Alberto Cabrera Pérez commited on
Commit
918eff5
·
1 Parent(s): dc59673

sycl: addressing non-contiguous src1 mul_mats (nc and batched) (llama/13343)

Browse files

* sycl: fixed non-contiguous src1 mul_mats (nc and batched)

* Fixed wrong static_cast inside kernel

ggml/src/ggml-sycl/common.hpp CHANGED
@@ -114,17 +114,12 @@ static void crash() {
114
  GGML_ABORT("SYCL error");
115
  }
116
 
117
- #define SYCL_CHECK(err) \
118
- do { \
119
- auto err_ = (err); \
120
- if (err_ != 0) \
121
- ggml_sycl_error( \
122
- #err, \
123
- __func__, \
124
- __FILE__, \
125
- __LINE__, \
126
- "Meet error in this line code!"); \
127
- } while (0)
128
 
129
  #if DPCT_COMPAT_RT_VERSION >= 11100
130
  #define GGML_SYCL_ASSUME(x) __builtin_assume(x)
 
114
  GGML_ABORT("SYCL error");
115
  }
116
 
117
+ #define SYCL_CHECK(err) \
118
+ do { \
119
+ auto err_ = (err); \
120
+ if (err_ != 0) \
121
+ ggml_sycl_error(#err, __func__, __FILE__, __LINE__, "Exception caught in this line of code."); \
122
+ } while (0)
 
 
 
 
 
123
 
124
  #if DPCT_COMPAT_RT_VERSION >= 11100
125
  #define GGML_SYCL_ASSUME(x) __builtin_assume(x)
ggml/src/ggml-sycl/convert.cpp CHANGED
@@ -437,41 +437,52 @@ static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int64_t k
437
  }
438
 
439
  template <typename src_t, typename dst_t>
440
- static void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k,
441
- const sycl::nd_item<3> &item_ct1) {
 
 
442
  const int64_t work_group_size = item_ct1.get_local_range(2);
443
- const int64_t global_id = item_ct1.get_local_id(2) + work_group_size * item_ct1.get_group(2);
 
 
 
 
444
 
445
  // make each work-item deal with more elements since sycl global range can not exceed max int
446
- const src_t * x = (const src_t *) vx;
447
- for (int64_t i = global_id; i < k; i += work_group_size * item_ct1.get_group_range(2)) {
448
- y[i] = x[i];
 
 
 
 
449
  }
450
  }
451
 
452
  template <typename src_t, typename dst_t>
453
- static void convert_unary_sycl(const void *__restrict__ vx,
454
- dst_t *__restrict__ y, const int64_t k,
455
- dpct::queue_ptr stream) {
456
- const int64_t num_blocks = (k + SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / SYCL_DEQUANTIZE_BLOCK_SIZE;
 
 
457
 
458
  // decrease global range when it exceeds the max int
459
- int64_t local_size = downsample_sycl_global_range(num_blocks, SYCL_DEQUANTIZE_BLOCK_SIZE);
460
- sycl::range<3> block_nums(1, 1, num_blocks);
461
- sycl::range<3> local_range(1, 1, local_size);
462
- {
463
- dpct::has_capability_or_fail(stream->get_device(),
464
- {sycl::aspect::fp16});
465
 
466
- stream->parallel_for(
467
- sycl::nd_range<3>(block_nums * local_range, local_range),
468
- [=](sycl::nd_item<3> item_ct1) {
469
- convert_unary<src_t>(vx, y, k, item_ct1);
470
- });
471
- }
472
  }
473
 
474
- to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor *dst) {
 
 
 
 
 
475
  switch (type) {
476
  case GGML_TYPE_Q4_0:
477
  if (dst->src[0]->extra &&
@@ -574,3 +585,12 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
574
  return nullptr;
575
  }
576
  }
 
 
 
 
 
 
 
 
 
 
437
  }
438
 
439
  template <typename src_t, typename dst_t>
