Gaurav Garg and JohannesGaessler committed
Commit 3a7ca19 · 1 Parent(s): 53dd8ad

CUDA: Improve flash decoding kernel GPU occupancy for BS=1 case (llama/12183)


- Determine the number of active blocks per SM with the cudaOccupancyMaxActiveBlocksPerMultiprocessor API and use this value to pick the optimal parallel_blocks value at runtime (see the sketch below).
- Prefer the vector flash attention kernels over the MMA kernel for the BS=1 case.

Fixes Issue: #12182
---------

Co-authored-by: Johannes Gäßler <[email protected]>
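
The core of the change is a runtime heuristic added to launch_fattn() in ggml/src/ggml-cuda/fattn-common.cuh. Below is a condensed, host-side sketch of that logic; nsm (number of SMs), ntiles_total, ntiles_KQ, block_dim, nbytes_shared and fattn_kernel are assumed to be computed as in the full function, and the stream-k path is omitted:

```cpp
// Sketch of the occupancy-based parallel_blocks selection (simplified from
// launch_fattn; requires <algorithm> for std::max/std::min).
int parallel_blocks = 1;

// Max. number of active blocks per SM as limited by occupancy:
int max_blocks_per_sm = 1;
CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
    &max_blocks_per_sm, fattn_kernel, block_dim.x*block_dim.y*block_dim.z, nbytes_shared));

// At least large enough to fill a single wave of blocks ...
parallel_blocks = std::max((nsm*max_blocks_per_sm) / ntiles_total, 1);
// ... but no larger than the tensor size allows:
parallel_blocks = std::min(parallel_blocks, ntiles_KQ);

// If ntiles_total % blocks_per_wave != 0, tail effects waste SM time.
// Test whether a larger parallel_blocks value gives better wave efficiency:
const int blocks_per_wave        = nsm * max_blocks_per_sm;
int nwaves_best                  = 0;
int efficiency_percent_best      = 0;
for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KQ; ++parallel_blocks_test) {
    const int nblocks_total      = ntiles_total * parallel_blocks_test;
    const int nwaves             = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave;
    const int efficiency_percent = 100*nblocks_total / (nwaves*blocks_per_wave);

    // Stop trying configurations with more waves once efficiency is already good:
    if (efficiency_percent_best >= 90 && nwaves > nwaves_best) {
        break;
    }
    if (efficiency_percent > efficiency_percent_best) {
        nwaves_best             = nwaves;
        efficiency_percent_best = efficiency_percent;
        parallel_blocks         = parallel_blocks_test;
    }
}
```

The grid is then launched as (ntiles_x, parallel_blocks, ne[2]*ne[3]); whenever parallel_blocks > 1, each block writes unnormalized partial results that a separate combine kernel reduces afterwards.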

ggml/src/ggml-cuda/fattn-common.cuh CHANGED
@@ -606,48 +606,47 @@ static __global__ void flash_attn_stream_k_fixup(
606
  *dst = dst_val / rowsum;
607
  }
608
 
609
- template<int D, int parallel_blocks> // D == head size
610
  #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
611
  __launch_bounds__(D, 1)
612
  #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
613
  static __global__ void flash_attn_combine_results(
614
  const float * __restrict__ VKQ_parts,
615
  const float2 * __restrict__ VKQ_meta,
616
- float * __restrict__ dst) {
617
- VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x;
618
- VKQ_meta += parallel_blocks * gridDim.y*blockIdx.x;
619
- dst += D * gridDim.y*blockIdx.x;
 
620
 
621
  const int tid = threadIdx.x;
622
  __builtin_assume(tid < D);
623
 
624
- __shared__ float2 meta[parallel_blocks];
625
  if (tid < 2*parallel_blocks) {
626
- ((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.y*(2*parallel_blocks) + tid];
627
  }
628
 
629
  __syncthreads();
630
 
631
  float kqmax = meta[0].x;
632
- #pragma unroll
633
  for (int l = 1; l < parallel_blocks; ++l) {
634
  kqmax = max(kqmax, meta[l].x);
635
  }
636
 
637
  float VKQ_numerator = 0.0f;
638
  float VKQ_denominator = 0.0f;
639
- #pragma unroll
640
  for (int l = 0; l < parallel_blocks; ++l) {
641
  const float diff = meta[l].x - kqmax;
642
  const float KQ_max_scale = expf(diff);
643
  const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
644
  *((uint32_t *) &KQ_max_scale) &= ftz_mask;
645
 
646
- VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid];
647
  VKQ_denominator += KQ_max_scale * meta[l].y;
648
  }
649
 
650
- dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
651
  }
652
 
653
  static void on_no_fattn_vec_case(const int D) {
@@ -671,12 +670,10 @@ static void on_no_fattn_vec_case(const int D) {
671
  }
672
  }
673
 
674
- // parallel_blocks == 0 is stream-k decomposition
675
- template <int D, int ncols1, int ncols2, int parallel_blocks, int KQ_stride>
676
  void launch_fattn(
677
- ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
678
- const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V,
679
- const int warp_size = WARP_SIZE
680
  ) {
681
  constexpr int ncols = ncols1 * ncols2;
682
 
@@ -748,12 +745,14 @@ void launch_fattn(
748
  nb23 = nb23*bs*sizeof(half)/ts;
749
  }
750
 
 
 
751
  const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
752
  const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
753
 
754
  const dim3 block_dim(warp_size, nwarps, 1);
755
  dim3 blocks_num;
756
- if (parallel_blocks == 0) {
757
  // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
758
  const int max_blocks = 2*nsm;
759
  const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
@@ -769,9 +768,43 @@ void launch_fattn(
769
 
770
  dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float));
771
  } else {
772
- blocks_num.x = parallel_blocks*ntiles_x;
773
- blocks_num.y = Q->ne[2];
774
- blocks_num.z = Q->ne[3];
 
 
775
 
776
  if (parallel_blocks > 1) {
777
  dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
@@ -803,7 +836,7 @@ void launch_fattn(
803
  K_data,
804
  V_data,
805
  mask ? ((const char *) mask->data) : nullptr,
806
- (parallel_blocks) > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
807
  scale, max_bias, m0, m1, n_head_log2, logit_softcap,
808
  Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
809
  K->ne[0], K->ne[1], K->ne[2], K->ne[3],
@@ -815,7 +848,7 @@ void launch_fattn(
815
  );
816
  CUDA_CHECK(cudaGetLastError());
817
 
818
- if constexpr (parallel_blocks == 0) {
819
  if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
820
  const dim3 block_dim_combine(D, 1, 1);
821
  const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
@@ -824,13 +857,14 @@ void launch_fattn(
824
  <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
825
  ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
826
  }
827
- } else if constexpr (parallel_blocks > 1) {
828
  const dim3 block_dim_combine(D, 1, 1);
829
- const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
 
830
 
831
- flash_attn_combine_results<D, parallel_blocks>
832
- <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
833
- (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
834
  }
835
  CUDA_CHECK(cudaGetLastError());
836
  }
 
606
  *dst = dst_val / rowsum;
607
  }
608
 
609
+ template<int D> // D == head size
610
  #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
611
  __launch_bounds__(D, 1)
612
  #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
613
  static __global__ void flash_attn_combine_results(
614
  const float * __restrict__ VKQ_parts,
615
  const float2 * __restrict__ VKQ_meta,
616
+ float * __restrict__ dst,
617
+ const int parallel_blocks) {
618
+ VKQ_parts += parallel_blocks*D * gridDim.z*blockIdx.x;
619
+ VKQ_meta += parallel_blocks * gridDim.z*blockIdx.x;
620
+ dst += D * gridDim.z*blockIdx.x;
621
 
622
  const int tid = threadIdx.x;
623
  __builtin_assume(tid < D);
624
 
625
+ extern __shared__ float2 meta[];
626
  if (tid < 2*parallel_blocks) {
627
+ ((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + tid];
628
  }
629
 
630
  __syncthreads();
631
 
632
  float kqmax = meta[0].x;
 
633
  for (int l = 1; l < parallel_blocks; ++l) {
634
  kqmax = max(kqmax, meta[l].x);
635
  }
636
 
637
  float VKQ_numerator = 0.0f;
638
  float VKQ_denominator = 0.0f;
 
639
  for (int l = 0; l < parallel_blocks; ++l) {
640
  const float diff = meta[l].x - kqmax;
641
  const float KQ_max_scale = expf(diff);
642
  const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
643
  *((uint32_t *) &KQ_max_scale) &= ftz_mask;
644
 
645
+ VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.z*D + blockIdx.z*D + tid];
646
  VKQ_denominator += KQ_max_scale * meta[l].y;
647
  }
648
 
649
+ dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator;
650
  }
651
 
652
  static void on_no_fattn_vec_case(const int D) {
 
670
  }
671
  }
672
 
673
+ template <int D, int ncols1, int ncols2, int KQ_stride>
 
674
  void launch_fattn(
675
+ ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
676
+ const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
 
677
  ) {
678
  constexpr int ncols = ncols1 * ncols2;
679
 
 
745
  nb23 = nb23*bs*sizeof(half)/ts;
746
  }
747
 
748
+ int parallel_blocks = 1;
749
+
750
  const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
751
  const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
752
 
753
  const dim3 block_dim(warp_size, nwarps, 1);
754
  dim3 blocks_num;
755
+ if (stream_k) {
756
  // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
757
  const int max_blocks = 2*nsm;
758
  const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
 
768
 
769
  dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float));
770
  } else {
771
+ GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
772
+ const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
773
+
774
+ int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
775
+ CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
776
+
777
+ // parallel_blocks should be at least large enough to achieve max. occupancy for a single wave:
778
+ parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1);
779
+
780
+ // parallel_blocks must not be larger than what the tensor size allows:
781
+ parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
782
+
783
+ // If ntiles_total % blocks_per_wave != 0 then some efficiency is lost due to tail effects.
784
+ // Test whether parallel_blocks can be set to a higher value for better efficiency.
785
+ const int blocks_per_wave = nsm * max_blocks_per_sm;
786
+ int nwaves_best = 0;
787
+ int efficiency_percent_best = 0;
788
+ for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KQ; ++parallel_blocks_test) {
789
+ const int nblocks_total = ntiles_total * parallel_blocks_test;
790
+ const int nwaves = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave;
791
+ const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave);
792
+
793
+ // Stop trying configurations with more waves if we already have good efficiency to avoid excessive overhead.
794
+ if (efficiency_percent_best >= 90 && nwaves > nwaves_best) {
795
+ break;
796
+ }
797
+
798
+ if (efficiency_percent > efficiency_percent_best) {
799
+ nwaves_best = nwaves;
800
+ efficiency_percent_best = efficiency_percent;
801
+ parallel_blocks = parallel_blocks_test;
802
+ }
803
+ }
804
+
805
+ blocks_num.x = ntiles_x;
806
+ blocks_num.y = parallel_blocks;
807
+ blocks_num.z = Q->ne[2]*Q->ne[3];
808
 
809
  if (parallel_blocks > 1) {
810
  dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
 
836
  K_data,
837
  V_data,
838
  mask ? ((const char *) mask->data) : nullptr,
839
+ !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
840
  scale, max_bias, m0, m1, n_head_log2, logit_softcap,
841
  Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
842
  K->ne[0], K->ne[1], K->ne[2], K->ne[3],
 
848
  );
849
  CUDA_CHECK(cudaGetLastError());
850
 
851
+ if (stream_k) {
852
  if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
853
  const dim3 block_dim_combine(D, 1, 1);
854
  const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
 
857
  <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
858
  ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
859
  }
860
+ } else if (parallel_blocks > 1) {
861
  const dim3 block_dim_combine(D, 1, 1);
862
+ const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
863
+ const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
864
 
865
+ flash_attn_combine_results<D>
866
+ <<<blocks_num_combine, block_dim_combine, nbytes_shared_combine, main_stream>>>
867
+ (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks);
868
  }
869
  CUDA_CHECK(cudaGetLastError());
870
  }
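
Because parallel_blocks is now a runtime value rather than a template parameter, flash_attn_combine_results takes it as a kernel argument and switches from a statically sized __shared__ array to dynamically sized shared memory; the launch passes parallel_blocks*sizeof(float2) as the dynamic shared-memory size. A condensed version of the updated kernel and its launch is shown here for reference (the launch-bounds attribute and the flush-to-zero masking of the real code are omitted):

```cpp
template<int D> // D == head size
static __global__ void flash_attn_combine_results(
        const float  * __restrict__ VKQ_parts,
        const float2 * __restrict__ VKQ_meta,
        float        * __restrict__ dst,
        const int parallel_blocks) {
    VKQ_parts += parallel_blocks*D * gridDim.z*blockIdx.x;
    VKQ_meta  += parallel_blocks   * gridDim.z*blockIdx.x;
    dst       += D                 * gridDim.z*blockIdx.x;

    const int tid = threadIdx.x;

    extern __shared__ float2 meta[]; // sized at launch: parallel_blocks*sizeof(float2)
    if (tid < 2*parallel_blocks) {
        ((float *) meta)[tid] = ((const float *) VKQ_meta)[blockIdx.z*(2*parallel_blocks) + tid];
    }
    __syncthreads();

    float kqmax = meta[0].x;
    for (int l = 1; l < parallel_blocks; ++l) {
        kqmax = fmaxf(kqmax, meta[l].x);
    }

    float VKQ_numerator   = 0.0f;
    float VKQ_denominator = 0.0f;
    for (int l = 0; l < parallel_blocks; ++l) {
        const float KQ_max_scale = expf(meta[l].x - kqmax); // FTZ masking of the real code omitted
        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*gridDim.z*D + blockIdx.z*D + tid];
        VKQ_denominator += KQ_max_scale * meta[l].y;
    }
    dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator;
}

// Host side (as in launch_fattn): one thread per output element; the third
// launch parameter carries the dynamic shared-memory size.
//   const dim3   block_dim_combine(D, 1, 1);
//   const dim3   blocks_num_combine(Q->ne[1], 1, blocks_num.z);
//   const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
//   flash_attn_combine_results<D>
//       <<<blocks_num_combine, block_dim_combine, nbytes_shared_combine, main_stream>>>
//       (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks);
```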
ggml/src/ggml-cuda/fattn-mma-f16.cuh CHANGED
@@ -970,7 +970,8 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
970
  fattn_kernel = flash_attn_ext_f16<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap>;
971
  }
972
 
973
- launch_fattn<D, ncols1, ncols2, 0, KQ_per_iter>(ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, true, true);
 
974
  }
975
 
976
 
 
970
  fattn_kernel = flash_attn_ext_f16<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap>;
971
  }
972
 
973
+ launch_fattn<D, ncols1, ncols2, KQ_per_iter>
974
+ (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, FATTN_KQ_STRIDE, true, true, true);
975
  }
976
 
977
 
ggml/src/ggml-cuda/fattn-tile-f16.cu CHANGED
@@ -4,7 +4,7 @@
4
 
5
  #define FATTN_KQ_STRIDE_TILE_F16 64
6
 
7
- template<int D, int ncols, int nwarps, int parallel_blocks, bool use_logit_softcap> // D == head size
8
  #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
9
  __launch_bounds__(nwarps*WARP_SIZE, 1)
10
  #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
@@ -58,18 +58,17 @@ static __global__ void flash_attn_tile_ext_f16(
58
 
59
  //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
60
 
61
- const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
62
- const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
63
 
64
  const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
65
- const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0);
66
- const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio));
67
- const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
68
  const half * maskh = (const half *) mask + ne11*ic0;
69
 
70
  const int stride_KV2 = nb11 / sizeof(half2);
71
 
72
- const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
73
  const half slopeh = __float2half(slopef);
74
 
75
  static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@@ -105,8 +104,7 @@ static __global__ void flash_attn_tile_ext_f16(
105
 
106
  __syncthreads();
107
 
108
- const int k_start = parallel_blocks == 1 ? 0 : ip*FATTN_KQ_STRIDE_TILE_F16;
109
- for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE_TILE_F16) {
110
  // Calculate KQ tile and keep track of new maximum KQ values:
111
 
112
  half kqmax_new[ncols/nwarps];
@@ -271,16 +269,16 @@ static __global__ void flash_attn_tile_ext_f16(
271
  const int i0 = i00 + 2*threadIdx.x;
272
 
273
  half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
274
- if (parallel_blocks == 1) {
275
  dst_val /= __half2half2(kqsum_j);
276
  }
277
- const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
278
- dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = __low2float(dst_val);
279
- dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = __high2float(dst_val);
280
  }
281
 
282
- if (parallel_blocks != 1 && threadIdx.x == 0) {
283
- dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
284
  }
285
  }
286
  #else
@@ -288,7 +286,7 @@ static __global__ void flash_attn_tile_ext_f16(
288
  #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
289
  }
290
 
291
- template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
292
  void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
293
  const ggml_tensor * Q = dst->src[0];
294
  switch (Q->ne[0]) {
@@ -296,15 +294,17 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
296
  constexpr int D = 64;
297
  constexpr int nwarps = 8;
298
  constexpr size_t nbytes_shared = 0;
299
- fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
300
- launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
 
301
  } break;
302
  case 128: {
303
  constexpr int D = 128;
304
  constexpr int nwarps = 8;
305
  constexpr size_t nbytes_shared = 0;
306
- fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
307
- launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
 
308
  } break;
309
  default: {
310
  GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
@@ -324,37 +324,22 @@ void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_ten
324
 
325
  if (Q->ne[1] <= 16) {
326
  constexpr int cols_per_block = 16;
327
- constexpr int parallel_blocks = 4;
328
  if (logit_softcap == 0.0f) {
329
  constexpr bool use_logit_softcap = false;
330
- launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
331
  } else {
332
  constexpr bool use_logit_softcap = true;
333
- launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
334
- }
335
- return;
336
- }
337
-
338
- if (Q->ne[1] <= 32) {
339
- constexpr int cols_per_block = 32;
340
- constexpr int parallel_blocks = 4;
341
- if (logit_softcap == 0.0f) {
342
- constexpr bool use_logit_softcap = false;
343
- launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
344
- } else {
345
- constexpr bool use_logit_softcap = true;
346
- launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
347
  }
348
  return;
349
  }
350
 
351
  constexpr int cols_per_block = 32;
352
- constexpr int parallel_blocks = 1;
353
  if (logit_softcap == 0.0f) {
354
  constexpr bool use_logit_softcap = false;
355
- launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
356
  } else {
357
  constexpr bool use_logit_softcap = true;
358
- launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
359
  }
360
  }
 
4
 
5
  #define FATTN_KQ_STRIDE_TILE_F16 64
6
 
7
+ template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
8
  #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
9
  __launch_bounds__(nwarps*WARP_SIZE, 1)
10
  #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 
58
 
59
  //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
60
 
61
+ const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
 
62
 
63
  const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
64
+ const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
65
+ const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
66
+ const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
67
  const half * maskh = (const half *) mask + ne11*ic0;
68
 
69
  const int stride_KV2 = nb11 / sizeof(half2);
70
 
71
+ const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
72
  const half slopeh = __float2half(slopef);
73
 
74
  static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
 
104
 
105
  __syncthreads();
106
 
107
+ for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F16; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F16) {
 
108
  // Calculate KQ tile and keep track of new maximum KQ values:
109
 
110
  half kqmax_new[ncols/nwarps];
 
269
  const int i0 = i00 + 2*threadIdx.x;
270
 
271
  half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
272
+ if (gridDim.y == 1) {
273
  dst_val /= __half2half2(kqsum_j);
274
  }
275
+ const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
276
+ dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = __low2float(dst_val);
277
+ dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = __high2float(dst_val);
278
  }
279
 
280
+ if (gridDim.y != 1 && threadIdx.x == 0) {
281
+ dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
282
  }
283
  }
284
  #else
 
286
  #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
287
  }
288
 
289
+ template <int cols_per_block, bool use_logit_softcap>
290
  void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
291
  const ggml_tensor * Q = dst->src[0];
292
  switch (Q->ne[0]) {
 
294
  constexpr int D = 64;
295
  constexpr int nwarps = 8;
296
  constexpr size_t nbytes_shared = 0;
297
+ fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
298
+ launch_fattn<D, cols_per_block, 1, -1>
299
+ (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false);
300
  } break;
301
  case 128: {
302
  constexpr int D = 128;
303
  constexpr int nwarps = 8;
304
  constexpr size_t nbytes_shared = 0;
305
+ fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
306
+ launch_fattn<D, cols_per_block, 1, -1>
307
+ (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false);
308
  } break;
309
  default: {
310
  GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
 
324
 
325
  if (Q->ne[1] <= 16) {
326
  constexpr int cols_per_block = 16;
 
327
  if (logit_softcap == 0.0f) {
328
  constexpr bool use_logit_softcap = false;
329
+ launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
330
  } else {
331
  constexpr bool use_logit_softcap = true;
332
+ launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
 
 
333
  }
334
  return;
335
  }
336
 
337
  constexpr int cols_per_block = 32;
 
338
  if (logit_softcap == 0.0f) {
339
  constexpr bool use_logit_softcap = false;
340
+ launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
341
  } else {
342
  constexpr bool use_logit_softcap = true;
343
+ launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
344
  }
345
  }
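
With the parallel_blocks template parameter gone, the tile and vec kernels now map blockIdx.x to the Q columns, blockIdx.y to the KV chunk processed in parallel, and blockIdx.z to the combined (head, batch) index; the KV loop strides by gridDim.y tiles. The output indexing that recurs in these kernels can be summarized by two small device helpers (hypothetical names, written only to make the index arithmetic explicit; the real kernels inline these expressions):

```cpp
// ic0: first Q/QKV column handled by this block, j_VKQ: column offset within
// the block, D: head size, i: element index within the head.
__device__ __forceinline__ int dst_index(const int ic0, const int j_VKQ, const int D, const int i) {
    const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y; // partial results interleaved per column
    return j_dst*D*gridDim.z + D*blockIdx.z + i;            // then the (head, batch) slice, then the element
}

// Per-column (kqmax, kqsum) metadata, written only when gridDim.y != 1,
// i.e. when flash_attn_combine_results will run afterwards.
__device__ __forceinline__ int dst_meta_index(const int ic0, const int j_VKQ) {
    return ((ic0 + j_VKQ)*gridDim.z + blockIdx.z)*gridDim.y + blockIdx.y;
}
```

When gridDim.y == 1 the kernels divide by kqsum in place and skip the metadata write, so no combine pass is needed.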
ggml/src/ggml-cuda/fattn-tile-f32.cu CHANGED
@@ -4,7 +4,7 @@
4
 
5
  #define FATTN_KQ_STRIDE_TILE_F32 32
6
 
7
- template<int D, int ncols, int nwarps, int parallel_blocks, bool use_logit_softcap> // D == head size
8
  #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
9
  __launch_bounds__(nwarps*WARP_SIZE, 1)
10
  #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
@@ -58,18 +58,17 @@ static __global__ void flash_attn_tile_ext_f32(
58
 
59
  // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
60
 
61
- const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
62
- const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
63
 
64
  const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
65
- const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0);
66
- const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio));
67
- const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
68
  const half * maskh = (const half *) mask + ne11*ic0;
69
 
70
  const int stride_KV2 = nb11 / sizeof(half2);
71
 
72
- const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
73
 
74
  static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
75
 
@@ -103,8 +102,7 @@ static __global__ void flash_attn_tile_ext_f32(
103
 
104
  __syncthreads();
105
 
106
- const int k_start = parallel_blocks == 1 ? 0 : ip*FATTN_KQ_STRIDE_TILE_F32;
107
- for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE_TILE_F32) {
108
  // Calculate KQ tile and keep track of new maximum KQ values:
109
 
110
  float kqmax_new[ncols/nwarps];
@@ -269,17 +267,17 @@ static __global__ void flash_attn_tile_ext_f32(
269
  const int i0 = i00 + 2*threadIdx.x;
270
 
271
  float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
272
- if (parallel_blocks == 1) {
273
  dst_val.x /= kqsum_j;
274
  dst_val.y /= kqsum_j;
275
  }
276
- const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
277
- dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = dst_val.x;
278
- dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = dst_val.y;
279
  }
280
 
281
- if (parallel_blocks != 1 && threadIdx.x == 0) {
282
- dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
283
  }
284
  }
285
  #else
@@ -287,7 +285,7 @@ static __global__ void flash_attn_tile_ext_f32(
287
  #endif // FLASH_ATTN_AVAILABLE
288
  }
289
 
290
- template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
291
  void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
292
  const ggml_tensor * Q = dst->src[0];
293
  switch (Q->ne[0]) {
@@ -295,15 +293,17 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
295
  constexpr int D = 64;
296
  constexpr int nwarps = 8;
297
  constexpr size_t nbytes_shared = 0;
298
- fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
299
- launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
 
300
  } break;
301
  case 128: {
302
  constexpr int D = 128;
303
  constexpr int nwarps = 8;
304
  constexpr size_t nbytes_shared = 0;
305
- fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
306
- launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
 
307
  } break;
308
  default: {
309
  GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
@@ -320,37 +320,22 @@ void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_ten
320
 
321
  if (Q->ne[1] <= 16) {
322
  constexpr int cols_per_block = 16;
323
- constexpr int parallel_blocks = 4;
324
  if (logit_softcap == 0.0f) {
325
  constexpr bool use_logit_softcap = false;
326
- launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
327
  } else {
328
  constexpr bool use_logit_softcap = true;
329
- launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
330
- }
331
- return;
332
- }
333
-
334
- if (Q->ne[1] <= 32) {
335
- constexpr int cols_per_block = 32;
336
- constexpr int parallel_blocks = 4;
337
- if (logit_softcap == 0.0f) {
338
- constexpr bool use_logit_softcap = false;
339
- launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
340
- } else {
341
- constexpr bool use_logit_softcap = true;
342
- launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
343
  }
344
  return;
345
  }
346
 
347
  constexpr int cols_per_block = 32;
348
- constexpr int parallel_blocks = 1;
349
  if (logit_softcap == 0.0f) {
350
  constexpr bool use_logit_softcap = false;
351
- launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
352
  } else {
353
  constexpr bool use_logit_softcap = true;
354
- launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
355
  }
356
  }
 
4
 
5
  #define FATTN_KQ_STRIDE_TILE_F32 32
6
 
7
+ template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
8
  #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
9
  __launch_bounds__(nwarps*WARP_SIZE, 1)
10
  #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 
58
 
59
  // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
60
 
61
+ const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
 
62
 
63
  const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
64
+ const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
65
+ const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
66
+ const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
67
  const half * maskh = (const half *) mask + ne11*ic0;
68
 
69
  const int stride_KV2 = nb11 / sizeof(half2);
70
 
71
+ const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
72
 
73
  static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
74
 
 
102
 
103
  __syncthreads();
104
 
105
+ for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F32; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F32) {
 
106
  // Calculate KQ tile and keep track of new maximum KQ values:
107
 
108
  float kqmax_new[ncols/nwarps];
 
267
  const int i0 = i00 + 2*threadIdx.x;
268
 
269
  float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
270
+ if (gridDim.y == 1) {
271
  dst_val.x /= kqsum_j;
272
  dst_val.y /= kqsum_j;
273
  }
274
+ const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
275
+ dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = dst_val.x;
276
+ dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = dst_val.y;
277
  }
278
 
279
+ if (gridDim.y != 1 && threadIdx.x == 0) {
280
+ dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
281
  }
282
  }
283
  #else
 
285
  #endif // FLASH_ATTN_AVAILABLE
286
  }
287
 
288
+ template <int cols_per_block, bool use_logit_softcap>
289
  void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
290
  const ggml_tensor * Q = dst->src[0];
291
  switch (Q->ne[0]) {
 
293
  constexpr int D = 64;
294
  constexpr int nwarps = 8;
295
  constexpr size_t nbytes_shared = 0;
296
+ fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, use_logit_softcap>;
297
+ launch_fattn<D, cols_per_block, 1, -1>
298
+ (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false);
299
  } break;
300
  case 128: {
301
  constexpr int D = 128;
302
  constexpr int nwarps = 8;
303
  constexpr size_t nbytes_shared = 0;
304
+ fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, use_logit_softcap>;
305
+ launch_fattn<D, cols_per_block, 1, -1>
306
+ (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false);
307
  } break;
308
  default: {
309
  GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
 
320
 
321
  if (Q->ne[1] <= 16) {
322
  constexpr int cols_per_block = 16;
 
323
  if (logit_softcap == 0.0f) {
324
  constexpr bool use_logit_softcap = false;
325
+ launch_fattn_tile_f32_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
326
  } else {
327
  constexpr bool use_logit_softcap = true;
328
+ launch_fattn_tile_f32_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
 
 
329
  }
330
  return;
331
  }
332
 
333
  constexpr int cols_per_block = 32;
 
334
  if (logit_softcap == 0.0f) {
335
  constexpr bool use_logit_softcap = false;
336
+ launch_fattn_tile_f32_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
337
  } else {
338
  constexpr bool use_logit_softcap = true;
339
+ launch_fattn_tile_f32_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
340
  }
341
  }
ggml/src/ggml-cuda/fattn-vec-f16.cuh CHANGED
@@ -1,7 +1,7 @@
1
  #include "common.cuh"
2
  #include "fattn-common.cuh"
3
 
4
- template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
5
  #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
6
  __launch_bounds__(D, 1)
7
  #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
@@ -55,17 +55,16 @@ static __global__ void flash_attn_vec_ext_f16(
55
  constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
56
  constexpr dequantize_1_f16_t dequantize_1_v = get_dequantize_1_f16(type_V);
57
 
58
- const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
59
- const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
60
 
61
  const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
62
- Q += nb02* blockIdx.y + nb01*ic0;
63
- K += nb12*(blockIdx.y / gqa_ratio);
64
- V += nb22*(blockIdx.y / gqa_ratio);
65
 
66
  const half * maskh = (const half *) mask + ne11*ic0;
67
 
68
- const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
69
  const half slopeh = __float2half(slopef);
70
 
71
  static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@@ -172,8 +171,7 @@ static __global__ void flash_attn_vec_ext_f16(
172
 
173
  half2 VKQ[ncols] = {{0.0f, 0.0f}};
174
 
175
- const int k_start = parallel_blocks == 1 ? 0 : ip*D;
176
- for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
177
  // Calculate KQ tile and keep track of new maximum KQ values:
178
 
179
  // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
@@ -283,29 +281,29 @@ static __global__ void flash_attn_vec_ext_f16(
283
  kqsum[j_VKQ] = warp_reduce_sum((float)kqsum[j_VKQ]);
284
 
285
  half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ]));
286
- if (parallel_blocks == 1) {
287
  dst_val /= kqsum[j_VKQ];
288
  }
289
- const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
290
- dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
291
  }
292
 
293
- if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
294
- dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
295
  }
296
  #else
297
  NO_DEVICE_CODE;
298
  #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
299
  }
300
 
301
- template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
302
  void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
303
  constexpr int nwarps = D/WARP_SIZE;
304
- fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>;
305
  constexpr bool need_f16_K = D != 128;
306
  constexpr bool need_f16_V = D != 128 && D != 64;
307
  constexpr size_t nbytes_shared = 0;
308
- launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
309
  }
310
 
311
  template <int D, ggml_type type_K, ggml_type type_V>
@@ -325,65 +323,48 @@ void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml
325
  memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
326
 
327
  if (Q->ne[1] == 1) {
328
- constexpr int cols_per_block = 1;
329
- constexpr int parallel_blocks = 4;
330
  if (logit_softcap == 0.0f) {
331
  constexpr bool use_logit_softcap = false;
332
- ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
333
  } else {
334
  constexpr bool use_logit_softcap = true;
335
- ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
336
  }
337
  return;
338
  }
339
 
340
  if (Q->ne[1] == 2) {
341
- constexpr int cols_per_block = 2;
342
- constexpr int parallel_blocks = 4;
343
  if (logit_softcap == 0.0f) {
344
  constexpr bool use_logit_softcap = false;
345
- ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
346
  } else {
347
  constexpr bool use_logit_softcap = true;
348
- ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
349
  }
350
  return;
351
  }
352
 
353
  if (Q->ne[1] <= 4) {
354
- constexpr int cols_per_block = 4;
355
- constexpr int parallel_blocks = 4;
356
  if (logit_softcap == 0.0f) {
357
  constexpr bool use_logit_softcap = false;
358
- ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
359
  } else {
360
  constexpr bool use_logit_softcap = true;
361
- ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
362
  }
363
  return;
364
  }
365
 
366
- if (Q->ne[1] <= 8) {
367
- constexpr int cols_per_block = 8;
368
- constexpr int parallel_blocks = 4;
369
- if (logit_softcap == 0.0f) {
370
- constexpr bool use_logit_softcap = false;
371
- ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
372
- } else {
373
- constexpr bool use_logit_softcap = true;
374
- ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
375
- }
376
- return;
377
- }
378
-
379
- constexpr int cols_per_block = 8;
380
- constexpr int parallel_blocks = 1;
381
  if (logit_softcap == 0.0f) {
382
  constexpr bool use_logit_softcap = false;
383
- ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
384
  } else {
385
  constexpr bool use_logit_softcap = true;
386
- ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
387
  }
388
  }
389
 
 
1
  #include "common.cuh"
2
  #include "fattn-common.cuh"
3
 
4
+ template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
5
  #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
6
  __launch_bounds__(D, 1)
7
  #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 
55
  constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
56
  constexpr dequantize_1_f16_t dequantize_1_v = get_dequantize_1_f16(type_V);
57
 
58
+ const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
 
59
 
60
  const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
61
+ Q += nb02* blockIdx.z + nb01*ic0;
62
+ K += nb12*(blockIdx.z / gqa_ratio);
63
+ V += nb22*(blockIdx.z / gqa_ratio);
64
 
65
  const half * maskh = (const half *) mask + ne11*ic0;
66
 
67
+ const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
68
  const half slopeh = __float2half(slopef);
69
 
70
  static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
 
171
 
172
  half2 VKQ[ncols] = {{0.0f, 0.0f}};
173
 
174
+ for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
 
175
  // Calculate KQ tile and keep track of new maximum KQ values:
176
 
177
  // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
 
281
  kqsum[j_VKQ] = warp_reduce_sum((float)kqsum[j_VKQ]);
282
 
283
  half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ]));
284
+ if (gridDim.y == 1) {
285
  dst_val /= kqsum[j_VKQ];
286
  }
287
+ const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
288
+ dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
289
  }
290
 
291
+ if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
292
+ dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
293
  }
294
  #else
295
  NO_DEVICE_CODE;
296
  #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
297
  }
298
 
299
+ template <int D, int cols_per_block, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
300
  void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
301
  constexpr int nwarps = D/WARP_SIZE;
302
+ fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, type_K, type_V, use_logit_softcap>;
303
  constexpr bool need_f16_K = D != 128;
304
  constexpr bool need_f16_V = D != 128 && D != 64;
305
  constexpr size_t nbytes_shared = 0;
306
+ launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
307
  }
308
 
309
  template <int D, ggml_type type_K, ggml_type type_V>
 
323
  memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
324
 
325
  if (Q->ne[1] == 1) {
326
+ constexpr int cols_per_block = 1;
 
327
  if (logit_softcap == 0.0f) {
328
  constexpr bool use_logit_softcap = false;
329
+ ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
330
  } else {
331
  constexpr bool use_logit_softcap = true;
332
+ ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
333
  }
334
  return;
335
  }
336
 
337
  if (Q->ne[1] == 2) {
338
+ constexpr int cols_per_block = 2;
 
339
  if (logit_softcap == 0.0f) {
340
  constexpr bool use_logit_softcap = false;
341
+ ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
342
  } else {
343
  constexpr bool use_logit_softcap = true;
344
+ ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
345
  }
346
  return;
347
  }
348
 
349
  if (Q->ne[1] <= 4) {
350
+ constexpr int cols_per_block = 4;
 
351
  if (logit_softcap == 0.0f) {
352
  constexpr bool use_logit_softcap = false;
353
+ ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
354
  } else {
355
  constexpr bool use_logit_softcap = true;
356
+ ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
357
  }
358
  return;
359
  }
360
 
361
+ constexpr int cols_per_block = 8;
 
 
362
  if (logit_softcap == 0.0f) {
363
  constexpr bool use_logit_softcap = false;
364
+ ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
365
  } else {
366
  constexpr bool use_logit_softcap = true;
367
+ ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
368
  }
369
  }
370
 
ggml/src/ggml-cuda/fattn-vec-f32.cuh CHANGED
@@ -1,7 +1,7 @@
1
  #include "common.cuh"
2
  #include "fattn-common.cuh"
3
 
4
- template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
5
  #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
6
  __launch_bounds__(D, 1)
7
  #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
@@ -55,16 +55,15 @@ static __global__ void flash_attn_vec_ext_f32(
55
  constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
56
  constexpr dequantize_1_f32_t dequantize_1_v = get_dequantize_1_f32(type_V);
57
 
58
- const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
59
- const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
60
 
61
  const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
62
- Q += nb02* blockIdx.y + nb01*ic0;
63
- K += nb12*(blockIdx.y / gqa_ratio);
64
- V += nb22*(blockIdx.y / gqa_ratio); // K and V have same shape
65
  const half * maskh = (const half *) mask + ne11*ic0;
66
 
67
- const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
68
 
69
  static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
70
  constexpr int nwarps = D / WARP_SIZE;
@@ -167,8 +166,7 @@ static __global__ void flash_attn_vec_ext_f32(
167
 
168
  float VKQ[ncols] = {0.0f};
169
 
170
- const int k_start = parallel_blocks == 1 ? 0 : ip*D;
171
- for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
172
  // Calculate KQ tile and keep track of new maximum KQ values:
173
 
174
  float kqmax_new_arr[ncols];
@@ -268,29 +266,29 @@ static __global__ void flash_attn_vec_ext_f32(
268
  kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
269
 
270
  float dst_val = VKQ[j_VKQ];
271
- if (parallel_blocks == 1) {
272
  dst_val /= kqsum[j_VKQ];
273
  }
274
- const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
275
- dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
276
  }
277
 
278
- if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
279
- dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
280
  }
281
  #else
282
  NO_DEVICE_CODE;
283
  #endif // FLASH_ATTN_AVAILABLE
284
  }
285
 
286
- template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
287
  void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
288
  constexpr int nwarps = D/WARP_SIZE;
289
- fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>;
290
  constexpr bool need_f16_K = D != 128;
291
  constexpr bool need_f16_V = D != 128 && D != 64;
292
  constexpr size_t nbytes_shared = 0;
293
- launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
294
  }
295
 
296
  template <int D, ggml_type type_K, ggml_type type_V>
@@ -307,65 +305,48 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
307
  memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
308
 
309
  if (Q->ne[1] == 1) {
310
- constexpr int cols_per_block = 1;
311
- constexpr int parallel_blocks = 4;
312
  if (logit_softcap == 0.0f) {
313
  constexpr bool use_logit_softcap = false;
314
- ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
315
  } else {
316
  constexpr bool use_logit_softcap = true;
317
- ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
318
  }
319
  return;
320
  }
321
 
322
  if (Q->ne[1] == 2) {
323
- constexpr int cols_per_block = 2;
324
- constexpr int parallel_blocks = 4;
325
  if (logit_softcap == 0.0f) {
326
  constexpr bool use_logit_softcap = false;
327
- ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
328
  } else {
329
  constexpr bool use_logit_softcap = true;
330
- ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
331
  }
332
  return;
333
  }
334
 
335
  if (Q->ne[1] <= 4) {
336
- constexpr int cols_per_block = 4;
337
- constexpr int parallel_blocks = 4;
338
  if (logit_softcap == 0.0f) {
339
  constexpr bool use_logit_softcap = false;
340
- ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
341
  } else {
342
  constexpr bool use_logit_softcap = true;
343
- ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
344
  }
345
  return;
346
  }
347
 
348
- if (Q->ne[1] <= 8) {
349
- constexpr int cols_per_block = 8;
350
- constexpr int parallel_blocks = 4;
351
- if (logit_softcap == 0.0f) {
352
- constexpr bool use_logit_softcap = false;
353
- ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
354
- } else {
355
- constexpr bool use_logit_softcap = true;
356
- ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
357
- }
358
- return;
359
- }
360
-
361
- constexpr int cols_per_block = 8;
362
- constexpr int parallel_blocks = 1;
363
  if (logit_softcap == 0.0f) {
364
  constexpr bool use_logit_softcap = false;
365
- ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
366
  } else {
367
  constexpr bool use_logit_softcap = true;
368
- ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
369
  }
370
  }
371
 
 
1
  #include "common.cuh"
2
  #include "fattn-common.cuh"
3
 
4
+ template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
5
  #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
6
  __launch_bounds__(D, 1)
7
  #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 
55
  constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
56
  constexpr dequantize_1_f32_t dequantize_1_v = get_dequantize_1_f32(type_V);
57
 
58
+ const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
 
59
 
60
  const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
61
+ Q += nb02* blockIdx.z + nb01*ic0;
62
+ K += nb12*(blockIdx.z / gqa_ratio);
63
+ V += nb22*(blockIdx.z / gqa_ratio); // K and V have same shape
64
  const half * maskh = (const half *) mask + ne11*ic0;
65
 
66
+ const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
67
 
68
  static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
69
  constexpr int nwarps = D / WARP_SIZE;
 
166
 
167
  float VKQ[ncols] = {0.0f};
168
 
169
+ for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
 
170
  // Calculate KQ tile and keep track of new maximum KQ values:
171
 
172
  float kqmax_new_arr[ncols];
 
266
  kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
267
 
268
  float dst_val = VKQ[j_VKQ];
269
+ if (gridDim.y == 1) {
270
  dst_val /= kqsum[j_VKQ];
271
  }
272
+ const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
273
+ dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
274
  }
275
 
276
+ if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
277
+ dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
278
  }
279
  #else
280
  NO_DEVICE_CODE;
281
  #endif // FLASH_ATTN_AVAILABLE
282
  }
283
 
284
+ template <int D, int cols_per_block, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
285
  void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
286
  constexpr int nwarps = D/WARP_SIZE;
287
+ fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, type_K, type_V, use_logit_softcap>;
288
  constexpr bool need_f16_K = D != 128;
289
  constexpr bool need_f16_V = D != 128 && D != 64;
290
  constexpr size_t nbytes_shared = 0;
291
+ launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
292
  }
293
 
294
  template <int D, ggml_type type_K, ggml_type type_V>
 
305
  memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
306
 
307
  if (Q->ne[1] == 1) {
308
+ constexpr int cols_per_block = 1;
 
309
  if (logit_softcap == 0.0f) {
310
  constexpr bool use_logit_softcap = false;
311
+ ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
312
  } else {
313
  constexpr bool use_logit_softcap = true;
314
+ ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
315
  }
316
  return;
317
  }
318
 
319
  if (Q->ne[1] == 2) {
320
+ constexpr int cols_per_block = 2;
 
321
  if (logit_softcap == 0.0f) {
322
  constexpr bool use_logit_softcap = false;
323
+ ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
324
  } else {
325
  constexpr bool use_logit_softcap = true;
326
+ ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
327
  }
328
  return;
329
  }
330
 
331
  if (Q->ne[1] <= 4) {
332
+ constexpr int cols_per_block = 4;
 
333
  if (logit_softcap == 0.0f) {
334
  constexpr bool use_logit_softcap = false;
335
+ ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
336
  } else {
337
  constexpr bool use_logit_softcap = true;
338
+ ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
339
  }
340
  return;
341
  }
342
 
343
+ constexpr int cols_per_block = 8;
 
 
 
344
  if (logit_softcap == 0.0f) {
345
  constexpr bool use_logit_softcap = false;
346
+ ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
347
  } else {
348
  constexpr bool use_logit_softcap = true;
349
+ ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
350
  }
351
  }
352
 
ggml/src/ggml-cuda/fattn-wmma-f16.cu CHANGED
@@ -18,7 +18,7 @@ namespace wmma = rocwmma;
18
  #endif // FP16_MMA_AVAILABLE
19
 
20
  // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
21
- template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_logit_softcap>
22
  __launch_bounds__(nwarps*ggml_cuda_get_physical_warp_size(), 1)
23
  static __global__ void flash_attn_ext_f16(
24
  const char * __restrict__ Q,
@@ -67,8 +67,7 @@ static __global__ void flash_attn_ext_f16(
67
 
68
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
69
 
70
- const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
71
- const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
72
 
73
  static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
74
  static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
@@ -91,16 +90,16 @@ static __global__ void flash_attn_ext_f16(
91
  constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
92
 
93
  const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
94
- const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0);
95
- const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio));
96
- const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
97
  const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0;
98
  const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2);
99
 
100
  const int stride_Q = nb01 / sizeof(float);
101
  const int stride_KV = nb11 / sizeof(half);
102
 
103
- const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
104
  const half slopeh = __float2half(slopef);
105
  const half2 slope2 = make_half2(slopef, slopef);
106
 
@@ -176,7 +175,7 @@ static __global__ void flash_attn_ext_f16(
176
  __syncthreads();
177
 
178
  // Iterate over ne11 == previous tokens:
179
- for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) {
180
  // Calculate tile of KQ:
181
  #pragma unroll
182
  for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
@@ -395,7 +394,7 @@ static __global__ void flash_attn_ext_f16(
395
  if (ic0 + j_VKQ >= ne01) {
396
  return;
397
  }
398
- const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
399
 
400
  float KQ_rowsum_j;
401
  if (std::is_same<KQ_acc_t, float>::value) {
@@ -411,13 +410,13 @@ static __global__ void flash_attn_ext_f16(
411
  break;
412
  }
413
  float dst_val = VKQ[j_VKQ*D_padded + i];
414
- if (parallel_blocks == 1) {
415
  dst_val /= KQ_rowsum_j;
416
  }
417
- dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val;
418
  }
419
 
420
- if (parallel_blocks == 1 || threadIdx.x != 0) {
421
  continue;
422
  }
423
 
@@ -428,7 +427,7 @@ static __global__ void flash_attn_ext_f16(
428
  dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
429
  }
430
  dst_meta_val.y = KQ_rowsum_j;
431
- dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val;
432
  }
433
  #else
434
  NO_DEVICE_CODE;
@@ -462,60 +461,26 @@ static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");
462
  template <int D, int cols_per_block, typename KQ_acc_t>
463
  void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
464
  const ggml_tensor * KQV = dst;
465
- const ggml_tensor * Q = dst->src[0];
466
 
467
  constexpr int nwarps = 4;
468
 
469
  constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
470
- const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
471
- const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
472
  const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
473
 
474
  float logit_softcap;
475
  memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
476
 
477
- if (4*blocks_num_pb1 < 2*nsm) {
478
- constexpr int parallel_blocks = 4;
479
- fattn_kernel_t fattn_kernel;
480
- if (logit_softcap == 0.0f) {
481
- constexpr bool use_logit_softcap = false;
482
- fattn_kernel = flash_attn_ext_f16<
483
- D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
484
- } else {
485
- constexpr bool use_logit_softcap = true;
486
- fattn_kernel = flash_attn_ext_f16<
487
- D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
488
- }
489
- launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true, warp_size);
490
- return;
491
- }
492
- if (2*blocks_num_pb1 < 2*nsm) {
493
- constexpr int parallel_blocks = 2;
494
- fattn_kernel_t fattn_kernel;
495
- if (logit_softcap == 0.0f) {
496
- constexpr bool use_logit_softcap = false;
497
- fattn_kernel = flash_attn_ext_f16<
498
- D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
499
- } else {
500
- constexpr bool use_logit_softcap = true;
501
- fattn_kernel = flash_attn_ext_f16<
502
- D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
503
- }
504
- launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true, warp_size);
505
- return;
506
- }
507
- constexpr int parallel_blocks = 1;
508
  fattn_kernel_t fattn_kernel;
509
  if (logit_softcap == 0.0f) {
510
  constexpr bool use_logit_softcap = false;
511
  fattn_kernel = flash_attn_ext_f16<
512
- D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
513
  } else {
514
  constexpr bool use_logit_softcap = true;
515
  fattn_kernel = flash_attn_ext_f16<
516
- D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
517
  }
518
- launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true, warp_size);
519
  }
520
 
521
  void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
18
  #endif // FP16_MMA_AVAILABLE
19
 
20
  // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
21
+ template<int D, int ncols, int nwarps, int VKQ_stride, typename KQ_acc_t, bool use_logit_softcap>
22
  __launch_bounds__(nwarps*ggml_cuda_get_physical_warp_size(), 1)
23
  static __global__ void flash_attn_ext_f16(
24
  const char * __restrict__ Q,
 
67
 
68
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
69
 
70
+ const int ic0 = ncols*blockIdx.x; // Index of the first Q/QKV column to work on.
 
71
 
72
  static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
73
  static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
 
90
  constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
91
 
92
  const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
93
+ const float * Q_f = (const float *) (Q + nb02* blockIdx.z + nb01*ic0);
94
+ const half * K_h = (const half *) (K + nb12*(blockIdx.z / gqa_ratio));
95
+ const half * V_h = (const half *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
96
  const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0;
97
  const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2);
98
 
99
  const int stride_Q = nb01 / sizeof(float);
100
  const int stride_KV = nb11 / sizeof(half);
101
 
102
+ const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
103
  const half slopeh = __float2half(slopef);
104
  const half2 slope2 = make_half2(slopef, slopef);
105
 
 
175
  __syncthreads();
176
 
177
  // Iterate over ne11 == previous tokens:
178
+ for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE) {
179
  // Calculate tile of KQ:
180
  #pragma unroll
181
  for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
 
394
  if (ic0 + j_VKQ >= ne01) {
395
  return;
396
  }
397
+ const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
398
 
399
  float KQ_rowsum_j;
400
  if (std::is_same<KQ_acc_t, float>::value) {
 
410
  break;
411
  }
412
  float dst_val = VKQ[j_VKQ*D_padded + i];
413
+ if (gridDim.y == 1) {
414
  dst_val /= KQ_rowsum_j;
415
  }
416
+ dst[j_dst*gridDim.z*D + blockIdx.z*D + i] = dst_val;
417
  }
418
 
419
+ if (gridDim.y == 1 || threadIdx.x != 0) {
420
  continue;
421
  }
422
 
 
427
  dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
428
  }
429
  dst_meta_val.y = KQ_rowsum_j;
430
+ dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = dst_meta_val;
431
  }
432
  #else
433
  NO_DEVICE_CODE;
 
461
  template <int D, int cols_per_block, typename KQ_acc_t>
462
  void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
463
  const ggml_tensor * KQV = dst;
 
464
 
465
  constexpr int nwarps = 4;
466
 
467
  constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
 
 
468
  const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
469
 
470
  float logit_softcap;
471
  memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
472
 
 
 
473
  fattn_kernel_t fattn_kernel;
474
  if (logit_softcap == 0.0f) {
475
  constexpr bool use_logit_softcap = false;
476
  fattn_kernel = flash_attn_ext_f16<
477
+ D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), KQ_acc_t, use_logit_softcap>;
478
  } else {
479
  constexpr bool use_logit_softcap = true;
480
  fattn_kernel = flash_attn_ext_f16<
481
+ D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), KQ_acc_t, use_logit_softcap>;
482
  }
483
+ launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, 0, FATTN_KQ_STRIDE, true, true, false, warp_size);
484
  }
485
 
486
  void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
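
Editor's note (sketch, not part of the commit): with the parallel split now taken from `gridDim.y` at run time rather than from a `parallel_blocks` template parameter, the WMMA kernel's output writes shown in the `+` lines above flatten as follows. The helper names and parameter names below are hypothetical; `col` stands for `ic0 + j_VKQ`, `kv_block` for `blockIdx.y`, `head` for `blockIdx.z`, `n_kv_blocks` for `gridDim.y`, and `n_heads` for `gridDim.z`.

```cpp
// Hypothetical helpers that restate the index expressions from the + lines above.
// dst      layout: [ne01*gridDim.y, gridDim.z, D]  -> dst[j_dst*gridDim.z*D + blockIdx.z*D + i]
// dst_meta layout: [ne01, gridDim.z, gridDim.y]    -> dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z)*gridDim.y + blockIdx.y]
__host__ __device__ static inline size_t dst_idx(
        int col, int kv_block, int head, int n_kv_blocks, int n_heads, int D, int i) {
    const int j_dst = col*n_kv_blocks + kv_block; // (ic0 + j_VKQ)*gridDim.y + blockIdx.y
    return (size_t) j_dst*n_heads*D + (size_t) head*D + i;
}

__host__ __device__ static inline size_t dst_meta_idx(
        int col, int kv_block, int head, int n_kv_blocks, int n_heads) {
    return ((size_t) col*n_heads + head)*n_kv_blocks + kv_block;
}
```

When `gridDim.y == 1` the kernel divides by the row sum in place and skips the meta write entirely, as the unchanged `continue` path above shows; only the multi-block case needs the partial results and the (max, rowsum) metadata.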
ggml/src/ggml-cuda/fattn.cu CHANGED
@@ -281,13 +281,13 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
281
 
282
  if (!fp16_mma_available(cc)) {
283
  if (prec == GGML_PREC_DEFAULT) {
284
- if (Q->ne[1] <= 8) {
285
  ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
286
  } else {
287
  ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
288
  }
289
  } else {
290
- if (Q->ne[1] <= 8) {
291
  ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
292
  } else {
293
  ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
@@ -296,17 +296,17 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
296
  return;
297
  }
298
 
299
- const int gqa_ratio = Q->ne[2] / K->ne[2];
300
- const bool mma_fast_for_bs1 = fp16_mma_available(cc) && gqa_ratio % 2 == 0 &&
301
- K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && mask;
302
- if (Q->ne[1] == 1 && Q->ne[0] % (2*warp_size) == 0 && !mma_fast_for_bs1) {
 
303
  if (prec == GGML_PREC_DEFAULT) {
304
  ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
305
- return;
306
- } else if(Q->ne[0] <= 128) {
307
  ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
308
- return;
309
  }
 
310
  }
311
 
312
  // The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
 
281
 
282
  if (!fp16_mma_available(cc)) {
283
  if (prec == GGML_PREC_DEFAULT) {
284
+ if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
285
  ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
286
  } else {
287
  ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
288
  }
289
  } else {
290
+ if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
291
  ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
292
  } else {
293
  ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
 
296
  return;
297
  }
298
 
299
+ const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
300
+ const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
301
+ const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < GGML_CUDA_CC_ADA_LOVELACE && !mma_needs_data_conversion;
302
+ const bool can_use_vector_kernel = (Q->ne[0] % (2*warp_size) == 0) && (prec == GGML_PREC_DEFAULT || Q->ne[0] <= 128);
303
+ if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
304
  if (prec == GGML_PREC_DEFAULT) {
305
  ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
306
+ } else {
 
307
  ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
 
308
  }
309
+ return;
310
  }
311
 
312
  // The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
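
Editor's note (sketch, not part of the commit): the new batch-size-1 dispatch condition in `ggml_cuda_flash_attn_ext` can be read as a single predicate. The sketch below only regroups the `+` lines above into a hypothetical standalone helper; `cc`, `warp_size`, and `prec` stand for the same values the function already has in scope.

```cpp
// Hypothetical regrouping of the batch-size-1 heuristic from the diff above.
static bool use_vector_kernel_for_bs1(
        const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V,
        const ggml_tensor * mask, int cc, int warp_size, ggml_prec prec) {
    // The mma-based kernels have GQA-specific optimizations; they only apply for even GQA ratios with a mask.
    const bool gqa_opt_applies           = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask;
    const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
    const bool mma_faster_for_bs1        = new_mma_available(cc) && gqa_opt_applies &&
                                           cc < GGML_CUDA_CC_ADA_LOVELACE && !mma_needs_data_conversion;
    const bool can_use_vector_kernel     = (Q->ne[0] % (2*warp_size) == 0) &&
                                           (prec == GGML_PREC_DEFAULT || Q->ne[0] <= 128);
    return Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1;
}
```

When this predicate holds, the f16 or f32 vector kernel is launched directly and the function returns; otherwise dispatch continues to the tile/WMMA/MMA paths below.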
ggml/src/ggml-cuda/ggml-cuda.cu CHANGED
@@ -3230,6 +3230,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3230
  #ifndef FLASH_ATTN_AVAILABLE
3231
  return false;
3232
  #endif // FLASH_ATTN_AVAILABLE
 
 
 
3233
  if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
3234
  return false;
3235
  }
 
3230
  #ifndef FLASH_ATTN_AVAILABLE
3231
  return false;
3232
  #endif // FLASH_ATTN_AVAILABLE
3233
+ if (op->src[0]->ne[3] != 1) {
3234
+ return false;
3235
+ }
3236
  if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
3237
  return false;
3238
  }
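
Editor's note (sketch, not part of the commit): the added guard means the CUDA backend now reports `GGML_OP_FLASH_ATTN_EXT` as unsupported whenever Q (`op->src[0]`) has more than one slice in dimension 3, so the scheduler will not offload those cases to this backend. A minimal restatement with a hypothetical helper name:

```cpp
// Mirrors the added check: if (op->src[0]->ne[3] != 1) return false;
// For GGML_OP_FLASH_ATTN_EXT, op->src[0] is Q.
static bool fattn_q_dim3_supported(const ggml_tensor * op) {
    return op->src[0]->ne[3] == 1;
}
```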
ggml/src/ggml-cuda/vendors/hip.h CHANGED
@@ -129,6 +129,7 @@
129
  #define cudaGraph_t hipGraph_t
130
  #define cudaStream_t hipStream_t
131
  #define cudaSuccess hipSuccess
 
132
  #define __trap() do { abort(); __builtin_unreachable(); } while(0)
133
  #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
134
  #define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
 
129
  #define cudaGraph_t hipGraph_t
130
  #define cudaStream_t hipStream_t
131
  #define cudaSuccess hipSuccess
132
+ #define cudaOccupancyMaxActiveBlocksPerMultiprocessor hipOccupancyMaxActiveBlocksPerMultiprocessor
133
  #define __trap() do { abort(); __builtin_unreachable(); } while(0)
134
  #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
135
  #define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
ggml/src/ggml-cuda/vendors/musa.h CHANGED
@@ -134,5 +134,6 @@
134
  #define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
135
  #define cudaStreamBeginCapture musaStreamBeginCapture
136
  #define cudaStreamEndCapture musaStreamEndCapture
 
137
 
138
  typedef mt_bfloat16 nv_bfloat16;
 
134
  #define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
135
  #define cudaStreamBeginCapture musaStreamBeginCapture
136
  #define cudaStreamEndCapture musaStreamEndCapture
137
+ #define cudaOccupancyMaxActiveBlocksPerMultiprocessor musaOccupancyMaxActiveBlocksPerMultiprocessor
138
 
139
  typedef mt_bfloat16 nv_bfloat16;
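
Editor's note (sketch, not part of the commit): the two vendor-header defines above map the CUDA occupancy query onto its HIP and MUSA equivalents so the same call compiles on all three backends. As a reminder of the API's shape, below is standard CUDA runtime usage with a hypothetical helper name; `kernel`, `block_dim`, and `nbytes_shared` stand in for whatever launch configuration the caller already has.

```cpp
#include <cuda_runtime.h>

// Returns how many blocks of `kernel`, launched with `block_dim` threads per block and
// `nbytes_shared` bytes of dynamic shared memory, can be resident on a single SM.
// Multiplying the result by the SM count bounds how many blocks can run concurrently.
static int blocks_per_sm(const void * kernel, dim3 block_dim, size_t nbytes_shared) {
    int max_blocks = 0;
    const cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &max_blocks, kernel, block_dim.x*block_dim.y*block_dim.z, nbytes_shared);
    return err == cudaSuccess ? max_blocks : 1;
}
```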