newfrisbie committed
Commit b553b89 · 1 Parent(s): f8d4728

metal : add POOL2D and fix IM2COL (llama/9943)


* add pool_2d

Signed-off-by: Junhee Yoo <[email protected]>

* fix im2col and add unittest for N>=1024

Signed-off-by: Junhee Yoo <[email protected]>

* add tests for N % 1024 != 0

Signed-off-by: Junhee Yoo <[email protected]>

* remove trailing whitespaces

Signed-off-by: Junhee Yoo <[email protected]>

* apply suggestions

Signed-off-by: Junhee Yoo <[email protected]>

* apply more optimization

- original IM2COL kernel + _ext with MIN()

Signed-off-by: Junhee Yoo <[email protected]>

* apply review: change kernel name of pool_2d

Signed-off-by: Junhee Yoo <[email protected]>

* apply review

Signed-off-by: Junhee Yoo <[email protected]>

* fix more formatting and enhance readability

Signed-off-by: Junhee Yoo <[email protected]>

---------

Signed-off-by: Junhee Yoo <[email protected]>

Files changed (2)
  1. ggml/src/ggml-metal.m +111 -19
  2. ggml/src/ggml-metal.metal +178 -0
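
For context, the GGML_OP_POOL_2D support added by this commit is reached through the regular ggml graph API. The sketch below is not part of the diff; it assumes the ggml_pool_2d signature declared in ggml.h, and the tensor shape and pooling parameters are illustrative only.

// Minimal sketch (not part of this commit): building a pooling node that the
// Metal backend can now execute with the new pool_2d kernels.
#include "ggml.h"

static struct ggml_tensor * build_avg_pool(struct ggml_context * ctx) {
    // Illustrative F32 input: W=64, H=64, C=8, N=2 (the new Metal kernels
    // only handle F32 pooling).
    struct ggml_tensor * x = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 64, 64, 8, 2);

    // 2x2 average pooling, stride 2, no padding -> dispatched to
    // kernel_pool_2d_avg_f32 on the Metal backend.
    return ggml_pool_2d(ctx, x, GGML_OP_POOL_AVG, /*k0*/ 2, /*k1*/ 2,
                        /*s0*/ 2, /*s1*/ 2, /*p0*/ 0.0f, /*p1*/ 0.0f);
}

The im2col change, by contrast, is transparent to callers of ggml_im2col: it only affects which Metal kernel is dispatched once N * KH * KW exceeds the pipeline's maxTotalThreadsPerThreadgroup.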
ggml/src/ggml-metal.m CHANGED
@@ -241,6 +241,8 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
     GGML_METAL_KERNEL_TYPE_IM2COL_F16,
     GGML_METAL_KERNEL_TYPE_IM2COL_F32,
+    GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,
+    GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32,
     GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
     GGML_METAL_KERNEL_TYPE_PAD_F32,
     GGML_METAL_KERNEL_TYPE_ARANGE_F32,
@@ -272,6 +274,8 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_SIN,
     GGML_METAL_KERNEL_TYPE_COS,
     GGML_METAL_KERNEL_TYPE_SUM_ROWS,
+    GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
+    GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
 
     GGML_METAL_KERNEL_TYPE_COUNT
 };
@@ -685,6 +689,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, im2col_ext_f32, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
@@ -716,6 +722,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
     }
 
     [metal_library release];
@@ -844,8 +852,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
     case GGML_OP_IM2COL:
         return op->src[0]->type == GGML_TYPE_F16;
     case GGML_OP_POOL_1D:
-    case GGML_OP_POOL_2D:
         return false;
+    case GGML_OP_POOL_2D:
     case GGML_OP_UPSCALE:
     case GGML_OP_PAD:
     case GGML_OP_ARANGE:
@@ -2545,6 +2553,8 @@ static void ggml_metal_encode_node(
         } break;
     case GGML_OP_IM2COL:
         {
+            GGML_ASSERT(ggml_is_contiguous(src0));
+            GGML_ASSERT(ggml_is_contiguous(src1));
             GGML_ASSERT(src0->type == GGML_TYPE_F16);
             GGML_ASSERT(src1->type == GGML_TYPE_F32);
             GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
@@ -2574,30 +2584,54 @@ static void ggml_metal_encode_node(
             const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
             const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
 
-            id<MTLComputePipelineState> pipeline = nil;
+            id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline;
+
+            const bool is_gt_mttpt = ((size_t)(N * KH * KW)) > pipeline.maxTotalThreadsPerThreadgroup;
 
             switch (dst->type) {
-                case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break;
-                case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
+                case GGML_TYPE_F32: {
+                    pipeline = (is_gt_mttpt ?
+                                    ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32].pipeline
+                                    :
+                                    ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline);
+                } break;
+                case GGML_TYPE_F16: {
+                    pipeline = (is_gt_mttpt ?
+                                    ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16].pipeline
+                                    :
+                                    ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline);
+                } break;
                 default: GGML_ABORT("fatal error");
             };
 
             [encoder setComputePipelineState:pipeline];
-            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
-            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-            [encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
-            [encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
-            [encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
-            [encoder setBytes:&IH length:sizeof( int32_t) atIndex:5];
-            [encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6];
-            [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7];
-            [encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8];
-            [encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9];
-            [encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10];
-            [encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
-            [encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];
-
-            [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
+            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
+            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+            [encoder setBytes:&ofs0 length:sizeof(int32_t) atIndex:2];
+            [encoder setBytes:&ofs1 length:sizeof(int32_t) atIndex:3];
+            [encoder setBytes:&IW length:sizeof(int32_t) atIndex:4];
+            [encoder setBytes:&IH length:sizeof(int32_t) atIndex:5];
+            [encoder setBytes:&CHW length:sizeof(int32_t) atIndex:6];
+            [encoder setBytes:&s0 length:sizeof(int32_t) atIndex:7];
+            [encoder setBytes:&s1 length:sizeof(int32_t) atIndex:8];
+            [encoder setBytes:&p0 length:sizeof(int32_t) atIndex:9];
+            [encoder setBytes:&p1 length:sizeof(int32_t) atIndex:10];
+            [encoder setBytes:&d0 length:sizeof(int32_t) atIndex:11];
+            [encoder setBytes:&d1 length:sizeof(int32_t) atIndex:12];
+
+            if (is_gt_mttpt) {
+                [encoder setBytes:&N length:sizeof(int32_t) atIndex:13];
+                [encoder setBytes:&KH length:sizeof(int32_t) atIndex:14];
+                [encoder setBytes:&KW length:sizeof(int32_t) atIndex:15];
+
+                const uint64_t n_threads = MIN(pipeline.maxTotalThreadsPerThreadgroup, (uint64_t)N);
+
+                const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0);
+
+                [encoder dispatchThreadgroups:MTLSizeMake(quotient * CHW, OH, OW) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
+            } else {
+                [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
+            }
         } break;
     case GGML_OP_UPSCALE:
         {
@@ -3001,6 +3035,64 @@ static void ggml_metal_encode_node(
 
             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
         } break;
+    case GGML_OP_POOL_2D:
+        {
+            GGML_ASSERT(ggml_is_contiguous(src0));
+            GGML_ASSERT(src0t == GGML_TYPE_F32 && src0t == dstt);
+
+            const int32_t * opts = dst->op_params;
+            enum ggml_op_pool op = opts[0];
+
+            id<MTLComputePipelineState> pipeline = nil;
+            switch (src0t) {
+                case GGML_TYPE_F32: {
+                    switch(op) {
+                        case GGML_OP_POOL_AVG:
+                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32].pipeline; break;
+                        case GGML_OP_POOL_MAX:
+                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32].pipeline; break;
+                        default: GGML_ASSERT(false && "not implemented");
+                    }
+                } break;
+                default: GGML_ASSERT(false && "not implemented");
+            }
+
+            const int32_t k0 = opts[1];
+            const int32_t k1 = opts[2];
+            const int32_t s0 = opts[3];
+            const int32_t s1 = opts[4];
+            const int32_t p0 = opts[5];
+            const int32_t p1 = opts[6];
+
+            const int64_t IH = src0->ne[1];
+            const int64_t IW = src0->ne[0];
+
+            const int64_t N  = dst->ne[3];
+            const int64_t OC = dst->ne[2];
+            const int64_t OH = dst->ne[1];
+            const int64_t OW = dst->ne[0];
+
+            const int64_t parallel_elements = N * OC * OH * OW;
+            const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements);
+            const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads;
+
+            [encoder setComputePipelineState:pipeline];
+            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+            [encoder setBytes:&k0 length:sizeof(int32_t) atIndex:2];
+            [encoder setBytes:&k1 length:sizeof(int32_t) atIndex:3];
+            [encoder setBytes:&s0 length:sizeof(int32_t) atIndex:4];
+            [encoder setBytes:&s1 length:sizeof(int32_t) atIndex:5];
+            [encoder setBytes:&p0 length:sizeof(int32_t) atIndex:6];
+            [encoder setBytes:&p1 length:sizeof(int32_t) atIndex:7];
+            [encoder setBytes:&IH length:sizeof(int64_t) atIndex:8];
+            [encoder setBytes:&IW length:sizeof(int64_t) atIndex:9];
+            [encoder setBytes:&OH length:sizeof(int64_t) atIndex:10];
+            [encoder setBytes:&OW length:sizeof(int64_t) atIndex:11];
+            [encoder setBytes:&parallel_elements length:sizeof(int64_t) atIndex:12];
+
+            [encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
+        } break;
     default:
         {
             GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
ggml/src/ggml-metal.metal CHANGED
@@ -1933,6 +1933,85 @@ kernel void kernel_im2col(
 template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
 template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
 
+typedef void (im2col_ext_t)(
+        device const float * x,
+        device        char * dst,
+        constant   int32_t & ofs0,
+        constant   int32_t & ofs1,
+        constant   int32_t & IW,
+        constant   int32_t & IH,
+        constant   int32_t & CHW,
+        constant   int32_t & s0,
+        constant   int32_t & s1,
+        constant   int32_t & p0,
+        constant   int32_t & p1,
+        constant   int32_t & d0,
+        constant   int32_t & d1,
+        constant   int32_t & N,
+        constant   int32_t & KH,
+        constant   int32_t & KW,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tgpg[[threadgroups_per_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]]);
+
+template <typename T>
+kernel void kernel_im2col_ext(
+        device const float * x,
+        device        char * dst,
+        constant   int32_t & ofs0,
+        constant   int32_t & ofs1,
+        constant   int32_t & IW,
+        constant   int32_t & IH,
+        constant   int32_t & CHW,
+        constant   int32_t & s0,
+        constant   int32_t & s1,
+        constant   int32_t & p0,
+        constant   int32_t & p1,
+        constant   int32_t & d0,
+        constant   int32_t & d1,
+        constant   int32_t & N,
+        constant   int32_t & KH,
+        constant   int32_t & KW,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tgpg[[threadgroups_per_grid]],          // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]]) {       // [M, 1, 1]
+    const int32_t KHW = KH * KW;                      // KHW == ntg[1] * ntg[2], KW == ntg[2]
+
+    const int32_t d = tgpig[0] / CHW;
+    const int32_t chw = tgpig[0] % CHW;
+    const int32_t tgpig_0 = chw / KHW;                // 0 ~ (IC - 1)
+    const int32_t HW = tgpig[0] % KHW;
+
+    const int32_t tpitg_0 = (d * ntg[0]) + tpitg[0];
+    if (tpitg_0 >= N) {
+        return;
+    }
+
+    const int32_t tpitg_1 = HW / KW;
+    const int32_t tpitg_2 = HW % KW;
+
+    const int32_t iiw = tgpig[2] * s0 + tpitg_2 * d0 - p0;
+    const int32_t iih = tgpig[1] * s1 + tpitg_1 * d1 - p1;
+
+    const int32_t offset_dst =
+        (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
+        (tgpig_0 * KHW + tpitg_1 * KW + tpitg_2);
+
+    device T * pdst = (device T *) (dst);
+
+    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+        pdst[offset_dst] = 0.0f;
+    } else {
+        const int32_t offset_src = tpitg_0 * ofs0 + tgpig_0 * ofs1;
+        pdst[offset_dst] = x[offset_src + iih * IW + iiw];
+    }
+}
+
+template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
+template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
+
 kernel void kernel_upscale_f32(
         device const char * src0,
         device       char * dst,
@@ -6372,3 +6451,102 @@ template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t
 template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl>>;
 template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>;
 template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;
+
+kernel void kernel_pool_2d_max_f32(
+        device  const float * src0,
+        device        float * dst,
+        constant    int32_t & k0,
+        constant    int32_t & k1,
+        constant    int32_t & s0,
+        constant    int32_t & s1,
+        constant    int32_t & p0,
+        constant    int32_t & p1,
+        constant    int64_t & IH,
+        constant    int64_t & IW,
+        constant    int64_t & OH,
+        constant    int64_t & OW,
+        constant    int64_t & parallel_elements,
+        uint        gid[[thread_position_in_grid]]) {
+
+    if (gid >= parallel_elements) {
+        return;
+    }
+
+    const int idx = gid;
+    const int I_HW = IH * IW;
+    const int O_HW = OH * OW;
+    const int nc = idx / O_HW;
+    const int cur_oh = idx % O_HW / OW;
+    const int cur_ow = idx % O_HW % OW;
+
+    device const float * i_ptr = src0 + nc * I_HW;
+    device       float * o_ptr = dst + nc * O_HW;
+
+    const int start_h = cur_oh * s1 - p1;
+    const int bh = MAX(0, start_h);
+    const int eh = MIN(IH, start_h + k1);
+    const int start_w = cur_ow * s0 - p0;
+    const int bw = MAX(0, start_w);
+    const int ew = MIN(IW, start_w + k0);
+
+    float res = -INFINITY;
+
+    for (int i = bh; i < eh; i += 1) {
+        for (int j = bw; j < ew; j += 1) {
+            res = MAX(res, i_ptr[i * IW + j]);
+        }
+    }
+
+    o_ptr[cur_oh * OW + cur_ow] = res;
+}
+
+kernel void kernel_pool_2d_avg_f32(
+        device  const float * src0,
+        device        float * dst,
+        constant    int32_t & k0,
+        constant    int32_t & k1,
+        constant    int32_t & s0,
+        constant    int32_t & s1,
+        constant    int32_t & p0,
+        constant    int32_t & p1,
+        constant    int64_t & IH,
+        constant    int64_t & IW,
+        constant    int64_t & OH,
+        constant    int64_t & OW,
+        constant    int64_t & parallel_elements,
+        uint        gid[[thread_position_in_grid]]) {
+
+    if (gid >= parallel_elements) {
+        return;
+    }
+
+    const int idx = gid;
+    const int I_HW = IH * IW;
+    const int O_HW = OH * OW;
+    const int nc = idx / O_HW;
+    const int cur_oh = idx % O_HW / OW;
+    const int cur_ow = idx % O_HW % OW;
+
+    device const float * i_ptr = src0 + nc * I_HW;
+    device       float * o_ptr = dst + nc * O_HW;
+
+    const int start_h = cur_oh * s1 - p1;
+    const int bh = MAX(0, start_h);
+    const int eh = MIN(IH, start_h + k1);
+    const int start_w = cur_ow * s0 - p0;
+    const int bw = MAX(0, start_w);
+    const int ew = MIN(IW, start_w + k0);
+    // const float scale = 1. / ((eh - bh) * (ew - bw));
+    const float scale = 1. / (k0 * k1);
+
+    float res = 0;
+
+    for (int i = bh; i < eh; i += 1) {
+        for (int j = bw; j < ew; j += 1) {
+            float cur = i_ptr[i * IW + j];
+            res += cur * scale;
+        }
+    }
+
+    o_ptr[cur_oh * OW + cur_ow] = res;
+}