440
+ static void convert_unary_nc(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01,
441
+ const int64_t ne02, const int64_t s01, const int64_t s02, const int64_t s03,
442
+ const sycl::nd_item<3> & item_ct1) {
443
+
444
  const int64_t work_group_size = item_ct1.get_local_range(2);
445
+ const int64_t global_id = item_ct1.get_local_id(2) + work_group_size * item_ct1.get_group(2);
446
+
447
+ const int64_t i01 = item_ct1.get_group(1);
448
+ const int64_t i02 = item_ct1.get_group(0) % ne02;
449
+ const int64_t i03 = item_ct1.get_group(0) / ne02;
450
 
451
  // make each work-item deal with more elements since sycl global range can not exceed max int
452
+ const src_t * x = static_cast<const src_t *>(vx);
453
+ const int64_t ix = i03 * s03 + i02 * s02 + i01 * s01;
454
+ const int64_t iy = ((i03 * ne02 + i02) * ne01 + i01) * ne00;
455
+
456
+ #pragma unroll
457
+ for (int64_t i00 = global_id; i00 < ne00; i00 += work_group_size * item_ct1.get_group_range(2)) {
458
+ y[iy + i00] = static_cast<dst_t>(x[ix + i00]);
459
  }
460
  }
461
 
462
  template <typename src_t, typename dst_t>
463
+ static void convert_unary_nc_sycl(const void * __restrict__ vx, dst_t * __restrict__ y,
464
+ const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
465
+ const int64_t s01, const int64_t s02, const int64_t s03, dpct::queue_ptr queue) {
466
+ dpct::has_capability_or_fail(queue->get_device(), { sycl::aspect::fp16 });
467
+
468
+ sycl::range<3> global_size(ne02 * ne03, ne01, ceil_div(ne00, SYCL_DEQUANTIZE_BLOCK_SIZE));
469
 
470
  // decrease global range when it exceeds the max int
471
+ // TODO: Downsample logic is separated from the kernel, a rewrite is desirable
472
+ int64_t downsized_workgroup = downsample_sycl_global_range(global_size[0], SYCL_DEQUANTIZE_BLOCK_SIZE);
473
+ sycl::range<3> workgroup_size(1, 1, downsized_workgroup);
 
 
 
474
 
475
+ queue->parallel_for(sycl::nd_range<3>(global_size * workgroup_size, workgroup_size), [=](sycl::nd_item<3> item_ct1) {
476
+ convert_unary_nc<src_t>(vx, y, ne00, ne01, ne02, s01, s02, s03, item_ct1);
477
+ });
 
 
 
478
  }
479
 
480
+ template <typename src_t, typename dst_t>
481
+ static void convert_unary_sycl(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr queue) {
482
+ convert_unary_nc_sycl<src_t>(vx, y, k, 1, 1, 1, k, k, k, queue);
483
+ }
484
+
485
+ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
486
  switch (type) {
487
  case GGML_TYPE_Q4_0:
488
  if (dst->src[0]->extra &&
 
585
  return nullptr;
586
  }
587
  }
588
+
589
+ to_fp16_nc_sycl_t get_to_fp16_nc_sycl(ggml_type type) {
590
+ switch (type) {
591
+ case GGML_TYPE_F32:
592
+ return convert_unary_nc_sycl<float>;
593
+ default:
594
+ return nullptr;
595
+ }
596
+ }
ggml/src/ggml-sycl/convert.hpp CHANGED
@@ -1,6 +1,6 @@
1
  //
2
  // MIT license
3
- // Copyright (C) 2024 Intel Corporation
4
  // SPDX-License-Identifier: MIT
5
  //
6
 
@@ -16,12 +16,19 @@
16
  #include "common.hpp"
17
 
18
  template <typename T>
19
- using to_t_sycl_t = void (*)(const void *__restrict__ x, T *__restrict__ y,
20
- int64_t k, dpct::queue_ptr stream);
21
- typedef to_t_sycl_t<float> to_fp32_sycl_t;
22
  typedef to_t_sycl_t<sycl::half> to_fp16_sycl_t;
