Commit 918eff5
Author: Alberto Cabrera Pérez
Parent(s): dc59673

sycl: addressing non-contiguous src1 mul_mats (nc and batched) (llama/13343)
* sycl: fixed non-contiguous src1 mul_mats (nc and batched)
* Fixed wrong static_cast inside kernel
- ggml/src/ggml-sycl/common.hpp +6 -11
- ggml/src/ggml-sycl/convert.cpp +43 -23
- ggml/src/ggml-sycl/convert.hpp +14 -7
- ggml/src/ggml-sycl/ggml-sycl.cpp +75 -84
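
The heart of the patch is letting the batched SYCL mul_mat accept a src1 whose byte strides (nb11/nb12/nb13) do not describe a densely packed layout: such tensors are now gathered into a packed fp16 buffer before the GEMM. As a rough sketch of the stride arithmetic the new code relies on (the helper function below is illustrative and not part of the diff; ggml_type_size() and the nb fields are standard ggml):

// Illustrative only: derive per-dimension element strides for a possibly
// non-contiguous, unquantized src1 from its byte strides, as the patched
// mul_mat path does inline (the s11/s12/s13 values used throughout the diff).
static void src1_element_strides(const ggml_tensor * src1, int64_t & s11, int64_t & s12, int64_t & s13) {
    const size_t ts = ggml_type_size(src1->type);  // bytes per element
    s11 = src1->nb[1] / ts;  // elements between consecutive rows
    s12 = src1->nb[2] / ts;  // elements between consecutive matrices
    s13 = src1->nb[3] / ts;  // elements between consecutive batches
}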
ggml/src/ggml-sycl/common.hpp
CHANGED

@@ -114,17 +114,12 @@ static void crash() {
     GGML_ABORT("SYCL error");
 }

-#define SYCL_CHECK(err) …
-        …
-            __func__, \
-            __FILE__, \
-            __LINE__, \
-            "Meet error in this line code!"); \
-    } while (0)
+#define SYCL_CHECK(err) \
+    do { \
+        auto err_ = (err); \
+        if (err_ != 0) \
+            ggml_sycl_error(#err, __func__, __FILE__, __LINE__, "Exception caught in this line of code."); \
+    } while (0)

 #if DPCT_COMPAT_RT_VERSION >= 11100
 #define GGML_SYCL_ASSUME(x) __builtin_assume(x)
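
For orientation, the reworked macro is used the same way as before: wrap any backend call that returns a non-zero status. A minimal sketch (this exact call also appears later in the diff; the surrounding context object is assumed):

// On a non-zero status, ggml_sycl_error() receives the stringified expression
// plus the function, file and line of the failing call.
SYCL_CHECK(ggml_sycl_set_device(ctx.device));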
ggml/src/ggml-sycl/convert.cpp
CHANGED

@@ -437,41 +437,52 @@ static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int64_t k
 }

 template <typename src_t, typename dst_t>
-static void …
+static void convert_unary_nc(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01,
+                             const int64_t ne02, const int64_t s01, const int64_t s02, const int64_t s03,
+                             const sycl::nd_item<3> & item_ct1) {
+
     const int64_t work_group_size = item_ct1.get_local_range(2);
-    const int64_t global_id …
+    const int64_t global_id = item_ct1.get_local_id(2) + work_group_size * item_ct1.get_group(2);
+
+    const int64_t i01 = item_ct1.get_group(1);
+    const int64_t i02 = item_ct1.get_group(0) % ne02;
+    const int64_t i03 = item_ct1.get_group(0) / ne02;

     // make each work-item deal with more elements since sycl global range can not exceed max int
-    const src_t * x = …
+    const src_t * x = static_cast<const src_t *>(vx);
+    const int64_t ix = i03 * s03 + i02 * s02 + i01 * s01;
+    const int64_t iy = ((i03 * ne02 + i02) * ne01 + i01) * ne00;
+
+#pragma unroll
+    for (int64_t i00 = global_id; i00 < ne00; i00 += work_group_size * item_ct1.get_group_range(2)) {
+        y[iy + i00] = static_cast<dst_t>(x[ix + i00]);
     }
 }

 template <typename src_t, typename dst_t>
-static void …
+static void convert_unary_nc_sycl(const void * __restrict__ vx, dst_t * __restrict__ y,
+                                  const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+                                  const int64_t s01, const int64_t s02, const int64_t s03, dpct::queue_ptr queue) {
+    dpct::has_capability_or_fail(queue->get_device(), { sycl::aspect::fp16 });
+
+    sycl::range<3> global_size(ne02 * ne03, ne01, ceil_div(ne00, SYCL_DEQUANTIZE_BLOCK_SIZE));

     // decrease global range when it exceeds the max int
-    sycl::range<3> …
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-                …
-                convert_unary<src_t>(vx, y, k, item_ct1);
-            });
-    }
+    // TODO: Downsample logic is separated from the kernel, a rewrite is desirable
+    int64_t downsized_workgroup = downsample_sycl_global_range(global_size[0], SYCL_DEQUANTIZE_BLOCK_SIZE);
+    sycl::range<3> workgroup_size(1, 1, downsized_workgroup);
+
+    queue->parallel_for(sycl::nd_range<3>(global_size * workgroup_size, workgroup_size), [=](sycl::nd_item<3> item_ct1) {
+        convert_unary_nc<src_t>(vx, y, ne00, ne01, ne02, s01, s02, s03, item_ct1);
+    });
 }

+template <typename src_t, typename dst_t>
+static void convert_unary_sycl(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr queue) {
+    convert_unary_nc_sycl<src_t>(vx, y, k, 1, 1, 1, k, k, k, queue);
+}
+
 to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
     switch (type) {
         case GGML_TYPE_Q4_0:
             if (dst->src[0]->extra &&
@@ -574,3 +585,12 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
             return nullptr;
     }
 }
+
+to_fp16_nc_sycl_t get_to_fp16_nc_sycl(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_F32:
+            return convert_unary_nc_sycl<float>;
+        default:
+            return nullptr;
+    }
+}
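
To make the index mapping of convert_unary_nc concrete: each work-item reads the strided element x[i03*s03 + i02*s02 + i01*s01 + i00] and writes it to a densely packed destination in (i03, i02, i01, i00) order. A sequential host-side reference of the same mapping, as a sketch only (not part of the patch):

// Sequential equivalent of what convert_unary_nc computes in parallel:
// gather a strided source view into a packed buffer, casting element-wise.
template <typename src_t, typename dst_t>
static void convert_unary_nc_ref(const src_t * x, dst_t * y,
                                 int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
                                 int64_t s01, int64_t s02, int64_t s03) {
    for (int64_t i03 = 0; i03 < ne03; ++i03)
        for (int64_t i02 = 0; i02 < ne02; ++i02)
            for (int64_t i01 = 0; i01 < ne01; ++i01)
                for (int64_t i00 = 0; i00 < ne00; ++i00)
                    y[((i03 * ne02 + i02) * ne01 + i01) * ne00 + i00] =
                        static_cast<dst_t>(x[i03 * s03 + i02 * s02 + i01 * s01 + i00]);
}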
ggml/src/ggml-sycl/convert.hpp
CHANGED

@@ -1,6 +1,6 @@
 //
 // MIT license
-// Copyright (C) …
+// Copyright (C) 2025 Intel Corporation
 // SPDX-License-Identifier: MIT
 //

@@ -16,12 +16,19 @@
 #include "common.hpp"

 template <typename T>
-using to_t_sycl_t = void (*)(const void *__restrict__ x, T *__restrict__ y, …
-typedef to_t_sycl_t<float> to_fp32_sycl_t;
+using to_t_sycl_t = void (*)(const void * __restrict__ x, T * __restrict__ y, int64_t k, dpct::queue_ptr stream);
+typedef to_t_sycl_t<float> to_fp32_sycl_t;
 typedef to_t_sycl_t<sycl::half> to_fp16_sycl_t;

-to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor *dst);
-to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst);
+to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst);
+to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor * dst);

-…
+// Nc = Non-contiguous
+template <typename T>
+using to_t_nc_sycl_t = void (*)(const void * x, T * y, int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
+                                int64_t s01, int64_t s02, int64_t s03, dpct::queue_ptr queue);
+
+typedef to_t_nc_sycl_t<sycl::half> to_fp16_nc_sycl_t;
+to_fp16_nc_sycl_t get_to_fp16_nc_sycl(ggml_type type);
+
+#endif // GGML_SYCL_CONVERT_HPP
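
As a quick usage note for the new header API: a caller looks up the converter for src1's type and invokes it with the tensor's element strides; only GGML_TYPE_F32 is wired up so far. A sketch mirroring the call added to the batched mul_mat path below (src1_f16_packed is a hypothetical destination buffer; everything else follows the diff):

// Obtain and run the non-contiguous fp32 -> fp16 converter for src1.
const to_fp16_nc_sycl_t to_fp16_nc = get_to_fp16_nc_sycl(src1->type);
GGML_ASSERT(to_fp16_nc != nullptr);
to_fp16_nc(src1->data, src1_f16_packed, ne10, ne11, ne12, ne13, s11, s12, s13, queue);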
ggml/src/ggml-sycl/ggml-sycl.cpp
CHANGED

@@ -2694,35 +2694,31 @@ catch (sycl::exception const &exc) {
     std::exit(1);
 }

-static void k_compute_batched_ptrs(const sycl::half *src0_as_f16,
-                                   const …
-                                   int64_t r2, int64_t r3,
-                                   const sycl::nd_item<3> &item_ct1) {
-    int64_t i13 = item_ct1.get_group(2) * item_ct1.get_local_range(2) +
-                  item_ct1.get_local_id(2);
-    int64_t i12 = item_ct1.get_group(1) * item_ct1.get_local_range(1) +
-                  item_ct1.get_local_id(1);
+static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::half * src1_as_f16, char * dst,
+                                   const void ** ptrs_src, void ** ptrs_dst, int64_t ne12, int64_t ne13, int64_t ne23,
+                                   size_t nb02, size_t nb03, size_t nb12, size_t nb13, size_t nbd2, size_t nbd3,
+                                   int64_t r2, int64_t r3, const sycl::nd_item<3> & item_ct1) {
+    const int64_t i13 = item_ct1.get_group(2) * item_ct1.get_local_range(2) + item_ct1.get_local_id(2);
+    const int64_t i12 = item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1);

     if (i13 >= ne13 || i12 >= ne12) {
         return;
     }

-    int64_t i03 = i13 / r3;
-    int64_t i02 = i12 / r2;
+    const int64_t i03 = i13 / r3;
+    const int64_t i02 = i12 / r2;
+
+    const uint8_t * src0_bytes = reinterpret_cast<const uint8_t *>(src0_as_f16);
+    const uint8_t * src1_bytes = reinterpret_cast<const uint8_t *>(src1_as_f16);
+    uint8_t * dst_bytes = reinterpret_cast<uint8_t *>(dst);

-    ptrs_src[0*ne23 + i12 + i13*ne12] = …
-    ptrs_src[1*ne23 + i12 + i13*ne12] = …
-    ptrs_dst[0*ne23 + i12 + i13*ne12] = …
+    ptrs_src[0 * ne23 + i12 + i13 * ne12] = src0_bytes + i02 * nb02 + i03 * nb03;
+    ptrs_src[1 * ne23 + i12 + i13 * ne12] = src1_bytes + i12 * nb12 + i13 * nb13;
+    ptrs_dst[0 * ne23 + i12 + i13 * ne12] = dst_bytes + i12 * nbd2 + i13 * nbd3;
 }

-static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
-                                           …
-                                           const ggml_tensor *src1,
-                                           ggml_tensor *dst) try {
+static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor * src0,
+                                           const ggml_tensor * src1, ggml_tensor * dst) try {
     GGML_ASSERT(!ggml_is_transposed(src0));
     GGML_ASSERT(!ggml_is_transposed(src1));
     GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));

@@ -2730,102 +2726,100 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,

     GGML_TENSOR_BINARY_OP_LOCALS

+    // TODO: see https://github.com/ggml-org/llama.cpp/pull/13155
+    // Batched mul_mat requires a rewrite to support both oneDNN and non-contiguous dst
+    GGML_ASSERT(ggml_is_contiguous(dst));

     SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-    queue_ptr …
-    sycl::half *src0_as_f16 = (sycl::half *)src0_ddq;
-    float * src1_ddf = (float *) src1->data;
-    float * dst_ddf = (float *) dst->data;
+    queue_ptr queue = ctx.stream();
+
+    dpct::has_capability_or_fail(queue->get_device(), { sycl::aspect::fp16 });
+
+    const sycl::half * src0_f16 = static_cast<const sycl::half *>(src0->data);
+    float * dst_ddf = static_cast<float *>(dst->data);
+
+    const sycl::half * src1_f16 = static_cast<const sycl::half *>(src1->data);
+    const size_t type_size_src1 = ggml_type_size(src1->type);
+    GGML_ASSERT(nb10 == type_size_src1);
+
+    // SRC1 strides
+    int64_t s11 = nb11 / type_size_src1;
+    int64_t s12 = nb12 / type_size_src1;
+    int64_t s13 = nb13 / type_size_src1;
     ggml_sycl_pool_alloc<sycl::half> src1_f16_alloc(ctx.pool());
+
+    // convert src1 to fp16
     if (src1->type != GGML_TYPE_F16) {
-        const …
+        const to_fp16_nc_sycl_t to_fp16_nc_sycl = get_to_fp16_nc_sycl(src1->type);
+        GGML_ASSERT(to_fp16_nc_sycl != nullptr);
         const int64_t ne_src1 = ggml_nelements(src1);
         src1_f16_alloc.alloc(ne_src1);
+        to_fp16_nc_sycl(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, queue);
+
+        src1_f16 = src1_f16_alloc.get();
+        s11 = ne10;
+        s12 = ne11 * s11;
+        s13 = ne12 * s12;
     }
-    sycl::half *src1_f16 = src1->type == GGML_TYPE_F16 ? (sycl::half *)src1_ddf
-                                                       : src1_f16_alloc.get();

+    ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool());
+    char * dst_t = reinterpret_cast<char *>(dst_ddf);

-    dpct::library_data_t …
-    dpct::library_data_t …
+    dpct::library_data_t mkl_compute_type = dpct::library_data_t::real_float;
+    dpct::library_data_t mkl_data_type = dpct::library_data_t::real_float;

     // dst strides
     size_t nbd2 = dst->nb[2];
     size_t nbd3 = dst->nb[3];

     const float alpha_f32 = 1.0f;
-    const float beta_f32 …
+    const float beta_f32 = 0.0f;

     const void * alpha = &alpha_f32;
     const void * beta = &beta_f32;

-    dst_t = (char *) dst_ddf;
-
     GGML_ASSERT(ne12 % ne02 == 0);
     GGML_ASSERT(ne13 % ne03 == 0);

     // broadcast factors
-    const int64_t r2 = ne12/ne02;
-    const int64_t r3 = ne13/ne03;
+    const int64_t r2 = ne12 / ne02;
+    const int64_t r3 = ne13 / ne03;

     if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
-        SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
-            …
+        SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans,
+                                                    oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
+                                                    src0_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
+                                                    src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_t,
+                                                    mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
     } else {
-        const int ne23 = ne12*ne13;
+        const int ne23 = ne12 * ne13;

-        ggml_sycl_pool_alloc<const void *> …
-        ggml_sycl_pool_alloc< …
+        ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2 * ne23);
+        ggml_sycl_pool_alloc<void *> ptrs_dst(ctx.pool(), 1 * ne23);
         ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1);

         sycl::range<3> block_dims(1, ne12, ne13);
-        main_stream->submit([&](sycl::handler &cgh) {
-            const void **ptrs_src_get = ptrs_src.get();
-            void **ptrs_dst_get = ptrs_dst.get();
-            size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : nb12 / 2;
-            size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : nb13 / 2;
-            cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims),
-                             [=](sycl::nd_item<3> item_ct1) {
-                                 k_compute_batched_ptrs(
-                                     src0_as_f16, src1_f16,
-                                     dst_t, ptrs_src_get,
-                                     ptrs_dst_get, ne12, ne13, ne23,
-                                     nb02, nb03, nb12_scaled, nb13_scaled,
-                                     nbd2, nbd3, r2, r3, item_ct1);
-                             });
+        queue->submit([&](sycl::handler & cgh) {
+            const void ** ptrs_src_get = ptrs_src.get();
+            void ** ptrs_dst_get = ptrs_dst.get();
+            size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half);
+            size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half);
+            cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+                k_compute_batched_ptrs(src0_f16, src1_f16, dst_t, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,
+                                       nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1);
             });
-        }
+        });
+
         SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
-            * …
+            *queue, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
             (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
-            (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half,
-            (void **) (ptrs_dst.get() + 0 * ne23),
+            (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta,
+            (void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get())));
     }
+} catch (const sycl::exception & exc) {
+    std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
+    std::exit(1);
 }

 inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
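
To unpack the pointer-table arithmetic above: ptrs_src holds 2 * ne23 entries (all src0 slice pointers first, then all src1 slice pointers), indexed by the flattened batch position i12 + i13 * ne12, and broadcasting over dims 2 and 3 is expressed by mapping (i12, i13) back to src0's (i02, i03) through r2 = ne12 / ne02 and r3 = ne13 / ne03. A host-side sketch of the same bookkeeping (illustrative only; the byte-pointer variables are the ones from the kernel):

// For every (i12, i13) output slice, record where its src0 (possibly broadcast),
// src1 and dst sub-matrices start, exactly as k_compute_batched_ptrs does per work-item.
for (int64_t i13 = 0; i13 < ne13; ++i13) {
    for (int64_t i12 = 0; i12 < ne12; ++i12) {
        const int64_t i03  = i13 / r3;           // src0 index along dim 3 (broadcast)
        const int64_t i02  = i12 / r2;           // src0 index along dim 2 (broadcast)
        const int64_t slot = i12 + i13 * ne12;   // 0 .. ne23 - 1

        ptrs_src[0 * ne23 + slot] = src0_bytes + i02 * nb02 + i03 * nb03;
        ptrs_src[1 * ne23 + slot] = src1_bytes + i12 * nb12 + i13 * nb13;
        ptrs_dst[0 * ne23 + slot] = dst_bytes  + i12 * nbd2 + i13 * nbd3;
    }
}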
@@ -2966,7 +2960,7 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
             // The kernel from the if path is faster for that specific case, but does not support all mul mats.
             ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
         }
-    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
+    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
         // KQV single-batch
         ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst);
     } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
@@ -3873,9 +3867,6 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                 if (a->ne[3] != b->ne[3]) {
                     return false;
                 }
-                if (!ggml_is_contiguous(b)) {
-                    return false;
-                }
                 ggml_type a_type = a->type;
                 if (a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ4_XS ||
                     a_type == GGML_TYPE_IQ3_XXS || a_type == GGML_TYPE_IQ3_S ||