jeffbolznv commited on
Commit
ad8b504
·
1 Parent(s): 97d9aa6

vulkan: move common FA code to flash_attn_base.comp (llama/13556)

Browse files

* vulkan: move common FA code to flash_attn_base.comp

* vulkan: move common FA index/stride setup code to flash_attn_base.comp

* build fix

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp CHANGED
@@ -9,60 +9,13 @@
9
  #extension GL_KHR_shader_subgroup_shuffle : enable
10
 
11
  #include "types.comp"
 
12
 
13
- layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
14
-
15
- layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
16
- layout (constant_id = 1) const uint32_t Br = 1;
17
- layout (constant_id = 2) const uint32_t Bc = 32;
18
- layout (constant_id = 3) const uint32_t D = 32;
19
-
20
- layout (constant_id = 5) const uint32_t D_split = 16;
21
  const uint32_t D_per_thread = D / D_split;
22
 
23
  const uint32_t cols_per_iter = WorkGroupSize / D_split;
24
  const uint32_t cols_per_thread = Bc / cols_per_iter;
25
 
26
- layout (push_constant) uniform parameter {
27
- uint32_t N;
28
- uint32_t KV;
29
-
30
- uint32_t ne1;
31
- uint32_t ne2;
32
- uint32_t ne3;
33
-
34
- uint32_t neq2;
35
- uint32_t neq3;
36
- uint32_t nek2;
37
- uint32_t nek3;
38
- uint32_t nev2;
39
- uint32_t nev3;
40
- uint32_t nem1;
41
-
42
- uint32_t nb01;
43
- uint32_t nb02;
44
- uint32_t nb03;
45
- uint32_t nb11;
46
- uint32_t nb12;
47
- uint32_t nb13;
48
- uint32_t nb21;
49
- uint32_t nb22;
50
- uint32_t nb23;
51
- uint32_t nb31;
52
-
53
- float scale;
54
- float max_bias;
55
- float logit_softcap;
56
-
57
- uint32_t mask;
58
- uint32_t n_head_log2;
59
- float m0;
60
- float m1;
61
-
62
- uint32_t gqa_ratio;
63
- uint32_t split_kv;
64
- uint32_t k_num;
65
- } p;
66
 
67
  layout (binding = 0) readonly buffer Q {float data_q[];};
68
  layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];};
@@ -71,39 +24,6 @@ layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];};
71
  layout (binding = 2) readonly buffer V {float16_t data_v[];};
72
  layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
73
  layout (binding = 3) readonly buffer M {float16_t data_m[];};
74
- layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
75
-
76
- #if defined(A_TYPE_PACKED16)
77
- #define BINDING_IDX_K 0
78
- #define BINDING_IDX_V 1
79
- layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2];
80
- #endif
81
-
82
- #if defined(DATA_A_Q4_0)
83
- #define BLOCK_BYTE_SIZE 18
84
-
85
- vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
86
- uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
87
- uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
88
- uint shift = (iqs & 0x10) >> 2;
89
- vui_lo >>= shift;
90
- vui_hi >>= shift;
91
-
92
- return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
93
- }
94
- #endif
95
-
96
- #if defined(DATA_A_Q8_0)
97
- #define BLOCK_BYTE_SIZE 34
98
- vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
99
- const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
100
- const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
101
-
102
- return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
103
- }
104
- #endif
105
-
106
- #define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
107
 
108
  // Store the output when doing grouped query attention.
109
  // Rows index by Q's dimension 2, and the first N rows are valid.
@@ -114,27 +34,6 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
114
  return elem;
115
  }
116
 
