Eve committed on
Commit
ffdf466
·
1 Parent(s): f6cff0a

vulkan: matmul dequantization improvements (llama/12015)

Browse files

* faster dequant for old quants

* don't use unpack for iq4_nl

* vec2 unpack for q8

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp CHANGED
@@ -82,9 +82,9 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
82
  return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1]));
83
  }
84
  vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
85
- uint32_t v0 = data_a_packed16[a_offset + ib].qs[iqs/2];
86
- uint32_t v1 = data_a_packed16[a_offset + ib].qs[iqs/2 + 1];
87
- return vec4(int8_t(v0 & 0xFF), int8_t(v0 >> 8), int8_t(v1 & 0xFF), int8_t(v1 >> 8));
88
  }
89
  #endif
90
 
 
82
  return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1]));
83
  }
84
  vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
85
+ const i8vec2 v0 = unpack8(data_a_packed16[a_offset + ib].qs[iqs/2]);
86
+ const i8vec2 v1 = unpack8(data_a_packed16[a_offset + ib].qs[iqs/2 + 1]);
87
+ return vec4(v0.x, v0.y, v1.x, v1.y);
88
  }
89
  #endif
90
 
ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp CHANGED
@@ -92,7 +92,7 @@ float16_t dequantFuncQ8_0(const in decodeBufQ8_0 bl, const in uint blockCoords[2
92
  const uint iqs = idx;
93
 
94
  // Load 16b and select the byte for this element
95
- int32_t qs = unpack8(int32_t(bl.block.qs[(iqs & 0x1E) >> 1]))[iqs & 1];
96
  float16_t ret = float16_t(qs) * d;
97
  return ret;
98
  }
 
92
  const uint iqs = idx;
93
 
94
  // Load 16b and select the byte for this element
95
+ int32_t qs = unpack8(bl.block.qs[(iqs & 0x1E) >> 1])[iqs & 1];
96
  float16_t ret = float16_t(qs) * d;
97
  return ret;
98
  }
ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp CHANGED
@@ -32,6 +32,13 @@
32
  layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
33
 
34
  layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
 
 
 
 
 
 
 
35
  layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
36
  layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
37
 
@@ -243,74 +250,100 @@ void main() {
243
  #endif
244
  #elif defined(DATA_A_Q4_0)
245
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
246
- const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
247
-
248
- const uint ib = idx / 16;
249
- const uint iqs = idx & 0xF;
250
-
251
- const float d = float(data_a[ib].d);
252
- const uint vui = uint(data_a[ib].qs[iqs]);
253
- const vec2 v = (vec2(vui & 0xF, vui >> 4) - 8.0f) * d;
254
-
255
- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
256
- buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
 
 
 
 
 
 
 
257
  #elif defined(DATA_A_Q4_1)
258
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
259
- const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
260
-
261
- const uint ib = idx / 16;
262
- const uint iqs = idx & 0xF;
263
-
264
- const float d = float(data_a[ib].d);
265
- const float m = float(data_a[ib].m);
266
- const uint vui = uint(data_a[ib].qs[iqs]);
267
- const vec2 v = vec2(vui & 0xF, vui >> 4) * d + m;
268
-
269
- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
270
- buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
 
 
 
 
 
 
 
271
  #elif defined(DATA_A_Q5_0)
272
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
273
- const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
274
 
275
- const uint ib = idx / 16;
276
- const uint iqs = idx & 0xF;
277
 
278
- const float d = float(data_a[ib].d);
279
- const uint uint_qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0];
280
- const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
281
- const uint vui = uint(data_a[ib].qs[iqs]);
282
- const vec2 v = (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d;
 
 
283
 
284
  buf_a[buf_idx ] = FLOAT_TYPE(v.x);
 
285
  buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
 
286
  #elif defined(DATA_A_Q5_1)
287
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
288
- const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
289
 
290
- const uint ib = idx / 16;
291
- const uint iqs = idx & 0xF;
292
 
293
- const float d = float(data_a[ib].d);
294
- const float m = float(data_a[ib].m);
295
- const uint uint_qh = data_a[ib].qh;
296
- const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
297
- const uint vui = uint(data_a[ib].qs[iqs]);
298
- const vec2 v = vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) * d + m;
 
 
299
 
300
  buf_a[buf_idx ] = FLOAT_TYPE(v.x);
 
