JohannesGaessler commited on
Commit
5931562
·
1 Parent(s): d4b3604

CUDA: faster q2_K, q3_K MMQ + int8 tensor cores (llama/7921)

Browse files

* CUDA: faster q2_K, q3_K MMQ + int8 tensor cores

* try CI fix

* try CI fix

* try CI fix

* fix data race

* rever q2_K precision related changes

ggml-cuda.cu CHANGED
@@ -188,13 +188,15 @@ static ggml_cuda_device_info ggml_cuda_init() {
188
  info.default_tensor_split[id] = total_vram;
189
  total_vram += prop.totalGlobalMem;
190
 
 
 
191
  #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 
192
  info.devices[id].cc = 100*prop.major + 10*prop.minor + CC_OFFSET_AMD;
193
  #else
 
194
  info.devices[id].cc = 100*prop.major + 10*prop.minor;
195
  #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
196
- info.devices[id].smpb = prop.sharedMemPerBlock;
197
- info.devices[id].nsm = prop.multiProcessorCount;
198
  }
199
 
200
  for (int id = 0; id < info.device_count; ++id) {
 
188
  info.default_tensor_split[id] = total_vram;
189
  total_vram += prop.totalGlobalMem;
190
 
191
+ info.devices[id].nsm = prop.multiProcessorCount;
192
+ info.devices[id].smpb = prop.sharedMemPerBlock;
193
  #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
194
+ info.devices[id].smpbo = prop.sharedMemPerBlock;
195
  info.devices[id].cc = 100*prop.major + 10*prop.minor + CC_OFFSET_AMD;
196
  #else
197
+ info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
198
  info.devices[id].cc = 100*prop.major + 10*prop.minor;
199
  #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 
 
200
  }
201
 
202
  for (int id = 0; id < info.device_count; ++id) {
ggml-cuda/argsort.cu CHANGED
@@ -73,6 +73,7 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co
73
  const dim3 block_nums(1, nrows, 1);
74
  const size_t shared_mem = ncols_pad * sizeof(int);
75
 
 
76
  GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);
77
 
78
  if (order == GGML_SORT_ORDER_ASC) {
 
73
  const dim3 block_nums(1, nrows, 1);
74
  const size_t shared_mem = ncols_pad * sizeof(int);
75
 
76
+ // FIXME: this limit could be raised by ~2-4x on Ampere or newer
77
  GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);
78
 
79
  if (order == GGML_SORT_ORDER_ASC) {
ggml-cuda/common.cuh CHANGED
@@ -331,6 +331,10 @@ static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int
331
  #define FP16_AVAILABLE
332
  #endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
333
 
 
 
 
 
334
  #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
335
  #define FP16_MMA_AVAILABLE
336
  #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
@@ -661,6 +665,7 @@ struct ggml_cuda_device_info {
661
  int cc; // compute capability
662
  int nsm; // number of streaming multiprocessors
663
  size_t smpb; // max. shared memory per block
 
664
  bool vmm; // virtual memory support
665
  size_t vmm_granularity; // granularity of virtual memory
666
  size_t total_vram;
 
331
  #define FP16_AVAILABLE
332
  #endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
333
 
334
+ #if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
335
+ #define FAST_FP16_AVAILABLE
336
+ #endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
337
+
338
  #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
339
  #define FP16_MMA_AVAILABLE
340
  #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
 
665
  int cc; // compute capability
666
  int nsm; // number of streaming multiprocessors
667
  size_t smpb; // max. shared memory per block
668
+ size_t smpbo; // max. shared memory per block (with opt-in)
669
  bool vmm; // virtual memory support
670
  size_t vmm_granularity; // granularity of virtual memory
671
  size_t total_vram;
ggml-cuda/mmq.cuh CHANGED
@@ -10,10 +10,10 @@
10
  #define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
11
 
12
  typedef void (*load_tiles_mmq_t)(
13
- const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
14
  int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride);
15
  typedef void (*vec_dot_mmq_t)(
16
- const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
17
  const int * __restrict__ y, float * __restrict__ sum, const int & k0);
18
  typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1);
19
 
@@ -25,9 +25,8 @@ static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected b
25
  static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "Unexpected block_q8_1_mmq size");
26
 
27
  struct tile_x_sizes {
28
- int ql;
29
  int dm;
30
- int qh;
31
  int sc;
32
  };
33
 
@@ -67,16 +66,16 @@ static constexpr __device__ int get_mmq_y_device(int /*mmq_x*/) {
67
  #endif // __CUDA_ARCH__ >= CC_VOLTA
68
  #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
69
 
70
- #define TILE_X_SIZES_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0, 0}
71
- #define TILE_X_SIZES_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0, 0}
72
- #define TILE_X_SIZES_Q5_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_0 + mmq_y/QI5_0, 0, 0}
73
- #define TILE_X_SIZES_Q5_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_1 + mmq_y/QI5_1, 0, 0}
74
- #define TILE_X_SIZES_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI8_0 + mmq_y/QI8_0, 0, 0}
75
- #define TILE_X_SIZES_Q2_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI2_K + mmq_y/QI2_K, 0, mmq_y*WARP_SIZE/4 + mmq_y/4}
76
- #define TILE_X_SIZES_Q3_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI3_K + mmq_y/QI3_K, mmq_y*WARP_SIZE/2 + mmq_y/2, mmq_y*WARP_SIZE/4 + mmq_y/4}
77
- #define TILE_X_SIZES_Q4_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_K + mmq_y/QI4_K, 0, mmq_y*WARP_SIZE/8 + mmq_y/8}
78
- #define TILE_X_SIZES_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, 0, mmq_y*WARP_SIZE/8 + mmq_y/8}
79
- #define TILE_X_SIZES_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, 0, mmq_y*WARP_SIZE/8 + mmq_y/8}
80
 
81
  #define GET_TILE_X_SIZES_BODY \
82
  return type == GGML_TYPE_Q4_0 ? TILE_X_SIZES_Q4_0 : \
@@ -89,7 +88,7 @@ static constexpr __device__ int get_mmq_y_device(int /*mmq_x*/) {
89
  type == GGML_TYPE_Q4_K ? TILE_X_SIZES_Q4_K : \
90
  type == GGML_TYPE_Q5_K ? TILE_X_SIZES_Q5_K : \
91
  type == GGML_TYPE_Q6_K ? TILE_X_SIZES_Q6_K : \
92
- tile_x_sizes{0, 0, 0, 0}
93
 
94
  static tile_x_sizes get_tile_x_sizes_host(const ggml_type type, const int mmq_y) {
95
  GET_TILE_X_SIZES_BODY;
@@ -103,9 +102,9 @@ static constexpr __device__ tile_x_sizes get_tile_x_sizes_device(ggml_type type)
103
  // ------------------------------------------------------------
104
 
105
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
106
- const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
107
  int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
108
- GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
109
 
110
  const int kbx = threadIdx.x / QI4_0;
111
  const int kqsx = threadIdx.x % QI4_0;
@@ -122,7 +121,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
122
 
123
  const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
124
 
125
- x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx);
126
  }
127
 
128
  const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
@@ -144,10 +143,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
144
 
145
  template <int mmq_x, int mmq_y, int nwarps>
146
  static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
147
- const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
148
  const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
149
-
150
- GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
151
 
152
  const float * x_df = (const float *) x_dm;
153
  const int * y_qs = (const int *) y + 4;
@@ -172,7 +170,7 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
172
  }
173
 
174
  sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
175
- (&x_ql[i*(WARP_SIZE + 1) + k0], u, x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0],
176
  y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
177
  }
178
  }
@@ -180,10 +178,10 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
180
 
181
  template <int mmq_x, int mmq_y, int nwarps>
182
  static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma(
183
- const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
184
  const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
185
-
186
- GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
187
 
188
  typedef mma_int_A_I16K8 mma_A;
189
  typedef mma_int_B_J8K8 mma_B;
@@ -205,7 +203,7 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma(
205
  const int k = k0 + mma_A::get_k(l) % QI4_0;
206
  const int shift = 4*(mma_A::get_k(l) / QI4_0);
207
 
208
- A.x[l] = __vsubss4((x_ql[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F, 0x08080808);
209
  }
210
  #pragma unroll
211
  for (int l = 0; l < mma_C::ne/2; ++l) {
@@ -240,12 +238,16 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma(
240
  sum[(j0/B.J)*C.ne + l] += dA[l/2]*__low2float(dsB[l%2])*C.x[l];
241
  }
242
  }
 
 
 
 
243
  }
244
 
245
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
246
- const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
247
  int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
248
- GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
249
 
250
  const int kbx = threadIdx.x / QI4_1;
251
  const int kqsx = threadIdx.x % QI4_1;
@@ -260,7 +262,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
260
 
261
  const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
262
 
263
- x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
264
  }
265
 
266
  const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
@@ -282,10 +284,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
282
 
283
  template <int mmq_x, int mmq_y, int nwarps>
284
  static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
285
- const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
286
  const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
287
-
288
- GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
289
 
290
  const int * y_qs = (const int *) y + 4;
291
  const half2 * y_ds = (const half2 *) y;
@@ -309,7 +310,7 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
309
  }
310
 
311
  sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
312
- (&x_ql[i*(WARP_SIZE + 1) + k0], u, x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + k0/QI4_1],
313
  y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
314
  }
315
  }
@@ -317,10 +318,10 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
317
 
318
  template <int mmq_x, int mmq_y, int nwarps>
319
  static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma(
320
- const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
321
  const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
322
-
323
- GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
324
 
325
  typedef mma_int_A_I16K8 mma_A;
326
  typedef mma_int_B_J8K8 mma_B;
@@ -341,7 +342,7 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma(
341
  const int k = k0 + mma_A::get_k(l) % QI4_0;
342
  const int shift = 4*(mma_A::get_k(l) / QI4_0);
343
 
344
- A.x[l] = (x_ql[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F;
345
  }
346
  #pragma unroll
347
  for (int l = 0; l < mma_C::ne/2; ++l) {
@@ -377,12 +378,16 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma(
377
  sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
378
  }
379
  }
 
 
 
 
380
  }
381
 
382
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
383
- const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
384
  int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
385
- GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
386
 
387
  const int kbx = threadIdx.x / QI5_0;
388
  const int kqsx = threadIdx.x % QI5_0;
@@ -407,7 +412,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
407
  qs0 |= (qh << 25) & 0x10000000; // 3 -> 28
408
  qs0 = __vsubss4(qs0, 0x10101010); // subtract 16
409
 
410
- x_ql[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+0] = qs0;
411
 
412
  int qs1 = (ql >> 4) & 0x0F0F0F0F;
413
  qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4
@@ -416,7 +421,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
416
  qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
417
  qs1 = __vsubss4(qs1, 0x10101010); // subtract 16
418
 
419
- x_ql[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+1] = qs1;
420
  }
421
 
422
  const int blocks_per_tile_x_row = WARP_SIZE / QI5_0;
@@ -439,10 +444,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
439
 
440
  template <int mmq_x, int mmq_y, int nwarps>
441
  static __device__ __forceinline__ void vec_dot_q5_0_q8_1_dp4a(
442
- const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
443
  const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
444
-
445
- GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
446
 
447
  const float * x_dmf = (const float *) x_dm;
448
  const int * y_qs = (const int *) y + 4;
@@ -468,17 +472,17 @@ static __device__ __forceinline__ void vec_dot_q5_0_q8_1_dp4a(
468
  }
469
 
470
  sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, QR5_0*VDR_Q5_0_Q8_1_MMQ>
471
- (&x_ql[i*(2*WARP_SIZE + 1) + 2*k0], u, x_dmf[index_bx], y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
472
  }
473
  }
474
  }