117
- // Store column zero. This is used to save per-row m and L values for split_k.
118
- ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
119
- {
120
- if (r < N && c == 0) {
121
- uint32_t offset = iq2 + r;
122
- data_o[o_offset + offset] = D_TYPE(elem);
123
- }
124
- return elem;
125
- }
126
-
127
- // Load the slope matrix, indexed by Q's dimension 2.
128
- ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
129
- {
130
- const uint32_t h = iq2 + (r % p.gqa_ratio);
131
-
132
- const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
133
- const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
134
-
135
- return ACC_TYPE(pow(base, ACC_TYPE(exph)));
136
- }
137
-
138
  shared FLOAT_TYPE tmpsh[WorkGroupSize];
139
  shared vec4 tmpshv4[WorkGroupSize];
140
 
@@ -146,58 +45,12 @@ void main() {
146
  init_iq_shmem(gl_WorkGroupSize);
147
  #endif
148
 
149
- const uint32_t tid = gl_LocalInvocationIndex;
150
- const uint32_t N = p.N;
151
- const uint32_t KV = p.KV;
152
 
 
153
  const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
154
  const uint32_t col_tid = gl_LocalInvocationIndex / D_split;
155
 
156
- uint32_t i = gl_WorkGroupID.x;
157
- uint32_t split_k_index = 0;
158
-
159
- if (p.k_num > 1) {
160
- i = 0;
161
- split_k_index = gl_WorkGroupID.x;
162
- }
163
-
164
- const uint32_t Tr = CEIL_DIV(N, Br);
165
-
166
- const uint32_t start_j = split_k_index * p.split_kv / Bc;
167
- const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
168
-
169
- // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
170
- // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
171
- const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio;
172
- const uint32_t iq3 = gl_WorkGroupID.z;
173
-
174
- // broadcast factors
175
- const uint32_t rk2 = p.neq2/p.nek2;
176
- const uint32_t rk3 = p.neq3/p.nek3;
177
-
178
- const uint32_t rv2 = p.neq2/p.nev2;
179
- const uint32_t rv3 = p.neq3/p.nev3;
180
-
181
- // k indices
182
- const uint32_t ik3 = iq3 / rk3;
183
- const uint32_t ik2 = iq2 / rk2;
184
-
185
- // v indices
186
- const uint32_t iv3 = iq3 / rv3;
187
- const uint32_t iv2 = iq2 / rv2;
188
-
189
- // nb?1 are already divided by the type size and are in units of elements.
190
- // When using grouped query attention, Q is indexed by iq2, so the stride
191
- // should be nb02 (which is in bytes).
192
- uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
193
- uint32_t k_stride = p.nb11;
194
- uint32_t v_stride = p.nb21;
195
- // When using grouped query attention, all rows use the same mask (stride 0).
196
- // "p.gqa_ratio >> 16" is just a roundabout way of writing zero
197
- // that prevents the compiler from folding the "&" through the select
198
- // and breaking the alignment detection.
199
- uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
200
-
201
  uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
202
 
203
  [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {
 
9
  #extension GL_KHR_shader_subgroup_shuffle : enable
10
 
11
  #include "types.comp"
12
+ #include "flash_attn_base.comp"
13
 
 
 
 
 
 
 
 
 
14
  const uint32_t D_per_thread = D / D_split;
15
 
16
  const uint32_t cols_per_iter = WorkGroupSize / D_split;
17
  const uint32_t cols_per_thread = Bc / cols_per_iter;
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  layout (binding = 0) readonly buffer Q {float data_q[];};
21
  layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];};
 
24
  layout (binding = 2) readonly buffer V {float16_t data_v[];};
25
  layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
26
  layout (binding = 3) readonly buffer M {float16_t data_m[];};
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  // Store the output when doing grouped query attention.
29
  // Rows index by Q's dimension 2, and the first N rows are valid.
 
34
  return elem;
35
  }
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  shared FLOAT_TYPE tmpsh[WorkGroupSize];
38
  shared vec4 tmpshv4[WorkGroupSize];
39
 
 
45
  init_iq_shmem(gl_WorkGroupSize);
46
  #endif
47
 
48
+ init_indices();
 
 
49
 
