lhez committed on
Commit
d70ff9f
·
1 Parent(s): 68eb27a

opencl : add GEGLU, REGLU, SWIGLU (llama/14456)

Browse files
ggml/src/ggml-opencl/CMakeLists.txt CHANGED
@@ -65,6 +65,7 @@ set(GGML_OPENCL_KERNELS
65
  gemv_noshuffle_general
66
  gemv_noshuffle
67
  get_rows
 
68
  group_norm
69
  im2col_f32
70
  im2col_f16
 
65
  gemv_noshuffle_general
66
  gemv_noshuffle
67
  get_rows
68
+ glu
69
  group_norm
70
  im2col_f32
71
  im2col_f16
ggml/src/ggml-opencl/ggml-opencl.cpp CHANGED
@@ -351,6 +351,7 @@ struct ggml_backend_opencl_context {
351
  cl_program program_gemv_noshuffle_general;
352
  cl_program program_gemv_noshuffle;
353
  cl_program program_get_rows;
 
354
  cl_program program_im2col_f16;
355
  cl_program program_im2col_f32;
356
  cl_program program_mul_mat_Ab_Bi_8x4;
@@ -401,6 +402,8 @@ struct ggml_backend_opencl_context {
401
  cl_kernel kernel_relu;
402
  cl_kernel kernel_sigmoid_f32, kernel_sigmoid_f16;
403
  cl_kernel kernel_clamp;
 
 
404
  cl_kernel kernel_norm;
405
  cl_kernel kernel_rms_norm;
406
  cl_kernel kernel_group_norm;
@@ -738,6 +741,27 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
738
  GGML_LOG_CONT(".");
739
  }
740
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
741
  // get_rows
742
  {
743
  #ifdef GGML_OPENCL_EMBED_KERNELS
@@ -2242,6 +2266,15 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
2242
  default:
2243
  return false;
2244
  }
 
 
 
 
 
 
 
 
 
2245
  case GGML_OP_CLAMP:
2246
  return op->src[0]->type == GGML_TYPE_F32;
2247
  case GGML_OP_SOFT_MAX:
@@ -6143,6 +6176,91 @@ static void ggml_cl_sum_rows(ggml_backend_t backend, const ggml_tensor * src0, c
6143
  backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
6144
  }
6145
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6146
  //------------------------------------------------------------------------------
6147
  // Op offloading
6148
  //------------------------------------------------------------------------------
@@ -6244,6 +6362,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
6244
  default:
6245
  return false;
6246
  } break;
 
 
 
 
 
 
6247
  case GGML_OP_CLAMP:
6248
  if (!any_on_device) {
6249
  return false;
 
351
  cl_program program_gemv_noshuffle_general;
352
  cl_program program_gemv_noshuffle;
353
  cl_program program_get_rows;
354
+ cl_program program_glu;
355
  cl_program program_im2col_f16;
356
  cl_program program_im2col_f32;
357
  cl_program program_mul_mat_Ab_Bi_8x4;
 
402
  cl_kernel kernel_relu;
403
  cl_kernel kernel_sigmoid_f32, kernel_sigmoid_f16;
404
  cl_kernel kernel_clamp;
405
+ cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu,
406
+ kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16;
407
  cl_kernel kernel_norm;
408
  cl_kernel kernel_rms_norm;
409
  cl_kernel kernel_group_norm;
 
741
  GGML_LOG_CONT(".");
742
  }
743
 
744
+ // glu
745
+ {
746
+ #ifdef GGML_OPENCL_EMBED_KERNELS
747
+ const std::string kernel_src {
748
+ #include "glu.cl.h"
749
+ };
750
+ #else
751
+ const std::string kernel_src = read_file("glu.cl");
752
+ #endif
753
+ backend_ctx->program_glu =
754
+ build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
755
+
756
+ CL_CHECK((backend_ctx->kernel_geglu = clCreateKernel(backend_ctx->program_glu, "kernel_geglu", &err), err));
757
+ CL_CHECK((backend_ctx->kernel_reglu = clCreateKernel(backend_ctx->program_glu, "kernel_reglu", &err), err));
758
+ CL_CHECK((backend_ctx->kernel_swiglu = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu", &err), err));
759
+ CL_CHECK((backend_ctx->kernel_geglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_f16", &err), err));
760
+ CL_CHECK((backend_ctx->kernel_reglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_reglu_f16", &err), err));
761
+ CL_CHECK((backend_ctx->kernel_swiglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu_f16", &err), err));
762
+ GGML_LOG_CONT(".");
763
+ }
764
+
765
  // get_rows
766
  {
767
  #ifdef GGML_OPENCL_EMBED_KERNELS
 
2266
  default:
2267
  return false;
2268
  }
2269
+ case GGML_OP_GLU:
2270
+ switch (ggml_get_glu_op(op)) {
2271
+ case GGML_GLU_OP_GEGLU:
2272
+ case GGML_GLU_OP_REGLU:
2273
+ case GGML_GLU_OP_SWIGLU:
2274
+ return ggml_is_contiguous_1(op->src[0]) && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
2275
+ default:
2276
+ return false;
2277
+ }
2278
  case GGML_OP_CLAMP:
2279
  return op->src[0]->type == GGML_TYPE_F32;
2280
  case GGML_OP_SOFT_MAX:
 
6176
  backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
6177
  }
6178
 
6179
+ static void ggml_cl_glu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
6180
+ GGML_ASSERT(src0);
6181
+ GGML_ASSERT(src0->extra);
6182
+ GGML_ASSERT(dst);
6183
+ GGML_ASSERT(dst->extra);
6184
+
6185
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
6186
+
6187
+ if (src1) {
6188
+ GGML_ASSERT(src1);
6189
+ GGML_ASSERT(src1->extra);
6190
+ GGML_ASSERT(ggml_are_same_shape(src0, src1));
6191
+ }
6192
+
6193
+ ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
6194
+
6195
+ cl_kernel kernel;
6196
+ switch (ggml_get_glu_op(dst)) {
6197
+ case GGML_GLU_OP_GEGLU:
6198
+ if (dst->type == GGML_TYPE_F32) {
6199
+ kernel = backend_ctx->kernel_geglu;
6200
+ } else {
6201
+ kernel = backend_ctx->kernel_geglu_f16;
6202
+ }
6203
+ break;
6204
+ case GGML_GLU_OP_REGLU:
6205
+ if (dst->type == GGML_TYPE_F32) {
6206
+ kernel = backend_ctx->kernel_reglu;
6207
+ } else {
6208
+ kernel = backend_ctx->kernel_reglu_f16;
6209
+ }
6210
+ break;
6211
+ case GGML_GLU_OP_SWIGLU:
6212
+ if (dst->type == GGML_TYPE_F32) {
6213
+ kernel = backend_ctx->kernel_swiglu;
6214
+ } else {
6215
+ kernel = backend_ctx->kernel_swiglu_f16;
6216
+ }
6217
+ break;
6218
+ default:
6219
+ GGML_ABORT("Unsupported glu op");
6220
+ }
6221
+
6222
+ ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
6223
+ ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
6224
+
6225
+ ggml_tensor_extra_cl * extra1 = src1 ? (ggml_tensor_extra_cl *)src1->extra : nullptr;
6226
+
6227
+ cl_ulong offset0 = extra0->offset + src0->view_offs;
6228
+ cl_ulong offsetd = extrad->offset + dst->view_offs;
6229
+
6230
+ cl_ulong offset1 = extra1 ? extra1->offset + src1->view_offs : offset0;
6231
+
6232
+ const int ne0 = dst->ne[0];
6233
+
6234
+ const cl_ulong nb01 = src0->nb[1];
6235
+ const cl_ulong nb11 = src1 ? src1->nb[1] : nb01;
6236
+
6237
+ const cl_ulong nb1 = dst->nb[1];
6238
+
6239
+ const int swp = ((const int32_t *) dst->op_params)[1];
6240
+ const int ne00_off = src1 ? 0 : (swp ? ne0 : 0);
6241
+ const int ne10_off = src1 ? 0 : (swp ? 0 : ne0);
6242
+
6243
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
6244
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
6245
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), src1 ? &extra1->data_device : &extra0->data_device));
6246
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
6247
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
6248
+ CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
6249
+ CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb01));
6250
+ CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb11));
6251
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne0));
6252
+ CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb1));
6253
+ CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne00_off));
6254
+ CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne10_off));
6255
+
6256
+ const size_t nrows = ggml_nrows(src0);
6257
+ size_t nth = 512;
6258
+ size_t global_work_size[] = {nrows*nth, 1, 1};
6259
+ size_t local_work_size[] = {nth, 1, 1};
6260
+
6261
+ backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
6262
+ }
6263
+
6264
  //------------------------------------------------------------------------------