475
 
476
  template <int mmq_x, int mmq_y, int nwarps>
477
  static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma(
478
- const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
479
  const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
480
-
481
- GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
482
 
483
  typedef mma_int_A_I16K8 mma_A;
484
  typedef mma_int_B_J8K8 mma_B;
@@ -499,7 +503,7 @@ static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma(
499
  const int i = i0 + mma_A::get_i(l);
500
  const int k = 2*(k0 + mma_A::get_k(l) % QI5_0) + mma_A::get_k(l) / QI5_0;
501
 
502
- A.x[l] = x_ql[i*(2*WARP_SIZE + 1) + k];
503
  }
504
  #pragma unroll
505
  for (int l = 0; l < mma_C::ne/2; ++l) {
@@ -534,12 +538,16 @@ static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma(
534
  sum[(j0/B.J)*C.ne + l] += dA[l/2]*dB[l%2]*C.x[l];
535
  }
536
  }
 
 
 
 
537
  }
538
 
539
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
540
- const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
541
  int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
542
- GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
543
 
544
  const int kbx = threadIdx.x / QI5_1;
545
  const int kqsx = threadIdx.x % QI5_1;
@@ -563,7 +571,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
563
  qs0 |= (qh << 18) & 0x00100000; // 2 -> 20
564
  qs0 |= (qh << 25) & 0x10000000; // 3 -> 28
565
 
566
- x_ql[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+0] = qs0;
567
 
568
  int qs1 = (ql >> 4) & 0x0F0F0F0F;
569
  qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4
@@ -571,7 +579,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
571
  qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
572
  qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
573
 
574
- x_ql[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+1] = qs1;
575
  }
576
 
577
  const int blocks_per_tile_x_row = WARP_SIZE / QI5_1;
@@ -593,10 +601,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
593
 
594
  template <int mmq_x, int mmq_y, int nwarps>
595
  static __device__ __forceinline__ void vec_dot_q5_1_q8_1_dp4a(
596
- const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
597
  const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
598
-
599
- GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
600
 
601
  const int * y_qs = (const int *) y + 4;
602
  const half2 * y_ds = (const half2 *) y;
@@ -621,17 +628,17 @@ static __device__ __forceinline__ void vec_dot_q5_1_q8_1_dp4a(
621
  }
622
 
623
  sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
624
- (&x_ql[i*(2*WARP_SIZE + 1) + 2*k0], u, x_dm[index_bx], y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
625
  }
626
  }
627
  }
628
 
629
  template <int mmq_x, int mmq_y, int nwarps>
630
  static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma(
631
- const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
632
  const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
633
-
634
- GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
635
 
636
  typedef mma_int_A_I16K8 mma_A;
637
  typedef mma_int_B_J8K8 mma_B;
@@ -651,7 +658,7 @@ static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma(
651
  const int i = i0 + mma_A::get_i(l);
652
  const int k = 2*(k0 + mma_A::get_k(l) % QI5_1) + mma_A::get_k(l) / QI5_1;
653
 
654
- A.x[l] = x_ql[i*(2*WARP_SIZE + 1) + k];
655
  }
656
  #pragma unroll
657
  for (int l = 0; l < mma_C::ne/2; ++l) {
@@ -687,13 +694,16 @@ static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma(
687
  sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
688
  }
689
  }
 
 
 
 
690
  }
691
 
692
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
693
- const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
694
  int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
695
-
696
- GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
697
 
698
  const int kbx = threadIdx.x / QI8_0;
699
  const int kqsx = threadIdx.x % QI8_0;
@@ -709,7 +719,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
709
 
710
  const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
711
 
712
- x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_int8(bxi->qs, kqsx);
713
  }
714
 
715
  const int blocks_per_tile_x_row = WARP_SIZE / QI8_0;
@@ -731,10 +741,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
731
 
732
  template <int mmq_x, int mmq_y, int nwarps>
733
  static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
734
- const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
735
  const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
736
-
737
- GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
738
 
739
  const float * x_dmf = (const float *) x_dm;
740
  const int * y_qs = (const int *) y + 4;
@@ -749,7 +758,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
749
  const int i = i0 + threadIdx.x;
750
 
751
  sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
752
- (&x_ql[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0], x_dmf[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0],
753
  y_df[j*MMQ_TILE_Y_K + k0/QI8_1]);
754
  }
755
  }
@@ -757,10 +766,10 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
757
 
758
  template <int mmq_x, int mmq_y, int nwarps>
759
  static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
760
- const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
761
  const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
762
-
763
- GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
764
 
765
  typedef mma_int_A_I16K8 mma_A;
766
  typedef mma_int_B_J8K8 mma_B;
@@ -781,7 +790,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
781
  const int i = i0 + mma_A::get_i(l);
782
  const int k = k0 + mma_A::get_k(l);
783
 
784
- A.x[l] = x_ql[i*(WARP_SIZE + 1) + k];
785
  }
786
  #pragma unroll
787
  for (int l = 0; l < mma_C::ne/2; ++l) {
@@ -816,12 +825,15 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
816
  sum[(j0/B.J)*C.ne + l] += C.x[l]*dA[l/2]*dB[l%2];
817
  }
818
  }
 
 
 
 
819
  }
820
 
821
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
822
- const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
823
  int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
824
- GGML_UNUSED(x_qh);
825
 
826
  const int kbx = threadIdx.x / QI2_K;
827
  const int kqsx = threadIdx.x % QI2_K;
@@ -836,48 +848,42 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
836
 
837
  const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + kbx;
838
 
839
- x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
840
- }
841
-
842
- const int blocks_per_tile_x_row = WARP_SIZE / QI2_K;
843
- const int kbxd = threadIdx.x % blocks_per_tile_x_row;
844
 
845
  #pragma unroll
846
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI2_K) {
847
- int i = (i0 + threadIdx.y * QI2_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;
848
 
849
- if (need_check) {
850
- i = min(i, i_max);
851
- }
852
-
853
- const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + kbxd;
854
-
855
- x_dm[i * (WARP_SIZE/QI2_K) + i / QI2_K + kbxd] = bxi->dm;
856
- }
857
 
858
- #pragma unroll
859
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
860
- int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
861
 
862
- if (need_check) {
863
- i = min(i, i_max);
864
  }
865
 
866
- const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/4)) / (QI2_K/4);
 
 
 
 
 
 
867
 
868
- x_sc[i * (WARP_SIZE/4) + i / 4 + threadIdx.x % (WARP_SIZE/4)] = get_int_from_uint8_aligned(bxi->scales, threadIdx.x % (QI2_K/4));
869
  }
870
  }
871
 
872
  template <int mmq_x, int mmq_y, int nwarps>
873
- static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mul_mat(
874
- const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
875
  const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
876
 
877
- GGML_UNUSED(x_qh);
878
-
879
- const int * y_qs = (const int *) y + 4;
880
- const float * y_df = (const float *) y;
881
 
882
  #pragma unroll
883
  for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
@@ -887,30 +893,99 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mul_mat(
887
  for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
888
  const int i = i0 + threadIdx.x;
889
 
890
- const int kbx = k0 / QI2_K;
891
- const int ky = (k0 % QI2_K) * QR2_K;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
892
 
893
- int v[QR2_K*VDR_Q2_K_Q8_1_MMQ];
 
894
 
895
- const int kqsx = i*(WARP_SIZE + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2);
896
- const int shift = 2 * ((ky % (2*QI2_K)) / (QI2_K/2));
 
897
 
898
  #pragma unroll
899
- for (int l = 0; l < QR2_K*VDR_Q2_K_Q8_1_MMQ; ++l) {
900
- v[l] = (x_ql[kqsx + l] >> shift) & 0x03030303;
901
- }
902
 
903
- const uint8_t * scales = ((const uint8_t *) &x_sc[i*(WARP_SIZE/4) + i/4 + kbx*4]) + ky/4;
 
 
904
 
905
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq(
906
- v, &y_qs[j*MMQ_TILE_Y_K + (QR2_K*k0) % WARP_SIZE], scales,
907
- x_dm[i*(WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[j*MMQ_TILE_Y_K + ((QR2_K*k0) % WARP_SIZE)/QI8_1]);
 
 
 
 
 
 
 
908
  }
909
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
910
  }
911
 
912
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
913
- const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
914
  int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
915
 
916
  const int kbx = threadIdx.x / QI3_K;
@@ -926,7 +1001,25 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
926
 
927
  const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + kbx;
928
 
929
- x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
930
  }
931
 
932
  const int blocks_per_tile_x_row = WARP_SIZE / QI3_K;
@@ -946,20 +1039,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
946
  x_dmf[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd] = bxi->d;
947
  }
948
 
949
- #pragma unroll
950
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 2) {
951
- int i = i0 + threadIdx.y * 2 + threadIdx.x / (WARP_SIZE/2);
952
-
953
- if (need_check) {
954
- i = min(i, i_max);
955
- }
956
-
957
- const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/2)) / (QI3_K/2);
958
-
959
- // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
960
- x_qh[i * (WARP_SIZE/2) + i / 2 + threadIdx.x % (WARP_SIZE/2)] = ~get_int_from_uint8(bxi->hmask, threadIdx.x % (QI3_K/2));
961
- }
962
-
963
  #pragma unroll
964
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
965
  int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
@@ -987,13 +1066,13 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
987
  }
988
 
989
  template <int mmq_x, int mmq_y, int nwarps>
