ggerganov committed
Commit 58b0822 · 1 Parent(s): f083887

metal : small-batch mat-mul kernels (llama/10581)

* metal : small-batch mat-mul kernels

ggml-ci

* metal : add rest of types

ggml-ci

* metal : final adjustments

ggml-ci

* metal : add comments

ggml-ci
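
The new kernels tile the output so that each threadgroup computes a small block of dst: r0ptg rows of src0 by r1ptg rows of src1, with nsg simdgroups per threadgroup and nxpsg lanes cooperating on each src0 row. A minimal sketch of the resulting dispatch arithmetic (the helper name and the example sizes are hypothetical, not part of the commit):

#include <stdio.h>

// Hypothetical helper: mirrors the threadgroup sizing used by the new
// small-batch path in ggml_metal_encode_node below.
static void mul_mv_ext_grid(int ne01, int ne11, int ne12, int ne13) {
    const int nsg   = 2;                  // simdgroups per threadgroup
    const int nxpsg = ne11 < 3 ? 16 : 8;  // threads along a src0 row per simdgroup
    const int nypsg = 32/nxpsg;           // src0 rows per simdgroup
    const int r0ptg = nypsg*nsg;          // src0 rows per threadgroup

    int r1ptg = 4;                        // src1 rows per threadgroup
    switch (ne11) {
        case 2:         r1ptg = 2; break;
        case 3: case 6: r1ptg = 3; break;
        case 5:         r1ptg = 5; break;
        default:        r1ptg = 4; break; // ne11 == 4, 7, 8
    }

    printf("threadgroups = (%d, %d, %d), threadsPerThreadgroup = (32, %d, 1)\n",
        (ne01 + r0ptg - 1)/r0ptg, (ne11 + r1ptg - 1)/r1ptg, ne12*ne13, nsg);
}

int main(void) {
    // e.g. ne01 = 4096 src0 rows and a batch of ne11 = 4 src1 rows:
    // prints threadgroups = (512, 1, 1), threadsPerThreadgroup = (32, 2, 1)
    mul_mv_ext_grid(4096, 4, 1, 1);
    return 0;
}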

ggml/src/ggml-metal/ggml-metal-impl.h CHANGED
@@ -192,6 +192,30 @@ typedef struct {
     int16_t r3;
 } ggml_metal_kargs_mul_mv;
 
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne10;
+    int32_t  ne11;
+    int32_t  ne12;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    int32_t  ne0;
+    int32_t  ne1;
+    int16_t  r2;
+    int16_t  r3;
+    int16_t  nsg;
+    int16_t  nxpsg;
+    int16_t  r1ptg;
+} ggml_metal_kargs_mul_mv_ext;
+
 typedef struct {
     int32_t nei0;
     int32_t nei1;
ggml/src/ggml-metal/ggml-metal.m CHANGED
@@ -175,6 +175,46 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5,
     GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,
@@ -702,6 +742,46 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5, mul_mv_ext_f16_f32_r1_5, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2, mul_mv_ext_q4_0_f32_r1_2, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3, mul_mv_ext_q4_0_f32_r1_3, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4, mul_mv_ext_q4_0_f32_r1_4, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5, mul_mv_ext_q4_0_f32_r1_5, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2, mul_mv_ext_q4_1_f32_r1_2, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3, mul_mv_ext_q4_1_f32_r1_3, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4, mul_mv_ext_q4_1_f32_r1_4, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5, mul_mv_ext_q4_1_f32_r1_5, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2, mul_mv_ext_q5_0_f32_r1_2, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3, mul_mv_ext_q5_0_f32_r1_3, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4, mul_mv_ext_q5_0_f32_r1_4, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5, mul_mv_ext_q5_0_f32_r1_5, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2, mul_mv_ext_q5_1_f32_r1_2, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3, mul_mv_ext_q5_1_f32_r1_3, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4, mul_mv_ext_q5_1_f32_r1_4, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5, mul_mv_ext_q5_1_f32_r1_5, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2, mul_mv_ext_q8_0_f32_r1_2, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, mul_mv_ext_q8_0_f32_r1_3, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, mul_mv_ext_q8_0_f32_r1_4, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, mul_mv_ext_q8_0_f32_r1_5, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2, mul_mv_ext_q4_K_f32_r1_2, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3, mul_mv_ext_q4_K_f32_r1_3, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4, mul_mv_ext_q4_K_f32_r1_4, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5, mul_mv_ext_q4_K_f32_r1_5, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2, mul_mv_ext_q5_K_f32_r1_2, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3, mul_mv_ext_q5_K_f32_r1_3, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4, mul_mv_ext_q5_K_f32_r1_4, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5, mul_mv_ext_q5_K_f32_r1_5, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2, mul_mv_ext_q6_K_f32_r1_2, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3, mul_mv_ext_q6_K_f32_r1_3, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4, mul_mv_ext_q6_K_f32_r1_4, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5, mul_mv_ext_q6_K_f32_r1_5, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2, mul_mv_ext_iq4_nl_f32_r1_2, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3, mul_mv_ext_iq4_nl_f32_r1_3, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4, mul_mv_ext_iq4_nl_f32_r1_4, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5, mul_mv_ext_iq4_nl_f32_r1_5, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, has_simdgroup_reduction);
@@ -1936,30 +2016,180 @@ static void ggml_metal_encode_node(
 
             // find the break-even point where the matrix-matrix kernel becomes more efficient compared
             // to the matrix-vector kernel
-            int ne11_mm_min = 4;
+            const int ne11_mm_min = 4;
+
+            // first try to use small-batch mat-mv kernels
+            // these should be efficient for BS [2, ~8]
+            if (src1t == GGML_TYPE_F32 && (ne00%256 == 0) &&
+                (
+                  (
+                    (
+                      src0t == GGML_TYPE_F16 ||  // TODO: helper function
+                      src0t == GGML_TYPE_Q4_0 ||
+                      src0t == GGML_TYPE_Q4_1 ||
+                      src0t == GGML_TYPE_Q5_0 ||
+                      src0t == GGML_TYPE_Q5_1 ||
+                      src0t == GGML_TYPE_Q8_0 ||
+                      src0t == GGML_TYPE_IQ4_NL ||
+                      false) && (ne11 >= 2 && ne11 <= 8)
+                  ) ||
+                  (
+                    (
+                      src0t == GGML_TYPE_Q4_K ||
+                      src0t == GGML_TYPE_Q5_K ||
+                      src0t == GGML_TYPE_Q6_K ||
+                      false) && (ne11 >= 4 && ne11 <= 8)
+                  )
+                )
+               ) {
+                // TODO: determine the optimal parameters based on grid utilization
+                //       I still don't know why we should not always use the maximum available threads:
+                //
+                //       nsg = pipeline.maxTotalThreadsPerThreadgroup / 32
+                //
+                //       my current hypothesis is that the work grid is not evenly divisible for different nsg
+                //       values and there can be some tail effects when nsg is high. need to confirm this
+                //
+                const int nsg   = 2;                 // num simdgroups per threadgroup
+                const int nxpsg = ne11 < 3 ? 16 : 8; // num threads along row per simdgroup
+                const int nypsg = 32/nxpsg;          // num threads along col per simdgroup (i.e. a simdgroup processes that many src0 rows at a time)
+                const int r0ptg = nypsg*nsg;         // num src0 rows per threadgroup
+                int       r1ptg = 4;                 // num src1 rows per threadgroup
+
+                // note: not sure how optimal are those across all different hardware. there might be someting cleverer
+                switch (ne11) {
+                    case 2:
+                        r1ptg = 2; break;
+                    case 3:
+                    case 6:
+                        r1ptg = 3; break;
+                    case 4:
+                    case 7:
+                    case 8:
+                        r1ptg = 4; break;
+                    case 5:
+                        r1ptg = 5; break;
+                };
 
-#if 0
-            // the numbers below are measured on M2 Ultra for 7B and 13B models
-            // these numbers do not translate to other devices or model sizes
-            // TODO: need to find a better approach
-            if ([device.name isEqualToString:@"Apple M2 Ultra"]) {
-                switch (src0t) {
-                    case GGML_TYPE_F16:  ne11_mm_min = 2;  break;
-                    case GGML_TYPE_Q8_0: ne11_mm_min = 7;  break;
-                    case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
-                    case GGML_TYPE_Q3_K: ne11_mm_min = 7;  break;
-                    case GGML_TYPE_Q4_0:
-                    case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
-                    case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
-                    case GGML_TYPE_Q5_0:                   // not tested yet
-                    case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
-                    case GGML_TYPE_Q5_K: ne11_mm_min = 7;  break;
-                    case GGML_TYPE_Q6_K: ne11_mm_min = 7;  break;
-                    default:             ne11_mm_min = 1;  break;
-                }
-            }
-#endif
+                id<MTLComputePipelineState> pipeline = nil;
+
+                switch (src0->type) {
+                    case GGML_TYPE_F16:
+                        switch (r1ptg) {
+                            case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2].pipeline; break;
+                            case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3].pipeline; break;
+                            case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4].pipeline; break;
+                            case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5].pipeline; break;
+                            default: GGML_ABORT("not implemented");
+                        } break;
+                    case GGML_TYPE_Q4_0:
+                        switch (r1ptg) {
+                            case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2].pipeline; break;
+                            case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3].pipeline; break;
+                            case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4].pipeline; break;
+                            case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5].pipeline; break;
+                            default: GGML_ABORT("not implemented");
+                        } break;
+                    case GGML_TYPE_Q4_1:
+                        switch (r1ptg) {
+                            case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2].pipeline; break;
+                            case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3].pipeline; break;
+                            case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4].pipeline; break;
+                            case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5].pipeline; break;
+                            default: GGML_ABORT("not implemented");
+                        } break;
+                    case GGML_TYPE_Q5_0:
+                        switch (r1ptg) {
+                            case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2].pipeline; break;
+                            case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3].pipeline; break;
+                            case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4].pipeline; break;
+                            case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5].pipeline; break;
+                            default: GGML_ABORT("not implemented");
+                        } break;
+                    case GGML_TYPE_Q5_1:
+                        switch (r1ptg) {
+                            case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2].pipeline; break;
+                            case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3].pipeline; break;
+                            case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4].pipeline; break;
+                            case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5].pipeline; break;
+                            default: GGML_ABORT("not implemented");
+                        } break;
+                    case GGML_TYPE_Q8_0:
+                        switch (r1ptg) {
+                            case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2].pipeline; break;
+                            case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3].pipeline; break;
+                            case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4].pipeline; break;
+                            case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5].pipeline; break;
+                            default: GGML_ABORT("not implemented");
+                        } break;
+                    case GGML_TYPE_Q4_K:
+                        switch (r1ptg) {
+                            case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2].pipeline; break;
+                            case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3].pipeline; break;
+                            case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4].pipeline; break;
+                            case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5].pipeline; break;
+                            default: GGML_ABORT("not implemented");
+                        } break;
+                    case GGML_TYPE_Q5_K:
+                        switch (r1ptg) {
+                            case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2].pipeline; break;
+                            case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3].pipeline; break;
+                            case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4].pipeline; break;
+                            case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5].pipeline; break;
+                            default: GGML_ABORT("not implemented");
+                        } break;
+                    case GGML_TYPE_Q6_K:
+                        switch (r1ptg) {
+                            case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2].pipeline; break;
+                            case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3].pipeline; break;
+                            case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4].pipeline; break;
+                            case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5].pipeline; break;
+                            default: GGML_ABORT("not implemented");
+                        } break;
+                    case GGML_TYPE_IQ4_NL:
+                        switch (r1ptg) {
+                            case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2].pipeline; break;
+                            case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3].pipeline; break;
+                            case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4].pipeline; break;
+                            case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5].pipeline; break;
+                            default: GGML_ABORT("not implemented");
+                        } break;
+                    default: GGML_ABORT("not implemented");
+                }
+
+                ggml_metal_kargs_mul_mv_ext args = {
+                    /*.ne00  =*/ ne00,
+                    /*.ne01  =*/ ne01,
+                    /*.ne02  =*/ ne02,
+                    /*.nb00  =*/ nb00,
+                    /*.nb01  =*/ nb01,
+                    /*.nb02  =*/ nb02,
+                    /*.nb03  =*/ nb03,
+                    /*.ne10  =*/ ne10,
+                    /*.ne11  =*/ ne11,
+                    /*.ne12  =*/ ne12,
+                    /*.nb10  =*/ nb10,
+                    /*.nb11  =*/ nb11,
+                    /*.nb12  =*/ nb12,
+                    /*.nb13  =*/ nb13,
+                    /*.ne0   =*/ ne0,
+                    /*.ne1   =*/ ne1,
+                    /*.r2    =*/ r2,
+                    /*.r3    =*/ r3,
+                    /*.nsg   =*/ nsg,
+                    /*.nxpsg =*/ nxpsg,
+                    /*.r1ptg =*/ r1ptg,
+                };
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:3];
 
+                //printf("ne01 = %lld nr0ptg = %d\n", ne01, nr0ptg);
+                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + r0ptg - 1)/r0ptg, (ne11 + r1ptg - 1)/r1ptg, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+            } else
             // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
             // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
             if ([device supportsFamily:MTLGPUFamilyApple7] &&
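
Each dispatch above hands a threadgroup an r0ptg x r1ptg tile of the output; inside it, every lane accumulates float4 (or float4x4 for the K-quants) chunks of one dequantized src0 row against up to r1ptg src1 rows. Per output element the result is simply a dot product over ne00 elements; a scalar sketch of that reference computation (illustration only, not code from the commit):

// Reference value for dst element (i01, i11): dot product of a dequantized
// src0 row with one src1 row. The Metal kernels below only change how this
// sum is split across nxpsg lanes and r1ptg output columns.
static float mul_mv_ext_ref_one(const float * x_row, const float * y_row, int ne00) {
    float sum = 0.0f;
    for (int i = 0; i < ne00; ++i) {
        sum += x_row[i]*y_row[i];
    }
    return sum;
}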
ggml/src/ggml-metal/ggml-metal.metal CHANGED
@@ -47,6 +47,11 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
     reg = (type4x4)(*src);
 }
 
+template <typename type4>
+void dequantize_f16_t4(device const half4 * src, short il, thread type4 & reg) {
+    reg = (type4)(*(src + il));
+}
+
 #if defined(GGML_METAL_USE_BF16)
 template <typename type4x4>
 void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) {
@@ -55,7 +60,7 @@ void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & re
 #endif
 
 template <typename type4x4>
-void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
+void dequantize_q4_0(device const block_q4_0 * xb, short il, thread type4x4 & reg) {
     device const uint16_t * qs = ((device const uint16_t *)xb + 1);
     const float d1 = il ? (xb->d / 16.h) : xb->d;
     const float d2 = d1 / 256.f;
@@ -73,8 +78,23 @@ void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg
     reg = (type4x4) reg_f;
 }
 
+template <typename type4>
+void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 1);
+    const float d1 = (il/4) ? (xb->d / 16.h) : xb->d;
+    const float d2 = d1 / 256.f;
+    const float md = -8.h * xb->d;
+    const ushort mask0 = (il/4) ? 0x00F0 : 0x000F;
+    const ushort mask1 = mask0 << 8;
+
+    for (int i = 0; i < 2; i++) {
+        reg[2*i + 0] = d1 * (qs[2*(il%4) + i] & mask0) + md;
+        reg[2*i + 1] = d2 * (qs[2*(il%4) + i] & mask1) + md;
+    }
+}
+
 template <typename type4x4>
-void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
+void dequantize_q4_1(device const block_q4_1 * xb, short il, thread type4x4 & reg) {
     device const uint16_t * qs = ((device const uint16_t *)xb + 2);
     const float d1 = il ? (xb->d / 16.h) : xb->d;
     const float d2 = d1 / 256.f;
@@ -92,8 +112,23 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
     reg = (type4x4) reg_f;
 }
 
+template <typename type4>
+void dequantize_q4_1_t4(device const block_q4_1 * xb, short il, thread type4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 2);
+    const float d1 = (il/4) ? (xb->d / 16.h) : xb->d;
+    const float d2 = d1 / 256.f;
+    const float m = xb->m;
+    const ushort mask0 = (il/4) ? 0x00F0 : 0x000F;
+    const ushort mask1 = mask0 << 8;
+
+    for (int i = 0; i < 2; i++) {
+        reg[2*i + 0] = d1 * (qs[2*(il%4) + i] & mask0) + m;
+        reg[2*i + 1] = d2 * (qs[2*(il%4) + i] & mask1) + m;
+    }
+}
+
 template <typename type4x4>
-void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
+void dequantize_q5_0(device const block_q5_0 * xb, short il, thread type4x4 & reg) {
     device const uint16_t * qs = ((device const uint16_t *)xb + 3);
     const float d = xb->d;
     const float md = -16.h * xb->d;
@@ -124,8 +159,38 @@ void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg
     reg = (type4x4) reg_f;
 }
 
+template <typename type4>
+void dequantize_q5_0_t4(device const block_q5_0 * xb, short il, thread type4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 3);
+    const float d = xb->d;
+    const float md = -16.h * xb->d;
+    const ushort mask = (il/4) ? 0x00F0 : 0x000F;
+
+    const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+    const int x_mv = (il/4) ? 4 : 0;
+
+    const int gh_mv = (il/4) ? 12 : 0;
+    const int gh_bk = (il/4) ?  0 : 4;
+
+    for (int ii = 0; ii < 2; ii++) {
+        int i = 2*(il%4) + ii;
+
+        // extract the 5-th bits for x0 and x1
+        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
+        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+        // combine the 4-bits from qs with the 5th bit
+        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
+        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+        reg[2*ii + 0] = d * x0 + md;
+        reg[2*ii + 1] = d * x1 + md;
+    }
+}
+
 template <typename type4x4>
-void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
+void dequantize_q5_1(device const block_q5_1 * xb, short il, thread type4x4 & reg) {
     device const uint16_t * qs = ((device const uint16_t *)xb + 4);
     const float d = xb->d;
     const float m = xb->m;
@@ -156,10 +221,40 @@ void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg
     reg = (type4x4) reg_f;
 }
 
+template <typename type4>
+void dequantize_q5_1_t4(device const block_q5_1 * xb, short il, thread type4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 4);
+    const float d = xb->d;
+    const float m = xb->m;
+    const ushort mask = (il/4) ? 0x00F0 : 0x000F;
+
+    const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+    const int x_mv = (il/4) ? 4 : 0;
+
+    const int gh_mv = (il/4) ? 12 : 0;
+    const int gh_bk = (il/4) ?  0 : 4;
+
+    for (int ii = 0; ii < 2; ii++) {
+        int i = 2*(il%4) + ii;
+
+        // extract the 5-th bits for x0 and x1
+        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
+        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+        // combine the 4-bits from qs with the 5th bit
+        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
+        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+        reg[2*ii + 0] = d * x0 + m;
+        reg[2*ii + 1] = d * x1 + m;
+    }
+}
+
 template <typename type4x4>
 void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
     device const int8_t * qs = ((device const int8_t *)xb->qs);
-    const half d = xb->d;
+    const float d = xb->d;
 
     float4x4 reg_f;
 
@@ -170,6 +265,16 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg
     reg = (type4x4) reg_f;
 }
 
+template <typename type4>
+void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & reg) {
+    device const int8_t * qs = ((device const int8_t *)xb->qs);
+    const float d = xb->d;
+
+    for (int i = 0; i < 4; i++) {
+        reg[i] = (qs[4*(il%4) + i + 16*(il/4)] * d);
+    }
+}
+
 template <typename type4x4>
 void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
     const float d = xb->d;
@@ -224,7 +329,7 @@ static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q
 }
 
 template <typename type4x4>
-void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
+void dequantize_q4_K(device const block_q4_K * xb, short il, thread type4x4 & reg) {
     device const uchar * q = xb->qs;
 
     short is = (il/4) * 2;
@@ -236,7 +341,7 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
     const float dl = d * sc[0];
     const float ml = min * sc[1];
 
-    const ushort mask = il<2 ? 0x0F : 0xF0;
+    const ushort mask = il < 2 ? 0x0F : 0xF0;
     for (int i = 0; i < 16; ++i) {
         reg[i/4][i%4] = dl * (q[i] & mask) - ml;
     }
@@ -469,6 +574,19 @@ void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4
     }
 }
 
+template <typename type4>
+void dequantize_iq4_nl_t4(device const block_iq4_nl * xb, short il, thread type4 & reg) {
+    device const uint16_t * q4 = (device const uint16_t *)xb->qs;
+    const float d = xb->d;
+    uint32_t aux32;
+    thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
+    aux32 = ((q4[2*(il%4)] | (q4[2*(il%4)+1] << 16)) >> 4*(il/4)) & 0x0f0f0f0f;
+    reg[0] = d * kvalues_iq4nl_f[q8[0]];
+    reg[1] = d * kvalues_iq4nl_f[q8[1]];
+    reg[2] = d * kvalues_iq4nl_f[q8[2]];
+    reg[3] = d * kvalues_iq4nl_f[q8[3]];
+}
+
 template <typename type4x4>
 void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
     // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
@@ -1809,6 +1927,301 @@ kernel void kernel_mul_mv_q8_0_f32(
     kernel_mul_mv_q8_0_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
+// mat-vec kernel processing in chunks of float4
+// chpb - chunks per quantization block
+template<short nxpsg, short r1ptg, typename q_t, short chpb, void (*deq_t4)(device const q_t *, short, thread float4 &) >
+void kernel_mul_mv_ext_q4_f32_impl(
+        constant ggml_metal_kargs_mul_mv_ext & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    const short chpt = 4; // chunks per thread
+
+    //const short nxpsg = (32);
+    const short nypsg = (32/nxpsg);
+
+    const short tx = tiisg%nxpsg;
+    const short ty = tiisg/nxpsg;
+
+    const int i01 = tgpig.x*(nypsg*args.nsg) + nypsg*sgitg + ty;
+    const int i11 = tgpig.y*r1ptg;
+    const int i1m = tgpig.z;
+
+    const int i12 = i1m%args.ne12;
+    const int i13 = i1m/args.ne12;
+
+    const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 = i11*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+
+    device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0;
+
+    device const float4 * y4[r1ptg];
+
+    for (int ir1 = 0; ir1 < r1ptg; ++ir1) {
+        y4[ir1] = (i11 + ir1 < args.ne11) ? (device const float4 *) (src1 + offset1 + ir1*args.nb11) + tx : (device const float4 *) src1;
+    }
+
+    float sumf[r1ptg] = { [ 0 ... r1ptg - 1 ] = 0.0f };
+
+    short cch = tx%chpb; // current chunk index
+
+    for (int ich = tx; 4*ich < args.ne00; ich += chpt*nxpsg) {
+        float4 lx[chpt];
+
+#pragma unroll(chpt)
+        for (short ch = 0; ch < chpt; ++ch) {
+            deq_t4(xq, cch, lx[ch]);
+
+            cch += nxpsg;
+            if (cch >= chpb) {
+                xq += cch/chpb;
+                cch %= chpb;
+            }
+        }
+
+#pragma unroll(chpt)
+        for (short ch = 0; ch < chpt; ++ch) {
+#pragma unroll(r1ptg)
+            for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
+                sumf[ir1] += dot(lx[ch], y4[ir1][ch*nxpsg]);
+
+            }
+        }
+
+#pragma unroll(r1ptg)
+        for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
+            y4[ir1] += chpt*nxpsg;
+        }
+    }
+
+    // reduce only the threads in each row
+    for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
+        if (nxpsg >= 32) {
+            sumf[ir1] += simd_shuffle_down(sumf[ir1], 16);
+        }
+        if (nxpsg >= 16) {
+            sumf[ir1] += simd_shuffle_down(sumf[ir1],  8);
+        }
+        if (nxpsg >=  8) {
+            sumf[ir1] += simd_shuffle_down(sumf[ir1],  4);
+        }
+        if (nxpsg >=  4) {
+            sumf[ir1] += simd_shuffle_down(sumf[ir1],  2);
+        }
+        if (nxpsg >=  2) {
+            sumf[ir1] += simd_shuffle_down(sumf[ir1],  1);
+        }
+
+        //sumf[ir1] = simd_sum(sumf[ir1]);
+    }
+
+    if (tx == 0) {
+        for (short ir1 = 0; ir1 < r1ptg && i11 + ir1 < args.ne11; ++ir1) {
+            device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)(i11 + ir1)*args.ne0;
+
+            if (i01 < args.ne01) {
+                dst_f32[i01] = sumf[ir1];
+            }
+        }
+    }
+}
+
+// mat-vec kernel processing in chunks of float4x4
+template<short nxpsg, short r1ptg, typename q_t, short chpb, void (*deq_t4x4)(device const q_t *, short, thread float4x4 &) >
+void kernel_mul_mv_ext_q4x4_f32_impl(
+        constant ggml_metal_kargs_mul_mv_ext & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    const short chpt = 1;
+
+    //const short nxpsg = (32);
+    const short nypsg = (32/nxpsg);
+
+    const short tx = tiisg%nxpsg;
+    const short ty = tiisg/nxpsg;
+
+    const int i01 = tgpig.x*(nypsg*args.nsg) + nypsg*sgitg + ty;
+    const int i11 = tgpig.y*r1ptg;
+    const int i1m = tgpig.z;
+
+    const int i12 = i1m%args.ne12;
+    const int i13 = i1m/args.ne12;
+
+    const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 = i11*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+
+    device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0;
+
+    device const float4x4 * y4x4[r1ptg];
+
+    for (int ir1 = 0; ir1 < r1ptg; ++ir1) {
+        y4x4[ir1] = (i11 + ir1 < args.ne11) ? (device const float4x4 *) (src1 + offset1 + ir1*args.nb11) + tx : (device const float4x4 *) src1;
+    }
+
+    float sumf[r1ptg] = { [ 0 ... r1ptg - 1 ] = 0.0f };
+
+    short cch = tx%chpb;
+
+    for (int ich = tx; 16*ich < args.ne00; ich += chpt*nxpsg) {
+        float4x4 lx[chpt];
+
+#pragma unroll(chpt)
+        for (short ch = 0; ch < chpt; ++ch) {
+            deq_t4x4(xq, cch, lx[ch]);
+
+            cch += nxpsg;
+            if (cch >= chpb) {
+                xq += cch/chpb;
+                cch %= chpb;
+            }
+        }
+
+#pragma unroll(chpt)
+        for (short ch = 0; ch < chpt; ++ch) {
+#pragma unroll(r1ptg)
+            for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
+                sumf[ir1] +=
+                    dot(lx[ch][0], y4x4[ir1][ch*nxpsg][0]) +
+                    dot(lx[ch][1], y4x4[ir1][ch*nxpsg][1]) +
+                    dot(lx[ch][2], y4x4[ir1][ch*nxpsg][2]) +
+                    dot(lx[ch][3], y4x4[ir1][ch*nxpsg][3]);
+
+            }
+        }
+
+#pragma unroll(r1ptg)
+        for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
+            y4x4[ir1] += chpt*nxpsg;
+        }
+    }
+
+    for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
+        if (nxpsg >= 32) {
+            sumf[ir1] += simd_shuffle_down(sumf[ir1], 16);
+        }
+        if (nxpsg >= 16) {
+            sumf[ir1] += simd_shuffle_down(sumf[ir1],  8);
+        }
+        if (nxpsg >=  8) {
+            sumf[ir1] += simd_shuffle_down(sumf[ir1],  4);
+        }
+        if (nxpsg >=  4) {
+            sumf[ir1] += simd_shuffle_down(sumf[ir1],  2);
+        }
+        if (nxpsg >=  2) {
+            sumf[ir1] += simd_shuffle_down(sumf[ir1],  1);
+        }
+
+        //sumf[ir1] = simd_sum(sumf[ir1]);
+    }
+
+    if (tx == 0) {
+        for (short ir1 = 0; ir1 < r1ptg && i11 + ir1 < args.ne11; ++ir1) {
+            device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)(i11 + ir1)*args.ne0;
+
+            if (i01 < args.ne01) {
+                dst_f32[i01] = sumf[ir1];
+            }
+        }
+    }
+}
+
+// dispatchers needed for compile-time nxpsg
+// epb - elements per quantization block
+template<short r1ptg, typename q_t, short epb, void (*deq_t4)(device const q_t *, short, thread float4 &)>
+kernel void kernel_mul_mv_ext_q4_f32_disp(
+        constant ggml_metal_kargs_mul_mv_ext & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    switch (args.nxpsg) {
+        case  4: kernel_mul_mv_ext_q4_f32_impl<4,  r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
+        case  8: kernel_mul_mv_ext_q4_f32_impl<8,  r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
+        case 16: kernel_mul_mv_ext_q4_f32_impl<16, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
+        case 32: kernel_mul_mv_ext_q4_f32_impl<32, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
+    }
+}
+
+template<short r1ptg, typename q_t, short epb, void (*deq_t4x4)(device const q_t *, short, thread float4x4 &)>
+kernel void kernel_mul_mv_ext_q4x4_f32_disp(
+        constant ggml_metal_kargs_mul_mv_ext & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    switch (args.nxpsg) {
+        case  4: kernel_mul_mv_ext_q4x4_f32_impl<4,  r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
+        case  8: kernel_mul_mv_ext_q4x4_f32_impl<8,  r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
+        case 16: kernel_mul_mv_ext_q4x4_f32_impl<16, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
+        case 32: kernel_mul_mv_ext_q4x4_f32_impl<32, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
+    }
+}
+
+typedef decltype(kernel_mul_mv_ext_q4_f32_disp  <2, block_q8_0, 32, dequantize_q8_0_t4>) mul_mv_ext_q4_f32_t;
+typedef decltype(kernel_mul_mv_ext_q4x4_f32_disp<2, block_q4_K, 256, dequantize_q4_K>)   mul_mv_ext_q4x4_f32_t;
+
+template [[host_name("kernel_mul_mv_ext_f16_f32_r1_2")]]    kernel mul_mv_ext_q4_f32_t   kernel_mul_mv_ext_q4_f32_disp<2, half4, 4, dequantize_f16_t4>;
+template [[host_name("kernel_mul_mv_ext_f16_f32_r1_3")]]    kernel mul_mv_ext_q4_f32_t   kernel_mul_mv_ext_q4_f32_disp<3, half4, 4, dequantize_f16_t4>;
+template [[host_name("kernel_mul_mv_ext_f16_f32_r1_4")]]    kernel mul_mv_ext_q4_f32_t   kernel_mul_mv_ext_q4_f32_disp<4, half4, 4, dequantize_f16_t4>;
+template [[host_name("kernel_mul_mv_ext_f16_f32_r1_5")]]    kernel mul_mv_ext_q4_f32_t   kernel_mul_mv_ext_q4_f32_disp<5, half4, 4, dequantize_f16_t4>;
+
+template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]]   kernel mul_mv_ext_q4_f32_t   kernel_mul_mv_ext_q4_f32_disp<2, block_q4_0, 32, dequantize_q4_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]]   kernel mul_mv_ext_q4_f32_t   kernel_mul_mv_ext_q4_f32_disp<3, block_q4_0, 32, dequantize_q4_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]]   kernel mul_mv_ext_q4_f32_t   kernel_mul_mv_ext_q4_f32_disp<4, block_q4_0, 32, dequantize_q4_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_5")]]   kernel mul_mv_ext_q4_f32_t   kernel_mul_mv_ext_q4_f32_disp<5, block_q4_0, 32, dequantize_q4_0_t4>;
+
+template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_2")]]   kernel mul_mv_ext_q4_f32_t   kernel_mul_mv_ext_q4_f32_disp<2, block_q4_1, 32, dequantize_q4_1_t4>;
+template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_3")]]   kernel mul_mv_ext_q4_f32_t   kernel_mul_mv_ext_q4_f32_disp<3, block_q4_1, 32, dequantize_q4_1_t4>;
+template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_4")]]   kernel mul_mv_ext_q4_f32_t   kernel_mul_mv_ext_q4_f32_disp<4, block_q4_1, 32, dequantize_q4_1_t4>;
+template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_5")]]   kernel mul_mv_ext_q4_f32_t   kernel_mul_mv_ext_q4_f32_disp<5, block_q4_1, 32, dequantize_q4_1_t4>;
+
+template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_2")]]   kernel mul_mv_ext_q4_f32_t   kernel_mul_mv_ext_q4_f32_disp<2, block_q5_0, 32, dequantize_q5_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_3")]]   kernel mul_mv_ext_q4_f32_t   kernel_mul_mv_ext_q4_f32_disp<3, block_q5_0, 32, dequantize_q5_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_4")]]   kernel mul_mv_ext_q4_f32_t   kernel_mul_mv_ext_q4_f32_disp<4, block_q5_0, 32, dequantize_q5_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_5")]]   kernel mul_mv_ext_q4_f32_t   kernel_mul_mv_ext_q4_f32_disp<5, block_q5_0, 32, dequantize_q5_0_t4>;
+
+template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_2")]]   kernel mul_mv_ext_q4_f32_t   kernel_mul_mv_ext_q4_f32_disp<2, block_q5_1, 32, dequantize_q5_1_t4>;
+template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_3")]]   kernel mul_mv_ext_q4_f32_t   kernel_mul_mv_ext_q4_f32_disp<3, block_q5_1, 32, dequantize_q5_1_t4>;
+template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_4")]]   kernel mul_mv_ext_q4_f32_t   kernel_mul_mv_ext_q4_f32_disp<4, block_q5_1, 32, dequantize_q5_1_t4>;
+template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_5")]]   kernel mul_mv_ext_q4_f32_t   kernel_mul_mv_ext_q4_f32_disp<5, block_q5_1, 32, dequantize_q5_1_t4>;
+
+template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_2")]]   kernel mul_mv_ext_q4_f32_t   kernel_mul_mv_ext_q4_f32_disp<2, block_q8_0, 32, dequantize_q8_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_3")]]   kernel mul_mv_ext_q4_f32_t   kernel_mul_mv_ext_q4_f32_disp<3, block_q8_0, 32, dequantize_q8_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_4")]]   kernel mul_mv_ext_q4_f32_t   kernel_mul_mv_ext_q4_f32_disp<4, block_q8_0, 32, dequantize_q8_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_5")]]   kernel mul_mv_ext_q4_f32_t   kernel_mul_mv_ext_q4_f32_disp<5, block_q8_0, 32, dequantize_q8_0_t4>;
+
+template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t   kernel_mul_mv_ext_q4_f32_disp<2, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
+template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t   kernel_mul_mv_ext_q4_f32_disp<3, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
+template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t   kernel_mul_mv_ext_q4_f32_disp<4, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
+template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t   kernel_mul_mv_ext_q4_f32_disp<5, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
+
+template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_2")]]   kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q4_K, 256, dequantize_q4_K>;
+template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_3")]]   kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q4_K, 256, dequantize_q4_K>;
+template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_4")]]   kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q4_K, 256, dequantize_q4_K>;
+template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_5")]]   kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q4_K, 256, dequantize_q4_K>;
+
+template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_2")]]   kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q5_K, 256, dequantize_q5_K>;
+template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_3")]]   kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q5_K, 256, dequantize_q5_K>;
+template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_4")]]   kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q5_K, 256, dequantize_q5_K>;
+template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_5")]]   kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q5_K, 256, dequantize_q5_K>;
+
+template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_2")]]   kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q6_K, 256, dequantize_q6_K>;
+template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_3")]]   kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q6_K, 256, dequantize_q6_K>;
+template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_4")]]   kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q6_K, 256, dequantize_q6_K>;
+template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_5")]]   kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q6_K, 256, dequantize_q6_K>;
+
 #define N_MV_T_T 4
 
 template<typename T0, typename T04, typename T1, typename T14, typename args_t>
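
The simd_shuffle_down ladder that closes both kernel bodies above is a halving reduction across the nxpsg lanes of a row: lane tx adds the partial sum of lane tx + delta for delta = nxpsg/2, nxpsg/4, ..., 1, so lane 0 ends up holding the full sum and is the only lane that stores to dst. A small CPU model of that fold (hypothetical, for illustration only):

#include <stdio.h>

// Models the per-row fold performed by the simd_shuffle_down calls in
// kernel_mul_mv_ext_q4_f32_impl / kernel_mul_mv_ext_q4x4_f32_impl.
static float reduce_row(float * sumf, int nxpsg) {
    for (int delta = nxpsg/2; delta >= 1; delta /= 2) {
        for (int tx = 0; tx + delta < nxpsg; ++tx) {
            sumf[tx] += sumf[tx + delta]; // lane tx reads lane tx + delta, as simd_shuffle_down does
        }
    }
    return sumf[0]; // only lane 0 writes its result, as in the kernels
}

int main(void) {
    float partial[8] = {1, 2, 3, 4, 5, 6, 7, 8}; // e.g. nxpsg = 8 partial sums
    printf("%g\n", reduce_row(partial, 8));      // prints 36
    return 0;
}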