6265
  // Op offloading
6266
  //------------------------------------------------------------------------------
 
6362
  default:
6363
  return false;
6364
  } break;
6365
+ case GGML_OP_GLU:
6366
+ if (!any_on_device) {
6367
+ return false;
6368
+ }
6369
+ func = ggml_cl_glu;
6370
+ break;
6371
  case GGML_OP_CLAMP:
6372
  if (!any_on_device) {
6373
  return false;
ggml/src/ggml-opencl/kernels/glu.cl ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma OPENCL EXTENSION cl_khr_fp16 : enable
2
+
3
+ #define GELU_COEF_A 0.044715f
4
+ #define SQRT_2_OVER_PI 0.79788456080286535587989211986876f
5
+
6
+ //------------------------------------------------------------------------------
7
+ // geglu
8
+ //------------------------------------------------------------------------------
9
// GEGLU on float rows: dst[i] = gelu(a[i]) * b[i], with the tanh form of GELU.
//
// src0/src1/dst       - buffer base pointers; offset0/offset1/offsetd are byte
//                       offsets added to them before use.
// nb01/nb11/nb1       - byte stride between consecutive rows of src0/src1/dst.
// ne0                 - number of elements written per row.
// ne00_off/ne10_off   - element offset of the a / b operand inside its row.
//
// Each work-group processes the row given by its group id; its work-items
// step through the row by the work-group size.
kernel void kernel_geglu(
        global char * src0,
        ulong         offset0,
        global char * src1,
        ulong         offset1,
        global char * dst,
        ulong         offsetd,
        ulong         nb01,
        ulong         nb11,
        int           ne0,
        ulong         nb1,
        int           ne00_off,
        int           ne10_off
) {
    src0 += offset0;
    src1 += offset1;
    dst  += offsetd;

    const ulong row = get_group_id(0);

    global float * a_row   = (global float *)(src0 + row*nb01) + ne00_off;
    global float * b_row   = (global float *)(src1 + row*nb11) + ne10_off;
    global float * out_row = (global float *)(dst  + row*nb1);

    const int step = get_local_size(0);

    for (int col = get_local_id(0); col < ne0; col += step) {
        const float a = a_row[col];
        const float b = b_row[col];

        // Argument of tanh in the GELU approximation.
        const float inner = SQRT_2_OVER_PI*a*(1.0f + GELU_COEF_A*a*a);
        const float act   = 0.5f*a*(1.0f + tanh(inner));

        out_row[col] = act*b;
    }
}
40
+
41
// GEGLU on half rows: dst[i] = gelu(a[i]) * b[i], with the tanh form of GELU.
// The GELU expression is evaluated in float (the float literals promote the
// half operand), rounded to half, and then multiplied by b in half.
//
// src0/src1/dst       - buffer base pointers; offset0/offset1/offsetd are byte
//                       offsets added to them before use.
// nb01/nb11/nb1       - byte stride between consecutive rows of src0/src1/dst.
// ne0                 - number of elements written per row.
// ne00_off/ne10_off   - element offset of the a / b operand inside its row.
//
// Each work-group processes the row given by its group id; its work-items
// step through the row by the work-group size.
kernel void kernel_geglu_f16(
        global char * src0,
        ulong         offset0,
        global char * src1,
        ulong         offset1,
        global char * dst,
        ulong         offsetd,
        ulong         nb01,
        ulong         nb11,
        int           ne0,
        ulong         nb1,
        int           ne00_off,
        int           ne10_off
) {
    src0 += offset0;
    src1 += offset1;
    dst  += offsetd;

    const ulong row = get_group_id(0);

    global half * a_row   = (global half *)(src0 + row*nb01) + ne00_off;
    global half * b_row   = (global half *)(src1 + row*nb11) + ne10_off;
    global half * out_row = (global half *)(dst  + row*nb1);

    const int step = get_local_size(0);

    for (int col = get_local_id(0); col < ne0; col += step) {
        const half a = a_row[col];
        const half b = b_row[col];

        // Argument of tanh in the GELU approximation (float arithmetic).
        const float inner = SQRT_2_OVER_PI*a*(1.0f + GELU_COEF_A*a*a);
        const half  act   = 0.5f*a*(1.0f + tanh(inner));

        out_row[col] = act*b;
    }
}
72
+
73
+ //------------------------------------------------------------------------------
74
+ // reglu
75
+ //------------------------------------------------------------------------------
76
// REGLU on float rows: dst[i] = a[i] * b[i] * (a[i] > 0).
//
// src0/src1/dst       - buffer base pointers; offset0/offset1/offsetd are byte
//                       offsets added to them before use.
// nb01/nb11/nb1       - byte stride between consecutive rows of src0/src1/dst.
// ne0                 - number of elements written per row.
// ne00_off/ne10_off   - element offset of the a / b operand inside its row.
//
// Each work-group processes the row given by its group id; its work-items
// step through the row by the work-group size.
kernel void kernel_reglu(
        global char * src0,
        ulong         offset0,
        global char * src1,
        ulong         offset1,
        global char * dst,
        ulong         offsetd,
        ulong         nb01,
        ulong         nb11,
        int           ne0,
        ulong         nb1,
        int           ne00_off,
        int           ne10_off
) {
    src0 += offset0;
    src1 += offset1;
    dst  += offsetd;

    const ulong row = get_group_id(0);

    global float * a_row   = (global float *)(src0 + row*nb01) + ne00_off;
    global float * b_row   = (global float *)(src1 + row*nb11) + ne10_off;
    global float * out_row = (global float *)(dst  + row*nb1);

    const int step = get_local_size(0);

    for (int col = get_local_id(0); col < ne0; col += step) {
        const float a = a_row[col];
        const float b = b_row[col];

        // The comparison yields 0 or 1, so non-positive a multiplies the product by zero.
        out_row[col] = a*b*(a > 0.0f);
    }
}
105
+
106
// REGLU on half rows: dst[i] = a[i] * b[i] * (a[i] > 0).
//
// src0/src1/dst       - buffer base pointers; offset0/offset1/offsetd are byte
//                       offsets added to them before use.
// nb01/nb11/nb1       - byte stride between consecutive rows of src0/src1/dst.
// ne0                 - number of elements written per row.
// ne00_off/ne10_off   - element offset of the a / b operand inside its row.
//
// Each work-group processes the row given by its group id; its work-items
// step through the row by the work-group size.
kernel void kernel_reglu_f16(
        global char * src0,
        ulong         offset0,
        global char * src1,
        ulong         offset1,
        global char * dst,
        ulong         offsetd,
        ulong         nb01,
        ulong         nb11,
        int           ne0,
        ulong         nb1,
        int           ne00_off,
        int           ne10_off
) {
    src0 += offset0;
    src1 += offset1;
    dst  += offsetd;

    const ulong row = get_group_id(0);

    global half * a_row   = (global half *)(src0 + row*nb01) + ne00_off;
    global half * b_row   = (global half *)(src1 + row*nb11) + ne10_off;
    global half * out_row = (global half *)(dst  + row*nb1);

    const int step = get_local_size(0);

    for (int col = get_local_id(0); col < ne0; col += step) {
        const half a = a_row[col];
        const half b = b_row[col];

        // The comparison yields 0 or 1, so non-positive a multiplies the product by zero.
        out_row[col] = a*b*(a > 0.0f);
    }
}
135
+
136
+ //------------------------------------------------------------------------------
137
+ // swiglu
138
+ //------------------------------------------------------------------------------
139
// SWIGLU on float rows: dst[i] = silu(a[i]) * b[i], silu(x) = x / (1 + exp(-x)).
//
// src0/src1/dst       - buffer base pointers; offset0/offset1/offsetd are byte
//                       offsets added to them before use.
// nb01/nb11/nb1       - byte stride between consecutive rows of src0/src1/dst.
// ne0                 - number of elements written per row.
// ne00_off/ne10_off   - element offset of the a / b operand inside its row.
//
// Each work-group processes the row given by its group id; its work-items
// step through the row by the work-group size.
kernel void kernel_swiglu(
        global char * src0,
        ulong         offset0,
        global char * src1,
        ulong         offset1,
        global char * dst,
        ulong         offsetd,
        ulong         nb01,
        ulong         nb11,
        int           ne0,
        ulong         nb1,
        int           ne00_off,
        int           ne10_off
) {
    src0 += offset0;
    src1 += offset1;
    dst  += offsetd;

    const ulong row = get_group_id(0);

    global float * a_row   = (global float *)(src0 + row*nb01) + ne00_off;
    global float * b_row   = (global float *)(src1 + row*nb11) + ne10_off;
    global float * out_row = (global float *)(dst  + row*nb1);

    const int step = get_local_size(0);

    for (int col = get_local_id(0); col < ne0; col += step) {
        const float a = a_row[col];
        const float b = b_row[col];

        const float act = a / (1.0f + exp(-a));

        out_row[col] = act*b;
    }
}
170
+
171
// SWIGLU on half rows: dst[i] = silu(a[i]) * b[i], silu(x) = x / (1 + exp(-x)).
//
// Inputs are widened to float before any arithmetic and the result is rounded
// to half once, on the store. Evaluating exp(-a) directly on a half value
// overflows half's range for sufficiently negative a, which forces the output
// to zero even when the true result is representable; it also rounds the
// activation to half before the final multiply.
//
// src0/src1/dst       - buffer base pointers; offset0/offset1/offsetd are byte
//                       offsets added to them before use.
// nb01/nb11/nb1       - byte stride between consecutive rows of src0/src1/dst.
// ne0                 - number of elements written per row.
// ne00_off/ne10_off   - element offset of the a / b operand inside its row.
//
// Each work-group processes the row given by its group id; its work-items
// step through the row by the work-group size.
kernel void kernel_swiglu_f16(
        global char * src0,
        ulong         offset0,
        global char * src1,
        ulong         offset1,
        global char * dst,
        ulong         offsetd,
        ulong         nb01,
        ulong         nb11,
        int           ne0,
        ulong         nb1,
        int           ne00_off,
        int           ne10_off
) {
    src0 = (global char*)((global char*)src0 + offset0);
    src1 = (global char*)((global char*)src1 + offset1);
    dst  = (global char*)((global char*)dst  + offsetd);

    global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
    global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
    global half * dst_row  = (global half *) ((global char *) dst  + get_group_id(0)*nb1);

    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
        // half -> float: keep exp() and the division in float range/precision.
        const float x0 = src0_row[i0];
        const float x1 = src1_row[i0];

        const float silu = x0 / (1.0f + exp(-x0));

        // float -> half: single rounding of the final product.
        dst_row[i0] = silu*x1;
    }
}