990
- static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mul_mat(
991
- const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
992
  const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
993
 
994
- const float * x_dmf = (const float *) x_dm;
995
- const int * y_qs = (const int *) y + 4;
996
- const float * y_df = (const float *) y;
997
 
998
  #pragma unroll
999
  for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
@@ -1008,31 +1087,102 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mul_mat(
1008
 
1009
  const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;
1010
 
1011
- int v[QR3_K*VDR_Q3_K_Q8_1_MMQ];
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1012
 
1013
  #pragma unroll
1014
- for (int l = 0; l < QR3_K*VDR_Q3_K_Q8_1_MMQ; ++l) {
1015
- const int kqsx = i*(WARP_SIZE + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2);
1016
- const int shift = 2 * ((ky % 32) / 8);
1017
- const int vll = (x_ql[kqsx + l] >> shift) & 0x03030303;
1018
 
1019
- const int vh = x_qh[i*(WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8);
1020
- const int vlh = (vh << 2) & 0x04040404;
 
 
 
1021
 
1022
- v[l] = __vsubss4(vll, vlh);
1023
- }
 
1024
 
1025
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q3_K_q8_1_impl_mmq(
1026
- v, &y_qs[j*MMQ_TILE_Y_K + (k0*QR3_K) % WARP_SIZE], scales,
1027
- x_dmf[i*(WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[j*MMQ_TILE_Y_K + ((k0*QR3_K) % WARP_SIZE)/QI8_1]);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1028
  }
1029
  }
 
 
 
 
1030
  }
1031
 
1032
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
1033
- const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
1034
  int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
1035
- GGML_UNUSED(x_qh);
1036
 
1037
  const int kbx = 0; // threadIdx.x / QI4_K
1038
  const int kqsx = threadIdx.x; // threadIdx.x % QI4_K
@@ -1047,7 +1197,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1047
 
1048
  const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbx;
1049
 
1050
- x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
1051
  }
1052
 
1053
  const int blocks_per_tile_x_row = WARP_SIZE / QI4_K; // == 1 if QK_K == 256
@@ -1090,11 +1240,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1090
 
1091
  template <int mmq_x, int mmq_y, int nwarps>
1092
  static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
1093
- const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
1094
  const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
1095
 
1096
- GGML_UNUSED(x_qh);
1097
-
1098
  const int * y_qs = (const int *) y + 4;
1099
  const half2 * y_ds = (const half2 *) y;
1100
 
@@ -1109,7 +1257,7 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
1109
  const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2*((k0 % 16) / 8);
1110
 
1111
  sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq(
1112
- &x_ql[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + (QR4_K*k0) % WARP_SIZE], sc, sc+8,
1113
  x_dm[i*(WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[j*MMQ_TILE_Y_K + ((QR4_K*k0) % WARP_SIZE)/QI8_1]);
1114
  }
1115
  }
@@ -1117,10 +1265,9 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
1117
 
1118
  template <int mmq_x, int mmq_y, int nwarps>
1119
  static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma(
1120
- const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
1121
  const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
1122
-
1123
- GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
1124
 
1125
  typedef mma_int_A_I16K8 mma_A;
1126
  typedef mma_int_B_J8K8 mma_B;
@@ -1143,7 +1290,7 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma(
1143
  const int i = i0 + mma_A::get_i(l);
1144
  const int k = k0 + mma_A::get_k(l);
1145
 
1146
- A[kvdr/4].x[l] = (x_ql[i*(WARP_SIZE + 1) + k] >> kvdr) & 0x0F0F0F0F;
1147
  }
1148
 
1149
  #pragma unroll
@@ -1204,12 +1351,15 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma(
1204
  sum[(j0/mma_B::J)*mma_C::ne + l] += __low2float(dmA[l/2])*tmpd[l] - __high2float(dmA[l/2])*tmpm[l];
1205
  }
1206
  }
 
 
 
 
1207
  }
1208
 
1209
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
1210
- const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
1211
  int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
1212
- GGML_UNUSED(x_qh);
1213
 
1214
  const int kbx = 0; // threadIdx.x / QI5_K
1215
  const int kqsx = threadIdx.x; // threadIdx.x % QI5_K
@@ -1236,8 +1386,8 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1236
  const int kq0 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + 0;
1237
  const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + (QI5_K/4);
1238
 
1239
- x_ql[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0;
1240
- x_ql[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1;
1241
  }
1242
 
1243
  const int blocks_per_tile_x_row = WARP_SIZE / QI5_K; // == 1 if QK_K == 256
@@ -1280,11 +1430,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1280
 
1281
  template <int mmq_x, int mmq_y, int nwarps>
1282
  static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
1283
- const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
1284
  const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
1285
 
1286
- GGML_UNUSED(x_qh);
1287
-
1288
  const int * y_qs = (const int *) y + 4;
1289
  const half2 * y_ds = (const half2 *) y;
1290
 
@@ -1299,7 +1447,7 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
1299
  const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8);
1300
 
1301
  sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq(
1302
- &x_ql[i*(QR5_K*WARP_SIZE + 1) + QR5_K*k0], &y_qs[j*MMQ_TILE_Y_K + (QR5_K*k0) % WARP_SIZE], sc, sc+8,
1303
  x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[j*MMQ_TILE_Y_K + ((QR5_K*k0) % WARP_SIZE)/QI8_1]);
1304
  }
1305
  }
@@ -1307,10 +1455,9 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
1307
 
1308
  template <int mmq_x, int mmq_y, int nwarps>
1309
  static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma(
1310
- const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
1311
  const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
1312
-
1313
- GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
1314
 
1315
  typedef mma_int_A_I16K8 mma_A;
1316
  typedef mma_int_B_J8K8 mma_B;
@@ -1333,7 +1480,7 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma(
1333
  const int i = i0 + mma_A::get_i(l);
1334
  const int k = QR5_K*k0 + QR5_K*kvdr + mma_A::get_k(l);
1335
 
1336
- A[kvdr/4].x[l] = x_ql[i*(QR5_K*WARP_SIZE + 1) + k];
1337
  }
1338
 
1339
  #pragma unroll
@@ -1394,12 +1541,15 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma(
1394
  sum[(j0/mma_B::J)*mma_C::ne + l] += __low2float(dmA[l/2])*tmpd[l] - __high2float(dmA[l/2])*tmpm[l];
1395
  }
1396
  }
 
 
 
 
1397
  }
1398
 
1399
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
1400
- const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
1401
  int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
1402
- GGML_UNUSED(x_qh);
1403
 
1404
  const int kbx = 0; // threadIdx.x / QI6_K
1405
  const int kqsx = threadIdx.x; // threadIdx.x % QI6_K
@@ -1426,8 +1576,8 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1426
  const int kq0 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + 0;
1427
  const int kq1 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + (QI6_K/2);
1428
 
1429
- x_ql[i * (2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
1430
- x_ql[i * (2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
1431
  }
1432
 
1433
  const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256
@@ -1463,11 +1613,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1463
 
1464
  template <int mmq_x, int mmq_y, int nwarps>
1465
  static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
1466
- const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
1467
  const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
1468
 
1469
- GGML_UNUSED(x_qh);
1470
-
1471
  const float * x_dmf = (const float *) x_dm;
1472
  const int * y_qs = (const int *) y + 4;
1473
  const float * y_df = (const float *) y;
@@ -1483,7 +1631,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
1483
  const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/8]);
1484
 
1485
  sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq(
1486
- &x_ql[i*(QR6_K*WARP_SIZE + 1) + QR6_K*k0], &y_qs[j*MMQ_TILE_Y_K + (QR6_K*k0) % WARP_SIZE], sc,
1487
  x_dmf[i*(WARP_SIZE/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + ((QR6_K*k0) % WARP_SIZE)/QI8_1]);
1488
  }
1489
  }
@@ -1491,10 +1639,9 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
1491
 
1492
  template <int mmq_x, int mmq_y, int nwarps>
1493
  static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
1494
- const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
1495
  const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
1496
-
1497
- GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
1498
 
1499
  typedef mma_int_A_I16K4 mma_A;
1500
  typedef mma_int_B_J8K4 mma_B;
@@ -1505,7 +1652,9 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
1505
  const float * y_df = (const float *) y;
1506
 
1507
  const int i0 = threadIdx.y*mma_A::I;
 
1508
  static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
 
1509
 
1510
  mma_A A[4];
1511
  int scA[mma_C::ne/2][4];
@@ -1517,8 +1666,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
1517
  const int i = i0 + mma_A::get_i(l);
1518
  const int k = QR6_K*k0 + QR6_K*kvdr + mma_A::get_k(l);
1519
 
1520
- A[kvdr/2 + 0].x[l] = x_ql[i*(QR6_K*WARP_SIZE + 1) + k + 0];
1521
- A[kvdr/2 + 1].x[l] = x_ql[i*(QR6_K*WARP_SIZE + 1) + k + mma_A::K];
1522
  }
1523
 
1524
  #pragma unroll
@@ -1578,6 +1727,10 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
1578
  sum[(j0/mma_B::J)*mma_C::ne + l] += tmp[l]*dA[l/2];
1579
  }
1580
  }
 
 
 
 
1581
  }
1582
 
1583
  template<int mmq_x, int mmq_y, int nwarps, bool need_check>
@@ -1608,7 +1761,9 @@ static __device__ __forceinline__ void mmq_write_back_mma(const float * __restri
1608
  typedef mma_int_C_I16J8 mma_C;
1609
 
1610
  const int i0 = threadIdx.y*mma_C::I;
 
1611
  static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y");
 
1612
 
1613
  #pragma unroll
1614
  for (int j0 = 0; j0 < mmq_x; j0 += mma_C::J) {
@@ -1638,125 +1793,85 @@ struct mmq_type_traits;
1638
 
1639
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
1640
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_0> {
1641
- static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;
1642
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, nwarps, need_check>;
1643
- #ifdef INT8_MMA_AVAILABLE
1644
- static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
1645
- static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
1646
- #else
1647
- static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
1648
- static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
1649
- #endif // INT8_MMA_AVAILABLE
1650
  };
1651
 
1652
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
1653
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_1> {
1654
- static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ;
1655
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, nwarps, need_check>;
1656
- #ifdef INT8_MMA_AVAILABLE
1657
- static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
1658
- static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
1659
- #else
1660
- static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
1661
- static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
1662
- #endif // INT8_MMA_AVAILABLE
1663
  };
1664
 
1665
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
1666
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
1667
- static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ;
1668
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, nwarps, need_check>;
1669
- #ifdef INT8_MMA_AVAILABLE
1670
- static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
1671
- static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
1672
- #else
1673
- static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
1674
- static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
1675
- #endif // INT8_MMA_AVAILABLE
1676
  };
1677
 
1678
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
1679
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
1680
- static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ;
1681
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, nwarps, need_check>;
1682
- #ifdef INT8_MMA_AVAILABLE
1683
- static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
1684
- static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
1685
- #else
1686
- static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
1687
- static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
1688
- #endif // INT8_MMA_AVAILABLE
1689
  };
1690
 
1691
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
1692
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
1693
- static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ;
1694
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, nwarps, need_check>;
1695
- #ifdef INT8_MMA_AVAILABLE
1696
- static constexpr vec_dot_mmq_t vec_dot = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
1697
- static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
1698
- #else
1699
- static constexpr vec_dot_mmq_t vec_dot = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
1700
- static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
1701
- #endif // INT8_MMA_AVAILABLE
1702
  };
1703
 
1704
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
1705
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q2_K> {
1706
- static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
1707
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y, nwarps, need_check>;
1708
- static constexpr vec_dot_mmq_t vec_dot = vec_dot_q2_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
1709
- static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
1710
  };
1711
 
1712
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
1713
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q3_K> {
1714
- static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ;
1715
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, nwarps, need_check>;
1716
- static constexpr vec_dot_mmq_t vec_dot = vec_dot_q3_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
1717
- static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
1718
  };
1719
 