23
 
24
- to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor *dst);
25
- to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst);
26
 
27
- #endif // GGML_SYCL_CONVERT_HPP
 
 
 
 
 
 
 
 
 
1
  //
2
  // MIT license
3
+ // Copyright (C) 2025 Intel Corporation
4
  // SPDX-License-Identifier: MIT
5
  //
6
 
 
16
  #include "common.hpp"
17
 
18
  template <typename T>
19
+ using to_t_sycl_t = void (*)(const void * __restrict__ x, T * __restrict__ y, int64_t k, dpct::queue_ptr stream);
20
+ typedef to_t_sycl_t<float> to_fp32_sycl_t;
 
21
  typedef to_t_sycl_t<sycl::half> to_fp16_sycl_t;
22
 
23
+ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst);
24
+ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor * dst);
25
 
26
+ // Nc = Non-contiguous
27
+ template <typename T>
28
+ using to_t_nc_sycl_t = void (*)(const void * x, T * y, int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
29
+ int64_t s01, int64_t s02, int64_t s03, dpct::queue_ptr queue);
30
+
31
+ typedef to_t_nc_sycl_t<sycl::half> to_fp16_nc_sycl_t;
32
+ to_fp16_nc_sycl_t get_to_fp16_nc_sycl(ggml_type type);
33
+
34
+ #endif // GGML_SYCL_CONVERT_HPP
ggml/src/ggml-sycl/ggml-sycl.cpp CHANGED
@@ -2694,35 +2694,31 @@ catch (sycl::exception const &exc) {
2694
  std::exit(1);
2695
  }
2696
 
2697
- static void k_compute_batched_ptrs(const sycl::half *src0_as_f16,
2698
- const sycl::half *src1_as_f16, char *dst,
2699
- const void **ptrs_src, void **ptrs_dst,
2700
- int64_t ne12, int64_t ne13, int64_t ne23,
2701
- size_t nb02, size_t nb03, size_t nb12,
2702
- size_t nb13, size_t nbd2, size_t nbd3,
2703
- int64_t r2, int64_t r3,
2704
- const sycl::nd_item<3> &item_ct1) {
2705
- int64_t i13 = item_ct1.get_group(2) * item_ct1.get_local_range(2) +
2706
- item_ct1.get_local_id(2);
2707
- int64_t i12 = item_ct1.get_group(1) * item_ct1.get_local_range(1) +
2708
- item_ct1.get_local_id(1);
2709
 
2710
  if (i13 >= ne13 || i12 >= ne12) {
2711
  return;
2712
  }
2713
 
2714
- int64_t i03 = i13 / r3;
2715
- int64_t i02 = i12 / r2;
 
 
 
 
2716
 
2717
- ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
2718
- ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12 + i13*nb13;
2719
- ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3;
2720
  }
2721
 
