Commit b86860f (1 parent: bc1415b), committed by ggerganov

metal : use F32 accumulators in FA kernels (llama/13975)
ggml/src/ggml-metal/ggml-metal.m CHANGED

@@ -4766,6 +4766,8 @@ static bool ggml_metal_encode_node(
     GGML_ASSERT(nqptg % 8  == 0);
     GGML_ASSERT(ncpsg % 32 == 0);
 
+    const int is_q = ggml_is_quantized(src1->type) ? 1 : 0;
+
     // 2*(2*ncpsg + nqptg)*(nsg)
     // ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float)
     //
@@ -4773,7 +4775,7 @@
     // the shared memory needed for the simdgroups to load the KV cache
     // each thread loads (dequantizes) 16 head elements, there are 32 threads in the SG
     //
-#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(2*ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + is_q*(16*32*(nsg)))*(sizeof(float)/2), 16))
 
     int64_t nsgmax = 2;
 
@@ -4810,9 +4812,9 @@
     // and store the soft_max values and the mask
     //
     // ne00*(nsg)
-    // each simdgroup has a full f16 head vector in shared mem to accumulate results
+    // each simdgroup has a full f32 head vector in shared mem to accumulate results
     //
-#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + ne20*(nsg))*(sizeof(float)/2), 16))
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*ne20*(nsg))*(sizeof(float)/2), 16))
 
     int64_t nsgmax = 2;
     while (true) {
 
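The host-side change keeps the shared-memory layout in F16 units but reserves twice the space for the query/output block (ne00 becomes 2*ne00, ne20 becomes 2*ne20), since those values are now accumulated in F32, and it only reserves the 16*32 K/V staging area when the cache is actually quantized (is_q). A minimal C sketch of how the updated FATTN_SMEM formula scales with the number of simdgroups; the concrete values for ne00, nqptg and ncpsg below are illustrative assumptions, not taken from this commit:

    #include <stdint.h>
    #include <stdio.h>

    // same padding helper as GGML_PAD in ggml: round x up to a multiple of n
    #define GGML_PAD(x, n) (((x) + (n) - 1)/(n)*(n))

    int main(void) {
        const int64_t ne00  = 128; // K head size (illustrative)
        const int64_t nqptg = 8;   // queries processed per threadgroup (illustrative)
        const int64_t ncpsg = 32;  // cache values processed per simdgroup (illustrative)
        const int     is_q  = 1;   // 1 when the KV cache is quantized

        for (int64_t nsg = 1; nsg <= 4; nsg *= 2) {
            // new formula: the Q/O block is F32 (2*ne00 halves), the K/V staging
            // area is only counted for quantized caches
            const int64_t smem = GGML_PAD((nqptg*(2*ne00 + 2*(2*ncpsg + nqptg)*nsg)
                                           + is_q*(16*32*nsg))*((int64_t) sizeof(float)/2), 16);
            printf("nsg = %lld -> FATTN_SMEM = %lld bytes\n", (long long) nsg, (long long) smem);
        }
        return 0;
    }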
ggml/src/ggml-metal/ggml-metal.metal CHANGED

@@ -3328,14 +3328,14 @@ kernel void kernel_flash_attn_ext(
     constexpr short NW = N_SIMDWIDTH;
     constexpr short SH = (2*C + Q); // shared memory per simdgroup (s_t == float)
 
-    const short TS = nsg*SH;    // shared memory size per query in (s_t == float)
-    const short T  = DK + 2*TS; // shared memory size per query in (half)
+    const short TS = nsg*SH;      // shared memory size per query in (s_t == float)
+    const short T  = 2*DK + 2*TS; // shared memory size per query in (half)
 
-    threadgroup q_t  * sq  = (threadgroup q_t  *) (shmem_f16 + 0*DK); // holds the query data
-    threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
-    threadgroup o_t  * so  = (threadgroup o_t  *) (shmem_f16 + 0*DK); // reuse query data for accumulation
-    threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*DK); // same as above but in o4_t
-    threadgroup s_t  * ss  = (threadgroup s_t  *) (shmem_f16 + 2*sgitg*SH + Q*DK); // scratch buffer for attention, mask and diagonal matrix
+    threadgroup q_t  * sq  = (threadgroup q_t  *) (shmem_f16 + 0*DK); // holds the query data
+    threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
+    threadgroup o_t  * so  = (threadgroup o_t  *) (shmem_f16 + 0*DK); // reuse query data for accumulation
+    threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*DK); // same as above but in o4_t
+    threadgroup s_t  * ss  = (threadgroup s_t  *) (shmem_f16 + 2*sgitg*SH + 2*Q*DK); // scratch buffer for attention, mask and diagonal matrix
 
     threadgroup k_t    * sk    = (threadgroup k_t    *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
     threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
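shmem_f16 is still addressed in half units, so with q_t/o_t widened to float the query/output block now spans 2*DK halves per row; that is why T grows from DK + 2*TS to 2*DK + 2*TS and the attention scratch moves from Q*DK to 2*Q*DK. A small C sketch of that offset arithmetic; DK, Q, C and nsg are illustrative template parameters assumed for the example:

    #include <stdio.h>

    int main(void) {
        const int DK  = 128; // K head size (illustrative)
        const int Q   = 8;   // queries per threadgroup (illustrative)
        const int C   = 32;  // cache items per simdgroup (illustrative)
        const int nsg = 4;   // simdgroups per threadgroup (illustrative)

        const int SH = 2*C + Q;     // per-simdgroup scratch, counted in float
        const int TS = nsg*SH;      // scratch per query row, counted in float
        const int T  = 2*DK + 2*TS; // per-query footprint, counted in half

        // offsets into shmem_f16, in half units
        printf("sq/so (F32 query/output) at %d\n", 0);
        printf("ss (attention scratch)   at %d\n", 2*Q*DK); // was Q*DK with F16 output
        printf("sk (K/V staging)         at %d\n", Q*T);
        return 0;
    }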
@@ -3354,7 +3354,7 @@
                 if (iq1 + j < args.ne01) {
                     sq4[j*DK4 + i] = (q4_t) q4[i];
                 } else {
-                    sq4[j*DK4 + i] = (q4_t) 0.0f;
+                    sq4[j*DK4 + i] = 0;
                 }
             }
         }
@@ -3634,9 +3634,6 @@
 
         // reduce the warps sequentially
         for (ushort sg = 1; sg < nsg; ++sg) {
-            float S = { 0.0f };
-            float M = { -__FLT_MAX__/2 };
-
             threadgroup_barrier(mem_flags::mem_threadgroup);
 
             // each simdgroup stores its output to shared memory, reusing sq
@@ -3657,12 +3654,12 @@
                 const float M0 = ss[j*TS + 1];
                 const float M1 = ss[j*TS + sg*SH + 1];
 
-                M = max(M0, M1);
+                const float M = max(M0, M1);
 
                 const float ms0 = exp(M0 - M);
                 const float ms1 = exp(M1 - M);
 
-                S = S0*ms0 + S1*ms1;
+                const float S = S0*ms0 + S1*ms1;
 
                 if (tiisg == 0) {
                     ss[j*TS + 0] = S;
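The sequential reduction merges per-simdgroup partial softmax states with the usual streaming rule: take the larger running maximum, rescale both partial sums by the exponential of the difference, and add. A self-contained C version of that merge step; the function name and the sample inputs are hypothetical:

    #include <math.h>
    #include <stdio.h>

    // merge two partial softmax accumulators (S = sum of exponentials,
    // M = running maximum), mirroring the simdgroup reduction above
    static void softmax_merge(float S0, float M0, float S1, float M1,
                              float * S, float * M) {
        *M = fmaxf(M0, M1);

        const float ms0 = expf(M0 - *M);
        const float ms1 = expf(M1 - *M);

        *S = S0*ms0 + S1*ms1;
    }

    int main(void) {
        float S, M;
        softmax_merge(3.0f, 1.5f, 2.0f, 2.5f, &S, &M); // hypothetical partial states
        printf("merged: S = %f, M = %f\n", S, M);
        return 0;
    }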
@@ -3701,16 +3698,18 @@
             }
         }
 
-        device float4 * dst4 = (device float4 *) dst;
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2*Q*DK);
 
         // final rescale with 1/S and store to global memory
-        if (sgitg == 0) {
-            for (short j = 0; j < Q && iq1 + j < args.ne01; ++j) {
-                const float S = ss[j*TS + 0];
-
-                for (short i = tiisg; i < DV4; i += NW) {
-                    dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4 + i] = (float4) so4[j*DV4 + i]/S;
-                }
-            }
+        for (short j = sgitg; j < Q && iq1 + j < args.ne01; j += nsg) {
+            const float S = 1.0f/sf[j*TS + 0];
+
+            device float4 * dst4 = (device float4 *) dst + ((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4;
+
+            for (short i = tiisg; i < DV4; i += NW) {
+                dst4[i] = (float4) so4[j*DV4 + i]*S;
+            }
         }
     }
 }
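Since every simdgroup's output now sits in shared memory, the final rescale no longer has to be serialized on simdgroup 0: rows are distributed round-robin over all nsg simdgroups, and the per-element division by S is replaced with one reciprocal per row. The same distribution pattern, sketched in plain C with illustrative Q and nsg:

    #include <stdio.h>

    int main(void) {
        const int Q   = 8; // query rows to rescale (illustrative)
        const int nsg = 4; // simdgroups in the threadgroup (illustrative)

        // before: simdgroup 0 walked j = 0..Q-1 alone;
        // after: simdgroup sgitg handles rows sgitg, sgitg + nsg, ...
        for (int sgitg = 0; sgitg < nsg; ++sgitg) {
            for (int j = sgitg; j < Q; j += nsg) {
                printf("simdgroup %d rescales row %d\n", sgitg, j);
            }
        }
        return 0;
    }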
@@ -3719,12 +3718,22 @@
 // template to be able to explore different combinations
 //
 #define FA_TYPES \
-    half,  half4,   simdgroup_half8x8,  \
-    half,  half4x4, simdgroup_half8x8,  \
-    half,  half4x4, simdgroup_half8x8,  \
-    float,          simdgroup_float8x8, \
-    float,          simdgroup_float8x8, \
-    half,  half4,   simdgroup_half8x8
+    float, float4,   simdgroup_float8x8, \
+    half,  half4x4,  simdgroup_half8x8,  \
+    half,  half4x4,  simdgroup_half8x8,  \
+    float,           simdgroup_float8x8, \
+    float,           simdgroup_float8x8, \
+    float, float4,   simdgroup_float8x8
+  //half,  half4,    simdgroup_half8x8
+
+#define FA_TYPES_BF \
+    bfloat, bfloat4,   simdgroup_bfloat8x8, \
+    bfloat, bfloat4x4, simdgroup_bfloat8x8, \
+    bfloat, bfloat4x4, simdgroup_bfloat8x8, \
+    float,             simdgroup_float8x8,  \
+    float,             simdgroup_float8x8,  \
+    float,  float4,    simdgroup_float8x8
+  //half,  half4,    simdgroup_half8x8
 
 typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;
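The FA_TYPES rows that change are the query type (first row) and the output/accumulator type (last row), both widened from half to float; this accumulation is what the commit title refers to. A short C illustration of why F16 accumulators drift, assuming a compiler with _Float16 support (the loop count and step are arbitrary):

    #include <stdio.h>

    int main(void) {
        // add many small attention-sized contributions: the F16 sum stalls
        // once the step falls below half the spacing of the running sum
        _Float16 acc16 = (_Float16) 0.0f;
        float    acc32 = 0.0f;

        for (int i = 0; i < 100000; ++i) {
            acc16 += (_Float16) 0.01f;
            acc32 += 0.01f;
        }

        printf("f16 accumulator: %f (far below the true sum)\n", (double) acc16);
        printf("f32 accumulator: %f (close to 1000)\n",          (double) acc32);
        return 0;
    }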
 
@@ -3739,15 +3748,15 @@ template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_at
 template [[host_name("kernel_flash_attn_ext_f16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>;
 
 #if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
-template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
-template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96, 96>;
-template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112, 112>;
-template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128, 128>;
-template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 192>;
-template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 128>;
-template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256, 256>;
-template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
+template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
+template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
+template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96, 96>;
+template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112, 112>;
+template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128, 128>;
+template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 192>;
+template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 128>;
+template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256, 256>;
+template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
 #endif
 
 template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
@@ -3801,6 +3810,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q8_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 576, 512>;
 
 #undef FA_TYPES
+#undef FA_TYPES_BF
 
 template<
     typename q4_t, // query types in shared memory
@@ -3847,12 +3857,12 @@ kernel void kernel_flash_attn_ext_vec(
 
     const short T = DK + nsg*SH; // shared memory size per query in (half)
 
-  //threadgroup q_t   * sq  = (threadgroup q_t   *) (shmem_f16 + 0*DK); // holds the query data
-    threadgroup q4_t  * sq4 = (threadgroup q4_t  *) (shmem_f16 + 0*DK); // same as above but in q4_t
-    threadgroup s_t   * ss  = (threadgroup s_t   *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention
-    threadgroup s4_t  * ss4 = (threadgroup s4_t  *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t
-    threadgroup float * sm  = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask
-    threadgroup o4_t  * sr4 = (threadgroup o4_t  *) (shmem_f16 + sgitg*DV + Q*T); // scratch buffer for the results
+  //threadgroup q_t   * sq  = (threadgroup q_t   *) (shmem_f16 + 0*DK); // holds the query data
+    threadgroup q4_t  * sq4 = (threadgroup q4_t  *) (shmem_f16 + 0*DK); // same as above but in q4_t
+    threadgroup s_t   * ss  = (threadgroup s_t   *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention
+    threadgroup s4_t  * ss4 = (threadgroup s4_t  *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t
+    threadgroup float * sm  = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask
+    threadgroup o4_t  * sr4 = (threadgroup o4_t  *) (shmem_f16 + 2*sgitg*DV + Q*T); // scratch buffer for the results
 
     // store the result for all queries in local memory (the O matrix from the paper)
     o4_t lo[DV4/NL];
@@ -4157,7 +4167,7 @@
     half4, \
     float, \
     float, float4, \
-    half4
+    float4
 
 typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
 
 