1720
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
1721
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> {
1722
- static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ;
1723
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, nwarps, need_check>;
1724
- #ifdef INT8_MMA_AVAILABLE
1725
- static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
1726
- static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
1727
- #else
1728
- static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
1729
- static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
1730
- #endif // INT8_MMA_AVAILABLE
1731
  };
1732
 
1733
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
1734
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> {
1735
- static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ;
1736
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, nwarps, need_check>;
1737
- #ifdef INT8_MMA_AVAILABLE
1738
- static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
1739
- static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
1740
- #else
1741
- static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
1742
- static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
1743
- #endif // INT8_MMA_AVAILABLE
1744
  };
1745
 
1746
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
1747
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
1748
- static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ;
1749
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, nwarps, need_check>;
1750
- #ifdef INT8_MMA_AVAILABLE
1751
- static constexpr vec_dot_mmq_t vec_dot = vec_dot_q6_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
1752
- static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
1753
- #else
1754
- static constexpr vec_dot_mmq_t vec_dot = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
1755
- static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
1756
- #endif // INT8_MMA_AVAILABLE
1757
  };
1758
 
1759
- static int mmq_need_sum(const ggml_type type_x) {
1760
  switch (type_x) {
1761
  case GGML_TYPE_Q4_0:
1762
  case GGML_TYPE_Q4_1:
@@ -1790,7 +1905,7 @@ template <ggml_type type, int mmq_x, int nwarps, bool need_check>
1790
  #if __CUDA_ARCH__ >= CC_VOLTA
1791
  __launch_bounds__(WARP_SIZE*nwarps, 1)
1792
  #else
1793
- __launch_bounds__(WARP_SIZE*nwarps, type == GGML_TYPE_Q2_K ? 1 : 2)
1794
  #endif // __CUDA_ARCH__ >= CC_VOLTA
1795
  #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
1796
  static __global__ void mul_mat_q(
@@ -1809,16 +1924,21 @@ static __global__ void mul_mat_q(
1809
  constexpr int mmq_y = get_mmq_y_device(mmq_x);
1810
  constexpr int vdr = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vdr;
1811
  constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
1812
- constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot;
1813
- constexpr mmq_write_back_t write_back = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::write_back;
 
 
 
 
 
 
1814
 
1815
  constexpr tile_x_sizes txs = get_tile_x_sizes_device<mmq_y>(type);
1816
 
1817
  extern __shared__ char data_mul_mat_q[];
1818
- int * tile_x_ql = (int *) data_mul_mat_q;
1819
- half2 * tile_x_dm = (half2 *) (tile_x_ql + txs.ql);
1820
- int * tile_x_qh = (int *) (tile_x_dm + txs.dm);
1821
- int * tile_x_sc = (int *) (tile_x_qh + txs.qh);
1822
  int * tile_y = (int *) (tile_x_sc + txs.sc); // [mmq_x * (WARP_SIZE + WARP_SIZE/QI8_1)]
1823
 
1824
  const int blocks_per_row_x = ne00 / qk;
@@ -1834,7 +1954,7 @@ static __global__ void mul_mat_q(
1834
 
1835
  for (int kb0 = 0; kb0 < blocks_per_row_x; kb0 += blocks_per_warp) {
1836
 
1837
- load_tiles(x, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, stride01*blockIdx.x*mmq_y + kb0, tile_x_max_i, stride01);
1838
 
1839
  #pragma unroll
1840
  for (int kr = 0; kr < qr; ++kr) {
@@ -1850,7 +1970,7 @@ static __global__ void mul_mat_q(
1850
 
1851
  // #pragma unroll // unrolling this loop causes too much register pressure
1852
  for (int k0 = kr*WARP_SIZE/qr; k0 < (kr+1)*WARP_SIZE/qr; k0 += vdr) {
1853
- vec_dot(tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y, sum, k0);
1854
  }
1855
 
1856
  __syncthreads();
@@ -1867,6 +1987,19 @@ struct mmq_args {
1867
  int64_t ne0;
1868
  };
1869
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1870
  template <ggml_type type, int mmq_x, int nwarps>
1871
  static void launch_mul_mat_q(const mmq_args & args, cudaStream_t stream) {
1872
  const int id = ggml_cuda_get_device();
@@ -1878,10 +2011,7 @@ static void launch_mul_mat_q(const mmq_args & args, cudaStream_t stream) {
1878
  const dim3 block_nums(block_num_x, block_num_y, 1);
1879
  const dim3 block_dims(WARP_SIZE, nwarps, 1);
1880
 
1881
- const tile_x_sizes txs = get_tile_x_sizes_host(type, mmq_y);
1882
- const int shmem_x = txs.ql*sizeof(int) + txs.dm*sizeof(half2) + txs.qh*sizeof(int) + txs.sc*sizeof(int);
1883
- const int shmem_y = mmq_x*WARP_SIZE*sizeof(int) + mmq_x*(WARP_SIZE/QI8_1)*sizeof(half2);
1884
- const int shmem = shmem_x + GGML_PAD(shmem_y, nwarps*WARP_SIZE*sizeof(int));
1885
 
1886
  #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
1887
  static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
@@ -1905,9 +2035,10 @@ static void launch_mul_mat_q(const mmq_args & args, cudaStream_t stream) {
1905
 
1906
  template <ggml_type type>
1907
  void mul_mat_q_case(const mmq_args & args, cudaStream_t stream) {
1908
- const int id = ggml_cuda_get_device();
1909
- const int nsm = ggml_cuda_info().devices[id].nsm;
1910
- const int cc = ggml_cuda_info().devices[id].cc;
 
1911
 
1912
  const int mmq_x_max = get_mmq_x_max_host(cc);
1913
  const int mmq_y = get_mmq_y_host(cc, mmq_x_max);
@@ -1920,7 +2051,7 @@ void mul_mat_q_case(const mmq_args & args, cudaStream_t stream) {
1920
  const int block_num_x = (args.ne11 + mmq_x - 1) / mmq_x;
1921
  const int nwaves = (block_num_x*block_num_y + nsm - 1) / nsm;
1922
 
1923
- if (nwaves < nwaves_best) {
1924
  mmq_x_best = mmq_x;
1925
  nwaves_best = nwaves;
1926
  }
@@ -1928,54 +2059,55 @@ void mul_mat_q_case(const mmq_args & args, cudaStream_t stream) {
1928
 
1929
  switch (mmq_x_best) {
1930
  case 8:
1931
- launch_mul_mat_q<type, 8, 4>(args, stream);
1932
  break;
1933
  case 16:
1934
- launch_mul_mat_q<type, 16, 4>(args, stream);
1935
  break;
1936
  case 24:
1937
- launch_mul_mat_q<type, 24, 4>(args, stream);
1938
  break;
1939
  case 32:
1940
- launch_mul_mat_q<type, 32, 8>(args, stream);
1941
  break;
1942
  case 40:
1943
- launch_mul_mat_q<type, 40, 8>(args, stream);
1944
  break;
1945
  case 48:
1946
- launch_mul_mat_q<type, 48, 8>(args, stream);
1947
  break;
1948
  case 56:
1949
- launch_mul_mat_q<type, 56, 8>(args, stream);
1950
  break;
1951
  case 64:
1952
- launch_mul_mat_q<type, 64, 8>(args, stream);
1953
  break;
1954
  case 72:
1955
- launch_mul_mat_q<type, 72, 8>(args, stream);
1956
  break;
1957
  case 80:
1958
- launch_mul_mat_q<type, 80, 8>(args, stream);
1959
  break;
1960
  case 88:
1961
- launch_mul_mat_q<type, 88, 8>(args, stream);
1962
  break;
1963
  case 96:
1964
- launch_mul_mat_q<type, 96, 8>(args, stream);
1965
  break;
1966
  case 104:
1967
- launch_mul_mat_q<type, 104, 8>(args, stream);
1968
  break;
1969
  case 112:
1970
- launch_mul_mat_q<type, 112, 8>(args, stream);
1971
  break;
1972
  case 120:
1973
- launch_mul_mat_q<type, 120, 8>(args, stream);
1974
  break;
1975
  case 128:
1976
- launch_mul_mat_q<type, 128, 8>(args, stream);
1977
  break;
1978
  default:
 
1979
  GGML_ASSERT(false);
1980
  break;
1981
  }
 
10
  #define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
11
 
12
  typedef void (*load_tiles_mmq_t)(
13
+ const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
14
  int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride);
15
  typedef void (*vec_dot_mmq_t)(
16
+ const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
17
  const int * __restrict__ y, float * __restrict__ sum, const int & k0);
18
  typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1);
19
 
 
25
  static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "Unexpected block_q8_1_mmq size");
26
 
27
  struct tile_x_sizes {
28
+ int qs;
29
  int dm;
 
30
  int sc;
31
  };
32
 
 
66
  #endif // __CUDA_ARCH__ >= CC_VOLTA
67
  #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
68
 
69
+ #define TILE_X_SIZES_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0}
70
+ #define TILE_X_SIZES_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0}
71
+ #define TILE_X_SIZES_Q5_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_0 + mmq_y/QI5_0, 0}
72
+ #define TILE_X_SIZES_Q5_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_1 + mmq_y/QI5_1, 0}
73
+ #define TILE_X_SIZES_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI8_0 + mmq_y/QI8_0, 0}
74
+ #define TILE_X_SIZES_Q2_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE + mmq_y, 0}
75
+ #define TILE_X_SIZES_Q3_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI3_K + mmq_y/QI3_K, mmq_y*WARP_SIZE/4 + mmq_y/4}
76
+ #define TILE_X_SIZES_Q4_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_K + mmq_y/QI4_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
77
+ #define TILE_X_SIZES_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
78
+ #define TILE_X_SIZES_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
79
 
80
  #define GET_TILE_X_SIZES_BODY \
81
  return type == GGML_TYPE_Q4_0 ? TILE_X_SIZES_Q4_0 : \
 
88
  type == GGML_TYPE_Q4_K ? TILE_X_SIZES_Q4_K : \
89
  type == GGML_TYPE_Q5_K ? TILE_X_SIZES_Q5_K : \
90
  type == GGML_TYPE_Q6_K ? TILE_X_SIZES_Q6_K : \
91
+ tile_x_sizes{0, 0, 0}
92
 
93
  static tile_x_sizes get_tile_x_sizes_host(const ggml_type type, const int mmq_y) {
94
  GET_TILE_X_SIZES_BODY;
 
102
  // ------------------------------------------------------------
103
 
104
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
105
+ const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
106
  int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
107
+ GGML_UNUSED(x_sc);
108
 
109
  const int kbx = threadIdx.x / QI4_0;
110
  const int kqsx = threadIdx.x % QI4_0;
 
121
 
122
  const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
123
 
124
+ x_qs[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx);
125
  }
126
 
127
  const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
 
143
 
144
  template <int mmq_x, int mmq_y, int nwarps>
145
  static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
146
+ const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
147
  const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
148
+ GGML_UNUSED(x_sc);
 
149
 
150
  const float * x_df = (const float *) x_dm;
151
  const int * y_qs = (const int *) y + 4;
 
170
  }
171
 
172
  sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
173
+ (&x_qs[i*(WARP_SIZE + 1) + k0], u, x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0],
174
  y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
175
  }
176
  }
 
178
 
179
  template <int mmq_x, int mmq_y, int nwarps>
180
  static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma(
181
+ const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
182
  const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
183
+ #ifdef INT8_MMA_AVAILABLE
184
+ GGML_UNUSED(x_sc);
185
 
186
  typedef mma_int_A_I16K8 mma_A;
187
  typedef mma_int_B_J8K8 mma_B;
 
203
  const int k = k0 + mma_A::get_k(l) % QI4_0;
204
  const int shift = 4*(mma_A::get_k(l) / QI4_0);
205
 
206
+ A.x[l] = __vsubss4((x_qs[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F, 0x08080808);
207
  }
208
  #pragma unroll
209
  for (int l = 0; l < mma_C::ne/2; ++l) {
 
238
  sum[(j0/B.J)*C.ne + l] += dA[l/2]*__low2float(dsB[l%2])*C.x[l];
239
  }
240
  }
241
+ #else
242
+ GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
243
+ NO_DEVICE_CODE;
244
+ #endif // INT8_MMA_AVAILABLE
245
  }
246
 
247
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
248
+ const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
249
  int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
250
+ GGML_UNUSED(x_sc);
251
 
252
  const int kbx = threadIdx.x / QI4_1;
253
  const int kqsx = threadIdx.x % QI4_1;
 
262
 
263
  const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
264
 
265
+ x_qs[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
266
  }
267
 
268
  const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
 
284
 
285
  template <int mmq_x, int mmq_y, int nwarps>
286
  static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
287
+ const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
288
  const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
289
+ GGML_UNUSED(x_sc);
 
290
 
291
  const int * y_qs = (const int *) y + 4;
292
  const half2 * y_ds = (const half2 *) y;
 
310
  }
311
 
312
  sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
313
+ (&x_qs[i*(WARP_SIZE + 1) + k0], u, x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + k0/QI4_1],
314
  y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
315
  }
316
  }
 
318
 
319
  template <int mmq_x, int mmq_y, int nwarps>
320
  static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma(
321
+ const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
322
  const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
323
+ #ifdef INT8_MMA_AVAILABLE
324
+ GGML_UNUSED(x_sc);
325
 
326
  typedef mma_int_A_I16K8 mma_A;
327
  typedef mma_int_B_J8K8 mma_B;
 
342
  const int k = k0 + mma_A::get_k(l) % QI4_0;
343
  const int shift = 4*(mma_A::get_k(l) / QI4_0);
344
 
345
+ A.x[l] = (x_qs[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F;
346
  }
347
  #pragma unroll
348
  for (int l = 0; l < mma_C::ne/2; ++l) {
 
378
  sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
379
  }
380
  }
381
+ #else
382
+ GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
383
+ NO_DEVICE_CODE;
384
+ #endif // INT8_MMA_AVAILABLE
385
  }
386
 
387
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
388
+ const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
389
  int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
390
+ GGML_UNUSED(x_sc);
391
 
392
  const int kbx = threadIdx.x / QI5_0;
393
  const int kqsx = threadIdx.x % QI5_0;
 
412
  qs0 |= (qh << 25) & 0x10000000; // 3 -> 28
413
  qs0 = __vsubss4(qs0, 0x10101010); // subtract 16
414
 
415
+ x_qs[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+0] = qs0;
416
 
417
  int qs1 = (ql >> 4) & 0x0F0F0F0F;
418
  qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4
 
421
  qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
422
  qs1 = __vsubss4(qs1, 0x10101010); // subtract 16
423
 
424
+ x_qs[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+1] = qs1;
425
  }
426
 
427
  const int blocks_per_tile_x_row = WARP_SIZE / QI5_0;
 
444
 
445
  template <int mmq_x, int mmq_y, int nwarps>
446
  static __device__ __forceinline__ void vec_dot_q5_0_q8_1_dp4a(
447
+ const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
448
  const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
449
+ GGML_UNUSED(x_sc);
 
450
 
451
  const float * x_dmf = (const float *) x_dm;
452
  const int * y_qs = (const int *) y + 4;
 
472
  }
473
 
474
  sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, QR5_0*VDR_Q5_0_Q8_1_MMQ>
475
+ (&x_qs[i*(2*WARP_SIZE + 1) + 2*k0], u, x_dmf[index_bx], y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
476
  }
477
  }
478
  }
479
 
480
  template <int mmq_x, int mmq_y, int nwarps>
481
  static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma(
482
+ const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
483
  const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
484
+ #ifdef INT8_MMA_AVAILABLE
485
+ GGML_UNUSED(x_sc);
486
 
487
  typedef mma_int_A_I16K8 mma_A;
488
  typedef mma_int_B_J8K8 mma_B;
 
503
  const int i = i0 + mma_A::get_i(l);
504
  const int k = 2*(k0 + mma_A::get_k(l) % QI5_0) + mma_A::get_k(l) / QI5_0;
505
 
506
+ A.x[l] = x_qs[i*(2*WARP_SIZE + 1) + k];
507
  }
508
  #pragma unroll
509
  for (int l = 0; l < mma_C::ne/2; ++l) {
 
538
  sum[(j0/B.J)*C.ne + l] += dA[l/2]*dB[l%2]*C.x[l];
539
  }
540
  }
541
+ #else
542
+ GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
543
+ NO_DEVICE_CODE;
544
+ #endif // INT8_MMA_AVAILABLE
545
  }
546
 
547
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
548
+ const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
549
  int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
550
+ GGML_UNUSED(x_sc);
551
 
552
  const int kbx = threadIdx.x / QI5_1;
553
  const int kqsx = threadIdx.x % QI5_1;
 
571
  qs0 |= (qh << 18) & 0x00100000; // 2 -> 20
572
  qs0 |= (qh << 25) & 0x10000000; // 3 -> 28
573
 
574
+ x_qs[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+0] = qs0;
575
 
576
  int qs1 = (ql >> 4) & 0x0F0F0F0F;
577
  qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4
 
579
  qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
580
  qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
581
 
582
+ x_qs[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+1] = qs1;
583
  }
