Commit 8411e3c · Parent: fcd0c52

CUDA: MMQ support for iq4_nl, iq4_xs (llama/8278)

Files changed:
- ggml/src/ggml-cuda/fattn-common.cuh  +9 -9
- ggml/src/ggml-cuda/mmq.cu  +8 -0
- ggml/src/ggml-cuda/mmq.cuh  +183 -41
- ggml/src/ggml-cuda/template-instances/generate_cu_files.py  +2 -1
- ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu  +5 -0
- ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_xs.cu  +5 -0
- ggml/src/ggml-cuda/vecdotq.cuh  +14 -29
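Both new types encode weights as 4-bit indices into a fixed nonlinear codebook instead of plain linear 4-bit values; that is why the new MMQ tile loaders further down expand the packed indices through get_int_from_table_16 rather than a simple shift-and-mask. As a reference point, here is a minimal host-side sketch of dequantizing one IQ4_NL block; the block layout and codebook values follow ggml-common.h as I recall them and are assumptions for illustration, not contents of this commit.

// Illustrative sketch only: layout and codebook values are assumptions, not part of this diff.
#include <cstdint>

#define QK4_NL 32

static const int8_t kvalues_iq4nl[16] = {
    -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113,
};

struct block_iq4_nl_sketch {
    float   d;             // block scale (stored as fp16 in the real format)
    uint8_t qs[QK4_NL/2];  // 32 weights, two 4-bit codebook indices per byte
};

static void dequantize_iq4_nl_sketch(const block_iq4_nl_sketch * b, float * y) {
    for (int j = 0; j < QK4_NL/2; ++j) {
        y[j]            = b->d * kvalues_iq4nl[b->qs[j] & 0x0F]; // low nibble  -> first  half of block
        y[j + QK4_NL/2] = b->d * kvalues_iq4nl[b->qs[j] >>   4]; // high nibble -> second half of block
    }
}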
ggml/src/ggml-cuda/fattn-common.cuh CHANGED

@@ -68,7 +68,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
     const int iqs4 = k_KQ % QI4_0;
     const int shift = k_KQ & (QI8_1/2);

-    const int v = (
+    const int v = (get_int_b2(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
     const int u = Q_q8[k_KQ_0/WARP_SIZE];

     const int sumi = ggml_cuda_dp4a(v, u, 0);

@@ -108,7 +108,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
     const int iqs4 = k_KQ % QI4_1;
     const int shift = k_KQ & (QI8_1/2);

-    const int v = (
+    const int v = (get_int_b4(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
     const int u = Q_q8[k_KQ_0/WARP_SIZE];

     const int sumi = ggml_cuda_dp4a(v, u, 0);

@@ -153,8 +153,8 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
     const int iqs8 = k_KQ % QI8_1;
     const int shift = k_KQ & (QI8_1/2);

-    int v = (
-    const int vh =
+    int v = (get_int_b2(K_q5_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
+    const int vh = get_int_b2(K_q5_0[ib].qh, 0) >> (iqs8 * QI5_0);
     v |= (vh << 4) & 0x00000010; // 0 -> 4
     v |= (vh << 11) & 0x00001000; // 1 -> 12
     v |= (vh << 18) & 0x00100000; // 2 -> 20

@@ -200,8 +200,8 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
     const int iqs8 = k_KQ % QI8_1;
     const int shift = k_KQ & (QI8_1/2);

-    int v = (
-    const int vh =
+    int v = (get_int_b2(K_q5_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
+    const int vh = get_int_b2(K_q5_1[ib].qh, 0) >> (iqs8 * QI5_1);
     v |= (vh << 4) & 0x00000010; // 0 -> 4
     v |= (vh << 11) & 0x00001000; // 1 -> 12
     v |= (vh << 18) & 0x00100000; // 2 -> 20

@@ -249,7 +249,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
     const int ib = k_KQ / QI8_0;
     const int iqs = k_KQ % QI8_0;

-    const int v =
+    const int v = get_int_b2(K_q8_0[ib].qs, iqs);

     T Q_d;
     if (std::is_same<T, half>::value) {

@@ -408,7 +408,7 @@ static __device__ __forceinline__ T dequantize_1_q5_0(const void * __restrict__

     const T d = x[ib].d;
     const int ql0 = x[ib].qs[iqs];
-    const int qh0 =
+    const int qh0 = get_int_b2(x[ib].qh, 0);
     const int ql = ((ql0 >> (4*shift)) & 0x0F);
     const int qh = ((qh0 >> idq) << 4) & 0x10;
     const int q = (ql | qh) - 16;

@@ -433,7 +433,7 @@ static __device__ __forceinline__ T dequantize_1_q5_1(const void * __restrict__

     const half2 dm = x[ib].dm;
     const int ql0 = x[ib].qs[iqs];
-    const int qh0 =
+    const int qh0 = get_int_b4(x[ib].qh, 0);
     const int ql = ((ql0 >> (4*shift)) & 0x0F);
     const int qh = ((qh0 >> idq) << 4) & 0x10;
     const int q = (ql | qh);
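Every rewritten load above goes through get_int_b2 or get_int_b4. get_int_b2 is defined in the vecdotq.cuh hunk at the end of this commit; get_int_b4 is not shown in this diff, so its body below is an assumption. The split reflects alignment: blocks whose header is a single half (q4_0, q5_0, q8_0) only guarantee 2-byte alignment and need two 16-bit loads, while blocks with a half2 header (q4_1, q5_1, q4_K, q5_K) permit a direct 4-byte load. A minimal sketch of both helpers:

// Sketch of the two word loaders used throughout this commit. get_int_b2 matches the
// definition shown in vecdotq.cuh below; get_int_b4's body is an assumption (not in this diff).
#include <cstdint>

static __device__ __forceinline__ int get_int_b2(const void * x, const int & i32) {
    const uint16_t * x16 = (const uint16_t *) x; // assume at least 2 byte alignment

    int x32 = x16[2*i32 + 0] << 0;
    x32 |= x16[2*i32 + 1] << 16;

    return x32;
}

static __device__ __forceinline__ int get_int_b4(const void * x, const int & i32) {
    return ((const int *) x)[i32]; // assume at least 4 byte alignment
}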
ggml/src/ggml-cuda/mmq.cu CHANGED

@@ -59,6 +59,12 @@ void ggml_cuda_op_mul_mat_q(
         case GGML_TYPE_Q6_K:
             mul_mat_q_case<GGML_TYPE_Q6_K>(ctx, args, stream);
             break;
+        case GGML_TYPE_IQ4_XS:
+            mul_mat_q_case<GGML_TYPE_IQ4_XS>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ4_NL:
+            mul_mat_q_case<GGML_TYPE_IQ4_NL>(ctx, args, stream);
+            break;
         default:
             GGML_ASSERT(false);
             break;

@@ -87,6 +93,8 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
         case GGML_TYPE_Q4_K:
         case GGML_TYPE_Q5_K:
         case GGML_TYPE_Q6_K:
+        case GGML_TYPE_IQ4_XS:
+        case GGML_TYPE_IQ4_NL:
            mmq_supported = true;
            break;
        default:
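The two hunks above work together: ggml_cuda_should_use_mmq() (signature visible in the second hunk header) decides whether the quantized MMQ path is taken at all, and ggml_cuda_op_mul_mat_q() then switches on the weight type to pick the template instantiation. A hedged illustration of what this change enables; the concrete cc and ne11 values are made up for the example.

// Illustration only: argument values are invented; the real call sites in ggml-cuda live
// inside the mul_mat dispatch and are more involved than this.
const enum ggml_type type = GGML_TYPE_IQ4_XS;  // or GGML_TYPE_IQ4_NL
const int     cc   = 860;                      // CUDA compute capability, e.g. Ampere consumer
const int64_t ne11 = 8;                        // number of dst columns (batch size)

if (ggml_cuda_should_use_mmq(type, cc, ne11)) {
    // Before this commit the IQ4 types fell through to `default:` and MMQ was never used;
    // now ggml_cuda_op_mul_mat_q() reaches
    //     case GGML_TYPE_IQ4_XS: mul_mat_q_case<GGML_TYPE_IQ4_XS>(ctx, args, stream); break;
}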
ggml/src/ggml-cuda/mmq.cuh CHANGED

@@ -92,15 +92,17 @@ static constexpr __device__ int get_mmq_y_device() {

 static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
     return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 :
-        type == GGML_TYPE_Q4_1
-        type == GGML_TYPE_Q5_0
-        type == GGML_TYPE_Q5_1
-        type == GGML_TYPE_Q8_0
-        type == GGML_TYPE_Q2_K
-        type == GGML_TYPE_Q3_K
-        type == GGML_TYPE_Q4_K
-        type == GGML_TYPE_Q5_K
-        type == GGML_TYPE_Q6_K
+        type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 :
+        type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q5_0 :
+        type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q5_1 :
+        type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 :
+        type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K :
+        type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K :
+        type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K :
+        type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K :
+        type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K :
+        type == GGML_TYPE_IQ4_XS ? MMQ_DP4A_TXS_Q5_0 :
+        type == GGML_TYPE_IQ4_NL ? MMQ_DP4A_TXS_Q5_0 :
         tile_x_sizes{0, 0, 0};
 }

@@ -128,15 +130,17 @@ static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");

 static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
     return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q4_0 :
-        type == GGML_TYPE_Q4_1
-        type == GGML_TYPE_Q5_0
-        type == GGML_TYPE_Q5_1
-        type == GGML_TYPE_Q8_0
-        type == GGML_TYPE_Q2_K
-        type == GGML_TYPE_Q3_K
-        type == GGML_TYPE_Q4_K
-        type == GGML_TYPE_Q5_K
-        type == GGML_TYPE_Q6_K
+        type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q4_1 :
+        type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q5_0 :
+        type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q5_1 :
+        type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
+        type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K :
+        type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K :
+        type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q4_K :
+        type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q5_K :
+        type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K :
+        type == GGML_TYPE_IQ4_XS ? MMQ_MMA_TILE_X_K_Q5_0 :
+        type == GGML_TYPE_IQ4_NL ? MMQ_MMA_TILE_X_K_Q5_0 :
         0;
 }

@@ -185,9 +189,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;

 #ifdef INT8_MMA_AVAILABLE
-        x_qs[i*MMQ_MMA_TILE_X_K_Q4_0 + threadIdx.x] =
+        x_qs[i*MMQ_MMA_TILE_X_K_Q4_0 + threadIdx.x] = get_int_b2(bxi->qs, kqsx);
 #else
-        x_qs[i*(WARP_SIZE + 1) + threadIdx.x] =
+        x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_b2(bxi->qs, kqsx);
 #endif // INT8_MMA_AVAILABLE
     }

@@ -348,9 +352,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;

 #ifdef INT8_MMA_AVAILABLE
-        x_qs[i*MMQ_MMA_TILE_X_K_Q4_1 + threadIdx.x] =
+        x_qs[i*MMQ_MMA_TILE_X_K_Q4_1 + threadIdx.x] = get_int_b4(bxi->qs, kqsx);
 #else
-        x_qs[i*(WARP_SIZE + 1) + threadIdx.x] =
+        x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_b4(bxi->qs, kqsx);
 #endif // INT8_MMA_AVAILABLE
     }

@@ -509,8 +513,8 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin

         const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbx;

-        const int ql =
-        const int qh =
+        const int ql = get_int_b2(bxi->qs, kqsx);
+        const int qh = get_int_b2(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_0));

         int qs0 = (ql >> 0) & 0x0F0F0F0F;
         qs0 |= (qh << 4) & 0x00000010; // 0 -> 4

@@ -674,8 +678,8 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin

         const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbx;

-        const int ql =
-        const int qh =
+        const int ql = get_int_b4(bxi->qs, kqsx);
+        const int qh = get_int_b4(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_1));

         int qs0 = (ql >> 0) & 0x0F0F0F0F;
         qs0 |= (qh << 4) & 0x00000010; // 0 -> 4

@@ -839,9 +843,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;

 #ifdef INT8_MMA_AVAILABLE
-        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x] =
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x] = get_int_b2(bxi->qs, kqsx);
 #else
-        x_qs[i*(WARP_SIZE + 1) + threadIdx.x] =
+        x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_b2(bxi->qs, kqsx);
 #endif // INT8_MMA_AVAILABLE
     }

@@ -984,7 +988,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin

         const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + kbx;

-        const int x_ql_0 =
+        const int x_ql_0 = get_int_b2(bxi->qs, kqsx);

 #pragma unroll
         for (int l = 0; l < QR2_K; ++l) {

@@ -1166,8 +1170,8 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin

         const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + kbx;

-        const int x_ql_0 =
-        const int x_qh_0 =
+        const int x_ql_0 = get_int_b2(bxi->qs, kqsx);
+        const int x_qh_0 = get_int_b2(bxi->hmask, kqsx % (QI3_K/2)) >> (4 * (kqsx / (QI3_K/2)));

 #pragma unroll
         for (int l = 0; l < QR3_K; ++l) {

@@ -1225,11 +1229,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin

         const int ksc_low = ksc % (QI3_K/8);
         const int shift_low = 4 * (ksc / (QI3_K/8));
-        const int sc_low = (
+        const int sc_low = (get_int_b2(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F;

         const int ksc_high = QI3_K/8;
         const int shift_high = 2 * ksc;
-        const int sc_high = ((
+        const int sc_high = ((get_int_b2(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030;

         const int sc = __vsubss4(sc_low | sc_high, 0x20202020);

@@ -1393,9 +1397,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbx;

 #ifdef INT8_MMA_AVAILABLE
-        x_qs[i*MMQ_MMA_TILE_X_K_Q4_K + threadIdx.x] =
+        x_qs[i*MMQ_MMA_TILE_X_K_Q4_K + threadIdx.x] = get_int_b4(bxi->qs, kqsx);
 #else
-        x_qs[i*(WARP_SIZE + 1) + threadIdx.x] =
+        x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_b4(bxi->qs, kqsx);
 #endif // INT8_MMA_AVAILABLE
     }

@@ -1610,11 +1614,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + kbx;
         const int ky = QR5_K*kqsx;

-        const int ql =
+        const int ql = get_int_b4(bxi->qs, kqsx);
         const int ql0 = (ql >> 0) & 0x0F0F0F0F;
         const int ql1 = (ql >> 4) & 0x0F0F0F0F;

-        const int qh =
+        const int qh = get_int_b4(bxi->qh, kqsx % (QI5_K/4));
         const int qh0 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 0)) << 4) & 0x10101010;
         const int qh1 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 1)) << 4) & 0x10101010;

@@ -1832,11 +1836,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbx;
         const int ky = QR6_K*kqsx;

-        const int ql =
+        const int ql = get_int_b2(bxi->ql, kqsx);
         const int ql0 = (ql >> 0) & 0x0F0F0F0F;
         const int ql1 = (ql >> 4) & 0x0F0F0F0F;

-        const int qh =
+        const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4));
         const int qh0 = ((qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) << 4) & 0x30303030;
         const int qh1 = (qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) & 0x30303030;

@@ -1883,9 +1887,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / 4;

 #ifdef INT8_MMA_AVAILABLE
-        x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x % (WARP_SIZE/8)] =
+        x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8));
 #else
-        x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] =
+        x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8));
 #endif // INT8_MMA_AVAILABLE
     }
 }

@@ -2018,6 +2022,124 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
 #endif // INT8_MMA_AVAILABLE
 }

+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int * x_qs = (int *) x_tile;
+    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y);
+    int * x_qs = (int *) x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+    const int kbx  = threadIdx.x / QI4_NL;
+    const int kqsx = threadIdx.x % QI4_NL;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbx;
+
+        const int aux_q4 = get_int_b2(bxi->qs, kqsx);
+        const int2 v = get_int_from_table_16(aux_q4);
+        const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
+#ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + k0 + 0] = v.x;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + k0 + 4] = v.y;
+#else
+        x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
+        x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
+#endif // INT8_MMA_AVAILABLE
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI4_NL;
+    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_NL) {
+        int i = i0 + threadIdx.y * QI4_NL + threadIdx.x / blocks_per_tile_x_row;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd;
+
+#ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q5_0 + kbxd] = __half2float(bxi->d);
+#else
+        x_df[i*(WARP_SIZE/4) + i/4 + kbxd] = __half2float(bxi->d);
+#endif // INT8_MMA_AVAILABLE
+    }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_xs(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int * x_qs = (int *) x_tile;
+    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
+    int * x_qs = (int *) x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+    const int kbx  = 0;           // threadIdx.x / QI4_XS
+    const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride + kbx;
+
+        const int aux_q4 = get_int_b4(bxi->qs, kqsx);
+        const int2 v = get_int_from_table_16(aux_q4);
+        const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
+#ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + k0 + 0] = v.x;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + k0 + 4] = v.y;
+#else
+        x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
+        x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
+#endif // INT8_MMA_AVAILABLE
+    }
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
+        int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride;
+
+        const float d = __half2float(bxi->d);
+
+        const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F)
+            | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4);
+
+#ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q5_0 + threadIdx.x % 8] = d * (ls - 32);
+#else
+        x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = d * (ls - 32);
+#endif // INT8_MMA_AVAILABLE
+    }
+}
+
 template<int mmq_x, int mmq_y, int nwarps, bool need_check>
 static __device__ __forceinline__ void mmq_write_back_dp4a(
     const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {

@@ -2167,6 +2289,22 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
     static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
 };

+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_NL> {
+    static constexpr int vdr = VDR_IQ4_NL_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q5_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_XS> {
+    static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q5_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
 static bool mmq_need_sum(const ggml_type type_x) {
     switch (type_x) {
         case GGML_TYPE_Q4_0:

@@ -2184,6 +2322,8 @@ static bool mmq_need_sum(const ggml_type type_x) {
         case GGML_TYPE_Q5_K:
             return true;
         case GGML_TYPE_Q6_K:
+        case GGML_TYPE_IQ4_XS:
+        case GGML_TYPE_IQ4_NL:
             return false;
         default:
             GGML_ASSERT(false);

@@ -2608,6 +2748,8 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
 extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
 extern DECL_MMQ_CASE(GGML_TYPE_Q5_K);
 extern DECL_MMQ_CASE(GGML_TYPE_Q6_K);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);

 // -------------------------------------------------------------------------------------------------------------------------
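The only non-obvious arithmetic in load_tiles_iq4_xs above is the per-sub-block scale: a low nibble from scales_l and two bits from scales_h combine into a 6-bit value ls in [0, 63], and the effective scale is d * (ls - 32), so sub-block scales can be negative. A standalone restatement with a worked example; the helper name is invented, only the bit manipulation is taken from the code above.

// Restates the scale decode from load_tiles_iq4_xs for a single sub-block index kqs8 in [0, 7].
// iq4_xs_subblock_scale is a hypothetical helper name, not a function that exists in llama.cpp.
static __device__ __forceinline__ float iq4_xs_subblock_scale(
        const uint8_t * scales_l, const uint16_t scales_h, const float d, const int kqs8) {
    const int ls = ((scales_l[kqs8/2] >> (4*(kqs8 % 2))) & 0x0F)  // 4 low bits, one nibble per sub-block
        | (((scales_h >> (2*kqs8)) & 0x03) << 4);                 // 2 high bits
    return d * (ls - 32);
}

// Example: scales_l[0] = 0x4A, scales_h = 0x0002, kqs8 = 0
//   low nibble = 0xA = 10, high bits = 0x2 -> ls = 10 | (2 << 4) = 42 -> scale = d * (42 - 32) = 10*d.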
ggml/src/ggml-cuda/template-instances/generate_cu_files.py CHANGED

@@ -22,7 +22,8 @@ SOURCE_FATTN_WMMA_CASE = "DECL_FATTN_WMMA_F16_CASE({head_size}, {cols_per_block}

 TYPES_MMQ = [
     "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
-    "GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K"
+    "GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K",
+    "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS"
 ]

 SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu ADDED

@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);
ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_xs.cu ADDED

@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);
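The two new .cu files follow the existing template-instance pattern: each quantization type gets one small autogenerated translation unit so the expensive mul_mat_q_case instantiations compile in parallel, while mmq.cuh only declares them with the extern DECL_MMQ_CASE(...) lines added above. The macro body sketched here is an assumption about its shape; its real definition lives in mmq.cuh and is not part of this diff.

// Assumed shape of the macro (not shown in this diff): an explicit instantiation of
// mul_mat_q_case<type>. The parameter types are inferred from the call sites in mmq.cu
// (ctx, args, stream) and should be treated as assumptions.
#define DECL_MMQ_CASE(type) \
    template void mul_mat_q_case<type>(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream)

// mmq.cuh:                 extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);  // declaration only
// mmq-instance-iq4_nl.cu:         DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);  // the single definition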
ggml/src/ggml-cuda/vecdotq.cuh CHANGED

@@ -1,36 +1,8 @@
 #include "common.cuh"
 #include <cstdint>

-static __device__ __forceinline__ int get_int_from_int8(const int8_t * x8, const int & i32) {
-    const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment
-
-    int x32 = 0;
-    x32 |= x16[0] << 0;
-    x32 |= x16[1] << 16;
-
-    return x32;
-}
-
-static __device__ __forceinline__ int get_int_from_uint8(const uint8_t * x8, const int & i32) {
-    const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment
-
-    int x32 = 0;
-    x32 |= x16[0] << 0;
-    x32 |= x16[1] << 16;
-
-    return x32;
-}
-
-static __device__ __forceinline__ int get_int_from_int8_aligned(const int8_t * x8, const int & i32) {
-    return *((const int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
-}
-
-static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t * x8, const int & i32) {
-    return *((const int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
-}
-
 static __device__ __forceinline__ int get_int_b2(const void * x, const int & i32) {
-    const uint16_t * x16 = (const uint16_t *) x;
+    const uint16_t * x16 = (const uint16_t *) x; // assume at least 2 byte alignment

     int x32 = x16[2*i32 + 0] << 0;
     x32 |= x16[2*i32 + 1] << 16;

@@ -768,6 +740,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
 }

 #define VDR_IQ2_XXS_Q8_1_MMVQ 2
+#define VDR_IQ2_XXS_Q8_1_MMQ 2

 static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {

@@ -802,6 +775,7 @@ static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
 }

 #define VDR_IQ2_XS_Q8_1_MMVQ 2
+#define VDR_IQ2_XS_Q8_1_MMQ 2

 static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {

@@ -840,6 +814,7 @@ static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
 }

 #define VDR_IQ2_S_Q8_1_MMVQ 2
+#define VDR_IQ2_S_Q8_1_MMQ 2

 static __device__ __forceinline__ float vec_dot_iq2_s_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {

@@ -887,6 +862,7 @@ static __device__ __forceinline__ float vec_dot_iq2_s_q8_1(
 }

 #define VDR_IQ3_XXS_Q8_1_MMVQ 2
+#define VDR_IQ3_XXS_Q8_1_MMQ 2

 static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {

@@ -921,6 +897,7 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
 }

 #define VDR_IQ3_S_Q8_1_MMVQ 2
+#define VDR_IQ3_S_Q8_1_MMQ 2

 // TODO: don't use lookup table for signs
 static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(

@@ -962,6 +939,9 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
     return d * sumi;
 }

+#define VDR_IQ1_S_Q8_1_MMVQ 1
+#define VDR_IQ1_S_Q8_1_MMQ 1
+
 static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
     const block_iq1_s * bq1 = (const block_iq1_s *) vbq + kbx;

@@ -992,6 +972,9 @@ static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
     return d1q * (ds.x*sumi + ds.y*delta);
 }

+#define VDR_IQ1_M_Q8_1_MMVQ 1
+#define VDR_IQ1_M_Q8_1_MMQ 1
+
 static __device__ __forceinline__ float vec_dot_iq1_m_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {

@@ -1051,6 +1034,7 @@ static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4) {
 }

 #define VDR_IQ4_NL_Q8_1_MMVQ 2
+#define VDR_IQ4_NL_Q8_1_MMQ 4

 static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {

@@ -1074,6 +1058,7 @@ static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1(
 }

 #define VDR_IQ4_XS_Q8_1_MMVQ 4
+#define VDR_IQ4_XS_Q8_1_MMQ 4

 static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {