Spaces:
Running
Running
Nicolò Scipione
commited on
Commit
·
aedb0b3
1
Parent(s):
3547979
SYCL: Introducing memory host pool (llama/11251)
Browse files* Implement host pool for matrix_info
Creating a new memory pool on the host to store memory location for
matrix_info needed to launch gemm_batch from oneMKL/oneMath.
Removing complex support in gemm_batch since it is not used in llama.cpp
* Remove unnecessary headers and cast
* Reorder member variable to avoid warning on initialization
* Formatting
* Remove unused variable
* Address PR review feedback - remove warning
---------
Signed-off-by: nscipione <[email protected]>
- ggml/src/ggml-sycl/common.hpp +13 -0
- ggml/src/ggml-sycl/dpct/helper.hpp +40 -95
- ggml/src/ggml-sycl/ggml-sycl.cpp +84 -8
ggml/src/ggml-sycl/common.hpp
CHANGED
|
@@ -333,8 +333,12 @@ struct ggml_backend_sycl_context {
|
|
| 333 |
// pool
|
| 334 |
std::unique_ptr<ggml_sycl_pool> pools[GGML_SYCL_MAX_DEVICES];
|
| 335 |
|
|
|
|
|
|
|
| 336 |
static std::unique_ptr<ggml_sycl_pool> new_pool_for_device(queue_ptr qptr, int device);
|
| 337 |
|
|
|
|
|
|
|
| 338 |
ggml_sycl_pool & pool(int device) {
|
| 339 |
if (pools[device] == nullptr) {
|
| 340 |
pools[device] = new_pool_for_device(stream(device,0), device);
|
|
@@ -345,6 +349,15 @@ struct ggml_backend_sycl_context {
|
|
| 345 |
ggml_sycl_pool & pool() {
|
| 346 |
return pool(device);
|
| 347 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 348 |
};
|
| 349 |
|
| 350 |
// common device functions
|
|
|
|
| 333 |
// pool
|
| 334 |
std::unique_ptr<ggml_sycl_pool> pools[GGML_SYCL_MAX_DEVICES];
|
| 335 |
|
| 336 |
+
std::unique_ptr<ggml_sycl_pool> host_pools[GGML_SYCL_MAX_DEVICES];
|
| 337 |
+
|
| 338 |
static std::unique_ptr<ggml_sycl_pool> new_pool_for_device(queue_ptr qptr, int device);
|
| 339 |
|
| 340 |
+
static std::unique_ptr<ggml_sycl_pool> new_pool_for_host(queue_ptr qptr, int device);
|
| 341 |
+
|
| 342 |
ggml_sycl_pool & pool(int device) {
|
| 343 |
if (pools[device] == nullptr) {
|
| 344 |
pools[device] = new_pool_for_device(stream(device,0), device);
|
|
|
|
| 349 |
ggml_sycl_pool & pool() {
|
| 350 |
return pool(device);
|
| 351 |
}
|
| 352 |
+
|
| 353 |
+
ggml_sycl_pool & host_pool(int device) {
|
| 354 |
+
if (host_pools[device] == nullptr) {
|
| 355 |
+
host_pools[device] = new_pool_for_host(stream(device, 0), device);
|
| 356 |
+
}
|
| 357 |
+
return *host_pools[device];
|
| 358 |
+
}
|
| 359 |
+
|
| 360 |
+
ggml_sycl_pool & host_pool() { return host_pool(device); }
|
| 361 |
};
|
| 362 |
|
| 363 |
// common device functions
|
ggml/src/ggml-sycl/dpct/helper.hpp
CHANGED
|
@@ -82,6 +82,14 @@ inline std::string get_device_backend_and_type(const sycl::device &device) {
|
|
| 82 |
return device_type.str();
|
| 83 |
}
|
| 84 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
namespace dpct
|
| 86 |
{
|
| 87 |
typedef sycl::queue *queue_ptr;
|
|
@@ -1727,26 +1735,13 @@ namespace dpct
|
|
| 1727 |
};
|
| 1728 |
|
| 1729 |
template <class Ta, class Tb, class Tc, class Ts>
|
| 1730 |
-
inline void gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans,
|
| 1731 |
-
|
| 1732 |
-
const void *
|
| 1733 |
-
|
| 1734 |
-
int ldc, int batch_size)
|
| 1735 |
-
{
|
| 1736 |
-
struct matrix_info_t
|
| 1737 |
-
{
|
| 1738 |
-
oneapi::mkl::transpose transpose_info[2];
|
| 1739 |
-
Ts value_info[2];
|
| 1740 |
-
std::int64_t size_info[3];
|
| 1741 |
-
std::int64_t ld_info[3];
|
| 1742 |
-
std::int64_t groupsize_info;
|
| 1743 |
-
};
|
| 1744 |
-
|
| 1745 |
Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
|
| 1746 |
Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
|
| 1747 |
|
| 1748 |
-
matrix_info_t *matrix_info =
|
| 1749 |
-
(matrix_info_t *)std::malloc(sizeof(matrix_info_t));
|
| 1750 |
matrix_info->transpose_info[0] = a_trans;
|
| 1751 |
matrix_info->transpose_info[1] = b_trans;
|
| 1752 |
matrix_info->value_info[0] = alpha_value;
|
|
@@ -1763,23 +1758,18 @@ namespace dpct
|
|
| 1763 |
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
|
| 1764 |
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info,
|
| 1765 |
matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1,
|
| 1766 |
-
matrix_info->size_info + 2,
|
| 1767 |
-
matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
|
| 1768 |
-
matrix_info->
|
| 1769 |
-
&(matrix_info->groupsize_info));
|
| 1770 |
#else
|
| 1771 |
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
|
| 1772 |
q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info,
|
| 1773 |
-
matrix_info->size_info + 1, matrix_info->size_info + 2, matrix_info->value_info,
|
| 1774 |
reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
|
| 1775 |
-
matrix_info->ld_info + 1, matrix_info->value_info + 1
|
| 1776 |
-
matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
|
| 1777 |
#endif
|
| 1778 |
-
|
| 1779 |
-
q.submit([&](sycl::handler &cgh)
|
| 1780 |
-
{
|
| 1781 |
-
cgh.depends_on(e);
|
| 1782 |
-
cgh.host_task([=] { std::free(matrix_info); }); });
|
| 1783 |
}
|
| 1784 |
|
| 1785 |
template <class Ta, class Tb, class Tc, class Ts>
|
|
@@ -2422,25 +2412,11 @@ namespace dpct
|
|
| 2422 |
/// \param [in] ldc Leading dimension of C.
|
| 2423 |
/// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
|
| 2424 |
/// \param [in] scaling_type Data type of the scaling factors.
|
| 2425 |
-
inline void gemm_batch(sycl::queue &q, oneapi::mkl::transpose a_trans,
|
| 2426 |
-
|
| 2427 |
-
const void *
|
| 2428 |
-
library_data_t
|
| 2429 |
-
|
| 2430 |
-
void *c[], library_data_t c_type, int ldc,
|
| 2431 |
-
int batch_size, library_data_t scaling_type)
|
| 2432 |
-
{
|
| 2433 |
-
if (scaling_type == library_data_t::real_float &&
|
| 2434 |
-
c_type == library_data_t::complex_float)
|
| 2435 |
-
{
|
| 2436 |
-
scaling_type = library_data_t::complex_float;
|
| 2437 |
-
}
|
| 2438 |
-
else if (scaling_type == library_data_t::real_double &&
|
| 2439 |
-
c_type == library_data_t::complex_double)
|
| 2440 |
-
{
|
| 2441 |
-
scaling_type = library_data_t::complex_double;
|
| 2442 |
-
}
|
| 2443 |
-
|
| 2444 |
std::uint64_t key =
|
| 2445 |
detail::get_type_combination_id(a_type, b_type, c_type, scaling_type);
|
| 2446 |
switch (key)
|
|
@@ -2449,48 +2425,24 @@ namespace dpct
|
|
| 2449 |
library_data_t::real_float, library_data_t::real_float,
|
| 2450 |
library_data_t::real_float, library_data_t::real_float):
|
| 2451 |
{
|
| 2452 |
-
detail::gemm_batch_impl<float, float, float, float>(
|
| 2453 |
-
|
| 2454 |
-
batch_size);
|
| 2455 |
break;
|
| 2456 |
}
|
| 2457 |
case detail::get_type_combination_id(
|
| 2458 |
library_data_t::real_double, library_data_t::real_double,
|
| 2459 |
library_data_t::real_double, library_data_t::real_double):
|
| 2460 |
{
|
| 2461 |
-
detail::gemm_batch_impl<double, double, double, double>(
|
| 2462 |
-
|
| 2463 |
-
batch_size);
|
| 2464 |
-
break;
|
| 2465 |
-
}
|
| 2466 |
-
case detail::get_type_combination_id(
|
| 2467 |
-
library_data_t::complex_float, library_data_t::complex_float,
|
| 2468 |
-
library_data_t::complex_float, library_data_t::complex_float):
|
| 2469 |
-
{
|
| 2470 |
-
detail::gemm_batch_impl<std::complex<float>, std::complex<float>,
|
| 2471 |
-
std::complex<float>, std::complex<float>>(
|
| 2472 |
-
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
|
| 2473 |
-
batch_size);
|
| 2474 |
-
break;
|
| 2475 |
-
}
|
| 2476 |
-
case detail::get_type_combination_id(
|
| 2477 |
-
library_data_t::complex_double, library_data_t::complex_double,
|
| 2478 |
-
library_data_t::complex_double, library_data_t::complex_double):
|
| 2479 |
-
{
|
| 2480 |
-
detail::gemm_batch_impl<std::complex<double>, std::complex<double>,
|
| 2481 |
-
std::complex<double>, std::complex<double>>(
|
| 2482 |
-
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
|
| 2483 |
-
batch_size);
|
| 2484 |
break;
|
| 2485 |
}
|
| 2486 |
case detail::get_type_combination_id(
|
| 2487 |
library_data_t::real_half, library_data_t::real_half,
|
| 2488 |
library_data_t::real_half, library_data_t::real_half):
|
| 2489 |
{
|
| 2490 |
-
detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half,
|
| 2491 |
-
|
| 2492 |
-
a, lda, b, ldb, beta, c, ldc,
|
| 2493 |
-
batch_size);
|
| 2494 |
break;
|
| 2495 |
}
|
| 2496 |
#ifdef __INTEL_MKL__
|
|
@@ -2498,19 +2450,16 @@ namespace dpct
|
|
| 2498 |
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
| 2499 |
library_data_t::real_bfloat16, library_data_t::real_float):
|
| 2500 |
{
|
| 2501 |
-
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16,
|
| 2502 |
-
|
| 2503 |
-
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
|
| 2504 |
-
batch_size);
|
| 2505 |
break;
|
| 2506 |
}
|
| 2507 |
case detail::get_type_combination_id(
|
| 2508 |
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
| 2509 |
library_data_t::real_float, library_data_t::real_float):
|
| 2510 |
{
|
| 2511 |
-
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float,
|
| 2512 |
-
|
| 2513 |
-
b, ldb, beta, c, ldc, batch_size);
|
| 2514 |
break;
|
| 2515 |
}
|
| 2516 |
#endif
|
|
@@ -2522,10 +2471,9 @@ namespace dpct
|
|
| 2522 |
dpct::get_value(reinterpret_cast<const std::int32_t *>(alpha), q);
|
| 2523 |
float beta_float =
|
| 2524 |
dpct::get_value(reinterpret_cast<const std::int32_t *>(beta), q);
|
| 2525 |
-
detail::gemm_batch_impl<std::int8_t, std::int8_t, std::int32_t,
|
| 2526 |
-
|
| 2527 |
-
|
| 2528 |
-
batch_size);
|
| 2529 |
break;
|
| 2530 |
}
|
| 2531 |
case detail::get_type_combination_id(
|
|
@@ -2533,8 +2481,7 @@ namespace dpct
|
|
| 2533 |
library_data_t::real_float, library_data_t::real_float):
|
| 2534 |
{
|
| 2535 |
detail::gemm_batch_impl<std::int8_t, std::int8_t, float, float>(
|
| 2536 |
-
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
|
| 2537 |
-
batch_size);
|
| 2538 |
break;
|
| 2539 |
}
|
| 2540 |
case detail::get_type_combination_id(
|
|
@@ -2542,8 +2489,7 @@ namespace dpct
|
|
| 2542 |
library_data_t::real_float, library_data_t::real_float):
|
| 2543 |
{
|
| 2544 |
detail::gemm_batch_impl<sycl::half, sycl::half, float, float>(
|
| 2545 |
-
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
|
| 2546 |
-
batch_size);
|
| 2547 |
break;
|
| 2548 |
}
|
| 2549 |
case detail::get_type_combination_id(
|
|
@@ -2557,8 +2503,7 @@ namespace dpct
|
|
| 2557 |
sycl::half alpha_half(alpha_value);
|
| 2558 |
sycl::half beta_half(beta_value);
|
| 2559 |
detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
|
| 2560 |
-
q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc,
|
| 2561 |
-
batch_size);
|
| 2562 |
break;
|
| 2563 |
}
|
| 2564 |
default:
|
|
|
|
| 82 |
return device_type.str();
|
| 83 |
}
|
| 84 |
|
| 85 |
+
template <typename Ts> struct matrix_info_t {
|
| 86 |
+
oneapi::mkl::transpose transpose_info[2];
|
| 87 |
+
Ts value_info[2];
|
| 88 |
+
std::int64_t size_info[3];
|
| 89 |
+
std::int64_t ld_info[3];
|
| 90 |
+
std::int64_t groupsize_info;
|
| 91 |
+
};
|
| 92 |
+
|
| 93 |
namespace dpct
|
| 94 |
{
|
| 95 |
typedef sycl::queue *queue_ptr;
|
|
|
|
| 1735 |
};
|
| 1736 |
|
| 1737 |
template <class Ta, class Tb, class Tc, class Ts>
|
| 1738 |
+
inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans,
|
| 1739 |
+
int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b,
|
| 1740 |
+
int ldb, const void * beta, void ** c, int ldc, int batch_size,
|
| 1741 |
+
matrix_info_t<float> * matrix_info) {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1742 |
Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
|
| 1743 |
Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
|
| 1744 |
|
|
|
|
|
|
|
| 1745 |
matrix_info->transpose_info[0] = a_trans;
|
| 1746 |
matrix_info->transpose_info[1] = b_trans;
|
| 1747 |
matrix_info->value_info[0] = alpha_value;
|
|
|
|
| 1758 |
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
|
| 1759 |
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info,
|
| 1760 |
matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1,
|
| 1761 |
+
matrix_info->size_info + 2, reinterpret_cast<Ts *>(matrix_info->value_info),
|
| 1762 |
+
reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
|
| 1763 |
+
matrix_info->ld_info + 1, reinterpret_cast<Ts *>(matrix_info->value_info + 1),
|
| 1764 |
+
reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
|
| 1765 |
#else
|
| 1766 |
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
|
| 1767 |
q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info,
|
| 1768 |
+
matrix_info->size_info + 1, matrix_info->size_info + 2, reinterpret_cast<Ts *>(matrix_info->value_info),
|
| 1769 |
reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
|
| 1770 |
+
matrix_info->ld_info + 1, reinterpret_cast<Ts *>(matrix_info->value_info + 1),
|
| 1771 |
+
reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
|
| 1772 |
#endif
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1773 |
}
|
| 1774 |
|
| 1775 |
template <class Ta, class Tb, class Tc, class Ts>
|
|
|
|
| 2412 |
/// \param [in] ldc Leading dimension of C.
|
| 2413 |
/// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
|
| 2414 |
/// \param [in] scaling_type Data type of the scaling factors.
|
| 2415 |
+
inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,
|
| 2416 |
+
int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda,
|
| 2417 |
+
const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[],
|
| 2418 |
+
library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type,
|
| 2419 |
+
matrix_info_t<float> * matrix_info) {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2420 |
std::uint64_t key =
|
| 2421 |
detail::get_type_combination_id(a_type, b_type, c_type, scaling_type);
|
| 2422 |
switch (key)
|
|
|
|
| 2425 |
library_data_t::real_float, library_data_t::real_float,
|
| 2426 |
library_data_t::real_float, library_data_t::real_float):
|
| 2427 |
{
|
| 2428 |
+
detail::gemm_batch_impl<float, float, float, float>(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb,
|
| 2429 |
+
beta, c, ldc, batch_size, matrix_info);
|
|
|
|
| 2430 |
break;
|
| 2431 |
}
|
| 2432 |
case detail::get_type_combination_id(
|
| 2433 |
library_data_t::real_double, library_data_t::real_double,
|
| 2434 |
library_data_t::real_double, library_data_t::real_double):
|
| 2435 |
{
|
| 2436 |
+
detail::gemm_batch_impl<double, double, double, double>(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb,
|
| 2437 |
+
beta, c, ldc, batch_size, matrix_info);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2438 |
break;
|
| 2439 |
}
|
| 2440 |
case detail::get_type_combination_id(
|
| 2441 |
library_data_t::real_half, library_data_t::real_half,
|
| 2442 |
library_data_t::real_half, library_data_t::real_half):
|
| 2443 |
{
|
| 2444 |
+
detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
|
| 2445 |
+
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
|
|
|
|
|
|
|
| 2446 |
break;
|
| 2447 |
}
|
| 2448 |
#ifdef __INTEL_MKL__
|
|
|
|
| 2450 |
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
| 2451 |
library_data_t::real_bfloat16, library_data_t::real_float):
|
| 2452 |
{
|
| 2453 |
+
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float>(
|
| 2454 |
+
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
|
|
|
|
|
|
|
| 2455 |
break;
|
| 2456 |
}
|
| 2457 |
case detail::get_type_combination_id(
|
| 2458 |
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
| 2459 |
library_data_t::real_float, library_data_t::real_float):
|
| 2460 |
{
|
| 2461 |
+
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, float>(
|
| 2462 |
+
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
|
|
|
|
| 2463 |
break;
|
| 2464 |
}
|
| 2465 |
#endif
|
|
|
|
| 2471 |
dpct::get_value(reinterpret_cast<const std::int32_t *>(alpha), q);
|
| 2472 |
float beta_float =
|
| 2473 |
dpct::get_value(reinterpret_cast<const std::int32_t *>(beta), q);
|
| 2474 |
+
detail::gemm_batch_impl<std::int8_t, std::int8_t, std::int32_t, float>(
|
| 2475 |
+
q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc, batch_size,
|
| 2476 |
+
matrix_info);
|
|
|
|
| 2477 |
break;
|
| 2478 |
}
|
| 2479 |
case detail::get_type_combination_id(
|
|
|
|
| 2481 |
library_data_t::real_float, library_data_t::real_float):
|
| 2482 |
{
|
| 2483 |
detail::gemm_batch_impl<std::int8_t, std::int8_t, float, float>(
|
| 2484 |
+
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
|
|
|
|
| 2485 |
break;
|
| 2486 |
}
|
| 2487 |
case detail::get_type_combination_id(
|
|
|
|
| 2489 |
library_data_t::real_float, library_data_t::real_float):
|
| 2490 |
{
|
| 2491 |
detail::gemm_batch_impl<sycl::half, sycl::half, float, float>(
|
| 2492 |
+
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
|
|
|
|
| 2493 |
break;
|
| 2494 |
}
|
| 2495 |
case detail::get_type_combination_id(
|
|
|
|
| 2503 |
sycl::half alpha_half(alpha_value);
|
| 2504 |
sycl::half beta_half(beta_value);
|
| 2505 |
detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
|
| 2506 |
+
q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, batch_size, matrix_info);
|
|
|
|
| 2507 |
break;
|
| 2508 |
}
|
| 2509 |
default:
|
ggml/src/ggml-sycl/ggml-sycl.cpp
CHANGED
|
@@ -1173,6 +1173,85 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool {
|
|
| 1173 |
}
|
| 1174 |
};
|
| 1175 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1176 |
std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_device(queue_ptr qptr, int device) {
|
| 1177 |
// TBD: NO VMM support
|
| 1178 |
// if (ggml_sycl_info().devices[device].vmm) {
|
|
@@ -3363,6 +3442,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
|
|
| 3363 |
|
| 3364 |
ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
|
| 3365 |
ggml_sycl_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23);
|
|
|
|
| 3366 |
|
| 3367 |
sycl::range<3> block_dims(1, ne12, ne13);
|
| 3368 |
/*
|
|
@@ -3391,14 +3471,10 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
|
|
| 3391 |
});
|
| 3392 |
}
|
| 3393 |
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
|
| 3394 |
-
*main_stream, oneapi::mkl::transpose::trans,
|
| 3395 |
-
|
| 3396 |
-
(const void **)(ptrs_src.get() +
|
| 3397 |
-
|
| 3398 |
-
(const void **)(ptrs_src.get() + 1 * ne23),
|
| 3399 |
-
dpct::library_data_t::real_half, nb11 / nb10, beta,
|
| 3400 |
-
(void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23,
|
| 3401 |
-
cu_compute_type)));
|
| 3402 |
}
|
| 3403 |
}
|
| 3404 |
catch (sycl::exception const &exc) {
|
|
|
|
| 1173 |
}
|
| 1174 |
};
|
| 1175 |
|
| 1176 |
+
struct ggml_sycl_pool_host : public ggml_sycl_pool {
|
| 1177 |
+
queue_ptr qptr;
|
| 1178 |
+
int device;
|
| 1179 |
+
|
| 1180 |
+
inline static int counter{ 0 };
|
| 1181 |
+
|
| 1182 |
+
struct ggml_sycl_buffer {
|
| 1183 |
+
void * ptr = nullptr;
|
| 1184 |
+
size_t size = 0;
|
| 1185 |
+
};
|
| 1186 |
+
|
| 1187 |
+
// Set arbitrarly to 64
|
| 1188 |
+
static constexpr int MAX_POOL_SIZE{ 64 };
|
| 1189 |
+
std::vector<ggml_sycl_buffer> buffer_pool = std::vector<ggml_sycl_buffer>(MAX_POOL_SIZE);
|
| 1190 |
+
size_t pool_size = 0;
|
| 1191 |
+
|
| 1192 |
+
explicit ggml_sycl_pool_host(queue_ptr qptr_, int device_) : qptr(qptr_), device(device_) {}
|
| 1193 |
+
|
| 1194 |
+
~ggml_sycl_pool_host() {
|
| 1195 |
+
for (int i = 0; i < MAX_POOL_SIZE; ++i) {
|
| 1196 |
+
ggml_sycl_buffer & b = buffer_pool[i];
|
| 1197 |
+
if (b.ptr != nullptr) {
|
| 1198 |
+
SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(b.ptr, *qptr)));
|
| 1199 |
+
b.ptr = nullptr;
|
| 1200 |
+
pool_size -= b.size;
|
| 1201 |
+
b.size = 0;
|
| 1202 |
+
}
|
| 1203 |
+
}
|
| 1204 |
+
counter = 0;
|
| 1205 |
+
}
|
| 1206 |
+
|
| 1207 |
+
void * alloc(size_t size, size_t * actual_size) override {
|
| 1208 |
+
if (counter == MAX_POOL_SIZE) {
|
| 1209 |
+
ggml_sycl_buffer b = buffer_pool[0];
|
| 1210 |
+
void * ptr = b.ptr;
|
| 1211 |
+
*actual_size = b.size;
|
| 1212 |
+
counter = 1;
|
| 1213 |
+
return ptr;
|
| 1214 |
+
}
|
| 1215 |
+
ggml_sycl_buffer & b = buffer_pool[counter];
|
| 1216 |
+
|
| 1217 |
+
if (b.ptr == nullptr) {
|
| 1218 |
+
void * ptr;
|
| 1219 |
+
|
| 1220 |
+
SYCL_CHECK(CHECK_TRY_ERROR(ptr = (void *) sycl::malloc_host(size, *qptr)));
|
| 1221 |
+
if (!ptr) {
|
| 1222 |
+
GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on host\n", __func__, size);
|
| 1223 |
+
return nullptr;
|
| 1224 |
+
}
|
| 1225 |
+
pool_size += size;
|
| 1226 |
+
*actual_size = size;
|
| 1227 |
+
counter = counter + 1;
|
| 1228 |
+
return ptr;
|
| 1229 |
+
} else {
|
| 1230 |
+
++counter;
|
| 1231 |
+
b.size = size;
|
| 1232 |
+
return b.ptr;
|
| 1233 |
+
}
|
| 1234 |
+
}
|
| 1235 |
+
|
| 1236 |
+
void free(void * ptr, size_t size) override {
|
| 1237 |
+
// if the pool is not completed add the pointer to it in place of the first nullptr found.
|
| 1238 |
+
// Otherwise do nothing, pointers will be freed once the pool is deallocated.
|
| 1239 |
+
for (int i = 0; i < MAX_POOL_SIZE; ++i) {
|
| 1240 |
+
ggml_sycl_buffer & b = buffer_pool[i];
|
| 1241 |
+
if (b.ptr == nullptr) {
|
| 1242 |
+
b.ptr = ptr;
|
| 1243 |
+
b.size = size;
|
| 1244 |
+
return;
|
| 1245 |
+
}
|
| 1246 |
+
}
|
| 1247 |
+
}
|
| 1248 |
+
};
|
| 1249 |
+
|
| 1250 |
+
std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_host(queue_ptr qptr, int device) {
|
| 1251 |
+
// return pool for the host to speed up memory management
|
| 1252 |
+
return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_host(qptr, device));
|
| 1253 |
+
}
|
| 1254 |
+
|
| 1255 |
std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_device(queue_ptr qptr, int device) {
|
| 1256 |
// TBD: NO VMM support
|
| 1257 |
// if (ggml_sycl_info().devices[device].vmm) {
|
|
|
|
| 3442 |
|
| 3443 |
ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
|
| 3444 |
ggml_sycl_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23);
|
| 3445 |
+
ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1);
|
| 3446 |
|
| 3447 |
sycl::range<3> block_dims(1, ne12, ne13);
|
| 3448 |
/*
|
|
|
|
| 3471 |
});
|
| 3472 |
}
|
| 3473 |
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
|
| 3474 |
+
*main_stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
|
| 3475 |
+
(const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
|
| 3476 |
+
(const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, nb11 / nb10, beta,
|
| 3477 |
+
(void **) (ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, cu_compute_type, matrix_info.get())));
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3478 |
}
|
| 3479 |
}
|
| 3480 |
catch (sycl::exception const &exc) {
|