50
+ const uint32_t tid = gl_LocalInvocationIndex;
51
  const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
52
  const uint32_t col_tid = gl_LocalInvocationIndex / D_split;
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
55
 
56
  [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
3
+
4
+ layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
5
+ layout (constant_id = 1) const uint32_t Br = 1;
6
+ layout (constant_id = 2) const uint32_t Bc = 32;
7
+ layout (constant_id = 3) const uint32_t D = 32;
8
+ layout (constant_id = 4) const uint32_t Clamp = 0;
9
+ layout (constant_id = 5) const uint32_t D_split = 16;
10
+
11
+
12
+ layout (push_constant) uniform parameter {
13
+ uint32_t N;
14
+ uint32_t KV;
15
+
16
+ uint32_t ne1;
17
+ uint32_t ne2;
18
+ uint32_t ne3;
19
+
20
+ uint32_t neq2;
21
+ uint32_t neq3;
22
+ uint32_t nek2;
23
+ uint32_t nek3;
24
+ uint32_t nev2;
25
+ uint32_t nev3;
26
+ uint32_t nem1;
27
+
28
+ uint32_t nb01;
29
+ uint32_t nb02;
30
+ uint32_t nb03;
31
+ uint32_t nb11;
32
+ uint32_t nb12;
33
+ uint32_t nb13;
34
+ uint32_t nb21;
35
+ uint32_t nb22;
36
+ uint32_t nb23;
37
+ uint32_t nb31;
38
+
39
+ float scale;
40
+ float max_bias;
41
+ float logit_softcap;
42
+
43
+ uint32_t mask;
44
+ uint32_t n_head_log2;
45
+ float m0;
46
+ float m1;
47
+
48
+ uint32_t gqa_ratio;
49
+ uint32_t split_kv;
50
+ uint32_t k_num;
51
+ } p;
52
+
53
+ layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
54
+
55
+ #if defined(A_TYPE_PACKED16)
56
+ #define BINDING_IDX_K 0
57
+ #define BINDING_IDX_V 1
58
+ layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2];
59
+ #endif
60
+
61
+ #if defined(DATA_A_Q4_0)
62
+ #define BLOCK_BYTE_SIZE 18
63
+
64
+ vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
65
+ uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
66
+ uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
67
+ uint shift = (iqs & 0x10) >> 2;
68
+ vui_lo >>= shift;
69
+ vui_hi >>= shift;
70
+
71
+ return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
72
+ }
73
+ #endif
74
+
75
+ #if defined(DATA_A_Q8_0)
76
+ #define BLOCK_BYTE_SIZE 34
77
+ vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
78
+ const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
79
+ const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
80
+
81
+ return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
82
+ }
83
+ #endif
84
+
85
+ #define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
86
+
87
+
88
+ // Store column zero. This is used to save per-row m and L values for split_k.
89
+ ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
90
+ {
91
+ if (r < N && c == 0) {
92
+ uint32_t offset = iq2 + r;
93
+ data_o[o_offset + offset] = D_TYPE(elem);
94
+ }
95
+ return elem;
96
+ }
97
+
98
+ // Load the slope matrix, indexed by Q's dimension 2.
99
+ ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
100
+ {
101
+ const uint32_t h = iq2 + (r % p.gqa_ratio);
102
+
103
+ const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
104
+ const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
105
+
106
+ return ACC_TYPE(pow(base, ACC_TYPE(exph)));
107
+ }
108
+
109
+ uint32_t i, N, KV, split_k_index, Tr, start_j, end_j,
110
+ iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3,
111
+ q_stride, k_stride, v_stride, m_stride;
112
+
113
+ void init_indices()
114
+ {
115
+ N = p.N;
116
+ KV = p.KV;
117
+
118
+ i = gl_WorkGroupID.x;
119
+ split_k_index = 0;
120
+
121
+ if (p.k_num > 1) {
122
+ i = 0;
123
+ split_k_index = gl_WorkGroupID.x;
124
+ }
125
+
126
+ Tr = CEIL_DIV(N, Br);
127
+
128
+ start_j = split_k_index * p.split_kv / Bc;
129
+ end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
130
+
131
+ // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
132
+ // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
133
+ iq2 = gl_WorkGroupID.y * p.gqa_ratio;
134
+ iq3 = gl_WorkGroupID.z;
135
+
136
+ // broadcast factors
137
+ rk2 = p.neq2/p.nek2;
138
+ rk3 = p.neq3/p.nek3;
139
+
140
+ rv2 = p.neq2/p.nev2;
141
+ rv3 = p.neq3/p.nev3;
142
+
143
+ // k indices
144
+ ik3 = iq3 / rk3;
145
+ ik2 = iq2 / rk2;
146
+
147
+ // v indices
148
+ iv3 = iq3 / rv3;
149
+ iv2 = iq2 / rv2;
150
+
151
+ // nb?1 are already divided by the type size and are in units of elements.
152
+ // When using grouped query attention, Q is indexed by iq2, so the stride
153
+ // should be nb02 (which is in bytes).
154
+ q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
155
+ k_stride = p.nb11;
156
+ v_stride = p.nb21;
157
+ // When using grouped query attention, all rows use the same mask (stride 0).
158
+ // "p.gqa_ratio >> 16" is just a roundabout way of writing zero
159
+ // that prevents the compiler from folding the "&" through the select
160
+ // and breaking the alignment detection.
161
+ m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
162
+ }
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp CHANGED
@@ -11,14 +11,7 @@
11
  #extension GL_KHR_cooperative_matrix : enable