584
 
585
  const int blocks_per_tile_x_row = WARP_SIZE / QI5_1;
 
601
 
602
  template <int mmq_x, int mmq_y, int nwarps>
603
  static __device__ __forceinline__ void vec_dot_q5_1_q8_1_dp4a(
604
+ const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
605
  const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
606
+ GGML_UNUSED(x_sc);
 
607
 
608
  const int * y_qs = (const int *) y + 4;
609
  const half2 * y_ds = (const half2 *) y;
 
628
  }
629
 
630
  sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
631
+ (&x_qs[i*(2*WARP_SIZE + 1) + 2*k0], u, x_dm[index_bx], y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
632
  }
633
  }
634
  }
635
 
636
  template <int mmq_x, int mmq_y, int nwarps>
637
  static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma(
638
+ const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
639
  const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
640
+ #ifdef INT8_MMA_AVAILABLE
641
+ GGML_UNUSED(x_sc);
642
 
643
  typedef mma_int_A_I16K8 mma_A;
644
  typedef mma_int_B_J8K8 mma_B;
 
658
  const int i = i0 + mma_A::get_i(l);
659
  const int k = 2*(k0 + mma_A::get_k(l) % QI5_1) + mma_A::get_k(l) / QI5_1;
660
 
661
+ A.x[l] = x_qs[i*(2*WARP_SIZE + 1) + k];
662
  }
663
  #pragma unroll
664
  for (int l = 0; l < mma_C::ne/2; ++l) {
 
694
  sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
695
  }
696
  }
697
+ #else
698
+ GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
699
+ NO_DEVICE_CODE;
700
+ #endif // INT8_MMA_AVAILABLE
701
  }
702
 
703
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
704
+ const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
705
  int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
706
+ GGML_UNUSED(x_sc);
 
707
 
708
  const int kbx = threadIdx.x / QI8_0;
709
  const int kqsx = threadIdx.x % QI8_0;
 
719
 
720
  const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
721
 
722
+ x_qs[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_int8(bxi->qs, kqsx);
723
  }
724
 
725
  const int blocks_per_tile_x_row = WARP_SIZE / QI8_0;
 
741
 
742
  template <int mmq_x, int mmq_y, int nwarps>
743
  static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
744
+ const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
745
  const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
746
+ GGML_UNUSED(x_sc);
 
747
 
748
  const float * x_dmf = (const float *) x_dm;
749
  const int * y_qs = (const int *) y + 4;
 
758
  const int i = i0 + threadIdx.x;
759
 
760
  sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
761
+ (&x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0], x_dmf[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0],
762
  y_df[j*MMQ_TILE_Y_K + k0/QI8_1]);
763
  }
764
  }
 
766
 
767
  template <int mmq_x, int mmq_y, int nwarps>
768
  static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
769
+ const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
770
  const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
771
+ #ifdef INT8_MMA_AVAILABLE
772
+ GGML_UNUSED(x_sc);
773
 
774
  typedef mma_int_A_I16K8 mma_A;
775
  typedef mma_int_B_J8K8 mma_B;
 
790
  const int i = i0 + mma_A::get_i(l);
791
  const int k = k0 + mma_A::get_k(l);
792
 
793
+ A.x[l] = x_qs[i*(WARP_SIZE + 1) + k];
794
  }
795
  #pragma unroll
796
  for (int l = 0; l < mma_C::ne/2; ++l) {
 
825
  sum[(j0/B.J)*C.ne + l] += C.x[l]*dA[l/2]*dB[l%2];
826
  }
827
  }
828
+ #else
829
+ GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
830
+ NO_DEVICE_CODE;
831
+ #endif // INT8_MMA_AVAILABLE
832
  }
833
 
834
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
835
+ const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
836
  int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
 
837
 
838
  const int kbx = threadIdx.x / QI2_K;
839
  const int kqsx = threadIdx.x % QI2_K;
 
848
 
849
  const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + kbx;
850
 
851
+ const int x_ql_0 = get_int_from_uint8(bxi->qs, kqsx);
 
 
 
 
852
 
853
  #pragma unroll
854
+ for (int l = 0; l < QR2_K; ++l) {
855
+ const int k = kbx*QI2_K + (kqsx/8)*8 + l*2 + (kqsx % 8)/4;
856
 
857
+ int x_qs_k = ((x_ql_0 >> (2*l)) & 0x03030303) << (2*(kqsx % 4));
858
+ x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 1, WARP_SIZE);
859
+ x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 2, WARP_SIZE);
 
 
 
 
 
860
 
861
+ if (kqsx % QR2_K != 0) {
862
+ continue;
863
+ }
864
 
865
+ x_qs[i*(WARP_SIZE + 1) + k] = x_qs_k;
 
866
  }
867
 
