Yavor Ivanov committed on
Commit
2ed022e
·
1 Parent(s): e51f2d4

metal : Add missing unary ops Metal support (llama/14660)

Browse files
ggml/src/ggml-metal/ggml-metal.m CHANGED
@@ -173,6 +173,12 @@ enum ggml_metal_kernel_type {
173
  GGML_METAL_KERNEL_TYPE_SILU,
174
  GGML_METAL_KERNEL_TYPE_SILU_4,
175
  GGML_METAL_KERNEL_TYPE_ELU,
 
 
 
 
 
 
176
  GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
177
  GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
178
  GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
@@ -1155,6 +1161,12 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
1155
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
1156
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
1157
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ELU, elu, true);
 
 
 
 
 
 
1158
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, has_simdgroup_reduction);
1159
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction);
1160
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction);
@@ -1688,6 +1700,12 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1688
  case GGML_UNARY_OP_SILU:
1689
  case GGML_UNARY_OP_ELU:
1690
  case GGML_UNARY_OP_NEG:
 
 
 
 
 
 
1691
  return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
1692
  default:
1693
  return false;
@@ -2439,6 +2457,78 @@ static bool ggml_metal_encode_node(
2439
 
2440
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2441
  } break;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2442
  default:
2443
  {
2444
  GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
 
173
  GGML_METAL_KERNEL_TYPE_SILU,
174
  GGML_METAL_KERNEL_TYPE_SILU_4,
175
  GGML_METAL_KERNEL_TYPE_ELU,
176
+ GGML_METAL_KERNEL_TYPE_ABS,
177
+ GGML_METAL_KERNEL_TYPE_SGN,
178
+ GGML_METAL_KERNEL_TYPE_STEP,
179
+ GGML_METAL_KERNEL_TYPE_HARDSWISH,
180
+ GGML_METAL_KERNEL_TYPE_HARDSIGMOID,
181
+ GGML_METAL_KERNEL_TYPE_EXP,
182
  GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
183
  GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
184
  GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
 
1161
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
1162
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
1163
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ELU, elu, true);
1164
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ABS, abs, true);
1165
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SGN, sgn, true);
1166
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_STEP, step, true);
1167
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_HARDSWISH, hardswish, true);
1168
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_HARDSIGMOID, hardsigmoid, true);
1169
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_EXP, exp, true);
1170
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, has_simdgroup_reduction);
1171
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction);
1172
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction);
 
1700
  case GGML_UNARY_OP_SILU:
1701
  case GGML_UNARY_OP_ELU:
1702
  case GGML_UNARY_OP_NEG:
1703
+ case GGML_UNARY_OP_ABS:
1704
+ case GGML_UNARY_OP_SGN:
1705
+ case GGML_UNARY_OP_STEP:
1706
+ case GGML_UNARY_OP_HARDSWISH:
1707
+ case GGML_UNARY_OP_HARDSIGMOID:
1708
+ case GGML_UNARY_OP_EXP:
1709
  return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
1710
  default:
1711
  return false;
 
2457
 
2458
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2459
  } break;
2460
+ case GGML_UNARY_OP_ABS:
2461
+ {
2462
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ABS].pipeline;
2463
+
2464
+ [encoder setComputePipelineState:pipeline];
2465
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2466
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2467
+
2468
+ const int64_t n = ggml_nelements(dst);
2469
+
2470
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2471
+ } break;
2472
+ case GGML_UNARY_OP_SGN:
2473
+ {
2474
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SGN].pipeline;
2475
+
2476
+ [encoder setComputePipelineState:pipeline];
2477
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2478
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2479
+
2480
+ const int64_t n = ggml_nelements(dst);
2481
+
2482
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2483
+ } break;
2484
+ case GGML_UNARY_OP_STEP:
2485
+ {
2486
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_STEP].pipeline;
2487
+
2488
+ [encoder setComputePipelineState:pipeline];
2489
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2490
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2491
+
2492
+ const int64_t n = ggml_nelements(dst);
2493
+
2494
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2495
+ } break;
2496
+ case GGML_UNARY_OP_HARDSWISH:
2497
+ {
2498
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_HARDSWISH].pipeline;
2499
+
2500
+ [encoder setComputePipelineState:pipeline];
2501
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2502
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2503
+
2504
+ const int64_t n = ggml_nelements(dst);
2505
+
2506
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2507
+ } break;
2508
+ case GGML_UNARY_OP_HARDSIGMOID:
2509
+ {
2510
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_HARDSIGMOID].pipeline;
2511
+
2512
+ [encoder setComputePipelineState:pipeline];
2513
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2514
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2515
+
2516
+ const int64_t n = ggml_nelements(dst);
2517
+
2518
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2519
+ } break;
2520
+ case GGML_UNARY_OP_EXP:
2521
+ {
2522
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_EXP].pipeline;
2523
+
2524
+ [encoder setComputePipelineState:pipeline];
2525
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2526
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2527
+
2528
+ const int64_t n = ggml_nelements(dst);
2529
+
2530
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2531
+ } break;
2532
  default:
2533
  {
2534
  GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
ggml/src/ggml-metal/ggml-metal.metal CHANGED
@@ -1199,6 +1199,51 @@ kernel void kernel_neg(
1199
  dst[tpig] = -src0[tpig];
1200
  }
1201
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1202
  kernel void kernel_reglu(
1203
  device const char * src0,
1204
  device const char * src1,
 
1199
  dst[tpig] = -src0[tpig];
1200
  }
1201
 
1202
+ kernel void kernel_abs(
1203
+ device const float * src0,
1204
+ device float * dst,
1205
+ uint tpig[[thread_position_in_grid]]) {
1206
+ dst[tpig] = fabs(src0[tpig]);
1207
+ }
1208
+
1209
+ kernel void kernel_sgn(
1210
+ device const float * src0,
1211
+ device float * dst,
1212
+ uint tpig[[thread_position_in_grid]]) {
1213
+ device const float & x = src0[tpig];
1214
+ dst[tpig] = (x > 0.0f) ? 1.0f : ((x < 0.0f) ? -1.0f : 0.0f);
1215
+ }
1216
+
1217
+ kernel void kernel_step(
1218
+ device const float * src0,
1219
+ device float * dst,
1220
+ uint tpig[[thread_position_in_grid]]) {
1221
+ dst[tpig] = src0[tpig] > 0.0f ? 1.0f : 0.0f;
1222
+ }
1223
+
1224
+ kernel void kernel_hardswish(
1225
+ device const float * src0,
1226
+ device float * dst,
1227
+ uint tpig[[thread_position_in_grid]]) {
1228
+ device const float & x = src0[tpig];
1229
+ dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
1230
+ }
1231
+
1232
+ kernel void kernel_hardsigmoid(
1233
+ device const float * src0,
1234
+ device float * dst,
1235
+ uint tpig[[thread_position_in_grid]]) {
1236
+ device const float & x = src0[tpig];
1237
+ dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
1238
+ }
1239
+
1240
+ kernel void kernel_exp(
1241
+ device const float * src0,
1242
+ device float * dst,
1243
+ uint tpig[[thread_position_in_grid]]) {
1244
+ dst[tpig] = exp(src0[tpig]);
1245
+ }
1246
+
1247
  kernel void kernel_reglu(
1248
  device const char * src0,
1249
  device const char * src1,