301
  buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
 
302
  #elif defined(DATA_A_Q8_0)
303
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
304
  const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
305
 
306
- const uint ib = idx / 16;
307
- const uint iqs = (idx & 0xF) * 2;
308
 
309
- const float d = float(data_a[ib].d);
310
- const vec2 v = vec2(int(data_a[ib].qs[iqs]), int(data_a[ib].qs[iqs + 1])) * d;
 
 
311
 
312
  buf_a[buf_idx ] = FLOAT_TYPE(v.x);
313
  buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
 
 
314
  #elif defined(DATA_A_Q2_K)
315
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
316
  const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
@@ -623,17 +656,18 @@ void main() {
623
  buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
624
  #elif defined(DATA_A_IQ4_NL)
625
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
626
- const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
627
 
628
- const uint ib = idx / 16;
629
- const uint iqs = idx & 0xF;
630
 
631
- const float d = float(data_a[ib].d);
632
- const uint vui = uint(data_a[ib].qs[iqs]);
633
- const vec2 v = vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]) * d;
634
 
635
- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
636
- buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
 
 
637
  #endif
638
  }
639
  [[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
 
32
  layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
33
 
34
  layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
35
+ #if defined(A_TYPE_PACKED16)
36
+ layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
37
+ #endif
38
+ #if defined(A_TYPE_PACKED32)
39
+ layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
40
+ #endif
41
+
42
  layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
43
  layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
44
 
 
250
  #endif
251
  #elif defined(DATA_A_Q4_0)
252
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
253
+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a;
254
+
255
+ const uint ib = idx / 4;
256
+ const uint iqs = idx & 0x03;
257
+
258
+ const float d = float(data_a_packed16[ib].d);
259
+ const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16);
260
+ const vec4 v0 = (vec4(unpack8(vui & 0x0F0F0F0F)) - 8.0f) * d;
261
+ const vec4 v1 = (vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) - 8.0f) * d;
262
+
263
+ buf_a[buf_idx ] = FLOAT_TYPE(v0.x);
264
+ buf_a[buf_idx + 1 ] = FLOAT_TYPE(v0.y);
265
+ buf_a[buf_idx + 2 ] = FLOAT_TYPE(v0.z);
266
+ buf_a[buf_idx + 3 ] = FLOAT_TYPE(v0.w);
267
+ buf_a[buf_idx + 16] = FLOAT_TYPE(v1.x);
268
+ buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y);
269
+ buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z);
270
+ buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w);
271
  #elif defined(DATA_A_Q4_1)
272
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
273
+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a;
274
+
275
+ const uint ib = idx / 4;
276
+ const uint iqs = idx & 0x03;
277
+
278
+ const float d = float(data_a_packed16[ib].d);
279
+ const float m = float(data_a_packed16[ib].m);
280
+ const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16);
281
+ const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * d + m;
282
+ const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * d + m;
283
+
284
+ buf_a[buf_idx ] = FLOAT_TYPE(v0.x);
285
+ buf_a[buf_idx + 1 ] = FLOAT_TYPE(v0.y);
286
+ buf_a[buf_idx + 2 ] = FLOAT_TYPE(v0.z);
287
+ buf_a[buf_idx + 3 ] = FLOAT_TYPE(v0.w);
288
+ buf_a[buf_idx + 16] = FLOAT_TYPE(v1.x);
289
+ buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y);
290
+ buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z);
291
+ buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w);
292
  #elif defined(DATA_A_Q5_0)
293
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
294
+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a;
295
 
296
+ const uint ib = idx / 8;
297
+ const uint iqs = idx & 0x07;
298
 
299
+ const float d = float(data_a_packed16[ib].d);
300
+ const uint uint_qh = uint(data_a_packed16[ib].qh[1]) << 16 | uint(data_a_packed16[ib].qh[0]);
301
+ const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10);
302
+ const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10);
303
+
304
+ const uint vui = uint(data_a_packed16[ib].qs[iqs]);
305
+ const vec4 v = (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f) * d;
306
 
307
  buf_a[buf_idx ] = FLOAT_TYPE(v.x);
308
+ buf_a[buf_idx + 1 ] = FLOAT_TYPE(v.z);
309
  buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
310
+ buf_a[buf_idx + 17] = FLOAT_TYPE(v.w);
311
  #elif defined(DATA_A_Q5_1)