12
 
13
  #include "types.comp"
14
-
15
- layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
16
-
17
- layout (constant_id = 1) const uint32_t Br = 1;
18
- layout (constant_id = 2) const uint32_t Bc = 32;
19
- layout (constant_id = 3) const uint32_t D = 32;
20
-
21
- layout (constant_id = 5) const uint32_t D_split = 16;
22
 
23
  const uint32_t D_per_thread = D / D_split;
24
  const uint32_t row_split = 4;
@@ -26,46 +19,6 @@ const uint32_t rows_per_thread = Br / row_split;
26
  const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split;
27
  const uint32_t cols_per_thread = Bc / cols_per_iter;
28
 
29
- layout (push_constant) uniform parameter {
30
- uint32_t N;
31
- uint32_t KV;
32
-
33
- uint32_t ne1;
34
- uint32_t ne2;
35
- uint32_t ne3;
36
-
37
- uint32_t neq2;
38
- uint32_t neq3;
39
- uint32_t nek2;
40
- uint32_t nek3;
41
- uint32_t nev2;
42
- uint32_t nev3;
43
- uint32_t nem1;
44
-
45
- uint32_t nb01;
46
- uint32_t nb02;
47
- uint32_t nb03;
48
- uint32_t nb11;
49
- uint32_t nb12;
50
- uint32_t nb13;
51
- uint32_t nb21;
52
- uint32_t nb22;
53
- uint32_t nb23;
54
- uint32_t nb31;
55
-
56
- float scale;
57
- float max_bias;
58
- float logit_softcap;
59
-
60
- uint32_t mask;
61
- uint32_t n_head_log2;
62
- float m0;
63
- float m1;
64
-
65
- uint32_t gqa_ratio;
66
- uint32_t split_kv;
67
- uint32_t k_num;
68
- } p;
69
 
70
  layout (binding = 0) readonly buffer Q {float data_q[];};
71
  layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];};
@@ -74,39 +27,6 @@ layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];};
74
  layout (binding = 2) readonly buffer V {float16_t data_v[];};
75
  layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
76
  layout (binding = 3) readonly buffer M {float16_t data_m[];};
