Spaces:
Running
metal : add POOL2D and fix IM2COL (llama/9943)
Browse files* add pool_2d
Signed-off-by: Junhee Yoo <[email protected]>
* fix im2col and add unittest for N>=1024
Signed-off-by: Junhee Yoo <[email protected]>
* add tests for N % 1024 != 0
Signed-off-by: Junhee Yoo <[email protected]>
* remove trailing whitespaces
Signed-off-by: Junhee Yoo <[email protected]>
* apply suggestions
Signed-off-by: Junhee Yoo <[email protected]>
* apply more optimization
- original IM2COL kernel + _ext with MIN()
Signed-off-by: Junhee Yoo <[email protected]>
* apply review: change kernel name of pool_2d
Signed-off-by: Junhee Yoo <[email protected]>
* apply review
Signed-off-by: Junhee Yoo <[email protected]>
* fix more formatting and enhance readability
Signed-off-by: Junhee Yoo <[email protected]>
---------
Signed-off-by: Junhee Yoo <[email protected]>
- ggml/src/ggml-metal.m +111 -19
- ggml/src/ggml-metal.metal +178 -0
|
@@ -241,6 +241,8 @@ enum ggml_metal_kernel_type {
|
|
| 241 |
GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
|
| 242 |
GGML_METAL_KERNEL_TYPE_IM2COL_F16,
|
| 243 |
GGML_METAL_KERNEL_TYPE_IM2COL_F32,
|
|
|
|
|
|
|
| 244 |
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
|
| 245 |
GGML_METAL_KERNEL_TYPE_PAD_F32,
|
| 246 |
GGML_METAL_KERNEL_TYPE_ARANGE_F32,
|
|
@@ -272,6 +274,8 @@ enum ggml_metal_kernel_type {
|
|
| 272 |
GGML_METAL_KERNEL_TYPE_SIN,
|
| 273 |
GGML_METAL_KERNEL_TYPE_COS,
|
| 274 |
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
|
|
|
|
|
|
| 275 |
|
| 276 |
GGML_METAL_KERNEL_TYPE_COUNT
|
| 277 |
};
|
|
@@ -685,6 +689,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
| 685 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
|
| 686 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
|
| 687 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
|
|
|
|
|
|
|
| 688 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
|
| 689 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
|
| 690 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
|
|
@@ -716,6 +722,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
| 716 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
|
| 717 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
|
| 718 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
|
|
|
|
|
|
| 719 |
}
|
| 720 |
|
| 721 |
[metal_library release];
|
|
@@ -844,8 +852,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
| 844 |
case GGML_OP_IM2COL:
|
| 845 |
return op->src[0]->type == GGML_TYPE_F16;
|
| 846 |
case GGML_OP_POOL_1D:
|
| 847 |
-
case GGML_OP_POOL_2D:
|
| 848 |
return false;
|
|
|
|
| 849 |
case GGML_OP_UPSCALE:
|
| 850 |
case GGML_OP_PAD:
|
| 851 |
case GGML_OP_ARANGE:
|
|
@@ -2545,6 +2553,8 @@ static void ggml_metal_encode_node(
|
|
| 2545 |
} break;
|
| 2546 |
case GGML_OP_IM2COL:
|
| 2547 |
{
|
|
|
|
|
|
|
| 2548 |
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
| 2549 |
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
| 2550 |
GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
|
|
@@ -2574,30 +2584,54 @@ static void ggml_metal_encode_node(
|
|
| 2574 |
const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
|
| 2575 |
const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
|
| 2576 |
|
| 2577 |
-
id<MTLComputePipelineState> pipeline =
|
|
|
|
|
|
|
| 2578 |
|
| 2579 |
switch (dst->type) {
|
| 2580 |
-
case GGML_TYPE_F32:
|
| 2581 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2582 |
default: GGML_ABORT("fatal error");
|
| 2583 |
};
|
| 2584 |
|
| 2585 |
[encoder setComputePipelineState:pipeline];
|
| 2586 |
-
[encoder setBuffer:id_src1 offset:offs_src1
|
| 2587 |
-
[encoder setBuffer:id_dst offset:offs_dst
|
| 2588 |
-
[encoder setBytes:&ofs0 length:sizeof(
|
| 2589 |
-
[encoder setBytes:&ofs1 length:sizeof(
|
| 2590 |
-
[encoder setBytes:&IW length:sizeof(
|
| 2591 |
-
[encoder setBytes:&IH length:sizeof(
|
| 2592 |
-
[encoder setBytes:&CHW length:sizeof(
|
| 2593 |
-
[encoder setBytes:&s0 length:sizeof(
|
| 2594 |
-
[encoder setBytes:&s1 length:sizeof(
|
| 2595 |
-
[encoder setBytes:&p0 length:sizeof(
|
| 2596 |
-
[encoder setBytes:&p1 length:sizeof(
|
| 2597 |
-
[encoder setBytes:&d0 length:sizeof(
|
| 2598 |
-
[encoder setBytes:&d1 length:sizeof(
|
| 2599 |
-
|
| 2600 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2601 |
} break;
|
| 2602 |
case GGML_OP_UPSCALE:
|
| 2603 |
{
|
|
@@ -3001,6 +3035,64 @@ static void ggml_metal_encode_node(
|
|
| 3001 |
|
| 3002 |
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
| 3003 |
} break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3004 |
default:
|
| 3005 |
{
|
| 3006 |
GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
|
|
|
|
| 241 |
GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
|
| 242 |
GGML_METAL_KERNEL_TYPE_IM2COL_F16,
|
| 243 |
GGML_METAL_KERNEL_TYPE_IM2COL_F32,
|
| 244 |
+
GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,
|
| 245 |
+
GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32,
|
| 246 |
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
|
| 247 |
GGML_METAL_KERNEL_TYPE_PAD_F32,
|
| 248 |
GGML_METAL_KERNEL_TYPE_ARANGE_F32,
|
|
|
|
| 274 |
GGML_METAL_KERNEL_TYPE_SIN,
|
| 275 |
GGML_METAL_KERNEL_TYPE_COS,
|
| 276 |
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
| 277 |
+
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
|
| 278 |
+
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
|
| 279 |
|
| 280 |
GGML_METAL_KERNEL_TYPE_COUNT
|
| 281 |
};
|
|
|
|
| 689 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
|
| 690 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
|
| 691 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
|
| 692 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true);
|
| 693 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, im2col_ext_f32, true);
|
| 694 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
|
| 695 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
|
| 696 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
|
|
|
|
| 722 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
|
| 723 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
|
| 724 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
| 725 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
|
| 726 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
|
| 727 |
}
|
| 728 |
|
| 729 |
[metal_library release];
|
|
|
|
| 852 |
case GGML_OP_IM2COL:
|
| 853 |
return op->src[0]->type == GGML_TYPE_F16;
|
| 854 |
case GGML_OP_POOL_1D:
|
|
|
|
| 855 |
return false;
|
| 856 |
+
case GGML_OP_POOL_2D:
|
| 857 |
case GGML_OP_UPSCALE:
|
| 858 |
case GGML_OP_PAD:
|
| 859 |
case GGML_OP_ARANGE:
|
|
|
|
| 2553 |
} break;
|
| 2554 |
case GGML_OP_IM2COL:
|
| 2555 |
{
|
| 2556 |
+
GGML_ASSERT(ggml_is_contiguous(src0));
|
| 2557 |
+
GGML_ASSERT(ggml_is_contiguous(src1));
|
| 2558 |
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
| 2559 |
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
| 2560 |
GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
|
|
|
|
| 2584 |
const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
|
| 2585 |
const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
|
| 2586 |
|
| 2587 |
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline;
|
| 2588 |
+
|
| 2589 |
+
const bool is_gt_mttpt = ((size_t)(N * KH * KW)) > pipeline.maxTotalThreadsPerThreadgroup;
|
| 2590 |
|
| 2591 |
switch (dst->type) {
|
| 2592 |
+
case GGML_TYPE_F32: {
|
| 2593 |
+
pipeline = (is_gt_mttpt ?
|
| 2594 |
+
ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32].pipeline
|
| 2595 |
+
:
|
| 2596 |
+
ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline);
|
| 2597 |
+
} break;
|
| 2598 |
+
case GGML_TYPE_F16: {
|
| 2599 |
+
pipeline = (is_gt_mttpt ?
|
| 2600 |
+
ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16].pipeline
|
| 2601 |
+
:
|
| 2602 |
+
ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline);
|
| 2603 |
+
} break;
|
| 2604 |
default: GGML_ABORT("fatal error");
|
| 2605 |
};
|
| 2606 |
|
| 2607 |
[encoder setComputePipelineState:pipeline];
|
| 2608 |
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
|
| 2609 |
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
| 2610 |
+
[encoder setBytes:&ofs0 length:sizeof(int32_t) atIndex:2];
|
| 2611 |
+
[encoder setBytes:&ofs1 length:sizeof(int32_t) atIndex:3];
|
| 2612 |
+
[encoder setBytes:&IW length:sizeof(int32_t) atIndex:4];
|
| 2613 |
+
[encoder setBytes:&IH length:sizeof(int32_t) atIndex:5];
|
| 2614 |
+
[encoder setBytes:&CHW length:sizeof(int32_t) atIndex:6];
|
| 2615 |
+
[encoder setBytes:&s0 length:sizeof(int32_t) atIndex:7];
|
| 2616 |
+
[encoder setBytes:&s1 length:sizeof(int32_t) atIndex:8];
|
| 2617 |
+
[encoder setBytes:&p0 length:sizeof(int32_t) atIndex:9];
|
| 2618 |
+
[encoder setBytes:&p1 length:sizeof(int32_t) atIndex:10];
|
| 2619 |
+
[encoder setBytes:&d0 length:sizeof(int32_t) atIndex:11];
|
| 2620 |
+
[encoder setBytes:&d1 length:sizeof(int32_t) atIndex:12];
|
| 2621 |
+
|
| 2622 |
+
if (is_gt_mttpt) {
|
| 2623 |
+
[encoder setBytes:&N length:sizeof(int32_t) atIndex:13];
|
| 2624 |
+
[encoder setBytes:&KH length:sizeof(int32_t) atIndex:14];
|
| 2625 |
+
[encoder setBytes:&KW length:sizeof(int32_t) atIndex:15];
|
| 2626 |
+
|
| 2627 |
+
const uint64_t n_threads = MIN(pipeline.maxTotalThreadsPerThreadgroup, (uint64_t)N);
|
| 2628 |
+
|
| 2629 |
+
const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0);
|
| 2630 |
+
|
| 2631 |
+
[encoder dispatchThreadgroups:MTLSizeMake(quotient * CHW, OH, OW) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
|
| 2632 |
+
} else {
|
| 2633 |
+
[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
|
| 2634 |
+
}
|
| 2635 |
} break;
|
| 2636 |
case GGML_OP_UPSCALE:
|
| 2637 |
{
|
|
|
|
| 3035 |
|
| 3036 |
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
| 3037 |
} break;
|
| 3038 |
+
case GGML_OP_POOL_2D:
|
| 3039 |
+
{
|
| 3040 |
+
GGML_ASSERT(ggml_is_contiguous(src0));
|
| 3041 |
+
GGML_ASSERT(src0t == GGML_TYPE_F32 && src0t == dstt);
|
| 3042 |
+
|
| 3043 |
+
const int32_t * opts = dst->op_params;
|
| 3044 |
+
enum ggml_op_pool op = opts[0];
|
| 3045 |
+
|
| 3046 |
+
id<MTLComputePipelineState> pipeline = nil;
|
| 3047 |
+
switch (src0t) {
|
| 3048 |
+
case GGML_TYPE_F32: {
|
| 3049 |
+
switch(op) {
|
| 3050 |
+
case GGML_OP_POOL_AVG:
|
| 3051 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32].pipeline; break;
|
| 3052 |
+
case GGML_OP_POOL_MAX:
|
| 3053 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32].pipeline; break;
|
| 3054 |
+
default: GGML_ASSERT(false && "not implemented");
|
| 3055 |
+
}
|
| 3056 |
+
} break;
|
| 3057 |
+
default: GGML_ASSERT(false && "not implemented");
|
| 3058 |
+
}
|
| 3059 |
+
|
| 3060 |
+
const int32_t k0 = opts[1];
|
| 3061 |
+
const int32_t k1 = opts[2];
|
| 3062 |
+
const int32_t s0 = opts[3];
|
| 3063 |
+
const int32_t s1 = opts[4];
|
| 3064 |
+
const int32_t p0 = opts[5];
|
| 3065 |
+
const int32_t p1 = opts[6];
|
| 3066 |
+
|
| 3067 |
+
const int64_t IH = src0->ne[1];
|
| 3068 |
+
const int64_t IW = src0->ne[0];
|
| 3069 |
+
|
| 3070 |
+
const int64_t N = dst->ne[3];
|
| 3071 |
+
const int64_t OC = dst->ne[2];
|
| 3072 |
+
const int64_t OH = dst->ne[1];
|
| 3073 |
+
const int64_t OW = dst->ne[0];
|
| 3074 |
+
|
| 3075 |
+
const int64_t parallel_elements = N * OC * OH * OW;
|
| 3076 |
+
const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements);
|
| 3077 |
+
const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads;
|
| 3078 |
+
|
| 3079 |
+
[encoder setComputePipelineState:pipeline];
|
| 3080 |
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 3081 |
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
| 3082 |
+
[encoder setBytes:&k0 length:sizeof(int32_t) atIndex:2];
|
| 3083 |
+
[encoder setBytes:&k1 length:sizeof(int32_t) atIndex:3];
|
| 3084 |
+
[encoder setBytes:&s0 length:sizeof(int32_t) atIndex:4];
|
| 3085 |
+
[encoder setBytes:&s1 length:sizeof(int32_t) atIndex:5];
|
| 3086 |
+
[encoder setBytes:&p0 length:sizeof(int32_t) atIndex:6];
|
| 3087 |
+
[encoder setBytes:&p1 length:sizeof(int32_t) atIndex:7];
|
| 3088 |
+
[encoder setBytes:&IH length:sizeof(int64_t) atIndex:8];
|
| 3089 |
+
[encoder setBytes:&IW length:sizeof(int64_t) atIndex:9];
|
| 3090 |
+
[encoder setBytes:&OH length:sizeof(int64_t) atIndex:10];
|
| 3091 |
+
[encoder setBytes:&OW length:sizeof(int64_t) atIndex:11];
|
| 3092 |
+
[encoder setBytes:¶llel_elements length:sizeof(int64_t) atIndex:12];
|
| 3093 |
+
|
| 3094 |
+
[encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
|
| 3095 |
+
} break;
|
| 3096 |
default:
|
| 3097 |
{
|
| 3098 |
GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
|
|
@@ -1933,6 +1933,85 @@ kernel void kernel_im2col(
|
|
| 1933 |
template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
|
| 1934 |
template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
|
| 1935 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1936 |
kernel void kernel_upscale_f32(
|
| 1937 |
device const char * src0,
|
| 1938 |
device char * dst,
|
|
@@ -6372,3 +6451,102 @@ template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t
|
|
| 6372 |
template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl>>;
|
| 6373 |
template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>;
|
| 6374 |
template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1933 |
template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
|
| 1934 |
template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
|
| 1935 |
|
| 1936 |
+
typedef void (im2col_ext_t)(
|
| 1937 |
+
device const float * x,
|
| 1938 |
+
device char * dst,
|
| 1939 |
+
constant int32_t & ofs0,
|
| 1940 |
+
constant int32_t & ofs1,
|
| 1941 |
+
constant int32_t & IW,
|
| 1942 |
+
constant int32_t & IH,
|
| 1943 |
+
constant int32_t & CHW,
|
| 1944 |
+
constant int32_t & s0,
|
| 1945 |
+
constant int32_t & s1,
|
| 1946 |
+
constant int32_t & p0,
|
| 1947 |
+
constant int32_t & p1,
|
| 1948 |
+
constant int32_t & d0,
|
| 1949 |
+
constant int32_t & d1,
|
| 1950 |
+
constant int32_t & N,
|
| 1951 |
+
constant int32_t & KH,
|
| 1952 |
+
constant int32_t & KW,
|
| 1953 |
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1954 |
+
uint3 tgpg[[threadgroups_per_grid]],
|
| 1955 |
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
| 1956 |
+
uint3 ntg[[threads_per_threadgroup]]);
|
| 1957 |
+
|
| 1958 |
+
template <typename T>
|
| 1959 |
+
kernel void kernel_im2col_ext(
|
| 1960 |
+
device const float * x,
|
| 1961 |
+
device char * dst,
|
| 1962 |
+
constant int32_t & ofs0,
|
| 1963 |
+
constant int32_t & ofs1,
|
| 1964 |
+
constant int32_t & IW,
|
| 1965 |
+
constant int32_t & IH,
|
| 1966 |
+
constant int32_t & CHW,
|
| 1967 |
+
constant int32_t & s0,
|
| 1968 |
+
constant int32_t & s1,
|
| 1969 |
+
constant int32_t & p0,
|
| 1970 |
+
constant int32_t & p1,
|
| 1971 |
+
constant int32_t & d0,
|
| 1972 |
+
constant int32_t & d1,
|
| 1973 |
+
constant int32_t & N,
|
| 1974 |
+
constant int32_t & KH,
|
| 1975 |
+
constant int32_t & KW,
|
| 1976 |
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 1977 |
+
uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
|
| 1978 |
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
| 1979 |
+
uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
|
| 1980 |
+
const int32_t KHW = KH * KW; // KHW == ntg[1] * ntg[2], KW == ntg[2]
|
| 1981 |
+
|
| 1982 |
+
const int32_t d = tgpig[0] / CHW;
|
| 1983 |
+
const int32_t chw = tgpig[0] % CHW;
|
| 1984 |
+
const int32_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
|
| 1985 |
+
const int32_t HW = tgpig[0] % KHW;
|
| 1986 |
+
|
| 1987 |
+
const int32_t tpitg_0 = (d * ntg[0]) + tpitg[0];
|
| 1988 |
+
if (tpitg_0 >= N) {
|
| 1989 |
+
return;
|
| 1990 |
+
}
|
| 1991 |
+
|
| 1992 |
+
const int32_t tpitg_1 = HW / KW;
|
| 1993 |
+
const int32_t tpitg_2 = HW % KW;
|
| 1994 |
+
|
| 1995 |
+
const int32_t iiw = tgpig[2] * s0 + tpitg_2 * d0 - p0;
|
| 1996 |
+
const int32_t iih = tgpig[1] * s1 + tpitg_1 * d1 - p1;
|
| 1997 |
+
|
| 1998 |
+
const int32_t offset_dst =
|
| 1999 |
+
(tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
|
| 2000 |
+
(tgpig_0 * KHW + tpitg_1 * KW + tpitg_2);
|
| 2001 |
+
|
| 2002 |
+
device T * pdst = (device T *) (dst);
|
| 2003 |
+
|
| 2004 |
+
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
|
| 2005 |
+
pdst[offset_dst] = 0.0f;
|
| 2006 |
+
} else {
|
| 2007 |
+
const int32_t offset_src = tpitg_0 * ofs0 + tgpig_0 * ofs1;
|
| 2008 |
+
pdst[offset_dst] = x[offset_src + iih * IW + iiw];
|
| 2009 |
+
}
|
| 2010 |
+
}
|
| 2011 |
+
|
| 2012 |
+
template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
|
| 2013 |
+
template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
|
| 2014 |
+
|
| 2015 |
kernel void kernel_upscale_f32(
|
| 2016 |
device const char * src0,
|
| 2017 |
device char * dst,
|
|
|
|
| 6451 |
template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl>>;
|
| 6452 |
template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>;
|
| 6453 |
template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;
|
| 6454 |
+
|
| 6455 |
+
kernel void kernel_pool_2d_max_f32(
|
| 6456 |
+
device const float * src0,
|
| 6457 |
+
device float * dst,
|
| 6458 |
+
constant int32_t & k0,
|
| 6459 |
+
constant int32_t & k1,
|
| 6460 |
+
constant int32_t & s0,
|
| 6461 |
+
constant int32_t & s1,
|
| 6462 |
+
constant int32_t & p0,
|
| 6463 |
+
constant int32_t & p1,
|
| 6464 |
+
constant int64_t & IH,
|
| 6465 |
+
constant int64_t & IW,
|
| 6466 |
+
constant int64_t & OH,
|
| 6467 |
+
constant int64_t & OW,
|
| 6468 |
+
constant int64_t & parallel_elements,
|
| 6469 |
+
uint gid[[thread_position_in_grid]]) {
|
| 6470 |
+
|
| 6471 |
+
if (gid >= parallel_elements) {
|
| 6472 |
+
return;
|
| 6473 |
+
}
|
| 6474 |
+
|
| 6475 |
+
const int idx = gid;
|
| 6476 |
+
const int I_HW = IH * IW;
|
| 6477 |
+
const int O_HW = OH * OW;
|
| 6478 |
+
const int nc = idx / O_HW;
|
| 6479 |
+
const int cur_oh = idx % O_HW / OW;
|
| 6480 |
+
const int cur_ow = idx % O_HW % OW;
|
| 6481 |
+
|
| 6482 |
+
device const float * i_ptr = src0 + nc * I_HW;
|
| 6483 |
+
device float * o_ptr = dst + nc * O_HW;
|
| 6484 |
+
|
| 6485 |
+
const int start_h = cur_oh * s1 - p1;
|
| 6486 |
+
const int bh = MAX(0, start_h);
|
| 6487 |
+
const int eh = MIN(IH, start_h + k1);
|
| 6488 |
+
const int start_w = cur_ow * s0 - p0;
|
| 6489 |
+
const int bw = MAX(0, start_w);
|
| 6490 |
+
const int ew = MIN(IW, start_w + k0);
|
| 6491 |
+
|
| 6492 |
+
float res = -INFINITY;
|
| 6493 |
+
|
| 6494 |
+
for (int i = bh; i < eh; i += 1) {
|
| 6495 |
+
for (int j = bw; j < ew; j += 1) {
|
| 6496 |
+
res = MAX(res, i_ptr[i * IW + j]);
|
| 6497 |
+
}
|
| 6498 |
+
}
|
| 6499 |
+
|
| 6500 |
+
o_ptr[cur_oh * OW + cur_ow] = res;
|
| 6501 |
+
}
|
| 6502 |
+
|
| 6503 |
+
kernel void kernel_pool_2d_avg_f32(
|
| 6504 |
+
device const float * src0,
|
| 6505 |
+
device float * dst,
|
| 6506 |
+
constant int32_t & k0,
|
| 6507 |
+
constant int32_t & k1,
|
| 6508 |
+
constant int32_t & s0,
|
| 6509 |
+
constant int32_t & s1,
|
| 6510 |
+
constant int32_t & p0,
|
| 6511 |
+
constant int32_t & p1,
|
| 6512 |
+
constant int64_t & IH,
|
| 6513 |
+
constant int64_t & IW,
|
| 6514 |
+
constant int64_t & OH,
|
| 6515 |
+
constant int64_t & OW,
|
| 6516 |
+
constant int64_t & parallel_elements,
|
| 6517 |
+
uint gid[[thread_position_in_grid]]) {
|
| 6518 |
+
|
| 6519 |
+
if (gid >= parallel_elements) {
|
| 6520 |
+
return;
|
| 6521 |
+
}
|
| 6522 |
+
|
| 6523 |
+
const int idx = gid;
|
| 6524 |
+
const int I_HW = IH * IW;
|
| 6525 |
+
const int O_HW = OH * OW;
|
| 6526 |
+
const int nc = idx / O_HW;
|
| 6527 |
+
const int cur_oh = idx % O_HW / OW;
|
| 6528 |
+
const int cur_ow = idx % O_HW % OW;
|
| 6529 |
+
|
| 6530 |
+
device const float * i_ptr = src0 + nc * I_HW;
|
| 6531 |
+
device float * o_ptr = dst + nc * O_HW;
|
| 6532 |
+
|
| 6533 |
+
const int start_h = cur_oh * s1 - p1;
|
| 6534 |
+
const int bh = MAX(0, start_h);
|
| 6535 |
+
const int eh = MIN(IH, start_h + k1);
|
| 6536 |
+
const int start_w = cur_ow * s0 - p0;
|
| 6537 |
+
const int bw = MAX(0, start_w);
|
| 6538 |
+
const int ew = MIN(IW, start_w + k0);
|
| 6539 |
+
// const float scale = 1. / ((eh - bh) * (ew - bw));
|
| 6540 |
+
const float scale = 1. / (k0 * k1);
|
| 6541 |
+
|
| 6542 |
+
float res = 0;
|
| 6543 |
+
|
| 6544 |
+
for (int i = bh; i < eh; i += 1) {
|
| 6545 |
+
for (int j = bw; j < ew; j += 1) {
|
| 6546 |
+
float cur = i_ptr[i * IW + j];
|
| 6547 |
+
res += cur * scale;
|
| 6548 |
+
}
|
| 6549 |
+
}
|
| 6550 |
+
|
| 6551 |
+
o_ptr[cur_oh * OW + cur_ow] = res;
|
| 6552 |
+
}
|