metal : use F32 accumulators in FA kernels (llama/13975)
ggml/src/ggml-metal/ggml-metal.m CHANGED

@@ -4766,6 +4766,8 @@ static bool ggml_metal_encode_node(
                 GGML_ASSERT(nqptg % 8 == 0);
                 GGML_ASSERT(ncpsg % 32 == 0);
 
+                const int is_q = ggml_is_quantized(src1->type) ? 1 : 0;
+
                 // 2*(2*ncpsg + nqptg)*(nsg)
                 // ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float)
                 //
@@ -4773,7 +4775,7 @@ static bool ggml_metal_encode_node(
                 // the shared memory needed for the simdgroups to load the KV cache
                 // each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG
                 //
-#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(2*ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + is_q*(16*32*(nsg)))*(sizeof(float)/2), 16))
 
                 int64_t nsgmax = 2;
 
@@ -4810,9 +4812,9 @@ static bool ggml_metal_encode_node(
                 // and store the soft_max values and the mask
                 //
                 // ne00*(nsg)
-                // each simdgroup has a full f16 head vector in shared mem to accumulate results
+                // each simdgroup has a full f32 head vector in shared mem to accumulate results
                 //
-#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + ne20*(nsg))*(sizeof(float)/2), 16))
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*ne20*(nsg))*(sizeof(float)/2), 16))
 
                 int64_t nsgmax = 2;
                 while (true) {
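The host-side change does two things: the query/output slice of the threadgroup memory is doubled, because f32 accumulators take two half-sized units each, and the 16*32 dequantization scratch is now charged only when the KV cache is quantized (is_q). A minimal sketch of the budget arithmetic in C; the dimensions are illustrative assumptions rather than values from the commit, and GGML_PAD is redefined locally:

#include <stdio.h>

#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

int main(void) {
    const int ne00  = 128; // K head size (assumed for illustration)
    const int nqptg = 8;   // queries processed per threadgroup
    const int ncpsg = 32;  // cache values processed per simdgroup
    const int nsg   = 2;   // simdgroups per threadgroup

    for (int is_q = 0; is_q <= 1; ++is_q) {
        // old budget: f16 accumulators, dequantization scratch always reserved
        const int smem_old = GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*nsg) + 16*32*nsg)*2, 16);

        // new budget: f32 accumulators (2*ne00), scratch only for quantized KV
        const int smem_new = GGML_PAD((nqptg*(2*ne00 + 2*(2*ncpsg + nqptg)*nsg) + is_q*(16*32*nsg))*2, 16);

        printf("is_q = %d: old = %d bytes, new = %d bytes\n", is_q, smem_old, smem_new);
    }

    return 0;
}

With these particular numbers the wider accumulators and the dropped scratch cancel exactly for an f16 KV cache (8704 bytes either way), while for a quantized cache the doubled Q/O region comes on top of the scratch (10752 vs 8704 bytes).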
ggml/src/ggml-metal/ggml-metal.metal CHANGED

@@ -3328,14 +3328,14 @@ kernel void kernel_flash_attn_ext(
     constexpr short NW = N_SIMDWIDTH;
     constexpr short SH = (2*C + Q); // shared memory per simdgroup (s_t == float)
 
-    const short TS = nsg*SH;
-    const short T  = DK + 2*TS; // shared memory size per query in (half)
+    const short TS = nsg*SH;      // shared memory size per query in (s_t == float)
+    const short T  = 2*DK + 2*TS; // shared memory size per query in (half)
 
-    threadgroup q_t  * sq  = (threadgroup q_t  *) (shmem_f16 +              0*DK); // holds the query data
-    threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 +              0*DK); // same as above but in q4_t
-    threadgroup o_t  * so  = (threadgroup o_t  *) (shmem_f16 +              0*DK); // reuse query data for accumulation
-    threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 +              0*DK); // same as above but in o4_t
-    threadgroup s_t  * ss  = (threadgroup s_t  *) (shmem_f16 + 2*sgitg*SH + Q*DK); // scratch buffer for attention, mask and diagonal matrix
+    threadgroup q_t  * sq  = (threadgroup q_t  *) (shmem_f16 +                0*DK); // holds the query data
+    threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 +                0*DK); // same as above but in q4_t
+    threadgroup o_t  * so  = (threadgroup o_t  *) (shmem_f16 +                0*DK); // reuse query data for accumulation
+    threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 +                0*DK); // same as above but in o4_t
+    threadgroup s_t  * ss  = (threadgroup s_t  *) (shmem_f16 + 2*sgitg*SH + 2*Q*DK); // scratch buffer for attention, mask and diagonal matrix
 
     threadgroup k_t    * sk    = (threadgroup k_t    *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
     threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
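T is counted in half-sized units, so a Q/O region that now holds f32 values contributes 2*DK per query, and the f32 attention scratch follows it at offset 2*Q*DK; the KV load buffers stay where the Q*T expression puts them. A small sketch of the offset arithmetic, where DK = 128, Q = 8, C = 32, nsg = 2 are assumptions for illustration, not values from the commit:

#include <stdio.h>

int main(void) {
    // illustrative dimensions (assumed): K head size, queries per
    // threadgroup, cache items per simdgroup step, simdgroup count
    const int DK = 128, Q = 8, C = 32, nsg = 2;

    const int SH = 2*C + Q;     // per-simdgroup scratch, in s_t (f32) elements
    const int TS = nsg*SH;      // scratch for all simdgroups
    const int T  = 2*DK + 2*TS; // per-query shared memory, in half units

    // offsets into shmem_f16, in half units (each f32 spans two halves)
    printf("sq/so (f32 Q/O block):       0\n");
    printf("ss  (f32 attention scratch): %d + 2*sgitg*%d\n", 2*Q*DK, SH);
    printf("sk  (KV load scratch):       %d + sgitg*(4*16*KV)\n", Q*T);

    return 0;
}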

@@ -3354,7 +3354,7 @@ kernel void kernel_flash_attn_ext(
                 if (iq1 + j < args.ne01) {
                     sq4[j*DK4 + i] = (q4_t) q4[i];
                 } else {
-                    sq4[j*DK4 + i] = (q4_t) 0.0f;
+                    sq4[j*DK4 + i] = 0;
                 }
             }
         }

@@ -3634,9 +3634,6 @@ kernel void kernel_flash_attn_ext(
 
     // reduce the warps sequentially
    for (ushort sg = 1; sg < nsg; ++sg) {
-        float S = { 0.0f };
-        float M = { -__FLT_MAX__/2 };
-
         threadgroup_barrier(mem_flags::mem_threadgroup);
 
         // each simdgroup stores its output to shared memory, reusing sq

@@ -3657,12 +3654,12 @@ kernel void kernel_flash_attn_ext(
             const float M0 = ss[j*TS + 1];
             const float M1 = ss[j*TS + sg*SH + 1];
 
-            M = max(M0, M1);
+            const float M = max(M0, M1);
 
             const float ms0 = exp(M0 - M);
             const float ms1 = exp(M1 - M);
 
-            S = S0*ms0 + S1*ms1;
+            const float S = S0*ms0 + S1*ms1;
 
             if (tiisg == 0) {
                 ss[j*TS + 0] = S;
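With the loop-carried S and M gone, each pair of partial results is merged with the standard streaming-softmax rule: take M = max(M0, M1) and rescale each partial sum by exp(Mi - M) before adding. A minimal CPU sketch of the same rule (the struct and helper are hypothetical, not part of the kernel):

#include <math.h>
#include <stdio.h>

// one partial softmax state: running sum S and running maximum M
typedef struct { float S; float M; } softmax_state;

// merge two partial states; mirrors the per-row combination done in
// the simdgroup reduction above
static softmax_state softmax_merge(softmax_state a, softmax_state b) {
    const float M   = fmaxf(a.M, b.M);
    const float ms0 = expf(a.M - M);
    const float ms1 = expf(b.M - M);

    return (softmax_state){ a.S*ms0 + b.S*ms1, M };
}

int main(void) {
    const softmax_state s0 = { 1.75f, 3.0f };
    const softmax_state s1 = { 0.40f, 5.0f };

    const softmax_state s = softmax_merge(s0, s1);
    printf("S = %f, M = %f\n", s.S, s.M);

    return 0;
}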

@@ -3701,16 +3698,18 @@ kernel void kernel_flash_attn_ext(
         }
     }
 
-    device float4 * dst4 = (device float4 *) dst;
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2*Q*DK);
 
     // final rescale with 1/S and store to global memory
-    if (sgitg == 0) {
-        for (short j = 0; j < Q && iq1 + j < args.ne01; ++j) {
-            const float S = ss[j*TS + 0];
+    for (short j = sgitg; j < Q && iq1 + j < args.ne01; j += nsg) {
+        const float S = 1.0f/sf[j*TS + 0];
 
-            for (short i = tiisg; i < DV4; i += NW) {
-                dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4 + i] = (float4) so4[j*DV4 + i]/S;
-            }
+        device float4 * dst4 = (device float4 *) dst + ((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4;
+
+        for (short i = tiisg; i < DV4; i += NW) {
+            dst4[i] = (float4) so4[j*DV4 + i]*S;
         }
     }
 }
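Two details of the new epilogue: the per-element division by S becomes a single reciprocal followed by multiplies, and output rows are striped across all simdgroups (j starts at sgitg and advances by nsg) instead of being written by simdgroup 0 alone. The reciprocal-hoisting half of that, as a CPU sketch (hypothetical helper):

#include <stdio.h>

// rescale a row by 1/S: one division up front instead of one per element
static void rescale_row(float *row, int n, float S) {
    const float inv_S = 1.0f/S;

    for (int i = 0; i < n; ++i) {
        row[i] *= inv_S;
    }
}

int main(void) {
    float row[4] = { 2.0f, 4.0f, 6.0f, 8.0f };

    rescale_row(row, 4, 2.0f);
    printf("%.1f %.1f %.1f %.1f\n", row[0], row[1], row[2], row[3]);

    return 0;
}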

@@ -3719,12 +3718,22 @@ kernel void kernel_flash_attn_ext(
 // template to be able to explore different combinations
 //
 #define FA_TYPES \
-    half,  half4,     simdgroup_half8x8,  \
-    half,  half4x4,   simdgroup_half8x8,  \
-    half,  half4x4,   simdgroup_half8x8,  \
-    float,            simdgroup_float8x8, \
-    float,            simdgroup_float8x8, \
-    half,  half4,     simdgroup_half8x8
+    float, float4,    simdgroup_float8x8, \
+    half,  half4x4,   simdgroup_half8x8,  \
+    half,  half4x4,   simdgroup_half8x8,  \
+    float,            simdgroup_float8x8, \
+    float,            simdgroup_float8x8, \
+    float, float4,    simdgroup_float8x8
+  //half,  half4,     simdgroup_half8x8
+
+#define FA_TYPES_BF \
+    bfloat, bfloat4,   simdgroup_bfloat8x8, \
+    bfloat, bfloat4x4, simdgroup_bfloat8x8, \
+    bfloat, bfloat4x4, simdgroup_bfloat8x8, \
+    float,             simdgroup_float8x8,  \
+    float,             simdgroup_float8x8,  \
+    float,  float4,    simdgroup_float8x8
+  //half,   half4,     simdgroup_half8x8
 
 typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;
 
@@ -3739,15 +3748,15 @@ template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_at
 template [[host_name("kernel_flash_attn_ext_f16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>;
 
 #if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
-template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
-template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96, 96>;
-template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112, 112>;
-template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128, 128>;
-template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 192>;
-template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 128>;
-template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256, 256>;
-template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
+template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
+template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
+template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96, 96>;
+template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112, 112>;
+template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128, 128>;
+template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 192>;
+template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 128>;
+template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256, 256>;
+template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
 #endif
 
 template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
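FA_TYPES pins one concrete type per stage of the kernel (query load, K/V loads, Q*K accumulation, soft-max scratch, output accumulation); the commit flips the query and output entries from half to float and adds a separate FA_TYPES_BF list so the bf16 variants keep bfloat inputs while accumulating in f32. The same compile-time selection idea in plain C (ACC_TYPE and the helper are invented for illustration):

#include <stdio.h>

// the accumulator type is a compile-time parameter of the routine;
// the commit's change amounts to flipping it from half to float
#define ACC_TYPE float

static ACC_TYPE accumulate(const float *x, int n) {
    ACC_TYPE acc = 0;

    for (int i = 0; i < n; ++i) {
        acc += (ACC_TYPE) x[i];
    }

    return acc;
}

int main(void) {
    const float x[3] = { 0.1f, 0.2f, 0.3f };

    printf("%f\n", (double) accumulate(x, 3));

    return 0;
}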
@@ -3801,6 +3810,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q8_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 576, 512>;
 
 #undef FA_TYPES
+#undef FA_TYPES_BF
 
 template<
     typename q4_t, // query types in shared memory

@@ -3847,12 +3857,12 @@ kernel void kernel_flash_attn_ext_vec(
 
     const short T = DK + nsg*SH; // shared memory size per query in (half)
 
-  //threadgroup q_t   * sq  = (threadgroup q_t   *) (shmem_f16 +              0*DK); // holds the query data
-    threadgroup q4_t  * sq4 = (threadgroup q4_t  *) (shmem_f16 +              0*DK); // same as above but in q4_t
-    threadgroup s_t   * ss  = (threadgroup s_t   *) (shmem_f16 + sgitg*SH       + Q*DK); // scratch buffer for attention
-    threadgroup s4_t  * ss4 = (threadgroup s4_t  *) (shmem_f16 + sgitg*SH       + Q*DK); // same as above but in s4_t
-    threadgroup float * sm  = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask
-    threadgroup o4_t  * sr4 = (threadgroup o4_t  *) (shmem_f16 + sgitg*DV       + Q*T);  // scratch buffer for the results
+  //threadgroup q_t   * sq  = (threadgroup q_t   *) (shmem_f16 +                0*DK); // holds the query data
+    threadgroup q4_t  * sq4 = (threadgroup q4_t  *) (shmem_f16 +                0*DK); // same as above but in q4_t
+    threadgroup s_t   * ss  = (threadgroup s_t   *) (shmem_f16 + sgitg*SH       + Q*DK); // scratch buffer for attention
+    threadgroup s4_t  * ss4 = (threadgroup s4_t  *) (shmem_f16 + sgitg*SH       + Q*DK); // same as above but in s4_t
+    threadgroup float * sm  = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask
+    threadgroup o4_t  * sr4 = (threadgroup o4_t  *) (shmem_f16 + 2*sgitg*DV     + Q*T);  // scratch buffer for the results
 
     // store the result for all queries in local memory (the O matrix from the paper)
     o4_t lo[DV4/NL];

@@ -4157,7 +4167,7 @@ kernel void kernel_flash_attn_ext_vec(
     half4, \
     float, \
     float, float4, \
-    half4
+    float4
 
 typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
 
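The motivation for all of the above: a half-precision running sum stops absorbing small contributions once it grows past their unit in the last place, which is exactly what happens when many attention-weighted V values are accumulated. A small CPU demonstration of the effect (illustrative; assumes a compiler and target with _Float16 support, e.g. clang on Apple silicon):

#include <stdio.h>

// accumulate n copies of x in half and in float; with IEEE round-to-
// nearest the half accumulator stalls once x drops below half an ULP
// of the running sum
int main(void) {
    const int   n = 4096;
    const float x = 0.25f;

    _Float16 acc_h = 0;
    float    acc_f = 0;

    for (int i = 0; i < n; ++i) {
        acc_h = acc_h + (_Float16) x;
        acc_f = acc_f + x;
    }

    printf("half accumulator:  %f\n", (double)(float) acc_h);
    printf("float accumulator: %f\n", (double) acc_f);

    return 0;
}

Here the float accumulator reaches the exact 1024 while the half accumulator stalls at 512; widening the O accumulators (and the Q data whose shared memory they reuse) to f32 removes this class of error from the FA kernels.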