Spaces:
Running
Running
Yavor Ivanov
committed on
Commit
·
2ed022e
1
Parent(s):
e51f2d4
metal : Add missing unary ops Metal support (llama/14660)
Browse files
ggml/src/ggml-metal/ggml-metal.m
CHANGED
|
@@ -173,6 +173,12 @@ enum ggml_metal_kernel_type {
|
|
| 173 |
GGML_METAL_KERNEL_TYPE_SILU,
|
| 174 |
GGML_METAL_KERNEL_TYPE_SILU_4,
|
| 175 |
GGML_METAL_KERNEL_TYPE_ELU,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
|
| 177 |
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
|
| 178 |
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
|
|
@@ -1155,6 +1161,12 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
| 1155 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
|
| 1156 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
|
| 1157 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ELU, elu, true);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1158 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, has_simdgroup_reduction);
|
| 1159 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction);
|
| 1160 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction);
|
|
@@ -1688,6 +1700,12 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
| 1688 |
case GGML_UNARY_OP_SILU:
|
| 1689 |
case GGML_UNARY_OP_ELU:
|
| 1690 |
case GGML_UNARY_OP_NEG:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1691 |
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
| 1692 |
default:
|
| 1693 |
return false;
|
|
@@ -2439,6 +2457,78 @@ static bool ggml_metal_encode_node(
|
|
| 2439 |
|
| 2440 |
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
| 2441 |
} break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2442 |
default:
|
| 2443 |
{
|
| 2444 |
GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
|
|
|
|
| 173 |
GGML_METAL_KERNEL_TYPE_SILU,
|
| 174 |
GGML_METAL_KERNEL_TYPE_SILU_4,
|
| 175 |
GGML_METAL_KERNEL_TYPE_ELU,
|
| 176 |
+
GGML_METAL_KERNEL_TYPE_ABS,
|
| 177 |
+
GGML_METAL_KERNEL_TYPE_SGN,
|
| 178 |
+
GGML_METAL_KERNEL_TYPE_STEP,
|
| 179 |
+
GGML_METAL_KERNEL_TYPE_HARDSWISH,
|
| 180 |
+
GGML_METAL_KERNEL_TYPE_HARDSIGMOID,
|
| 181 |
+
GGML_METAL_KERNEL_TYPE_EXP,
|
| 182 |
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
|
| 183 |
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
|
| 184 |
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
|
|
|
|
| 1161 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
|
| 1162 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
|
| 1163 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ELU, elu, true);
|
| 1164 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ABS, abs, true);
|
| 1165 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SGN, sgn, true);
|
| 1166 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_STEP, step, true);
|
| 1167 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_HARDSWISH, hardswish, true);
|
| 1168 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_HARDSIGMOID, hardsigmoid, true);
|
| 1169 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_EXP, exp, true);
|
| 1170 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, has_simdgroup_reduction);
|
| 1171 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction);
|
| 1172 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction);
|
|
|
|
| 1700 |
case GGML_UNARY_OP_SILU:
|
| 1701 |
case GGML_UNARY_OP_ELU:
|
| 1702 |
case GGML_UNARY_OP_NEG:
|
| 1703 |
+
case GGML_UNARY_OP_ABS:
|
| 1704 |
+
case GGML_UNARY_OP_SGN:
|
| 1705 |
+
case GGML_UNARY_OP_STEP:
|
| 1706 |
+
case GGML_UNARY_OP_HARDSWISH:
|
| 1707 |
+
case GGML_UNARY_OP_HARDSIGMOID:
|
| 1708 |
+
case GGML_UNARY_OP_EXP:
|
| 1709 |
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
| 1710 |
default:
|
| 1711 |
return false;
|
|
|
|
| 2457 |
|
| 2458 |
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
| 2459 |
} break;
|
| 2460 |
+
case GGML_UNARY_OP_ABS:
|
| 2461 |
+
{
|
| 2462 |
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ABS].pipeline;
|
| 2463 |
+
|
| 2464 |
+
[encoder setComputePipelineState:pipeline];
|
| 2465 |
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 2466 |
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
| 2467 |
+
|
| 2468 |
+
const int64_t n = ggml_nelements(dst);
|
| 2469 |
+
|
| 2470 |
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
| 2471 |
+
} break;
|
| 2472 |
+
case GGML_UNARY_OP_SGN:
|
| 2473 |
+
{
|
| 2474 |
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SGN].pipeline;
|
| 2475 |
+
|
| 2476 |
+
[encoder setComputePipelineState:pipeline];
|
| 2477 |
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 2478 |
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
| 2479 |
+
|
| 2480 |
+
const int64_t n = ggml_nelements(dst);
|
| 2481 |
+
|
| 2482 |
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
| 2483 |
+
} break;
|
| 2484 |
+
case GGML_UNARY_OP_STEP:
|
| 2485 |
+
{
|
| 2486 |
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_STEP].pipeline;
|
| 2487 |
+
|
| 2488 |
+
[encoder setComputePipelineState:pipeline];
|
| 2489 |
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 2490 |
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
| 2491 |
+
|
| 2492 |
+
const int64_t n = ggml_nelements(dst);
|
| 2493 |
+
|
| 2494 |
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
| 2495 |
+
} break;
|
| 2496 |
+
case GGML_UNARY_OP_HARDSWISH:
|
| 2497 |
+
{
|
| 2498 |
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_HARDSWISH].pipeline;
|
| 2499 |
+
|
| 2500 |
+
[encoder setComputePipelineState:pipeline];
|
| 2501 |
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 2502 |
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
| 2503 |
+
|
| 2504 |
+
const int64_t n = ggml_nelements(dst);
|
| 2505 |
+
|
| 2506 |
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
| 2507 |
+
} break;
|
| 2508 |
+
case GGML_UNARY_OP_HARDSIGMOID:
|
| 2509 |
+
{
|
| 2510 |
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_HARDSIGMOID].pipeline;
|
| 2511 |
+
|
| 2512 |
+
[encoder setComputePipelineState:pipeline];
|
| 2513 |
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 2514 |
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
| 2515 |
+
|
| 2516 |
+
const int64_t n = ggml_nelements(dst);
|
| 2517 |
+
|
| 2518 |
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
| 2519 |
+
} break;
|
| 2520 |
+
case GGML_UNARY_OP_EXP:
|
| 2521 |
+
{
|
| 2522 |
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_EXP].pipeline;
|
| 2523 |
+
|
| 2524 |
+
[encoder setComputePipelineState:pipeline];
|
| 2525 |
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 2526 |
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
| 2527 |
+
|
| 2528 |
+
const int64_t n = ggml_nelements(dst);
|
| 2529 |
+
|
| 2530 |
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
| 2531 |
+
} break;
|
| 2532 |
default:
|
| 2533 |
{
|
| 2534 |
GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
|
ggml/src/ggml-metal/ggml-metal.metal
CHANGED
|
@@ -1199,6 +1199,51 @@ kernel void kernel_neg(
|
|
| 1199 |
dst[tpig] = -src0[tpig];
|
| 1200 |
}
|
| 1201 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1202 |
kernel void kernel_reglu(
|
| 1203 |
device const char * src0,
|
| 1204 |
device const char * src1,
|
|
|
|
| 1199 |
dst[tpig] = -src0[tpig];
|
| 1200 |
}
|
| 1201 |
|
| 1202 |
+
kernel void kernel_abs(
|
| 1203 |
+
device const float * src0,
|
| 1204 |
+
device float * dst,
|
| 1205 |
+
uint tpig[[thread_position_in_grid]]) {
|
| 1206 |
+
dst[tpig] = fabs(src0[tpig]);
|
| 1207 |
+
}
|
| 1208 |
+
|
| 1209 |
+
kernel void kernel_sgn(
|
| 1210 |
+
device const float * src0,
|
| 1211 |
+
device float * dst,
|
| 1212 |
+
uint tpig[[thread_position_in_grid]]) {
|
| 1213 |
+
device const float & x = src0[tpig];
|
| 1214 |
+
dst[tpig] = (x > 0.0f) ? 1.0f : ((x < 0.0f) ? -1.0f : 0.0f);
|
| 1215 |
+
}
|
| 1216 |
+
|
| 1217 |
+
kernel void kernel_step(
|
| 1218 |
+
device const float * src0,
|
| 1219 |
+
device float * dst,
|
| 1220 |
+
uint tpig[[thread_position_in_grid]]) {
|
| 1221 |
+
dst[tpig] = src0[tpig] > 0.0f ? 1.0f : 0.0f;
|
| 1222 |
+
}
|
| 1223 |
+
|
| 1224 |
+
kernel void kernel_hardswish(
|
| 1225 |
+
device const float * src0,
|
| 1226 |
+
device float * dst,
|
| 1227 |
+
uint tpig[[thread_position_in_grid]]) {
|
| 1228 |
+
device const float & x = src0[tpig];
|
| 1229 |
+
dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
|
| 1230 |
+
}
|
| 1231 |
+
|
| 1232 |
+
kernel void kernel_hardsigmoid(
|
| 1233 |
+
device const float * src0,
|
| 1234 |
+
device float * dst,
|
| 1235 |
+
uint tpig[[thread_position_in_grid]]) {
|
| 1236 |
+
device const float & x = src0[tpig];
|
| 1237 |
+
dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
|
| 1238 |
+
}
|
| 1239 |
+
|
| 1240 |
+
kernel void kernel_exp(
|
| 1241 |
+
device const float * src0,
|
| 1242 |
+
device float * dst,
|
| 1243 |
+
uint tpig[[thread_position_in_grid]]) {
|
| 1244 |
+
dst[tpig] = exp(src0[tpig]);
|
| 1245 |
+
}
|
| 1246 |
+
|
| 1247 |
kernel void kernel_reglu(
|
| 1248 |
device const char * src0,
|
| 1249 |
device const char * src1,
|