JohannesGaessler committed
Commit 6d14124 · Parent: 2f95156

CUDA: MMQ code deduplication + iquant support (llama/8495)


* CUDA: MMQ code deduplication + iquant support

* 1 less parallel job for CI build

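The deduplication works by collapsing the per-type MMA kernels (Q4_0, Q4_1, Q3_K, Q4_K, Q5_K) into the shared Q8_0/Q8_1 paths: each type's load_tiles function now converts its quantized data into a common 8-bit tile layout, and the mmq_type_traits specializations simply point several ggml types at the same vec_dot kernel. The sketch below is a minimal host-side illustration of that dispatch idea; the names are simplified stand-ins (the real traits in mmq.cuh are templated over mmq_x/mmq_y/nwarps/need_check and also carry vdr, load_tiles and vec_dot_dp4a members), so treat it as an illustration rather than the actual kernel code.

```cpp
// Hypothetical, simplified illustration of the traits-based dispatch in mmq.cuh.
// The real mmq_type_traits holds device function pointers (load_tiles,
// vec_dot_mma, vec_dot_dp4a); here a plain host-side function pointer stands in
// for the shared int8 dot-product kernel.
#include <cstdio>

using vec_dot_fn = void (*)(const int * x, const int * y, float * sum);

static void vec_dot_q8_0_shared(const int * /*x*/, const int * /*y*/, float * sum) {
    *sum += 1.0f;  // placeholder arithmetic, not the real MMA/dp4a math
}

enum class qtype { Q8_0, IQ3_XXS, IQ4_NL };

template <qtype T> struct mmq_traits_sketch;

// Several quantization types select the same shared kernel; only tile loading differs.
template <> struct mmq_traits_sketch<qtype::Q8_0>    { static constexpr vec_dot_fn vec_dot = vec_dot_q8_0_shared; };
template <> struct mmq_traits_sketch<qtype::IQ3_XXS> { static constexpr vec_dot_fn vec_dot = vec_dot_q8_0_shared; };
template <> struct mmq_traits_sketch<qtype::IQ4_NL>  { static constexpr vec_dot_fn vec_dot = vec_dot_q8_0_shared; };

int main() {
    float sum = 0.0f;
    mmq_traits_sketch<qtype::IQ3_XXS>::vec_dot(nullptr, nullptr, &sum);
    std::printf("sum = %.1f\n", sum);
    return 0;
}
```

The same pattern is what the new mmq_get_dp4a_tile_x_sizes and mmq_get_mma_tile_x_k tables below rely on: every newly supported IQ type maps onto an existing Q8_0- or Q3_K-shaped tile instead of introducing a new layout.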
ggml/src/ggml-cuda/mmq.cu CHANGED
@@ -59,6 +59,24 @@ void ggml_cuda_op_mul_mat_q(
         case GGML_TYPE_Q6_K:
             mul_mat_q_case<GGML_TYPE_Q6_K>(ctx, args, stream);
             break;
+        case GGML_TYPE_IQ2_XXS:
+            mul_mat_q_case<GGML_TYPE_IQ2_XXS>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ2_XS:
+            mul_mat_q_case<GGML_TYPE_IQ2_XS>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ2_S:
+            mul_mat_q_case<GGML_TYPE_IQ2_S>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ3_XXS:
+            mul_mat_q_case<GGML_TYPE_IQ3_XXS>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ3_S:
+            mul_mat_q_case<GGML_TYPE_IQ3_S>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ1_S:
+            mul_mat_q_case<GGML_TYPE_IQ1_S>(ctx, args, stream);
+            break;
         case GGML_TYPE_IQ4_XS:
             mul_mat_q_case<GGML_TYPE_IQ4_XS>(ctx, args, stream);
             break;
@@ -93,6 +111,12 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
         case GGML_TYPE_Q4_K:
         case GGML_TYPE_Q5_K:
         case GGML_TYPE_Q6_K:
+        case GGML_TYPE_IQ2_XXS:
+        case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ2_S:
+        case GGML_TYPE_IQ3_XXS:
+        case GGML_TYPE_IQ3_S:
+        case GGML_TYPE_IQ1_S:
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ4_NL:
             mmq_supported = true;
ggml/src/ggml-cuda/mmq.cuh CHANGED
@@ -63,6 +63,14 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
         case GGML_TYPE_Q5_K:
             return MMQ_Q8_1_DS_LAYOUT_DS4;
         case GGML_TYPE_Q6_K:
+        case GGML_TYPE_IQ2_XXS:
+        case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ2_S:
+        case GGML_TYPE_IQ3_XXS:
+        case GGML_TYPE_IQ3_S:
+            return MMQ_Q8_1_DS_LAYOUT_D4;
+        case GGML_TYPE_IQ1_S:
+            return MMQ_Q8_1_DS_LAYOUT_DS4;
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ4_NL:
             return MMQ_Q8_1_DS_LAYOUT_D4;
@@ -131,15 +139,16 @@ static constexpr __device__ int get_mmq_y_device() {
131
  #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
132
  }
133
 
134
- #define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0}
135
- #define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0}
136
- #define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_0 + mmq_y/(QI8_0/2), 0}
137
- #define MMQ_DP4A_TXS_Q8_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_1 + mmq_y/(QI8_1/2), 0}
138
- #define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE + mmq_y, 0}
139
- #define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y, mmq_y*WARP_SIZE/8 + mmq_y/8}
140
- #define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_K + mmq_y/QI4_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
141
- #define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
142
- #define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
 
143
 
144
  static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
145
  return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 :
@@ -152,42 +161,46 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
152
  type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K :
153
  type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K :
154
  type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K :
 
 
 
 
 
 
155
  type == GGML_TYPE_IQ4_XS ? MMQ_DP4A_TXS_Q8_0 :
156
  type == GGML_TYPE_IQ4_NL ? MMQ_DP4A_TXS_Q8_0 :
157
  tile_x_sizes{0, 0, 0};
158
  }
159
 
160
- #define MMQ_MMA_TILE_X_K_Q4_0 (1*WARP_SIZE + WARP_SIZE/QI4_0 + 4)
161
- #define MMQ_MMA_TILE_X_K_Q4_1 (1*WARP_SIZE + WARP_SIZE/QI4_1 + 4)
162
  #define MMQ_MMA_TILE_X_K_Q8_0 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4)
163
  #define MMQ_MMA_TILE_X_K_Q8_1 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4)
164
  #define MMQ_MMA_TILE_X_K_Q2_K (2*WARP_SIZE + WARP_SIZE + 4)
165
- #define MMQ_MMA_TILE_X_K_Q3_K (2*WARP_SIZE + WARP_SIZE/(2*QI3_K) + WARP_SIZE/8 + 7)
166
- #define MMQ_MMA_TILE_X_K_Q4_K (1*WARP_SIZE + WARP_SIZE/QI4_K + WARP_SIZE/8 + 7)
167
- #define MMQ_MMA_TILE_X_K_Q5_K (2*WARP_SIZE + WARP_SIZE/QI5_K + WARP_SIZE/8 + 7)
168
  #define MMQ_MMA_TILE_X_K_Q6_K (2*WARP_SIZE + WARP_SIZE/QI6_K + WARP_SIZE/8 + 7)
169
 
170
- static_assert(MMQ_MMA_TILE_X_K_Q4_0 % 8 == 4, "Wrong padding.");
171
- static_assert(MMQ_MMA_TILE_X_K_Q4_1 % 8 == 4, "Wrong padding.");
172
  static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding.");
173
  static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding.");
174
  static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding.");
175
  static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
176
- static_assert(MMQ_MMA_TILE_X_K_Q4_K % 8 == 4, "Wrong padding.");
177
- static_assert(MMQ_MMA_TILE_X_K_Q5_K % 8 == 4, "Wrong padding.");
178
  static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
179
 
180
  static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
181
- return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q4_0 :
182
- type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q4_1 :
183
  type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
184
  type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q8_1 :
185
  type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
186
  type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K :
187
  type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K :
188
- type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q4_K :
189
- type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q5_K :
190
  type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K :
 
 
 
 
 
 
191
  type == GGML_TYPE_IQ4_XS ? MMQ_MMA_TILE_X_K_Q8_0 :
192
  type == GGML_TYPE_IQ4_NL ? MMQ_MMA_TILE_X_K_Q8_0 :
193
  0;
@@ -216,7 +229,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
216
 
217
  #ifdef INT8_MMA_AVAILABLE
218
  int * x_qs = (int *) x_tile;
219
- float * x_df = (float *) (x_qs + WARP_SIZE);
220
  #else
221
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
222
  int * x_qs = (int *) x_tile;
@@ -235,11 +248,13 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
235
  }
236
 
237
  const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
 
238
 
239
  #ifdef INT8_MMA_AVAILABLE
240
- x_qs[i*MMQ_MMA_TILE_X_K_Q4_0 + threadIdx.x] = get_int_b2(bxi->qs, kqsx);
 
241
  #else
242
- x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_b2(bxi->qs, kqsx);
243
  #endif // INT8_MMA_AVAILABLE
244
  }
245
 
@@ -257,7 +272,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
257
  const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd;
258
 
259
  #ifdef INT8_MMA_AVAILABLE
260
- x_df[i*MMQ_MMA_TILE_X_K_Q4_0 + kbxd] = bxi->d;
261
  #else
262
  x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + kbxd] = bxi->d;
263
  #endif // INT8_MMA_AVAILABLE
@@ -304,95 +319,12 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
304
  }
305
  }
306
 
307
- template <int mmq_x, int mmq_y, int nwarps>
308
- static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma(
309
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
310
- #ifdef INT8_MMA_AVAILABLE
311
-
312
- typedef mma_int_A_I16K8 mma_A;
313
- typedef mma_int_B_J8K8 mma_B;
314
- typedef mma_int_C_I16J8 mma_C;
315
-
316
- constexpr int granularity = mmq_get_granularity_device(mmq_x);
317
- constexpr int rows_per_warp = 2 * granularity;
318
- constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
319
-
320
- y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
321
-
322
- const int * x_qs = (const int *) x;
323
- const float * x_df = (const float *) x_qs + WARP_SIZE;
324
- const int * y_qs = (const int *) y + 4;
325
- const half2 * y_ds = (const half2 *) y;
326
-
327
- mma_A A[ntx][4];
328
- float dA[ntx][mma_C::ne/2][4];
329
-
330
- const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
331
-
332
- #pragma unroll
333
- for (int n = 0; n < ntx; ++n) {
334
- #pragma unroll
335
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_0*QI4_0) {
336
- const int k0 = k00 + k01;
337
-
338
- #pragma unroll
339
- for (int l = 0; l < mma_A::ne; ++l) {
340
- const int i = i0 + n*mma_A::I + mma_A::get_i(l);
341
- const int k = k0/QR4_0 + mma_A::get_k(l) % QI4_0;
342
- const int shift = 4*(mma_A::get_k(l) / QI4_0);
343
-
344
- A[n][k01/(QR4_0*QI4_0)].x[l] = __vsubss4((x_qs[i*MMQ_MMA_TILE_X_K_Q4_0 + k] >> shift) & 0x0F0F0F0F, 0x08080808);
345
- }
346
-
347
- #pragma unroll
348
- for (int l = 0; l < mma_C::ne/2; ++l) {
349
- const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
350
-
351
- dA[n][l][k01/(QR4_0*QI4_0)] = x_df[i*MMQ_MMA_TILE_X_K_Q4_0 + k0/(QR4_0*QI4_0)];
352
- }
353
- }
354
- }
355
-
356
- #pragma unroll
357
- for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
358
- #pragma unroll
359
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_0*QI4_0) {
360
- mma_B B;
361
- float dB[mma_C::ne/2];
362
-
363
- B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
364
-
365
- #pragma unroll
366
- for (int l = 0; l < mma_C::ne/2; ++l) {
367
- const int j = j0 + mma_C::get_j(l);
368
-
369
- dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
370
- }
371
-
372
- #pragma unroll
373
- for (int n = 0; n < ntx; ++n) {
374
- mma_C C;
375
- C.mma_K8(A[n][k01/(QR4_0*QI4_0)], B);
376
-
377
- #pragma unroll
378
- for (int l = 0; l < mma_C::ne; ++l) {
379
- sum[(j0/mma_C::J + n)*mma_C::ne + l] += dA[n][l/2][k01/(QR4_0*QI4_0)]*dB[l%2]*C.x[l];
380
- }
381
- }
382
- }
383
- }
384
- #else
385
- GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
386
- NO_DEVICE_CODE;
387
- #endif // INT8_MMA_AVAILABLE
388
- }
389
-
390
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
391
  const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
392
 
393
  #ifdef INT8_MMA_AVAILABLE
394
  int * x_qs = (int *) x_tile;
395
- half2 * x_dm = (half2 *) (x_qs + WARP_SIZE);
396
  #else
397
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
398
  int * x_qs = (int *) x_tile;
@@ -411,11 +343,13 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
411
  }
412
 
413
  const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
 
414
 
415
  #ifdef INT8_MMA_AVAILABLE
416
- x_qs[i*MMQ_MMA_TILE_X_K_Q4_1 + threadIdx.x] = get_int_b4(bxi->qs, kqsx);
 
417
  #else
418
- x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_b4(bxi->qs, kqsx);
419
  #endif // INT8_MMA_AVAILABLE
420
  }
421
 
@@ -433,7 +367,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
433
  const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd;
434
 
435
  #ifdef INT8_MMA_AVAILABLE
436
- x_dm[i*MMQ_MMA_TILE_X_K_Q4_1 + kbxd] = bxi->dm;
437
  #else
438
  x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + kbxd] = bxi->dm;
439
  #endif // INT8_MMA_AVAILABLE
@@ -480,88 +414,6 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
480
  }
481
  }
482
 
483
- template <int mmq_x, int mmq_y, int nwarps>
484
- static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma(
485
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
486
- #ifdef INT8_MMA_AVAILABLE
487
-
488
- typedef mma_int_A_I16K8 mma_A;
489
- typedef mma_int_A_I16K4 mma_A_K4;
490
- typedef mma_int_B_J8K8 mma_B;
491
- typedef mma_int_C_I16J8 mma_C;
492
-
493
- constexpr int granularity = mmq_get_granularity_device(mmq_x);
494
- constexpr int rows_per_warp = 2 * granularity;
495
- constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
496
-
497
- y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
498
-
499
- const int * x_qs = (const int *) x;
500
- const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE;
501
- const int * y_qs = (const int *) y + 4;
502
- const half2 * y_ds = (const half2 *) y;
503
-
504
- mma_A A[ntx][4];
505
- half2 dmA[ntx][mma_C::ne/2][4];
506
-
507
- const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
508
-
509
- #pragma unroll
510
- for (int n = 0; n < ntx; ++n) {
511
- #pragma unroll
512
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_1*QI4_1) {
513
- const int k0 = k00 + k01;
514
-
515
- A[n][k01/(QR4_1*QI4_1)].load_low(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q4_1 + k0/QR4_1, MMQ_MMA_TILE_X_K_Q4_1);
516
- A[n][k01/(QR4_1*QI4_1)].x[2] = (A[n][k01/(QR4_1*QI4_1)].x[0] >> 4) & 0x0F0F0F0F;
517
- A[n][k01/(QR4_1*QI4_1)].x[3] = (A[n][k01/(QR4_1*QI4_1)].x[1] >> 4) & 0x0F0F0F0F;
518
- A[n][k01/(QR4_1*QI4_1)].x[0] &= 0x0F0F0F0F;
519
- A[n][k01/(QR4_1*QI4_1)].x[1] &= 0x0F0F0F0F;
520
-
521
- #pragma unroll
522
- for (int l = 0; l < mma_C::ne/2; ++l) {
523
- const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
524
-
525
- dmA[n][l][k01/(QR4_1*QI4_1)] = x_dm[i*MMQ_MMA_TILE_X_K_Q4_1 + k0/(QR4_1*QI4_1)];
526
- }
527
- }
528
- }
529
-
530
- #pragma unroll
531
- for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
532
- #pragma unroll
533
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_1*QI4_1) {
534
- mma_B B;
535
- half2 dsB[mma_C::ne/2];
536
-
537
- B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
538
-
539
- #pragma unroll
540
- for (int l = 0; l < mma_C::ne/2; ++l) {
541
- const int j = j0 + mma_C::get_j(l);
542
-
543
- dsB[l] = y_ds[j*MMQ_TILE_Y_K + k01/QI8_1];
544
- }
545
-
546
- #pragma unroll
547
- for (int n = 0; n < ntx; ++n) {
548
- mma_C C;
549
- C.mma_K8(A[n][k01/(QR4_1*QI4_1)], B);
550
-
551
- #pragma unroll
552
- for (int l = 0; l < mma_C::ne; ++l) {
553
- const half2 dmA_dsB = dmA[n][l/2][k01/(QR4_1*QI4_1)]*dsB[l%2];
554
- sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
555
- }
556
- }
557
- }
558
- }
559
- #else
560
- GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
561
- NO_DEVICE_CODE;
562
- #endif // INT8_MMA_AVAILABLE
563
- }
564
-
565
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
566
  const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
567
 
@@ -789,10 +641,9 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
789
  }
790
  }
791
 
792
- template <int mmq_x, int mmq_y, int nwarps>
793
  static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
794
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
795
- #ifdef INT8_MMA_AVAILABLE
796
 
797
  typedef mma_int_A_I16K8 mma_A;
798
  typedef mma_int_B_J8K8 mma_B;
@@ -808,6 +659,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
808
  const float * x_df = (const float *) x_qs + 2*WARP_SIZE;
809
  const int * y_qs = (const int *) y + 4;
810
  const float * y_df = (const float *) y;
 
811
 
812
  mma_A A[ntx][WARP_SIZE/QI8_0];
813
  float dA[ntx][mma_C::ne/2][WARP_SIZE/QI8_0];
@@ -840,18 +692,20 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
840
  for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
841
  #pragma unroll
842
  for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
843
- const int k0 = k00 + k01;
844
-
845
- mma_B B;
846
  float dB[mma_C::ne/2];
847
 
848
- B.load(y_qs + j0*MMQ_TILE_Y_K + k0 % WARP_SIZE, MMQ_TILE_Y_K);
849
 
850
  #pragma unroll
851
  for (int l = 0; l < mma_C::ne/2; ++l) {
852
  const int j = j0 + mma_C::get_j(l);
853
 
854
- dB[l] = y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)];
 
 
 
 
855
  }
856
 
857
  #pragma unroll
@@ -866,10 +720,6 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
866
  }
867
  }
868
  }
869
- #else
870
- GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
871
- NO_DEVICE_CODE;
872
- #endif // INT8_MMA_AVAILABLE
873
  }
874
 
875
  template <int mmq_x, int mmq_y, int nwarps>
@@ -905,7 +755,6 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
905
  template <int mmq_x, int mmq_y, int nwarps>
906
  static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
907
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
908
- #ifdef INT8_MMA_AVAILABLE
909
 
910
  typedef mma_int_A_I16K8 mma_A;
911
  typedef mma_int_B_J8K8 mma_B;
@@ -922,8 +771,8 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
922
  const int * y_qs = (const int *) y + 4;
923
  const half2 * y_dm = (const half2 *) y;
924
 
925
- mma_A A[ntx][WARP_SIZE/QI8_1];
926
- half2 dmA[ntx][mma_C::ne/2][WARP_SIZE/QI8_1];
927
 
928
  const int i0 = (threadIdx.y/ntx)*rows_per_warp;
929
 
@@ -944,7 +793,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
944
  for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
945
  const int k0 = k00 + k01;
946
 
947
- dmA[n][l][k01/QI8_1] = x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1];
948
  }
949
  }
950
  }
@@ -953,18 +802,16 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
953
  for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
954
  #pragma unroll
955
  for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
956
- const int k0 = k00 + k01;
957
-
958
- mma_B B;
959
- half2 dsB[mma_C::ne/2];
960
 
961
- B.load(y_qs + j0*MMQ_TILE_Y_K + k0 % WARP_SIZE, MMQ_TILE_Y_K);
962
 
963
  #pragma unroll
964
  for (int l = 0; l < mma_C::ne/2; ++l) {
965
  const int j = j0 + mma_C::get_j(l);
966
 
967
- dsB[l] = y_dm[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)];
968
  }
969
 
970
  #pragma unroll
@@ -974,8 +821,120 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
974
 
975
  #pragma unroll
976
  for (int l = 0; l < mma_C::ne; ++l) {
977
- const half2 dmA_dsB = dmA[n][l/2][k01/QI8_1]*dsB[l%2];
978
- sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
979
  }
980
  }
981
  }
@@ -1222,7 +1181,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1222
  #ifdef INT8_MMA_AVAILABLE
1223
  int * x_qs = (int *) x_tile;
1224
  float * x_df = (float *) (x_qs + WARP_SIZE*2);
1225
- int * x_sc = (int *) (x_df + 1);
1226
  #else
1227
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
1228
  int * x_qs = (int *) x_tile;
@@ -1262,23 +1220,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1262
  }
1263
  }
1264
 
1265
- #pragma unroll
1266
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps*WARP_SIZE) {
1267
- int i = (i0 + threadIdx.y*WARP_SIZE + threadIdx.x) % mmq_y;
1268
-
1269
- if (need_check) {
1270
- i = min(i, i_max);
1271
- }
1272
-
1273
- const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
1274
-
1275
- #ifdef INT8_MMA_AVAILABLE
1276
- x_df[i*MMQ_MMA_TILE_X_K_Q3_K] = bxi->d;
1277
- #else
1278
- x_df[i] = bxi->d;
1279
- #endif // INT8_MMA_AVAILABLE
1280
- }
1281
-
1282
  #pragma unroll
1283
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps*8) {
1284
  int i = i0 + threadIdx.y*8 + threadIdx.x/(WARP_SIZE/8);
@@ -1302,11 +1243,32 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1302
  const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
1303
 
1304
  #ifdef INT8_MMA_AVAILABLE
1305
- x_sc[i*MMQ_MMA_TILE_X_K_Q3_K + threadIdx.x % (WARP_SIZE/8)] = sc;
 
 
 
 
 
 
1306
  #else
1307
- x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = sc;
1308
  #endif // INT8_MMA_AVAILABLE
1309
  }
1310
  }
1311
 
1312
  template <int mmq_x, int mmq_y, int nwarps>
@@ -1342,99 +1304,14 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
1342
  }
1343
  }
1344
 
1345
- template <int mmq_x, int mmq_y, int nwarps>
1346
- static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mma(
1347
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
1348
- #ifdef INT8_MMA_AVAILABLE
1349
-
1350
- typedef mma_int_A_I16K4 mma_A;
1351
- typedef mma_int_A_I16K8 mma_A_K8;
1352
- typedef mma_int_B_J8K4 mma_B;
1353
- typedef mma_int_C_I16J8 mma_C;
1354
-
1355
- constexpr int granularity = mmq_get_granularity_device(mmq_x);
1356
- constexpr int rows_per_warp = 2 * granularity;
1357
- constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
1358
-
1359
- y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
1360
-
1361
- const int * x_qs = (const int *) x;
1362
- const float * x_df = (const float *) x_qs + WARP_SIZE*2;
1363
- const int * x_sc = (const int *) x_df + 1;
1364
- const int * y_qs = (const int *) y + 4;
1365
- const float * y_df = (const float *) y;
1366
-
1367
- const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
1368
-
1369
- mma_A A[ntx][8];
1370
- int scA[ntx][mma_C::ne/2][8];
1371
- float dA[ntx][mma_C::ne/2];
1372
-
1373
- #pragma unroll
1374
- for (int n = 0; n < ntx; ++n) {
1375
- #pragma unroll
1376
- for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
1377
- const int k0 = k00 + k01;
1378
-
1379
- ((mma_A_K8 *) A[n])[k01/8].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
1380
- }
1381
-
1382
- #pragma unroll
1383
- for (int l = 0; l < mma_C::ne/2; ++l) {
1384
- const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
1385
-
1386
- #pragma unroll
1387
- for (int k01 = 0; k01 < WARP_SIZE; k01 += 16) {
1388
- const int k0 = k00 + k01;
1389
-
1390
- const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q3_K + k0/16];
1391
- const int8_t * sc = (const int8_t *) &sc_packed;
1392
-
1393
- #pragma unroll
1394
- for (int ksc = 0; ksc < sizeof(int); ++ksc) {
1395
- scA[n][l][k01/4 + ksc] = sc[ksc];
1396
- }
1397
- }
1398
-
1399
- dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K];
1400
- }
1401
- }
1402
-
1403
- #pragma unroll
1404
- for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
1405
- #pragma unroll
1406
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
1407
- mma_B B[2];
1408
- float dB[mma_C::ne/2];
1409
-
1410
- B[0].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
1411
- B[1].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
1412
-
1413
- #pragma unroll
1414
- for (int l = 0; l < mma_C::ne/2; ++l) {
1415
- const int j = j0 + mma_C::get_j(l);
1416
-
1417
- dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
1418
- }
1419
-
1420
- #pragma unroll
1421
- for (int n = 0; n < ntx; ++n) {
1422
- mma_C C[2];
1423
- C[0].mma_K4(A[n][k01/4 + 0], B[0]);
1424
- C[1].mma_K4(A[n][k01/4 + 1], B[1]);
1425
-
1426
- #pragma unroll
1427
- for (int l = 0; l < mma_C::ne; ++l) {
1428
- sum[(j0/mma_C::J + n)*mma_C::ne + l] += dA[n][l/2]*dB[l%2]*
1429
- (C[0].x[l]*scA[n][l/2][k01/4 + 0] + C[1].x[l]*scA[n][l/2][k01/4 + 1]);
1430
- }
1431
- }
1432
- }
1433
- }
1434
- #else
1435
- GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
1436
- NO_DEVICE_CODE;
1437
- #endif // INT8_MMA_AVAILABLE
1438
  }
1439
 
1440
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
@@ -1442,8 +1319,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1442
 
1443
  #ifdef INT8_MMA_AVAILABLE
1444
  int * x_qs = (int *) x_tile;
1445
- half2 * x_dm = (half2 *) (x_qs + WARP_SIZE);
1446
- int * x_sc = (int *) (x_dm + WARP_SIZE/QI4_K);
1447
  #else
1448
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
1449
  int * x_qs = (int *) x_tile;
@@ -1451,9 +1327,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1451
  int * x_sc = (int *) (x_dm + txs.dm);
1452
  #endif // INT8_MMA_AVAILABLE
1453
 
1454
- const int kbx = 0; // threadIdx.x / QI4_K
1455
- const int kqsx = threadIdx.x; // threadIdx.x % QI4_K
1456
-
1457
  #pragma unroll
1458
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
1459
  int i = i0 + threadIdx.y;
@@ -1462,33 +1335,59 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1462
  i = min(i, i_max);
1463
  }
1464
 
1465
- const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbx;
 
1466
 
1467
  #ifdef INT8_MMA_AVAILABLE
1468
- x_qs[i*MMQ_MMA_TILE_X_K_Q4_K + threadIdx.x] = get_int_b4(bxi->qs, kqsx);
 
1469
  #else
1470
- x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_b4(bxi->qs, kqsx);
1471
  #endif // INT8_MMA_AVAILABLE
1472
  }
1473
 
1474
- const int blocks_per_tile_x_row = WARP_SIZE / QI4_K; // == 1 if QK_K == 256
1475
- const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256
1476
 
1477
  #pragma unroll
1478
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_K) {
1479
- int i = (i0 + threadIdx.y * QI4_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;
1480
 
1481
  if (need_check) {
1482
  i = min(i, i_max);
1483
  }
1484
 
1485
- const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbxd;
1486
 
1487
- #ifdef INT8_MMA_AVAILABLE
1488
- x_dm[i*MMQ_MMA_TILE_X_K_Q4_K + kbxd] = bxi->dm;
1489
  #else
1490
- x_dm[i*(WARP_SIZE/QI4_K) + i/QI4_K + kbxd] = bxi->dm;
1491
- #endif // INT8_MMA_AVAILABLE
1492
  }
1493
 
1494
  #pragma unroll
@@ -1504,17 +1403,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1504
  const int * scales = (const int *) bxi->scales;
1505
 
1506
  const int ksc = threadIdx.x % (WARP_SIZE/8);
 
1507
 
1508
- // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8
1509
- int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
1510
- scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits
1511
-
1512
- #ifdef INT8_MMA_AVAILABLE
1513
- x_sc[i*MMQ_MMA_TILE_X_K_Q4_K + ksc] = scales8;
1514
- #else
1515
- x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8;
1516
- #endif // INT8_MMA_AVAILABLE
1517
  }
 
1518
  }
1519
 
1520
  template <int mmq_x, int mmq_y, int nwarps>
@@ -1544,124 +1437,10 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
1544
 
1545
  sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq(
1546
  &x_qs[i*(WARP_SIZE + 1) + k0/2], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
1547
- x_dm[i*(WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
1548
- }
1549
- }
1550
- }
1551
- }
1552
-
1553
- template <int mmq_x, int mmq_y, int nwarps>
1554
- static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma(
1555
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
1556
- #ifdef INT8_MMA_AVAILABLE
1557
-
1558
- typedef mma_int_A_I16K8 mma_A;
1559
- typedef mma_int_B_J8K8 mma_B;
1560
- typedef mma_int_C_I16J8 mma_C;
1561
-
1562
- constexpr int granularity = mmq_get_granularity_device(mmq_x);
1563
- constexpr int rows_per_warp = 2 * granularity;
1564
- constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
1565
-
1566
- y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
1567
-
1568
- const int * x_qs = (const int *) x;
1569
- const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE;
1570
- const int * x_sc = (const int *) x_dm + WARP_SIZE/QI4_K;
1571
- const int * y_qs = (const int *) y + 4;
1572
- const half2 * y_ds = (const half2 *) y;
1573
-
1574
- const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
1575
-
1576
- mma_A A[ntx][4];
1577
- int scA[ntx][mma_C::ne/2][4];
1578
- int mA[ntx][mma_C::ne/2][4];
1579
- half2 dmA[ntx][mma_C::ne/2];
1580
-
1581
- #pragma unroll
1582
- for (int n = 0; n < ntx; ++n) {
1583
- #pragma unroll
1584
- for (int k01 = 0; k01 < WARP_SIZE; k01 += 16) {
1585
- const int k0 = k00 + k01;
1586
-
1587
- A[n][k01/8 + 0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q4_K + k0/QR4_K, MMQ_MMA_TILE_X_K_Q4_K);
1588
-
1589
- #pragma unroll
1590
- for (int l = 0; l < mma_A::ne; ++l) {
1591
- A[n][k01/8 + 1].x[l] = (A[n][k01/8 + 0].x[l] >> 4) & 0x0F0F0F0F;
1592
- A[n][k01/8 + 0].x[l] &= 0x0F0F0F0F;
1593
- }
1594
- }
1595
-
1596
- #pragma unroll
1597
- for (int l = 0; l < mma_C::ne/2; ++l) {
1598
- const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
1599
-
1600
- const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q4_K + (k00/32 + 0)];
1601
- const int m_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q4_K + (k00/32 + 2)];
1602
-
1603
- const uint8_t * sc = (const uint8_t *) &sc_packed;
1604
- const uint8_t * m = (const uint8_t *) &m_packed;
1605
-
1606
- #pragma unroll
1607
- for (int ksc = 0; ksc < sizeof(int); ++ksc) {
1608
- scA[n][l][ksc] = sc[ksc];
1609
- mA[n][l][ksc] = m[ksc];
1610
  }
1611
  }
1612
-
1613
- #pragma unroll
1614
- for (int l = 0; l < mma_C::ne/2; ++l) {
1615
- const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
1616
-
1617
- dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q4_K];
1618
- }
1619
  }
1620
-
1621
- #pragma unroll
1622
- for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
1623
- float tmpd[ntx][mma_C::ne] = {{0.0f}};
1624
- float tmpm[ntx][mma_C::ne] = {{0.0f}};
1625
-
1626
- #pragma unroll
1627
- for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
1628
- mma_B B;
1629
- half2 dsB[mma_C::ne/2];
1630
-
1631
- B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
1632
-
1633
- #pragma unroll
1634
- for (int l = 0; l < mma_C::ne/2; ++l) {
1635
- const int j = j0 + mma_C::get_j(l);
1636
-
1637
- dsB[l] = y_ds[j*MMQ_TILE_Y_K + k01/QI8_1];
1638
- }
1639
-
1640
- #pragma unroll
1641
- for (int n = 0; n < ntx; ++n) {
1642
- mma_C C;
1643
- C.mma_K8(A[n][k01/8], B);
1644
-
1645
- #pragma unroll
1646
- for (int l = 0; l < mma_C::ne; ++l) {
1647
- tmpd[n][l] += (C.x[l]*scA[n][l/2][k01/8]) * __low2float(dsB[l%2]);
1648
- tmpm[n][l] += mA[n][l/2][k01/8] * __high2float(dsB[l%2]);
1649
- }
1650
- }
1651
- }
1652
-
1653
- #pragma unroll
1654
- for (int n = 0; n < ntx; ++n) {
1655
- #pragma unroll
1656
- for (int l = 0; l < mma_C::ne; ++l) {
1657
- sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA[n][l/2])*tmpd[n][l] - __high2float(dmA[n][l/2])*tmpm[n][l];
1658
- }
1659
- }
1660
- }
1661
- #else
1662
- GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
1663
- NO_DEVICE_CODE;
1664
- #endif // INT8_MMA_AVAILABLE
1665
  }
1666
 
1667
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
@@ -1670,7 +1449,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1670
  #ifdef INT8_MMA_AVAILABLE
1671
  int * x_qs = (int *) x_tile;
1672
  half2 * x_dm = (half2 *) (x_qs + WARP_SIZE*2);
1673
- int * x_sc = (int *) (x_dm + WARP_SIZE/QI5_K);
1674
  #else
1675
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
1676
  int * x_qs = (int *) x_tile;
@@ -1678,9 +1456,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1678
  int * x_sc = (int *) (x_dm + txs.dm);
1679
  #endif // INT8_MMA_AVAILABLE
1680
 
1681
- const int kbx = 0; // threadIdx.x / QI5_K
1682
- const int kqsx = threadIdx.x; // threadIdx.x % QI5_K
1683
-
1684
  #pragma unroll
1685
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
1686
  int i = i0 + threadIdx.y;
@@ -1689,73 +1464,91 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1689
  i = min(i, i_max);
1690
  }
1691
 
1692
- const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + kbx;
1693
- const int ky = QR5_K*kqsx;
1694
 
1695
- const int ql = get_int_b4(bxi->qs, kqsx);
1696
  const int ql0 = (ql >> 0) & 0x0F0F0F0F;
1697
  const int ql1 = (ql >> 4) & 0x0F0F0F0F;
1698
 
1699
- const int qh = get_int_b4(bxi->qh, kqsx % (QI5_K/4));
1700
- const int qh0 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 0)) << 4) & 0x10101010;
1701
- const int qh1 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 1)) << 4) & 0x10101010;
1702
 
1703
  const int kq0 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + 0;
1704
- const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + (QI5_K/4);
1705
 
1706
  #ifdef INT8_MMA_AVAILABLE
1707
- x_qs[i*MMQ_MMA_TILE_X_K_Q5_K + kq0] = ql0 | qh0;
1708
- x_qs[i*MMQ_MMA_TILE_X_K_Q5_K + kq1] = ql1 | qh1;
1709
  #else
1710
  x_qs[i*(2*WARP_SIZE + 1) + kq0] = ql0 | qh0;
1711
  x_qs[i*(2*WARP_SIZE + 1) + kq1] = ql1 | qh1;
1712
  #endif // INT8_MMA_AVAILABLE
1713
  }
1714
 
1715
- const int blocks_per_tile_x_row = WARP_SIZE / QI5_K; // == 1 if QK_K == 256
1716
- const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256
1717
 
1718
  #pragma unroll
1719
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_K) {
1720
- int i = (i0 + threadIdx.y * QI5_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;
1721
 
1722
  if (need_check) {
1723
  i = min(i, i_max);
1724
  }
1725
 
1726
- const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + kbxd;
1727
 
1728
- #ifdef INT8_MMA_AVAILABLE
1729
- x_dm[i*MMQ_MMA_TILE_X_K_Q5_K + kbxd] = bxi->dm;
1730
  #else
1731
- x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K + kbxd] = bxi->dm;
1732
- #endif // INT8_MMA_AVAILABLE
1733
  }
1734
 
1735
  #pragma unroll
1736
- for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
1737
- int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y;
1738
 
1739
  if (need_check) {
1740
  i = min(i, i_max);
1741
  }
1742
 
1743
- const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / (QI5_K/8);
1744
 
1745
  const int * scales = (const int *) bxi->scales;
1746
 
1747
  const int ksc = threadIdx.x % (WARP_SIZE/8);
 
1748
 
1749
- // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8
1750
- int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
1751
- scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits
1752
-
1753
- #ifdef INT8_MMA_AVAILABLE
1754
- x_sc[i*MMQ_MMA_TILE_X_K_Q5_K + ksc] = scales8;
1755
- #else
1756
- x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8;
1757
- #endif // INT8_MMA_AVAILABLE
1758
  }
 
1759
  }
1760
 
1761
  template <int mmq_x, int mmq_y, int nwarps>
@@ -1771,134 +1564,24 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
1771
 
1772
  // #pragma unroll
1773
  for (int k01 = 0; k01 < WARP_SIZE; k01 += QR5_K*VDR_Q5_K_Q8_1_MMQ) {
1774
- const int k0 = k00 + k01;
1775
-
1776
- #pragma unroll
1777
- for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
1778
- const int j = j0 + threadIdx.y;
1779
-
1780
- #pragma unroll
1781
- for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
1782
- const int i = i0 + threadIdx.x;
1783
-
1784
- const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k00/32]) + 2*(k01/16);
1785
-
1786
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq(
1787
- &x_qs[i*(QR5_K*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
1788
- x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
1789
- }
1790
- }
1791
- }
1792
- }
1793
-
1794
- template <int mmq_x, int mmq_y, int nwarps>
1795
- static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma(
1796
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
1797
- #ifdef INT8_MMA_AVAILABLE
1798
-
1799
- typedef mma_int_A_I16K8 mma_A;
1800
- typedef mma_int_B_J8K8 mma_B;
1801
- typedef mma_int_C_I16J8 mma_C;
1802
-
1803
- constexpr int granularity = mmq_get_granularity_device(mmq_x);
1804
- constexpr int rows_per_warp = 2 * granularity;
1805
- constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
1806
-
1807
- y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
1808
-
1809
- const int * x_qs = (const int *) x;
1810
- const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE*2;
1811
- const int * x_sc = (const int *) x_dm + WARP_SIZE/QI5_K;
1812
- const int * y_qs = (const int *) y + 4;
1813
- const half2 * y_ds = (const half2 *) y;
1814
-
1815
- const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
1816
-
1817
- mma_A A[ntx][4];
1818
- int scA[ntx][mma_C::ne/2][4];
1819
- int mA[ntx][mma_C::ne/2][4];
1820
- half2 dmA[ntx][mma_C::ne/2];
1821
-
1822
- #pragma unroll
1823
- for (int n = 0; n < ntx; ++n) {
1824
- #pragma unroll
1825
- for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
1826
- const int k0 = k00 + k01;
1827
-
1828
- A[n][k01/8].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q5_K + k0, MMQ_MMA_TILE_X_K_Q5_K);
1829
- }
1830
-
1831
- #pragma unroll
1832
- for (int l = 0; l < mma_C::ne/2; ++l) {
1833
- const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
1834
-
1835
- const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q5_K + (k00/32 + 0)];
1836
- const int m_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q5_K + (k00/32 + 2)];
1837
-
1838
- const uint8_t * sc = (const uint8_t *) &sc_packed;
1839
- const uint8_t * m = (const uint8_t *) &m_packed;
1840
-
1841
- #pragma unroll
1842
- for (int ksc = 0; ksc < sizeof(int); ++ksc) {
1843
- scA[n][l][ksc] = sc[ksc];
1844
- mA[n][l][ksc] = m[ksc];
1845
- }
1846
- }
1847
-
1848
- #pragma unroll
1849
- for (int l = 0; l < mma_C::ne/2; ++l) {
1850
- const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
1851
-
1852
- dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q5_K];
1853
- }
1854
- }
1855
-
1856
- #pragma unroll
1857
- for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
1858
- float tmpd[ntx][mma_C::ne] = {{0.0f}};
1859
- float tmpm[ntx][mma_C::ne] = {{0.0f}};
1860
-
1861
- #pragma unroll
1862
- for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
1863
- const int k0 = k00 + k01;
1864
-
1865
- mma_B B;
1866
- half2 dsB[mma_C::ne/2];
1867
-
1868
- B.load(y_qs + j0*MMQ_TILE_Y_K + k0 % WARP_SIZE, MMQ_TILE_Y_K);
1869
-
1870
- #pragma unroll
1871
- for (int l = 0; l < mma_C::ne/2; ++l) {
1872
- const int j = j0 + mma_C::get_j(l);
1873
-
1874
- dsB[l] = y_ds[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)];
1875
- }
1876
 
1877
  #pragma unroll
1878
- for (int n = 0; n < ntx; ++n) {
1879
- mma_C C;
1880
- C.mma_K8(A[n][k01/8], B);
1881
 
1882
  #pragma unroll
1883
- for (int l = 0; l < mma_C::ne; ++l) {
1884
- tmpd[n][l] += (C.x[l]*scA[n][l/2][k01/8]) * __low2float(dsB[l%2]);
1885
- tmpm[n][l] += mA[n][l/2][k01/8] * __high2float(dsB[l%2]);
1886
- }
1887
- }
1888
- }
1889
 
1890
- #pragma unroll
1891
- for (int n = 0; n < ntx; ++n) {
1892
- #pragma unroll
1893
- for (int l = 0; l < mma_C::ne; ++l) {
1894
- sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA[n][l/2])*tmpd[n][l] - __high2float(dmA[n][l/2])*tmpm[n][l];
1895
  }
1896
  }
1897
  }
1898
- #else
1899
- GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
1900
- NO_DEVICE_CODE;
1901
- #endif // INT8_MMA_AVAILABLE
1902
  }
1903
 
1904
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
@@ -1915,9 +1598,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1915
  int * x_sc = (int *) (x_df + txs.dm);
1916
  #endif // INT8_MMA_AVAILABLE
1917
 
1918
- const int kbx = 0; // threadIdx.x / QI6_K
1919
- const int kqsx = threadIdx.x; // threadIdx.x % QI6_K
1920
-
1921
  #pragma unroll
1922
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
1923
  int i = i0 + threadIdx.y;
@@ -1926,19 +1606,18 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1926
  i = min(i, i_max);
1927
  }
1928
 
1929
- const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbx;
1930
- const int ky = QR6_K*kqsx;
1931
 
1932
- const int ql = get_int_b2(bxi->ql, kqsx);
1933
  const int ql0 = (ql >> 0) & 0x0F0F0F0F;
1934
  const int ql1 = (ql >> 4) & 0x0F0F0F0F;
1935
 
1936
- const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4));
1937
- const int qh0 = ((qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) << 4) & 0x30303030;
1938
- const int qh1 = (qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) & 0x30303030;
1939
 
1940
- const int kq0 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + 0;
1941
- const int kq1 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + (QI6_K/2);
1942
 
1943
  #ifdef INT8_MMA_AVAILABLE
1944
  x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
@@ -2187,6 +1866,358 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
2187
  }
2188
  }
2189
 
2190
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_xs(
2191
  const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
2192
 
@@ -2320,7 +2351,7 @@ template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2320
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_0> {
2321
  static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;
2322
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, nwarps, need_check>;
2323
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q4_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
2324
  static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2325
  };
2326
 
@@ -2328,7 +2359,7 @@ template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2328
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_1> {
2329
  static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ;
2330
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, nwarps, need_check>;
2331
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q4_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
2332
  static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2333
  };
2334
 
@@ -2336,7 +2367,7 @@ template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2336
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
2337
  static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ;
2338
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, nwarps, need_check>;
2339
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
2340
  static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2341
  };
2342
 
@@ -2352,7 +2383,7 @@ template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2352
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
2353
  static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ;
2354
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, nwarps, need_check>;
2355
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
2356
  static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2357
  };
2358
 
@@ -2368,7 +2399,7 @@ template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2368
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q3_K> {
2369
  static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ;
2370
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, nwarps, need_check>;
2371
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q3_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
2372
  static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2373
  };
2374
 
@@ -2376,7 +2407,7 @@ template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2376
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> {
2377
  static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ;
2378
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, nwarps, need_check>;
2379
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q4_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
2380
  static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2381
  };
2382
 
@@ -2384,7 +2415,7 @@ template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2384
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> {
2385
  static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ;
2386
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, nwarps, need_check>;
2387
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q5_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
2388
  static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2389
  };
2390
 
@@ -2396,11 +2427,59 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
2396
  static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2397
  };
2398
 
2399
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2400
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_NL> {
2401
  static constexpr int vdr = VDR_IQ4_NL_Q8_1_MMQ;
2402
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl<mmq_y, nwarps, need_check>;
2403
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
2404
  static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2405
  };
2406
 
@@ -2408,7 +2487,7 @@ template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2408
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_XS> {
2409
  static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ;
2410
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs<mmq_y, nwarps, need_check>;
2411
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
2412
  static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2413
  };
2414
 
@@ -2837,6 +2916,12 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
 extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
 extern DECL_MMQ_CASE(GGML_TYPE_Q5_K);
 extern DECL_MMQ_CASE(GGML_TYPE_Q6_K);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XXS);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XS);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ2_S);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ3_XXS);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ3_S);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S);
 extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);
 extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);
 
 
139
  #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
140
  }
141
 
142
+ #define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0}
143
+ #define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0}
144
+ #define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_0 + mmq_y/(QI8_0/2), 0}
145
+ #define MMQ_DP4A_TXS_Q8_0_16 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*4/QI8_0 + mmq_y/(QI8_0/4), 0}
146
+ #define MMQ_DP4A_TXS_Q8_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_1 + mmq_y/(QI8_1/2), 0}
147
+ #define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE + mmq_y, 0}
148
+ #define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y, mmq_y*WARP_SIZE/8 + mmq_y/8}
149
+ #define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
150
+ #define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
151
+ #define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
152
 
153
  static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
154
  return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 :
 
161
  type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K :
162
  type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K :
163
  type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K :
164
+ type == GGML_TYPE_IQ2_XXS ? MMQ_DP4A_TXS_Q8_0 :
165
+ type == GGML_TYPE_IQ2_XS ? MMQ_DP4A_TXS_Q8_0_16 :
166
+ type == GGML_TYPE_IQ2_S ? MMQ_DP4A_TXS_Q8_0_16 :
167
+ type == GGML_TYPE_IQ3_XXS ? MMQ_DP4A_TXS_Q8_0 :
168
+ type == GGML_TYPE_IQ3_S ? MMQ_DP4A_TXS_Q8_0 :
169
+ type == GGML_TYPE_IQ1_S ? MMQ_DP4A_TXS_Q8_0 :
170
  type == GGML_TYPE_IQ4_XS ? MMQ_DP4A_TXS_Q8_0 :
171
  type == GGML_TYPE_IQ4_NL ? MMQ_DP4A_TXS_Q8_0 :
172
  tile_x_sizes{0, 0, 0};
173
  }
174
 
 
 
175
  #define MMQ_MMA_TILE_X_K_Q8_0 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4)
176
  #define MMQ_MMA_TILE_X_K_Q8_1 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4)
177
  #define MMQ_MMA_TILE_X_K_Q2_K (2*WARP_SIZE + WARP_SIZE + 4)
178
+ #define MMQ_MMA_TILE_X_K_Q3_K (2*WARP_SIZE + WARP_SIZE/2 + 4)
 
 
179
  #define MMQ_MMA_TILE_X_K_Q6_K (2*WARP_SIZE + WARP_SIZE/QI6_K + WARP_SIZE/8 + 7)
180
 
 
 
181
  static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding.");
182
  static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding.");
183
  static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding.");
184
  static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
 
 
185
  static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
186
 
187
  static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
188
+ return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
189
+ type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q8_1 :
190
  type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
191
  type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q8_1 :
192
  type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
193
  type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K :
194
  type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K :
195
+ type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q8_1 :
196
+ type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q8_1 :
197
  type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K :
198
+ type == GGML_TYPE_IQ2_XXS ? MMQ_MMA_TILE_X_K_Q8_0 :
199
+ type == GGML_TYPE_IQ2_XS ? MMQ_MMA_TILE_X_K_Q3_K :
200
+ type == GGML_TYPE_IQ2_S ? MMQ_MMA_TILE_X_K_Q3_K :
201
+ type == GGML_TYPE_IQ3_XXS ? MMQ_MMA_TILE_X_K_Q8_0 :
202
+ type == GGML_TYPE_IQ3_S ? MMQ_MMA_TILE_X_K_Q8_0 :
203
+ type == GGML_TYPE_IQ1_S ? MMQ_MMA_TILE_X_K_Q8_0 :
204
  type == GGML_TYPE_IQ4_XS ? MMQ_MMA_TILE_X_K_Q8_0 :
205
  type == GGML_TYPE_IQ4_NL ? MMQ_MMA_TILE_X_K_Q8_0 :
206
  0;
 
229
 
230
  #ifdef INT8_MMA_AVAILABLE
231
  int * x_qs = (int *) x_tile;
232
+ float * x_df = (float *) (x_qs + 2*WARP_SIZE);
233
  #else
234
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
235
  int * x_qs = (int *) x_tile;
 
248
  }
249
 
250
  const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
251
+ const int qs0 = get_int_b2(bxi->qs, kqsx);
252
 
253
  #ifdef INT8_MMA_AVAILABLE
254
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0] = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808);
255
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808);
256
  #else
257
+ x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
258
  #endif // INT8_MMA_AVAILABLE
259
  }
260
 
 
272
  const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd;
273
 
274
  #ifdef INT8_MMA_AVAILABLE
275
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
276
  #else
277
  x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + kbxd] = bxi->d;
278
  #endif // INT8_MMA_AVAILABLE
 
319
  }
320
  }
321
 
322
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
323
  const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
324
 
325
  #ifdef INT8_MMA_AVAILABLE
326
  int * x_qs = (int *) x_tile;
327
+ half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
328
  #else
329
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
330
  int * x_qs = (int *) x_tile;
 
343
  }
344
 
345
  const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
346
+ const int qs0 = get_int_b4(bxi->qs, kqsx);
347
 
348
  #ifdef INT8_MMA_AVAILABLE
349
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0] = (qs0 >> 0) & 0x0F0F0F0F;
350
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F;
351
  #else
352
+ x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
353
  #endif // INT8_MMA_AVAILABLE
354
  }
355
 
 
367
  const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd;
368
 
369
  #ifdef INT8_MMA_AVAILABLE
370
+ x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
371
  #else
372
  x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + kbxd] = bxi->dm;
373
  #endif // INT8_MMA_AVAILABLE
 
414
  }
415
  }
416
 
417
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
418
  const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
419
 
 
641
  }
642
  }
643
 
644
+ template <int mmq_x, int mmq_y, int nwarps, mmq_q8_1_ds_layout ds_layout>
645
  static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
646
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 
647
 
648
  typedef mma_int_A_I16K8 mma_A;
649
  typedef mma_int_B_J8K8 mma_B;
 
659
  const float * x_df = (const float *) x_qs + 2*WARP_SIZE;
660
  const int * y_qs = (const int *) y + 4;
661
  const float * y_df = (const float *) y;
662
+ const half2 * y_ds = (const half2 *) y;
663
 
664
  mma_A A[ntx][WARP_SIZE/QI8_0];
665
  float dA[ntx][mma_C::ne/2][WARP_SIZE/QI8_0];
 
692
  for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
693
  #pragma unroll
694
  for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
695
+ mma_B B;
 
 
696
  float dB[mma_C::ne/2];
697
 
698
+ B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
699
 
700
  #pragma unroll
701
  for (int l = 0; l < mma_C::ne/2; ++l) {
702
  const int j = j0 + mma_C::get_j(l);
703
 
704
+ if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
705
+ dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
706
+ } else {
707
+ dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
708
+ }
709
  }
710
 
711
  #pragma unroll
 
720
  }
721
  }
722
  }
 
 
 
 
723
  }
724
 
725
  template <int mmq_x, int mmq_y, int nwarps>
 
755
  template <int mmq_x, int mmq_y, int nwarps>
756
  static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
757
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 
758
 
759
  typedef mma_int_A_I16K8 mma_A;
760
  typedef mma_int_B_J8K8 mma_B;
 
771
  const int * y_qs = (const int *) y + 4;
772
  const half2 * y_dm = (const half2 *) y;
773
 
774
+ mma_A A[ntx][WARP_SIZE/QI8_1];
775
+ float2 dmA[ntx][mma_C::ne/2][WARP_SIZE/QI8_1];
776
 
777
  const int i0 = (threadIdx.y/ntx)*rows_per_warp;
778
 
 
793
  for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
794
  const int k0 = k00 + k01;
795
 
796
+ dmA[n][l][k01/QI8_1] = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]);
797
  }
798
  }
799
  }
 
802
  for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
803
  #pragma unroll
804
  for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
805
+ mma_B B;
806
+ float2 dsB[mma_C::ne/2];
 
 
807
 
808
+ B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
809
 
810
  #pragma unroll
811
  for (int l = 0; l < mma_C::ne/2; ++l) {
812
  const int j = j0 + mma_C::get_j(l);
813
 
814
+ dsB[l] = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
815
  }
816
 
817
  #pragma unroll
 
821
 
822
  #pragma unroll
823
  for (int l = 0; l < mma_C::ne; ++l) {
824
+ sum[(j0/mma_C::J + n)*mma_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l];
825
+ sum[(j0/mma_C::J + n)*mma_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y;
826
+ }
827
+ }
828
+ }
829
+ }
830
+ }
831
+
832
+ template <int mmq_x, int mmq_y, int nwarps>
833
+ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
834
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
835
+
836
+ constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
837
+ const int * x_qs = (const int *) x;
838
+ const float * x_df = (const float *) x_qs + txs.qs;
839
+ const int * y_qs = (const int *) y + 4;
840
+ const float * y_df = (const float *) y;
841
+
842
+ // #pragma unroll
843
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
844
+ const int k0 = k00 + k01;
845
+
846
+ #pragma unroll
847
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
848
+ const int j = j0 + threadIdx.y;
849
+
850
+ #pragma unroll
851
+ for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
852
+ const int i = i0 + threadIdx.x;
853
+
854
+ sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_16_q8_1_impl<QI8_0>(
855
+ &x_qs[i*(2*WARP_SIZE + 1) + k0],
856
+ &y_qs[j*MMQ_TILE_Y_K + k01],
857
+ &x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + k0/(QI8_0/2)],
858
+ y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
859
+ }
860
+ }
861
+ }
862
+ }
863
+
864
+ template <int mmq_x, int mmq_y, int nwarps>
865
+ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
866
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
867
+ #ifdef INT8_MMA_AVAILABLE
868
+
869
+ typedef mma_int_A_I16K4 mma_A;
870
+ typedef mma_int_A_I16K8 mma_A_K8;
871
+ typedef mma_int_B_J8K4 mma_B;
872
+ typedef mma_int_C_I16J8 mma_C;
873
+
874
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
875
+ constexpr int rows_per_warp = 2 * granularity;
876
+ constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
877
+
878
+ y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
879
+
880
+ const int * x_qs = (const int *) x;
881
+ const float * x_df = (const float *) x_qs + WARP_SIZE*2;
882
+ const int * y_qs = (const int *) y + 4;
883
+ const float * y_df = (const float *) y;
884
+
885
+ const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
886
+
887
+ mma_A A[ntx][8];
888
+ float dA[ntx][mma_C::ne/2][8];
889
+
890
+ #pragma unroll
891
+ for (int n = 0; n < ntx; ++n) {
892
+ #pragma unroll
893
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
894
+ const int k0 = k00 + k01;
895
+
896
+ ((mma_A_K8 *) A[n])[k01/8].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
897
+ }
898
+
899
+ #pragma unroll
900
+ for (int l = 0; l < mma_C::ne/2; ++l) {
901
+ const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
902
+
903
+ #pragma unroll
904
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += 4) {
905
+ const int k0 = k00 + k01;
906
+
907
+ dA[n][l][k01/4] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4];
908
+ }
909
+ }
910
+ }
911
+
912
+ #pragma unroll
913
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
914
+ #pragma unroll
915
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
916
+ mma_B B[2];
917
+ float dB[mma_C::ne/2];
918
+
919
+ B[0].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
920
+ B[1].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
921
+
922
+ #pragma unroll
923
+ for (int l = 0; l < mma_C::ne/2; ++l) {
924
+ const int j = j0 + mma_C::get_j(l);
925
+
926
+ dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
927
+ }
928
+
929
+ #pragma unroll
930
+ for (int n = 0; n < ntx; ++n) {
931
+ mma_C C[2];
932
+ C[0].mma_K4(A[n][k01/4 + 0], B[0]);
933
+ C[1].mma_K4(A[n][k01/4 + 1], B[1]);
934
+
935
+ #pragma unroll
936
+ for (int l = 0; l < mma_C::ne; ++l) {
937
+ sum[(j0/mma_C::J + n)*mma_C::ne + l] += dB[l%2]*(C[0].x[l]*dA[n][l/2][k01/4 + 0] + C[1].x[l]*dA[n][l/2][k01/4 + 1]);
938
  }
939
  }
940
  }
 
1181
  #ifdef INT8_MMA_AVAILABLE
1182
  int * x_qs = (int *) x_tile;
1183
  float * x_df = (float *) (x_qs + WARP_SIZE*2);
 
1184
  #else
1185
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
1186
  int * x_qs = (int *) x_tile;
 
1220
  }
1221
  }
1222
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1223
  #pragma unroll
1224
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps*8) {
1225
  int i = i0 + threadIdx.y*8 + threadIdx.x/(WARP_SIZE/8);
 
1243
  const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
1244
 
1245
  #ifdef INT8_MMA_AVAILABLE
1246
+ const int8_t * sc8 = (const int8_t *) &sc;
1247
+ const float d = bxi->d;
1248
+
1249
+ #pragma unroll
1250
+ for (int l = 0; l < sizeof(int); ++l) {
1251
+ x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*(threadIdx.x % (WARP_SIZE/8)) + l] = d*sc8[l];
1252
+ }
1253
  #else
1254
+ x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = sc;
1255
  #endif // INT8_MMA_AVAILABLE
1256
  }
1257
+
1258
+ #ifndef INT8_MMA_AVAILABLE
1259
+ #pragma unroll
1260
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*WARP_SIZE) {
1261
+ int i = (i0 + threadIdx.y*WARP_SIZE + threadIdx.x) % mmq_y;
1262
+
1263
+ if (need_check) {
1264
+ i = min(i, i_max);
1265
+ }
1266
+
1267
+ const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
1268
+
1269
+ x_df[i] = bxi->d;
1270
+ }
1271
+ #endif // INT8_MMA_AVAILABLE
1272
  }
1273
 
1274
  template <int mmq_x, int mmq_y, int nwarps>
 
1304
  }
1305
  }
1306
 
1307
+ static __device__ __forceinline__ int unpack_scales_q45_K(const int * scales, const int ksc) {
1308
+ // scale arrangement after the following two lines:
1309
+ // - ksc == 0: sc0, sc1, sc2, sc3
1310
+ // - ksc == 1: sc4, sc5, sc6, sc7
1311
+ // - ksc == 2: m0, m1, m2, m3
1312
+ // - ksc == 3: m4, m5, m6, m7
1313
+ return ((scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F) | // lower 4 bits
1314
+ ((scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030); // upper 2 bits
1315
  }
1316
 
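For reference, the scale unpacking above can be exercised on the host. A minimal sketch, assuming a made-up 12-byte scales field and a host copy of the function (names and values are illustrative, not part of the patch):

// Host-side sketch of unpack_scales_q45_K (illustrative; mirrors the device code above).
#include <cstdint>
#include <cstdio>
#include <cstring>

static int unpack_scales_q45_K_ref(const int * scales, const int ksc) {
    return ((scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F) | // lower 4 bits
           ((scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030);                     // upper 2 bits
}

int main() {
    const uint8_t raw[12] = {0x3F,0x12,0x25,0x08,0x31,0x2A,0x17,0x0C,0x3D,0x21,0x05,0x3B}; // made-up scales field
    int scales[3];
    memcpy(scales, raw, sizeof(scales));
    for (int ksc = 0; ksc < 4; ++ksc) {
        const int packed = unpack_scales_q45_K_ref(scales, ksc);
        const uint8_t * b = (const uint8_t *) &packed;
        printf("ksc=%d: %2d %2d %2d %2d\n", ksc, b[0], b[1], b[2], b[3]); // ksc 0/1: sc0..sc7, ksc 2/3: m0..m7
    }
    return 0;
}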
1317
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
 
1319
 
1320
  #ifdef INT8_MMA_AVAILABLE
1321
  int * x_qs = (int *) x_tile;
1322
+ half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
 
1323
  #else
1324
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
1325
  int * x_qs = (int *) x_tile;
 
1327
  int * x_sc = (int *) (x_dm + txs.dm);
1328
  #endif // INT8_MMA_AVAILABLE
1329
 
 
 
 
1330
  #pragma unroll
1331
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
1332
  int i = i0 + threadIdx.y;
 
1335
  i = min(i, i_max);
1336
  }
1337
 
1338
+ const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
1339
+ const int qs0 = get_int_b4(bxi->qs, threadIdx.x);
1340
 
1341
  #ifdef INT8_MMA_AVAILABLE
1342
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F;
1343
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F;
1344
  #else
1345
+ x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
1346
  #endif // INT8_MMA_AVAILABLE
1347
  }
1348
 
1349
+ #ifdef INT8_MMA_AVAILABLE
 
1350
 
1351
  #pragma unroll
1352
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) {
1353
+ int i = (i0 + threadIdx.y*16 + threadIdx.x/(WARP_SIZE/16)) % mmq_y;
1354
 
1355
  if (need_check) {
1356
  i = min(i, i_max);
1357
  }
1358
 
1359
+ const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
1360
+
1361
+ const int * scales = (const int *) bxi->scales;
1362
+ const int ksc = threadIdx.x % (WARP_SIZE/16);
1363
+
1364
+ const int sc32 = unpack_scales_q45_K(scales, ksc + 0);
1365
+ const int m32 = unpack_scales_q45_K(scales, ksc + 2);
1366
+
1367
+ const uint8_t * sc8 = (const uint8_t *) &sc32;
1368
+ const uint8_t * m8 = (const uint8_t *) &m32;
1369
+
1370
+ const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
1371
+
1372
+ #pragma unroll
1373
+ for (int l = 0; l < sizeof(int); ++l) {
1374
+ x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
1375
+ }
1376
+ }
1377
 
 
 
1378
  #else
1379
+
1380
+ #pragma unroll
1381
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*QI4_K) {
1382
+ int i = (i0 + threadIdx.y*QI4_K + threadIdx.x) % mmq_y;
1383
+
1384
+ if (need_check) {
1385
+ i = min(i, i_max);
1386
+ }
1387
+
1388
+ const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
1389
+
1390
+ x_dm[i] = bxi->dm;
1391
  }
1392
 
1393
  #pragma unroll
 
1403
  const int * scales = (const int *) bxi->scales;
1404
 
1405
  const int ksc = threadIdx.x % (WARP_SIZE/8);
1406
+ const int scales8 = unpack_scales_q45_K(scales, ksc);
1407
 
1408
+ x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8;
 
1409
  }
1410
+ #endif // INT8_MMA_AVAILABLE
1411
  }
1412
 
1413
  template <int mmq_x, int mmq_y, int nwarps>
 
1437
 
1438
  sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq(
1439
  &x_qs[i*(WARP_SIZE + 1) + k0/2], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
1440
+ x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
 
1441
  }
1442
  }
1443
  }
 
1444
  }
1445
 
1446
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
 
1449
  #ifdef INT8_MMA_AVAILABLE
1450
  int * x_qs = (int *) x_tile;
1451
  half2 * x_dm = (half2 *) (x_qs + WARP_SIZE*2);
 
1452
  #else
1453
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
1454
  int * x_qs = (int *) x_tile;
 
1456
  int * x_sc = (int *) (x_dm + txs.dm);
1457
  #endif // INT8_MMA_AVAILABLE
1458
 
 
 
 
1459
  #pragma unroll
1460
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
1461
  int i = i0 + threadIdx.y;
 
1464
  i = min(i, i_max);
1465
  }
1466
 
1467
+ const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
1468
+ const int ky = QR5_K*threadIdx.x;
1469
 
1470
+ const int ql = get_int_b4(bxi->qs, threadIdx.x);
1471
  const int ql0 = (ql >> 0) & 0x0F0F0F0F;
1472
  const int ql1 = (ql >> 4) & 0x0F0F0F0F;
1473
 
1474
+ const int qh = get_int_b4(bxi->qh, threadIdx.x % (QI5_K/4));
1475
+ const int qh0 = ((qh >> (2 * (threadIdx.x / (QI5_K/4)) + 0)) << 4) & 0x10101010;
1476
+ const int qh1 = ((qh >> (2 * (threadIdx.x / (QI5_K/4)) + 1)) << 4) & 0x10101010;
1477
 
1478
  const int kq0 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + 0;
1479
+ const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + QI5_K/4;
1480
 
1481
  #ifdef INT8_MMA_AVAILABLE
1482
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0;
1483
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1;
1484
  #else
1485
  x_qs[i*(2*WARP_SIZE + 1) + kq0] = ql0 | qh0;
1486
  x_qs[i*(2*WARP_SIZE + 1) + kq1] = ql1 | qh1;
1487
  #endif // INT8_MMA_AVAILABLE
1488
  }
1489
 
1490
+ #ifdef INT8_MMA_AVAILABLE
 
1491
 
1492
  #pragma unroll
1493
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) {
1494
+ int i = (i0 + threadIdx.y*16 + threadIdx.x/(WARP_SIZE/16)) % mmq_y;
1495
 
1496
  if (need_check) {
1497
  i = min(i, i_max);
1498
  }
1499
 
1500
+ const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
1501
+
1502
+ const int * scales = (const int *) bxi->scales;
1503
+ const int ksc = threadIdx.x % (WARP_SIZE/16);
1504
+
1505
+ const int sc32 = unpack_scales_q45_K(scales, ksc + 0);
1506
+ const int m32 = unpack_scales_q45_K(scales, ksc + 2);
1507
+
1508
+ const uint8_t * sc8 = (const uint8_t *) &sc32;
1509
+ const uint8_t * m8 = (const uint8_t *) &m32;
1510
+
1511
+ const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
1512
+
1513
+ #pragma unroll
1514
+ for (int l = 0; l < sizeof(int); ++l) {
1515
+ x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
1516
+ }
1517
+ }
1518
 
 
 
1519
  #else
1520
+
1521
+ #pragma unroll
1522
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*QI5_K) {
1523
+ int i = (i0 + threadIdx.y*QI5_K + threadIdx.x) % mmq_y;
1524
+
1525
+ if (need_check) {
1526
+ i = min(i, i_max);
1527
+ }
1528
+
1529
+ const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
1530
+
1531
+ x_dm[i] = bxi->dm;
1532
  }
1533
 
1534
  #pragma unroll
1535
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*8) {
1536
+ int i = (i0 + threadIdx.y*8 + threadIdx.x/(WARP_SIZE/8)) % mmq_y;
1537
 
1538
  if (need_check) {
1539
  i = min(i, i_max);
1540
  }
1541
 
1542
+ const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
1543
 
1544
  const int * scales = (const int *) bxi->scales;
1545
 
1546
  const int ksc = threadIdx.x % (WARP_SIZE/8);
1547
+ const int scales8 = unpack_scales_q45_K(scales, ksc);
1548
 
1549
+ x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8;
1550
  }
1551
+ #endif // INT8_MMA_AVAILABLE
1552
  }
1553
 
1554
  template <int mmq_x, int mmq_y, int nwarps>
 
1564
 
1565
  // #pragma unroll
1566
  for (int k01 = 0; k01 < WARP_SIZE; k01 += QR5_K*VDR_Q5_K_Q8_1_MMQ) {
1567
+ const int k0 = k00 + k01;
 
1568
 
1569
  #pragma unroll
1570
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
1571
+ const int j = j0 + threadIdx.y;
 
1572
 
1573
  #pragma unroll
1574
+ for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
1575
+ const int i = i0 + threadIdx.x;
 
 
 
 
1576
 
1577
+ const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k00/32]) + 2*(k01/16);
1578
+
1579
+ sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq(
1580
+ &x_qs[i*(QR5_K*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
1581
+ x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
1582
  }
1583
  }
1584
  }
 
 
 
 
1585
  }
1586
 
1587
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
 
1598
  int * x_sc = (int *) (x_df + txs.dm);
1599
  #endif // INT8_MMA_AVAILABLE
1600
 
 
 
 
1601
  #pragma unroll
1602
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
1603
  int i = i0 + threadIdx.y;
 
1606
  i = min(i, i_max);
1607
  }
1608
 
1609
+ const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride;
 
1610
 
1611
+ const int ql = get_int_b2(bxi->ql, threadIdx.x);
1612
  const int ql0 = (ql >> 0) & 0x0F0F0F0F;
1613
  const int ql1 = (ql >> 4) & 0x0F0F0F0F;
1614
 
1615
+ const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (threadIdx.x / (QI6_K/2)) + threadIdx.x % (QI6_K/4));
1616
+ const int qh0 = ((qh >> ((threadIdx.x & 0x08) >> 2)) << 4) & 0x30303030;
1617
+ const int qh1 = (qh >> ((threadIdx.x & 0x08) >> 2)) & 0x30303030;
1618
 
1619
+ const int kq0 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + 0;
1620
+ const int kq1 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + QI6_K/2;
1621
 
1622
  #ifdef INT8_MMA_AVAILABLE
1623
  x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
 
1866
  }
1867
  }
1868
 
1869
+ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xxs(
1870
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
1871
+
1872
+ #ifdef INT8_MMA_AVAILABLE
1873
+ int * x_qs = (int *) x_tile;
1874
+ float * x_df = (float *) (x_qs + WARP_SIZE*2);
1875
+ #else
1876
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y);
1877
+ int * x_qs = (int *) x_tile;
1878
+ float * x_df = (float *) (x_qs + txs.qs);
1879
+ #endif // INT8_MMA_AVAILABLE
1880
+
1881
+ const int kqsx = threadIdx.x % (QI2_XXS/2);
1882
+
1883
+ #pragma unroll
1884
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_XXS/2)) {
1885
+ int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_XXS) + threadIdx.x/(QI2_XXS/2);
1886
+
1887
+ if (need_check) {
1888
+ i = min(i, i_max);
1889
+ }
1890
+
1891
+ const block_iq2_xxs * bxi = (const block_iq2_xxs *) x + kbx0 + i*stride;
1892
+
1893
+ const int q2 = get_int_b2(bxi->qs, 2*kqsx+0);
1894
+ const uint8_t * aux8 = (const uint8_t *) &q2;
1895
+ const uint32_t aux32 = get_int_b2(bxi->qs, 2*kqsx+1);
1896
+
1897
+ #pragma unroll
1898
+ for (int l = 0; l < QR2_XXS; ++l) {
1899
+ const int * grid_pos = (const int *) (iq2xxs_grid + aux8[l]);
1900
+ const int signs_packed = ksigns_iq2xs[(aux32 >> (7*l)) & 0x7F];
1901
+
1902
+ const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000);
1903
+ const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0);
1904
+
1905
+ const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
1906
+ const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
1907
+
1908
+ #ifdef INT8_MMA_AVAILABLE
1909
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0;
1910
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1;
1911
+ #else
1912
+ x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid0;
1913
+ x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid1;
1914
+ #endif // INT8_MMA_AVAILABLE
1915
+ }
1916
+
1917
+ const int ls = aux32 >> 28;
1918
+ const float d = bxi->d;
1919
+ #ifdef INT8_MMA_AVAILABLE
1920
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4;
1921
+ #else
1922
+ x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = (ls*d + d/2)/4;
1923
+ #endif // INT8_MMA_AVAILABLE
1924
+ }
1925
+ }
1926
+
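The sign handling above builds a per-byte mask with __vcmpne4 and uses (grid ^ mask) - mask to negate exactly the bytes whose sign bit is set. A scalar sketch of the same step, with made-up inputs (not part of the patch):

// Scalar equivalent of the __vcmpne4/__vsub4 sign trick above (illustrative only).
#include <cstdint>
#include <cstdio>

static uint32_t apply_signs_bytewise(uint32_t grid, unsigned sign_bits /* one bit per byte */) {
    uint32_t out = 0;
    for (int b = 0; b < 4; ++b) {
        int8_t v = (int8_t) (grid >> (8*b));
        if (sign_bits & (1u << b)) {
            v = (int8_t) -v;                 // same effect as (v ^ 0xFF) - 0xFF on that byte
        }
        out |= (uint32_t) (uint8_t) v << (8*b);
    }
    return out;
}

int main() {
    const uint32_t grid = 0x08190211u;                   // four packed example magnitudes
    printf("%08x\n", apply_signs_bytewise(grid, 0x5u));  // negate bytes 0 and 2
    return 0;
}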
1927
+ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xs(
1928
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
1929
+
1930
+ #ifdef INT8_MMA_AVAILABLE
1931
+ int * x_qs = (int *) x_tile;
1932
+ float * x_df = (float *) (x_qs + WARP_SIZE*2);
1933
+ #else
1934
+ constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
1935
+ int * x_qs = (int *) x_tile;
1936
+ float * x_df = (float *) (x_qs + txs.qs);
1937
+ #endif // INT8_MMA_AVAILABLE
1938
+
1939
+ const int kqsx = threadIdx.x % (QI2_XS/2);
1940
+
1941
+ #pragma unroll
1942
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_XS/2)) {
1943
+ int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_XS) + threadIdx.x/(QI2_XS/2);
1944
+
1945
+ if (need_check) {
1946
+ i = min(i, i_max);
1947
+ }
1948
+
1949
+ const block_iq2_xs * bxi = (const block_iq2_xs *) x + kbx0 + i*stride;
1950
+
1951
+ const int2 q2_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1));
1952
+ const uint16_t * q2 = (const uint16_t *) &q2_packed;
1953
+
1954
+ #pragma unroll
1955
+ for (int l = 0; l < QR2_XS; ++l) {
1956
+ const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l] & 0x000001FF));
1957
+ const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9));
1958
+
1959
+ const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
1960
+ const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
1961
+
1962
+ #ifdef INT8_MMA_AVAILABLE
1963
+ x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
1964
+ x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
1965
+ #else
1966
+ x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l;
1967
+ x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h;
1968
+ #endif // INT8_MMA_AVAILABLE
1969
+ }
1970
+
1971
+ const int ls = bxi->scales[kqsx];
1972
+ const float d = bxi->d;
1973
+ #ifdef INT8_MMA_AVAILABLE
1974
+ x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
1975
+ x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
1976
+ #else
1977
+ x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
1978
+ x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
1979
+ #endif // INT8_MMA_AVAILABLE
1980
+ }
1981
+ }
1982
+
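The scale expressions above, ((ls & 0x0F)*d + d/2)/4 and its shifted twin, appear to fold the iq2-style sub-block factor d*(0.5f + ls)*0.25f into one float per 16 values; a quick equivalence check under that assumption (illustrative, not part of the patch):

// Equivalence check for the per-16-value scale above (assumes the iq2-style
// sub-block factor d*(0.5f + ls)*0.25f).
#include <cassert>
#include <cmath>

int main() {
    const float d = 0.0123f;              // example block scale
    for (int ls = 0; ls < 16; ++ls) {     // 4-bit sub-block scale
        const float a = (ls*d + d/2)/4;
        const float b = d*(0.5f + ls)*0.25f;
        assert(std::fabs(a - b) < 1e-7f);
    }
    return 0;
}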
1983
+ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_s(
1984
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
1985
+
1986
+ #ifdef INT8_MMA_AVAILABLE
1987
+ int * x_qs = (int *) x_tile;
1988
+ float * x_df = (float *) (x_qs + WARP_SIZE*2);
1989
+ #else
1990
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y);
1991
+ int * x_qs = (int *) x_tile;
1992
+ float * x_df = (float *) (x_qs + txs.qs);
1993
+ #endif // INT8_MMA_AVAILABLE
1994
+
1995
+ const int kqsx = threadIdx.x % (QI2_S/2);
1996
+
1997
+ #pragma unroll
1998
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_S/2)) {
1999
+ int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_S) + threadIdx.x/(QI2_S/2);
2000
+
2001
+ if (need_check) {
2002
+ i = min(i, i_max);
2003
+ }
2004
+
2005
+ const block_iq2_s * bxi = (const block_iq2_s *) x + kbx0 + i*stride;
2006
+
2007
+ const int qs_packed = get_int_b2(bxi->qs, kqsx);
2008
+ const uint8_t * qs = (const uint8_t *) &qs_packed;
2009
+
2010
+ const int qh = bxi->qh[kqsx];
2011
+
2012
+ const int signs_packed_32 = get_int_b2(bxi->qs, QK_K/32 + kqsx);
2013
+ const uint8_t * signs_packed_8 = (const uint8_t *) &signs_packed_32;
2014
+
2015
+ #pragma unroll
2016
+ for (int l = 0; l < QR2_S; ++l) {
2017
+ const int * grid_pos = (const int *)(iq2s_grid + (qs[l] | ((qh << (8-2*l)) & 0x300)));
2018
+
2019
+ const int signs0 = __vcmpne4(((signs_packed_8[l] & 0x03) << 7) | ((signs_packed_8[l] & 0x0C) << 21), 0x00000000);
2020
+ const int signs1 = __vcmpne4(((signs_packed_8[l] & 0x30) << 3) | ((signs_packed_8[l] & 0xC0) << 17), 0x00000000);
2021
+
2022
+ const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0);
2023
+ const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1);
2024
+
2025
+ #ifdef INT8_MMA_AVAILABLE
2026
+ x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
2027
+ x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
2028
+ #else
2029
+ x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l;
2030
+ x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h;
2031
+ #endif // INT8_MMA_AVAILABLE
2032
+ }
2033
+
2034
+ const int ls = bxi->scales[kqsx];
2035
+ const float d = bxi->d;
2036
+ #ifdef INT8_MMA_AVAILABLE
2037
+ x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2038
+ x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2039
+ #else
2040
+ x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2041
+ x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2042
+ #endif // INT8_MMA_AVAILABLE
2043
+ }
2044
+ }
2045
+
2046
+ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_xxs(
2047
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
2048
+
2049
+ #ifdef INT8_MMA_AVAILABLE
2050
+ int * x_qs = (int *) x_tile;
2051
+ float * x_df = (float *) (x_qs + WARP_SIZE*2);
2052
+ #else
2053
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y);
2054
+ int * x_qs = (int *) x_tile;
2055
+ float * x_df = (float *) (x_qs + txs.qs);
2056
+ #endif // INT8_MMA_AVAILABLE
2057
+
2058
+ const int kqsx = threadIdx.x % (QI3_XXS/2);
2059
+
2060
+ #pragma unroll
2061
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI3_XXS/2)) {
2062
+ int i = i0 + threadIdx.y*(2*WARP_SIZE/QI3_XXS) + threadIdx.x/(QI3_XXS/2);
2063
+
2064
+ if (need_check) {
2065
+ i = min(i, i_max);
2066
+ }
2067
+
2068
+ const block_iq3_xxs * bxi = (const block_iq3_xxs *) x + kbx0 + i*stride;
2069
+
2070
+ const int2 q3_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1));
2071
+ const uint8_t * q3 = (const uint8_t *) &q3_packed;
2072
+ const uint32_t aux32 = get_int_b2(bxi->qs, QK_K/16 + kqsx);
2073
+
2074
+ #pragma unroll
2075
+ for (int l = 0; l < QR3_XXS; ++l) {
2076
+ const int2 grid_pos = make_int2(iq3xxs_grid[q3[2*l+0]], iq3xxs_grid[q3[2*l+1]]);
2077
+
2078
+ const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l)) & 0x7F));
2079
+
2080
+ const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
2081
+ const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
2082
+
2083
+ #ifdef INT8_MMA_AVAILABLE
2084
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l;
2085
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h;
2086
+ #else
2087
+ x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l;
2088
+ x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h;
2089
+ #endif // INT8_MMA_AVAILABLE
2090
+ }
2091
+
2092
+ const int ls = aux32 >> 28;
2093
+ const float d = bxi->d;
2094
+ #ifdef INT8_MMA_AVAILABLE
2095
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2;
2096
+ #else
2097
+ x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = (ls*d + d/2)/2;
2098
+ #endif // INT8_MMA_AVAILABLE
2099
+ }
2100
+ }
2101
+
2102
+ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_s(
2103
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
2104
+
2105
+ #ifdef INT8_MMA_AVAILABLE
2106
+ int * x_qs = (int *) x_tile;
2107
+ float * x_df = (float *) (x_qs + WARP_SIZE*2);
2108
+ #else
2109
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
2110
+ int * x_qs = (int *) x_tile;
2111
+ float * x_df = (float *) (x_qs + txs.qs);
2112
+ #endif // INT8_MMA_AVAILABLE
2113
+
2114
+ const int kqsx = threadIdx.x % (QI3_S/2);
2115
+
2116
+ #pragma unroll
2117
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI3_S/2)) {
2118
+ int i = i0 + threadIdx.y*(2*WARP_SIZE/QI3_S) + threadIdx.x/(QI3_S/2);
2119
+
2120
+ if (need_check) {
2121
+ i = min(i, i_max);
2122
+ }
2123
+
2124
+ const block_iq3_s * bxi = (const block_iq3_s *) x + kbx0 + i*stride;
2125
+
2126
+ const int2 qs_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1));
2127
+ const uint8_t * qs = (const uint8_t *) &qs_packed;
2128
+
2129
+ const int qh = bxi->qh[kqsx];
2130
+
2131
+ const int signs_packed_32 = get_int_b2(bxi->signs, kqsx);
2132
+ const uint8_t * signs_packed_8 = (const uint8_t *) &signs_packed_32;
2133
+
2134
+ #pragma unroll
2135
+ for (int l = 0; l < QR3_S; ++l) {
2136
+ const int2 grid_pos = make_int2(
2137
+ iq3s_grid[qs[2*l+0] | ((qh << (8 - 2*l)) & 0x100)],
2138
+ iq3s_grid[qs[2*l+1] | ((qh << (7 - 2*l)) & 0x100)]);
2139
+
2140
+ const int signs0 = __vcmpne4(((signs_packed_8[l] & 0x03) << 7) | ((signs_packed_8[l] & 0x0C) << 21), 0x00000000);
2141
+ const int signs1 = __vcmpne4(((signs_packed_8[l] & 0x30) << 3) | ((signs_packed_8[l] & 0xC0) << 17), 0x00000000);
2142
+
2143
+ const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
2144
+ const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
2145
+
2146
+ #ifdef INT8_MMA_AVAILABLE
2147
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l;
2148
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h;
2149
+ #else
2150
+ x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+0)] = grid_l;
2151
+ x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+1)] = grid_h;
2152
+ #endif // INT8_MMA_AVAILABLE
2153
+ }
2154
+
2155
+ const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F);
2156
+ const float d = bxi->d;
2157
+ #ifdef INT8_MMA_AVAILABLE
2158
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d;
2159
+ #else
2160
+ x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = ls*d;
2161
+ #endif // INT8_MMA_AVAILABLE
2162
+ }
2163
+ }
2164
+
2165
+ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq1_s(
2166
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
2167
+
2168
+ #ifdef INT8_MMA_AVAILABLE
2169
+ int * x_qs = (int *) x_tile;
2170
+ half2 * x_ds = (half2 *) (x_qs + WARP_SIZE*2);
2171
+ #else
2172
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
2173
+ int * x_qs = (int *) x_tile;
2174
+ half2 * x_ds = (half2 *) (x_qs + txs.qs);
2175
+ #endif // INT8_MMA_AVAILABLE
2176
+
2177
+ const int kqsx = threadIdx.x % QI1_S;
2178
+
2179
+ #pragma unroll
2180
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI1_S) {
2181
+ int i = i0 + threadIdx.y*(WARP_SIZE/QI1_S) + threadIdx.x/QI1_S;
2182
+
2183
+ if (need_check) {
2184
+ i = min(i, i_max);
2185
+ }
2186
+
2187
+ const block_iq1_s * bxi = (const block_iq1_s *) x + kbx0 + i*stride;
2188
+
2189
+ const int qs_packed = get_int_b2(bxi->qs, kqsx);
2190
+ const uint8_t * qs = (const uint8_t *) &qs_packed;
2191
+
2192
+ const int qh = bxi->qh[kqsx];
2193
+
2194
+ #pragma unroll
2195
+ for (int l = 0; l < QR1_S/2; ++l) {
2196
+ const int grid = iq1s_grid_gpu[qs[l] | (((qh >> (3*l)) & 0x07) << 8)];
2197
+
2198
+ const int grid0 = (grid >> 0) & 0x0F0F0F0F;
2199
+ const int grid1 = (grid >> 4) & 0x0F0F0F0F;
2200
+
2201
+ #ifdef INT8_MMA_AVAILABLE
2202
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0;
2203
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1;
2204
+ #else
2205
+ x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+0)] = grid0;
2206
+ x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+1)] = grid1;
2207
+ #endif // INT8_MMA_AVAILABLE
2208
+ }
2209
+
2210
+ const float d1q = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1);
2211
+ const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000);
2212
+
2213
+ #ifdef INT8_MMA_AVAILABLE
2214
+ x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta);
2215
+ #else
2216
+ x_ds[i*(WARP_SIZE/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta);
2217
+ #endif // INT8_MMA_AVAILABLE
2218
+ }
2219
+ }
2220
+
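In the IQ1_S tile above, x_ds packs the per-8-value scale d1q together with d1q*delta, where delta resolves to -1 + IQ1S_DELTA or -1 - IQ1S_DELTA depending on bit 15 of qh. A tiny host check of that expression (the IQ1S_DELTA value is an assumption, expected to match ggml-common.h):

// Host check of the IQ1_S delta expression above (illustrative only).
#include <cstdio>

#define IQ1S_DELTA 0.125f   // assumed to match the definition in ggml-common.h

static float iq1s_delta(int qh) {
    return -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000);
}

int main() {
    printf("qh bit 15 clear: %f\n", iq1s_delta(0x0000)); // -1 + IQ1S_DELTA
    printf("qh bit 15 set:   %f\n", iq1s_delta(0x8000)); // -1 - IQ1S_DELTA
    return 0;
}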
2221
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_xs(
2222
  const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
2223
 
 
2351
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_0> {
2352
  static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;
2353
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, nwarps, need_check>;
2354
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_DS4>;
2355
  static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2356
  };
2357
 
 
2359
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_1> {
2360
  static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ;
2361
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, nwarps, need_check>;
2362
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
2363
  static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2364
  };
2365
 
 
2367
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
2368
  static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ;
2369
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, nwarps, need_check>;
2370
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
2371
  static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2372
  };
2373
 
 
2383
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
2384
  static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ;
2385
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, nwarps, need_check>;
2386
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
2387
  static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2388
  };
2389
 
 
2399
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q3_K> {
2400
  static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ;
2401
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, nwarps, need_check>;
2402
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>;
2403
  static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2404
  };
2405
 
 
2407
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> {
2408
  static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ;
2409
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, nwarps, need_check>;
2410
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
2411
  static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2412
  };
2413
 
 
2415
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> {
2416
  static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ;
2417
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, nwarps, need_check>;
2418
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
2419
  static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2420
  };
2421
 
 
2427
  static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2428
  };
2429
 
2430
+ template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2431
+ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_XXS> {
2432
+ static constexpr int vdr = VDR_IQ2_XXS_Q8_1_MMQ;
2433
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xxs<mmq_y, nwarps, need_check>;
2434
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
2435
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2436
+ };
2437
+
2438
+ template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2439
+ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_XS> {
2440
+ static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ;
2441
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xs<mmq_y, nwarps, need_check>;
2442
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>;
2443
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2444
+ };
2445
+
2446
+ template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2447
+ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_S> {
2448
+ static constexpr int vdr = VDR_IQ2_S_Q8_1_MMQ;
2449
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_s<mmq_y, nwarps, need_check>;
2450
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>;
2451
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2452
+ };
2453
+
2454
+ template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2455
+ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ3_XXS> {
2456
+ static constexpr int vdr = VDR_IQ3_XXS_Q8_1_MMQ;
2457
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_xxs<mmq_y, nwarps, need_check>;
2458
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
2459
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2460
+ };
2461
+
2462
+ template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2463
+ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ3_S> {
2464
+ static constexpr int vdr = VDR_IQ3_S_Q8_1_MMQ;
2465
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_s<mmq_y, nwarps, need_check>;
2466
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
2467
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2468
+ };
2469
+
2470
+ template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2471
+ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ1_S> {
2472
+ static constexpr int vdr = VDR_IQ1_S_Q8_1_MMQ;
2473
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_s<mmq_y, nwarps, need_check>;
2474
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
2475
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2476
+ };
2477
+
2478
  template <int mmq_x, int mmq_y, int nwarps, bool need_check>
2479
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_NL> {
2480
  static constexpr int vdr = VDR_IQ4_NL_Q8_1_MMQ;
2481
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl<mmq_y, nwarps, need_check>;
2482
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
2483
  static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2484
  };
2485
 
 
2487
  struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_XS> {
2488
  static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ;
2489
  static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs<mmq_y, nwarps, need_check>;
2490
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
2491
  static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
2492
  };
2493
 
 
2916
  extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
2917
  extern DECL_MMQ_CASE(GGML_TYPE_Q5_K);
2918
  extern DECL_MMQ_CASE(GGML_TYPE_Q6_K);
2919
+ extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XXS);
2920
+ extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XS);
2921
+ extern DECL_MMQ_CASE(GGML_TYPE_IQ2_S);
2922
+ extern DECL_MMQ_CASE(GGML_TYPE_IQ3_XXS);
2923
+ extern DECL_MMQ_CASE(GGML_TYPE_IQ3_S);
2924
+ extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S);
2925
  extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);
2926
  extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);
2927
 
ggml/src/ggml-cuda/template-instances/generate_cu_files.py CHANGED
@@ -23,7 +23,8 @@ SOURCE_FATTN_WMMA_CASE = "DECL_FATTN_WMMA_F16_CASE({head_size}, {cols_per_block}
23
  TYPES_MMQ = [
24
  "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
25
  "GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K",
26
- "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS"
 
27
  ]
28
 
29
  SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
23
  TYPES_MMQ = [
24
  "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
25
  "GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K",
26
+ "GGML_TYPE_IQ2_XXS", "GGML_TYPE_IQ2_XS", "GGML_TYPE_IQ2_S", "GGML_TYPE_IQ3_XXS", "GGML_TYPE_IQ3_S",
27
+ "GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS"
28
  ]
29
 
30
  SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s.cu ADDED
@@ -0,0 +1,5 @@
1
+ // This file has been autogenerated by generate_cu_files.py, do not edit manually.
2
+
3
+ #include "../mmq.cuh"
4
+
5
+ DECL_MMQ_CASE(GGML_TYPE_IQ1_S);
ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s.cu ADDED
@@ -0,0 +1,5 @@
1
+ // This file has been autogenerated by generate_cu_files.py, do not edit manually.
2
+
3
+ #include "../mmq.cuh"
4
+
5
+ DECL_MMQ_CASE(GGML_TYPE_IQ2_S);
ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu ADDED
@@ -0,0 +1,5 @@
1
+ // This file has been autogenerated by generate_cu_files.py, do not edit manually.
2
+
3
+ #include "../mmq.cuh"
4
+
5
+ DECL_MMQ_CASE(GGML_TYPE_IQ2_XS);
ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu ADDED
@@ -0,0 +1,5 @@
1
+ // This file has been autogenerated by generate_cu_files.py, do not edit manually.
2
+
3
+ #include "../mmq.cuh"
4
+
5
+ DECL_MMQ_CASE(GGML_TYPE_IQ2_XXS);
ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s.cu ADDED
@@ -0,0 +1,5 @@
1
+ // This file has been autogenerated by generate_cu_files.py, do not edit manually.
2
+
3
+ #include "../mmq.cuh"
4
+
5
+ DECL_MMQ_CASE(GGML_TYPE_IQ3_S);
ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu ADDED
@@ -0,0 +1,5 @@
1
+ // This file has been autogenerated by generate_cu_files.py, do not edit manually.
2
+
3
+ #include "../mmq.cuh"
4
+
5
+ DECL_MMQ_CASE(GGML_TYPE_IQ3_XXS);
ggml/src/ggml-cuda/vecdotq.cuh CHANGED
@@ -188,6 +188,27 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp
188
  return sumi*d8d8 + m8s8 / (QI8_1 / vdr);
189
  }
190
 
191
  #define VDR_Q2_K_Q8_1_MMVQ 1
192
  #define VDR_Q2_K_Q8_1_MMQ 4
193
 
 
188
  return sumi*d8d8 + m8s8 / (QI8_1 / vdr);
189
  }
190
 
191
+ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_0_16_q8_1_impl(
192
+ const int * v, const int * u, const float * d8_0, const float & d8_1) {
193
+
194
+ float sumf = 0.0f;
195
+
196
+ #pragma unroll
197
+ for (int i0 = 0; i0 < vdr; i0 += QI8_0/2) {
198
+ int sumi = 0;
199
+
200
+ #pragma unroll
201
+ for (int i = i0; i < i0 + QI8_0/2; ++i) {
202
+ // SIMD dot product of quantized values
203
+ sumi = ggml_cuda_dp4a(v[i], u[i], sumi);
204
+ }
205
+
206
+ sumf += d8_0[i0/(QI8_0/2)]*sumi;
207
+ }
208
+
209
+ return d8_1*sumf;
210
+ }
211
+
212
  #define VDR_Q2_K_Q8_1_MMVQ 1
213
  #define VDR_Q2_K_Q8_1_MMQ 4
214
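For reference, vec_dot_q8_0_16_q8_1_impl applies one x-side float scale per 16 int8 values (QI8_0/2 ints) and a single d8_1 on the y side. A scalar sketch of the same reduction with hypothetical inputs (not part of the patch):

// Scalar reference for vec_dot_q8_0_16_q8_1_impl<vdr> (illustrative only).
#include <cstdint>
#include <cstdio>

static float dot_q8_0_16_q8_1_ref(const int8_t * v, const int8_t * u,
                                  const float * d8_0, float d8_1, int n /* multiple of 16 */) {
    float sumf = 0.0f;
    for (int i0 = 0; i0 < n; i0 += 16) {
        int sumi = 0;
        for (int i = i0; i < i0 + 16; ++i) {
            sumi += v[i]*u[i];        // what ggml_cuda_dp4a accumulates 4 bytes at a time
        }
        sumf += d8_0[i0/16]*sumi;     // per-16-value scale on the x side
    }
    return d8_1*sumf;                 // single q8_1 scale on the y side
}

int main() {
    int8_t v[32], u[32];
    for (int i = 0; i < 32; ++i) { v[i] = (int8_t)(i - 16); u[i] = (int8_t)(3 - i); }
    const float d8_0[2] = {0.5f, 0.25f};                  // hypothetical per-16 scales
    printf("%f\n", dot_q8_0_16_q8_1_ref(v, u, d8_0, 0.01f, 32));
    return 0;
}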