Spaces:
Running
Running
Commit
·
5931562
1
Parent(s):
d4b3604
CUDA: faster q2_K, q3_K MMQ + int8 tensor cores (llama/7921)
Browse files* CUDA: faster q2_K, q3_K MMQ + int8 tensor cores
* try CI fix
* try CI fix
* try CI fix
* fix data race
* rever q2_K precision related changes
- ggml-cuda.cu +4 -2
- ggml-cuda/argsort.cu +1 -0
- ggml-cuda/common.cuh +5 -0
- ggml-cuda/mmq.cuh +430 -298
- ggml-cuda/softmax.cu +1 -0
- ggml-cuda/vecdotq.cuh +16 -19
ggml-cuda.cu
CHANGED
|
@@ -188,13 +188,15 @@ static ggml_cuda_device_info ggml_cuda_init() {
|
|
| 188 |
info.default_tensor_split[id] = total_vram;
|
| 189 |
total_vram += prop.totalGlobalMem;
|
| 190 |
|
|
|
|
|
|
|
| 191 |
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
|
|
|
| 192 |
info.devices[id].cc = 100*prop.major + 10*prop.minor + CC_OFFSET_AMD;
|
| 193 |
#else
|
|
|
|
| 194 |
info.devices[id].cc = 100*prop.major + 10*prop.minor;
|
| 195 |
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
| 196 |
-
info.devices[id].smpb = prop.sharedMemPerBlock;
|
| 197 |
-
info.devices[id].nsm = prop.multiProcessorCount;
|
| 198 |
}
|
| 199 |
|
| 200 |
for (int id = 0; id < info.device_count; ++id) {
|
|
|
|
| 188 |
info.default_tensor_split[id] = total_vram;
|
| 189 |
total_vram += prop.totalGlobalMem;
|
| 190 |
|
| 191 |
+
info.devices[id].nsm = prop.multiProcessorCount;
|
| 192 |
+
info.devices[id].smpb = prop.sharedMemPerBlock;
|
| 193 |
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
| 194 |
+
info.devices[id].smpbo = prop.sharedMemPerBlock;
|
| 195 |
info.devices[id].cc = 100*prop.major + 10*prop.minor + CC_OFFSET_AMD;
|
| 196 |
#else
|
| 197 |
+
info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
|
| 198 |
info.devices[id].cc = 100*prop.major + 10*prop.minor;
|
| 199 |
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
|
|
|
|
|
|
| 200 |
}
|
| 201 |
|
| 202 |
for (int id = 0; id < info.device_count; ++id) {
|
ggml-cuda/argsort.cu
CHANGED
|
@@ -73,6 +73,7 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co
|
|
| 73 |
const dim3 block_nums(1, nrows, 1);
|
| 74 |
const size_t shared_mem = ncols_pad * sizeof(int);
|
| 75 |
|
|
|
|
| 76 |
GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);
|
| 77 |
|
| 78 |
if (order == GGML_SORT_ORDER_ASC) {
|
|
|
|
| 73 |
const dim3 block_nums(1, nrows, 1);
|
| 74 |
const size_t shared_mem = ncols_pad * sizeof(int);
|
| 75 |
|
| 76 |
+
// FIXME: this limit could be raised by ~2-4x on Ampere or newer
|
| 77 |
GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);
|
| 78 |
|
| 79 |
if (order == GGML_SORT_ORDER_ASC) {
|
ggml-cuda/common.cuh
CHANGED
|
@@ -331,6 +331,10 @@ static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int
|
|
| 331 |
#define FP16_AVAILABLE
|
| 332 |
#endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
|
| 333 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 334 |
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
|
| 335 |
#define FP16_MMA_AVAILABLE
|
| 336 |
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
|
|
@@ -661,6 +665,7 @@ struct ggml_cuda_device_info {
|
|
| 661 |
int cc; // compute capability
|
| 662 |
int nsm; // number of streaming multiprocessors
|
| 663 |
size_t smpb; // max. shared memory per block
|
|
|
|
| 664 |
bool vmm; // virtual memory support
|
| 665 |
size_t vmm_granularity; // granularity of virtual memory
|
| 666 |
size_t total_vram;
|
|
|
|
| 331 |
#define FP16_AVAILABLE
|
| 332 |
#endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
|
| 333 |
|
| 334 |
+
#if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
|
| 335 |
+
#define FAST_FP16_AVAILABLE
|
| 336 |
+
#endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
|
| 337 |
+
|
| 338 |
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
|
| 339 |
#define FP16_MMA_AVAILABLE
|
| 340 |
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
|
|
|
|
| 665 |
int cc; // compute capability
|
| 666 |
int nsm; // number of streaming multiprocessors
|
| 667 |
size_t smpb; // max. shared memory per block
|
| 668 |
+
size_t smpbo; // max. shared memory per block (with opt-in)
|
| 669 |
bool vmm; // virtual memory support
|
| 670 |
size_t vmm_granularity; // granularity of virtual memory
|
| 671 |
size_t total_vram;
|
ggml-cuda/mmq.cuh
CHANGED
|
@@ -10,10 +10,10 @@
|
|
| 10 |
#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
|
| 11 |
|
| 12 |
typedef void (*load_tiles_mmq_t)(
|
| 13 |
-
const char * __restrict__ x, int * __restrict__
|
| 14 |
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride);
|
| 15 |
typedef void (*vec_dot_mmq_t)(
|
| 16 |
-
const int * __restrict__
|
| 17 |
const int * __restrict__ y, float * __restrict__ sum, const int & k0);
|
| 18 |
typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1);
|
| 19 |
|
|
@@ -25,9 +25,8 @@ static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected b
|
|
| 25 |
static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "Unexpected block_q8_1_mmq size");
|
| 26 |
|
| 27 |
struct tile_x_sizes {
|
| 28 |
-
int
|
| 29 |
int dm;
|
| 30 |
-
int qh;
|
| 31 |
int sc;
|
| 32 |
};
|
| 33 |
|
|
@@ -67,16 +66,16 @@ static constexpr __device__ int get_mmq_y_device(int /*mmq_x*/) {
|
|
| 67 |
#endif // __CUDA_ARCH__ >= CC_VOLTA
|
| 68 |
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
| 69 |
|
| 70 |
-
#define TILE_X_SIZES_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0
|
| 71 |
-
#define TILE_X_SIZES_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0
|
| 72 |
-
#define TILE_X_SIZES_Q5_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_0 + mmq_y/QI5_0, 0
|
| 73 |
-
#define TILE_X_SIZES_Q5_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_1 + mmq_y/QI5_1, 0
|
| 74 |
-
#define TILE_X_SIZES_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI8_0 + mmq_y/QI8_0, 0
|
| 75 |
-
#define TILE_X_SIZES_Q2_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE
|
| 76 |
-
#define TILE_X_SIZES_Q3_K tile_x_sizes{mmq_y*WARP_SIZE
|
| 77 |
-
#define TILE_X_SIZES_Q4_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_K + mmq_y/QI4_K,
|
| 78 |
-
#define TILE_X_SIZES_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K,
|
| 79 |
-
#define TILE_X_SIZES_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K,
|
| 80 |
|
| 81 |
#define GET_TILE_X_SIZES_BODY \
|
| 82 |
return type == GGML_TYPE_Q4_0 ? TILE_X_SIZES_Q4_0 : \
|
|
@@ -89,7 +88,7 @@ static constexpr __device__ int get_mmq_y_device(int /*mmq_x*/) {
|
|
| 89 |
type == GGML_TYPE_Q4_K ? TILE_X_SIZES_Q4_K : \
|
| 90 |
type == GGML_TYPE_Q5_K ? TILE_X_SIZES_Q5_K : \
|
| 91 |
type == GGML_TYPE_Q6_K ? TILE_X_SIZES_Q6_K : \
|
| 92 |
-
tile_x_sizes{0, 0, 0
|
| 93 |
|
| 94 |
static tile_x_sizes get_tile_x_sizes_host(const ggml_type type, const int mmq_y) {
|
| 95 |
GET_TILE_X_SIZES_BODY;
|
|
@@ -103,9 +102,9 @@ static constexpr __device__ tile_x_sizes get_tile_x_sizes_device(ggml_type type)
|
|
| 103 |
// ------------------------------------------------------------
|
| 104 |
|
| 105 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
|
| 106 |
-
const char * __restrict__ x, int * __restrict__
|
| 107 |
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
|
| 108 |
-
GGML_UNUSED(
|
| 109 |
|
| 110 |
const int kbx = threadIdx.x / QI4_0;
|
| 111 |
const int kqsx = threadIdx.x % QI4_0;
|
|
@@ -122,7 +121,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
| 122 |
|
| 123 |
const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
|
| 124 |
|
| 125 |
-
|
| 126 |
}
|
| 127 |
|
| 128 |
const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
|
|
@@ -144,10 +143,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
| 144 |
|
| 145 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 146 |
static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
|
| 147 |
-
const int * __restrict__
|
| 148 |
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
| 149 |
-
|
| 150 |
-
GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
|
| 151 |
|
| 152 |
const float * x_df = (const float *) x_dm;
|
| 153 |
const int * y_qs = (const int *) y + 4;
|
|
@@ -172,7 +170,7 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
|
|
| 172 |
}
|
| 173 |
|
| 174 |
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
|
| 175 |
-
(&
|
| 176 |
y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
|
| 177 |
}
|
| 178 |
}
|
|
@@ -180,10 +178,10 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
|
|
| 180 |
|
| 181 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 182 |
static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma(
|
| 183 |
-
const int * __restrict__
|
| 184 |
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
| 185 |
-
|
| 186 |
-
GGML_UNUSED(
|
| 187 |
|
| 188 |
typedef mma_int_A_I16K8 mma_A;
|
| 189 |
typedef mma_int_B_J8K8 mma_B;
|
|
@@ -205,7 +203,7 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma(
|
|
| 205 |
const int k = k0 + mma_A::get_k(l) % QI4_0;
|
| 206 |
const int shift = 4*(mma_A::get_k(l) / QI4_0);
|
| 207 |
|
| 208 |
-
A.x[l] = __vsubss4((
|
| 209 |
}
|
| 210 |
#pragma unroll
|
| 211 |
for (int l = 0; l < mma_C::ne/2; ++l) {
|
|
@@ -240,12 +238,16 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma(
|
|
| 240 |
sum[(j0/B.J)*C.ne + l] += dA[l/2]*__low2float(dsB[l%2])*C.x[l];
|
| 241 |
}
|
| 242 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
}
|
| 244 |
|
| 245 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
|
| 246 |
-
const char * __restrict__ x, int * __restrict__
|
| 247 |
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
|
| 248 |
-
GGML_UNUSED(
|
| 249 |
|
| 250 |
const int kbx = threadIdx.x / QI4_1;
|
| 251 |
const int kqsx = threadIdx.x % QI4_1;
|
|
@@ -260,7 +262,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
| 260 |
|
| 261 |
const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
|
| 262 |
|
| 263 |
-
|
| 264 |
}
|
| 265 |
|
| 266 |
const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
|
|
@@ -282,10 +284,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
| 282 |
|
| 283 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 284 |
static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
|
| 285 |
-
const int * __restrict__
|
| 286 |
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
| 287 |
-
|
| 288 |
-
GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
|
| 289 |
|
| 290 |
const int * y_qs = (const int *) y + 4;
|
| 291 |
const half2 * y_ds = (const half2 *) y;
|
|
@@ -309,7 +310,7 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
|
|
| 309 |
}
|
| 310 |
|
| 311 |
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
|
| 312 |
-
(&
|
| 313 |
y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
|
| 314 |
}
|
| 315 |
}
|
|
@@ -317,10 +318,10 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
|
|
| 317 |
|
| 318 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 319 |
static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma(
|
| 320 |
-
const int * __restrict__
|
| 321 |
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
| 322 |
-
|
| 323 |
-
GGML_UNUSED(
|
| 324 |
|
| 325 |
typedef mma_int_A_I16K8 mma_A;
|
| 326 |
typedef mma_int_B_J8K8 mma_B;
|
|
@@ -341,7 +342,7 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma(
|
|
| 341 |
const int k = k0 + mma_A::get_k(l) % QI4_0;
|
| 342 |
const int shift = 4*(mma_A::get_k(l) / QI4_0);
|
| 343 |
|
| 344 |
-
A.x[l] = (
|
| 345 |
}
|
| 346 |
#pragma unroll
|
| 347 |
for (int l = 0; l < mma_C::ne/2; ++l) {
|
|
@@ -377,12 +378,16 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma(
|
|
| 377 |
sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
|
| 378 |
}
|
| 379 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 380 |
}
|
| 381 |
|
| 382 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
|
| 383 |
-
const char * __restrict__ x, int * __restrict__
|
| 384 |
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
|
| 385 |
-
GGML_UNUSED(
|
| 386 |
|
| 387 |
const int kbx = threadIdx.x / QI5_0;
|
| 388 |
const int kqsx = threadIdx.x % QI5_0;
|
|
@@ -407,7 +412,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
| 407 |
qs0 |= (qh << 25) & 0x10000000; // 3 -> 28
|
| 408 |
qs0 = __vsubss4(qs0, 0x10101010); // subtract 16
|
| 409 |
|
| 410 |
-
|
| 411 |
|
| 412 |
int qs1 = (ql >> 4) & 0x0F0F0F0F;
|
| 413 |
qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4
|
|
@@ -416,7 +421,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
| 416 |
qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
|
| 417 |
qs1 = __vsubss4(qs1, 0x10101010); // subtract 16
|
| 418 |
|
| 419 |
-
|
| 420 |
}
|
| 421 |
|
| 422 |
const int blocks_per_tile_x_row = WARP_SIZE / QI5_0;
|
|
@@ -439,10 +444,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
| 439 |
|
| 440 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 441 |
static __device__ __forceinline__ void vec_dot_q5_0_q8_1_dp4a(
|
| 442 |
-
const int * __restrict__
|
| 443 |
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
| 444 |
-
|
| 445 |
-
GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
|
| 446 |
|
| 447 |
const float * x_dmf = (const float *) x_dm;
|
| 448 |
const int * y_qs = (const int *) y + 4;
|
|
@@ -468,17 +472,17 @@ static __device__ __forceinline__ void vec_dot_q5_0_q8_1_dp4a(
|
|
| 468 |
}
|
| 469 |
|
| 470 |
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, QR5_0*VDR_Q5_0_Q8_1_MMQ>
|
| 471 |
-
(&
|
| 472 |
}
|
| 473 |
}
|
| 474 |
}
|
| 475 |
|
| 476 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 477 |
static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma(
|
| 478 |
-
const int * __restrict__
|
| 479 |
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
| 480 |
-
|
| 481 |
-
GGML_UNUSED(
|
| 482 |
|
| 483 |
typedef mma_int_A_I16K8 mma_A;
|
| 484 |
typedef mma_int_B_J8K8 mma_B;
|
|
@@ -499,7 +503,7 @@ static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma(
|
|
| 499 |
const int i = i0 + mma_A::get_i(l);
|
| 500 |
const int k = 2*(k0 + mma_A::get_k(l) % QI5_0) + mma_A::get_k(l) / QI5_0;
|
| 501 |
|
| 502 |
-
A.x[l] =
|
| 503 |
}
|
| 504 |
#pragma unroll
|
| 505 |
for (int l = 0; l < mma_C::ne/2; ++l) {
|
|
@@ -534,12 +538,16 @@ static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma(
|
|
| 534 |
sum[(j0/B.J)*C.ne + l] += dA[l/2]*dB[l%2]*C.x[l];
|
| 535 |
}
|
| 536 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 537 |
}
|
| 538 |
|
| 539 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
|
| 540 |
-
const char * __restrict__ x, int * __restrict__
|
| 541 |
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
|
| 542 |
-
GGML_UNUSED(
|
| 543 |
|
| 544 |
const int kbx = threadIdx.x / QI5_1;
|
| 545 |
const int kqsx = threadIdx.x % QI5_1;
|
|
@@ -563,7 +571,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
| 563 |
qs0 |= (qh << 18) & 0x00100000; // 2 -> 20
|
| 564 |
qs0 |= (qh << 25) & 0x10000000; // 3 -> 28
|
| 565 |
|
| 566 |
-
|
| 567 |
|
| 568 |
int qs1 = (ql >> 4) & 0x0F0F0F0F;
|
| 569 |
qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4
|
|
@@ -571,7 +579,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
| 571 |
qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
|
| 572 |
qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
|
| 573 |
|
| 574 |
-
|
| 575 |
}
|
| 576 |
|
| 577 |
const int blocks_per_tile_x_row = WARP_SIZE / QI5_1;
|
|
@@ -593,10 +601,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
| 593 |
|
| 594 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 595 |
static __device__ __forceinline__ void vec_dot_q5_1_q8_1_dp4a(
|
| 596 |
-
const int * __restrict__
|
| 597 |
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
| 598 |
-
|
| 599 |
-
GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
|
| 600 |
|
| 601 |
const int * y_qs = (const int *) y + 4;
|
| 602 |
const half2 * y_ds = (const half2 *) y;
|
|
@@ -621,17 +628,17 @@ static __device__ __forceinline__ void vec_dot_q5_1_q8_1_dp4a(
|
|
| 621 |
}
|
| 622 |
|
| 623 |
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
|
| 624 |
-
(&
|
| 625 |
}
|
| 626 |
}
|
| 627 |
}
|
| 628 |
|
| 629 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 630 |
static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma(
|
| 631 |
-
const int * __restrict__
|
| 632 |
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
| 633 |
-
|
| 634 |
-
GGML_UNUSED(
|
| 635 |
|
| 636 |
typedef mma_int_A_I16K8 mma_A;
|
| 637 |
typedef mma_int_B_J8K8 mma_B;
|
|
@@ -651,7 +658,7 @@ static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma(
|
|
| 651 |
const int i = i0 + mma_A::get_i(l);
|
| 652 |
const int k = 2*(k0 + mma_A::get_k(l) % QI5_1) + mma_A::get_k(l) / QI5_1;
|
| 653 |
|
| 654 |
-
A.x[l] =
|
| 655 |
}
|
| 656 |
#pragma unroll
|
| 657 |
for (int l = 0; l < mma_C::ne/2; ++l) {
|
|
@@ -687,13 +694,16 @@ static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma(
|
|
| 687 |
sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
|
| 688 |
}
|
| 689 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 690 |
}
|
| 691 |
|
| 692 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
|
| 693 |
-
const char * __restrict__ x, int * __restrict__
|
| 694 |
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
|
| 695 |
-
|
| 696 |
-
GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
|
| 697 |
|
| 698 |
const int kbx = threadIdx.x / QI8_0;
|
| 699 |
const int kqsx = threadIdx.x % QI8_0;
|
|
@@ -709,7 +719,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
| 709 |
|
| 710 |
const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
|
| 711 |
|
| 712 |
-
|
| 713 |
}
|
| 714 |
|
| 715 |
const int blocks_per_tile_x_row = WARP_SIZE / QI8_0;
|
|
@@ -731,10 +741,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
| 731 |
|
| 732 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 733 |
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
|
| 734 |
-
const int * __restrict__
|
| 735 |
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
| 736 |
-
|
| 737 |
-
GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
|
| 738 |
|
| 739 |
const float * x_dmf = (const float *) x_dm;
|
| 740 |
const int * y_qs = (const int *) y + 4;
|
|
@@ -749,7 +758,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
|
|
| 749 |
const int i = i0 + threadIdx.x;
|
| 750 |
|
| 751 |
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
|
| 752 |
-
(&
|
| 753 |
y_df[j*MMQ_TILE_Y_K + k0/QI8_1]);
|
| 754 |
}
|
| 755 |
}
|
|
@@ -757,10 +766,10 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
|
|
| 757 |
|
| 758 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 759 |
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
|
| 760 |
-
const int * __restrict__
|
| 761 |
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
| 762 |
-
|
| 763 |
-
GGML_UNUSED(
|
| 764 |
|
| 765 |
typedef mma_int_A_I16K8 mma_A;
|
| 766 |
typedef mma_int_B_J8K8 mma_B;
|
|
@@ -781,7 +790,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
|
|
| 781 |
const int i = i0 + mma_A::get_i(l);
|
| 782 |
const int k = k0 + mma_A::get_k(l);
|
| 783 |
|
| 784 |
-
A.x[l] =
|
| 785 |
}
|
| 786 |
#pragma unroll
|
| 787 |
for (int l = 0; l < mma_C::ne/2; ++l) {
|
|
@@ -816,12 +825,15 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
|
|
| 816 |
sum[(j0/B.J)*C.ne + l] += C.x[l]*dA[l/2]*dB[l%2];
|
| 817 |
}
|
| 818 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 819 |
}
|
| 820 |
|
| 821 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
|
| 822 |
-
const char * __restrict__ x, int * __restrict__
|
| 823 |
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
|
| 824 |
-
GGML_UNUSED(x_qh);
|
| 825 |
|
| 826 |
const int kbx = threadIdx.x / QI2_K;
|
| 827 |
const int kqsx = threadIdx.x % QI2_K;
|
|
@@ -836,48 +848,42 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
| 836 |
|
| 837 |
const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + kbx;
|
| 838 |
|
| 839 |
-
|
| 840 |
-
}
|
| 841 |
-
|
| 842 |
-
const int blocks_per_tile_x_row = WARP_SIZE / QI2_K;
|
| 843 |
-
const int kbxd = threadIdx.x % blocks_per_tile_x_row;
|
| 844 |
|
| 845 |
#pragma unroll
|
| 846 |
-
|
| 847 |
-
|
| 848 |
|
| 849 |
-
|
| 850 |
-
|
| 851 |
-
|
| 852 |
-
|
| 853 |
-
const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + kbxd;
|
| 854 |
-
|
| 855 |
-
x_dm[i * (WARP_SIZE/QI2_K) + i / QI2_K + kbxd] = bxi->dm;
|
| 856 |
-
}
|
| 857 |
|
| 858 |
-
|
| 859 |
-
|
| 860 |
-
|
| 861 |
|
| 862 |
-
|
| 863 |
-
i = min(i, i_max);
|
| 864 |
}
|
| 865 |
|
| 866 |
-
const
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 867 |
|
| 868 |
-
|
| 869 |
}
|
| 870 |
}
|
| 871 |
|
| 872 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 873 |
-
static __device__ __forceinline__ void
|
| 874 |
-
const int * __restrict__
|
| 875 |
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
| 876 |
|
| 877 |
-
|
| 878 |
-
|
| 879 |
-
const int * y_qs = (const int *) y + 4;
|
| 880 |
-
const float * y_df = (const float *) y;
|
| 881 |
|
| 882 |
#pragma unroll
|
| 883 |
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
|
@@ -887,30 +893,99 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mul_mat(
|
|
| 887 |
for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
|
| 888 |
const int i = i0 + threadIdx.x;
|
| 889 |
|
| 890 |
-
|
| 891 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 892 |
|
| 893 |
-
|
|
|
|
| 894 |
|
| 895 |
-
|
| 896 |
-
|
|
|
|
| 897 |
|
| 898 |
#pragma unroll
|
| 899 |
-
|
| 900 |
-
|
| 901 |
-
|
| 902 |
|
| 903 |
-
|
|
|
|
|
|
|
| 904 |
|
| 905 |
-
|
| 906 |
-
|
| 907 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 908 |
}
|
| 909 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 910 |
}
|
| 911 |
|
| 912 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
|
| 913 |
-
const char * __restrict__ x, int * __restrict__
|
| 914 |
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
|
| 915 |
|
| 916 |
const int kbx = threadIdx.x / QI3_K;
|
|
@@ -926,7 +1001,25 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
| 926 |
|
| 927 |
const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + kbx;
|
| 928 |
|
| 929 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 930 |
}
|
| 931 |
|
| 932 |
const int blocks_per_tile_x_row = WARP_SIZE / QI3_K;
|
|
@@ -946,20 +1039,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
| 946 |
x_dmf[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd] = bxi->d;
|
| 947 |
}
|
| 948 |
|
| 949 |
-
#pragma unroll
|
| 950 |
-
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 2) {
|
| 951 |
-
int i = i0 + threadIdx.y * 2 + threadIdx.x / (WARP_SIZE/2);
|
| 952 |
-
|
| 953 |
-
if (need_check) {
|
| 954 |
-
i = min(i, i_max);
|
| 955 |
-
}
|
| 956 |
-
|
| 957 |
-
const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/2)) / (QI3_K/2);
|
| 958 |
-
|
| 959 |
-
// invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
|
| 960 |
-
x_qh[i * (WARP_SIZE/2) + i / 2 + threadIdx.x % (WARP_SIZE/2)] = ~get_int_from_uint8(bxi->hmask, threadIdx.x % (QI3_K/2));
|
| 961 |
-
}
|
| 962 |
-
|
| 963 |
#pragma unroll
|
| 964 |
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
|
| 965 |
int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
|
|
@@ -987,13 +1066,13 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
| 987 |
}
|
| 988 |
|
| 989 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 990 |
-
static __device__ __forceinline__ void
|
| 991 |
-
const int * __restrict__
|
| 992 |
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
| 993 |
|
| 994 |
-
const float *
|
| 995 |
-
const int * y_qs
|
| 996 |
-
const float * y_df
|
| 997 |
|
| 998 |
#pragma unroll
|
| 999 |
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
|
@@ -1008,31 +1087,102 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mul_mat(
|
|
| 1008 |
|
| 1009 |
const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;
|
| 1010 |
|
| 1011 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1012 |
|
| 1013 |
#pragma unroll
|
| 1014 |
-
|
| 1015 |
-
|
| 1016 |
-
|
| 1017 |
-
const int vll = (x_ql[kqsx + l] >> shift) & 0x03030303;
|
| 1018 |
|
| 1019 |
-
|
| 1020 |
-
|
|
|
|
|
|
|
|
|
|
| 1021 |
|
| 1022 |
-
|
| 1023 |
-
|
|
|
|
| 1024 |
|
| 1025 |
-
|
| 1026 |
-
|
| 1027 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1028 |
}
|
| 1029 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1030 |
}
|
| 1031 |
|
| 1032 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
|
| 1033 |
-
const char * __restrict__ x, int * __restrict__
|
| 1034 |
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
|
| 1035 |
-
GGML_UNUSED(x_qh);
|
| 1036 |
|
| 1037 |
const int kbx = 0; // threadIdx.x / QI4_K
|
| 1038 |
const int kqsx = threadIdx.x; // threadIdx.x % QI4_K
|
|
@@ -1047,7 +1197,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
| 1047 |
|
| 1048 |
const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbx;
|
| 1049 |
|
| 1050 |
-
|
| 1051 |
}
|
| 1052 |
|
| 1053 |
const int blocks_per_tile_x_row = WARP_SIZE / QI4_K; // == 1 if QK_K == 256
|
|
@@ -1090,11 +1240,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
| 1090 |
|
| 1091 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 1092 |
static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
|
| 1093 |
-
const int * __restrict__
|
| 1094 |
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
| 1095 |
|
| 1096 |
-
GGML_UNUSED(x_qh);
|
| 1097 |
-
|
| 1098 |
const int * y_qs = (const int *) y + 4;
|
| 1099 |
const half2 * y_ds = (const half2 *) y;
|
| 1100 |
|
|
@@ -1109,7 +1257,7 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
|
|
| 1109 |
const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2*((k0 % 16) / 8);
|
| 1110 |
|
| 1111 |
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq(
|
| 1112 |
-
&
|
| 1113 |
x_dm[i*(WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[j*MMQ_TILE_Y_K + ((QR4_K*k0) % WARP_SIZE)/QI8_1]);
|
| 1114 |
}
|
| 1115 |
}
|
|
@@ -1117,10 +1265,9 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
|
|
| 1117 |
|
| 1118 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 1119 |
static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma(
|
| 1120 |
-
const int * __restrict__
|
| 1121 |
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
| 1122 |
-
|
| 1123 |
-
GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
|
| 1124 |
|
| 1125 |
typedef mma_int_A_I16K8 mma_A;
|
| 1126 |
typedef mma_int_B_J8K8 mma_B;
|
|
@@ -1143,7 +1290,7 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma(
|
|
| 1143 |
const int i = i0 + mma_A::get_i(l);
|
| 1144 |
const int k = k0 + mma_A::get_k(l);
|
| 1145 |
|
| 1146 |
-
A[kvdr/4].x[l] = (
|
| 1147 |
}
|
| 1148 |
|
| 1149 |
#pragma unroll
|
|
@@ -1204,12 +1351,15 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma(
|
|
| 1204 |
sum[(j0/mma_B::J)*mma_C::ne + l] += __low2float(dmA[l/2])*tmpd[l] - __high2float(dmA[l/2])*tmpm[l];
|
| 1205 |
}
|
| 1206 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1207 |
}
|
| 1208 |
|
| 1209 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
|
| 1210 |
-
const char * __restrict__ x, int * __restrict__
|
| 1211 |
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
|
| 1212 |
-
GGML_UNUSED(x_qh);
|
| 1213 |
|
| 1214 |
const int kbx = 0; // threadIdx.x / QI5_K
|
| 1215 |
const int kqsx = threadIdx.x; // threadIdx.x % QI5_K
|
|
@@ -1236,8 +1386,8 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
| 1236 |
const int kq0 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + 0;
|
| 1237 |
const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + (QI5_K/4);
|
| 1238 |
|
| 1239 |
-
|
| 1240 |
-
|
| 1241 |
}
|
| 1242 |
|
| 1243 |
const int blocks_per_tile_x_row = WARP_SIZE / QI5_K; // == 1 if QK_K == 256
|
|
@@ -1280,11 +1430,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
| 1280 |
|
| 1281 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 1282 |
static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
|
| 1283 |
-
const int * __restrict__
|
| 1284 |
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
| 1285 |
|
| 1286 |
-
GGML_UNUSED(x_qh);
|
| 1287 |
-
|
| 1288 |
const int * y_qs = (const int *) y + 4;
|
| 1289 |
const half2 * y_ds = (const half2 *) y;
|
| 1290 |
|
|
@@ -1299,7 +1447,7 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
|
|
| 1299 |
const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8);
|
| 1300 |
|
| 1301 |
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq(
|
| 1302 |
-
&
|
| 1303 |
x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[j*MMQ_TILE_Y_K + ((QR5_K*k0) % WARP_SIZE)/QI8_1]);
|
| 1304 |
}
|
| 1305 |
}
|
|
@@ -1307,10 +1455,9 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
|
|
| 1307 |
|
| 1308 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 1309 |
static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma(
|
| 1310 |
-
const int * __restrict__
|
| 1311 |
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
| 1312 |
-
|
| 1313 |
-
GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
|
| 1314 |
|
| 1315 |
typedef mma_int_A_I16K8 mma_A;
|
| 1316 |
typedef mma_int_B_J8K8 mma_B;
|
|
@@ -1333,7 +1480,7 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma(
|
|
| 1333 |
const int i = i0 + mma_A::get_i(l);
|
| 1334 |
const int k = QR5_K*k0 + QR5_K*kvdr + mma_A::get_k(l);
|
| 1335 |
|
| 1336 |
-
A[kvdr/4].x[l] =
|
| 1337 |
}
|
| 1338 |
|
| 1339 |
#pragma unroll
|
|
@@ -1394,12 +1541,15 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma(
|
|
| 1394 |
sum[(j0/mma_B::J)*mma_C::ne + l] += __low2float(dmA[l/2])*tmpd[l] - __high2float(dmA[l/2])*tmpm[l];
|
| 1395 |
}
|
| 1396 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1397 |
}
|
| 1398 |
|
| 1399 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
|
| 1400 |
-
const char * __restrict__ x, int * __restrict__
|
| 1401 |
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
|
| 1402 |
-
GGML_UNUSED(x_qh);
|
| 1403 |
|
| 1404 |
const int kbx = 0; // threadIdx.x / QI6_K
|
| 1405 |
const int kqsx = threadIdx.x; // threadIdx.x % QI6_K
|
|
@@ -1426,8 +1576,8 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
| 1426 |
const int kq0 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + 0;
|
| 1427 |
const int kq1 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + (QI6_K/2);
|
| 1428 |
|
| 1429 |
-
|
| 1430 |
-
|
| 1431 |
}
|
| 1432 |
|
| 1433 |
const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256
|
|
@@ -1463,11 +1613,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
| 1463 |
|
| 1464 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 1465 |
static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
|
| 1466 |
-
const int * __restrict__
|
| 1467 |
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
| 1468 |
|
| 1469 |
-
GGML_UNUSED(x_qh);
|
| 1470 |
-
|
| 1471 |
const float * x_dmf = (const float *) x_dm;
|
| 1472 |
const int * y_qs = (const int *) y + 4;
|
| 1473 |
const float * y_df = (const float *) y;
|
|
@@ -1483,7 +1631,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
|
|
| 1483 |
const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/8]);
|
| 1484 |
|
| 1485 |
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq(
|
| 1486 |
-
&
|
| 1487 |
x_dmf[i*(WARP_SIZE/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + ((QR6_K*k0) % WARP_SIZE)/QI8_1]);
|
| 1488 |
}
|
| 1489 |
}
|
|
@@ -1491,10 +1639,9 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
|
|
| 1491 |
|
| 1492 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 1493 |
static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
|
| 1494 |
-
const int * __restrict__
|
| 1495 |
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
| 1496 |
-
|
| 1497 |
-
GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
|
| 1498 |
|
| 1499 |
typedef mma_int_A_I16K4 mma_A;
|
| 1500 |
typedef mma_int_B_J8K4 mma_B;
|
|
@@ -1505,7 +1652,9 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
|
|
| 1505 |
const float * y_df = (const float *) y;
|
| 1506 |
|
| 1507 |
const int i0 = threadIdx.y*mma_A::I;
|
|
|
|
| 1508 |
static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
|
|
|
|
| 1509 |
|
| 1510 |
mma_A A[4];
|
| 1511 |
int scA[mma_C::ne/2][4];
|
|
@@ -1517,8 +1666,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
|
|
| 1517 |
const int i = i0 + mma_A::get_i(l);
|
| 1518 |
const int k = QR6_K*k0 + QR6_K*kvdr + mma_A::get_k(l);
|
| 1519 |
|
| 1520 |
-
A[kvdr/2 + 0].x[l] =
|
| 1521 |
-
A[kvdr/2 + 1].x[l] =
|
| 1522 |
}
|
| 1523 |
|
| 1524 |
#pragma unroll
|
|
@@ -1578,6 +1727,10 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
|
|
| 1578 |
sum[(j0/mma_B::J)*mma_C::ne + l] += tmp[l]*dA[l/2];
|
| 1579 |
}
|
| 1580 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1581 |
}
|
| 1582 |
|
| 1583 |
template<int mmq_x, int mmq_y, int nwarps, bool need_check>
|
|
@@ -1608,7 +1761,9 @@ static __device__ __forceinline__ void mmq_write_back_mma(const float * __restri
|
|
| 1608 |
typedef mma_int_C_I16J8 mma_C;
|
| 1609 |
|
| 1610 |
const int i0 = threadIdx.y*mma_C::I;
|
|
|
|
| 1611 |
static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y");
|
|
|
|
| 1612 |
|
| 1613 |
#pragma unroll
|
| 1614 |
for (int j0 = 0; j0 < mmq_x; j0 += mma_C::J) {
|
|
@@ -1638,125 +1793,85 @@ struct mmq_type_traits;
|
|
| 1638 |
|
| 1639 |
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
| 1640 |
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_0> {
|
| 1641 |
-
static constexpr int vdr
|
| 1642 |
-
static constexpr load_tiles_mmq_t load_tiles
|
| 1643 |
-
|
| 1644 |
-
static constexpr vec_dot_mmq_t
|
| 1645 |
-
static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
|
| 1646 |
-
#else
|
| 1647 |
-
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
|
| 1648 |
-
static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
|
| 1649 |
-
#endif // INT8_MMA_AVAILABLE
|
| 1650 |
};
|
| 1651 |
|
| 1652 |
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
| 1653 |
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_1> {
|
| 1654 |
-
static constexpr int vdr
|
| 1655 |
-
static constexpr load_tiles_mmq_t load_tiles
|
| 1656 |
-
|
| 1657 |
-
static constexpr vec_dot_mmq_t
|
| 1658 |
-
static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
|
| 1659 |
-
#else
|
| 1660 |
-
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
|
| 1661 |
-
static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
|
| 1662 |
-
#endif // INT8_MMA_AVAILABLE
|
| 1663 |
};
|
| 1664 |
|
| 1665 |
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
| 1666 |
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
|
| 1667 |
-
static constexpr int vdr
|
| 1668 |
-
static constexpr load_tiles_mmq_t load_tiles
|
| 1669 |
-
|
| 1670 |
-
static constexpr vec_dot_mmq_t
|
| 1671 |
-
static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
|
| 1672 |
-
#else
|
| 1673 |
-
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
|
| 1674 |
-
static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
|
| 1675 |
-
#endif // INT8_MMA_AVAILABLE
|
| 1676 |
};
|
| 1677 |
|
| 1678 |
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
| 1679 |
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
|
| 1680 |
-
static constexpr int vdr
|
| 1681 |
-
static constexpr load_tiles_mmq_t load_tiles
|
| 1682 |
-
|
| 1683 |
-
static constexpr vec_dot_mmq_t
|
| 1684 |
-
static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
|
| 1685 |
-
#else
|
| 1686 |
-
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
|
| 1687 |
-
static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
|
| 1688 |
-
#endif // INT8_MMA_AVAILABLE
|
| 1689 |
};
|
| 1690 |
|
| 1691 |
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
| 1692 |
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
|
| 1693 |
-
static constexpr int vdr
|
| 1694 |
-
static constexpr load_tiles_mmq_t load_tiles
|
| 1695 |
-
|
| 1696 |
-
static constexpr vec_dot_mmq_t
|
| 1697 |
-
static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
|
| 1698 |
-
#else
|
| 1699 |
-
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
|
| 1700 |
-
static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
|
| 1701 |
-
#endif // INT8_MMA_AVAILABLE
|
| 1702 |
};
|
| 1703 |
|
| 1704 |
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
| 1705 |
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q2_K> {
|
| 1706 |
-
static constexpr int vdr
|
| 1707 |
-
static constexpr load_tiles_mmq_t load_tiles
|
| 1708 |
-
static constexpr vec_dot_mmq_t
|
| 1709 |
-
static constexpr
|
| 1710 |
};
|
| 1711 |
|
| 1712 |
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
| 1713 |
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q3_K> {
|
| 1714 |
-
static constexpr int vdr
|
| 1715 |
-
static constexpr load_tiles_mmq_t load_tiles
|
| 1716 |
-
static constexpr vec_dot_mmq_t
|
| 1717 |
-
static constexpr
|
| 1718 |
};
|
| 1719 |
|
| 1720 |
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
| 1721 |
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> {
|
| 1722 |
-
static constexpr int vdr
|
| 1723 |
-
static constexpr load_tiles_mmq_t load_tiles
|
| 1724 |
-
|
| 1725 |
-
static constexpr vec_dot_mmq_t
|
| 1726 |
-
static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
|
| 1727 |
-
#else
|
| 1728 |
-
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
|
| 1729 |
-
static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
|
| 1730 |
-
#endif // INT8_MMA_AVAILABLE
|
| 1731 |
};
|
| 1732 |
|
| 1733 |
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
| 1734 |
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> {
|
| 1735 |
-
static constexpr int vdr
|
| 1736 |
-
static constexpr load_tiles_mmq_t load_tiles
|
| 1737 |
-
|
| 1738 |
-
static constexpr vec_dot_mmq_t
|
| 1739 |
-
static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
|
| 1740 |
-
#else
|
| 1741 |
-
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
|
| 1742 |
-
static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
|
| 1743 |
-
#endif // INT8_MMA_AVAILABLE
|
| 1744 |
};
|
| 1745 |
|
| 1746 |
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
| 1747 |
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
|
| 1748 |
-
static constexpr int vdr
|
| 1749 |
-
static constexpr load_tiles_mmq_t load_tiles
|
| 1750 |
-
|
| 1751 |
-
static constexpr vec_dot_mmq_t
|
| 1752 |
-
static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
|
| 1753 |
-
#else
|
| 1754 |
-
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
|
| 1755 |
-
static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
|
| 1756 |
-
#endif // INT8_MMA_AVAILABLE
|
| 1757 |
};
|
| 1758 |
|
| 1759 |
-
static
|
| 1760 |
switch (type_x) {
|
| 1761 |
case GGML_TYPE_Q4_0:
|
| 1762 |
case GGML_TYPE_Q4_1:
|
|
@@ -1790,7 +1905,7 @@ template <ggml_type type, int mmq_x, int nwarps, bool need_check>
|
|
| 1790 |
#if __CUDA_ARCH__ >= CC_VOLTA
|
| 1791 |
__launch_bounds__(WARP_SIZE*nwarps, 1)
|
| 1792 |
#else
|
| 1793 |
-
__launch_bounds__(WARP_SIZE*nwarps,
|
| 1794 |
#endif // __CUDA_ARCH__ >= CC_VOLTA
|
| 1795 |
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
| 1796 |
static __global__ void mul_mat_q(
|
|
@@ -1809,16 +1924,21 @@ static __global__ void mul_mat_q(
|
|
| 1809 |
constexpr int mmq_y = get_mmq_y_device(mmq_x);
|
| 1810 |
constexpr int vdr = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vdr;
|
| 1811 |
constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
|
| 1812 |
-
|
| 1813 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1814 |
|
| 1815 |
constexpr tile_x_sizes txs = get_tile_x_sizes_device<mmq_y>(type);
|
| 1816 |
|
| 1817 |
extern __shared__ char data_mul_mat_q[];
|
| 1818 |
-
int *
|
| 1819 |
-
half2 * tile_x_dm = (half2 *) (
|
| 1820 |
-
int *
|
| 1821 |
-
int * tile_x_sc = (int *) (tile_x_qh + txs.qh);
|
| 1822 |
int * tile_y = (int *) (tile_x_sc + txs.sc); // [mmq_x * (WARP_SIZE + WARP_SIZE/QI8_1)]
|
| 1823 |
|
| 1824 |
const int blocks_per_row_x = ne00 / qk;
|
|
@@ -1834,7 +1954,7 @@ static __global__ void mul_mat_q(
|
|
| 1834 |
|
| 1835 |
for (int kb0 = 0; kb0 < blocks_per_row_x; kb0 += blocks_per_warp) {
|
| 1836 |
|
| 1837 |
-
load_tiles(x,
|
| 1838 |
|
| 1839 |
#pragma unroll
|
| 1840 |
for (int kr = 0; kr < qr; ++kr) {
|
|
@@ -1850,7 +1970,7 @@ static __global__ void mul_mat_q(
|
|
| 1850 |
|
| 1851 |
// #pragma unroll // unrolling this loop causes too much register pressure
|
| 1852 |
for (int k0 = kr*WARP_SIZE/qr; k0 < (kr+1)*WARP_SIZE/qr; k0 += vdr) {
|
| 1853 |
-
vec_dot(
|
| 1854 |
}
|
| 1855 |
|
| 1856 |
__syncthreads();
|
|
@@ -1867,6 +1987,19 @@ struct mmq_args {
|
|
| 1867 |
int64_t ne0;
|
| 1868 |
};
|
| 1869 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1870 |
template <ggml_type type, int mmq_x, int nwarps>
|
| 1871 |
static void launch_mul_mat_q(const mmq_args & args, cudaStream_t stream) {
|
| 1872 |
const int id = ggml_cuda_get_device();
|
|
@@ -1878,10 +2011,7 @@ static void launch_mul_mat_q(const mmq_args & args, cudaStream_t stream) {
|
|
| 1878 |
const dim3 block_nums(block_num_x, block_num_y, 1);
|
| 1879 |
const dim3 block_dims(WARP_SIZE, nwarps, 1);
|
| 1880 |
|
| 1881 |
-
const
|
| 1882 |
-
const int shmem_x = txs.ql*sizeof(int) + txs.dm*sizeof(half2) + txs.qh*sizeof(int) + txs.sc*sizeof(int);
|
| 1883 |
-
const int shmem_y = mmq_x*WARP_SIZE*sizeof(int) + mmq_x*(WARP_SIZE/QI8_1)*sizeof(half2);
|
| 1884 |
-
const int shmem = shmem_x + GGML_PAD(shmem_y, nwarps*WARP_SIZE*sizeof(int));
|
| 1885 |
|
| 1886 |
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
| 1887 |
static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
|
@@ -1905,9 +2035,10 @@ static void launch_mul_mat_q(const mmq_args & args, cudaStream_t stream) {
|
|
| 1905 |
|
| 1906 |
template <ggml_type type>
|
| 1907 |
void mul_mat_q_case(const mmq_args & args, cudaStream_t stream) {
|
| 1908 |
-
const int id
|
| 1909 |
-
const int nsm
|
| 1910 |
-
const int cc
|
|
|
|
| 1911 |
|
| 1912 |
const int mmq_x_max = get_mmq_x_max_host(cc);
|
| 1913 |
const int mmq_y = get_mmq_y_host(cc, mmq_x_max);
|
|
@@ -1920,7 +2051,7 @@ void mul_mat_q_case(const mmq_args & args, cudaStream_t stream) {
|
|
| 1920 |
const int block_num_x = (args.ne11 + mmq_x - 1) / mmq_x;
|
| 1921 |
const int nwaves = (block_num_x*block_num_y + nsm - 1) / nsm;
|
| 1922 |
|
| 1923 |
-
if (nwaves < nwaves_best) {
|
| 1924 |
mmq_x_best = mmq_x;
|
| 1925 |
nwaves_best = nwaves;
|
| 1926 |
}
|
|
@@ -1928,54 +2059,55 @@ void mul_mat_q_case(const mmq_args & args, cudaStream_t stream) {
|
|
| 1928 |
|
| 1929 |
switch (mmq_x_best) {
|
| 1930 |
case 8:
|
| 1931 |
-
launch_mul_mat_q<type, 8,
|
| 1932 |
break;
|
| 1933 |
case 16:
|
| 1934 |
-
launch_mul_mat_q<type, 16,
|
| 1935 |
break;
|
| 1936 |
case 24:
|
| 1937 |
-
launch_mul_mat_q<type, 24,
|
| 1938 |
break;
|
| 1939 |
case 32:
|
| 1940 |
-
launch_mul_mat_q<type, 32,
|
| 1941 |
break;
|
| 1942 |
case 40:
|
| 1943 |
-
launch_mul_mat_q<type, 40,
|
| 1944 |
break;
|
| 1945 |
case 48:
|
| 1946 |
-
launch_mul_mat_q<type, 48,
|
| 1947 |
break;
|
| 1948 |
case 56:
|
| 1949 |
-
launch_mul_mat_q<type, 56,
|
| 1950 |
break;
|
| 1951 |
case 64:
|
| 1952 |
-
launch_mul_mat_q<type, 64,
|
| 1953 |
break;
|
| 1954 |
case 72:
|
| 1955 |
-
launch_mul_mat_q<type, 72,
|
| 1956 |
break;
|
| 1957 |
case 80:
|
| 1958 |
-
launch_mul_mat_q<type, 80,
|
| 1959 |
break;
|
| 1960 |
case 88:
|
| 1961 |
-
launch_mul_mat_q<type, 88,
|
| 1962 |
break;
|
| 1963 |
case 96:
|
| 1964 |
-
launch_mul_mat_q<type, 96,
|
| 1965 |
break;
|
| 1966 |
case 104:
|
| 1967 |
-
launch_mul_mat_q<type, 104,
|
| 1968 |
break;
|
| 1969 |
case 112:
|
| 1970 |
-
launch_mul_mat_q<type, 112,
|
| 1971 |
break;
|
| 1972 |
case 120:
|
| 1973 |
-
launch_mul_mat_q<type, 120,
|
| 1974 |
break;
|
| 1975 |
case 128:
|
| 1976 |
-
launch_mul_mat_q<type, 128,
|
| 1977 |
break;
|
| 1978 |
default:
|
|
|
|
| 1979 |
GGML_ASSERT(false);
|
| 1980 |
break;
|
| 1981 |
}
|
|
|
|
| 10 |
#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
|
| 11 |
|
| 12 |
typedef void (*load_tiles_mmq_t)(
|
| 13 |
+
const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
|
| 14 |
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride);
|
| 15 |
typedef void (*vec_dot_mmq_t)(
|
| 16 |
+
const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
|
| 17 |
const int * __restrict__ y, float * __restrict__ sum, const int & k0);
|
| 18 |
typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1);
|
| 19 |
|
|
|
|
| 25 |
static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "Unexpected block_q8_1_mmq size");
|
| 26 |
|
| 27 |
struct tile_x_sizes {
|
| 28 |
+
int qs;
|
| 29 |
int dm;
|
|
|
|
| 30 |
int sc;
|
| 31 |
};
|
| 32 |
|
|
|
|
| 66 |
#endif // __CUDA_ARCH__ >= CC_VOLTA
|
| 67 |
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
| 68 |
|
| 69 |
+
#define TILE_X_SIZES_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0}
|
| 70 |
+
#define TILE_X_SIZES_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0}
|
| 71 |
+
#define TILE_X_SIZES_Q5_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_0 + mmq_y/QI5_0, 0}
|
| 72 |
+
#define TILE_X_SIZES_Q5_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_1 + mmq_y/QI5_1, 0}
|
| 73 |
+
#define TILE_X_SIZES_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI8_0 + mmq_y/QI8_0, 0}
|
| 74 |
+
#define TILE_X_SIZES_Q2_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE + mmq_y, 0}
|
| 75 |
+
#define TILE_X_SIZES_Q3_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI3_K + mmq_y/QI3_K, mmq_y*WARP_SIZE/4 + mmq_y/4}
|
| 76 |
+
#define TILE_X_SIZES_Q4_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_K + mmq_y/QI4_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
|
| 77 |
+
#define TILE_X_SIZES_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
|
| 78 |
+
#define TILE_X_SIZES_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
|
| 79 |
|
| 80 |
#define GET_TILE_X_SIZES_BODY \
|
| 81 |
return type == GGML_TYPE_Q4_0 ? TILE_X_SIZES_Q4_0 : \
|
|
|
|
| 88 |
type == GGML_TYPE_Q4_K ? TILE_X_SIZES_Q4_K : \
|
| 89 |
type == GGML_TYPE_Q5_K ? TILE_X_SIZES_Q5_K : \
|
| 90 |
type == GGML_TYPE_Q6_K ? TILE_X_SIZES_Q6_K : \
|
| 91 |
+
tile_x_sizes{0, 0, 0}
|
| 92 |
|
| 93 |
static tile_x_sizes get_tile_x_sizes_host(const ggml_type type, const int mmq_y) {
|
| 94 |
GET_TILE_X_SIZES_BODY;
|
|
|
|
| 102 |
// ------------------------------------------------------------
|
| 103 |
|
| 104 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
|
| 105 |
+
const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
|
| 106 |
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
|
| 107 |
+
GGML_UNUSED(x_sc);
|
| 108 |
|
| 109 |
const int kbx = threadIdx.x / QI4_0;
|
| 110 |
const int kqsx = threadIdx.x % QI4_0;
|
|
|
|
| 121 |
|
| 122 |
const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
|
| 123 |
|
| 124 |
+
x_qs[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx);
|
| 125 |
}
|
| 126 |
|
| 127 |
const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
|
|
|
|
| 143 |
|
| 144 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 145 |
static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
|
| 146 |
+
const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
|
| 147 |
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
| 148 |
+
GGML_UNUSED(x_sc);
|
|
|
|
| 149 |
|
| 150 |
const float * x_df = (const float *) x_dm;
|
| 151 |
const int * y_qs = (const int *) y + 4;
|
|
|
|
| 170 |
}
|
| 171 |
|
| 172 |
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
|
| 173 |
+
(&x_qs[i*(WARP_SIZE + 1) + k0], u, x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0],
|
| 174 |
y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
|
| 175 |
}
|
| 176 |
}
|
|
|
|
| 178 |
|
| 179 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 180 |
static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma(
|
| 181 |
+
const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
|
| 182 |
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
| 183 |
+
#ifdef INT8_MMA_AVAILABLE
|
| 184 |
+
GGML_UNUSED(x_sc);
|
| 185 |
|
| 186 |
typedef mma_int_A_I16K8 mma_A;
|
| 187 |
typedef mma_int_B_J8K8 mma_B;
|
|
|
|
| 203 |
const int k = k0 + mma_A::get_k(l) % QI4_0;
|
| 204 |
const int shift = 4*(mma_A::get_k(l) / QI4_0);
|
| 205 |
|
| 206 |
+
A.x[l] = __vsubss4((x_qs[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F, 0x08080808);
|
| 207 |
}
|
| 208 |
#pragma unroll
|
| 209 |
for (int l = 0; l < mma_C::ne/2; ++l) {
|
|
|
|
| 238 |
sum[(j0/B.J)*C.ne + l] += dA[l/2]*__low2float(dsB[l%2])*C.x[l];
|
| 239 |
}
|
| 240 |
}
|
| 241 |
+
#else
|
| 242 |
+
GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
|
| 243 |
+
NO_DEVICE_CODE;
|
| 244 |
+
#endif // INT8_MMA_AVAILABLE
|
| 245 |
}
|
| 246 |
|
| 247 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
|
| 248 |
+
const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
|
| 249 |
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
|
| 250 |
+
GGML_UNUSED(x_sc);
|
| 251 |
|
| 252 |
const int kbx = threadIdx.x / QI4_1;
|
| 253 |
const int kqsx = threadIdx.x % QI4_1;
|
|
|
|
| 262 |
|
| 263 |
const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
|
| 264 |
|
| 265 |
+
x_qs[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
|
| 266 |
}
|
| 267 |
|
| 268 |
const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
|
|
|
|
| 284 |
|
| 285 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 286 |
static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
|
| 287 |
+
const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
|
| 288 |
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
| 289 |
+
GGML_UNUSED(x_sc);
|
|
|
|
| 290 |
|
| 291 |
const int * y_qs = (const int *) y + 4;
|
| 292 |
const half2 * y_ds = (const half2 *) y;
|
|
|
|
| 310 |
}
|
| 311 |
|
| 312 |
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
|
| 313 |
+
(&x_qs[i*(WARP_SIZE + 1) + k0], u, x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + k0/QI4_1],
|
| 314 |
y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
|
| 315 |
}
|
| 316 |
}
|
|
|
|
| 318 |
|
| 319 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 320 |
static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma(
|
| 321 |
+
const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
|
| 322 |
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
| 323 |
+
#ifdef INT8_MMA_AVAILABLE
|
| 324 |
+
GGML_UNUSED(x_sc);
|
| 325 |
|
| 326 |
typedef mma_int_A_I16K8 mma_A;
|
| 327 |
typedef mma_int_B_J8K8 mma_B;
|
|
|
|
| 342 |
const int k = k0 + mma_A::get_k(l) % QI4_0;
|
| 343 |
const int shift = 4*(mma_A::get_k(l) / QI4_0);
|
| 344 |
|
| 345 |
+
A.x[l] = (x_qs[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F;
|
| 346 |
}
|
| 347 |
#pragma unroll
|
| 348 |
for (int l = 0; l < mma_C::ne/2; ++l) {
|
|
|
|
| 378 |
sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
|
| 379 |
}
|
| 380 |
}
|
| 381 |
+
#else
|
| 382 |
+
GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
|
| 383 |
+
NO_DEVICE_CODE;
|
| 384 |
+
#endif // INT8_MMA_AVAILABLE
|
| 385 |
}
|
| 386 |
|
| 387 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
|
| 388 |
+
const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
|
| 389 |
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
|
| 390 |
+
GGML_UNUSED(x_sc);
|
| 391 |
|
| 392 |
const int kbx = threadIdx.x / QI5_0;
|
| 393 |
const int kqsx = threadIdx.x % QI5_0;
|
|
|
|
| 412 |
qs0 |= (qh << 25) & 0x10000000; // 3 -> 28
|
| 413 |
qs0 = __vsubss4(qs0, 0x10101010); // subtract 16
|
| 414 |
|
| 415 |
+
x_qs[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+0] = qs0;
|
| 416 |
|
| 417 |
int qs1 = (ql >> 4) & 0x0F0F0F0F;
|
| 418 |
qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4
|
|
|
|
| 421 |
qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
|
| 422 |
qs1 = __vsubss4(qs1, 0x10101010); // subtract 16
|
| 423 |
|
| 424 |
+
x_qs[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+1] = qs1;
|
| 425 |
}
|
| 426 |
|
| 427 |
const int blocks_per_tile_x_row = WARP_SIZE / QI5_0;
|
|
|
|
| 444 |
|
| 445 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 446 |
static __device__ __forceinline__ void vec_dot_q5_0_q8_1_dp4a(
|
| 447 |
+
const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
|
| 448 |
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
| 449 |
+
GGML_UNUSED(x_sc);
|
|
|
|
| 450 |
|
| 451 |
const float * x_dmf = (const float *) x_dm;
|
| 452 |
const int * y_qs = (const int *) y + 4;
|
|
|
|
| 472 |
}
|
| 473 |
|
| 474 |
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, QR5_0*VDR_Q5_0_Q8_1_MMQ>
|
| 475 |
+
(&x_qs[i*(2*WARP_SIZE + 1) + 2*k0], u, x_dmf[index_bx], y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
|
| 476 |
}
|
| 477 |
}
|
| 478 |
}
|
| 479 |
|
| 480 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 481 |
static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma(
|
| 482 |
+
const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
|
| 483 |
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
| 484 |
+
#ifdef INT8_MMA_AVAILABLE
|
| 485 |
+
GGML_UNUSED(x_sc);
|
| 486 |
|
| 487 |
typedef mma_int_A_I16K8 mma_A;
|
| 488 |
typedef mma_int_B_J8K8 mma_B;
|
|
|
|
| 503 |
const int i = i0 + mma_A::get_i(l);
|
| 504 |
const int k = 2*(k0 + mma_A::get_k(l) % QI5_0) + mma_A::get_k(l) / QI5_0;
|
| 505 |
|
| 506 |
+
A.x[l] = x_qs[i*(2*WARP_SIZE + 1) + k];
|
| 507 |
}
|
| 508 |
#pragma unroll
|
| 509 |
for (int l = 0; l < mma_C::ne/2; ++l) {
|
|
|
|
| 538 |
sum[(j0/B.J)*C.ne + l] += dA[l/2]*dB[l%2]*C.x[l];
|
| 539 |
}
|
| 540 |
}
|
| 541 |
+
#else
|
| 542 |
+
GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
|
| 543 |
+
NO_DEVICE_CODE;
|
| 544 |
+
#endif // INT8_MMA_AVAILABLE
|
| 545 |
}
|
| 546 |
|
| 547 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
|
| 548 |
+
const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
|
| 549 |
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
|
| 550 |
+
GGML_UNUSED(x_sc);
|
| 551 |
|
| 552 |
const int kbx = threadIdx.x / QI5_1;
|
| 553 |
const int kqsx = threadIdx.x % QI5_1;
|
|
|
|
| 571 |
qs0 |= (qh << 18) & 0x00100000; // 2 -> 20
|
| 572 |
qs0 |= (qh << 25) & 0x10000000; // 3 -> 28
|
| 573 |
|
| 574 |
+
x_qs[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+0] = qs0;
|
| 575 |
|
| 576 |
int qs1 = (ql >> 4) & 0x0F0F0F0F;
|
| 577 |
qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4
|
|
|
|
| 579 |
qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
|
| 580 |
qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
|
| 581 |
|
| 582 |
+
x_qs[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+1] = qs1;
|
| 583 |
}
|
| 584 |
|
| 585 |
const int blocks_per_tile_x_row = WARP_SIZE / QI5_1;
|
|
|
|
| 601 |
|
| 602 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 603 |
static __device__ __forceinline__ void vec_dot_q5_1_q8_1_dp4a(
|
| 604 |
+
const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
|
| 605 |
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
| 606 |
+
GGML_UNUSED(x_sc);
|
|
|
|
| 607 |
|
| 608 |
const int * y_qs = (const int *) y + 4;
|
| 609 |
const half2 * y_ds = (const half2 *) y;
|
|
|
|
| 628 |
}
|
| 629 |
|
| 630 |
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
|
| 631 |
+
(&x_qs[i*(2*WARP_SIZE + 1) + 2*k0], u, x_dm[index_bx], y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
|
| 632 |
}
|
| 633 |
}
|
| 634 |
}
|
| 635 |
|
| 636 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 637 |
static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma(
|
| 638 |
+
const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
|
| 639 |
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
| 640 |
+
#ifdef INT8_MMA_AVAILABLE
|
| 641 |
+
GGML_UNUSED(x_sc);
|
| 642 |
|
| 643 |
typedef mma_int_A_I16K8 mma_A;
|
| 644 |
typedef mma_int_B_J8K8 mma_B;
|
|
|
|
| 658 |
const int i = i0 + mma_A::get_i(l);
|
| 659 |
const int k = 2*(k0 + mma_A::get_k(l) % QI5_1) + mma_A::get_k(l) / QI5_1;
|
| 660 |
|
| 661 |
+
A.x[l] = x_qs[i*(2*WARP_SIZE + 1) + k];
|
| 662 |
}
|
| 663 |
#pragma unroll
|
| 664 |
for (int l = 0; l < mma_C::ne/2; ++l) {
|
|
|
|
| 694 |
sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
|
| 695 |
}
|
| 696 |
}
|
| 697 |
+
#else
|
| 698 |
+
GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
|
| 699 |
+
NO_DEVICE_CODE;
|
| 700 |
+
#endif // INT8_MMA_AVAILABLE
|
| 701 |
}
|
| 702 |
|
| 703 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
|
| 704 |
+
const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
|
| 705 |
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
|
| 706 |
+
GGML_UNUSED(x_sc);
|
|
|
|
| 707 |
|
| 708 |
const int kbx = threadIdx.x / QI8_0;
|
| 709 |
const int kqsx = threadIdx.x % QI8_0;
|
|
|
|
| 719 |
|
| 720 |
const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
|
| 721 |
|
| 722 |
+
x_qs[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_int8(bxi->qs, kqsx);
|
| 723 |
}
|
| 724 |
|
| 725 |
const int blocks_per_tile_x_row = WARP_SIZE / QI8_0;
|
|
|
|
| 741 |
|
| 742 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 743 |
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
|
| 744 |
+
const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
|
| 745 |
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
| 746 |
+
GGML_UNUSED(x_sc);
|
|
|
|
| 747 |
|
| 748 |
const float * x_dmf = (const float *) x_dm;
|
| 749 |
const int * y_qs = (const int *) y + 4;
|
|
|
|
| 758 |
const int i = i0 + threadIdx.x;
|
| 759 |
|
| 760 |
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
|
| 761 |
+
(&x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0], x_dmf[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0],
|
| 762 |
y_df[j*MMQ_TILE_Y_K + k0/QI8_1]);
|
| 763 |
}
|
| 764 |
}
|
|
|
|
| 766 |
|
| 767 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 768 |
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
|
| 769 |
+
const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
|
| 770 |
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
| 771 |
+
#ifdef INT8_MMA_AVAILABLE
|
| 772 |
+
GGML_UNUSED(x_sc);
|
| 773 |
|
| 774 |
typedef mma_int_A_I16K8 mma_A;
|
| 775 |
typedef mma_int_B_J8K8 mma_B;
|
|
|
|
| 790 |
const int i = i0 + mma_A::get_i(l);
|
| 791 |
const int k = k0 + mma_A::get_k(l);
|
| 792 |
|
| 793 |
+
A.x[l] = x_qs[i*(WARP_SIZE + 1) + k];
|
| 794 |
}
|
| 795 |
#pragma unroll
|
| 796 |
for (int l = 0; l < mma_C::ne/2; ++l) {
|
|
|
|
| 825 |
sum[(j0/B.J)*C.ne + l] += C.x[l]*dA[l/2]*dB[l%2];
|
| 826 |
}
|
| 827 |
}
|
| 828 |
+
#else
|
| 829 |
+
GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
|
| 830 |
+
NO_DEVICE_CODE;
|
| 831 |
+
#endif // INT8_MMA_AVAILABLE
|
| 832 |
}
|
| 833 |
|
| 834 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
|
| 835 |
+
const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
|
| 836 |
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
|
|
|
|
| 837 |
|
| 838 |
const int kbx = threadIdx.x / QI2_K;
|
| 839 |
const int kqsx = threadIdx.x % QI2_K;
|
|
|
|
| 848 |
|
| 849 |
const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + kbx;
|
| 850 |
|
| 851 |
+
const int x_ql_0 = get_int_from_uint8(bxi->qs, kqsx);
|
|
|
|
|
|
|
|
|
|
|
|
|
| 852 |
|
| 853 |
#pragma unroll
|
| 854 |
+
for (int l = 0; l < QR2_K; ++l) {
|
| 855 |
+
const int k = kbx*QI2_K + (kqsx/8)*8 + l*2 + (kqsx % 8)/4;
|
| 856 |
|
| 857 |
+
int x_qs_k = ((x_ql_0 >> (2*l)) & 0x03030303) << (2*(kqsx % 4));
|
| 858 |
+
x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 1, WARP_SIZE);
|
| 859 |
+
x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 2, WARP_SIZE);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 860 |
|
| 861 |
+
if (kqsx % QR2_K != 0) {
|
| 862 |
+
continue;
|
| 863 |
+
}
|
| 864 |
|
| 865 |
+
x_qs[i*(WARP_SIZE + 1) + k] = x_qs_k;
|
|
|
|
| 866 |
}
|
| 867 |
|
| 868 |
+
const int sc_m = bxi->scales[kqsx];
|
| 869 |
+
#ifdef FAST_FP16_AVAILABLE
|
| 870 |
+
const half2 x_dm_ik = __hmul2(bxi->dm, make_half2(sc_m & 0x0F, sc_m >> 4));
|
| 871 |
+
#else
|
| 872 |
+
const float2 bxi_dmf = __half22float2(bxi->dm);
|
| 873 |
+
const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4));
|
| 874 |
+
#endif // FAST_FP16_AVAILABLE
|
| 875 |
|
| 876 |
+
x_dm[i*(WARP_SIZE + 1) + threadIdx.x] = x_dm_ik;
|
| 877 |
}
|
| 878 |
}
|
| 879 |
|
| 880 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 881 |
+
static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
|
| 882 |
+
const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
|
| 883 |
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
| 884 |
|
| 885 |
+
const int * y_qs = (const int *) y + 4;
|
| 886 |
+
const float * y_df = (const float *) y;
|
|
|
|
|
|
|
| 887 |
|
| 888 |
#pragma unroll
|
| 889 |
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
|
|
|
| 893 |
for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
|
| 894 |
const int i = i0 + threadIdx.x;
|
| 895 |
|
| 896 |
+
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq(
|
| 897 |
+
&x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + (QR2_K*k0) % WARP_SIZE],
|
| 898 |
+
&x_dm[i*(WARP_SIZE + 1) + k0], y_df[j*MMQ_TILE_Y_K + ((QR2_K*k0) % WARP_SIZE)/QI8_1]);
|
| 899 |
+
}
|
| 900 |
+
}
|
| 901 |
+
}
|
| 902 |
+
|
| 903 |
+
template <int mmq_x, int mmq_y, int nwarps>
|
| 904 |
+
static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
|
| 905 |
+
const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
|
| 906 |
+
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
| 907 |
+
#ifdef INT8_MMA_AVAILABLE
|
| 908 |
+
|
| 909 |
+
typedef mma_int_A_I16K4 mma_A;
|
| 910 |
+
typedef mma_int_B_J8K4 mma_B;
|
| 911 |
+
typedef mma_int_C_I16J8 mma_C;
|
| 912 |
+
|
| 913 |
+
const int * y_qs = (const int *) y + 4;
|
| 914 |
+
const float * y_df = (const float *) y;
|
| 915 |
|
| 916 |
+
const int i0 = threadIdx.y*mma_A::I;
|
| 917 |
+
static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
|
| 918 |
|
| 919 |
+
mma_A A[2];
|
| 920 |
+
float dA[mma_C::ne/2][2];
|
| 921 |
+
float mA[mma_C::ne/2][2];
|
| 922 |
|
| 923 |
#pragma unroll
|
| 924 |
+
for (int l = 0; l < mma_A::ne; ++l) {
|
| 925 |
+
const int i = i0 + mma_A::get_i(l);
|
| 926 |
+
const int shift = 2*mma_A::get_k(l);
|
| 927 |
|
| 928 |
+
A[0].x[l] = (x_qs[i*(WARP_SIZE + 1) + k0 + 0] >> shift) & 0x03030303;
|
| 929 |
+
A[1].x[l] = (x_qs[i*(WARP_SIZE + 1) + k0 + 1] >> shift) & 0x03030303;
|
| 930 |
+
}
|
| 931 |
|
| 932 |
+
#pragma unroll
|
| 933 |
+
for (int l = 0; l < mma_C::ne/2; ++l) {
|
| 934 |
+
const int i = i0 + mma_C::get_i(2*l);
|
| 935 |
+
|
| 936 |
+
#pragma unroll
|
| 937 |
+
for (int kk = 0; kk < 2; ++kk) {
|
| 938 |
+
const float2 dm = __half22float2(x_dm[i*(WARP_SIZE + 1) + k0 + kk]);
|
| 939 |
+
|
| 940 |
+
dA[l][kk] = dm.x;
|
| 941 |
+
mA[l][kk] = dm.y;
|
| 942 |
}
|
| 943 |
}
|
| 944 |
+
|
| 945 |
+
#pragma unroll
|
| 946 |
+
for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
|
| 947 |
+
mma_C Cd[2];
|
| 948 |
+
mma_C Cm[2];
|
| 949 |
+
mma_B B[2];
|
| 950 |
+
float dB[mma_C::ne/2];
|
| 951 |
+
|
| 952 |
+
#pragma unroll
|
| 953 |
+
for (int l = 0; l < mma_B::ne; ++l) {
|
| 954 |
+
const int j = j0 + mma_B::get_j(l);
|
| 955 |
+
const int k = (4*k0 + mma_B::get_k(l)) % WARP_SIZE;
|
| 956 |
+
|
| 957 |
+
B[0].x[l] = y_qs[j*MMQ_TILE_Y_K + k + 0];
|
| 958 |
+
B[1].x[l] = y_qs[j*MMQ_TILE_Y_K + k + mma_B::K];
|
| 959 |
+
}
|
| 960 |
+
#pragma unroll
|
| 961 |
+
for (int l = 0; l < mma_C::ne/2; ++l) {
|
| 962 |
+
const int j = j0 + mma_C::get_j(l);
|
| 963 |
+
|
| 964 |
+
dB[l] = y_df[j*MMQ_TILE_Y_K + ((4*k0)/QI8_1) % (WARP_SIZE/QI8_1)];
|
| 965 |
+
}
|
| 966 |
+
|
| 967 |
+
Cd[0].mma_K4(A[0], B[0]);
|
| 968 |
+
Cd[1].mma_K4(A[1], B[1]);
|
| 969 |
+
|
| 970 |
+
mma_A A1;
|
| 971 |
+
A1.x[0] = 0x01010101;
|
| 972 |
+
A1.x[1] = 0x01010101;
|
| 973 |
+
Cm[0].mma_K4(A1, B[0]);
|
| 974 |
+
Cm[1].mma_K4(A1, B[1]);
|
| 975 |
+
|
| 976 |
+
#pragma unroll
|
| 977 |
+
for (int l = 0; l < mma_C::ne; ++l) {
|
| 978 |
+
sum[(j0/mma_B::J)*mma_C::ne + l] += (Cd[0].x[l]*dA[l/2][0] + Cd[1].x[l]*dA[l/2][1] - Cm[0].x[l]*mA[l/2][0] - Cm[1].x[l]*mA[l/2][1])*dB[l%2];
|
| 979 |
+
}
|
| 980 |
+
}
|
| 981 |
+
#else
|
| 982 |
+
GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
|
| 983 |
+
NO_DEVICE_CODE;
|
| 984 |
+
#endif // INT8_MMA_AVAILABLE
|
| 985 |
}
|
| 986 |
|
| 987 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
|
| 988 |
+
const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
|
| 989 |
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
|
| 990 |
|
| 991 |
const int kbx = threadIdx.x / QI3_K;
|
|
|
|
| 1001 |
|
| 1002 |
const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + kbx;
|
| 1003 |
|
| 1004 |
+
const int x_ql_0 = get_int_from_uint8(bxi->qs, kqsx);
|
| 1005 |
+
const int x_qh_0 = get_int_from_uint8(bxi->hmask, kqsx % (QI3_K/2)) >> (4 * (kqsx / (QI3_K/2)));
|
| 1006 |
+
|
| 1007 |
+
#pragma unroll
|
| 1008 |
+
for (int l = 0; l < QR3_K; ++l) {
|
| 1009 |
+
const int k = kbx*(QR3_K*QI3_K) + (kqsx/8)*32 + l*8 + kqsx % 8;
|
| 1010 |
+
|
| 1011 |
+
const int x_ql_k = (x_ql_0 >> (2*l)) & 0x03030303;
|
| 1012 |
+
const int x_qh_k = ((x_qh_0 >> l) << 2) & 0x04040404;
|
| 1013 |
+
|
| 1014 |
+
int x_qs_k = (x_ql_k | x_qh_k) << (4*(k%2));
|
| 1015 |
+
x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 1, WARP_SIZE);
|
| 1016 |
+
|
| 1017 |
+
if (kqsx % 2 != 0) {
|
| 1018 |
+
continue;
|
| 1019 |
+
}
|
| 1020 |
+
|
| 1021 |
+
x_qs[i*(2*WARP_SIZE + 1) + k/2] = x_qs_k;
|
| 1022 |
+
}
|
| 1023 |
}
|
| 1024 |
|
| 1025 |
const int blocks_per_tile_x_row = WARP_SIZE / QI3_K;
|
|
|
|
| 1039 |
x_dmf[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd] = bxi->d;
|
| 1040 |
}
|
| 1041 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1042 |
#pragma unroll
|
| 1043 |
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
|
| 1044 |
int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
|
|
|
|
| 1066 |
}
|
| 1067 |
|
| 1068 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 1069 |
+
static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
|
| 1070 |
+
const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
|
| 1071 |
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
| 1072 |
|
| 1073 |
+
const float * x_df = (const float *) x_dm;
|
| 1074 |
+
const int * y_qs = (const int *) y + 4;
|
| 1075 |
+
const float * y_df = (const float *) y;
|
| 1076 |
|
| 1077 |
#pragma unroll
|
| 1078 |
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
|
|
|
| 1087 |
|
| 1088 |
const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;
|
| 1089 |
|
| 1090 |
+
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q3_K_q8_1_impl_mmq(
|
| 1091 |
+
&x_qs[i*(2*WARP_SIZE + 1) + 2*k0], &y_qs[j*MMQ_TILE_Y_K + (k0*QR3_K) % WARP_SIZE], scales,
|
| 1092 |
+
x_df[i*(WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[j*MMQ_TILE_Y_K + ((k0*QR3_K) % WARP_SIZE)/QI8_1]);
|
| 1093 |
+
}
|
| 1094 |
+
}
|
| 1095 |
+
}
|
| 1096 |
+
|
| 1097 |
+
template <int mmq_x, int mmq_y, int nwarps>
|
| 1098 |
+
static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mma(
|
| 1099 |
+
const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
|
| 1100 |
+
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
| 1101 |
+
#ifdef INT8_MMA_AVAILABLE
|
| 1102 |
+
|
| 1103 |
+
typedef mma_int_A_I16K4 mma_A;
|
| 1104 |
+
typedef mma_int_B_J8K4 mma_B;
|
| 1105 |
+
typedef mma_int_C_I16J8 mma_C;
|
| 1106 |
+
|
| 1107 |
+
const float * x_df = (const float *) x_dm;
|
| 1108 |
+
const int * y_qs = (const int *) y + 4;
|
| 1109 |
+
const float * y_df = (const float *) y;
|
| 1110 |
+
|
| 1111 |
+
const int i0 = threadIdx.y*mma_A::I;
|
| 1112 |
+
static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
|
| 1113 |
+
|
| 1114 |
+
mma_A A[2];
|
| 1115 |
+
int scA[mma_C::ne/2][2];
|
| 1116 |
+
float dA[mma_C::ne/2];
|
| 1117 |
|
| 1118 |
#pragma unroll
|
| 1119 |
+
for (int l = 0; l < mma_A::ne; ++l) {
|
| 1120 |
+
const int i = i0 + mma_A::get_i(l);
|
| 1121 |
+
const int k = QR3_K*k0 + mma_A::get_k(l);
|
|
|
|
| 1122 |
|
| 1123 |
+
A[0].x[l] = (x_qs[i*(2*WARP_SIZE + 1) + k/2 + 0] >> (4*(k%2))) & 0x0F0F0F0F;
|
| 1124 |
+
A[1].x[l] = (x_qs[i*(2*WARP_SIZE + 1) + k/2 + mma_A::K/2] >> (4*(k%2))) & 0x0F0F0F0F;
|
| 1125 |
+
A[0].x[l] = __vsubss4(A[0].x[l], 0x04040404);
|
| 1126 |
+
A[1].x[l] = __vsubss4(A[1].x[l], 0x04040404);
|
| 1127 |
+
}
|
| 1128 |
|
| 1129 |
+
#pragma unroll
|
| 1130 |
+
for (int l = 0; l < mma_C::ne/2; ++l) {
|
| 1131 |
+
const int i = i0 + mma_C::get_i(2*l);
|
| 1132 |
|
| 1133 |
+
const int kbx = k0 / QI3_K;
|
| 1134 |
+
const int ky = (k0 % QI3_K) * QR3_K;
|
| 1135 |
+
const int8_t * sc = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;
|
| 1136 |
+
|
| 1137 |
+
scA[l][0] = sc[0];
|
| 1138 |
+
scA[l][1] = sc[1];
|
| 1139 |
+
}
|
| 1140 |
+
|
| 1141 |
+
#pragma unroll
|
| 1142 |
+
for (int l = 0; l < mma_C::ne/2; ++l) {
|
| 1143 |
+
const int i = i0 + mma_C::get_i(2*l);
|
| 1144 |
+
|
| 1145 |
+
dA[l] = x_df[i*(WARP_SIZE/QI3_K) + i/QI3_K + k0/QI3_K];
|
| 1146 |
+
}
|
| 1147 |
+
|
| 1148 |
+
#pragma unroll
|
| 1149 |
+
for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
|
| 1150 |
+
mma_C C[2];
|
| 1151 |
+
mma_B B[2];
|
| 1152 |
+
float dB[mma_C::ne/2];
|
| 1153 |
+
|
| 1154 |
+
#pragma unroll
|
| 1155 |
+
for (int l = 0; l < mma_B::ne; ++l) {
|
| 1156 |
+
const int j = j0 + mma_B::get_j(l);
|
| 1157 |
+
const int k = (4*k0 + mma_B::get_k(l)) % WARP_SIZE;
|
| 1158 |
+
|
| 1159 |
+
B[0].x[l] = y_qs[j*MMQ_TILE_Y_K + k + 0];
|
| 1160 |
+
B[1].x[l] = y_qs[j*MMQ_TILE_Y_K + k + mma_B::K];
|
| 1161 |
+
}
|
| 1162 |
+
#pragma unroll
|
| 1163 |
+
for (int l = 0; l < mma_C::ne/2; ++l) {
|
| 1164 |
+
const int j = j0 + mma_C::get_j(l);
|
| 1165 |
+
|
| 1166 |
+
dB[l] = y_df[j*MMQ_TILE_Y_K + ((4*k0)/QI8_1) % (WARP_SIZE/QI8_1)];
|
| 1167 |
+
}
|
| 1168 |
+
|
| 1169 |
+
C[0].mma_K4(A[0], B[0]);
|
| 1170 |
+
C[1].mma_K4(A[1], B[1]);
|
| 1171 |
+
|
| 1172 |
+
#pragma unroll
|
| 1173 |
+
for (int l = 0; l < mma_C::ne; ++l) {
|
| 1174 |
+
sum[(j0/mma_B::J)*mma_C::ne + l] += (C[0].x[l]*scA[l/2][0] + C[1].x[l]*scA[l/2][1])*dA[l/2]*dB[l%2];
|
| 1175 |
}
|
| 1176 |
}
|
| 1177 |
+
#else
|
| 1178 |
+
GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
|
| 1179 |
+
NO_DEVICE_CODE;
|
| 1180 |
+
#endif // INT8_MMA_AVAILABLE
|
| 1181 |
}
|
| 1182 |
|
| 1183 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
|
| 1184 |
+
const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
|
| 1185 |
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
|
|
|
|
| 1186 |
|
| 1187 |
const int kbx = 0; // threadIdx.x / QI4_K
|
| 1188 |
const int kqsx = threadIdx.x; // threadIdx.x % QI4_K
|
|
|
|
| 1197 |
|
| 1198 |
const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbx;
|
| 1199 |
|
| 1200 |
+
x_qs[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
|
| 1201 |
}
|
| 1202 |
|
| 1203 |
const int blocks_per_tile_x_row = WARP_SIZE / QI4_K; // == 1 if QK_K == 256
|
|
|
|
| 1240 |
|
| 1241 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 1242 |
static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
|
| 1243 |
+
const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
|
| 1244 |
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
| 1245 |
|
|
|
|
|
|
|
| 1246 |
const int * y_qs = (const int *) y + 4;
|
| 1247 |
const half2 * y_ds = (const half2 *) y;
|
| 1248 |
|
|
|
|
| 1257 |
const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2*((k0 % 16) / 8);
|
| 1258 |
|
| 1259 |
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq(
|
| 1260 |
+
&x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + (QR4_K*k0) % WARP_SIZE], sc, sc+8,
|
| 1261 |
x_dm[i*(WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[j*MMQ_TILE_Y_K + ((QR4_K*k0) % WARP_SIZE)/QI8_1]);
|
| 1262 |
}
|
| 1263 |
}
|
|
|
|
| 1265 |
|
| 1266 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 1267 |
static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma(
|
| 1268 |
+
const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
|
| 1269 |
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
| 1270 |
+
#ifdef INT8_MMA_AVAILABLE
|
|
|
|
| 1271 |
|
| 1272 |
typedef mma_int_A_I16K8 mma_A;
|
| 1273 |
typedef mma_int_B_J8K8 mma_B;
|
|
|
|
| 1290 |
const int i = i0 + mma_A::get_i(l);
|
| 1291 |
const int k = k0 + mma_A::get_k(l);
|
| 1292 |
|
| 1293 |
+
A[kvdr/4].x[l] = (x_qs[i*(WARP_SIZE + 1) + k] >> kvdr) & 0x0F0F0F0F;
|
| 1294 |
}
|
| 1295 |
|
| 1296 |
#pragma unroll
|
|
|
|
| 1351 |
sum[(j0/mma_B::J)*mma_C::ne + l] += __low2float(dmA[l/2])*tmpd[l] - __high2float(dmA[l/2])*tmpm[l];
|
| 1352 |
}
|
| 1353 |
}
|
| 1354 |
+
#else
|
| 1355 |
+
GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
|
| 1356 |
+
NO_DEVICE_CODE;
|
| 1357 |
+
#endif // INT8_MMA_AVAILABLE
|
| 1358 |
}
|
| 1359 |
|
| 1360 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
|
| 1361 |
+
const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
|
| 1362 |
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
|
|
|
|
| 1363 |
|
| 1364 |
const int kbx = 0; // threadIdx.x / QI5_K
|
| 1365 |
const int kqsx = threadIdx.x; // threadIdx.x % QI5_K
|
|
|
|
| 1386 |
const int kq0 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + 0;
|
| 1387 |
const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + (QI5_K/4);
|
| 1388 |
|
| 1389 |
+
x_qs[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0;
|
| 1390 |
+
x_qs[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1;
|
| 1391 |
}
|
| 1392 |
|
| 1393 |
const int blocks_per_tile_x_row = WARP_SIZE / QI5_K; // == 1 if QK_K == 256
|
|
|
|
| 1430 |
|
| 1431 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 1432 |
static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
|
| 1433 |
+
const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
|
| 1434 |
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
| 1435 |
|
|
|
|
|
|
|
| 1436 |
const int * y_qs = (const int *) y + 4;
|
| 1437 |
const half2 * y_ds = (const half2 *) y;
|
| 1438 |
|
|
|
|
| 1447 |
const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8);
|
| 1448 |
|
| 1449 |
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq(
|
| 1450 |
+
&x_qs[i*(QR5_K*WARP_SIZE + 1) + QR5_K*k0], &y_qs[j*MMQ_TILE_Y_K + (QR5_K*k0) % WARP_SIZE], sc, sc+8,
|
| 1451 |
x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[j*MMQ_TILE_Y_K + ((QR5_K*k0) % WARP_SIZE)/QI8_1]);
|
| 1452 |
}
|
| 1453 |
}
|
|
|
|
| 1455 |
|
| 1456 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 1457 |
static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma(
|
| 1458 |
+
const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
|
| 1459 |
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
| 1460 |
+
#ifdef INT8_MMA_AVAILABLE
|
|
|
|
| 1461 |
|
| 1462 |
typedef mma_int_A_I16K8 mma_A;
|
| 1463 |
typedef mma_int_B_J8K8 mma_B;
|
|
|
|
| 1480 |
const int i = i0 + mma_A::get_i(l);
|
| 1481 |
const int k = QR5_K*k0 + QR5_K*kvdr + mma_A::get_k(l);
|
| 1482 |
|
| 1483 |
+
A[kvdr/4].x[l] = x_qs[i*(QR5_K*WARP_SIZE + 1) + k];
|
| 1484 |
}
|
| 1485 |
|
| 1486 |
#pragma unroll
|
|
|
|
| 1541 |
sum[(j0/mma_B::J)*mma_C::ne + l] += __low2float(dmA[l/2])*tmpd[l] - __high2float(dmA[l/2])*tmpm[l];
|
| 1542 |
}
|
| 1543 |
}
|
| 1544 |
+
#else
|
| 1545 |
+
GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
|
| 1546 |
+
NO_DEVICE_CODE;
|
| 1547 |
+
#endif // INT8_MMA_AVAILABLE
|
| 1548 |
}
|
| 1549 |
|
| 1550 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
|
| 1551 |
+
const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
|
| 1552 |
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
|
|
|
|
| 1553 |
|
| 1554 |
const int kbx = 0; // threadIdx.x / QI6_K
|
| 1555 |
const int kqsx = threadIdx.x; // threadIdx.x % QI6_K
|
|
|
|
| 1576 |
const int kq0 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + 0;
|
| 1577 |
const int kq1 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + (QI6_K/2);
|
| 1578 |
|
| 1579 |
+
x_qs[i * (2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
|
| 1580 |
+
x_qs[i * (2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
|
| 1581 |
}
|
| 1582 |
|
| 1583 |
const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256
|
|
|
|
| 1613 |
|
| 1614 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 1615 |
static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
|
| 1616 |
+
const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
|
| 1617 |
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
| 1618 |
|
|
|
|
|
|
|
| 1619 |
const float * x_dmf = (const float *) x_dm;
|
| 1620 |
const int * y_qs = (const int *) y + 4;
|
| 1621 |
const float * y_df = (const float *) y;
|
|
|
|
| 1631 |
const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/8]);
|
| 1632 |
|
| 1633 |
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq(
|
| 1634 |
+
&x_qs[i*(QR6_K*WARP_SIZE + 1) + QR6_K*k0], &y_qs[j*MMQ_TILE_Y_K + (QR6_K*k0) % WARP_SIZE], sc,
|
| 1635 |
x_dmf[i*(WARP_SIZE/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + ((QR6_K*k0) % WARP_SIZE)/QI8_1]);
|
| 1636 |
}
|
| 1637 |
}
|
|
|
|
| 1639 |
|
| 1640 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 1641 |
static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
|
| 1642 |
+
const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
|
| 1643 |
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
| 1644 |
+
#ifdef INT8_MMA_AVAILABLE
|
|
|
|
| 1645 |
|
| 1646 |
typedef mma_int_A_I16K4 mma_A;
|
| 1647 |
typedef mma_int_B_J8K4 mma_B;
|
|
|
|
| 1652 |
const float * y_df = (const float *) y;
|
| 1653 |
|
| 1654 |
const int i0 = threadIdx.y*mma_A::I;
|
| 1655 |
+
#ifdef INT8_MMA_AVAILABLE
|
| 1656 |
static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
|
| 1657 |
+
#endif // INT8_MMA_AVAILABLE
|
| 1658 |
|
| 1659 |
mma_A A[4];
|
| 1660 |
int scA[mma_C::ne/2][4];
|
|
|
|
| 1666 |
const int i = i0 + mma_A::get_i(l);
|
| 1667 |
const int k = QR6_K*k0 + QR6_K*kvdr + mma_A::get_k(l);
|
| 1668 |
|
| 1669 |
+
A[kvdr/2 + 0].x[l] = x_qs[i*(QR6_K*WARP_SIZE + 1) + k + 0];
|
| 1670 |
+
A[kvdr/2 + 1].x[l] = x_qs[i*(QR6_K*WARP_SIZE + 1) + k + mma_A::K];
|
| 1671 |
}
|
| 1672 |
|
| 1673 |
#pragma unroll
|
|
|
|
| 1727 |
sum[(j0/mma_B::J)*mma_C::ne + l] += tmp[l]*dA[l/2];
|
| 1728 |
}
|
| 1729 |
}
|
| 1730 |
+
#else
|
| 1731 |
+
GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
|
| 1732 |
+
NO_DEVICE_CODE;
|
| 1733 |
+
#endif // INT8_MMA_AVAILABLE
|
| 1734 |
}
|
| 1735 |
|
| 1736 |
template<int mmq_x, int mmq_y, int nwarps, bool need_check>
|
|
|
|
| 1761 |
typedef mma_int_C_I16J8 mma_C;
|
| 1762 |
|
| 1763 |
const int i0 = threadIdx.y*mma_C::I;
|
| 1764 |
+
#ifdef INT8_MMA_AVAILABLE
|
| 1765 |
static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y");
|
| 1766 |
+
#endif // INT8_MMA_AVAILABLE
|
| 1767 |
|
| 1768 |
#pragma unroll
|
| 1769 |
for (int j0 = 0; j0 < mmq_x; j0 += mma_C::J) {
|
|
|
|
| 1793 |
|
| 1794 |
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
| 1795 |
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_0> {
|
| 1796 |
+
static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;
|
| 1797 |
+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, nwarps, need_check>;
|
| 1798 |
+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q4_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
|
| 1799 |
+
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1800 |
};
|
| 1801 |
|
| 1802 |
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
| 1803 |
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_1> {
|
| 1804 |
+
static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ;
|
| 1805 |
+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, nwarps, need_check>;
|
| 1806 |
+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q4_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
|
| 1807 |
+
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1808 |
};
|
| 1809 |
|
| 1810 |
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
| 1811 |
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
|
| 1812 |
+
static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ;
|
| 1813 |
+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, nwarps, need_check>;
|
| 1814 |
+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q5_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
|
| 1815 |
+
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1816 |
};
|
| 1817 |
|
| 1818 |
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
| 1819 |
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
|
| 1820 |
+
static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ;
|
| 1821 |
+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, nwarps, need_check>;
|
| 1822 |
+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q5_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
|
| 1823 |
+
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1824 |
};
|
| 1825 |
|
| 1826 |
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
| 1827 |
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
|
| 1828 |
+
static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ;
|
| 1829 |
+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, nwarps, need_check>;
|
| 1830 |
+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
|
| 1831 |
+
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1832 |
};
|
| 1833 |
|
| 1834 |
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
| 1835 |
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q2_K> {
|
| 1836 |
+
static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
|
| 1837 |
+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y, nwarps, need_check>;
|
| 1838 |
+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q2_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
|
| 1839 |
+
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q2_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
|
| 1840 |
};
|
| 1841 |
|
| 1842 |
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
| 1843 |
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q3_K> {
|
| 1844 |
+
static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ;
|
| 1845 |
+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, nwarps, need_check>;
|
| 1846 |
+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q3_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
|
| 1847 |
+
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
|
| 1848 |
};
|
| 1849 |
|
| 1850 |
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
| 1851 |
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> {
|
| 1852 |
+
static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ;
|
| 1853 |
+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, nwarps, need_check>;
|
| 1854 |
+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q4_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
|
| 1855 |
+
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1856 |
};
|
| 1857 |
|
| 1858 |
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
| 1859 |
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> {
|
| 1860 |
+
static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ;
|
| 1861 |
+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, nwarps, need_check>;
|
| 1862 |
+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q5_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
|
| 1863 |
+
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1864 |
};
|
| 1865 |
|
| 1866 |
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
| 1867 |
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
|
| 1868 |
+
static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ;
|
| 1869 |
+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, nwarps, need_check>;
|
| 1870 |
+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q6_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
|
| 1871 |
+
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1872 |
};
|
| 1873 |
|
| 1874 |
+
static bool mmq_need_sum(const ggml_type type_x) {
|
| 1875 |
switch (type_x) {
|
| 1876 |
case GGML_TYPE_Q4_0:
|
| 1877 |
case GGML_TYPE_Q4_1:
|
|
|
|
| 1905 |
#if __CUDA_ARCH__ >= CC_VOLTA
|
| 1906 |
__launch_bounds__(WARP_SIZE*nwarps, 1)
|
| 1907 |
#else
|
| 1908 |
+
__launch_bounds__(WARP_SIZE*nwarps, 2)
|
| 1909 |
#endif // __CUDA_ARCH__ >= CC_VOLTA
|
| 1910 |
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
| 1911 |
static __global__ void mul_mat_q(
|
|
|
|
| 1924 |
constexpr int mmq_y = get_mmq_y_device(mmq_x);
|
| 1925 |
constexpr int vdr = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vdr;
|
| 1926 |
constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
|
| 1927 |
+
|
| 1928 |
+
#ifdef INT8_MMA_AVAILABLE
|
| 1929 |
+
constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot_mma;
|
| 1930 |
+
constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
|
| 1931 |
+
#else
|
| 1932 |
+
constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot_dp4a;
|
| 1933 |
+
constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
|
| 1934 |
+
#endif // INT8_MMA_AVAILABLE
|
| 1935 |
|
| 1936 |
constexpr tile_x_sizes txs = get_tile_x_sizes_device<mmq_y>(type);
|
| 1937 |
|
| 1938 |
extern __shared__ char data_mul_mat_q[];
|
| 1939 |
+
int * tile_x_qs = (int *) data_mul_mat_q;
|
| 1940 |
+
half2 * tile_x_dm = (half2 *) (tile_x_qs + txs.qs);
|
| 1941 |
+
int * tile_x_sc = (int *) (tile_x_dm + txs.dm);
|
|
|
|
| 1942 |
int * tile_y = (int *) (tile_x_sc + txs.sc); // [mmq_x * (WARP_SIZE + WARP_SIZE/QI8_1)]
|
| 1943 |
|
| 1944 |
const int blocks_per_row_x = ne00 / qk;
|
|
|
|
| 1954 |
|
| 1955 |
for (int kb0 = 0; kb0 < blocks_per_row_x; kb0 += blocks_per_warp) {
|
| 1956 |
|
| 1957 |
+
load_tiles(x, tile_x_qs, tile_x_dm, tile_x_sc, stride01*blockIdx.x*mmq_y + kb0, tile_x_max_i, stride01);
|
| 1958 |
|
| 1959 |
#pragma unroll
|
| 1960 |
for (int kr = 0; kr < qr; ++kr) {
|
|
|
|
| 1970 |
|
| 1971 |
// #pragma unroll // unrolling this loop causes too much register pressure
|
| 1972 |
for (int k0 = kr*WARP_SIZE/qr; k0 < (kr+1)*WARP_SIZE/qr; k0 += vdr) {
|
| 1973 |
+
vec_dot(tile_x_qs, tile_x_dm, tile_x_sc, tile_y, sum, k0);
|
| 1974 |
}
|
| 1975 |
|
| 1976 |
__syncthreads();
|
|
|
|
| 1987 |
int64_t ne0;
|
| 1988 |
};
|
| 1989 |
|
| 1990 |
+
constexpr int mmq_get_nwarps(int mmq_x) {
|
| 1991 |
+
return mmq_x >= 32 ? 8 : 4;
|
| 1992 |
+
}
|
| 1993 |
+
|
| 1994 |
+
static int mmq_get_shmem(const ggml_type type, const int mmq_x, const int mmq_y) {
|
| 1995 |
+
const tile_x_sizes txs = get_tile_x_sizes_host(type, mmq_y);
|
| 1996 |
+
const int nwarps = mmq_get_nwarps(mmq_x);
|
| 1997 |
+
|
| 1998 |
+
const int shmem_x = txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
|
| 1999 |
+
const int shmem_y = mmq_x*WARP_SIZE*sizeof(int) + mmq_x*(WARP_SIZE/QI8_1)*sizeof(half2);
|
| 2000 |
+
return shmem_x + GGML_PAD(shmem_y, nwarps*WARP_SIZE*sizeof(int));
|
| 2001 |
+
}
|
| 2002 |
+
|
| 2003 |
template <ggml_type type, int mmq_x, int nwarps>
|
| 2004 |
static void launch_mul_mat_q(const mmq_args & args, cudaStream_t stream) {
|
| 2005 |
const int id = ggml_cuda_get_device();
|
|
|
|
| 2011 |
const dim3 block_nums(block_num_x, block_num_y, 1);
|
| 2012 |
const dim3 block_dims(WARP_SIZE, nwarps, 1);
|
| 2013 |
|
| 2014 |
+
const int shmem = mmq_get_shmem(type, mmq_x, mmq_y);
|
|
|
|
|
|
|
|
|
|
| 2015 |
|
| 2016 |
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
| 2017 |
static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
|
|
|
| 2035 |
|
| 2036 |
template <ggml_type type>
|
| 2037 |
void mul_mat_q_case(const mmq_args & args, cudaStream_t stream) {
|
| 2038 |
+
const int id = ggml_cuda_get_device();
|
| 2039 |
+
const int nsm = ggml_cuda_info().devices[id].nsm;
|
| 2040 |
+
const int cc = ggml_cuda_info().devices[id].cc;
|
| 2041 |
+
const int smpbo = ggml_cuda_info().devices[id].smpbo;
|
| 2042 |
|
| 2043 |
const int mmq_x_max = get_mmq_x_max_host(cc);
|
| 2044 |
const int mmq_y = get_mmq_y_host(cc, mmq_x_max);
|
|
|
|
| 2051 |
const int block_num_x = (args.ne11 + mmq_x - 1) / mmq_x;
|
| 2052 |
const int nwaves = (block_num_x*block_num_y + nsm - 1) / nsm;
|
| 2053 |
|
| 2054 |
+
if (nwaves < nwaves_best && mmq_get_shmem(type, mmq_x, mmq_y) <= smpbo) {
|
| 2055 |
mmq_x_best = mmq_x;
|
| 2056 |
nwaves_best = nwaves;
|
| 2057 |
}
|
|
|
|
| 2059 |
|
| 2060 |
switch (mmq_x_best) {
|
| 2061 |
case 8:
|
| 2062 |
+
launch_mul_mat_q<type, 8, mmq_get_nwarps( 8)>(args, stream);
|
| 2063 |
break;
|
| 2064 |
case 16:
|
| 2065 |
+
launch_mul_mat_q<type, 16, mmq_get_nwarps( 16)>(args, stream);
|
| 2066 |
break;
|
| 2067 |
case 24:
|
| 2068 |
+
launch_mul_mat_q<type, 24, mmq_get_nwarps( 24)>(args, stream);
|
| 2069 |
break;
|
| 2070 |
case 32:
|
| 2071 |
+
launch_mul_mat_q<type, 32, mmq_get_nwarps( 32)>(args, stream);
|
| 2072 |
break;
|
| 2073 |
case 40:
|
| 2074 |
+
launch_mul_mat_q<type, 40, mmq_get_nwarps( 40)>(args, stream);
|
| 2075 |
break;
|
| 2076 |
case 48:
|
| 2077 |
+
launch_mul_mat_q<type, 48, mmq_get_nwarps( 48)>(args, stream);
|
| 2078 |
break;
|
| 2079 |
case 56:
|
| 2080 |
+
launch_mul_mat_q<type, 56, mmq_get_nwarps( 56)>(args, stream);
|
| 2081 |
break;
|
| 2082 |
case 64:
|
| 2083 |
+
launch_mul_mat_q<type, 64, mmq_get_nwarps( 64)>(args, stream);
|
| 2084 |
break;
|
| 2085 |
case 72:
|
| 2086 |
+
launch_mul_mat_q<type, 72, mmq_get_nwarps( 72)>(args, stream);
|
| 2087 |
break;
|
| 2088 |
case 80:
|
| 2089 |
+
launch_mul_mat_q<type, 80, mmq_get_nwarps( 80)>(args, stream);
|
| 2090 |
break;
|
| 2091 |
case 88:
|
| 2092 |
+
launch_mul_mat_q<type, 88, mmq_get_nwarps( 88)>(args, stream);
|
| 2093 |
break;
|
| 2094 |
case 96:
|
| 2095 |
+
launch_mul_mat_q<type, 96, mmq_get_nwarps( 96)>(args, stream);
|
| 2096 |
break;
|
| 2097 |
case 104:
|
| 2098 |
+
launch_mul_mat_q<type, 104, mmq_get_nwarps(104)>(args, stream);
|
| 2099 |
break;
|
| 2100 |
case 112:
|
| 2101 |
+
launch_mul_mat_q<type, 112, mmq_get_nwarps(112)>(args, stream);
|
| 2102 |
break;
|
| 2103 |
case 120:
|
| 2104 |
+
launch_mul_mat_q<type, 120, mmq_get_nwarps(120)>(args, stream);
|
| 2105 |
break;
|
| 2106 |
case 128:
|
| 2107 |
+
launch_mul_mat_q<type, 128, mmq_get_nwarps(128)>(args, stream);
|
| 2108 |
break;
|
| 2109 |
default:
|
| 2110 |
+
fprintf(stderr, "mmq_x_best=%d\n", mmq_x_best);
|
| 2111 |
GGML_ASSERT(false);
|
| 2112 |
break;
|
| 2113 |
}
|
ggml-cuda/softmax.cu
CHANGED
|
@@ -130,6 +130,7 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
|
|
| 130 |
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
| 131 |
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
| 132 |
|
|
|
|
| 133 |
if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
|
| 134 |
switch (ncols_x) {
|
| 135 |
case 32:
|
|
|
|
| 130 |
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
| 131 |
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
| 132 |
|
| 133 |
+
// FIXME: this limit could be raised by ~2-4x on Ampere or newer
|
| 134 |
if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
|
| 135 |
switch (ncols_x) {
|
| 136 |
case 32:
|
ggml-cuda/vecdotq.cuh
CHANGED
|
@@ -265,36 +265,31 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
|
|
| 265 |
|
| 266 |
// contiguous u/y values
|
| 267 |
static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(
|
| 268 |
-
const int * __restrict__ v, const int * __restrict__ u, const
|
| 269 |
-
const half2 & dm2, const float & d8) {
|
| 270 |
|
| 271 |
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
|
| 272 |
-
|
| 273 |
-
|
| 274 |
|
| 275 |
#pragma unroll
|
| 276 |
for (int i0 = 0; i0 < QI8_1; i0 += QI8_1/2) {
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
// fill int with 4x m
|
| 282 |
-
int m = sc >> 4;
|
| 283 |
-
m |= m << 8;
|
| 284 |
-
m |= m << 16;
|
| 285 |
|
|
|
|
| 286 |
#pragma unroll
|
| 287 |
for (int i = i0; i < i0 + QI8_1/2; ++i) {
|
| 288 |
-
|
| 289 |
-
|
|
|
|
| 290 |
}
|
| 291 |
|
| 292 |
-
|
|
|
|
| 293 |
}
|
| 294 |
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
return d8 * (dm2f.x*sumi_d - dm2f.y*sumi_m);
|
| 298 |
#else
|
| 299 |
NO_DEVICE_CODE;
|
| 300 |
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
|
|
@@ -352,8 +347,10 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
|
|
| 352 |
for (int i0 = 0; i0 < QR3_K*VDR_Q3_K_Q8_1_MMQ; i0 += QI8_1/2) {
|
| 353 |
int sumi_sc = 0;
|
| 354 |
|
|
|
|
| 355 |
for (int i = i0; i < i0 + QI8_1/2; ++i) {
|
| 356 |
-
|
|
|
|
| 357 |
}
|
| 358 |
|
| 359 |
sumi += sumi_sc * scales[i0 / (QI8_1/2)];
|
|
|
|
| 265 |
|
| 266 |
// contiguous u/y values
|
| 267 |
static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(
|
| 268 |
+
const int * __restrict__ v, const int * __restrict__ u, const half2 * dm2, const float & d8) {
|
|
|
|
| 269 |
|
| 270 |
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
|
| 271 |
+
float sumf_d = 0.0f;
|
| 272 |
+
float sumf_m = 0.0f;
|
| 273 |
|
| 274 |
#pragma unroll
|
| 275 |
for (int i0 = 0; i0 < QI8_1; i0 += QI8_1/2) {
|
| 276 |
+
const float2 dm2f = __half22float2(dm2[i0/(QI8_1/2)]);
|
| 277 |
+
int sumi_d = 0;
|
| 278 |
+
int sumi_m = 0;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 279 |
|
| 280 |
+
const int vi0 = v[i0/(QI8_1/2)];
|
| 281 |
#pragma unroll
|
| 282 |
for (int i = i0; i < i0 + QI8_1/2; ++i) {
|
| 283 |
+
const int vi = (vi0 >> (2*(i % (QI8_1/2)))) & 0x03030303;
|
| 284 |
+
sumi_d = __dp4a(vi, u[i], sumi_d); // SIMD dot product
|
| 285 |
+
sumi_m = __dp4a(0x01010101, u[i], sumi_m);
|
| 286 |
}
|
| 287 |
|
| 288 |
+
sumf_d += dm2f.x * sumi_d;
|
| 289 |
+
sumf_m += dm2f.y * sumi_m;
|
| 290 |
}
|
| 291 |
|
| 292 |
+
return d8*(sumf_d - sumf_m);
|
|
|
|
|
|
|
| 293 |
#else
|
| 294 |
NO_DEVICE_CODE;
|
| 295 |
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
|
|
|
|
| 347 |
for (int i0 = 0; i0 < QR3_K*VDR_Q3_K_Q8_1_MMQ; i0 += QI8_1/2) {
|
| 348 |
int sumi_sc = 0;
|
| 349 |
|
| 350 |
+
#pragma unroll
|
| 351 |
for (int i = i0; i < i0 + QI8_1/2; ++i) {
|
| 352 |
+
const int vi = __vsubss4((v[i/2] >> (4*(i%2))) & 0x0F0F0F0F, 0x04040404);
|
| 353 |
+
sumi_sc = __dp4a(vi, u[i], sumi_sc); // SIMD dot product
|
| 354 |
}
|
| 355 |
|
| 356 |
sumi += sumi_sc * scales[i0 / (QI8_1/2)];
|