JohannesGaessler committed · Commit 8411e3c · 1 Parent(s): fcd0c52

CUDA: MMQ support for iq4_nl, iq4_xs (llama/8278)

ggml/src/ggml-cuda/fattn-common.cuh CHANGED

@@ -68,7 +68,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
     const int iqs4 = k_KQ % QI4_0;
     const int shift = k_KQ & (QI8_1/2);

-    const int v = (get_int_from_uint8(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
+    const int v = (get_int_b2(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
     const int u = Q_q8[k_KQ_0/WARP_SIZE];

     const int sumi = ggml_cuda_dp4a(v, u, 0);
@@ -108,7 +108,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
     const int iqs4 = k_KQ % QI4_1;
     const int shift = k_KQ & (QI8_1/2);

-    const int v = (get_int_from_uint8_aligned(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
+    const int v = (get_int_b4(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
     const int u = Q_q8[k_KQ_0/WARP_SIZE];

     const int sumi = ggml_cuda_dp4a(v, u, 0);
@@ -153,8 +153,8 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
     const int iqs8 = k_KQ % QI8_1;
     const int shift = k_KQ & (QI8_1/2);

-    int v = (get_int_from_uint8(K_q5_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
-    const int vh = get_int_from_uint8(K_q5_0[ib].qh, 0) >> (iqs8 * QI5_0);
+    int v = (get_int_b2(K_q5_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
+    const int vh = get_int_b2(K_q5_0[ib].qh, 0) >> (iqs8 * QI5_0);
     v |= (vh << 4) & 0x00000010; // 0 -> 4
     v |= (vh << 11) & 0x00001000; // 1 -> 12
     v |= (vh << 18) & 0x00100000; // 2 -> 20
@@ -200,8 +200,8 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
     const int iqs8 = k_KQ % QI8_1;
     const int shift = k_KQ & (QI8_1/2);

-    int v = (get_int_from_uint8(K_q5_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
-    const int vh = get_int_from_uint8(K_q5_1[ib].qh, 0) >> (iqs8 * QI5_1);
+    int v = (get_int_b2(K_q5_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
+    const int vh = get_int_b2(K_q5_1[ib].qh, 0) >> (iqs8 * QI5_1);
     v |= (vh << 4) & 0x00000010; // 0 -> 4
     v |= (vh << 11) & 0x00001000; // 1 -> 12
     v |= (vh << 18) & 0x00100000; // 2 -> 20
@@ -249,7 +249,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
     const int ib = k_KQ / QI8_0;
     const int iqs = k_KQ % QI8_0;

-    const int v = get_int_from_int8(K_q8_0[ib].qs, iqs);
+    const int v = get_int_b2(K_q8_0[ib].qs, iqs);

     T Q_d;
     if (std::is_same<T, half>::value) {
@@ -408,7 +408,7 @@ static __device__ __forceinline__ T dequantize_1_q5_0(const void * __restrict__

     const T d = x[ib].d;
     const int ql0 = x[ib].qs[iqs];
-    const int qh0 = get_int_from_uint8(x[ib].qh, 0);
+    const int qh0 = get_int_b2(x[ib].qh, 0);
     const int ql = ((ql0 >> (4*shift)) & 0x0F);
     const int qh = ((qh0 >> idq) << 4) & 0x10;
     const int q = (ql | qh) - 16;
@@ -433,7 +433,7 @@ static __device__ __forceinline__ T dequantize_1_q5_1(const void * __restrict__

     const half2 dm = x[ib].dm;
     const int ql0 = x[ib].qs[iqs];
-    const int qh0 = get_int_from_uint8_aligned(x[ib].qh, 0);
+    const int qh0 = get_int_b4(x[ib].qh, 0);
     const int ql = ((ql0 >> (4*shift)) & 0x0F);
     const int qh = ((qh0 >> idq) << 4) & 0x10;
     const int q = (ql | qh);
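
Note on this file: the hunks above swap the old get_int_from_uint8 / get_int_from_uint8_aligned / get_int_from_int8 helpers for the get_int_b2 / get_int_b4 loaders kept in vecdotq.cuh (last file of this commit). Below is a rough host-side analogue of the two loaders, for illustration only; get_int_b4's body is not part of this diff and is assumed here to be a plain aligned 32-bit load.

    // Host-side sketch, not part of the commit. Assumes a little-endian host,
    // matching the byte order the CUDA kernels rely on.
    #include <cassert>
    #include <cstdint>

    // Mirrors get_int_b2: builds a 32-bit value from two 16-bit reads, so the
    // source only needs 2-byte alignment (blocks whose first field is a half).
    static int get_int_b2_host(const void * x, int i32) {
        const uint16_t * x16 = (const uint16_t *) x;
        int x32 = x16[2*i32 + 0] << 0;
        x32    |= x16[2*i32 + 1] << 16;
        return x32;
    }

    // Assumed shape of get_int_b4: a single 32-bit load, valid for blocks that
    // start with a half2 and are therefore 4-byte aligned (q4_1, q5_1, ...).
    static int get_int_b4_host(const void * x, int i32) {
        return ((const int *) x)[i32];
    }

    int main() {
        alignas(4) uint8_t qs[8] = {0x12, 0x34, 0x56, 0x78, 0x9A, 0xBC, 0xDE, 0x70};
        // On 4-byte aligned data both loaders agree, which is why the separate
        // *_aligned helpers could be collapsed into the b2/b4 pair.
        assert(get_int_b2_host(qs, 1) == get_int_b4_host(qs, 1));
        return 0;
    }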

ggml/src/ggml-cuda/mmq.cu CHANGED

@@ -59,6 +59,12 @@ void ggml_cuda_op_mul_mat_q(
         case GGML_TYPE_Q6_K:
             mul_mat_q_case<GGML_TYPE_Q6_K>(ctx, args, stream);
             break;
+        case GGML_TYPE_IQ4_XS:
+            mul_mat_q_case<GGML_TYPE_IQ4_XS>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ4_NL:
+            mul_mat_q_case<GGML_TYPE_IQ4_NL>(ctx, args, stream);
+            break;
         default:
             GGML_ASSERT(false);
             break;
@@ -87,6 +93,8 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
         case GGML_TYPE_Q4_K:
         case GGML_TYPE_Q5_K:
         case GGML_TYPE_Q6_K:
+        case GGML_TYPE_IQ4_XS:
+        case GGML_TYPE_IQ4_NL:
             mmq_supported = true;
             break;
         default:

ggml/src/ggml-cuda/mmq.cuh CHANGED

@@ -92,15 +92,17 @@ static constexpr __device__ int get_mmq_y_device() {

 static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
     return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 :
-        type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 :
-        type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q5_0 :
-        type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q5_1 :
-        type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 :
-        type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K :
-        type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K :
-        type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K :
-        type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K :
-        type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K :
+        type == GGML_TYPE_Q4_1   ? MMQ_DP4A_TXS_Q4_1 :
+        type == GGML_TYPE_Q5_0   ? MMQ_DP4A_TXS_Q5_0 :
+        type == GGML_TYPE_Q5_1   ? MMQ_DP4A_TXS_Q5_1 :
+        type == GGML_TYPE_Q8_0   ? MMQ_DP4A_TXS_Q8_0 :
+        type == GGML_TYPE_Q2_K   ? MMQ_DP4A_TXS_Q2_K :
+        type == GGML_TYPE_Q3_K   ? MMQ_DP4A_TXS_Q3_K :
+        type == GGML_TYPE_Q4_K   ? MMQ_DP4A_TXS_Q4_K :
+        type == GGML_TYPE_Q5_K   ? MMQ_DP4A_TXS_Q5_K :
+        type == GGML_TYPE_Q6_K   ? MMQ_DP4A_TXS_Q6_K :
+        type == GGML_TYPE_IQ4_XS ? MMQ_DP4A_TXS_Q5_0 :
+        type == GGML_TYPE_IQ4_NL ? MMQ_DP4A_TXS_Q5_0 :
         tile_x_sizes{0, 0, 0};
 }

@@ -128,15 +130,17 @@ static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");

 static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
     return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q4_0 :
-        type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q4_1 :
-        type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q5_0 :
-        type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q5_1 :
-        type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
-        type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K :
-        type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K :
-        type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q4_K :
-        type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q5_K :
-        type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K :
+        type == GGML_TYPE_Q4_1   ? MMQ_MMA_TILE_X_K_Q4_1 :
+        type == GGML_TYPE_Q5_0   ? MMQ_MMA_TILE_X_K_Q5_0 :
+        type == GGML_TYPE_Q5_1   ? MMQ_MMA_TILE_X_K_Q5_1 :
+        type == GGML_TYPE_Q8_0   ? MMQ_MMA_TILE_X_K_Q8_0 :
+        type == GGML_TYPE_Q2_K   ? MMQ_MMA_TILE_X_K_Q2_K :
+        type == GGML_TYPE_Q3_K   ? MMQ_MMA_TILE_X_K_Q3_K :
+        type == GGML_TYPE_Q4_K   ? MMQ_MMA_TILE_X_K_Q4_K :
+        type == GGML_TYPE_Q5_K   ? MMQ_MMA_TILE_X_K_Q5_K :
+        type == GGML_TYPE_Q6_K   ? MMQ_MMA_TILE_X_K_Q6_K :
+        type == GGML_TYPE_IQ4_XS ? MMQ_MMA_TILE_X_K_Q5_0 :
+        type == GGML_TYPE_IQ4_NL ? MMQ_MMA_TILE_X_K_Q5_0 :
         0;
 }

@@ -185,9 +189,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;

 #ifdef INT8_MMA_AVAILABLE
-        x_qs[i*MMQ_MMA_TILE_X_K_Q4_0 + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx);
+        x_qs[i*MMQ_MMA_TILE_X_K_Q4_0 + threadIdx.x] = get_int_b2(bxi->qs, kqsx);
 #else
-        x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx);
+        x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_b2(bxi->qs, kqsx);
 #endif // INT8_MMA_AVAILABLE
     }

@@ -348,9 +352,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;

 #ifdef INT8_MMA_AVAILABLE
-        x_qs[i*MMQ_MMA_TILE_X_K_Q4_1 + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+        x_qs[i*MMQ_MMA_TILE_X_K_Q4_1 + threadIdx.x] = get_int_b4(bxi->qs, kqsx);
 #else
-        x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+        x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_b4(bxi->qs, kqsx);
 #endif // INT8_MMA_AVAILABLE
     }

@@ -509,8 +513,8 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin

         const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbx;

-        const int ql = get_int_from_uint8(bxi->qs, kqsx);
-        const int qh = get_int_from_uint8(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_0));
+        const int ql = get_int_b2(bxi->qs, kqsx);
+        const int qh = get_int_b2(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_0));

         int qs0 = (ql >> 0) & 0x0F0F0F0F;
         qs0 |= (qh << 4) & 0x00000010; // 0 -> 4
@@ -674,8 +678,8 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin

         const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbx;

-        const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
-        const int qh = get_int_from_uint8_aligned(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_1));
+        const int ql = get_int_b4(bxi->qs, kqsx);
+        const int qh = get_int_b4(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_1));

         int qs0 = (ql >> 0) & 0x0F0F0F0F;
         qs0 |= (qh << 4) & 0x00000010; // 0 -> 4
@@ -839,9 +843,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;

 #ifdef INT8_MMA_AVAILABLE
-        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x] = get_int_from_int8(bxi->qs, kqsx);
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x] = get_int_b2(bxi->qs, kqsx);
 #else
-        x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_from_int8(bxi->qs, kqsx);
+        x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_b2(bxi->qs, kqsx);
 #endif // INT8_MMA_AVAILABLE
     }

@@ -984,7 +988,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin

         const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + kbx;

-        const int x_ql_0 = get_int_from_uint8(bxi->qs, kqsx);
+        const int x_ql_0 = get_int_b2(bxi->qs, kqsx);

 #pragma unroll
         for (int l = 0; l < QR2_K; ++l) {
@@ -1166,8 +1170,8 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin

         const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + kbx;

-        const int x_ql_0 = get_int_from_uint8(bxi->qs, kqsx);
-        const int x_qh_0 = get_int_from_uint8(bxi->hmask, kqsx % (QI3_K/2)) >> (4 * (kqsx / (QI3_K/2)));
+        const int x_ql_0 = get_int_b2(bxi->qs, kqsx);
+        const int x_qh_0 = get_int_b2(bxi->hmask, kqsx % (QI3_K/2)) >> (4 * (kqsx / (QI3_K/2)));

 #pragma unroll
         for (int l = 0; l < QR3_K; ++l) {
@@ -1225,11 +1229,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin

         const int ksc_low = ksc % (QI3_K/8);
         const int shift_low = 4 * (ksc / (QI3_K/8));
-        const int sc_low = (get_int_from_uint8(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F;
+        const int sc_low = (get_int_b2(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F;

         const int ksc_high = QI3_K/8;
         const int shift_high = 2 * ksc;
-        const int sc_high = ((get_int_from_uint8(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030;
+        const int sc_high = ((get_int_b2(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030;

         const int sc = __vsubss4(sc_low | sc_high, 0x20202020);

@@ -1393,9 +1397,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbx;

 #ifdef INT8_MMA_AVAILABLE
-        x_qs[i*MMQ_MMA_TILE_X_K_Q4_K + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+        x_qs[i*MMQ_MMA_TILE_X_K_Q4_K + threadIdx.x] = get_int_b4(bxi->qs, kqsx);
 #else
-        x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+        x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_b4(bxi->qs, kqsx);
 #endif // INT8_MMA_AVAILABLE
     }

@@ -1610,11 +1614,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + kbx;
         const int ky = QR5_K*kqsx;

-        const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
+        const int ql = get_int_b4(bxi->qs, kqsx);
         const int ql0 = (ql >> 0) & 0x0F0F0F0F;
         const int ql1 = (ql >> 4) & 0x0F0F0F0F;

-        const int qh = get_int_from_uint8_aligned(bxi->qh, kqsx % (QI5_K/4));
+        const int qh = get_int_b4(bxi->qh, kqsx % (QI5_K/4));
         const int qh0 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 0)) << 4) & 0x10101010;
         const int qh1 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 1)) << 4) & 0x10101010;

@@ -1832,11 +1836,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbx;
         const int ky = QR6_K*kqsx;

-        const int ql = get_int_from_uint8(bxi->ql, kqsx);
+        const int ql = get_int_b2(bxi->ql, kqsx);
         const int ql0 = (ql >> 0) & 0x0F0F0F0F;
         const int ql1 = (ql >> 4) & 0x0F0F0F0F;

-        const int qh = get_int_from_uint8(bxi->qh, (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4));
+        const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4));
         const int qh0 = ((qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) << 4) & 0x30303030;
         const int qh1 = (qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) & 0x30303030;

@@ -1883,9 +1887,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / 4;

 #ifdef INT8_MMA_AVAILABLE
-        x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, threadIdx.x % (QI6_K/8));
+        x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8));
 #else
-        x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, threadIdx.x % (QI6_K/8));
+        x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8));
 #endif // INT8_MMA_AVAILABLE
     }
 }
@@ -2018,6 +2022,124 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
 #endif // INT8_MMA_AVAILABLE
 }

+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int * x_qs = (int *) x_tile;
+    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y);
+    int * x_qs = (int *) x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+    const int kbx  = threadIdx.x / QI4_NL;
+    const int kqsx = threadIdx.x % QI4_NL;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbx;
+
+        const int aux_q4 = get_int_b2(bxi->qs, kqsx);
+        const int2 v = get_int_from_table_16(aux_q4);
+        const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
+#ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + k0 + 0] = v.x;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + k0 + 4] = v.y;
+#else
+        x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
+        x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
+#endif // INT8_MMA_AVAILABLE
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI4_NL;
+    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_NL) {
+        int i = i0 + threadIdx.y * QI4_NL + threadIdx.x / blocks_per_tile_x_row;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd;
+
+#ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q5_0 + kbxd] = __half2float(bxi->d);
+#else
+        x_df[i*(WARP_SIZE/4) + i/4 + kbxd] = __half2float(bxi->d);
+#endif // INT8_MMA_AVAILABLE
+    }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_xs(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int * x_qs = (int *) x_tile;
+    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
+    int * x_qs = (int *) x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+    const int kbx  = 0;           // threadIdx.x / QI4_XS
+    const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride + kbx;
+
+        const int aux_q4 = get_int_b4(bxi->qs, kqsx);
+        const int2 v = get_int_from_table_16(aux_q4);
+        const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
+#ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + k0 + 0] = v.x;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + k0 + 4] = v.y;
+#else
+        x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
+        x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
+#endif // INT8_MMA_AVAILABLE
+    }
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
+        int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride;
+
+        const float d = __half2float(bxi->d);
+
+        const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F)
+            | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4);
+
+#ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q5_0 + threadIdx.x % 8] = d * (ls - 32);
+#else
+        x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = d * (ls - 32);
+#endif // INT8_MMA_AVAILABLE
+    }
+}
+
 template<int mmq_x, int mmq_y, int nwarps, bool need_check>
 static __device__ __forceinline__ void mmq_write_back_dp4a(
     const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
@@ -2167,6 +2289,22 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
     static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
 };

+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_NL> {
+    static constexpr int vdr = VDR_IQ4_NL_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q5_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_XS> {
+    static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q5_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
 static bool mmq_need_sum(const ggml_type type_x) {
     switch (type_x) {
         case GGML_TYPE_Q4_0:
@@ -2184,6 +2322,8 @@ static bool mmq_need_sum(const ggml_type type_x) {
         case GGML_TYPE_Q5_K:
             return true;
         case GGML_TYPE_Q6_K:
+        case GGML_TYPE_IQ4_XS:
+        case GGML_TYPE_IQ4_NL:
             return false;
         default:
             GGML_ASSERT(false);
@@ -2608,6 +2748,8 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
 extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
 extern DECL_MMQ_CASE(GGML_TYPE_Q5_K);
 extern DECL_MMQ_CASE(GGML_TYPE_Q6_K);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);

 // -------------------------------------------------------------------------------------------------------------------------
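
Note on the new IQ4 tile loaders: load_tiles_iq4_nl / load_tiles_iq4_xs above expand every 32-bit word of packed 4-bit indices through get_int_from_table_16 (defined in vecdotq.cuh, see the hunks near the end of this commit) before writing the tile, so the existing vec_dot_q5_0_q8_1_mma / vec_dot_q5_0_q8_1_dp4a kernels can then run on plain int8 data. A host-side sketch of that expansion follows, for illustration only; the table values here are placeholders, the actual non-linear grid is the kvalues_iq4nl table in ggml's common headers.

    // Illustrative sketch of the 4-bit -> int8 table expansion used by the
    // IQ4_NL/IQ4_XS tile loaders; table values below are placeholders.
    #include <cstdint>
    #include <cstdio>

    struct int8x4 { int8_t v[4]; };

    // Mirrors the idea of get_int_from_table_16: split a packed word of eight
    // 4-bit indices into low/high nibbles and map each through a 16-entry table.
    static void expand_nibbles_via_table(uint32_t q4, const int8_t table[16],
                                         int8x4 * lo, int8x4 * hi) {
        for (int i = 0; i < 4; ++i) {
            lo->v[i] = table[(q4 >> (8*i    )) & 0x0F]; // low  nibble of byte i
            hi->v[i] = table[(q4 >> (8*i + 4)) & 0x0F]; // high nibble of byte i
        }
    }

    int main() {
        // Placeholder non-linear grid (not ggml's actual values).
        const int8_t table[16] = {-127, -100, -80, -60, -45, -32, -20, -9,
                                     1,   12,  24,  37,  52,  68,  88, 112};
        int8x4 lo, hi;
        expand_nibbles_via_table(0x76543210u, table, &lo, &hi);
        // After expansion the shared-memory tile holds int8 values, which is why
        // the Q5_0 dot-product paths can be reused for IQ4_NL/IQ4_XS.
        printf("%d %d %d %d / %d %d %d %d\n",
               lo.v[0], lo.v[1], lo.v[2], lo.v[3], hi.v[0], hi.v[1], hi.v[2], hi.v[3]);
        return 0;
    }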

ggml/src/ggml-cuda/template-instances/generate_cu_files.py CHANGED

@@ -22,7 +22,8 @@ SOURCE_FATTN_WMMA_CASE = "DECL_FATTN_WMMA_F16_CASE({head_size}, {cols_per_block}

 TYPES_MMQ = [
     "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
-    "GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K"
+    "GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K",
+    "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS"
 ]

 SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.

ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu ADDED

@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);

ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_xs.cu ADDED

@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);

ggml/src/ggml-cuda/vecdotq.cuh CHANGED

@@ -1,36 +1,8 @@
 #include "common.cuh"
 #include <cstdint>

-static __device__ __forceinline__ int get_int_from_int8(const int8_t * x8, const int & i32) {
-    const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment
-
-    int x32 = 0;
-    x32 |= x16[0] << 0;
-    x32 |= x16[1] << 16;
-
-    return x32;
-}
-
-static __device__ __forceinline__ int get_int_from_uint8(const uint8_t * x8, const int & i32) {
-    const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment
-
-    int x32 = 0;
-    x32 |= x16[0] << 0;
-    x32 |= x16[1] << 16;
-
-    return x32;
-}
-
-static __device__ __forceinline__ int get_int_from_int8_aligned(const int8_t * x8, const int & i32) {
-    return *((const int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
-}
-
-static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t * x8, const int & i32) {
-    return *((const int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
-}
-
 static __device__ __forceinline__ int get_int_b2(const void * x, const int & i32) {
-    const uint16_t * x16 = (const uint16_t *) x;
+    const uint16_t * x16 = (const uint16_t *) x; // assume at least 2 byte alignment

     int x32 = x16[2*i32 + 0] << 0;
     x32 |= x16[2*i32 + 1] << 16;
@@ -768,6 +740,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
 }

 #define VDR_IQ2_XXS_Q8_1_MMVQ 2
+#define VDR_IQ2_XXS_Q8_1_MMQ  2

 static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
@@ -802,6 +775,7 @@ static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
 }

 #define VDR_IQ2_XS_Q8_1_MMVQ 2
+#define VDR_IQ2_XS_Q8_1_MMQ  2

 static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
@@ -840,6 +814,7 @@ static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
 }

 #define VDR_IQ2_S_Q8_1_MMVQ 2
+#define VDR_IQ2_S_Q8_1_MMQ  2

 static __device__ __forceinline__ float vec_dot_iq2_s_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
@@ -887,6 +862,7 @@ static __device__ __forceinline__ float vec_dot_iq2_s_q8_1(
 }

 #define VDR_IQ3_XXS_Q8_1_MMVQ 2
+#define VDR_IQ3_XXS_Q8_1_MMQ  2

 static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
@@ -921,6 +897,7 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
 }

 #define VDR_IQ3_S_Q8_1_MMVQ 2
+#define VDR_IQ3_S_Q8_1_MMQ  2

 // TODO: don't use lookup table for signs
 static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
@@ -962,6 +939,9 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
     return d * sumi;
 }

+#define VDR_IQ1_S_Q8_1_MMVQ 1
+#define VDR_IQ1_S_Q8_1_MMQ  1
+
 static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
     const block_iq1_s * bq1 = (const block_iq1_s *) vbq + kbx;
@@ -992,6 +972,9 @@ static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
     return d1q * (ds.x*sumi + ds.y*delta);
 }

+#define VDR_IQ1_M_Q8_1_MMVQ 1
+#define VDR_IQ1_M_Q8_1_MMQ  1
+
 static __device__ __forceinline__ float vec_dot_iq1_m_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {

@@ -1051,6 +1034,7 @@ static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4) {
 }

 #define VDR_IQ4_NL_Q8_1_MMVQ 2
+#define VDR_IQ4_NL_Q8_1_MMQ  4

 static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
@@ -1074,6 +1058,7 @@ static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1(
 }

 #define VDR_IQ4_XS_Q8_1_MMVQ 4
+#define VDR_IQ4_XS_Q8_1_MMQ  4

 static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
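
Note on the VDR defines: the new VDR_*_Q8_1_MMQ constants above are consumed by the mmq_type_traits specializations in mmq.cuh (e.g. vdr = VDR_IQ4_NL_Q8_1_MMQ); VDR appears to denote how many contiguous quantized ints each thread handles per dot-product step. A purely illustrative compile-time summary of the values introduced in this hunk:

    // Illustrative only; values copied from the #defines added in this commit.
    enum class iq_type { iq1_s, iq1_m, iq2_xxs, iq2_xs, iq2_s, iq3_xxs, iq3_s, iq4_nl, iq4_xs };

    constexpr int vdr_q8_1_mmq(iq_type t) {
        return t == iq_type::iq4_nl || t == iq_type::iq4_xs ? 4  // VDR_IQ4_*_Q8_1_MMQ
             : t == iq_type::iq1_s  || t == iq_type::iq1_m  ? 1  // VDR_IQ1_*_Q8_1_MMQ
             :                                                2; // VDR_IQ2_* / VDR_IQ3_*_Q8_1_MMQ
    }

    static_assert(vdr_q8_1_mmq(iq_type::iq4_nl) == 4, "matches VDR_IQ4_NL_Q8_1_MMQ");
    static_assert(vdr_q8_1_mmq(iq_type::iq1_m)  == 1, "matches VDR_IQ1_M_Q8_1_MMQ");
    static_assert(vdr_q8_1_mmq(iq_type::iq3_s)  == 2, "matches VDR_IQ3_S_Q8_1_MMQ");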