ggml : implement REGLU/GEGLU/SWIGLU ops (llama/14158)
* implement unary REGLU/GEGLU/SWIGLU cpu ops
* relax constraints
* duplicate shape of source
* fix ggml_vec_geglu_f16
* special case gated ops
* implement unary REGLU/GEGLU/SWIGLU cuda ops
* tighten constraints again
* refactor into GGML_GLU_OP
* metal : add glu kernels
ggml-ci
* add CUDA_GLU_BLOCK_SIZE [no ci]
* more constraints and use 64bit ints
ggml-ci
* 64bit multiplication [no ci]
* implement swapped variants (cpu/cuda)
* update comment [no ci]
ggml-ci
* Vulkan: Add GLU ops and shaders
* SYCL: Implement fused kernel GEGLU, SWIGLU and REGLU for single up+gate
* ggml : implement GLU for split up/gate (llama/14181)
* implement GLU for split up/gate
* add tests for ggml_glu_split
* Vulkan: Implement glu_split logic and shader support
* add split to logging [no ci]
* SYCL: refactor element_wise ops and add split up and gate support to gated kernels
* SYCL: switch GEGLU to use tanh approximation
---------
Co-authored-by: 0cc4m <[email protected]>
Co-authored-by: Akarshan <[email protected]>
* GGML: increase OP count in assertion
* Refactor: Optimize SYCL element-wise operations with unary function inlining
This commit refactors the SYCL element-wise operations to improve performance by:
- Inlining unary operations (sgn, abs, elu, gelu, silu, etc.) to reduce kernel launch overhead.
- Introducing helper functions `op_xxx` for each unary operation to encapsulate the logic.
- Replacing direct kernel calls with calls to these inlined functions.
- Using `__dpct_inline__` to encourage compiler inlining.
- Minor code cleanup and consistency improvements.
The changes aim to reduce kernel launch overhead and improve the overall efficiency of element-wise operations on SYCL devices.
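As a rough sketch of the pattern this describes (the helper and kernel names below are illustrative stand-ins, not the actual SYCL source), each activation becomes a small inlined function and the kernel body reduces to a bounds check plus one call:

    // illustrative sketch only: op_silu / unary_silu_kernel are hypothetical names
    template <typename T>
    static __dpct_inline__ T op_silu(T x) {
        return x / (static_cast<T>(1.0f) + sycl::native::exp(-x));
    }

    template <typename T>
    static void unary_silu_kernel(const T * x, T * dst, const int k, const sycl::nd_item<3> & item_ct1) {
        const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
        if (i >= k) {
            return;
        }
        dst[i] = op_silu(x[i]); // one inlined call replaces the math duplicated across kernels
    }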
* vulkan: Increase workgroup size for GLU, for performance (llama/14345)
* vulkan: Increase workgroup size for GLU, for performance
* vulkan: change GLU shaders to do one element per invocation rather than one row per workgroup
* merge fix
* metal : add support for split and swap
ggml-ci
---------
Co-authored-by: Georgi Gerganov <[email protected]>
Co-authored-by: 0cc4m <[email protected]>
Co-authored-by: Akarshan <[email protected]>
Co-authored-by: Jeff Bolz <[email protected]>
- ggml/include/ggml.h +69 -0
- ggml/src/ggml-cpu/ggml-cpu.c +16 -0
- ggml/src/ggml-cpu/ops.cpp +457 -0
- ggml/src/ggml-cpu/ops.h +1 -0
- ggml/src/ggml-cpu/vec.cpp +24 -0
- ggml/src/ggml-cpu/vec.h +54 -0
- ggml/src/ggml-cuda/ggml-cuda.cu +25 -0
- ggml/src/ggml-cuda/unary.cu +89 -0
- ggml/src/ggml-cuda/unary.cuh +7 -0
- ggml/src/ggml-metal/ggml-metal-impl.h +11 -0
- ggml/src/ggml-metal/ggml-metal.m +71 -0
- ggml/src/ggml-metal/ggml-metal.metal +64 -0
- ggml/src/ggml-sycl/element_wise.cpp +635 -1026
- ggml/src/ggml-sycl/element_wise.hpp +16 -9
- ggml/src/ggml-sycl/ggml-sycl.cpp +25 -0
- ggml/src/ggml-vulkan/ggml-vulkan.cpp +114 -3
- ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
- ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +7 -0
- ggml/src/ggml.c +136 -2
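For orientation before the diff: each new op consumes either one fused tensor whose rows hold the up projection in the first half and the gate in the second half (reversed when swapped is set), or two separate tensors of the same shape. A minimal usage sketch based on the ggml.h declarations added below; ctx, up, gate, n_ff and n_tokens are assumed to exist and are placeholders:

    // fused layout: each row is [up_0 .. up_{n_ff-1} | gate_0 .. gate_{n_ff-1}]
    struct ggml_tensor * fused = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2*n_ff, n_tokens);
    struct ggml_tensor * act   = ggml_swiglu(ctx, fused);          // result is [n_ff, n_tokens]: silu(up) * gate

    // split layout: up and gate are separate [n_ff, n_tokens] tensors (llama/14181)
    struct ggml_tensor * act2  = ggml_swiglu_split(ctx, up, gate); // same result as the fused call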
ggml/include/ggml.h

@@ -520,6 +520,8 @@ extern "C" {
     GGML_OP_CROSS_ENTROPY_LOSS_BACK,
     GGML_OP_OPT_STEP_ADAMW,
 
+    GGML_OP_GLU,
+
     GGML_OP_COUNT,
 };
 

@@ -543,6 +545,14 @@ extern "C" {
         GGML_UNARY_OP_COUNT,
     };
 
+    enum ggml_glu_op {
+        GGML_GLU_OP_REGLU,
+        GGML_GLU_OP_GEGLU,
+        GGML_GLU_OP_SWIGLU,
+
+        GGML_GLU_OP_COUNT,
+    };
+
    enum ggml_object_type {
        GGML_OBJECT_TYPE_TENSOR,
        GGML_OBJECT_TYPE_GRAPH,

@@ -658,6 +668,7 @@ extern "C" {
     GGML_API const char * ggml_op_symbol(enum ggml_op op);
 
     GGML_API const char * ggml_unary_op_name(enum ggml_unary_op op);
+    GGML_API const char * ggml_glu_op_name(enum ggml_glu_op op);
     GGML_API const char * ggml_op_desc(const struct ggml_tensor * t); // unary or op name
 
     GGML_API size_t ggml_element_size(const struct ggml_tensor * tensor);

@@ -762,6 +773,7 @@ extern "C" {
     GGML_API void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3);
 
     GGML_API enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor);
+    GGML_API enum ggml_glu_op   ggml_get_glu_op(const struct ggml_tensor * tensor);
 
     GGML_API void *  ggml_get_data    (const struct ggml_tensor * tensor);
     GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);

@@ -1090,6 +1102,63 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    // gated linear unit ops
+    // A: n columns, r rows,
+    // result is n / 2 columns, r rows,
+    // expects gate in second half of row, unless swapped is true
+    GGML_API struct ggml_tensor * ggml_glu(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a,
+            enum ggml_glu_op op,
+            bool swapped);
+
+    GGML_API struct ggml_tensor * ggml_reglu(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a);
+
+    GGML_API struct ggml_tensor * ggml_reglu_swapped(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a);
+
+    GGML_API struct ggml_tensor * ggml_geglu(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a);
+
+    GGML_API struct ggml_tensor * ggml_geglu_swapped(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a);
+
+    GGML_API struct ggml_tensor * ggml_swiglu(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a);
+
+    GGML_API struct ggml_tensor * ggml_swiglu_swapped(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a);
+
+    // A: n columns, r rows,
+    // B: n columns, r rows,
+    GGML_API struct ggml_tensor * ggml_glu_split(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a,
+            struct ggml_tensor * b,
+            enum ggml_glu_op op);
+
+    GGML_API struct ggml_tensor * ggml_reglu_split(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a,
+            struct ggml_tensor * b);
+
+    GGML_API struct ggml_tensor * ggml_geglu_split(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a,
+            struct ggml_tensor * b);
+
+    GGML_API struct ggml_tensor * ggml_swiglu_split(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a,
+            struct ggml_tensor * b);
+
     // normalize along rows
     GGML_API struct ggml_tensor * ggml_norm(
             struct ggml_context * ctx,
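In math terms, with x the up half, g the gate half and sigma the logistic sigmoid, the three variants computed by the kernels further down are:

$$\operatorname{reglu}(x,g) = \max(x,0)\,g, \qquad \operatorname{geglu}(x,g) = \operatorname{GELU}(x)\,g, \qquad \operatorname{swiglu}(x,g) = x\,\sigma(x)\,g = \frac{x}{1+e^{-x}}\,g.$$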
ggml/src/ggml-cpu/ggml-cpu.c

@@ -1949,6 +1949,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_unary(params, tensor);
             } break;
+        case GGML_OP_GLU:
+            {
+                ggml_compute_forward_glu(params, tensor);
+            } break;
         case GGML_OP_GET_REL_POS:
             {
                 ggml_compute_forward_get_rel_pos(params, tensor);

@@ -2159,6 +2163,18 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                     GGML_ABORT("fatal error");
             }
             break;
+        case GGML_OP_GLU:
+            switch (ggml_get_glu_op(node)) {
+                case GGML_GLU_OP_REGLU:
+                case GGML_GLU_OP_GEGLU:
+                case GGML_GLU_OP_SWIGLU:
+                    {
+                        n_tasks = n_threads;
+                    } break;
+                default:
+                    GGML_ABORT("fatal error");
+            }
+            break;
        case GGML_OP_SILU_BACK:
        case GGML_OP_MUL:
        case GGML_OP_DIV:
ggml/src/ggml-cpu/ops.cpp

@@ -3184,6 +3184,435 @@ void ggml_compute_forward_silu_back(
     }
 }
 
+// ggml_compute_forward_reglu
+
+static void ggml_compute_forward_reglu_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * src0_p = (float *) (src0_d + i1*src0_o);
+        float * src1_p = (float *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_reglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            GGML_UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_reglu_f16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
+        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_reglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            const float v = GGML_FP16_TO_FP32(x);
+            GGML_UNUSED(v);
+            assert(!isnan(v));
+            assert(!isinf(v));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_reglu(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_reglu_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_reglu_f16(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_geglu
+
+static void ggml_compute_forward_geglu_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * src0_p = (float *) (src0_d + i1*src0_o);
+        float * src1_p = (float *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_geglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            GGML_UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_geglu_f16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
+        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_geglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            const float v = GGML_FP16_TO_FP32(x);
+            GGML_UNUSED(v);
+            assert(!isnan(v));
+            assert(!isinf(v));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_geglu(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_geglu_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_geglu_f16(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_swiglu
+
+static void ggml_compute_forward_swiglu_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * src0_p = (float *) (src0_d + i1*src0_o);
+        float * src1_p = (float *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_swiglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            GGML_UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_swiglu_f16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
+        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_swiglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            const float v = GGML_FP16_TO_FP32(x);
+            GGML_UNUSED(v);
+            assert(!isnan(v));
+            assert(!isinf(v));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_swiglu(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_swiglu_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_swiglu_f16(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_norm
 
 static void ggml_compute_forward_norm_f32(

@@ -8057,6 +8486,34 @@ void ggml_compute_forward_unary(
     }
 }
 
+//ggml_compute_forward_glu
+
+void ggml_compute_forward_glu(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_glu_op op = ggml_get_glu_op(dst);
+
+    switch (op) {
+        case GGML_GLU_OP_REGLU:
+            {
+                ggml_compute_forward_reglu(params, dst);
+            } break;
+        case GGML_GLU_OP_GEGLU:
+            {
+                ggml_compute_forward_geglu(params, dst);
+            } break;
+        case GGML_GLU_OP_SWIGLU:
+            {
+                ggml_compute_forward_swiglu(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_get_rel_pos
 
 static void ggml_compute_forward_get_rel_pos_f16(
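The row split used by all the CPU paths above is the usual ceil-divide partition; as a concrete check with made-up numbers: nr = 10 rows on nth = 4 threads gives dr = (10 + 4 - 1)/4 = 3, so threads 0..3 process row ranges [0,3), [3,6), [6,9) and [9,10).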
ggml/src/ggml-cpu/ops.h

@@ -94,6 +94,7 @@ void ggml_compute_forward_ssm_scan(const struct ggml_compute_params * params, st
 void ggml_compute_forward_win_part(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_win_unpart(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_unary(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_glu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_get_rel_pos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_add_rel_pos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, struct ggml_tensor * dst);
ggml/src/ggml-cpu/vec.cpp

@@ -254,6 +254,30 @@ void ggml_vec_silu_f32(const int n, float * y, const float * x) {
     }
 }
 
+void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float * g) {
+    int i = 0;
+#if defined(__AVX512F__) && defined(__AVX512DQ__)
+    for (; i + 15 < n; i += 16) {
+        _mm512_storeu_ps(y + i, _mm512_mul_ps(ggml_v_silu(_mm512_loadu_ps(x + i)), _mm512_loadu_ps(g + i)));
+    }
+#elif defined(__AVX2__) && defined(__FMA__)
+    for (; i + 7 < n; i += 8) {
+        _mm256_storeu_ps(y + i, _mm256_mul_ps(ggml_v_silu(_mm256_loadu_ps(x + i)), _mm256_loadu_ps(g + i)));
+    }
+#elif defined(__SSE2__)
+    for (; i + 3 < n; i += 4) {
+        _mm_storeu_ps(y + i, _mm_mul_ps(ggml_v_silu(_mm_loadu_ps(x + i)), _mm_loadu_ps(g + i)));
+    }
+#elif defined(__ARM_NEON) && defined(__aarch64__)
+    for (; i + 3 < n; i += 4) {
+        vst1q_f32(y + i, vmulq_f32(ggml_v_silu(vld1q_f32(x + i)), vld1q_f32(g + i)));
+    }
+#endif
+    for (; i < n; ++i) {
+        y[i] = ggml_silu_f32(x[i]) * g[i];
+    }
+}
+
 ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
     int i = 0;
     ggml_float sum = 0;
ggml/src/ggml-cpu/vec.h

@@ -905,6 +905,60 @@ inline static void ggml_vec_silu_backward_f16(const int n, ggml_fp16_t * dx, con
     }
 }
 
+inline static void ggml_vec_reglu_f32 (const int n, float * y, const float * x, const float * g) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = (x[i] > 0.f) ? x[i] * g[i] : 0.f;
+    }
+}
+
+inline static void ggml_vec_reglu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
+    for (int i = 0; i < n; ++i) {
+        float v = GGML_FP16_TO_FP32(x[i]);
+        y[i] = GGML_FP32_TO_FP16((v > 0.f) ? v * GGML_FP16_TO_FP32(g[i]) : 0.f);
+    }
+}
+
+#ifdef GGML_GELU_FP16
+inline static void ggml_vec_geglu_f32(const int n, float * y, const float * x, const float * g) {
+    uint16_t t;
+    for (int i = 0; i < n; ++i) {
+        if (x[i] <= -10.0f) {
+            y[i] = 0.0f;
+        } else if (x[i] >= 10.0f) {
+            y[i] = x[i] * g[i];
+        } else {
+            ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
+            memcpy(&t, &fp16, sizeof(uint16_t));
+            y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t]) * g[i];
+        }
+    }
+}
+#else
+inline static void ggml_vec_geglu_f32(const int n, float * y, const float * x, const float * g) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = ggml_gelu_f32(x[i]) * g[i];
+    }
+}
+#endif
+
+inline static void ggml_vec_geglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
+    const uint16_t * i16 = (const uint16_t *) x;
+    for (int i = 0; i < n; ++i) {
+        float v = GGML_FP16_TO_FP32(g[i]);
+        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(ggml_table_gelu_f16[i16[i]]) * v);
+    }
+}
+
+void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float * g);
+
+inline static void ggml_vec_swiglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
+    for (int i = 0; i < n; ++i) {
+        float v = GGML_FP16_TO_FP32(x[i]);
+        float w = GGML_FP16_TO_FP32(g[i]);
+        y[i] = GGML_FP32_TO_FP16((v/(1.0f + expf(-v))) * w);
+    }
+}
+
 inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
 #ifndef GGML_USE_ACCELERATE
     ggml_float sum = 0.0;
ggml/src/ggml-cuda/ggml-cuda.cu

@@ -2303,6 +2303,21 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
                     return false;
             }
             break;
+        case GGML_OP_GLU:
+            switch (ggml_get_glu_op(dst)) {
+                case GGML_GLU_OP_REGLU:
+                    ggml_cuda_op_reglu(ctx, dst);
+                    break;
+                case GGML_GLU_OP_GEGLU:
+                    ggml_cuda_op_geglu(ctx, dst);
+                    break;
+                case GGML_GLU_OP_SWIGLU:
+                    ggml_cuda_op_swiglu(ctx, dst);
+                    break;
+                default:
+                    return false;
+            }
+            break;
         case GGML_OP_NORM:
             ggml_cuda_op_norm(ctx, dst);
             break;

@@ -3096,6 +3111,16 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                     return false;
             }
             break;
+        case GGML_OP_GLU:
+            switch (ggml_get_glu_op(op)) {
+                case GGML_GLU_OP_REGLU:
+                case GGML_GLU_OP_GEGLU:
+                case GGML_GLU_OP_SWIGLU:
+                    return ggml_is_contiguous_1(op->src[0]);
+                default:
+                    return false;
+            }
+            break;
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
             {
ggml/src/ggml-cuda/unary.cu

@@ -196,6 +196,95 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_unary<op_log>(ctx, dst);
 }
 
+/* gated ops */
+
+template <float (*op)(float), typename T>
+static __global__ void unary_gated_op_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1) {
+    const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    // perform base op and multiply with gate (either offset in same tensor or a separate one)
+    const int64_t j0 = (i / n) * o0 + (i % n);
+    const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
+
+    dst[i] = (T)(op((float)x[j0]) * (float)g[j1]);
+}
+
+template <float (*op)(float), typename T>
+static void unary_gated_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, cudaStream_t stream) {
+    const int64_t num_blocks = (k + CUDA_GLU_BLOCK_SIZE - 1) / CUDA_GLU_BLOCK_SIZE;
+    unary_gated_op_kernel<op><<<num_blocks, CUDA_GLU_BLOCK_SIZE, 0, stream>>>(x, g, dst, k, n, o0, o1);
+}
+
+template <float (*op)(float)>
+void ggml_cuda_op_unary_gated(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    void * src0_d = src0->data;
+    void * src1_d = src1 ? src1->data : src0->data;
+    const int64_t src0_o = src0->nb[1];
+    const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+    void * dst_d = dst->data;
+    const int64_t nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(src0->nb[0] == ggml_element_size(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == dst->type);
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src1->nb[0] == ggml_element_size(src1));
+        GGML_ASSERT(src1->ne[0] == nc);
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int32_t swapped = ((const int32_t *) dst->op_params)[1];
+
+    if (src0->type == GGML_TYPE_F16) {
+        half * src0_p = (half *) src0_d;
+        half * src1_p = (half *) src1_d;
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        unary_gated_cuda<op>(src0_p, src1_p, (half *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(half), src1_o / sizeof(half), stream);
+    } else {
+        float * src0_p = (float *) src0_d;
+        float * src1_p = (float *) src1_d;
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        unary_gated_cuda<op>(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), stream);
+    }
+}
+
+void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary_gated<op_relu>(ctx, dst);
+}
+
+void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary_gated<op_gelu>(ctx, dst);
+}
+
+void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary_gated<op_silu>(ctx, dst);
+}
+
 /* silu_back */
 
 static __device__ __forceinline__ float op_silu_back(float grad, float x) {
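In unary_gated_op_kernel above, i / n selects the destination row and i % n the column, while o0 and o1 are the source row strides in elements, so j0 = (i / n) * o0 + (i % n) addresses the matching element of a source row that may be wider than the output (the fused case). A standalone C++ check of that mapping, with made-up sizes for illustration:

    #include <cassert>
    #include <cstdint>

    int main() {
        const int64_t n = 4, o0 = 8;        // 4 output columns; fused source rows are 8 elements wide
        for (int64_t i = 0; i < 12; ++i) {  // 3 rows x 4 columns of output elements
            const int64_t j0 = (i / n) * o0 + (i % n);
            assert(j0 / o0 == i / n);       // same row index in the source
            assert(j0 % o0 == i % n);       // same column within the row
        }
        return 0;
    }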
ggml/src/ggml-cuda/unary.cuh

@@ -15,6 +15,7 @@
 #define CUDA_SQRT_BLOCK_SIZE 256
 #define CUDA_SIN_BLOCK_SIZE 256
 #define CUDA_COS_BLOCK_SIZE 256
+#define CUDA_GLU_BLOCK_SIZE 256
 
 void ggml_cuda_op_abs(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 

@@ -57,3 +58,9 @@ void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
ggml/src/ggml-metal/ggml-metal-impl.h

@@ -422,6 +422,17 @@ typedef struct {
     int32_t KHW; // KH * KW, pre-computed on CPU to save GPU resources
 } ggml_metal_kargs_im2col;
 
+typedef struct{
+    int32_t  ne00;
+    uint64_t nb01;
+    int32_t  ne10;
+    uint64_t nb11;
+    int32_t  ne0;
+    uint64_t nb1;
+    int32_t  i00;
+    int32_t  i10;
+} ggml_metal_kargs_glu;
+
 typedef struct {
     int64_t ne00;
     int64_t ne01;
ggml/src/ggml-metal/ggml-metal.m

@@ -526,6 +526,9 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_SIN,
     GGML_METAL_KERNEL_TYPE_COS,
     GGML_METAL_KERNEL_TYPE_NEG,
+    GGML_METAL_KERNEL_TYPE_REGLU,
+    GGML_METAL_KERNEL_TYPE_GEGLU,
+    GGML_METAL_KERNEL_TYPE_SWIGLU,
     GGML_METAL_KERNEL_TYPE_SUM_ROWS,
     GGML_METAL_KERNEL_TYPE_MEAN,
     GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,

@@ -1502,6 +1505,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REGLU, reglu, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU, geglu, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);

@@ -1680,6 +1686,15 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
                 default:
                     return false;
             }
+        case GGML_OP_GLU:
+            switch (ggml_get_glu_op(op)) {
+                case GGML_GLU_OP_REGLU:
+                case GGML_GLU_OP_GEGLU:
+                case GGML_GLU_OP_SWIGLU:
+                    return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
+                default:
+                    return false;
+            }
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
         case GGML_OP_VIEW:

@@ -2419,6 +2434,62 @@ static bool ggml_metal_encode_node(
                     GGML_ABORT("fatal error");
             }
         } break;
+        case GGML_OP_GLU:
+        {
+            GGML_ASSERT(ggml_is_contiguous_1(src0));
+
+            if (src1) {
+                GGML_ASSERT(ggml_are_same_shape(src0, src1));
+            }
+
+            id<MTLComputePipelineState> pipeline = nil;
+
+            switch (ggml_get_glu_op(node)) {
+                case GGML_GLU_OP_REGLU:
+                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REGLU].pipeline;
+                    break;
+                case GGML_GLU_OP_GEGLU:
+                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU].pipeline;
+                    break;
+                case GGML_GLU_OP_SWIGLU:
+                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline;
+                    break;
+                default:
+                    GGML_ABORT("fatal error");
+            }
+
+            const int32_t swp = ((const int32_t *) dst->op_params)[1];
+
+            const int32_t i00 = swp ? ne0 : 0;
+            const int32_t i10 = swp ? 0 : ne0;
+
+            ggml_metal_kargs_glu args = {
+                /*.ne00 =*/ ne00,
+                /*.nb01 =*/ nb01,
+                /*.ne10 =*/ src1 ? ne10 : ne00,
+                /*.nb11 =*/ src1 ? nb11 : nb01,
+                /*.ne0  =*/ ne0,
+                /*.nb1  =*/ nb1,
+                /*.i00  =*/ src1 ? 0 : i00,
+                /*.i10  =*/ src1 ? 0 : i10,
+            };
+
+            [encoder setComputePipelineState:pipeline];
+            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+            if (src1) {
+                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+            } else {
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+            }
+            [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+            [encoder setBytes:&args length:sizeof(args) atIndex:3];
+
+            const int64_t nrows = ggml_nrows(src0);
+
+            const int32_t nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00/2);
+
+            [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+        } break;
         case GGML_OP_SQR:
         {
             GGML_ASSERT(ggml_is_contiguous(src0));
ggml/src/ggml-metal/ggml-metal.metal

@@ -1191,6 +1191,70 @@ kernel void kernel_neg(
     dst[tpig] = -src0[tpig];
 }
 
+kernel void kernel_reglu(
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        constant ggml_metal_kargs_glu & args,
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint   ntg[[threads_per_threadgroup]]) {
+    device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
+    device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
+    device       float * dst_row  = (device       float *) ((device       char *) dst  + tgpig*args.nb1);
+
+    for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
+        const float x0 = src0_row[i0];
+        const float x1 = src1_row[i0];
+
+        dst_row[i0] = x0*x1*(x0 > 0.0f);
+    }
+}
+
+kernel void kernel_geglu(
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        constant ggml_metal_kargs_glu & args,
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint   ntg[[threads_per_threadgroup]]) {
+    device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
+    device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
+    device       float * dst_row  = (device       float *) ((device       char *) dst  + tgpig*args.nb1);
+
+    for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
+        const float x0 = src0_row[i0];
+        const float x1 = src1_row[i0];
+
+        const float gelu = 0.5f*x0*(1.0f + precise::tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0)));
+
+        dst_row[i0] = gelu*x1;
+    }
+}
+
+kernel void kernel_swiglu(
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        constant ggml_metal_kargs_glu & args,
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint   ntg[[threads_per_threadgroup]]) {
+    device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
+    device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
+    device       float * dst_row  = (device       float *) ((device       char *) dst  + tgpig*args.nb1);
+
+    for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
+        const float x0 = src0_row[i0];
+        const float x1 = src1_row[i0];
+
+        const float silu = x0 / (1.0f + exp(-x0));
+
+        dst_row[i0] = silu*x1;
+    }
+}
+
 template <bool norm>
 kernel void kernel_sum_rows(
         constant ggml_metal_kargs_sum_rows & args,
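The kernel_geglu body above, like the SYCL GEGLU after the "switch GEGLU to use tanh approximation" change, evaluates the standard tanh approximation of GELU:

$$\operatorname{GELU}(x) \approx 0.5\,x\left(1 + \tanh\!\left(\sqrt{2/\pi}\,\bigl(x + 0.044715\,x^{3}\bigr)\right)\right),$$

which matches the shader expression with SQRT_2_OVER_PI = sqrt(2/pi) ~ 0.7978845608 and GELU_COEF_A = 0.044715.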
ggml/src/ggml-sycl/element_wise.cpp

@@ -1,12 +1,19 @@
 #include "common.hpp"
+#include "ggml.h"
 #include "element_wise.hpp"
 
 static void acc_f32(const float * x, const float * y, float * dst, const int ne,
                     const int ne10, const int ne11, const int ne12,
-                    const int nb1, const int nb2, int offset, const sycl::nd_item<
-    const int i = item_ct1
-                  item_ct1.get_local_id(2);
     if (i >= ne) {
         return;
     }

@@ -21,535 +28,375 @@ static void acc_f32(const float * x, const float * y, float * dst, const int ne,
     }
 }
 
 template<typename T>
-static
-
-    dst[i] = x[i] > static_cast<T>(0.f) ? static_cast<T>(1.f) : ((x[i] < static_cast<T>(0.f) ? static_cast<T>(-1.f) : static_cast<T>(0.f)));
-}
 }
 
 template<typename T>
-static
-
-    dst[i] = sycl::fabs(x[i]);
-}
 }
 
 template<typename T>
-static
-
-    dst[i] = (x[i] > static_cast<T>(0.f)) ? x[i] : sycl::expm1(x[i]);
-}
 }
 
 template<typename T>
-static
-    const sycl::nd_item<3> &item_ct1) {
     const T GELU_COEF_A = static_cast<T>(0.044715f);
     const T SQRT_2_OVER_PI = static_cast<T>(0.79788456080286535587989211986876f);
-
-
-
-    if (i >= k) {
-        return;
-    }
-
-    float xi = x[i];
-    dst[i] = static_cast<T>(0.5f) * xi *
-             (static_cast<T>(1.0f) +
-              sycl::tanh(SQRT_2_OVER_PI * xi * (static_cast<T>(1.0f) + GELU_COEF_A * xi * xi)));
 }
 
 template<typename T>
-static
-
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-
-    if (i >= k) {
-        return;
-    }
-    dst[i] = x[i] / (static_cast<T>(1.0f) + sycl::native::exp(-x[i]));
 }
 
 template<typename T>
-static
-
-
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-    if (i >= k) {
-        return;
-    }
-    dst[i] = x[i] * (static_cast<T>(1.0f) / (static_cast<T>(1.0f) + sycl::native::exp(GELU_QUICK_COEF * x[i])));
 }
 
 template<typename T>
-static
     const T SQRT_2_INV = static_cast<T>(0.70710678118654752440084436210484f);
-
-    auto x_i = x[i];
-    dst[i] = static_cast<T>(0.5f) * x_i * (static_cast<T>(1.0f) + sycl::erf(x_i * SQRT_2_INV));
-}
 }
 
 template<typename T>
-static
-
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-    if (i >= k) {
-        return;
-    }
-    dst[i] = sycl::tanh((x[i]));
 }
 
 template<typename T>
-static
-
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-
-    if (i >= k) {
-        return;
-    }
-    dst[i] = sycl::fmax((x[i]), static_cast<T>(0));
 }
 
 template<typename T>
-static
-
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-
-    if (i >= k) {
-        return;
-    }
-    dst[i] = 1.0f / (static_cast<T>(1.0f) + sycl::native::exp(-x[i]));
 }
 
 template<typename T>
-static
-
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-
-    if (i >= k) {
-        return;
-    }
-    dst[i] = sycl::sqrt(x[i]);
 }
 
 template<typename T>
-static
-
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-
-    if (i >= k) {
-        return;
-    }
-    dst[i] = sycl::sin(x[i]);
 }
 
 template<typename T>
-static
-
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-
-    if (i >= k) {
-        return;
-    }
-    dst[i] = sycl::cos(x[i]);
 }
 
 template<typename T>
-static
-
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-
-    if (i >= k) {
-        return;
-    }
-    dst[i] = sycl::fmin(static_cast<T>(1.0f), sycl::fmax(static_cast<T>(0.0f), (x[i] + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
dst[i] = sycl::fmin(static_cast<T>(1.0f), sycl::fmax(static_cast<T>(0.0f), (x[i] + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
|
| 177 |
}
|
| 178 |
|
| 179 |
template<typename T>
|
| 180 |
-
static
|
| 181 |
-
|
| 182 |
-
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
| 183 |
-
item_ct1.get_local_id(2);
|
| 184 |
-
|
| 185 |
-
if (i >= k) {
|
| 186 |
-
return;
|
| 187 |
-
}
|
| 188 |
-
dst[i] = x[i] * sycl::fmin(static_cast<T>(1.0f), sycl::fmax(static_cast<T>(0.0f), (x[i] + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
|
| 189 |
}
|
| 190 |
|
| 191 |
template<typename T>
|
| 192 |
-
static
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
item_ct1.get_local_id(2);
|
| 196 |
|
| 197 |
-
|
| 198 |
-
|
|
|
|
|
|
|
| 199 |
}
|
| 200 |
-
|
| 201 |
}
|
| 202 |
|
| 203 |
template<typename T>
|
| 204 |
-
static
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
item_ct1.get_local_id(2);
|
| 208 |
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
T xi = x[i];
|
| 213 |
-
if (xi <= 0) {
|
| 214 |
-
dst[i] = neg_infinity<T>();
|
| 215 |
-
} else {
|
| 216 |
-
dst[i] = sycl::log(xi);
|
| 217 |
-
}
|
| 218 |
}
|
| 219 |
|
| 220 |
template<typename T>
|
| 221 |
-
static
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
|
|
|
| 225 |
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
dst[i] = -x[i];
|
| 230 |
}
|
| 231 |
|
| 232 |
template<typename T>
|
| 233 |
-
static
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
item_ct1.get_local_id(2);
|
| 237 |
|
| 238 |
-
|
| 239 |
-
|
|
|
|
|
|
|
| 240 |
}
|
| 241 |
-
dst[i] = x[i] > static_cast<T>(0.0f);
|
| 242 |
}
|
| 243 |
|
| 244 |
template<typename T>
|
| 245 |
-
static void
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
item_ct1.get_local_id(2);
|
| 249 |
-
if (i >= k) {
|
| 250 |
-
return;
|
| 251 |
}
|
| 252 |
-
dst[i] = sycl::fmax((x[i]), static_cast<T>(0)) +
|
| 253 |
-
sycl::fmin((x[i]), static_cast<T>(0.0f)) * negative_slope;
|
| 254 |
}
|
| 255 |
|
| 256 |
template<typename T>
|
| 257 |
-
static void
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
item_ct1.get_local_id(2);
|
| 261 |
-
|
| 262 |
-
if (i >= k) {
|
| 263 |
-
return;
|
| 264 |
}
|
| 265 |
-
dst[i] = x[i] * x[i];
|
| 266 |
}
|
| 267 |
|
| 268 |
-
template<typename
|
| 269 |
-
static void
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
const float sf2, const float sf3, const sycl::nd_item<1> &item_ct1) {
|
| 273 |
-
int index = item_ct1.get_local_id(0) +
|
| 274 |
-
item_ct1.get_group(0) * item_ct1.get_local_range(0);
|
| 275 |
-
if (index >= ne10 * ne11 * ne12 * ne13) {
|
| 276 |
-
return;
|
| 277 |
}
|
| 278 |
-
// operation
|
| 279 |
-
int i10 = index % ne10;
|
| 280 |
-
int i11 = (index / ne10) % ne11;
|
| 281 |
-
int i12 = (index / (ne10 * ne11)) % ne12;
|
| 282 |
-
int i13 = (index / (ne10 * ne11 * ne12)) % ne13;
|
| 283 |
-
|
| 284 |
-
int i00 = i10 / sf0;
|
| 285 |
-
int i01 = i11 / sf1;
|
| 286 |
-
int i02 = i12 / sf2;
|
| 287 |
-
int i03 = i13 / sf3;
|
| 288 |
-
|
| 289 |
-
dst[index] = *(const T *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
|
| 290 |
}
|
| 291 |
|
| 292 |
-
template
|
| 293 |
-
static void
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
item_ct1.get_group(2) * item_ct1.get_local_range(2);
|
| 297 |
-
if (nidx >= ne0) {
|
| 298 |
-
return;
|
| 299 |
}
|
|
|
|
| 300 |
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
int offset_src = nidx + item_ct1.get_group(1) * ne00 +
|
| 306 |
-
item_ct1.get_group(0) * ne00 * ne01;
|
| 307 |
-
dst[offset_dst] = x[offset_src];
|
| 308 |
-
} else {
|
| 309 |
-
dst[offset_dst] = static_cast<T>(0.0f);
|
| 310 |
}
|
| 311 |
}
|
| 312 |
|
| 313 |
-
|
| 314 |
template<typename T>
|
| 315 |
-
static void
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
item_ct1.get_local_id(2);
|
| 319 |
-
|
| 320 |
-
if (i >= k) {
|
| 321 |
-
return;
|
| 322 |
}
|
| 323 |
-
|
| 324 |
-
dst[i] = x[i] < static_cast<T>(min) ? static_cast<T>(min) : (x[i] > static_cast<T>(max) ? static_cast<T>(max) : x[i]);
|
| 325 |
}
|
| 326 |
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
sycl_parallel_for(stream,
|
| 333 |
-
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE),
|
| 334 |
-
sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)),
|
| 335 |
-
[=](sycl::nd_item<3> item_ct1) {
|
| 336 |
-
acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset, item_ct1);
|
| 337 |
-
});
|
| 338 |
}
|
| 339 |
|
| 340 |
template<typename T>
|
| 341 |
-
static void
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
|
| 346 |
-
sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
|
| 347 |
-
[=](sycl::nd_item<3> item_ct1) { gelu(x, dst, k, item_ct1); });
|
| 348 |
}
|
| 349 |
|
| 350 |
template<typename T>
|
| 351 |
-
static void
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE),
|
| 356 |
-
sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE)),
|
| 357 |
-
[=](sycl::nd_item<3> item_ct1) { silu(x, dst, k, item_ct1); });
|
| 358 |
}
|
| 359 |
|
| 360 |
template<typename T>
|
| 361 |
-
static void
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
stream, sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range(1, 1, 256)), sycl::range(1, 1, 256)),
|
| 366 |
-
[=](sycl::nd_item<3> item_ct1) { sgn(x, dst, k, item_ct1); });
|
| 367 |
}
|
| 368 |
|
| 369 |
template<typename T>
|
| 370 |
-
static void
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
stream,
|
| 375 |
-
sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)),
|
| 376 |
-
[=](sycl::nd_item<3> item_ct1) { abs_op(x, dst, k, item_ct1); });
|
| 377 |
}
|
| 378 |
|
| 379 |
-
|
| 380 |
template<typename T>
|
| 381 |
-
static void
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
stream,
|
| 386 |
-
sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)),
|
| 387 |
-
[=](sycl::nd_item<3> item_ct1) { elu_op(x, dst, k, item_ct1); });
|
| 388 |
}
|
| 389 |
|
| 390 |
template<typename T>
|
| 391 |
-
static void
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
|
| 396 |
-
sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
|
| 397 |
-
[=](sycl::nd_item<3> item_ct1) { gelu_quick(x, dst, k, item_ct1); });
|
| 398 |
}
|
| 399 |
|
| 400 |
-
|
| 401 |
template<typename T>
|
| 402 |
-
static void
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
|
| 407 |
-
sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
|
| 408 |
-
[=](sycl::nd_item<3> item_ct1) { gelu_erf(x, dst, k, item_ct1); });
|
| 409 |
}
|
| 410 |
|
| 411 |
template<typename T>
|
| 412 |
-
static void
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE),
|
| 417 |
-
sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE)),
|
| 418 |
-
[=](sycl::nd_item<3> item_ct1) { tanh(x, dst, k, item_ct1); });
|
| 419 |
}
|
| 420 |
|
| 421 |
template<typename T>
|
| 422 |
-
static void
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE),
|
| 427 |
-
sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)),
|
| 428 |
-
[=](sycl::nd_item<3> item_ct1) { relu(x, dst, k, item_ct1); });
|
| 429 |
}
|
| 430 |
|
| 431 |
template<typename T>
|
| 432 |
-
static void
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
stream,
|
| 437 |
-
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE),
|
| 438 |
-
sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE)),
|
| 439 |
-
[=](sycl::nd_item<3> item_ct1) { hardsigmoid(x, dst, k, item_ct1); });
|
| 440 |
}
|
| 441 |
|
| 442 |
template<typename T>
|
| 443 |
-
static void
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
stream,
|
| 448 |
-
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE),
|
| 449 |
-
sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE)),
|
| 450 |
-
[=](sycl::nd_item<3> item_ct1) { hardswish(x, dst, k, item_ct1); });
|
| 451 |
}
|
| 452 |
|
| 453 |
template<typename T>
|
| 454 |
-
static void
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE),
|
| 459 |
-
sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)),
|
| 460 |
-
[=](sycl::nd_item<3> item_ct1) { exp(x, dst, k, item_ct1); });
|
| 461 |
}
|
| 462 |
|
| 463 |
template<typename T>
|
| 464 |
-
static void
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE),
|
| 469 |
-
sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)),
|
| 470 |
-
[=](sycl::nd_item<3> item_ct1) { log(x, dst, k, item_ct1); });
|
| 471 |
}
|
| 472 |
|
| 473 |
template<typename T>
|
| 474 |
-
static void
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE),
|
| 479 |
-
sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE)),
|
| 480 |
-
[=](sycl::nd_item<3> item_ct1) { neg(x, dst, k, item_ct1); });
|
| 481 |
}
|
| 482 |
|
| 483 |
-
template<typename
|
| 484 |
-
static void
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 491 |
}
|
| 492 |
|
| 493 |
-
template<typename T>
|
| 494 |
-
static void
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 502 |
}
|
| 503 |
|
| 504 |
template<typename T>
|
| 505 |
-
static void
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
|
| 509 |
-
|
| 510 |
-
sycl::range<3>(1, 1, SYCL_SQRT_BLOCK_SIZE)),
|
| 511 |
-
[=](sycl::nd_item<3> item_ct1) { sqrt(x, dst, k, item_ct1); });
|
| 512 |
}
|
| 513 |
|
| 514 |
template<typename T>
|
| 515 |
-
static void
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
|
| 519 |
-
|
| 520 |
-
|
| 521 |
-
[=](sycl::nd_item<3> item_ct1) { sin(x, dst, k, item_ct1); });
|
| 522 |
}
|
| 523 |
|
| 524 |
template<typename T>
|
| 525 |
-
static void
|
| 526 |
-
|
| 527 |
-
|
| 528 |
-
|
| 529 |
-
|
| 530 |
-
|
| 531 |
-
[=](sycl::nd_item<3> item_ct1) { cos(x, dst, k, item_ct1); });
|
| 532 |
}
|
| 533 |
|
| 534 |
template<typename T>
|
| 535 |
-
static void
|
| 536 |
-
|
| 537 |
-
|
| 538 |
-
|
| 539 |
-
|
| 540 |
-
|
| 541 |
-
sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)),
|
| 542 |
-
[=](sycl::nd_item<3> item_ct1) { leaky_relu(x, dst, k, negative_slope, item_ct1); });
|
| 543 |
}
|
| 544 |
|
| 545 |
-
|
| 546 |
-
static void
|
| 547 |
-
|
| 548 |
-
|
|
|
|
|
|
|
| 549 |
sycl_parallel_for(stream,
|
| 550 |
-
|
| 551 |
-
|
| 552 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 553 |
}
|
| 554 |
|
| 555 |
template<typename T>
|
|
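Two GELU variants coexist in the kernels deleted above (and in their replacements): the tanh approximation used by `gelu`, and the exact erf form used by `gelu_erf`. A small self-contained C++ check of how closely they agree, using `std::` math in place of `sycl::` (illustrative only, not part of the commit):

```cpp
#include <cmath>
#include <cstdio>

static float gelu_tanh_approx(float x) {
    const float GELU_COEF_A    = 0.044715f;
    const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
    return 0.5f*x*(1.0f + std::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
}

static float gelu_erf_exact(float x) {
    const float SQRT_2_INV = 0.70710678118654752440084436210484f;
    return 0.5f*x*(1.0f + std::erf(x*SQRT_2_INV));
}

int main() {
    // The two forms agree to roughly 1e-3 over typical activation ranges.
    for (float x = -4.0f; x <= 4.0f; x += 1.0f) {
        std::printf("x=%+.1f  tanh=%+.6f  erf=%+.6f\n",
                    x, gelu_tanh_approx(x), gelu_erf_exact(x));
    }
    return 0;
}
```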
@@ -558,7 +405,7 @@ static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01,
                          const int ne12, const int ne13, const float sf0, const float sf1,
                          const float sf2, const float sf3, queue_ptr stream) {
     int dst_size = ne10 * ne11 * ne12 * ne13;
-    int num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE;
     sycl::range<1> gridDim(num_blocks * SYCL_UPSCALE_BLOCK_SIZE);
     sycl_parallel_for<1>(
         stream, sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
@@ -570,7 +417,7 @@ template<typename T>
 static void pad_sycl(const T *x, T *dst, const int ne00,
                      const int ne01, const int ne02, const int ne0,
                      const int ne1, const int ne2, queue_ptr stream) {
-    int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE;
     sycl::range<3> gridDim(ne2, ne1, num_blocks);
     sycl_parallel_for(stream,
                       sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),
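Every launcher in this file sizes its grid with the same rounded-up division, `(k + BLOCK - 1) / BLOCK`, which the refactored version of the file centralizes as `ceil_div`. A tiny standalone C++ illustration of that idiom:

```cpp
#include <cassert>

// Rounded-up integer division: how many blocks of `block` cover `n` elements.
static int ceil_div(int n, int block) {
    return (n + block - 1) / block;
}

int main() {
    assert(ceil_div(1024, 256) == 4);  // exact fit
    assert(ceil_div(1025, 256) == 5);  // one extra block for the tail
    assert(ceil_div(1,    256) == 1);
    return 0;
}
```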
@@ -578,22 +425,11 @@ static void pad_sycl(const T *x, T *dst, const int ne00,
                       [=](sycl::nd_item<3> item_ct1) { pad(x, dst, ne0, ne00, ne01, ne02, item_ct1); });
 }
 
-template<typename T>
-static void clamp_sycl(const T *x, T *dst, const float min,
-                       const float max, const int k,
-                       queue_ptr stream) {
-    const int num_blocks = (k + SYCL_CLAMP_BLOCK_SIZE - 1) / SYCL_CLAMP_BLOCK_SIZE;
-    sycl_parallel_for(stream,
-                      sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE),
-                                        sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE)),
-                      [=](sycl::nd_item<3> item_ct1) { clamp(x, dst, min, max, k, item_ct1); });
-}
-
 inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 #if defined (GGML_SYCL_F16)
     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
     GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
-
 #else
     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
@@ -606,14 +442,14 @@ inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
         case GGML_TYPE_F16:
             {
                 auto data_pts = cast_data<sycl::half>(dst);
-                sgn_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
                 break;
             }
 #endif
         case GGML_TYPE_F32:
             {
                 auto data_pts = cast_data<float>(dst);
-                sgn_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
                 break;
             }
         default:
@@ -621,11 +457,11 @@ inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
     }
 }
 
-inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 #if defined (GGML_SYCL_F16)
     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
     GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
-
 #else
     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
@@ -633,52 +469,66 @@ inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
     GGML_ASSERT(dst->src[0]->type == dst->type);
     dpct::queue_ptr main_stream = ctx.stream();
     SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-    switch (dst->type) {
-#if defined (GGML_SYCL_F16)
-        case GGML_TYPE_F16:
-            {
-                auto data_pts = cast_data<sycl::half>(dst);
-                abs_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-#endif
-        case GGML_TYPE_F32:
-            {
-                auto data_pts = cast_data<float>(dst);
-                abs_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-        default:
-            GGML_ABORT("GGML tensor type not supported!\n");
     }
-}
-
-
-inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-#if defined (GGML_SYCL_F16)
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
-
-#else
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-#endif
-    GGML_ASSERT(dst->src[0]->type == dst->type);
-    dpct::queue_ptr main_stream = ctx.stream();
-    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
     switch (dst->type) {
 #if defined (GGML_SYCL_F16)
         case GGML_TYPE_F16:
             {
-                auto data_pts = cast_data<sycl::half>(dst);
-                elu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
                 break;
             }
 #endif
         case GGML_TYPE_F32:
             {
-                auto data_pts = cast_data<float>(dst);
-                elu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
                 break;
             }
         default:
@@ -686,7 +536,8 @@ inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
     }
 }
 
-inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 #if defined (GGML_SYCL_F16)
     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
     GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
@@ -695,52 +546,31 @@ inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
 #endif
     GGML_ASSERT(dst->src[0]->type == dst->type);
-    dpct::queue_ptr main_stream = ctx.stream();
-    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-    switch (dst->type) {
-#if defined (GGML_SYCL_F16)
-        case GGML_TYPE_F16:
-            {
-                auto data_pts = cast_data<sycl::half>(dst);
-                silu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-#endif
-        case GGML_TYPE_F32:
-            {
-                auto data_pts = cast_data<float>(dst);
-                silu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-        default:
-            GGML_ABORT("GGML tensor type not supported!\n");
-    }
-}
 
-inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-#if defined (GGML_SYCL_F16)
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
-#else
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-#endif
-    GGML_ASSERT(dst->src[0]->type == dst->type);
     dpct::queue_ptr main_stream = ctx.stream();
     SYCL_CHECK(ggml_sycl_set_device(ctx.device));
     switch (dst->type) {
 #if defined (GGML_SYCL_F16)
         case GGML_TYPE_F16:
             {
                 auto data_pts = cast_data<sycl::half>(dst);
-                gelu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
                 break;
             }
 #endif
         case GGML_TYPE_F32:
             {
                 auto data_pts = cast_data<float>(dst);
-                gelu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
                 break;
             }
         default:
@@ -748,7 +578,8 @@ inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst
     }
 }
 
-inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 #if defined (GGML_SYCL_F16)
     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
     GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
@@ -757,6 +588,7 @@ inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
 #endif
     GGML_ASSERT(dst->src[0]->type == dst->type);
     dpct::queue_ptr main_stream = ctx.stream();
     SYCL_CHECK(ggml_sycl_set_device(ctx.device));
     switch (dst->type) {
@@ -764,14 +596,16 @@ inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor
         case GGML_TYPE_F16:
             {
                 auto data_pts = cast_data<sycl::half>(dst);
-                gelu_quick_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
                 break;
             }
 #endif
         case GGML_TYPE_F32:
             {
                 auto data_pts = cast_data<float>(dst);
-                gelu_quick_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
                 break;
             }
         default:
@@ -779,593 +613,320 @@ inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor
     }
 }
 
-inline void ggml_sycl_op_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-#if defined (GGML_SYCL_F16)
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
-#else
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-#endif
-    GGML_ASSERT(dst->src[0]->type == dst->type);
-    dpct::queue_ptr main_stream = ctx.stream();
-    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-    switch (dst->type) {
-#if defined (GGML_SYCL_F16)
-        case GGML_TYPE_F16:
-            {
-                auto data_pts = cast_data<sycl::half>(dst);
-                gelu_erf_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-#endif
-        case GGML_TYPE_F32:
-            {
-                auto data_pts = cast_data<float>(dst);
-                gelu_erf_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-        default:
-            GGML_ABORT("GGML tensor type not supported!\n");
-    }
-}
-
-inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-#if defined (GGML_SYCL_F16)
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
-#else
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-#endif
-    GGML_ASSERT(dst->src[0]->type == dst->type);
-    dpct::queue_ptr main_stream = ctx.stream();
-    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-    switch (dst->type) {
-#if defined (GGML_SYCL_F16)
-        case GGML_TYPE_F16:
-            {
-                auto data_pts = cast_data<sycl::half>(dst);
-                tanh_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-#endif
-        case GGML_TYPE_F32:
-            {
-                auto data_pts = cast_data<float>(dst);
-                tanh_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-        default:
-            GGML_ABORT("GGML tensor type not supported!\n");
-    }
-}
-
-inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-#if defined (GGML_SYCL_F16)
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
-#else
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-#endif
-    GGML_ASSERT(dst->src[0]->type == dst->type);
-    dpct::queue_ptr main_stream = ctx.stream();
-    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-
-    switch (dst->type) {
-#if defined (GGML_SYCL_F16)
-        case GGML_TYPE_F16:
-            {
-                auto data_pts = cast_data<sycl::half>(dst);
-                relu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-#endif
-        case GGML_TYPE_F32:
-            {
-                auto data_pts = cast_data<float>(dst);
-                relu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-        default:
-            GGML_ABORT("GGML tensor type not supported!\n");
-    }
-}
 
-inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-#if defined (GGML_SYCL_F16)
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
-#else
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-#endif
-    GGML_ASSERT(dst->src[0]->type == dst->type);
 
-    dpct::queue_ptr main_stream = ctx.stream();
-    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
 
-    switch (dst->type) {
-#if defined (GGML_SYCL_F16)
-        case GGML_TYPE_F16:
-            {
-                auto data_pts = cast_data<sycl::half>(dst);
-                hardsigmoid_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-#endif
-        case GGML_TYPE_F32:
-            {
-                auto data_pts = cast_data<float>(dst);
-                hardsigmoid_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-        default:
-            GGML_ABORT("GGML tensor type not supported!\n");
-    }
 }
 
-inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    switch (dst->type) {
-#if defined (GGML_SYCL_F16)
-        case GGML_TYPE_F16:
-            {
-                auto data_pts = cast_data<sycl::half>(dst);
-                hardswish_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-#endif
-        case GGML_TYPE_F32:
-            {
-                auto data_pts = cast_data<float>(dst);
-                hardswish_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-        default:
-            GGML_ABORT("GGML tensor type not supported!\n");
-    }
 }
 
-inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    switch (dst->type) {
-#if defined (GGML_SYCL_F16)
-        case GGML_TYPE_F16:
-            {
-                auto data_pts = cast_data<sycl::half>(dst);
-                exp_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-#endif
-        case GGML_TYPE_F32:
-            {
-                auto data_pts = cast_data<float>(dst);
-                exp_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-        default:
-            GGML_ABORT("GGML tensor type not supported!\n");
-    }
 }
 
-inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    switch (dst->type) {
-#if defined (GGML_SYCL_F16)
-        case GGML_TYPE_F16:
-            {
-                auto data_pts = cast_data<sycl::half>(dst);
-                log_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-#endif
-        case GGML_TYPE_F32:
-            {
-                auto data_pts = cast_data<float>(dst);
-                log_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-        default:
-            GGML_ABORT("GGML tensor type not supported!\n");
-    }
 }
 
-inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    switch (dst->type) {
-#if defined (GGML_SYCL_F16)
-        case GGML_TYPE_F16:
-            {
-                auto data_pts = cast_data<sycl::half>(dst);
-                sigmoid_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-#endif
-        case GGML_TYPE_F32:
-            {
-                auto data_pts = cast_data<float>(dst);
-                sigmoid_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-        default:
-            GGML_ABORT("GGML tensor type not supported!\n");
-    }
 }
 
-inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-    switch (dst->type) {
-#if defined (GGML_SYCL_F16)
-        case GGML_TYPE_F16:
-            {
-                auto data_pts = cast_data<sycl::half>(dst);
-                sqrt_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-#endif
-        case GGML_TYPE_F32:
-            {
-                auto data_pts = cast_data<float>(dst);
-                sqrt_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-        default:
-            GGML_ABORT("GGML tensor type not supported!\n");
-    }
 }
-
-inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    switch (dst->type) {
-#if defined (GGML_SYCL_F16)
-        case GGML_TYPE_F16:
-            {
-                auto data_pts = cast_data<sycl::half>(dst);
-                sin_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-#endif
-        case GGML_TYPE_F32:
-            {
-                auto data_pts = cast_data<float>(dst);
-                sin_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-        default:
-            GGML_ABORT("GGML tensor type not supported!\n");
-    }
 }
 
-inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    switch (dst->type) {
-#if defined (GGML_SYCL_F16)
-        case GGML_TYPE_F16:
-            {
-                auto data_pts = cast_data<sycl::half>(dst);
-                cos_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-#endif
-        case GGML_TYPE_F32:
-            {
-                auto data_pts = cast_data<float>(dst);
-                cos_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-        default:
-            GGML_ABORT("GGML tensor type not supported!\n");
-    }
 }
 
-inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    switch (dst->type) {
-#if defined (GGML_SYCL_F16)
-        case GGML_TYPE_F16:
-            {
-                auto data_pts = cast_data<sycl::half>(dst);
-                step_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-#endif
-        case GGML_TYPE_F32:
-            {
-                auto data_pts = cast_data<float>(dst);
-                step_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-        default:
-            GGML_ABORT("GGML tensor type not supported!\n");
-    }
 }
 
-inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    switch (dst->type) {
-#if defined (GGML_SYCL_F16)
-        case GGML_TYPE_F16:
-            {
-                auto data_pts = cast_data<sycl::half>(dst);
-                neg_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-#endif
-        case GGML_TYPE_F32:
-            {
-                auto data_pts = cast_data<float>(dst);
-                neg_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-        default:
-            GGML_ABORT("GGML tensor type not supported!\n");
-    }
 }
 
-inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    float negative_slope;
-    memcpy(&negative_slope, dst->op_params, sizeof(float));
 
-    switch (dst->type) {
-#if defined (GGML_SYCL_F16)
-        case GGML_TYPE_F16:
-            {
-                auto data_pts = cast_data<sycl::half>(dst);
-                leaky_relu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), negative_slope, main_stream);
-                break;
-            }
-#endif
-        case GGML_TYPE_F32:
-            {
-                auto data_pts = cast_data<float>(dst);
-                leaky_relu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), negative_slope, main_stream);
-                break;
-            }
-        default:
-            GGML_ABORT("GGML tensor type not supported!\n");
-    }
 }
 
-inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    switch (dst->type) {
-#if defined (GGML_SYCL_F16)
-        case GGML_TYPE_F16:
-            {
-                auto data_pts = cast_data<sycl::half>(dst);
-                sqr_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-#endif
-        case GGML_TYPE_F32:
-            {
-                auto data_pts = cast_data<float>(dst);
-                sqr_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-        default:
-            GGML_ABORT("GGML tensor type not supported!\n");
-    }
 }
 
-inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    switch (dst->type) {
-#if defined (GGML_SYCL_F16)
-        case GGML_TYPE_F16:
-            {
-                auto data_pts = cast_data<sycl::half>(dst);
-                upscale_sycl(data_pts.src, data_pts.dst, dst->src[0]->nb[0], dst->src[0]->nb[1], dst->src[0]->nb[2],
-                             dst->src[0]->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3,
-                             main_stream);
-                break;
-            }
-#endif
-        case GGML_TYPE_F32:
-            {
-                auto data_pts = cast_data<float>(dst);
-                upscale_sycl(data_pts.src, data_pts.dst, dst->src[0]->nb[0], dst->src[0]->nb[1], dst->src[0]->nb[2],
-                             dst->src[0]->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3,
-                             main_stream);
-                break;
-            }
-        default:
-            GGML_ABORT("GGML tensor type not supported!\n");
-    }
 }
 
-inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-    switch (dst->type) {
-#if defined (GGML_SYCL_F16)
-        case GGML_TYPE_F16:
-            {
-                auto data_pts = cast_data<sycl::half>(dst);
-                pad_sycl(data_pts.src, data_pts.dst, dst->src[0]->ne[0], dst->src[0]->ne[1], dst->src[0]->ne[2], dst->ne[0],
-                         dst->ne[1], dst->ne[2], main_stream);
-                break;
-            }
-#endif
-        case GGML_TYPE_F32:
-            {
-                auto data_pts = cast_data<float>(dst);
-                pad_sycl(data_pts.src, data_pts.dst, dst->src[0]->ne[0], dst->src[0]->ne[1], dst->src[0]->ne[2], dst->ne[0],
-                         dst->ne[1], dst->ne[2], main_stream);
-                break;
-            }
-        default:
-            GGML_ABORT("GGML tensor type not supported!\n");
-    }
 }
 
-inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 }
 
-inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
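Among the wrappers deleted above are the hardsigmoid/hardswish pair; hardswish is just x times hardsigmoid(x), a piecewise-linear stand-in for silu. A few spot values in plain C++ (illustrative only):

```cpp
#include <algorithm>
#include <cstdio>

static float hardsigmoid(float x) { return std::min(1.0f, std::max(0.0f, (x + 3.0f) / 6.0f)); }
static float hardswish(float x)   { return x * hardsigmoid(x); }

int main() {
    // Saturates at 0 below x = -3 and passes x through above x = +3,
    // with a linear blend in between.
    for (float x : {-4.0f, -3.0f, 0.0f, 3.0f, 4.0f}) {
        std::printf("x=%+.1f  hardsigmoid=%.3f  hardswish=%+.3f\n",
                    x, hardsigmoid(x), hardswish(x));
    }
    return 0;
}
```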
@@ -1381,7 +942,40 @@ inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
     //  int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
     int offset = dst->op_params[3] / 4; // offset in bytes
 
-    acc_f32_sycl(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), dst->src[1]->ne[0], dst->src[1]->ne[1], dst->src[1]->ne[2], nb1, nb2, offset, main_stream);
 }
 
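`ggml_sycl_op_acc` adds src1 into a strided view of src0 at a given offset. A plain C++ sketch of that accumulate-into-view semantics, simplified to 2-D with element strides rather than byte strides (illustrative names, not the commit's code):

```cpp
#include <vector>

// dst = x (src0 pass-through), then y (src1) is added into a view of dst
// located at `offset` with row stride `nb1` — a simplified model of GGML_OP_ACC.
static void acc_ref(const std::vector<float> & x, const std::vector<float> & y,
                    std::vector<float> & dst, int ne10, int ne11, int nb1, int offset) {
    dst = x;  // copy src0 through unchanged
    for (int i1 = 0; i1 < ne11; ++i1) {
        for (int i0 = 0; i0 < ne10; ++i0) {
            dst[offset + i1*nb1 + i0] += y[i1*ne10 + i0];
        }
    }
}
```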
|
@@ -1509,3 +1103,18 @@ void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
| 1509 |
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
| 1510 |
ggml_sycl_op_elu(ctx, dst);
|
| 1511 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
 #include "common.hpp"
+#include "ggml-sycl/presets.hpp"
 #include "ggml.h"
 #include "element_wise.hpp"
 
+#define SYCL_GLOBAL_ID_LOOP(K, ITEM) \
+    for (auto i = ITEM.get_global_id(0); i < (size_t)K; i += ITEM.get_global_range(0))
+
+#define SYCL_LOCAL_ID_CALC(ITEM, IDX) \
+    (ITEM.get_local_range(IDX) * ITEM.get_group(IDX) + ITEM.get_local_id(IDX))
+
+
 static void acc_f32(const float * x, const float * y, float * dst, const int ne,
                     const int ne10, const int ne11, const int ne12,
+                    const int nb1, const int nb2, int offset, const sycl::nd_item<1> &item_ct1) {
+    const int i = SYCL_LOCAL_ID_CALC(item_ct1, 0);
     if (i >= ne) {
         return;
     }
 
     }
 }
 
+/* Unary OP funcs */
 template<typename T>
+static __dpct_inline__ T op_sgn(T x) {
+    return x > static_cast<T>(0.f) ? static_cast<T>(1.f) : ((x < static_cast<T>(0.f) ? static_cast<T>(-1.f) : static_cast<T>(0.f)));
 }
 
 template<typename T>
+static __dpct_inline__ T op_abs(T x) {
+    return sycl::fabs(x);
 }
 
 template<typename T>
+static __dpct_inline__ T op_elu(T x) {
+    return (x > static_cast<T>(0.f)) ? x : sycl::expm1(x);
 }
 
 template<typename T>
+static __dpct_inline__ T op_gelu(T x) {
     const T GELU_COEF_A    = static_cast<T>(0.044715f);
     const T SQRT_2_OVER_PI = static_cast<T>(0.79788456080286535587989211986876f);
+    return static_cast<T>(0.5f) * x *
+           (static_cast<T>(1.0f) +
+            sycl::tanh(SQRT_2_OVER_PI * x * (static_cast<T>(1.0f) + GELU_COEF_A * x * x)));
 }
 
 template<typename T>
+static __dpct_inline__ T op_silu(T x) {
+    return x / (static_cast<T>(1.0f) + sycl::native::exp(-x));
 }
 
 template<typename T>
+static __dpct_inline__ T op_gelu_quick(T x) {
+    const T GELU_QUICK_COEF_LOCAL = static_cast<T>(-1.702f);
+    return x * (static_cast<T>(1.0f) / (static_cast<T>(1.0f) + sycl::native::exp(GELU_QUICK_COEF_LOCAL * x)));
 }
 
 template<typename T>
+static __dpct_inline__ T op_gelu_erf(T x) {
     const T SQRT_2_INV = static_cast<T>(0.70710678118654752440084436210484f);
+    return static_cast<T>(0.5f) * x * (static_cast<T>(1.0f) + sycl::erf(x * SQRT_2_INV));
 }
 
 template<typename T>
+static __dpct_inline__ T op_tanh(T x) {
+    return sycl::tanh(x);
 }
 
 template<typename T>
+static __dpct_inline__ T op_relu(T x) {
+    return sycl::fmax(x, static_cast<T>(0));
 }
 
 template<typename T>
+static __dpct_inline__ T op_sigmoid(T x) {
+    return static_cast<T>(1.0f) / (static_cast<T>(1.0f) + sycl::native::exp(-x));
 }
 
 template<typename T>
+static __dpct_inline__ T op_sqrt(T x) {
+    return sycl::sqrt(x);
 }
 
 template<typename T>
+static __dpct_inline__ T op_sin(T x) {
+    return sycl::sin(x);
 }
 
 template<typename T>
+static __dpct_inline__ T op_cos(T x) {
+    return sycl::cos(x);
 }
 
 template<typename T>
+static __dpct_inline__ T op_hardsigmoid(T x) {
+    return sycl::fmin(static_cast<T>(1.0f), sycl::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
 }
 
 template<typename T>
+static __dpct_inline__ T op_hardswish(T x) {
+    return x * sycl::fmin(static_cast<T>(1.0f), sycl::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
 }
 
 template<typename T>
+static __dpct_inline__ T op_exp(T x) {
+    return sycl::exp(x);
+}
 
+template<typename T>
+static __dpct_inline__ T op_log(T x) {
+    if (x <= static_cast<T>(0)) {
+        return neg_infinity<T>();
     }
+    return sycl::log(x);
 }
 
 template<typename T>
+static __dpct_inline__ T op_neg(T x) {
+    return -x;
+}
 
+template<typename T>
+static __dpct_inline__ T op_step(T x) {
+    return (x > static_cast<T>(0.0f)) ? static_cast<T>(1.0f) : static_cast<T>(0.0f);
 }
 
 template<typename T>
+static __dpct_inline__ T op_leaky_relu(T x, float negative_slope) {
+    T neg_slope_T = static_cast<T>(negative_slope);
+    return sycl::fmax(x, static_cast<T>(0)) +
+           sycl::fmin(x, static_cast<T>(0.0f)) * neg_slope_T;
+}
 
+template<typename T>
+static __dpct_inline__ T op_sqr(T x) {
+    return x * x;
 }
 
 template<typename T>
+static __dpct_inline__ T op_clamp(T x, float min_val, float max_val) {
+    return x < static_cast<T>(min_val) ? static_cast<T>(min_val) : (x > static_cast<T>(max_val) ? static_cast<T>(max_val) : x);
+}
 
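Because each `op_xxx` helper is a pure scalar function, its behaviour is easy to sanity-check on the host. A quick standalone C++ harness using `std::` equivalents of the `sycl::` calls (not part of the commit):

```cpp
#include <cassert>
#include <cmath>

// Host-side stand-ins for the scalar helpers above (std:: replaces sycl::).
static float op_silu(float x) { return x / (1.0f + std::exp(-x)); }
static float op_step(float x) { return x > 0.0f ? 1.0f : 0.0f; }
static float op_clamp(float x, float lo, float hi) {
    return x < lo ? lo : (x > hi ? hi : x);
}

int main() {
    assert(op_silu(0.0f) == 0.0f);                            // silu(0) = 0
    assert(op_step(-1.0f) == 0.0f && op_step(2.0f) == 1.0f);  // unit step
    assert(op_clamp(5.0f, -1.0f, 1.0f) == 1.0f);              // clipped to hi
    return 0;
}
```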
| 153 |
+
template<typename T>
|
| 154 |
+
static void unary_op_sgn_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
| 155 |
+
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
| 156 |
+
dst[i] = op_sgn(x[i]);
|
| 157 |
}
|
|
|
|
| 158 |
}
|
| 159 |
|
| 160 |
template<typename T>
|
| 161 |
+
static void unary_op_abs_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
| 162 |
+
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
| 163 |
+
dst[i] = op_abs(x[i]);
|
|
|
|
|
|
|
|
|
|
| 164 |
}
|
|
|
|
|
|
|
| 165 |
}
|
| 166 |
|
| 167 |
template<typename T>
|
| 168 |
+
static void unary_op_elu_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
| 169 |
+
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
| 170 |
+
dst[i] = op_elu(x[i]);
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
}
|
|
|
|
| 172 |
}
|
| 173 |
|
| 174 |
+
template<typename T>
|
| 175 |
+
static void unary_op_gelu_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
| 176 |
+
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
| 177 |
+
dst[i] = op_gelu(x[i]);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
}
|
| 180 |
|
| 181 |
+
template<typename T>
|
| 182 |
+
static void unary_op_silu_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
| 183 |
+
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
| 184 |
+
dst[i] = op_silu(x[i]);
|
|
|
|
|
|
|
|
|
|
| 185 |
}
|
| 186 |
+
}
|
| 187 |
|
| 188 |
+
template<typename T>
|
| 189 |
+
static void unary_op_gelu_quick_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
| 190 |
+
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
| 191 |
+
dst[i] = op_gelu_quick(x[i]);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
}
|
| 193 |
}
|
| 194 |
|
|
|
|
| 195 |
template<typename T>
|
| 196 |
+
static void unary_op_gelu_erf_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
| 197 |
+
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
| 198 |
+
dst[i] = op_gelu_erf(x[i]);
|
|
|
|
|
|
|
|
|
|
|
|
|
| 199 |
}
|
|
|
|
|
|
|
| 200 |
}
|
| 201 |
|
| 202 |
+
template<typename T>
|
| 203 |
+
static void unary_op_tanh_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
| 204 |
+
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
| 205 |
+
dst[i] = op_tanh(x[i]);
|
| 206 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
}
|
| 208 |
|
| 209 |
template<typename T>
|
| 210 |
+
static void unary_op_relu_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
| 211 |
+
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
| 212 |
+
dst[i] = op_relu(x[i]);
|
| 213 |
+
}
|
|
|
|
|
|
|
|
|
|
| 214 |
}
|
| 215 |
|
| 216 |
template<typename T>
|
| 217 |
+
static void unary_op_sigmoid_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
| 218 |
+
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
| 219 |
+
dst[i] = op_sigmoid(x[i]);
|
| 220 |
+
}
|
|
|
|
|
|
|
|
|
|
| 221 |
}
|
| 222 |
|
| 223 |
template<typename T>
|
| 224 |
+
static void unary_op_sqrt_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
| 225 |
+
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
| 226 |
+
dst[i] = op_sqrt(x[i]);
|
| 227 |
+
}
|
|
|
|
|
|
|
| 228 |
}
|
| 229 |
|
| 230 |
template<typename T>
|
| 231 |
+
static void unary_op_sin_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
| 232 |
+
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
| 233 |
+
dst[i] = op_sin(x[i]);
|
| 234 |
+
}
|
|
|
|
|
|
|
|
|
|
| 235 |
}
|
| 236 |
|
|
|
|
| 237 |
template<typename T>
|
| 238 |
+
static void unary_op_cos_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
| 239 |
+
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
| 240 |
+
dst[i] = op_cos(x[i]);
|
| 241 |
+
}
|
|
|
|
|
|
|
|
|
|
| 242 |
}
|
| 243 |
|
| 244 |
template<typename T>
|
| 245 |
+
static void unary_op_hardsigmoid_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
| 246 |
+
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
| 247 |
+
dst[i] = op_hardsigmoid(x[i]);
|
| 248 |
+
}
|
|
|
|
|
|
|
|
|
|
| 249 |
}
|
| 250 |
|
|
|
|
| 251 |
template<typename T>
|
| 252 |
+
static void unary_op_hardswish_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
| 253 |
+
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
| 254 |
+
dst[i] = op_hardswish(x[i]);
|
| 255 |
+
}
|
|
|
|
|
|
|
|
|
|
| 256 |
}
|
| 257 |
|
| 258 |
template<typename T>
|
| 259 |
+
static void unary_op_exp_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
| 260 |
+
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
| 261 |
+
dst[i] = op_exp(x[i]);
|
| 262 |
+
}
|
|
|
|
|
|
|
|
|
|
| 263 |
}
|
| 264 |
|
| 265 |
template<typename T>
|
| 266 |
+
static void unary_op_log_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
| 267 |
+
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
| 268 |
+
dst[i] = op_log(x[i]);
|
| 269 |
+
}
|
|
|
|
|
|
|
|
|
|
| 270 |
}
|
| 271 |
|
| 272 |
template<typename T>
|
| 273 |
+
static void unary_op_neg_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
| 274 |
+
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
| 275 |
+
dst[i] = op_neg(x[i]);
|
| 276 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 277 |
}
|
| 278 |
|
| 279 |
template<typename T>
|
| 280 |
+
static void unary_op_step_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
| 281 |
+
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
| 282 |
+
dst[i] = op_step(x[i]);
|
| 283 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 284 |
}
|
| 285 |
|
| 286 |
template<typename T>
|
| 287 |
+
static void unary_op_leaky_relu_kernel(const T * x, T * dst, const int k, float negative_slope, const sycl::nd_item<1> &item_ct1) {
|
| 288 |
+
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
| 289 |
+
dst[i] = op_leaky_relu(x[i], negative_slope);
|
| 290 |
+
}
|
|
|
|
|
|
|
|
|
|
| 291 |
}
|
| 292 |
|
| 293 |
template<typename T>
|
| 294 |
+
static void unary_op_sqr_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
| 295 |
+
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
| 296 |
+
dst[i] = op_sqr(x[i]);
|
| 297 |
+
}
|
|
|
|
|
|
|
|
|
|
| 298 |
}
|
| 299 |
|
| 300 |
template<typename T>
|
| 301 |
+
static void unary_op_clamp_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1, float min_val, float max_val) {
|
| 302 |
+
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
| 303 |
+
dst[i] = op_clamp(x[i], min_val, max_val);
|
| 304 |
+
}
|
|
|
|
|
|
|
|
|
|
| 305 |
}
|
| 306 |
|
| 307 |
+
template<typename T>
|
| 308 |
+
static void upscale(const T *x, T *dst, const int nb00, const int nb01,
|
| 309 |
+
const int nb02, const int nb03, const int ne10, const int ne11,
|
| 310 |
+
const int ne12, const int ne13, const float sf0, const float sf1,
|
| 311 |
+
const float sf2, const float sf3, const sycl::nd_item<1> &item_ct1) {
|
| 312 |
+
int index = item_ct1.get_local_id(0) +
|
| 313 |
+
item_ct1.get_group(0) * item_ct1.get_local_range(0);
|
| 314 |
+
if (index >= ne10 * ne11 * ne12 * ne13) {
|
| 315 |
+
return;
|
| 316 |
+
}
|
| 317 |
+
// operation
|
| 318 |
+
int i10 = index % ne10;
|
| 319 |
+
int i11 = (index / ne10) % ne11;
|
| 320 |
+
int i12 = (index / (ne10 * ne11)) % ne12;
|
| 321 |
+
int i13 = (index / (ne10 * ne11 * ne12)) % ne13;
|
| 322 |
+
|
| 323 |
+
int i00 = static_cast<int>(i10 / sf0);
|
| 324 |
+
int i01 = static_cast<int>(i11 / sf1);
|
| 325 |
+
int i02 = static_cast<int>(i12 / sf2);
|
| 326 |
+
int i03 = static_cast<int>(i13 / sf3);
|
| 327 |
+
|
| 328 |
+
dst[index] = *(const T *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
|
| 329 |
}
|
| 330 |
|
| 331 |
+
template <typename T>
|
| 332 |
+
static void pad(const T *x, T *dst, const int ne0, const int ne00, const int ne01, const int ne02,
|
| 333 |
+
const sycl::nd_item<3> &item_ct1) {
|
| 334 |
+
int nidx = SYCL_LOCAL_ID_CALC(item_ct1, 2);
|
| 335 |
+
if (nidx >= ne0) {
|
| 336 |
+
return;
|
| 337 |
+
}
|
| 338 |
+
|
| 339 |
+
// operation
|
| 340 |
+
int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
|
| 341 |
+
item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
|
| 342 |
+
if (nidx < ne00 && item_ct1.get_group(1) < (size_t) ne01 && item_ct1.get_group(0) < (size_t) ne02) {
|
| 343 |
+
int offset_src = nidx + item_ct1.get_group(1) * ne00 +
|
| 344 |
+
item_ct1.get_group(0) * ne00 * ne01;
|
| 345 |
+
dst[offset_dst] = x[offset_src];
|
| 346 |
+
} else {
|
| 347 |
+
dst[offset_dst] = static_cast<T>(0.0f);
|
| 348 |
+
}
|
| 349 |
}
|
| 350 |
|
| 351 |
template<typename T>
|
| 352 |
+
static void clamp(const T * x, T * dst, const float min, const float max, const int k,
|
| 353 |
+
const sycl::nd_item<1> &item_ct1) {
|
| 354 |
+
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
| 355 |
+
dst[i] = x[i] < static_cast<T>(min) ? static_cast<T>(min) : (x[i] > static_cast<T>(max) ? static_cast<T>(max) : x[i]);
|
| 356 |
+
}
|
|
|
|
|
|
|
| 357 |
}

template<typename T>
+static void gated_op_fused_geglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        const int64_t j0 = (i / n) * o0 + (i % n);
+        const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
+        dst[i] = op_gelu(x[j0]) * g[j1];
+    }
}

template<typename T>
+static void gated_op_fused_reglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        const int64_t j0 = (i / n) * o0 + (i % n);
+        const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
+        dst[i] = op_relu(x[j0]) * g[j1];
+    }
}

template<typename T>
+static void gated_op_fused_swiglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        const int64_t j0 = (i / n) * o0 + (i % n);
+        const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
+        dst[i] = op_silu(x[j0]) * g[j1];
+    }
}
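The three gated kernels differ only in the activation; the addressing is shared. For flat output index i and output row width n, j0 selects the activation element and j1 the gate element, with o0/o1 the row strides (in elements) of the two sources. A host-side reference of the same mapping, a minimal sketch where `gelu_ref` is a hypothetical stand-in for the device-side op_gelu:

#include <cmath>
#include <cstdint>

static float gelu_ref(float x) {
    return 0.5f * x * (1.0f + std::tanh(0.797884561f * (x + 0.044715f * x * x * x)));
}

// k: total output elements, n: output row width (nc),
// o0/o1: row strides of the activation and gate sources.
static void geglu_ref(const float * x, const float * g, float * dst,
                      uint64_t k, uint64_t n, uint64_t o0, uint64_t o1) {
    for (uint64_t i = 0; i < k; ++i) {
        const uint64_t j0 = (i / n) * o0 + (i % n);                   // activation element
        const uint64_t j1 = (o0 == o1) ? j0 : (i / n) * o1 + (i % n); // gate element
        dst[i] = gelu_ref(x[j0]) * g[j1];
    }
}

In the fused case x and g are two pointers into the same tensor, offset by nc, so o0 == o1 and the same index serves both halves of the row.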

+namespace ggml_sycl_detail {
+static void acc_f32_sycl(const float *x, const float *y, float *dst,
+                         const int n_elements, const int ne10, const int ne11,
+                         const int ne12, const int nb1, const int nb2,
+                         const int offset, queue_ptr stream) {
+    int num_blocks = ceil_div(n_elements, SYCL_ACC_BLOCK_SIZE);
    sycl_parallel_for(stream,
+        sycl::nd_range<1>(sycl::range<1>(num_blocks) *
+                              sycl::range<1>(SYCL_ACC_BLOCK_SIZE),
+                          sycl::range<1>(SYCL_ACC_BLOCK_SIZE)),
+        [=](sycl::nd_item<1> item_ct1) {
+            acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset,
+                    item_ct1);
+        });
}

template<typename T>
static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01,
                         const int nb02, const int nb03, const int ne10, const int ne11,
                         const int ne12, const int ne13, const float sf0, const float sf1,
                         const float sf2, const float sf3, queue_ptr stream) {
    int dst_size = ne10 * ne11 * ne12 * ne13;
+    int num_blocks = ceil_div(dst_size, SYCL_UPSCALE_BLOCK_SIZE);
    sycl::range<1> gridDim(num_blocks * SYCL_UPSCALE_BLOCK_SIZE);
    sycl_parallel_for<1>(
        stream, sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
...
static void pad_sycl(const T *x, T *dst, const int ne00,
                     const int ne01, const int ne02, const int ne0,
                     const int ne1, const int ne2, queue_ptr stream) {
+    int num_blocks = ceil_div(ne0, SYCL_PAD_BLOCK_SIZE);
    sycl::range<3> gridDim(ne2, ne1, num_blocks);
    sycl_parallel_for(stream,
        sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),
                          sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)),
        [=](sycl::nd_item<3> item_ct1) { pad(x, dst, ne0, ne00, ne01, ne02, item_ct1); });
}

+template<typename KernelInvoker, typename... Args>
+static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
#if defined (GGML_SYCL_F16)
    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
#else
    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
#endif
    GGML_ASSERT(dst->src[0]->type == dst->type);
    dpct::queue_ptr main_stream = ctx.stream();
    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
    switch (dst->type) {
#if defined (GGML_SYCL_F16)
    case GGML_TYPE_F16:
        {
            auto data_pts = cast_data<sycl::half>(dst);
+            kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward<Args>(args)...);
            break;
        }
#endif
    case GGML_TYPE_F32:
        {
            auto data_pts = cast_data<float>(dst);
+            kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward<Args>(args)...);
            break;
        }
    default:
...
    }
}

+template<typename KernelInvoker, typename... Args>
+static inline void dispatch_ggml_sycl_op_fused_glu(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
#if defined (GGML_SYCL_F16)
    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
#else
    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
#endif
    GGML_ASSERT(dst->src[0]->type == dst->type);
    dpct::queue_ptr main_stream = ctx.stream();
    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const int64_t nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_is_contiguous_1(dst->src[0]));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    const int32_t swapped = ((const int32_t *) dst->op_params)[1];
+    void * src0_d = src0->data;
+    void * src1_d = src1 ? src1->data : src0->data;
+    const int64_t src0_o = src0->nb[1];
+    const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+    void * dst_d = dst->data;
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src1->nb[0] == ggml_element_size(src1));
+        GGML_ASSERT(src1->ne[0] == nc);
+        GGML_ASSERT(src0->type == src1->type);
    }
    switch (dst->type) {
#if defined (GGML_SYCL_F16)
    case GGML_TYPE_F16:
        {
+            sycl::half * src0_p = (sycl::half *) src0_d;
+            sycl::half * src1_p = (sycl::half *) src1_d;
+
+            if (!src1) {
+                src0_p += swapped ? nc : 0;
+                src1_p += swapped ? 0 : nc;
+            }
+            kernel_invoker(src0_p,
+                           src1_p,
+                           (sycl::half *) dst_d,
+                           ggml_nelements(dst),
+                           nc,
+                           src0_o / sizeof(sycl::half),
+                           src1_o / sizeof(sycl::half),
+                           main_stream,
+                           std::forward<Args>(args)...);
            break;
        }
#endif
    case GGML_TYPE_F32:
        {
+            float * src0_p = (float *) src0_d;
+            float * src1_p = (float *) src1_d;
+
+            if (!src1) {
+                src0_p += swapped ? nc : 0;
+                src1_p += swapped ? 0 : nc;
+            }
+
+            kernel_invoker(src0_p,
+                           src1_p,
+                           (float *) dst_d,
+                           ggml_nelements(dst),
+                           nc,
+                           src0_o / sizeof(float),
+                           src1_o / sizeof(float),
+                           main_stream,
+                           std::forward<Args>(args)...);
            break;
        }
    default:
...
    }
}
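The only GLU-specific work in this dispatcher is choosing base pointers: in the fused case the activation and gate are the two halves of each src0 row (with `swapped` selecting which half plays which role), while in the split case the gate has its own tensor. A condensed F32 reference of that setup, a sketch with hypothetical parameter names:

#include <cstdint>

// src1_data == nullptr means the fused single-tensor layout.
static void glu_base_pointers(const float * src0_data, const float * src1_data,
                              bool swapped, int64_t nc,
                              const float ** x, const float ** g) {
    const bool split = src1_data != nullptr;
    if (split) {
        *x = src0_data;                       // activation tensor
        *g = src1_data;                       // separate gate tensor
    } else {
        *x = src0_data + (swapped ? nc : 0);  // one half of each fused row
        *g = src0_data + (swapped ? 0 : nc);  // the other half
    }
}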

+template<typename KernelInvoker, typename... Args>
+static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
#if defined (GGML_SYCL_F16)
    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
#else
    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
#endif
    GGML_ASSERT(dst->src[0]->type == dst->type);

    dpct::queue_ptr main_stream = ctx.stream();
    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+
+    const float sf0 = (float) dst->ne[0] / dst->src[0]->ne[0];
+    const float sf1 = (float) dst->ne[1] / dst->src[0]->ne[1];
+    const float sf2 = (float) dst->ne[2] / dst->src[0]->ne[2];
+    const float sf3 = (float) dst->ne[3] / dst->src[0]->ne[3];
    switch (dst->type) {
#if defined (GGML_SYCL_F16)
    case GGML_TYPE_F16:
        {
            auto data_pts = cast_data<sycl::half>(dst);
+            kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->nb[0], (int)dst->src[0]->nb[1], (int)dst->src[0]->nb[2],
+                           (int)dst->src[0]->nb[3], (int)dst->ne[0], (int)dst->ne[1], (int)dst->ne[2], (int)dst->ne[3], sf0, sf1, sf2, sf3,
+                           main_stream, std::forward<Args>(args)...);
            break;
        }
#endif
    case GGML_TYPE_F32:
        {
            auto data_pts = cast_data<float>(dst);
+            kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->nb[0], (int)dst->src[0]->nb[1], (int)dst->src[0]->nb[2],
+                           (int)dst->src[0]->nb[3], (int)dst->ne[0], (int)dst->ne[1], (int)dst->ne[2], (int)dst->ne[3], sf0, sf1, sf2, sf3,
+                           main_stream, std::forward<Args>(args)...);
            break;
        }
    default:
...
    }
}

+template<typename KernelInvoker, typename... Args>
+static inline void dispatch_ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
#if defined (GGML_SYCL_F16)
    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
#else
    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
#endif
    GGML_ASSERT(dst->src[0]->type == dst->type);
+    GGML_ASSERT(dst->src[0]->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
    dpct::queue_ptr main_stream = ctx.stream();
    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
    switch (dst->type) {
#if defined (GGML_SYCL_F16)
    case GGML_TYPE_F16:
        {
            auto data_pts = cast_data<sycl::half>(dst);
+            kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->ne[0], (int)dst->src[0]->ne[1], (int)dst->src[0]->ne[2], (int)dst->ne[0],
+                           (int)dst->ne[1], (int)dst->ne[2], main_stream, std::forward<Args>(args)...);
            break;
        }
#endif
    case GGML_TYPE_F32:
        {
            auto data_pts = cast_data<float>(dst);
+            kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->ne[0], (int)dst->src[0]->ne[1], (int)dst->src[0]->ne[2], (int)dst->ne[0],
+                           (int)dst->ne[1], (int)dst->ne[2], main_stream, std::forward<Args>(args)...);
            break;
        }
    default:
...
    }
}

+} // namespace ggml_sycl_detail
+static inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, 256);
+            sycl_parallel_for(stream,
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
+                                  sycl::range<1>(256)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_sgn_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
}

+static inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, 256);
+            sycl_parallel_for(stream,
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
+                                  sycl::range<1>(256)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_abs_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
}

+static inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, 256);
+            sycl_parallel_for(stream,
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
+                                  sycl::range<1>(256)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_elu_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
}

+static inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, SYCL_SILU_BLOCK_SIZE);
+            sycl_parallel_for(stream,
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SILU_BLOCK_SIZE),
+                                  sycl::range<1>(SYCL_SILU_BLOCK_SIZE)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_silu_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
}

+static inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, SYCL_GELU_BLOCK_SIZE);
+            sycl_parallel_for(stream,
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE),
+                                  sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_gelu_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
}

+static inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, SYCL_GELU_BLOCK_SIZE);
+            sycl_parallel_for(stream,
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE),
+                                  sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_gelu_quick_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
}
+
+static inline void ggml_sycl_op_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, SYCL_GELU_BLOCK_SIZE);
+            sycl_parallel_for(stream,
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE),
+                                  sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_gelu_erf_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
}

+static inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, SYCL_TANH_BLOCK_SIZE);
+            sycl_parallel_for(stream,
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_TANH_BLOCK_SIZE),
+                                  sycl::range<1>(SYCL_TANH_BLOCK_SIZE)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_tanh_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
}

+static inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, SYCL_RELU_BLOCK_SIZE);
+            sycl_parallel_for(stream,
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_RELU_BLOCK_SIZE),
+                                  sycl::range<1>(SYCL_RELU_BLOCK_SIZE)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_relu_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
}

+static inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, SYCL_HARDSIGMOID_BLOCK_SIZE);
+            sycl_parallel_for(stream,
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_HARDSIGMOID_BLOCK_SIZE),
+                                  sycl::range<1>(SYCL_HARDSIGMOID_BLOCK_SIZE)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_hardsigmoid_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
}

+static inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, SYCL_HARDSWISH_BLOCK_SIZE);
+            sycl_parallel_for(stream,
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_HARDSWISH_BLOCK_SIZE),
+                                  sycl::range<1>(SYCL_HARDSWISH_BLOCK_SIZE)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_hardswish_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
+}

+static inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, SYCL_EXP_BLOCK_SIZE);
+            sycl_parallel_for(stream,
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_EXP_BLOCK_SIZE),
+                                  sycl::range<1>(SYCL_EXP_BLOCK_SIZE)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_exp_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
}

+static inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, SYCL_EXP_BLOCK_SIZE); // Using EXP block size
+            sycl_parallel_for(stream,
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_EXP_BLOCK_SIZE),
+                                  sycl::range<1>(SYCL_EXP_BLOCK_SIZE)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_log_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
}

+static inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, SYCL_NEG_BLOCK_SIZE);
+            sycl_parallel_for(stream,
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_NEG_BLOCK_SIZE),
+                                  sycl::range<1>(SYCL_NEG_BLOCK_SIZE)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_neg_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
+}

+static inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, SYCL_NEG_BLOCK_SIZE); // Using NEG block size
+            sycl_parallel_for(stream,
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_NEG_BLOCK_SIZE),
+                                  sycl::range<1>(SYCL_NEG_BLOCK_SIZE)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_step_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
+}

+static inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, SYCL_SIGMOID_BLOCK_SIZE);
+            sycl_parallel_for(stream,
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIGMOID_BLOCK_SIZE),
+                                  sycl::range<1>(SYCL_SIGMOID_BLOCK_SIZE)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_sigmoid_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
}

+static inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, SYCL_SQRT_BLOCK_SIZE);
+            sycl_parallel_for(stream,
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQRT_BLOCK_SIZE),
+                                  sycl::range<1>(SYCL_SQRT_BLOCK_SIZE)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_sqrt_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
}

+static inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, SYCL_SIN_BLOCK_SIZE);
+            sycl_parallel_for(stream,
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIN_BLOCK_SIZE),
+                                  sycl::range<1>(SYCL_SIN_BLOCK_SIZE)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_sin_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
+}

+static inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, SYCL_SIN_BLOCK_SIZE); // Using SIN block size
+            sycl_parallel_for(stream,
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIN_BLOCK_SIZE),
+                                  sycl::range<1>(SYCL_SIN_BLOCK_SIZE)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_cos_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
+}

+static inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    float negative_slope;
+    memcpy(&negative_slope, dst->op_params, sizeof(float));
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream, float slope) {
+            const int num_blocks = ceil_div(k_elements, SYCL_RELU_BLOCK_SIZE);
+            sycl_parallel_for(stream,
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_RELU_BLOCK_SIZE),
+                                  sycl::range<1>(SYCL_RELU_BLOCK_SIZE)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_leaky_relu_kernel(src, dst_ptr, k_elements, slope, item_ct1);
+                });
+        }, negative_slope);
+}
+
+static inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, SYCL_SQR_BLOCK_SIZE);
+            sycl_parallel_for(stream,
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQR_BLOCK_SIZE),
+                                  sycl::range<1>(SYCL_SQR_BLOCK_SIZE)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_sqr_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
+}
+
+static inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_upscale(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int nb00, int nb01, int nb02, int nb03,
+           int ne10, int ne11, int ne12, int ne13, float sf0, float sf1, float sf2, float sf3,
+           queue_ptr stream) {
+            ggml_sycl_detail::upscale_sycl(src, dst_ptr, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, stream);
+        });
}

+static inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_pad(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2,
+           queue_ptr stream) {
+            ggml_sycl_detail::pad_sycl(src, dst_ptr, ne00, ne01, ne02, ne0, ne1, ne2, stream);
+        });
+}

+static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    float min_val;
+    float max_val;
+    memcpy(&min_val, dst->op_params, sizeof(float));
+    memcpy(&max_val, (float *) dst->op_params + 1, sizeof(float));
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream, float min_arg, float max_arg) {
+            const int num_blocks = ceil_div(k_elements, SYCL_CLAMP_BLOCK_SIZE);
+            sycl_parallel_for(stream,
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE),
+                                  sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    clamp(src, dst_ptr, min_arg, max_arg, k_elements, item_ct1);
+                });
+        }, min_val, max_val);
+}
+
+static inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);
...
    // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
    int offset = dst->op_params[3] / 4; // offset in bytes

+    ggml_sycl_detail::acc_f32_sycl(src0_dd, src1_dd, dst_dd, (int)ggml_nelements(dst), (int)dst->src[1]->ne[0], (int)dst->src[1]->ne[1], (int)dst->src[1]->ne[2], nb1, nb2, offset, main_stream);
+}
+
+static inline void ggml_sycl_op_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
+        [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
+            const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
+            sycl_parallel_for(main_stream,
+                sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
+                    gated_op_fused_geglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
+                });
+        });
+}
+
+static inline void ggml_sycl_op_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
+        [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
+            const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_RELU_BLOCK_SIZE); // Using RELU block size for reglu
+            sycl_parallel_for(main_stream,
+                sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
+                    gated_op_fused_reglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
+                });
+        });
+}
+
+static inline void ggml_sycl_op_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
+        [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
+            const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_SILU_BLOCK_SIZE); // Using SILU block size for swiglu
+            sycl_parallel_for(main_stream,
+                sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
+                    gated_op_fused_swiglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
+                });
+        });
}

...
    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
    ggml_sycl_op_elu(ctx, dst);
}
+
+void ggml_sycl_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_geglu(ctx, dst);
+}
+
+void ggml_sycl_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_reglu(ctx, dst);
+}
+
+void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_swiglu(ctx, dst);
+}
ggml/src/ggml-sycl/element_wise.hpp

@@ -3,27 +3,30 @@

#include "common.hpp"
#include "ggml.h"
-#include <limits>
+#include <limits> // For std::numeric_limits

template <typename T>
T neg_infinity() {
    return -std::numeric_limits<T>::infinity();
}

-template<typename T>
+template<typename T_Dst, typename T_Src = T_Dst>
struct typed_data {
-    const T * src;
-    T * dst;
+    const T_Src * src;
+    T_Dst * dst;
};

-template<typename T>
-typed_data<T> cast_data(ggml_tensor * dst) {
+template<typename T_Dst, typename T_Src = T_Dst>
+typed_data<T_Dst, T_Src> cast_data(ggml_tensor * dst) {
    return {
-        /* .src = */ static_cast<const T *>(dst->src[0]->data),
-        /* .dst = */ static_cast<T *>(dst->data)
+        /* .src = */ static_cast<const T_Src *>(dst->src[0]->data),
+        /* .dst = */ static_cast<T_Dst *>(dst->data)
    };
}

+const float GELU_QUICK_COEF = -1.702f;
+
+
void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

void ggml_sycl_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

@@ -73,5 +76,9 @@ void ggml_sycl_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

void ggml_sycl_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
-#endif // GGML_SYCL_ELEMENTWISE_HPP

+void ggml_sycl_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+void ggml_sycl_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+
+#endif // GGML_SYCL_ELEMENTWISE_HPP
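A minimal usage sketch for the reworked cast_data helper, assuming a ggml_tensor * dst whose src[0] is set: the default T_Src = T_Dst keeps the common same-type call down to a single template argument.

// Sketch only; mirrors how the dispatchers above consume typed_data.
auto f32_pts  = cast_data<float>(dst);       // typed_data<float>{ const float * src; float * dst; }
auto half_pts = cast_data<sycl::half>(dst);  // typed_data<sycl::half>
// e.g. kernel_invoker(f32_pts.src, f32_pts.dst, (int)ggml_nelements(dst->src[0]), stream, ...);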
ggml/src/ggml-sycl/ggml-sycl.cpp

@@ -3676,6 +3676,21 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
            return false;
        }
        break;
+    case GGML_OP_GLU:
+        switch (ggml_get_glu_op(dst)) {
+            case GGML_GLU_OP_REGLU:
+                ggml_sycl_reglu(ctx, dst);
+                break;
+            case GGML_GLU_OP_GEGLU:
+                ggml_sycl_geglu(ctx, dst);
+                break;
+            case GGML_GLU_OP_SWIGLU:
+                ggml_sycl_swiglu(ctx, dst);
+                break;
+            default:
+                return false;
+        }
+        break;
    case GGML_OP_NORM:
        ggml_sycl_norm(ctx, dst);
        break;

@@ -4212,6 +4227,16 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
        default:
            return false;
    }
+    case GGML_OP_GLU:
+        switch (ggml_get_glu_op(op)) {
+            case GGML_GLU_OP_REGLU:
+            case GGML_GLU_OP_GEGLU:
+            case GGML_GLU_OP_SWIGLU:
+                return ggml_is_contiguous_1(op->src[0]);
+            default:
+                return false;
+        }
+        break;
    case GGML_OP_MUL_MAT:
    case GGML_OP_MUL_MAT_ID:
        {
ggml/src/ggml-vulkan/ggml-vulkan.cpp

@@ -437,6 +437,10 @@ struct vk_device_struct {
    vk_pipeline pipeline_tanh[2];
    vk_pipeline pipeline_sigmoid[2];

+    vk_pipeline pipeline_geglu[2];
+    vk_pipeline pipeline_reglu[2];
+    vk_pipeline pipeline_swiglu[2];
+
    vk_pipeline pipeline_leaky_relu_f32;
    vk_pipeline pipeline_silu_back_f32;
    vk_pipeline pipeline_diag_mask_inf_f32;

@@ -661,6 +665,13 @@ struct vk_op_push_constants {
    float param2;
};

+struct vk_op_glu_push_constants {
+    uint32_t N;
+    uint32_t ne00;
+    uint32_t ne20;
+    uint32_t mode; // 0: default, 1: swapped, 2: split
+};
+
struct vk_op_unary_push_constants {
    uint32_t ne;
    uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;

@@ -2757,6 +2768,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
    CREATE_UNARY(sigmoid)
#undef CREATE_UNARY

+#define CREATE_GLU(name) \
+    ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
+    ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true);
+
+    CREATE_GLU(geglu)
+    CREATE_GLU(reglu)
+    CREATE_GLU(swiglu)
+#undef CREATE_GLU
+
    ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
    ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);

@@ -6473,6 +6493,24 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
            break;
        }
        return nullptr;
+    case GGML_OP_GLU:
+        if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) ||
+            (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) ||
+            (src0->type != dst->type)) {
+            return nullptr;
+        }
+
+        switch (ggml_get_glu_op(dst)) {
+            case GGML_GLU_OP_GEGLU:
+                return ctx->device->pipeline_geglu[dst->type == GGML_TYPE_F16];
+            case GGML_GLU_OP_REGLU:
+                return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16];
+            case GGML_GLU_OP_SWIGLU:
+                return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16];
+            default:
+                break;
+        }
+        return nullptr;
    case GGML_OP_DIAG_MASK_INF:
        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
            return ctx->device->pipeline_diag_mask_inf_f32;

@@ -6933,6 +6971,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
    case GGML_OP_CONCAT:
    case GGML_OP_UPSCALE:
    case GGML_OP_UNARY:
+    case GGML_OP_GLU:
    case GGML_OP_CONV_2D_DW:
        {
            uint32_t ne = ggml_nelements(dst);

@@ -6973,7 +7012,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
        }
    }

-    if (op == GGML_OP_SOFT_MAX) {
+    if (op == GGML_OP_SOFT_MAX || op == GGML_OP_GLU) {
        // Empty src1 is possible in soft_max, but the shader needs a buffer
        vk_subbuffer subbuf_y;
        if (use_src1) {

@@ -7566,6 +7605,25 @@ static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, con
    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
}

+static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    const bool swapped = (bool)dst->op_params[1];
+    const bool split = src1 != nullptr;
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    if (!split) {
+        GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]);
+    } else {
+        GGML_ASSERT(src0->ne[0] == src1->ne[0]);
+        GGML_ASSERT(src0->ne[0] == dst->ne[0]);
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const uint32_t mode = split ? 2 : (swapped ? 1 : 0);
+
+    ggml_vk_op_f32<vk_op_glu_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU, { (uint32_t)ggml_nelements(dst), (uint32_t)src0->ne[0], (uint32_t)dst->ne[0], mode }, dryrun);
+}
+
static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
    int32_t * op_params = (int32_t *)dst->op_params;
    ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun);

@@ -8778,6 +8836,16 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
            return false;
        }
        break;
+    case GGML_OP_GLU:
+        switch (ggml_get_glu_op(node)) {
+            case GGML_GLU_OP_GEGLU:
+            case GGML_GLU_OP_REGLU:
+            case GGML_GLU_OP_SWIGLU:
+                break;
+            default:
+                return false;
+        }
+        break;
    case GGML_OP_REPEAT:
    case GGML_OP_REPEAT_BACK:
    case GGML_OP_GET_ROWS:

@@ -8870,6 +8938,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
    case GGML_OP_RMS_NORM_BACK:
    case GGML_OP_L2_NORM:
    case GGML_OP_UNARY:
+    case GGML_OP_GLU:
    case GGML_OP_DIAG_MASK_INF:
    case GGML_OP_SOFT_MAX:
    case GGML_OP_SOFT_MAX_BACK:

@@ -9013,6 +9082,17 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
            return false;
        }
        break;
+    case GGML_OP_GLU:
+        switch (ggml_get_glu_op(node)) {
+            case GGML_GLU_OP_GEGLU:
+            case GGML_GLU_OP_REGLU:
+            case GGML_GLU_OP_SWIGLU:
+                ggml_vk_glu(ctx, compute_ctx, src0, src1, node, dryrun);
+                break;
+            default:
+                return false;
+        }
+        break;
    case GGML_OP_DIAG_MASK_INF:
        ggml_vk_diag_mask_inf(ctx, compute_ctx, src0, node, dryrun);

@@ -9138,8 +9218,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
    if (!ok) {
        if (node->op == GGML_OP_UNARY) {
            std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast<ggml_unary_op>(node->op_params[0])) << ")" << std::endl;
-        }
-        else {
+        } else if (node->op == GGML_OP_GLU) {
+            std::cerr << __func__ << ": error: op not supported GLU " << node->name << " (" << ggml_glu_op_name(static_cast<ggml_glu_op>(node->op_params[0])) << ")" << std::endl;
+        } else {
            std::cerr << __func__ << ": error: op not supported " << node->name << " (" << ggml_op_name(node->op) << ")" << std::endl;
        }
    }

@@ -9218,6 +9299,17 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
            return false;
        }
        break;
+    case GGML_OP_GLU:
+        switch (ggml_get_glu_op(tensor)) {
+            case GGML_GLU_OP_GEGLU:
+            case GGML_GLU_OP_REGLU:
+            case GGML_GLU_OP_SWIGLU:
+                buf = tensor->buffer;
+                break;
+            default:
+                return false;
+        }
+        break;
    case GGML_OP_MUL_MAT:
    case GGML_OP_MUL_MAT_ID:
    case GGML_OP_FLASH_ATTN_EXT:

@@ -10016,6 +10108,19 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
            return false;
        }
        break;
+    case GGML_OP_GLU:
+        switch (ggml_get_glu_op(op)) {
+            case GGML_GLU_OP_GEGLU:
+            case GGML_GLU_OP_REGLU:
+            case GGML_GLU_OP_SWIGLU:
+                return ggml_is_contiguous(op->src[0]) &&
+                       (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
+                       (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
+                       (op->src[0]->type == op->type);
+            default:
+                return false;
+        }
+        break;
    case GGML_OP_MUL_MAT:
    case GGML_OP_MUL_MAT_ID:
        {

@@ -10746,6 +10851,12 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
        std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
        GGML_ABORT("fatal error");
    }
+    } else if (tensor->op == GGML_OP_GLU) {
+        if (src_clone[1] == nullptr) {
+            tensor_clone = ggml_glu(ggml_ctx, src_clone[0], (ggml_glu_op) tensor->op_params[0], tensor->op_params[1]);
+        } else {
+            tensor_clone = ggml_glu_split(ggml_ctx, src_clone[0], src_clone[1], (ggml_glu_op) tensor->op_params[0]);
+        }
    } else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) {
        if (src1 == nullptr) {
            tensor_clone = ggml_dup(ggml_ctx, src_clone[0]);
ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp

@@ -0,0 +1,13 @@
+#version 450
+
+#include "glu_head.comp"
+
+const float GELU_COEF_A = 0.044715f;
+const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
+
+float op(float a, float b) {
+    const float val = SQRT_2_OVER_PI*a*(1.0f + GELU_COEF_A*a*a);
+    return 0.5f*a*(2.0f - 2.0f / (exp(2 * val) + 1)) * b;
+}
+
+#include "glu_main.comp"
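Why this shader's formula is the tanh-approximated GELU: with val = sqrt(2/pi)*a*(1 + 0.044715*a*a), the identity tanh(v) = 1 - 2/(exp(2v) + 1) gives 0.5*a*(2 - 2/(exp(2*val) + 1)) == 0.5*a*(1 + tanh(val)), just expressed through exp(). A quick stand-alone check in C++:

#include <cmath>
#include <cstdio>

int main() {
    const float as[] = {-2.0f, -0.5f, 0.0f, 0.5f, 2.0f};
    for (float a : as) {
        const float v = 0.79788456080286535587989211986876f * a * (1.0f + 0.044715f * a * a);
        const float via_exp  = 0.5f * a * (2.0f - 2.0f / (std::exp(2.0f * v) + 1.0f)); // shader form
        const float via_tanh = 0.5f * a * (1.0f + std::tanh(v));                       // textbook form
        std::printf("a=% .2f exp-form=% .6f tanh-form=% .6f\n", a, via_exp, via_tanh);
    }
    return 0;
}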
ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp

@@ -0,0 +1,15 @@
+#extension GL_EXT_shader_16bit_storage : require
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+layout (binding = 1) readonly buffer B {A_TYPE data_b[];};
+layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
+
+layout (push_constant) uniform parameter
+{
+    uint N;
+    uint ne00;
+    uint ne20;
+    uint mode;
+} p;
ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp

@@ -0,0 +1,29 @@
+void main() {
+    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+
+    if (i >= p.N) {
+        return;
+    }
+
+    const uint row = i / p.ne20;
+    const uint col = i - row * p.ne20;
+
+    if (p.mode == 0) {
+        // Default
+        const uint offset = p.ne00 / 2;
+        const uint idx = row * p.ne00 + col;
+
+        data_d[row * offset + col] = D_TYPE(op(float(data_a[idx]), float(data_a[idx + offset])));
+    } else if (p.mode == 1) {
+        // Swapped
+        const uint offset = p.ne00 / 2;
+        const uint idx = row * p.ne00 + col;
+
+        data_d[row * offset + col] = D_TYPE(op(float(data_a[idx + offset]), float(data_a[idx])));
+    } else {
+        // Split
+        const uint idx = row * p.ne00 + col;
+
+        data_d[idx] = D_TYPE(op(float(data_a[idx]), float(data_b[idx])));
+    }
+}
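For validation, the three modes (default, swapped, split) reduce to a simple CPU loop. A C++ reference sketch, where `op` is any of the gate functions (GEGLU/REGLU/SWIGLU):

#include <cstdint>
#include <functional>
#include <vector>

// d holds rows * ne20 elements; ne00 is the source row width
// (2 * ne20 for the fused modes, ne20 for split).
void glu_ref(const std::vector<float> & a, const std::vector<float> & b,
             std::vector<float> & d, uint32_t ne00, uint32_t ne20, uint32_t mode,
             const std::function<float(float, float)> & op) {
    const uint32_t rows = (uint32_t) d.size() / ne20;
    for (uint32_t row = 0; row < rows; ++row) {
        for (uint32_t col = 0; col < ne20; ++col) {
            const uint32_t idx = row * ne00 + col;  // index into the source row
            float x, g;
            if (mode == 0) {        // default: first half activates, second half gates
                x = a[idx]; g = a[idx + ne00 / 2];
            } else if (mode == 1) { // swapped: halves trade roles
                x = a[idx + ne00 / 2]; g = a[idx];
            } else {                // split: gate comes from tensor b (ne00 == ne20 here)
                x = a[idx]; g = b[idx];
            }
            d[row * ne20 + col] = op(x, g);
        }
    }
}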
ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp

@@ -0,0 +1,9 @@
+#version 450
+
+#include "glu_head.comp"
+
+float op(float a, float b) {
+    return max(a, 0.0f) * b;
+}
+
+#include "glu_main.comp"
ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp

@@ -0,0 +1,9 @@
+#version 450
+
+#include "glu_head.comp"
+
+float op(float a, float b) {
+    return a / (1.0f + exp(-a)) * b;
+}
+
+#include "glu_main.comp"
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

@@ -585,6 +585,13 @@ void process_shaders() {
    string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
    string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});

+    string_to_spv("geglu_f16", "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+    string_to_spv("geglu_f32", "geglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+    string_to_spv("reglu_f16", "reglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+    string_to_spv("reglu_f32", "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+    string_to_spv("swiglu_f16", "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+    string_to_spv("swiglu_f32", "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+
    string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
    string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
ggml/src/ggml.c
@@ -982,9 +982,11 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS",
     "CROSS_ENTROPY_LOSS_BACK",
     "OPT_STEP_ADAMW",
+
+    "GLU",
 };
 
-static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");
+static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1079,9 +1081,11 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss(x,y)",
     "cross_entropy_loss_back(x,y)",
     "adamw(x)",
+
+    "glu(x)",
 };
 
-static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");
+static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -1107,6 +1111,15 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
 static_assert(GGML_UNARY_OP_COUNT == 15, "GGML_UNARY_OP_COUNT != 15");
 
 
+static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = {
+    "REGLU",
+    "GEGLU",
+    "SWIGLU",
+};
+
+static_assert(GGML_GLU_OP_COUNT == 3, "GGML_GLU_OP_COUNT != 3");
+
+
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
 static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
 
@@ -1209,11 +1222,19 @@ const char * ggml_unary_op_name(enum ggml_unary_op op) {
     return GGML_UNARY_OP_NAME[op];
 }
 
+const char * ggml_glu_op_name(enum ggml_glu_op op) {
+    return GGML_GLU_OP_NAME[op];
+}
+
 const char * ggml_op_desc(const struct ggml_tensor * t) {
     if (t->op == GGML_OP_UNARY) {
         enum ggml_unary_op uop = ggml_get_unary_op(t);
         return ggml_unary_op_name(uop);
     }
+    if (t->op == GGML_OP_GLU) {
+        enum ggml_glu_op gop = ggml_get_glu_op(t);
+        return ggml_glu_op_name(gop);
+    }
     return ggml_op_name(t->op);
 }
 
@@ -1730,6 +1751,11 @@ enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor) {
     return (enum ggml_unary_op) ggml_get_op_params_i32(tensor, 0);
 }
 
+enum ggml_glu_op ggml_get_glu_op(const struct ggml_tensor * tensor) {
+    GGML_ASSERT(tensor->op == GGML_OP_GLU);
+    return (enum ggml_glu_op) ggml_get_op_params_i32(tensor, 0);
+}
+
 const char * ggml_get_name(const struct ggml_tensor * tensor) {
     return tensor->name;
 }
@@ -2609,6 +2635,114 @@ struct ggml_tensor * ggml_exp_inplace(
     return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_EXP);
 }
 
+// ggml_glu
+
+static struct ggml_tensor * ggml_glu_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        enum ggml_glu_op      op,
+        bool                  swapped) {
+    GGML_ASSERT(ggml_is_contiguous_1(a));
+
+    if (b) {
+        GGML_ASSERT(ggml_is_contiguous_1(b));
+        GGML_ASSERT(ggml_are_same_shape(a, b));
+        GGML_ASSERT(a->type == b->type);
+    }
+
+    int64_t ne[GGML_MAX_DIMS] = { a->ne[0] / 2 }; for (int i = 1; i < GGML_MAX_DIMS; i++) ne[i] = a->ne[i];
+    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, b ? a->ne : ne, NULL, 0);
+
+    ggml_set_op_params_i32(result, 0, (int32_t) op);
+    ggml_set_op_params_i32(result, 1, (int32_t) swapped);
+
+    result->op     = GGML_OP_GLU;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_glu(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        enum ggml_glu_op      op,
+        bool                  swapped) {
+    return ggml_glu_impl(ctx, a, NULL, op, swapped);
+}
+
+struct ggml_tensor * ggml_glu_split(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        enum ggml_glu_op      op) {
+    return ggml_glu_impl(ctx, a, b, op, false);
+}
+
+// ggml_reglu
+
+struct ggml_tensor * ggml_reglu(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_REGLU, false);
+}
+
+struct ggml_tensor * ggml_reglu_swapped(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_REGLU, true);
+}
+
+struct ggml_tensor * ggml_reglu_split(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_REGLU, false);
+}
+
+// ggml_geglu
+
+struct ggml_tensor * ggml_geglu(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU, false);
+}
+
+struct ggml_tensor * ggml_geglu_swapped(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU, true);
+}
+
+struct ggml_tensor * ggml_geglu_split(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU, false);
+}
+
+// ggml_swiglu
+
+struct ggml_tensor * ggml_swiglu(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_SWIGLU, false);
+}
+
+struct ggml_tensor * ggml_swiglu_swapped(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_SWIGLU, true);
+}
+
+struct ggml_tensor * ggml_swiglu_split(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_SWIGLU, false);
+}
+
 // ggml_norm
 
 static struct ggml_tensor * ggml_norm_impl(
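As ggml_glu_impl above shows, both parameters ride along in the result tensor's op params: slot 0 holds the ggml_glu_op, slot 1 the swapped flag. A backend kernel would recover them roughly like this (a sketch; the dst variable is illustrative):

    const enum ggml_glu_op op      = ggml_get_glu_op(dst);                 // op params slot 0
    const bool             swapped = ggml_get_op_params_i32(dst, 1) != 0;  // op params slot 1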
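Putting the new API together, a usage sketch (illustrative only, not part of the commit; sizes and names are made up). The fused variants halve ne[0]; following ggml_glu_impl's default mode, the activation applies to the first half of the rows, or to the first argument in the split variants:

    #include "ggml.h"

    // three ways to request SwiGLU with the new ops
    void build_swiglu_examples(struct ggml_context * ctx, int64_t n_ff, int64_t n_tokens) {
        // fused: rows hold [gate | up]; the result has ne[0] = n_ff
        struct ggml_tensor * fused = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2*n_ff, n_tokens);
        struct ggml_tensor * y0    = ggml_swiglu(ctx, fused);

        // fused, swapped: rows hold [up | gate] instead
        struct ggml_tensor * y1    = ggml_swiglu_swapped(ctx, fused);

        // split: gate and up as two same-shape tensors; the first is activated
        struct ggml_tensor * gate = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_ff, n_tokens);
        struct ggml_tensor * up   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_ff, n_tokens);
        struct ggml_tensor * y2   = ggml_swiglu_split(ctx, gate, up);

        (void) y0; (void) y1; (void) y2;
    }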