ggerganov committed on
Commit
0534b5d
·
1 Parent(s): 0c32e28

metal : add GGML_OP_REPEAT kernels (llama/7557)

Browse files
Files changed (2) hide show
  1. ggml-metal.m +49 -4
  2. ggml-metal.metal +47 -0
ggml-metal.m CHANGED
@@ -35,6 +35,10 @@ enum ggml_metal_kernel_type {
35
  GGML_METAL_KERNEL_TYPE_MUL_ROW,
36
  GGML_METAL_KERNEL_TYPE_DIV,
37
  GGML_METAL_KERNEL_TYPE_DIV_ROW,
 
 
 
 
38
  GGML_METAL_KERNEL_TYPE_SCALE,
39
  GGML_METAL_KERNEL_TYPE_SCALE_4,
40
  GGML_METAL_KERNEL_TYPE_CLAMP,
@@ -485,6 +489,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
485
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
486
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
487
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
 
 
 
 
488
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
489
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
490
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
@@ -746,6 +754,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
746
  case GGML_OP_ACC:
747
  case GGML_OP_MUL:
748
  case GGML_OP_DIV:
 
749
  case GGML_OP_SCALE:
750
  case GGML_OP_CLAMP:
751
  case GGML_OP_SQR:
@@ -979,8 +988,6 @@ static enum ggml_status ggml_metal_graph_compute(
979
  switch (dst->op) {
980
  case GGML_OP_CONCAT:
981
  {
982
- const int64_t nb = ne00;
983
-
984
  id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
985
 
986
  [encoder setComputePipelineState:pipeline];
@@ -1011,7 +1018,6 @@ static enum ggml_status ggml_metal_graph_compute(
1011
  [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
1012
  [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
1013
  [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
1014
- [encoder setBytes:&nb length:sizeof(nb) atIndex:27];
1015
 
1016
  const int nth = MIN(1024, ne0);
1017
 
@@ -1021,11 +1027,14 @@ static enum ggml_status ggml_metal_graph_compute(
1021
  case GGML_OP_MUL:
1022
  case GGML_OP_DIV:
1023
  {
 
 
 
1024
  const size_t offs = 0;
1025
 
1026
  bool bcast_row = false;
1027
 
1028
- int64_t nb = ne00;
1029
 
1030
  id<MTLComputePipelineState> pipeline = nil;
1031
 
@@ -1094,6 +1103,42 @@ static enum ggml_status ggml_metal_graph_compute(
1094
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1095
  }
1096
  } break;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1097
  case GGML_OP_ACC:
1098
  {
1099
  GGML_ASSERT(src0t == GGML_TYPE_F32);
 
35
  GGML_METAL_KERNEL_TYPE_MUL_ROW,
36
  GGML_METAL_KERNEL_TYPE_DIV,
37
  GGML_METAL_KERNEL_TYPE_DIV_ROW,
38
+ GGML_METAL_KERNEL_TYPE_REPEAT_F32,
39
+ GGML_METAL_KERNEL_TYPE_REPEAT_F16,
40
+ GGML_METAL_KERNEL_TYPE_REPEAT_I32,
41
+ GGML_METAL_KERNEL_TYPE_REPEAT_I16,
42
  GGML_METAL_KERNEL_TYPE_SCALE,
43
  GGML_METAL_KERNEL_TYPE_SCALE_4,
44
  GGML_METAL_KERNEL_TYPE_CLAMP,
 
489
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
490
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
491
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
492
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
493
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
494
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true);
495
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true);
496
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
497
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
498
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
 
754
  case GGML_OP_ACC:
755
  case GGML_OP_MUL:
756
  case GGML_OP_DIV:
757
+ case GGML_OP_REPEAT:
758
  case GGML_OP_SCALE:
759
  case GGML_OP_CLAMP:
760
  case GGML_OP_SQR:
 
988
  switch (dst->op) {
989
  case GGML_OP_CONCAT:
990
  {
 
 
991
  id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
992
 
993
  [encoder setComputePipelineState:pipeline];
 
1018
  [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
1019
  [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
1020
  [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
 
1021
 
1022
  const int nth = MIN(1024, ne0);
1023
 
 
1027
  case GGML_OP_MUL:
1028
  case GGML_OP_DIV:
1029
  {
1030
+ GGML_ASSERT(src0t == GGML_TYPE_F32);
1031
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
1032
+
1033
  const size_t offs = 0;
1034
 
1035
  bool bcast_row = false;
1036
 
1037
+ int64_t nb = ne00; // used by the "row" kernels
1038
 
1039
  id<MTLComputePipelineState> pipeline = nil;
1040
 
 
1103
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1104
  }
1105
  } break;
1106
// GGML_OP_REPEAT: encode one dispatch of a kernel_repeat_* pipeline, selected by the
// element type of src0 (src0t).
+ case GGML_OP_REPEAT:
1107
+ {
1108
+ id<MTLComputePipelineState> pipeline;
1109
+
1110
// Only F32, F16, I32 and I16 have a repeat pipeline here; any other src0 type
// falls through to the default branch and hits GGML_ASSERT(false).
+ switch (src0t) {
1111
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break;
1112
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break;
1113
+ case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break;
1114
+ case GGML_TYPE_I16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break;
1115
+ default: GGML_ASSERT(false);
1116
+ }
1117
+
1118
// Argument slots follow the parameter order of kernel_repeat in ggml-metal.metal:
//   0      src0 buffer        1      dst buffer
//   2-5    ne00..ne03         6-9    nb00..nb03   (src0 extents / strides)
//   10-13  ne0..ne3           14-17  nb0..nb3     (dst extents / strides)
+ [encoder setComputePipelineState:pipeline];
1119
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1120
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1121
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1122
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1123
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1124
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
1125
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
1126
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1127
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1128
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
1129
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
1130
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
1131
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
1132
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
1133
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
1134
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
1135
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
1136
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
1137
+
1138
// Threads per threadgroup: ne0, capped by the pipeline's maxTotalThreadsPerThreadgroup.
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
1139
+
1140
// Grid is ne1 x ne2 x ne3 threadgroups; the kernel reads the threadgroup position
// as (i1, i2, i3) and loops over i0 inside the group.
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1141
+ } break;
1142
  case GGML_OP_ACC:
1143
  {
1144
  GGML_ASSERT(src0t == GGML_TYPE_F32);
ggml-metal.metal CHANGED
@@ -168,6 +168,53 @@ kernel void kernel_div(
168
  }
169
  }
170
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  // assumption: src1 is a row
172
  // broadcast src1 into src0
173
  kernel void kernel_add_row(
 
168
  }
169
  }
170
 
171
// kernel_repeat<T>: writes every element of dst by reading src0 at the destination
// index taken modulo the source extent in each of the four dimensions, so src0 is
// tiled across dst. T is the element type that is loaded and stored.
//
// Each threadgroup handles one (i1, i2, i3) position of dst, taken from its position
// in the grid; the threads of the group share the ne0 elements along dimension 0.
+ template<typename T>
172
+ kernel void kernel_repeat(
173
+ device const char * src0,
174
+ device char * dst,
175
// src0 extents (ne0x) and strides (nb0x); the strides are applied to char pointers
// below, i.e. they are byte offsets.
+ constant int64_t & ne00,
176
+ constant int64_t & ne01,
177
+ constant int64_t & ne02,
178
+ constant int64_t & ne03,
179
+ constant uint64_t & nb00,
180
+ constant uint64_t & nb01,
181
+ constant uint64_t & nb02,
182
+ constant uint64_t & nb03,
183
// dst extents (nex) and byte strides (nbx).
+ constant int64_t & ne0,
184
+ constant int64_t & ne1,
185
+ constant int64_t & ne2,
186
+ constant int64_t & ne3,
187
+ constant uint64_t & nb0,
188
+ constant uint64_t & nb1,
189
+ constant uint64_t & nb2,
190
+ constant uint64_t & nb3,
191
+ uint3 tgpig[[threadgroup_position_in_grid]],
192
+ uint3 tpitg[[thread_position_in_threadgroup]],
193
+ uint3 ntg[[threads_per_threadgroup]]) {
194
// Destination indices for dimensions 3, 2, 1 come from the threadgroup position.
+ const int64_t i3 = tgpig.z;
195
+ const int64_t i2 = tgpig.y;
196
+ const int64_t i1 = tgpig.x;
197
+
198
// Wrap each destination index into the source extent.
+ const int64_t i03 = i3 % ne03;
199
+ const int64_t i02 = i2 % ne02;
200
+ const int64_t i01 = i1 % ne01;
201
+
202
// Base addresses of the source and destination positions for dimensions 1..3.
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
203
+ device char * dst_ptr = dst + i3*nb3 + i2*nb2 + i1*nb1 ;
204
+
205
// Thread t of the group handles i0 = t, t + ntg.x, t + 2*ntg.x, ... below ne0.
// NOTE(review): i0 and i00 are 32-bit int while ne0 / ne00 are int64_t; presumably
// extents along dimension 0 stay below 2^31 - confirm.
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
206
+ const int i00 = i0 % ne00;
207
+ *((device T *)(dst_ptr + i0*nb0)) = *((device T *)(src0_ptr + i00*nb00));
208
+ }
209
+ }
210
+
211
// kernel_repeat_t names the function type of kernel_repeat (taken from the float
// instantiation) so the explicit instantiations below can be declared with it.
+ typedef decltype(kernel_repeat<float>) kernel_repeat_t;
212
+
213
// One explicit instantiation per supported element type, each exported under its own
// host_name. The host side registers repeat_f32 / repeat_f16 / repeat_i32 / repeat_i16
// through GGML_METAL_ADD_KERNEL; presumably that macro adds the "kernel_" prefix -
// confirm against its definition.
+ template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>;
214
+ template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>;
215
+ template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
216
+ template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
217
+
218
  // assumption: src1 is a row
219
  // broadcast src1 into src0
220
  kernel void kernel_add_row(