868
+ const int sc_m = bxi->scales[kqsx];
869
+ #ifdef FAST_FP16_AVAILABLE
870
+ const half2 x_dm_ik = __hmul2(bxi->dm, make_half2(sc_m & 0x0F, sc_m >> 4));
871
+ #else
872
+ const float2 bxi_dmf = __half22float2(bxi->dm);
873
+ const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4));
874
+ #endif // FAST_FP16_AVAILABLE
875
 
876
+ x_dm[i*(WARP_SIZE + 1) + threadIdx.x] = x_dm_ik;
877
  }
878
  }
879
 
880
  template <int mmq_x, int mmq_y, int nwarps>
881
+ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
882
+ const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
883
  const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
884
 
885
+ const int * y_qs = (const int *) y + 4;
886
+ const float * y_df = (const float *) y;
 
 
887
 
888
  #pragma unroll
889
  for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
 
893
  for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
894
  const int i = i0 + threadIdx.x;
895
 
896
+ sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq(
897
+ &x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + (QR2_K*k0) % WARP_SIZE],
898
+ &x_dm[i*(WARP_SIZE + 1) + k0], y_df[j*MMQ_TILE_Y_K + ((QR2_K*k0) % WARP_SIZE)/QI8_1]);
899
+ }
900
+ }
901
+ }
902
+
903
+ template <int mmq_x, int mmq_y, int nwarps>
904
+ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
905
+ const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
906
+ const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
907
+ #ifdef INT8_MMA_AVAILABLE
908
+
909
+ typedef mma_int_A_I16K4 mma_A;
910
+ typedef mma_int_B_J8K4 mma_B;
911
+ typedef mma_int_C_I16J8 mma_C;
912
+
913
+ const int * y_qs = (const int *) y + 4;
914
+ const float * y_df = (const float *) y;
915
 
916
+ const int i0 = threadIdx.y*mma_A::I;
917
+ static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
918
 
919
+ mma_A A[2];
920
+ float dA[mma_C::ne/2][2];
921
+ float mA[mma_C::ne/2][2];
922
 
923
  #pragma unroll
924
+ for (int l = 0; l < mma_A::ne; ++l) {
925
+ const int i = i0 + mma_A::get_i(l);
926
+ const int shift = 2*mma_A::get_k(l);
927
 
928
+ A[0].x[l] = (x_qs[i*(WARP_SIZE + 1) + k0 + 0] >> shift) & 0x03030303;
929
+ A[1].x[l] = (x_qs[i*(WARP_SIZE + 1) + k0 + 1] >> shift) & 0x03030303;
930
+ }
931
 
932
+ #pragma unroll
933
+ for (int l = 0; l < mma_C::ne/2; ++l) {
934
+ const int i = i0 + mma_C::get_i(2*l);
935
+
936
+ #pragma unroll
937
+ for (int kk = 0; kk < 2; ++kk) {
938
+ const float2 dm = __half22float2(x_dm[i*(WARP_SIZE + 1) + k0 + kk]);
939
+
940
+ dA[l][kk] = dm.x;
941
+ mA[l][kk] = dm.y;
942
  }
943
  }
944
+
945
+ #pragma unroll
946
+ for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
947
+ mma_C Cd[2];
948
+ mma_C Cm[2];
949
+ mma_B B[2];
950
+ float dB[mma_C::ne/2];
951
+
952
+ #pragma unroll
953
+ for (int l = 0; l < mma_B::ne; ++l) {
954
+ const int j = j0 + mma_B::get_j(l);
955
+ const int k = (4*k0 + mma_B::get_k(l)) % WARP_SIZE;
956
+
957
+ B[0].x[l] = y_qs[j*MMQ_TILE_Y_K + k + 0];
958
+ B[1].x[l] = y_qs[j*MMQ_TILE_Y_K + k + mma_B::K];
959
+ }
960
+ #pragma unroll
961
+ for (int l = 0; l < mma_C::ne/2; ++l) {
962
+ const int j = j0 + mma_C::get_j(l);
963
+
964
+ dB[l] = y_df[j*MMQ_TILE_Y_K + ((4*k0)/QI8_1) % (WARP_SIZE/QI8_1)];
965
+ }
966
+
967
+ Cd[0].mma_K4(A[0], B[0]);
968
+ Cd[1].mma_K4(A[1], B[1]);
969
+
970
+ mma_A A1;
971
+ A1.x[0] = 0x01010101;
972
+ A1.x[1] = 0x01010101;
973
+ Cm[0].mma_K4(A1, B[0]);
974
+ Cm[1].mma_K4(A1, B[1]);
975
+
976
+ #pragma unroll
977
+ for (int l = 0; l < mma_C::ne; ++l) {
978
+ sum[(j0/mma_B::J)*mma_C::ne + l] += (Cd[0].x[l]*dA[l/2][0] + Cd[1].x[l]*dA[l/2][1] - Cm[0].x[l]*mA[l/2][0] - Cm[1].x[l]*mA[l/2][1])*dB[l%2];
979
+ }
980
+ }
981
+ #else
982
+ GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
983
+ NO_DEVICE_CODE;
984
+ #endif // INT8_MMA_AVAILABLE
985
  }
986
 
987
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
988
+ const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
989
  int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
990
 
991
  const int kbx = threadIdx.x / QI3_K;
 
1001
 
1002
  const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + kbx;
1003
 
1004
+ const int x_ql_0 = get_int_from_uint8(bxi->qs, kqsx);
1005
+ const int x_qh_0 = get_int_from_uint8(bxi->hmask, kqsx % (QI3_K/2)) >> (4 * (kqsx / (QI3_K/2)));
1006
+
1007
+ #pragma unroll
1008
+ for (int l = 0; l < QR3_K; ++l) {
1009
+ const int k = kbx*(QR3_K*QI3_K) + (kqsx/8)*32 + l*8 + kqsx % 8;
1010
+
1011
+ const int x_ql_k = (x_ql_0 >> (2*l)) & 0x03030303;
1012
+ const int x_qh_k = ((x_qh_0 >> l) << 2) & 0x04040404;
1013
+
1014
+ int x_qs_k = (x_ql_k | x_qh_k) << (4*(k%2));
1015
+ x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 1, WARP_SIZE);
1016
+
1017
+ if (kqsx % 2 != 0) {
1018
+ continue;
1019
+ }
1020
+
1021
+ x_qs[i*(2*WARP_SIZE + 1) + k/2] = x_qs_k;
1022
+ }
1023
  }
1024
 
1025
  const int blocks_per_tile_x_row = WARP_SIZE / QI3_K;
 
1039
  x_dmf[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd] = bxi->d;
1040
  }
1041
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1042
  #pragma unroll
1043
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
1044
  int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
 
1066
  }
1067
 
1068
  template <int mmq_x, int mmq_y, int nwarps>
1069
+ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
1070
+ const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
1071
  const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
1072
 
1073
+ const float * x_df = (const float *) x_dm;
1074
+ const int * y_qs = (const int *) y + 4;
1075
+ const float * y_df = (const float *) y;
1076
 
1077
  #pragma unroll
1078
  for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
 
1087
 
1088
  const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;
1089
 
1090
+ sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q3_K_q8_1_impl_mmq(
1091
+ &x_qs[i*(2*WARP_SIZE + 1) + 2*k0], &y_qs[j*MMQ_TILE_Y_K + (k0*QR3_K) % WARP_SIZE], scales,
1092
+ x_df[i*(WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[j*MMQ_TILE_Y_K + ((k0*QR3_K) % WARP_SIZE)/QI8_1]);
1093
+ }
1094
+ }
1095
+ }
1096
+
1097
+ template <int mmq_x, int mmq_y, int nwarps>
1098
+ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mma(
1099
+ const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
1100
+ const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
1101
+ #ifdef INT8_MMA_AVAILABLE
1102
+
1103
+ typedef mma_int_A_I16K4 mma_A;
1104
+ typedef mma_int_B_J8K4 mma_B;
1105
+ typedef mma_int_C_I16J8 mma_C;
1106
+
1107
+ const float * x_df = (const float *) x_dm;
1108
+ const int * y_qs = (const int *) y + 4;
1109
+ const float * y_df = (const float *) y;
1110
+
1111
+ const int i0 = threadIdx.y*mma_A::I;
1112
+ static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
1113
+
1114
+ mma_A A[2];
1115
+ int scA[mma_C::ne/2][2];
1116
+ float dA[mma_C::ne/2];
1117
 
1118
  #pragma unroll
1119
+ for (int l = 0; l < mma_A::ne; ++l) {
1120
+ const int i = i0 + mma_A::get_i(l);
1121
+ const int k = QR3_K*k0 + mma_A::get_k(l);
 
1122
 
1123
+ A[0].x[l] = (x_qs[i*(2*WARP_SIZE + 1) + k/2 + 0] >> (4*(k%2))) & 0x0F0F0F0F;
1124
+ A[1].x[l] = (x_qs[i*(2*WARP_SIZE + 1) + k/2 + mma_A::K/2] >> (4*(k%2))) & 0x0F0F0F0F;
1125
+ A[0].x[l] = __vsubss4(A[0].x[l], 0x04040404);
1126
+ A[1].x[l] = __vsubss4(A[1].x[l], 0x04040404);
1127
+ }
1128
 
1129
+ #pragma unroll
1130
+ for (int l = 0; l < mma_C::ne/2; ++l) {
1131
+ const int i = i0 + mma_C::get_i(2*l);
1132
 
1133
+ const int kbx = k0 / QI3_K;
1134
+ const int ky = (k0 % QI3_K) * QR3_K;
1135
+ const int8_t * sc = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;
1136
+
1137
+ scA[l][0] = sc[0];
1138
+ scA[l][1] = sc[1];
1139
+ }
1140
+
1141
+ #pragma unroll
1142
+ for (int l = 0; l < mma_C::ne/2; ++l) {
1143
+ const int i = i0 + mma_C::get_i(2*l);
1144
+
1145
+ dA[l] = x_df[i*(WARP_SIZE/QI3_K) + i/QI3_K + k0/QI3_K];
1146
+ }
1147
+
1148
+ #pragma unroll
1149
+ for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
1150
+ mma_C C[2];
1151
+ mma_B B[2];
1152
+ float dB[mma_C::ne/2];
1153
+
1154
+ #pragma unroll
1155
+ for (int l = 0; l < mma_B::ne; ++l) {
1156
+ const int j = j0 + mma_B::get_j(l);
1157
+ const int k = (4*k0 + mma_B::get_k(l)) % WARP_SIZE;
1158
+
1159
+ B[0].x[l] = y_qs[j*MMQ_TILE_Y_K + k + 0];
1160
+ B[1].x[l] = y_qs[j*MMQ_TILE_Y_K + k + mma_B::K];
1161
+ }
1162
+ #pragma unroll
1163
+ for (int l = 0; l < mma_C::ne/2; ++l) {
1164
+ const int j = j0 + mma_C::get_j(l);
1165
+
1166
+ dB[l] = y_df[j*MMQ_TILE_Y_K + ((4*k0)/QI8_1) % (WARP_SIZE/QI8_1)];
1167
+ }
1168
+
1169
+ C[0].mma_K4(A[0], B[0]);
1170
+ C[1].mma_K4(A[1], B[1]);
1171
+
1172
+ #pragma unroll
1173
+ for (int l = 0; l < mma_C::ne; ++l) {
1174
+ sum[(j0/mma_B::J)*mma_C::ne + l] += (C[0].x[l]*scA[l/2][0] + C[1].x[l]*scA[l/2][1])*dA[l/2]*dB[l%2];
1175
  }
1176
  }