2722
- static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
2723
- const ggml_tensor *src0,
2724
- const ggml_tensor *src1,
2725
- ggml_tensor *dst) try {
2726
  GGML_ASSERT(!ggml_is_transposed(src0));
2727
  GGML_ASSERT(!ggml_is_transposed(src1));
2728
  GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
@@ -2730,102 +2726,100 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
2730
 
2731
  GGML_TENSOR_BINARY_OP_LOCALS
2732
 
 
 
 
2733
 
2734
  SYCL_CHECK(ggml_sycl_set_device(ctx.device));
2735
- queue_ptr main_stream = ctx.stream();;
2736
 
2737
- void * src0_ddq = src0->data;
2738
- sycl::half *src0_as_f16 = (sycl::half *)src0_ddq;
2739
- float * src1_ddf = (float *) src1->data;
2740
- float * dst_ddf = (float *) dst->data;
2741
 
2742
- // convert src1 to fp16
 
 
 
 
 
 
 
 
 
 
2743
  ggml_sycl_pool_alloc<sycl::half> src1_f16_alloc(ctx.pool());
 
 
2744
  if (src1->type != GGML_TYPE_F16) {
2745
- const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
 
2746
  const int64_t ne_src1 = ggml_nelements(src1);
2747
  src1_f16_alloc.alloc(ne_src1);
2748
- GGML_ASSERT(to_fp16_sycl != nullptr);
2749
- to_fp16_sycl(src1_ddf, src1_f16_alloc.get(), ne_src1, main_stream);
 
 
 
 
2750
  }
2751
- sycl::half *src1_f16 = src1->type == GGML_TYPE_F16 ? (sycl::half *)src1_ddf
2752
- : src1_f16_alloc.get();
2753
 
2754
- char * dst_t;
 
2755
 
2756
- dpct::library_data_t cu_compute_type = dpct::library_data_t::real_float;
2757
- dpct::library_data_t cu_data_type = dpct::library_data_t::real_float;
2758
 
2759
  // dst strides
2760
  size_t nbd2 = dst->nb[2];
2761
  size_t nbd3 = dst->nb[3];
2762
 
2763
  const float alpha_f32 = 1.0f;
2764
- const float beta_f32 = 0.0f;
2765
 
2766
  const void * alpha = &alpha_f32;
2767
  const void * beta = &beta_f32;
2768
 
2769
- dst_t = (char *) dst_ddf;
2770
-
2771
  GGML_ASSERT(ne12 % ne02 == 0);
2772
  GGML_ASSERT(ne13 % ne03 == 0);
2773
 
2774
  // broadcast factors
2775
- const int64_t r2 = ne12/ne02;
2776
- const int64_t r3 = ne13/ne03;
2777
 
2778
  if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
2779
  // there is no broadcast and src0, src1 are contiguous across dims 2, 3
2780
- SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
2781
- *main_stream, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
2782
- (const char *) src0_as_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
2783
- (const char *) src1_f16, dpct::library_data_t::real_half, nb11 / nb10, nb12 / nb10, beta, (char *) dst_t,
2784
- cu_data_type, ne01, nb2 / nb0, ne12 * ne13, cu_compute_type)));
2785
  } else {
2786
- const int ne23 = ne12*ne13;
2787
 
2788
- ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
2789
- ggml_sycl_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23);
2790
  ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1);
2791
 
2792
  sycl::range<3> block_dims(1, ne12, ne13);
2793
- /*
2794
- DPCT1049:47: The work-group size passed to the SYCL kernel may exceed
2795
- the limit. To get the device limit, query
2796
- info::device::max_work_group_size. Adjust the work-group size if needed.
2797
- */
2798
- {
2799
- dpct::has_capability_or_fail(main_stream->get_device(),
2800
- {sycl::aspect::fp16});
2801
-
2802
- main_stream->submit([&](sycl::handler &cgh) {
2803
- const void **ptrs_src_get = ptrs_src.get();
2804
- void **ptrs_dst_get = ptrs_dst.get();
2805
- size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : nb12 / 2;
2806
- size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : nb13 / 2;
2807
- cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims),
2808
- [=](sycl::nd_item<3> item_ct1) {
2809
- k_compute_batched_ptrs(
2810
- src0_as_f16, src1_f16,
2811
- dst_t, ptrs_src_get,
2812
- ptrs_dst_get, ne12, ne13, ne23,
2813
- nb02, nb03, nb12_scaled, nb13_scaled,
2814
- nbd2, nbd3, r2, r3, item_ct1);
2815
- });
2816
  });
2817
- }
 
2818
  SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
2819
- *main_stream, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
2820
  (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
2821
- (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, nb11 / nb10, beta,
2822
- (void **) (ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, cu_compute_type, matrix_info.get())));
2823
  }
2824
- }
2825
- catch (sycl::exception const &exc) {
2826
- std::cerr << exc.what() << "Exception caught at file:" << __FILE__
2827
- << ", line:" << __LINE__ << std::endl;
2828
- std::exit(1);
2829
  }