77
- layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
78
-
79
- #if defined(A_TYPE_PACKED16)
80
- #define BINDING_IDX_K 0
81
- #define BINDING_IDX_V 1
82
- layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2];
83
- #endif
84
-
85
- #if defined(DATA_A_Q4_0)
86
- #define BLOCK_BYTE_SIZE 18
87
-
88
- vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
89
- uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
90
- uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
91
- uint shift = (iqs & 0x10) >> 2;
92
- vui_lo >>= shift;
93
- vui_hi >>= shift;
94
-
95
- return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
96
- }
97
- #endif
98
-
99
- #if defined(DATA_A_Q8_0)
100
- #define BLOCK_BYTE_SIZE 34
101
- vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
102
- const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
103
- const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
104
-
105
- return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
106
- }
107
- #endif
108
-
109
- #define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
110
 
111
  // Store the output when doing grouped query attention.
112
  // Rows index by Q's dimension 2, and the first N rows are valid.
@@ -117,27 +37,6 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
117
  return elem;
118
  }
119
 
120
- // Store column zero. This is used to save per-row m and L values for split_k.
121
- ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
122
- {
123
- if (r < N && c == 0) {
124
- uint32_t offset = iq2 + r;
125
- data_o[o_offset + offset] = D_TYPE(elem);
126
- }
127
- return elem;
128
- }
129
-
130
- // Load the slope matrix, indexed by Q's dimension 2.
131
- ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
132
- {
133
- const uint32_t h = iq2 + (r % p.gqa_ratio);
134
-
135
- const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
136
- const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
137
-
138
- return ACC_TYPE(pow(base, ACC_TYPE(exph)));
139
- }
140
-
141
  // These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd
142
  const uint32_t MatBr = 16;
143
  const uint32_t MatBc = 16;
