Sigbjørn Skjæret ggerganov OccamRazor Akarshan jeffbolznv commited on
Commit
add5c0f
·
1 Parent(s): 737f12d

ggml : implement REGLU/GEGLU/SWIGLU ops (llama/14158)

Browse files

* implement unary REGLU/GEGLU/SWIGLU cpu ops

* relax constraints

* duplicate shape of source

* fix ggml_vec_geglu_f16

* special case gated ops

* implement unary REGLU/GEGLU/SWIGLU cuda ops

* tighten constraints again

* refactor into GGML_GLU_OP

* metal : add glu kernels

ggml-ci

* add CUDA_GLU_BLOCK_SIZE [no ci]

* more constraints and use 64bit ints

ggml-ci

* 64bit multiplication [no ci]

* implement swapped variants (cpu/cuda)

* update comment [no ci]

ggml-ci

* Vulkan: Add GLU ops and shaders

* SYCL: Implement fused kernel GEGLU, SWIGLU and REGLU for single up+gate

* ggml : implement GLU for split up/gate (llama/14181)

* implement GLU for split up/gate

* add tests for ggml_glu_split

* Vulkan: Implement glu_split logic and shader support

* add split to logging [no ci]

* SYCL: refactor element_size ops and add split up and gate support to gated kernels

* SYCL: switch GEGLU to use tanh approximation

---------

Co-authored-by: 0cc4m <[email protected]>
Co-authored-by: Akarshan <[email protected]>

* GGML: increase OP count in assertion

* Refactor: Optimize SYCL element-wise operations with unary function inlining

This commit refactors the SYCL element-wise operations to improve performance by:

- Inlining unary operations (sgn, abs, elu, gelu, silu, etc.) to reduce kernel launch overhead.
- Introducing helper functions `op_xxx` for each unary operation to encapsulate the logic.
- Replacing direct kernel calls with calls to these inlined functions.
- Using `__dpct_inline__` to encourage compiler inlining.
- Minor code cleanup and consistency improvements.

The changes aim to reduce kernel launch overhead and improve the overall efficiency of element-wise operations on SYCL devices.

* vulkan: Increase workgroup size for GLU, for performance (llama/14345)

* vulkan: Increase workgroup size for GLU, for performance

* vulkan: change GLU shaders to do one element per invocation rather than one row per workgroup

* merge fix

* metal : add support for split and swap

ggml-ci

---------

Co-authored-by: Georgi Gerganov <[email protected]>
Co-authored-by: 0cc4m <[email protected]>
Co-authored-by: Akarshan <[email protected]>
Co-authored-by: Jeff Bolz <[email protected]>

ggml/include/ggml.h CHANGED
@@ -520,6 +520,8 @@ extern "C" {
520
  GGML_OP_CROSS_ENTROPY_LOSS_BACK,
521
  GGML_OP_OPT_STEP_ADAMW,
522
 
 
 
523
  GGML_OP_COUNT,
524
  };
525
 
@@ -543,6 +545,14 @@ extern "C" {
543
  GGML_UNARY_OP_COUNT,
544
  };
545
 
 
 
 
 
 
 
 
 
546
  enum ggml_object_type {
547
  GGML_OBJECT_TYPE_TENSOR,
548
  GGML_OBJECT_TYPE_GRAPH,
@@ -658,6 +668,7 @@ extern "C" {
658
  GGML_API const char * ggml_op_symbol(enum ggml_op op);
659
 
660
  GGML_API const char * ggml_unary_op_name(enum ggml_unary_op op);
 
661
  GGML_API const char * ggml_op_desc(const struct ggml_tensor * t); // unary or op name
662
 
663
  GGML_API size_t ggml_element_size(const struct ggml_tensor * tensor);
@@ -762,6 +773,7 @@ extern "C" {
762
  GGML_API void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3);
763
 
764
  GGML_API enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor);
 
765
 
766
  GGML_API void * ggml_get_data (const struct ggml_tensor * tensor);
767
  GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);
@@ -1090,6 +1102,63 @@ extern "C" {
1090
  struct ggml_context * ctx,
1091
  struct ggml_tensor * a);
1092
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1093
  // normalize along rows
1094
  GGML_API struct ggml_tensor * ggml_norm(
1095
  struct ggml_context * ctx,
 
520
  GGML_OP_CROSS_ENTROPY_LOSS_BACK,
521
  GGML_OP_OPT_STEP_ADAMW,
522
 
523
+ GGML_OP_GLU,
524
+
525
  GGML_OP_COUNT,
526
  };
527
 
 
545
  GGML_UNARY_OP_COUNT,
546
  };
547
 
548
+ enum ggml_glu_op {
549
+ GGML_GLU_OP_REGLU,
550
+ GGML_GLU_OP_GEGLU,
551
+ GGML_GLU_OP_SWIGLU,
552
+
553
+ GGML_GLU_OP_COUNT,
554
+ };
555
+
556
  enum ggml_object_type {
557
  GGML_OBJECT_TYPE_TENSOR,
558
  GGML_OBJECT_TYPE_GRAPH,
 
668
  GGML_API const char * ggml_op_symbol(enum ggml_op op);
669
 
670
  GGML_API const char * ggml_unary_op_name(enum ggml_unary_op op);
671
+ GGML_API const char * ggml_glu_op_name(enum ggml_glu_op op);
672
  GGML_API const char * ggml_op_desc(const struct ggml_tensor * t); // unary or op name
673
 
674
  GGML_API size_t ggml_element_size(const struct ggml_tensor * tensor);
 
773
  GGML_API void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3);
774
 
775
  GGML_API enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor);
776
+ GGML_API enum ggml_glu_op ggml_get_glu_op(const struct ggml_tensor * tensor);
777
 
778
  GGML_API void * ggml_get_data (const struct ggml_tensor * tensor);
779
  GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);
 
1102
  struct ggml_context * ctx,
1103
  struct ggml_tensor * a);
1104
 
1105
+ // gated linear unit ops
1106
+ // A: n columns, r rows,
1107
+ // result is n / 2 columns, r rows,
1108
+ // expects gate in second half of row, unless swapped is true
1109
+ GGML_API struct ggml_tensor * ggml_glu(
1110
+ struct ggml_context * ctx,
1111
+ struct ggml_tensor * a,
1112
+ enum ggml_glu_op op,
1113
+ bool swapped);
1114
+
1115
+ GGML_API struct ggml_tensor * ggml_reglu(
1116
+ struct ggml_context * ctx,
1117
+ struct ggml_tensor * a);
1118
+
1119
+ GGML_API struct ggml_tensor * ggml_reglu_swapped(
1120
+ struct ggml_context * ctx,
1121
+ struct ggml_tensor * a);
1122
+
1123
+ GGML_API struct ggml_tensor * ggml_geglu(
1124
+ struct ggml_context * ctx,
1125
+ struct ggml_tensor * a);
1126
+
1127
+ GGML_API struct ggml_tensor * ggml_geglu_swapped(
1128
+ struct ggml_context * ctx,
1129
+ struct ggml_tensor * a);
1130
+
1131
+ GGML_API struct ggml_tensor * ggml_swiglu(
1132
+ struct ggml_context * ctx,
1133
+ struct ggml_tensor * a);
1134
+
1135
+ GGML_API struct ggml_tensor * ggml_swiglu_swapped(
1136
+ struct ggml_context * ctx,
1137
+ struct ggml_tensor * a);
1138
+
1139
+ // A: n columns, r rows,
1140
+ // B: n columns, r rows,
1141
+ GGML_API struct ggml_tensor * ggml_glu_split(
1142
+ struct ggml_context * ctx,
1143
+ struct ggml_tensor * a,
1144
+ struct ggml_tensor * b,
1145
+ enum ggml_glu_op op);
1146
+
1147
+ GGML_API struct ggml_tensor * ggml_reglu_split(
1148
+ struct ggml_context * ctx,
1149
+ struct ggml_tensor * a,
1150
+ struct ggml_tensor * b);
1151
+
1152
+ GGML_API struct ggml_tensor * ggml_geglu_split(
1153
+ struct ggml_context * ctx,
1154
+ struct ggml_tensor * a,
1155
+ struct ggml_tensor * b);
1156
+
1157
+ GGML_API struct ggml_tensor * ggml_swiglu_split(
1158
+ struct ggml_context * ctx,
1159
+ struct ggml_tensor * a,
1160
+ struct ggml_tensor * b);
1161
+
1162
  // normalize along rows
1163
  GGML_API struct ggml_tensor * ggml_norm(
1164
  struct ggml_context * ctx,
ggml/src/ggml-cpu/ggml-cpu.c CHANGED
@@ -1949,6 +1949,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
1949
  {
1950
  ggml_compute_forward_unary(params, tensor);
1951
  } break;
 
 
 
 
1952
  case GGML_OP_GET_REL_POS:
1953
  {
1954
  ggml_compute_forward_get_rel_pos(params, tensor);
@@ -2159,6 +2163,18 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
2159
  GGML_ABORT("fatal error");
2160
  }
2161
  break;
 
 
 
 
 
 
 
 
 
 
 
 
2162
  case GGML_OP_SILU_BACK:
2163
  case GGML_OP_MUL:
2164
  case GGML_OP_DIV:
 
1949
  {
1950
  ggml_compute_forward_unary(params, tensor);
1951
  } break;
1952
+ case GGML_OP_GLU:
1953
+ {
1954
+ ggml_compute_forward_glu(params, tensor);
1955
+ } break;
1956
  case GGML_OP_GET_REL_POS:
1957
  {
1958
  ggml_compute_forward_get_rel_pos(params, tensor);
 
2163
  GGML_ABORT("fatal error");
2164
  }
2165
  break;
2166
+ case GGML_OP_GLU:
2167
+ switch (ggml_get_glu_op(node)) {
2168
+ case GGML_GLU_OP_REGLU:
2169
+ case GGML_GLU_OP_GEGLU:
2170
+ case GGML_GLU_OP_SWIGLU:
2171
+ {
2172
+ n_tasks = n_threads;
2173
+ } break;
2174
+ default:
2175
+ GGML_ABORT("fatal error");
2176
+ }
2177
+ break;
2178
  case GGML_OP_SILU_BACK:
2179
  case GGML_OP_MUL:
2180
  case GGML_OP_DIV:
ggml/src/ggml-cpu/ops.cpp CHANGED
@@ -3184,6 +3184,435 @@ void ggml_compute_forward_silu_back(
3184
  }
3185
  }
3186
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3187
  // ggml_compute_forward_norm
3188
 
3189
  static void ggml_compute_forward_norm_f32(
@@ -8057,6 +8486,34 @@ void ggml_compute_forward_unary(
8057
  }
8058
  }
8059
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8060
  // ggml_compute_forward_get_rel_pos
8061
 
8062
  static void ggml_compute_forward_get_rel_pos_f16(
 
3184
  }
3185
  }
3186
 
3187
+ // ggml_compute_forward_reglu
3188
+
3189
+ static void ggml_compute_forward_reglu_f32(
3190
+ const ggml_compute_params * params,
3191
+ ggml_tensor * dst) {
3192
+
3193
+ const ggml_tensor * src0 = dst->src[0];
3194
+ const ggml_tensor * src1 = dst->src[1];
3195
+ char * src0_d = (char *) src0->data;
3196
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
3197
+ const size_t src0_o = src0->nb[1];
3198
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3199
+
3200
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
3201
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
3202
+
3203
+ if (src1) {
3204
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
3205
+ GGML_ASSERT(src0->type == src1->type);
3206
+ }
3207
+
3208
+ const int ith = params->ith;
3209
+ const int nth = params->nth;
3210
+
3211
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3212
+ const int nr = ggml_nrows(src0);
3213
+
3214
+ GGML_ASSERT(dst->ne[0] == nc);
3215
+ GGML_ASSERT(ggml_nrows(dst) == nr);
3216
+
3217
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3218
+
3219
+ // rows per thread
3220
+ const int dr = (nr + nth - 1)/nth;
3221
+
3222
+ // row range for this thread
3223
+ const int ir0 = dr*ith;
3224
+ const int ir1 = MIN(ir0 + dr, nr);
3225
+
3226
+ for (int i1 = ir0; i1 < ir1; i1++) {
3227
+ float * src0_p = (float *) (src0_d + i1*src0_o);
3228
+ float * src1_p = (float *) (src1_d + i1*src1_o);
3229
+
3230
+ if (!src1) {
3231
+ src0_p += swapped ? nc : 0;
3232
+ src1_p += swapped ? 0 : nc;
3233
+ }
3234
+
3235
+ ggml_vec_reglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3236
+
3237
+ #ifndef NDEBUG
3238
+ for (int k = 0; k < nc; k++) {
3239
+ const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3240
+ GGML_UNUSED(x);
3241
+ assert(!isnan(x));
3242
+ assert(!isinf(x));
3243
+ }
3244
+ #endif
3245
+ }
3246
+ }
3247
+
3248
+ static void ggml_compute_forward_reglu_f16(
3249
+ const ggml_compute_params * params,
3250
+ ggml_tensor * dst) {
3251
+
3252
+ const ggml_tensor * src0 = dst->src[0];
3253
+ const ggml_tensor * src1 = dst->src[1];
3254
+ char * src0_d = (char *) src0->data;
3255
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
3256
+ const size_t src0_o = src0->nb[1];
3257
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3258
+
3259
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
3260
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
3261
+
3262
+ if (src1) {
3263
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
3264
+ GGML_ASSERT(src0->type == src1->type);
3265
+ }
3266
+
3267
+ const int ith = params->ith;
3268
+ const int nth = params->nth;
3269
+
3270
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3271
+ const int nr = ggml_nrows(src0);
3272
+
3273
+ GGML_ASSERT(dst->ne[0] == nc);
3274
+ GGML_ASSERT(ggml_nrows(dst) == nr);
3275
+
3276
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3277
+
3278
+ // rows per thread
3279
+ const int dr = (nr + nth - 1)/nth;
3280
+
3281
+ // row range for this thread
3282
+ const int ir0 = dr*ith;
3283
+ const int ir1 = MIN(ir0 + dr, nr);
3284
+
3285
+ for (int i1 = ir0; i1 < ir1; i1++) {
3286
+ ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
3287
+ ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
3288
+
3289
+ if (!src1) {
3290
+ src0_p += swapped ? nc : 0;
3291
+ src1_p += swapped ? 0 : nc;
3292
+ }
3293
+
3294
+ ggml_vec_reglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3295
+
3296
+ #ifndef NDEBUG
3297
+ for (int k = 0; k < nc; k++) {
3298
+ const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3299
+ const float v = GGML_FP16_TO_FP32(x);
3300
+ GGML_UNUSED(v);
3301
+ assert(!isnan(v));
3302
+ assert(!isinf(v));
3303
+ }
3304
+ #endif
3305
+ }
3306
+ }
3307
+
3308
+ static void ggml_compute_forward_reglu(
3309
+ const ggml_compute_params * params,
3310
+ ggml_tensor * dst) {
3311
+
3312
+ const ggml_tensor * src0 = dst->src[0];
3313
+
3314
+ switch (src0->type) {
3315
+ case GGML_TYPE_F32:
3316
+ {
3317
+ ggml_compute_forward_reglu_f32(params, dst);
3318
+ } break;
3319
+ case GGML_TYPE_F16:
3320
+ {
3321
+ ggml_compute_forward_reglu_f16(params, dst);
3322
+ } break;
3323
+ default:
3324
+ {
3325
+ GGML_ABORT("fatal error");
3326
+ }
3327
+ }
3328
+ }
3329
+
3330
+ // ggml_compute_forward_geglu
3331
+
3332
+ static void ggml_compute_forward_geglu_f32(
3333
+ const ggml_compute_params * params,
3334
+ ggml_tensor * dst) {
3335
+
3336
+ const ggml_tensor * src0 = dst->src[0];
3337
+ const ggml_tensor * src1 = dst->src[1];
3338
+ char * src0_d = (char *) src0->data;
3339
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
3340
+ const size_t src0_o = src0->nb[1];
3341
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3342
+
3343
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
3344
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
3345
+
3346
+ if (src1) {
3347
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
3348
+ GGML_ASSERT(src0->type == src1->type);
3349
+ }
3350
+
3351
+ const int ith = params->ith;
3352
+ const int nth = params->nth;
3353
+
3354
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3355
+ const int nr = ggml_nrows(src0);
3356
+
3357
+ GGML_ASSERT(dst->ne[0] == nc);
3358
+ GGML_ASSERT(ggml_nrows(dst) == nr);
3359
+
3360
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3361
+
3362
+ // rows per thread
3363
+ const int dr = (nr + nth - 1)/nth;
3364
+
3365
+ // row range for this thread
3366
+ const int ir0 = dr*ith;
3367
+ const int ir1 = MIN(ir0 + dr, nr);
3368
+
3369
+ for (int i1 = ir0; i1 < ir1; i1++) {
3370
+ float * src0_p = (float *) (src0_d + i1*src0_o);
3371
+ float * src1_p = (float *) (src1_d + i1*src1_o);
3372
+
3373
+ if (!src1) {
3374
+ src0_p += swapped ? nc : 0;
3375
+ src1_p += swapped ? 0 : nc;
3376
+ }
3377
+
3378
+ ggml_vec_geglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3379
+
3380
+ #ifndef NDEBUG
3381
+ for (int k = 0; k < nc; k++) {
3382
+ const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3383
+ GGML_UNUSED(x);
3384
+ assert(!isnan(x));
3385
+ assert(!isinf(x));
3386
+ }
3387
+ #endif
3388
+ }
3389
+ }
3390
+
3391
+ static void ggml_compute_forward_geglu_f16(
3392
+ const ggml_compute_params * params,
3393
+ ggml_tensor * dst) {
3394
+
3395
+ const ggml_tensor * src0 = dst->src[0];
3396
+ const ggml_tensor * src1 = dst->src[1];
3397
+ char * src0_d = (char *) src0->data;
3398
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
3399
+ const size_t src0_o = src0->nb[1];
3400
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3401
+
3402
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
3403
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
3404
+
3405
+ if (src1) {
3406
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
3407
+ GGML_ASSERT(src0->type == src1->type);
3408
+ }
3409
+
3410
+ const int ith = params->ith;
3411
+ const int nth = params->nth;
3412
+
3413
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3414
+ const int nr = ggml_nrows(src0);
3415
+
3416
+ GGML_ASSERT(dst->ne[0] == nc);
3417
+ GGML_ASSERT(ggml_nrows(dst) == nr);
3418
+
3419
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3420
+
3421
+ // rows per thread
3422
+ const int dr = (nr + nth - 1)/nth;
3423
+
3424
+ // row range for this thread
3425
+ const int ir0 = dr*ith;
3426
+ const int ir1 = MIN(ir0 + dr, nr);
3427
+
3428
+ for (int i1 = ir0; i1 < ir1; i1++) {
3429
+ ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
3430
+ ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
3431
+
3432
+ if (!src1) {
3433
+ src0_p += swapped ? nc : 0;
3434
+ src1_p += swapped ? 0 : nc;
3435
+ }
3436
+
3437
+ ggml_vec_geglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3438
+
3439
+ #ifndef NDEBUG
3440
+ for (int k = 0; k < nc; k++) {
3441
+ const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3442
+ const float v = GGML_FP16_TO_FP32(x);
3443
+ GGML_UNUSED(v);
3444
+ assert(!isnan(v));
3445
+ assert(!isinf(v));
3446
+ }
3447
+ #endif
3448
+ }
3449
+ }
3450
+
3451
+ static void ggml_compute_forward_geglu(
3452
+ const ggml_compute_params * params,
3453
+ ggml_tensor * dst) {
3454
+
3455
+ const ggml_tensor * src0 = dst->src[0];
3456
+
3457
+ switch (src0->type) {
3458
+ case GGML_TYPE_F32:
3459
+ {
3460
+ ggml_compute_forward_geglu_f32(params, dst);
3461
+ } break;
3462
+ case GGML_TYPE_F16:
3463
+ {
3464
+ ggml_compute_forward_geglu_f16(params, dst);
3465
+ } break;
3466
+ default:
3467
+ {
3468
+ GGML_ABORT("fatal error");
3469
+ }
3470
+ }
3471
+ }
3472
+
3473
+ // ggml_compute_forward_swiglu
3474
+
3475
+ static void ggml_compute_forward_swiglu_f32(
3476
+ const ggml_compute_params * params,
3477
+ ggml_tensor * dst) {
3478
+
3479
+ const ggml_tensor * src0 = dst->src[0];
3480
+ const ggml_tensor * src1 = dst->src[1];
3481
+ char * src0_d = (char *) src0->data;
3482
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
3483
+ const size_t src0_o = src0->nb[1];
3484
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3485
+
3486
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
3487
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
3488
+
3489
+ if (src1) {
3490
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
3491
+ GGML_ASSERT(src0->type == src1->type);
3492
+ }
3493
+
3494
+ const int ith = params->ith;
3495
+ const int nth = params->nth;
3496
+
3497
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3498
+ const int nr = ggml_nrows(src0);
3499
+
3500
+ GGML_ASSERT(dst->ne[0] == nc);
3501
+ GGML_ASSERT(ggml_nrows(dst) == nr);
3502
+
3503
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3504
+
3505
+ // rows per thread
3506
+ const int dr = (nr + nth - 1)/nth;
3507
+
3508
+ // row range for this thread
3509
+ const int ir0 = dr*ith;
3510
+ const int ir1 = MIN(ir0 + dr, nr);
3511
+
3512
+ for (int i1 = ir0; i1 < ir1; i1++) {
3513
+ float * src0_p = (float *) (src0_d + i1*src0_o);
3514
+ float * src1_p = (float *) (src1_d + i1*src1_o);
3515
+
3516
+ if (!src1) {
3517
+ src0_p += swapped ? nc : 0;
3518
+ src1_p += swapped ? 0 : nc;
3519
+ }
3520
+
3521
+ ggml_vec_swiglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3522
+
3523
+ #ifndef NDEBUG
3524
+ for (int k = 0; k < nc; k++) {
3525
+ const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3526
+ GGML_UNUSED(x);
3527
+ assert(!isnan(x));
3528
+ assert(!isinf(x));
3529
+ }
3530
+ #endif
3531
+ }
3532
+ }
3533
+
3534
+ static void ggml_compute_forward_swiglu_f16(
3535
+ const ggml_compute_params * params,
3536
+ ggml_tensor * dst) {
3537
+
3538
+ const ggml_tensor * src0 = dst->src[0];
3539
+ const ggml_tensor * src1 = dst->src[1];
3540
+ char * src0_d = (char *) src0->data;
3541
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
3542
+ const size_t src0_o = src0->nb[1];
3543
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3544
+
3545
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
3546
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
3547
+
3548
+ if (src1) {
3549
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
3550
+ GGML_ASSERT(src0->type == src1->type);
3551
+ }
3552
+
3553
+ const int ith = params->ith;
3554
+ const int nth = params->nth;
3555
+
3556
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3557
+ const int nr = ggml_nrows(src0);
3558
+
3559
+ GGML_ASSERT(dst->ne[0] == nc);
3560
+ GGML_ASSERT(ggml_nrows(dst) == nr);
3561
+
3562
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3563
+
3564
+ // rows per thread
3565
+ const int dr = (nr + nth - 1)/nth;
3566
+
3567
+ // row range for this thread
3568
+ const int ir0 = dr*ith;
3569
+ const int ir1 = MIN(ir0 + dr, nr);
3570
+
3571
+ for (int i1 = ir0; i1 < ir1; i1++) {
3572
+ ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
3573
+ ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
3574
+
3575
+ if (!src1) {
3576
+ src0_p += swapped ? nc : 0;
3577
+ src1_p += swapped ? 0 : nc;
3578
+ }
3579
+
3580
+ ggml_vec_swiglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3581
+
3582
+ #ifndef NDEBUG
3583
+ for (int k = 0; k < nc; k++) {
3584
+ const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3585
+ const float v = GGML_FP16_TO_FP32(x);
3586
+ GGML_UNUSED(v);
3587
+ assert(!isnan(v));
3588
+ assert(!isinf(v));
3589
+ }
3590
+ #endif
3591
+ }
3592
+ }
3593
+
3594
+ static void ggml_compute_forward_swiglu(
3595
+ const ggml_compute_params * params,
3596
+ ggml_tensor * dst) {
3597
+
3598
+ const ggml_tensor * src0 = dst->src[0];
3599
+
3600
+ switch (src0->type) {
3601
+ case GGML_TYPE_F32:
3602
+ {
3603
+ ggml_compute_forward_swiglu_f32(params, dst);
3604
+ } break;
3605
+ case GGML_TYPE_F16:
3606
+ {
3607
+ ggml_compute_forward_swiglu_f16(params, dst);
3608
+ } break;
3609
+ default:
3610
+ {
3611
+ GGML_ABORT("fatal error");
3612
+ }
3613
+ }
3614
+ }
3615
+
3616
  // ggml_compute_forward_norm
3617
 
3618
  static void ggml_compute_forward_norm_f32(
 
8486
  }
8487
  }
8488
 
8489
+ //ggml_compute_forward_glu
8490
+
8491
+ void ggml_compute_forward_glu(
8492
+ const ggml_compute_params * params,
8493
+ ggml_tensor * dst) {
8494
+
8495
+ const ggml_glu_op op = ggml_get_glu_op(dst);
8496
+
8497
+ switch (op) {
8498
+ case GGML_GLU_OP_REGLU:
8499
+ {
8500
+ ggml_compute_forward_reglu(params, dst);
8501
+ } break;
8502
+ case GGML_GLU_OP_GEGLU:
8503
+ {
8504
+ ggml_compute_forward_geglu(params, dst);
8505
+ } break;
8506
+ case GGML_GLU_OP_SWIGLU:
8507
+ {
8508
+ ggml_compute_forward_swiglu(params, dst);
8509
+ } break;
8510
+ default:
8511
+ {
8512
+ GGML_ABORT("fatal error");
8513
+ }
8514
+ }
8515
+ }
8516
+
8517
  // ggml_compute_forward_get_rel_pos
8518
 
8519
  static void ggml_compute_forward_get_rel_pos_f16(
ggml/src/ggml-cpu/ops.h CHANGED
@@ -94,6 +94,7 @@ void ggml_compute_forward_ssm_scan(const struct ggml_compute_params * params, st
94
  void ggml_compute_forward_win_part(const struct ggml_compute_params * params, struct ggml_tensor * dst);
95
  void ggml_compute_forward_win_unpart(const struct ggml_compute_params * params, struct ggml_tensor * dst);
96
  void ggml_compute_forward_unary(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 
97
  void ggml_compute_forward_get_rel_pos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
98
  void ggml_compute_forward_add_rel_pos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
99
  void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 
94
  void ggml_compute_forward_win_part(const struct ggml_compute_params * params, struct ggml_tensor * dst);
95
  void ggml_compute_forward_win_unpart(const struct ggml_compute_params * params, struct ggml_tensor * dst);
96
  void ggml_compute_forward_unary(const struct ggml_compute_params * params, struct ggml_tensor * dst);
97
+ void ggml_compute_forward_glu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
98
  void ggml_compute_forward_get_rel_pos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
99
  void ggml_compute_forward_add_rel_pos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
100
  void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, struct ggml_tensor * dst);
ggml/src/ggml-cpu/vec.cpp CHANGED
@@ -254,6 +254,30 @@ void ggml_vec_silu_f32(const int n, float * y, const float * x) {
254
  }
255
  }
256
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
  ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
258
  int i = 0;
259
  ggml_float sum = 0;
 
254
  }
255
  }
256
 
257
+ void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float * g) {
258
+ int i = 0;
259
+ #if defined(__AVX512F__) && defined(__AVX512DQ__)
260
+ for (; i + 15 < n; i += 16) {
261
+ _mm512_storeu_ps(y + i, _mm512_mul_ps(ggml_v_silu(_mm512_loadu_ps(x + i)), _mm512_loadu_ps(g + i)));
262
+ }
263
+ #elif defined(__AVX2__) && defined(__FMA__)
264
+ for (; i + 7 < n; i += 8) {
265
+ _mm256_storeu_ps(y + i, _mm256_mul_ps(ggml_v_silu(_mm256_loadu_ps(x + i)), _mm256_loadu_ps(g + i)));
266
+ }
267
+ #elif defined(__SSE2__)
268
+ for (; i + 3 < n; i += 4) {
269
+ _mm_storeu_ps(y + i, _mm_mul_ps(ggml_v_silu(_mm_loadu_ps(x + i)), _mm_loadu_ps(g + i)));
270
+ }
271
+ #elif defined(__ARM_NEON) && defined(__aarch64__)
272
+ for (; i + 3 < n; i += 4) {
273
+ vst1q_f32(y + i, vmulq_f32(ggml_v_silu(vld1q_f32(x + i)), vld1q_f32(g + i)));
274
+ }
275
+ #endif
276
+ for (; i < n; ++i) {
277
+ y[i] = ggml_silu_f32(x[i]) * g[i];
278
+ }
279
+ }
280
+
281
  ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
282
  int i = 0;
283
  ggml_float sum = 0;
ggml/src/ggml-cpu/vec.h CHANGED
@@ -905,6 +905,60 @@ inline static void ggml_vec_silu_backward_f16(const int n, ggml_fp16_t * dx, con
905
  }
906
  }
907
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
908
  inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
909
  #ifndef GGML_USE_ACCELERATE
910
  ggml_float sum = 0.0;
 
905
  }
906
  }
907
 
908
+ inline static void ggml_vec_reglu_f32 (const int n, float * y, const float * x, const float * g) {
909
+ for (int i = 0; i < n; ++i) {
910
+ y[i] = (x[i] > 0.f) ? x[i] * g[i] : 0.f;
911
+ }
912
+ }
913
+
914
+ inline static void ggml_vec_reglu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
915
+ for (int i = 0; i < n; ++i) {
916
+ float v = GGML_FP16_TO_FP32(x[i]);
917
+ y[i] = GGML_FP32_TO_FP16((v > 0.f) ? v * GGML_FP16_TO_FP32(g[i]) : 0.f);
918
+ }
919
+ }
920
+
921
+ #ifdef GGML_GELU_FP16
922
+ inline static void ggml_vec_geglu_f32(const int n, float * y, const float * x, const float * g) {
923
+ uint16_t t;
924
+ for (int i = 0; i < n; ++i) {
925
+ if (x[i] <= -10.0f) {
926
+ y[i] = 0.0f;
927
+ } else if (x[i] >= 10.0f) {
928
+ y[i] = x[i] * g[i];
929
+ } else {
930
+ ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
931
+ memcpy(&t, &fp16, sizeof(uint16_t));
932
+ y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t]) * g[i];
933
+ }
934
+ }
935
+ }
936
+ #else
937
+ inline static void ggml_vec_geglu_f32(const int n, float * y, const float * x, const float * g) {
938
+ for (int i = 0; i < n; ++i) {
939
+ y[i] = ggml_gelu_f32(x[i]) * g[i];
940
+ }
941
+ }
942
+ #endif
943
+
944
+ inline static void ggml_vec_geglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
945
+ const uint16_t * i16 = (const uint16_t *) x;
946
+ for (int i = 0; i < n; ++i) {
947
+ float v = GGML_FP16_TO_FP32(g[i]);
948
+ y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(ggml_table_gelu_f16[i16[i]]) * v);
949
+ }
950
+ }
951
+
952
+ void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float * g);
953
+
954
+ inline static void ggml_vec_swiglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
955
+ for (int i = 0; i < n; ++i) {
956
+ float v = GGML_FP16_TO_FP32(x[i]);
957
+ float w = GGML_FP16_TO_FP32(g[i]);
958
+ y[i] = GGML_FP32_TO_FP16((v/(1.0f + expf(-v))) * w);
959
+ }
960
+ }
961
+
962
  inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
963
  #ifndef GGML_USE_ACCELERATE
964
  ggml_float sum = 0.0;
ggml/src/ggml-cuda/ggml-cuda.cu CHANGED
@@ -2303,6 +2303,21 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
2303
  return false;
2304
  }
2305
  break;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2306
  case GGML_OP_NORM:
2307
  ggml_cuda_op_norm(ctx, dst);
2308
  break;
@@ -3096,6 +3111,16 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3096
  return false;
3097
  }
3098
  break;
 
 
 
 
 
 
 
 
 
 
3099
  case GGML_OP_MUL_MAT:
3100
  case GGML_OP_MUL_MAT_ID:
3101
  {
 
2303
  return false;
2304
  }
2305
  break;
2306
+ case GGML_OP_GLU:
2307
+ switch (ggml_get_glu_op(dst)) {
2308
+ case GGML_GLU_OP_REGLU:
2309
+ ggml_cuda_op_reglu(ctx, dst);
2310
+ break;
2311
+ case GGML_GLU_OP_GEGLU:
2312
+ ggml_cuda_op_geglu(ctx, dst);
2313
+ break;
2314
+ case GGML_GLU_OP_SWIGLU:
2315
+ ggml_cuda_op_swiglu(ctx, dst);
2316
+ break;
2317
+ default:
2318
+ return false;
2319
+ }
2320
+ break;
2321
  case GGML_OP_NORM:
2322
  ggml_cuda_op_norm(ctx, dst);
2323
  break;
 
3111
  return false;
3112
  }
3113
  break;
3114
+ case GGML_OP_GLU:
3115
+ switch (ggml_get_glu_op(op)) {
3116
+ case GGML_GLU_OP_REGLU:
3117
+ case GGML_GLU_OP_GEGLU:
3118
+ case GGML_GLU_OP_SWIGLU:
3119
+ return ggml_is_contiguous_1(op->src[0]);
3120
+ default:
3121
+ return false;
3122
+ }
3123
+ break;
3124
  case GGML_OP_MUL_MAT:
3125
  case GGML_OP_MUL_MAT_ID:
3126
  {
ggml/src/ggml-cuda/unary.cu CHANGED
@@ -196,6 +196,95 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
196
  ggml_cuda_op_unary<op_log>(ctx, dst);
197
  }
198
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
  /* silu_back */
200
 
201
  static __device__ __forceinline__ float op_silu_back(float grad, float x) {
 
196
  ggml_cuda_op_unary<op_log>(ctx, dst);
197
  }
198
 
199
+ /* gated ops */
200
+
201
+ template <float (*op)(float), typename T>
202
+ static __global__ void unary_gated_op_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1) {
203
+ const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x;
204
+
205
+ if (i >= k) {
206
+ return;
207
+ }
208
+
209
+ // perform base op and multiply with gate (either offset in same tensor or a separate one)
210
+ const int64_t j0 = (i / n) * o0 + (i % n);
211
+ const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
212
+
213
+ dst[i] = (T)(op((float)x[j0]) * (float)g[j1]);
214
+ }
215
+
216
+ template <float (*op)(float), typename T>
217
+ static void unary_gated_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, cudaStream_t stream) {
218
+ const int64_t num_blocks = (k + CUDA_GLU_BLOCK_SIZE - 1) / CUDA_GLU_BLOCK_SIZE;
219
+ unary_gated_op_kernel<op><<<num_blocks, CUDA_GLU_BLOCK_SIZE, 0, stream>>>(x, g, dst, k, n, o0, o1);
220
+ }
221
+
222
+ template <float (*op)(float)>
223
+ void ggml_cuda_op_unary_gated(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
224
+ const ggml_tensor * src0 = dst->src[0];
225
+ const ggml_tensor * src1 = dst->src[1];
226
+ void * src0_d = src0->data;
227
+ void * src1_d = src1 ? src1->data : src0->data;
228
+ const int64_t src0_o = src0->nb[1];
229
+ const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
230
+ void * dst_d = dst->data;
231
+ const int64_t nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
232
+ cudaStream_t stream = ctx.stream();
233
+
234
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
235
+ GGML_ASSERT(src0->nb[0] == ggml_element_size(src0));
236
+ GGML_ASSERT(ggml_is_contiguous(dst));
237
+
238
+ GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
239
+ GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
240
+ GGML_ASSERT(src0->type == dst->type);
241
+ GGML_ASSERT(dst->ne[0] == nc);
242
+ GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0));
243
+
244
+ if (src1) {
245
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
246
+ GGML_ASSERT(src1->nb[0] == ggml_element_size(src1));
247
+ GGML_ASSERT(src1->ne[0] == nc);
248
+ GGML_ASSERT(src0->type == src1->type);
249
+ }
250
+
251
+ const int32_t swapped = ((const int32_t *) dst->op_params)[1];
252
+
253
+ if (src0->type == GGML_TYPE_F16) {
254
+ half * src0_p = (half *) src0_d;
255
+ half * src1_p = (half *) src1_d;
256
+
257
+ if (!src1) {
258
+ src0_p += swapped ? nc : 0;
259
+ src1_p += swapped ? 0 : nc;
260
+ }
261
+
262
+ unary_gated_cuda<op>(src0_p, src1_p, (half *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(half), src1_o / sizeof(half), stream);
263
+ } else {
264
+ float * src0_p = (float *) src0_d;
265
+ float * src1_p = (float *) src1_d;
266
+
267
+ if (!src1) {
268
+ src0_p += swapped ? nc : 0;
269
+ src1_p += swapped ? 0 : nc;
270
+ }
271
+
272
+ unary_gated_cuda<op>(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), stream);
273
+ }
274
+ }
275
+
276
+ void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
277
+ ggml_cuda_op_unary_gated<op_relu>(ctx, dst);
278
+ }
279
+
280
+ void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
281
+ ggml_cuda_op_unary_gated<op_gelu>(ctx, dst);
282
+ }
283
+
284
+ void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
285
+ ggml_cuda_op_unary_gated<op_silu>(ctx, dst);
286
+ }
287
+
288
  /* silu_back */
289
 
290
  static __device__ __forceinline__ float op_silu_back(float grad, float x) {
ggml/src/ggml-cuda/unary.cuh CHANGED
@@ -15,6 +15,7 @@
15
  #define CUDA_SQRT_BLOCK_SIZE 256
16
  #define CUDA_SIN_BLOCK_SIZE 256
17
  #define CUDA_COS_BLOCK_SIZE 256
 
18
 
19
  void ggml_cuda_op_abs(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
20
 
@@ -57,3 +58,9 @@ void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
57
  void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
58
 
59
  void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 
 
 
 
 
 
15
  #define CUDA_SQRT_BLOCK_SIZE 256
16
  #define CUDA_SIN_BLOCK_SIZE 256
17
  #define CUDA_COS_BLOCK_SIZE 256
18
+ #define CUDA_GLU_BLOCK_SIZE 256
19
 
20
  void ggml_cuda_op_abs(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
21
 
 
58
  void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
59
 
60
  void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
61
+
62
+ void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
63
+
64
+ void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
65
+
66
+ void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
ggml/src/ggml-metal/ggml-metal-impl.h CHANGED
@@ -422,6 +422,17 @@ typedef struct {
422
  int32_t KHW; // KH * KW, pre-computed on CPU to save GPU resources
423
  } ggml_metal_kargs_im2col;
424
 
 
 
 
 
 
 
 
 
 
 
 
425
  typedef struct {
426
  int64_t ne00;
427
  int64_t ne01;
 
422
  int32_t KHW; // KH * KW, pre-computed on CPU to save GPU resources
423
  } ggml_metal_kargs_im2col;
424
 
425
+ typedef struct{
426
+ int32_t ne00;
427
+ uint64_t nb01;
428
+ int32_t ne10;
429
+ uint64_t nb11;
430
+ int32_t ne0;
431
+ uint64_t nb1;
432
+ int32_t i00;
433
+ int32_t i10;
434
+ } ggml_metal_kargs_glu;
435
+
436
  typedef struct {
437
  int64_t ne00;
438
  int64_t ne01;
ggml/src/ggml-metal/ggml-metal.m CHANGED
@@ -526,6 +526,9 @@ enum ggml_metal_kernel_type {
526
  GGML_METAL_KERNEL_TYPE_SIN,
527
  GGML_METAL_KERNEL_TYPE_COS,
528
  GGML_METAL_KERNEL_TYPE_NEG,
 
 
 
529
  GGML_METAL_KERNEL_TYPE_SUM_ROWS,
530
  GGML_METAL_KERNEL_TYPE_MEAN,
531
  GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
@@ -1502,6 +1505,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
1502
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
1503
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
1504
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
 
 
 
1505
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
1506
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
1507
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
@@ -1680,6 +1686,15 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1680
  default:
1681
  return false;
1682
  }
 
 
 
 
 
 
 
 
 
1683
  case GGML_OP_NONE:
1684
  case GGML_OP_RESHAPE:
1685
  case GGML_OP_VIEW:
@@ -2419,6 +2434,62 @@ static bool ggml_metal_encode_node(
2419
  GGML_ABORT("fatal error");
2420
  }
2421
  } break;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2422
  case GGML_OP_SQR:
2423
  {
2424
  GGML_ASSERT(ggml_is_contiguous(src0));
 
526
  GGML_METAL_KERNEL_TYPE_SIN,
527
  GGML_METAL_KERNEL_TYPE_COS,
528
  GGML_METAL_KERNEL_TYPE_NEG,
529
+ GGML_METAL_KERNEL_TYPE_REGLU,
530
+ GGML_METAL_KERNEL_TYPE_GEGLU,
531
+ GGML_METAL_KERNEL_TYPE_SWIGLU,
532
  GGML_METAL_KERNEL_TYPE_SUM_ROWS,
533
  GGML_METAL_KERNEL_TYPE_MEAN,
534
  GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
 
1505
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
1506
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
1507
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
1508
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REGLU, reglu, true);
1509
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU, geglu, true);
1510
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true);
1511
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
1512
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
1513
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
 
1686
  default:
1687
  return false;
1688
  }
1689
+ case GGML_OP_GLU:
1690
+ switch (ggml_get_glu_op(op)) {
1691
+ case GGML_GLU_OP_REGLU:
1692
+ case GGML_GLU_OP_GEGLU:
1693
+ case GGML_GLU_OP_SWIGLU:
1694
+ return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
1695
+ default:
1696
+ return false;
1697
+ }
1698
  case GGML_OP_NONE:
1699
  case GGML_OP_RESHAPE:
1700
  case GGML_OP_VIEW:
 
2434
  GGML_ABORT("fatal error");
2435
  }
2436
  } break;
2437
+ case GGML_OP_GLU:
2438
+ {
2439
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
2440
+
2441
+ if (src1) {
2442
+ GGML_ASSERT(ggml_are_same_shape(src0, src1));
2443
+ }
2444
+
2445
+ id<MTLComputePipelineState> pipeline = nil;
2446
+
2447
+ switch (ggml_get_glu_op(node)) {
2448
+ case GGML_GLU_OP_REGLU:
2449
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REGLU].pipeline;
2450
+ break;
2451
+ case GGML_GLU_OP_GEGLU:
2452
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU].pipeline;
2453
+ break;
2454
+ case GGML_GLU_OP_SWIGLU:
2455
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline;
2456
+ break;
2457
+ default:
2458
+ GGML_ABORT("fatal error");
2459
+ }
2460
+
2461
+ const int32_t swp = ((const int32_t *) dst->op_params)[1];
2462
+
2463
+ const int32_t i00 = swp ? ne0 : 0;
2464
+ const int32_t i10 = swp ? 0 : ne0;
2465
+
2466
+ ggml_metal_kargs_glu args = {
2467
+ /*.ne00 =*/ ne00,
2468
+ /*.nb01 =*/ nb01,
2469
+ /*.ne10 =*/ src1 ? ne10 : ne00,
2470
+ /*.nb11 =*/ src1 ? nb11 : nb01,
2471
+ /*.ne0 =*/ ne0,
2472
+ /*.nb1 =*/ nb1,
2473
+ /*.i00 =*/ src1 ? 0 : i00,
2474
+ /*.i10 =*/ src1 ? 0 : i10,
2475
+ };
2476
+
2477
+ [encoder setComputePipelineState:pipeline];
2478
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2479
+ if (src1) {
2480
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2481
+ } else {
2482
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
2483
+ }
2484
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2485
+ [encoder setBytes:&args length:sizeof(args) atIndex:3];
2486
+
2487
+ const int64_t nrows = ggml_nrows(src0);
2488
+
2489
+ const int32_t nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00/2);
2490
+
2491
+ [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2492
+ } break;
2493
  case GGML_OP_SQR:
2494
  {
2495
  GGML_ASSERT(ggml_is_contiguous(src0));
ggml/src/ggml-metal/ggml-metal.metal CHANGED
@@ -1191,6 +1191,70 @@ kernel void kernel_neg(
1191
  dst[tpig] = -src0[tpig];
1192
  }
1193
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1194
  template <bool norm>
1195
  kernel void kernel_sum_rows(
1196
  constant ggml_metal_kargs_sum_rows & args,
 
1191
  dst[tpig] = -src0[tpig];
1192
  }
1193
 
1194
+ kernel void kernel_reglu(
1195
+ device const char * src0,
1196
+ device const char * src1,
1197
+ device char * dst,
1198
+ constant ggml_metal_kargs_glu & args,
1199
+ uint tgpig[[threadgroup_position_in_grid]],
1200
+ uint tpitg[[thread_position_in_threadgroup]],
1201
+ uint ntg[[threads_per_threadgroup]]) {
1202
+ device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1203
+ device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1204
+ device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
1205
+
1206
+ for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
1207
+ const float x0 = src0_row[i0];
1208
+ const float x1 = src1_row[i0];
1209
+
1210
+ dst_row[i0] = x0*x1*(x0 > 0.0f);
1211
+ }
1212
+ }
1213
+
1214
+ kernel void kernel_geglu(
1215
+ device const char * src0,
1216
+ device const char * src1,
1217
+ device char * dst,
1218
+ constant ggml_metal_kargs_glu & args,
1219
+ uint tgpig[[threadgroup_position_in_grid]],
1220
+ uint tpitg[[thread_position_in_threadgroup]],
1221
+ uint ntg[[threads_per_threadgroup]]) {
1222
+ device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1223
+ device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1224
+ device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
1225
+
1226
+ for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
1227
+ const float x0 = src0_row[i0];
1228
+ const float x1 = src1_row[i0];
1229
+
1230
+ const float gelu = 0.5f*x0*(1.0f + precise::tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0)));
1231
+
1232
+ dst_row[i0] = gelu*x1;
1233
+ }
1234
+ }
1235
+
1236
+ kernel void kernel_swiglu(
1237
+ device const char * src0,
1238
+ device const char * src1,
1239
+ device char * dst,
1240
+ constant ggml_metal_kargs_glu & args,
1241
+ uint tgpig[[threadgroup_position_in_grid]],
1242
+ uint tpitg[[thread_position_in_threadgroup]],
1243
+ uint ntg[[threads_per_threadgroup]]) {
1244
+ device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1245
+ device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1246
+ device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
1247
+
1248
+ for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
1249
+ const float x0 = src0_row[i0];
1250
+ const float x1 = src1_row[i0];
1251
+
1252
+ const float silu = x0 / (1.0f + exp(-x0));
1253
+
1254
+ dst_row[i0] = silu*x1;
1255
+ }
1256
+ }
1257
+
1258
  template <bool norm>
1259
  kernel void kernel_sum_rows(
1260
  constant ggml_metal_kargs_sum_rows & args,
ggml/src/ggml-sycl/element_wise.cpp CHANGED
@@ -1,12 +1,19 @@
1
  #include "common.hpp"
 
2
  #include "ggml.h"
3
  #include "element_wise.hpp"
4
 
 
 
 
 
 
 
 
5
  static void acc_f32(const float * x, const float * y, float * dst, const int ne,
6
  const int ne10, const int ne11, const int ne12,
7
- const int nb1, const int nb2, int offset, const sycl::nd_item<3> &item_ct1) {
8
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
9
- item_ct1.get_local_id(2);
10
  if (i >= ne) {
11
  return;
12
  }
@@ -21,535 +28,375 @@ static void acc_f32(const float * x, const float * y, float * dst, const int ne,
21
  }
22
  }
23
 
 
24
  template<typename T>
25
- static void sgn(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) {
26
- for(auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) {
27
- dst[i] = x[i] > static_cast<T>(0.f) ? static_cast<T>(1.f) : ((x[i] < static_cast<T>(0.f) ? static_cast<T>(-1.f) : static_cast<T>(0.f)));
28
- }
29
  }
30
 
31
  template<typename T>
32
- static void abs_op(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) {
33
- for(auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) {
34
- dst[i] = sycl::fabs(x[i]);
35
- }
36
  }
37
 
38
  template<typename T>
39
- static void elu_op(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) {
40
- for(auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) {
41
- dst[i] = (x[i] > static_cast<T>(0.f)) ? x[i] : sycl::expm1(x[i]);
42
- }
43
  }
44
 
45
  template<typename T>
46
- static void gelu(const T * x, T * dst, const int k,
47
- const sycl::nd_item<3> &item_ct1) {
48
  const T GELU_COEF_A = static_cast<T>(0.044715f);
49
  const T SQRT_2_OVER_PI = static_cast<T>(0.79788456080286535587989211986876f);
50
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
51
- item_ct1.get_local_id(2);
52
-
53
- if (i >= k) {
54
- return;
55
- }
56
-
57
- float xi = x[i];
58
- dst[i] = static_cast<T>(0.5f) * xi *
59
- (static_cast<T>(1.0f) +
60
- sycl::tanh(SQRT_2_OVER_PI * xi * (static_cast<T>(1.0f) + GELU_COEF_A * xi * xi)));
61
  }
62
 
63
  template<typename T>
64
- static void silu(const T * x, T * dst, const int k,
65
- const sycl::nd_item<3> &item_ct1) {
66
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
67
- item_ct1.get_local_id(2);
68
-
69
- if (i >= k) {
70
- return;
71
- }
72
- dst[i] = x[i] / (static_cast<T>(1.0f) + sycl::native::exp(-x[i]));
73
  }
74
 
75
  template<typename T>
76
- static void gelu_quick(const T *x, T *dst, int k,
77
- const sycl::nd_item<3> &item_ct1) {
78
- const float GELU_QUICK_COEF = -1.702f;
79
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
80
- item_ct1.get_local_id(2);
81
- if (i >= k) {
82
- return;
83
- }
84
- dst[i] = x[i] * (static_cast<T>(1.0f) / (static_cast<T>(1.0f) + sycl::native::exp(GELU_QUICK_COEF * x[i])));
85
  }
86
 
87
  template<typename T>
88
- static void gelu_erf(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) {
89
  const T SQRT_2_INV = static_cast<T>(0.70710678118654752440084436210484f);
90
- for(auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) {
91
- auto x_i = x[i];
92
- dst[i] = static_cast<T>(0.5f) * x_i * (static_cast<T>(1.0f) + sycl::erf(x_i * SQRT_2_INV));
93
- }
94
  }
95
 
96
  template<typename T>
97
- static void tanh(const T *x, T *dst, int k,
98
- const sycl::nd_item<3> &item_ct1) {
99
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
100
- item_ct1.get_local_id(2);
101
- if (i >= k) {
102
- return;
103
- }
104
- dst[i] = sycl::tanh((x[i]));
105
  }
106
 
107
  template<typename T>
108
- static void relu(const T * x, T * dst, const int k,
109
- const sycl::nd_item<3> &item_ct1) {
110
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
111
- item_ct1.get_local_id(2);
112
-
113
- if (i >= k) {
114
- return;
115
- }
116
- dst[i] = sycl::fmax((x[i]), static_cast<T>(0));
117
  }
118
 
119
  template<typename T>
120
- static void sigmoid(const T * x, T * dst, const int k,
121
- const sycl::nd_item<3> &item_ct1) {
122
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
123
- item_ct1.get_local_id(2);
124
-
125
- if (i >= k) {
126
- return;
127
- }
128
- dst[i] = 1.0f / (static_cast<T>(1.0f) + sycl::native::exp(-x[i]));
129
  }
130
 
131
  template<typename T>
132
- static void sqrt(const T * x, T * dst, const int k,
133
- const sycl::nd_item<3> &item_ct1) {
134
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
135
- item_ct1.get_local_id(2);
136
-
137
- if (i >= k) {
138
- return;
139
- }
140
- dst[i] = sycl::sqrt(x[i]);
141
  }
142
 
143
  template<typename T>
144
- static void sin(const T * x, T * dst, const int k,
145
- const sycl::nd_item<3> &item_ct1) {
146
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
147
- item_ct1.get_local_id(2);
148
-
149
- if (i >= k) {
150
- return;
151
- }
152
- dst[i] = sycl::sin(x[i]);
153
  }
154
 
155
  template<typename T>
156
- static void cos(const T * x, T * dst, const int k,
157
- const sycl::nd_item<3> &item_ct1) {
158
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
159
- item_ct1.get_local_id(2);
160
-
161
- if (i >= k) {
162
- return;
163
- }
164
- dst[i] = sycl::cos(x[i]);
165
  }
166
 
167
  template<typename T>
168
- static void hardsigmoid(const T * x, T * dst, const int k,
169
- const sycl::nd_item<3> &item_ct1) {
170
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
171
- item_ct1.get_local_id(2);
172
-
173
- if (i >= k) {
174
- return;
175
- }
176
- dst[i] = sycl::fmin(static_cast<T>(1.0f), sycl::fmax(static_cast<T>(0.0f), (x[i] + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
177
  }
178
 
179
  template<typename T>
180
- static void hardswish(const T * x, T * dst, const int k,
181
- const sycl::nd_item<3> &item_ct1) {
182
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
183
- item_ct1.get_local_id(2);
184
-
185
- if (i >= k) {
186
- return;
187
- }
188
- dst[i] = x[i] * sycl::fmin(static_cast<T>(1.0f), sycl::fmax(static_cast<T>(0.0f), (x[i] + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
189
  }
190
 
191
  template<typename T>
192
- static void exp(const T * x, T * dst, const int k,
193
- const sycl::nd_item<3> &item_ct1) {
194
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
195
- item_ct1.get_local_id(2);
196
 
197
- if (i >= k) {
198
- return;
 
 
199
  }
200
- dst[i] = sycl::exp(x[i]);
201
  }
202
 
203
  template<typename T>
204
- static void log(const T * x, T * dst, const int k,
205
- const sycl::nd_item<3> &item_ct1) {
206
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
207
- item_ct1.get_local_id(2);
208
 
209
- if (i >= k) {
210
- return;
211
- }
212
- T xi = x[i];
213
- if (xi <= 0) {
214
- dst[i] = neg_infinity<T>();
215
- } else {
216
- dst[i] = sycl::log(xi);
217
- }
218
  }
219
 
220
  template<typename T>
221
- static void neg(const T * x, T * dst, const int k,
222
- const sycl::nd_item<3> &item_ct1) {
223
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
224
- item_ct1.get_local_id(2);
 
225
 
226
- if (i >= k) {
227
- return;
228
- }
229
- dst[i] = -x[i];
230
  }
231
 
232
  template<typename T>
233
- static void step(const T * x, T * dst, const int k,
234
- const sycl::nd_item<3> &item_ct1) {
235
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
236
- item_ct1.get_local_id(2);
237
 
238
- if (i >= k) {
239
- return;
 
 
240
  }
241
- dst[i] = x[i] > static_cast<T>(0.0f);
242
  }
243
 
244
  template<typename T>
245
- static void leaky_relu(const T *x, T *dst, const int k, const float negative_slope,
246
- const sycl::nd_item<3> &item_ct1) {
247
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
248
- item_ct1.get_local_id(2);
249
- if (i >= k) {
250
- return;
251
  }
252
- dst[i] = sycl::fmax((x[i]), static_cast<T>(0)) +
253
- sycl::fmin((x[i]), static_cast<T>(0.0f)) * negative_slope;
254
  }
255
 
256
  template<typename T>
257
- static void sqr(const T * x, T * dst, const int k,
258
- const sycl::nd_item<3> &item_ct1) {
259
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
260
- item_ct1.get_local_id(2);
261
-
262
- if (i >= k) {
263
- return;
264
  }
265
- dst[i] = x[i] * x[i];
266
  }
267
 
268
- template<typename T>
269
- static void upscale(const T *x, T *dst, const int nb00, const int nb01,
270
- const int nb02, const int nb03, const int ne10, const int ne11,
271
- const int ne12, const int ne13, const float sf0, const float sf1,
272
- const float sf2, const float sf3, const sycl::nd_item<1> &item_ct1) {
273
- int index = item_ct1.get_local_id(0) +
274
- item_ct1.get_group(0) * item_ct1.get_local_range(0);
275
- if (index >= ne10 * ne11 * ne12 * ne13) {
276
- return;
277
  }
278
- // operation
279
- int i10 = index % ne10;
280
- int i11 = (index / ne10) % ne11;
281
- int i12 = (index / (ne10 * ne11)) % ne12;
282
- int i13 = (index / (ne10 * ne11 * ne12)) % ne13;
283
-
284
- int i00 = i10 / sf0;
285
- int i01 = i11 / sf1;
286
- int i02 = i12 / sf2;
287
- int i03 = i13 / sf3;
288
-
289
- dst[index] = *(const T *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
290
  }
291
 
292
- template <typename T>
293
- static void pad(const T *x, T *dst, const int ne0, const int ne00, const int ne01, const int ne02,
294
- const sycl::nd_item<3> &item_ct1) {
295
- int nidx = item_ct1.get_local_id(2) +
296
- item_ct1.get_group(2) * item_ct1.get_local_range(2);
297
- if (nidx >= ne0) {
298
- return;
299
  }
 
300
 
301
- // operation
302
- int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
303
- item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
304
- if (nidx < ne00 && item_ct1.get_group(1) < (size_t) ne01 && item_ct1.get_group(0) < (size_t) ne02) {
305
- int offset_src = nidx + item_ct1.get_group(1) * ne00 +
306
- item_ct1.get_group(0) * ne00 * ne01;
307
- dst[offset_dst] = x[offset_src];
308
- } else {
309
- dst[offset_dst] = static_cast<T>(0.0f);
310
  }
311
  }
312
 
313
-
314
  template<typename T>
315
- static void clamp(const T * x, T * dst, const float min, const float max, const int k,
316
- const sycl::nd_item<3> &item_ct1) {
317
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
318
- item_ct1.get_local_id(2);
319
-
320
- if (i >= k) {
321
- return;
322
  }
323
-
324
- dst[i] = x[i] < static_cast<T>(min) ? static_cast<T>(min) : (x[i] > static_cast<T>(max) ? static_cast<T>(max) : x[i]);
325
  }
326
 
327
- static void acc_f32_sycl(const float *x, const float *y, float *dst,
328
- const int n_elements, const int ne10, const int ne11,
329
- const int ne12, const int nb1, const int nb2,
330
- const int offset, queue_ptr stream) {
331
- int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE;
332
- sycl_parallel_for(stream,
333
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE),
334
- sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)),
335
- [=](sycl::nd_item<3> item_ct1) {
336
- acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset, item_ct1);
337
- });
338
  }
339
 
340
  template<typename T>
341
- static void gelu_sycl(const T *x, T *dst, const int k,
342
- queue_ptr stream) {
343
- const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
344
- sycl_parallel_for(stream,
345
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
346
- sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
347
- [=](sycl::nd_item<3> item_ct1) { gelu(x, dst, k, item_ct1); });
348
  }
349
 
350
  template<typename T>
351
- static void silu_sycl(const T *x, T *dst, const int k,
352
- queue_ptr stream) {
353
- const int num_blocks = (k + SYCL_SILU_BLOCK_SIZE - 1) / SYCL_SILU_BLOCK_SIZE;
354
- sycl_parallel_for(stream,
355
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE),
356
- sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE)),
357
- [=](sycl::nd_item<3> item_ct1) { silu(x, dst, k, item_ct1); });
358
  }
359
 
360
  template<typename T>
361
- static void sgn_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
362
- // hard code for now
363
- const int num_blocks = ceil_div(k, 256);
364
- sycl_parallel_for(
365
- stream, sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range(1, 1, 256)), sycl::range(1, 1, 256)),
366
- [=](sycl::nd_item<3> item_ct1) { sgn(x, dst, k, item_ct1); });
367
  }
368
 
369
  template<typename T>
370
- static void abs_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
371
- // hard code for now
372
- const int num_blocks = ceil_div(k, 256);
373
- sycl_parallel_for(
374
- stream,
375
- sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)),
376
- [=](sycl::nd_item<3> item_ct1) { abs_op(x, dst, k, item_ct1); });
377
  }
378
 
379
-
380
  template<typename T>
381
- static void elu_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
382
- // hard code for now
383
- const int num_blocks = ceil_div(k, 256);
384
- sycl_parallel_for(
385
- stream,
386
- sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)),
387
- [=](sycl::nd_item<3> item_ct1) { elu_op(x, dst, k, item_ct1); });
388
  }
389
 
390
  template<typename T>
391
- static void gelu_quick_sycl(const T *x, T *dst, const int k,
392
- queue_ptr stream) {
393
- const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
394
- sycl_parallel_for(stream,
395
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
396
- sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
397
- [=](sycl::nd_item<3> item_ct1) { gelu_quick(x, dst, k, item_ct1); });
398
  }
399
 
400
-
401
  template<typename T>
402
- static void gelu_erf_sycl(const T *x, T *dst, const int k,
403
- queue_ptr stream) {
404
- const int num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
405
- sycl_parallel_for(stream,
406
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
407
- sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
408
- [=](sycl::nd_item<3> item_ct1) { gelu_erf(x, dst, k, item_ct1); });
409
  }
410
 
411
  template<typename T>
412
- static void tanh_sycl(const T *x, T *dst, const int k,
413
- queue_ptr stream) {
414
- const int num_blocks = (k + SYCL_TANH_BLOCK_SIZE - 1) / SYCL_TANH_BLOCK_SIZE;
415
- sycl_parallel_for(stream,
416
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE),
417
- sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE)),
418
- [=](sycl::nd_item<3> item_ct1) { tanh(x, dst, k, item_ct1); });
419
  }
420
 
421
  template<typename T>
422
- static void relu_sycl(const T *x, T *dst, const int k,
423
- queue_ptr stream) {
424
- const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
425
- sycl_parallel_for(stream,
426
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE),
427
- sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)),
428
- [=](sycl::nd_item<3> item_ct1) { relu(x, dst, k, item_ct1); });
429
  }
430
 
431
  template<typename T>
432
- static void hardsigmoid_sycl(const T *x, T *dst, const int k,
433
- queue_ptr stream) {
434
- const int num_blocks = (k + SYCL_HARDSIGMOID_BLOCK_SIZE - 1) / SYCL_HARDSIGMOID_BLOCK_SIZE;
435
- sycl_parallel_for(
436
- stream,
437
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE),
438
- sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE)),
439
- [=](sycl::nd_item<3> item_ct1) { hardsigmoid(x, dst, k, item_ct1); });
440
  }
441
 
442
  template<typename T>
443
- static void hardswish_sycl(const T *x, T *dst, const int k,
444
- queue_ptr stream) {
445
- const int num_blocks = (k + SYCL_HARDSWISH_BLOCK_SIZE - 1) / SYCL_HARDSWISH_BLOCK_SIZE;
446
- sycl_parallel_for(
447
- stream,
448
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE),
449
- sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE)),
450
- [=](sycl::nd_item<3> item_ct1) { hardswish(x, dst, k, item_ct1); });
451
  }
452
 
453
  template<typename T>
454
- static void exp_sycl(const T *x, T *dst, const int k,
455
- queue_ptr stream) {
456
- const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE;
457
- sycl_parallel_for(stream,
458
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE),
459
- sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)),
460
- [=](sycl::nd_item<3> item_ct1) { exp(x, dst, k, item_ct1); });
461
  }
462
 
463
  template<typename T>
464
- static void log_sycl(const T *x, T *dst, const int k,
465
- queue_ptr stream) {
466
- const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE;
467
- sycl_parallel_for(stream,
468
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE),
469
- sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)),
470
- [=](sycl::nd_item<3> item_ct1) { log(x, dst, k, item_ct1); });
471
  }
472
 
473
  template<typename T>
474
- static void neg_sycl(const T *x, T *dst, const int k,
475
- queue_ptr stream) {
476
- const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE;
477
- sycl_parallel_for(stream,
478
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE),
479
- sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE)),
480
- [=](sycl::nd_item<3> item_ct1) { neg(x, dst, k, item_ct1); });
481
  }
482
 
483
- template<typename T>
484
- static void step_sycl(const T *x, T *dst, const int k,
485
- queue_ptr stream) {
486
- const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE;
487
- sycl_parallel_for(stream,
488
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE),
489
- sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE)),
490
- [=](sycl::nd_item<3> item_ct1) { step(x, dst, k, item_ct1); });
 
 
491
  }
492
 
493
- template<typename T>
494
- static void sigmoid_sycl(const T *x, T *dst, const int k,
495
- queue_ptr stream) {
496
- const int num_blocks = (k + SYCL_SIGMOID_BLOCK_SIZE - 1) / SYCL_SIGMOID_BLOCK_SIZE;
497
- sycl_parallel_for(
498
- stream,
499
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_SIGMOID_BLOCK_SIZE),
500
- sycl::range<3>(1, 1, SYCL_SIGMOID_BLOCK_SIZE)),
501
- [=](sycl::nd_item<3> item_ct1) { sigmoid(x, dst, k, item_ct1); });
 
 
502
  }
503
 
504
  template<typename T>
505
- static void sqrt_sycl(const T *x, T *dst, const int k,
506
- queue_ptr stream) {
507
- const int num_blocks = (k + SYCL_SQRT_BLOCK_SIZE - 1) / SYCL_SQRT_BLOCK_SIZE;
508
- sycl_parallel_for(stream,
509
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_SQRT_BLOCK_SIZE),
510
- sycl::range<3>(1, 1, SYCL_SQRT_BLOCK_SIZE)),
511
- [=](sycl::nd_item<3> item_ct1) { sqrt(x, dst, k, item_ct1); });
512
  }
513
 
514
  template<typename T>
515
- static void sin_sycl(const T *x, T *dst, const int k,
516
- queue_ptr stream) {
517
- const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE;
518
- sycl_parallel_for(stream,
519
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE),
520
- sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE)),
521
- [=](sycl::nd_item<3> item_ct1) { sin(x, dst, k, item_ct1); });
522
  }
523
 
524
  template<typename T>
525
- static void cos_sycl(const T *x, T *dst, const int k,
526
- queue_ptr stream) {
527
- const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE;
528
- sycl_parallel_for(stream,
529
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE),
530
- sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE)),
531
- [=](sycl::nd_item<3> item_ct1) { cos(x, dst, k, item_ct1); });
532
  }
533
 
534
  template<typename T>
535
- static void leaky_relu_sycl(const T *x, T *dst, const int k,
536
- const float negative_slope,
537
- queue_ptr stream) {
538
- const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
539
- sycl_parallel_for(stream,
540
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE),
541
- sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)),
542
- [=](sycl::nd_item<3> item_ct1) { leaky_relu(x, dst, k, negative_slope, item_ct1); });
543
  }
544
 
545
- template<typename T>
546
- static void sqr_sycl(const T *x, T *dst, const int k,
547
- queue_ptr stream) {
548
- const int num_blocks = (k + SYCL_SQR_BLOCK_SIZE - 1) / SYCL_SQR_BLOCK_SIZE;
 
 
549
  sycl_parallel_for(stream,
550
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE),
551
- sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE)),
552
- [=](sycl::nd_item<3> item_ct1) { sqr(x, dst, k, item_ct1); });
 
 
553
  }
554
 
555
  template<typename T>
@@ -558,7 +405,7 @@ static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01,
558
  const int ne12, const int ne13, const float sf0, const float sf1,
559
  const float sf2, const float sf3, queue_ptr stream) {
560
  int dst_size = ne10 * ne11 * ne12 * ne13;
561
- int num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE;
562
  sycl::range<1> gridDim(num_blocks * SYCL_UPSCALE_BLOCK_SIZE);
563
  sycl_parallel_for<1>(
564
  stream, sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
@@ -570,7 +417,7 @@ template<typename T>
570
  static void pad_sycl(const T *x, T *dst, const int ne00,
571
  const int ne01, const int ne02, const int ne0,
572
  const int ne1, const int ne2, queue_ptr stream) {
573
- int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE;
574
  sycl::range<3> gridDim(ne2, ne1, num_blocks);
575
  sycl_parallel_for(stream,
576
  sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),
@@ -578,22 +425,11 @@ static void pad_sycl(const T *x, T *dst, const int ne00,
578
  [=](sycl::nd_item<3> item_ct1) { pad(x, dst, ne0, ne00, ne01, ne02, item_ct1); });
579
  }
580
 
581
- template<typename T>
582
- static void clamp_sycl(const T *x, T *dst, const float min,
583
- const float max, const int k,
584
- queue_ptr stream) {
585
- const int num_blocks = (k + SYCL_CLAMP_BLOCK_SIZE - 1) / SYCL_CLAMP_BLOCK_SIZE;
586
- sycl_parallel_for(stream,
587
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE),
588
- sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE)),
589
- [=](sycl::nd_item<3> item_ct1) { clamp(x, dst, min, max, k, item_ct1); });
590
- }
591
-
592
- inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
593
  #if defined (GGML_SYCL_F16)
594
  GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
595
  GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
596
-
597
  #else
598
  GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
599
  GGML_ASSERT(dst->type == GGML_TYPE_F32);
@@ -606,14 +442,14 @@ inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
606
  case GGML_TYPE_F16:
607
  {
608
  auto data_pts = cast_data<sycl::half>(dst);
609
- sgn_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
610
  break;
611
  }
612
  #endif
613
  case GGML_TYPE_F32:
614
  {
615
  auto data_pts = cast_data<float>(dst);
616
- sgn_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
617
  break;
618
  }
619
  default:
@@ -621,11 +457,11 @@ inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
621
  }
622
  }
623
 
624
- inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 
625
  #if defined (GGML_SYCL_F16)
626
  GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
627
  GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
628
-
629
  #else
630
  GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
631
  GGML_ASSERT(dst->type == GGML_TYPE_F32);
@@ -633,52 +469,66 @@ inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
633
  GGML_ASSERT(dst->src[0]->type == dst->type);
634
  dpct::queue_ptr main_stream = ctx.stream();
635
  SYCL_CHECK(ggml_sycl_set_device(ctx.device));
636
- switch (dst->type) {
637
- #if defined (GGML_SYCL_F16)
638
- case GGML_TYPE_F16:
639
- {
640
- auto data_pts = cast_data<sycl::half>(dst);
641
- abs_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
642
- break;
643
- }
644
- #endif
645
- case GGML_TYPE_F32:
646
- {
647
- auto data_pts = cast_data<float>(dst);
648
- abs_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
649
- break;
650
- }
651
- default:
652
- GGML_ABORT("GGML tensor type not supported!\n");
653
  }
654
- }
655
-
656
-
657
- inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
658
- #if defined (GGML_SYCL_F16)
659
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
660
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
661
-
662
- #else
663
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
664
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
665
- #endif
666
- GGML_ASSERT(dst->src[0]->type == dst->type);
667
- dpct::queue_ptr main_stream = ctx.stream();
668
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
669
  switch (dst->type) {
670
  #if defined (GGML_SYCL_F16)
671
  case GGML_TYPE_F16:
672
  {
673
- auto data_pts = cast_data<sycl::half>(dst);
674
- elu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
 
 
675
  break;
676
  }
677
  #endif
678
  case GGML_TYPE_F32:
679
  {
680
- auto data_pts = cast_data<float>(dst);
681
- elu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
 
 
682
  break;
683
  }
684
  default:
@@ -686,7 +536,8 @@ inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
686
  }
687
  }
688
 
689
- inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 
690
  #if defined (GGML_SYCL_F16)
691
  GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
692
  GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
@@ -695,52 +546,31 @@ inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst
695
  GGML_ASSERT(dst->type == GGML_TYPE_F32);
696
  #endif
697
  GGML_ASSERT(dst->src[0]->type == dst->type);
698
- dpct::queue_ptr main_stream = ctx.stream();
699
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
700
- switch (dst->type) {
701
- #if defined (GGML_SYCL_F16)
702
- case GGML_TYPE_F16:
703
- {
704
- auto data_pts = cast_data<sycl::half>(dst);
705
- silu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
706
- break;
707
- }
708
- #endif
709
- case GGML_TYPE_F32:
710
- {
711
- auto data_pts = cast_data<float>(dst);
712
- silu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
713
- break;
714
- }
715
- default:
716
- GGML_ABORT("GGML tensor type not supported!\n");
717
- }
718
- }
719
 
720
- inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
721
- #if defined (GGML_SYCL_F16)
722
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
723
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
724
- #else
725
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
726
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
727
- #endif
728
- GGML_ASSERT(dst->src[0]->type == dst->type);
729
  dpct::queue_ptr main_stream = ctx.stream();
730
  SYCL_CHECK(ggml_sycl_set_device(ctx.device));
 
 
731
  switch (dst->type) {
732
  #if defined (GGML_SYCL_F16)
733
  case GGML_TYPE_F16:
734
  {
735
  auto data_pts = cast_data<sycl::half>(dst);
736
- gelu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
 
 
737
  break;
738
  }
739
  #endif
740
  case GGML_TYPE_F32:
741
  {
742
  auto data_pts = cast_data<float>(dst);
743
- gelu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
 
 
744
  break;
745
  }
746
  default:
@@ -748,7 +578,8 @@ inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst
748
  }
749
  }
750
 
751
- inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
 
752
  #if defined (GGML_SYCL_F16)
753
  GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
754
  GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
@@ -757,6 +588,7 @@ inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor
757
  GGML_ASSERT(dst->type == GGML_TYPE_F32);
758
  #endif
759
  GGML_ASSERT(dst->src[0]->type == dst->type);
 
760
  dpct::queue_ptr main_stream = ctx.stream();
761
  SYCL_CHECK(ggml_sycl_set_device(ctx.device));
762
  switch (dst->type) {
@@ -764,14 +596,16 @@ inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor
764
  case GGML_TYPE_F16:
765
  {
766
  auto data_pts = cast_data<sycl::half>(dst);
767
- gelu_quick_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
 
768
  break;
769
  }
770
  #endif
771
  case GGML_TYPE_F32:
772
  {
773
  auto data_pts = cast_data<float>(dst);
774
- gelu_quick_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
 
775
  break;
776
  }
777
  default:
@@ -779,593 +613,320 @@ inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor
779
  }
780
  }
781
 
782
- inline void ggml_sycl_op_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
783
- #if defined (GGML_SYCL_F16)
784
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
785
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
786
- #else
787
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
788
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
789
- #endif
790
- GGML_ASSERT(dst->src[0]->type == dst->type);
791
- dpct::queue_ptr main_stream = ctx.stream();
792
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
793
- switch (dst->type) {
794
- #if defined (GGML_SYCL_F16)
795
- case GGML_TYPE_F16:
796
- {
797
- auto data_pts = cast_data<sycl::half>(dst);
798
- gelu_erf_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
799
- break;
800
- }
801
- #endif
802
- case GGML_TYPE_F32:
803
- {
804
- auto data_pts = cast_data<float>(dst);
805
- gelu_erf_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
806
- break;
807
- }
808
- default:
809
- GGML_ABORT("GGML tensor type not supported!\n");
810
- }
811
- }
812
-
813
-
814
- inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
815
- #if defined (GGML_SYCL_F16)
816
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
817
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
818
- #else
819
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
820
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
821
- #endif
822
- GGML_ASSERT(dst->src[0]->type == dst->type);
823
- dpct::queue_ptr main_stream = ctx.stream();
824
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
825
- switch (dst->type) {
826
- #if defined (GGML_SYCL_F16)
827
- case GGML_TYPE_F16:
828
- {
829
- auto data_pts = cast_data<sycl::half>(dst);
830
- tanh_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
831
- break;
832
- }
833
- #endif
834
- case GGML_TYPE_F32:
835
- {
836
- auto data_pts = cast_data<float>(dst);
837
- tanh_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
838
- break;
839
- }
840
- default:
841
- GGML_ABORT("GGML tensor type not supported!\n");
842
- }
843
- }
844
-
845
- inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
846
- #if defined (GGML_SYCL_F16)
847
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
848
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
849
- #else
850
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
851
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
852
- #endif
853
- GGML_ASSERT(dst->src[0]->type == dst->type);
854
- dpct::queue_ptr main_stream = ctx.stream();
855
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
856
-
857
- switch (dst->type) {
858
- #if defined (GGML_SYCL_F16)
859
- case GGML_TYPE_F16:
860
- {
861
- auto data_pts = cast_data<sycl::half>(dst);
862
- relu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
863
- break;
864
- }
865
- #endif
866
- case GGML_TYPE_F32:
867
- {
868
- auto data_pts = cast_data<float>(dst);
869
- relu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
870
- break;
871
- }
872
- default:
873
- GGML_ABORT("GGML tensor type not supported!\n");
874
- }
875
- }
876
 
877
- inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
878
- #if defined (GGML_SYCL_F16)
879
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
880
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
881
- #else
882
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
883
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
884
- #endif
885
- GGML_ASSERT(dst->src[0]->type == dst->type);
886
 
887
- dpct::queue_ptr main_stream = ctx.stream();
888
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
889
 
890
- switch (dst->type) {
891
- #if defined (GGML_SYCL_F16)
892
- case GGML_TYPE_F16:
893
- {
894
- auto data_pts = cast_data<sycl::half>(dst);
895
- hardsigmoid_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
896
- break;
897
- }
898
- #endif
899
- case GGML_TYPE_F32:
900
- {
901
- auto data_pts = cast_data<float>(dst);
902
- hardsigmoid_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
903
- break;
904
- }
905
- default:
906
- GGML_ABORT("GGML tensor type not supported!\n");
907
- }
908
  }
909
 
910
- inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
911
- #if defined (GGML_SYCL_F16)
912
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
913
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
914
- #else
915
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
916
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
917
- #endif
918
- GGML_ASSERT(dst->src[0]->type == dst->type);
919
- dpct::queue_ptr main_stream = ctx.stream();
920
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
921
- switch (dst->type) {
922
- #if defined (GGML_SYCL_F16)
923
- case GGML_TYPE_F16:
924
- {
925
- auto data_pts = cast_data<sycl::half>(dst);
926
- hardswish_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
927
- break;
928
- }
929
- #endif
930
- case GGML_TYPE_F32:
931
- {
932
- auto data_pts = cast_data<float>(dst);
933
- hardswish_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
934
- break;
935
- }
936
- default:
937
- GGML_ABORT("GGML tensor type not supported!\n");
938
- }
939
  }
940
 
941
- inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
942
- #if defined (GGML_SYCL_F16)
943
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
944
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
945
- #else
946
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
947
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
948
- #endif
949
- GGML_ASSERT(dst->src[0]->type == dst->type);
950
- dpct::queue_ptr main_stream = ctx.stream();
951
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
952
- switch (dst->type) {
953
- #if defined (GGML_SYCL_F16)
954
- case GGML_TYPE_F16:
955
- {
956
- auto data_pts = cast_data<sycl::half>(dst);
957
- exp_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
958
- break;
959
- }
960
- #endif
961
- case GGML_TYPE_F32:
962
- {
963
- auto data_pts = cast_data<float>(dst);
964
- exp_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
965
- break;
966
- }
967
- default:
968
- GGML_ABORT("GGML tensor type not supported!\n");
969
- }
970
  }
971
 
972
- inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
973
- #if defined (GGML_SYCL_F16)
974
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
975
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
976
- #else
977
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
978
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
979
- #endif
980
- GGML_ASSERT(dst->src[0]->type == dst->type);
981
- dpct::queue_ptr main_stream = ctx.stream();
982
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
983
- switch (dst->type) {
984
- #if defined (GGML_SYCL_F16)
985
- case GGML_TYPE_F16:
986
- {
987
- auto data_pts = cast_data<sycl::half>(dst);
988
- log_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
989
- break;
990
- }
991
- #endif
992
- case GGML_TYPE_F32:
993
- {
994
- auto data_pts = cast_data<float>(dst);
995
- log_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
996
- break;
997
- }
998
- default:
999
- GGML_ABORT("GGML tensor type not supported!\n");
1000
- }
1001
  }
1002
 
1003
- inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1004
- #if defined (GGML_SYCL_F16)
1005
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
1006
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
1007
- #else
1008
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
1009
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
1010
- #endif
1011
- GGML_ASSERT(dst->src[0]->type == dst->type);
1012
- dpct::queue_ptr main_stream = ctx.stream();
1013
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
1014
- switch (dst->type) {
1015
- #if defined (GGML_SYCL_F16)
1016
- case GGML_TYPE_F16:
1017
- {
1018
- auto data_pts = cast_data<sycl::half>(dst);
1019
- sigmoid_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
1020
- break;
1021
- }
1022
- #endif
1023
- case GGML_TYPE_F32:
1024
- {
1025
- auto data_pts = cast_data<float>(dst);
1026
- sigmoid_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
1027
- break;
1028
- }
1029
- default:
1030
- GGML_ABORT("GGML tensor type not supported!\n");
1031
- }
1032
  }
1033
 
1034
- inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1035
- #if defined (GGML_SYCL_F16)
1036
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
1037
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
1038
- #else
1039
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
1040
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
1041
- #endif
1042
- GGML_ASSERT(dst->src[0]->type == dst->type);
1043
-
1044
- dpct::queue_ptr main_stream = ctx.stream();
1045
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
1046
- switch (dst->type) {
1047
- #if defined (GGML_SYCL_F16)
1048
- case GGML_TYPE_F16:
1049
- {
1050
- auto data_pts = cast_data<sycl::half>(dst);
1051
- sqrt_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
1052
- break;
1053
- }
1054
- #endif
1055
- case GGML_TYPE_F32:
1056
- {
1057
- auto data_pts = cast_data<float>(dst);
1058
- sqrt_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
1059
- break;
1060
- }
1061
- default:
1062
- GGML_ABORT("GGML tensor type not supported!\n");
1063
- }
1064
  }
1065
-
1066
- inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1067
- #if defined (GGML_SYCL_F16)
1068
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
1069
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
1070
- #else
1071
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
1072
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
1073
- #endif
1074
- GGML_ASSERT(dst->src[0]->type == dst->type);
1075
- dpct::queue_ptr main_stream = ctx.stream();
1076
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
1077
- switch (dst->type) {
1078
- #if defined (GGML_SYCL_F16)
1079
- case GGML_TYPE_F16:
1080
- {
1081
- auto data_pts = cast_data<sycl::half>(dst);
1082
- sin_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
1083
- break;
1084
- }
1085
- #endif
1086
- case GGML_TYPE_F32:
1087
- {
1088
- auto data_pts = cast_data<float>(dst);
1089
- sin_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
1090
- break;
1091
- }
1092
- default:
1093
- GGML_ABORT("GGML tensor type not supported!\n");
1094
- }
1095
  }
1096
 
1097
- inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1098
- #if defined (GGML_SYCL_F16)
1099
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
1100
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
1101
- #else
1102
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
1103
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
1104
- #endif
1105
- GGML_ASSERT(dst->src[0]->type == dst->type);
1106
- dpct::queue_ptr main_stream = ctx.stream();
1107
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
1108
- switch (dst->type) {
1109
- #if defined (GGML_SYCL_F16)
1110
- case GGML_TYPE_F16:
1111
- {
1112
- auto data_pts = cast_data<sycl::half>(dst);
1113
- cos_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
1114
- break;
1115
- }
1116
- #endif
1117
- case GGML_TYPE_F32:
1118
- {
1119
- auto data_pts = cast_data<float>(dst);
1120
- cos_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
1121
- break;
1122
- }
1123
- default:
1124
- GGML_ABORT("GGML tensor type not supported!\n");
1125
- }
1126
  }
1127
 
1128
- inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1129
- #if defined (GGML_SYCL_F16)
1130
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
1131
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
1132
- #else
1133
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
1134
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
1135
- #endif
1136
- GGML_ASSERT(dst->src[0]->type == dst->type);
1137
- dpct::queue_ptr main_stream = ctx.stream();
1138
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
1139
- switch (dst->type) {
1140
- #if defined (GGML_SYCL_F16)
1141
- case GGML_TYPE_F16:
1142
- {
1143
- auto data_pts = cast_data<sycl::half>(dst);
1144
- step_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
1145
- break;
1146
- }
1147
- #endif
1148
- case GGML_TYPE_F32:
1149
- {
1150
- auto data_pts = cast_data<float>(dst);
1151
- step_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
1152
- break;
1153
- }
1154
- default:
1155
- GGML_ABORT("GGML tensor type not supported!\n");
1156
- }
1157
  }
1158
 
1159
- inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1160
- #if defined (GGML_SYCL_F16)
1161
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
1162
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
1163
- #else
1164
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
1165
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
1166
- #endif
1167
- GGML_ASSERT(dst->src[0]->type == dst->type);
1168
- dpct::queue_ptr main_stream = ctx.stream();
1169
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
1170
- switch (dst->type) {
1171
- #if defined (GGML_SYCL_F16)
1172
- case GGML_TYPE_F16:
1173
- {
1174
- auto data_pts = cast_data<sycl::half>(dst);
1175
- neg_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
1176
- break;
1177
- }
1178
- #endif
1179
- case GGML_TYPE_F32:
1180
- {
1181
- auto data_pts = cast_data<float>(dst);
1182
- neg_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
1183
- break;
1184
- }
1185
- default:
1186
- GGML_ABORT("GGML tensor type not supported!\n");
1187
- }
1188
  }
1189
 
1190
- inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1191
- #if defined (GGML_SYCL_F16)
1192
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
1193
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
1194
- #else
1195
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
1196
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
1197
- #endif
 
 
1198
 
1199
- GGML_ASSERT(dst->src[0]->type == dst->type);
1200
- float negative_slope;
1201
- memcpy(&negative_slope, dst->op_params, sizeof(float));
1202
- dpct::queue_ptr main_stream = ctx.stream();
1203
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
1204
- switch (dst->type) {
1205
- #if defined (GGML_SYCL_F16)
1206
- case GGML_TYPE_F16:
1207
- {
1208
- auto data_pts = cast_data<sycl::half>(dst);
1209
- leaky_relu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), negative_slope, main_stream);
1210
- break;
1211
- }
1212
- #endif
1213
- case GGML_TYPE_F32:
1214
- {
1215
- auto data_pts = cast_data<float>(dst);
1216
- leaky_relu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), negative_slope, main_stream);
1217
- break;
1218
- }
1219
- default:
1220
- GGML_ABORT("GGML tensor type not supported!\n");
1221
- }
1222
  }
1223
 
1224
- inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1225
- #if defined (GGML_SYCL_F16)
1226
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
1227
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
1228
- #else
1229
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
1230
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
1231
- #endif
1232
- GGML_ASSERT(dst->src[0]->type == dst->type);
1233
- dpct::queue_ptr main_stream = ctx.stream();
1234
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
1235
- switch (dst->type) {
1236
- #if defined (GGML_SYCL_F16)
1237
- case GGML_TYPE_F16:
1238
- {
1239
- auto data_pts = cast_data<sycl::half>(dst);
1240
- sqr_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
1241
- break;
1242
- }
1243
- #endif
1244
- case GGML_TYPE_F32:
1245
- {
1246
- auto data_pts = cast_data<float>(dst);
1247
- sqr_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
1248
- break;
1249
- }
1250
- default:
1251
- GGML_ABORT("GGML tensor type not supported!\n");
1252
- }
1253
  }
1254
 
1255
- inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1256
- #if defined (GGML_SYCL_F16)
1257
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
1258
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
1259
- #else
1260
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
1261
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
1262
- #endif
1263
- GGML_ASSERT(dst->src[0]->type == dst->type);
 
 
 
1264
 
1265
- dpct::queue_ptr main_stream = ctx.stream();
1266
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
 
 
1267
 
1268
- const float sf0 = (float) dst->ne[0] / dst->src[0]->ne[0];
1269
- const float sf1 = (float) dst->ne[1] / dst->src[0]->ne[1];
1270
- const float sf2 = (float) dst->ne[2] / dst->src[0]->ne[2];
1271
- const float sf3 = (float) dst->ne[3] / dst->src[0]->ne[3];
1272
- switch (dst->type) {
1273
- #if defined (GGML_SYCL_F16)
1274
- case GGML_TYPE_F16:
1275
- {
1276
- auto data_pts = cast_data<sycl::half>(dst);
1277
- upscale_sycl(data_pts.src, data_pts.dst, dst->src[0]->nb[0], dst->src[0]->nb[1], dst->src[0]->nb[2],
1278
- dst->src[0]->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3,
1279
- main_stream);
1280
- break;
1281
- }
1282
- #endif
1283
- case GGML_TYPE_F32:
1284
- {
1285
- auto data_pts = cast_data<float>(dst);
1286
- upscale_sycl(data_pts.src, data_pts.dst, dst->src[0]->nb[0], dst->src[0]->nb[1], dst->src[0]->nb[2],
1287
- dst->src[0]->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3,
1288
- main_stream);
1289
- break;
1290
- }
1291
- default:
1292
- GGML_ABORT("GGML tensor type not supported!\n");
1293
- }
1294
  }
1295
 
1296
- inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1297
- #if defined (GGML_SYCL_F16)
1298
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
1299
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
1300
- #else
1301
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
1302
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
1303
- #endif
1304
- GGML_ASSERT(dst->src[0]->type == dst->type);
1305
- GGML_ASSERT(dst->src[0]->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
1306
- dpct::queue_ptr main_stream = ctx.stream();
1307
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
1308
- switch (dst->type) {
1309
- #if defined (GGML_SYCL_F16)
1310
- case GGML_TYPE_F16:
1311
- {
1312
- auto data_pts = cast_data<sycl::half>(dst);
1313
- pad_sycl(data_pts.src, data_pts.dst, dst->src[0]->ne[0], dst->src[0]->ne[1], dst->src[0]->ne[2], dst->ne[0],
1314
- dst->ne[1], dst->ne[2], main_stream);
1315
- break;
1316
- }
1317
- #endif
1318
- case GGML_TYPE_F32:
1319
- {
1320
- auto data_pts = cast_data<float>(dst);
1321
- pad_sycl(data_pts.src, data_pts.dst, dst->src[0]->ne[0], dst->src[0]->ne[1], dst->src[0]->ne[2], dst->ne[0],
1322
- dst->ne[1], dst->ne[2], main_stream);
1323
- break;
1324
- }
1325
- default:
1326
- GGML_ABORT("GGML tensor type not supported!\n");
1327
- }
1328
  }
1329
 
1330
- inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1331
- #if defined(GGML_SYCL_F16)
1332
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
1333
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
1334
- #else
 
 
1335
 
1336
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
1337
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
1338
- #endif
1339
- GGML_ASSERT(dst->src[0]->type == dst->type);
1340
- dpct::queue_ptr main_stream = ctx.stream();
1341
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
1342
- float min;
1343
- float max;
1344
- memcpy(&min, dst->op_params, sizeof(float));
1345
- memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
 
 
1346
 
1347
- switch (dst->type) {
1348
- #if defined(GGML_SYCL_F16)
1349
- case GGML_TYPE_F16:
1350
- {
1351
- auto data_pts = cast_data<sycl::half>(dst);
1352
- clamp_sycl(data_pts.src, data_pts.dst, min, max, ggml_nelements(dst->src[0]), main_stream);
1353
- break;
1354
- }
1355
- #endif
1356
- case GGML_TYPE_F32:
1357
- {
1358
- auto data_pts = cast_data<float>(dst);
1359
- clamp_sycl(data_pts.src, data_pts.dst, min, max, ggml_nelements(dst->src[0]), main_stream);
1360
- break;
1361
- }
1362
- default:
1363
- GGML_ABORT("GGML tensor type not supported!\n");
1364
- }
 
 
1365
  }
1366
 
1367
- inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
 
 
1368
 
 
1369
  GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
1370
  GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32);
1371
  GGML_ASSERT( dst->type == GGML_TYPE_F32);
@@ -1381,7 +942,40 @@ inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
1381
  // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
1382
  int offset = dst->op_params[3] / 4; // offset in bytes
1383
 
1384
- acc_f32_sycl(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), dst->src[1]->ne[0], dst->src[1]->ne[1], dst->src[1]->ne[2], nb1, nb2, offset, main_stream);
 
 
1385
  }
1386
 
1387
 
@@ -1509,3 +1103,18 @@ void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1509
  scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1510
  ggml_sycl_op_elu(ctx, dst);
1511
  }
 
 
1
  #include "common.hpp"
2
+ #include "ggml-sycl/presets.hpp"
3
  #include "ggml.h"
4
  #include "element_wise.hpp"
5
 
6
+ #define SYCL_GLOBAL_ID_LOOP(K, ITEM) \
7
+ for (auto i = ITEM.get_global_id(0); i < (size_t)K; i += ITEM.get_global_range(0))
8
+
9
+ #define SYCL_LOCAL_ID_CALC(ITEM, IDX) \
10
+ (ITEM.get_local_range(IDX) * ITEM.get_group(IDX) + ITEM.get_local_id(IDX))
11
+
12
+
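The two macros above give every kernel in this file the same traversal pattern: SYCL_GLOBAL_ID_LOOP is a grid-stride loop, so a kernel touches all k elements for any launched global range, and SYCL_LOCAL_ID_CALC flattens group and local ids into one index. A minimal sketch of a kernel written against SYCL_GLOBAL_ID_LOOP (the kernel name and the scale parameter are illustrative, not part of this change):

    // Illustrative grid-stride kernel in the style of SYCL_GLOBAL_ID_LOOP:
    // each work-item starts at its global id and advances by the global range,
    // so every element in [0, k) is written exactly once for any launch size >= 1.
    template<typename T>
    static void scale_kernel_sketch(const T * x, T * dst, const int k, const float s,
                                    const sycl::nd_item<1> & item_ct1) {
        for (auto i = item_ct1.get_global_id(0); i < (size_t) k; i += item_ct1.get_global_range(0)) {
            dst[i] = x[i] * static_cast<T>(s);
        }
    }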
13
  static void acc_f32(const float * x, const float * y, float * dst, const int ne,
14
  const int ne10, const int ne11, const int ne12,
15
+ const int nb1, const int nb2, int offset, const sycl::nd_item<1> &item_ct1) {
16
+ const int i = SYCL_LOCAL_ID_CALC(item_ct1, 0);
 
17
  if (i >= ne) {
18
  return;
19
  }
 
28
  }
29
  }
30
 
31
+ /* Unary OP funcs */
32
  template<typename T>
33
+ static __dpct_inline__ T op_sgn(T x) {
34
+ return x > static_cast<T>(0.f) ? static_cast<T>(1.f) : ((x < static_cast<T>(0.f) ? static_cast<T>(-1.f) : static_cast<T>(0.f)));
 
 
35
  }
36
 
37
  template<typename T>
38
+ static __dpct_inline__ T op_abs(T x) {
39
+ return sycl::fabs(x);
 
 
40
  }
41
 
42
  template<typename T>
43
+ static __dpct_inline__ T op_elu(T x) {
44
+ return (x > static_cast<T>(0.f)) ? x : sycl::expm1(x);
 
 
45
  }
46
 
47
  template<typename T>
48
+ static __dpct_inline__ T op_gelu(T x) {
 
49
  const T GELU_COEF_A = static_cast<T>(0.044715f);
50
  const T SQRT_2_OVER_PI = static_cast<T>(0.79788456080286535587989211986876f);
51
+ return static_cast<T>(0.5f) * x *
52
+ (static_cast<T>(1.0f) +
53
+ sycl::tanh(SQRT_2_OVER_PI * x * (static_cast<T>(1.0f) + GELU_COEF_A * x * x)));
 
 
 
 
54
  }
55
 
56
  template<typename T>
57
+ static __dpct_inline__ T op_silu(T x) {
58
+ return x / (static_cast<T>(1.0f) + sycl::native::exp(-x));
 
 
 
 
59
  }
60
 
61
  template<typename T>
62
+ static __dpct_inline__ T op_gelu_quick(T x) {
63
+ const T GELU_QUICK_COEF_LOCAL = static_cast<T>(-1.702f);
64
+ return x * (static_cast<T>(1.0f) / (static_cast<T>(1.0f) + sycl::native::exp(GELU_QUICK_COEF_LOCAL * x)));
 
 
 
 
65
  }
66
 
67
  template<typename T>
68
+ static __dpct_inline__ T op_gelu_erf(T x) {
69
  const T SQRT_2_INV = static_cast<T>(0.70710678118654752440084436210484f);
70
+ return static_cast<T>(0.5f) * x * (static_cast<T>(1.0f) + sycl::erf(x * SQRT_2_INV));
 
 
 
71
  }
72
 
73
  template<typename T>
74
+ static __dpct_inline__ T op_tanh(T x) {
75
+ return sycl::tanh(x);
 
 
 
 
76
  }
77
 
78
  template<typename T>
79
+ static __dpct_inline__ T op_relu(T x) {
80
+ return sycl::fmax(x, static_cast<T>(0));
 
 
 
 
81
  }
82
 
83
  template<typename T>
84
+ static __dpct_inline__ T op_sigmoid(T x) {
85
+ return static_cast<T>(1.0f) / (static_cast<T>(1.0f) + sycl::native::exp(-x));
 
 
 
 
86
  }
87
 
88
  template<typename T>
89
+ static __dpct_inline__ T op_sqrt(T x) {
90
+ return sycl::sqrt(x);
 
 
 
 
91
  }
92
 
93
  template<typename T>
94
+ static __dpct_inline__ T op_sin(T x) {
95
+ return sycl::sin(x);
 
 
 
 
96
  }
97
 
98
  template<typename T>
99
+ static __dpct_inline__ T op_cos(T x) {
100
+ return sycl::cos(x);
 
 
 
 
101
  }
102
 
103
  template<typename T>
104
+ static __dpct_inline__ T op_hardsigmoid(T x) {
105
+ return sycl::fmin(static_cast<T>(1.0f), sycl::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
 
 
 
 
106
  }
107
 
108
  template<typename T>
109
+ static __dpct_inline__ T op_hardswish(T x) {
110
+ return x * sycl::fmin(static_cast<T>(1.0f), sycl::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
 
 
 
 
111
  }
112
 
113
  template<typename T>
114
+ static __dpct_inline__ T op_exp(T x) {
115
+ return sycl::exp(x);
116
+ }
 
117
 
118
+ template<typename T>
119
+ static __dpct_inline__ T op_log(T x) {
120
+ if (x <= static_cast<T>(0)) {
121
+ return neg_infinity<T>();
122
  }
123
+ return sycl::log(x);
124
  }
125
 
126
  template<typename T>
127
+ static __dpct_inline__ T op_neg(T x) {
128
+ return -x;
129
+ }
 
130
 
131
+ template<typename T>
132
+ static __dpct_inline__ T op_step(T x) {
133
+ return (x > static_cast<T>(0.0f)) ? static_cast<T>(1.0f) : static_cast<T>(0.0f);
 
 
 
 
134
  }
135
 
136
  template<typename T>
137
+ static __dpct_inline__ T op_leaky_relu(T x, float negative_slope) {
138
+ T neg_slope_T = static_cast<T>(negative_slope);
139
+ return sycl::fmax(x, static_cast<T>(0)) +
140
+ sycl::fmin(x, static_cast<T>(0.0f)) * neg_slope_T;
141
+ }
142
 
143
+ template<typename T>
144
+ static __dpct_inline__ T op_sqr(T x) {
145
+ return x * x;
 
146
  }
147
 
148
  template<typename T>
149
+ static __dpct_inline__ T op_clamp(T x, float min_val, float max_val) {
150
+ return x < static_cast<T>(min_val) ? static_cast<T>(min_val) : (x > static_cast<T>(max_val) ? static_cast<T>(max_val) : x);
151
+ }
 
152
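These op_* helpers are plain element-wise functions templated on T, so one definition serves both the float and sycl::half paths and can be reused by the unary kernels below as well as by the fused GLU kernels. op_gelu uses the usual tanh approximation; a host-side reference of the same formula, included purely as a sketch for comparison (assumes <cmath>):

    // Sketch: host-side form of the tanh-based GELU approximation used by op_gelu,
    //   gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * x * (1 + 0.044715 * x^2)))
    static float gelu_tanh_ref(float x) {
        const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
        const float GELU_COEF_A    = 0.044715f;
        return 0.5f * x * (1.0f + std::tanh(SQRT_2_OVER_PI * x * (1.0f + GELU_COEF_A * x * x)));
    }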
 
153
+ template<typename T>
154
+ static void unary_op_sgn_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
155
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
156
+ dst[i] = op_sgn(x[i]);
157
  }
 
158
  }
159
 
160
  template<typename T>
161
+ static void unary_op_abs_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
162
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
163
+ dst[i] = op_abs(x[i]);
 
 
 
164
  }
 
 
165
  }
166
 
167
  template<typename T>
168
+ static void unary_op_elu_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
169
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
170
+ dst[i] = op_elu(x[i]);
 
 
 
 
171
  }
 
172
  }
173
 
174
+ template<typename T>
175
+ static void unary_op_gelu_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
176
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
177
+ dst[i] = op_gelu(x[i]);
 
 
 
 
 
178
  }
 
 
 
 
179
  }
180
 
181
+ template<typename T>
182
+ static void unary_op_silu_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
183
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
184
+ dst[i] = op_silu(x[i]);
 
 
 
185
  }
186
+ }
187
 
188
+ template<typename T>
189
+ static void unary_op_gelu_quick_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
190
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
191
+ dst[i] = op_gelu_quick(x[i]);
 
 
 
 
 
192
  }
193
  }
194
 
 
195
  template<typename T>
196
+ static void unary_op_gelu_erf_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
197
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
198
+ dst[i] = op_gelu_erf(x[i]);
 
 
 
 
199
  }
 
 
200
  }
201
 
202
+ template<typename T>
203
+ static void unary_op_tanh_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
204
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
205
+ dst[i] = op_tanh(x[i]);
206
+ }
 
 
 
 
207
  }
208
 
209
  template<typename T>
210
+ static void unary_op_relu_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
211
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
212
+ dst[i] = op_relu(x[i]);
213
+ }
 
 
 
214
  }
215
 
216
  template<typename T>
217
+ static void unary_op_sigmoid_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
218
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
219
+ dst[i] = op_sigmoid(x[i]);
220
+ }
 
 
 
221
  }
222
 
223
  template<typename T>
224
+ static void unary_op_sqrt_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
225
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
226
+ dst[i] = op_sqrt(x[i]);
227
+ }
 
 
228
  }
229
 
230
  template<typename T>
231
+ static void unary_op_sin_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
232
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
233
+ dst[i] = op_sin(x[i]);
234
+ }
 
 
 
235
  }
236
 
 
237
  template<typename T>
238
+ static void unary_op_cos_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
239
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
240
+ dst[i] = op_cos(x[i]);
241
+ }
 
 
 
242
  }
243
 
244
  template<typename T>
245
+ static void unary_op_hardsigmoid_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
246
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
247
+ dst[i] = op_hardsigmoid(x[i]);
248
+ }
 
 
 
249
  }
250
 
 
251
  template<typename T>
252
+ static void unary_op_hardswish_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
253
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
254
+ dst[i] = op_hardswish(x[i]);
255
+ }
 
 
 
256
  }
257
 
258
  template<typename T>
259
+ static void unary_op_exp_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
260
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
261
+ dst[i] = op_exp(x[i]);
262
+ }
 
 
 
263
  }
264
 
265
  template<typename T>
266
+ static void unary_op_log_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
267
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
268
+ dst[i] = op_log(x[i]);
269
+ }
 
 
 
270
  }
271
 
272
  template<typename T>
273
+ static void unary_op_neg_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
274
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
275
+ dst[i] = op_neg(x[i]);
276
+ }
 
 
 
 
277
  }
278
 
279
  template<typename T>
280
+ static void unary_op_step_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
281
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
282
+ dst[i] = op_step(x[i]);
283
+ }
 
 
 
 
284
  }
285
 
286
  template<typename T>
287
+ static void unary_op_leaky_relu_kernel(const T * x, T * dst, const int k, float negative_slope, const sycl::nd_item<1> &item_ct1) {
288
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
289
+ dst[i] = op_leaky_relu(x[i], negative_slope);
290
+ }
 
 
 
291
  }
292
 
293
  template<typename T>
294
+ static void unary_op_sqr_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
295
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
296
+ dst[i] = op_sqr(x[i]);
297
+ }
 
 
 
298
  }
299
 
300
  template<typename T>
301
+ static void unary_op_clamp_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1, float min_val, float max_val) {
302
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
303
+ dst[i] = op_clamp(x[i], min_val, max_val);
304
+ }
 
 
 
305
  }
306
 
307
+ template<typename T>
308
+ static void upscale(const T *x, T *dst, const int nb00, const int nb01,
309
+ const int nb02, const int nb03, const int ne10, const int ne11,
310
+ const int ne12, const int ne13, const float sf0, const float sf1,
311
+ const float sf2, const float sf3, const sycl::nd_item<1> &item_ct1) {
312
+ int index = item_ct1.get_local_id(0) +
313
+ item_ct1.get_group(0) * item_ct1.get_local_range(0);
314
+ if (index >= ne10 * ne11 * ne12 * ne13) {
315
+ return;
316
+ }
317
+ // operation
318
+ int i10 = index % ne10;
319
+ int i11 = (index / ne10) % ne11;
320
+ int i12 = (index / (ne10 * ne11)) % ne12;
321
+ int i13 = (index / (ne10 * ne11 * ne12)) % ne13;
322
+
323
+ int i00 = static_cast<int>(i10 / sf0);
324
+ int i01 = static_cast<int>(i11 / sf1);
325
+ int i02 = static_cast<int>(i12 / sf2);
326
+ int i03 = static_cast<int>(i13 / sf3);
327
+
328
+ dst[index] = *(const T *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
329
  }
330
 
331
+ template <typename T>
332
+ static void pad(const T *x, T *dst, const int ne0, const int ne00, const int ne01, const int ne02,
333
+ const sycl::nd_item<3> &item_ct1) {
334
+ int nidx = SYCL_LOCAL_ID_CALC(item_ct1, 2);
335
+ if (nidx >= ne0) {
336
+ return;
337
+ }
338
+
339
+ // operation
340
+ int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
341
+ item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
342
+ if (nidx < ne00 && item_ct1.get_group(1) < (size_t) ne01 && item_ct1.get_group(0) < (size_t) ne02) {
343
+ int offset_src = nidx + item_ct1.get_group(1) * ne00 +
344
+ item_ct1.get_group(0) * ne00 * ne01;
345
+ dst[offset_dst] = x[offset_src];
346
+ } else {
347
+ dst[offset_dst] = static_cast<T>(0.0f);
348
+ }
349
  }
350
 
351
  template<typename T>
352
+ static void clamp(const T * x, T * dst, const float min, const float max, const int k,
353
+ const sycl::nd_item<1> &item_ct1) {
354
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
355
+ dst[i] = x[i] < static_cast<T>(min) ? static_cast<T>(min) : (x[i] > static_cast<T>(max) ? static_cast<T>(max) : x[i]);
356
+ }
 
 
357
  }
358
 
359
  template<typename T>
360
+ static void gated_op_fused_geglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) {
361
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
362
+ const int64_t j0 = (i / n) * o0 + (i % n);
363
+ const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
364
+ dst[i] = op_gelu(x[j0]) * g[j1];
365
+ }
 
366
  }
367
 
368
  template<typename T>
369
+ static void gated_op_fused_reglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) {
370
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
371
+ const int64_t j0 = (i / n) * o0 + (i % n);
372
+ const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
373
+ dst[i] = op_relu(x[j0]) * g[j1];
374
+ }
 
375
  }
376
 
377
  template<typename T>
378
+ static void gated_op_fused_swiglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) {
379
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
380
+ const int64_t j0 = (i / n) * o0 + (i % n);
381
+ const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
382
+ dst[i] = op_silu(x[j0]) * g[j1];
383
+ }
 
 
384
  }
385
 
386
+ namespace ggml_sycl_detail {
387
+ static void acc_f32_sycl(const float *x, const float *y, float *dst,
388
+ const int n_elements, const int ne10, const int ne11,
389
+ const int ne12, const int nb1, const int nb2,
390
+ const int offset, queue_ptr stream) {
391
+ int num_blocks = ceil_div(n_elements, SYCL_ACC_BLOCK_SIZE);
392
  sycl_parallel_for(stream,
393
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) *
394
+ sycl::range<1>(SYCL_ACC_BLOCK_SIZE),
395
+ sycl::range<1>(SYCL_ACC_BLOCK_SIZE)),
396
+ [=](sycl::nd_item<1> item_ct1) {
397
+ acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset,
398
+ item_ct1);
399
+ });
400
  }
401
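acc_f32_sycl and the launchers that follow all size their grids the same way: ceil_div(n, block) work-groups of block work-items, i.e. a global range of num_blocks * block. Together with the grid-stride loops (or early returns) in the kernels, the padding work-items in the last group do no work. A quick check of the arithmetic with illustrative numbers:

    // Sketch: for n = 1000 and a 256-wide work-group, ceil_div gives 4 groups,
    // so the global range is 1024 and the last 24 work-items simply skip the loop body.
    constexpr int ceil_div_sketch(int a, int b) { return (a + b - 1) / b; }
    static_assert(ceil_div_sketch(1000, 256) == 4, "4 groups of 256 cover 1000 elements");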
 
402
  template<typename T>
 
405
  const int ne12, const int ne13, const float sf0, const float sf1,
406
  const float sf2, const float sf3, queue_ptr stream) {
407
  int dst_size = ne10 * ne11 * ne12 * ne13;
408
+ int num_blocks = ceil_div(dst_size, SYCL_UPSCALE_BLOCK_SIZE);
409
  sycl::range<1> gridDim(num_blocks * SYCL_UPSCALE_BLOCK_SIZE);
410
  sycl_parallel_for<1>(
411
  stream, sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
 
417
  static void pad_sycl(const T *x, T *dst, const int ne00,
418
  const int ne01, const int ne02, const int ne0,
419
  const int ne1, const int ne2, queue_ptr stream) {
420
+ int num_blocks = ceil_div(ne0, SYCL_PAD_BLOCK_SIZE);
421
  sycl::range<3> gridDim(ne2, ne1, num_blocks);
422
  sycl_parallel_for(stream,
423
  sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),
 
425
  [=](sycl::nd_item<3> item_ct1) { pad(x, dst, ne0, ne00, ne01, ne02, item_ct1); });
426
  }
427
 
428
+ template<typename KernelInvoker, typename... Args>
429
+ static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
 
 
430
  #if defined (GGML_SYCL_F16)
431
  GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
432
  GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
 
433
  #else
434
  GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
435
  GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
442
  case GGML_TYPE_F16:
443
  {
444
  auto data_pts = cast_data<sycl::half>(dst);
445
+ kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward<Args>(args)...);
446
  break;
447
  }
448
  #endif
449
  case GGML_TYPE_F32:
450
  {
451
  auto data_pts = cast_data<float>(dst);
452
+ kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward<Args>(args)...);
453
  break;
454
  }
455
  default:
 
457
  }
458
  }
459
 
460
+ template<typename KernelInvoker, typename... Args>
461
+ static inline void dispatch_ggml_sycl_op_fused_glu(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
462
  #if defined (GGML_SYCL_F16)
463
  GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
464
  GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
 
465
  #else
466
  GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
467
  GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
469
  GGML_ASSERT(dst->src[0]->type == dst->type);
470
  dpct::queue_ptr main_stream = ctx.stream();
471
  SYCL_CHECK(ggml_sycl_set_device(ctx.device));
472
+ const ggml_tensor * src0 = dst->src[0];
473
+ const ggml_tensor * src1 = dst->src[1];
474
+ const int64_t nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
475
+ GGML_ASSERT(dst->ne[0] == nc);
476
+ GGML_ASSERT(ggml_is_contiguous_1(dst->src[0]));
477
+ GGML_ASSERT(ggml_is_contiguous(dst));
478
+ const int32_t swapped = ((const int32_t *) dst->op_params)[1];
479
+ void * src0_d = src0->data;
480
+ void * src1_d = src1 ? src1->data : src0->data;
481
+ const int64_t src0_o = src0->nb[1];
482
+ const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
483
+ void * dst_d = dst->data;
484
+ if (src1) {
485
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
486
+ GGML_ASSERT(src1->nb[0] == ggml_element_size(src1));
487
+ GGML_ASSERT(src1->ne[0] == nc);
488
+ GGML_ASSERT(src0->type == src1->type);
489
  }
 
490
  switch (dst->type) {
491
  #if defined (GGML_SYCL_F16)
492
  case GGML_TYPE_F16:
493
  {
494
+ sycl::half * src0_p = (sycl::half *) src0_d;
495
+ sycl::half * src1_p = (sycl::half *) src1_d;
496
+
497
+ if (!src1) {
498
+ src0_p += swapped ? nc : 0;
499
+ src1_p += swapped ? 0 : nc;
500
+ }
501
+ kernel_invoker(src0_p,
502
+ src1_p,
503
+ (sycl::half *) dst_d,
504
+ ggml_nelements(dst),
505
+ nc,
506
+ src0_o / sizeof(sycl::half),
507
+ src1_o / sizeof(sycl::half),
508
+ main_stream,
509
+ std::forward<Args>(args)...);
510
  break;
511
  }
512
  #endif
513
  case GGML_TYPE_F32:
514
  {
515
+ float * src0_p = (float *) src0_d;
516
+ float * src1_p = (float *) src1_d;
517
+
518
+ if (!src1) {
519
+ src0_p += swapped ? nc : 0;
520
+ src1_p += swapped ? 0 : nc;
521
+ }
522
+
523
+ kernel_invoker(src0_p,
524
+ src1_p,
525
+ (float *) dst_d,
526
+ ggml_nelements(dst),
527
+ nc,
528
+ src0_o / sizeof(float),
529
+ src1_o / sizeof(float),
530
+ main_stream,
531
+ std::forward<Args>(args)...);
532
  break;
533
  }
534
  default:
 
536
  }
537
  }
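A small illustration (illustrative helper, not from the patch) of how the swapped/fused pointer selection above picks the two halves of src0 when there is no separate gate tensor:

// select_halves(8, /*split=*/false, /*swapped=*/false) -> {0, 4}: x = cols 0..3, gate = cols 4..7
// select_halves(8, /*split=*/false, /*swapped=*/true)  -> {4, 0}: x = cols 4..7, gate = cols 0..3
#include <cstdint>

struct glu_halves { int64_t x_off; int64_t gate_off; };

static glu_halves select_halves(int64_t ne0, bool split, bool swapped) {
    const int64_t nc = split ? ne0 : ne0 / 2;   // output row width, as computed above
    if (split) {
        return { 0, 0 };                        // separate up/gate tensors, no offset needed
    }
    return swapped ? glu_halves{ nc, 0 } : glu_halves{ 0, nc };
}

This mirrors the "src0_p += swapped ? nc : 0; src1_p += swapped ? 0 : nc;" adjustments in both type branches.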
538
 
539
+ template<typename KernelInvoker, typename... Args>
540
+ static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
541
  #if defined (GGML_SYCL_F16)
542
  GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
543
  GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
 
546
  GGML_ASSERT(dst->type == GGML_TYPE_F32);
547
  #endif
548
  GGML_ASSERT(dst->src[0]->type == dst->type);
 
549
 
550
  dpct::queue_ptr main_stream = ctx.stream();
551
  SYCL_CHECK(ggml_sycl_set_device(ctx.device));
552
+
553
+ const float sf0 = (float) dst->ne[0] / dst->src[0]->ne[0];
554
+ const float sf1 = (float) dst->ne[1] / dst->src[0]->ne[1];
555
+ const float sf2 = (float) dst->ne[2] / dst->src[0]->ne[2];
556
+ const float sf3 = (float) dst->ne[3] / dst->src[0]->ne[3];
557
  switch (dst->type) {
558
  #if defined (GGML_SYCL_F16)
559
  case GGML_TYPE_F16:
560
  {
561
  auto data_pts = cast_data<sycl::half>(dst);
562
+ kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->nb[0], (int)dst->src[0]->nb[1], (int)dst->src[0]->nb[2],
563
+ (int)dst->src[0]->nb[3], (int)dst->ne[0], (int)dst->ne[1], (int)dst->ne[2], (int)dst->ne[3], sf0, sf1, sf2, sf3,
564
+ main_stream, std::forward<Args>(args)...);
565
  break;
566
  }
567
  #endif
568
  case GGML_TYPE_F32:
569
  {
570
  auto data_pts = cast_data<float>(dst);
571
+ kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->nb[0], (int)dst->src[0]->nb[1], (int)dst->src[0]->nb[2],
572
+ (int)dst->src[0]->nb[3], (int)dst->ne[0], (int)dst->ne[1], (int)dst->ne[2], (int)dst->ne[3], sf0, sf1, sf2, sf3,
573
+ main_stream, std::forward<Args>(args)...);
574
  break;
575
  }
576
  default:
 
578
  }
579
  }
580
 
581
+ template<typename KernelInvoker, typename... Args>
582
+ static inline void dispatch_ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
583
  #if defined (GGML_SYCL_F16)
584
  GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
585
  GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
 
588
  GGML_ASSERT(dst->type == GGML_TYPE_F32);
589
  #endif
590
  GGML_ASSERT(dst->src[0]->type == dst->type);
591
+ GGML_ASSERT(dst->src[0]->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
592
  dpct::queue_ptr main_stream = ctx.stream();
593
  SYCL_CHECK(ggml_sycl_set_device(ctx.device));
594
  switch (dst->type) {
 
596
  case GGML_TYPE_F16:
597
  {
598
  auto data_pts = cast_data<sycl::half>(dst);
599
+ kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->ne[0], (int)dst->src[0]->ne[1], (int)dst->src[0]->ne[2], (int)dst->ne[0],
600
+ (int)dst->ne[1], (int)dst->ne[2], main_stream, std::forward<Args>(args)...);
601
  break;
602
  }
603
  #endif
604
  case GGML_TYPE_F32:
605
  {
606
  auto data_pts = cast_data<float>(dst);
607
+ kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->ne[0], (int)dst->src[0]->ne[1], (int)dst->src[0]->ne[2], (int)dst->ne[0],
608
+ (int)dst->ne[1], (int)dst->ne[2], main_stream, std::forward<Args>(args)...);
609
  break;
610
  }
611
  default:
 
613
  }
614
  }
615
 
616
+ } // namespace ggml_sycl_detail
617
618
 
 
 
619
 
620
+ static inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
621
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
622
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
623
+ const int num_blocks = ceil_div(k_elements, 256);
624
+ sycl_parallel_for(stream,
625
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
626
+ sycl::range<1>(256)),
627
+ [=](sycl::nd_item<1> item_ct1) {
628
+ unary_op_sgn_kernel(src, dst_ptr, k_elements, item_ct1);
629
+ });
630
+ });
631
  }
632
 
633
+ static inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
634
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
635
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
636
+ const int num_blocks = ceil_div(k_elements, 256);
637
+ sycl_parallel_for(stream,
638
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
639
+ sycl::range<1>(256)),
640
+ [=](sycl::nd_item<1> item_ct1) {
641
+ unary_op_abs_kernel(src, dst_ptr, k_elements, item_ct1);
642
+ });
643
+ });
644
  }
645
 
646
+ static inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
647
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
648
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
649
+ const int num_blocks = ceil_div(k_elements, 256);
650
+ sycl_parallel_for(stream,
651
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
652
+ sycl::range<1>(256)),
653
+ [=](sycl::nd_item<1> item_ct1) {
654
+ unary_op_elu_kernel(src, dst_ptr, k_elements, item_ct1);
655
+ });
656
+ });
657
  }
658
 
659
+ static inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
660
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
661
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
662
+ const int num_blocks = ceil_div(k_elements, SYCL_SILU_BLOCK_SIZE);
663
+ sycl_parallel_for(stream,
664
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SILU_BLOCK_SIZE),
665
+ sycl::range<1>(SYCL_SILU_BLOCK_SIZE)),
666
+ [=](sycl::nd_item<1> item_ct1) {
667
+ unary_op_silu_kernel(src, dst_ptr, k_elements, item_ct1);
668
+ });
669
+ });
670
  }
671
 
672
+ static inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
673
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
674
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
675
+ const int num_blocks = ceil_div(k_elements, SYCL_GELU_BLOCK_SIZE);
676
+ sycl_parallel_for(stream,
677
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE),
678
+ sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
679
+ [=](sycl::nd_item<1> item_ct1) {
680
+ unary_op_gelu_kernel(src, dst_ptr, k_elements, item_ct1);
681
+ });
682
+ });
683
  }
684
 
685
+ static inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
686
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
687
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
688
+ const int num_blocks = ceil_div(k_elements, SYCL_GELU_BLOCK_SIZE);
689
+ sycl_parallel_for(stream,
690
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE),
691
+ sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
692
+ [=](sycl::nd_item<1> item_ct1) {
693
+ unary_op_gelu_quick_kernel(src, dst_ptr, k_elements, item_ct1);
694
+ });
695
+ });
696
  }
697
+
698
+ static inline void ggml_sycl_op_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
699
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
700
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
701
+ const int num_blocks = ceil_div(k_elements, SYCL_GELU_BLOCK_SIZE);
702
+ sycl_parallel_for(stream,
703
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE),
704
+ sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
705
+ [=](sycl::nd_item<1> item_ct1) {
706
+ unary_op_gelu_erf_kernel(src, dst_ptr, k_elements, item_ct1);
707
+ });
708
+ });
709
  }
710
 
711
+ static inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
712
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
713
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
714
+ const int num_blocks = ceil_div(k_elements, SYCL_TANH_BLOCK_SIZE);
715
+ sycl_parallel_for(stream,
716
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_TANH_BLOCK_SIZE),
717
+ sycl::range<1>(SYCL_TANH_BLOCK_SIZE)),
718
+ [=](sycl::nd_item<1> item_ct1) {
719
+ unary_op_tanh_kernel(src, dst_ptr, k_elements, item_ct1);
720
+ });
721
+ });
722
  }
723
 
724
+ static inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
725
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
726
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
727
+ const int num_blocks = ceil_div(k_elements, SYCL_RELU_BLOCK_SIZE);
728
+ sycl_parallel_for(stream,
729
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_RELU_BLOCK_SIZE),
730
+ sycl::range<1>(SYCL_RELU_BLOCK_SIZE)),
731
+ [=](sycl::nd_item<1> item_ct1) {
732
+ unary_op_relu_kernel(src, dst_ptr, k_elements, item_ct1);
733
+ });
734
+ });
735
  }
736
 
737
+ static inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
738
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
739
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
740
+ const int num_blocks = ceil_div(k_elements, SYCL_HARDSIGMOID_BLOCK_SIZE);
741
+ sycl_parallel_for(stream,
742
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_HARDSIGMOID_BLOCK_SIZE),
743
+ sycl::range<1>(SYCL_HARDSIGMOID_BLOCK_SIZE)),
744
+ [=](sycl::nd_item<1> item_ct1) {
745
+ unary_op_hardsigmoid_kernel(src, dst_ptr, k_elements, item_ct1);
746
+ });
747
+ });
748
  }
749
 
750
+ static inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
751
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
752
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
753
+ const int num_blocks = ceil_div(k_elements, SYCL_HARDSWISH_BLOCK_SIZE);
754
+ sycl_parallel_for(stream,
755
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_HARDSWISH_BLOCK_SIZE),
756
+ sycl::range<1>(SYCL_HARDSWISH_BLOCK_SIZE)),
757
+ [=](sycl::nd_item<1> item_ct1) {
758
+ unary_op_hardswish_kernel(src, dst_ptr, k_elements, item_ct1);
759
+ });
760
+ });
761
+ }
762
 
763
+ static inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
764
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
765
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
766
+ const int num_blocks = ceil_div(k_elements, SYCL_EXP_BLOCK_SIZE);
767
+ sycl_parallel_for(stream,
768
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_EXP_BLOCK_SIZE),
769
+ sycl::range<1>(SYCL_EXP_BLOCK_SIZE)),
770
+ [=](sycl::nd_item<1> item_ct1) {
771
+ unary_op_exp_kernel(src, dst_ptr, k_elements, item_ct1);
772
+ });
773
+ });
774
  }
775
 
776
+ static inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
777
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
778
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
779
+ const int num_blocks = ceil_div(k_elements, SYCL_EXP_BLOCK_SIZE); // Using EXP block size
780
+ sycl_parallel_for(stream,
781
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_EXP_BLOCK_SIZE),
782
+ sycl::range<1>(SYCL_EXP_BLOCK_SIZE)),
783
+ [=](sycl::nd_item<1> item_ct1) {
784
+ unary_op_log_kernel(src, dst_ptr, k_elements, item_ct1);
785
+ });
786
+ });
787
  }
788
 
789
+ static inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
790
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
791
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
792
+ const int num_blocks = ceil_div(k_elements, SYCL_NEG_BLOCK_SIZE);
793
+ sycl_parallel_for(stream,
794
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_NEG_BLOCK_SIZE),
795
+ sycl::range<1>(SYCL_NEG_BLOCK_SIZE)),
796
+ [=](sycl::nd_item<1> item_ct1) {
797
+ unary_op_neg_kernel(src, dst_ptr, k_elements, item_ct1);
798
+ });
799
+ });
800
+ }
801
 
802
+ static inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
803
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
804
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
805
+ const int num_blocks = ceil_div(k_elements, SYCL_NEG_BLOCK_SIZE); // Using NEG block size
806
+ sycl_parallel_for(stream,
807
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_NEG_BLOCK_SIZE),
808
+ sycl::range<1>(SYCL_NEG_BLOCK_SIZE)),
809
+ [=](sycl::nd_item<1> item_ct1) {
810
+ unary_op_step_kernel(src, dst_ptr, k_elements, item_ct1);
811
+ });
812
+ });
813
+ }
814
 
815
+ static inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
816
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
817
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
818
+ const int num_blocks = ceil_div(k_elements, SYCL_SIGMOID_BLOCK_SIZE);
819
+ sycl_parallel_for(stream,
820
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIGMOID_BLOCK_SIZE),
821
+ sycl::range<1>(SYCL_SIGMOID_BLOCK_SIZE)),
822
+ [=](sycl::nd_item<1> item_ct1) {
823
+ unary_op_sigmoid_kernel(src, dst_ptr, k_elements, item_ct1);
824
+ });
825
+ });
826
  }
827
 
828
+ static inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
829
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
830
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
831
+ const int num_blocks = ceil_div(k_elements, SYCL_SQRT_BLOCK_SIZE);
832
+ sycl_parallel_for(stream,
833
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQRT_BLOCK_SIZE),
834
+ sycl::range<1>(SYCL_SQRT_BLOCK_SIZE)),
835
+ [=](sycl::nd_item<1> item_ct1) {
836
+ unary_op_sqrt_kernel(src, dst_ptr, k_elements, item_ct1);
837
+ });
838
+ });
839
  }
840
 
841
+ static inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
842
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
843
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
844
+ const int num_blocks = ceil_div(k_elements, SYCL_SIN_BLOCK_SIZE);
845
+ sycl_parallel_for(stream,
846
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIN_BLOCK_SIZE),
847
+ sycl::range<1>(SYCL_SIN_BLOCK_SIZE)),
848
+ [=](sycl::nd_item<1> item_ct1) {
849
+ unary_op_sin_kernel(src, dst_ptr, k_elements, item_ct1);
850
+ });
851
+ });
852
+ }
853
 
854
+ static inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
855
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
856
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
857
+ const int num_blocks = ceil_div(k_elements, SYCL_SIN_BLOCK_SIZE); // Using SIN block size
858
+ sycl_parallel_for(stream,
859
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIN_BLOCK_SIZE),
860
+ sycl::range<1>(SYCL_SIN_BLOCK_SIZE)),
861
+ [=](sycl::nd_item<1> item_ct1) {
862
+ unary_op_cos_kernel(src, dst_ptr, k_elements, item_ct1);
863
+ });
864
+ });
865
+ }
866
 
867
+ static inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
868
+ float negative_slope;
869
+ memcpy(&negative_slope, dst->op_params, sizeof(float));
870
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
871
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream, float slope) {
872
+ const int num_blocks = ceil_div(k_elements, SYCL_RELU_BLOCK_SIZE);
873
+ sycl_parallel_for(stream,
874
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_RELU_BLOCK_SIZE),
875
+ sycl::range<1>(SYCL_RELU_BLOCK_SIZE)),
876
+ [=](sycl::nd_item<1> item_ct1) {
877
+ unary_op_leaky_relu_kernel(src, dst_ptr, k_elements, slope, item_ct1);
878
+ });
879
+ }, negative_slope);
880
+ }
881
+
882
+ static inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
883
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
884
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
885
+ const int num_blocks = ceil_div(k_elements, SYCL_SQR_BLOCK_SIZE);
886
+ sycl_parallel_for(stream,
887
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQR_BLOCK_SIZE),
888
+ sycl::range<1>(SYCL_SQR_BLOCK_SIZE)),
889
+ [=](sycl::nd_item<1> item_ct1) {
890
+ unary_op_sqr_kernel(src, dst_ptr, k_elements, item_ct1);
891
+ });
892
+ });
893
+ }
894
+
895
+ static inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
896
+ ggml_sycl_detail::dispatch_ggml_sycl_op_upscale(ctx, dst,
897
+ [](const auto* src, auto* dst_ptr, int nb00, int nb01, int nb02, int nb03,
898
+ int ne10, int ne11, int ne12, int ne13, float sf0, float sf1, float sf2, float sf3,
899
+ queue_ptr stream) {
900
+ ggml_sycl_detail::upscale_sycl(src, dst_ptr, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, stream);
901
+ });
902
  }
903
 
904
+ static inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
905
+ ggml_sycl_detail::dispatch_ggml_sycl_op_pad(ctx, dst,
906
+ [](const auto* src, auto* dst_ptr, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2,
907
+ queue_ptr stream) {
908
+ ggml_sycl_detail::pad_sycl(src, dst_ptr, ne00, ne01, ne02, ne0, ne1, ne2, stream);
909
+ });
910
+ }
911
 
912
+ static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
913
+ float min_val;
914
+ float max_val;
915
+ memcpy(&min_val, dst->op_params, sizeof(float));
916
+ memcpy(&max_val, (float *) dst->op_params + 1, sizeof(float));
917
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
918
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream, float min_arg, float max_arg) {
919
+ const int num_blocks = ceil_div(k_elements, SYCL_CLAMP_BLOCK_SIZE);
920
+ sycl_parallel_for(stream,
921
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE),
922
+ sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE)),
923
+ [=](sycl::nd_item<1> item_ct1) {
924
+ clamp(src, dst_ptr, min_arg, max_arg, k_elements, item_ct1);
925
+ });
926
+ }, min_val, max_val);
927
+ }
928
+
929
+ static inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
930
  GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
931
  GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32);
932
  GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
942
  // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
943
  int offset = dst->op_params[3] / 4; // offset in bytes
944
 
945
+ ggml_sycl_detail::acc_f32_sycl(src0_dd, src1_dd, dst_dd, (int)ggml_nelements(dst), (int)dst->src[1]->ne[0], (int)dst->src[1]->ne[1], (int)dst->src[1]->ne[2], nb1, nb2, offset, main_stream);
946
+ }
947
+
948
+ static inline void ggml_sycl_op_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
949
+ ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
950
+ [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
951
+ const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
952
+ sycl_parallel_for(main_stream,
953
+ sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
954
+ gated_op_fused_geglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
955
+ });
956
+ });
957
+ }
958
+
959
+ static inline void ggml_sycl_op_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
960
+ ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
961
+ [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
962
+ const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_RELU_BLOCK_SIZE); // Using RELU block size for reglu
963
+ sycl_parallel_for(main_stream,
964
+ sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
965
+ gated_op_fused_reglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
966
+ });
967
+ });
968
+ }
969
+
970
+ static inline void ggml_sycl_op_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
971
+ ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
972
+ [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
973
+ const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_SILU_BLOCK_SIZE); // Using SILU block size for swiglu
974
+ sycl_parallel_for(main_stream,
975
+ sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
976
+ gated_op_fused_swiglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
977
+ });
978
+ });
979
  }
980
 
981
 
 
1103
  scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1104
  ggml_sycl_op_elu(ctx, dst);
1105
  }
1106
+
1107
+ void ggml_sycl_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1108
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1109
+ ggml_sycl_op_geglu(ctx, dst);
1110
+ }
1111
+
1112
+ void ggml_sycl_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1113
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1114
+ ggml_sycl_op_reglu(ctx, dst);
1115
+ }
1116
+
1117
+ void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1118
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1119
+ ggml_sycl_op_swiglu(ctx, dst);
1120
+ }
ggml/src/ggml-sycl/element_wise.hpp CHANGED
@@ -3,27 +3,30 @@
3
 
4
  #include "common.hpp"
5
  #include "ggml.h"
6
- #include <limits.h>
7
 
8
  template <typename T>
9
  T neg_infinity() {
10
  return -std::numeric_limits<T>::infinity();
11
  }
12
 
13
- template<typename T>
14
  struct typed_data {
15
- const T * src;
16
- T * dst;
17
  };
18
 
19
- template<typename T>
20
- typed_data<T> cast_data(ggml_tensor * dst) {
21
  return {
22
- /* .src = */ static_cast<const T *>(dst->src[0]->data),
23
- /* .dst = */ static_cast<T *>(dst->data)
24
  };
25
  }
26
 
 
 
 
27
  void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
28
 
29
  void ggml_sycl_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
@@ -73,5 +76,9 @@ void ggml_sycl_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
73
  void ggml_sycl_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
74
 
75
  void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
76
- #endif // GGML_SYCL_ELEMENTWISE_HPP
77
 
3
 
4
  #include "common.hpp"
5
  #include "ggml.h"
6
+ #include <limits> // For std::numeric_limits
7
 
8
  template <typename T>
9
  T neg_infinity() {
10
  return -std::numeric_limits<T>::infinity();
11
  }
12
 
13
+ template<typename T_Dst, typename T_Src = T_Dst>
14
  struct typed_data {
15
+ const T_Src * src;
16
+ T_Dst * dst;
17
  };
18
 
19
+ template<typename T_Dst, typename T_Src = T_Dst>
20
+ typed_data<T_Dst, T_Src> cast_data(ggml_tensor * dst) {
21
  return {
22
+ /* .src = */ static_cast<const T_Src *>(dst->src[0]->data),
23
+ /* .dst = */ static_cast<T_Dst *>(dst->data)
24
  };
25
  }
26
 
27
+ const float GELU_QUICK_COEF = -1.702f;
28
+
29
+
30
  void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
31
 
32
  void ggml_sycl_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
76
  void ggml_sycl_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
77
 
78
  void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
79
 
80
+ void ggml_sycl_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
81
+ void ggml_sycl_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
82
+ void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
83
+
84
+ #endif // GGML_SYCL_ELEMENTWISE_HPP
ggml/src/ggml-sycl/ggml-sycl.cpp CHANGED
@@ -3676,6 +3676,21 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
3676
  return false;
3677
  }
3678
  break;
3679
  case GGML_OP_NORM:
3680
  ggml_sycl_norm(ctx, dst);
3681
  break;
@@ -4212,6 +4227,16 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
4212
  default:
4213
  return false;
4214
  }
4215
  case GGML_OP_MUL_MAT:
4216
  case GGML_OP_MUL_MAT_ID:
4217
  {
 
3676
  return false;
3677
  }
3678
  break;
3679
+ case GGML_OP_GLU:
3680
+ switch (ggml_get_glu_op(dst)) {
3681
+ case GGML_GLU_OP_REGLU:
3682
+ ggml_sycl_reglu(ctx, dst);
3683
+ break;
3684
+ case GGML_GLU_OP_GEGLU:
3685
+ ggml_sycl_geglu(ctx, dst);
3686
+ break;
3687
+ case GGML_GLU_OP_SWIGLU:
3688
+ ggml_sycl_swiglu(ctx, dst);
3689
+ break;
3690
+ default:
3691
+ return false;
3692
+ }
3693
+ break;
3694
  case GGML_OP_NORM:
3695
  ggml_sycl_norm(ctx, dst);
3696
  break;
 
4227
  default:
4228
  return false;
4229
  }
4230
+ case GGML_OP_GLU:
4231
+ switch (ggml_get_glu_op(op)) {
4232
+ case GGML_GLU_OP_REGLU:
4233
+ case GGML_GLU_OP_GEGLU:
4234
+ case GGML_GLU_OP_SWIGLU:
4235
+ return ggml_is_contiguous_1(op->src[0]);
4236
+ default:
4237
+ return false;
4238
+ }
4239
+ break;
4240
  case GGML_OP_MUL_MAT:
4241
  case GGML_OP_MUL_MAT_ID:
4242
  {
ggml/src/ggml-vulkan/ggml-vulkan.cpp CHANGED
@@ -437,6 +437,10 @@ struct vk_device_struct {
437
  vk_pipeline pipeline_tanh[2];
438
  vk_pipeline pipeline_sigmoid[2];
439
 
 
 
 
 
440
  vk_pipeline pipeline_leaky_relu_f32;
441
  vk_pipeline pipeline_silu_back_f32;
442
  vk_pipeline pipeline_diag_mask_inf_f32;
@@ -661,6 +665,13 @@ struct vk_op_push_constants {
661
  float param2;
662
  };
663
 
664
  struct vk_op_unary_push_constants {
665
  uint32_t ne;
666
  uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
@@ -2757,6 +2768,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
2757
  CREATE_UNARY(sigmoid)
2758
  #undef CREATE_UNARY
2759
 
2760
  ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2761
  ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2762
 
@@ -6473,6 +6493,24 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
6473
  break;
6474
  }
6475
  return nullptr;
6476
  case GGML_OP_DIAG_MASK_INF:
6477
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6478
  return ctx->device->pipeline_diag_mask_inf_f32;
@@ -6933,6 +6971,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
6933
  case GGML_OP_CONCAT:
6934
  case GGML_OP_UPSCALE:
6935
  case GGML_OP_UNARY:
 
6936
  case GGML_OP_CONV_2D_DW:
6937
  {
6938
  uint32_t ne = ggml_nelements(dst);
@@ -6973,7 +7012,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
6973
  }
6974
  }
6975
 
6976
- if (op == GGML_OP_SOFT_MAX) {
6977
  // Empty src1 is possible in soft_max, but the shader needs a buffer
6978
  vk_subbuffer subbuf_y;
6979
  if (use_src1) {
@@ -7566,6 +7605,25 @@ static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, con
7566
  ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
7567
  }
7568
 
7569
  static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7570
  int32_t * op_params = (int32_t *)dst->op_params;
7571
  ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun);
@@ -8778,6 +8836,16 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
8778
  return false;
8779
  }
8780
  break;
8781
  case GGML_OP_REPEAT:
8782
  case GGML_OP_REPEAT_BACK:
8783
  case GGML_OP_GET_ROWS:
@@ -8870,6 +8938,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
8870
  case GGML_OP_RMS_NORM_BACK:
8871
  case GGML_OP_L2_NORM:
8872
  case GGML_OP_UNARY:
 
8873
  case GGML_OP_DIAG_MASK_INF:
8874
  case GGML_OP_SOFT_MAX:
8875
  case GGML_OP_SOFT_MAX_BACK:
@@ -9013,6 +9082,17 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
9013
  return false;
9014
  }
9015
  break;
9016
  case GGML_OP_DIAG_MASK_INF:
9017
  ggml_vk_diag_mask_inf(ctx, compute_ctx, src0, node, dryrun);
9018
 
@@ -9138,8 +9218,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
9138
  if (!ok) {
9139
  if (node->op == GGML_OP_UNARY) {
9140
  std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast<ggml_unary_op>(node->op_params[0])) << ")" << std::endl;
9141
- }
9142
- else {
 
9143
  std::cerr << __func__ << ": error: op not supported " << node->name << " (" << ggml_op_name(node->op) << ")" << std::endl;
9144
  }
9145
  }
@@ -9218,6 +9299,17 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
9218
  return false;
9219
  }
9220
  break;
9221
  case GGML_OP_MUL_MAT:
9222
  case GGML_OP_MUL_MAT_ID:
9223
  case GGML_OP_FLASH_ATTN_EXT:
@@ -10016,6 +10108,19 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
10016
  return false;
10017
  }
10018
  break;
10019
  case GGML_OP_MUL_MAT:
10020
  case GGML_OP_MUL_MAT_ID:
10021
  {
@@ -10746,6 +10851,12 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
10746
  std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
10747
  GGML_ABORT("fatal error");
10748
  }
10749
  } else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) {
10750
  if (src1 == nullptr) {
10751
  tensor_clone = ggml_dup(ggml_ctx, src_clone[0]);
 
437
  vk_pipeline pipeline_tanh[2];
438
  vk_pipeline pipeline_sigmoid[2];
439
 
440
+ vk_pipeline pipeline_geglu[2];
441
+ vk_pipeline pipeline_reglu[2];
442
+ vk_pipeline pipeline_swiglu[2];
443
+
444
  vk_pipeline pipeline_leaky_relu_f32;
445
  vk_pipeline pipeline_silu_back_f32;
446
  vk_pipeline pipeline_diag_mask_inf_f32;
 
665
  float param2;
666
  };
667
 
668
+ struct vk_op_glu_push_constants {
669
+ uint32_t N;
670
+ uint32_t ne00;
671
+ uint32_t ne20;
672
+ uint32_t mode; // 0: default, 1: swapped, 2: split
673
+ };
674
+
675
  struct vk_op_unary_push_constants {
676
  uint32_t ne;
677
  uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
 
2768
  CREATE_UNARY(sigmoid)
2769
  #undef CREATE_UNARY
2770
 
2771
+ #define CREATE_GLU(name) \
2772
+ ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
2773
+ ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true);
2774
+
2775
+ CREATE_GLU(geglu)
2776
+ CREATE_GLU(reglu)
2777
+ CREATE_GLU(swiglu)
2778
+ #undef CREATE_GLU
2779
+
2780
  ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2781
  ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2782
 
 
6493
  break;
6494
  }
6495
  return nullptr;
6496
+ case GGML_OP_GLU:
6497
+ if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) ||
6498
+ (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) ||
6499
+ (src0->type != dst->type)) {
6500
+ return nullptr;
6501
+ }
6502
+
6503
+ switch (ggml_get_glu_op(dst)) {
6504
+ case GGML_GLU_OP_GEGLU:
6505
+ return ctx->device->pipeline_geglu[dst->type == GGML_TYPE_F16];
6506
+ case GGML_GLU_OP_REGLU:
6507
+ return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16];
6508
+ case GGML_GLU_OP_SWIGLU:
6509
+ return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16];
6510
+ default:
6511
+ break;
6512
+ }
6513
+ return nullptr;
6514
  case GGML_OP_DIAG_MASK_INF:
6515
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6516
  return ctx->device->pipeline_diag_mask_inf_f32;
 
6971
  case GGML_OP_CONCAT:
6972
  case GGML_OP_UPSCALE:
6973
  case GGML_OP_UNARY:
6974
+ case GGML_OP_GLU:
6975
  case GGML_OP_CONV_2D_DW:
6976
  {
6977
  uint32_t ne = ggml_nelements(dst);
 
7012
  }
7013
  }
7014
 
7015
+ if (op == GGML_OP_SOFT_MAX || op == GGML_OP_GLU) {
7016
  // Empty src1 is possible in soft_max, but the shader needs a buffer
7017
  vk_subbuffer subbuf_y;
7018
  if (use_src1) {
 
7605
  ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
7606
  }
7607
 
7608
+ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
7609
+ const bool swapped = (bool)dst->op_params[1];
7610
+ const bool split = src1 != nullptr;
7611
+
7612
+ GGML_ASSERT(ggml_is_contiguous(src0));
7613
+
7614
+ if (!split) {
7615
+ GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]);
7616
+ } else {
7617
+ GGML_ASSERT(src0->ne[0] == src1->ne[0]);
7618
+ GGML_ASSERT(src0->ne[0] == dst->ne[0]);
7619
+ GGML_ASSERT(src0->type == src1->type);
7620
+ }
7621
+
7622
+ const uint32_t mode = split ? 2 : (swapped ? 1 : 0);
7623
+
7624
+ ggml_vk_op_f32<vk_op_glu_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU, { (uint32_t)ggml_nelements(dst), (uint32_t)src0->ne[0], (uint32_t)dst->ne[0], mode }, dryrun);
7625
+ }
7626
+
7627
  static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7628
  int32_t * op_params = (int32_t *)dst->op_params;
7629
  ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun);
 
8836
  return false;
8837
  }
8838
  break;
8839
+ case GGML_OP_GLU:
8840
+ switch (ggml_get_glu_op(node)) {
8841
+ case GGML_GLU_OP_GEGLU:
8842
+ case GGML_GLU_OP_REGLU:
8843
+ case GGML_GLU_OP_SWIGLU:
8844
+ break;
8845
+ default:
8846
+ return false;
8847
+ }
8848
+ break;
8849
  case GGML_OP_REPEAT:
8850
  case GGML_OP_REPEAT_BACK:
8851
  case GGML_OP_GET_ROWS:
 
8938
  case GGML_OP_RMS_NORM_BACK:
8939
  case GGML_OP_L2_NORM:
8940
  case GGML_OP_UNARY:
8941
+ case GGML_OP_GLU:
8942
  case GGML_OP_DIAG_MASK_INF:
8943
  case GGML_OP_SOFT_MAX:
8944
  case GGML_OP_SOFT_MAX_BACK:
 
9082
  return false;
9083
  }
9084
  break;
9085
+ case GGML_OP_GLU:
9086
+ switch (ggml_get_glu_op(node)) {
9087
+ case GGML_GLU_OP_GEGLU:
9088
+ case GGML_GLU_OP_REGLU:
9089
+ case GGML_GLU_OP_SWIGLU:
9090
+ ggml_vk_glu(ctx, compute_ctx, src0, src1, node, dryrun);
9091
+ break;
9092
+ default:
9093
+ return false;
9094
+ }
9095
+ break;
9096
  case GGML_OP_DIAG_MASK_INF:
9097
  ggml_vk_diag_mask_inf(ctx, compute_ctx, src0, node, dryrun);
9098
 
 
9218
  if (!ok) {
9219
  if (node->op == GGML_OP_UNARY) {
9220
  std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast<ggml_unary_op>(node->op_params[0])) << ")" << std::endl;
9221
+ } else if (node->op == GGML_OP_GLU) {
9222
+ std::cerr << __func__ << ": error: op not supported GLU " << node->name << " (" << ggml_glu_op_name(static_cast<ggml_glu_op>(node->op_params[0])) << ")" << std::endl;
9223
+ } else {
9224
  std::cerr << __func__ << ": error: op not supported " << node->name << " (" << ggml_op_name(node->op) << ")" << std::endl;
9225
  }
9226
  }
 
9299
  return false;
9300
  }
9301
  break;
9302
+ case GGML_OP_GLU:
9303
+ switch (ggml_get_glu_op(tensor)) {
9304
+ case GGML_GLU_OP_GEGLU:
9305
+ case GGML_GLU_OP_REGLU:
9306
+ case GGML_GLU_OP_SWIGLU:
9307
+ buf = tensor->buffer;
9308
+ break;
9309
+ default:
9310
+ return false;
9311
+ }
9312
+ break;
9313
  case GGML_OP_MUL_MAT:
9314
  case GGML_OP_MUL_MAT_ID:
9315
  case GGML_OP_FLASH_ATTN_EXT:
 
10108
  return false;
10109
  }
10110
  break;
10111
+ case GGML_OP_GLU:
10112
+ switch (ggml_get_glu_op(op)) {
10113
+ case GGML_GLU_OP_GEGLU:
10114
+ case GGML_GLU_OP_REGLU:
10115
+ case GGML_GLU_OP_SWIGLU:
10116
+ return ggml_is_contiguous(op->src[0]) &&
10117
+ (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
10118
+ (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
10119
+ (op->src[0]->type == op->type);
10120
+ default:
10121
+ return false;
10122
+ }
10123
+ break;
10124
  case GGML_OP_MUL_MAT:
10125
  case GGML_OP_MUL_MAT_ID:
10126
  {
 
10851
  std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
10852
  GGML_ABORT("fatal error");
10853
  }
10854
+ } else if (tensor->op == GGML_OP_GLU) {
10855
+ if (src_clone[1] == nullptr) {
10856
+ tensor_clone = ggml_glu(ggml_ctx, src_clone[0], (ggml_glu_op) tensor->op_params[0], tensor->op_params[1]);
10857
+ } else {
10858
+ tensor_clone = ggml_glu_split(ggml_ctx, src_clone[0], src_clone[1], (ggml_glu_op) tensor->op_params[0]);
10859
+ }
10860
  } else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) {
10861
  if (src1 == nullptr) {
10862
  tensor_clone = ggml_dup(ggml_ctx, src_clone[0]);
ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp ADDED
@@ -0,0 +1,13 @@
1
+ #version 450
2
+
3
+ #include "glu_head.comp"
4
+
5
+ const float GELU_COEF_A = 0.044715f;
6
+ const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
7
+
8
+ float op(float a, float b) {
9
+ const float val = SQRT_2_OVER_PI*a*(1.0f + GELU_COEF_A*a*a);
10
+ return 0.5f*a*(2.0f - 2.0f / (exp(2 * val) + 1)) * b;
11
+ }
12
+
13
+ #include "glu_main.comp"
ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp ADDED
@@ -0,0 +1,15 @@
1
+ #extension GL_EXT_shader_16bit_storage : require
2
+
3
+ layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
4
+
5
+ layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
6
+ layout (binding = 1) readonly buffer B {A_TYPE data_b[];};
7
+ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
8
+
9
+ layout (push_constant) uniform parameter
10
+ {
11
+ uint N;
12
+ uint ne00;
13
+ uint ne20;
14
+ uint mode;
15
+ } p;
ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp ADDED
@@ -0,0 +1,29 @@
1
+ void main() {
2
+ const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
3
+
4
+ if (i >= p.N) {
5
+ return;
6
+ }
7
+
8
+ const uint row = i / p.ne20;
9
+ const uint col = i - row * p.ne20;
10
+
11
+ if (p.mode == 0) {
12
+ // Default
13
+ const uint offset = p.ne00 / 2;
14
+ const uint idx = row * p.ne00 + col;
15
+
16
+ data_d[row * offset + col] = D_TYPE(op(float(data_a[idx]), float(data_a[idx + offset])));
17
+ } else if (p.mode == 1) {
18
+ // Swapped
19
+ const uint offset = p.ne00 / 2;
20
+ const uint idx = row * p.ne00 + col;
21
+
22
+ data_d[row * offset + col] = D_TYPE(op(float(data_a[idx + offset]), float(data_a[idx])));
23
+ } else {
24
+ // Split
25
+ const uint idx = row * p.ne00 + col;
26
+
27
+ data_d[idx] = D_TYPE(op(float(data_a[idx]), float(data_b[idx])));
28
+ }
29
+ }
ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp ADDED
@@ -0,0 +1,9 @@
1
+ #version 450
2
+
3
+ #include "glu_head.comp"
4
+
5
+ float op(float a, float b) {
6
+ return max(a, 0.0f) * b;
7
+ }
8
+
9
+ #include "glu_main.comp"
ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp ADDED
@@ -0,0 +1,9 @@
1
+ #version 450
2
+
3
+ #include "glu_head.comp"
4
+
5
+ float op(float a, float b) {
6
+ return a / (1.0f + exp(-a)) * b;
7
+ }
8
+
9
+ #include "glu_main.comp"
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp CHANGED
@@ -585,6 +585,13 @@ void process_shaders() {
585
  string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
586
  string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
587
 
588
  string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
589
  string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
590
 
 
585
  string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
586
  string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
587
 
588
+ string_to_spv("geglu_f16", "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
589
+ string_to_spv("geglu_f32", "geglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
590
+ string_to_spv("reglu_f16", "reglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
591
+ string_to_spv("reglu_f32", "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
592
+ string_to_spv("swiglu_f16", "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
593
+ string_to_spv("swiglu_f32", "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
594
+
595
  string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
596
  string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
597
 
ggml/src/ggml.c CHANGED
@@ -982,9 +982,11 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
982
  "CROSS_ENTROPY_LOSS",
983
  "CROSS_ENTROPY_LOSS_BACK",
984
  "OPT_STEP_ADAMW",
 
 
985
  };
986
 
987
- static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");
988
 
989
  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
990
  "none",
@@ -1079,9 +1081,11 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
1079
  "cross_entropy_loss(x,y)",
1080
  "cross_entropy_loss_back(x,y)",
1081
  "adamw(x)",
 
 
1082
  };
1083
 
1084
- static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");
1085
 
1086
  static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
1087
 
@@ -1107,6 +1111,15 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
1107
  static_assert(GGML_UNARY_OP_COUNT == 15, "GGML_UNARY_OP_COUNT != 15");
1108
 
1109
 
1110
  static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
1111
  static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
1112
 
@@ -1209,11 +1222,19 @@ const char * ggml_unary_op_name(enum ggml_unary_op op) {
1209
  return GGML_UNARY_OP_NAME[op];
1210
  }
1211
 
 
 
 
 
1212
  const char * ggml_op_desc(const struct ggml_tensor * t) {
1213
  if (t->op == GGML_OP_UNARY) {
1214
  enum ggml_unary_op uop = ggml_get_unary_op(t);
1215
  return ggml_unary_op_name(uop);
1216
  }
 
 
 
 
1217
  return ggml_op_name(t->op);
1218
  }
1219
 
@@ -1730,6 +1751,11 @@ enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor) {
1730
  return (enum ggml_unary_op) ggml_get_op_params_i32(tensor, 0);
1731
  }
1732
 
 
 
 
 
 
1733
  const char * ggml_get_name(const struct ggml_tensor * tensor) {
1734
  return tensor->name;
1735
  }
@@ -2609,6 +2635,114 @@ struct ggml_tensor * ggml_exp_inplace(
2609
  return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_EXP);
2610
  }
2611
 
2612
  // ggml_norm
2613
 
2614
  static struct ggml_tensor * ggml_norm_impl(
 
982
  "CROSS_ENTROPY_LOSS",
983
  "CROSS_ENTROPY_LOSS_BACK",
984
  "OPT_STEP_ADAMW",
985
+
986
+ "GLU",
987
  };
988
 
989
+ static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85");
990
 
991
  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
992
  "none",
 
1081
  "cross_entropy_loss(x,y)",
1082
  "cross_entropy_loss_back(x,y)",
1083
  "adamw(x)",
1084
+
1085
+ "glu(x)",
1086
  };
1087
 
1088
+ static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85");
1089
 
1090
  static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
1091
 
 
1111
  static_assert(GGML_UNARY_OP_COUNT == 15, "GGML_UNARY_OP_COUNT != 15");
1112
 
1113
 
1114
+ static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = {
1115
+ "REGLU",
1116
+ "GEGLU",
1117
+ "SWIGLU",
1118
+ };
1119
+
1120
+ static_assert(GGML_GLU_OP_COUNT == 3, "GGML_GLU_OP_COUNT != 3");
1121
+
1122
+
1123
  static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
1124
  static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
1125
 
 
1222
  return GGML_UNARY_OP_NAME[op];
1223
  }
1224
 
1225
+ const char * ggml_glu_op_name(enum ggml_glu_op op) {
1226
+ return GGML_GLU_OP_NAME[op];
1227
+ }
1228
+
1229
  const char * ggml_op_desc(const struct ggml_tensor * t) {
1230
  if (t->op == GGML_OP_UNARY) {
1231
  enum ggml_unary_op uop = ggml_get_unary_op(t);
1232
  return ggml_unary_op_name(uop);
1233
  }
1234
+ if (t->op == GGML_OP_GLU) {
1235
+ enum ggml_glu_op gop = ggml_get_glu_op(t);
1236
+ return ggml_glu_op_name(gop);
1237
+ }
1238
  return ggml_op_name(t->op);
1239
  }
1240
 
 
1751
  return (enum ggml_unary_op) ggml_get_op_params_i32(tensor, 0);
1752
  }
1753
 
1754
+ enum ggml_glu_op ggml_get_glu_op(const struct ggml_tensor * tensor) {
1755
+ GGML_ASSERT(tensor->op == GGML_OP_GLU);
1756
+ return (enum ggml_glu_op) ggml_get_op_params_i32(tensor, 0);
1757
+ }
1758
+
1759
  const char * ggml_get_name(const struct ggml_tensor * tensor) {
1760
  return tensor->name;
1761
  }
 
2635
  return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_EXP);
2636
  }
2637
 
2638
+ // ggml_glu
2639
+
2640
+ static struct ggml_tensor * ggml_glu_impl(
2641
+ struct ggml_context * ctx,
2642
+ struct ggml_tensor * a,
2643
+ struct ggml_tensor * b,
2644
+ enum ggml_glu_op op,
2645
+ bool swapped) {
2646
+ GGML_ASSERT(ggml_is_contiguous_1(a));
2647
+
2648
+ if (b) {
2649
+ GGML_ASSERT(ggml_is_contiguous_1(b));
2650
+ GGML_ASSERT(ggml_are_same_shape(a, b));
2651
+ GGML_ASSERT(a->type == b->type);
2652
+ }
2653
+
2654
+ int64_t ne[GGML_MAX_DIMS] = { a->ne[0] / 2 }; for (int i = 1; i < GGML_MAX_DIMS; i++) ne[i] = a->ne[i];
2655
+ struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, b ? a->ne : ne, NULL, 0);
2656
+
2657
+ ggml_set_op_params_i32(result, 0, (int32_t) op);
2658
+ ggml_set_op_params_i32(result, 1, (int32_t) swapped);
2659
+
2660
+ result->op = GGML_OP_GLU;
2661
+ result->src[0] = a;
2662
+ result->src[1] = b;
2663
+
2664
+ return result;
2665
+ }
2666
+
2667
+ struct ggml_tensor * ggml_glu(
2668
+ struct ggml_context * ctx,
2669
+ struct ggml_tensor * a,
2670
+ enum ggml_glu_op op,
2671
+ bool swapped) {
2672
+ return ggml_glu_impl(ctx, a, NULL, op, swapped);
2673
+ }
2674
+
2675
+ struct ggml_tensor * ggml_glu_split(
2676
+ struct ggml_context * ctx,
2677
+ struct ggml_tensor * a,
2678
+ struct ggml_tensor * b,
2679
+ enum ggml_glu_op op) {
2680
+ return ggml_glu_impl(ctx, a, b, op, false);
2681
+ }
2682
+
2683
+ // ggml_reglu
2684
+
2685
+ struct ggml_tensor * ggml_reglu(
2686
+ struct ggml_context * ctx,
2687
+ struct ggml_tensor * a) {
2688
+ return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_REGLU, false);
2689
+ }
2690
+
2691
+ struct ggml_tensor * ggml_reglu_swapped(
2692
+ struct ggml_context * ctx,
2693
+ struct ggml_tensor * a) {
2694
+ return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_REGLU, true);
2695
+ }
2696
+
2697
+ struct ggml_tensor * ggml_reglu_split(
2698
+ struct ggml_context * ctx,
2699
+ struct ggml_tensor * a,
2700
+ struct ggml_tensor * b) {
2701
+ return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_REGLU, false);
2702
+ }
2703
+
2704
+ // ggml_geglu
2705
+
2706
+ struct ggml_tensor * ggml_geglu(
2707
+ struct ggml_context * ctx,
2708
+ struct ggml_tensor * a) {
2709
+ return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU, false);
2710
+ }
2711
+
2712
+ struct ggml_tensor * ggml_geglu_swapped(
2713
+ struct ggml_context * ctx,
2714
+ struct ggml_tensor * a) {
2715
+ return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU, true);
2716
+ }
2717
+
2718
+ struct ggml_tensor * ggml_geglu_split(
2719
+ struct ggml_context * ctx,
2720
+ struct ggml_tensor * a,
2721
+ struct ggml_tensor * b) {
2722
+ return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU, false);
2723
+ }
2724
+
2725
+ // ggml_swiglu
2726
+
2727
+ struct ggml_tensor * ggml_swiglu(
2728
+ struct ggml_context * ctx,
2729
+ struct ggml_tensor * a) {
2730
+ return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_SWIGLU, false);
2731
+ }
2732
+
2733
+ struct ggml_tensor * ggml_swiglu_swapped(
2734
+ struct ggml_context * ctx,
2735
+ struct ggml_tensor * a) {
2736
+ return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_SWIGLU, true);
2737
+ }
2738
+
2739
+ struct ggml_tensor * ggml_swiglu_split(
2740
+ struct ggml_context * ctx,
2741
+ struct ggml_tensor * a,
2742
+ struct ggml_tensor * b) {
2743
+ return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_SWIGLU, false);
2744
+ }
2745
+
2746
  // ggml_norm
2747
 
2748
  static struct ggml_tensor * ggml_norm_impl(
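To close, a minimal usage sketch of the new public GLU entry points defined above (the shapes in the comments are illustrative assumptions, not taken from the commit):

// Fused layout: one tensor holding [x | gate] along dim 0 -> ggml_swiglu halves ne[0].
// Split layout: separate up/gate tensors of identical shape -> ggml_glu_split keeps ne[0].
#include "ggml.h"

static struct ggml_tensor * ffn_gated(struct ggml_context * ctx,
                                      struct ggml_tensor * up_gate,   // e.g. ne[0] == 2*n_ff, or NULL
                                      struct ggml_tensor * up,        // e.g. ne[0] == n_ff
                                      struct ggml_tensor * gate) {    // same shape/type as up
    if (up_gate) {
        return ggml_swiglu(ctx, up_gate);                       // result ne[0] == n_ff
    }
    return ggml_glu_split(ctx, up, gate, GGML_GLU_OP_SWIGLU);   // result ne[0] == n_ff
}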