Optimize RWKV6 Operator Naming and Implement Multi-core CPU/SYCL Acceleration (llama/10133)
* rwkv6: rename to wkv6
* rwkv6: support avx2 avx512 armv8 armv9
* rwkv6: update cuda file name
* rwkv6: rename params
* wkv on sycl
* sycl: add some ops
* sycl: enhance OP support checks
* wkv6: drop armv9 and transfer to GGML style
ggml-ci
* sync : ggml
* update the function to use appropriate types
* fix define error
* Update ggml/src/ggml-cpu.c
* add appropriate asserts
* move element-wise functions outside
* put the declaration outside the loop
* rewrite to be more inline with the common pattern for distributing threads
* use recommended way GGML_TENSOR_LOCALS
---------
Co-authored-by: Georgi Gerganov <[email protected]>
Co-authored-by: Diego Devesa <[email protected]>
Co-authored-by: Plamen Minev <[email protected]>
Co-authored-by: Yuri Khrustalev <[email protected]>
Co-authored-by: Meng, Hengyu <[email protected]>
- ggml/include/ggml.h +2 -2
- ggml/src/ggml-cpu.c +160 -48
- ggml/src/ggml-cuda.cu +4 -4
- ggml/src/ggml-cuda/wkv6.cu +89 -0
- ggml/src/ggml-cuda/wkv6.cuh +5 -0
- ggml/src/ggml-sycl.cpp +259 -1030
- ggml/src/ggml-sycl/backend.hpp +3 -0
- ggml/src/ggml-sycl/common.cpp +40 -0
- ggml/src/ggml-sycl/common.hpp +258 -0
- ggml/src/ggml-sycl/concat.cpp +1 -0
- ggml/src/ggml-sycl/element_wise.cpp +1011 -0
- ggml/src/ggml-sycl/element_wise.hpp +76 -0
- ggml/src/ggml-sycl/outprod.cpp +55 -0
- ggml/src/ggml-sycl/outprod.hpp +11 -0
- ggml/src/ggml-sycl/presets.hpp +6 -0
- ggml/src/ggml-sycl/wkv6.cpp +138 -0
- ggml/src/ggml-sycl/wkv6.hpp +10 -0
- ggml/src/ggml.c +6 -6
ggml/include/ggml.h:

```diff
@@ -509,7 +509,7 @@ extern "C" {
         GGML_OP_WIN_UNPART,
         GGML_OP_GET_REL_POS,
         GGML_OP_ADD_REL_POS,
-        GGML_OP_RWKV_WKV,
+        GGML_OP_RWKV_WKV6,

         GGML_OP_UNARY,

@@ -1819,7 +1819,7 @@ extern "C" {
            struct ggml_tensor * pw,
            struct ggml_tensor * ph);

-    GGML_API struct ggml_tensor * ggml_rwkv_wkv(
+    GGML_API struct ggml_tensor * ggml_rwkv_wkv6(
            struct ggml_context * ctx,
            struct ggml_tensor * k,
            struct ggml_tensor * v,
```
ggml/src/ggml-cpu.c:

```diff
@@ -11642,24 +11642,30 @@ static void ggml_compute_forward_add_rel_pos(
     }
 }

-// ggml_compute_forward_rwkv_wkv
+// ggml_compute_forward_rwkv_wkv6

-static void ggml_compute_forward_rwkv_wkv_f32(
+static void ggml_compute_forward_rwkv_wkv6_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
-    const size_t T = dst->src[1]->ne[3];
-    const size_t C = dst->ne[0];
-    const size_t H = dst->src[1]->ne[2];
-    const size_t n_seqs = dst->src[5]->ne[1];
+    const int64_t T = dst->src[1]->ne[3];
+    const int64_t C = dst->ne[0];
+    const int64_t HEADS = dst->src[1]->ne[2];
+    const int64_t n_seqs = dst->src[5]->ne[1];
+    const int64_t head_size = C / HEADS;

     float * dst_data = (float *) dst->data;
     float * state = ((float *) dst->data) + C * T;

-    if (params->ith != 0) {
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    if (ith >= HEADS) {
         return;
     }

-    memset(dst_data, 0, T * C * sizeof(float));
+    const int h_start = (HEADS * ith) / nth;
+    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
+                       (HEADS * (ith + 1)) / nth : HEADS;

     float * k = (float *) dst->src[0]->data;
     float * v = (float *) dst->src[1]->data;
@@ -11667,54 +11673,160 @@ static void ggml_compute_forward_rwkv_wkv_f32(
     float * time_faaaa = (float *) dst->src[3]->data;
     float * time_decay = (float *) dst->src[4]->data;

-    size_t t_stride = H * (C / H);
+    size_t t_stride = HEADS * head_size; // Same to C

-    size_t h_stride = C / H;
-    size_t h_stride_2d = (C / H) * (C / H);
+    size_t h_stride = C / HEADS;
+    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
+    size_t h_stride_2d = head_size * head_size;

-    // basically fused operations:
-    // dst = r @ (time_faaaa * (k @ v) + state),
-    // state = time_decay * state + (k @ v),
-    // recursive through each token
-    for (size_t t = 0; t < T; t++) {
-        size_t t_offset = t * t_stride;
-        size_t state_offset = (C / H) * C * (t / (T / n_seqs));
-        float * state_cur = state + state_offset;
-        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
-
-        for (size_t h = 0; h < H; h++) {
-            size_t h_offset = h * h_stride;
-            size_t t_h_offset = t_offset + h_offset;
-            size_t h_2d_offset = h * h_stride_2d;
-
-            for (size_t i = 0; i < C / H; i++) {
-                size_t t_h_i_offset = t_h_offset + i;
-                size_t h_i_offset = h_offset + i;
-                size_t h_2d_i_offset = h_2d_offset + i * h_stride;
-
-                float k_val = k[t_h_i_offset];
-                float r_val = r[t_h_i_offset];
-                float time_faaaa_val = time_faaaa[h_i_offset];
-                // RWKV v6: different time_decay for each token.
-                float time_decay_val = time_decay[t_h_i_offset];
-
-                for (size_t j = 0; j < C / H; j++) {
-                    size_t t_h_j_offset = t_h_offset + j;
-                    size_t h_2d_i_j_offset = h_2d_i_offset + j;
-
-                    float v_val = v[t_h_j_offset];
-                    float kv_val = v_val * k_val;
-                    float prev_state_val = state_prev[h_2d_i_j_offset];
-                    float temp_val = kv_val * time_faaaa_val + prev_state_val;
-                    dst_data[t_h_j_offset] += temp_val * r_val;
-                    state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
-                }
-            }
-        }
-    }
+    if (ith == 0) {
+        memset(dst_data, 0, T * C * sizeof(float));
+    }
+    ggml_barrier(params->threadpool);
+
+    #if defined(__AVX__) && !defined(__AVX512F__)
+        #define GGML_F32X GGML_F32x8
+        #define GGML_F32X_SET1 GGML_F32x8_SET1
+        #define GGML_F32X_LOAD GGML_F32x8_LOAD
+        #define GGML_F32X_STORE GGML_F32x8_STORE
+        #define GGML_F32X_MUL GGML_F32x8_MUL
+        #define GGML_F32X_FMA GGML_F32x8_FMA
+        #define WKV_VECTOR_SIZE 8
+    #elif defined(__AVX512F__)
+        #define GGML_F32X GGML_F32x16
+        #define GGML_F32X_SET1 GGML_F32x16_SET1
+        #define GGML_F32X_LOAD GGML_F32x16_LOAD
+        #define GGML_F32X_STORE GGML_F32x16_STORE
+        #define GGML_F32X_MUL GGML_F32x16_MUL
+        #define GGML_F32X_FMA GGML_F32x16_FMA
+        #define WKV_VECTOR_SIZE 16
+    #elif defined(__ARM_NEON) && defined(__aarch64__)
+        #define GGML_F32X GGML_F32x4
+        #define GGML_F32X_SET1 GGML_F32x4_SET1
+        #define GGML_F32X_LOAD GGML_F32x4_LOAD
+        #define GGML_F32X_STORE GGML_F32x4_STORE
+        #define GGML_F32X_MUL GGML_F32x4_MUL
+        #define GGML_F32X_FMA GGML_F32x4_FMA
+        #define WKV_VECTOR_SIZE 4
+    #endif
+
+    #ifdef WKV_VECTOR_SIZE
+        const int64_t vec_count = head_size / WKV_VECTOR_SIZE;
+
+        for (int64_t t = 0; t < T; t++) {
+            size_t t_offset = t * t_stride;
+            size_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                size_t h_offset = h * h_stride;
+                size_t t_h_offset = t_offset + h_offset;
+                size_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    size_t t_h_i_offset = t_h_offset + i;
+                    size_t h_i_offset = h_offset + i;
+                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float k_val = k[t_h_i_offset];
+                    float r_val = r[t_h_i_offset];
+                    float time_faaaa_val = time_faaaa[h_i_offset];
+                    float time_decay_val = time_decay[t_h_i_offset];
+
+                    // Broadcast scalar values to vectors
+                    GGML_F32X k_vec = GGML_F32X_SET1(k_val);
+                    GGML_F32X r_vec = GGML_F32X_SET1(r_val);
+                    GGML_F32X time_faaaa_vec = GGML_F32X_SET1(time_faaaa_val);
+                    GGML_F32X time_decay_vec = GGML_F32X_SET1(time_decay_val);
+
+                    for (int64_t j = 0; j < vec_count; j++) {
+                        size_t base_j = j * WKV_VECTOR_SIZE;
+                        size_t t_h_j_offset = t_h_offset + base_j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
+
+                        // Load x elements at once
+                        GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
+                        GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
+                        GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
+
+                        // Compute kv = v * k
+                        GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);
+
+                        // Compute temp = kv * time_faaaa + prev_state
+                        GGML_F32X temp_vec = GGML_F32X_FMA(prev_state_vec, kv_vec, time_faaaa_vec);
+
+                        // Update dst: dst += temp * r
+                        dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, r_vec);
+                        GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
+
+                        // Update state: state = prev_state * time_decay + kv
+                        GGML_F32X new_state_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, time_decay_vec);
+                        GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], new_state_vec);
+                    }
+
+                    // Handle remaining elements, this will not be used.
+                    for (int64_t j = vec_count * WKV_VECTOR_SIZE; j < head_size; j++) {
+                        size_t t_h_j_offset = t_h_offset + j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
+                        float v_val = v[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        float temp_val = kv_val * time_faaaa_val + prev_state_val;
+                        dst_data[t_h_j_offset] += temp_val * r_val;
+                        state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
+                    }
+                }
+            }
+        }
+
+    #else
+        // basically fused operations:
+        // dst = r @ (time_faaaa * (k @ v) + state),
+        // state = time_decay * state + (k @ v),
+        // recursive through each token
+        for (int64_t t = 0; t < T; t++) {
+            size_t t_offset = t * t_stride;
+            size_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                size_t h_offset = h * h_stride;
+                size_t t_h_offset = t_offset + h_offset;
+                size_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    size_t t_h_i_offset = t_h_offset + i;
+                    size_t h_i_offset = h_offset + i;
+                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float k_val = k[t_h_i_offset];
+                    float r_val = r[t_h_i_offset];
+                    float time_faaaa_val = time_faaaa[h_i_offset];
+                    // RWKV v6: different time_decay for each token.
+                    float time_decay_val = time_decay[t_h_i_offset];
+
+                    for (int64_t j = 0; j < head_size; j++) {
+                        size_t t_h_j_offset = t_h_offset + j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                        float v_val = v[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        float temp_val = kv_val * time_faaaa_val + prev_state_val;
+                        dst_data[t_h_j_offset] += temp_val * r_val;
+                        state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
+                    }
+                }
+            }
+        }
+    #endif
 }

-static void ggml_compute_forward_rwkv_wkv(
+
+static void ggml_compute_forward_rwkv_wkv6(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
@@ -11723,7 +11835,7 @@ static void ggml_compute_forward_rwkv_wkv(
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_rwkv_wkv_f32(params, dst);
+                ggml_compute_forward_rwkv_wkv6_f32(params, dst);
             } break;
         default:
             {
@@ -12475,9 +12587,9 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
            {
                ggml_compute_forward_add_rel_pos(params, tensor);
            } break;
-        case GGML_OP_RWKV_WKV:
+        case GGML_OP_RWKV_WKV6:
            {
-                ggml_compute_forward_rwkv_wkv(params, tensor);
+                ggml_compute_forward_rwkv_wkv6(params, tensor);
            } break;
        case GGML_OP_MAP_UNARY:
            {
@@ -12775,7 +12887,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
        case GGML_OP_WIN_PART:
        case GGML_OP_WIN_UNPART:
        case GGML_OP_GET_REL_POS:
-        case GGML_OP_RWKV_WKV:
+        case GGML_OP_RWKV_WKV6:
        case GGML_OP_MAP_UNARY:
        case GGML_OP_MAP_BINARY:
        case GGML_OP_MAP_CUSTOM1_F32:
```
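Per the last commit bullet, the head loop now follows ggml's common thread-distribution pattern: thread `ith` of `nth` owns the contiguous range `[HEADS*ith/nth, HEADS*(ith+1)/nth)`, so every head is covered exactly once with near-equal slices. A standalone model of the split (illustrative, not part of the commit):

```c
#include <stdio.h>

// ggml's usual static partition: item range [n*ith/nth, n*(ith+1)/nth).
// Here n plays the role of HEADS, matching h_start/h_end in the hunk above;
// the clamp to n mirrors the defensive ternary in the kernel.
static void head_range(int n, int ith, int nth, int * start, int * end) {
    *start = (n * ith) / nth;
    *end   = (n * (ith + 1)) / nth;
    if (*end > n) {
        *end = n;
    }
}

int main(void) {
    // e.g. 32 heads over 6 threads -> contiguous, near-equal slices
    for (int ith = 0; ith < 6; ith++) {
        int h0, h1;
        head_range(32, ith, 6, &h0, &h1);
        printf("thread %d: heads [%d, %d)\n", ith, h0, h1);
    }
    return 0;
}
```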
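The vector branch and its scalar tail must compute the same values, which pins down the macro semantics: `GGML_F32X_FMA(a, b, c)` is used exactly where the scalar code computes `a + b*c`. A scalar model of one inner-loop step under that assumption (made explicit here because the macro definitions are not part of this diff):

```c
// Assumed semantics, consistent with the hunk's comments: FMA(a,b,c) = a + b*c.
static inline float f32x_fma(float a, float b, float c) {
    return a + b * c;
}

// One (i, j) step of the wkv6 inner loop in scalar form, mirroring the
// vector body above:
//   temp   = prev_state + kv * time_faaaa
//   dst   += temp * r
//   state  = kv + prev_state * time_decay
static inline void wkv6_step(float kv, float r, float tf, float td,
                             float * dst, float * state) {
    const float temp = f32x_fma(*state, kv, tf);
    *dst   = f32x_fma(*dst, temp, r);
    *state = f32x_fma(kv, *state, td);
}
```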
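Both branches select the recurrent state per token with `state_offset = head_size * C * (t / (T / n_seqs))` and fall back to `dst->src[5]` only at a chunk boundary: the `T` tokens are laid out as `n_seqs` contiguous chunks, the first token of each chunk reads the incoming state, and every later token reads the state its predecessor just wrote in place. A runnable model of that selection (names are illustrative):

```c
#include <stdio.h>

int main(void) {
    const int T = 8, n_seqs = 2;
    const int tokens_per_seq = T / n_seqs;
    for (int t = 0; t < T; t++) {
        const int seq        = t / tokens_per_seq;
        const int from_input = (t % tokens_per_seq) == 0;
        printf("t=%d: seq=%d, state_prev <- %s\n",
               t, seq, from_input ? "src[5] (incoming)" : "state_cur (in place)");
    }
    return 0;
}
```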
ggml/src/ggml-cuda.cu:

```diff
@@ -36,7 +36,7 @@
 #include "ggml-cuda/tsembd.cuh"
 #include "ggml-cuda/unary.cuh"
 #include "ggml-cuda/upscale.cuh"
-#include "ggml-cuda/rwkv-wkv.cuh"
+#include "ggml-cuda/wkv6.cuh"

 #include <algorithm>
 #include <array>
@@ -2319,8 +2319,8 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
        case GGML_OP_CROSS_ENTROPY_LOSS:
            ggml_cuda_cross_entropy_loss(ctx, dst);
            break;
-        case GGML_OP_RWKV_WKV:
-            ggml_cuda_op_rwkv_wkv(ctx, dst);
+        case GGML_OP_RWKV_WKV6:
+            ggml_cuda_op_rwkv_wkv6(ctx, dst);
            break;
        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
            ggml_cuda_cross_entropy_loss_back(ctx, dst);
@@ -3153,7 +3153,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
        case GGML_OP_ARANGE:
        case GGML_OP_TIMESTEP_EMBEDDING:
        case GGML_OP_LEAKY_RELU:
-        case GGML_OP_RWKV_WKV:
+        case GGML_OP_RWKV_WKV6:
            return true;
        case GGML_OP_FLASH_ATTN_EXT: {
#ifndef FLASH_ATTN_AVAILABLE
```
ggml/src/ggml-cuda/wkv6.cu (new file, 89 lines):

```cpp
#include "common.cuh"
#include "wkv6.cuh"

static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {
    const int tid = threadIdx.x;
    const int bid = blockIdx.x;

    const int head_size = CUDA_WKV_BLOCK_SIZE;
    const int batch_i = bid / H;
    const int head_i = bid % H;
    const int state_size = C * head_size;
    const int n_seq_tokens = T / B;

    float state[head_size];
    __shared__ float _k[head_size], _r[head_size], _tf[head_size], _td[head_size];

#pragma unroll
    for (int i = 0; i < head_size; i++) {
        state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
    }

    __syncthreads();
    _tf[tid] = tf[head_i * head_size + tid];
    __syncthreads();

    for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
        __syncthreads();
        _k[tid] = k[t];
        _r[tid] = r[t];
        _td[tid] = td[t];
        __syncthreads();

        const float _v = v[t];
        float y = 0;
        for (int j = 0; j < head_size; j += 4) {
            const float4& k = (float4&)(_k[j]);
            const float4& r = (float4&)(_r[j]);
            const float4& tf = (float4&)(_tf[j]);
            const float4& td = (float4&)(_td[j]);
            float4& s = (float4&)(state[j]);
            float4 kv;

            kv.x = k.x * _v;
            kv.y = k.y * _v;
            kv.z = k.z * _v;
            kv.w = k.w * _v;

            y += r.x * (tf.x * kv.x + s.x);
            y += r.y * (tf.y * kv.y + s.y);
            y += r.z * (tf.z * kv.z + s.z);
            y += r.w * (tf.w * kv.w + s.w);

            s.x = s.x * td.x + kv.x;
            s.y = s.y * td.y + kv.y;
            s.z = s.z * td.z + kv.z;
            s.w = s.w * td.w + kv.w;
        }
        dst[t] = y;
    }

#pragma unroll
    for (int i = 0; i < head_size; i++) {
        dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
    }
}

void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const float * k_d  = (const float *)dst->src[0]->data;
    const float * v_d  = (const float *)dst->src[1]->data;
    const float * r_d  = (const float *)dst->src[2]->data;
    const float * tf_d = (const float *)dst->src[3]->data;
    const float * td_d = (const float *)dst->src[4]->data;
    const float * s_d  = (const float *)dst->src[5]->data;

    const int64_t B = dst->src[5]->ne[1];
    const int64_t T = dst->src[0]->ne[3];
    const int64_t C = dst->ne[0];
    const int64_t H = dst->src[0]->ne[2];

    float * dst_d = (float *)dst->data;

    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
    GGML_ASSERT(C % H == 0);
    GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE); // The current cuda kernel is designed for RWKV6, HEAD_SIZE == 64

    rwkv_wkv_f32<<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
}
```
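The launch `rwkv_wkv_f32<<<B * H, C / H>>>` assigns one block per (sequence, head) pair and one thread per state column; each thread then walks its sequence's tokens with stride `C`, so consecutive threads read consecutive floats of `k`, `r`, and `td` on every step (coalesced accesses). A host-side model of the index walk, useful for convincing yourself of the access pattern (illustrative only):

```c
#include <stdio.h>

// Reproduce the token-index loop from rwkv_wkv_f32 on the host:
//   t = batch_i*n_seq_tokens*C + head_i*head_size + tid, stepping by C.
int main(void) {
    const int B = 2, H = 2, head_size = 4;   // toy sizes (real head_size is 64)
    const int C = H * head_size;
    const int n_seq_tokens = 3;
    for (int batch_i = 0; batch_i < B; batch_i++)
    for (int head_i = 0; head_i < H; head_i++)
    for (int tid = 0; tid < head_size; tid++) {
        printf("block(%d,%d) tid=%d:", batch_i, head_i, tid);
        for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
             t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
             t += C) {
            printf(" %d", t);
        }
        printf("\n");
    }
    return 0;
}
```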
ggml/src/ggml-cuda/wkv6.cuh (new file, 5 lines):

```cpp
#include "common.cuh"

#define CUDA_WKV_BLOCK_SIZE 64

void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
```
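`CUDA_WKV_BLOCK_SIZE` is fixed at 64 because the kernel keeps one state row per thread in a statically sized `float state[head_size]` array, so the head size must be a compile-time constant; the host-side `GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE)` then rejects tensors with any other head size at runtime. A compile-time guard in the same spirit (illustrative, not in the commit):

```c
#define CUDA_WKV_BLOCK_SIZE 64

// The float4 inner loop of the kernel steps through the head in groups of
// four floats, so the block size must stay a multiple of 4.
_Static_assert(CUDA_WKV_BLOCK_SIZE % 4 == 0,
               "rwkv_wkv_f32's float4 loop assumes a multiple of 4");
```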
|
@@ -1194,272 +1194,8 @@ typedef void (*ggml_sycl_op_mul_mat_t)(
|
|
| 1194 |
float *dst_dd_i, const int64_t row_low, const int64_t row_high,
|
| 1195 |
const int64_t src1_ncols, const int64_t src1_padded_row_size,
|
| 1196 |
const queue_ptr &stream);
|
| 1197 |
-
typedef void (*ggml_sycl_op_flatten_t)(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
| 1198 |
-
const ggml_tensor *src1,
|
| 1199 |
-
ggml_tensor *dst, const float *src0_dd,
|
| 1200 |
-
const float *src1_dd, float *dst_dd,
|
| 1201 |
-
const queue_ptr &main_stream);
|
| 1202 |
-
|
| 1203 |
-
static __dpct_inline__ float op_repeat(const float a, const float b) {
|
| 1204 |
-
return b;
|
| 1205 |
-
GGML_UNUSED(a);
|
| 1206 |
-
}
|
| 1207 |
-
|
| 1208 |
-
static __dpct_inline__ float op_add(const float a, const float b) {
|
| 1209 |
-
return a + b;
|
| 1210 |
-
}
|
| 1211 |
-
|
| 1212 |
-
static __dpct_inline__ float op_mul(const float a, const float b) {
|
| 1213 |
-
return a * b;
|
| 1214 |
-
}
|
| 1215 |
-
|
| 1216 |
-
static __dpct_inline__ float op_div(const float a, const float b) {
|
| 1217 |
-
return a / b;
|
| 1218 |
-
}
|
| 1219 |
-
|
| 1220 |
-
template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
|
| 1221 |
-
static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
|
| 1222 |
-
int ne0, int ne1, int ne2, int ne3,
|
| 1223 |
-
int ne10, int ne11, int ne12, int ne13,
|
| 1224 |
-
/*int s0, */ int s1, int s2, int s3,
|
| 1225 |
-
/*int s10,*/ int s11, int s12, int s13,
|
| 1226 |
-
const sycl::nd_item<3> &item_ct1) {
|
| 1227 |
-
const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
| 1228 |
-
item_ct1.get_local_id(2);
|
| 1229 |
-
const int i1 = (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
|
| 1230 |
-
item_ct1.get_local_id(1));
|
| 1231 |
-
const int i2 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
|
| 1232 |
-
item_ct1.get_local_id(0)) /
|
| 1233 |
-
ne3;
|
| 1234 |
-
const int i3 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
|
| 1235 |
-
item_ct1.get_local_id(0)) %
|
| 1236 |
-
ne3;
|
| 1237 |
-
|
| 1238 |
-
if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
|
| 1239 |
-
return;
|
| 1240 |
-
}
|
| 1241 |
-
|
| 1242 |
-
const int i11 = i1 % ne11;
|
| 1243 |
-
const int i12 = i2 % ne12;
|
| 1244 |
-
const int i13 = i3 % ne13;
|
| 1245 |
-
|
| 1246 |
-
const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
|
| 1247 |
-
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
|
| 1248 |
-
const size_t i_dst = i_src0;
|
| 1249 |
-
|
| 1250 |
-
const src0_t * src0_row = src0 + i_src0;
|
| 1251 |
-
const src1_t * src1_row = src1 + i_src1;
|
| 1252 |
-
dst_t * dst_row = dst + i_dst;
|
| 1253 |
-
|
| 1254 |
-
for (int i0 = i0s; i0 < ne0;
|
| 1255 |
-
i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) {
|
| 1256 |
-
const int i10 = i0 % ne10;
|
| 1257 |
-
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
|
| 1258 |
-
}
|
| 1259 |
-
}
|
| 1260 |
|
| 1261 |
-
template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
|
| 1262 |
-
static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
|
| 1263 |
-
int ne0, int ne1, int ne2, int ne3,
|
| 1264 |
-
int ne10, int ne11, int ne12, int ne13,
|
| 1265 |
-
/*int s0, */ int s1, int s2, int s3,
|
| 1266 |
-
/*int s10,*/ int s11, int s12, int s13,
|
| 1267 |
-
const sycl::nd_item<3> &item_ct1) {
|
| 1268 |
|
| 1269 |
-
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
| 1270 |
-
item_ct1.get_local_id(2);
|
| 1271 |
-
|
| 1272 |
-
const int i3 = i/(ne2*ne1*ne0);
|
| 1273 |
-
const int i2 = (i/(ne1*ne0)) % ne2;
|
| 1274 |
-
const int i1 = (i/ne0) % ne1;
|
| 1275 |
-
const int i0 = i % ne0;
|
| 1276 |
-
|
| 1277 |
-
if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
|
| 1278 |
-
return;
|
| 1279 |
-
}
|
| 1280 |
-
|
| 1281 |
-
const int i11 = i1 % ne11;
|
| 1282 |
-
const int i12 = i2 % ne12;
|
| 1283 |
-
const int i13 = i3 % ne13;
|
| 1284 |
-
|
| 1285 |
-
const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
|
| 1286 |
-
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
|
| 1287 |
-
const size_t i_dst = i_src0;
|
| 1288 |
-
|
| 1289 |
-
const src0_t * src0_row = src0 + i_src0;
|
| 1290 |
-
const src1_t * src1_row = src1 + i_src1;
|
| 1291 |
-
dst_t * dst_row = dst + i_dst;
|
| 1292 |
-
|
| 1293 |
-
const int i10 = i0 % ne10;
|
| 1294 |
-
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
|
| 1295 |
-
}
|
| 1296 |
-
|
| 1297 |
-
static void acc_f32(const float * x, const float * y, float * dst, const int ne,
|
| 1298 |
-
const int ne10, const int ne11, const int ne12,
|
| 1299 |
-
const int nb1, const int nb2, int offset, const sycl::nd_item<3> &item_ct1) {
|
| 1300 |
-
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
| 1301 |
-
item_ct1.get_local_id(2);
|
| 1302 |
-
if (i >= ne) {
|
| 1303 |
-
return;
|
| 1304 |
-
}
|
| 1305 |
-
int src1_idx = i - offset;
|
| 1306 |
-
int oz = src1_idx / nb2;
|
| 1307 |
-
int oy = (src1_idx - (oz * nb2)) / nb1;
|
| 1308 |
-
int ox = src1_idx % nb1;
|
| 1309 |
-
if (src1_idx >= 0 && ox < ne10 && oy < ne11 && oz < ne12) {
|
| 1310 |
-
dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11];
|
| 1311 |
-
} else {
|
| 1312 |
-
dst[i] = x[i];
|
| 1313 |
-
}
|
| 1314 |
-
}
|
| 1315 |
-
|
| 1316 |
-
static void gelu_f32(const float * x, float * dst, const int k,
|
| 1317 |
-
const sycl::nd_item<3> &item_ct1) {
|
| 1318 |
-
const float GELU_COEF_A = 0.044715f;
|
| 1319 |
-
const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
| 1320 |
-
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
| 1321 |
-
item_ct1.get_local_id(2);
|
| 1322 |
-
|
| 1323 |
-
if (i >= k) {
|
| 1324 |
-
return;
|
| 1325 |
-
}
|
| 1326 |
-
|
| 1327 |
-
float xi = x[i];
|
| 1328 |
-
dst[i] = 0.5f * xi *
|
| 1329 |
-
(1.0f +
|
| 1330 |
-
sycl::tanh(SQRT_2_OVER_PI * xi * (1.0f + GELU_COEF_A * xi * xi)));
|
| 1331 |
-
}
|
| 1332 |
-
|
| 1333 |
-
static void silu_f32(const float * x, float * dst, const int k,
|
| 1334 |
-
const sycl::nd_item<3> &item_ct1) {
|
| 1335 |
-
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
| 1336 |
-
item_ct1.get_local_id(2);
|
| 1337 |
-
|
| 1338 |
-
if (i >= k) {
|
| 1339 |
-
return;
|
| 1340 |
-
}
|
| 1341 |
-
dst[i] = x[i] / (1.0f + sycl::native::exp(-x[i]));
|
| 1342 |
-
}
|
| 1343 |
-
|
| 1344 |
-
static void gelu_quick_f32(const float *x, float *dst, int k,
|
| 1345 |
-
const sycl::nd_item<3> &item_ct1) {
|
| 1346 |
-
const float GELU_QUICK_COEF = -1.702f;
|
| 1347 |
-
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
| 1348 |
-
item_ct1.get_local_id(2);
|
| 1349 |
-
if (i >= k) {
|
| 1350 |
-
return;
|
| 1351 |
-
}
|
| 1352 |
-
dst[i] = x[i] * (1.0f / (1.0f + sycl::native::exp(GELU_QUICK_COEF * x[i])));
|
| 1353 |
-
}
|
| 1354 |
-
|
| 1355 |
-
static void tanh_f32(const float *x, float *dst, int k,
|
| 1356 |
-
const sycl::nd_item<3> &item_ct1) {
|
| 1357 |
-
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
| 1358 |
-
item_ct1.get_local_id(2);
|
| 1359 |
-
if (i >= k) {
|
| 1360 |
-
return;
|
| 1361 |
-
}
|
| 1362 |
-
dst[i] = sycl::tanh((float)(x[i]));
|
| 1363 |
-
}
|
| 1364 |
-
|
| 1365 |
-
static void relu_f32(const float * x, float * dst, const int k,
|
| 1366 |
-
const sycl::nd_item<3> &item_ct1) {
|
| 1367 |
-
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
| 1368 |
-
item_ct1.get_local_id(2);
|
| 1369 |
-
|
| 1370 |
-
if (i >= k) {
|
| 1371 |
-
return;
|
| 1372 |
-
}
|
| 1373 |
-
dst[i] = sycl::fmax((float)(x[i]), (float)0);
|
| 1374 |
-
}
|
| 1375 |
-
|
| 1376 |
-
static void hardsigmoid_f32(const float * x, float * dst, const int k,
|
| 1377 |
-
const sycl::nd_item<3> &item_ct1) {
|
| 1378 |
-
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
| 1379 |
-
item_ct1.get_local_id(2);
|
| 1380 |
-
|
| 1381 |
-
if (i >= k) {
|
| 1382 |
-
return;
|
| 1383 |
-
}
|
| 1384 |
-
dst[i] = sycl::fmin(1.0f, sycl::fmax(0.0f, (x[i] + 3.0f) / 6.0f));
|
| 1385 |
-
}
|
| 1386 |
-
|
| 1387 |
-
static void hardswish_f32(const float * x, float * dst, const int k,
|
| 1388 |
-
const sycl::nd_item<3> &item_ct1) {
|
| 1389 |
-
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
| 1390 |
-
item_ct1.get_local_id(2);
|
| 1391 |
-
|
| 1392 |
-
if (i >= k) {
|
| 1393 |
-
return;
|
| 1394 |
-
}
|
| 1395 |
-
dst[i] = x[i] * sycl::fmin(1.0f, sycl::fmax(0.0f, (x[i] + 3.0f) / 6.0f));
|
| 1396 |
-
}
|
| 1397 |
-
|
| 1398 |
-
static void leaky_relu_f32(const float *x, float *dst, const int k, const float negative_slope,
|
| 1399 |
-
const sycl::nd_item<3> &item_ct1) {
|
| 1400 |
-
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
| 1401 |
-
item_ct1.get_local_id(2);
|
| 1402 |
-
if (i >= k) {
|
| 1403 |
-
return;
|
| 1404 |
-
}
|
| 1405 |
-
dst[i] = sycl::fmax((float)(x[i]), (float)0) +
|
| 1406 |
-
sycl::fmin((float)(x[i]), 0.0f) * negative_slope;
|
| 1407 |
-
}
|
| 1408 |
-
|
| 1409 |
-
static void sqr_f32(const float * x, float * dst, const int k,
|
| 1410 |
-
const sycl::nd_item<3> &item_ct1) {
|
| 1411 |
-
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
| 1412 |
-
item_ct1.get_local_id(2);
|
| 1413 |
-
|
| 1414 |
-
if (i >= k) {
|
| 1415 |
-
return;
|
| 1416 |
-
}
|
| 1417 |
-
dst[i] = x[i] * x[i];
|
| 1418 |
-
}
|
| 1419 |
-
|
| 1420 |
-
static void upscale_f32(const float *x, float *dst, const int nb00, const int nb01,
|
| 1421 |
-
const int nb02, const int nb03, const int ne10, const int ne11,
|
| 1422 |
-
const int ne12, const int ne13, const float sf0, const float sf1,
|
| 1423 |
-
const float sf2, const float sf3, const sycl::nd_item<1> &item_ct1) {
|
| 1424 |
-
int index = item_ct1.get_local_id(0) +
|
| 1425 |
-
item_ct1.get_group(0) * item_ct1.get_local_range(0);
|
| 1426 |
-
if (index >= ne10 * ne11 * ne12 * ne13) {
|
| 1427 |
-
return;
|
| 1428 |
-
}
|
| 1429 |
-
// operation
|
| 1430 |
-
int i10 = index % ne10;
|
| 1431 |
-
int i11 = (index / ne10) % ne11;
|
| 1432 |
-
int i12 = (index / (ne10 * ne11)) % ne12;
|
| 1433 |
-
int i13 = (index / (ne10 * ne11 * ne12)) % ne13;
|
| 1434 |
-
|
| 1435 |
-
int i00 = i10 / sf0;
|
| 1436 |
-
int i01 = i11 / sf1;
|
| 1437 |
-
int i02 = i12 / sf2;
|
| 1438 |
-
int i03 = i13 / sf3;
|
| 1439 |
-
|
| 1440 |
-
dst[index] = *(float *)((char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
|
| 1441 |
-
}
|
| 1442 |
-
|
| 1443 |
-
static void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02,
|
| 1444 |
-
const sycl::nd_item<3> &item_ct1) {
|
| 1445 |
-
int nidx = item_ct1.get_local_id(2) +
|
| 1446 |
-
item_ct1.get_group(2) * item_ct1.get_local_range(2);
|
| 1447 |
-
if (nidx >= ne0) {
|
| 1448 |
-
return;
|
| 1449 |
-
}
|
| 1450 |
-
|
| 1451 |
-
// operation
|
| 1452 |
-
int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
|
| 1453 |
-
item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
|
| 1454 |
-
if (nidx < ne00 && item_ct1.get_group(1) < ne01 &&
|
| 1455 |
-
item_ct1.get_group(0) < ne02) {
|
| 1456 |
-
int offset_src = nidx + item_ct1.get_group(1) * ne00 +
|
| 1457 |
-
item_ct1.get_group(0) * ne00 * ne01;
|
| 1458 |
-
dst[offset_dst] = x[offset_src];
|
| 1459 |
-
} else {
|
| 1460 |
-
dst[offset_dst] = 0.0f;
|
| 1461 |
-
}
|
| 1462 |
-
}
|
| 1463 |
|
| 1464 |
template<int QUANT_BLOCK_TILE>
|
| 1465 |
static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded,
|
|
@@ -2148,297 +1884,6 @@ static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tens
|
|
| 2148 |
(void) dst;
|
| 2149 |
}
|
| 2150 |
|
| 2151 |
-
template<float (*bin_op)(const float, const float)>
|
| 2152 |
-
struct bin_bcast_sycl {
|
| 2153 |
-
template <typename src0_t, typename src1_t, typename dst_t>
|
| 2154 |
-
void operator()(ggml_backend_sycl_context & ctx,
|
| 2155 |
-
const struct ggml_tensor *src0,
|
| 2156 |
-
const struct ggml_tensor *src1, struct ggml_tensor *dst,
|
| 2157 |
-
const src0_t *src0_dd, const src1_t *src1_dd, dst_t *dst_dd,
|
| 2158 |
-
queue_ptr stream) {
|
| 2159 |
-
|
| 2160 |
-
GGML_TENSOR_BINARY_OP_LOCALS
|
| 2161 |
-
|
| 2162 |
-
int nr0 = ne10/ne0;
|
| 2163 |
-
int nr1 = ne11/ne1;
|
| 2164 |
-
int nr2 = ne12/ne2;
|
| 2165 |
-
int nr3 = ne13/ne3;
|
| 2166 |
-
|
| 2167 |
-
int nr[4] = { nr0, nr1, nr2, nr3 };
|
| 2168 |
-
|
| 2169 |
-
// collapse dimensions until first broadcast dimension
|
| 2170 |
-
int64_t cne0[] = {ne0, ne1, ne2, ne3};
|
| 2171 |
-
int64_t cne1[] = {ne10, ne11, ne12, ne13};
|
| 2172 |
-
size_t cnb0[] = {nb0, nb1, nb2, nb3};
|
| 2173 |
-
size_t cnb1[] = {nb10, nb11, nb12, nb13};
|
| 2174 |
-
auto collapse = [](int64_t cne[]) {
|
| 2175 |
-
cne[0] *= cne[1];
|
| 2176 |
-
cne[1] = cne[2];
|
| 2177 |
-
cne[2] = cne[3];
|
| 2178 |
-
cne[3] = 1;
|
| 2179 |
-
};
|
| 2180 |
-
|
| 2181 |
-
auto collapse_nb = [](size_t cnb[], int64_t cne[]) {
|
| 2182 |
-
cnb[1] *= cne[1];
|
| 2183 |
-
cnb[2] *= cne[2];
|
| 2184 |
-
cnb[3] *= cne[3];
|
| 2185 |
-
};
|
| 2186 |
-
|
| 2187 |
-
for (int i = 0; i < 4; i++) {
|
| 2188 |
-
if (nr[i] != 1) {
|
| 2189 |
-
break;
|
| 2190 |
-
}
|
| 2191 |
-
if (i > 0) {
|
| 2192 |
-
collapse_nb(cnb0, cne0);
|
| 2193 |
-
collapse_nb(cnb1, cne1);
|
| 2194 |
-
collapse(cne0);
|
| 2195 |
-
collapse(cne1);
|
| 2196 |
-
}
|
| 2197 |
-
}
|
| 2198 |
-
{
|
| 2199 |
-
int64_t ne0 = cne0[0];
|
| 2200 |
-
int64_t ne1 = cne0[1];
|
| 2201 |
-
int64_t ne2 = cne0[2];
|
| 2202 |
-
int64_t ne3 = cne0[3];
|
| 2203 |
-
|
| 2204 |
-
int64_t ne10 = cne1[0];
|
| 2205 |
-
int64_t ne11 = cne1[1];
|
| 2206 |
-
int64_t ne12 = cne1[2];
|
| 2207 |
-
int64_t ne13 = cne1[3];
|
| 2208 |
-
|
| 2209 |
-
size_t nb0 = cnb0[0];
|
| 2210 |
-
size_t nb1 = cnb0[1];
|
| 2211 |
-
size_t nb2 = cnb0[2];
|
| 2212 |
-
size_t nb3 = cnb0[3];
|
| 2213 |
-
|
| 2214 |
-
size_t nb10 = cnb1[0];
|
| 2215 |
-
size_t nb11 = cnb1[1];
|
| 2216 |
-
size_t nb12 = cnb1[2];
|
| 2217 |
-
size_t nb13 = cnb1[3];
|
| 2218 |
-
|
| 2219 |
-
size_t s0 = nb0 / sizeof(dst_t);
|
| 2220 |
-
size_t s1 = nb1 / sizeof(dst_t);
|
| 2221 |
-
size_t s2 = nb2 / sizeof(dst_t);
|
| 2222 |
-
size_t s3 = nb3 / sizeof(dst_t);
|
| 2223 |
-
|
| 2224 |
-
size_t s10 = nb10 / sizeof(src1_t);
|
| 2225 |
-
size_t s11 = nb11 / sizeof(src1_t);
|
| 2226 |
-
size_t s12 = nb12 / sizeof(src1_t);
|
| 2227 |
-
size_t s13 = nb13 / sizeof(src1_t);
|
| 2228 |
-
|
| 2229 |
-
GGML_ASSERT(s0 == 1);
|
| 2230 |
-
GGML_ASSERT(s10 == 1);
|
| 2231 |
-
|
| 2232 |
-
const int block_size = 128;
|
| 2233 |
-
|
| 2234 |
-
int64_t hne0 = std::max(ne0/2LL, 1LL);
|
| 2235 |
-
|
| 2236 |
-
sycl::range<3> block_dims(1, 1, 1);
|
| 2237 |
-
block_dims[2] = std::min<unsigned int>(hne0, block_size);
|
| 2238 |
-
block_dims[1] = std::min<unsigned int>(
|
| 2239 |
-
ne1, block_size / (unsigned int)block_dims[2]);
|
| 2240 |
-
block_dims[0] = std::min(
|
| 2241 |
-
std::min<unsigned int>(
|
| 2242 |
-
ne2 * ne3, block_size / (unsigned int)block_dims[2] /
|
| 2243 |
-
(unsigned int)block_dims[1]),
|
| 2244 |
-
64U);
|
| 2245 |
-
|
| 2246 |
-
sycl::range<3> block_nums(
|
| 2247 |
-
(ne2 * ne3 + block_dims[0] - 1) / block_dims[0],
|
| 2248 |
-
(ne1 + block_dims[1] - 1) / block_dims[1],
|
| 2249 |
-
(hne0 + block_dims[2] - 1) / block_dims[2]);
|
| 2250 |
-
|
| 2251 |
-
if (block_nums[0] > 65535) {
|
| 2252 |
-
// this is the maximum number of blocks in z direction, fallback to 1D grid kernel
|
| 2253 |
-
int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
|
| 2254 |
-
{
|
| 2255 |
-
dpct::has_capability_or_fail(stream->get_device(),
|
| 2256 |
-
{sycl::aspect::fp16});
|
| 2257 |
-
|
| 2258 |
-
stream->parallel_for(
|
| 2259 |
-
sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) *
|
| 2260 |
-
sycl::range<3>(1, 1, block_size),
|
| 2261 |
-
sycl::range<3>(1, 1, block_size)),
|
| 2262 |
-
[=](sycl::nd_item<3> item_ct1) {
|
| 2263 |
-
k_bin_bcast_unravel<bin_op>(
|
| 2264 |
-
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3,
|
| 2265 |
-
ne10, ne11, ne12, ne13, s1, s2, s3, s11, s12,
|
| 2266 |
-
s13, item_ct1);
|
| 2267 |
-
});
|
| 2268 |
-
}
|
| 2269 |
-
} else {
|
| 2270 |
-
/*
|
| 2271 |
-
DPCT1049:16: The work-group size passed to the SYCL kernel may
|
| 2272 |
-
exceed the limit. To get the device limit, query
|
| 2273 |
-
info::device::max_work_group_size. Adjust the work-group size if
|
| 2274 |
-
needed.
|
| 2275 |
-
*/
|
| 2276 |
-
dpct::has_capability_or_fail(stream->get_device(),
|
| 2277 |
-
{sycl::aspect::fp16});
|
| 2278 |
-
|
| 2279 |
-
stream->parallel_for(
|
| 2280 |
-
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
| 2281 |
-
[=](sycl::nd_item<3> item_ct1) {
|
| 2282 |
-
k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
|
| 2283 |
-
ne2, ne3, ne10, ne11, ne12, ne13,
|
| 2284 |
-
s1, s2, s3, s11, s12, s13,
|
| 2285 |
-
item_ct1);
|
| 2286 |
-
});
|
| 2287 |
-
}
|
| 2288 |
-
}
|
| 2289 |
-
}
|
| 2290 |
-
};
|
| 2291 |
-
|
| 2292 |
-
static void acc_f32_sycl(const float *x, const float *y, float *dst,
|
| 2293 |
-
const int n_elements, const int ne10, const int ne11,
|
| 2294 |
-
const int ne12, const int nb1, const int nb2,
|
| 2295 |
-
const int offset, queue_ptr stream) {
|
| 2296 |
-
int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE;
|
| 2297 |
-
stream->parallel_for(
|
| 2298 |
-
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
| 2299 |
-
sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE),
|
| 2300 |
-
sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)),
|
| 2301 |
-
[=](sycl::nd_item<3> item_ct1) {
|
| 2302 |
-
acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset,
|
| 2303 |
-
item_ct1);
|
| 2304 |
-
});
|
| 2305 |
-
}
|
| 2306 |
-
|
| 2307 |
-
static void gelu_f32_sycl(const float *x, float *dst, const int k,
|
| 2308 |
-
queue_ptr stream) {
|
| 2309 |
-
const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
|
| 2310 |
-
stream->parallel_for(
|
| 2311 |
-
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
| 2312 |
-
sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
|
| 2313 |
-
sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
|
| 2314 |
-
[=](sycl::nd_item<3> item_ct1) {
|
| 2315 |
-
gelu_f32(x, dst, k, item_ct1);
|
| 2316 |
-
});
|
| 2317 |
-
}
|
| 2318 |
-
|
| 2319 |
-
static void silu_f32_sycl(const float *x, float *dst, const int k,
|
| 2320 |
-
queue_ptr stream) {
|
| 2321 |
-
const int num_blocks = (k + SYCL_SILU_BLOCK_SIZE - 1) / SYCL_SILU_BLOCK_SIZE;
|
| 2322 |
-
stream->parallel_for(
|
| 2323 |
-
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
| 2324 |
-
sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE),
|
| 2325 |
-
sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE)),
|
| 2326 |
-
[=](sycl::nd_item<3> item_ct1) {
|
| 2327 |
-
silu_f32(x, dst, k, item_ct1);
|
| 2328 |
-
});
|
| 2329 |
-
}
|
| 2330 |
-
|
| 2331 |
-
static void gelu_quick_f32_sycl(const float *x, float *dst, const int k,
|
| 2332 |
-
queue_ptr stream) {
|
| 2333 |
-
const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
|
| 2334 |
-
stream->parallel_for(
|
| 2335 |
-
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
| 2336 |
-
sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
|
| 2337 |
-
sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
|
| 2338 |
-
[=](sycl::nd_item<3> item_ct1) {
|
| 2339 |
-
gelu_quick_f32(x, dst, k, item_ct1);
|
| 2340 |
-
});
|
| 2341 |
-
}
|
| 2342 |
-
|
| 2343 |
-
static void tanh_f32_sycl(const float *x, float *dst, const int k,
|
| 2344 |
-
queue_ptr stream) {
|
| 2345 |
-
const int num_blocks = (k + SYCL_TANH_BLOCK_SIZE - 1) / SYCL_TANH_BLOCK_SIZE;
|
| 2346 |
-
stream->parallel_for(
|
| 2347 |
-
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
| 2348 |
-
sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE),
|
| 2349 |
-
sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE)),
|
| 2350 |
-
[=](sycl::nd_item<3> item_ct1) {
|
| 2351 |
-
tanh_f32(x, dst, k, item_ct1);
|
| 2352 |
-
});
|
| 2353 |
-
}
|
| 2354 |
-
|
| 2355 |
-
static void relu_f32_sycl(const float *x, float *dst, const int k,
|
| 2356 |
-
queue_ptr stream) {
|
| 2357 |
-
const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
|
| 2358 |
-
stream->parallel_for(
|
| 2359 |
-
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
| 2360 |
-
sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE),
|
| 2361 |
-
sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)),
|
| 2362 |
-
[=](sycl::nd_item<3> item_ct1) {
|
| 2363 |
-
relu_f32(x, dst, k, item_ct1);
|
| 2364 |
-
});
|
| 2365 |
-
}
|
| 2366 |
-
|
| 2367 |
-
static void hardsigmoid_f32_sycl(const float *x, float *dst, const int k,
|
| 2368 |
-
queue_ptr stream) {
|
| 2369 |
-
const int num_blocks = (k + SYCL_HARDSIGMOID_BLOCK_SIZE - 1) / SYCL_HARDSIGMOID_BLOCK_SIZE;
|
| 2370 |
-
stream->parallel_for(
|
| 2371 |
-
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
| 2372 |
-
sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE),
|
| 2373 |
-
sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE)),
|
| 2374 |
-
[=](sycl::nd_item<3> item_ct1) {
|
| 2375 |
-
hardsigmoid_f32(x, dst, k, item_ct1);
|
| 2376 |
-
});
|
| 2377 |
-
}
|
| 2378 |
-
|
| 2379 |
-
static void hardswish_f32_sycl(const float *x, float *dst, const int k,
|
| 2380 |
-
queue_ptr stream) {
|
| 2381 |
-
const int num_blocks = (k + SYCL_HARDSWISH_BLOCK_SIZE - 1) / SYCL_HARDSWISH_BLOCK_SIZE;
|
| 2382 |
-
stream->parallel_for(
|
| 2383 |
-
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
| 2384 |
-
sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE),
|
| 2385 |
-
sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE)),
|
| 2386 |
-
[=](sycl::nd_item<3> item_ct1) {
|
| 2387 |
-
hardswish_f32(x, dst, k, item_ct1);
|
| 2388 |
-
});
|
| 2389 |
-
}
|
| 2390 |
-
|
| 2391 |
-
static void leaky_relu_f32_sycl(const float *x, float *dst, const int k,
|
| 2392 |
-
const float negative_slope,
|
| 2393 |
-
queue_ptr stream) {
|
| 2394 |
-
const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
|
| 2395 |
-
stream->parallel_for(
|
| 2396 |
-
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
| 2397 |
-
sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE),
|
| 2398 |
-
sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)),
|
| 2399 |
-
[=](sycl::nd_item<3> item_ct1) {
|
| 2400 |
-
leaky_relu_f32(x, dst, k, negative_slope, item_ct1);
|
| 2401 |
-
});
|
| 2402 |
-
}
|
| 2403 |
-
|
| 2404 |
-
static void sqr_f32_sycl(const float *x, float *dst, const int k,
|
| 2405 |
-
queue_ptr stream) {
|
| 2406 |
-
const int num_blocks = (k + SYCL_SQR_BLOCK_SIZE - 1) / SYCL_SQR_BLOCK_SIZE;
|
| 2407 |
-
stream->parallel_for(
|
| 2408 |
-
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
| 2409 |
-
sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE),
|
| 2410 |
-
sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE)),
|
| 2411 |
-
[=](sycl::nd_item<3> item_ct1) {
|
| 2412 |
-
sqr_f32(x, dst, k, item_ct1);
|
| 2413 |
-
});
|
| 2414 |
-
}
|
| 2415 |
-
|
| 2416 |
-
static void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01,
|
| 2417 |
-
const int nb02, const int nb03, const int ne10, const int ne11,
|
| 2418 |
-
const int ne12, const int ne13, const float sf0, const float sf1,
|
| 2419 |
-
const float sf2, const float sf3, queue_ptr stream) {
|
| 2420 |
-
int dst_size = ne10 * ne11 * ne12 * ne13;
|
| 2421 |
-
int num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE;
|
| 2422 |
-
sycl::range<1> gridDim(num_blocks * SYCL_UPSCALE_BLOCK_SIZE);
|
| 2423 |
-
stream->parallel_for(
|
| 2424 |
-
sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)),
|
| 2425 |
-
[=](sycl::nd_item<1> item_ct1) {
|
| 2426 |
-
upscale_f32(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, item_ct1);
|
| 2427 |
-
});
|
| 2428 |
-
}
|
| 2429 |
-
|
| 2430 |
-
static void pad_f32_sycl(const float *x, float *dst, const int ne00,
|
| 2431 |
-
const int ne01, const int ne02, const int ne0,
|
| 2432 |
-
const int ne1, const int ne2, queue_ptr stream) {
|
| 2433 |
-
int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE;
|
| 2434 |
-
sycl::range<3> gridDim(ne2, ne1, num_blocks);
|
| 2435 |
-
stream->parallel_for(
|
| 2436 |
-
sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),
|
| 2437 |
-
sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)),
|
| 2438 |
-
[=](sycl::nd_item<3> item_ct1) {
|
| 2439 |
-
pad_f32(x, dst, ne0, ne00, ne01, ne02, item_ct1);
|
| 2440 |
-
});
|
| 2441 |
-
}
|
| 2442 |
|
| 2443 |
static void quantize_row_q8_1_sycl(const float *x, void *vy, const int kx,
|
| 2444 |
const int ky, const int kx_padded,
|
|
@@ -2816,6 +2261,58 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
|
|
| 2816 |
}
|
| 2817 |
}
|
| 2818 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2819 |
static void diag_mask_inf_f32_sycl(const float *x, float *dst,
|
| 2820 |
const int ncols_x, const int nrows_x,
|
| 2821 |
const int rows_per_channel, const int n_past,
|
|
@@ -2855,362 +2352,111 @@ static dpct::err0 ggml_sycl_cpy_tensor_2d(void *dst,
|
|
| 2855 |
} else {
|
| 2856 |
// GGML_SYCL_DEBUG("GGML_ABORT("fatal error")\n");
|
| 2857 |
GGML_ABORT("fatal error");
|
| 2858 |
-
}
|
| 2859 |
-
char * dst_ptr = (char *) dst;
|
| 2860 |
-
|
| 2861 |
-
GGML_TENSOR_LOCALS_1(int64_t, ne, src, ne);
|
| 2862 |
-
GGML_TENSOR_LOCALS(int64_t, nb, src, nb);
|
| 2863 |
-
const enum ggml_type type = src->type;
|
| 2864 |
-
const int64_t ts = ggml_type_size(type);
|
| 2865 |
-
const int64_t bs = ggml_blck_size(type);
|
| 2866 |
-
int64_t i1_diff = i1_high - i1_low;
|
| 2867 |
-
|
| 2868 |
-
const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;
|
| 2869 |
-
if (nb0 == ts && nb1 == ts*ne0/bs) {
|
| 2870 |
-
// GGML_SYCL_DEBUG("stream->memcpy: dst_ptr=%p, x=%p, size=%lu\n", dst_ptr, x, i1_diff * nb1);
|
| 2871 |
-
// return CHECK_TRY_ERROR(stream->memcpy(dst_ptr, x, i1_diff * nb1));
|
| 2872 |
-
return CHECK_TRY_ERROR(dpct::async_dpct_memcpy(dst_ptr, x, i1_diff * nb1,
|
| 2873 |
-
kind, *stream));
|
| 2874 |
-
|
| 2875 |
-
} else if (nb0 == ts) {
|
| 2876 |
-
return CHECK_TRY_ERROR(
|
| 2877 |
-
dpct::async_dpct_memcpy(dst_ptr, ts * ne0 / bs, x, nb1,
|
| 2878 |
-
ts * ne0 / bs, i1_diff, kind, *stream));
|
| 2879 |
-
} else {
|
| 2880 |
-
for (int64_t i1 = 0; i1 < i1_diff; i1++) {
|
| 2881 |
-
const void * rx = (const void *) ((const char *) x + i1*nb1);
|
| 2882 |
-
void * rd = (void *) (dst_ptr + i1*ts*ne0/bs);
|
| 2883 |
-
// pretend the row is a matrix with cols=1
|
| 2884 |
-
dpct::err0 r = CHECK_TRY_ERROR(dpct::async_dpct_memcpy(
|
| 2885 |
-
rd, ts / bs, rx, nb0, ts / bs, ne0, kind, *stream));
|
| 2886 |
-
/*
|
| 2887 |
-
DPCT1001:85: The statement could not be removed.
|
| 2888 |
-
*/
|
| 2889 |
-
/*
|
| 2890 |
-
DPCT1000:86: Error handling if-stmt was detected but could not be
|
| 2891 |
-
rewritten.
|
| 2892 |
-
*/
|
| 2893 |
-
if (r != 0) return r;
|
| 2894 |
-
}
|
| 2895 |
-
return 0;
|
| 2896 |
-
}
|
| 2897 |
-
}
|
| 2898 |
-
catch (sycl::exception const &exc) {
|
| 2899 |
-
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
|
| 2900 |
-
<< ", line:" << __LINE__ << std::endl;
|
| 2901 |
-
std::exit(1);
|
| 2902 |
-
}
|
| 2903 |
-
|
| 2904 |
-
static void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
| 2905 |
-
const ggml_tensor *src1, ggml_tensor *dst,
|
| 2906 |
-
const float *src0_d, const float *src1_d,
|
| 2907 |
-
float *dst_d, const queue_ptr &stream) {
|
| 2908 |
-
|
| 2909 |
-
GGML_ASSERT(src1->type == GGML_TYPE_I32);
|
| 2910 |
-
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
| 2911 |
-
|
| 2912 |
-
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
|
| 2913 |
-
GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
|
| 2914 |
-
GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
|
| 2915 |
-
|
| 2916 |
-
const int32_t * src1_i32 = (const int32_t *) src1_d;
|
| 2917 |
-
|
| 2918 |
-
switch (src0->type) {
|
| 2919 |
-
case GGML_TYPE_F16:
|
| 2920 |
-
get_rows_sycl_float(ctx, src0, src1, dst, (const sycl::half *)src0_d,
|
| 2921 |
-
src1_i32, dst_d, stream);
|
| 2922 |
-
break;
|
| 2923 |
-
case GGML_TYPE_F32:
|
| 2924 |
-
get_rows_sycl_float(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
|
| 2925 |
-
break;
|
| 2926 |
-
case GGML_TYPE_Q4_0:
|
| 2927 |
-
get_rows_sycl<QK4_0, QR4_0, dequantize_q4_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
|
| 2928 |
-
break;
|
| 2929 |
-
case GGML_TYPE_Q4_1:
|
| 2930 |
-
get_rows_sycl<QK4_1, QR4_1, dequantize_q4_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
|
| 2931 |
-
break;
|
| 2932 |
-
case GGML_TYPE_Q5_0:
|
| 2933 |
-
get_rows_sycl<QK5_0, QR5_0, dequantize_q5_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
|
| 2934 |
-
break;
|
| 2935 |
-
case GGML_TYPE_Q5_1:
|
| 2936 |
-
get_rows_sycl<QK5_1, QR5_1, dequantize_q5_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
|
| 2937 |
-
break;
|
| 2938 |
-
case GGML_TYPE_Q8_0:
|
| 2939 |
-
get_rows_sycl<QK8_0, QR8_0, dequantize_q8_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
|
| 2940 |
-
break;
|
| 2941 |
-
default:
|
| 2942 |
-
// TODO: k-quants
|
| 2943 |
-
fprintf(stderr, "%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type));
|
| 2944 |
-
GGML_ABORT("fatal error");
|
| 2945 |
-
break;
|
| 2946 |
-
}
|
| 2947 |
-
}
|
| 2948 |
-
|
| 2949 |
-
template <class op>
|
| 2950 |
-
inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
| 2951 |
-
const ggml_tensor *src1, ggml_tensor *dst,
|
| 2952 |
-
const float *src0_dd, const float *src1_dd,
|
| 2953 |
-
float *dst_dd,
|
| 2954 |
-
const queue_ptr &main_stream) {
|
| 2955 |
-
|
| 2956 |
-
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
| 2957 |
-
op()(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
|
| 2958 |
-
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
|
| 2959 |
-
op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd,
|
| 2960 |
-
(sycl::half *)dst_dd, main_stream);
|
| 2961 |
-
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
|
| 2962 |
-
op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd, dst_dd,
|
| 2963 |
-
main_stream);
|
| 2964 |
-
} else if (src0->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {
|
| 2965 |
-
op()(ctx, src0, src1, dst, (const int32_t *)src0_dd, (const int32_t *)src1_dd, (int32_t *)dst_dd,
|
| 2966 |
-
main_stream);
|
| 2967 |
-
} else if (src0->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) {
|
| 2968 |
-
op()(ctx, src0, src1, dst, (const int16_t *)src0_dd, (const int16_t *)src1_dd, (int16_t *)dst_dd,
|
| 2969 |
-
main_stream);
|
| 2970 |
-
} else {
|
| 2971 |
-
fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
|
| 2972 |
-
ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
|
| 2973 |
-
GGML_ABORT("fatal error");
|
| 2974 |
-
}
|
| 2975 |
-
}
|
| 2976 |
-
|
| 2977 |
-
static void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
| 2978 |
-
const ggml_tensor *src1, ggml_tensor *dst,
|
| 2979 |
-
-                                const float *src0_d, const float *src1_d,
-                                float *dst_d,
-                                const queue_ptr &main_stream) {
-
-    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_repeat>>(ctx, dst, src0, dst, nullptr, src0_d, dst_d, main_stream);
-
-    (void) src1;
-    (void) src1_d;
-}
-
-inline void ggml_sycl_op_add(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
-                             ggml_tensor *dst, const float *src0_dd,
-                             const float *src1_dd, float *dst_dd,
-                             const queue_ptr &main_stream) {
-
-    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_add>>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
-}
-
-inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
-                             ggml_tensor *dst, const float *src0_dd,
-                             const float *src1_dd, float *dst_dd,
-                             const queue_ptr &main_stream) {
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported
-
-    int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
-    int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
-    // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
-    int offset = dst->op_params[3] / 4; // offset in bytes
-
-    acc_f32_sycl(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], nb1, nb2, offset, main_stream);
-
-    (void) dst;
-}
-
-inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
-                             ggml_tensor *dst, const float *src0_dd,
-                             const float *src1_dd, float *dst_dd,
-                             const queue_ptr &main_stream) {
-
-    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_mul>>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
-}
-
-inline void ggml_sycl_op_div(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
-                             ggml_tensor *dst, const float *src0_dd,
-                             const float *src1_dd, float *dst_dd,
-                             const queue_ptr &main_stream) {
-
-    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_div>>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
-}
-
-inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
-                              ggml_tensor *dst, const float *src0_dd,
-                              const float *src1_dd, float *dst_dd,
-                              const queue_ptr &main_stream) {
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    gelu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
-
-    (void) src1;
-    (void) dst;
-    (void) src1_dd;
-}
-
-inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
-                              ggml_tensor *dst, const float *src0_dd,
-                              const float *src1_dd, float *dst_dd,
-                              const queue_ptr &main_stream) {
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    silu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
-
-    (void) src1;
-    (void) dst;
-    (void) src1_dd;
-}
-
-inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                    const ggml_tensor *src1, ggml_tensor *dst,
-                                    const float *src0_dd, const float *src1_dd,
-                                    float *dst_dd,
-                                    const queue_ptr &main_stream) {
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    gelu_quick_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
-
-    (void) src1;
-    (void) dst;
-    (void) src1_dd;
-}
-
-inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
-                              ggml_tensor *dst, const float *src0_dd,
-                              const float *src1_dd, float *dst_dd,
-                              const queue_ptr &main_stream) {
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-    tanh_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
-
-    (void) src1;
-    (void) dst;
-    (void) src1_dd;
-}
-
-inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
-                              ggml_tensor *dst, const float *src0_dd,
-                              const float *src1_dd, float *dst_dd,
-                              const queue_ptr &main_stream) {
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
-
-    (void) src1;
-    (void) dst;
-    (void) src1_dd;
-}
-
-static void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                     const ggml_tensor *src1, ggml_tensor *dst,
-                                     const float *src0_dd, const float *src1_dd,
-                                     float *dst_dd,
-                                     const queue_ptr &main_stream) {
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    hardsigmoid_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
-
-    (void) src1;
-    (void) dst;
-    (void) src1_dd;
-}
-
-static void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                   const ggml_tensor *src1, ggml_tensor *dst,
-                                   const float *src0_dd, const float *src1_dd,
-                                   float *dst_dd, const queue_ptr &main_stream) {
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    hardswish_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
-
-    (void) src1;
-    (void) dst;
-    (void) src1_dd;
-}
-
-inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                    const ggml_tensor *src1, ggml_tensor *dst,
-                                    const float *src0_dd, const float *src1_dd,
-                                    float *dst_dd,
-                                    const queue_ptr &main_stream) {
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    float negative_slope;
-    memcpy(&negative_slope, dst->op_params, sizeof(float));
-
-    leaky_relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), negative_slope, main_stream);
-
-    (void) src1;
-    (void) dst;
-    (void) src1_dd;
-}
-
-inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
-                             ggml_tensor *dst, const float *src0_dd,
-                             const float *src1_dd, float *dst_dd,
-                             const queue_ptr &main_stream) {
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    sqr_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
-
-    (void) src1;
-    (void) dst;
-    (void) src1_dd;
-}
-
-inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                 const ggml_tensor *src1, ggml_tensor *dst,
-                                 const float *src0_dd, const float *src1_dd,
-                                 float *dst_dd,
-                                 const queue_ptr &main_stream) {
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-
-    const float sf0 = (float)dst->ne[0]/src0->ne[0];
-    const float sf1 = (float)dst->ne[1]/src0->ne[1];
-    const float sf2 = (float)dst->ne[2]/src0->ne[2];
-    const float sf3 = (float)dst->ne[3]/src0->ne[3];
-
-    upscale_f32_sycl(src0_dd, dst_dd, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
-                     dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3,
-                     main_stream);
-
-    (void) src1;
-    (void) dst;
-    (void) src1_dd;
-}
-
-inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
-                             ggml_tensor *dst, const float *src0_dd,
-                             const float *src1_dd, float *dst_dd,
-                             const queue_ptr &main_stream) {
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-
-    pad_f32_sycl(src0_dd, dst_dd,
-                 src0->ne[0], src0->ne[1], src0->ne[2],
-                 dst->ne[0], dst->ne[1], dst->ne[2], main_stream);

     (void) src1;
-    (void) dst;
-    (void) src1_dd;
 }

 inline void ggml_sycl_op_mul_mat_sycl(
     ggml_backend_sycl_context & ctx,
     const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
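All of the element-wise wrappers deleted above follow one template: assert f32 input and output, launch the matching `<op>_f32_sycl` kernel over `ggml_nelements(src0)`, and silence the unused parameters. Per the change list they now live in ggml/src/ggml-sycl/element_wise.cpp. A minimal sketch of that shared shape, assuming the ggml SYCL headers are in scope (`my_op_f32_sycl` is an invented stand-in, not a real launcher):

// Sketch only: the common pattern of the relocated unary wrappers.
inline void ggml_sycl_op_generic_unary(ggml_backend_sycl_context & ctx,
                                       const ggml_tensor *src0, const ggml_tensor *src1,
                                       ggml_tensor *dst, const float *src0_dd,
                                       const float *src1_dd, float *dst_dd,
                                       const queue_ptr &main_stream) {
    GGML_ASSERT(src0->type == GGML_TYPE_F32);  // every wrapper checks f32 in...
    GGML_ASSERT( dst->type == GGML_TYPE_F32);  // ...and f32 out

    my_op_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); // hypothetical launcher

    (void) src1;    // unary ops ignore the second operand
    (void) dst;
    (void) src1_dd;
}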
@@ -3379,6 +2625,23 @@ static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, const ggml_tens
     (void) src1_dd;
 }

 inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
                                   const ggml_tensor *src1, ggml_tensor *dst,
                                   const float *src0_dd, const float *src1_dd,
@@ -3419,6 +2682,25 @@ inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, const ggml_ten
     (void) src1_dd;
 }

 inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
                                        const ggml_tensor *src1,
                                        ggml_tensor *dst, const float *src0_dd,
@@ -3489,46 +2771,6 @@ inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, const ggml_tenso
     (void) src1_dd;
 }

-static void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                 const ggml_tensor *src1, ggml_tensor *dst,
-                                 const ggml_sycl_op_flatten_t op) try {
-    const int64_t nrows0 = ggml_nrows(src0);
-
-    const bool use_src1 = src1 != nullptr;
-    const int64_t nrows1 = use_src1 ? ggml_nrows(src1) : 1;
-
-    GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
-    GGML_ASSERT( dst->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
-
-    ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
-    ggml_tensor_extra_gpu * src1_extra = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr;
-    ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
-
-    // dd = data device
-    float * src0_ddf = (float *) src0->data;
-    float * src1_ddf = use_src1 ? (float *) src1->data : nullptr;
-    float * dst_ddf = (float *) dst->data;
-
-    ggml_sycl_pool_alloc<float> src0_f(ctx.pool());
-    ggml_sycl_pool_alloc<float> src1_f(ctx.pool());
-    ggml_sycl_pool_alloc<float> dst_f(ctx.pool());
-
-    ggml_sycl_set_device(ctx.device);
-    queue_ptr main_stream = ctx.stream();
-    // GGML_SYCL_DEBUG("ctx.device=%d, main_stream=%p src0_on_device=%d, src1_on_device=%d, dst_on_device=%d\n",
-    // ctx.device, main_stream, src0_on_device, src1_on_device, dst_on_device);
-
-    // do the computation
-    op(ctx, src0, src1, dst, src0_ddf, src1_ddf, dst_ddf, main_stream);
-    // print_ggml_tensor("tensor", dst);
-}
-catch (sycl::exception const &exc) {
-
-    std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-              << ", line:" << __LINE__ << std::endl;
-    std::exit(1);
-}
-
 static void ggml_sycl_set_peer_access(const int n_tokens, int main_device) {
     static bool peer_access_enabled = false;
@@ -3908,112 +3150,21 @@ static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, const ggml_tenso
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }

-static void ggml_sycl_add(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_add);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
-
-static void ggml_sycl_acc(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_acc);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
-
-static void ggml_sycl_mul(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_mul);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
-
-static void ggml_sycl_div(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_div);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
-
-static void ggml_sycl_gelu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_gelu);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
-
-static void ggml_sycl_silu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_silu);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
-
-static void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_gelu_quick);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
-
-static void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_tanh);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
-
-static void ggml_sycl_relu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_relu);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
-
-static void ggml_sycl_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_hardsigmoid);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
-
-static void ggml_sycl_hardswish(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_hardswish);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
-
-static void ggml_sycl_leaky_relu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_leaky_relu);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
-
-static void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sqr);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
-
 static void ggml_sycl_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
     ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_norm);
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }

-static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_group_norm);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
-
-static void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_upscale);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
-
-static void ggml_sycl_pad(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_pad);
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }

-
-static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_rms_norm);
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
@@ -4632,6 +3783,11 @@ static void ggml_sycl_im2col(ggml_backend_sycl_context & ctx, const ggml_tensor
     ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_im2col);
 }

 static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(ggml_is_contiguous(src0));
     ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sum_rows);
@@ -4642,6 +3798,11 @@ static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, const ggml_tensor
|
|
| 4642 |
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_argsort);
|
| 4643 |
}
|
| 4644 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4645 |
static void ggml_sycl_nop(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 4646 |
(void) src0;
|
| 4647 |
(void) src1;
|
|
@@ -4673,6 +3834,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
     ggml_sycl_func_t func;

     switch (tensor->op) {
         case GGML_OP_CONV_TRANSPOSE_1D:
             func = ggml_sycl_op_conv_transpose_1d;
             break;
@@ -4686,19 +3850,32 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
             func = ggml_sycl_dup;
             break;
         case GGML_OP_ADD:
             func = ggml_sycl_add;
             break;
         case GGML_OP_ACC:
             func = ggml_sycl_acc;
             break;
         case GGML_OP_MUL:
             func = ggml_sycl_mul;
             break;
         case GGML_OP_DIV:
             func = ggml_sycl_div;
             break;
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(tensor)) {
                 case GGML_UNARY_OP_GELU:
                     func = ggml_sycl_gelu;
                     break;
@@ -4714,12 +3891,18 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
                 case GGML_UNARY_OP_RELU:
                     func = ggml_sycl_relu;
                     break;
                 case GGML_UNARY_OP_HARDSIGMOID:
                     func = ggml_sycl_hardsigmoid;
                     break;
                 case GGML_UNARY_OP_HARDSWISH:
                     func = ggml_sycl_hardswish;
                     break;
                 default:
                     return false;
             }
@@ -4757,12 +3940,24 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
             }
             func = ggml_sycl_mul_mat_id;
             break;
         case GGML_OP_SCALE:
             func = ggml_sycl_scale;
             break;
         case GGML_OP_SQR:
             func = ggml_sycl_sqr;
             break;
         case GGML_OP_CLAMP:
             func = ggml_sycl_clamp;
             break;
@@ -4794,6 +3989,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
         case GGML_OP_POOL_2D:
             func = ggml_sycl_pool2d;
             break;
         case GGML_OP_SUM_ROWS:
             func = ggml_sycl_sum_rows;
             break;
@@ -4803,6 +4001,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
         case GGML_OP_TIMESTEP_EMBEDDING:
             func = ggml_sycl_op_timestep_embedding;
             break;
         default:
             return false;
     }
@@ -5125,13 +4326,17 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         } break;
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
                 case GGML_UNARY_OP_GELU:
                 case GGML_UNARY_OP_SILU:
                 case GGML_UNARY_OP_RELU:
                 case GGML_UNARY_OP_HARDSIGMOID:
                 case GGML_UNARY_OP_HARDSWISH:
                 case GGML_UNARY_OP_GELU_QUICK:
                 case GGML_UNARY_OP_TANH:
                     return ggml_is_contiguous(op->src[0]);
                 default:
                     return false;
@@ -5168,6 +4373,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
             }
             return true;
         } break;
         case GGML_OP_GET_ROWS:
             {
                 switch (op->src[0]->type) {
@@ -5213,10 +4420,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_CONCAT:
             {
                 ggml_type src0_type = op->src[0]->type;
-                int dim = op->op_params[0];
-                return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) && src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16 && dim == 2;
             } break;
         case GGML_OP_DUP:
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
         case GGML_OP_REPEAT:
@@ -5225,11 +4432,17 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_TRANSPOSE:
         case GGML_OP_NORM:
         case GGML_OP_ADD:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
         case GGML_OP_RMS_NORM:
         case GGML_OP_SCALE:
         case GGML_OP_SQR:
         case GGML_OP_CLAMP:
             return true;
         case GGML_OP_CONT:
@@ -5243,6 +4456,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
             // TODO: add support for the new F32 operations
             return op->src[0]->type == GGML_TYPE_F16;
         case GGML_OP_POOL_2D:
         case GGML_OP_SUM_ROWS:
         case GGML_OP_ARGSORT:
         case GGML_OP_ACC:
@@ -5251,6 +4465,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_PAD:
         case GGML_OP_LEAKY_RELU:
         case GGML_OP_TIMESTEP_EMBEDDING:
             return true;
         default:
             return false;
@@ -5268,9 +4483,23 @@ static bool ggml_backend_sycl_device_supports_buft(ggml_backend_dev_t dev, ggml_
     return buft_ctx->device == sycl_ctx->device;
 }

 static bool ggml_backend_sycl_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
     const int min_batch_size = 32;
-    return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;
     GGML_UNUSED(dev);
 }
     float *dst_dd_i, const int64_t row_low, const int64_t row_high,
     const int64_t src1_ncols, const int64_t src1_padded_row_size,
     const queue_ptr &stream);
 template<int QUANT_BLOCK_TILE>
 static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded,

     (void) dst;
 }
 static void quantize_row_q8_1_sycl(const float *x, void *vy, const int kx,
                                    const int ky, const int kx_padded,

     }
 }
+static void argmax_f32_i32_sycl(const float *x, int *dst, const int ncols,
+                                const int nrows, queue_ptr stream) {
+    const sycl::range<3> block_dims(1, 1, SYCL_ARGMAX_BLOCK_SIZE);
+    const sycl::range<3> block_nums(1, nrows, 1);
+    const size_t shared_mem = 256 * sizeof(float);
+
+    stream->submit([&](sycl::handler &cgh) {
+        sycl::local_accessor<float, 1> shared_data(
+            sycl::range<1>(shared_mem/sizeof(float)), cgh);
+        sycl::local_accessor<int, 1> shared_indices(
+            sycl::range<1>(shared_mem/sizeof(float)), cgh);
+
+        cgh.parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
+            [=](sycl::nd_item<3> item_ct1) {
+                const int tid = item_ct1.get_local_id(2);
+                const int row = item_ct1.get_global_id(1);
+
+                float max_val = -INFINITY;
+                int max_idx = -1;
+
+                for (int col = tid; col < ncols; col += 256) {
+                    float val = x[row * ncols + col];
+                    if (val > max_val) {
+                        max_val = val;
+                        max_idx = col;
+                    }
+                }
+
+                shared_data[tid] = max_val;
+                shared_indices[tid] = max_idx;
+                item_ct1.barrier(sycl::access::fence_space::local_space);
+
+                for (int stride = 256/2; stride > 0; stride >>= 1) {
+                    if (tid < stride) {
+                        float val1 = shared_data[tid];
+                        float val2 = shared_data[tid + stride];
+                        if (val2 > val1) {
+                            shared_data[tid] = val2;
+                            shared_indices[tid] = shared_indices[tid + stride];
+                        }
+                    }
+                    item_ct1.barrier(sycl::access::fence_space::local_space);
+                }
+
+
+                if (tid == 0) {
+                    dst[row] = shared_indices[0];
+                }
+            });
+    });
+}
 static void diag_mask_inf_f32_sycl(const float *x, float *dst,
                                    const int ncols_x, const int nrows_x,
                                    const int rows_per_channel, const int n_past,
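The new argmax kernel gives each row one work-group: every work-item scans a strided slice of the row (note the loop strides and the reduction both hard-code 256, so they assume SYCL_ARGMAX_BLOCK_SIZE is 256), then a shared-memory tree reduction halves the active range until work-item 0 holds the row's winner. A host-side reference for checking its output; this is my sketch, not part of the patch, and on ties the kernel's winning index depends on how columns were split across work-items, so only compare indices where the row maximum is unique:

#include <cmath>
#include <cstddef>
#include <vector>

// Sequential row-wise argmax: out[r] is the column of the first maximum in row r.
static std::vector<int> argmax_rows_ref(const std::vector<float> & x,
                                        size_t nrows, size_t ncols) {
    std::vector<int> out(nrows, -1);            // -1 matches the kernel's initial max_idx
    for (size_t r = 0; r < nrows; ++r) {
        float best = -INFINITY;
        for (size_t c = 0; c < ncols; ++c) {
            const float v = x[r * ncols + c];
            if (v > best) {                     // strict '>' keeps the first maximum
                best = v;
                out[r] = (int) c;
            }
        }
    }
    return out;
}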
     } else {
         // GGML_SYCL_DEBUG("GGML_ABORT("fatal error")\n");
         GGML_ABORT("fatal error");
+    }
+    char * dst_ptr = (char *) dst;

+    GGML_TENSOR_LOCALS_1(int64_t, ne, src, ne);
+    GGML_TENSOR_LOCALS(int64_t, nb, src, nb);
+    const enum ggml_type type = src->type;
+    const int64_t ts = ggml_type_size(type);
+    const int64_t bs = ggml_blck_size(type);
+    int64_t i1_diff = i1_high - i1_low;

+    const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;
+    if (nb0 == ts && nb1 == ts*ne0/bs) {
+        // GGML_SYCL_DEBUG("stream->memcpy: dst_ptr=%p, x=%p, size=%lu\n", dst_ptr, x, i1_diff * nb1);
+        // return CHECK_TRY_ERROR(stream->memcpy(dst_ptr, x, i1_diff * nb1));
+        return CHECK_TRY_ERROR(dpct::async_dpct_memcpy(dst_ptr, x, i1_diff * nb1,
+                                                       kind, *stream));

+    } else if (nb0 == ts) {
+        return CHECK_TRY_ERROR(
+            dpct::async_dpct_memcpy(dst_ptr, ts * ne0 / bs, x, nb1,
+                                    ts * ne0 / bs, i1_diff, kind, *stream));
+    } else {
+        for (int64_t i1 = 0; i1 < i1_diff; i1++) {
+            const void * rx = (const void *) ((const char *) x + i1*nb1);
+            void * rd = (void *) (dst_ptr + i1*ts*ne0/bs);
+            // pretend the row is a matrix with cols=1
+            dpct::err0 r = CHECK_TRY_ERROR(dpct::async_dpct_memcpy(
+                rd, ts / bs, rx, nb0, ts / bs, ne0, kind, *stream));
+            /*
+            DPCT1001:85: The statement could not be removed.
+            */
+            /*
+            DPCT1000:86: Error handling if-stmt was detected but could not be
+            rewritten.
+            */
+            if (r != 0) return r;
+        }
+        return 0;
+    }
 }
+catch (sycl::exception const &exc) {
+    std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+              << ", line:" << __LINE__ << std::endl;
+    std::exit(1);
 }
+static void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                  const ggml_tensor *src1, ggml_tensor *dst,
+                                  const float *src0_d, const float *src1_d,
+                                  float *dst_d, const queue_ptr &stream) {
+
+    GGML_ASSERT(src1->type == GGML_TYPE_I32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);

+    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
+    GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
+    GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));

+    const int32_t * src1_i32 = (const int32_t *) src1_d;

+    switch (src0->type) {
+        case GGML_TYPE_F16:
+            get_rows_sycl_float(ctx, src0, src1, dst, (const sycl::half *)src0_d,
+                                src1_i32, dst_d, stream);
+            break;
+        case GGML_TYPE_F32:
+            get_rows_sycl_float(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            break;
+        case GGML_TYPE_Q4_0:
+            get_rows_sycl<QK4_0, QR4_0, dequantize_q4_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            break;
+        case GGML_TYPE_Q4_1:
+            get_rows_sycl<QK4_1, QR4_1, dequantize_q4_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            break;
+        case GGML_TYPE_Q5_0:
+            get_rows_sycl<QK5_0, QR5_0, dequantize_q5_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            break;
+        case GGML_TYPE_Q5_1:
+            get_rows_sycl<QK5_1, QR5_1, dequantize_q5_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            break;
+        case GGML_TYPE_Q8_0:
+            get_rows_sycl<QK8_0, QR8_0, dequantize_q8_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            break;
+        default:
+            // TODO: k-quants
+            fprintf(stderr, "%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type));
+            GGML_ABORT("fatal error");
+            break;
+    }
 }
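The dispatcher above routes each source type to a templated gather kernel, where the quantized cases additionally dequantize each block while copying. Semantically, GET_ROWS just copies the rows of src0 named by the indices in src1. A scalar reference for the f32 case, as a sketch of mine rather than code from the patch:

#include <cstdint>
#include <cstring>

// dst row i becomes a copy of src0 row rows[i]; the SYCL kernels do the
// same, with a per-block dequantize step for Q4_0/Q4_1/Q5_0/Q5_1/Q8_0.
static void get_rows_ref_f32(const float *src0, const int32_t *rows,
                             float *dst, int64_t n_rows_out, int64_t row_len) {
    for (int64_t i = 0; i < n_rows_out; ++i) {
        std::memcpy(dst + i * row_len,
                    src0 + (int64_t) rows[i] * row_len,
                    row_len * sizeof(float));
    }
}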

+static void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                const ggml_tensor *src1, ggml_tensor *dst,
+                                const float *src0_d, const float *src1_d,
+                                float *dst_d,
+                                const queue_ptr &main_stream) {
+
+    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_repeat>>(ctx, dst, src0, dst, nullptr, src0_d, dst_d, main_stream);

     (void) src1;
+    (void) src1_d;
 }

+
 inline void ggml_sycl_op_mul_mat_sycl(
     ggml_backend_sycl_context & ctx,
     const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
     (void) src1_dd;
 }

+inline void ggml_sycl_op_sum(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                             const ggml_tensor *src1, ggml_tensor *dst,
+                             const float *src0_dd, const float *src1_dd,
+                             float *dst_dd,
+                             const queue_ptr &main_stream) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    const int64_t ne = ggml_nelements(src0);
+
+    sum_rows_f32_sycl(src0_dd, dst_dd, ne, 1, main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
 inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
                                   const ggml_tensor *src1, ggml_tensor *dst,
                                   const float *src0_dd, const float *src1_dd,
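Note the trick in the new GGML_OP_SUM path: instead of a dedicated kernel, it calls the existing row-sum launcher with `nrows = 1`, viewing the whole contiguous tensor as one row of `ne` elements. In host terms, as a sketch of mine:

#include <cstdint>

// Summing every element equals a row-sum over a 1 x ne view of the
// (contiguous) tensor: one "row" of length ne -> one output value.
static float sum_ref(const float *x, int64_t ne) {
    float acc = 0.0f;
    for (int64_t i = 0; i < ne; ++i) {
        acc += x[i];
    }
    return acc;
}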
     (void) src1_dd;
 }

+inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                const ggml_tensor *src1, ggml_tensor *dst,
+                                const float *src0_dd, const float *src1_dd,
+                                float *dst_dd,
+                                const queue_ptr &main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_I32);
+
+    const int64_t ncols = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+
+    argmax_f32_i32_sycl(src0_dd, (int *)dst_dd, ncols, nrows, main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
 inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
                                        const ggml_tensor *src1,
                                        ggml_tensor *dst, const float *src0_dd,
     (void) src1_dd;
 }
 static void ggml_sycl_set_peer_access(const int n_tokens, int main_device) {
     static bool peer_access_enabled = false;

     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 static void ggml_sycl_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
     ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_norm);
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }

+static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_rms_norm);
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }

+static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_group_norm);
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
|
| 3783 |
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_im2col);
|
| 3784 |
}
|
| 3785 |
|
| 3786 |
+
static void ggml_sycl_sum(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 3787 |
+
GGML_ASSERT(ggml_is_contiguous(src0));
|
| 3788 |
+
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sum);
|
| 3789 |
+
}
|
| 3790 |
+
|
 static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(ggml_is_contiguous(src0));
     ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sum_rows);

     ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_argsort);
 }

+static void ggml_sycl_argmax(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_argmax);
+}
+
 static void ggml_sycl_nop(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     (void) src0;
     (void) src1;
     ggml_sycl_func_t func;

     switch (tensor->op) {
+        case GGML_OP_ARGMAX:
+            func = ggml_sycl_argmax;
+            break;
         case GGML_OP_CONV_TRANSPOSE_1D:
             func = ggml_sycl_op_conv_transpose_1d;
             break;

             func = ggml_sycl_dup;
             break;
         case GGML_OP_ADD:
+        case GGML_OP_ADD1: // TODO: more efficient implementation
             func = ggml_sycl_add;
             break;
+        case GGML_OP_SUB:
+            func = ggml_sycl_sub;
+            break;
         case GGML_OP_ACC:
             func = ggml_sycl_acc;
             break;
         case GGML_OP_MUL:
             func = ggml_sycl_mul;
             break;
+        case GGML_OP_LOG:
+            func = ggml_sycl_log;
+            break;
         case GGML_OP_DIV:
             func = ggml_sycl_div;
             break;
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(tensor)) {
+                case GGML_UNARY_OP_NEG:
+                    func = ggml_sycl_neg;
+                    break;
+                case GGML_UNARY_OP_STEP:
+                    func = ggml_sycl_step;
+                    break;
                 case GGML_UNARY_OP_GELU:
                     func = ggml_sycl_gelu;
                     break;

                 case GGML_UNARY_OP_RELU:
                     func = ggml_sycl_relu;
                     break;
+                case GGML_UNARY_OP_SIGMOID:
+                    func = ggml_sycl_sigmoid;
+                    break;
                 case GGML_UNARY_OP_HARDSIGMOID:
                     func = ggml_sycl_hardsigmoid;
                     break;
                 case GGML_UNARY_OP_HARDSWISH:
                     func = ggml_sycl_hardswish;
                     break;
+                case GGML_UNARY_OP_EXP:
+                    func = ggml_sycl_exp;
+                    break;
                 default:
                     return false;
             }
             }
             func = ggml_sycl_mul_mat_id;
             break;
+        case GGML_OP_OUT_PROD:
+            func = ggml_sycl_op_out_prod;
+            break;
         case GGML_OP_SCALE:
             func = ggml_sycl_scale;
             break;
         case GGML_OP_SQR:
             func = ggml_sycl_sqr;
             break;
+        case GGML_OP_SQRT:
+            func = ggml_sycl_sqrt;
+            break;
+        case GGML_OP_SIN:
+            func = ggml_sycl_sin;
+            break;
+        case GGML_OP_COS:
+            func = ggml_sycl_cos;
+            break;
         case GGML_OP_CLAMP:
             func = ggml_sycl_clamp;
             break;

         case GGML_OP_POOL_2D:
             func = ggml_sycl_pool2d;
             break;
+        case GGML_OP_SUM:
+            func = ggml_sycl_sum;
+            break;
         case GGML_OP_SUM_ROWS:
             func = ggml_sycl_sum_rows;
             break;

         case GGML_OP_TIMESTEP_EMBEDDING:
             func = ggml_sycl_op_timestep_embedding;
             break;
+        case GGML_OP_RWKV_WKV6:
+            func = ggml_sycl_op_rwkv_wkv6;
+            break;
         default:
             return false;
     }
         } break;
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
+                case GGML_UNARY_OP_NEG:
+                case GGML_UNARY_OP_STEP:
                 case GGML_UNARY_OP_GELU:
                 case GGML_UNARY_OP_SILU:
                 case GGML_UNARY_OP_RELU:
+                case GGML_UNARY_OP_SIGMOID:
                 case GGML_UNARY_OP_HARDSIGMOID:
                 case GGML_UNARY_OP_HARDSWISH:
                 case GGML_UNARY_OP_GELU_QUICK:
                 case GGML_UNARY_OP_TANH:
+                case GGML_UNARY_OP_EXP:
                     return ggml_is_contiguous(op->src[0]);
                 default:
                     return false;

             }
             return true;
         } break;
+        case GGML_OP_OUT_PROD:
+            return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
         case GGML_OP_GET_ROWS:
             {
                 switch (op->src[0]->type) {
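The OUT_PROD guard above admits only f32 operands and results with unit ne[2]/ne[3], i.e. a plain 2-D outer product with no batch dimensions. As a hedged reference (ggml's exact operand and layout convention for GGML_OP_OUT_PROD is defined in ggml.c; this is only the textbook operation the new SYCL path accelerates):

#include <cstddef>

// Rank-1 outer product of two vectors: out[i][j] = x[i] * y[j].
static void out_prod_ref(const float *x, const float *y, float *out,
                         size_t m, size_t n) {
    for (size_t i = 0; i < m; ++i) {
        for (size_t j = 0; j < n; ++j) {
            out[i * n + j] = x[i] * y[j];   // 2-D result, no batch dims
        }
    }
}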
         case GGML_OP_CONCAT:
             {
                 ggml_type src0_type = op->src[0]->type;
+                return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
             } break;
         case GGML_OP_DUP:
+        case GGML_OP_ARGMAX:
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
         case GGML_OP_REPEAT:

         case GGML_OP_TRANSPOSE:
         case GGML_OP_NORM:
         case GGML_OP_ADD:
+        case GGML_OP_ADD1:
+        case GGML_OP_LOG:
+        case GGML_OP_SUB:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
         case GGML_OP_RMS_NORM:
         case GGML_OP_SCALE:
         case GGML_OP_SQR:
+        case GGML_OP_SQRT:
+        case GGML_OP_SIN:
+        case GGML_OP_COS:
         case GGML_OP_CLAMP:
             return true;
         case GGML_OP_CONT:

             // TODO: add support for the new F32 operations
             return op->src[0]->type == GGML_TYPE_F16;
         case GGML_OP_POOL_2D:
+        case GGML_OP_SUM:
         case GGML_OP_SUM_ROWS:
         case GGML_OP_ARGSORT:
         case GGML_OP_ACC:

         case GGML_OP_PAD:
         case GGML_OP_LEAKY_RELU:
         case GGML_OP_TIMESTEP_EMBEDDING:
+        case GGML_OP_RWKV_WKV6:
             return true;
         default:
             return false;
     return buft_ctx->device == sycl_ctx->device;
 }

+static int64_t get_op_batch_size(const ggml_tensor * op) {
+    switch (op->op) {
+        case GGML_OP_GET_ROWS:
+            return op->ne[1]; // this will increse the speed of prefill in test
+        case GGML_OP_MUL_MAT:
+            return op->ne[1];
+        case GGML_OP_MUL_MAT_ID:
+        case GGML_OP_ROPE:
+            return op->ne[2];
+        default:
+            return ggml_nrows(op);
+    }
+}
+
 static bool ggml_backend_sycl_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
     const int min_batch_size = 32;
+    return get_op_batch_size(op) >= min_batch_size;
     GGML_UNUSED(dev);
 }
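The new offload rule is a simple threshold: an op is worth sending to the device once its per-op "batch" dimension (chosen by get_op_batch_size above) reaches min_batch_size. A self-contained illustration with invented shapes, in the spirit of the rule rather than quoting it:

#include <cstdio>

int main() {
    const int  min_batch_size = 32;
    const long prompt_ne1 = 512; // e.g. prompt processing: many rows at once
    const long decode_ne1 = 1;   // e.g. single-token decode: one row
    // bool promotes to int for printf's %d
    std::printf("prompt offload: %d\n", prompt_ne1 >= min_batch_size); // 1
    std::printf("decode offload: %d\n", decode_ne1 >= min_batch_size); // 0
    return 0;
}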
@@ -26,5 +26,8 @@
 #include "softmax.hpp"
 #include "tsembd.hpp"
 #include "im2col.hpp"
+#include "wkv6.hpp"
+#include "outprod.hpp"
+#include "element_wise.hpp"

 #endif // GGML_SYCL_BACKEND_HPP
@@ -62,3 +62,43 @@ int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block
     }
     return sycl_down_blk_size;
 }
+
+void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                          const ggml_tensor *src1, ggml_tensor *dst,
+                          const ggml_sycl_op_flatten_t op) try {
+    const int64_t nrows0 = ggml_nrows(src0);
+
+    const bool use_src1 = src1 != nullptr;
+    const int64_t nrows1 = use_src1 ? ggml_nrows(src1) : 1;
+
+    GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
+    GGML_ASSERT( dst->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
+
+    ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
+    ggml_tensor_extra_gpu * src1_extra = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr;
+    ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
+
+    // dd = data device
+    float * src0_ddf = (float *) src0->data;
+    float * src1_ddf = use_src1 ? (float *) src1->data : nullptr;
+    float * dst_ddf = (float *) dst->data;
+
+    ggml_sycl_pool_alloc<float> src0_f(ctx.pool());
+    ggml_sycl_pool_alloc<float> src1_f(ctx.pool());
+    ggml_sycl_pool_alloc<float> dst_f(ctx.pool());
+
+    ggml_sycl_set_device(ctx.device);
+    queue_ptr main_stream = ctx.stream();
+    // GGML_SYCL_DEBUG("ctx.device=%d, main_stream=%p src0_on_device=%d, src1_on_device=%d, dst_on_device=%d\n",
+    // ctx.device, main_stream, src0_on_device, src1_on_device, dst_on_device);
+
+    // do the computation
+    op(ctx, src0, src1, dst, src0_ddf, src1_ddf, dst_ddf, main_stream);
+    // print_ggml_tensor("tensor", dst);
+}
+catch (sycl::exception const &exc) {
+
+    std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+              << ", line:" << __LINE__ << std::endl;
+    std::exit(1);
+}
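ggml_sycl_op_flatten, now shared from common.cpp, is the single launcher behind all the simple op wrappers: it resolves device pointers and the queue, then invokes the per-op callback. A sketch of how a new op would hook in, assuming the ggml SYCL headers are in scope (`my_scale` and `my_scale_f32_sycl` are invented names for illustration; the callback signature is the real ggml_sycl_op_flatten_t):

// Step 1: a flatten-style callback that launches the device kernel.
static void ggml_sycl_op_my_scale(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
                                  const ggml_tensor *src1, ggml_tensor *dst,
                                  const float *src0_dd, const float *src1_dd,
                                  float *dst_dd, const queue_ptr &main_stream) {
    my_scale_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); // hypothetical launcher
    (void) src1; (void) dst; (void) src1_dd;
}

// Step 2: the thin wrapper the dispatch table points at.
static void ggml_sycl_my_scale(ggml_backend_sycl_context & ctx, const ggml_tensor * src0,
                               const ggml_tensor * src1, ggml_tensor * dst) {
    // flatten resolves device pointers and the stream, then calls the callback
    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_my_scale);
}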
@@ -404,4 +404,262 @@ static __dpct_inline__ Tp* get_pointer(sycl::local_accessor<Tp, dim> acc) {

 int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size);

+
typedef void (*ggml_sycl_op_flatten_t)(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
| 408 |
+
const ggml_tensor *src1,
|
| 409 |
+
ggml_tensor *dst, const float *src0_dd,
|
| 410 |
+
const float *src1_dd, float *dst_dd,
|
| 411 |
+
const queue_ptr &main_stream);
|
| 412 |
+
|
| 413 |
+
template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
|
| 414 |
+
static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
|
| 415 |
+
int ne0, int ne1, int ne2, int ne3,
|
| 416 |
+
int ne10, int ne11, int ne12, int ne13,
|
| 417 |
+
/*int s0, */ int s1, int s2, int s3,
|
| 418 |
+
/*int s10,*/ int s11, int s12, int s13,
|
| 419 |
+
const sycl::nd_item<3> &item_ct1) {
|
| 420 |
+
const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
| 421 |
+
item_ct1.get_local_id(2);
|
| 422 |
+
const int i1 = (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
|
| 423 |
+
item_ct1.get_local_id(1));
|
| 424 |
+
const int i2 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
|
| 425 |
+
item_ct1.get_local_id(0)) /
|
| 426 |
+
ne3;
|
| 427 |
+
const int i3 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
|
| 428 |
+
item_ct1.get_local_id(0)) %
|
| 429 |
+
ne3;
|
| 430 |
+
|
| 431 |
+
if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
|
| 432 |
+
return;
|
| 433 |
+
}
|
| 434 |
+
|
| 435 |
+
const int i11 = i1 % ne11;
|
| 436 |
+
const int i12 = i2 % ne12;
|
| 437 |
+
const int i13 = i3 % ne13;
|
| 438 |
+
|
| 439 |
+
const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
|
| 440 |
+
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
|
| 441 |
+
const size_t i_dst = i_src0;
|
| 442 |
+
|
| 443 |
+
const src0_t * src0_row = src0 + i_src0;
|
| 444 |
+
const src1_t * src1_row = src1 + i_src1;
|
| 445 |
+
dst_t * dst_row = dst + i_dst;
|
| 446 |
+
|
| 447 |
+
for (int i0 = i0s; i0 < ne0;
|
| 448 |
+
i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) {
|
| 449 |
+
const int i10 = i0 % ne10;
|
| 450 |
+
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
|
| 451 |
+
}
|
| 452 |
+
}
|
| 453 |
+
|
| 454 |
+
template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
|
| 455 |
+
static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
|
| 456 |
+
int ne0, int ne1, int ne2, int ne3,
|
| 457 |
+
int ne10, int ne11, int ne12, int ne13,
|
| 458 |
+
/*int s0, */ int s1, int s2, int s3,
|
| 459 |
+
/*int s10,*/ int s11, int s12, int s13,
|
| 460 |
+
const sycl::nd_item<3> &item_ct1) {
|
| 461 |
+
|
| 462 |
+
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
| 463 |
+
item_ct1.get_local_id(2);
|
| 464 |
+
|
| 465 |
+
const int i3 = i/(ne2*ne1*ne0);
|
| 466 |
+
const int i2 = (i/(ne1*ne0)) % ne2;
|
| 467 |
+
const int i1 = (i/ne0) % ne1;
|
| 468 |
+
const int i0 = i % ne0;
|
| 469 |
+
|
| 470 |
+
if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
|
| 471 |
+
return;
|
| 472 |
+
}
|
| 473 |
+
|
| 474 |
+
const int i11 = i1 % ne11;
|
| 475 |
+
const int i12 = i2 % ne12;
|
| 476 |
+
const int i13 = i3 % ne13;
|
| 477 |
+
|
| 478 |
+
const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
|
| 479 |
+
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
|
| 480 |
+
const size_t i_dst = i_src0;
|
| 481 |
+
|
| 482 |
+
const src0_t * src0_row = src0 + i_src0;
|
| 483 |
+
const src1_t * src1_row = src1 + i_src1;
|
| 484 |
+
dst_t * dst_row = dst + i_dst;
|
| 485 |
+
|
| 486 |
+
const int i10 = i0 % ne10;
|
| 487 |
+
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
|
| 488 |
+
}
|
| 489 |
+
|
| 490 |
+
|
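The i10/i11/i12/i13 modulo indexing in both kernels is the entire broadcast mechanism: along any dimension where src1 is smaller than dst, src1's index wraps around, so its values repeat across the output. In scalar form, with shapes invented for the example:

#include <cstdio>

int main() {
    const int ne1 = 4, ne11 = 2;        // dst has 4 rows, src1 only 2
    for (int i1 = 0; i1 < ne1; ++i1) {
        const int i11 = i1 % ne11;      // src1 rows repeat: 0,1,0,1
        std::printf("dst row %d reads src1 row %d\n", i1, i11);
    }
    return 0;
}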
+
+template<float (*bin_op)(const float, const float)>
+struct bin_bcast_sycl {
+    template <typename src0_t, typename src1_t, typename dst_t>
+    void operator()(ggml_backend_sycl_context & ctx,
+                    const struct ggml_tensor *src0,
+                    const struct ggml_tensor *src1, struct ggml_tensor *dst,
+                    const src0_t *src0_dd, const src1_t *src1_dd, dst_t *dst_dd,
+                    queue_ptr stream) {
+
+        GGML_TENSOR_BINARY_OP_LOCALS
+
+        int nr0 = ne10/ne0;
+        int nr1 = ne11/ne1;
+        int nr2 = ne12/ne2;
+        int nr3 = ne13/ne3;
+
+        int nr[4] = { nr0, nr1, nr2, nr3 };
+
+        // collapse dimensions until first broadcast dimension
+        int64_t cne0[] = {ne0, ne1, ne2, ne3};
+        int64_t cne1[] = {ne10, ne11, ne12, ne13};
+        size_t cnb0[] = {nb0, nb1, nb2, nb3};
+        size_t cnb1[] = {nb10, nb11, nb12, nb13};
+        auto collapse = [](int64_t cne[]) {
+            cne[0] *= cne[1];
+            cne[1] = cne[2];
+            cne[2] = cne[3];
+            cne[3] = 1;
+        };
+
+        auto collapse_nb = [](size_t cnb[], int64_t cne[]) {
+            cnb[1] *= cne[1];
+            cnb[2] *= cne[2];
+            cnb[3] *= cne[3];
+        };
+
+        for (int i = 0; i < 4; i++) {
+            if (nr[i] != 1) {
+                break;
+            }
+            if (i > 0) {
+                collapse_nb(cnb0, cne0);
+                collapse_nb(cnb1, cne1);
+                collapse(cne0);
+                collapse(cne1);
+            }
+        }
+        {
+            int64_t ne0 = cne0[0];
+            int64_t ne1 = cne0[1];
+            int64_t ne2 = cne0[2];
+            int64_t ne3 = cne0[3];
+
+            int64_t ne10 = cne1[0];
+            int64_t ne11 = cne1[1];
+            int64_t ne12 = cne1[2];
+            int64_t ne13 = cne1[3];
+
+            size_t nb0 = cnb0[0];
+            size_t nb1 = cnb0[1];
+            size_t nb2 = cnb0[2];
+            size_t nb3 = cnb0[3];
+
+            size_t nb10 = cnb1[0];
+            size_t nb11 = cnb1[1];
+            size_t nb12 = cnb1[2];
+            size_t nb13 = cnb1[3];
+
+            size_t s0 = nb0 / sizeof(dst_t);
+            size_t s1 = nb1 / sizeof(dst_t);
+            size_t s2 = nb2 / sizeof(dst_t);
+            size_t s3 = nb3 / sizeof(dst_t);
+
+            size_t s10 = nb10 / sizeof(src1_t);
+            size_t s11 = nb11 / sizeof(src1_t);
+            size_t s12 = nb12 / sizeof(src1_t);
+            size_t s13 = nb13 / sizeof(src1_t);
+
+            GGML_ASSERT(s0 == 1);
+            GGML_ASSERT(s10 == 1);
+
+            const int block_size = 128;
+
+            int64_t hne0 = std::max(ne0/2LL, 1LL);
+
+            sycl::range<3> block_dims(1, 1, 1);
+            block_dims[2] = std::min<unsigned int>(hne0, block_size);
+            block_dims[1] = std::min<unsigned int>(
+                ne1, block_size / (unsigned int)block_dims[2]);
+            block_dims[0] = std::min(
+                std::min<unsigned int>(
+                    ne2 * ne3, block_size / (unsigned int)block_dims[2] /
+                                   (unsigned int)block_dims[1]),
+                64U);
+
+            sycl::range<3> block_nums(
+                (ne2 * ne3 + block_dims[0] - 1) / block_dims[0],
+                (ne1 + block_dims[1] - 1) / block_dims[1],
+                (hne0 + block_dims[2] - 1) / block_dims[2]);
+
+            if (block_nums[0] > 65535) {
+                // this is the maximum number of blocks in z direction, fallback to 1D grid kernel
+                int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
+                {
+                    dpct::has_capability_or_fail(stream->get_device(),
+                                                 {sycl::aspect::fp16});
+
+                    stream->parallel_for(
+                        sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) *
+                                              sycl::range<3>(1, 1, block_size),
+                                          sycl::range<3>(1, 1, block_size)),
+                        [=](sycl::nd_item<3> item_ct1) {
+                            k_bin_bcast_unravel<bin_op>(
+                                src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3,
+                                ne10, ne11, ne12, ne13, s1, s2, s3, s11, s12,
+                                s13, item_ct1);
+                        });
+                }
+            } else {
+                /*
+                DPCT1049:16: The work-group size passed to the SYCL kernel may
+                exceed the limit. To get the device limit, query
+                info::device::max_work_group_size. Adjust the work-group size if
+                needed.
+                */
+                dpct::has_capability_or_fail(stream->get_device(),
+                                             {sycl::aspect::fp16});
+
+                stream->parallel_for(
+                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
+                        k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
|
| 623 |
+
ne2, ne3, ne10, ne11, ne12, ne13,
|
| 624 |
+
s1, s2, s3, s11, s12, s13,
|
| 625 |
+
item_ct1);
|
| 626 |
+
});
|
| 627 |
+
}
|
| 628 |
+
}
|
| 629 |
+
}
|
| 630 |
+
};
|
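The collapse step deserves a gloss: leading dimensions in which src1 is not broadcast get folded together, so the kernel loops over fewer, larger dimensions and the per-element index math is cheaper. A host-only sketch of the folding with hypothetical shapes (the stride-side collapse_nb is omitted for brevity):

#include <cstdint>
#include <cstdio>

// Mirror of the collapse helper in bin_bcast_sycl: fold the next dimension
// into dimension 0 and shift the rest down.
static void collapse(int64_t cne[4]) {
    cne[0] *= cne[1];
    cne[1]  = cne[2];
    cne[2]  = cne[3];
    cne[3]  = 1;
}

int main() {
    int64_t cne0[4] = {8, 4, 2, 3};   // dst shape
    int64_t cne1[4] = {8, 4, 1, 3};   // src1 shape: broadcast in dim 2
    int64_t nr[4]   = {1, 1, 0, 1};   // ne1[i]/ne[i]; != 1 marks a broadcast dim

    for (int i = 0; i < 4; i++) {
        if (nr[i] != 1) {
            break;                    // stop at the first broadcast dimension
        }
        if (i > 0) {
            collapse(cne0);
            collapse(cne1);
        }
    }
    // dims 0 and 1 have merged: dst {32, 2, 3, 1}, src1 {32, 1, 3, 1}
    printf("dst:  %lld x %lld x %lld x %lld\n",
           (long long) cne0[0], (long long) cne0[1],
           (long long) cne0[2], (long long) cne0[3]);
    printf("src1: %lld x %lld x %lld x %lld\n",
           (long long) cne1[0], (long long) cne1[1],
           (long long) cne1[2], (long long) cne1[3]);
    return 0;
}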
| 631 |
+
|
| 632 |
+
template <class op>
|
| 633 |
+
inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
| 634 |
+
const ggml_tensor *src1, ggml_tensor *dst,
|
| 635 |
+
const float *src0_dd, const float *src1_dd,
|
| 636 |
+
float *dst_dd,
|
| 637 |
+
const queue_ptr &main_stream) {
|
| 638 |
+
|
| 639 |
+
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
| 640 |
+
op()(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
|
| 641 |
+
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
|
| 642 |
+
op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd,
|
| 643 |
+
(sycl::half *)dst_dd, main_stream);
|
| 644 |
+
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
|
| 645 |
+
op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd, dst_dd,
|
| 646 |
+
main_stream);
|
| 647 |
+
} else if (src0->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {
|
| 648 |
+
op()(ctx, src0, src1, dst, (const int32_t *)src0_dd, (const int32_t *)src1_dd, (int32_t *)dst_dd,
|
| 649 |
+
main_stream);
|
| 650 |
+
} else if (src0->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) {
|
| 651 |
+
op()(ctx, src0, src1, dst, (const int16_t *)src0_dd, (const int16_t *)src1_dd, (int16_t *)dst_dd,
|
| 652 |
+
main_stream);
|
| 653 |
+
} else {
|
| 654 |
+
fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
|
| 655 |
+
ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
|
| 656 |
+
GGML_ABORT("fatal error");
|
| 657 |
+
}
|
| 658 |
+
}
|
| 659 |
+
|
| 660 |
+
|
| 661 |
+
void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
| 662 |
+
const ggml_tensor *src1, ggml_tensor *dst,
|
| 663 |
+
const ggml_sycl_op_flatten_t op);
|
| 664 |
+
|
| 665 |
#endif // GGML_SYCL_COMMON_HPP
|
|
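One launch-configuration detail worth spelling out before moving on: every grid size in bin_bcast_sycl is a ceiling division, the x-range only needs to cover roughly half of ne0 because k_bin_bcast strides over i0, and the 3D grid falls back to the flat k_bin_bcast_unravel kernel once the collapsed ne2*ne3 exceeds the 65535 work-groups addressable in that range dimension. A host-side sketch with hypothetical sizes:

#include <cstdio>

// The (n + b - 1) / b idiom used for every grid size in this file.
static int ceil_div(int n, int b) { return (n + b - 1) / b; }

int main() {
    const int block_size = 128;              // as in bin_bcast_sycl
    const int ne0 = 1000, ne1 = 7;           // hypothetical collapsed dims
    const long long ne2xne3 = 70000;         // big enough to trip the fallback

    const int hne0 = ne0 / 2 > 1 ? ne0 / 2 : 1;  // each work-item covers ~2 elems
    const int bdx  = hne0 < block_size ? hne0 : block_size;
    printf("x-groups: %d\n", ceil_div(hne0, bdx));   // ceil(500 / 128) = 4

    if (ne2xne3 > 65535) {
        // 1D fallback: one work-item per destination element.
        const long long n = (long long) ne0 * ne1 * ne2xne3;
        printf("fallback groups: %lld\n", (n + block_size - 1) / block_size);
    }
    return 0;
}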
@@ -106,6 +106,7 @@ static void concat_f32_sycl(const float *x, const float *y, float *dst,
|
|
| 106 |
concat_f32_dim1(x, y, dst, ne0, ne01, item_ct1);
|
| 107 |
});
|
| 108 |
break;
|
| 109 |
+
// dim >=2 will be dispatched to the default path
|
| 110 |
default:
|
| 111 |
stream->parallel_for(
|
| 112 |
sycl::nd_range<3>(gridDim *
|
|
@@ -0,0 +1,1011 @@
|
| 1 |
+
#include "common.hpp"
|
| 2 |
+
#include "element_wise.hpp"
|
| 3 |
+
|
| 4 |
+
void acc_f32(const float * x, const float * y, float * dst, const int ne,
|
| 5 |
+
const int ne10, const int ne11, const int ne12,
|
| 6 |
+
const int nb1, const int nb2, int offset, const sycl::nd_item<3> &item_ct1) {
|
| 7 |
+
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
| 8 |
+
item_ct1.get_local_id(2);
|
| 9 |
+
if (i >= ne) {
|
| 10 |
+
return;
|
| 11 |
+
}
|
| 12 |
+
int src1_idx = i - offset;
|
| 13 |
+
int oz = src1_idx / nb2;
|
| 14 |
+
int oy = (src1_idx - (oz * nb2)) / nb1;
|
| 15 |
+
int ox = src1_idx % nb1;
|
| 16 |
+
if (src1_idx >= 0 && ox < ne10 && oy < ne11 && oz < ne12) {
|
| 17 |
+
dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11];
|
| 18 |
+
} else {
|
| 19 |
+
dst[i] = x[i];
|
| 20 |
+
}
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
void gelu_f32(const float * x, float * dst, const int k,
|
| 24 |
+
const sycl::nd_item<3> &item_ct1) {
|
| 25 |
+
const float GELU_COEF_A = 0.044715f;
|
| 26 |
+
const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
| 27 |
+
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
| 28 |
+
item_ct1.get_local_id(2);
|
| 29 |
+
|
| 30 |
+
if (i >= k) {
|
| 31 |
+
return;
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
float xi = x[i];
|
| 35 |
+
dst[i] = 0.5f * xi *
|
| 36 |
+
(1.0f +
|
| 37 |
+
sycl::tanh(SQRT_2_OVER_PI * xi * (1.0f + GELU_COEF_A * xi * xi)));
|
| 38 |
+
}
|
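The two constants in gelu_f32 come from the standard tanh approximation of GELU; the kernel merely factors the cubic term as x + 0.044715 x^3 = x (1 + 0.044715 x^2):

\mathrm{GELU}(x) \approx \tfrac{1}{2}\, x \left( 1 + \tanh\!\left( \sqrt{2/\pi}\, \left( x + 0.044715\, x^{3} \right) \right) \right), \qquad \sqrt{2/\pi} \approx 0.7978845608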
| 39 |
+
|
| 40 |
+
void silu_f32(const float * x, float * dst, const int k,
|
| 41 |
+
const sycl::nd_item<3> &item_ct1) {
|
| 42 |
+
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
| 43 |
+
item_ct1.get_local_id(2);
|
| 44 |
+
|
| 45 |
+
if (i >= k) {
|
| 46 |
+
return;
|
| 47 |
+
}
|
| 48 |
+
dst[i] = x[i] / (1.0f + sycl::native::exp(-x[i]));
|
| 49 |
+
}
|
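silu_f32 is the sigmoid-weighted linear unit; writing out the sigmoid makes the kernel's one-liner immediate:

\mathrm{SiLU}(x) = x \, \sigma(x) = \frac{x}{1 + e^{-x}}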
| 50 |
+
|
| 51 |
+
void gelu_quick_f32(const float *x, float *dst, int k,
|
| 52 |
+
const sycl::nd_item<3> &item_ct1) {
|
| 53 |
+
const float GELU_QUICK_COEF = -1.702f;
|
| 54 |
+
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
| 55 |
+
item_ct1.get_local_id(2);
|
| 56 |
+
if (i >= k) {
|
| 57 |
+
return;
|
| 58 |
+
}
|
| 59 |
+
dst[i] = x[i] * (1.0f / (1.0f + sycl::native::exp(GELU_QUICK_COEF * x[i])));
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
void tanh_f32(const float *x, float *dst, int k,
|
| 63 |
+
const sycl::nd_item<3> &item_ct1) {
|
| 64 |
+
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
| 65 |
+
item_ct1.get_local_id(2);
|
| 66 |
+
if (i >= k) {
|
| 67 |
+
return;
|
| 68 |
+
}
|
| 69 |
+
dst[i] = sycl::tanh((float)(x[i]));
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
void relu_f32(const float * x, float * dst, const int k,
|
| 73 |
+
const sycl::nd_item<3> &item_ct1) {
|
| 74 |
+
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
| 75 |
+
item_ct1.get_local_id(2);
|
| 76 |
+
|
| 77 |
+
if (i >= k) {
|
| 78 |
+
return;
|
| 79 |
+
}
|
| 80 |
+
dst[i] = sycl::fmax((float)(x[i]), (float)0);
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
void sigmoid_f32(const float * x, float * dst, const int k,
|
| 84 |
+
const sycl::nd_item<3> &item_ct1) {
|
| 85 |
+
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
| 86 |
+
item_ct1.get_local_id(2);
|
| 87 |
+
|
| 88 |
+
if (i >= k) {
|
| 89 |
+
return;
|
| 90 |
+
}
|
| 91 |
+
dst[i] = 1.0f / (1.0f + sycl::native::exp(-x[i]));
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
void sqrt_f32(const float * x, float * dst, const int k,
|
| 95 |
+
const sycl::nd_item<3> &item_ct1) {
|
| 96 |
+
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
| 97 |
+
item_ct1.get_local_id(2);
|
| 98 |
+
|
| 99 |
+
if (i >= k) {
|
| 100 |
+
return;
|
| 101 |
+
}
|
| 102 |
+
dst[i] = sycl::sqrt(x[i]);
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
void sin_f32(const float * x, float * dst, const int k,
|
| 106 |
+
const sycl::nd_item<3> &item_ct1) {
|
| 107 |
+
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
| 108 |
+
item_ct1.get_local_id(2);
|
| 109 |
+
|
| 110 |
+
if (i >= k) {
|
| 111 |
+
return;
|
| 112 |
+
}
|
| 113 |
+
dst[i] = sycl::sin(x[i]);
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
void cos_f32(const float * x, float * dst, const int k,
|
| 117 |
+
const sycl::nd_item<3> &item_ct1) {
|
| 118 |
+
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
| 119 |
+
item_ct1.get_local_id(2);
|
| 120 |
+
|
| 121 |
+
if (i >= k) {
|
| 122 |
+
return;
|
| 123 |
+
}
|
| 124 |
+
dst[i] = sycl::cos(x[i]);
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
void hardsigmoid_f32(const float * x, float * dst, const int k,
|
| 128 |
+
const sycl::nd_item<3> &item_ct1) {
|
| 129 |
+
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
| 130 |
+
item_ct1.get_local_id(2);
|
| 131 |
+
|
| 132 |
+
if (i >= k) {
|
| 133 |
+
return;
|
| 134 |
+
}
|
| 135 |
+
dst[i] = sycl::fmin(1.0f, sycl::fmax(0.0f, (x[i] + 3.0f) / 6.0f));
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
void hardswish_f32(const float * x, float * dst, const int k,
|
| 139 |
+
const sycl::nd_item<3> &item_ct1) {
|
| 140 |
+
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
| 141 |
+
item_ct1.get_local_id(2);
|
| 142 |
+
|
| 143 |
+
if (i >= k) {
|
| 144 |
+
return;
|
| 145 |
+
}
|
| 146 |
+
dst[i] = x[i] * sycl::fmin(1.0f, sycl::fmax(0.0f, (x[i] + 3.0f) / 6.0f));
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
void exp_f32(const float * x, float * dst, const int k,
|
| 150 |
+
const sycl::nd_item<3> &item_ct1) {
|
| 151 |
+
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
| 152 |
+
item_ct1.get_local_id(2);
|
| 153 |
+
|
| 154 |
+
if (i >= k) {
|
| 155 |
+
return;
|
| 156 |
+
}
|
| 157 |
+
dst[i] = sycl::exp(x[i]);
|
| 158 |
+
}
|
| 159 |
+
|
| 160 |
+
void log_f32(const float * x, float * dst, const int k,
|
| 161 |
+
const sycl::nd_item<3> &item_ct1) {
|
| 162 |
+
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
| 163 |
+
item_ct1.get_local_id(2);
|
| 164 |
+
|
| 165 |
+
if (i >= k) {
|
| 166 |
+
return;
|
| 167 |
+
}
|
| 168 |
+
float xi = x[i];
|
| 169 |
+
if (xi <= 0) {
|
| 170 |
+
dst[i] = -INFINITY;
|
| 171 |
+
} else {
|
| 172 |
+
dst[i] = sycl::log(xi);
|
| 173 |
+
}
|
| 174 |
+
}
|
| 175 |
+
|
| 176 |
+
void neg_f32(const float * x, float * dst, const int k,
|
| 177 |
+
const sycl::nd_item<3> &item_ct1) {
|
| 178 |
+
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
| 179 |
+
item_ct1.get_local_id(2);
|
| 180 |
+
|
| 181 |
+
if (i >= k) {
|
| 182 |
+
return;
|
| 183 |
+
}
|
| 184 |
+
dst[i] = -x[i];
|
| 185 |
+
}
|
| 186 |
+
|
| 187 |
+
void step_f32(const float * x, float * dst, const int k,
|
| 188 |
+
const sycl::nd_item<3> &item_ct1) {
|
| 189 |
+
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
| 190 |
+
item_ct1.get_local_id(2);
|
| 191 |
+
|
| 192 |
+
if (i >= k) {
|
| 193 |
+
return;
|
| 194 |
+
}
|
| 195 |
+
dst[i] = x[i] > 0.0f;
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
void leaky_relu_f32(const float *x, float *dst, const int k, const float negative_slope,
|
| 199 |
+
const sycl::nd_item<3> &item_ct1) {
|
| 200 |
+
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
| 201 |
+
item_ct1.get_local_id(2);
|
| 202 |
+
if (i >= k) {
|
| 203 |
+
return;
|
| 204 |
+
}
|
| 205 |
+
dst[i] = sycl::fmax((float)(x[i]), (float)0) +
|
| 206 |
+
sycl::fmin((float)(x[i]), 0.0f) * negative_slope;
|
| 207 |
+
}
|
| 208 |
+
|
| 209 |
+
void sqr_f32(const float * x, float * dst, const int k,
|
| 210 |
+
const sycl::nd_item<3> &item_ct1) {
|
| 211 |
+
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
| 212 |
+
item_ct1.get_local_id(2);
|
| 213 |
+
|
| 214 |
+
if (i >= k) {
|
| 215 |
+
return;
|
| 216 |
+
}
|
| 217 |
+
dst[i] = x[i] * x[i];
|
| 218 |
+
}
|
| 219 |
+
|
| 220 |
+
void upscale_f32(const float *x, float *dst, const int nb00, const int nb01,
|
| 221 |
+
const int nb02, const int nb03, const int ne10, const int ne11,
|
| 222 |
+
const int ne12, const int ne13, const float sf0, const float sf1,
|
| 223 |
+
const float sf2, const float sf3, const sycl::nd_item<1> &item_ct1) {
|
| 224 |
+
int index = item_ct1.get_local_id(0) +
|
| 225 |
+
item_ct1.get_group(0) * item_ct1.get_local_range(0);
|
| 226 |
+
if (index >= ne10 * ne11 * ne12 * ne13) {
|
| 227 |
+
return;
|
| 228 |
+
}
|
| 229 |
+
// operation
|
| 230 |
+
int i10 = index % ne10;
|
| 231 |
+
int i11 = (index / ne10) % ne11;
|
| 232 |
+
int i12 = (index / (ne10 * ne11)) % ne12;
|
| 233 |
+
int i13 = (index / (ne10 * ne11 * ne12)) % ne13;
|
| 234 |
+
|
| 235 |
+
int i00 = i10 / sf0;
|
| 236 |
+
int i01 = i11 / sf1;
|
| 237 |
+
int i02 = i12 / sf2;
|
| 238 |
+
int i03 = i13 / sf3;
|
| 239 |
+
|
| 240 |
+
dst[index] = *(float *)((char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
|
| 241 |
+
}
|
| 242 |
+
|
| 243 |
+
void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02,
|
| 244 |
+
const sycl::nd_item<3> &item_ct1) {
|
| 245 |
+
int nidx = item_ct1.get_local_id(2) +
|
| 246 |
+
item_ct1.get_group(2) * item_ct1.get_local_range(2);
|
| 247 |
+
if (nidx >= ne0) {
|
| 248 |
+
return;
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
// operation
|
| 252 |
+
int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
|
| 253 |
+
item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
|
| 254 |
+
if (nidx < ne00 && item_ct1.get_group(1) < ne01 &&
|
| 255 |
+
item_ct1.get_group(0) < ne02) {
|
| 256 |
+
int offset_src = nidx + item_ct1.get_group(1) * ne00 +
|
| 257 |
+
item_ct1.get_group(0) * ne00 * ne01;
|
| 258 |
+
dst[offset_dst] = x[offset_src];
|
| 259 |
+
} else {
|
| 260 |
+
dst[offset_dst] = 0.0f;
|
| 261 |
+
}
|
| 262 |
+
}
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
void acc_f32_sycl(const float *x, const float *y, float *dst,
|
| 267 |
+
const int n_elements, const int ne10, const int ne11,
|
| 268 |
+
const int ne12, const int nb1, const int nb2,
|
| 269 |
+
const int offset, queue_ptr stream) {
|
| 270 |
+
int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE;
|
| 271 |
+
stream->parallel_for(
|
| 272 |
+
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
| 273 |
+
sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE),
|
| 274 |
+
sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)),
|
| 275 |
+
[=](sycl::nd_item<3> item_ct1) {
|
| 276 |
+
acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset,
|
| 277 |
+
item_ct1);
|
| 278 |
+
});
|
| 279 |
+
}
|
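Every launcher that follows repeats the same shape as acc_f32_sycl above: a ceiling-division block count, then a 1D launch expressed as a 3D nd_range whose global size is num_blocks * BLOCK and whose work-group size is BLOCK. A hypothetical generic helper (not part of this patch; queue_ptr is spelled sycl::queue * here) makes the shared pattern explicit:

#include <sycl/sycl.hpp>

// Hypothetical helper showing the launch shape shared by gelu_f32_sycl,
// silu_f32_sycl, etc.: k items total, BLOCK work-items per group.
template <int BLOCK, typename F>
static void launch_elementwise(sycl::queue * stream, int k, F f) {
    const int num_blocks = (k + BLOCK - 1) / BLOCK;        // ceiling division
    stream->parallel_for(
        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
                              sycl::range<3>(1, 1, BLOCK), // global size
                          sycl::range<3>(1, 1, BLOCK)),    // work-group size
        [=](sycl::nd_item<3> item_ct1) { f(item_ct1); });
}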
| 280 |
+
|
| 281 |
+
void gelu_f32_sycl(const float *x, float *dst, const int k,
|
| 282 |
+
queue_ptr stream) {
|
| 283 |
+
const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
|
| 284 |
+
stream->parallel_for(
|
| 285 |
+
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
| 286 |
+
sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
|
| 287 |
+
sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
|
| 288 |
+
[=](sycl::nd_item<3> item_ct1) {
|
| 289 |
+
gelu_f32(x, dst, k, item_ct1);
|
| 290 |
+
});
|
| 291 |
+
}
|
| 292 |
+
|
| 293 |
+
void silu_f32_sycl(const float *x, float *dst, const int k,
|
| 294 |
+
queue_ptr stream) {
|
| 295 |
+
const int num_blocks = (k + SYCL_SILU_BLOCK_SIZE - 1) / SYCL_SILU_BLOCK_SIZE;
|
| 296 |
+
stream->parallel_for(
|
| 297 |
+
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
| 298 |
+
sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE),
|
| 299 |
+
sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE)),
|
| 300 |
+
[=](sycl::nd_item<3> item_ct1) {
|
| 301 |
+
silu_f32(x, dst, k, item_ct1);
|
| 302 |
+
});
|
| 303 |
+
}
|
| 304 |
+
|
| 305 |
+
void gelu_quick_f32_sycl(const float *x, float *dst, const int k,
|
| 306 |
+
queue_ptr stream) {
|
| 307 |
+
const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
|
| 308 |
+
stream->parallel_for(
|
| 309 |
+
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
| 310 |
+
sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
|
| 311 |
+
sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
|
| 312 |
+
[=](sycl::nd_item<3> item_ct1) {
|
| 313 |
+
gelu_quick_f32(x, dst, k, item_ct1);
|
| 314 |
+
});
|
| 315 |
+
}
|
| 316 |
+
|
| 317 |
+
void tanh_f32_sycl(const float *x, float *dst, const int k,
|
| 318 |
+
queue_ptr stream) {
|
| 319 |
+
const int num_blocks = (k + SYCL_TANH_BLOCK_SIZE - 1) / SYCL_TANH_BLOCK_SIZE;
|
| 320 |
+
stream->parallel_for(
|
| 321 |
+
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
| 322 |
+
sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE),
|
| 323 |
+
sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE)),
|
| 324 |
+
[=](sycl::nd_item<3> item_ct1) {
|
| 325 |
+
tanh_f32(x, dst, k, item_ct1);
|
| 326 |
+
});
|
| 327 |
+
}
|
| 328 |
+
|
| 329 |
+
void relu_f32_sycl(const float *x, float *dst, const int k,
|
| 330 |
+
queue_ptr stream) {
|
| 331 |
+
const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
|
| 332 |
+
stream->parallel_for(
|
| 333 |
+
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
| 334 |
+
sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE),
|
| 335 |
+
sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)),
|
| 336 |
+
[=](sycl::nd_item<3> item_ct1) {
|
| 337 |
+
relu_f32(x, dst, k, item_ct1);
|
| 338 |
+
});
|
| 339 |
+
}
|
| 340 |
+
|
| 341 |
+
void hardsigmoid_f32_sycl(const float *x, float *dst, const int k,
|
| 342 |
+
queue_ptr stream) {
|
| 343 |
+
const int num_blocks = (k + SYCL_HARDSIGMOID_BLOCK_SIZE - 1) / SYCL_HARDSIGMOID_BLOCK_SIZE;
|
| 344 |
+
stream->parallel_for(
|
| 345 |
+
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
| 346 |
+
sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE),
|
| 347 |
+
sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE)),
|
| 348 |
+
[=](sycl::nd_item<3> item_ct1) {
|
| 349 |
+
hardsigmoid_f32(x, dst, k, item_ct1);
|
| 350 |
+
});
|
| 351 |
+
}
|
| 352 |
+
|
| 353 |
+
void hardswish_f32_sycl(const float *x, float *dst, const int k,
|
| 354 |
+
queue_ptr stream) {
|
| 355 |
+
const int num_blocks = (k + SYCL_HARDSWISH_BLOCK_SIZE - 1) / SYCL_HARDSWISH_BLOCK_SIZE;
|
| 356 |
+
stream->parallel_for(
|
| 357 |
+
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
| 358 |
+
sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE),
|
| 359 |
+
sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE)),
|
| 360 |
+
[=](sycl::nd_item<3> item_ct1) {
|
| 361 |
+
hardswish_f32(x, dst, k, item_ct1);
|
| 362 |
+
});
|
| 363 |
+
}
|
| 364 |
+
|
| 365 |
+
void exp_f32_sycl(const float *x, float *dst, const int k,
|
| 366 |
+
queue_ptr stream) {
|
| 367 |
+
const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE;
|
| 368 |
+
stream->parallel_for(
|
| 369 |
+
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
| 370 |
+
sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE),
|
| 371 |
+
sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)),
|
| 372 |
+
[=](sycl::nd_item<3> item_ct1) {
|
| 373 |
+
exp_f32(x, dst, k, item_ct1);
|
| 374 |
+
});
|
| 375 |
+
}
|
| 376 |
+
|
| 377 |
+
void log_f32_sycl(const float *x, float *dst, const int k,
|
| 378 |
+
queue_ptr stream) {
|
| 379 |
+
const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE;
|
| 380 |
+
stream->parallel_for(
|
| 381 |
+
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
| 382 |
+
sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE),
|
| 383 |
+
sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)),
|
| 384 |
+
[=](sycl::nd_item<3> item_ct1) {
|
| 385 |
+
log_f32(x, dst, k, item_ct1);
|
| 386 |
+
});
|
| 387 |
+
}
|
| 388 |
+
|
| 389 |
+
void neg_f32_sycl(const float *x, float *dst, const int k,
|
| 390 |
+
queue_ptr stream) {
|
| 391 |
+
const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE;
|
| 392 |
+
stream->parallel_for(
|
| 393 |
+
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
| 394 |
+
sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE),
|
| 395 |
+
sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE)),
|
| 396 |
+
[=](sycl::nd_item<3> item_ct1) {
|
| 397 |
+
neg_f32(x, dst, k, item_ct1);
|
| 398 |
+
});
|
| 399 |
+
}
|
| 400 |
+
|
| 401 |
+
void step_f32_sycl(const float *x, float *dst, const int k,
|
| 402 |
+
queue_ptr stream) {
|
| 403 |
+
const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE;
|
| 404 |
+
stream->parallel_for(
|
| 405 |
+
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
| 406 |
+
sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE),
|
| 407 |
+
sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE)),
|
| 408 |
+
[=](sycl::nd_item<3> item_ct1) {
|
| 409 |
+
step_f32(x, dst, k, item_ct1);
|
| 410 |
+
});
|
| 411 |
+
}
|
| 412 |
+
|
| 413 |
+
void sigmoid_f32_sycl(const float *x, float *dst, const int k,
|
| 414 |
+
queue_ptr stream) {
|
| 415 |
+
const int num_blocks = (k + SYCL_SIGMOID_BLOCK_SIZE - 1) / SYCL_SIGMOID_BLOCK_SIZE;
|
| 416 |
+
stream->parallel_for(
|
| 417 |
+
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
| 418 |
+
sycl::range<3>(1, 1, SYCL_SIGMOID_BLOCK_SIZE),
|
| 419 |
+
sycl::range<3>(1, 1, SYCL_SIGMOID_BLOCK_SIZE)),
|
| 420 |
+
[=](sycl::nd_item<3> item_ct1) {
|
| 421 |
+
sigmoid_f32(x, dst, k, item_ct1);
|
| 422 |
+
});
|
| 423 |
+
}
|
| 424 |
+
|
| 425 |
+
void sqrt_f32_sycl(const float *x, float *dst, const int k,
|
| 426 |
+
queue_ptr stream) {
|
| 427 |
+
const int num_blocks = (k + SYCL_SQRT_BLOCK_SIZE - 1) / SYCL_SQRT_BLOCK_SIZE;
|
| 428 |
+
stream->parallel_for(
|
| 429 |
+
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
| 430 |
+
sycl::range<3>(1, 1, SYCL_SQRT_BLOCK_SIZE),
|
| 431 |
+
sycl::range<3>(1, 1, SYCL_SQRT_BLOCK_SIZE)),
|
| 432 |
+
[=](sycl::nd_item<3> item_ct1) {
|
| 433 |
+
sqrt_f32(x, dst, k, item_ct1);
|
| 434 |
+
});
|
| 435 |
+
}
|
| 436 |
+
|
| 437 |
+
void sin_f32_sycl(const float *x, float *dst, const int k,
|
| 438 |
+
queue_ptr stream) {
|
| 439 |
+
const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE;
|
| 440 |
+
stream->parallel_for(
|
| 441 |
+
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
| 442 |
+
sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE),
|
| 443 |
+
sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE)),
|
| 444 |
+
[=](sycl::nd_item<3> item_ct1) {
|
| 445 |
+
sin_f32(x, dst, k, item_ct1);
|
| 446 |
+
});
|
| 447 |
+
}
|
| 448 |
+
|
| 449 |
+
void cos_f32_sycl(const float *x, float *dst, const int k,
|
| 450 |
+
queue_ptr stream) {
|
| 451 |
+
const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE;
|
| 452 |
+
stream->parallel_for(
|
| 453 |
+
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
| 454 |
+
sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE),
|
| 455 |
+
sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE)),
|
| 456 |
+
[=](sycl::nd_item<3> item_ct1) {
|
| 457 |
+
cos_f32(x, dst, k, item_ct1);
|
| 458 |
+
});
|
| 459 |
+
}
|
| 460 |
+
|
| 461 |
+
void leaky_relu_f32_sycl(const float *x, float *dst, const int k,
|
| 462 |
+
const float negative_slope,
|
| 463 |
+
queue_ptr stream) {
|
| 464 |
+
const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
|
| 465 |
+
stream->parallel_for(
|
| 466 |
+
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
| 467 |
+
sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE),
|
| 468 |
+
sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)),
|
| 469 |
+
[=](sycl::nd_item<3> item_ct1) {
|
| 470 |
+
leaky_relu_f32(x, dst, k, negative_slope, item_ct1);
|
| 471 |
+
});
|
| 472 |
+
}
|
| 473 |
+
|
| 474 |
+
void sqr_f32_sycl(const float *x, float *dst, const int k,
|
| 475 |
+
queue_ptr stream) {
|
| 476 |
+
const int num_blocks = (k + SYCL_SQR_BLOCK_SIZE - 1) / SYCL_SQR_BLOCK_SIZE;
|
| 477 |
+
stream->parallel_for(
|
| 478 |
+
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
| 479 |
+
sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE),
|
| 480 |
+
sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE)),
|
| 481 |
+
[=](sycl::nd_item<3> item_ct1) {
|
| 482 |
+
sqr_f32(x, dst, k, item_ct1);
|
| 483 |
+
});
|
| 484 |
+
}
|
| 485 |
+
|
| 486 |
+
void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01,
|
| 487 |
+
const int nb02, const int nb03, const int ne10, const int ne11,
|
| 488 |
+
const int ne12, const int ne13, const float sf0, const float sf1,
|
| 489 |
+
const float sf2, const float sf3, queue_ptr stream) {
|
| 490 |
+
int dst_size = ne10 * ne11 * ne12 * ne13;
|
| 491 |
+
int num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE;
|
| 492 |
+
sycl::range<1> gridDim(num_blocks * SYCL_UPSCALE_BLOCK_SIZE);
|
| 493 |
+
stream->parallel_for(
|
| 494 |
+
sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)),
|
| 495 |
+
[=](sycl::nd_item<1> item_ct1) {
|
| 496 |
+
upscale_f32(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, item_ct1);
|
| 497 |
+
});
|
| 498 |
+
}
|
| 499 |
+
|
| 500 |
+
void pad_f32_sycl(const float *x, float *dst, const int ne00,
|
| 501 |
+
const int ne01, const int ne02, const int ne0,
|
| 502 |
+
const int ne1, const int ne2, queue_ptr stream) {
|
| 503 |
+
int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE;
|
| 504 |
+
sycl::range<3> gridDim(ne2, ne1, num_blocks);
|
| 505 |
+
stream->parallel_for(
|
| 506 |
+
sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),
|
| 507 |
+
sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)),
|
| 508 |
+
[=](sycl::nd_item<3> item_ct1) {
|
| 509 |
+
pad_f32(x, dst, ne0, ne00, ne01, ne02, item_ct1);
|
| 510 |
+
});
|
| 511 |
+
}
|
| 512 |
+
|
| 513 |
+
inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
|
| 514 |
+
ggml_tensor *dst, const float *src0_dd,
|
| 515 |
+
const float *src1_dd, float *dst_dd,
|
| 516 |
+
const queue_ptr &main_stream) {
|
| 517 |
+
|
| 518 |
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
| 519 |
+
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
| 520 |
+
|
| 521 |
+
silu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
|
| 522 |
+
|
| 523 |
+
(void) src1;
|
| 524 |
+
(void) dst;
|
| 525 |
+
(void) src1_dd;
|
| 526 |
+
}
|
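The trailing (void) casts recur in every op wrapper below because each one has to match ggml_sycl_op_flatten_t, the callback type used by ggml_sycl_op_flatten (declared at the end of common.hpp above), even when the op only touches src0 and dst. Reconstructed from the parameter lists these functions share, the typedef is assumed to read:

typedef void (*ggml_sycl_op_flatten_t)(ggml_backend_sycl_context & ctx,
                                       const ggml_tensor * src0,
                                       const ggml_tensor * src1,
                                       ggml_tensor * dst,
                                       const float * src0_dd,
                                       const float * src1_dd,
                                       float * dst_dd,
                                       const queue_ptr & main_stream);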
| 527 |
+
|
| 528 |
+
inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
|
| 529 |
+
ggml_tensor *dst, const float *src0_dd,
|
| 530 |
+
const float *src1_dd, float *dst_dd,
|
| 531 |
+
const queue_ptr &main_stream) {
|
| 532 |
+
|
| 533 |
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
| 534 |
+
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
| 535 |
+
|
| 536 |
+
gelu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
|
| 537 |
+
|
| 538 |
+
(void) src1;
|
| 539 |
+
(void) dst;
|
| 540 |
+
(void) src1_dd;
|
| 541 |
+
}
|
| 542 |
+
inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
| 543 |
+
const ggml_tensor *src1, ggml_tensor *dst,
|
| 544 |
+
const float *src0_dd, const float *src1_dd,
|
| 545 |
+
float *dst_dd,
|
| 546 |
+
const queue_ptr &main_stream) {
|
| 547 |
+
|
| 548 |
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
| 549 |
+
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
| 550 |
+
|
| 551 |
+
gelu_quick_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
|
| 552 |
+
|
| 553 |
+
(void) src1;
|
| 554 |
+
(void) dst;
|
| 555 |
+
(void) src1_dd;
|
| 556 |
+
}
|
| 557 |
+
|
| 558 |
+
inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
|
| 559 |
+
ggml_tensor *dst, const float *src0_dd,
|
| 560 |
+
const float *src1_dd, float *dst_dd,
|
| 561 |
+
const queue_ptr &main_stream) {
|
| 562 |
+
|
| 563 |
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
| 564 |
+
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
| 565 |
+
tanh_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
|
| 566 |
+
|
| 567 |
+
(void) src1;
|
| 568 |
+
(void) dst;
|
| 569 |
+
(void) src1_dd;
|
| 570 |
+
}
|
| 571 |
+
|
| 572 |
+
inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
|
| 573 |
+
ggml_tensor *dst, const float *src0_dd,
|
| 574 |
+
const float *src1_dd, float *dst_dd,
|
| 575 |
+
const queue_ptr &main_stream) {
|
| 576 |
+
|
| 577 |
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
| 578 |
+
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
| 579 |
+
|
| 580 |
+
relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
|
| 581 |
+
|
| 582 |
+
(void) src1;
|
| 583 |
+
(void) dst;
|
| 584 |
+
(void) src1_dd;
|
| 585 |
+
}
|
| 586 |
+
|
| 587 |
+
inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
| 588 |
+
const ggml_tensor *src1, ggml_tensor *dst,
|
| 589 |
+
const float *src0_dd, const float *src1_dd,
|
| 590 |
+
float *dst_dd,
|
| 591 |
+
const queue_ptr &main_stream) {
|
| 592 |
+
|
| 593 |
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
| 594 |
+
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
| 595 |
+
|
| 596 |
+
hardsigmoid_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
|
| 597 |
+
|
| 598 |
+
(void) src1;
|
| 599 |
+
(void) dst;
|
| 600 |
+
(void) src1_dd;
|
| 601 |
+
}
|
| 602 |
+
|
| 603 |
+
inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
| 604 |
+
const ggml_tensor *src1, ggml_tensor *dst,
|
| 605 |
+
const float *src0_dd, const float *src1_dd,
|
| 606 |
+
float *dst_dd, const queue_ptr &main_stream) {
|
| 607 |
+
|
| 608 |
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
| 609 |
+
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
| 610 |
+
|
| 611 |
+
hardswish_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
|
| 612 |
+
|
| 613 |
+
(void) src1;
|
| 614 |
+
(void) dst;
|
| 615 |
+
(void) src1_dd;
|
| 616 |
+
}
|
| 617 |
+
|
| 618 |
+
inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
| 619 |
+
const ggml_tensor *src1, ggml_tensor *dst,
|
| 620 |
+
const float *src0_dd, const float *src1_dd,
|
| 621 |
+
float *dst_dd, const queue_ptr &main_stream) {
|
| 622 |
+
|
| 623 |
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
| 624 |
+
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
| 625 |
+
|
| 626 |
+
exp_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
|
| 627 |
+
|
| 628 |
+
(void) src1;
|
| 629 |
+
(void) dst;
|
| 630 |
+
(void) src1_dd;
|
| 631 |
+
}
|
| 632 |
+
|
| 633 |
+
inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
| 634 |
+
const ggml_tensor *src1, ggml_tensor *dst,
|
| 635 |
+
const float *src0_dd, const float *src1_dd,
|
| 636 |
+
float *dst_dd, const queue_ptr &main_stream) {
|
| 637 |
+
|
| 638 |
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
| 639 |
+
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
| 640 |
+
|
| 641 |
+
log_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
|
| 642 |
+
|
| 643 |
+
(void) src1;
|
| 644 |
+
(void) dst;
|
| 645 |
+
(void) src1_dd;
|
| 646 |
+
}
|
| 647 |
+
|
| 648 |
+
inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
| 649 |
+
const ggml_tensor *src1, ggml_tensor *dst,
|
| 650 |
+
const float *src0_dd, const float *src1_dd,
|
| 651 |
+
float *dst_dd, const queue_ptr &main_stream) {
|
| 652 |
+
|
| 653 |
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
| 654 |
+
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
| 655 |
+
|
| 656 |
+
sigmoid_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
|
| 657 |
+
|
| 658 |
+
(void) src1;
|
| 659 |
+
(void) dst;
|
| 660 |
+
(void) src1_dd;
|
| 661 |
+
}
|
| 662 |
+
|
| 663 |
+
inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
| 664 |
+
const ggml_tensor *src1, ggml_tensor *dst,
|
| 665 |
+
const float *src0_dd, const float *src1_dd,
|
| 666 |
+
float *dst_dd, const queue_ptr &main_stream) {
|
| 667 |
+
|
| 668 |
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
| 669 |
+
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
| 670 |
+
|
| 671 |
+
sqrt_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
|
| 672 |
+
|
| 673 |
+
(void) src1;
|
| 674 |
+
(void) dst;
|
| 675 |
+
(void) src1_dd;
|
| 676 |
+
}
|
| 677 |
+
|
| 678 |
+
inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
| 679 |
+
const ggml_tensor *src1, ggml_tensor *dst,
|
| 680 |
+
const float *src0_dd, const float *src1_dd,
|
| 681 |
+
float *dst_dd, const queue_ptr &main_stream) {
|
| 682 |
+
|
| 683 |
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
| 684 |
+
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
| 685 |
+
|
| 686 |
+
sin_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
|
| 687 |
+
|
| 688 |
+
(void) src1;
|
| 689 |
+
(void) dst;
|
| 690 |
+
(void) src1_dd;
|
| 691 |
+
}
|
| 692 |
+
|
| 693 |
+
inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
| 694 |
+
const ggml_tensor *src1, ggml_tensor *dst,
|
| 695 |
+
const float *src0_dd, const float *src1_dd,
|
| 696 |
+
float *dst_dd, const queue_ptr &main_stream) {
|
| 697 |
+
|
| 698 |
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
| 699 |
+
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
| 700 |
+
|
| 701 |
+
cos_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
|
| 702 |
+
|
| 703 |
+
(void) src1;
|
| 704 |
+
(void) dst;
|
| 705 |
+
(void) src1_dd;
|
| 706 |
+
}
|
| 707 |
+
|
| 708 |
+
inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
| 709 |
+
const ggml_tensor *src1, ggml_tensor *dst,
|
| 710 |
+
const float *src0_dd, const float *src1_dd,
|
| 711 |
+
float *dst_dd, const queue_ptr &main_stream) {
|
| 712 |
+
|
| 713 |
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
| 714 |
+
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
| 715 |
+
|
| 716 |
+
step_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
|
| 717 |
+
|
| 718 |
+
(void) src1;
|
| 719 |
+
(void) dst;
|
| 720 |
+
(void) src1_dd;
|
| 721 |
+
}
|
| 722 |
+
|
| 723 |
+
inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
| 724 |
+
const ggml_tensor *src1, ggml_tensor *dst,
|
| 725 |
+
const float *src0_dd, const float *src1_dd,
|
| 726 |
+
float *dst_dd, const queue_ptr &main_stream) {
|
| 727 |
+
|
| 728 |
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
| 729 |
+
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
| 730 |
+
|
| 731 |
+
neg_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
|
| 732 |
+
|
| 733 |
+
(void) src1;
|
| 734 |
+
(void) dst;
|
| 735 |
+
(void) src1_dd;
|
| 736 |
+
}
|
| 737 |
+
|
| 738 |
+
inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
| 739 |
+
const ggml_tensor *src1, ggml_tensor *dst,
|
| 740 |
+
const float *src0_dd, const float *src1_dd,
|
| 741 |
+
float *dst_dd,
|
| 742 |
+
const queue_ptr &main_stream) {
|
| 743 |
+
|
| 744 |
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
| 745 |
+
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
| 746 |
+
|
| 747 |
+
float negative_slope;
|
| 748 |
+
memcpy(&negative_slope, dst->op_params, sizeof(float));
|
| 749 |
+
|
| 750 |
+
leaky_relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), negative_slope, main_stream);
|
| 751 |
+
|
| 752 |
+
(void) src1;
|
| 753 |
+
(void) dst;
|
| 754 |
+
(void) src1_dd;
|
| 755 |
+
}
|
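negative_slope reaches the backend through dst->op_params, where the graph-construction side stored it. A minimal runnable sketch using the public ggml API (CPU-side only; the buffer size is arbitrary):

#include "ggml.h"
#include <cstring>
#include <cstdio>

int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16 * 1024 * 1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    struct ggml_tensor * t = ggml_leaky_relu(ctx, x, /*negative_slope=*/0.1f, /*inplace=*/false);

    // Read the slope back exactly as ggml_sycl_op_leaky_relu does:
    float negative_slope;
    memcpy(&negative_slope, t->op_params, sizeof(float));
    printf("negative_slope = %f\n", negative_slope);   // 0.100000

    ggml_free(ctx);
    return 0;
}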
| 756 |
+
|
| 757 |
+
inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
|
| 758 |
+
ggml_tensor *dst, const float *src0_dd,
|
| 759 |
+
const float *src1_dd, float *dst_dd,
|
| 760 |
+
const queue_ptr &main_stream) {
|
| 761 |
+
|
| 762 |
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
| 763 |
+
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
| 764 |
+
|
| 765 |
+
sqr_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
|
| 766 |
+
|
| 767 |
+
(void) src1;
|
| 768 |
+
(void) dst;
|
| 769 |
+
(void) src1_dd;
|
| 770 |
+
}
|
| 771 |
+
|
| 772 |
+
inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
| 773 |
+
const ggml_tensor *src1, ggml_tensor *dst,
|
| 774 |
+
const float *src0_dd, const float *src1_dd,
|
| 775 |
+
float *dst_dd,
|
| 776 |
+
const queue_ptr &main_stream) {
|
| 777 |
+
|
| 778 |
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
| 779 |
+
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
| 780 |
+
|
| 781 |
+
const float sf0 = (float)dst->ne[0]/src0->ne[0];
|
| 782 |
+
const float sf1 = (float)dst->ne[1]/src0->ne[1];
|
| 783 |
+
const float sf2 = (float)dst->ne[2]/src0->ne[2];
|
| 784 |
+
const float sf3 = (float)dst->ne[3]/src0->ne[3];
|
| 785 |
+
|
| 786 |
+
upscale_f32_sycl(src0_dd, dst_dd, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
|
| 787 |
+
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3,
|
| 788 |
+
main_stream);
|
| 789 |
+
|
| 790 |
+
(void) src1;
|
| 791 |
+
(void) dst;
|
| 792 |
+
(void) src1_dd;
|
| 793 |
+
}
|
| 794 |
+
|
| 795 |
+
inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
|
| 796 |
+
ggml_tensor *dst, const float *src0_dd,
|
| 797 |
+
const float *src1_dd, float *dst_dd,
|
| 798 |
+
const queue_ptr &main_stream) {
|
| 799 |
+
|
| 800 |
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
| 801 |
+
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
| 802 |
+
GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
|
| 803 |
+
|
| 804 |
+
pad_f32_sycl(src0_dd, dst_dd,
|
| 805 |
+
src0->ne[0], src0->ne[1], src0->ne[2],
|
| 806 |
+
dst->ne[0], dst->ne[1], dst->ne[2], main_stream);
|
| 807 |
+
|
| 808 |
+
(void) src1;
|
| 809 |
+
(void) dst;
|
| 810 |
+
(void) src1_dd;
|
| 811 |
+
}
|
| 812 |
+
|
| 813 |
+
inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
|
| 814 |
+
ggml_tensor *dst, const float *src0_dd,
|
| 815 |
+
const float *src1_dd, float *dst_dd,
|
| 816 |
+
const queue_ptr &main_stream) {
|
| 817 |
+
|
| 818 |
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
| 819 |
+
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
| 820 |
+
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
| 821 |
+
GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported
|
| 822 |
+
|
| 823 |
+
int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
|
| 824 |
+
int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
|
| 825 |
+
// int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
|
| 826 |
+
int offset = dst->op_params[3] / 4; // offset in bytes
|
| 827 |
+
|
| 828 |
+
acc_f32_sycl(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], nb1, nb2, offset, main_stream);
|
| 829 |
+
|
| 830 |
+
(void) dst;
|
| 831 |
+
}
|
| 832 |
+
|
| 833 |
+
inline void ggml_sycl_op_add(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
|
| 834 |
+
ggml_tensor *dst, const float *src0_dd,
|
| 835 |
+
const float *src1_dd, float *dst_dd,
|
| 836 |
+
const queue_ptr &main_stream) {
|
| 837 |
+
|
| 838 |
+
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_add>>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
|
| 839 |
+
}
|
| 840 |
+
|
| 841 |
+
inline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
|
| 842 |
+
ggml_tensor *dst, const float *src0_dd,
|
| 843 |
+
const float *src1_dd, float *dst_dd,
|
| 844 |
+
const queue_ptr &main_stream) {
|
| 845 |
+
|
| 846 |
+
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_sub>>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
|
| 847 |
+
}
|
| 848 |
+
|
| 849 |
+
inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
|
| 850 |
+
ggml_tensor *dst, const float *src0_dd,
|
| 851 |
+
const float *src1_dd, float *dst_dd,
|
| 852 |
+
const queue_ptr &main_stream) {
|
| 853 |
+
|
| 854 |
+
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_mul>>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
|
| 855 |
+
}
|
| 856 |
+
|
| 857 |
+
inline void ggml_sycl_op_div(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
|
| 858 |
+
ggml_tensor *dst, const float *src0_dd,
|
| 859 |
+
const float *src1_dd, float *dst_dd,
|
| 860 |
+
const queue_ptr &main_stream) {
|
| 861 |
+
|
| 862 |
+
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_div>>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
|
| 863 |
+
}
|
| 864 |
+
|
| 865 |
+
|
| 866 |
+
void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 867 |
+
GGML_SYCL_DEBUG("call %s\n", __func__);
|
| 868 |
+
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sqrt);
|
| 869 |
+
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
| 870 |
+
}
|
| 871 |
+
|
| 872 |
+
void ggml_sycl_sin(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 873 |
+
GGML_SYCL_DEBUG("call %s\n", __func__);
|
| 874 |
+
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sin);
|
| 875 |
+
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
| 876 |
+
}
|
| 877 |
+
|
| 878 |
+
void ggml_sycl_cos(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 879 |
+
GGML_SYCL_DEBUG("call %s\n", __func__);
|
| 880 |
+
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_cos);
|
| 881 |
+
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
| 882 |
+
}
|
| 883 |
+
|
| 884 |
+
void ggml_sycl_acc(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 885 |
+
GGML_SYCL_DEBUG("call %s\n", __func__);
|
| 886 |
+
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_acc);
|
| 887 |
+
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
| 888 |
+
}
|
| 889 |
+
|
| 890 |
+
void ggml_sycl_gelu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 891 |
+
GGML_SYCL_DEBUG("call %s\n", __func__);
|
| 892 |
+
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_gelu);
|
| 893 |
+
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
| 894 |
+
}
|
| 895 |
+
|
| 896 |
+
void ggml_sycl_silu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 897 |
+
GGML_SYCL_DEBUG("call %s\n", __func__);
|
| 898 |
+
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_silu);
|
| 899 |
+
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
| 900 |
+
}
|
| 901 |
+
|
| 902 |
+
void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 903 |
+
GGML_SYCL_DEBUG("call %s\n", __func__);
|
| 904 |
+
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_gelu_quick);
|
| 905 |
+
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
| 906 |
+
}
|
| 907 |
+
|
| 908 |
+
void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 909 |
+
GGML_SYCL_DEBUG("call %s\n", __func__);
|
| 910 |
+
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_tanh);
|
| 911 |
+
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
| 912 |
+
}
|
| 913 |
+
|
| 914 |
+
void ggml_sycl_relu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 915 |
+
GGML_SYCL_DEBUG("call %s\n", __func__);
|
| 916 |
+
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_relu);
|
| 917 |
+
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
| 918 |
+
}
|
| 919 |
+
|
| 920 |
+
void ggml_sycl_sigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 921 |
+
GGML_SYCL_DEBUG("call %s\n", __func__);
|
| 922 |
+
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sigmoid);
|
| 923 |
+
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
| 924 |
+
}
|
| 925 |
+
|
| 926 |
+
void ggml_sycl_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 927 |
+
GGML_SYCL_DEBUG("call %s\n", __func__);
|
| 928 |
+
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_hardsigmoid);
|
| 929 |
+
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
| 930 |
+
}
|
| 931 |
+
|
| 932 |
+
void ggml_sycl_hardswish(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 933 |
+
GGML_SYCL_DEBUG("call %s\n", __func__);
|
| 934 |
+
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_hardswish);
|
| 935 |
+
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
| 936 |
+
}
|
| 937 |
+
|
| 938 |
+
|
| 939 |
+
void ggml_sycl_exp(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 940 |
+
GGML_SYCL_DEBUG("call %s\n", __func__);
|
| 941 |
+
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_exp);
|
| 942 |
+
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
| 943 |
+
}
|
| 944 |
+
|
| 945 |
+
void ggml_sycl_log(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 946 |
+
GGML_SYCL_DEBUG("call %s\n", __func__);
|
| 947 |
+
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_log);
|
| 948 |
+
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
| 949 |
+
}
|
| 950 |
+
|
| 951 |
+
void ggml_sycl_neg(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 952 |
+
GGML_SYCL_DEBUG("call %s\n", __func__);
|
| 953 |
+
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_neg);
|
| 954 |
+
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
| 955 |
+
}
|
| 956 |
+
|
| 957 |
+
void ggml_sycl_step(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 958 |
+
GGML_SYCL_DEBUG("call %s\n", __func__);
|
| 959 |
+
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_step);
|
| 960 |
+
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
| 961 |
+
}
|
| 962 |
+
|
| 963 |
+
void ggml_sycl_leaky_relu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 964 |
+
GGML_SYCL_DEBUG("call %s\n", __func__);
|
| 965 |
+
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_leaky_relu);
|
| 966 |
+
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
| 967 |
+
}
|
| 968 |
+
|
| 969 |
+
void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 970 |
+
GGML_SYCL_DEBUG("call %s\n", __func__);
|
| 971 |
+
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sqr);
|
| 972 |
+
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
| 973 |
+
}
|
| 974 |
+
|
| 975 |
+
void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 976 |
+
GGML_SYCL_DEBUG("call %s\n", __func__);
|
| 977 |
+
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_upscale);
|
| 978 |
+
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
| 979 |
+
}
|
| 980 |
+
|
| 981 |
+
void ggml_sycl_pad(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 982 |
+
GGML_SYCL_DEBUG("call %s\n", __func__);
|
| 983 |
+
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_pad);
|
| 984 |
+
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
| 985 |
+
}
|
| 986 |
+
|
| 987 |
+
|
| 988 |
+
|
| 989 |
+
void ggml_sycl_add(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 990 |
+
GGML_SYCL_DEBUG("call %s\n", __func__);
|
| 991 |
+
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_add);
|
| 992 |
+
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
| 993 |
+
}
|
| 994 |
+
|
| 995 |
+
void ggml_sycl_sub(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 996 |
+
GGML_SYCL_DEBUG("call %s\n", __func__);
|
| 997 |
+
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sub);
|
| 998 |
+
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
| 999 |
+
}
|
| 1000 |
+
|
| 1001 |
+
void ggml_sycl_mul(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 1002 |
+
GGML_SYCL_DEBUG("call %s\n", __func__);
|
| 1003 |
+
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_mul);
|
| 1004 |
+
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
| 1005 |
+
}
|
| 1006 |
+
|
| 1007 |
+
void ggml_sycl_div(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 1008 |
+
GGML_SYCL_DEBUG("call %s\n", __func__);
|
| 1009 |
+
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_div);
|
| 1010 |
+
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
| 1011 |
+
}
|
|
@@ -0,0 +1,76 @@
|
| 1 |
+
#ifndef GGML_SYCL_ELEMENTWISE_HPP
|
| 2 |
+
#define GGML_SYCL_ELEMENTWISE_HPP
|
| 3 |
+
|
| 4 |
+
#include "common.hpp"
|
| 5 |
+
|
| 6 |
+
static __dpct_inline__ float op_repeat(const float a, const float b) {
|
| 7 |
+
return b;
|
| 8 |
+
GGML_UNUSED(a);
|
| 9 |
+
}
|
| 10 |
+
|
| 11 |
+
static __dpct_inline__ float op_add(const float a, const float b) {
|
| 12 |
+
return a + b;
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
static __dpct_inline__ float op_sub(const float a, const float b) {
|
| 16 |
+
return a - b;
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
static __dpct_inline__ float op_mul(const float a, const float b) {
|
| 20 |
+
return a * b;
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
static __dpct_inline__ float op_div(const float a, const float b) {
|
| 24 |
+
return a / b;
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
| 29 |
+
|
| 30 |
+
void ggml_sycl_sin(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
| 31 |
+
|
| 32 |
+
void ggml_sycl_cos(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
| 33 |
+
|
| 34 |
+
void ggml_sycl_acc(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
| 35 |
+
|
| 36 |
+
void ggml_sycl_gelu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
| 37 |
+
|
| 38 |
+
void ggml_sycl_silu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
| 39 |
+
|
| 40 |
+
void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
| 41 |
+
|
| 42 |
+
void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
| 43 |
+
|
| 44 |
+
void ggml_sycl_relu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
| 45 |
+
|
| 46 |
+
void ggml_sycl_sigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
| 47 |
+
|
| 48 |
+
void ggml_sycl_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
| 49 |
+
|
| 50 |
+
void ggml_sycl_hardswish(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
| 51 |
+
|
| 52 |
+
void ggml_sycl_exp(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
| 53 |
+
|
| 54 |
+
void ggml_sycl_log(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
| 55 |
+
|
| 56 |
+
void ggml_sycl_neg(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
| 57 |
+
|
| 58 |
+
void ggml_sycl_step(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
| 59 |
+
|
| 60 |
+
void ggml_sycl_leaky_relu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
| 61 |
+
|
| 62 |
+
void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
| 63 |
+
|
| 64 |
+
void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
| 65 |
+
|
| 66 |
+
void ggml_sycl_pad(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
| 67 |
+
|
| 68 |
+
void ggml_sycl_add(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
| 69 |
+
|
| 70 |
+
void ggml_sycl_sub(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
| 71 |
+
|
| 72 |
+
void ggml_sycl_mul(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
| 73 |
+
|
| 74 |
+
void ggml_sycl_div(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
| 75 |
+
|
| 76 |
+
#endif // GGML_SYCL_ELEMENTWISE_HPP
|
|
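One detail worth noting in this header: op_repeat discards its first argument and returns b, which lets the generic binary-broadcast kernel double as GGML_OP_REPEAT, since broadcasting b over a's shape and writing b back is exactly a repeat. A plain-C++ sketch of that reuse; bin_bcast and its modulo indexing are illustrative stand-ins for the backend's broadcast kernel, not its real implementation:

#include <cstdio>

float op_repeat(float a, float b) { (void)a; return b; }
float op_add   (float a, float b) { return a + b; }

// Toy binary broadcast: b has n1 elements, broadcast across a's n0 elements.
template <float (*OP)(float, float)>
void bin_bcast(const float * a, const float * b, float * dst, int n0, int n1) {
    for (int i = 0; i < n0; ++i) {
        dst[i] = OP(a[i], b[i % n1]); // modulo indexing = broadcast of b
    }
}

int main() {
    const float a[6] = {10, 20, 30, 40, 50, 60};
    const float b[2] = {1.5f, 2.5f};
    float dst[6];

    bin_bcast<op_repeat>(a, b, dst, 6, 2); // {1.5, 2.5, 1.5, 2.5, 1.5, 2.5} -- a repeat
    for (float v : dst) printf("%g ", v);
    printf("\n");

    bin_bcast<op_add>(a, b, dst, 6, 2);    // same loop, now an add with broadcast
    for (float v : dst) printf("%g ", v);
    printf("\n");
    return 0;
}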
ggml/src/ggml-sycl/outprod.cpp
@@ -0,0 +1,55 @@
+#include <sycl/sycl.hpp>
+#include "outprod.hpp"
+
+
+void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
+                           const ggml_tensor* src1, ggml_tensor* dst) {
+
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    // Get SYCL queue
+    dpct::queue_ptr stream = ctx.stream();
+
+    // Dimension checks
+    GGML_ASSERT(ne01 == ne11);  // Inner dimensions must match
+    GGML_ASSERT(ne0 == ne00);   // Output rows match src0 rows
+    GGML_ASSERT(ne1 == ne10);   // Output cols match src1 cols
+
+    // Get data pointers
+    const float* src0_d = (const float*)src0->data;
+    const float* src1_d = (const float*)src1->data;
+    float* dst_d = (float*)dst->data;
+
+    // GEMM parameters
+    const float alpha = 1.0f;
+    const float beta = 0.0f;
+
+    // Handle transposition of src1
+    const bool src1_T = ggml_is_transposed(src1);
+    const oneapi::mkl::transpose src1_op =
+        src1_T ? oneapi::mkl::transpose::nontrans : oneapi::mkl::transpose::trans;
+    const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float);
+
+    try {
+        // Perform matrix multiplication using oneMKL GEMM
+        oneapi::mkl::blas::gemm(*stream,
+            oneapi::mkl::transpose::nontrans, src1_op,
+            ne0, ne1, ne01,
+            alpha,
+            src0_d, ne00,
+            src1_d, ldb,
+            beta,
+            dst_d, ne0);
+    }
+    catch (sycl::exception const& exc) {
+        std::cerr << exc.what() << std::endl;
+        GGML_ASSERT(false);
+    }
+}
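The gemm call above computes dst = src0 · src1^T in column-major layout: out_prod accumulates a rank-1 outer product per column pair over the shared inner dimension (ne01 == ne11). A naive reference loop that pins down the same index mapping; the dimension names mirror the asserts, and the data values are made up for illustration:

#include <cstdio>

int main() {
    const int m = 2, n = 3, k = 4; // m = ne00, n = ne10, k = ne01 == ne11
    float a[m * k], b[n * k], dst[m * n]; // column-major, like oneMKL BLAS
    for (int i = 0; i < m * k; ++i) a[i] = 0.1f * float(i);
    for (int i = 0; i < n * k; ++i) b[i] = 0.2f * float(i);

    for (int j = 0; j < n; ++j) {
        for (int i = 0; i < m; ++i) {
            float acc = 0.0f;
            for (int p = 0; p < k; ++p) {
                // A(i,p) * B(j,p) == (A * B^T)(i,j): a sum of rank-1 outer products
                acc += a[i + p * m] * b[j + p * n];
            }
            dst[i + j * m] = acc;
        }
    }
    printf("dst(0,0) = %g, dst(%d,%d) = %g\n", dst[0], m - 1, n - 1, dst[m * n - 1]);
    return 0;
}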
ggml/src/ggml-sycl/outprod.hpp
@@ -0,0 +1,11 @@
+#ifndef GGML_SYCL_OUTPROD_HPP
+#define GGML_SYCL_OUTPROD_HPP
+
+#include "common.hpp"
+
+void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
+                           const ggml_tensor* src1, ggml_tensor* dst);
+
+
+#endif // GGML_SYCL_OUTPROD_HPP
ggml/src/ggml-sycl/presets.hpp
@@ -25,6 +25,11 @@
 #define SYCL_RELU_BLOCK_SIZE 256
 #define SYCL_HARDSIGMOID_BLOCK_SIZE 256
 #define SYCL_HARDSWISH_BLOCK_SIZE 256
+#define SYCL_EXP_BLOCK_SIZE 256
+#define SYCL_NEG_BLOCK_SIZE 256
+#define SYCL_SIGMOID_BLOCK_SIZE 256
+#define SYCL_SQRT_BLOCK_SIZE 256
+#define SYCL_SIN_BLOCK_SIZE 256
 #define SYCL_SQR_BLOCK_SIZE 256
 #define SYCL_CPY_BLOCK_SIZE 32
 #define SYCL_SCALE_BLOCK_SIZE 256
@@ -41,6 +46,7 @@
 #define SYCL_ACC_BLOCK_SIZE 256
 #define SYCL_IM2COL_BLOCK_SIZE 256
 #define SYCL_POOL2D_BLOCK_SIZE 256
+#define SYCL_ARGMAX_BLOCK_SIZE 256
 #define SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE 256
 #define SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE 256
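These block-size presets feed the usual launch arithmetic: the element count is ceil-divided into work-groups of the preset size. A minimal standalone sketch of that math, using one of the constants added above:

#include <cstdio>

#define SYCL_NEG_BLOCK_SIZE 256

int main() {
    const int n = 1000; // elements to process
    const int num_blocks = (n + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE; // ceil divide
    printf("%d elements -> %d work-group(s) of %d\n", n, num_blocks, SYCL_NEG_BLOCK_SIZE);
    return 0;
}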
ggml/src/ggml-sycl/wkv6.cpp
@@ -0,0 +1,138 @@
+#include <sycl/sycl.hpp>
+#include "wkv6.hpp"
+
+constexpr int WKV_BLOCK_SIZE = 64; // Matching CUDA_WKV_BLOCK_SIZE
+
+// Helper function for the main kernel
+static void rwkv_wkv_f32_kernel(
+    const int B, const int T, const int C, const int H,
+    const float* k, const float* v, const float* r,
+    const float* tf, const float* td, const float* s,
+    float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) {
+
+    const int tid = item_ct1.get_local_id(2);
+    const int bid = item_ct1.get_group(2);
+
+    const int head_size = WKV_BLOCK_SIZE;
+    const int batch_i = bid / H;
+    const int head_i = bid % H;
+    const int state_size = C * head_size;
+    const int n_seq_tokens = T / B;
+
+    // Set up shared memory pointers
+    float* _k = shared_mem;
+    float* _r = _k + head_size;
+    float* _tf = _r + head_size;
+    float* _td = _tf + head_size;
+
+    // Local state array
+    float state[WKV_BLOCK_SIZE];
+
+    // Load initial state
+    #pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
+    }
+
+    // Sync threads before shared memory operations
+    item_ct1.barrier(sycl::access::fence_space::local_space);
+
+    // Load time-mixing parameters
+    _tf[tid] = tf[head_i * head_size + tid];
+    item_ct1.barrier(sycl::access::fence_space::local_space);
+
+    // Main sequence processing loop
+    for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
+         t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
+         t += C) {
+
+        item_ct1.barrier(sycl::access::fence_space::local_space);
+
+        // Load current timestep data to shared memory
+        _k[tid] = k[t];
+        _r[tid] = r[t];
+        _td[tid] = td[t];
+
+        item_ct1.barrier(sycl::access::fence_space::local_space);
+
+        const float _v = v[t];
+        float y = 0;
+
+        // Process in chunks of 4 for better vectorization
+        sycl::float4 k4, r4, tf4, td4, s4, kv4;
+        #pragma unroll
+        for (int j = 0; j < head_size; j += 4) {
+            // Load data in vec4 chunks
+            k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
+            r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
+            tf4 = sycl::float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
+            td4 = sycl::float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
+            s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
+
+            // Compute key-value product
+            sycl::float4 kv4 = k4 * _v;
+
+            // Accumulate weighted sum
+            y += sycl::dot(r4, tf4 * kv4 + s4);
+
+            // Update state
+            s4 = s4 * td4 + kv4;
+
+            // Store updated state
+            state[j] = s4.x();
+            state[j+1] = s4.y();
+            state[j+2] = s4.z();
+            state[j+3] = s4.w();
+        }
+
+        dst[t] = y;
+    }
+
+    // Save final state
+    #pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
+    }
+}
+
+void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
+                            const ggml_tensor* src1, ggml_tensor* dst) {
+
+    const float* k_d = (const float*)dst->src[0]->data;
+    const float* v_d = (const float*)dst->src[1]->data;
+    const float* r_d = (const float*)dst->src[2]->data;
+    const float* tf_d = (const float*)dst->src[3]->data;
+    const float* td_d = (const float*)dst->src[4]->data;
+    const float* s_d = (const float*)dst->src[5]->data;
+    float* dst_d = (float*)dst->data;
+
+    const int64_t B = dst->src[5]->ne[1];
+    const int64_t T = dst->src[0]->ne[3];
+    const int64_t C = dst->ne[0];
+    const int64_t H = dst->src[0]->ne[2];
+
+    GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
+    GGML_ASSERT(C % H == 0);
+    GGML_ASSERT(C / H == WKV_BLOCK_SIZE); // The current sycl kernel is designed for RWKV6, HEAD_SIZE == 64
+
+    dpct::queue_ptr stream = ctx.stream();
+
+    // Calculate execution configuration
+    const size_t shared_mem_size = WKV_BLOCK_SIZE * 4 * sizeof(float); // For k, r, tf, td
+    sycl::range<3> block_dims(1, 1, C / H);
+    sycl::range<3> grid_dims(1, 1, B * H);
+
+    // Submit kernel
+    stream->submit([&](sycl::handler& cgh) {
+        sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
+
+        cgh.parallel_for(
+            sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+            [=](sycl::nd_item<3> item_ct1) {
+                rwkv_wkv_f32_kernel(
+                    B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
+                    item_ct1, shared_mem_acc.get_pointer()
+                );
+            });
+    });
+}
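The float4 body above is a vectorized form of the scalar WKV6 recurrence each lane runs: y accumulates r[j] * (tf[j] * k[j] * v + state[j]), and the state decays as state[j] = state[j] * td[j] + k[j] * v. A plain-C++ reference for one lane over one timestep; head_size is shrunk to 4 and the values are made up so the trace stays readable:

#include <cstdio>

int main() {
    const int head_size = 4; // the kernel uses 64; 4 keeps this readable
    float k[]  = {0.1f, 0.2f, 0.3f, 0.4f};    // current keys (shared _k)
    float r[]  = {1.0f, 0.5f, 0.25f, 0.125f}; // current receptance (shared _r)
    float tf[] = {0.9f, 0.9f, 0.9f, 0.9f};    // time_first (shared _tf)
    float td[] = {0.8f, 0.8f, 0.8f, 0.8f};    // time_decay (shared _td)
    float state[4] = {0.0f, 0.0f, 0.0f, 0.0f};
    const float v = 2.0f;                     // this lane's value, v[t]

    float y = 0.0f;
    for (int j = 0; j < head_size; ++j) {
        const float kv = k[j] * v;            // key-value product
        y += r[j] * (tf[j] * kv + state[j]);  // weighted sum, as in sycl::dot above
        state[j] = state[j] * td[j] + kv;     // decayed state update
    }
    printf("y = %g, state[0] = %g\n", y, state[0]);
    return 0;
}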
ggml/src/ggml-sycl/wkv6.hpp
@@ -0,0 +1,10 @@
+#ifndef GGML_SYCL_WKV6_HPP
+#define GGML_SYCL_WKV6_HPP
+
+#include "common.hpp"
+
+void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                            const ggml_tensor *src1, ggml_tensor * dst);
+
+
+#endif // GGML_SYCL_WKV6_HPP
ggml/src/ggml.c
@@ -975,7 +975,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "WIN_UNPART",
     "GET_REL_POS",
     "ADD_REL_POS",
-    "RWKV_WKV",
+    "RWKV_WKV6",
 
     "UNARY",
 
@@ -1070,7 +1070,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "win_unpart(x)",
     "get_rel_pos(x)",
     "add_rel_pos(x)",
-    "rwkv_wkv(k, v, r, tf, td, s)",
+    "rwkv_wkv6(k, v, r, tf, td, s)",
 
     "unary(x)",
 
@@ -4503,9 +4503,9 @@ struct ggml_tensor * ggml_add_rel_pos_inplace(
     return ggml_add_rel_pos_impl(ctx, a, pw, ph, true);
 }
 
-// ggml_rwkv_wkv
+// ggml_rwkv_wkv6
 
-struct ggml_tensor * ggml_rwkv_wkv(
+struct ggml_tensor * ggml_rwkv_wkv6(
     struct ggml_context * ctx,
     struct ggml_tensor * k,
     struct ggml_tensor * v,
@@ -4537,7 +4537,7 @@ struct ggml_tensor * ggml_rwkv_wkv(
     const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
-    result->op     = GGML_OP_RWKV_WKV;
+    result->op     = GGML_OP_RWKV_WKV6;
     result->src[0] = k;
     result->src[1] = v;
     result->src[2] = r;
 
@@ -6084,7 +6084,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
         } break;
     case GGML_OP_GET_REL_POS:
     case GGML_OP_ADD_REL_POS:
-    case GGML_OP_RWKV_WKV:
+    case GGML_OP_RWKV_WKV6:
    case GGML_OP_MAP_UNARY:
    case GGML_OP_MAP_BINARY:
    case GGML_OP_MAP_CUSTOM1_F32:
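For callers, only the name changes: graph-building code now invokes ggml_rwkv_wkv6. A hedged usage sketch; the argument order (k, v, r, tf, td, state) follows the src[] indexing used by the SYCL kernel above, while the helper name build_wkv6 and the surrounding setup are illustrative, not part of this change:

#include "ggml.h"

// Assumed wrapper for graph construction; ggml_context creation and the
// input tensors are expected to be set up by the caller in the usual way.
struct ggml_tensor * build_wkv6(
        struct ggml_context * ctx,
        struct ggml_tensor * k,  struct ggml_tensor * v,  struct ggml_tensor * r,
        struct ggml_tensor * tf, struct ggml_tensor * td, struct ggml_tensor * state) {
    // Per the result shape in ggml.c above, { S*H, n_tokens + S*n_seqs, 1, 1 }:
    // the first n_tokens rows hold the outputs, followed by the updated state,
    // which the caller splits off with views.
    return ggml_rwkv_wkv6(ctx, k, v, r, tf, td, state);
}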