typedef void (* fattn_kernel_t)(
        const char * __restrict__ Q,
        const char * __restrict__ K,
        const char * __restrict__ V,
        const char * __restrict__ mask,
        float      * __restrict__ dst,
        float2     * __restrict__ dst_meta,
        const float scale,
        const float max_bias,
        const float m0,
        const float m1,
        const uint32_t n_head_log2,
        const int ne00,
        const int ne01,
        const int ne02,
        const int ne03,
        const int ne10,
        const int ne11,
        const int ne12,
        const int ne13,
        const int ne31,
        const int nb31,
        const int nb01,
        const int nb02,
        const int nb03,
        const int nb11,
        const int nb12,
        const int nb13,
        const int ne0,
        const int ne1,
        const int ne2,
        const int ne3);
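
// Argument convention (taken from the launch in launch_fattn below): ne00..ne03 and
// nb01..nb03 are the dims and byte strides of Q, ne10..ne13 and nb11..nb13 those of K,
// ne31/nb31 describe the mask, and ne0..ne3 are the dims of dst. scale is the softmax
// scale read from op_params; max_bias, m0, m1 and n_head_log2 are the ALiBi parameters.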
| template<int D, int parallel_blocks> // D == head size | |
| __launch_bounds__(D, 1) | |
| static __global__ void flash_attn_combine_results( | |
| const float * __restrict__ VKQ_parts, | |
| const float2 * __restrict__ VKQ_meta, | |
| float * __restrict__ dst) { | |
| VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x; | |
| VKQ_meta += parallel_blocks * gridDim.y*blockIdx.x; | |
| dst += D * gridDim.y*blockIdx.x; | |
| const int tid = threadIdx.x; | |
| __builtin_assume(tid < D); | |
| __shared__ float2 meta[parallel_blocks]; | |
| if (tid < 2*parallel_blocks) { | |
| ((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.y*(2*parallel_blocks) + tid]; | |
| } | |
| __syncthreads(); | |
| float kqmax = meta[0].x; | |
| for (int l = 1; l < parallel_blocks; ++l) { | |
| kqmax = max(kqmax, meta[l].x); | |
| } | |
| float VKQ_numerator = 0.0f; | |
| float VKQ_denominator = 0.0f; | |
| for (int l = 0; l < parallel_blocks; ++l) { | |
| const float diff = meta[l].x - kqmax; | |
| const float KQ_max_scale = expf(diff); | |
| const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD); | |
| *((uint32_t *) &KQ_max_scale) &= ftz_mask; | |
| VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid]; | |
| VKQ_denominator += KQ_max_scale * meta[l].y; | |
| } | |
| dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator; | |
| } | |
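
// Reference-only sketch (not part of ggml, name combine_partial_results_reference is
// illustrative): the same rescaling math on the host for a single output element that
// was split across parallel_blocks partial results. parts[l] is the partial V*softmax(KQ)
// accumulator and meta[l] the {KQ max, softmax denominator} pair, mirroring dst_meta above.
static inline float combine_partial_results_reference(const float * parts, const float2 * meta, int parallel_blocks) {
    float kqmax = meta[0].x;
    for (int l = 1; l < parallel_blocks; ++l) {
        kqmax = fmaxf(kqmax, meta[l].x);
    }

    float numerator   = 0.0f;
    float denominator = 0.0f;
    for (int l = 0; l < parallel_blocks; ++l) {
        const float s = expf(meta[l].x - kqmax); // rescale onto the common maximum
        numerator   += s * parts[l];
        denominator += s * meta[l].y;
    }
    return numerator / denominator;
}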
| template <int D, int parallel_blocks> | |
| void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, int nwarps, int cols_per_block) { | |
| const ggml_tensor * Q = dst->src[0]; | |
| const ggml_tensor * K = dst->src[1]; | |
| const ggml_tensor * V = dst->src[2]; | |
| const ggml_tensor * mask = dst->src[3]; | |
| ggml_tensor * KQV = dst; | |
| GGML_ASSERT(Q->type == GGML_TYPE_F32); | |
| GGML_ASSERT(K->type == GGML_TYPE_F16); | |
| GGML_ASSERT(V->type == GGML_TYPE_F16); | |
| GGML_ASSERT(KQV->type == GGML_TYPE_F32); | |
| GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16); | |
| GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) && | |
| "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big"); | |
| GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding."); | |
| ggml_cuda_pool & pool = ctx.pool(); | |
| cudaStream_t main_stream = ctx.stream(); | |
| ggml_cuda_pool_alloc<float> dst_tmp(pool); | |
| ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool); | |
| if (parallel_blocks > 1) { | |
| dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); | |
| dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); | |
| } | |
| const dim3 block_dim(WARP_SIZE, nwarps, 1); | |
| const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]); | |
| const int shmem = 0; | |
| float scale = 1.0f; | |
| float max_bias = 0.0f; | |
| memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float)); | |
| memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float)); | |
| const uint32_t n_head = Q->ne[2]; | |
| const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); | |
| const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); | |
| const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); | |
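    // The per-head slope is derived from these bases inside the kernels, along the lines
    // of the usual ALiBi scheme (a sketch, not a quote of the kernel code):
    //   slope(h) = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
    // with max_bias == 0.0f both bases are 1.0f, so every head gets a slope of 1.0f.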
| fattn_kernel<<<blocks_num, block_dim, shmem, main_stream>>>( | |
| (const char *) Q->data, | |
| (const char *) K->data, | |
| (const char *) V->data, | |
| mask ? ((const char *) mask->data) : nullptr, | |
| (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, | |
| scale, max_bias, m0, m1, n_head_log2, | |
| Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], | |
| K->ne[0], K->ne[1], K->ne[2], K->ne[3], | |
| mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, | |
| Q->nb[1], Q->nb[2], Q->nb[3], | |
| K->nb[1], K->nb[2], K->nb[3], | |
| KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] | |
| ); | |
| CUDA_CHECK(cudaGetLastError()); | |
| if ((parallel_blocks) == 1) { | |
| return; | |
| } | |
| const dim3 block_dim_combine(D, 1, 1); | |
| const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z); | |
| const int shmem_combine = 0; | |
| flash_attn_combine_results<D, parallel_blocks> | |
| <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>> | |
| (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); | |
| CUDA_CHECK(cudaGetLastError()); | |
| } | |
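
// Illustrative caller sketch (kept as a comment because the kernel name is hypothetical
// and not defined in this file): a head-size-specific wrapper is expected to fix the
// template parameters and hand a matching fattn_kernel_t instantiation to launch_fattn.
//
//   template <int cols_per_block, int parallel_blocks>
//   static void launch_fattn_f16_64(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
//       constexpr int D      = 64; // head size
//       constexpr int nwarps = 8;  // warps per block
//       fattn_kernel_t kernel = flash_attn_ext_f16_example<D, cols_per_block, nwarps, parallel_blocks>;
//       launch_fattn<D, parallel_blocks>(ctx, dst, kernel, nwarps, cols_per_block);
//   }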