JohannesGaessler committed
Commit a6d9f2d · 1 Parent(s): 195fe29

CUDA: faster large batch FA without tensor cores (llama/7314)

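For orientation before the file-by-file diff: the dispatch change in ggml-cuda/fattn.cu (last hunk below) routes batches of more than 8 Q columns to the new tile kernels whenever tensor cores cannot be used, and keeps the vec kernels on AMD. A minimal standalone sketch of that decision follows; the enum and function names are invented here for illustration and are not part of the commit.

#include <cstdint>

// Hypothetical summary of the kernel selection added by this commit (illustration only).
enum class fattn_kernel { vec_f32, vec_f16_no_mma, tile_f32, tile_f16, existing_path };

static fattn_kernel choose_fattn_kernel(bool is_amd, bool prec_default,
                                        bool fast_fp16, bool fp16_mma, int64_t n_q_cols) {
    if (is_amd) {
        // On AMD the tile kernels perform poorly, so the vec kernels are kept.
        return prec_default ? fattn_kernel::vec_f16_no_mma : fattn_kernel::vec_f32;
    }
    if (!fast_fp16) {
        // No fast fp16 support: f32 kernels, tile variant for larger batches.
        return n_q_cols <= 8 ? fattn_kernel::vec_f32 : fattn_kernel::tile_f32;
    }
    if (!fp16_mma) {
        // Fast fp16 but no tensor cores: the new f16 tile kernel for larger batches.
        return n_q_cols <= 8 ? fattn_kernel::vec_f16_no_mma : fattn_kernel::tile_f16;
    }
    return fattn_kernel::existing_path; // tensor cores available: unchanged by this commit
}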
ggml-cuda/fattn-tile-f16.cu ADDED
@@ -0,0 +1,395 @@
+ #include "common.cuh"
+ #include "fattn-common.cuh"
+ #include "fattn-tile-f16.cuh"
+
+ #define FATTN_KQ_STRIDE_TILE_F16 64
+
+ template<int D, int ncols, int nwarps, int parallel_blocks> // D == head size
+ #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+ __launch_bounds__(nwarps*WARP_SIZE, 1)
+ #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+ static __global__ void flash_attn_tile_ext_f16(
+         const char * __restrict__ Q,
+         const char * __restrict__ K,
+         const char * __restrict__ V,
+         const char * __restrict__ mask,
+         float * __restrict__ dst,
+         float2 * __restrict__ dst_meta,
+         const float scale,
+         const float max_bias,
+         const float m0,
+         const float m1,
+         const uint32_t n_head_log2,
+         const int ne00,
+         const int ne01,
+         const int ne02,
+         const int ne03,
+         const int ne10,
+         const int ne11,
+         const int ne12,
+         const int ne13,
+         const int ne31,
+         const int nb31,
+         const int nb01,
+         const int nb02,
+         const int nb03,
+         const int nb11,
+         const int nb12,
+         const int nb13,
+         const int ne0,
+         const int ne1,
+         const int ne2,
+         const int ne3) {
+ #if FP16_AVAILABLE
+     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+
+     const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
+     const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
+
+     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+     const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0);
+     const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio));
+     const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
+     const half * maskh = (const half *) mask + ne11*ic0;
+
+     const int stride_KV2 = nb11 / sizeof(half2);
+
+     half slopeh = __float2half(1.0f);
+
+     // ALiBi
+     if (max_bias > 0.0f) {
+         const uint32_t h = blockIdx.y;
+
+         const float base = h < n_head_log2 ? m0 : m1;
+         const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+         slopeh = __float2half(powf(base, exph));
+     }
+
+     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
+
+     __shared__ half KQ[ncols*FATTN_KQ_STRIDE_TILE_F16];
+     half2 * KQ2 = (half2 *) KQ;
+
+     __shared__ half2 KV_tmp[FATTN_KQ_STRIDE_TILE_F16][D/2 + 1]; // Pad D to avoid memory bank conflicts.
+
+     half kqmax[ncols/nwarps];
+ #pragma unroll
+     for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+         kqmax[j0/nwarps] = -HALF_MAX_HALF;
+     }
+     half2 kqsum[ncols/nwarps] = {{0.0f, 0.0f}};
+
+     half2 VKQ[ncols/nwarps][(D/2)/WARP_SIZE] = {{{0.0f, 0.0f}}};
+
+     // Convert Q to half2 and store in registers:
+     __shared__ half2 Q_h2[ncols][D/2];
+ #pragma unroll
+     for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+         const int j = j0 + threadIdx.y;
+
+ #pragma unroll
+         for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+             const int i = i0 + threadIdx.x;
+
+             const float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i];
+             Q_h2[j][i] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
+         }
+     }
+
+     __syncthreads();
+
+     const int k_start = parallel_blocks == 1 ? 0 : ip*FATTN_KQ_STRIDE_TILE_F16;
+     for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE_TILE_F16) {
+         // Calculate KQ tile and keep track of new maximum KQ values:
+
+         half kqmax_new[ncols/nwarps];
+ #pragma unroll
+         for (int j = 0; j < ncols/nwarps; ++j) {
+             kqmax_new[j] = kqmax[j];
+         }
+
+ #pragma unroll
+         for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += nwarps) {
+             const int i_KQ = i_KQ_0 + threadIdx.y;
+
+ #pragma unroll
+             for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
+                 const int k_KQ = k_KQ_0 + threadIdx.x;
+
+                 KV_tmp[i_KQ][k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
+             }
+         }
+
+         __syncthreads();
+
+         half2 sum2[FATTN_KQ_STRIDE_TILE_F16/WARP_SIZE][ncols/nwarps] = {{{0.0f, 0.0f}}};
+
+ #pragma unroll
+         for (int k_KQ = 0; k_KQ < D/2; ++k_KQ) {
+             half2 K_k[FATTN_KQ_STRIDE_TILE_F16/WARP_SIZE];
+             half2 Q_k[ncols/nwarps];
+
+ #pragma unroll
+             for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) {
+                 const int i_KQ = i_KQ_0 + threadIdx.x;
+
+                 K_k[i_KQ_0/WARP_SIZE] = KV_tmp[i_KQ][k_KQ];
+             }
+ #pragma unroll
+             for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
+                 const int j_KQ = j_KQ_0 + threadIdx.y;
+
+                 Q_k[j_KQ_0/nwarps] = Q_h2[j_KQ][k_KQ];
+             }
+
+ #pragma unroll
+             for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) {
+ #pragma unroll
+                 for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
+                     sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += K_k[i_KQ_0/WARP_SIZE]*Q_k[j_KQ_0/nwarps];
+                 }
+             }
+         }
+
+ #pragma unroll
+         for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) {
+             const int i_KQ = i_KQ_0 + threadIdx.x;
+
+ #pragma unroll
+             for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
+                 const int j_KQ = j_KQ_0 + threadIdx.y;
+
+                 half sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+                 sum += mask ? slopeh*maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
+
+                 kqmax_new[j_KQ_0/nwarps] = ggml_cuda_hmax(kqmax_new[j_KQ_0/nwarps], sum);
+
+                 KQ[j_KQ*FATTN_KQ_STRIDE_TILE_F16 + i_KQ] = sum;
+             }
+         }
+
+         __syncthreads();
+
+ #pragma unroll
+         for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+             const int j = j0 + threadIdx.y;
+
+             kqmax_new[j0/nwarps] = warp_reduce_max(kqmax_new[j0/nwarps]);
+             const half2 KQ_max_scale = __half2half2(hexp(kqmax[j0/nwarps] - kqmax_new[j0/nwarps]));
+             kqmax[j0/nwarps] = kqmax_new[j0/nwarps];
+
+ #pragma unroll
+             for (int i0 = 0; i0 < FATTN_KQ_STRIDE_TILE_F16/2; i0 += WARP_SIZE) {
+                 const int i = i0 + threadIdx.x;
+
+                 const half2 diff = KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + i] - __half2half2(kqmax[j0/nwarps]);
+                 const half2 val = h2exp(diff);
+                 kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + val;
+                 KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + i] = val;
+             }
+
+ #pragma unroll
+             for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                 VKQ[j0/nwarps][i0/WARP_SIZE] *= KQ_max_scale;
+             }
+         }
+
+         __syncthreads();
+
+ #pragma unroll
+         for (int k0 = 0; k0 < FATTN_KQ_STRIDE_TILE_F16; k0 += nwarps) {
+             const int k = k0 + threadIdx.y;
+
+ #pragma unroll
+             for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                 const int i = i0 + threadIdx.x;
+
+                 KV_tmp[k][i] = V_h2[(k_VKQ_0 + k)*stride_KV2 + i];
+             }
+         }
+
+         __syncthreads();
+
+ #pragma unroll
+         for (int k0 = 0; k0 < FATTN_KQ_STRIDE_TILE_F16; k0 += 2) {
+             half2 V_k[(D/2)/WARP_SIZE][2];
+             half2 KQ_k[ncols/nwarps];
+
+ #pragma unroll
+             for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                 const int i = i0 + threadIdx.x;
+
+                 V_k[i0/WARP_SIZE][0] = KV_tmp[k0 + 0][i];
+                 V_k[i0/WARP_SIZE][1] = KV_tmp[k0 + 1][i];
+             }
+ #pragma unroll
+             for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+                 const int j = j0 + threadIdx.y;
+
+                 KQ_k[j0/nwarps] = KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + k0/2];
+             }
+
+ #pragma unroll
+             for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+ #pragma unroll
+                 for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+                     VKQ[j0/nwarps][i0/WARP_SIZE] += V_k[i0/WARP_SIZE][0]* __low2half2(KQ_k[j0/nwarps]);
+                     VKQ[j0/nwarps][i0/WARP_SIZE] += V_k[i0/WARP_SIZE][1]*__high2half2(KQ_k[j0/nwarps]);
+                 }
+             }
+         }
+
+         __syncthreads();
+     }
+
+ #pragma unroll
+     for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
+         const int j_VKQ = j_VKQ_0 + threadIdx.y;
+
+         half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]);
+         kqsum_j = warp_reduce_sum(kqsum_j);
+
+ #pragma unroll
+         for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
+             const int i0 = i00 + 2*threadIdx.x;
+
+             half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
+             if (parallel_blocks == 1) {
+                 dst_val /= __half2half2(kqsum_j);
+             }
+             const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
+             dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = __low2float(dst_val);
+             dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = __high2float(dst_val);
+         }
+
+         if (parallel_blocks != 1 && threadIdx.x == 0) {
+             dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
+         }
+     }
+ #else
+     NO_DEVICE_CODE;
+ #endif // FP16_AVAILABLE
+ }
+
+ template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_tile_f16(
+         const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
+         ggml_cuda_pool & pool, cudaStream_t main_stream
+ ) {
+     ggml_cuda_pool_alloc<float> dst_tmp(pool);
+     ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
+
+     if (parallel_blocks > 1) {
+         dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
+         dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
+     }
+
+     constexpr int nwarps = 8;
+     const dim3 block_dim(WARP_SIZE, nwarps, 1);
+     const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
+     const int shmem = 0;
+
+     float scale = 1.0f;
+     float max_bias = 0.0f;
+
+     memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
+     memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
+
+     const uint32_t n_head = Q->ne[2];
+     const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+     const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
+     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+     flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>
+         <<<blocks_num, block_dim, shmem, main_stream>>> (
+         (const char *) Q->data,
+         (const char *) K->data,
+         (const char *) V->data,
+         mask ? ((const char *) mask->data) : nullptr,
+         parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
+         scale, max_bias, m0, m1, n_head_log2,
+         Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+         K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+         mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
+         Q->nb[1], Q->nb[2], Q->nb[3],
+         K->nb[1], K->nb[2], K->nb[3],
+         KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+     );
+     CUDA_CHECK(cudaGetLastError());
+
+     if (parallel_blocks == 1) {
+         return;
+     }
+
+     const dim3 block_dim_combine(D, 1, 1);
+     const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
+     const int shmem_combine = 0;
+
+     flash_attn_combine_results<D, parallel_blocks>
+         <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
+         (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
+     CUDA_CHECK(cudaGetLastError());
+ }
+
+ void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+     const ggml_tensor * Q = dst->src[0];
+     const ggml_tensor * K = dst->src[1];
+     const ggml_tensor * V = dst->src[2];
+
+     const ggml_tensor * mask = dst->src[3];
+
+     ggml_tensor * KQV = dst;
+
+     const int32_t precision = KQV->op_params[2];
+     GGML_ASSERT(precision == GGML_PREC_DEFAULT);
+     GGML_ASSERT(Q->ne[0] == 64 || Q->ne[0] == 128 && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
+
+     if (Q->ne[1] <= 16) {
+         constexpr int cols_per_block = 16;
+         constexpr int parallel_blocks = 4;
+         switch (Q->ne[0]) {
+             case 64:
+                 launch_fattn_tile_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+                 break;
+             case 128:
+                 launch_fattn_tile_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+                 break;
+             default:
+                 GGML_ASSERT(false);
+                 break;
+         }
+         return;
+     }
+
+     if (Q->ne[1] <= 32) {
+         constexpr int cols_per_block = 32;
+         constexpr int parallel_blocks = 4;
+         switch (Q->ne[0]) {
+             case 64:
+                 launch_fattn_tile_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+                 break;
+             case 128:
+                 launch_fattn_tile_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+                 break;
+             default:
+                 GGML_ASSERT(false);
+                 break;
+         }
+         return;
+     }
+
+     constexpr int cols_per_block = 32;
+     constexpr int parallel_blocks = 1;
+     switch (Q->ne[0]) {
+         case 64:
+             launch_fattn_tile_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+             break;
+         case 128:
+             launch_fattn_tile_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+             break;
+         default:
+             GGML_ASSERT(false);
+             break;
+     }
+ }
ggml-cuda/fattn-tile-f16.cuh ADDED
@@ -0,0 +1,3 @@
+ #include "common.cuh"
+
+ void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
ggml-cuda/fattn-tile-f32.cu ADDED
@@ -0,0 +1,393 @@
+ #include "common.cuh"
+ #include "fattn-common.cuh"
+ #include "fattn-tile-f32.cuh"
+
+ #define FATTN_KQ_STRIDE_TILE_F32 32
+
+ template<int D, int ncols, int nwarps, int parallel_blocks> // D == head size
+ #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+ __launch_bounds__(nwarps*WARP_SIZE, 1)
+ #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+ static __global__ void flash_attn_tile_ext_f32(
+         const char * __restrict__ Q,
+         const char * __restrict__ K,
+         const char * __restrict__ V,
+         const char * __restrict__ mask,
+         float * __restrict__ dst,
+         float2 * __restrict__ dst_meta,
+         const float scale,
+         const float max_bias,
+         const float m0,
+         const float m1,
+         const uint32_t n_head_log2,
+         const int ne00,
+         const int ne01,
+         const int ne02,
+         const int ne03,
+         const int ne10,
+         const int ne11,
+         const int ne12,
+         const int ne13,
+         const int ne31,
+         const int nb31,
+         const int nb01,
+         const int nb02,
+         const int nb03,
+         const int nb11,
+         const int nb12,
+         const int nb13,
+         const int ne0,
+         const int ne1,
+         const int ne2,
+         const int ne3) {
+     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+
+     const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
+     const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
+
+     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+     const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0);
+     const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio));
+     const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
+     const half * maskh = (const half *) mask + ne11*ic0;
+
+     const int stride_KV2 = nb11 / sizeof(half2);
+
+     float slope = 1.0f;
+
+     // ALiBi
+     if (max_bias > 0.0f) {
+         const uint32_t h = blockIdx.y;
+
+         const float base = h < n_head_log2 ? m0 : m1;
+         const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+         slope = powf(base, exph);
+     }
+
+     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
+
+     __shared__ float KQ[ncols*FATTN_KQ_STRIDE_TILE_F32];
+
+     __shared__ float KV_tmp[FATTN_KQ_STRIDE_TILE_F32][D + 1]; // Pad D to avoid memory bank conflicts.
+     float2 * KV_tmp2 = (float2 *) KV_tmp;
+
+     float kqmax[ncols/nwarps];
+ #pragma unroll
+     for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+         kqmax[j0/nwarps] = -FLT_MAX/2.0f;
+     }
+     float kqsum[ncols/nwarps] = {0.0f};
+
+     float2 VKQ[ncols/nwarps][(D/2)/WARP_SIZE] = {{{0.0f, 0.0f}}};
+
+     // Convert Q to half2 and store in registers:
+     __shared__ float Q_f[ncols][D];
+ #pragma unroll
+     for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+         const int j = j0 + threadIdx.y;
+
+ #pragma unroll
+         for (int i0 = 0; i0 < D; i0 += 2*WARP_SIZE) {
+             float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i0/2 + threadIdx.x];
+             Q_f[j][i0 + 0*WARP_SIZE + threadIdx.x] = tmp.x * scale;
+             Q_f[j][i0 + 1*WARP_SIZE + threadIdx.x] = tmp.y * scale;
+         }
+     }
+
+     __syncthreads();
+
+     const int k_start = parallel_blocks == 1 ? 0 : ip*FATTN_KQ_STRIDE_TILE_F32;
+     for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE_TILE_F32) {
+         // Calculate KQ tile and keep track of new maximum KQ values:
+
+         float kqmax_new[ncols/nwarps];
+ #pragma unroll
+         for (int j = 0; j < ncols/nwarps; ++j) {
+             kqmax_new[j] = kqmax[j];
+         }
+
+ #pragma unroll
+         for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += nwarps) {
+             const int i_KQ = i_KQ_0 + threadIdx.y;
+
+ #pragma unroll
+             for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 2*WARP_SIZE) {
+                 const half2 tmp = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x];
+                 KV_tmp[i_KQ][k_KQ_0 + 0*WARP_SIZE + threadIdx.x] = __low2float(tmp);
+                 KV_tmp[i_KQ][k_KQ_0 + 1*WARP_SIZE + threadIdx.x] = __high2float(tmp);
+             }
+         }
+
+         __syncthreads();
+
+         float sum[FATTN_KQ_STRIDE_TILE_F32/WARP_SIZE][ncols/nwarps] = {{0.0f}};
+
+ #pragma unroll
+         for (int k_KQ = 0; k_KQ < D; ++k_KQ) {
+             float K_k[FATTN_KQ_STRIDE_TILE_F32/WARP_SIZE];
+             float Q_k[ncols/nwarps];
+
+ #pragma unroll
+             for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += WARP_SIZE) {
+                 const int i_KQ = i_KQ_0 + threadIdx.x;
+
+                 K_k[i_KQ_0/WARP_SIZE] = KV_tmp[i_KQ][k_KQ];
+             }
+ #pragma unroll
+             for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
+                 const int j_KQ = j_KQ_0 + threadIdx.y;
+
+                 Q_k[j_KQ_0/nwarps] = Q_f[j_KQ][k_KQ];
+             }
+
+ #pragma unroll
+             for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += WARP_SIZE) {
+ #pragma unroll
+                 for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
+                     sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += K_k[i_KQ_0/WARP_SIZE] * Q_k[j_KQ_0/nwarps];
+                 }
+             }
+         }
+
+ #pragma unroll
+         for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += WARP_SIZE) {
+             const int i_KQ = i_KQ_0 + threadIdx.x;
+
+ #pragma unroll
+             for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
+                 const int j_KQ = j_KQ_0 + threadIdx.y;
+
+                 sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
+
+                 kqmax_new[j_KQ_0/nwarps] = fmaxf(kqmax_new[j_KQ_0/nwarps], sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+
+                 KQ[j_KQ*FATTN_KQ_STRIDE_TILE_F32 + i_KQ] = sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps];
+             }
+         }
+
+         __syncthreads();
+
+ #pragma unroll
+         for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+             const int j = j0 + threadIdx.y;
+
+             kqmax_new[j0/nwarps] = warp_reduce_max(kqmax_new[j0/nwarps]);
+             const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new[j0/nwarps]);
+             kqmax[j0/nwarps] = kqmax_new[j0/nwarps];
+
+             float kqsum_add = 0.0f;
+ #pragma unroll
+             for (int i0 = 0; i0 < FATTN_KQ_STRIDE_TILE_F32; i0 += WARP_SIZE) {
+                 const int i = i0 + threadIdx.x;
+
+                 const float diff = KQ[j*FATTN_KQ_STRIDE_TILE_F32 + i] - kqmax[j0/nwarps];
+                 const float val = expf(diff);
+                 kqsum_add += val;
+                 KQ[j*FATTN_KQ_STRIDE_TILE_F32 + i] = val;
+             }
+             kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + kqsum_add;
+
+ #pragma unroll
+             for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                 VKQ[j0/nwarps][i0/WARP_SIZE].x *= KQ_max_scale;
+                 VKQ[j0/nwarps][i0/WARP_SIZE].y *= KQ_max_scale;
+             }
+         }
+
+         __syncthreads();
+
+ #pragma unroll
+         for (int k0 = 0; k0 < FATTN_KQ_STRIDE_TILE_F32; k0 += nwarps) {
+             const int k = k0 + threadIdx.y;
+
+ #pragma unroll
+             for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                 const int i = i0 + threadIdx.x;
+
+                 KV_tmp2[k*(D/2) + i].x = __low2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]);
+                 KV_tmp2[k*(D/2) + i].y = __high2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]);
+             }
+         }
+
+         __syncthreads();
+
+ #pragma unroll
+         for (int k = 0; k < FATTN_KQ_STRIDE_TILE_F32; ++k) {
+             float2 V_k[(D/2)/WARP_SIZE];
+             float KQ_k[ncols/nwarps];
+
+ #pragma unroll
+             for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                 const int i = i0 + threadIdx.x;
+
+                 V_k[i0/WARP_SIZE] = KV_tmp2[k*(D/2) + i];
+             }
+ #pragma unroll
+             for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+                 const int j = j0 + threadIdx.y;
+
+                 KQ_k[j0/nwarps] = KQ[j*FATTN_KQ_STRIDE_TILE_F32 + k];
+             }
+
+ #pragma unroll
+             for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+ #pragma unroll
+                 for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+                     VKQ[j0/nwarps][i0/WARP_SIZE].x += V_k[i0/WARP_SIZE].x*KQ_k[j0/nwarps];
+                     VKQ[j0/nwarps][i0/WARP_SIZE].y += V_k[i0/WARP_SIZE].y*KQ_k[j0/nwarps];
+                 }
+             }
+         }
+
+         __syncthreads();
+     }
+
+ #pragma unroll
+     for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
+         const int j_VKQ = j_VKQ_0 + threadIdx.y;
+
+         float kqsum_j = kqsum[j_VKQ_0/nwarps];
+         kqsum_j = warp_reduce_sum(kqsum_j);
+
+ #pragma unroll
+         for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
+             const int i0 = i00 + 2*threadIdx.x;
+
+             float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
+             if (parallel_blocks == 1) {
+                 dst_val.x /= kqsum_j;
+                 dst_val.y /= kqsum_j;
+             }
+             const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
+             dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = dst_val.x;
+             dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = dst_val.y;
+         }
+
+         if (parallel_blocks != 1 && threadIdx.x == 0) {
+             dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
+         }
+     }
+ }
+
+ template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_tile_f32(
+         const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
+         ggml_cuda_pool & pool, cudaStream_t main_stream
+ ) {
+     ggml_cuda_pool_alloc<float> dst_tmp(pool);
+     ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
+
+     if (parallel_blocks > 1) {
+         dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
+         dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
+     }
+
+     constexpr int nwarps = 8;
+     const dim3 block_dim(WARP_SIZE, nwarps, 1);
+     const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
+     const int shmem = 0;
+
+     float scale = 1.0f;
+     float max_bias = 0.0f;
+
+     memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
+     memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
+
+     const uint32_t n_head = Q->ne[2];
+     const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+     const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
+     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+     flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>
+         <<<blocks_num, block_dim, shmem, main_stream>>> (
+         (const char *) Q->data,
+         (const char *) K->data,
+         (const char *) V->data,
+         mask ? ((const char *) mask->data) : nullptr,
+         parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
+         scale, max_bias, m0, m1, n_head_log2,
+         Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+         K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+         mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
+         Q->nb[1], Q->nb[2], Q->nb[3],
+         K->nb[1], K->nb[2], K->nb[3],
+         KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+     );
+     CUDA_CHECK(cudaGetLastError());
+
+     if (parallel_blocks == 1) {
+         return;
+     }
+
+     const dim3 block_dim_combine(D, 1, 1);
+     const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
+     const int shmem_combine = 0;
+
+     flash_attn_combine_results<D, parallel_blocks>
+         <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
+         (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
+     CUDA_CHECK(cudaGetLastError());
+ }
+
+ void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+     const ggml_tensor * Q = dst->src[0];
+     const ggml_tensor * K = dst->src[1];
+     const ggml_tensor * V = dst->src[2];
+
+     const ggml_tensor * mask = dst->src[3];
+
+     ggml_tensor * KQV = dst;
+
+     const int32_t precision = KQV->op_params[2];
+     GGML_ASSERT(precision == GGML_PREC_DEFAULT);
+     GGML_ASSERT(Q->ne[0] == 64 || Q->ne[0] == 128 && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
+
+     if (Q->ne[1] <= 16) {
+         constexpr int cols_per_block = 16;
+         constexpr int parallel_blocks = 4;
+         switch (Q->ne[0]) {
+             case 64:
+                 launch_fattn_tile_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+                 break;
+             case 128:
+                 launch_fattn_tile_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+                 break;
+             default:
+                 GGML_ASSERT(false);
+                 break;
+         }
+         return;
+     }
+
+     if (Q->ne[1] <= 32) {
+         constexpr int cols_per_block = 32;
+         constexpr int parallel_blocks = 4;
+         switch (Q->ne[0]) {
+             case 64:
+                 launch_fattn_tile_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+                 break;
+             case 128:
+                 launch_fattn_tile_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+                 break;
+             default:
+                 GGML_ASSERT(false);
+                 break;
+         }
+         return;
+     }
+
+     constexpr int cols_per_block = 32;
+     constexpr int parallel_blocks = 1;
+     switch (Q->ne[0]) {
+         case 64:
+             launch_fattn_tile_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+             break;
+         case 128:
+             launch_fattn_tile_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+             break;
+         default:
+             GGML_ASSERT(false);
+             break;
+     }
+ }
ggml-cuda/fattn-tile-f32.cuh ADDED
@@ -0,0 +1,3 @@
+ #include "common.cuh"
+
+ void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
ggml-cuda/fattn.cu CHANGED
@@ -1,5 +1,7 @@
  #include "common.cuh"
  #include "fattn-common.cuh"
+ #include "fattn-tile-f16.cuh"
+ #include "fattn-tile-f32.cuh"
  #include "fattn-vec-f16.cuh"
  #include "fattn-vec-f32.cuh"
  #include "fattn.cuh"
@@ -88,7 +90,7 @@ static __global__ void flash_attn_ext_f16(

      // ALiBi
      if (max_bias > 0.0f) {
-         const int h = blockIdx.y;
+         const uint32_t h = blockIdx.y;

          const float base = h < n_head_log2 ? m0 : m1;
          const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
@@ -541,13 +543,31 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst

      const int32_t precision = KQV->op_params[2];

+     // On AMD the tile kernels perform poorly, use the vec kernel instead:
+     if (cc >= CC_OFFSET_AMD) {
+         if (precision == GGML_PREC_DEFAULT) {
+             ggml_cuda_flash_attn_ext_vec_f16_no_mma(ctx, dst);
+         } else {
+             ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
+         }
+         return;
+     }
+
      if (!fast_fp16_available(cc)) {
-         ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
+         if (Q->ne[1] <= 8) {
+             ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
+         } else {
+             ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
+         }
          return;
      }

      if (!fp16_mma_available(cc)) {
-         ggml_cuda_flash_attn_ext_vec_f16_no_mma(ctx, dst);
+         if (Q->ne[1] <= 8) {
+             ggml_cuda_flash_attn_ext_vec_f16_no_mma(ctx, dst);
+         } else {
+             ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
+         }
          return;
      }
573