1177
+ #else
1178
+ GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
1179
+ NO_DEVICE_CODE;
1180
+ #endif // INT8_MMA_AVAILABLE
1181
  }
1182
 
1183
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
1184
+ const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
1185
  int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
 
1186
 
1187
  const int kbx = 0; // threadIdx.x / QI4_K
1188
  const int kqsx = threadIdx.x; // threadIdx.x % QI4_K
 
1197
 
1198
  const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbx;
1199
 
1200
+ x_qs[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
1201
  }
1202
 
1203
  const int blocks_per_tile_x_row = WARP_SIZE / QI4_K; // == 1 if QK_K == 256
 
1240
 
1241
  template <int mmq_x, int mmq_y, int nwarps>
1242
  static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
1243
+ const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
1244
  const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
1245
 
 
 
1246
  const int * y_qs = (const int *) y + 4;
1247
  const half2 * y_ds = (const half2 *) y;
1248
 
 
1257
  const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2*((k0 % 16) / 8);
1258
 
1259
  sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq(
1260
+ &x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + (QR4_K*k0) % WARP_SIZE], sc, sc+8,
1261
  x_dm[i*(WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[j*MMQ_TILE_Y_K + ((QR4_K*k0) % WARP_SIZE)/QI8_1]);
1262
  }
1263
  }
 
1265
 
1266
  template <int mmq_x, int mmq_y, int nwarps>
1267
  static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma(
1268
+ const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
1269
  const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
1270
+ #ifdef INT8_MMA_AVAILABLE
 
1271
 
1272
  typedef mma_int_A_I16K8 mma_A;
1273
  typedef mma_int_B_J8K8 mma_B;
 
1290
  const int i = i0 + mma_A::get_i(l);
1291
  const int k = k0 + mma_A::get_k(l);
1292
 
1293
+ A[kvdr/4].x[l] = (x_qs[i*(WARP_SIZE + 1) + k] >> kvdr) & 0x0F0F0F0F;
1294
  }
1295
 
1296
  #pragma unroll
 
1351
  sum[(j0/mma_B::J)*mma_C::ne + l] += __low2float(dmA[l/2])*tmpd[l] - __high2float(dmA[l/2])*tmpm[l];
1352
  }
1353
  }
1354
+ #else
1355
+ GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
1356
+ NO_DEVICE_CODE;
1357
+ #endif // INT8_MMA_AVAILABLE
1358
  }
1359
 
1360
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
1361
+ const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
1362
  int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
 
1363
 
1364
  const int kbx = 0; // threadIdx.x / QI5_K
1365
  const int kqsx = threadIdx.x; // threadIdx.x % QI5_K
 
1386
  const int kq0 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + 0;
1387
  const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + (QI5_K/4);
1388
 
1389
+ x_qs[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0;
1390
+ x_qs[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1;
1391
  }
1392
 
1393
  const int blocks_per_tile_x_row = WARP_SIZE / QI5_K; // == 1 if QK_K == 256
 
1430
 
1431
  template <int mmq_x, int mmq_y, int nwarps>
1432
  static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
1433
+ const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
1434
  const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
1435
 
 
 
1436
  const int * y_qs = (const int *) y + 4;
1437
  const half2 * y_ds = (const half2 *) y;
1438
 
 
1447
  const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8);
1448
 
1449
  sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq(
1450
+ &x_qs[i*(QR5_K*WARP_SIZE + 1) + QR5_K*k0], &y_qs[j*MMQ_TILE_Y_K + (QR5_K*k0) % WARP_SIZE], sc, sc+8,
1451
  x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[j*MMQ_TILE_Y_K + ((QR5_K*k0) % WARP_SIZE)/QI8_1]);
1452
  }
1453
  }
 
1455
 
1456
  template <int mmq_x, int mmq_y, int nwarps>
1457
  static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma(
1458
+ const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
1459
  const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
1460
+ #ifdef INT8_MMA_AVAILABLE
 
1461
 
1462
  typedef mma_int_A_I16K8 mma_A;
1463
  typedef mma_int_B_J8K8 mma_B;
 
1480
  const int i = i0 + mma_A::get_i(l);
1481
  const int k = QR5_K*k0 + QR5_K*kvdr + mma_A::get_k(l);
1482
 
1483
+ A[kvdr/4].x[l] = x_qs[i*(QR5_K*WARP_SIZE + 1) + k];
1484
  }
1485
 
1486
  #pragma unroll
 
1541
  sum[(j0/mma_B::J)*mma_C::ne + l] += __low2float(dmA[l/2])*tmpd[l] - __high2float(dmA[l/2])*tmpm[l];
1542
  }
1543
  }
1544
+ #else
1545
+ GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
1546
+ NO_DEVICE_CODE;
1547
+ #endif // INT8_MMA_AVAILABLE
1548
  }
1549
 
1550
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
1551
+ const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
1552
  int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
 
1553
 
1554
  const int kbx = 0; // threadIdx.x / QI6_K
1555
  const int kqsx = threadIdx.x; // threadIdx.x % QI6_K
 
1576
  const int kq0 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + 0;
1577
  const int kq1 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + (QI6_K/2);
1578
 
1579
+ x_qs[i * (2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
1580
+ x_qs[i * (2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
1581
  }
1582
 
1583
  const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256
 
1613
 
1614
  template <int mmq_x, int mmq_y, int nwarps>
1615
  static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
1616
+ const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
1617
  const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
1618
 
 
 
1619
  const float * x_dmf = (const float *) x_dm;
1620
  const int * y_qs = (const int *) y + 4;
1621
  const float * y_df = (const float *) y;
 
1631
  const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/8]);
1632
 
1633
  sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq(
1634
+ &x_qs[i*(QR6_K*WARP_SIZE + 1) + QR6_K*k0], &y_qs[j*MMQ_TILE_Y_K + (QR6_K*k0) % WARP_SIZE], sc,
1635
  x_dmf[i*(WARP_SIZE/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + ((QR6_K*k0) % WARP_SIZE)/QI8_1]);
1636
  }
1637
  }
 
1639
 
1640
  template <int mmq_x, int mmq_y, int nwarps>
1641
  static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
1642
+ const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
1643
  const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
1644
+ #ifdef INT8_MMA_AVAILABLE
 
1645
 
1646
  typedef mma_int_A_I16K4 mma_A;
1647
  typedef mma_int_B_J8K4 mma_B;
 
1652
  const float * y_df = (const float *) y;
1653
 
1654
  const int i0 = threadIdx.y*mma_A::I;
1655
+ #ifdef INT8_MMA_AVAILABLE
1656
  static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
1657
+ #endif // INT8_MMA_AVAILABLE
1658
 
1659
  mma_A A[4];
1660
  int scA[mma_C::ne/2][4];
 
1666
  const int i = i0 + mma_A::get_i(l);
1667
  const int k = QR6_K*k0 + QR6_K*kvdr + mma_A::get_k(l);
1668
 
1669
+ A[kvdr/2 + 0].x[l] = x_qs[i*(QR6_K*WARP_SIZE + 1) + k + 0];
1670
+ A[kvdr/2 + 1].x[l] = x_qs[i*(QR6_K*WARP_SIZE + 1) + k + mma_A::K];
1671
  }
1672
 
1673
  #pragma unroll
 
1727
  sum[(j0/mma_B::J)*mma_C::ne + l] += tmp[l]*dA[l/2];
1728
  }
1729
  }
1730
+ #else
1731
+ GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
1732
+ NO_DEVICE_CODE;
1733
+ #endif // INT8_MMA_AVAILABLE
1734
  }
1735
 
1736
  template<int mmq_x, int mmq_y, int nwarps, bool need_check>
 
1761
  typedef mma_int_C_I16J8 mma_C;
1762
 
1763
  const int i0 = threadIdx.y*mma_C::I;
1764
+ #ifdef INT8_MMA_AVAILABLE
1765
  static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y");
1766
+ #endif // INT8_MMA_AVAILABLE
1767
 
1768
  #pragma unroll
1769
  for (int j0 = 0; j0 < mmq_x; j0 += mma_C::J) {
 
1793
 
1794
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
1795
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_0> {
1796
+ static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;
1797
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, nwarps, need_check>;
1798
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q4_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
1799
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
 
 
 
 
 
1800
  };
1801
 
1802
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
1803
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_1> {
1804
+ static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ;
1805
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, nwarps, need_check>;
1806
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q4_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
1807
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
 
 
 
 
 
1808
  };
1809
 
1810
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
1811
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
1812
+ static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ;
1813
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, nwarps, need_check>;
1814
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q5_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
1815
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
 
 
 
 
 
1816
  };
1817
 
1818
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
1819
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
1820
+ static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ;
1821
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, nwarps, need_check>;
1822
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q5_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
1823
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
 
 
 
 
 
1824
  };
1825
 
1826
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
1827
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
1828
+ static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ;
1829
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, nwarps, need_check>;
1830
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
1831
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
 
 
 
 
 
1832
  };
1833
 
1834
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
1835
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q2_K> {
1836
+ static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
1837
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y, nwarps, need_check>;
1838
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q2_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
1839
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q2_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
1840
  };
1841
 
1842
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
1843
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q3_K> {
1844
+ static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ;
1845
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, nwarps, need_check>;
1846
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q3_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
1847
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
1848
  };
1849
 
1850
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
1851
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> {
1852
+ static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ;
1853
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, nwarps, need_check>;
1854
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q4_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
1855
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
 
 
 
 
 
1856
  };
1857
 
1858
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
1859
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> {
1860
+ static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ;
1861
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, nwarps, need_check>;
1862
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q5_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
1863
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
 
 
 
 
 
1864
  };
1865
 
1866
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
1867
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
1868
+ static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ;
1869
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, nwarps, need_check>;
1870
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q6_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
1871
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
 
 
 
 
 
1872
  };
1873
 