2830
 
2831
  inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
@@ -2966,7 +2960,7 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
2966
  // The kernel from the if path is faster for that specific case, but does not support all mul mats.
2967
  ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
2968
  }
2969
- } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
2970
  // KQV single-batch
2971
  ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst);
2972
  } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
@@ -3873,9 +3867,6 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
3873
  if (a->ne[3] != b->ne[3]) {
3874
  return false;
3875
  }
3876
- if (!ggml_is_contiguous(b)) {
3877
- return false;
3878
- }
3879
  ggml_type a_type = a->type;
3880
  if (a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ4_XS ||
3881
  a_type == GGML_TYPE_IQ3_XXS || a_type == GGML_TYPE_IQ3_S ||
 
2694
  std::exit(1);
2695
  }
2696
 
2697
+ static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::half * src1_as_f16, char * dst,
2698
+ const void ** ptrs_src, void ** ptrs_dst, int64_t ne12, int64_t ne13, int64_t ne23,
2699
+ size_t nb02, size_t nb03, size_t nb12, size_t nb13, size_t nbd2, size_t nbd3,
2700
+ int64_t r2, int64_t r3, const sycl::nd_item<3> & item_ct1) {
2701
+ const int64_t i13 = item_ct1.get_group(2) * item_ct1.get_local_range(2) + item_ct1.get_local_id(2);
2702
+ const int64_t i12 = item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1);
 
 
 
 
 
 
2703
 
2704
  if (i13 >= ne13 || i12 >= ne12) {
2705
  return;
2706
  }
2707
 
2708
+ const int64_t i03 = i13 / r3;
2709
+ const int64_t i02 = i12 / r2;
2710
+
2711
+ const uint8_t * src0_bytes = reinterpret_cast<const uint8_t *>(src0_as_f16);
2712
+ const uint8_t * src1_bytes = reinterpret_cast<const uint8_t *>(src1_as_f16);
2713
+ uint8_t * dst_bytes = reinterpret_cast<uint8_t *>(dst);
2714
 
2715
+ ptrs_src[0 * ne23 + i12 + i13 * ne12] = src0_bytes + i02 * nb02 + i03 * nb03;
2716
+ ptrs_src[1 * ne23 + i12 + i13 * ne12] = src1_bytes + i12 * nb12 + i13 * nb13;
2717
+ ptrs_dst[0 * ne23 + i12 + i13 * ne12] = dst_bytes + i12 * nbd2 + i13 * nbd3;
2718
  }
2719
 
2720
+ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor * src0,
2721
+ const ggml_tensor * src1, ggml_tensor * dst) try {
 
 
2722
  GGML_ASSERT(!ggml_is_transposed(src0));
2723
  GGML_ASSERT(!ggml_is_transposed(src1));
2724
  GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
 
2726
 
2727
  GGML_TENSOR_BINARY_OP_LOCALS
2728
 
2729
+ // TODO: see https://github.com/ggml-org/llama.cpp/pull/13155
2730
+ // Batched mul_mat requires a rewrite to support both oneDNN and non-contiguous dst
2731
+ GGML_ASSERT(ggml_is_contiguous(dst));
2732
 
2733
  SYCL_CHECK(ggml_sycl_set_device(ctx.device));
2734
+ queue_ptr queue = ctx.stream();
2735
 
2736
+ dpct::has_capability_or_fail(queue->get_device(), { sycl::aspect::fp16 });
 
 
 
2737
 
2738
+ const sycl::half * src0_f16 = static_cast<const sycl::half *>(src0->data);
2739
+ float * dst_ddf = static_cast<float *>(dst->data);
2740
+
2741
+ const sycl::half * src1_f16 = static_cast<const sycl::half *>(src1->data);
2742
+ const size_t type_size_src1 = ggml_type_size(src1->type);
2743
+ GGML_ASSERT(nb10 == type_size_src1);
2744
+
2745
+ // SRC1 strides
2746
+ int64_t s11 = nb11 / type_size_src1;
2747
+ int64_t s12 = nb12 / type_size_src1;
2748
+ int64_t s13 = nb13 / type_size_src1;
2749
  ggml_sycl_pool_alloc<sycl::half> src1_f16_alloc(ctx.pool());
2750
+
2751
+ // convert src1 to fp16
2752
  if (src1->type != GGML_TYPE_F16) {
2753
+ const to_fp16_nc_sycl_t to_fp16_nc_sycl = get_to_fp16_nc_sycl(src1->type);
2754
+ GGML_ASSERT(to_fp16_nc_sycl != nullptr);
2755
  const int64_t ne_src1 = ggml_nelements(src1);
2756
  src1_f16_alloc.alloc(ne_src1);
2757
+ to_fp16_nc_sycl(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, queue);
2758
+
2759
+ src1_f16 = src1_f16_alloc.get();
2760
+ s11 = ne10;
2761
+ s12 = ne11 * s11;
2762
+ s13 = ne12 * s12;
2763
  }
 
 
2764
 
2765
+ ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool());
2766
+ char * dst_t = reinterpret_cast<char *>(dst_ddf);
2767
 