@@ -162,9 +61,9 @@ void main() {
162
  init_iq_shmem(gl_WorkGroupSize);
163
  #endif
164
 
 
 
165
  const uint32_t tid = gl_LocalInvocationIndex;
166
- const uint32_t N = p.N;
167
- const uint32_t KV = p.KV;
168
 
169
  const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split;
170
  const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup;
@@ -173,51 +72,6 @@ void main() {
173
 
174
  #define tile_row(r) (row_tid * rows_per_thread + (r))
175
 
176
- uint32_t i = gl_WorkGroupID.x;
177
- uint32_t split_k_index = 0;
178
-
179
- if (p.k_num > 1) {
180
- i = 0;
181
- split_k_index = gl_WorkGroupID.x;
182
- }
183
-
184
- const uint32_t Tr = CEIL_DIV(N, Br);
185
-
186
- const uint32_t start_j = split_k_index * p.split_kv / Bc;
187
- const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
188
-
189
- // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
190
- // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
191
- const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio;
192
- const uint32_t iq3 = gl_WorkGroupID.z;
193
-
194
- // broadcast factors
195
- const uint32_t rk2 = p.neq2/p.nek2;
196
- const uint32_t rk3 = p.neq3/p.nek3;
197
-
198
- const uint32_t rv2 = p.neq2/p.nev2;
199
- const uint32_t rv3 = p.neq3/p.nev3;
200
-
201
- // k indices
202
- const uint32_t ik3 = iq3 / rk3;
203
- const uint32_t ik2 = iq2 / rk2;
204
-
205
- // v indices
206
- const uint32_t iv3 = iq3 / rv3;
207
- const uint32_t iv2 = iq2 / rv2;
208
-
209
- // nb?1 are already divided by the type size and are in units of elements.
210
- // When using grouped query attention, Q is indexed by iq2, so the stride
211
- // should be nb02 (which is in bytes).
212
- uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
213
- uint32_t k_stride = p.nb11;
214
- uint32_t v_stride = p.nb21;
215
- // When using grouped query attention, all rows use the same mask (stride 0).
216
- // "p.gqa_ratio >> 16" is just a roundabout way of writing zero
217
- // that prevents the compiler from folding the "&" through the select
218
- // and breaking the alignment detection.
219
- uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
220
-
221
  uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
222
 
223
  [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {
 
11
  #extension GL_KHR_cooperative_matrix : enable
12
 
13
  #include "types.comp"
14
+ #include "flash_attn_base.comp"
 
 
 
 
 
 
 
15
 
16
  const uint32_t D_per_thread = D / D_split;
17
  const uint32_t row_split = 4;
 
19
  const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split;
20
  const uint32_t cols_per_thread = Bc / cols_per_iter;
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  layout (binding = 0) readonly buffer Q {float data_q[];};
24
  layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];};
 
27
  layout (binding = 2) readonly buffer V {float16_t data_v[];};
28
  layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
29
  layout (binding = 3) readonly buffer M {float16_t data_m[];};
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  // Store the output when doing grouped query attention.
32
  // Rows index by Q's dimension 2, and the first N rows are valid.
 
37
  return elem;
38
  }
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  // These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd
41
  const uint32_t MatBr = 16;
42
  const uint32_t MatBc = 16;
 
61
  init_iq_shmem(gl_WorkGroupSize);
62
  #endif
63
 
64
+ init_indices();
65
+
66
  const uint32_t tid = gl_LocalInvocationIndex;
 
 
67
 
68
  const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split;
69
  const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup;
 
72
 
73
  #define tile_row(r) (row_tid * rows_per_thread + (r))
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
76
 
77
  [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp CHANGED
@@ -18,62 +18,12 @@
18
 
19
  #include "types.comp"
20
  #include "dequant_funcs_cm2.comp"
21
-
22
- layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
23
-
24
- layout (constant_id = 1) const uint32_t Br = 32;
25
- layout (constant_id = 2) const uint32_t Bc = 32;
26
- layout (constant_id = 3) const uint32_t D = 32;
27
- layout (constant_id = 4) const uint32_t Clamp = gl_CooperativeMatrixClampModeConstantNV;
28
-
29
- layout (push_constant) uniform parameter {
30
- uint32_t N;
31
- uint32_t KV;
32
-
33
- uint32_t ne1;
34
- uint32_t ne2;
35
- uint32_t ne3;
36
-
37
- uint32_t neq2;
38
- uint32_t neq3;
39
- uint32_t nek2;
40
- uint32_t nek3;
41
- uint32_t nev2;
42
- uint32_t nev3;
43
- uint32_t nem1;
44
-
45
- uint32_t nb01;
46
- uint32_t nb02;
47
- uint32_t nb03;
48
- uint32_t nb11;
49
- uint32_t nb12;
50
- uint32_t nb13;
51
- uint32_t nb21;
52
- uint32_t nb22;
53
- uint32_t nb23;
54
- uint32_t nb31;
55
-
56
- float scale;
57
- float max_bias;
58
- float logit_softcap;
59
-
60
- uint32_t mask;
61
- uint32_t n_head_log2;
62
- float m0;
63
- float m1;
64
-
65
- uint32_t gqa_ratio;
66
- uint32_t split_kv;
67
- uint32_t k_num;
68
- } p;
69
 
70
  layout (binding = 0) readonly buffer Q {uint8_t data_q[];};
71
  layout (binding = 1) readonly buffer K {uint8_t data_k[];};
72
  layout (binding = 2) readonly buffer V {uint8_t data_v[];};
73
  layout (binding = 3) readonly buffer M {uint8_t data_m[];};
74
- layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
75
-
76
- #define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
77
 
78
  ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
79
  return max(x, y);
@@ -118,67 +68,12 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
118
  return elem;
119
  }
120
 
121
- // Store column zero. This is used to save per-row m and L values for split_k.
122
- ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
123
- {
124
- if (r < N && c == 0) {
125
- uint32_t offset = iq2 + r;
126
- data_o[o_offset + offset] = D_TYPE(elem);
127
- }
128
- return elem;
129
- }
130
-
131
- // Load the slope matrix, indexed by Q's dimension 2.
132
- ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
133
- {
134
- const uint32_t h = iq2 + (r % p.gqa_ratio);
135
-
136
- const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
137
- const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
138
-
139
- return ACC_TYPE(pow(base, ACC_TYPE(exph)));
140
- }
141
-
142
  void main() {
143
  #ifdef NEEDS_INIT_IQ_SHMEM
144
  init_iq_shmem(gl_WorkGroupSize);
145
  #endif
146
 
147
- const uint32_t N = p.N;
148
- const uint32_t KV = p.KV;
149
-
150
- uint32_t i = gl_WorkGroupID.x;
151
- uint32_t split_k_index = 0;
152
-
153
- if (p.k_num > 1) {
154
- i = 0;
155
- split_k_index = gl_WorkGroupID.x;
156
- }
157
-
158
- const uint32_t Tr = CEIL_DIV(N, Br);
159
-
160
- const uint32_t start_j = split_k_index * p.split_kv / Bc;
161
- const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
162
-
163
- // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
164
- // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
165
- const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio;
166
- const uint32_t iq3 = gl_WorkGroupID.z;
167
-
168
- // broadcast factors
169
- const uint32_t rk2 = p.neq2/p.nek2;
170
- const uint32_t rk3 = p.neq3/p.nek3;
171
-
172
- const uint32_t rv2 = p.neq2/p.nev2;
173
- const uint32_t rv3 = p.neq3/p.nev3;
174
-
175
- // k indices
176
- const uint32_t ik3 = iq3 / rk3;
177
- const uint32_t ik2 = iq2 / rk2;
178
-
179
- // v indices
180
- const uint32_t iv3 = iq3 / rv3;
181
- const uint32_t iv2 = iq2 / rv2;
182
 
183
  tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
184
  tensorLayoutNV<2, Clamp> tensorLayoutK = createTensorLayoutNV(2, Clamp);
@@ -195,17 +90,6 @@ void main() {
195
  tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D);
196
  tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D);
197
 
198
- // nb?1 are already divided by the type size and are in units of elements.
199
- // When using grouped query attention, Q is indexed by iq2, so the stride
200
- // should be nb02 (which is in bytes).
201
- uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
202
- uint32_t k_stride = p.nb11;
203
- uint32_t v_stride = p.nb21;
204
- // When using grouped query attention, all rows use the same mask (stride 0).
205
- // "p.gqa_ratio >> 16" is just a roundabout way of writing zero
206
- // that prevents the compiler from folding the "&" through the select
207
- // and breaking the alignment detection.
208
- uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
209
  // hint to the compiler that strides are aligned for the aligned variant of the shader
210
  if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
211
  {
 
18
 
19
  #include "types.comp"
20
  #include "dequant_funcs_cm2.comp"
21
+ #include "flash_attn_base.comp"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  layout (binding = 0) readonly buffer Q {uint8_t data_q[];};
24
  layout (binding = 1) readonly buffer K {uint8_t data_k[];};
25
  layout (binding = 2) readonly buffer V {uint8_t data_v[];};
26
  layout (binding = 3) readonly buffer M {uint8_t data_m[];};
 
 
 
27
 
28
  ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
29
  return max(x, y);
 
68
  return elem;
69
  }
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  void main() {
72
  #ifdef NEEDS_INIT_IQ_SHMEM
73
  init_iq_shmem(gl_WorkGroupSize);
74
  #endif
75
 
76
+ init_indices();
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
  tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
79
  tensorLayoutNV<2, Clamp> tensorLayoutK = createTensorLayoutNV(2, Clamp);
 
90
  tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D);
91
  tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D);
92
 
 
 
 
 
 
 
 
 
 
 
 
93
  // hint to the compiler that strides are aligned for the aligned variant of the shader
94
  if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
95
  {