Zhiyuan Li, ggerganov, Diego Devesa, pacominev, Yuri Khrustalev, Meng, Hengyu committed
Commit f58e658 · 1 Parent(s): 847669b

Optimize RWKV6 Operator Naming and Implement Multi-core CPU/SYCL Acceleration (llama/10133)


* rwkv6: rename to wkv6

* rwkv6: support avx2 avx512 armv8 armv9

* rwkv6: update cuda file name

* rwkv6: rename params

* wkv on sycl

* sycl: add some ops

* sycl: Enhance OP support judgment

* wkv6: drop armv9 and transfer to GGML style

ggml-ci

* sync : ggml

* update the function to use appropriate types

* fix define error

* Update ggml/src/ggml-cpu.c

* add appropriate asserts

* move element-wise functions outside

* put the declaration outside the loop

* rewrite to be more inline with the common pattern for distributing threads

* use recommended way GGML_TENSOR_LOCALS

---------

Co-authored-by: Georgi Gerganov <[email protected]>
Co-authored-by: Diego Devesa <[email protected]>
Co-authored-by: Plamen Minev <[email protected]>
Co-authored-by: Yuri Khrustalev <[email protected]>
Co-authored-by: Meng, Hengyu <[email protected]>
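
For orientation, every kernel touched by this commit computes the same fused WKV6 recurrence noted in the CPU source: dst = r @ (time_faaaa * (k @ v) + state) and state = time_decay * state + (k @ v), applied token by token. A minimal single-head, single-sequence scalar sketch in plain C, derived from the CPU fallback path in the diff below (the helper itself is illustrative and not part of ggml; dst must be pre-zeroed, as the real kernel does with memset):

    #include <stddef.h>

    // Scalar reference of one WKV6 head; indices mirror the CPU fallback loop.
    void wkv6_head_ref(size_t n_tokens, size_t head_size,
                       const float *k, const float *v, const float *r,
                       const float *time_faaaa,   // per-channel bonus, [head_size]
                       const float *time_decay,   // per-token decay,  [n_tokens * head_size]
                       float *state,              // [head_size * head_size], updated in place
                       float *dst) {              // [n_tokens * head_size], pre-zeroed by caller
        for (size_t t = 0; t < n_tokens; t++) {
            for (size_t i = 0; i < head_size; i++) {
                const float k_val  = k[t * head_size + i];
                const float r_val  = r[t * head_size + i];
                const float tf_val = time_faaaa[i];
                const float td_val = time_decay[t * head_size + i]; // RWKV v6: decay differs per token
                for (size_t j = 0; j < head_size; j++) {
                    const float kv   = v[t * head_size + j] * k_val;
                    const float prev = state[i * head_size + j];
                    dst[t * head_size + j]  += (kv * tf_val + prev) * r_val;
                    state[i * head_size + j] = prev * td_val + kv;
                }
            }
        }
    }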

ggml/include/ggml.h CHANGED
@@ -509,7 +509,7 @@ extern "C" {
         GGML_OP_WIN_UNPART,
         GGML_OP_GET_REL_POS,
         GGML_OP_ADD_REL_POS,
-        GGML_OP_RWKV_WKV,
+        GGML_OP_RWKV_WKV6,

         GGML_OP_UNARY,

@@ -1819,7 +1819,7 @@ extern "C" {
             struct ggml_tensor  * pw,
             struct ggml_tensor  * ph);

-    GGML_API struct ggml_tensor * ggml_rwkv_wkv(
+    GGML_API struct ggml_tensor * ggml_rwkv_wkv6(
             struct ggml_context * ctx,
             struct ggml_tensor  * k,
             struct ggml_tensor  * v,
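
The header hunk above is cut off after the v parameter by the page rendering. Judging from the six source tensors read by the CPU and CUDA kernels in this commit (src[0] through src[5]), the remaining arguments are presumably r, the time_faaaa and time_decay tensors, and the recurrent state; treat this sketch as an assumption, not a verbatim copy of the header:

    // Presumed full prototype; trailing parameters inferred from the kernels' src[] order.
    GGML_API struct ggml_tensor * ggml_rwkv_wkv6(
            struct ggml_context * ctx,
            struct ggml_tensor  * k,        // src[0]
            struct ggml_tensor  * v,        // src[1]
            struct ggml_tensor  * r,        // src[2]
            struct ggml_tensor  * tf,       // src[3], time_faaaa
            struct ggml_tensor  * td,       // src[4], time_decay
            struct ggml_tensor  * state);   // src[5], per-sequence recurrent state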
ggml/src/ggml-cpu.c CHANGED
@@ -11642,24 +11642,30 @@ static void ggml_compute_forward_add_rel_pos(
         }
     }

-// ggml_compute_forward_rwkv_wkv
+// ggml_compute_forward_rwkv_wkv6

-static void ggml_compute_forward_rwkv_wkv_f32(
+static void ggml_compute_forward_rwkv_wkv6_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
-    const size_t T = dst->src[1]->ne[3];
-    const size_t C = dst->ne[0];
-    const size_t H = dst->src[1]->ne[2];
-    const size_t n_seqs = dst->src[5]->ne[1];
+    const int64_t T = dst->src[1]->ne[3];
+    const int64_t C = dst->ne[0];
+    const int64_t HEADS = dst->src[1]->ne[2];
+    const int64_t n_seqs = dst->src[5]->ne[1];
+    const int64_t head_size = C / HEADS;

     float * dst_data = (float *) dst->data;
     float * state = ((float *) dst->data) + C * T;

-    if (params->ith != 0) {
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    if (ith >= HEADS) {
         return;
     }

-    memset(dst_data, 0, T * C * sizeof(float));
+    const int h_start = (HEADS * ith) / nth;
+    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
+                (HEADS * (ith + 1)) / nth : HEADS;

     float * k = (float *) dst->src[0]->data;
     float * v = (float *) dst->src[1]->data;

@@ -11667,54 +11673,160 @@ static void ggml_compute_forward_rwkv_wkv_f32(
     float * time_faaaa = (float *) dst->src[3]->data;
     float * time_decay = (float *) dst->src[4]->data;

-    size_t t_stride = H * (C / H);
-
-    size_t h_stride = C / H;
-    size_t h_stride_2d = (C / H) * (C / H);
-
-    // basically fused operations:
-    // dst = r @ (time_faaaa * (k @ v) + state),
-    // state = time_decay * state + (k @ v),
-    // recursive through each token
-    for (size_t t = 0; t < T; t++) {
-        size_t t_offset = t * t_stride;
-        size_t state_offset = (C / H) * C * (t / (T / n_seqs));
-        float * state_cur = state + state_offset;
-        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
-
-        for (size_t h = 0; h < H; h++) {
-            size_t h_offset = h * h_stride;
-            size_t t_h_offset = t_offset + h_offset;
-            size_t h_2d_offset = h * h_stride_2d;
-
-            for (size_t i = 0; i < C / H; i++) {
-                size_t t_h_i_offset = t_h_offset + i;
-                size_t h_i_offset = h_offset + i;
-                size_t h_2d_i_offset = h_2d_offset + i * h_stride;
-
-                float k_val = k[t_h_i_offset];
-                float r_val = r[t_h_i_offset];
-                float time_faaaa_val = time_faaaa[h_i_offset];
-                // RWKV v6: different time_decay for each token.
-                float time_decay_val = time_decay[t_h_i_offset];
-
-                for (size_t j = 0; j < C / H; j ++) {
-                    size_t t_h_j_offset = t_h_offset + j;
-                    size_t h_2d_i_j_offset = h_2d_i_offset + j;
-
-                    float v_val = v[t_h_j_offset];
-                    float kv_val = v_val * k_val;
-                    float prev_state_val = state_prev[h_2d_i_j_offset];
-                    float temp_val = kv_val * time_faaaa_val + prev_state_val;
-                    dst_data[t_h_j_offset] += temp_val * r_val;
-                    state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
-                }
-            }
-        }
-    }
+    size_t t_stride = HEADS * head_size; // Same to C
+
+    size_t h_stride = C / HEADS;
+    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
+    size_t h_stride_2d = head_size * head_size;
+
+    if (ith == 0) {
+        memset(dst_data, 0, T * C * sizeof(float));
+    }
+    ggml_barrier(params->threadpool);
+
+    #if defined(__AVX__) && !defined(__AVX512F__)
+        #define GGML_F32X GGML_F32x8
+        #define GGML_F32X_SET1 GGML_F32x8_SET1
+        #define GGML_F32X_LOAD GGML_F32x8_LOAD
+        #define GGML_F32X_STORE GGML_F32x8_STORE
+        #define GGML_F32X_MUL GGML_F32x8_MUL
+        #define GGML_F32X_FMA GGML_F32x8_FMA
+        #define WKV_VECTOR_SIZE 8
+    #elif defined(__AVX512F__)
+        #define GGML_F32X GGML_F32x16
+        #define GGML_F32X_SET1 GGML_F32x16_SET1
+        #define GGML_F32X_LOAD GGML_F32x16_LOAD
+        #define GGML_F32X_STORE GGML_F32x16_STORE
+        #define GGML_F32X_MUL GGML_F32x16_MUL
+        #define GGML_F32X_FMA GGML_F32x16_FMA
+        #define WKV_VECTOR_SIZE 16
+    #elif defined(__ARM_NEON) && defined(__aarch64__)
+        #define GGML_F32X GGML_F32x4
+        #define GGML_F32X_SET1 GGML_F32x4_SET1
+        #define GGML_F32X_LOAD GGML_F32x4_LOAD
+        #define GGML_F32X_STORE GGML_F32x4_STORE
+        #define GGML_F32X_MUL GGML_F32x4_MUL
+        #define GGML_F32X_FMA GGML_F32x4_FMA
+        #define WKV_VECTOR_SIZE 4
+    #endif
+
+    #ifdef WKV_VECTOR_SIZE
+        const int64_t vec_count = head_size / WKV_VECTOR_SIZE;
+
+        for (int64_t t = 0; t < T; t++) {
+            size_t t_offset = t * t_stride;
+            size_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                size_t h_offset = h * h_stride;
+                size_t t_h_offset = t_offset + h_offset;
+                size_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    size_t t_h_i_offset = t_h_offset + i;
+                    size_t h_i_offset = h_offset + i;
+                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float k_val = k[t_h_i_offset];
+                    float r_val = r[t_h_i_offset];
+                    float time_faaaa_val = time_faaaa[h_i_offset];
+                    float time_decay_val = time_decay[t_h_i_offset];
+
+                    // Broadcast scalar values to vectors
+                    GGML_F32X k_vec = GGML_F32X_SET1(k_val);
+                    GGML_F32X r_vec = GGML_F32X_SET1(r_val);
+                    GGML_F32X time_faaaa_vec = GGML_F32X_SET1(time_faaaa_val);
+                    GGML_F32X time_decay_vec = GGML_F32X_SET1(time_decay_val);
+
+                    for (int64_t j = 0; j < vec_count; j++) {
+                        size_t base_j = j * WKV_VECTOR_SIZE;
+                        size_t t_h_j_offset = t_h_offset + base_j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
+
+                        // Load x elements at once
+                        GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
+                        GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
+                        GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
+
+                        // Compute kv = v * k
+                        GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);
+
+                        // Compute temp = kv * time_faaaa + prev_state
+                        GGML_F32X temp_vec = GGML_F32X_FMA(prev_state_vec, kv_vec, time_faaaa_vec);
+
+                        // Update dst: dst += temp * r
+                        dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, r_vec);
+                        GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
+
+                        // Update state: state = prev_state * time_decay + kv
+                        GGML_F32X new_state_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, time_decay_vec);
+                        GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], new_state_vec);
+                    }
+
+                    // Handle remaining elements, this will not be used.
+                    for (int64_t j = vec_count * WKV_VECTOR_SIZE; j < head_size; j++) {
+                        size_t t_h_j_offset = t_h_offset + j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
+                        float v_val = v[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        float temp_val = kv_val * time_faaaa_val + prev_state_val;
+                        dst_data[t_h_j_offset] += temp_val * r_val;
+                        state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
+                    }
+                }
+            }
+        }
+
+    #else
+        // basically fused operations:
+        // dst = r @ (time_faaaa * (k @ v) + state),
+        // state = time_decay * state + (k @ v),
+        // recursive through each token
+        for (int64_t t = 0; t < T; t++) {
+            size_t t_offset = t * t_stride;
+            size_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                size_t h_offset = h * h_stride;
+                size_t t_h_offset = t_offset + h_offset;
+                size_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    size_t t_h_i_offset = t_h_offset + i;
+                    size_t h_i_offset = h_offset + i;
+                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float k_val = k[t_h_i_offset];
+                    float r_val = r[t_h_i_offset];
+                    float time_faaaa_val = time_faaaa[h_i_offset];
+                    // RWKV v6: different time_decay for each token.
+                    float time_decay_val = time_decay[t_h_i_offset];
+
+                    for (int64_t j = 0; j < head_size; j++) {
+                        size_t t_h_j_offset = t_h_offset + j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                        float v_val = v[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        float temp_val = kv_val * time_faaaa_val + prev_state_val;
+                        dst_data[t_h_j_offset] += temp_val * r_val;
+                        state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
+                    }
+                }
+            }
+        }
+    #endif
     }

-static void ggml_compute_forward_rwkv_wkv(
+
+static void ggml_compute_forward_rwkv_wkv6(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {

@@ -11723,7 +11835,7 @@ static void ggml_compute_forward_rwkv_wkv(
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_rwkv_wkv_f32(params, dst);
+                ggml_compute_forward_rwkv_wkv6_f32(params, dst);
             } break;
         default:
             {

@@ -12475,9 +12587,9 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_add_rel_pos(params, tensor);
             } break;
-        case GGML_OP_RWKV_WKV:
+        case GGML_OP_RWKV_WKV6:
             {
-                ggml_compute_forward_rwkv_wkv(params, tensor);
+                ggml_compute_forward_rwkv_wkv6(params, tensor);
             } break;
         case GGML_OP_MAP_UNARY:
             {

@@ -12775,7 +12887,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_WIN_PART:
         case GGML_OP_WIN_UNPART:
         case GGML_OP_GET_REL_POS:
-        case GGML_OP_RWKV_WKV:
+        case GGML_OP_RWKV_WKV6:
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
         case GGML_OP_MAP_CUSTOM1_F32:
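
The commit bullet "rewrite to be more inline with the common pattern for distributing threads" refers to the h_start/h_end computation above: thread ith of nth now handles a contiguous range of heads instead of thread 0 doing all the work. A standalone sketch of that partitioning (helper and driver are illustrative, not ggml code):

    #include <assert.h>

    // Head range owned by thread `ith` out of `nth`, as in the kernel above.
    static void wkv6_head_range(int ith, int nth, int heads, int *h_start, int *h_end) {
        *h_start = (heads * ith) / nth;
        int end  = (heads * (ith + 1)) / nth;
        *h_end   = end < heads ? end : heads;
    }

    int main(void) {
        // Example: 32 heads over 6 threads; the ranges tile the heads exactly once.
        int covered = 0;
        for (int ith = 0; ith < 6; ith++) {
            int s, e;
            wkv6_head_range(ith, 6, 32, &s, &e);
            covered += e - s;
        }
        assert(covered == 32);
        return 0;
    }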
ggml/src/ggml-cuda.cu CHANGED
@@ -36,7 +36,7 @@
 #include "ggml-cuda/tsembd.cuh"
 #include "ggml-cuda/unary.cuh"
 #include "ggml-cuda/upscale.cuh"
-#include "ggml-cuda/rwkv-wkv.cuh"
+#include "ggml-cuda/wkv6.cuh"

 #include <algorithm>
 #include <array>

@@ -2319,8 +2319,8 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_CROSS_ENTROPY_LOSS:
             ggml_cuda_cross_entropy_loss(ctx, dst);
             break;
-        case GGML_OP_RWKV_WKV:
-            ggml_cuda_op_rwkv_wkv(ctx, dst);
+        case GGML_OP_RWKV_WKV6:
+            ggml_cuda_op_rwkv_wkv6(ctx, dst);
             break;
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
             ggml_cuda_cross_entropy_loss_back(ctx, dst);

@@ -3153,7 +3153,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_ARANGE:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_LEAKY_RELU:
-        case GGML_OP_RWKV_WKV:
+        case GGML_OP_RWKV_WKV6:
             return true;
         case GGML_OP_FLASH_ATTN_EXT: {
 #ifndef FLASH_ATTN_AVAILABLE
ggml/src/ggml-cuda/wkv6.cu ADDED
@@ -0,0 +1,89 @@
+#include "common.cuh"
+#include "wkv6.cuh"
+
+static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {
+    const int tid = threadIdx.x;
+    const int bid = blockIdx.x;
+
+    const int head_size = CUDA_WKV_BLOCK_SIZE;
+    const int batch_i = bid / H;
+    const int head_i = bid % H;
+    const int state_size = C * head_size;
+    const int n_seq_tokens = T / B;
+
+    float state[head_size];
+    __shared__ float _k[head_size], _r[head_size], _tf[head_size], _td[head_size];
+
+#pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
+    }
+
+    __syncthreads();
+    _tf[tid] = tf[head_i * head_size + tid];
+    __syncthreads();
+
+    for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
+        __syncthreads();
+        _k[tid] = k[t];
+        _r[tid] = r[t];
+        _td[tid] = td[t];
+        __syncthreads();
+
+        const float _v = v[t];
+        float y = 0;
+        for (int j = 0; j < head_size; j += 4) {
+            const float4& k = (float4&)(_k[j]);
+            const float4& r = (float4&)(_r[j]);
+            const float4& tf = (float4&)(_tf[j]);
+            const float4& td = (float4&)(_td[j]);
+            float4& s = (float4&)(state[j]);
+            float4 kv;
+
+            kv.x = k.x * _v;
+            kv.y = k.y * _v;
+            kv.z = k.z * _v;
+            kv.w = k.w * _v;
+
+            y += r.x * (tf.x * kv.x + s.x);
+            y += r.y * (tf.y * kv.y + s.y);
+            y += r.z * (tf.z * kv.z + s.z);
+            y += r.w * (tf.w * kv.w + s.w);
+
+            s.x = s.x * td.x + kv.x;
+            s.y = s.y * td.y + kv.y;
+            s.z = s.z * td.z + kv.z;
+            s.w = s.w * td.w + kv.w;
+        }
+        dst[t] = y;
+    }
+
+#pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
+    }
+}
+
+void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const float * k_d  = (const float *)dst->src[0]->data;
+    const float * v_d  = (const float *)dst->src[1]->data;
+    const float * r_d  = (const float *)dst->src[2]->data;
+    const float * tf_d = (const float *)dst->src[3]->data;
+    const float * td_d = (const float *)dst->src[4]->data;
+    const float * s_d  = (const float *)dst->src[5]->data;
+
+    const int64_t B = dst->src[5]->ne[1];
+    const int64_t T = dst->src[0]->ne[3];
+    const int64_t C = dst->ne[0];
+    const int64_t H = dst->src[0]->ne[2];
+
+    float * dst_d = (float *)dst->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
+    GGML_ASSERT(C % H == 0);
+    GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE); // The current cuda kernel is designed for RWKV6, HEAD_SIZE == 64
+
+    rwkv_wkv_f32<<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
+}
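
The launch on the last line maps one CUDA block to each (sequence, head) pair and one thread to each row of that head's state, which is why the kernel asserts that the head size equals CUDA_WKV_BLOCK_SIZE (64). A small host-side sketch of that geometry with hypothetical shapes (illustrative only):

    #include <assert.h>
    #include <stdio.h>

    int main(void) {
        const int C = 2048, H = 32, B = 4;   // hypothetical RWKV6 shapes
        const int head_size = C / H;         // must be 64 for this kernel
        assert(C % H == 0 && head_size == 64);
        const int grid_dim  = B * H;         // one block per (sequence, head) pair
        const int block_dim = head_size;     // one thread per state row
        printf("grid=%d block=%d\n", grid_dim, block_dim);
        return 0;
    }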
ggml/src/ggml-cuda/wkv6.cuh ADDED
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_WKV_BLOCK_SIZE 64
+
+void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
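
Both the CPU and CUDA paths write into a single dst buffer: the T*C per-token outputs come first, followed by one updated state block per sequence (state = dst->data + C*T on the CPU side, dst[T*C + ...] in the CUDA kernel). A tiny illustrative helper for that layout (hypothetical shapes, not ggml code):

    #include <stdio.h>
    #include <stddef.h>

    // Offset of sequence `seq`'s state block: outputs (C*T floats) come first,
    // then one C*head_size state block per sequence.
    static size_t wkv6_state_offset(size_t C, size_t T, size_t head_size, size_t seq) {
        return C * T + seq * C * head_size;
    }

    int main(void) {
        const size_t C = 2048, T = 8, head_size = 64;   // hypothetical shapes
        printf("state of sequence 1 starts at float %zu\n",
               wkv6_state_offset(C, T, head_size, 1));
        return 0;
    }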
ggml/src/ggml-sycl.cpp CHANGED
@@ -1194,272 +1194,8 @@ typedef void (*ggml_sycl_op_mul_mat_t)(
1194
  float *dst_dd_i, const int64_t row_low, const int64_t row_high,
1195
  const int64_t src1_ncols, const int64_t src1_padded_row_size,
1196
  const queue_ptr &stream);
1197
- typedef void (*ggml_sycl_op_flatten_t)(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
1198
- const ggml_tensor *src1,
1199
- ggml_tensor *dst, const float *src0_dd,
1200
- const float *src1_dd, float *dst_dd,
1201
- const queue_ptr &main_stream);
1202
-
1203
- static __dpct_inline__ float op_repeat(const float a, const float b) {
1204
- return b;
1205
- GGML_UNUSED(a);
1206
- }
1207
-
1208
- static __dpct_inline__ float op_add(const float a, const float b) {
1209
- return a + b;
1210
- }
1211
-
1212
- static __dpct_inline__ float op_mul(const float a, const float b) {
1213
- return a * b;
1214
- }
1215
-
1216
- static __dpct_inline__ float op_div(const float a, const float b) {
1217
- return a / b;
1218
- }
1219
-
1220
- template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
1221
- static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
1222
- int ne0, int ne1, int ne2, int ne3,
1223
- int ne10, int ne11, int ne12, int ne13,
1224
- /*int s0, */ int s1, int s2, int s3,
1225
- /*int s10,*/ int s11, int s12, int s13,
1226
- const sycl::nd_item<3> &item_ct1) {
1227
- const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
1228
- item_ct1.get_local_id(2);
1229
- const int i1 = (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
1230
- item_ct1.get_local_id(1));
1231
- const int i2 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
1232
- item_ct1.get_local_id(0)) /
1233
- ne3;
1234
- const int i3 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
1235
- item_ct1.get_local_id(0)) %
1236
- ne3;
1237
-
1238
- if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
1239
- return;
1240
- }
1241
-
1242
- const int i11 = i1 % ne11;
1243
- const int i12 = i2 % ne12;
1244
- const int i13 = i3 % ne13;
1245
-
1246
- const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
1247
- const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
1248
- const size_t i_dst = i_src0;
1249
-
1250
- const src0_t * src0_row = src0 + i_src0;
1251
- const src1_t * src1_row = src1 + i_src1;
1252
- dst_t * dst_row = dst + i_dst;
1253
-
1254
- for (int i0 = i0s; i0 < ne0;
1255
- i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) {
1256
- const int i10 = i0 % ne10;
1257
- dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
1258
- }
1259
- }
1260
 
1261
- template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
1262
- static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
1263
- int ne0, int ne1, int ne2, int ne3,
1264
- int ne10, int ne11, int ne12, int ne13,
1265
- /*int s0, */ int s1, int s2, int s3,
1266
- /*int s10,*/ int s11, int s12, int s13,
1267
- const sycl::nd_item<3> &item_ct1) {
1268
 
1269
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
1270
- item_ct1.get_local_id(2);
1271
-
1272
- const int i3 = i/(ne2*ne1*ne0);
1273
- const int i2 = (i/(ne1*ne0)) % ne2;
1274
- const int i1 = (i/ne0) % ne1;
1275
- const int i0 = i % ne0;
1276
-
1277
- if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
1278
- return;
1279
- }
1280
-
1281
- const int i11 = i1 % ne11;
1282
- const int i12 = i2 % ne12;
1283
- const int i13 = i3 % ne13;
1284
-
1285
- const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
1286
- const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
1287
- const size_t i_dst = i_src0;
1288
-
1289
- const src0_t * src0_row = src0 + i_src0;
1290
- const src1_t * src1_row = src1 + i_src1;
1291
- dst_t * dst_row = dst + i_dst;
1292
-
1293
- const int i10 = i0 % ne10;
1294
- dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
1295
- }
1296
-
1297
- static void acc_f32(const float * x, const float * y, float * dst, const int ne,
1298
- const int ne10, const int ne11, const int ne12,
1299
- const int nb1, const int nb2, int offset, const sycl::nd_item<3> &item_ct1) {
1300
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
1301
- item_ct1.get_local_id(2);
1302
- if (i >= ne) {
1303
- return;
1304
- }
1305
- int src1_idx = i - offset;
1306
- int oz = src1_idx / nb2;
1307
- int oy = (src1_idx - (oz * nb2)) / nb1;
1308
- int ox = src1_idx % nb1;
1309
- if (src1_idx >= 0 && ox < ne10 && oy < ne11 && oz < ne12) {
1310
- dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11];
1311
- } else {
1312
- dst[i] = x[i];
1313
- }
1314
- }
1315
-
1316
- static void gelu_f32(const float * x, float * dst, const int k,
1317
- const sycl::nd_item<3> &item_ct1) {
1318
- const float GELU_COEF_A = 0.044715f;
1319
- const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
1320
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
1321
- item_ct1.get_local_id(2);
1322
-
1323
- if (i >= k) {
1324
- return;
1325
- }
1326
-
1327
- float xi = x[i];
1328
- dst[i] = 0.5f * xi *
1329
- (1.0f +
1330
- sycl::tanh(SQRT_2_OVER_PI * xi * (1.0f + GELU_COEF_A * xi * xi)));
1331
- }
1332
-
1333
- static void silu_f32(const float * x, float * dst, const int k,
1334
- const sycl::nd_item<3> &item_ct1) {
1335
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
1336
- item_ct1.get_local_id(2);
1337
-
1338
- if (i >= k) {
1339
- return;
1340
- }
1341
- dst[i] = x[i] / (1.0f + sycl::native::exp(-x[i]));
1342
- }
1343
-
1344
- static void gelu_quick_f32(const float *x, float *dst, int k,
1345
- const sycl::nd_item<3> &item_ct1) {
1346
- const float GELU_QUICK_COEF = -1.702f;
1347
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
1348
- item_ct1.get_local_id(2);
1349
- if (i >= k) {
1350
- return;
1351
- }
1352
- dst[i] = x[i] * (1.0f / (1.0f + sycl::native::exp(GELU_QUICK_COEF * x[i])));
1353
- }
1354
-
1355
- static void tanh_f32(const float *x, float *dst, int k,
1356
- const sycl::nd_item<3> &item_ct1) {
1357
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
1358
- item_ct1.get_local_id(2);
1359
- if (i >= k) {
1360
- return;
1361
- }
1362
- dst[i] = sycl::tanh((float)(x[i]));
1363
- }
1364
-
1365
- static void relu_f32(const float * x, float * dst, const int k,
1366
- const sycl::nd_item<3> &item_ct1) {
1367
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
1368
- item_ct1.get_local_id(2);
1369
-
1370
- if (i >= k) {
1371
- return;
1372
- }
1373
- dst[i] = sycl::fmax((float)(x[i]), (float)0);
1374
- }
1375
-
1376
- static void hardsigmoid_f32(const float * x, float * dst, const int k,
1377
- const sycl::nd_item<3> &item_ct1) {
1378
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
1379
- item_ct1.get_local_id(2);
1380
-
1381
- if (i >= k) {
1382
- return;
1383
- }
1384
- dst[i] = sycl::fmin(1.0f, sycl::fmax(0.0f, (x[i] + 3.0f) / 6.0f));
1385
- }
1386
-
1387
- static void hardswish_f32(const float * x, float * dst, const int k,
1388
- const sycl::nd_item<3> &item_ct1) {
1389
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
1390
- item_ct1.get_local_id(2);
1391
-
1392
- if (i >= k) {
1393
- return;
1394
- }
1395
- dst[i] = x[i] * sycl::fmin(1.0f, sycl::fmax(0.0f, (x[i] + 3.0f) / 6.0f));
1396
- }
1397
-
1398
- static void leaky_relu_f32(const float *x, float *dst, const int k, const float negative_slope,
1399
- const sycl::nd_item<3> &item_ct1) {
1400
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
1401
- item_ct1.get_local_id(2);
1402
- if (i >= k) {
1403
- return;
1404
- }
1405
- dst[i] = sycl::fmax((float)(x[i]), (float)0) +
1406
- sycl::fmin((float)(x[i]), 0.0f) * negative_slope;
1407
- }
1408
-
1409
- static void sqr_f32(const float * x, float * dst, const int k,
1410
- const sycl::nd_item<3> &item_ct1) {
1411
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
1412
- item_ct1.get_local_id(2);
1413
-
1414
- if (i >= k) {
1415
- return;
1416
- }
1417
- dst[i] = x[i] * x[i];
1418
- }
1419
-
1420
- static void upscale_f32(const float *x, float *dst, const int nb00, const int nb01,
1421
- const int nb02, const int nb03, const int ne10, const int ne11,
1422
- const int ne12, const int ne13, const float sf0, const float sf1,
1423
- const float sf2, const float sf3, const sycl::nd_item<1> &item_ct1) {
1424
- int index = item_ct1.get_local_id(0) +
1425
- item_ct1.get_group(0) * item_ct1.get_local_range(0);
1426
- if (index >= ne10 * ne11 * ne12 * ne13) {
1427
- return;
1428
- }
1429
- // operation
1430
- int i10 = index % ne10;
1431
- int i11 = (index / ne10) % ne11;
1432
- int i12 = (index / (ne10 * ne11)) % ne12;
1433
- int i13 = (index / (ne10 * ne11 * ne12)) % ne13;
1434
-
1435
- int i00 = i10 / sf0;
1436
- int i01 = i11 / sf1;
1437
- int i02 = i12 / sf2;
1438
- int i03 = i13 / sf3;
1439
-
1440
- dst[index] = *(float *)((char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
1441
- }
1442
-
1443
- static void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02,
1444
- const sycl::nd_item<3> &item_ct1) {
1445
- int nidx = item_ct1.get_local_id(2) +
1446
- item_ct1.get_group(2) * item_ct1.get_local_range(2);
1447
- if (nidx >= ne0) {
1448
- return;
1449
- }
1450
-
1451
- // operation
1452
- int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
1453
- item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
1454
- if (nidx < ne00 && item_ct1.get_group(1) < ne01 &&
1455
- item_ct1.get_group(0) < ne02) {
1456
- int offset_src = nidx + item_ct1.get_group(1) * ne00 +
1457
- item_ct1.get_group(0) * ne00 * ne01;
1458
- dst[offset_dst] = x[offset_src];
1459
- } else {
1460
- dst[offset_dst] = 0.0f;
1461
- }
1462
- }
1463
 
1464
  template<int QUANT_BLOCK_TILE>
1465
  static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded,
@@ -2148,297 +1884,6 @@ static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tens
2148
  (void) dst;
2149
  }
2150
 
2151
- template<float (*bin_op)(const float, const float)>
2152
- struct bin_bcast_sycl {
2153
- template <typename src0_t, typename src1_t, typename dst_t>
2154
- void operator()(ggml_backend_sycl_context & ctx,
2155
- const struct ggml_tensor *src0,
2156
- const struct ggml_tensor *src1, struct ggml_tensor *dst,
2157
- const src0_t *src0_dd, const src1_t *src1_dd, dst_t *dst_dd,
2158
- queue_ptr stream) {
2159
-
2160
- GGML_TENSOR_BINARY_OP_LOCALS
2161
-
2162
- int nr0 = ne10/ne0;
2163
- int nr1 = ne11/ne1;
2164
- int nr2 = ne12/ne2;
2165
- int nr3 = ne13/ne3;
2166
-
2167
- int nr[4] = { nr0, nr1, nr2, nr3 };
2168
-
2169
- // collapse dimensions until first broadcast dimension
2170
- int64_t cne0[] = {ne0, ne1, ne2, ne3};
2171
- int64_t cne1[] = {ne10, ne11, ne12, ne13};
2172
- size_t cnb0[] = {nb0, nb1, nb2, nb3};
2173
- size_t cnb1[] = {nb10, nb11, nb12, nb13};
2174
- auto collapse = [](int64_t cne[]) {
2175
- cne[0] *= cne[1];
2176
- cne[1] = cne[2];
2177
- cne[2] = cne[3];
2178
- cne[3] = 1;
2179
- };
2180
-
2181
- auto collapse_nb = [](size_t cnb[], int64_t cne[]) {
2182
- cnb[1] *= cne[1];
2183
- cnb[2] *= cne[2];
2184
- cnb[3] *= cne[3];
2185
- };
2186
-
2187
- for (int i = 0; i < 4; i++) {
2188
- if (nr[i] != 1) {
2189
- break;
2190
- }
2191
- if (i > 0) {
2192
- collapse_nb(cnb0, cne0);
2193
- collapse_nb(cnb1, cne1);
2194
- collapse(cne0);
2195
- collapse(cne1);
2196
- }
2197
- }
2198
- {
2199
- int64_t ne0 = cne0[0];
2200
- int64_t ne1 = cne0[1];
2201
- int64_t ne2 = cne0[2];
2202
- int64_t ne3 = cne0[3];
2203
-
2204
- int64_t ne10 = cne1[0];
2205
- int64_t ne11 = cne1[1];
2206
- int64_t ne12 = cne1[2];
2207
- int64_t ne13 = cne1[3];
2208
-
2209
- size_t nb0 = cnb0[0];
2210
- size_t nb1 = cnb0[1];
2211
- size_t nb2 = cnb0[2];
2212
- size_t nb3 = cnb0[3];
2213
-
2214
- size_t nb10 = cnb1[0];
2215
- size_t nb11 = cnb1[1];
2216
- size_t nb12 = cnb1[2];
2217
- size_t nb13 = cnb1[3];
2218
-
2219
- size_t s0 = nb0 / sizeof(dst_t);
2220
- size_t s1 = nb1 / sizeof(dst_t);
2221
- size_t s2 = nb2 / sizeof(dst_t);
2222
- size_t s3 = nb3 / sizeof(dst_t);
2223
-
2224
- size_t s10 = nb10 / sizeof(src1_t);
2225
- size_t s11 = nb11 / sizeof(src1_t);
2226
- size_t s12 = nb12 / sizeof(src1_t);
2227
- size_t s13 = nb13 / sizeof(src1_t);
2228
-
2229
- GGML_ASSERT(s0 == 1);
2230
- GGML_ASSERT(s10 == 1);
2231
-
2232
- const int block_size = 128;
2233
-
2234
- int64_t hne0 = std::max(ne0/2LL, 1LL);
2235
-
2236
- sycl::range<3> block_dims(1, 1, 1);
2237
- block_dims[2] = std::min<unsigned int>(hne0, block_size);
2238
- block_dims[1] = std::min<unsigned int>(
2239
- ne1, block_size / (unsigned int)block_dims[2]);
2240
- block_dims[0] = std::min(
2241
- std::min<unsigned int>(
2242
- ne2 * ne3, block_size / (unsigned int)block_dims[2] /
2243
- (unsigned int)block_dims[1]),
2244
- 64U);
2245
-
2246
- sycl::range<3> block_nums(
2247
- (ne2 * ne3 + block_dims[0] - 1) / block_dims[0],
2248
- (ne1 + block_dims[1] - 1) / block_dims[1],
2249
- (hne0 + block_dims[2] - 1) / block_dims[2]);
2250
-
2251
- if (block_nums[0] > 65535) {
2252
- // this is the maximum number of blocks in z direction, fallback to 1D grid kernel
2253
- int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
2254
- {
2255
- dpct::has_capability_or_fail(stream->get_device(),
2256
- {sycl::aspect::fp16});
2257
-
2258
- stream->parallel_for(
2259
- sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) *
2260
- sycl::range<3>(1, 1, block_size),
2261
- sycl::range<3>(1, 1, block_size)),
2262
- [=](sycl::nd_item<3> item_ct1) {
2263
- k_bin_bcast_unravel<bin_op>(
2264
- src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3,
2265
- ne10, ne11, ne12, ne13, s1, s2, s3, s11, s12,
2266
- s13, item_ct1);
2267
- });
2268
- }
2269
- } else {
2270
- /*
2271
- DPCT1049:16: The work-group size passed to the SYCL kernel may
2272
- exceed the limit. To get the device limit, query
2273
- info::device::max_work_group_size. Adjust the work-group size if
2274
- needed.
2275
- */
2276
- dpct::has_capability_or_fail(stream->get_device(),
2277
- {sycl::aspect::fp16});
2278
-
2279
- stream->parallel_for(
2280
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
2281
- [=](sycl::nd_item<3> item_ct1) {
2282
- k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
2283
- ne2, ne3, ne10, ne11, ne12, ne13,
2284
- s1, s2, s3, s11, s12, s13,
2285
- item_ct1);
2286
- });
2287
- }
2288
- }
2289
- }
2290
- };
2291
-
2292
- static void acc_f32_sycl(const float *x, const float *y, float *dst,
2293
- const int n_elements, const int ne10, const int ne11,
2294
- const int ne12, const int nb1, const int nb2,
2295
- const int offset, queue_ptr stream) {
2296
- int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE;
2297
- stream->parallel_for(
2298
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
2299
- sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE),
2300
- sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)),
2301
- [=](sycl::nd_item<3> item_ct1) {
2302
- acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset,
2303
- item_ct1);
2304
- });
2305
- }
2306
-
2307
- static void gelu_f32_sycl(const float *x, float *dst, const int k,
2308
- queue_ptr stream) {
2309
- const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
2310
- stream->parallel_for(
2311
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
2312
- sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
2313
- sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
2314
- [=](sycl::nd_item<3> item_ct1) {
2315
- gelu_f32(x, dst, k, item_ct1);
2316
- });
2317
- }
2318
-
2319
- static void silu_f32_sycl(const float *x, float *dst, const int k,
2320
- queue_ptr stream) {
2321
- const int num_blocks = (k + SYCL_SILU_BLOCK_SIZE - 1) / SYCL_SILU_BLOCK_SIZE;
2322
- stream->parallel_for(
2323
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
2324
- sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE),
2325
- sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE)),
2326
- [=](sycl::nd_item<3> item_ct1) {
2327
- silu_f32(x, dst, k, item_ct1);
2328
- });
2329
- }
2330
-
2331
- static void gelu_quick_f32_sycl(const float *x, float *dst, const int k,
2332
- queue_ptr stream) {
2333
- const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
2334
- stream->parallel_for(
2335
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
2336
- sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
2337
- sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
2338
- [=](sycl::nd_item<3> item_ct1) {
2339
- gelu_quick_f32(x, dst, k, item_ct1);
2340
- });
2341
- }
2342
-
2343
- static void tanh_f32_sycl(const float *x, float *dst, const int k,
2344
- queue_ptr stream) {
2345
- const int num_blocks = (k + SYCL_TANH_BLOCK_SIZE - 1) / SYCL_TANH_BLOCK_SIZE;
2346
- stream->parallel_for(
2347
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
2348
- sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE),
2349
- sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE)),
2350
- [=](sycl::nd_item<3> item_ct1) {
2351
- tanh_f32(x, dst, k, item_ct1);
2352
- });
2353
- }
2354
-
2355
- static void relu_f32_sycl(const float *x, float *dst, const int k,
2356
- queue_ptr stream) {
2357
- const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
2358
- stream->parallel_for(
2359
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
2360
- sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE),
2361
- sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)),
2362
- [=](sycl::nd_item<3> item_ct1) {
2363
- relu_f32(x, dst, k, item_ct1);
2364
- });
2365
- }
2366
-
2367
- static void hardsigmoid_f32_sycl(const float *x, float *dst, const int k,
2368
- queue_ptr stream) {
2369
- const int num_blocks = (k + SYCL_HARDSIGMOID_BLOCK_SIZE - 1) / SYCL_HARDSIGMOID_BLOCK_SIZE;
2370
- stream->parallel_for(
2371
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
2372
- sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE),
2373
- sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE)),
2374
- [=](sycl::nd_item<3> item_ct1) {
2375
- hardsigmoid_f32(x, dst, k, item_ct1);
2376
- });
2377
- }
2378
-
2379
- static void hardswish_f32_sycl(const float *x, float *dst, const int k,
2380
- queue_ptr stream) {
2381
- const int num_blocks = (k + SYCL_HARDSWISH_BLOCK_SIZE - 1) / SYCL_HARDSWISH_BLOCK_SIZE;
2382
- stream->parallel_for(
2383
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
2384
- sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE),
2385
- sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE)),
2386
- [=](sycl::nd_item<3> item_ct1) {
2387
- hardswish_f32(x, dst, k, item_ct1);
2388
- });
2389
- }
2390
-
2391
- static void leaky_relu_f32_sycl(const float *x, float *dst, const int k,
2392
- const float negative_slope,
2393
- queue_ptr stream) {
2394
- const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
2395
- stream->parallel_for(
2396
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
2397
- sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE),
2398
- sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)),
2399
- [=](sycl::nd_item<3> item_ct1) {
2400
- leaky_relu_f32(x, dst, k, negative_slope, item_ct1);
2401
- });
2402
- }
2403
-
2404
- static void sqr_f32_sycl(const float *x, float *dst, const int k,
2405
- queue_ptr stream) {
2406
- const int num_blocks = (k + SYCL_SQR_BLOCK_SIZE - 1) / SYCL_SQR_BLOCK_SIZE;
2407
- stream->parallel_for(
2408
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
2409
- sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE),
2410
- sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE)),
2411
- [=](sycl::nd_item<3> item_ct1) {
2412
- sqr_f32(x, dst, k, item_ct1);
2413
- });
2414
- }
2415
-
2416
- static void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01,
2417
- const int nb02, const int nb03, const int ne10, const int ne11,
2418
- const int ne12, const int ne13, const float sf0, const float sf1,
2419
- const float sf2, const float sf3, queue_ptr stream) {
2420
- int dst_size = ne10 * ne11 * ne12 * ne13;
2421
- int num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE;
2422
- sycl::range<1> gridDim(num_blocks * SYCL_UPSCALE_BLOCK_SIZE);
2423
- stream->parallel_for(
2424
- sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)),
2425
- [=](sycl::nd_item<1> item_ct1) {
2426
- upscale_f32(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, item_ct1);
2427
- });
2428
- }
2429
-
2430
- static void pad_f32_sycl(const float *x, float *dst, const int ne00,
2431
- const int ne01, const int ne02, const int ne0,
2432
- const int ne1, const int ne2, queue_ptr stream) {
2433
- int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE;
2434
- sycl::range<3> gridDim(ne2, ne1, num_blocks);
2435
- stream->parallel_for(
2436
- sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),
2437
- sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)),
2438
- [=](sycl::nd_item<3> item_ct1) {
2439
- pad_f32(x, dst, ne0, ne00, ne01, ne02, item_ct1);
2440
- });
2441
- }
2442
 
2443
  static void quantize_row_q8_1_sycl(const float *x, void *vy, const int kx,
2444
  const int ky, const int kx_padded,
@@ -2816,6 +2261,58 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
2816
  }
2817
  }
2818
 
2819
  static void diag_mask_inf_f32_sycl(const float *x, float *dst,
2820
  const int ncols_x, const int nrows_x,
2821
  const int rows_per_channel, const int n_past,
@@ -2855,362 +2352,111 @@ static dpct::err0 ggml_sycl_cpy_tensor_2d(void *dst,
2855
  } else {
2856
  // GGML_SYCL_DEBUG("GGML_ABORT("fatal error")\n");
2857
  GGML_ABORT("fatal error");
2858
- }
2859
- char * dst_ptr = (char *) dst;
2860
-
2861
- GGML_TENSOR_LOCALS_1(int64_t, ne, src, ne);
2862
- GGML_TENSOR_LOCALS(int64_t, nb, src, nb);
2863
- const enum ggml_type type = src->type;
2864
- const int64_t ts = ggml_type_size(type);
2865
- const int64_t bs = ggml_blck_size(type);
2866
- int64_t i1_diff = i1_high - i1_low;
2867
-
2868
- const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;
2869
- if (nb0 == ts && nb1 == ts*ne0/bs) {
2870
- // GGML_SYCL_DEBUG("stream->memcpy: dst_ptr=%p, x=%p, size=%lu\n", dst_ptr, x, i1_diff * nb1);
2871
- // return CHECK_TRY_ERROR(stream->memcpy(dst_ptr, x, i1_diff * nb1));
2872
- return CHECK_TRY_ERROR(dpct::async_dpct_memcpy(dst_ptr, x, i1_diff * nb1,
2873
- kind, *stream));
2874
-
2875
- } else if (nb0 == ts) {
2876
- return CHECK_TRY_ERROR(
2877
- dpct::async_dpct_memcpy(dst_ptr, ts * ne0 / bs, x, nb1,
2878
- ts * ne0 / bs, i1_diff, kind, *stream));
2879
- } else {
2880
- for (int64_t i1 = 0; i1 < i1_diff; i1++) {
2881
- const void * rx = (const void *) ((const char *) x + i1*nb1);
2882
- void * rd = (void *) (dst_ptr + i1*ts*ne0/bs);
2883
- // pretend the row is a matrix with cols=1
2884
- dpct::err0 r = CHECK_TRY_ERROR(dpct::async_dpct_memcpy(
2885
- rd, ts / bs, rx, nb0, ts / bs, ne0, kind, *stream));
2886
- /*
2887
- DPCT1001:85: The statement could not be removed.
2888
- */
2889
- /*
2890
- DPCT1000:86: Error handling if-stmt was detected but could not be
2891
- rewritten.
2892
- */
2893
- if (r != 0) return r;
2894
- }
2895
- return 0;
2896
- }
2897
- }
2898
- catch (sycl::exception const &exc) {
2899
- std::cerr << exc.what() << "Exception caught at file:" << __FILE__
2900
- << ", line:" << __LINE__ << std::endl;
2901
- std::exit(1);
2902
- }
2903
-
2904
- static void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
2905
- const ggml_tensor *src1, ggml_tensor *dst,
2906
- const float *src0_d, const float *src1_d,
2907
- float *dst_d, const queue_ptr &stream) {
2908
-
2909
- GGML_ASSERT(src1->type == GGML_TYPE_I32);
2910
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
2911
-
2912
- GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
2913
- GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
2914
- GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
2915
-
2916
- const int32_t * src1_i32 = (const int32_t *) src1_d;
2917
-
2918
- switch (src0->type) {
2919
- case GGML_TYPE_F16:
2920
- get_rows_sycl_float(ctx, src0, src1, dst, (const sycl::half *)src0_d,
2921
- src1_i32, dst_d, stream);
2922
- break;
2923
- case GGML_TYPE_F32:
2924
- get_rows_sycl_float(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
2925
- break;
2926
- case GGML_TYPE_Q4_0:
2927
- get_rows_sycl<QK4_0, QR4_0, dequantize_q4_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
2928
- break;
2929
- case GGML_TYPE_Q4_1:
2930
- get_rows_sycl<QK4_1, QR4_1, dequantize_q4_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
2931
- break;
2932
- case GGML_TYPE_Q5_0:
2933
- get_rows_sycl<QK5_0, QR5_0, dequantize_q5_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
2934
- break;
2935
- case GGML_TYPE_Q5_1:
2936
- get_rows_sycl<QK5_1, QR5_1, dequantize_q5_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
2937
- break;
2938
- case GGML_TYPE_Q8_0:
2939
- get_rows_sycl<QK8_0, QR8_0, dequantize_q8_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
2940
- break;
2941
- default:
2942
- // TODO: k-quants
2943
- fprintf(stderr, "%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type));
2944
- GGML_ABORT("fatal error");
2945
- break;
2946
- }
2947
- }
2948
-
2949
- template <class op>
2950
- inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
2951
- const ggml_tensor *src1, ggml_tensor *dst,
2952
- const float *src0_dd, const float *src1_dd,
2953
- float *dst_dd,
2954
- const queue_ptr &main_stream) {
2955
-
2956
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
2957
- op()(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
2958
- } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
2959
- op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd,
2960
- (sycl::half *)dst_dd, main_stream);
2961
- } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
2962
- op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd, dst_dd,
2963
- main_stream);
2964
- } else if (src0->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {
2965
- op()(ctx, src0, src1, dst, (const int32_t *)src0_dd, (const int32_t *)src1_dd, (int32_t *)dst_dd,
2966
- main_stream);
2967
- } else if (src0->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) {
2968
- op()(ctx, src0, src1, dst, (const int16_t *)src0_dd, (const int16_t *)src1_dd, (int16_t *)dst_dd,
2969
- main_stream);
2970
- } else {
2971
- fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
2972
- ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
2973
- GGML_ABORT("fatal error");
2974
- }
2975
- }
2976
-
2977
- static void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
2978
- const ggml_tensor *src1, ggml_tensor *dst,
2979
- const float *src0_d, const float *src1_d,
2980
- float *dst_d,
2981
- const queue_ptr &main_stream) {
2982
-
2983
- ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_repeat>>(ctx, dst, src0, dst, nullptr, src0_d, dst_d, main_stream);
2984
-
2985
- (void) src1;
2986
- (void) src1_d;
2987
- }
2988
-
2989
- inline void ggml_sycl_op_add(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
2990
- ggml_tensor *dst, const float *src0_dd,
2991
- const float *src1_dd, float *dst_dd,
2992
- const queue_ptr &main_stream) {
2993
-
2994
- ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_add>>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
2995
- }
2996
-
2997
- inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
2998
- ggml_tensor *dst, const float *src0_dd,
2999
- const float *src1_dd, float *dst_dd,
3000
- const queue_ptr &main_stream) {
3001
-
3002
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
3003
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
3004
- GGML_ASSERT( dst->type == GGML_TYPE_F32);
3005
- GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported
3006
-
3007
- int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
3008
- int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
3009
- // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
3010
- int offset = dst->op_params[3] / 4; // offset in bytes
3011
-
3012
- acc_f32_sycl(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], nb1, nb2, offset, main_stream);
3013
-
3014
- (void) dst;
3015
- }
3016
-
3017
- inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
3018
- ggml_tensor *dst, const float *src0_dd,
3019
- const float *src1_dd, float *dst_dd,
3020
- const queue_ptr &main_stream) {
3021
-
3022
- ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_mul>>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
3023
- }
3024
-
3025
- inline void ggml_sycl_op_div(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
3026
- ggml_tensor *dst, const float *src0_dd,
3027
- const float *src1_dd, float *dst_dd,
3028
- const queue_ptr &main_stream) {
3029
-
3030
- ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_div>>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
3031
- }
3032
-
3033
- inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
3034
- ggml_tensor *dst, const float *src0_dd,
3035
- const float *src1_dd, float *dst_dd,
3036
- const queue_ptr &main_stream) {
3037
-
3038
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
3039
- GGML_ASSERT( dst->type == GGML_TYPE_F32);
3040
-
3041
- gelu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
3042
-
3043
- (void) src1;
3044
- (void) dst;
3045
- (void) src1_dd;
3046
- }
3047
-
3048
- inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
3049
- ggml_tensor *dst, const float *src0_dd,
3050
- const float *src1_dd, float *dst_dd,
3051
- const queue_ptr &main_stream) {
3052
-
3053
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
3054
- GGML_ASSERT( dst->type == GGML_TYPE_F32);
3055
-
3056
- silu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
3057
-
3058
- (void) src1;
3059
- (void) dst;
3060
- (void) src1_dd;
3061
- }
3062
-
3063
- inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
3064
- const ggml_tensor *src1, ggml_tensor *dst,
3065
- const float *src0_dd, const float *src1_dd,
3066
- float *dst_dd,
3067
- const queue_ptr &main_stream) {
3068
-
3069
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
3070
- GGML_ASSERT( dst->type == GGML_TYPE_F32);
3071
-
3072
- gelu_quick_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
3073
-
3074
- (void) src1;
3075
- (void) dst;
3076
- (void) src1_dd;
3077
- }
3078
-
3079
- inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
3080
- ggml_tensor *dst, const float *src0_dd,
3081
- const float *src1_dd, float *dst_dd,
3082
- const queue_ptr &main_stream) {
3083
-
3084
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
3085
- GGML_ASSERT( dst->type == GGML_TYPE_F32);
3086
- tanh_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
3087
-
3088
- (void) src1;
3089
- (void) dst;
3090
- (void) src1_dd;
3091
- }
3092
-
3093
- inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
3094
- ggml_tensor *dst, const float *src0_dd,
3095
- const float *src1_dd, float *dst_dd,
3096
- const queue_ptr &main_stream) {
3097
-
3098
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
3099
- GGML_ASSERT( dst->type == GGML_TYPE_F32);
3100
-
3101
- relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
3102
-
3103
- (void) src1;
3104
- (void) dst;
3105
- (void) src1_dd;
3106
- }
3107
-
3108
- static void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
3109
- const ggml_tensor *src1, ggml_tensor *dst,
3110
- const float *src0_dd, const float *src1_dd,
3111
- float *dst_dd,
3112
- const queue_ptr &main_stream) {
3113
-
3114
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
3115
- GGML_ASSERT( dst->type == GGML_TYPE_F32);
3116
-
3117
- hardsigmoid_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
3118
-
3119
- (void) src1;
3120
- (void) dst;
3121
- (void) src1_dd;
3122
- }
3123
-
3124
- static void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
3125
- const ggml_tensor *src1, ggml_tensor *dst,
3126
- const float *src0_dd, const float *src1_dd,
3127
- float *dst_dd, const queue_ptr &main_stream) {
3128
-
3129
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
3130
- GGML_ASSERT( dst->type == GGML_TYPE_F32);
3131
-
3132
- hardswish_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
3133
-
3134
- (void) src1;
3135
- (void) dst;
3136
- (void) src1_dd;
3137
- }
3138
-
3139
- inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
3140
- const ggml_tensor *src1, ggml_tensor *dst,
3141
- const float *src0_dd, const float *src1_dd,
3142
- float *dst_dd,
3143
- const queue_ptr &main_stream) {
3144
-
3145
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
3146
- GGML_ASSERT( dst->type == GGML_TYPE_F32);
3147
 
3148
- float negative_slope;
3149
- memcpy(&negative_slope, dst->op_params, sizeof(float));
 
 
 
 
3150
 
3151
- leaky_relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), negative_slope, main_stream);
 
 
 
 
 
3152
 
3153
- (void) src1;
3154
- (void) dst;
3155
- (void) src1_dd;
3156
  }
3157
-
3158
- inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
3159
- ggml_tensor *dst, const float *src0_dd,
3160
- const float *src1_dd, float *dst_dd,
3161
- const queue_ptr &main_stream) {
3162
-
3163
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
3164
- GGML_ASSERT( dst->type == GGML_TYPE_F32);
3165
-
3166
- sqr_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
3167
-
3168
- (void) src1;
3169
- (void) dst;
3170
- (void) src1_dd;
3171
  }
3172
 
3173
- inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
3174
- const ggml_tensor *src1, ggml_tensor *dst,
3175
- const float *src0_dd, const float *src1_dd,
3176
- float *dst_dd,
3177
- const queue_ptr &main_stream) {
3178
 
3179
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
3180
  GGML_ASSERT(dst->type == GGML_TYPE_F32);
3181
 
3182
- const float sf0 = (float)dst->ne[0]/src0->ne[0];
3183
- const float sf1 = (float)dst->ne[1]/src0->ne[1];
3184
- const float sf2 = (float)dst->ne[2]/src0->ne[2];
3185
- const float sf3 = (float)dst->ne[3]/src0->ne[3];
3186
 
3187
- upscale_f32_sycl(src0_dd, dst_dd, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
3188
- dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3,
3189
- main_stream);
3190
 
3191
- (void) src1;
3192
- (void) dst;
3193
- (void) src1_dd;
3194
  }
3195
 
3196
- inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
3197
- ggml_tensor *dst, const float *src0_dd,
3198
- const float *src1_dd, float *dst_dd,
3199
- const queue_ptr &main_stream) {
3200
 
3201
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
3202
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
3203
- GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
 
 
3204
 
3205
- pad_f32_sycl(src0_dd, dst_dd,
3206
- src0->ne[0], src0->ne[1], src0->ne[2],
3207
- dst->ne[0], dst->ne[1], dst->ne[2], main_stream);
3208
 
3209
  (void) src1;
3210
- (void) dst;
3211
- (void) src1_dd;
3212
  }
3213
 
 
3214
  inline void ggml_sycl_op_mul_mat_sycl(
3215
  ggml_backend_sycl_context & ctx,
3216
  const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
@@ -3379,6 +2625,23 @@ static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, const ggml_tens
3379
  (void) src1_dd;
3380
  }
3381
 
3382
  inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
3383
  const ggml_tensor *src1, ggml_tensor *dst,
3384
  const float *src0_dd, const float *src1_dd,
@@ -3419,6 +2682,25 @@ inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, const ggml_ten
3419
  (void) src1_dd;
3420
  }
3421
 
3422
  inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
3423
  const ggml_tensor *src1,
3424
  ggml_tensor *dst, const float *src0_dd,
@@ -3489,46 +2771,6 @@ inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, const ggml_tenso
3489
  (void) src1_dd;
3490
  }
3491
 
3492
- static void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
3493
- const ggml_tensor *src1, ggml_tensor *dst,
3494
- const ggml_sycl_op_flatten_t op) try {
3495
- const int64_t nrows0 = ggml_nrows(src0);
3496
-
3497
- const bool use_src1 = src1 != nullptr;
3498
- const int64_t nrows1 = use_src1 ? ggml_nrows(src1) : 1;
3499
-
3500
- GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
3501
- GGML_ASSERT( dst->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
3502
-
3503
- ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
3504
- ggml_tensor_extra_gpu * src1_extra = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr;
3505
- ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
3506
-
3507
- // dd = data device
3508
- float * src0_ddf = (float *) src0->data;
3509
- float * src1_ddf = use_src1 ? (float *) src1->data : nullptr;
3510
- float * dst_ddf = (float *) dst->data;
3511
-
3512
- ggml_sycl_pool_alloc<float> src0_f(ctx.pool());
3513
- ggml_sycl_pool_alloc<float> src1_f(ctx.pool());
3514
- ggml_sycl_pool_alloc<float> dst_f(ctx.pool());
3515
-
3516
- ggml_sycl_set_device(ctx.device);
3517
- queue_ptr main_stream = ctx.stream();
3518
- // GGML_SYCL_DEBUG("ctx.device=%d, main_stream=%p src0_on_device=%d, src1_on_device=%d, dst_on_device=%d\n",
3519
- // ctx.device, main_stream, src0_on_device, src1_on_device, dst_on_device);
3520
-
3521
- // do the computation
3522
- op(ctx, src0, src1, dst, src0_ddf, src1_ddf, dst_ddf, main_stream);
3523
- // print_ggml_tensor("tensor", dst);
3524
- }
3525
- catch (sycl::exception const &exc) {
3526
-
3527
- std::cerr << exc.what() << "Exception caught at file:" << __FILE__
3528
- << ", line:" << __LINE__ << std::endl;
3529
- std::exit(1);
3530
- }
3531
-
3532
  static void ggml_sycl_set_peer_access(const int n_tokens, int main_device) {
3533
  static bool peer_access_enabled = false;
3534
 
@@ -3908,112 +3150,21 @@ static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, const ggml_tenso
3908
  GGML_SYCL_DEBUG("call %s done\n", __func__);
3909
  }
3910
 
3911
- static void ggml_sycl_add(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3912
- GGML_SYCL_DEBUG("call %s\n", __func__);
3913
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_add);
3914
- GGML_SYCL_DEBUG("call %s done\n", __func__);
3915
- }
3916
-
3917
- static void ggml_sycl_acc(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3918
- GGML_SYCL_DEBUG("call %s\n", __func__);
3919
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_acc);
3920
- GGML_SYCL_DEBUG("call %s done\n", __func__);
3921
- }
3922
-
3923
- static void ggml_sycl_mul(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3924
- GGML_SYCL_DEBUG("call %s\n", __func__);
3925
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_mul);
3926
- GGML_SYCL_DEBUG("call %s done\n", __func__);
3927
- }
3928
-
3929
- static void ggml_sycl_div(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3930
- GGML_SYCL_DEBUG("call %s\n", __func__);
3931
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_div);
3932
- GGML_SYCL_DEBUG("call %s done\n", __func__);
3933
- }
3934
-
3935
- static void ggml_sycl_gelu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3936
- GGML_SYCL_DEBUG("call %s\n", __func__);
3937
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_gelu);
3938
- GGML_SYCL_DEBUG("call %s done\n", __func__);
3939
- }
3940
-
3941
- static void ggml_sycl_silu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3942
- GGML_SYCL_DEBUG("call %s\n", __func__);
3943
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_silu);
3944
- GGML_SYCL_DEBUG("call %s done\n", __func__);
3945
- }
3946
-
3947
- static void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3948
- GGML_SYCL_DEBUG("call %s\n", __func__);
3949
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_gelu_quick);
3950
- GGML_SYCL_DEBUG("call %s done\n", __func__);
3951
- }
3952
-
3953
- static void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3954
- GGML_SYCL_DEBUG("call %s\n", __func__);
3955
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_tanh);
3956
- GGML_SYCL_DEBUG("call %s done\n", __func__);
3957
- }
3958
-
3959
- static void ggml_sycl_relu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3960
- GGML_SYCL_DEBUG("call %s\n", __func__);
3961
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_relu);
3962
- GGML_SYCL_DEBUG("call %s done\n", __func__);
3963
- }
3964
-
3965
- static void ggml_sycl_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3966
- GGML_SYCL_DEBUG("call %s\n", __func__);
3967
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_hardsigmoid);
3968
- GGML_SYCL_DEBUG("call %s done\n", __func__);
3969
- }
3970
-
3971
- static void ggml_sycl_hardswish(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3972
- GGML_SYCL_DEBUG("call %s\n", __func__);
3973
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_hardswish);
3974
- GGML_SYCL_DEBUG("call %s done\n", __func__);
3975
- }
3976
-
3977
- static void ggml_sycl_leaky_relu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3978
- GGML_SYCL_DEBUG("call %s\n", __func__);
3979
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_leaky_relu);
3980
- GGML_SYCL_DEBUG("call %s done\n", __func__);
3981
- }
3982
-
3983
- static void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3984
- GGML_SYCL_DEBUG("call %s\n", __func__);
3985
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sqr);
3986
- GGML_SYCL_DEBUG("call %s done\n", __func__);
3987
- }
3988
-
3989
  static void ggml_sycl_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3990
  GGML_SYCL_DEBUG("call %s\n", __func__);
3991
  ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_norm);
3992
  GGML_SYCL_DEBUG("call %s done\n", __func__);
3993
  }
3994
 
3995
- static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3996
- GGML_SYCL_DEBUG("call %s\n", __func__);
3997
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_group_norm);
3998
- GGML_SYCL_DEBUG("call %s done\n", __func__);
3999
- }
4000
-
4001
- static void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
4002
- GGML_SYCL_DEBUG("call %s\n", __func__);
4003
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_upscale);
4004
- GGML_SYCL_DEBUG("call %s done\n", __func__);
4005
- }
4006
-
4007
- static void ggml_sycl_pad(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
4008
  GGML_SYCL_DEBUG("call %s\n", __func__);
4009
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_pad);
4010
  GGML_SYCL_DEBUG("call %s done\n", __func__);
4011
  }
4012
 
4013
-
4014
- static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
4015
  GGML_SYCL_DEBUG("call %s\n", __func__);
4016
- ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_rms_norm);
4017
  GGML_SYCL_DEBUG("call %s done\n", __func__);
4018
  }
4019
 
@@ -4632,6 +3783,11 @@ static void ggml_sycl_im2col(ggml_backend_sycl_context & ctx, const ggml_tensor
4632
  ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_im2col);
4633
  }
4634
 
 
4635
  static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
4636
  GGML_ASSERT(ggml_is_contiguous(src0));
4637
  ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sum_rows);
@@ -4642,6 +3798,11 @@ static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, const ggml_tensor
4642
  ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_argsort);
4643
  }
4644
 
 
4645
  static void ggml_sycl_nop(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
4646
  (void) src0;
4647
  (void) src1;
@@ -4673,6 +3834,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
4673
  ggml_sycl_func_t func;
4674
 
4675
  switch (tensor->op) {
 
 
 
4676
  case GGML_OP_CONV_TRANSPOSE_1D:
4677
  func = ggml_sycl_op_conv_transpose_1d;
4678
  break;
@@ -4686,19 +3850,32 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
4686
  func = ggml_sycl_dup;
4687
  break;
4688
  case GGML_OP_ADD:
 
4689
  func = ggml_sycl_add;
4690
  break;
 
 
 
4691
  case GGML_OP_ACC:
4692
  func = ggml_sycl_acc;
4693
  break;
4694
  case GGML_OP_MUL:
4695
  func = ggml_sycl_mul;
4696
  break;
 
 
 
4697
  case GGML_OP_DIV:
4698
  func = ggml_sycl_div;
4699
  break;
4700
  case GGML_OP_UNARY:
4701
  switch (ggml_get_unary_op(tensor)) {
 
4702
  case GGML_UNARY_OP_GELU:
4703
  func = ggml_sycl_gelu;
4704
  break;
@@ -4714,12 +3891,18 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
4714
  case GGML_UNARY_OP_RELU:
4715
  func = ggml_sycl_relu;
4716
  break;
 
 
 
4717
  case GGML_UNARY_OP_HARDSIGMOID:
4718
  func = ggml_sycl_hardsigmoid;
4719
  break;
4720
  case GGML_UNARY_OP_HARDSWISH:
4721
  func = ggml_sycl_hardswish;
4722
  break;
 
 
 
4723
  default:
4724
  return false;
4725
  }
@@ -4757,12 +3940,24 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
4757
  }
4758
  func = ggml_sycl_mul_mat_id;
4759
  break;
 
 
 
4760
  case GGML_OP_SCALE:
4761
  func = ggml_sycl_scale;
4762
  break;
4763
  case GGML_OP_SQR:
4764
  func = ggml_sycl_sqr;
4765
  break;
 
4766
  case GGML_OP_CLAMP:
4767
  func = ggml_sycl_clamp;
4768
  break;
@@ -4794,6 +3989,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
4794
  case GGML_OP_POOL_2D:
4795
  func = ggml_sycl_pool2d;
4796
  break;
 
 
 
4797
  case GGML_OP_SUM_ROWS:
4798
  func = ggml_sycl_sum_rows;
4799
  break;
@@ -4803,6 +4001,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
4803
  case GGML_OP_TIMESTEP_EMBEDDING:
4804
  func = ggml_sycl_op_timestep_embedding;
4805
  break;
 
 
 
4806
  default:
4807
  return false;
4808
  }
@@ -5125,13 +4326,17 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
5125
  } break;
5126
  case GGML_OP_UNARY:
5127
  switch (ggml_get_unary_op(op)) {
 
 
5128
  case GGML_UNARY_OP_GELU:
5129
  case GGML_UNARY_OP_SILU:
5130
  case GGML_UNARY_OP_RELU:
 
5131
  case GGML_UNARY_OP_HARDSIGMOID:
5132
  case GGML_UNARY_OP_HARDSWISH:
5133
  case GGML_UNARY_OP_GELU_QUICK:
5134
  case GGML_UNARY_OP_TANH:
 
5135
  return ggml_is_contiguous(op->src[0]);
5136
  default:
5137
  return false;
@@ -5168,6 +4373,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
5168
  }
5169
  return true;
5170
  } break;
 
 
5171
  case GGML_OP_GET_ROWS:
5172
  {
5173
  switch (op->src[0]->type) {
@@ -5213,10 +4420,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
5213
  case GGML_OP_CONCAT:
5214
  {
5215
  ggml_type src0_type = op->src[0]->type;
5216
- int dim = op->op_params[0];
5217
- return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) && src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16 && dim == 2;
5218
  } break;
5219
  case GGML_OP_DUP:
 
5220
  case GGML_OP_NONE:
5221
  case GGML_OP_RESHAPE:
5222
  case GGML_OP_REPEAT:
@@ -5225,11 +4432,17 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
5225
  case GGML_OP_TRANSPOSE:
5226
  case GGML_OP_NORM:
5227
  case GGML_OP_ADD:
 
 
 
5228
  case GGML_OP_MUL:
5229
  case GGML_OP_DIV:
5230
  case GGML_OP_RMS_NORM:
5231
  case GGML_OP_SCALE:
5232
  case GGML_OP_SQR:
 
 
 
5233
  case GGML_OP_CLAMP:
5234
  return true;
5235
  case GGML_OP_CONT:
@@ -5243,6 +4456,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
5243
  // TODO: add support for the new F32 operations
5244
  return op->src[0]->type == GGML_TYPE_F16;
5245
  case GGML_OP_POOL_2D:
 
5246
  case GGML_OP_SUM_ROWS:
5247
  case GGML_OP_ARGSORT:
5248
  case GGML_OP_ACC:
@@ -5251,6 +4465,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
5251
  case GGML_OP_PAD:
5252
  case GGML_OP_LEAKY_RELU:
5253
  case GGML_OP_TIMESTEP_EMBEDDING:
 
5254
  return true;
5255
  default:
5256
  return false;
@@ -5268,9 +4483,23 @@ static bool ggml_backend_sycl_device_supports_buft(ggml_backend_dev_t dev, ggml_
5268
  return buft_ctx->device == sycl_ctx->device;
5269
  }
5270
 
5271
  static bool ggml_backend_sycl_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
5272
  const int min_batch_size = 32;
5273
- return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS && op->op != GGML_OP_MUL_MAT_ID;
5274
  GGML_UNUSED(dev);
5275
  }
5276
 
 
1194
  float *dst_dd_i, const int64_t row_low, const int64_t row_high,
1195
  const int64_t src1_ncols, const int64_t src1_padded_row_size,
1196
  const queue_ptr &stream);
 
1197
 
1198
 
1199
 
1200
  template<int QUANT_BLOCK_TILE>
1201
  static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded,
 
1884
  (void) dst;
1885
  }
1886
 
1887
 
1888
  static void quantize_row_q8_1_sycl(const float *x, void *vy, const int kx,
1889
  const int ky, const int kx_padded,
 
2261
  }
2262
  }
2263
 
2264
+ static void argmax_f32_i32_sycl(const float *x, int *dst, const int ncols,
2265
+ const int nrows, queue_ptr stream) {
2266
+ const sycl::range<3> block_dims(1, 1, SYCL_ARGMAX_BLOCK_SIZE);
2267
+ const sycl::range<3> block_nums(1, nrows, 1);
2268
+ const size_t shared_mem = 256 * sizeof(float);
2269
+
2270
+ stream->submit([&](sycl::handler &cgh) {
2271
+ sycl::local_accessor<float, 1> shared_data(
2272
+ sycl::range<1>(shared_mem/sizeof(float)), cgh);
2273
+ sycl::local_accessor<int, 1> shared_indices(
2274
+ sycl::range<1>(shared_mem/sizeof(float)), cgh);
2275
+
2276
+ cgh.parallel_for(
2277
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
2278
+ [=](sycl::nd_item<3> item_ct1) {
2279
+ const int tid = item_ct1.get_local_id(2);
2280
+ const int row = item_ct1.get_global_id(1);
2281
+
2282
+ float max_val = -INFINITY;
2283
+ int max_idx = -1;
2284
+
2285
+ for (int col = tid; col < ncols; col += 256) {
2286
+ float val = x[row * ncols + col];
2287
+ if (val > max_val) {
2288
+ max_val = val;
2289
+ max_idx = col;
2290
+ }
2291
+ }
2292
+
2293
+ shared_data[tid] = max_val;
2294
+ shared_indices[tid] = max_idx;
2295
+ item_ct1.barrier(sycl::access::fence_space::local_space);
2296
+
2297
+ for (int stride = 256/2; stride > 0; stride >>= 1) {
2298
+ if (tid < stride) {
2299
+ float val1 = shared_data[tid];
2300
+ float val2 = shared_data[tid + stride];
2301
+ if (val2 > val1) {
2302
+ shared_data[tid] = val2;
2303
+ shared_indices[tid] = shared_indices[tid + stride];
2304
+ }
2305
+ }
2306
+ item_ct1.barrier(sycl::access::fence_space::local_space);
2307
+ }
2308
+
2309
+
2310
+ if (tid == 0) {
2311
+ dst[row] = shared_indices[0];
2312
+ }
2313
+ });
2314
+ });
2315
+ }
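
The added argmax kernel is a standard two-phase reduction: each work-item first scans a strided slice of its row, then the per-item (value, index) pairs are combined in local memory with a tree reduction. A minimal host-side C++ sketch of the same scheme, for reference only (argmax_row_reference is a hypothetical name, not part of this commit):

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Serial reference of the per-row argmax reduction used by argmax_f32_i32_sycl:
    // phase 1 emulates the strided per-work-item scan, phase 2 the local-memory tree.
    static int argmax_row_reference(const float * x, int ncols, int block_size = 256) {
        std::vector<float> vals(block_size, -INFINITY);
        std::vector<int>   idxs(block_size, -1);

        for (int tid = 0; tid < block_size; ++tid) {                   // strided scan
            for (int col = tid; col < ncols; col += block_size) {
                if (x[col] > vals[tid]) { vals[tid] = x[col]; idxs[tid] = col; }
            }
        }
        for (int stride = block_size / 2; stride > 0; stride >>= 1) {  // tree reduction
            for (int tid = 0; tid < stride; ++tid) {
                if (vals[tid + stride] > vals[tid]) {
                    vals[tid] = vals[tid + stride];
                    idxs[tid] = idxs[tid + stride];
                }
            }
        }
        return idxs[0];
    }

    int main() {
        std::vector<float> row = {0.1f, -2.0f, 3.5f, 3.4f, 0.0f};
        printf("argmax = %d\n", argmax_row_reference(row.data(), (int) row.size()));  // prints 2
        return 0;
    }
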
2316
  static void diag_mask_inf_f32_sycl(const float *x, float *dst,
2317
  const int ncols_x, const int nrows_x,
2318
  const int rows_per_channel, const int n_past,
 
2352
  } else {
2353
  // GGML_SYCL_DEBUG("GGML_ABORT("fatal error")\n");
2354
  GGML_ABORT("fatal error");
2355
+ }
2356
+ char * dst_ptr = (char *) dst;
 
2357
 
2358
+ GGML_TENSOR_LOCALS_1(int64_t, ne, src, ne);
2359
+ GGML_TENSOR_LOCALS(int64_t, nb, src, nb);
2360
+ const enum ggml_type type = src->type;
2361
+ const int64_t ts = ggml_type_size(type);
2362
+ const int64_t bs = ggml_blck_size(type);
2363
+ int64_t i1_diff = i1_high - i1_low;
2364
 
2365
+ const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;
2366
+ if (nb0 == ts && nb1 == ts*ne0/bs) {
2367
+ // GGML_SYCL_DEBUG("stream->memcpy: dst_ptr=%p, x=%p, size=%lu\n", dst_ptr, x, i1_diff * nb1);
2368
+ // return CHECK_TRY_ERROR(stream->memcpy(dst_ptr, x, i1_diff * nb1));
2369
+ return CHECK_TRY_ERROR(dpct::async_dpct_memcpy(dst_ptr, x, i1_diff * nb1,
2370
+ kind, *stream));
2371
 
2372
+ } else if (nb0 == ts) {
2373
+ return CHECK_TRY_ERROR(
2374
+ dpct::async_dpct_memcpy(dst_ptr, ts * ne0 / bs, x, nb1,
2375
+ ts * ne0 / bs, i1_diff, kind, *stream));
2376
+ } else {
2377
+ for (int64_t i1 = 0; i1 < i1_diff; i1++) {
2378
+ const void * rx = (const void *) ((const char *) x + i1*nb1);
2379
+ void * rd = (void *) (dst_ptr + i1*ts*ne0/bs);
2380
+ // pretend the row is a matrix with cols=1
2381
+ dpct::err0 r = CHECK_TRY_ERROR(dpct::async_dpct_memcpy(
2382
+ rd, ts / bs, rx, nb0, ts / bs, ne0, kind, *stream));
2383
+ /*
2384
+ DPCT1001:85: The statement could not be removed.
2385
+ */
2386
+ /*
2387
+ DPCT1000:86: Error handling if-stmt was detected but could not be
2388
+ rewritten.
2389
+ */
2390
+ if (r != 0) return r;
2391
+ }
2392
+ return 0;
2393
+ }
2394
  }
2395
+ catch (sycl::exception const &exc) {
2396
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
2397
+ << ", line:" << __LINE__ << std::endl;
2398
+ std::exit(1);
 
2399
  }
2400
 
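The copy helper above chooses between three paths based on the source strides: a single flat copy when the tensor is fully contiguous, a pitched row-by-row copy when only the rows are contiguous, and a per-element-block copy otherwise. A hedged host-side sketch of that decision using plain memcpy (copy_rows_reference is a made-up name; the real code uses dpct::async_dpct_memcpy):

    #include <cstdint>
    #include <cstring>

    // ts = type size, bs = block size of the type, ne0 = elements per row,
    // nb0/nb1 = element/row strides in bytes, nrows = number of rows to copy.
    static void copy_rows_reference(char * dst, const char * src,
                                    int64_t ne0, int64_t nrows,
                                    int64_t nb0, int64_t nb1,
                                    int64_t ts, int64_t bs) {
        const int64_t row_bytes = ts * ne0 / bs;
        if (nb0 == ts && nb1 == row_bytes) {
            memcpy(dst, src, nrows * nb1);                       // fully contiguous
        } else if (nb0 == ts) {
            for (int64_t i1 = 0; i1 < nrows; ++i1) {             // contiguous rows, padded stride
                memcpy(dst + i1 * row_bytes, src + i1 * nb1, row_bytes);
            }
        } else {
            for (int64_t i1 = 0; i1 < nrows; ++i1) {             // fully strided
                for (int64_t i0 = 0; i0 < ne0; ++i0) {
                    memcpy(dst + i1 * row_bytes + i0 * (ts / bs),
                           src + i1 * nb1      + i0 * nb0, ts / bs);
                }
            }
        }
    }
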
2401
+ static void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
2402
+ const ggml_tensor *src1, ggml_tensor *dst,
2403
+ const float *src0_d, const float *src1_d,
2404
+ float *dst_d, const queue_ptr &stream) {
 
2405
 
2406
+ GGML_ASSERT(src1->type == GGML_TYPE_I32);
2407
  GGML_ASSERT(dst->type == GGML_TYPE_F32);
2408
 
2409
+ GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
2410
+ GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
2411
+ GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
 
2412
 
2413
+ const int32_t * src1_i32 = (const int32_t *) src1_d;
 
 
2414
 
2415
+ switch (src0->type) {
2416
+ case GGML_TYPE_F16:
2417
+ get_rows_sycl_float(ctx, src0, src1, dst, (const sycl::half *)src0_d,
2418
+ src1_i32, dst_d, stream);
2419
+ break;
2420
+ case GGML_TYPE_F32:
2421
+ get_rows_sycl_float(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
2422
+ break;
2423
+ case GGML_TYPE_Q4_0:
2424
+ get_rows_sycl<QK4_0, QR4_0, dequantize_q4_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
2425
+ break;
2426
+ case GGML_TYPE_Q4_1:
2427
+ get_rows_sycl<QK4_1, QR4_1, dequantize_q4_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
2428
+ break;
2429
+ case GGML_TYPE_Q5_0:
2430
+ get_rows_sycl<QK5_0, QR5_0, dequantize_q5_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
2431
+ break;
2432
+ case GGML_TYPE_Q5_1:
2433
+ get_rows_sycl<QK5_1, QR5_1, dequantize_q5_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
2434
+ break;
2435
+ case GGML_TYPE_Q8_0:
2436
+ get_rows_sycl<QK8_0, QR8_0, dequantize_q8_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
2437
+ break;
2438
+ default:
2439
+ // TODO: k-quants
2440
+ fprintf(stderr, "%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type));
2441
+ GGML_ABORT("fatal error");
2442
+ break;
2443
+ }
2444
  }
2445
 
2446
 
2447
+ static void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
2448
+ const ggml_tensor *src1, ggml_tensor *dst,
2449
+ const float *src0_d, const float *src1_d,
2450
+ float *dst_d,
2451
+ const queue_ptr &main_stream) {
2452
 
2453
+ ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_repeat>>(ctx, dst, src0, dst, nullptr, src0_d, dst_d, main_stream);
 
 
2454
 
2455
  (void) src1;
2456
+ (void) src1_d;
 
2457
  }
2458
 
2459
+
2460
  inline void ggml_sycl_op_mul_mat_sycl(
2461
  ggml_backend_sycl_context & ctx,
2462
  const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
 
2625
  (void) src1_dd;
2626
  }
2627
 
2628
+ inline void ggml_sycl_op_sum(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
2629
+ const ggml_tensor *src1, ggml_tensor *dst,
2630
+ const float *src0_dd, const float *src1_dd,
2631
+ float *dst_dd,
2632
+ const queue_ptr &main_stream) {
2633
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
2634
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
2635
+
2636
+ const int64_t ne = ggml_nelements(src0);
2637
+
2638
+ sum_rows_f32_sycl(src0_dd, dst_dd, ne, 1, main_stream);
2639
+
2640
+ (void) src1;
2641
+ (void) dst;
2642
+ (void) src1_dd;
2643
+ }
2644
+
2645
  inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
2646
  const ggml_tensor *src1, ggml_tensor *dst,
2647
  const float *src0_dd, const float *src1_dd,
 
2682
  (void) src1_dd;
2683
  }
2684
 
2685
+ inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
2686
+ const ggml_tensor *src1, ggml_tensor *dst,
2687
+ const float *src0_dd, const float *src1_dd,
2688
+ float *dst_dd,
2689
+ const queue_ptr &main_stream) {
2690
+
2691
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
2692
+ GGML_ASSERT( dst->type == GGML_TYPE_I32);
2693
+
2694
+ const int64_t ncols = src0->ne[0];
2695
+ const int64_t nrows = ggml_nrows(src0);
2696
+
2697
+ argmax_f32_i32_sycl(src0_dd, (int *)dst_dd, ncols, nrows, main_stream);
2698
+
2699
+ (void) src1;
2700
+ (void) dst;
2701
+ (void) src1_dd;
2702
+ }
2703
+
2704
  inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
2705
  const ggml_tensor *src1,
2706
  ggml_tensor *dst, const float *src0_dd,
 
2771
  (void) src1_dd;
2772
  }
2773
 
2774
  static void ggml_sycl_set_peer_access(const int n_tokens, int main_device) {
2775
  static bool peer_access_enabled = false;
2776
 
 
3150
  GGML_SYCL_DEBUG("call %s done\n", __func__);
3151
  }
3152
3153
  static void ggml_sycl_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3154
  GGML_SYCL_DEBUG("call %s\n", __func__);
3155
  ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_norm);
3156
  GGML_SYCL_DEBUG("call %s done\n", __func__);
3157
  }
3158
 
3159
+ static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
 
3160
  GGML_SYCL_DEBUG("call %s\n", __func__);
3161
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_rms_norm);
3162
  GGML_SYCL_DEBUG("call %s done\n", __func__);
3163
  }
3164
 
3165
+ static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
 
3166
  GGML_SYCL_DEBUG("call %s\n", __func__);
3167
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_group_norm);
3168
  GGML_SYCL_DEBUG("call %s done\n", __func__);
3169
  }
3170
 
 
3783
  ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_im2col);
3784
  }
3785
 
3786
+ static void ggml_sycl_sum(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3787
+ GGML_ASSERT(ggml_is_contiguous(src0));
3788
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sum);
3789
+ }
3790
+
3791
  static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3792
  GGML_ASSERT(ggml_is_contiguous(src0));
3793
  ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sum_rows);
 
3798
  ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_argsort);
3799
  }
3800
 
3801
+ static void ggml_sycl_argmax(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3802
+ GGML_ASSERT(ggml_is_contiguous(src0));
3803
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_argmax);
3804
+ }
3805
+
3806
  static void ggml_sycl_nop(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3807
  (void) src0;
3808
  (void) src1;
 
3834
  ggml_sycl_func_t func;
3835
 
3836
  switch (tensor->op) {
3837
+ case GGML_OP_ARGMAX:
3838
+ func = ggml_sycl_argmax;
3839
+ break;
3840
  case GGML_OP_CONV_TRANSPOSE_1D:
3841
  func = ggml_sycl_op_conv_transpose_1d;
3842
  break;
 
3850
  func = ggml_sycl_dup;
3851
  break;
3852
  case GGML_OP_ADD:
3853
+ case GGML_OP_ADD1: // TODO: more efficient implementation
3854
  func = ggml_sycl_add;
3855
  break;
3856
+ case GGML_OP_SUB:
3857
+ func = ggml_sycl_sub;
3858
+ break;
3859
  case GGML_OP_ACC:
3860
  func = ggml_sycl_acc;
3861
  break;
3862
  case GGML_OP_MUL:
3863
  func = ggml_sycl_mul;
3864
  break;
3865
+ case GGML_OP_LOG:
3866
+ func = ggml_sycl_log;
3867
+ break;
3868
  case GGML_OP_DIV:
3869
  func = ggml_sycl_div;
3870
  break;
3871
  case GGML_OP_UNARY:
3872
  switch (ggml_get_unary_op(tensor)) {
3873
+ case GGML_UNARY_OP_NEG:
3874
+ func = ggml_sycl_neg;
3875
+ break;
3876
+ case GGML_UNARY_OP_STEP:
3877
+ func = ggml_sycl_step;
3878
+ break;
3879
  case GGML_UNARY_OP_GELU:
3880
  func = ggml_sycl_gelu;
3881
  break;
 
3891
  case GGML_UNARY_OP_RELU:
3892
  func = ggml_sycl_relu;
3893
  break;
3894
+ case GGML_UNARY_OP_SIGMOID:
3895
+ func = ggml_sycl_sigmoid;
3896
+ break;
3897
  case GGML_UNARY_OP_HARDSIGMOID:
3898
  func = ggml_sycl_hardsigmoid;
3899
  break;
3900
  case GGML_UNARY_OP_HARDSWISH:
3901
  func = ggml_sycl_hardswish;
3902
  break;
3903
+ case GGML_UNARY_OP_EXP:
3904
+ func = ggml_sycl_exp;
3905
+ break;
3906
  default:
3907
  return false;
3908
  }
 
3940
  }
3941
  func = ggml_sycl_mul_mat_id;
3942
  break;
3943
+ case GGML_OP_OUT_PROD:
3944
+ func = ggml_sycl_op_out_prod;
3945
+ break;
3946
  case GGML_OP_SCALE:
3947
  func = ggml_sycl_scale;
3948
  break;
3949
  case GGML_OP_SQR:
3950
  func = ggml_sycl_sqr;
3951
  break;
3952
+ case GGML_OP_SQRT:
3953
+ func = ggml_sycl_sqrt;
3954
+ break;
3955
+ case GGML_OP_SIN:
3956
+ func = ggml_sycl_sin;
3957
+ break;
3958
+ case GGML_OP_COS:
3959
+ func = ggml_sycl_cos;
3960
+ break;
3961
  case GGML_OP_CLAMP:
3962
  func = ggml_sycl_clamp;
3963
  break;
 
3989
  case GGML_OP_POOL_2D:
3990
  func = ggml_sycl_pool2d;
3991
  break;
3992
+ case GGML_OP_SUM:
3993
+ func = ggml_sycl_sum;
3994
+ break;
3995
  case GGML_OP_SUM_ROWS:
3996
  func = ggml_sycl_sum_rows;
3997
  break;
 
4001
  case GGML_OP_TIMESTEP_EMBEDDING:
4002
  func = ggml_sycl_op_timestep_embedding;
4003
  break;
4004
+ case GGML_OP_RWKV_WKV6:
4005
+ func = ggml_sycl_op_rwkv_wkv6;
4006
+ break;
4007
  default:
4008
  return false;
4009
  }
 
4326
  } break;
4327
  case GGML_OP_UNARY:
4328
  switch (ggml_get_unary_op(op)) {
4329
+ case GGML_UNARY_OP_NEG:
4330
+ case GGML_UNARY_OP_STEP:
4331
  case GGML_UNARY_OP_GELU:
4332
  case GGML_UNARY_OP_SILU:
4333
  case GGML_UNARY_OP_RELU:
4334
+ case GGML_UNARY_OP_SIGMOID:
4335
  case GGML_UNARY_OP_HARDSIGMOID:
4336
  case GGML_UNARY_OP_HARDSWISH:
4337
  case GGML_UNARY_OP_GELU_QUICK:
4338
  case GGML_UNARY_OP_TANH:
4339
+ case GGML_UNARY_OP_EXP:
4340
  return ggml_is_contiguous(op->src[0]);
4341
  default:
4342
  return false;
 
4373
  }
4374
  return true;
4375
  } break;
4376
+ case GGML_OP_OUT_PROD:
4377
+ return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
4378
  case GGML_OP_GET_ROWS:
4379
  {
4380
  switch (op->src[0]->type) {
 
4420
  case GGML_OP_CONCAT:
4421
  {
4422
  ggml_type src0_type = op->src[0]->type;
4423
+ return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
 
4424
  } break;
4425
  case GGML_OP_DUP:
4426
+ case GGML_OP_ARGMAX:
4427
  case GGML_OP_NONE:
4428
  case GGML_OP_RESHAPE:
4429
  case GGML_OP_REPEAT:
 
4432
  case GGML_OP_TRANSPOSE:
4433
  case GGML_OP_NORM:
4434
  case GGML_OP_ADD:
4435
+ case GGML_OP_ADD1:
4436
+ case GGML_OP_LOG:
4437
+ case GGML_OP_SUB:
4438
  case GGML_OP_MUL:
4439
  case GGML_OP_DIV:
4440
  case GGML_OP_RMS_NORM:
4441
  case GGML_OP_SCALE:
4442
  case GGML_OP_SQR:
4443
+ case GGML_OP_SQRT:
4444
+ case GGML_OP_SIN:
4445
+ case GGML_OP_COS:
4446
  case GGML_OP_CLAMP:
4447
  return true;
4448
  case GGML_OP_CONT:
 
4456
  // TODO: add support for the new F32 operations
4457
  return op->src[0]->type == GGML_TYPE_F16;
4458
  case GGML_OP_POOL_2D:
4459
+ case GGML_OP_SUM:
4460
  case GGML_OP_SUM_ROWS:
4461
  case GGML_OP_ARGSORT:
4462
  case GGML_OP_ACC:
 
4465
  case GGML_OP_PAD:
4466
  case GGML_OP_LEAKY_RELU:
4467
  case GGML_OP_TIMESTEP_EMBEDDING:
4468
+ case GGML_OP_RWKV_WKV6:
4469
  return true;
4470
  default:
4471
  return false;
 
4483
  return buft_ctx->device == sycl_ctx->device;
4484
  }
4485
 
4486
+ static int64_t get_op_batch_size(const ggml_tensor * op) {
4487
+ switch (op->op) {
4488
+ case GGML_OP_GET_ROWS:
4489
+ return op->ne[1]; // this will increase the speed of prefill in test
4490
+ case GGML_OP_MUL_MAT:
4491
+ return op->ne[1];
4492
+ case GGML_OP_MUL_MAT_ID:
4493
+ case GGML_OP_ROPE:
4494
+ return op->ne[2];
4495
+ default:
4496
+ return ggml_nrows(op);
4497
+ }
4498
+ }
4499
+
4500
  static bool ggml_backend_sycl_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
4501
  const int min_batch_size = 32;
4502
+ return get_op_batch_size(op) >= min_batch_size;
4503
  GGML_UNUSED(dev);
4504
  }
4505
 
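Previously the offload heuristic looked only at op->ne[1] and special-cased GET_ROWS and MUL_MAT_ID; get_op_batch_size now picks the dimension that actually represents the batch for each op and compares it against the same threshold. A tiny illustration of the new rule with made-up shapes (illustration only, not ggml code):

    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t min_batch_size = 32;
        const int64_t mul_mat_batch = 64;  // hypothetical MUL_MAT: batch taken from ne[1]
        const int64_t rope_batch    = 8;   // hypothetical ROPE:    batch taken from ne[2]
        printf("MUL_MAT offloaded: %s\n", mul_mat_batch >= min_batch_size ? "yes" : "no");
        printf("ROPE    offloaded: %s\n", rope_batch    >= min_batch_size ? "yes" : "no");
        return 0;
    }
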
ggml/src/ggml-sycl/backend.hpp CHANGED
@@ -26,5 +26,8 @@
26
  #include "softmax.hpp"
27
  #include "tsembd.hpp"
28
  #include "im2col.hpp"
 
29
 
30
  #endif // GGML_SYCL_BACKEND_HPP
 
26
  #include "softmax.hpp"
27
  #include "tsembd.hpp"
28
  #include "im2col.hpp"
29
+ #include "wkv6.hpp"
30
+ #include "outprod.hpp"
31
+ #include "element_wise.hpp"
32
 
33
  #endif // GGML_SYCL_BACKEND_HPP
ggml/src/ggml-sycl/common.cpp CHANGED
@@ -62,3 +62,43 @@ int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block
62
  }
63
  return sycl_down_blk_size;
64
  }
 
62
  }
63
  return sycl_down_blk_size;
64
  }
65
+
66
+ void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
67
+ const ggml_tensor *src1, ggml_tensor *dst,
68
+ const ggml_sycl_op_flatten_t op) try {
69
+ const int64_t nrows0 = ggml_nrows(src0);
70
+
71
+ const bool use_src1 = src1 != nullptr;
72
+ const int64_t nrows1 = use_src1 ? ggml_nrows(src1) : 1;
73
+
74
+ GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
75
+ GGML_ASSERT( dst->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
76
+
77
+ ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
78
+ ggml_tensor_extra_gpu * src1_extra = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr;
79
+ ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
80
+
81
+ // dd = data device
82
+ float * src0_ddf = (float *) src0->data;
83
+ float * src1_ddf = use_src1 ? (float *) src1->data : nullptr;
84
+ float * dst_ddf = (float *) dst->data;
85
+
86
+ ggml_sycl_pool_alloc<float> src0_f(ctx.pool());
87
+ ggml_sycl_pool_alloc<float> src1_f(ctx.pool());
88
+ ggml_sycl_pool_alloc<float> dst_f(ctx.pool());
89
+
90
+ ggml_sycl_set_device(ctx.device);
91
+ queue_ptr main_stream = ctx.stream();
92
+ // GGML_SYCL_DEBUG("ctx.device=%d, main_stream=%p src0_on_device=%d, src1_on_device=%d, dst_on_device=%d\n",
93
+ // ctx.device, main_stream, src0_on_device, src1_on_device, dst_on_device);
94
+
95
+ // do the computation
96
+ op(ctx, src0, src1, dst, src0_ddf, src1_ddf, dst_ddf, main_stream);
97
+ // print_ggml_tensor("tensor", dst);
98
+ }
99
+ catch (sycl::exception const &exc) {
100
+
101
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
102
+ << ", line:" << __LINE__ << std::endl;
103
+ std::exit(1);
104
+ }
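
With ggml_sycl_op_flatten exported from common.cpp, every flat element-wise op follows the same two-layer pattern: a body with the ggml_sycl_op_flatten_t signature that only launches kernels, and a thin graph-level wrapper that hands it to the shared helper. A sketch of that pattern as it is used throughout this commit (ggml_sycl_op_mynop / ggml_sycl_mynop are hypothetical names):

    #include "common.hpp"

    // op body: receives resolved device pointers and the queue from the wrapper
    static void ggml_sycl_op_mynop(ggml_backend_sycl_context & ctx, const ggml_tensor * src0,
                                   const ggml_tensor * src1, ggml_tensor * dst,
                                   const float * src0_dd, const float * src1_dd,
                                   float * dst_dd, const queue_ptr & main_stream) {
        // launch SYCL kernels on main_stream using src0_dd / dst_dd here
        (void) ctx; (void) src0; (void) src1; (void) dst;
        (void) src0_dd; (void) src1_dd; (void) dst_dd; (void) main_stream;
    }

    // graph-level entry point: pointer and stream plumbing is delegated to the helper
    static void ggml_sycl_mynop(ggml_backend_sycl_context & ctx, const ggml_tensor * src0,
                                const ggml_tensor * src1, ggml_tensor * dst) {
        ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_mynop);
    }
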
ggml/src/ggml-sycl/common.hpp CHANGED
@@ -404,4 +404,262 @@ static __dpct_inline__ Tp* get_pointer(sycl::local_accessor<Tp, dim> acc) {
404
 
405
  int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size);
406
 
407
  #endif // GGML_SYCL_COMMON_HPP
 
404
 
405
  int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size);
406
 
407
+ typedef void (*ggml_sycl_op_flatten_t)(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
408
+ const ggml_tensor *src1,
409
+ ggml_tensor *dst, const float *src0_dd,
410
+ const float *src1_dd, float *dst_dd,
411
+ const queue_ptr &main_stream);
412
+
413
+ template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
414
+ static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
415
+ int ne0, int ne1, int ne2, int ne3,
416
+ int ne10, int ne11, int ne12, int ne13,
417
+ /*int s0, */ int s1, int s2, int s3,
418
+ /*int s10,*/ int s11, int s12, int s13,
419
+ const sycl::nd_item<3> &item_ct1) {
420
+ const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
421
+ item_ct1.get_local_id(2);
422
+ const int i1 = (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
423
+ item_ct1.get_local_id(1));
424
+ const int i2 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
425
+ item_ct1.get_local_id(0)) /
426
+ ne3;
427
+ const int i3 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
428
+ item_ct1.get_local_id(0)) %
429
+ ne3;
430
+
431
+ if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
432
+ return;
433
+ }
434
+
435
+ const int i11 = i1 % ne11;
436
+ const int i12 = i2 % ne12;
437
+ const int i13 = i3 % ne13;
438
+
439
+ const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
440
+ const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
441
+ const size_t i_dst = i_src0;
442
+
443
+ const src0_t * src0_row = src0 + i_src0;
444
+ const src1_t * src1_row = src1 + i_src1;
445
+ dst_t * dst_row = dst + i_dst;
446
+
447
+ for (int i0 = i0s; i0 < ne0;
448
+ i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) {
449
+ const int i10 = i0 % ne10;
450
+ dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
451
+ }
452
+ }
453
+
454
+ template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
455
+ static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
456
+ int ne0, int ne1, int ne2, int ne3,
457
+ int ne10, int ne11, int ne12, int ne13,
458
+ /*int s0, */ int s1, int s2, int s3,
459
+ /*int s10,*/ int s11, int s12, int s13,
460
+ const sycl::nd_item<3> &item_ct1) {
461
+
462
+ const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
463
+ item_ct1.get_local_id(2);
464
+
465
+ const int i3 = i/(ne2*ne1*ne0);
466
+ const int i2 = (i/(ne1*ne0)) % ne2;
467
+ const int i1 = (i/ne0) % ne1;
468
+ const int i0 = i % ne0;
469
+
470
+ if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
471
+ return;
472
+ }
473
+
474
+ const int i11 = i1 % ne11;
475
+ const int i12 = i2 % ne12;
476
+ const int i13 = i3 % ne13;
477
+
478
+ const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
479
+ const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
480
+ const size_t i_dst = i_src0;
481
+
482
+ const src0_t * src0_row = src0 + i_src0;
483
+ const src1_t * src1_row = src1 + i_src1;
484
+ dst_t * dst_row = dst + i_dst;
485
+
486
+ const int i10 = i0 % ne10;
487
+ dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
488
+ }
489
+
490
+
491
+ template<float (*bin_op)(const float, const float)>
492
+ struct bin_bcast_sycl {
493
+ template <typename src0_t, typename src1_t, typename dst_t>
494
+ void operator()(ggml_backend_sycl_context & ctx,
495
+ const struct ggml_tensor *src0,
496
+ const struct ggml_tensor *src1, struct ggml_tensor *dst,
497
+ const src0_t *src0_dd, const src1_t *src1_dd, dst_t *dst_dd,
498
+ queue_ptr stream) {
499
+
500
+ GGML_TENSOR_BINARY_OP_LOCALS
501
+
502
+ int nr0 = ne10/ne0;
503
+ int nr1 = ne11/ne1;
504
+ int nr2 = ne12/ne2;
505
+ int nr3 = ne13/ne3;
506
+
507
+ int nr[4] = { nr0, nr1, nr2, nr3 };
508
+
509
+ // collapse dimensions until first broadcast dimension
510
+ int64_t cne0[] = {ne0, ne1, ne2, ne3};
511
+ int64_t cne1[] = {ne10, ne11, ne12, ne13};
512
+ size_t cnb0[] = {nb0, nb1, nb2, nb3};
513
+ size_t cnb1[] = {nb10, nb11, nb12, nb13};
514
+ auto collapse = [](int64_t cne[]) {
515
+ cne[0] *= cne[1];
516
+ cne[1] = cne[2];
517
+ cne[2] = cne[3];
518
+ cne[3] = 1;
519
+ };
520
+
521
+ auto collapse_nb = [](size_t cnb[], int64_t cne[]) {
522
+ cnb[1] *= cne[1];
523
+ cnb[2] *= cne[2];
524
+ cnb[3] *= cne[3];
525
+ };
526
+
527
+ for (int i = 0; i < 4; i++) {
528
+ if (nr[i] != 1) {
529
+ break;
530
+ }
531
+ if (i > 0) {
532
+ collapse_nb(cnb0, cne0);
533
+ collapse_nb(cnb1, cne1);
534
+ collapse(cne0);
535
+ collapse(cne1);
536
+ }
537
+ }
538
+ {
539
+ int64_t ne0 = cne0[0];
540
+ int64_t ne1 = cne0[1];
541
+ int64_t ne2 = cne0[2];
542
+ int64_t ne3 = cne0[3];
543
+
544
+ int64_t ne10 = cne1[0];
545
+ int64_t ne11 = cne1[1];
546
+ int64_t ne12 = cne1[2];
547
+ int64_t ne13 = cne1[3];
548
+
549
+ size_t nb0 = cnb0[0];
550
+ size_t nb1 = cnb0[1];
551
+ size_t nb2 = cnb0[2];
552
+ size_t nb3 = cnb0[3];
553
+
554
+ size_t nb10 = cnb1[0];
555
+ size_t nb11 = cnb1[1];
556
+ size_t nb12 = cnb1[2];
557
+ size_t nb13 = cnb1[3];
558
+
559
+ size_t s0 = nb0 / sizeof(dst_t);
560
+ size_t s1 = nb1 / sizeof(dst_t);
561
+ size_t s2 = nb2 / sizeof(dst_t);
562
+ size_t s3 = nb3 / sizeof(dst_t);
563
+
564
+ size_t s10 = nb10 / sizeof(src1_t);
565
+ size_t s11 = nb11 / sizeof(src1_t);
566
+ size_t s12 = nb12 / sizeof(src1_t);
567
+ size_t s13 = nb13 / sizeof(src1_t);
568
+
569
+ GGML_ASSERT(s0 == 1);
570
+ GGML_ASSERT(s10 == 1);
571
+
572
+ const int block_size = 128;
573
+
574
+ int64_t hne0 = std::max(ne0/2LL, 1LL);
575
+
576
+ sycl::range<3> block_dims(1, 1, 1);
577
+ block_dims[2] = std::min<unsigned int>(hne0, block_size);
578
+ block_dims[1] = std::min<unsigned int>(
579
+ ne1, block_size / (unsigned int)block_dims[2]);
580
+ block_dims[0] = std::min(
581
+ std::min<unsigned int>(
582
+ ne2 * ne3, block_size / (unsigned int)block_dims[2] /
583
+ (unsigned int)block_dims[1]),
584
+ 64U);
585
+
586
+ sycl::range<3> block_nums(
587
+ (ne2 * ne3 + block_dims[0] - 1) / block_dims[0],
588
+ (ne1 + block_dims[1] - 1) / block_dims[1],
589
+ (hne0 + block_dims[2] - 1) / block_dims[2]);
590
+
591
+ if (block_nums[0] > 65535) {
592
+ // this is the maximum number of blocks in z direction, fallback to 1D grid kernel
593
+ int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
594
+ {
595
+ dpct::has_capability_or_fail(stream->get_device(),
596
+ {sycl::aspect::fp16});
597
+
598
+ stream->parallel_for(
599
+ sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) *
600
+ sycl::range<3>(1, 1, block_size),
601
+ sycl::range<3>(1, 1, block_size)),
602
+ [=](sycl::nd_item<3> item_ct1) {
603
+ k_bin_bcast_unravel<bin_op>(
604
+ src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3,
605
+ ne10, ne11, ne12, ne13, s1, s2, s3, s11, s12,
606
+ s13, item_ct1);
607
+ });
608
+ }
609
+ } else {
610
+ /*
611
+ DPCT1049:16: The work-group size passed to the SYCL kernel may
612
+ exceed the limit. To get the device limit, query
613
+ info::device::max_work_group_size. Adjust the work-group size if
614
+ needed.
615
+ */
616
+ dpct::has_capability_or_fail(stream->get_device(),
617
+ {sycl::aspect::fp16});
618
+
619
+ stream->parallel_for(
620
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
621
+ [=](sycl::nd_item<3> item_ct1) {
622
+ k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
623
+ ne2, ne3, ne10, ne11, ne12, ne13,
624
+ s1, s2, s3, s11, s12, s13,
625
+ item_ct1);
626
+ });
627
+ }
628
+ }
629
+ }
630
+ };
631
+
632
+ template <class op>
633
+ inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
634
+ const ggml_tensor *src1, ggml_tensor *dst,
635
+ const float *src0_dd, const float *src1_dd,
636
+ float *dst_dd,
637
+ const queue_ptr &main_stream) {
638
+
639
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
640
+ op()(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
641
+ } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
642
+ op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd,
643
+ (sycl::half *)dst_dd, main_stream);
644
+ } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
645
+ op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd, dst_dd,
646
+ main_stream);
647
+ } else if (src0->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {
648
+ op()(ctx, src0, src1, dst, (const int32_t *)src0_dd, (const int32_t *)src1_dd, (int32_t *)dst_dd,
649
+ main_stream);
650
+ } else if (src0->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) {
651
+ op()(ctx, src0, src1, dst, (const int16_t *)src0_dd, (const int16_t *)src1_dd, (int16_t *)dst_dd,
652
+ main_stream);
653
+ } else {
654
+ fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
655
+ ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
656
+ GGML_ABORT("fatal error");
657
+ }
658
+ }
659
+
660
+
661
+ void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
662
+ const ggml_tensor *src1, ggml_tensor *dst,
663
+ const ggml_sycl_op_flatten_t op);
664
+
665
  #endif // GGML_SYCL_COMMON_HPP
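
The broadcast kernels above resolve the second operand's index by taking each destination coordinate modulo the corresponding src1 dimension, after collapsing leading non-broadcast dimensions. A standalone, contiguous-only C++ reference of that index rule (bcast_add_reference is illustrative, not part of the backend):

    #include <cstdio>
    #include <vector>

    // dst[i3][i2][i1][i0] = src0[same index] + src1[i3 % ne13][i2 % ne12][i1 % ne11][i0 % ne10]
    static void bcast_add_reference(const float * src0, const float * src1, float * dst,
                                    int ne0, int ne1, int ne2, int ne3,
                                    int ne10, int ne11, int ne12, int ne13) {
        for (int i3 = 0; i3 < ne3; ++i3)
        for (int i2 = 0; i2 < ne2; ++i2)
        for (int i1 = 0; i1 < ne1; ++i1)
        for (int i0 = 0; i0 < ne0; ++i0) {
            const int j0 = i0 % ne10, j1 = i1 % ne11, j2 = i2 % ne12, j3 = i3 % ne13;
            const size_t id = (((size_t)i3*ne2 + i2)*ne1 + i1)*ne0 + i0;
            const size_t is = (((size_t)j3*ne12 + j2)*ne11 + j1)*ne10 + j0;
            dst[id] = src0[id] + src1[is];
        }
    }

    int main() {
        // broadcast a 4-element bias (ne11 = 1) across 2 rows of 4 elements
        std::vector<float> a = {1, 2, 3, 4, 5, 6, 7, 8};
        std::vector<float> b = {10, 20, 30, 40};
        std::vector<float> d(8);
        bcast_add_reference(a.data(), b.data(), d.data(), 4, 2, 1, 1, 4, 1, 1, 1);
        for (float v : d) printf("%g ", v);   // 11 22 33 44 15 26 37 48
        printf("\n");
        return 0;
    }
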
ggml/src/ggml-sycl/concat.cpp CHANGED
@@ -106,6 +106,7 @@ static void concat_f32_sycl(const float *x, const float *y, float *dst,
106
  concat_f32_dim1(x, y, dst, ne0, ne01, item_ct1);
107
  });
108
  break;
 
109
  default:
110
  stream->parallel_for(
111
  sycl::nd_range<3>(gridDim *
 
106
  concat_f32_dim1(x, y, dst, ne0, ne01, item_ct1);
107
  });
108
  break;
109
+ // dim >=2 will be dispatched to the default path
110
  default:
111
  stream->parallel_for(
112
  sycl::nd_range<3>(gridDim *
ggml/src/ggml-sycl/element_wise.cpp ADDED
@@ -0,0 +1,1011 @@
 
1
+ #include "common.hpp"
2
+ #include "element_wise.hpp"
3
+
4
+ void acc_f32(const float * x, const float * y, float * dst, const int ne,
5
+ const int ne10, const int ne11, const int ne12,
6
+ const int nb1, const int nb2, int offset, const sycl::nd_item<3> &item_ct1) {
7
+ const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
8
+ item_ct1.get_local_id(2);
9
+ if (i >= ne) {
10
+ return;
11
+ }
12
+ int src1_idx = i - offset;
13
+ int oz = src1_idx / nb2;
14
+ int oy = (src1_idx - (oz * nb2)) / nb1;
15
+ int ox = src1_idx % nb1;
16
+ if (src1_idx >= 0 && ox < ne10 && oy < ne11 && oz < ne12) {
17
+ dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11];
18
+ } else {
19
+ dst[i] = x[i];
20
+ }
21
+ }
22
+
23
+ void gelu_f32(const float * x, float * dst, const int k,
24
+ const sycl::nd_item<3> &item_ct1) {
25
+ const float GELU_COEF_A = 0.044715f;
26
+ const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
27
+ const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
28
+ item_ct1.get_local_id(2);
29
+
30
+ if (i >= k) {
31
+ return;
32
+ }
33
+
34
+ float xi = x[i];
35
+ dst[i] = 0.5f * xi *
36
+ (1.0f +
37
+ sycl::tanh(SQRT_2_OVER_PI * xi * (1.0f + GELU_COEF_A * xi * xi)));
38
+ }
39
+
40
+ void silu_f32(const float * x, float * dst, const int k,
41
+ const sycl::nd_item<3> &item_ct1) {
42
+ const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
43
+ item_ct1.get_local_id(2);
44
+
45
+ if (i >= k) {
46
+ return;
47
+ }
48
+ dst[i] = x[i] / (1.0f + sycl::native::exp(-x[i]));
49
+ }
50
+
51
+ void gelu_quick_f32(const float *x, float *dst, int k,
52
+ const sycl::nd_item<3> &item_ct1) {
53
+ const float GELU_QUICK_COEF = -1.702f;
54
+ const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
55
+ item_ct1.get_local_id(2);
56
+ if (i >= k) {
57
+ return;
58
+ }
59
+ dst[i] = x[i] * (1.0f / (1.0f + sycl::native::exp(GELU_QUICK_COEF * x[i])));
60
+ }
61
+
62
+ void tanh_f32(const float *x, float *dst, int k,
63
+ const sycl::nd_item<3> &item_ct1) {
64
+ const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
65
+ item_ct1.get_local_id(2);
66
+ if (i >= k) {
67
+ return;
68
+ }
69
+ dst[i] = sycl::tanh((float)(x[i]));
70
+ }
71
+
72
+ void relu_f32(const float * x, float * dst, const int k,
73
+ const sycl::nd_item<3> &item_ct1) {
74
+ const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
75
+ item_ct1.get_local_id(2);
76
+
77
+ if (i >= k) {
78
+ return;
79
+ }
80
+ dst[i] = sycl::fmax((float)(x[i]), (float)0);
81
+ }
82
+
83
+ void sigmoid_f32(const float * x, float * dst, const int k,
84
+ const sycl::nd_item<3> &item_ct1) {
85
+ const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
86
+ item_ct1.get_local_id(2);
87
+
88
+ if (i >= k) {
89
+ return;
90
+ }
91
+ dst[i] = 1.0f / (1.0f + sycl::native::exp(-x[i]));
92
+ }
93
+
94
+ void sqrt_f32(const float * x, float * dst, const int k,
95
+ const sycl::nd_item<3> &item_ct1) {
96
+ const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
97
+ item_ct1.get_local_id(2);
98
+
99
+ if (i >= k) {
100
+ return;
101
+ }
102
+ dst[i] = sycl::sqrt(x[i]);
103
+ }
104
+
105
+ void sin_f32(const float * x, float * dst, const int k,
106
+ const sycl::nd_item<3> &item_ct1) {
107
+ const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
108
+ item_ct1.get_local_id(2);
109
+
110
+ if (i >= k) {
111
+ return;
112
+ }
113
+ dst[i] = sycl::sin(x[i]);
114
+ }
115
+
116
+ void cos_f32(const float * x, float * dst, const int k,
117
+ const sycl::nd_item<3> &item_ct1) {
118
+ const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
119
+ item_ct1.get_local_id(2);
120
+
121
+ if (i >= k) {
122
+ return;
123
+ }
124
+ dst[i] = sycl::cos(x[i]);
125
+ }
126
+
127
+ void hardsigmoid_f32(const float * x, float * dst, const int k,
128
+ const sycl::nd_item<3> &item_ct1) {
129
+ const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
130
+ item_ct1.get_local_id(2);
131
+
132
+ if (i >= k) {
133
+ return;
134
+ }
135
+ dst[i] = sycl::fmin(1.0f, sycl::fmax(0.0f, (x[i] + 3.0f) / 6.0f));
136
+ }
137
+
138
+ void hardswish_f32(const float * x, float * dst, const int k,
139
+ const sycl::nd_item<3> &item_ct1) {
140
+ const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
141
+ item_ct1.get_local_id(2);
142
+
143
+ if (i >= k) {
144
+ return;
145
+ }
146
+ dst[i] = x[i] * sycl::fmin(1.0f, sycl::fmax(0.0f, (x[i] + 3.0f) / 6.0f));
147
+ }
148
+
149
+ void exp_f32(const float * x, float * dst, const int k,
150
+ const sycl::nd_item<3> &item_ct1) {
151
+ const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
152
+ item_ct1.get_local_id(2);
153
+
154
+ if (i >= k) {
155
+ return;
156
+ }
157
+ dst[i] = sycl::exp(x[i]);
158
+ }
159
+
160
+ void log_f32(const float * x, float * dst, const int k,
161
+ const sycl::nd_item<3> &item_ct1) {
162
+ const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
163
+ item_ct1.get_local_id(2);
164
+
165
+ if (i >= k) {
166
+ return;
167
+ }
168
+ float xi = x[i];
169
+ if (xi <= 0) {
170
+ dst[i] = -INFINITY;
171
+ } else {
172
+ dst[i] = sycl::log(xi);
173
+ }
174
+ }
175
+
176
+ void neg_f32(const float * x, float * dst, const int k,
177
+ const sycl::nd_item<3> &item_ct1) {
178
+ const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
179
+ item_ct1.get_local_id(2);
180
+
181
+ if (i >= k) {
182
+ return;
183
+ }
184
+ dst[i] = -x[i];
185
+ }
186
+
187
+ void step_f32(const float * x, float * dst, const int k,
188
+ const sycl::nd_item<3> &item_ct1) {
189
+ const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
190
+ item_ct1.get_local_id(2);
191
+
192
+ if (i >= k) {
193
+ return;
194
+ }
195
+ dst[i] = x[i] > 0.0f;
196
+ }
197
+
198
+ void leaky_relu_f32(const float *x, float *dst, const int k, const float negative_slope,
199
+ const sycl::nd_item<3> &item_ct1) {
200
+ const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
201
+ item_ct1.get_local_id(2);
202
+ if (i >= k) {
203
+ return;
204
+ }
205
+ dst[i] = sycl::fmax((float)(x[i]), (float)0) +
206
+ sycl::fmin((float)(x[i]), 0.0f) * negative_slope;
207
+ }
208
+
209
+ void sqr_f32(const float * x, float * dst, const int k,
210
+ const sycl::nd_item<3> &item_ct1) {
211
+ const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
212
+ item_ct1.get_local_id(2);
213
+
214
+ if (i >= k) {
215
+ return;
216
+ }
217
+ dst[i] = x[i] * x[i];
218
+ }
219
+
220
+ void upscale_f32(const float *x, float *dst, const int nb00, const int nb01,
221
+ const int nb02, const int nb03, const int ne10, const int ne11,
222
+ const int ne12, const int ne13, const float sf0, const float sf1,
223
+ const float sf2, const float sf3, const sycl::nd_item<1> &item_ct1) {
224
+ int index = item_ct1.get_local_id(0) +
225
+ item_ct1.get_group(0) * item_ct1.get_local_range(0);
226
+ if (index >= ne10 * ne11 * ne12 * ne13) {
227
+ return;
228
+ }
229
+ // operation
230
+ int i10 = index % ne10;
231
+ int i11 = (index / ne10) % ne11;
232
+ int i12 = (index / (ne10 * ne11)) % ne12;
233
+ int i13 = (index / (ne10 * ne11 * ne12)) % ne13;
234
+
235
+ int i00 = i10 / sf0;
236
+ int i01 = i11 / sf1;
237
+ int i02 = i12 / sf2;
238
+ int i03 = i13 / sf3;
239
+
240
+ dst[index] = *(float *)((char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
241
+ }
242
+
243
+ void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02,
244
+ const sycl::nd_item<3> &item_ct1) {
245
+ int nidx = item_ct1.get_local_id(2) +
246
+ item_ct1.get_group(2) * item_ct1.get_local_range(2);
247
+ if (nidx >= ne0) {
248
+ return;
249
+ }
250
+
251
+ // operation
252
+ int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
253
+ item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
254
+ if (nidx < ne00 && item_ct1.get_group(1) < ne01 &&
255
+ item_ct1.get_group(0) < ne02) {
256
+ int offset_src = nidx + item_ct1.get_group(1) * ne00 +
257
+ item_ct1.get_group(0) * ne00 * ne01;
258
+ dst[offset_dst] = x[offset_src];
259
+ } else {
260
+ dst[offset_dst] = 0.0f;
261
+ }
262
+ }
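
All of the launchers below size their 1-D grid with the usual ceil-division, num_blocks = (k + BLOCK - 1) / BLOCK, and rely on the i >= k guard inside each kernel to mask the tail of the last work-group. A quick check of that formula (illustrative only):

    #include <cassert>

    static int ceil_div(int k, int block) { return (k + block - 1) / block; }

    int main() {
        assert(ceil_div(1000, 256) == 4);  // 4 * 256 = 1024 covers 1000 elements
        assert(ceil_div(1024, 256) == 4);  // exact fit
        assert(ceil_div(1025, 256) == 5);  // one extra, partially masked work-group
        return 0;
    }
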
263
+
264
+
265
+
266
+ void acc_f32_sycl(const float *x, const float *y, float *dst,
267
+ const int n_elements, const int ne10, const int ne11,
268
+ const int ne12, const int nb1, const int nb2,
269
+ const int offset, queue_ptr stream) {
270
+ int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE;
271
+ stream->parallel_for(
272
+ sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
273
+ sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE),
274
+ sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)),
275
+ [=](sycl::nd_item<3> item_ct1) {
276
+ acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset,
277
+ item_ct1);
278
+ });
279
+ }
280
+
281
+ void gelu_f32_sycl(const float *x, float *dst, const int k,
282
+ queue_ptr stream) {
283
+ const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
284
+ stream->parallel_for(
285
+ sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
286
+ sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
287
+ sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
288
+ [=](sycl::nd_item<3> item_ct1) {
289
+ gelu_f32(x, dst, k, item_ct1);
290
+ });
291
+ }
292
+
293
+ void silu_f32_sycl(const float *x, float *dst, const int k,
294
+ queue_ptr stream) {
295
+ const int num_blocks = (k + SYCL_SILU_BLOCK_SIZE - 1) / SYCL_SILU_BLOCK_SIZE;
296
+ stream->parallel_for(
297
+ sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
298
+ sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE),
299
+ sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE)),
300
+ [=](sycl::nd_item<3> item_ct1) {
301
+ silu_f32(x, dst, k, item_ct1);
302
+ });
303
+ }
304
+
305
+ void gelu_quick_f32_sycl(const float *x, float *dst, const int k,
306
+ queue_ptr stream) {
307
+ const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
308
+ stream->parallel_for(
309
+ sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
310
+ sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
311
+ sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
312
+ [=](sycl::nd_item<3> item_ct1) {
313
+ gelu_quick_f32(x, dst, k, item_ct1);
314
+ });
315
+ }
316
+
317
+ void tanh_f32_sycl(const float *x, float *dst, const int k,
318
+ queue_ptr stream) {
319
+ const int num_blocks = (k + SYCL_TANH_BLOCK_SIZE - 1) / SYCL_TANH_BLOCK_SIZE;
320
+ stream->parallel_for(
321
+ sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
322
+ sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE),
323
+ sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE)),
324
+ [=](sycl::nd_item<3> item_ct1) {
325
+ tanh_f32(x, dst, k, item_ct1);
326
+ });
327
+ }
328
+
329
+ void relu_f32_sycl(const float *x, float *dst, const int k,
330
+ queue_ptr stream) {
331
+ const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
332
+ stream->parallel_for(
333
+ sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
334
+ sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE),
335
+ sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)),
336
+ [=](sycl::nd_item<3> item_ct1) {
337
+ relu_f32(x, dst, k, item_ct1);
338
+ });
339
+ }
340
+
341
+ void hardsigmoid_f32_sycl(const float *x, float *dst, const int k,
342
+ queue_ptr stream) {
343
+ const int num_blocks = (k + SYCL_HARDSIGMOID_BLOCK_SIZE - 1) / SYCL_HARDSIGMOID_BLOCK_SIZE;
344
+ stream->parallel_for(
345
+ sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
346
+ sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE),
347
+ sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE)),
348
+ [=](sycl::nd_item<3> item_ct1) {
349
+ hardsigmoid_f32(x, dst, k, item_ct1);
350
+ });
351
+ }
352
+
353
+ void hardswish_f32_sycl(const float *x, float *dst, const int k,
354
+ queue_ptr stream) {
355
+ const int num_blocks = (k + SYCL_HARDSWISH_BLOCK_SIZE - 1) / SYCL_HARDSWISH_BLOCK_SIZE;
356
+ stream->parallel_for(
357
+ sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
358
+ sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE),
359
+ sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE)),
360
+ [=](sycl::nd_item<3> item_ct1) {
361
+ hardswish_f32(x, dst, k, item_ct1);
362
+ });
363
+ }
364
+
365
+ void exp_f32_sycl(const float *x, float *dst, const int k,
366
+ queue_ptr stream) {
367
+ const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE;
368
+ stream->parallel_for(
369
+ sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
370
+ sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE),
371
+ sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)),
372
+ [=](sycl::nd_item<3> item_ct1) {
373
+ exp_f32(x, dst, k, item_ct1);
374
+ });
375
+ }
376
+
377
+ void log_f32_sycl(const float *x, float *dst, const int k,
378
+ queue_ptr stream) {
379
+ const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE;
380
+ stream->parallel_for(
381
+ sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
382
+ sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE),
383
+ sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)),
384
+ [=](sycl::nd_item<3> item_ct1) {
385
+ log_f32(x, dst, k, item_ct1);
386
+ });
387
+ }
388
+
389
+ void neg_f32_sycl(const float *x, float *dst, const int k,
390
+ queue_ptr stream) {
391
+ const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE;
392
+ stream->parallel_for(
393
+ sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
394
+ sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE),
395
+ sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE)),
396
+ [=](sycl::nd_item<3> item_ct1) {
397
+ neg_f32(x, dst, k, item_ct1);
398
+ });
399
+ }
400
+
401
+ void step_f32_sycl(const float *x, float *dst, const int k,
402
+ queue_ptr stream) {
403
+ const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE;
404
+ stream->parallel_for(
405
+ sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
406
+ sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE),
407
+ sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE)),
408
+ [=](sycl::nd_item<3> item_ct1) {
409
+ step_f32(x, dst, k, item_ct1);
410
+ });
411
+ }
412
+
413
+ void sigmoid_f32_sycl(const float *x, float *dst, const int k,
414
+ queue_ptr stream) {
415
+ const int num_blocks = (k + SYCL_SIGMOID_BLOCK_SIZE - 1) / SYCL_SIGMOID_BLOCK_SIZE;
416
+ stream->parallel_for(
417
+ sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
418
+ sycl::range<3>(1, 1, SYCL_SIGMOID_BLOCK_SIZE),
419
+ sycl::range<3>(1, 1, SYCL_SIGMOID_BLOCK_SIZE)),
420
+ [=](sycl::nd_item<3> item_ct1) {
421
+ sigmoid_f32(x, dst, k, item_ct1);
422
+ });
423
+ }
424
+
425
+ void sqrt_f32_sycl(const float *x, float *dst, const int k,
426
+ queue_ptr stream) {
427
+ const int num_blocks = (k + SYCL_SQRT_BLOCK_SIZE - 1) / SYCL_SQRT_BLOCK_SIZE;
428
+ stream->parallel_for(
429
+ sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
430
+ sycl::range<3>(1, 1, SYCL_SQRT_BLOCK_SIZE),
431
+ sycl::range<3>(1, 1, SYCL_SQRT_BLOCK_SIZE)),
432
+ [=](sycl::nd_item<3> item_ct1) {
433
+ sqrt_f32(x, dst, k, item_ct1);
434
+ });
435
+ }
436
+
437
+ void sin_f32_sycl(const float *x, float *dst, const int k,
438
+ queue_ptr stream) {
439
+ const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE;
440
+ stream->parallel_for(
441
+ sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
442
+ sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE),
443
+ sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE)),
444
+ [=](sycl::nd_item<3> item_ct1) {
445
+ sin_f32(x, dst, k, item_ct1);
446
+ });
447
+ }
448
+
449
+ void cos_f32_sycl(const float *x, float *dst, const int k,
450
+ queue_ptr stream) {
451
+ const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE;
452
+ stream->parallel_for(
453
+ sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
454
+ sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE),
455
+ sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE)),
456
+ [=](sycl::nd_item<3> item_ct1) {
457
+ cos_f32(x, dst, k, item_ct1);
458
+ });
459
+ }
460
+
461
+ void leaky_relu_f32_sycl(const float *x, float *dst, const int k,
462
+ const float negative_slope,
463
+ queue_ptr stream) {
464
+ const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
465
+ stream->parallel_for(
466
+ sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
467
+ sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE),
468
+ sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)),
469
+ [=](sycl::nd_item<3> item_ct1) {
470
+ leaky_relu_f32(x, dst, k, negative_slope, item_ct1);
471
+ });
472
+ }
473
+
474
+ void sqr_f32_sycl(const float *x, float *dst, const int k,
475
+ queue_ptr stream) {
476
+ const int num_blocks = (k + SYCL_SQR_BLOCK_SIZE - 1) / SYCL_SQR_BLOCK_SIZE;
477
+ stream->parallel_for(
478
+ sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
479
+ sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE),
480
+ sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE)),
481
+ [=](sycl::nd_item<3> item_ct1) {
482
+ sqr_f32(x, dst, k, item_ct1);
483
+ });
484
+ }
485
+
486
+ void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01,
487
+ const int nb02, const int nb03, const int ne10, const int ne11,
488
+ const int ne12, const int ne13, const float sf0, const float sf1,
489
+ const float sf2, const float sf3, queue_ptr stream) {
490
+ int dst_size = ne10 * ne11 * ne12 * ne13;
491
+ int num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE;
492
+ sycl::range<1> gridDim(num_blocks * SYCL_UPSCALE_BLOCK_SIZE);
493
+ stream->parallel_for(
494
+ sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)),
495
+ [=](sycl::nd_item<1> item_ct1) {
496
+ upscale_f32(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, item_ct1);
497
+ });
498
+ }
499
+
500
+ void pad_f32_sycl(const float *x, float *dst, const int ne00,
501
+ const int ne01, const int ne02, const int ne0,
502
+ const int ne1, const int ne2, queue_ptr stream) {
503
+ int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE;
504
+ sycl::range<3> gridDim(ne2, ne1, num_blocks);
505
+ stream->parallel_for(
506
+ sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),
507
+ sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)),
508
+ [=](sycl::nd_item<3> item_ct1) {
509
+ pad_f32(x, dst, ne0, ne00, ne01, ne02, item_ct1);
510
+ });
511
+ }
512
+
513
+ inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
514
+ ggml_tensor *dst, const float *src0_dd,
515
+ const float *src1_dd, float *dst_dd,
516
+ const queue_ptr &main_stream) {
517
+
518
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
519
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
520
+
521
+ silu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
522
+
523
+ (void) src1;
524
+ (void) dst;
525
+ (void) src1_dd;
526
+ }
527
+
528
+ inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
529
+ ggml_tensor *dst, const float *src0_dd,
530
+ const float *src1_dd, float *dst_dd,
531
+ const queue_ptr &main_stream) {
532
+
533
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
534
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
535
+
536
+ gelu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
537
+
538
+ (void) src1;
539
+ (void) dst;
540
+ (void) src1_dd;
541
+ }
542
+ inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
543
+ const ggml_tensor *src1, ggml_tensor *dst,
544
+ const float *src0_dd, const float *src1_dd,
545
+ float *dst_dd,
546
+ const queue_ptr &main_stream) {
547
+
548
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
549
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
550
+
551
+ gelu_quick_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
552
+
553
+ (void) src1;
554
+ (void) dst;
555
+ (void) src1_dd;
556
+ }
557
+
558
+ inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
559
+ ggml_tensor *dst, const float *src0_dd,
560
+ const float *src1_dd, float *dst_dd,
561
+ const queue_ptr &main_stream) {
562
+
563
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
564
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
565
+ tanh_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
566
+
567
+ (void) src1;
568
+ (void) dst;
569
+ (void) src1_dd;
570
+ }
571
+
572
+ inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
573
+ ggml_tensor *dst, const float *src0_dd,
574
+ const float *src1_dd, float *dst_dd,
575
+ const queue_ptr &main_stream) {
576
+
577
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
578
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
579
+
580
+ relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
581
+
582
+ (void) src1;
583
+ (void) dst;
584
+ (void) src1_dd;
585
+ }
586
+
587
+ inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
588
+ const ggml_tensor *src1, ggml_tensor *dst,
589
+ const float *src0_dd, const float *src1_dd,
590
+ float *dst_dd,
591
+ const queue_ptr &main_stream) {
592
+
593
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
594
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
595
+
596
+ hardsigmoid_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
597
+
598
+ (void) src1;
599
+ (void) dst;
600
+ (void) src1_dd;
601
+ }
602
+
603
+ inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
604
+ const ggml_tensor *src1, ggml_tensor *dst,
605
+ const float *src0_dd, const float *src1_dd,
606
+ float *dst_dd, const queue_ptr &main_stream) {
607
+
608
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
609
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
610
+
611
+ hardswish_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
612
+
613
+ (void) src1;
614
+ (void) dst;
615
+ (void) src1_dd;
616
+ }
617
+
618
+ inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
619
+ const ggml_tensor *src1, ggml_tensor *dst,
620
+ const float *src0_dd, const float *src1_dd,
621
+ float *dst_dd, const queue_ptr &main_stream) {
622
+
623
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
624
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
625
+
626
+ exp_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
627
+
628
+ (void) src1;
629
+ (void) dst;
630
+ (void) src1_dd;
631
+ }
632
+
633
+ inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
634
+ const ggml_tensor *src1, ggml_tensor *dst,
635
+ const float *src0_dd, const float *src1_dd,
636
+ float *dst_dd, const queue_ptr &main_stream) {
637
+
638
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
639
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
640
+
641
+ log_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
642
+
643
+ (void) src1;
644
+ (void) dst;
645
+ (void) src1_dd;
646
+ }
647
+
648
+ inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
649
+ const ggml_tensor *src1, ggml_tensor *dst,
650
+ const float *src0_dd, const float *src1_dd,
651
+ float *dst_dd, const queue_ptr &main_stream) {
652
+
653
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
654
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
655
+
656
+ sigmoid_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
657
+
658
+ (void) src1;
659
+ (void) dst;
660
+ (void) src1_dd;
661
+ }
662
+
663
+ inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
664
+ const ggml_tensor *src1, ggml_tensor *dst,
665
+ const float *src0_dd, const float *src1_dd,
666
+ float *dst_dd, const queue_ptr &main_stream) {
667
+
668
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
669
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
670
+
671
+ sqrt_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
672
+
673
+ (void) src1;
674
+ (void) dst;
675
+ (void) src1_dd;
676
+ }
677
+
678
+ inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
679
+ const ggml_tensor *src1, ggml_tensor *dst,
680
+ const float *src0_dd, const float *src1_dd,
681
+ float *dst_dd, const queue_ptr &main_stream) {
682
+
683
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
684
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
685
+
686
+ sin_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
687
+
688
+ (void) src1;
689
+ (void) dst;
690
+ (void) src1_dd;
691
+ }
692
+
693
+ inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
694
+ const ggml_tensor *src1, ggml_tensor *dst,
695
+ const float *src0_dd, const float *src1_dd,
696
+ float *dst_dd, const queue_ptr &main_stream) {
697
+
698
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
699
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
700
+
701
+ cos_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
702
+
703
+ (void) src1;
704
+ (void) dst;
705
+ (void) src1_dd;
706
+ }
707
+
708
+ inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
709
+ const ggml_tensor *src1, ggml_tensor *dst,
710
+ const float *src0_dd, const float *src1_dd,
711
+ float *dst_dd, const queue_ptr &main_stream) {
712
+
713
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
714
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
715
+
716
+ step_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
717
+
718
+ (void) src1;
719
+ (void) dst;
720
+ (void) src1_dd;
721
+ }
722
+
723
+ inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
724
+ const ggml_tensor *src1, ggml_tensor *dst,
725
+ const float *src0_dd, const float *src1_dd,
726
+ float *dst_dd, const queue_ptr &main_stream) {
727
+
728
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
729
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
730
+
731
+ neg_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
732
+
733
+ (void) src1;
734
+ (void) dst;
735
+ (void) src1_dd;
736
+ }
737
+
738
+ inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
739
+ const ggml_tensor *src1, ggml_tensor *dst,
740
+ const float *src0_dd, const float *src1_dd,
741
+ float *dst_dd,
742
+ const queue_ptr &main_stream) {
743
+
744
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
745
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
746
+
747
+ float negative_slope;
748
+ memcpy(&negative_slope, dst->op_params, sizeof(float));
749
+
750
+ leaky_relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), negative_slope, main_stream);
751
+
752
+ (void) src1;
753
+ (void) dst;
754
+ (void) src1_dd;
755
+ }
756
+
757
+ inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
758
+ ggml_tensor *dst, const float *src0_dd,
759
+ const float *src1_dd, float *dst_dd,
760
+ const queue_ptr &main_stream) {
761
+
762
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
763
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
764
+
765
+ sqr_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
766
+
767
+ (void) src1;
768
+ (void) dst;
769
+ (void) src1_dd;
770
+ }
771
+
772
+ inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
773
+ const ggml_tensor *src1, ggml_tensor *dst,
774
+ const float *src0_dd, const float *src1_dd,
775
+ float *dst_dd,
776
+ const queue_ptr &main_stream) {
777
+
778
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
779
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
780
+
781
+ const float sf0 = (float)dst->ne[0]/src0->ne[0];
782
+ const float sf1 = (float)dst->ne[1]/src0->ne[1];
783
+ const float sf2 = (float)dst->ne[2]/src0->ne[2];
784
+ const float sf3 = (float)dst->ne[3]/src0->ne[3];
785
+
786
+ upscale_f32_sycl(src0_dd, dst_dd, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
787
+ dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3,
788
+ main_stream);
789
+
790
+ (void) src1;
791
+ (void) dst;
792
+ (void) src1_dd;
793
+ }
794
+
795
+ inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
796
+ ggml_tensor *dst, const float *src0_dd,
797
+ const float *src1_dd, float *dst_dd,
798
+ const queue_ptr &main_stream) {
799
+
800
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
801
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
802
+ GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
803
+
804
+ pad_f32_sycl(src0_dd, dst_dd,
805
+ src0->ne[0], src0->ne[1], src0->ne[2],
806
+ dst->ne[0], dst->ne[1], dst->ne[2], main_stream);
807
+
808
+ (void) src1;
809
+ (void) dst;
810
+ (void) src1_dd;
811
+ }
812
+
813
+ inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
814
+ ggml_tensor *dst, const float *src0_dd,
815
+ const float *src1_dd, float *dst_dd,
816
+ const queue_ptr &main_stream) {
817
+
818
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
819
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
820
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
821
+ GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported
822
+
823
+ int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
824
+ int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
825
+ // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
826
+ int offset = dst->op_params[3] / 4; // byte offset converted to float elements
827
+
828
+ acc_f32_sycl(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], nb1, nb2, offset, main_stream);
829
+
830
+ (void) dst;
831
+ }
832
+
833
+ inline void ggml_sycl_op_add(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
834
+ ggml_tensor *dst, const float *src0_dd,
835
+ const float *src1_dd, float *dst_dd,
836
+ const queue_ptr &main_stream) {
837
+
838
+ ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_add>>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
839
+ }
840
+
841
+ inline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
842
+ ggml_tensor *dst, const float *src0_dd,
843
+ const float *src1_dd, float *dst_dd,
844
+ const queue_ptr &main_stream) {
845
+
846
+ ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_sub>>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
847
+ }
848
+
849
+ inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
850
+ ggml_tensor *dst, const float *src0_dd,
851
+ const float *src1_dd, float *dst_dd,
852
+ const queue_ptr &main_stream) {
853
+
854
+ ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_mul>>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
855
+ }
856
+
857
+ inline void ggml_sycl_op_div(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
858
+ ggml_tensor *dst, const float *src0_dd,
859
+ const float *src1_dd, float *dst_dd,
860
+ const queue_ptr &main_stream) {
861
+
862
+ ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_div>>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
863
+ }
864
+
865
+
866
+ void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
867
+ GGML_SYCL_DEBUG("call %s\n", __func__);
868
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sqrt);
869
+ GGML_SYCL_DEBUG("call %s done\n", __func__);
870
+ }
871
+
872
+ void ggml_sycl_sin(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
873
+ GGML_SYCL_DEBUG("call %s\n", __func__);
874
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sin);
875
+ GGML_SYCL_DEBUG("call %s done\n", __func__);
876
+ }
877
+
878
+ void ggml_sycl_cos(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
879
+ GGML_SYCL_DEBUG("call %s\n", __func__);
880
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_cos);
881
+ GGML_SYCL_DEBUG("call %s done\n", __func__);
882
+ }
883
+
884
+ void ggml_sycl_acc(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
885
+ GGML_SYCL_DEBUG("call %s\n", __func__);
886
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_acc);
887
+ GGML_SYCL_DEBUG("call %s done\n", __func__);
888
+ }
889
+
890
+ void ggml_sycl_gelu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
891
+ GGML_SYCL_DEBUG("call %s\n", __func__);
892
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_gelu);
893
+ GGML_SYCL_DEBUG("call %s done\n", __func__);
894
+ }
895
+
896
+ void ggml_sycl_silu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
897
+ GGML_SYCL_DEBUG("call %s\n", __func__);
898
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_silu);
899
+ GGML_SYCL_DEBUG("call %s done\n", __func__);
900
+ }
901
+
902
+ void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
903
+ GGML_SYCL_DEBUG("call %s\n", __func__);
904
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_gelu_quick);
905
+ GGML_SYCL_DEBUG("call %s done\n", __func__);
906
+ }
907
+
908
+ void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
909
+ GGML_SYCL_DEBUG("call %s\n", __func__);
910
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_tanh);
911
+ GGML_SYCL_DEBUG("call %s done\n", __func__);
912
+ }
913
+
914
+ void ggml_sycl_relu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
915
+ GGML_SYCL_DEBUG("call %s\n", __func__);
916
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_relu);
917
+ GGML_SYCL_DEBUG("call %s done\n", __func__);
918
+ }
919
+
920
+ void ggml_sycl_sigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
921
+ GGML_SYCL_DEBUG("call %s\n", __func__);
922
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sigmoid);
923
+ GGML_SYCL_DEBUG("call %s done\n", __func__);
924
+ }
925
+
926
+ void ggml_sycl_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
927
+ GGML_SYCL_DEBUG("call %s\n", __func__);
928
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_hardsigmoid);
929
+ GGML_SYCL_DEBUG("call %s done\n", __func__);
930
+ }
931
+
932
+ void ggml_sycl_hardswish(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
933
+ GGML_SYCL_DEBUG("call %s\n", __func__);
934
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_hardswish);
935
+ GGML_SYCL_DEBUG("call %s done\n", __func__);
936
+ }
937
+
938
+
939
+ void ggml_sycl_exp(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
940
+ GGML_SYCL_DEBUG("call %s\n", __func__);
941
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_exp);
942
+ GGML_SYCL_DEBUG("call %s done\n", __func__);
943
+ }
944
+
945
+ void ggml_sycl_log(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
946
+ GGML_SYCL_DEBUG("call %s\n", __func__);
947
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_log);
948
+ GGML_SYCL_DEBUG("call %s done\n", __func__);
949
+ }
950
+
951
+ void ggml_sycl_neg(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
952
+ GGML_SYCL_DEBUG("call %s\n", __func__);
953
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_neg);
954
+ GGML_SYCL_DEBUG("call %s done\n", __func__);
955
+ }
956
+
957
+ void ggml_sycl_step(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
958
+ GGML_SYCL_DEBUG("call %s\n", __func__);
959
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_step);
960
+ GGML_SYCL_DEBUG("call %s done\n", __func__);
961
+ }
962
+
963
+ void ggml_sycl_leaky_relu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
964
+ GGML_SYCL_DEBUG("call %s\n", __func__);
965
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_leaky_relu);
966
+ GGML_SYCL_DEBUG("call %s done\n", __func__);
967
+ }
968
+
969
+ void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
970
+ GGML_SYCL_DEBUG("call %s\n", __func__);
971
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sqr);
972
+ GGML_SYCL_DEBUG("call %s done\n", __func__);
973
+ }
974
+
975
+ void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
976
+ GGML_SYCL_DEBUG("call %s\n", __func__);
977
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_upscale);
978
+ GGML_SYCL_DEBUG("call %s done\n", __func__);
979
+ }
980
+
981
+ void ggml_sycl_pad(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
982
+ GGML_SYCL_DEBUG("call %s\n", __func__);
983
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_pad);
984
+ GGML_SYCL_DEBUG("call %s done\n", __func__);
985
+ }
986
+
987
+
988
+
989
+ void ggml_sycl_add(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
990
+ GGML_SYCL_DEBUG("call %s\n", __func__);
991
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_add);
992
+ GGML_SYCL_DEBUG("call %s done\n", __func__);
993
+ }
994
+
995
+ void ggml_sycl_sub(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
996
+ GGML_SYCL_DEBUG("call %s\n", __func__);
997
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sub);
998
+ GGML_SYCL_DEBUG("call %s done\n", __func__);
999
+ }
1000
+
1001
+ void ggml_sycl_mul(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
1002
+ GGML_SYCL_DEBUG("call %s\n", __func__);
1003
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_mul);
1004
+ GGML_SYCL_DEBUG("call %s done\n", __func__);
1005
+ }
1006
+
1007
+ void ggml_sycl_div(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
1008
+ GGML_SYCL_DEBUG("call %s\n", __func__);
1009
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_div);
1010
+ GGML_SYCL_DEBUG("call %s done\n", __func__);
1011
+ }
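
Every *_f32_sycl wrapper above follows the same launch pattern: ceil-divide the element count by a fixed block size, then submit a 3-D nd_range whose first two dimensions are 1. Below is a minimal, self-contained sketch of that pattern; scale_by_two, SKETCH_BLOCK_SIZE and the use of a plain sycl::queue (instead of the backend's queue_ptr) are illustrative simplifications, not part of this commit.

// Minimal sketch of the launcher pattern used by the *_f32_sycl wrappers above.
// scale_by_two and SKETCH_BLOCK_SIZE are hypothetical names, not ggml code.
#include <sycl/sycl.hpp>
#include <cstdio>

constexpr int SKETCH_BLOCK_SIZE = 256;

static void scale_by_two(const float *x, float *dst, const int k,
                         const sycl::nd_item<3> &item_ct1) {
    // Global linear index, same formula as the element-wise kernels above.
    const int i = item_ct1.get_local_id(2) +
                  item_ct1.get_group(2) * item_ct1.get_local_range(2);
    if (i >= k) {
        return; // guard the partial tail block
    }
    dst[i] = 2.0f * x[i];
}

static void scale_by_two_sycl(const float *x, float *dst, const int k,
                              sycl::queue &q) {
    // Ceil-divide so the last, partially filled block is still launched.
    const int num_blocks = (k + SKETCH_BLOCK_SIZE - 1) / SKETCH_BLOCK_SIZE;
    q.parallel_for(
        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
                              sycl::range<3>(1, 1, SKETCH_BLOCK_SIZE),
                          sycl::range<3>(1, 1, SKETCH_BLOCK_SIZE)),
        [=](sycl::nd_item<3> item_ct1) {
            scale_by_two(x, dst, k, item_ct1);
        });
}

int main() {
    sycl::queue q;
    const int k = 1000; // deliberately not a multiple of the block size
    float *x   = sycl::malloc_shared<float>(k, q);
    float *dst = sycl::malloc_shared<float>(k, q);
    for (int i = 0; i < k; ++i) x[i] = (float) i;
    scale_by_two_sycl(x, dst, k, q);
    q.wait();
    std::printf("dst[999] = %f\n", dst[999]); // expect 1998.0
    sycl::free(x, q);
    sycl::free(dst, q);
    return 0;
}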
ggml/src/ggml-sycl/element_wise.hpp ADDED
@@ -0,0 +1,76 @@
1
+ #ifndef GGML_SYCL_ELEMENTWISE_HPP
2
+ #define GGML_SYCL_ELEMENTWISE_HPP
3
+
4
+ #include "common.hpp"
5
+
6
+ static __dpct_inline__ float op_repeat(const float a, const float b) {
7
+ return b;
8
+ GGML_UNUSED(a);
9
+ }
10
+
11
+ static __dpct_inline__ float op_add(const float a, const float b) {
12
+ return a + b;
13
+ }
14
+
15
+ static __dpct_inline__ float op_sub(const float a, const float b) {
16
+ return a - b;
17
+ }
18
+
19
+ static __dpct_inline__ float op_mul(const float a, const float b) {
20
+ return a * b;
21
+ }
22
+
23
+ static __dpct_inline__ float op_div(const float a, const float b) {
24
+ return a / b;
25
+ }
26
+
27
+
28
+ void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
29
+
30
+ void ggml_sycl_sin(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
31
+
32
+ void ggml_sycl_cos(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
33
+
34
+ void ggml_sycl_acc(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
35
+
36
+ void ggml_sycl_gelu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
37
+
38
+ void ggml_sycl_silu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
39
+
40
+ void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
41
+
42
+ void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
43
+
44
+ void ggml_sycl_relu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
45
+
46
+ void ggml_sycl_sigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
47
+
48
+ void ggml_sycl_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
49
+
50
+ void ggml_sycl_hardswish(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
51
+
52
+ void ggml_sycl_exp(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
53
+
54
+ void ggml_sycl_log(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
55
+
56
+ void ggml_sycl_neg(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
57
+
58
+ void ggml_sycl_step(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
59
+
60
+ void ggml_sycl_leaky_relu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
61
+
62
+ void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
63
+
64
+ void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
65
+
66
+ void ggml_sycl_pad(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
67
+
68
+ void ggml_sycl_add(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
69
+
70
+ void ggml_sycl_sub(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
71
+
72
+ void ggml_sycl_mul(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
73
+
74
+ void ggml_sycl_div(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
75
+
76
+ #endif // GGML_SYCL_ELEMENTWISE_HPP
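
The tiny op_add/op_sub/op_mul/op_div helpers exist so that one broadcast kernel template can be instantiated once per operator, as in ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_add>> above. The stripped-down sketch below shows that idea without the broadcasting logic; the names are hypothetical, and it assumes (as the real template does) that the compile-time function-pointer parameter is resolved statically by the device compiler.

// Simplified illustration (not ggml code): a single kernel template
// instantiated per binary operator passed as a compile-time function pointer.
#include <sycl/sycl.hpp>
#include <cstdio>

inline float op_add_sketch(float a, float b) { return a + b; }
inline float op_mul_sketch(float a, float b) { return a * b; }

template <float (*bin_op)(float, float)>
static void binary_f32_sycl(const float *a, const float *b, float *dst,
                            const int k, sycl::queue &q) {
    q.parallel_for(sycl::range<1>(k), [=](sycl::id<1> i) {
        // The operator is a template argument, so each instantiation
        // (add, mul, ...) becomes a dedicated kernel with the call inlined.
        dst[i] = bin_op(a[i], b[i]);
    });
}

int main() {
    sycl::queue q;
    const int k = 4;
    float *a = sycl::malloc_shared<float>(k, q);
    float *b = sycl::malloc_shared<float>(k, q);
    float *d = sycl::malloc_shared<float>(k, q);
    for (int i = 0; i < k; ++i) { a[i] = (float) i; b[i] = 10.0f; }
    binary_f32_sycl<op_add_sketch>(a, b, d, k, q); q.wait();
    std::printf("add: %f\n", d[3]); // 13.0
    binary_f32_sycl<op_mul_sketch>(a, b, d, k, q); q.wait();
    std::printf("mul: %f\n", d[3]); // 30.0
    sycl::free(a, q); sycl::free(b, q); sycl::free(d, q);
    return 0;
}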
ggml/src/ggml-sycl/outprod.cpp ADDED
@@ -0,0 +1,55 @@
1
+ #include <sycl/sycl.hpp>
2
+ #include "outprod.hpp"
3
+
4
+
5
+ void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
6
+ const ggml_tensor* src1, ggml_tensor* dst) {
7
+
8
+
9
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
10
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
11
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
12
+ GGML_ASSERT(ggml_is_contiguous(src0));
13
+ GGML_ASSERT(ggml_is_contiguous(dst));
14
+
15
+ GGML_TENSOR_BINARY_OP_LOCALS
16
+
17
+ // Get SYCL queue
18
+ dpct::queue_ptr stream = ctx.stream();
19
+
20
+ // Dimension checks
21
+ GGML_ASSERT(ne01 == ne11); // Inner dimensions must match
22
+ GGML_ASSERT(ne0 == ne00); // Output rows match src0 rows
23
+ GGML_ASSERT(ne1 == ne10); // Output cols match src1 cols
24
+
25
+ // Get data pointers
26
+ const float* src0_d = (const float*)src0->data;
27
+ const float* src1_d = (const float*)src1->data;
28
+ float* dst_d = (float*)dst->data;
29
+
30
+ // GEMM parameters
31
+ const float alpha = 1.0f;
32
+ const float beta = 0.0f;
33
+
34
+ // Handle transposition of src1
35
+ const bool src1_T = ggml_is_transposed(src1);
36
+ const oneapi::mkl::transpose src1_op =
37
+ src1_T ? oneapi::mkl::transpose::nontrans : oneapi::mkl::transpose::trans;
38
+ const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float);
39
+
40
+ try {
41
+ // Perform matrix multiplication using oneMKL GEMM
42
+ oneapi::mkl::blas::gemm(*stream,
43
+ oneapi::mkl::transpose::nontrans, src1_op,
44
+ ne0, ne1, ne01,
45
+ alpha,
46
+ src0_d, ne00,
47
+ src1_d, ldb,
48
+ beta,
49
+ dst_d, ne0);
50
+ }
51
+ catch (sycl::exception const& exc) {
52
+ std::cerr << exc.what() << std::endl;
53
+ GGML_ASSERT(false);
54
+ }
55
+ }
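
For contiguous float32 inputs and a single 2-D case, the GEMM call above amounts to dst[i0 + i1*ne0] = sum_k src0[i0 + k*ne00] * src1[i1 + k*ne10], i.e. a sum of outer products over the shared dimension ne01 == ne11. The plain reference loop below is an illustrative check only, not part of the backend.

// Reference computation equivalent to the oneMKL GEMM call above, for
// contiguous float32 matrices and a single, non-batched 2-D case.
#include <cstdio>
#include <vector>

// src0 is ne00 x ne01, src1 is ne10 x ne11 (ne[0] is the fastest dimension),
// dst is ne0 x ne1 with ne0 == ne00, ne1 == ne10 and ne01 == ne11 contracted.
static void out_prod_ref(const float *src0, const float *src1, float *dst,
                         int ne00, int ne01, int ne10) {
    for (int i1 = 0; i1 < ne10; ++i1) {
        for (int i0 = 0; i0 < ne00; ++i0) {
            float sum = 0.0f;
            for (int k = 0; k < ne01; ++k) {
                sum += src0[i0 + k * ne00] * src1[i1 + k * ne10];
            }
            dst[i0 + i1 * ne00] = sum; // ne0 == ne00
        }
    }
}

int main() {
    // src0 is 2 x 3, src1 is 4 x 3, so dst is 2 x 4.
    const int ne00 = 2, ne01 = 3, ne10 = 4;
    std::vector<float> src0(ne00 * ne01), src1(ne10 * ne01, 1.0f), dst(ne00 * ne10);
    for (size_t i = 0; i < src0.size(); ++i) src0[i] = (float) (i + 1);
    out_prod_ref(src0.data(), src1.data(), dst.data(), ne00, ne01, ne10);
    std::printf("dst[0] = %f\n", dst[0]); // 1 + 3 + 5 = 9
    return 0;
}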
ggml/src/ggml-sycl/outprod.hpp ADDED
@@ -0,0 +1,11 @@
1
+ #ifndef GGML_SYCL_OUTPROD_HPP
2
+ #define GGML_SYCL_OUTPROD_HPP
3
+
4
+ #include "common.hpp"
5
+
6
+ void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
7
+ const ggml_tensor* src1, ggml_tensor* dst);
8
+
9
+
10
+ #endif // GGML_SYCL_OUTPROD_HPP
11
+
ggml/src/ggml-sycl/presets.hpp CHANGED
@@ -25,6 +25,11 @@
25
  #define SYCL_RELU_BLOCK_SIZE 256
26
  #define SYCL_HARDSIGMOID_BLOCK_SIZE 256
27
  #define SYCL_HARDSWISH_BLOCK_SIZE 256
 
 
 
 
 
28
  #define SYCL_SQR_BLOCK_SIZE 256
29
  #define SYCL_CPY_BLOCK_SIZE 32
30
  #define SYCL_SCALE_BLOCK_SIZE 256
@@ -41,6 +46,7 @@
41
  #define SYCL_ACC_BLOCK_SIZE 256
42
  #define SYCL_IM2COL_BLOCK_SIZE 256
43
  #define SYCL_POOL2D_BLOCK_SIZE 256
 
44
  #define SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE 256
45
  #define SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE 256
46
 
 
25
  #define SYCL_RELU_BLOCK_SIZE 256
26
  #define SYCL_HARDSIGMOID_BLOCK_SIZE 256
27
  #define SYCL_HARDSWISH_BLOCK_SIZE 256
28
+ #define SYCL_EXP_BLOCK_SIZE 256
29
+ #define SYCL_NEG_BLOCK_SIZE 256
30
+ #define SYCL_SIGMOID_BLOCK_SIZE 256
31
+ #define SYCL_SQRT_BLOCK_SIZE 256
32
+ #define SYCL_SIN_BLOCK_SIZE 256
33
  #define SYCL_SQR_BLOCK_SIZE 256
34
  #define SYCL_CPY_BLOCK_SIZE 32
35
  #define SYCL_SCALE_BLOCK_SIZE 256
 
46
  #define SYCL_ACC_BLOCK_SIZE 256
47
  #define SYCL_IM2COL_BLOCK_SIZE 256
48
  #define SYCL_POOL2D_BLOCK_SIZE 256
49
+ #define SYCL_ARGMAX_BLOCK_SIZE 256
50
  #define SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE 256
51
  #define SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE 256
52
 
ggml/src/ggml-sycl/wkv6.cpp ADDED
@@ -0,0 +1,138 @@
1
+ #include <sycl/sycl.hpp>
2
+ #include "wkv6.hpp"
3
+
4
+ constexpr int WKV_BLOCK_SIZE = 64; // Matching CUDA_WKV_BLOCK_SIZE
5
+
6
+ // Helper function for the main kernel
7
+ static void rwkv_wkv_f32_kernel(
8
+ const int B, const int T, const int C, const int H,
9
+ const float* k, const float* v, const float* r,
10
+ const float* tf, const float* td, const float* s,
11
+ float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) {
12
+
13
+ const int tid = item_ct1.get_local_id(2);
14
+ const int bid = item_ct1.get_group(2);
15
+
16
+ const int head_size = WKV_BLOCK_SIZE;
17
+ const int batch_i = bid / H;
18
+ const int head_i = bid % H;
19
+ const int state_size = C * head_size;
20
+ const int n_seq_tokens = T / B;
21
+
22
+ // Set up shared memory pointers
23
+ float* _k = shared_mem;
24
+ float* _r = _k + head_size;
25
+ float* _tf = _r + head_size;
26
+ float* _td = _tf + head_size;
27
+
28
+ // Local state array
29
+ float state[WKV_BLOCK_SIZE];
30
+
31
+ // Load initial state
32
+ #pragma unroll
33
+ for (int i = 0; i < head_size; i++) {
34
+ state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
35
+ }
36
+
37
+ // Sync threads before shared memory operations
38
+ item_ct1.barrier(sycl::access::fence_space::local_space);
39
+
40
+ // Load time-mixing parameters
41
+ _tf[tid] = tf[head_i * head_size + tid];
42
+ item_ct1.barrier(sycl::access::fence_space::local_space);
43
+
44
+ // Main sequence processing loop
45
+ for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
46
+ t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
47
+ t += C) {
48
+
49
+ item_ct1.barrier(sycl::access::fence_space::local_space);
50
+
51
+ // Load current timestep data to shared memory
52
+ _k[tid] = k[t];
53
+ _r[tid] = r[t];
54
+ _td[tid] = td[t];
55
+
56
+ item_ct1.barrier(sycl::access::fence_space::local_space);
57
+
58
+ const float _v = v[t];
59
+ float y = 0;
60
+
61
+ // Process in chunks of 4 for better vectorization
62
+ sycl::float4 k4, r4, tf4, td4, s4, kv4;
63
+ #pragma unroll
64
+ for (int j = 0; j < head_size; j += 4) {
65
+ // Load data in vec4 chunks
66
+ k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
67
+ r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
68
+ tf4 = sycl::float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
69
+ td4 = sycl::float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
70
+ s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
71
+
72
+ // Compute key-value product
73
+ kv4 = k4 * _v;
74
+
75
+ // Accumulate weighted sum
76
+ y += sycl::dot(r4, tf4 * kv4 + s4);
77
+
78
+ // Update state
79
+ s4 = s4 * td4 + kv4;
80
+
81
+ // Store updated state
82
+ state[j] = s4.x();
83
+ state[j+1] = s4.y();
84
+ state[j+2] = s4.z();
85
+ state[j+3] = s4.w();
86
+ }
87
+
88
+ dst[t] = y;
89
+ }
90
+
91
+ // Save final state
92
+ #pragma unroll
93
+ for (int i = 0; i < head_size; i++) {
94
+ dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
95
+ }
96
+ }
97
+
98
+ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
99
+ const ggml_tensor* src1, ggml_tensor* dst) {
100
+
101
+ const float* k_d = (const float*)dst->src[0]->data;
102
+ const float* v_d = (const float*)dst->src[1]->data;
103
+ const float* r_d = (const float*)dst->src[2]->data;
104
+ const float* tf_d = (const float*)dst->src[3]->data;
105
+ const float* td_d = (const float*)dst->src[4]->data;
106
+ const float* s_d = (const float*)dst->src[5]->data;
107
+ float* dst_d = (float*)dst->data;
108
+
109
+ const int64_t B = dst->src[5]->ne[1];
110
+ const int64_t T = dst->src[0]->ne[3];
111
+ const int64_t C = dst->ne[0];
112
+ const int64_t H = dst->src[0]->ne[2];
113
+
114
+ GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
115
+ GGML_ASSERT(C % H == 0);
116
+ GGML_ASSERT(C / H == WKV_BLOCK_SIZE); // the current SYCL kernel assumes RWKV6 with HEAD_SIZE == 64
117
+
118
+ dpct::queue_ptr stream = ctx.stream();
119
+
120
+ // Calculate execution configuration
121
+ const size_t shared_mem_size = WKV_BLOCK_SIZE * 4 * sizeof(float); // For k, r, tf, td
122
+ sycl::range<3> block_dims(1, 1, C / H);
123
+ sycl::range<3> grid_dims(1, 1, B * H);
124
+
125
+ // Submit kernel
126
+ stream->submit([&](sycl::handler& cgh) {
127
+ sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
128
+
129
+ cgh.parallel_for(
130
+ sycl::nd_range<3>(grid_dims * block_dims, block_dims),
131
+ [=](sycl::nd_item<3> item_ct1) {
132
+ rwkv_wkv_f32_kernel(
133
+ B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
134
+ item_ct1, shared_mem_acc.get_pointer()
135
+ );
136
+ });
137
+ });
138
+ }
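
Stripped of the shared-memory staging and float4 vectorization, each work-item of the kernel above evaluates, for its value channel, y_t = sum_j r_t[j] * (tf[j] * kv_j + state[j]) with kv_j = k_t[j] * v_t, and then updates state[j] = state[j] * td_t[j] + kv_j. The scalar reference below covers a single such lane; it is illustrative only and the helper name is hypothetical.

// Scalar reference for one (batch, head, value-channel) lane of the kernel above:
// the same recurrence, without shared memory or float4 vectorization.
#include <cstdio>
#include <vector>

// T tokens, head_size channels. k, r, td are T x head_size, tf is head_size,
// v holds the single value channel handled by this lane, state is head_size.
static void wkv6_lane_ref(int T, int head_size,
                          const float *k, const float *v, const float *r,
                          const float *tf, const float *td,
                          float *state, float *y) {
    for (int t = 0; t < T; ++t) {
        float acc = 0.0f;
        for (int j = 0; j < head_size; ++j) {
            const float kv = k[t * head_size + j] * v[t];
            acc += r[t * head_size + j] * (tf[j] * kv + state[j]);
            state[j] = state[j] * td[t * head_size + j] + kv; // decay + accumulate
        }
        y[t] = acc;
    }
}

int main() {
    const int T = 2, head_size = 4;
    std::vector<float> k(T * head_size, 0.5f), r(T * head_size, 1.0f),
                       td(T * head_size, 0.9f), tf(head_size, 0.25f),
                       v(T, 2.0f), state(head_size, 0.0f), y(T, 0.0f);
    wkv6_lane_ref(T, head_size, k.data(), v.data(), r.data(),
                  tf.data(), td.data(), state.data(), y.data());
    std::printf("y[0] = %f, y[1] = %f\n", y[0], y[1]); // 1.0 and 5.0 for these inputs
    return 0;
}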
ggml/src/ggml-sycl/wkv6.hpp ADDED
@@ -0,0 +1,10 @@
1
+ #ifndef GGML_SYCL_WKV6_HPP
2
+ #define GGML_SYCL_WKV6_HPP
3
+
4
+ #include "common.hpp"
5
+
6
+ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
7
+ const ggml_tensor *src1, ggml_tensor * dst);
8
+
9
+
10
+ #endif // GGML_SYCL_WKV6_HPP
ggml/src/ggml.c CHANGED
@@ -975,7 +975,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
975
  "WIN_UNPART",
976
  "GET_REL_POS",
977
  "ADD_REL_POS",
978
- "RWKV_WKV",
979
 
980
  "UNARY",
981
 
@@ -1070,7 +1070,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
1070
  "win_unpart(x)",
1071
  "get_rel_pos(x)",
1072
  "add_rel_pos(x)",
1073
- "rwkv_wkv(k, v, r, tf, td, s)",
1074
 
1075
  "unary(x)",
1076
 
@@ -4503,9 +4503,9 @@ struct ggml_tensor * ggml_add_rel_pos_inplace(
4503
  return ggml_add_rel_pos_impl(ctx, a, pw, ph, true);
4504
  }
4505
 
4506
- // ggml_rwkv_wkv
4507
 
4508
- struct ggml_tensor * ggml_rwkv_wkv(
4509
  struct ggml_context * ctx,
4510
  struct ggml_tensor * k,
4511
  struct ggml_tensor * v,
@@ -4537,7 +4537,7 @@ struct ggml_tensor * ggml_rwkv_wkv(
4537
  const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
4538
  struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
4539
 
4540
- result->op = GGML_OP_RWKV_WKV;
4541
  result->src[0] = k;
4542
  result->src[1] = v;
4543
  result->src[2] = r;
@@ -6084,7 +6084,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
6084
  } break;
6085
  case GGML_OP_GET_REL_POS:
6086
  case GGML_OP_ADD_REL_POS:
6087
- case GGML_OP_RWKV_WKV:
6088
  case GGML_OP_MAP_UNARY:
6089
  case GGML_OP_MAP_BINARY:
6090
  case GGML_OP_MAP_CUSTOM1_F32:
 
975
  "WIN_UNPART",
976
  "GET_REL_POS",
977
  "ADD_REL_POS",
978
+ "RWKV_WKV6",
979
 
980
  "UNARY",
981
 
 
1070
  "win_unpart(x)",
1071
  "get_rel_pos(x)",
1072
  "add_rel_pos(x)",
1073
+ "rwkv_wkv6(k, v, r, tf, td, s)",
1074
 
1075
  "unary(x)",
1076
 
 
4503
  return ggml_add_rel_pos_impl(ctx, a, pw, ph, true);
4504
  }
4505
 
4506
+ // ggml_rwkv_wkv6
4507
 
4508
+ struct ggml_tensor * ggml_rwkv_wkv6(
4509
  struct ggml_context * ctx,
4510
  struct ggml_tensor * k,
4511
  struct ggml_tensor * v,
 
4537
  const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
4538
  struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
4539
 
4540
+ result->op = GGML_OP_RWKV_WKV6;
4541
  result->src[0] = k;
4542
  result->src[1] = v;
4543
  result->src[2] = r;
 
6084
  } break;
6085
  case GGML_OP_GET_REL_POS:
6086
  case GGML_OP_ADD_REL_POS:
6087
+ case GGML_OP_RWKV_WKV6:
6088
  case GGML_OP_MAP_UNARY:
6089
  case GGML_OP_MAP_BINARY:
6090
  case GGML_OP_MAP_CUSTOM1_F32:
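
The rename has to be applied consistently in three places: the GGML_OP_* enum value, its entry in GGML_OP_NAME, and its entry in GGML_OP_SYMBOL. The sketch below is a generic illustration of keeping such parallel tables in sync at compile time; the names are hypothetical and this is not ggml's actual mechanism.

// Generic illustration (not ggml's definitions): an op enum with parallel
// name/symbol tables, checked at compile time so a rename or addition that
// misses one table fails to build instead of misbehaving at runtime.
#include <cstdio>

enum sketch_op {
    SKETCH_OP_ADD,
    SKETCH_OP_RWKV_WKV6, // renamed entry
    SKETCH_OP_COUNT,
};

static const char * SKETCH_OP_NAME[] = {
    "ADD",
    "RWKV_WKV6",
};

static const char * SKETCH_OP_SYMBOL[] = {
    "x+y",
    "rwkv_wkv6(k, v, r, tf, td, s)",
};

static_assert(sizeof(SKETCH_OP_NAME)   / sizeof(SKETCH_OP_NAME[0])   == SKETCH_OP_COUNT, "name table out of sync");
static_assert(sizeof(SKETCH_OP_SYMBOL) / sizeof(SKETCH_OP_SYMBOL[0]) == SKETCH_OP_COUNT, "symbol table out of sync");

int main() {
    std::printf("%s -> %s\n", SKETCH_OP_NAME[SKETCH_OP_RWKV_WKV6], SKETCH_OP_SYMBOL[SKETCH_OP_RWKV_WKV6]);
    return 0;
}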