312
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
313
+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a;
314
 
315
+ const uint ib = idx / 8;
316
+ const uint iqs = idx & 0x07;
317
 
318
+ const float d = float(data_a_packed16[ib].d);
319
+ const float m = float(data_a_packed16[ib].m);
320
+ const uint uint_qh = data_a_packed16[ib].qh;
321
+ const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10);
322
+ const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10);
323
+
324
+ const uint vui = uint(data_a_packed16[ib].qs[iqs]);
325
+ const vec4 v = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) * d + m;
326
 
327
  buf_a[buf_idx ] = FLOAT_TYPE(v.x);
328
+ buf_a[buf_idx + 1 ] = FLOAT_TYPE(v.z);
329
  buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
330
+ buf_a[buf_idx + 17] = FLOAT_TYPE(v.w);
331
  #elif defined(DATA_A_Q8_0)
332
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
333
  const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
334
 
335
+ const uint ib = idx / 8;
336
+ const uint iqs = idx & 0x07;
337
 
338
+ const float d = float(data_a_packed16[ib].d);
339
+ const i8vec2 v0 = unpack8(data_a_packed16[ib].qs[2*iqs]);
340
+ const i8vec2 v1 = unpack8(data_a_packed16[ib].qs[2*iqs + 1]);
341
+ const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d;
342
 
343
  buf_a[buf_idx ] = FLOAT_TYPE(v.x);
344
  buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
345
+ buf_a[buf_idx + 2] = FLOAT_TYPE(v.z);
346
+ buf_a[buf_idx + 3] = FLOAT_TYPE(v.w);
347
  #elif defined(DATA_A_Q2_K)
348
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
349
  const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
 
656
  buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
657
  #elif defined(DATA_A_IQ4_NL)
658
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
659
+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a;
660
 
661
+ const uint ib = idx / 8;
662
+ const uint iqs = idx & 0x07;
663
 
664
+ const FLOAT_TYPE d = FLOAT_TYPE(data_a_packed16[ib].d);
665
+ const uint vui = uint(data_a_packed16[ib].qs[iqs]);
 
666
 
667
+ buf_a[buf_idx ] = FLOAT_TYPE(kvalues_iq4nl[vui & 0xF]) * d;
668
+ buf_a[buf_idx + 1 ] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]) * d;
669
+ buf_a[buf_idx + 16] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)]) * d;
670
+ buf_a[buf_idx + 17] = FLOAT_TYPE(kvalues_iq4nl[vui >> 12]) * d;
671
  #endif
672
  }
673
  [[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
ggml/src/ggml-vulkan/vulkan-shaders/types.comp CHANGED
@@ -139,7 +139,7 @@ struct block_q8_0
139
  struct block_q8_0_packed16
140
  {
141
  float16_t d;
142
- uint16_t qs[32/2];
143
  };
144
 
145
  #if defined(DATA_A_Q8_0)
 
139
  struct block_q8_0_packed16
140
  {
141
  float16_t d;
142
+ int16_t qs[32/2];
143
  };
144
 
145
  #if defined(DATA_A_Q8_0)
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp CHANGED
@@ -325,11 +325,17 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
325
  string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
326
 
327
  for (const auto& tname : type_names) {
 
 
 
 
 
 
328
  std::string data_a_key = "DATA_A_" + to_uppercase(tname);
329
  // For unaligned, load one at a time for f32/f16, or two at a time for quants
330
- std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16") ? "1" : "2";
331
  // For aligned matmul loads
332
- std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : "2";
333
 
334
  // don't generate f32 variants for coopmat2
335
  if (!coopmat2) {
 
325
  string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
326
 
327
  for (const auto& tname : type_names) {
328
+ std::string load_vec_quant = "2";
329
+ if ((tname == "q4_0") || (tname == "q4_1"))
330
+ load_vec_quant = "8";
331
+ else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq4_nl"))
332
+ load_vec_quant = "4";
333
+
334
  std::string data_a_key = "DATA_A_" + to_uppercase(tname);
335
  // For unaligned, load one at a time for f32/f16, or two at a time for quants
336
+ std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16") ? "1" : load_vec_quant;
337
  // For aligned matmul loads
338
+ std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : load_vec_quant;
339
 
340
  // don't generate f32 variants for coopmat2
341
  if (!coopmat2) {