Spaces:
Running
Running
lhez
commited on
Commit
·
d70ff9f
1
Parent(s):
68eb27a
opencl : add GEGLU, REGLU, SWIGLU (llama/14456)
Browse files
ggml/src/ggml-opencl/CMakeLists.txt
CHANGED
|
@@ -65,6 +65,7 @@ set(GGML_OPENCL_KERNELS
|
|
| 65 |
gemv_noshuffle_general
|
| 66 |
gemv_noshuffle
|
| 67 |
get_rows
|
|
|
|
| 68 |
group_norm
|
| 69 |
im2col_f32
|
| 70 |
im2col_f16
|
|
|
|
| 65 |
gemv_noshuffle_general
|
| 66 |
gemv_noshuffle
|
| 67 |
get_rows
|
| 68 |
+
glu
|
| 69 |
group_norm
|
| 70 |
im2col_f32
|
| 71 |
im2col_f16
|
ggml/src/ggml-opencl/ggml-opencl.cpp
CHANGED
|
@@ -351,6 +351,7 @@ struct ggml_backend_opencl_context {
|
|
| 351 |
cl_program program_gemv_noshuffle_general;
|
| 352 |
cl_program program_gemv_noshuffle;
|
| 353 |
cl_program program_get_rows;
|
|
|
|
| 354 |
cl_program program_im2col_f16;
|
| 355 |
cl_program program_im2col_f32;
|
| 356 |
cl_program program_mul_mat_Ab_Bi_8x4;
|
|
@@ -401,6 +402,8 @@ struct ggml_backend_opencl_context {
|
|
| 401 |
cl_kernel kernel_relu;
|
| 402 |
cl_kernel kernel_sigmoid_f32, kernel_sigmoid_f16;
|
| 403 |
cl_kernel kernel_clamp;
|
|
|
|
|
|
|
| 404 |
cl_kernel kernel_norm;
|
| 405 |
cl_kernel kernel_rms_norm;
|
| 406 |
cl_kernel kernel_group_norm;
|
|
@@ -738,6 +741,27 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
|
| 738 |
GGML_LOG_CONT(".");
|
| 739 |
}
|
| 740 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 741 |
// get_rows
|
| 742 |
{
|
| 743 |
#ifdef GGML_OPENCL_EMBED_KERNELS
|
|
@@ -2242,6 +2266,15 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
|
|
| 2242 |
default:
|
| 2243 |
return false;
|
| 2244 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2245 |
case GGML_OP_CLAMP:
|
| 2246 |
return op->src[0]->type == GGML_TYPE_F32;
|
| 2247 |
case GGML_OP_SOFT_MAX:
|
|
@@ -6143,6 +6176,91 @@ static void ggml_cl_sum_rows(ggml_backend_t backend, const ggml_tensor * src0, c
|
|
| 6143 |
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
|
| 6144 |
}
|
| 6145 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6146 |
//------------------------------------------------------------------------------
|
| 6147 |
// Op offloading
|
| 6148 |
//------------------------------------------------------------------------------
|
|
@@ -6244,6 +6362,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
|
|
| 6244 |
default:
|
| 6245 |
return false;
|
| 6246 |
} break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6247 |
case GGML_OP_CLAMP:
|
| 6248 |
if (!any_on_device) {
|
| 6249 |
return false;
|
|
|
|
| 351 |
cl_program program_gemv_noshuffle_general;
|
| 352 |
cl_program program_gemv_noshuffle;
|
| 353 |
cl_program program_get_rows;
|
| 354 |
+
cl_program program_glu;
|
| 355 |
cl_program program_im2col_f16;
|
| 356 |
cl_program program_im2col_f32;
|
| 357 |
cl_program program_mul_mat_Ab_Bi_8x4;
|
|
|
|
| 402 |
cl_kernel kernel_relu;
|
| 403 |
cl_kernel kernel_sigmoid_f32, kernel_sigmoid_f16;
|
| 404 |
cl_kernel kernel_clamp;
|
| 405 |
+
cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu,
|
| 406 |
+
kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16;
|
| 407 |
cl_kernel kernel_norm;
|
| 408 |
cl_kernel kernel_rms_norm;
|
| 409 |
cl_kernel kernel_group_norm;
|
|
|
|
| 741 |
GGML_LOG_CONT(".");
|
| 742 |
}
|
| 743 |
|
| 744 |
+
// glu
|
| 745 |
+
{
|
| 746 |
+
#ifdef GGML_OPENCL_EMBED_KERNELS
|
| 747 |
+
const std::string kernel_src {
|
| 748 |
+
#include "glu.cl.h"
|
| 749 |
+
};
|
| 750 |
+
#else
|
| 751 |
+
const std::string kernel_src = read_file("glu.cl");
|
| 752 |
+
#endif
|
| 753 |
+
backend_ctx->program_glu =
|
| 754 |
+
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
|
| 755 |
+
|
| 756 |
+
CL_CHECK((backend_ctx->kernel_geglu = clCreateKernel(backend_ctx->program_glu, "kernel_geglu", &err), err));
|
| 757 |
+
CL_CHECK((backend_ctx->kernel_reglu = clCreateKernel(backend_ctx->program_glu, "kernel_reglu", &err), err));
|
| 758 |
+
CL_CHECK((backend_ctx->kernel_swiglu = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu", &err), err));
|
| 759 |
+
CL_CHECK((backend_ctx->kernel_geglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_f16", &err), err));
|
| 760 |
+
CL_CHECK((backend_ctx->kernel_reglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_reglu_f16", &err), err));
|
| 761 |
+
CL_CHECK((backend_ctx->kernel_swiglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu_f16", &err), err));
|
| 762 |
+
GGML_LOG_CONT(".");
|
| 763 |
+
}
|
| 764 |
+
|
| 765 |
// get_rows
|
| 766 |
{
|
| 767 |
#ifdef GGML_OPENCL_EMBED_KERNELS
|
|
|
|
| 2266 |
default:
|
| 2267 |
return false;
|
| 2268 |
}
|
| 2269 |
+
case GGML_OP_GLU:
|
| 2270 |
+
switch (ggml_get_glu_op(op)) {
|
| 2271 |
+
case GGML_GLU_OP_GEGLU:
|
| 2272 |
+
case GGML_GLU_OP_REGLU:
|
| 2273 |
+
case GGML_GLU_OP_SWIGLU:
|
| 2274 |
+
return ggml_is_contiguous_1(op->src[0]) && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
|
| 2275 |
+
default:
|
| 2276 |
+
return false;
|
| 2277 |
+
}
|
| 2278 |
case GGML_OP_CLAMP:
|
| 2279 |
return op->src[0]->type == GGML_TYPE_F32;
|
| 2280 |
case GGML_OP_SOFT_MAX:
|
|
|
|
| 6176 |
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
|
| 6177 |
}
|
| 6178 |
|
| 6179 |
+
static void ggml_cl_glu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 6180 |
+
GGML_ASSERT(src0);
|
| 6181 |
+
GGML_ASSERT(src0->extra);
|
| 6182 |
+
GGML_ASSERT(dst);
|
| 6183 |
+
GGML_ASSERT(dst->extra);
|
| 6184 |
+
|
| 6185 |
+
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
| 6186 |
+
|
| 6187 |
+
if (src1) {
|
| 6188 |
+
GGML_ASSERT(src1);
|
| 6189 |
+
GGML_ASSERT(src1->extra);
|
| 6190 |
+
GGML_ASSERT(ggml_are_same_shape(src0, src1));
|
| 6191 |
+
}
|
| 6192 |
+
|
| 6193 |
+
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
|
| 6194 |
+
|
| 6195 |
+
cl_kernel kernel;
|
| 6196 |
+
switch (ggml_get_glu_op(dst)) {
|
| 6197 |
+
case GGML_GLU_OP_GEGLU:
|
| 6198 |
+
if (dst->type == GGML_TYPE_F32) {
|
| 6199 |
+
kernel = backend_ctx->kernel_geglu;
|
| 6200 |
+
} else {
|
| 6201 |
+
kernel = backend_ctx->kernel_geglu_f16;
|
| 6202 |
+
}
|
| 6203 |
+
break;
|
| 6204 |
+
case GGML_GLU_OP_REGLU:
|
| 6205 |
+
if (dst->type == GGML_TYPE_F32) {
|
| 6206 |
+
kernel = backend_ctx->kernel_reglu;
|
| 6207 |
+
} else {
|
| 6208 |
+
kernel = backend_ctx->kernel_reglu_f16;
|
| 6209 |
+
}
|
| 6210 |
+
break;
|
| 6211 |
+
case GGML_GLU_OP_SWIGLU:
|
| 6212 |
+
if (dst->type == GGML_TYPE_F32) {
|
| 6213 |
+
kernel = backend_ctx->kernel_swiglu;
|
| 6214 |
+
} else {
|
| 6215 |
+
kernel = backend_ctx->kernel_swiglu_f16;
|
| 6216 |
+
}
|
| 6217 |
+
break;
|
| 6218 |
+
default:
|
| 6219 |
+
GGML_ABORT("Unsupported glu op");
|
| 6220 |
+
}
|
| 6221 |
+
|
| 6222 |
+
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
|
| 6223 |
+
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
|
| 6224 |
+
|
| 6225 |
+
ggml_tensor_extra_cl * extra1 = src1 ? (ggml_tensor_extra_cl *)src1->extra : nullptr;
|
| 6226 |
+
|
| 6227 |
+
cl_ulong offset0 = extra0->offset + src0->view_offs;
|
| 6228 |
+
cl_ulong offsetd = extrad->offset + dst->view_offs;
|
| 6229 |
+
|
| 6230 |
+
cl_ulong offset1 = extra1 ? extra1->offset + src1->view_offs : offset0;
|
| 6231 |
+
|
| 6232 |
+
const int ne0 = dst->ne[0];
|
| 6233 |
+
|
| 6234 |
+
const cl_ulong nb01 = src0->nb[1];
|
| 6235 |
+
const cl_ulong nb11 = src1 ? src1->nb[1] : nb01;
|
| 6236 |
+
|
| 6237 |
+
const cl_ulong nb1 = dst->nb[1];
|
| 6238 |
+
|
| 6239 |
+
const int swp = ((const int32_t *) dst->op_params)[1];
|
| 6240 |
+
const int ne00_off = src1 ? 0 : (swp ? ne0 : 0);
|
| 6241 |
+
const int ne10_off = src1 ? 0 : (swp ? 0 : ne0);
|
| 6242 |
+
|
| 6243 |
+
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
|
| 6244 |
+
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
|
| 6245 |
+
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), src1 ? &extra1->data_device : &extra0->data_device));
|
| 6246 |
+
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
|
| 6247 |
+
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
|
| 6248 |
+
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
|
| 6249 |
+
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb01));
|
| 6250 |
+
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb11));
|
| 6251 |
+
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne0));
|
| 6252 |
+
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb1));
|
| 6253 |
+
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne00_off));
|
| 6254 |
+
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne10_off));
|
| 6255 |
+
|
| 6256 |
+
const size_t nrows = ggml_nrows(src0);
|
| 6257 |
+
size_t nth = 512;
|
| 6258 |
+
size_t global_work_size[] = {nrows*nth, 1, 1};
|
| 6259 |
+
size_t local_work_size[] = {nth, 1, 1};
|
| 6260 |
+
|
| 6261 |
+
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
|
| 6262 |
+
}
|
| 6263 |
+
|
| 6264 |
//------------------------------------------------------------------------------
|
| 6265 |
// Op offloading
|
| 6266 |
//------------------------------------------------------------------------------
|
|
|
|
| 6362 |
default:
|
| 6363 |
return false;
|
| 6364 |
} break;
|
| 6365 |
+
case GGML_OP_GLU:
|
| 6366 |
+
if (!any_on_device) {
|
| 6367 |
+
return false;
|
| 6368 |
+
}
|
| 6369 |
+
func = ggml_cl_glu;
|
| 6370 |
+
break;
|
| 6371 |
case GGML_OP_CLAMP:
|
| 6372 |
if (!any_on_device) {
|
| 6373 |
return false;
|
ggml/src/ggml-opencl/kernels/glu.cl
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
| 2 |
+
|
| 3 |
+
#define GELU_COEF_A 0.044715f
|
| 4 |
+
#define SQRT_2_OVER_PI 0.79788456080286535587989211986876f
|
| 5 |
+
|
| 6 |
+
//------------------------------------------------------------------------------
|
| 7 |
+
// geglu
|
| 8 |
+
//------------------------------------------------------------------------------
|
| 9 |
+
kernel void kernel_geglu(
|
| 10 |
+
global char * src0,
|
| 11 |
+
ulong offset0,
|
| 12 |
+
global char * src1,
|
| 13 |
+
ulong offset1,
|
| 14 |
+
global char * dst,
|
| 15 |
+
ulong offsetd,
|
| 16 |
+
ulong nb01,
|
| 17 |
+
ulong nb11,
|
| 18 |
+
int ne0,
|
| 19 |
+
ulong nb1,
|
| 20 |
+
int ne00_off,
|
| 21 |
+
int ne10_off
|
| 22 |
+
) {
|
| 23 |
+
src0 = (global char*)((global char*)src0 + offset0);
|
| 24 |
+
src1 = (global char*)((global char*)src1 + offset1);
|
| 25 |
+
dst = (global char*)((global char*)dst + offsetd);
|
| 26 |
+
|
| 27 |
+
global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
|
| 28 |
+
global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
|
| 29 |
+
global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1);
|
| 30 |
+
|
| 31 |
+
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
|
| 32 |
+
const float x0 = src0_row[i0];
|
| 33 |
+
const float x1 = src1_row[i0];
|
| 34 |
+
|
| 35 |
+
const float gelu = 0.5f*x0*(1.0f + tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0)));
|
| 36 |
+
|
| 37 |
+
dst_row[i0] = gelu*x1;
|
| 38 |
+
}
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
kernel void kernel_geglu_f16(
|
| 42 |
+
global char * src0,
|
| 43 |
+
ulong offset0,
|
| 44 |
+
global char * src1,
|
| 45 |
+
ulong offset1,
|
| 46 |
+
global char * dst,
|
| 47 |
+
ulong offsetd,
|
| 48 |
+
ulong nb01,
|
| 49 |
+
ulong nb11,
|
| 50 |
+
int ne0,
|
| 51 |
+
ulong nb1,
|
| 52 |
+
int ne00_off,
|
| 53 |
+
int ne10_off
|
| 54 |
+
) {
|
| 55 |
+
src0 = (global char*)((global char*)src0 + offset0);
|
| 56 |
+
src1 = (global char*)((global char*)src1 + offset1);
|
| 57 |
+
dst = (global char*)((global char*)dst + offsetd);
|
| 58 |
+
|
| 59 |
+
global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
|
| 60 |
+
global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
|
| 61 |
+
global half * dst_row = (global half *) ((global char *) dst + get_group_id(0)*nb1);
|
| 62 |
+
|
| 63 |
+
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
|
| 64 |
+
const half x0 = src0_row[i0];
|
| 65 |
+
const half x1 = src1_row[i0];
|
| 66 |
+
|
| 67 |
+
const half gelu = 0.5f*x0*(1.0f + tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0)));
|
| 68 |
+
|
| 69 |
+
dst_row[i0] = gelu*x1;
|
| 70 |
+
}
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
//------------------------------------------------------------------------------
|
| 74 |
+
// reglu
|
| 75 |
+
//------------------------------------------------------------------------------
|
| 76 |
+
kernel void kernel_reglu(
|
| 77 |
+
global char * src0,
|
| 78 |
+
ulong offset0,
|
| 79 |
+
global char * src1,
|
| 80 |
+
ulong offset1,
|
| 81 |
+
global char * dst,
|
| 82 |
+
ulong offsetd,
|
| 83 |
+
ulong nb01,
|
| 84 |
+
ulong nb11,
|
| 85 |
+
int ne0,
|
| 86 |
+
ulong nb1,
|
| 87 |
+
int ne00_off,
|
| 88 |
+
int ne10_off
|
| 89 |
+
) {
|
| 90 |
+
src0 = (global char*)((global char*)src0 + offset0);
|
| 91 |
+
src1 = (global char*)((global char*)src1 + offset1);
|
| 92 |
+
dst = (global char*)((global char*)dst + offsetd);
|
| 93 |
+
|
| 94 |
+
global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
|
| 95 |
+
global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
|
| 96 |
+
global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1);
|
| 97 |
+
|
| 98 |
+
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
|
| 99 |
+
const float x0 = src0_row[i0];
|
| 100 |
+
const float x1 = src1_row[i0];
|
| 101 |
+
|
| 102 |
+
dst_row[i0] = x0*x1*(x0 > 0.0f);
|
| 103 |
+
}
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
kernel void kernel_reglu_f16(
|
| 107 |
+
global char * src0,
|
| 108 |
+
ulong offset0,
|
| 109 |
+
global char * src1,
|
| 110 |
+
ulong offset1,
|
| 111 |
+
global char * dst,
|
| 112 |
+
ulong offsetd,
|
| 113 |
+
ulong nb01,
|
| 114 |
+
ulong nb11,
|
| 115 |
+
int ne0,
|
| 116 |
+
ulong nb1,
|
| 117 |
+
int ne00_off,
|
| 118 |
+
int ne10_off
|
| 119 |
+
) {
|
| 120 |
+
src0 = (global char*)((global char*)src0 + offset0);
|
| 121 |
+
src1 = (global char*)((global char*)src1 + offset1);
|
| 122 |
+
dst = (global char*)((global char*)dst + offsetd);
|
| 123 |
+
|
| 124 |
+
global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
|
| 125 |
+
global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
|
| 126 |
+
global half * dst_row = (global half *) ((global char *) dst + get_group_id(0)*nb1);
|
| 127 |
+
|
| 128 |
+
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
|
| 129 |
+
const half x0 = src0_row[i0];
|
| 130 |
+
const half x1 = src1_row[i0];
|
| 131 |
+
|
| 132 |
+
dst_row[i0] = x0*x1*(x0 > 0.0f);
|
| 133 |
+
}
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
//------------------------------------------------------------------------------
|
| 137 |
+
// swiglu
|
| 138 |
+
//------------------------------------------------------------------------------
|
| 139 |
+
kernel void kernel_swiglu(
|
| 140 |
+
global char * src0,
|
| 141 |
+
ulong offset0,
|
| 142 |
+
global char * src1,
|
| 143 |
+
ulong offset1,
|
| 144 |
+
global char * dst,
|
| 145 |
+
ulong offsetd,
|
| 146 |
+
ulong nb01,
|
| 147 |
+
ulong nb11,
|
| 148 |
+
int ne0,
|
| 149 |
+
ulong nb1,
|
| 150 |
+
int ne00_off,
|
| 151 |
+
int ne10_off
|
| 152 |
+
) {
|
| 153 |
+
src0 = (global char*)((global char*)src0 + offset0);
|
| 154 |
+
src1 = (global char*)((global char*)src1 + offset1);
|
| 155 |
+
dst = (global char*)((global char*)dst + offsetd);
|
| 156 |
+
|
| 157 |
+
global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
|
| 158 |
+
global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
|
| 159 |
+
global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1);
|
| 160 |
+
|
| 161 |
+
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
|
| 162 |
+
const float x0 = src0_row[i0];
|
| 163 |
+
const float x1 = src1_row[i0];
|
| 164 |
+
|
| 165 |
+
const float silu = x0 / (1.0f + exp(-x0));
|
| 166 |
+
|
| 167 |
+
dst_row[i0] = silu*x1;
|
| 168 |
+
}
|
| 169 |
+
}
|
| 170 |
+
|
| 171 |
+
kernel void kernel_swiglu_f16(
|
| 172 |
+
global char * src0,
|
| 173 |
+
ulong offset0,
|
| 174 |
+
global char * src1,
|
| 175 |
+
ulong offset1,
|
| 176 |
+
global char * dst,
|
| 177 |
+
ulong offsetd,
|
| 178 |
+
ulong nb01,
|
| 179 |
+
ulong nb11,
|
| 180 |
+
int ne0,
|
| 181 |
+
ulong nb1,
|
| 182 |
+
int ne00_off,
|
| 183 |
+
int ne10_off
|
| 184 |
+
) {
|
| 185 |
+
src0 = (global char*)((global char*)src0 + offset0);
|
| 186 |
+
src1 = (global char*)((global char*)src1 + offset1);
|
| 187 |
+
dst = (global char*)((global char*)dst + offsetd);
|
| 188 |
+
|
| 189 |
+
global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
|
| 190 |
+
global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
|
| 191 |
+
global half * dst_row = (global half *) ((global char *) dst + get_group_id(0)*nb1);
|
| 192 |
+
|
| 193 |
+
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
|
| 194 |
+
const half x0 = src0_row[i0];
|
| 195 |
+
const half x1 = src1_row[i0];
|
| 196 |
+
|
| 197 |
+
const half silu = x0 / (1.0f + exp(-x0));
|
| 198 |
+
|
| 199 |
+
dst_row[i0] = silu*x1;
|
| 200 |
+
}
|
| 201 |
+
}
|