2768
+ dpct::library_data_t mkl_compute_type = dpct::library_data_t::real_float;
2769
+ dpct::library_data_t mkl_data_type = dpct::library_data_t::real_float;
2770
 
2771
  // dst strides
2772
  size_t nbd2 = dst->nb[2];
2773
  size_t nbd3 = dst->nb[3];
2774
 
2775
  const float alpha_f32 = 1.0f;
2776
+ const float beta_f32 = 0.0f;
2777
 
2778
  const void * alpha = &alpha_f32;
2779
  const void * beta = &beta_f32;
2780
 
 
 
2781
  GGML_ASSERT(ne12 % ne02 == 0);
2782
  GGML_ASSERT(ne13 % ne03 == 0);
2783
 
2784
  // broadcast factors
2785
+ const int64_t r2 = ne12 / ne02;
2786
+ const int64_t r3 = ne13 / ne03;
2787
 
2788
  if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
2789
  // there is no broadcast and src0, src1 are contiguous across dims 2, 3
2790
+ SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans,
2791
+ oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
2792
+ src0_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
2793
+ src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_t,
2794
+ mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
2795
  } else {
2796
+ const int ne23 = ne12 * ne13;
2797
 
2798
+ ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2 * ne23);
2799
+ ggml_sycl_pool_alloc<void *> ptrs_dst(ctx.pool(), 1 * ne23);
2800
  ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1);
2801
 
2802
  sycl::range<3> block_dims(1, ne12, ne13);
2803
+ queue->submit([&](sycl::handler & cgh) {
2804
+ const void ** ptrs_src_get = ptrs_src.get();
2805
+ void ** ptrs_dst_get = ptrs_dst.get();
2806
+ size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half);
2807
+ size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half);
2808
+ cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
2809
+ k_compute_batched_ptrs(src0_f16, src1_f16, dst_t, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,
2810
+ nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2811
  });
2812
+ });
2813
+
2814
  SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
2815
+ *queue, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
2816
  (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
2817
+ (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta,
2818
+ (void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get())));
2819
  }
2820
+ } catch (const sycl::exception & exc) {
2821
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
2822
+ std::exit(1);
 
 
2823
  }
2824
 
2825
  inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
 
2960
  // The kernel from the if path is faster for that specific case, but does not support all mul mats.
2961
  ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
2962
  }
2963
+ } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
2964
  // KQV single-batch
2965
  ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst);
2966
  } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
 
3867
  if (a->ne[3] != b->ne[3]) {
3868
  return false;
3869
  }
 
 
 
3870
  ggml_type a_type = a->type;
3871
  if (a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ4_XS ||
3872
  a_type == GGML_TYPE_IQ3_XXS || a_type == GGML_TYPE_IQ3_S ||