1874
+ static bool mmq_need_sum(const ggml_type type_x) {
1875
  switch (type_x) {
1876
  case GGML_TYPE_Q4_0:
1877
  case GGML_TYPE_Q4_1:
 
1905
  #if __CUDA_ARCH__ >= CC_VOLTA
1906
  __launch_bounds__(WARP_SIZE*nwarps, 1)
1907
  #else
1908
+ __launch_bounds__(WARP_SIZE*nwarps, 2)
1909
  #endif // __CUDA_ARCH__ >= CC_VOLTA
1910
  #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
1911
  static __global__ void mul_mat_q(
 
1924
  constexpr int mmq_y = get_mmq_y_device(mmq_x);
1925
  constexpr int vdr = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vdr;
1926
  constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
1927
+
1928
+ #ifdef INT8_MMA_AVAILABLE
1929
+ constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot_mma;
1930
+ constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
1931
+ #else
1932
+ constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot_dp4a;
1933
+ constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
1934
+ #endif // INT8_MMA_AVAILABLE
1935
 
1936
  constexpr tile_x_sizes txs = get_tile_x_sizes_device<mmq_y>(type);
1937
 
1938
  extern __shared__ char data_mul_mat_q[];
1939
+ int * tile_x_qs = (int *) data_mul_mat_q;
1940
+ half2 * tile_x_dm = (half2 *) (tile_x_qs + txs.qs);
1941
+ int * tile_x_sc = (int *) (tile_x_dm + txs.dm);
 
1942
  int * tile_y = (int *) (tile_x_sc + txs.sc); // [mmq_x * (WARP_SIZE + WARP_SIZE/QI8_1)]
1943
 
1944
  const int blocks_per_row_x = ne00 / qk;
 
1954
 
1955
  for (int kb0 = 0; kb0 < blocks_per_row_x; kb0 += blocks_per_warp) {
1956
 
1957
+ load_tiles(x, tile_x_qs, tile_x_dm, tile_x_sc, stride01*blockIdx.x*mmq_y + kb0, tile_x_max_i, stride01);
1958
 
1959
  #pragma unroll
1960
  for (int kr = 0; kr < qr; ++kr) {
 
1970
 
1971
  // #pragma unroll // unrolling this loop causes too much register pressure
1972
  for (int k0 = kr*WARP_SIZE/qr; k0 < (kr+1)*WARP_SIZE/qr; k0 += vdr) {
1973
+ vec_dot(tile_x_qs, tile_x_dm, tile_x_sc, tile_y, sum, k0);
1974
  }
1975
 
1976
  __syncthreads();
 
1987
  int64_t ne0;
1988
  };
1989
 
1990
+ constexpr int mmq_get_nwarps(int mmq_x) {
1991
+ return mmq_x >= 32 ? 8 : 4;
1992
+ }
1993
+
1994
+ static int mmq_get_shmem(const ggml_type type, const int mmq_x, const int mmq_y) {
1995
+ const tile_x_sizes txs = get_tile_x_sizes_host(type, mmq_y);
1996
+ const int nwarps = mmq_get_nwarps(mmq_x);
1997
+
1998
+ const int shmem_x = txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
1999
+ const int shmem_y = mmq_x*WARP_SIZE*sizeof(int) + mmq_x*(WARP_SIZE/QI8_1)*sizeof(half2);
2000
+ return shmem_x + GGML_PAD(shmem_y, nwarps*WARP_SIZE*sizeof(int));
2001
+ }
2002
+
2003
  template <ggml_type type, int mmq_x, int nwarps>
2004
  static void launch_mul_mat_q(const mmq_args & args, cudaStream_t stream) {
2005
  const int id = ggml_cuda_get_device();
 
2011
  const dim3 block_nums(block_num_x, block_num_y, 1);
2012
  const dim3 block_dims(WARP_SIZE, nwarps, 1);
2013
 
2014
+ const int shmem = mmq_get_shmem(type, mmq_x, mmq_y);
 
 
 
2015
 
2016
  #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
2017
  static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
 
2035
 
2036
  template <ggml_type type>
2037
  void mul_mat_q_case(const mmq_args & args, cudaStream_t stream) {
2038
+ const int id = ggml_cuda_get_device();
2039
+ const int nsm = ggml_cuda_info().devices[id].nsm;
2040
+ const int cc = ggml_cuda_info().devices[id].cc;
2041
+ const int smpbo = ggml_cuda_info().devices[id].smpbo;
2042
 
2043
  const int mmq_x_max = get_mmq_x_max_host(cc);
2044
  const int mmq_y = get_mmq_y_host(cc, mmq_x_max);
 
2051
  const int block_num_x = (args.ne11 + mmq_x - 1) / mmq_x;
2052
  const int nwaves = (block_num_x*block_num_y + nsm - 1) / nsm;
2053
 
2054
+ if (nwaves < nwaves_best && mmq_get_shmem(type, mmq_x, mmq_y) <= smpbo) {
2055
  mmq_x_best = mmq_x;
2056
  nwaves_best = nwaves;
2057
  }
 
2059
 
2060
  switch (mmq_x_best) {
2061
  case 8:
2062
+ launch_mul_mat_q<type, 8, mmq_get_nwarps( 8)>(args, stream);
2063
  break;
2064
  case 16:
2065
+ launch_mul_mat_q<type, 16, mmq_get_nwarps( 16)>(args, stream);
2066
  break;
2067
  case 24:
2068
+ launch_mul_mat_q<type, 24, mmq_get_nwarps( 24)>(args, stream);
2069
  break;
2070
  case 32:
2071
+ launch_mul_mat_q<type, 32, mmq_get_nwarps( 32)>(args, stream);
2072
  break;
2073
  case 40:
2074
+ launch_mul_mat_q<type, 40, mmq_get_nwarps( 40)>(args, stream);
2075
  break;
2076
  case 48:
2077
+ launch_mul_mat_q<type, 48, mmq_get_nwarps( 48)>(args, stream);
2078
  break;
2079
  case 56:
2080
+ launch_mul_mat_q<type, 56, mmq_get_nwarps( 56)>(args, stream);
2081
  break;
2082
  case 64:
2083
+ launch_mul_mat_q<type, 64, mmq_get_nwarps( 64)>(args, stream);
2084
  break;
2085
  case 72:
2086
+ launch_mul_mat_q<type, 72, mmq_get_nwarps( 72)>(args, stream);
2087
  break;
2088
  case 80:
2089
+ launch_mul_mat_q<type, 80, mmq_get_nwarps( 80)>(args, stream);
2090
  break;
2091
  case 88:
2092
+ launch_mul_mat_q<type, 88, mmq_get_nwarps( 88)>(args, stream);
2093
  break;
2094
  case 96:
2095
+ launch_mul_mat_q<type, 96, mmq_get_nwarps( 96)>(args, stream);
2096
  break;
2097
  case 104:
2098
+ launch_mul_mat_q<type, 104, mmq_get_nwarps(104)>(args, stream);
2099
  break;
2100
  case 112:
2101
+ launch_mul_mat_q<type, 112, mmq_get_nwarps(112)>(args, stream);
2102
  break;
2103
  case 120:
2104
+ launch_mul_mat_q<type, 120, mmq_get_nwarps(120)>(args, stream);
2105
  break;
2106
  case 128:
2107
+ launch_mul_mat_q<type, 128, mmq_get_nwarps(128)>(args, stream);
2108
  break;
2109
  default:
2110
+ fprintf(stderr, "mmq_x_best=%d\n", mmq_x_best);
2111
  GGML_ASSERT(false);
2112
  break;
2113
  }
ggml-cuda/softmax.cu CHANGED
@@ -130,6 +130,7 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
130
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
131
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
132
 
 
133
  if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
134
  switch (ncols_x) {
135
  case 32:
 
130
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
131
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
132
 
133
+ // FIXME: this limit could be raised by ~2-4x on Ampere or newer
134
  if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
135
  switch (ncols_x) {
136
  case 32:
ggml-cuda/vecdotq.cuh CHANGED
@@ -265,36 +265,31 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
265
 
266
  // contiguous u/y values
267
  static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(
268
- const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ scales,
269
- const half2 & dm2, const float & d8) {
270
 
271
  #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
272
- int sumi_d = 0;
273
- int sumi_m = 0;
274
 
275
  #pragma unroll
276
  for (int i0 = 0; i0 < QI8_1; i0 += QI8_1/2) {
277
- int sumi_d_sc = 0;
278
-
279
- const int sc = scales[i0 / (QI8_1/2)];
280
-
281
- // fill int with 4x m
282
- int m = sc >> 4;
283
- m |= m << 8;
284
- m |= m << 16;
285
 
 
286
  #pragma unroll
287
  for (int i = i0; i < i0 + QI8_1/2; ++i) {
288
- sumi_d_sc = __dp4a(v[i], u[i], sumi_d_sc); // SIMD dot product
289
- sumi_m = __dp4a(m, u[i], sumi_m); // multiply sum of q8_1 values with m
 
290
  }
291
 
292
- sumi_d += sumi_d_sc * (sc & 0xF);
 
293
  }
294
 
295
- const float2 dm2f = __half22float2(dm2);
296
-
297
- return d8 * (dm2f.x*sumi_d - dm2f.y*sumi_m);
298
  #else
299
  NO_DEVICE_CODE;
300
  #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
@@ -352,8 +347,10 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
352
  for (int i0 = 0; i0 < QR3_K*VDR_Q3_K_Q8_1_MMQ; i0 += QI8_1/2) {
353
  int sumi_sc = 0;
354
 
 
355
  for (int i = i0; i < i0 + QI8_1/2; ++i) {
356
- sumi_sc = __dp4a(v[i], u[i], sumi_sc); // SIMD dot product
 
357
  }
358
 
359
  sumi += sumi_sc * scales[i0 / (QI8_1/2)];
 
265
 
266
  // contiguous u/y values
267
  static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(
268
+ const int * __restrict__ v, const int * __restrict__ u, const half2 * dm2, const float & d8) {
 
269
 
270
  #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
271
+ float sumf_d = 0.0f;
272
+ float sumf_m = 0.0f;
273
 
274
  #pragma unroll
275
  for (int i0 = 0; i0 < QI8_1; i0 += QI8_1/2) {
276
+ const float2 dm2f = __half22float2(dm2[i0/(QI8_1/2)]);
277
+ int sumi_d = 0;
278
+ int sumi_m = 0;
 
 
 
 
 
279
 
280
+ const int vi0 = v[i0/(QI8_1/2)];
281
  #pragma unroll
282
  for (int i = i0; i < i0 + QI8_1/2; ++i) {
283
+ const int vi = (vi0 >> (2*(i % (QI8_1/2)))) & 0x03030303;
284
+ sumi_d = __dp4a(vi, u[i], sumi_d); // SIMD dot product
285
+ sumi_m = __dp4a(0x01010101, u[i], sumi_m);
286
  }
287
 
288
+ sumf_d += dm2f.x * sumi_d;
289
+ sumf_m += dm2f.y * sumi_m;
290
  }
291
 
292
+ return d8*(sumf_d - sumf_m);
 
 
293
  #else
294
  NO_DEVICE_CODE;
295
  #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 
347
  for (int i0 = 0; i0 < QR3_K*VDR_Q3_K_Q8_1_MMQ; i0 += QI8_1/2) {
348
  int sumi_sc = 0;
349
 
350
+ #pragma unroll
351
  for (int i = i0; i < i0 + QI8_1/2; ++i) {
352
+ const int vi = __vsubss4((v[i/2] >> (4*(i%2))) & 0x0F0F0F0F, 0x04040404);
353
+ sumi_sc = __dp4a(vi, u[i], sumi_sc); // SIMD dot product
354
  }
355
 
356
  sumi += sumi_sc * scales[i0 / (QI8_1/2)];