Spaces:
Running
Running
metal : add GGML_OP_REPEAT kernels (llama/7557)
Browse files- ggml-metal.m +49 -4
- ggml-metal.metal +47 -0
ggml-metal.m
CHANGED
|
@@ -35,6 +35,10 @@ enum ggml_metal_kernel_type {
|
|
| 35 |
GGML_METAL_KERNEL_TYPE_MUL_ROW,
|
| 36 |
GGML_METAL_KERNEL_TYPE_DIV,
|
| 37 |
GGML_METAL_KERNEL_TYPE_DIV_ROW,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
GGML_METAL_KERNEL_TYPE_SCALE,
|
| 39 |
GGML_METAL_KERNEL_TYPE_SCALE_4,
|
| 40 |
GGML_METAL_KERNEL_TYPE_CLAMP,
|
|
@@ -485,6 +489,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
| 485 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
|
| 486 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
|
| 487 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
|
|
|
|
|
|
|
|
|
|
|
|
|
| 488 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
|
| 489 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
|
| 490 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
|
|
@@ -746,6 +754,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|
| 746 |
case GGML_OP_ACC:
|
| 747 |
case GGML_OP_MUL:
|
| 748 |
case GGML_OP_DIV:
|
|
|
|
| 749 |
case GGML_OP_SCALE:
|
| 750 |
case GGML_OP_CLAMP:
|
| 751 |
case GGML_OP_SQR:
|
|
@@ -979,8 +988,6 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
| 979 |
switch (dst->op) {
|
| 980 |
case GGML_OP_CONCAT:
|
| 981 |
{
|
| 982 |
-
const int64_t nb = ne00;
|
| 983 |
-
|
| 984 |
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
|
| 985 |
|
| 986 |
[encoder setComputePipelineState:pipeline];
|
|
@@ -1011,7 +1018,6 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
| 1011 |
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
|
| 1012 |
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
|
| 1013 |
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
|
| 1014 |
-
[encoder setBytes:&nb length:sizeof(nb) atIndex:27];
|
| 1015 |
|
| 1016 |
const int nth = MIN(1024, ne0);
|
| 1017 |
|
|
@@ -1021,11 +1027,14 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
| 1021 |
case GGML_OP_MUL:
|
| 1022 |
case GGML_OP_DIV:
|
| 1023 |
{
|
|
|
|
|
|
|
|
|
|
| 1024 |
const size_t offs = 0;
|
| 1025 |
|
| 1026 |
bool bcast_row = false;
|
| 1027 |
|
| 1028 |
-
int64_t nb = ne00;
|
| 1029 |
|
| 1030 |
id<MTLComputePipelineState> pipeline = nil;
|
| 1031 |
|
|
@@ -1094,6 +1103,42 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
| 1094 |
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
| 1095 |
}
|
| 1096 |
} break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1097 |
case GGML_OP_ACC:
|
| 1098 |
{
|
| 1099 |
GGML_ASSERT(src0t == GGML_TYPE_F32);
|
|
|
|
| 35 |
GGML_METAL_KERNEL_TYPE_MUL_ROW,
|
| 36 |
GGML_METAL_KERNEL_TYPE_DIV,
|
| 37 |
GGML_METAL_KERNEL_TYPE_DIV_ROW,
|
| 38 |
+
GGML_METAL_KERNEL_TYPE_REPEAT_F32,
|
| 39 |
+
GGML_METAL_KERNEL_TYPE_REPEAT_F16,
|
| 40 |
+
GGML_METAL_KERNEL_TYPE_REPEAT_I32,
|
| 41 |
+
GGML_METAL_KERNEL_TYPE_REPEAT_I16,
|
| 42 |
GGML_METAL_KERNEL_TYPE_SCALE,
|
| 43 |
GGML_METAL_KERNEL_TYPE_SCALE_4,
|
| 44 |
GGML_METAL_KERNEL_TYPE_CLAMP,
|
|
|
|
| 489 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
|
| 490 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
|
| 491 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
|
| 492 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
|
| 493 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
|
| 494 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true);
|
| 495 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true);
|
| 496 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
|
| 497 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
|
| 498 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
|
|
|
|
| 754 |
case GGML_OP_ACC:
|
| 755 |
case GGML_OP_MUL:
|
| 756 |
case GGML_OP_DIV:
|
| 757 |
+
case GGML_OP_REPEAT:
|
| 758 |
case GGML_OP_SCALE:
|
| 759 |
case GGML_OP_CLAMP:
|
| 760 |
case GGML_OP_SQR:
|
|
|
|
| 988 |
switch (dst->op) {
|
| 989 |
case GGML_OP_CONCAT:
|
| 990 |
{
|
|
|
|
|
|
|
| 991 |
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
|
| 992 |
|
| 993 |
[encoder setComputePipelineState:pipeline];
|
|
|
|
| 1018 |
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
|
| 1019 |
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
|
| 1020 |
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
|
|
|
|
| 1021 |
|
| 1022 |
const int nth = MIN(1024, ne0);
|
| 1023 |
|
|
|
|
| 1027 |
case GGML_OP_MUL:
|
| 1028 |
case GGML_OP_DIV:
|
| 1029 |
{
|
| 1030 |
+
GGML_ASSERT(src0t == GGML_TYPE_F32);
|
| 1031 |
+
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
| 1032 |
+
|
| 1033 |
const size_t offs = 0;
|
| 1034 |
|
| 1035 |
bool bcast_row = false;
|
| 1036 |
|
| 1037 |
+
int64_t nb = ne00; // used by the "row" kernels
|
| 1038 |
|
| 1039 |
id<MTLComputePipelineState> pipeline = nil;
|
| 1040 |
|
|
|
|
| 1103 |
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
| 1104 |
}
|
| 1105 |
} break;
|
| 1106 |
+
case GGML_OP_REPEAT:
|
| 1107 |
+
{
|
| 1108 |
+
id<MTLComputePipelineState> pipeline;
|
| 1109 |
+
|
| 1110 |
+
switch (src0t) {
|
| 1111 |
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break;
|
| 1112 |
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break;
|
| 1113 |
+
case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break;
|
| 1114 |
+
case GGML_TYPE_I16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break;
|
| 1115 |
+
default: GGML_ASSERT(false);
|
| 1116 |
+
}
|
| 1117 |
+
|
| 1118 |
+
[encoder setComputePipelineState:pipeline];
|
| 1119 |
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 1120 |
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
| 1121 |
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
| 1122 |
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
| 1123 |
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
| 1124 |
+
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
| 1125 |
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
| 1126 |
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
| 1127 |
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
| 1128 |
+
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
| 1129 |
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
|
| 1130 |
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
|
| 1131 |
+
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
|
| 1132 |
+
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
|
| 1133 |
+
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
|
| 1134 |
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
| 1135 |
+
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
|
| 1136 |
+
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
|
| 1137 |
+
|
| 1138 |
+
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
|
| 1139 |
+
|
| 1140 |
+
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
| 1141 |
+
} break;
|
| 1142 |
case GGML_OP_ACC:
|
| 1143 |
{
|
| 1144 |
GGML_ASSERT(src0t == GGML_TYPE_F32);
|
ggml-metal.metal
CHANGED
|
@@ -168,6 +168,53 @@ kernel void kernel_div(
|
|
| 168 |
}
|
| 169 |
}
|
| 170 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
// assumption: src1 is a row
|
| 172 |
// broadcast src1 into src0
|
| 173 |
kernel void kernel_add_row(
|
|
|
|
| 168 |
}
|
| 169 |
}
|
| 170 |
|
| 171 |
+
template<typename T>
|
| 172 |
+
kernel void kernel_repeat(
|
| 173 |
+
device const char * src0,
|
| 174 |
+
device char * dst,
|
| 175 |
+
constant int64_t & ne00,
|
| 176 |
+
constant int64_t & ne01,
|
| 177 |
+
constant int64_t & ne02,
|
| 178 |
+
constant int64_t & ne03,
|
| 179 |
+
constant uint64_t & nb00,
|
| 180 |
+
constant uint64_t & nb01,
|
| 181 |
+
constant uint64_t & nb02,
|
| 182 |
+
constant uint64_t & nb03,
|
| 183 |
+
constant int64_t & ne0,
|
| 184 |
+
constant int64_t & ne1,
|
| 185 |
+
constant int64_t & ne2,
|
| 186 |
+
constant int64_t & ne3,
|
| 187 |
+
constant uint64_t & nb0,
|
| 188 |
+
constant uint64_t & nb1,
|
| 189 |
+
constant uint64_t & nb2,
|
| 190 |
+
constant uint64_t & nb3,
|
| 191 |
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 192 |
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
| 193 |
+
uint3 ntg[[threads_per_threadgroup]]) {
|
| 194 |
+
const int64_t i3 = tgpig.z;
|
| 195 |
+
const int64_t i2 = tgpig.y;
|
| 196 |
+
const int64_t i1 = tgpig.x;
|
| 197 |
+
|
| 198 |
+
const int64_t i03 = i3 % ne03;
|
| 199 |
+
const int64_t i02 = i2 % ne02;
|
| 200 |
+
const int64_t i01 = i1 % ne01;
|
| 201 |
+
|
| 202 |
+
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
|
| 203 |
+
device char * dst_ptr = dst + i3*nb3 + i2*nb2 + i1*nb1 ;
|
| 204 |
+
|
| 205 |
+
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
| 206 |
+
const int i00 = i0 % ne00;
|
| 207 |
+
*((device T *)(dst_ptr + i0*nb0)) = *((device T *)(src0_ptr + i00*nb00));
|
| 208 |
+
}
|
| 209 |
+
}
|
| 210 |
+
|
| 211 |
+
typedef decltype(kernel_repeat<float>) kernel_repeat_t;
|
| 212 |
+
|
| 213 |
+
template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>;
|
| 214 |
+
template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>;
|
| 215 |
+
template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
|
| 216 |
+
template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
|
| 217 |
+
|
| 218 |
// assumption: src1 is a row
|
| 219 |
// broadcast src1 into src0
|
| 220 |
kernel void kernel_add_row(
|