Commit ad8b504 · Parent: 97d9aa6

vulkan: move common FA code to flash_attn_base.comp (llama/13556)

* vulkan: move common FA code to flash_attn_base.comp
* vulkan: move common FA index/stride setup code to flash_attn_base.comp
* build fix
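The index/stride setup being consolidated includes the split-k partitioning of the KV dimension. As a rough standalone C sketch (not ggml code) of what the new init_indices() computes per split, where how split_kv is chosen host-side is an assumption of this example:

/* Sketch of the split-k partitioning mirrored from init_indices():
 * each split owns a contiguous range [start_j, end_j) of Bc-wide
 * column blocks of the KV dimension. */
#include <stdint.h>
#include <stdio.h>

#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))

int main(void) {
    const uint32_t KV = 4096, Bc = 32, k_num = 4;
    const uint32_t split_kv = CEIL_DIV(KV, k_num); /* assumed host-side choice */

    for (uint32_t split_k_index = 0; split_k_index < k_num; split_k_index++) {
        uint32_t hi      = (split_k_index + 1) * split_kv;
        uint32_t start_j = split_k_index * split_kv / Bc;
        uint32_t end_j   = CEIL_DIV(KV < hi ? KV : hi, Bc);
        printf("split %u: j in [%u, %u)\n", split_k_index, start_j, end_j);
    }
    return 0;
}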
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
CHANGED
@@ -9,60 +9,13 @@
 #extension GL_KHR_shader_subgroup_shuffle : enable
 
 #include "types.comp"
+#include "flash_attn_base.comp"
 
-layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
-
-layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
-layout (constant_id = 1) const uint32_t Br = 1;
-layout (constant_id = 2) const uint32_t Bc = 32;
-layout (constant_id = 3) const uint32_t D = 32;
-
-layout (constant_id = 5) const uint32_t D_split = 16;
 const uint32_t D_per_thread = D / D_split;
 
 const uint32_t cols_per_iter = WorkGroupSize / D_split;
 const uint32_t cols_per_thread = Bc / cols_per_iter;
 
-layout (push_constant) uniform parameter {
-    uint32_t N;
-    uint32_t KV;
-
-    uint32_t ne1;
-    uint32_t ne2;
-    uint32_t ne3;
-
-    uint32_t neq2;
-    uint32_t neq3;
-    uint32_t nek2;
-    uint32_t nek3;
-    uint32_t nev2;
-    uint32_t nev3;
-    uint32_t nem1;
-
-    uint32_t nb01;
-    uint32_t nb02;
-    uint32_t nb03;
-    uint32_t nb11;
-    uint32_t nb12;
-    uint32_t nb13;
-    uint32_t nb21;
-    uint32_t nb22;
-    uint32_t nb23;
-    uint32_t nb31;
-
-    float scale;
-    float max_bias;
-    float logit_softcap;
-
-    uint32_t mask;
-    uint32_t n_head_log2;
-    float m0;
-    float m1;
-
-    uint32_t gqa_ratio;
-    uint32_t split_kv;
-    uint32_t k_num;
-} p;
 
 layout (binding = 0) readonly buffer Q {float data_q[];};
 layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];};
@@ -71,39 +24,6 @@ layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];};
 layout (binding = 2) readonly buffer V {float16_t data_v[];};
 layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
 layout (binding = 3) readonly buffer M {float16_t data_m[];};
-layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
-
-#if defined(A_TYPE_PACKED16)
-#define BINDING_IDX_K 0
-#define BINDING_IDX_V 1
-layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2];
-#endif
-
-#if defined(DATA_A_Q4_0)
-#define BLOCK_BYTE_SIZE 18
-
-vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
-    uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
-    uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
-    uint shift = (iqs & 0x10) >> 2;
-    vui_lo >>= shift;
-    vui_hi >>= shift;
-
-    return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
-}
-#endif
-
-#if defined(DATA_A_Q8_0)
-#define BLOCK_BYTE_SIZE 34
-vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
-    const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
-    const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
-
-    return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
-}
-#endif
-
-#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
 
 // Store the output when doing grouped query attention.
 // Rows index by Q's dimension 2, and the first N rows are valid.
@@ -114,27 +34,6 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
     return elem;
 }
 
-// Store column zero. This is used to save per-row m and L values for split_k.
-ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
-{
-    if (r < N && c == 0) {
-        uint32_t offset = iq2 + r;
-        data_o[o_offset + offset] = D_TYPE(elem);
-    }
-    return elem;
-}
-
-// Load the slope matrix, indexed by Q's dimension 2.
-ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
-{
-    const uint32_t h = iq2 + (r % p.gqa_ratio);
-
-    const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
-    const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
-
-    return ACC_TYPE(pow(base, ACC_TYPE(exph)));
-}
-
 shared FLOAT_TYPE tmpsh[WorkGroupSize];
 shared vec4 tmpshv4[WorkGroupSize];
 
@@ -146,58 +45,12 @@ void main() {
     init_iq_shmem(gl_WorkGroupSize);
 #endif
 
-
-    const uint32_t N = p.N;
-    const uint32_t KV = p.KV;
+    init_indices();
 
+    const uint32_t tid = gl_LocalInvocationIndex;
     const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
     const uint32_t col_tid = gl_LocalInvocationIndex / D_split;
 
-    uint32_t i = gl_WorkGroupID.x;
-    uint32_t split_k_index = 0;
-
-    if (p.k_num > 1) {
-        i = 0;
-        split_k_index = gl_WorkGroupID.x;
-    }
-
-    const uint32_t Tr = CEIL_DIV(N, Br);
-
-    const uint32_t start_j = split_k_index * p.split_kv / Bc;
-    const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
-
-    // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
-    // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
-    const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio;
-    const uint32_t iq3 = gl_WorkGroupID.z;
-
-    // broadcast factors
-    const uint32_t rk2 = p.neq2/p.nek2;
-    const uint32_t rk3 = p.neq3/p.nek3;
-
-    const uint32_t rv2 = p.neq2/p.nev2;
-    const uint32_t rv3 = p.neq3/p.nev3;
-
-    // k indices
-    const uint32_t ik3 = iq3 / rk3;
-    const uint32_t ik2 = iq2 / rk2;
-
-    // v indices
-    const uint32_t iv3 = iq3 / rv3;
-    const uint32_t iv2 = iq2 / rv2;
-
-    // nb?1 are already divided by the type size and are in units of elements.
-    // When using grouped query attention, Q is indexed by iq2, so the stride
-    // should be nb02 (which is in bytes).
-    uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
-    uint32_t k_stride = p.nb11;
-    uint32_t v_stride = p.nb21;
-    // When using grouped query attention, all rows use the same mask (stride 0).
-    // "p.gqa_ratio >> 16" is just a roundabout way of writing zero
-    // that prevents the compiler from folding the "&" through the select
-    // and breaking the alignment detection.
-    uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
-
     uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
 
     [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {
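After the change, flash_attn.comp keeps only its own tiling constants on top of the shared spec constants: with the defaults from the new base file (WorkGroupSize = 128, Bc = 32, D = 32, D_split = 16), each invocation covers D_per_thread = 2 elements of the head dimension and cols_per_thread = 4 KV columns. A minimal C sketch of that arithmetic and of the tid -> (d_tid, col_tid) mapping used in main():

/* Per-thread tiling of the scalar FA shader, using the default
 * specialization constants. Not ggml code; values are the defaults
 * declared in flash_attn_base.comp. */
#include <stdint.h>
#include <stdio.h>

int main(void) {
    const uint32_t WorkGroupSize = 128, Bc = 32, D = 32, D_split = 16;

    const uint32_t D_per_thread    = D / D_split;             /* 2 */
    const uint32_t cols_per_iter   = WorkGroupSize / D_split; /* 8 */
    const uint32_t cols_per_thread = Bc / cols_per_iter;      /* 4 */

    /* gl_LocalInvocationIndex -> (d_tid, col_tid), as in main() */
    for (uint32_t tid = 0; tid < WorkGroupSize; tid += 37) {
        uint32_t d_tid   = tid % D_split;
        uint32_t col_tid = tid / D_split;
        printf("tid %3u -> d_tid %2u col_tid %u\n", tid, d_tid, col_tid);
    }
    printf("D_per_thread=%u cols_per_iter=%u cols_per_thread=%u\n",
           D_per_thread, cols_per_iter, cols_per_thread);
    return 0;
}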
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp
ADDED
@@ -0,0 +1,162 @@
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
+layout (constant_id = 1) const uint32_t Br = 1;
+layout (constant_id = 2) const uint32_t Bc = 32;
+layout (constant_id = 3) const uint32_t D = 32;
+layout (constant_id = 4) const uint32_t Clamp = 0;
+layout (constant_id = 5) const uint32_t D_split = 16;
+
+
+layout (push_constant) uniform parameter {
+    uint32_t N;
+    uint32_t KV;
+
+    uint32_t ne1;
+    uint32_t ne2;
+    uint32_t ne3;
+
+    uint32_t neq2;
+    uint32_t neq3;
+    uint32_t nek2;
+    uint32_t nek3;
+    uint32_t nev2;
+    uint32_t nev3;
+    uint32_t nem1;
+
+    uint32_t nb01;
+    uint32_t nb02;
+    uint32_t nb03;
+    uint32_t nb11;
+    uint32_t nb12;
+    uint32_t nb13;
+    uint32_t nb21;
+    uint32_t nb22;
+    uint32_t nb23;
+    uint32_t nb31;
+
+    float scale;
+    float max_bias;
+    float logit_softcap;
+
+    uint32_t mask;
+    uint32_t n_head_log2;
+    float m0;
+    float m1;
+
+    uint32_t gqa_ratio;
+    uint32_t split_kv;
+    uint32_t k_num;
+} p;
+
+layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
+
+#if defined(A_TYPE_PACKED16)
+#define BINDING_IDX_K 0
+#define BINDING_IDX_V 1
+layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2];
+#endif
+
+#if defined(DATA_A_Q4_0)
+#define BLOCK_BYTE_SIZE 18
+
+vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
+    uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
+    uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
+    uint shift = (iqs & 0x10) >> 2;
+    vui_lo >>= shift;
+    vui_hi >>= shift;
+
+    return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
+}
+#endif
+
+#if defined(DATA_A_Q8_0)
+#define BLOCK_BYTE_SIZE 34
+vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
+    const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
+    const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
+
+    return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
+}
+#endif
+
+#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
+
+
+// Store column zero. This is used to save per-row m and L values for split_k.
+ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
+{
+    if (r < N && c == 0) {
+        uint32_t offset = iq2 + r;
+        data_o[o_offset + offset] = D_TYPE(elem);
+    }
+    return elem;
+}
+
+// Load the slope matrix, indexed by Q's dimension 2.
+ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
+{
+    const uint32_t h = iq2 + (r % p.gqa_ratio);
+
+    const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
+    const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
+
+    return ACC_TYPE(pow(base, ACC_TYPE(exph)));
+}
+
+uint32_t i, N, KV, split_k_index, Tr, start_j, end_j,
+         iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3,
+         q_stride, k_stride, v_stride, m_stride;
+
+void init_indices()
+{
+    N = p.N;
+    KV = p.KV;
+
+    i = gl_WorkGroupID.x;
+    split_k_index = 0;
+
+    if (p.k_num > 1) {
+        i = 0;
+        split_k_index = gl_WorkGroupID.x;
+    }
+
+    Tr = CEIL_DIV(N, Br);
+
+    start_j = split_k_index * p.split_kv / Bc;
+    end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
+
+    // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
+    // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
+    iq2 = gl_WorkGroupID.y * p.gqa_ratio;
+    iq3 = gl_WorkGroupID.z;
+
+    // broadcast factors
+    rk2 = p.neq2/p.nek2;
+    rk3 = p.neq3/p.nek3;
+
+    rv2 = p.neq2/p.nev2;
+    rv3 = p.neq3/p.nev3;
+
+    // k indices
+    ik3 = iq3 / rk3;
+    ik2 = iq2 / rk2;
+
+    // v indices
+    iv3 = iq3 / rv3;
+    iv2 = iq2 / rv2;
+
+    // nb?1 are already divided by the type size and are in units of elements.
+    // When using grouped query attention, Q is indexed by iq2, so the stride
+    // should be nb02 (which is in bytes).
+    q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
+    k_stride = p.nb11;
+    v_stride = p.nb21;
+    // When using grouped query attention, all rows use the same mask (stride 0).
+    // "p.gqa_ratio >> 16" is just a roundabout way of writing zero
+    // that prevents the compiler from folding the "&" through the select
+    // and breaking the alignment detection.
+    m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
+}
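Among the helpers the new base file centralizes, perElemOpComputeSlope() evaluates the ALiBi slope for head h: m0^(h+1) for the first n_head_log2 heads, and m1^(2*(h - n_head_log2) + 1) for the rest, with m0, m1, and n_head_log2 supplied via push constants. A standalone C model of the same formula (the head count and m0/m1 values below are illustrative choices, not taken from this diff):

/* C model of the slope formula in perElemOpComputeSlope(). */
#include <math.h>
#include <stdint.h>
#include <stdio.h>

static float slope(uint32_t h, uint32_t n_head_log2, float m0, float m1) {
    float base = h < n_head_log2 ? m0 : m1;
    int   exph = (int)(h < n_head_log2 ? h + 1 : 2 * (h - n_head_log2) + 1);
    return powf(base, (float)exph);
}

int main(void) {
    /* e.g. with n_head_log2 = 8: m0 = 2^(-8/8), m1 = 2^(-4/8) */
    const uint32_t n_head_log2 = 8;
    const float m0 = powf(2.0f, -1.0f), m1 = powf(2.0f, -0.5f);
    for (uint32_t h = 0; h < 12; h++)
        printf("head %2u: slope %.6f\n", h, slope(h, n_head_log2, m0, m1));
    return 0;
}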
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
CHANGED
@@ -11,14 +11,7 @@
 #extension GL_KHR_cooperative_matrix : enable
 
 #include "types.comp"
+#include "flash_attn_base.comp"
-
-layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
-
-layout (constant_id = 1) const uint32_t Br = 1;
-layout (constant_id = 2) const uint32_t Bc = 32;
-layout (constant_id = 3) const uint32_t D = 32;
-
-layout (constant_id = 5) const uint32_t D_split = 16;
 
 const uint32_t D_per_thread = D / D_split;
 const uint32_t row_split = 4;
@@ -26,46 +19,6 @@
 const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split;
 const uint32_t cols_per_thread = Bc / cols_per_iter;
 
-layout (push_constant) uniform parameter {
-    uint32_t N;
-    uint32_t KV;
-
-    uint32_t ne1;
-    uint32_t ne2;
-    uint32_t ne3;
-
-    uint32_t neq2;
-    uint32_t neq3;
-    uint32_t nek2;
-    uint32_t nek3;
-    uint32_t nev2;
-    uint32_t nev3;
-    uint32_t nem1;
-
-    uint32_t nb01;
-    uint32_t nb02;
-    uint32_t nb03;
-    uint32_t nb11;
-    uint32_t nb12;
-    uint32_t nb13;
-    uint32_t nb21;
-    uint32_t nb22;
-    uint32_t nb23;
-    uint32_t nb31;
-
-    float scale;
-    float max_bias;
-    float logit_softcap;
-
-    uint32_t mask;
-    uint32_t n_head_log2;
-    float m0;
-    float m1;
-
-    uint32_t gqa_ratio;
-    uint32_t split_kv;
-    uint32_t k_num;
-} p;
 
 layout (binding = 0) readonly buffer Q {float data_q[];};
 layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];};
@@ -74,39 +27,6 @@ layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];};
 layout (binding = 2) readonly buffer V {float16_t data_v[];};
 layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
 layout (binding = 3) readonly buffer M {float16_t data_m[];};
-layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
-
-#if defined(A_TYPE_PACKED16)
-#define BINDING_IDX_K 0
-#define BINDING_IDX_V 1
-layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2];
-#endif
-
-#if defined(DATA_A_Q4_0)
-#define BLOCK_BYTE_SIZE 18
-
-vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
-    uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
-    uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
-    uint shift = (iqs & 0x10) >> 2;
-    vui_lo >>= shift;
-    vui_hi >>= shift;
-
-    return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
-}
-#endif
-
-#if defined(DATA_A_Q8_0)
-#define BLOCK_BYTE_SIZE 34
-vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
-    const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
-    const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
-
-    return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
-}
-#endif
-
-#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
 
 // Store the output when doing grouped query attention.
 // Rows index by Q's dimension 2, and the first N rows are valid.
@@ -117,27 +37,6 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
     return elem;
 }
 
-// Store column zero. This is used to save per-row m and L values for split_k.
-ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
-{
-    if (r < N && c == 0) {
-        uint32_t offset = iq2 + r;
-        data_o[o_offset + offset] = D_TYPE(elem);
-    }
-    return elem;
-}
-
-// Load the slope matrix, indexed by Q's dimension 2.
-ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
-{
-    const uint32_t h = iq2 + (r % p.gqa_ratio);
-
-    const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
-    const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
-
-    return ACC_TYPE(pow(base, ACC_TYPE(exph)));
-}
-
 // These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd
 const uint32_t MatBr = 16;
 const uint32_t MatBc = 16;
@@ -162,9 +61,9 @@ void main() {
     init_iq_shmem(gl_WorkGroupSize);
 #endif
 
+    init_indices();
+
     const uint32_t tid = gl_LocalInvocationIndex;
-    const uint32_t N = p.N;
-    const uint32_t KV = p.KV;
 
     const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split;
     const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup;
@@ -173,51 +72,6 @@
 
 #define tile_row(r) (row_tid * rows_per_thread + (r))
 
-    uint32_t i = gl_WorkGroupID.x;
-    uint32_t split_k_index = 0;
-
-    if (p.k_num > 1) {
-        i = 0;
-        split_k_index = gl_WorkGroupID.x;
-    }
-
-    const uint32_t Tr = CEIL_DIV(N, Br);
-
-    const uint32_t start_j = split_k_index * p.split_kv / Bc;
-    const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
-
-    // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
-    // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
-    const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio;
-    const uint32_t iq3 = gl_WorkGroupID.z;
-
-    // broadcast factors
-    const uint32_t rk2 = p.neq2/p.nek2;
-    const uint32_t rk3 = p.neq3/p.nek3;
-
-    const uint32_t rv2 = p.neq2/p.nev2;
-    const uint32_t rv3 = p.neq3/p.nev3;
-
-    // k indices
-    const uint32_t ik3 = iq3 / rk3;
-    const uint32_t ik2 = iq2 / rk2;
-
-    // v indices
-    const uint32_t iv3 = iq3 / rv3;
-    const uint32_t iv2 = iq2 / rv2;
-
-    // nb?1 are already divided by the type size and are in units of elements.
-    // When using grouped query attention, Q is indexed by iq2, so the stride
-    // should be nb02 (which is in bytes).
-    uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
-    uint32_t k_stride = p.nb11;
-    uint32_t v_stride = p.nb21;
-    // When using grouped query attention, all rows use the same mask (stride 0).
-    // "p.gqa_ratio >> 16" is just a roundabout way of writing zero
-    // that prevents the compiler from folding the "&" through the select
-    // and breaking the alignment detection.
-    uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
-
     uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
 
     [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {
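This shader keeps its rowgroup tiling: the workgroup is divided into row_split = 4 rowgroups, and tile_row(r) maps a rowgroup-local row to a tile row. A small C sketch of that mapping (Br = 16 is an assumed illustrative value here; the shader takes Br from a specialization constant):

/* Rowgroup mapping mirrored from flash_attn_cm1.comp: 128 threads form
 * 4 rowgroups of 32 threads, each owning Br / row_split tile rows. */
#include <stdint.h>
#include <stdio.h>

#define TILE_ROW(row_tid, rows_per_thread, r) ((row_tid) * (rows_per_thread) + (r))

int main(void) {
    const uint32_t WorkGroupSize = 128, row_split = 4, Br = 16; /* Br assumed */
    const uint32_t threads_per_rowgroup = WorkGroupSize / row_split; /* 32 */
    const uint32_t rows_per_thread      = Br / row_split;            /* 4  */

    for (uint32_t tid = 0; tid < WorkGroupSize; tid += threads_per_rowgroup) {
        uint32_t row_tid = tid / threads_per_rowgroup;
        printf("tid %3u: row_tid %u -> rows [%u, %u)\n", tid, row_tid,
               TILE_ROW(row_tid, rows_per_thread, 0),
               TILE_ROW(row_tid, rows_per_thread, rows_per_thread));
    }
    return 0;
}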
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
CHANGED
@@ -18,62 +18,12 @@
 
 #include "types.comp"
 #include "dequant_funcs_cm2.comp"
+#include "flash_attn_base.comp"
-
-layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
-
-layout (constant_id = 1) const uint32_t Br = 32;
-layout (constant_id = 2) const uint32_t Bc = 32;
-layout (constant_id = 3) const uint32_t D = 32;
-layout (constant_id = 4) const uint32_t Clamp = gl_CooperativeMatrixClampModeConstantNV;
-
-layout (push_constant) uniform parameter {
-    uint32_t N;
-    uint32_t KV;
-
-    uint32_t ne1;
-    uint32_t ne2;
-    uint32_t ne3;
-
-    uint32_t neq2;
-    uint32_t neq3;
-    uint32_t nek2;
-    uint32_t nek3;
-    uint32_t nev2;
-    uint32_t nev3;
-    uint32_t nem1;
-
-    uint32_t nb01;
-    uint32_t nb02;
-    uint32_t nb03;
-    uint32_t nb11;
-    uint32_t nb12;
-    uint32_t nb13;
-    uint32_t nb21;
-    uint32_t nb22;
-    uint32_t nb23;
-    uint32_t nb31;
-
-    float scale;
-    float max_bias;
-    float logit_softcap;
-
-    uint32_t mask;
-    uint32_t n_head_log2;
-    float m0;
-    float m1;
-
-    uint32_t gqa_ratio;
-    uint32_t split_kv;
-    uint32_t k_num;
-} p;
 
 layout (binding = 0) readonly buffer Q {uint8_t data_q[];};
 layout (binding = 1) readonly buffer K {uint8_t data_k[];};
 layout (binding = 2) readonly buffer V {uint8_t data_v[];};
 layout (binding = 3) readonly buffer M {uint8_t data_m[];};
-layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
-
-#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
 
 ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
     return max(x, y);
@@ -118,67 +68,12 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
     return elem;
 }
 
-// Store column zero. This is used to save per-row m and L values for split_k.
-ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
-{
-    if (r < N && c == 0) {
-        uint32_t offset = iq2 + r;
-        data_o[o_offset + offset] = D_TYPE(elem);
-    }
-    return elem;
-}
-
-// Load the slope matrix, indexed by Q's dimension 2.
-ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
-{
-    const uint32_t h = iq2 + (r % p.gqa_ratio);
-
-    const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
-    const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
-
-    return ACC_TYPE(pow(base, ACC_TYPE(exph)));
-}
-
 void main() {
 #ifdef NEEDS_INIT_IQ_SHMEM
     init_iq_shmem(gl_WorkGroupSize);
 #endif
 
-
-    const uint32_t KV = p.KV;
-
-    uint32_t i = gl_WorkGroupID.x;
-    uint32_t split_k_index = 0;
-
-    if (p.k_num > 1) {
-        i = 0;
-        split_k_index = gl_WorkGroupID.x;
-    }
-
-    const uint32_t Tr = CEIL_DIV(N, Br);
-
-    const uint32_t start_j = split_k_index * p.split_kv / Bc;
-    const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
-
-    // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
-    // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
-    const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio;
-    const uint32_t iq3 = gl_WorkGroupID.z;
-
-    // broadcast factors
-    const uint32_t rk2 = p.neq2/p.nek2;
-    const uint32_t rk3 = p.neq3/p.nek3;
-
-    const uint32_t rv2 = p.neq2/p.nev2;
-    const uint32_t rv3 = p.neq3/p.nev3;
-
-    // k indices
-    const uint32_t ik3 = iq3 / rk3;
-    const uint32_t ik2 = iq2 / rk2;
-
-    // v indices
-    const uint32_t iv3 = iq3 / rv3;
-    const uint32_t iv2 = iq2 / rv2;
+    init_indices();
 
     tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
     tensorLayoutNV<2, Clamp> tensorLayoutK = createTensorLayoutNV(2, Clamp);
@@ -195,17 +90,6 @@
     tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D);
     tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D);
 
-    // nb?1 are already divided by the type size and are in units of elements.
-    // When using grouped query attention, Q is indexed by iq2, so the stride
-    // should be nb02 (which is in bytes).
-    uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
-    uint32_t k_stride = p.nb11;
-    uint32_t v_stride = p.nb21;
-    // When using grouped query attention, all rows use the same mask (stride 0).
-    // "p.gqa_ratio >> 16" is just a roundabout way of writing zero
-    // that prevents the compiler from folding the "&" through the select
-    // and breaking the alignment detection.
-    uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
     // hint to the compiler that strides are aligned for the aligned variant of the shader
    if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
    {
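For reference, the Q4_0 branch of dequantize4() (moved verbatim into flash_attn_base.comp) unpacks four 4-bit quants from two adjacent 16-bit words of qs, with iqs & 0x10 selecting the low or high nibbles of each byte. A C model of the same bit manipulation (the block bytes and scale below are invented for the demo):

/* C model of the Q4_0 dequantize4(): four 4-bit quants from two uint16
 * words, dequantized as d * (q - 8). */
#include <stdint.h>
#include <stdio.h>

static void dequantize4_q4_0(const uint16_t qs[8], float d, uint32_t iqs, float out[4]) {
    uint32_t vui_lo = qs[(iqs & 0xF) / 2 + 0];
    uint32_t vui_hi = qs[(iqs & 0xF) / 2 + 1];
    uint32_t shift  = (iqs & 0x10) >> 2;   /* 0 for low nibbles, 4 for high */
    vui_lo >>= shift;
    vui_hi >>= shift;
    out[0] = d * ((float)((vui_lo     ) & 0xF) - 8.0f);
    out[1] = d * ((float)((vui_lo >> 8) & 0xF) - 8.0f);
    out[2] = d * ((float)((vui_hi     ) & 0xF) - 8.0f);
    out[3] = d * ((float)((vui_hi >> 8) & 0xF) - 8.0f);
}

int main(void) {
    uint16_t qs[8] = {0x1234, 0x5678, 0, 0, 0, 0, 0, 0};  /* made-up block */
    float out[4];
    dequantize4_q4_0(qs, 0.5f, 0, out);  /* low nibbles of bytes 0..3 */
    for (int k = 0; k < 4; k++) printf("%g ", out[k]);
    printf("\n");
    return 0;
}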