Spaces:
Running
Running
Commit
·
3a7ca19
1
Parent(s):
53dd8ad
CUDA: Improve flash decoding kernel GPU occupancy for BS=1 case (llama/12183)
Browse files- Find out active blocks per SM using cudaOccupancyMaxActiveBlocksPerMultiprocessor API. Use this value to determine the optimal parallel_blocks value.
- Prefer vector flash attention kernels over MMA kernel for BS=1
Fixes Issue: #12182
---------
Co-authored-by: Johannes Gäßler <[email protected]>
- ggml/src/ggml-cuda/fattn-common.cuh +61 -27
- ggml/src/ggml-cuda/fattn-mma-f16.cuh +2 -1
- ggml/src/ggml-cuda/fattn-tile-f16.cu +24 -39
- ggml/src/ggml-cuda/fattn-tile-f32.cu +24 -39
- ggml/src/ggml-cuda/fattn-vec-f16.cuh +27 -46
- ggml/src/ggml-cuda/fattn-vec-f32.cuh +27 -46
- ggml/src/ggml-cuda/fattn-wmma-f16.cu +15 -50
- ggml/src/ggml-cuda/fattn.cu +9 -9
- ggml/src/ggml-cuda/ggml-cuda.cu +3 -0
- ggml/src/ggml-cuda/vendors/hip.h +1 -0
- ggml/src/ggml-cuda/vendors/musa.h +1 -0
ggml/src/ggml-cuda/fattn-common.cuh
CHANGED
|
@@ -606,48 +606,47 @@ static __global__ void flash_attn_stream_k_fixup(
|
|
| 606 |
*dst = dst_val / rowsum;
|
| 607 |
}
|
| 608 |
|
| 609 |
-
template<int D
|
| 610 |
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
| 611 |
__launch_bounds__(D, 1)
|
| 612 |
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
| 613 |
static __global__ void flash_attn_combine_results(
|
| 614 |
const float * __restrict__ VKQ_parts,
|
| 615 |
const float2 * __restrict__ VKQ_meta,
|
| 616 |
-
float * __restrict__ dst
|
| 617 |
-
|
| 618 |
-
|
| 619 |
-
|
|
|
|
| 620 |
|
| 621 |
const int tid = threadIdx.x;
|
| 622 |
__builtin_assume(tid < D);
|
| 623 |
|
| 624 |
-
__shared__ float2 meta[
|
| 625 |
if (tid < 2*parallel_blocks) {
|
| 626 |
-
((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.
|
| 627 |
}
|
| 628 |
|
| 629 |
__syncthreads();
|
| 630 |
|
| 631 |
float kqmax = meta[0].x;
|
| 632 |
-
#pragma unroll
|
| 633 |
for (int l = 1; l < parallel_blocks; ++l) {
|
| 634 |
kqmax = max(kqmax, meta[l].x);
|
| 635 |
}
|
| 636 |
|
| 637 |
float VKQ_numerator = 0.0f;
|
| 638 |
float VKQ_denominator = 0.0f;
|
| 639 |
-
#pragma unroll
|
| 640 |
for (int l = 0; l < parallel_blocks; ++l) {
|
| 641 |
const float diff = meta[l].x - kqmax;
|
| 642 |
const float KQ_max_scale = expf(diff);
|
| 643 |
const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
|
| 644 |
*((uint32_t *) &KQ_max_scale) &= ftz_mask;
|
| 645 |
|
| 646 |
-
VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.
|
| 647 |
VKQ_denominator += KQ_max_scale * meta[l].y;
|
| 648 |
}
|
| 649 |
|
| 650 |
-
dst[blockIdx.
|
| 651 |
}
|
| 652 |
|
| 653 |
static void on_no_fattn_vec_case(const int D) {
|
|
@@ -671,12 +670,10 @@ static void on_no_fattn_vec_case(const int D) {
|
|
| 671 |
}
|
| 672 |
}
|
| 673 |
|
| 674 |
-
|
| 675 |
-
template <int D, int ncols1, int ncols2, int parallel_blocks, int KQ_stride>
|
| 676 |
void launch_fattn(
|
| 677 |
-
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
|
| 678 |
-
const int
|
| 679 |
-
const int warp_size = WARP_SIZE
|
| 680 |
) {
|
| 681 |
constexpr int ncols = ncols1 * ncols2;
|
| 682 |
|
|
@@ -748,12 +745,14 @@ void launch_fattn(
|
|
| 748 |
nb23 = nb23*bs*sizeof(half)/ts;
|
| 749 |
}
|
| 750 |
|
|
|
|
|
|
|
| 751 |
const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
|
| 752 |
const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
|
| 753 |
|
| 754 |
const dim3 block_dim(warp_size, nwarps, 1);
|
| 755 |
dim3 blocks_num;
|
| 756 |
-
if (
|
| 757 |
// For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
|
| 758 |
const int max_blocks = 2*nsm;
|
| 759 |
const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
|
|
@@ -769,9 +768,43 @@ void launch_fattn(
|
|
| 769 |
|
| 770 |
dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float));
|
| 771 |
} else {
|
| 772 |
-
|
| 773 |
-
|
| 774 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 775 |
|
| 776 |
if (parallel_blocks > 1) {
|
| 777 |
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
|
|
@@ -803,7 +836,7 @@ void launch_fattn(
|
|
| 803 |
K_data,
|
| 804 |
V_data,
|
| 805 |
mask ? ((const char *) mask->data) : nullptr,
|
| 806 |
-
|
| 807 |
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
|
| 808 |
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
| 809 |
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
|
@@ -815,7 +848,7 @@ void launch_fattn(
|
|
| 815 |
);
|
| 816 |
CUDA_CHECK(cudaGetLastError());
|
| 817 |
|
| 818 |
-
if
|
| 819 |
if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
|
| 820 |
const dim3 block_dim_combine(D, 1, 1);
|
| 821 |
const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
|
|
@@ -824,13 +857,14 @@ void launch_fattn(
|
|
| 824 |
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
|
| 825 |
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
|
| 826 |
}
|
| 827 |
-
} else if
|
| 828 |
const dim3 block_dim_combine(D, 1, 1);
|
| 829 |
-
const dim3 blocks_num_combine(Q->ne[1],
|
|
|
|
| 830 |
|
| 831 |
-
flash_attn_combine_results<D
|
| 832 |
-
<<<blocks_num_combine, block_dim_combine,
|
| 833 |
-
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
|
| 834 |
}
|
| 835 |
CUDA_CHECK(cudaGetLastError());
|
| 836 |
}
|
|
|
|
| 606 |
*dst = dst_val / rowsum;
|
| 607 |
}
|
| 608 |
|
| 609 |
+
template<int D> // D == head size
|
| 610 |
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
| 611 |
__launch_bounds__(D, 1)
|
| 612 |
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
| 613 |
static __global__ void flash_attn_combine_results(
|
| 614 |
const float * __restrict__ VKQ_parts,
|
| 615 |
const float2 * __restrict__ VKQ_meta,
|
| 616 |
+
float * __restrict__ dst,
|
| 617 |
+
const int parallel_blocks) {
|
| 618 |
+
VKQ_parts += parallel_blocks*D * gridDim.z*blockIdx.x;
|
| 619 |
+
VKQ_meta += parallel_blocks * gridDim.z*blockIdx.x;
|
| 620 |
+
dst += D * gridDim.z*blockIdx.x;
|
| 621 |
|
| 622 |
const int tid = threadIdx.x;
|
| 623 |
__builtin_assume(tid < D);
|
| 624 |
|
| 625 |
+
extern __shared__ float2 meta[];
|
| 626 |
if (tid < 2*parallel_blocks) {
|
| 627 |
+
((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + tid];
|
| 628 |
}
|
| 629 |
|
| 630 |
__syncthreads();
|
| 631 |
|
| 632 |
float kqmax = meta[0].x;
|
|
|
|
| 633 |
for (int l = 1; l < parallel_blocks; ++l) {
|
| 634 |
kqmax = max(kqmax, meta[l].x);
|
| 635 |
}
|
| 636 |
|
| 637 |
float VKQ_numerator = 0.0f;
|
| 638 |
float VKQ_denominator = 0.0f;
|
|
|
|
| 639 |
for (int l = 0; l < parallel_blocks; ++l) {
|
| 640 |
const float diff = meta[l].x - kqmax;
|
| 641 |
const float KQ_max_scale = expf(diff);
|
| 642 |
const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
|
| 643 |
*((uint32_t *) &KQ_max_scale) &= ftz_mask;
|
| 644 |
|
| 645 |
+
VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.z*D + blockIdx.z*D + tid];
|
| 646 |
VKQ_denominator += KQ_max_scale * meta[l].y;
|
| 647 |
}
|
| 648 |
|
| 649 |
+
dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator;
|
| 650 |
}
|
| 651 |
|
| 652 |
static void on_no_fattn_vec_case(const int D) {
|
|
|
|
| 670 |
}
|
| 671 |
}
|
| 672 |
|
| 673 |
+
template <int D, int ncols1, int ncols2, int KQ_stride>
|
|
|
|
| 674 |
void launch_fattn(
|
| 675 |
+
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
|
| 676 |
+
const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
|
|
|
|
| 677 |
) {
|
| 678 |
constexpr int ncols = ncols1 * ncols2;
|
| 679 |
|
|
|
|
| 745 |
nb23 = nb23*bs*sizeof(half)/ts;
|
| 746 |
}
|
| 747 |
|
| 748 |
+
int parallel_blocks = 1;
|
| 749 |
+
|
| 750 |
const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
|
| 751 |
const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
|
| 752 |
|
| 753 |
const dim3 block_dim(warp_size, nwarps, 1);
|
| 754 |
dim3 blocks_num;
|
| 755 |
+
if (stream_k) {
|
| 756 |
// For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
|
| 757 |
const int max_blocks = 2*nsm;
|
| 758 |
const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
|
|
|
|
| 768 |
|
| 769 |
dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float));
|
| 770 |
} else {
|
| 771 |
+
GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
|
| 772 |
+
const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
|
| 773 |
+
|
| 774 |
+
int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
|
| 775 |
+
CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
|
| 776 |
+
|
| 777 |
+
// parallel_blocks should be at least large enough to achieve max. occupancy for a single wave:
|
| 778 |
+
parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1);
|
| 779 |
+
|
| 780 |
+
// parallel_blocks must not be larger than what the tensor size allows:
|
| 781 |
+
parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
|
| 782 |
+
|
| 783 |
+
// If ntiles_total % blocks_per_wave != 0 then some efficiency is lost due to tail effects.
|
| 784 |
+
// Test whether parallel_blocks can be set to a higher value for better efficiency.
|
| 785 |
+
const int blocks_per_wave = nsm * max_blocks_per_sm;
|
| 786 |
+
int nwaves_best = 0;
|
| 787 |
+
int efficiency_percent_best = 0;
|
| 788 |
+
for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KQ; ++parallel_blocks_test) {
|
| 789 |
+
const int nblocks_total = ntiles_total * parallel_blocks_test;
|
| 790 |
+
const int nwaves = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave;
|
| 791 |
+
const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave);
|
| 792 |
+
|
| 793 |
+
// Stop trying configurations with more waves if we already have good efficiency to avoid excessive overhead.
|
| 794 |
+
if (efficiency_percent_best >= 90 && nwaves > nwaves_best) {
|
| 795 |
+
break;
|
| 796 |
+
}
|
| 797 |
+
|
| 798 |
+
if (efficiency_percent > efficiency_percent_best) {
|
| 799 |
+
nwaves_best = nwaves;
|
| 800 |
+
efficiency_percent_best = efficiency_percent;
|
| 801 |
+
parallel_blocks = parallel_blocks_test;
|
| 802 |
+
}
|
| 803 |
+
}
|
| 804 |
+
|
| 805 |
+
blocks_num.x = ntiles_x;
|
| 806 |
+
blocks_num.y = parallel_blocks;
|
| 807 |
+
blocks_num.z = Q->ne[2]*Q->ne[3];
|
| 808 |
|
| 809 |
if (parallel_blocks > 1) {
|
| 810 |
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
|
|
|
|
| 836 |
K_data,
|
| 837 |
V_data,
|
| 838 |
mask ? ((const char *) mask->data) : nullptr,
|
| 839 |
+
!stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
|
| 840 |
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
|
| 841 |
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
| 842 |
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
|
|
|
| 848 |
);
|
| 849 |
CUDA_CHECK(cudaGetLastError());
|
| 850 |
|
| 851 |
+
if (stream_k) {
|
| 852 |
if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
|
| 853 |
const dim3 block_dim_combine(D, 1, 1);
|
| 854 |
const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
|
|
|
|
| 857 |
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
|
| 858 |
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
|
| 859 |
}
|
| 860 |
+
} else if (parallel_blocks > 1) {
|
| 861 |
const dim3 block_dim_combine(D, 1, 1);
|
| 862 |
+
const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
|
| 863 |
+
const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
|
| 864 |
|
| 865 |
+
flash_attn_combine_results<D>
|
| 866 |
+
<<<blocks_num_combine, block_dim_combine, nbytes_shared_combine, main_stream>>>
|
| 867 |
+
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks);
|
| 868 |
}
|
| 869 |
CUDA_CHECK(cudaGetLastError());
|
| 870 |
}
|
ggml/src/ggml-cuda/fattn-mma-f16.cuh
CHANGED
|
@@ -970,7 +970,8 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
|
| 970 |
fattn_kernel = flash_attn_ext_f16<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap>;
|
| 971 |
}
|
| 972 |
|
| 973 |
-
launch_fattn<D, ncols1, ncols2,
|
|
|
|
| 974 |
}
|
| 975 |
|
| 976 |
|
|
|
|
| 970 |
fattn_kernel = flash_attn_ext_f16<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap>;
|
| 971 |
}
|
| 972 |
|
| 973 |
+
launch_fattn<D, ncols1, ncols2, KQ_per_iter>
|
| 974 |
+
(ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, FATTN_KQ_STRIDE, true, true, true);
|
| 975 |
}
|
| 976 |
|
| 977 |
|
ggml/src/ggml-cuda/fattn-tile-f16.cu
CHANGED
|
@@ -4,7 +4,7 @@
|
|
| 4 |
|
| 5 |
#define FATTN_KQ_STRIDE_TILE_F16 64
|
| 6 |
|
| 7 |
-
template<int D, int ncols, int nwarps,
|
| 8 |
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
| 9 |
__launch_bounds__(nwarps*WARP_SIZE, 1)
|
| 10 |
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
|
@@ -58,18 +58,17 @@ static __global__ void flash_attn_tile_ext_f16(
|
|
| 58 |
|
| 59 |
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
| 60 |
|
| 61 |
-
const int ic0 =
|
| 62 |
-
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
|
| 63 |
|
| 64 |
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
| 65 |
-
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.
|
| 66 |
-
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.
|
| 67 |
-
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.
|
| 68 |
const half * maskh = (const half *) mask + ne11*ic0;
|
| 69 |
|
| 70 |
const int stride_KV2 = nb11 / sizeof(half2);
|
| 71 |
|
| 72 |
-
const float slopef = get_alibi_slope(max_bias, blockIdx.
|
| 73 |
const half slopeh = __float2half(slopef);
|
| 74 |
|
| 75 |
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
|
|
@@ -105,8 +104,7 @@ static __global__ void flash_attn_tile_ext_f16(
|
|
| 105 |
|
| 106 |
__syncthreads();
|
| 107 |
|
| 108 |
-
|
| 109 |
-
for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE_TILE_F16) {
|
| 110 |
// Calculate KQ tile and keep track of new maximum KQ values:
|
| 111 |
|
| 112 |
half kqmax_new[ncols/nwarps];
|
|
@@ -271,16 +269,16 @@ static __global__ void flash_attn_tile_ext_f16(
|
|
| 271 |
const int i0 = i00 + 2*threadIdx.x;
|
| 272 |
|
| 273 |
half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
|
| 274 |
-
if (
|
| 275 |
dst_val /= __half2half2(kqsum_j);
|
| 276 |
}
|
| 277 |
-
const int j_dst = (ic0 + j_VKQ)*
|
| 278 |
-
dst[j_dst*D*gridDim.
|
| 279 |
-
dst[j_dst*D*gridDim.
|
| 280 |
}
|
| 281 |
|
| 282 |
-
if (
|
| 283 |
-
dst_meta[(ic0 + j_VKQ)*gridDim.
|
| 284 |
}
|
| 285 |
}
|
| 286 |
#else
|
|
@@ -288,7 +286,7 @@ static __global__ void flash_attn_tile_ext_f16(
|
|
| 288 |
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
|
| 289 |
}
|
| 290 |
|
| 291 |
-
template <int cols_per_block,
|
| 292 |
void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
| 293 |
const ggml_tensor * Q = dst->src[0];
|
| 294 |
switch (Q->ne[0]) {
|
|
@@ -296,15 +294,17 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
|
|
| 296 |
constexpr int D = 64;
|
| 297 |
constexpr int nwarps = 8;
|
| 298 |
constexpr size_t nbytes_shared = 0;
|
| 299 |
-
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps,
|
| 300 |
-
launch_fattn<D, cols_per_block, 1,
|
|
|
|
| 301 |
} break;
|
| 302 |
case 128: {
|
| 303 |
constexpr int D = 128;
|
| 304 |
constexpr int nwarps = 8;
|
| 305 |
constexpr size_t nbytes_shared = 0;
|
| 306 |
-
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps,
|
| 307 |
-
launch_fattn<D, cols_per_block, 1,
|
|
|
|
| 308 |
} break;
|
| 309 |
default: {
|
| 310 |
GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
|
|
@@ -324,37 +324,22 @@ void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_ten
|
|
| 324 |
|
| 325 |
if (Q->ne[1] <= 16) {
|
| 326 |
constexpr int cols_per_block = 16;
|
| 327 |
-
constexpr int parallel_blocks = 4;
|
| 328 |
if (logit_softcap == 0.0f) {
|
| 329 |
constexpr bool use_logit_softcap = false;
|
| 330 |
-
launch_fattn_tile_f16_64_128<cols_per_block,
|
| 331 |
} else {
|
| 332 |
constexpr bool use_logit_softcap = true;
|
| 333 |
-
launch_fattn_tile_f16_64_128<cols_per_block,
|
| 334 |
-
}
|
| 335 |
-
return;
|
| 336 |
-
}
|
| 337 |
-
|
| 338 |
-
if (Q->ne[1] <= 32) {
|
| 339 |
-
constexpr int cols_per_block = 32;
|
| 340 |
-
constexpr int parallel_blocks = 4;
|
| 341 |
-
if (logit_softcap == 0.0f) {
|
| 342 |
-
constexpr bool use_logit_softcap = false;
|
| 343 |
-
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
|
| 344 |
-
} else {
|
| 345 |
-
constexpr bool use_logit_softcap = true;
|
| 346 |
-
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
|
| 347 |
}
|
| 348 |
return;
|
| 349 |
}
|
| 350 |
|
| 351 |
constexpr int cols_per_block = 32;
|
| 352 |
-
constexpr int parallel_blocks = 1;
|
| 353 |
if (logit_softcap == 0.0f) {
|
| 354 |
constexpr bool use_logit_softcap = false;
|
| 355 |
-
launch_fattn_tile_f16_64_128<cols_per_block,
|
| 356 |
} else {
|
| 357 |
constexpr bool use_logit_softcap = true;
|
| 358 |
-
launch_fattn_tile_f16_64_128<cols_per_block,
|
| 359 |
}
|
| 360 |
}
|
|
|
|
| 4 |
|
| 5 |
#define FATTN_KQ_STRIDE_TILE_F16 64
|
| 6 |
|
| 7 |
+
template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
|
| 8 |
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
| 9 |
__launch_bounds__(nwarps*WARP_SIZE, 1)
|
| 10 |
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
|
|
|
| 58 |
|
| 59 |
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
| 60 |
|
| 61 |
+
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
|
|
|
|
| 62 |
|
| 63 |
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
| 64 |
+
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
|
| 65 |
+
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
|
| 66 |
+
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
|
| 67 |
const half * maskh = (const half *) mask + ne11*ic0;
|
| 68 |
|
| 69 |
const int stride_KV2 = nb11 / sizeof(half2);
|
| 70 |
|
| 71 |
+
const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
|
| 72 |
const half slopeh = __float2half(slopef);
|
| 73 |
|
| 74 |
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
|
|
|
|
| 104 |
|
| 105 |
__syncthreads();
|
| 106 |
|
| 107 |
+
for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F16; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F16) {
|
|
|
|
| 108 |
// Calculate KQ tile and keep track of new maximum KQ values:
|
| 109 |
|
| 110 |
half kqmax_new[ncols/nwarps];
|
|
|
|
| 269 |
const int i0 = i00 + 2*threadIdx.x;
|
| 270 |
|
| 271 |
half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
|
| 272 |
+
if (gridDim.y == 1) {
|
| 273 |
dst_val /= __half2half2(kqsum_j);
|
| 274 |
}
|
| 275 |
+
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
|
| 276 |
+
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = __low2float(dst_val);
|
| 277 |
+
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = __high2float(dst_val);
|
| 278 |
}
|
| 279 |
|
| 280 |
+
if (gridDim.y != 1 && threadIdx.x == 0) {
|
| 281 |
+
dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
|
| 282 |
}
|
| 283 |
}
|
| 284 |
#else
|
|
|
|
| 286 |
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
|
| 287 |
}
|
| 288 |
|
| 289 |
+
template <int cols_per_block, bool use_logit_softcap>
|
| 290 |
void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
| 291 |
const ggml_tensor * Q = dst->src[0];
|
| 292 |
switch (Q->ne[0]) {
|
|
|
|
| 294 |
constexpr int D = 64;
|
| 295 |
constexpr int nwarps = 8;
|
| 296 |
constexpr size_t nbytes_shared = 0;
|
| 297 |
+
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
|
| 298 |
+
launch_fattn<D, cols_per_block, 1, -1>
|
| 299 |
+
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false);
|
| 300 |
} break;
|
| 301 |
case 128: {
|
| 302 |
constexpr int D = 128;
|
| 303 |
constexpr int nwarps = 8;
|
| 304 |
constexpr size_t nbytes_shared = 0;
|
| 305 |
+
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
|
| 306 |
+
launch_fattn<D, cols_per_block, 1, -1>
|
| 307 |
+
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false);
|
| 308 |
} break;
|
| 309 |
default: {
|
| 310 |
GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
|
|
|
|
| 324 |
|
| 325 |
if (Q->ne[1] <= 16) {
|
| 326 |
constexpr int cols_per_block = 16;
|
|
|
|
| 327 |
if (logit_softcap == 0.0f) {
|
| 328 |
constexpr bool use_logit_softcap = false;
|
| 329 |
+
launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
|
| 330 |
} else {
|
| 331 |
constexpr bool use_logit_softcap = true;
|
| 332 |
+
launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 333 |
}
|
| 334 |
return;
|
| 335 |
}
|
| 336 |
|
| 337 |
constexpr int cols_per_block = 32;
|
|
|
|
| 338 |
if (logit_softcap == 0.0f) {
|
| 339 |
constexpr bool use_logit_softcap = false;
|
| 340 |
+
launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
|
| 341 |
} else {
|
| 342 |
constexpr bool use_logit_softcap = true;
|
| 343 |
+
launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
|
| 344 |
}
|
| 345 |
}
|
ggml/src/ggml-cuda/fattn-tile-f32.cu
CHANGED
|
@@ -4,7 +4,7 @@
|
|
| 4 |
|
| 5 |
#define FATTN_KQ_STRIDE_TILE_F32 32
|
| 6 |
|
| 7 |
-
template<int D, int ncols, int nwarps,
|
| 8 |
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
| 9 |
__launch_bounds__(nwarps*WARP_SIZE, 1)
|
| 10 |
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
|
@@ -58,18 +58,17 @@ static __global__ void flash_attn_tile_ext_f32(
|
|
| 58 |
|
| 59 |
// In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
| 60 |
|
| 61 |
-
const int ic0 =
|
| 62 |
-
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
|
| 63 |
|
| 64 |
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
| 65 |
-
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.
|
| 66 |
-
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.
|
| 67 |
-
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.
|
| 68 |
const half * maskh = (const half *) mask + ne11*ic0;
|
| 69 |
|
| 70 |
const int stride_KV2 = nb11 / sizeof(half2);
|
| 71 |
|
| 72 |
-
const float slope = get_alibi_slope(max_bias, blockIdx.
|
| 73 |
|
| 74 |
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
|
| 75 |
|
|
@@ -103,8 +102,7 @@ static __global__ void flash_attn_tile_ext_f32(
|
|
| 103 |
|
| 104 |
__syncthreads();
|
| 105 |
|
| 106 |
-
|
| 107 |
-
for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE_TILE_F32) {
|
| 108 |
// Calculate KQ tile and keep track of new maximum KQ values:
|
| 109 |
|
| 110 |
float kqmax_new[ncols/nwarps];
|
|
@@ -269,17 +267,17 @@ static __global__ void flash_attn_tile_ext_f32(
|
|
| 269 |
const int i0 = i00 + 2*threadIdx.x;
|
| 270 |
|
| 271 |
float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
|
| 272 |
-
if (
|
| 273 |
dst_val.x /= kqsum_j;
|
| 274 |
dst_val.y /= kqsum_j;
|
| 275 |
}
|
| 276 |
-
const int j_dst = (ic0 + j_VKQ)*
|
| 277 |
-
dst[j_dst*D*gridDim.
|
| 278 |
-
dst[j_dst*D*gridDim.
|
| 279 |
}
|
| 280 |
|
| 281 |
-
if (
|
| 282 |
-
dst_meta[(ic0 + j_VKQ)*gridDim.
|
| 283 |
}
|
| 284 |
}
|
| 285 |
#else
|
|
@@ -287,7 +285,7 @@ static __global__ void flash_attn_tile_ext_f32(
|
|
| 287 |
#endif // FLASH_ATTN_AVAILABLE
|
| 288 |
}
|
| 289 |
|
| 290 |
-
template <int cols_per_block,
|
| 291 |
void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
| 292 |
const ggml_tensor * Q = dst->src[0];
|
| 293 |
switch (Q->ne[0]) {
|
|
@@ -295,15 +293,17 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
|
|
| 295 |
constexpr int D = 64;
|
| 296 |
constexpr int nwarps = 8;
|
| 297 |
constexpr size_t nbytes_shared = 0;
|
| 298 |
-
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps,
|
| 299 |
-
launch_fattn<D, cols_per_block, 1,
|
|
|
|
| 300 |
} break;
|
| 301 |
case 128: {
|
| 302 |
constexpr int D = 128;
|
| 303 |
constexpr int nwarps = 8;
|
| 304 |
constexpr size_t nbytes_shared = 0;
|
| 305 |
-
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps,
|
| 306 |
-
launch_fattn<D, cols_per_block, 1,
|
|
|
|
| 307 |
} break;
|
| 308 |
default: {
|
| 309 |
GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
|
|
@@ -320,37 +320,22 @@ void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_ten
|
|
| 320 |
|
| 321 |
if (Q->ne[1] <= 16) {
|
| 322 |
constexpr int cols_per_block = 16;
|
| 323 |
-
constexpr int parallel_blocks = 4;
|
| 324 |
if (logit_softcap == 0.0f) {
|
| 325 |
constexpr bool use_logit_softcap = false;
|
| 326 |
-
launch_fattn_tile_f32_64_128<cols_per_block,
|
| 327 |
} else {
|
| 328 |
constexpr bool use_logit_softcap = true;
|
| 329 |
-
launch_fattn_tile_f32_64_128<cols_per_block,
|
| 330 |
-
}
|
| 331 |
-
return;
|
| 332 |
-
}
|
| 333 |
-
|
| 334 |
-
if (Q->ne[1] <= 32) {
|
| 335 |
-
constexpr int cols_per_block = 32;
|
| 336 |
-
constexpr int parallel_blocks = 4;
|
| 337 |
-
if (logit_softcap == 0.0f) {
|
| 338 |
-
constexpr bool use_logit_softcap = false;
|
| 339 |
-
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
|
| 340 |
-
} else {
|
| 341 |
-
constexpr bool use_logit_softcap = true;
|
| 342 |
-
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
|
| 343 |
}
|
| 344 |
return;
|
| 345 |
}
|
| 346 |
|
| 347 |
constexpr int cols_per_block = 32;
|
| 348 |
-
constexpr int parallel_blocks = 1;
|
| 349 |
if (logit_softcap == 0.0f) {
|
| 350 |
constexpr bool use_logit_softcap = false;
|
| 351 |
-
launch_fattn_tile_f32_64_128<cols_per_block,
|
| 352 |
} else {
|
| 353 |
constexpr bool use_logit_softcap = true;
|
| 354 |
-
launch_fattn_tile_f32_64_128<cols_per_block,
|
| 355 |
}
|
| 356 |
}
|
|
|
|
| 4 |
|
| 5 |
#define FATTN_KQ_STRIDE_TILE_F32 32
|
| 6 |
|
| 7 |
+
template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
|
| 8 |
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
| 9 |
__launch_bounds__(nwarps*WARP_SIZE, 1)
|
| 10 |
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
|
|
|
| 58 |
|
| 59 |
// In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
| 60 |
|
| 61 |
+
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
|
|
|
|
| 62 |
|
| 63 |
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
| 64 |
+
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
|
| 65 |
+
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
|
| 66 |
+
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
|
| 67 |
const half * maskh = (const half *) mask + ne11*ic0;
|
| 68 |
|
| 69 |
const int stride_KV2 = nb11 / sizeof(half2);
|
| 70 |
|
| 71 |
+
const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
|
| 72 |
|
| 73 |
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
|
| 74 |
|
|
|
|
| 102 |
|
| 103 |
__syncthreads();
|
| 104 |
|
| 105 |
+
for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F32; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F32) {
|
|
|
|
| 106 |
// Calculate KQ tile and keep track of new maximum KQ values:
|
| 107 |
|
| 108 |
float kqmax_new[ncols/nwarps];
|
|
|
|
| 267 |
const int i0 = i00 + 2*threadIdx.x;
|
| 268 |
|
| 269 |
float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
|
| 270 |
+
if (gridDim.y == 1) {
|
| 271 |
dst_val.x /= kqsum_j;
|
| 272 |
dst_val.y /= kqsum_j;
|
| 273 |
}
|
| 274 |
+
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
|
| 275 |
+
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = dst_val.x;
|
| 276 |
+
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = dst_val.y;
|
| 277 |
}
|
| 278 |
|
| 279 |
+
if (gridDim.y != 1 && threadIdx.x == 0) {
|
| 280 |
+
dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
|
| 281 |
}
|
| 282 |
}
|
| 283 |
#else
|
|
|
|
| 285 |
#endif // FLASH_ATTN_AVAILABLE
|
| 286 |
}
|
| 287 |
|
| 288 |
+
template <int cols_per_block, bool use_logit_softcap>
|
| 289 |
void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
| 290 |
const ggml_tensor * Q = dst->src[0];
|
| 291 |
switch (Q->ne[0]) {
|
|
|
|
| 293 |
constexpr int D = 64;
|
| 294 |
constexpr int nwarps = 8;
|
| 295 |
constexpr size_t nbytes_shared = 0;
|
| 296 |
+
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, use_logit_softcap>;
|
| 297 |
+
launch_fattn<D, cols_per_block, 1, -1>
|
| 298 |
+
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false);
|
| 299 |
} break;
|
| 300 |
case 128: {
|
| 301 |
constexpr int D = 128;
|
| 302 |
constexpr int nwarps = 8;
|
| 303 |
constexpr size_t nbytes_shared = 0;
|
| 304 |
+
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, use_logit_softcap>;
|
| 305 |
+
launch_fattn<D, cols_per_block, 1, -1>
|
| 306 |
+
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false);
|
| 307 |
} break;
|
| 308 |
default: {
|
| 309 |
GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
|
|
|
|
| 320 |
|
| 321 |
if (Q->ne[1] <= 16) {
|
| 322 |
constexpr int cols_per_block = 16;
|
|
|
|
| 323 |
if (logit_softcap == 0.0f) {
|
| 324 |
constexpr bool use_logit_softcap = false;
|
| 325 |
+
launch_fattn_tile_f32_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
|
| 326 |
} else {
|
| 327 |
constexpr bool use_logit_softcap = true;
|
| 328 |
+
launch_fattn_tile_f32_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 329 |
}
|
| 330 |
return;
|
| 331 |
}
|
| 332 |
|
| 333 |
constexpr int cols_per_block = 32;
|
|
|
|
| 334 |
if (logit_softcap == 0.0f) {
|
| 335 |
constexpr bool use_logit_softcap = false;
|
| 336 |
+
launch_fattn_tile_f32_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
|
| 337 |
} else {
|
| 338 |
constexpr bool use_logit_softcap = true;
|
| 339 |
+
launch_fattn_tile_f32_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
|
| 340 |
}
|
| 341 |
}
|
ggml/src/ggml-cuda/fattn-vec-f16.cuh
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
#include "common.cuh"
|
| 2 |
#include "fattn-common.cuh"
|
| 3 |
|
| 4 |
-
template<int D, int ncols,
|
| 5 |
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
| 6 |
__launch_bounds__(D, 1)
|
| 7 |
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
|
@@ -55,17 +55,16 @@ static __global__ void flash_attn_vec_ext_f16(
|
|
| 55 |
constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
|
| 56 |
constexpr dequantize_1_f16_t dequantize_1_v = get_dequantize_1_f16(type_V);
|
| 57 |
|
| 58 |
-
const int ic0 =
|
| 59 |
-
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
|
| 60 |
|
| 61 |
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
| 62 |
-
Q += nb02* blockIdx.
|
| 63 |
-
K += nb12*(blockIdx.
|
| 64 |
-
V += nb22*(blockIdx.
|
| 65 |
|
| 66 |
const half * maskh = (const half *) mask + ne11*ic0;
|
| 67 |
|
| 68 |
-
const float slopef = get_alibi_slope(max_bias, blockIdx.
|
| 69 |
const half slopeh = __float2half(slopef);
|
| 70 |
|
| 71 |
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
|
|
@@ -172,8 +171,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
|
| 172 |
|
| 173 |
half2 VKQ[ncols] = {{0.0f, 0.0f}};
|
| 174 |
|
| 175 |
-
|
| 176 |
-
for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
|
| 177 |
// Calculate KQ tile and keep track of new maximum KQ values:
|
| 178 |
|
| 179 |
// For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
|
|
@@ -283,29 +281,29 @@ static __global__ void flash_attn_vec_ext_f16(
|
|
| 283 |
kqsum[j_VKQ] = warp_reduce_sum((float)kqsum[j_VKQ]);
|
| 284 |
|
| 285 |
half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ]));
|
| 286 |
-
if (
|
| 287 |
dst_val /= kqsum[j_VKQ];
|
| 288 |
}
|
| 289 |
-
const int j_dst = (ic0 + j_VKQ)*
|
| 290 |
-
dst[j_dst*D*gridDim.
|
| 291 |
}
|
| 292 |
|
| 293 |
-
if (
|
| 294 |
-
dst_meta[(ic0 + tid)*gridDim.
|
| 295 |
}
|
| 296 |
#else
|
| 297 |
NO_DEVICE_CODE;
|
| 298 |
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
|
| 299 |
}
|
| 300 |
|
| 301 |
-
template <int D, int cols_per_block,
|
| 302 |
void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
| 303 |
constexpr int nwarps = D/WARP_SIZE;
|
| 304 |
-
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block,
|
| 305 |
constexpr bool need_f16_K = D != 128;
|
| 306 |
constexpr bool need_f16_V = D != 128 && D != 64;
|
| 307 |
constexpr size_t nbytes_shared = 0;
|
| 308 |
-
launch_fattn<D, cols_per_block, 1,
|
| 309 |
}
|
| 310 |
|
| 311 |
template <int D, ggml_type type_K, ggml_type type_V>
|
|
@@ -325,65 +323,48 @@ void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml
|
|
| 325 |
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
| 326 |
|
| 327 |
if (Q->ne[1] == 1) {
|
| 328 |
-
constexpr int cols_per_block
|
| 329 |
-
constexpr int parallel_blocks = 4;
|
| 330 |
if (logit_softcap == 0.0f) {
|
| 331 |
constexpr bool use_logit_softcap = false;
|
| 332 |
-
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block,
|
| 333 |
} else {
|
| 334 |
constexpr bool use_logit_softcap = true;
|
| 335 |
-
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block,
|
| 336 |
}
|
| 337 |
return;
|
| 338 |
}
|
| 339 |
|
| 340 |
if (Q->ne[1] == 2) {
|
| 341 |
-
constexpr int cols_per_block
|
| 342 |
-
constexpr int parallel_blocks = 4;
|
| 343 |
if (logit_softcap == 0.0f) {
|
| 344 |
constexpr bool use_logit_softcap = false;
|
| 345 |
-
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block,
|
| 346 |
} else {
|
| 347 |
constexpr bool use_logit_softcap = true;
|
| 348 |
-
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block,
|
| 349 |
}
|
| 350 |
return;
|
| 351 |
}
|
| 352 |
|
| 353 |
if (Q->ne[1] <= 4) {
|
| 354 |
-
constexpr int cols_per_block
|
| 355 |
-
constexpr int parallel_blocks = 4;
|
| 356 |
if (logit_softcap == 0.0f) {
|
| 357 |
constexpr bool use_logit_softcap = false;
|
| 358 |
-
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block,
|
| 359 |
} else {
|
| 360 |
constexpr bool use_logit_softcap = true;
|
| 361 |
-
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block,
|
| 362 |
}
|
| 363 |
return;
|
| 364 |
}
|
| 365 |
|
| 366 |
-
|
| 367 |
-
constexpr int cols_per_block = 8;
|
| 368 |
-
constexpr int parallel_blocks = 4;
|
| 369 |
-
if (logit_softcap == 0.0f) {
|
| 370 |
-
constexpr bool use_logit_softcap = false;
|
| 371 |
-
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
| 372 |
-
} else {
|
| 373 |
-
constexpr bool use_logit_softcap = true;
|
| 374 |
-
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
| 375 |
-
}
|
| 376 |
-
return;
|
| 377 |
-
}
|
| 378 |
-
|
| 379 |
-
constexpr int cols_per_block = 8;
|
| 380 |
-
constexpr int parallel_blocks = 1;
|
| 381 |
if (logit_softcap == 0.0f) {
|
| 382 |
constexpr bool use_logit_softcap = false;
|
| 383 |
-
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block,
|
| 384 |
} else {
|
| 385 |
constexpr bool use_logit_softcap = true;
|
| 386 |
-
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block,
|
| 387 |
}
|
| 388 |
}
|
| 389 |
|
|
|
|
| 1 |
#include "common.cuh"
|
| 2 |
#include "fattn-common.cuh"
|
| 3 |
|
| 4 |
+
template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
|
| 5 |
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
| 6 |
__launch_bounds__(D, 1)
|
| 7 |
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
|
|
|
| 55 |
constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
|
| 56 |
constexpr dequantize_1_f16_t dequantize_1_v = get_dequantize_1_f16(type_V);
|
| 57 |
|
| 58 |
+
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
|
|
|
|
| 59 |
|
| 60 |
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
| 61 |
+
Q += nb02* blockIdx.z + nb01*ic0;
|
| 62 |
+
K += nb12*(blockIdx.z / gqa_ratio);
|
| 63 |
+
V += nb22*(blockIdx.z / gqa_ratio);
|
| 64 |
|
| 65 |
const half * maskh = (const half *) mask + ne11*ic0;
|
| 66 |
|
| 67 |
+
const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
|
| 68 |
const half slopeh = __float2half(slopef);
|
| 69 |
|
| 70 |
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
|
|
|
|
| 171 |
|
| 172 |
half2 VKQ[ncols] = {{0.0f, 0.0f}};
|
| 173 |
|
| 174 |
+
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
|
|
|
|
| 175 |
// Calculate KQ tile and keep track of new maximum KQ values:
|
| 176 |
|
| 177 |
// For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
|
|
|
|
| 281 |
kqsum[j_VKQ] = warp_reduce_sum((float)kqsum[j_VKQ]);
|
| 282 |
|
| 283 |
half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ]));
|
| 284 |
+
if (gridDim.y == 1) {
|
| 285 |
dst_val /= kqsum[j_VKQ];
|
| 286 |
}
|
| 287 |
+
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
|
| 288 |
+
dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
|
| 289 |
}
|
| 290 |
|
| 291 |
+
if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
|
| 292 |
+
dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
|
| 293 |
}
|
| 294 |
#else
|
| 295 |
NO_DEVICE_CODE;
|
| 296 |
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
|
| 297 |
}
|
| 298 |
|
| 299 |
+
template <int D, int cols_per_block, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
|
| 300 |
void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
| 301 |
constexpr int nwarps = D/WARP_SIZE;
|
| 302 |
+
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, type_K, type_V, use_logit_softcap>;
|
| 303 |
constexpr bool need_f16_K = D != 128;
|
| 304 |
constexpr bool need_f16_V = D != 128 && D != 64;
|
| 305 |
constexpr size_t nbytes_shared = 0;
|
| 306 |
+
launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
|
| 307 |
}
|
| 308 |
|
| 309 |
template <int D, ggml_type type_K, ggml_type type_V>
|
|
|
|
| 323 |
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
| 324 |
|
| 325 |
if (Q->ne[1] == 1) {
|
| 326 |
+
constexpr int cols_per_block = 1;
|
|
|
|
| 327 |
if (logit_softcap == 0.0f) {
|
| 328 |
constexpr bool use_logit_softcap = false;
|
| 329 |
+
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
| 330 |
} else {
|
| 331 |
constexpr bool use_logit_softcap = true;
|
| 332 |
+
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
| 333 |
}
|
| 334 |
return;
|
| 335 |
}
|
| 336 |
|
| 337 |
if (Q->ne[1] == 2) {
|
| 338 |
+
constexpr int cols_per_block = 2;
|
|
|
|
| 339 |
if (logit_softcap == 0.0f) {
|
| 340 |
constexpr bool use_logit_softcap = false;
|
| 341 |
+
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
| 342 |
} else {
|
| 343 |
constexpr bool use_logit_softcap = true;
|
| 344 |
+
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
| 345 |
}
|
| 346 |
return;
|
| 347 |
}
|
| 348 |
|
| 349 |
if (Q->ne[1] <= 4) {
|
| 350 |
+
constexpr int cols_per_block = 4;
|
|
|
|
| 351 |
if (logit_softcap == 0.0f) {
|
| 352 |
constexpr bool use_logit_softcap = false;
|
| 353 |
+
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
| 354 |
} else {
|
| 355 |
constexpr bool use_logit_softcap = true;
|
| 356 |
+
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
| 357 |
}
|
| 358 |
return;
|
| 359 |
}
|
| 360 |
|
| 361 |
+
constexpr int cols_per_block = 8;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 362 |
if (logit_softcap == 0.0f) {
|
| 363 |
constexpr bool use_logit_softcap = false;
|
| 364 |
+
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
| 365 |
} else {
|
| 366 |
constexpr bool use_logit_softcap = true;
|
| 367 |
+
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
| 368 |
}
|
| 369 |
}
|
| 370 |
|
ggml/src/ggml-cuda/fattn-vec-f32.cuh
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
#include "common.cuh"
|
| 2 |
#include "fattn-common.cuh"
|
| 3 |
|
| 4 |
-
template<int D, int ncols,
|
| 5 |
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
| 6 |
__launch_bounds__(D, 1)
|
| 7 |
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
|
@@ -55,16 +55,15 @@ static __global__ void flash_attn_vec_ext_f32(
|
|
| 55 |
constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
|
| 56 |
constexpr dequantize_1_f32_t dequantize_1_v = get_dequantize_1_f32(type_V);
|
| 57 |
|
| 58 |
-
const int ic0 =
|
| 59 |
-
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
|
| 60 |
|
| 61 |
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
| 62 |
-
Q += nb02* blockIdx.
|
| 63 |
-
K += nb12*(blockIdx.
|
| 64 |
-
V += nb22*(blockIdx.
|
| 65 |
const half * maskh = (const half *) mask + ne11*ic0;
|
| 66 |
|
| 67 |
-
const float slope = get_alibi_slope(max_bias, blockIdx.
|
| 68 |
|
| 69 |
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
|
| 70 |
constexpr int nwarps = D / WARP_SIZE;
|
|
@@ -167,8 +166,7 @@ static __global__ void flash_attn_vec_ext_f32(
|
|
| 167 |
|
| 168 |
float VKQ[ncols] = {0.0f};
|
| 169 |
|
| 170 |
-
|
| 171 |
-
for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
|
| 172 |
// Calculate KQ tile and keep track of new maximum KQ values:
|
| 173 |
|
| 174 |
float kqmax_new_arr[ncols];
|
|
@@ -268,29 +266,29 @@ static __global__ void flash_attn_vec_ext_f32(
|
|
| 268 |
kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
|
| 269 |
|
| 270 |
float dst_val = VKQ[j_VKQ];
|
| 271 |
-
if (
|
| 272 |
dst_val /= kqsum[j_VKQ];
|
| 273 |
}
|
| 274 |
-
const int j_dst = (ic0 + j_VKQ)*
|
| 275 |
-
dst[j_dst*D*gridDim.
|
| 276 |
}
|
| 277 |
|
| 278 |
-
if (
|
| 279 |
-
dst_meta[(ic0 + tid)*gridDim.
|
| 280 |
}
|
| 281 |
#else
|
| 282 |
NO_DEVICE_CODE;
|
| 283 |
#endif // FLASH_ATTN_AVAILABLE
|
| 284 |
}
|
| 285 |
|
| 286 |
-
template <int D, int cols_per_block,
|
| 287 |
void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
| 288 |
constexpr int nwarps = D/WARP_SIZE;
|
| 289 |
-
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block,
|
| 290 |
constexpr bool need_f16_K = D != 128;
|
| 291 |
constexpr bool need_f16_V = D != 128 && D != 64;
|
| 292 |
constexpr size_t nbytes_shared = 0;
|
| 293 |
-
launch_fattn<D, cols_per_block, 1,
|
| 294 |
}
|
| 295 |
|
| 296 |
template <int D, ggml_type type_K, ggml_type type_V>
|
|
@@ -307,65 +305,48 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
|
|
| 307 |
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
| 308 |
|
| 309 |
if (Q->ne[1] == 1) {
|
| 310 |
-
constexpr int cols_per_block
|
| 311 |
-
constexpr int parallel_blocks = 4;
|
| 312 |
if (logit_softcap == 0.0f) {
|
| 313 |
constexpr bool use_logit_softcap = false;
|
| 314 |
-
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block,
|
| 315 |
} else {
|
| 316 |
constexpr bool use_logit_softcap = true;
|
| 317 |
-
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block,
|
| 318 |
}
|
| 319 |
return;
|
| 320 |
}
|
| 321 |
|
| 322 |
if (Q->ne[1] == 2) {
|
| 323 |
-
constexpr int cols_per_block
|
| 324 |
-
constexpr int parallel_blocks = 4;
|
| 325 |
if (logit_softcap == 0.0f) {
|
| 326 |
constexpr bool use_logit_softcap = false;
|
| 327 |
-
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block,
|
| 328 |
} else {
|
| 329 |
constexpr bool use_logit_softcap = true;
|
| 330 |
-
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block,
|
| 331 |
}
|
| 332 |
return;
|
| 333 |
}
|
| 334 |
|
| 335 |
if (Q->ne[1] <= 4) {
|
| 336 |
-
constexpr int cols_per_block
|
| 337 |
-
constexpr int parallel_blocks = 4;
|
| 338 |
if (logit_softcap == 0.0f) {
|
| 339 |
constexpr bool use_logit_softcap = false;
|
| 340 |
-
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block,
|
| 341 |
} else {
|
| 342 |
constexpr bool use_logit_softcap = true;
|
| 343 |
-
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block,
|
| 344 |
}
|
| 345 |
return;
|
| 346 |
}
|
| 347 |
|
| 348 |
-
|
| 349 |
-
constexpr int cols_per_block = 8;
|
| 350 |
-
constexpr int parallel_blocks = 4;
|
| 351 |
-
if (logit_softcap == 0.0f) {
|
| 352 |
-
constexpr bool use_logit_softcap = false;
|
| 353 |
-
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
| 354 |
-
} else {
|
| 355 |
-
constexpr bool use_logit_softcap = true;
|
| 356 |
-
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
| 357 |
-
}
|
| 358 |
-
return;
|
| 359 |
-
}
|
| 360 |
-
|
| 361 |
-
constexpr int cols_per_block = 8;
|
| 362 |
-
constexpr int parallel_blocks = 1;
|
| 363 |
if (logit_softcap == 0.0f) {
|
| 364 |
constexpr bool use_logit_softcap = false;
|
| 365 |
-
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block,
|
| 366 |
} else {
|
| 367 |
constexpr bool use_logit_softcap = true;
|
| 368 |
-
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block,
|
| 369 |
}
|
| 370 |
}
|
| 371 |
|
|
|
|
| 1 |
#include "common.cuh"
|
| 2 |
#include "fattn-common.cuh"
|
| 3 |
|
| 4 |
+
template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
|
| 5 |
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
| 6 |
__launch_bounds__(D, 1)
|
| 7 |
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
|
|
|
| 55 |
constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
|
| 56 |
constexpr dequantize_1_f32_t dequantize_1_v = get_dequantize_1_f32(type_V);
|
| 57 |
|
| 58 |
+
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
|
|
|
|
| 59 |
|
| 60 |
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
| 61 |
+
Q += nb02* blockIdx.z + nb01*ic0;
|
| 62 |
+
K += nb12*(blockIdx.z / gqa_ratio);
|
| 63 |
+
V += nb22*(blockIdx.z / gqa_ratio); // K and V have same shape
|
| 64 |
const half * maskh = (const half *) mask + ne11*ic0;
|
| 65 |
|
| 66 |
+
const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
|
| 67 |
|
| 68 |
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
|
| 69 |
constexpr int nwarps = D / WARP_SIZE;
|
|
|
|
| 166 |
|
| 167 |
float VKQ[ncols] = {0.0f};
|
| 168 |
|
| 169 |
+
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
|
|
|
|
| 170 |
// Calculate KQ tile and keep track of new maximum KQ values:
|
| 171 |
|
| 172 |
float kqmax_new_arr[ncols];
|
|
|
|
| 266 |
kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
|
| 267 |
|
| 268 |
float dst_val = VKQ[j_VKQ];
|
| 269 |
+
if (gridDim.y == 1) {
|
| 270 |
dst_val /= kqsum[j_VKQ];
|
| 271 |
}
|
| 272 |
+
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
|
| 273 |
+
dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
|
| 274 |
}
|
| 275 |
|
| 276 |
+
if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
|
| 277 |
+
dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
|
| 278 |
}
|
| 279 |
#else
|
| 280 |
NO_DEVICE_CODE;
|
| 281 |
#endif // FLASH_ATTN_AVAILABLE
|
| 282 |
}
|
| 283 |
|
| 284 |
+
template <int D, int cols_per_block, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
|
| 285 |
void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
| 286 |
constexpr int nwarps = D/WARP_SIZE;
|
| 287 |
+
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, type_K, type_V, use_logit_softcap>;
|
| 288 |
constexpr bool need_f16_K = D != 128;
|
| 289 |
constexpr bool need_f16_V = D != 128 && D != 64;
|
| 290 |
constexpr size_t nbytes_shared = 0;
|
| 291 |
+
launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
|
| 292 |
}
|
| 293 |
|
| 294 |
template <int D, ggml_type type_K, ggml_type type_V>
|
|
|
|
| 305 |
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
| 306 |
|
| 307 |
if (Q->ne[1] == 1) {
|
| 308 |
+
constexpr int cols_per_block = 1;
|
|
|
|
| 309 |
if (logit_softcap == 0.0f) {
|
| 310 |
constexpr bool use_logit_softcap = false;
|
| 311 |
+
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
| 312 |
} else {
|
| 313 |
constexpr bool use_logit_softcap = true;
|
| 314 |
+
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
| 315 |
}
|
| 316 |
return;
|
| 317 |
}
|
| 318 |
|
| 319 |
if (Q->ne[1] == 2) {
|
| 320 |
+
constexpr int cols_per_block = 2;
|
|
|
|
| 321 |
if (logit_softcap == 0.0f) {
|
| 322 |
constexpr bool use_logit_softcap = false;
|
| 323 |
+
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
| 324 |
} else {
|
| 325 |
constexpr bool use_logit_softcap = true;
|
| 326 |
+
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
| 327 |
}
|
| 328 |
return;
|
| 329 |
}
|
| 330 |
|
| 331 |
if (Q->ne[1] <= 4) {
|
| 332 |
+
constexpr int cols_per_block = 4;
|
|
|
|
| 333 |
if (logit_softcap == 0.0f) {
|
| 334 |
constexpr bool use_logit_softcap = false;
|
| 335 |
+
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
| 336 |
} else {
|
| 337 |
constexpr bool use_logit_softcap = true;
|
| 338 |
+
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
| 339 |
}
|
| 340 |
return;
|
| 341 |
}
|
| 342 |
|
| 343 |
+
constexpr int cols_per_block = 8;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 344 |
if (logit_softcap == 0.0f) {
|
| 345 |
constexpr bool use_logit_softcap = false;
|
| 346 |
+
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
| 347 |
} else {
|
| 348 |
constexpr bool use_logit_softcap = true;
|
| 349 |
+
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
|
| 350 |
}
|
| 351 |
}
|
| 352 |
|
ggml/src/ggml-cuda/fattn-wmma-f16.cu
CHANGED
|
@@ -18,7 +18,7 @@ namespace wmma = rocwmma;
|
|
| 18 |
#endif // FP16_MMA_AVAILABLE
|
| 19 |
|
| 20 |
// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
|
| 21 |
-
template<int D, int ncols, int nwarps, int VKQ_stride,
|
| 22 |
__launch_bounds__(nwarps*ggml_cuda_get_physical_warp_size(), 1)
|
| 23 |
static __global__ void flash_attn_ext_f16(
|
| 24 |
const char * __restrict__ Q,
|
|
@@ -67,8 +67,7 @@ static __global__ void flash_attn_ext_f16(
|
|
| 67 |
|
| 68 |
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
| 69 |
|
| 70 |
-
const int ic0 = ncols*
|
| 71 |
-
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
|
| 72 |
|
| 73 |
static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
|
| 74 |
static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
|
|
@@ -91,16 +90,16 @@ static __global__ void flash_attn_ext_f16(
|
|
| 91 |
constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
|
| 92 |
|
| 93 |
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
| 94 |
-
const float * Q_f = (const float *) (Q + nb02* blockIdx.
|
| 95 |
-
const half * K_h = (const half *) (K + nb12*(blockIdx.
|
| 96 |
-
const half * V_h = (const half *) (V + nb12*(blockIdx.
|
| 97 |
const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0;
|
| 98 |
const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2);
|
| 99 |
|
| 100 |
const int stride_Q = nb01 / sizeof(float);
|
| 101 |
const int stride_KV = nb11 / sizeof(half);
|
| 102 |
|
| 103 |
-
const float slopef = get_alibi_slope(max_bias, blockIdx.
|
| 104 |
const half slopeh = __float2half(slopef);
|
| 105 |
const half2 slope2 = make_half2(slopef, slopef);
|
| 106 |
|
|
@@ -176,7 +175,7 @@ static __global__ void flash_attn_ext_f16(
|
|
| 176 |
__syncthreads();
|
| 177 |
|
| 178 |
// Iterate over ne11 == previous tokens:
|
| 179 |
-
for (int k_VKQ_0 =
|
| 180 |
// Calculate tile of KQ:
|
| 181 |
#pragma unroll
|
| 182 |
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
|
|
@@ -395,7 +394,7 @@ static __global__ void flash_attn_ext_f16(
|
|
| 395 |
if (ic0 + j_VKQ >= ne01) {
|
| 396 |
return;
|
| 397 |
}
|
| 398 |
-
const int j_dst = (ic0 + j_VKQ)*
|
| 399 |
|
| 400 |
float KQ_rowsum_j;
|
| 401 |
if (std::is_same<KQ_acc_t, float>::value) {
|
|
@@ -411,13 +410,13 @@ static __global__ void flash_attn_ext_f16(
|
|
| 411 |
break;
|
| 412 |
}
|
| 413 |
float dst_val = VKQ[j_VKQ*D_padded + i];
|
| 414 |
-
if (
|
| 415 |
dst_val /= KQ_rowsum_j;
|
| 416 |
}
|
| 417 |
-
dst[j_dst*gridDim.
|
| 418 |
}
|
| 419 |
|
| 420 |
-
if (
|
| 421 |
continue;
|
| 422 |
}
|
| 423 |
|
|
@@ -428,7 +427,7 @@ static __global__ void flash_attn_ext_f16(
|
|
| 428 |
dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
|
| 429 |
}
|
| 430 |
dst_meta_val.y = KQ_rowsum_j;
|
| 431 |
-
dst_meta[(ic0 + j_VKQ)*gridDim.
|
| 432 |
}
|
| 433 |
#else
|
| 434 |
NO_DEVICE_CODE;
|
|
@@ -462,60 +461,26 @@ static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");
|
|
| 462 |
template <int D, int cols_per_block, typename KQ_acc_t>
|
| 463 |
void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
| 464 |
const ggml_tensor * KQV = dst;
|
| 465 |
-
const ggml_tensor * Q = dst->src[0];
|
| 466 |
|
| 467 |
constexpr int nwarps = 4;
|
| 468 |
|
| 469 |
constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
|
| 470 |
-
const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
|
| 471 |
-
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
|
| 472 |
const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
|
| 473 |
|
| 474 |
float logit_softcap;
|
| 475 |
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
| 476 |
|
| 477 |
-
if (4*blocks_num_pb1 < 2*nsm) {
|
| 478 |
-
constexpr int parallel_blocks = 4;
|
| 479 |
-
fattn_kernel_t fattn_kernel;
|
| 480 |
-
if (logit_softcap == 0.0f) {
|
| 481 |
-
constexpr bool use_logit_softcap = false;
|
| 482 |
-
fattn_kernel = flash_attn_ext_f16<
|
| 483 |
-
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
| 484 |
-
} else {
|
| 485 |
-
constexpr bool use_logit_softcap = true;
|
| 486 |
-
fattn_kernel = flash_attn_ext_f16<
|
| 487 |
-
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
| 488 |
-
}
|
| 489 |
-
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true, warp_size);
|
| 490 |
-
return;
|
| 491 |
-
}
|
| 492 |
-
if (2*blocks_num_pb1 < 2*nsm) {
|
| 493 |
-
constexpr int parallel_blocks = 2;
|
| 494 |
-
fattn_kernel_t fattn_kernel;
|
| 495 |
-
if (logit_softcap == 0.0f) {
|
| 496 |
-
constexpr bool use_logit_softcap = false;
|
| 497 |
-
fattn_kernel = flash_attn_ext_f16<
|
| 498 |
-
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
| 499 |
-
} else {
|
| 500 |
-
constexpr bool use_logit_softcap = true;
|
| 501 |
-
fattn_kernel = flash_attn_ext_f16<
|
| 502 |
-
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
| 503 |
-
}
|
| 504 |
-
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true, warp_size);
|
| 505 |
-
return;
|
| 506 |
-
}
|
| 507 |
-
constexpr int parallel_blocks = 1;
|
| 508 |
fattn_kernel_t fattn_kernel;
|
| 509 |
if (logit_softcap == 0.0f) {
|
| 510 |
constexpr bool use_logit_softcap = false;
|
| 511 |
fattn_kernel = flash_attn_ext_f16<
|
| 512 |
-
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m),
|
| 513 |
} else {
|
| 514 |
constexpr bool use_logit_softcap = true;
|
| 515 |
fattn_kernel = flash_attn_ext_f16<
|
| 516 |
-
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m),
|
| 517 |
}
|
| 518 |
-
launch_fattn<D, cols_per_block, 1,
|
| 519 |
}
|
| 520 |
|
| 521 |
void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
|
| 18 |
#endif // FP16_MMA_AVAILABLE
|
| 19 |
|
| 20 |
// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
|
| 21 |
+
template<int D, int ncols, int nwarps, int VKQ_stride, typename KQ_acc_t, bool use_logit_softcap>
|
| 22 |
__launch_bounds__(nwarps*ggml_cuda_get_physical_warp_size(), 1)
|
| 23 |
static __global__ void flash_attn_ext_f16(
|
| 24 |
const char * __restrict__ Q,
|
|
|
|
| 67 |
|
| 68 |
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
| 69 |
|
| 70 |
+
const int ic0 = ncols*blockIdx.x; // Index of the first Q/QKV column to work on.
|
|
|
|
| 71 |
|
| 72 |
static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
|
| 73 |
static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
|
|
|
|
| 90 |
constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
|
| 91 |
|
| 92 |
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
| 93 |
+
const float * Q_f = (const float *) (Q + nb02* blockIdx.z + nb01*ic0);
|
| 94 |
+
const half * K_h = (const half *) (K + nb12*(blockIdx.z / gqa_ratio));
|
| 95 |
+
const half * V_h = (const half *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
|
| 96 |
const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0;
|
| 97 |
const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2);
|
| 98 |
|
| 99 |
const int stride_Q = nb01 / sizeof(float);
|
| 100 |
const int stride_KV = nb11 / sizeof(half);
|
| 101 |
|
| 102 |
+
const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
|
| 103 |
const half slopeh = __float2half(slopef);
|
| 104 |
const half2 slope2 = make_half2(slopef, slopef);
|
| 105 |
|
|
|
|
| 175 |
__syncthreads();
|
| 176 |
|
| 177 |
// Iterate over ne11 == previous tokens:
|
| 178 |
+
for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE) {
|
| 179 |
// Calculate tile of KQ:
|
| 180 |
#pragma unroll
|
| 181 |
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
|
|
|
|
| 394 |
if (ic0 + j_VKQ >= ne01) {
|
| 395 |
return;
|
| 396 |
}
|
| 397 |
+
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
|
| 398 |
|
| 399 |
float KQ_rowsum_j;
|
| 400 |
if (std::is_same<KQ_acc_t, float>::value) {
|
|
|
|
| 410 |
break;
|
| 411 |
}
|
| 412 |
float dst_val = VKQ[j_VKQ*D_padded + i];
|
| 413 |
+
if (gridDim.y == 1) {
|
| 414 |
dst_val /= KQ_rowsum_j;
|
| 415 |
}
|
| 416 |
+
dst[j_dst*gridDim.z*D + blockIdx.z*D + i] = dst_val;
|
| 417 |
}
|
| 418 |
|
| 419 |
+
if (gridDim.y == 1 || threadIdx.x != 0) {
|
| 420 |
continue;
|
| 421 |
}
|
| 422 |
|
|
|
|
| 427 |
dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
|
| 428 |
}
|
| 429 |
dst_meta_val.y = KQ_rowsum_j;
|
| 430 |
+
dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = dst_meta_val;
|
| 431 |
}
|
| 432 |
#else
|
| 433 |
NO_DEVICE_CODE;
|
|
|
|
| 461 |
template <int D, int cols_per_block, typename KQ_acc_t>
|
| 462 |
void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
| 463 |
const ggml_tensor * KQV = dst;
|
|
|
|
| 464 |
|
| 465 |
constexpr int nwarps = 4;
|
| 466 |
|
| 467 |
constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
|
|
|
|
|
|
|
| 468 |
const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
|
| 469 |
|
| 470 |
float logit_softcap;
|
| 471 |
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
| 472 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 473 |
fattn_kernel_t fattn_kernel;
|
| 474 |
if (logit_softcap == 0.0f) {
|
| 475 |
constexpr bool use_logit_softcap = false;
|
| 476 |
fattn_kernel = flash_attn_ext_f16<
|
| 477 |
+
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), KQ_acc_t, use_logit_softcap>;
|
| 478 |
} else {
|
| 479 |
constexpr bool use_logit_softcap = true;
|
| 480 |
fattn_kernel = flash_attn_ext_f16<
|
| 481 |
+
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), KQ_acc_t, use_logit_softcap>;
|
| 482 |
}
|
| 483 |
+
launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, 0, FATTN_KQ_STRIDE, true, true, false, warp_size);
|
| 484 |
}
|
| 485 |
|
| 486 |
void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
ggml/src/ggml-cuda/fattn.cu
CHANGED
|
@@ -281,13 +281,13 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
|
| 281 |
|
| 282 |
if (!fp16_mma_available(cc)) {
|
| 283 |
if (prec == GGML_PREC_DEFAULT) {
|
| 284 |
-
if (Q->ne[1] <= 8) {
|
| 285 |
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
| 286 |
} else {
|
| 287 |
ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
|
| 288 |
}
|
| 289 |
} else {
|
| 290 |
-
if (Q->ne[1] <= 8) {
|
| 291 |
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
| 292 |
} else {
|
| 293 |
ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
|
|
@@ -296,17 +296,17 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
|
| 296 |
return;
|
| 297 |
}
|
| 298 |
|
| 299 |
-
const
|
| 300 |
-
const bool
|
| 301 |
-
|
| 302 |
-
|
|
|
|
| 303 |
if (prec == GGML_PREC_DEFAULT) {
|
| 304 |
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
| 305 |
-
|
| 306 |
-
} else if(Q->ne[0] <= 128) {
|
| 307 |
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
| 308 |
-
return;
|
| 309 |
}
|
|
|
|
| 310 |
}
|
| 311 |
|
| 312 |
// The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
|
|
|
|
| 281 |
|
| 282 |
if (!fp16_mma_available(cc)) {
|
| 283 |
if (prec == GGML_PREC_DEFAULT) {
|
| 284 |
+
if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
|
| 285 |
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
| 286 |
} else {
|
| 287 |
ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
|
| 288 |
}
|
| 289 |
} else {
|
| 290 |
+
if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
|
| 291 |
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
| 292 |
} else {
|
| 293 |
ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
|
|
|
|
| 296 |
return;
|
| 297 |
}
|
| 298 |
|
| 299 |
+
const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
|
| 300 |
+
const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
|
| 301 |
+
const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < GGML_CUDA_CC_ADA_LOVELACE && !mma_needs_data_conversion;
|
| 302 |
+
const bool can_use_vector_kernel = (Q->ne[0] % (2*warp_size) == 0) && (prec == GGML_PREC_DEFAULT || Q->ne[0] <= 128);
|
| 303 |
+
if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
|
| 304 |
if (prec == GGML_PREC_DEFAULT) {
|
| 305 |
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
| 306 |
+
} else {
|
|
|
|
| 307 |
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
|
|
|
| 308 |
}
|
| 309 |
+
return;
|
| 310 |
}
|
| 311 |
|
| 312 |
// The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
|
ggml/src/ggml-cuda/ggml-cuda.cu
CHANGED
|
@@ -3230,6 +3230,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|
| 3230 |
#ifndef FLASH_ATTN_AVAILABLE
|
| 3231 |
return false;
|
| 3232 |
#endif // FLASH_ATTN_AVAILABLE
|
|
|
|
|
|
|
|
|
|
| 3233 |
if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
|
| 3234 |
return false;
|
| 3235 |
}
|
|
|
|
| 3230 |
#ifndef FLASH_ATTN_AVAILABLE
|
| 3231 |
return false;
|
| 3232 |
#endif // FLASH_ATTN_AVAILABLE
|
| 3233 |
+
if (op->src[0]->ne[3] != 1) {
|
| 3234 |
+
return false;
|
| 3235 |
+
}
|
| 3236 |
if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
|
| 3237 |
return false;
|
| 3238 |
}
|
ggml/src/ggml-cuda/vendors/hip.h
CHANGED
|
@@ -129,6 +129,7 @@
|
|
| 129 |
#define cudaGraph_t hipGraph_t
|
| 130 |
#define cudaStream_t hipStream_t
|
| 131 |
#define cudaSuccess hipSuccess
|
|
|
|
| 132 |
#define __trap() do { abort(); __builtin_unreachable(); } while(0)
|
| 133 |
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
|
| 134 |
#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
|
|
|
|
| 129 |
#define cudaGraph_t hipGraph_t
|
| 130 |
#define cudaStream_t hipStream_t
|
| 131 |
#define cudaSuccess hipSuccess
|
| 132 |
+
#define cudaOccupancyMaxActiveBlocksPerMultiprocessor hipOccupancyMaxActiveBlocksPerMultiprocessor
|
| 133 |
#define __trap() do { abort(); __builtin_unreachable(); } while(0)
|
| 134 |
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
|
| 135 |
#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
|
ggml/src/ggml-cuda/vendors/musa.h
CHANGED
|
@@ -134,5 +134,6 @@
|
|
| 134 |
#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
|
| 135 |
#define cudaStreamBeginCapture musaStreamBeginCapture
|
| 136 |
#define cudaStreamEndCapture musaStreamEndCapture
|
|
|
|
| 137 |
|
| 138 |
typedef mt_bfloat16 nv_bfloat16;
|
|
|
|
| 134 |
#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
|
| 135 |
#define cudaStreamBeginCapture musaStreamBeginCapture
|
| 136 |
#define cudaStreamEndCapture musaStreamEndCapture
|
| 137 |
+
#define cudaOccupancyMaxActiveBlocksPerMultiprocessor musaOccupancyMaxActiveBlocksPerMultiprocessor
|
| 138 |
|
| 139 |
typedef mt_bfloat16 nv_bfloat16;
|