Nicolò Scipione committed on
Commit
aedb0b3
·
1 Parent(s): 3547979

SYCL: Introducing memory host pool (llama/11251)

Browse files

* Implement host pool for matrix_info

Create a new memory pool on the host that provides the storage for the
matrix_info structure needed to launch gemm_batch from oneMKL/oneMath.
Remove complex-number support from gemm_batch since it is not used in llama.cpp.

* Remove unnecessary headers and cast

* Reorder member variables to avoid a warning on initialization

* Formatting

* Remove unused variable

* Address PR review feedback - remove warning

---------

Signed-off-by: nscipione <[email protected]>

ggml/src/ggml-sycl/common.hpp CHANGED
@@ -333,8 +333,12 @@ struct ggml_backend_sycl_context {
333
  // pool
334
  std::unique_ptr<ggml_sycl_pool> pools[GGML_SYCL_MAX_DEVICES];
335
 
 
 
336
  static std::unique_ptr<ggml_sycl_pool> new_pool_for_device(queue_ptr qptr, int device);
337
 
 
 
338
  ggml_sycl_pool & pool(int device) {
339
  if (pools[device] == nullptr) {
340
  pools[device] = new_pool_for_device(stream(device,0), device);
@@ -345,6 +349,15 @@ struct ggml_backend_sycl_context {
345
  ggml_sycl_pool & pool() {
346
  return pool(device);
347
  }
 
 
 
 
 
 
 
 
 
348
  };
349
 
350
  // common device functions
 
333
  // pool
334
  std::unique_ptr<ggml_sycl_pool> pools[GGML_SYCL_MAX_DEVICES];
335
 
336
+ std::unique_ptr<ggml_sycl_pool> host_pools[GGML_SYCL_MAX_DEVICES];
337
+
338
  static std::unique_ptr<ggml_sycl_pool> new_pool_for_device(queue_ptr qptr, int device);
339
 
340
+ static std::unique_ptr<ggml_sycl_pool> new_pool_for_host(queue_ptr qptr, int device);
341
+
342
  ggml_sycl_pool & pool(int device) {
343
  if (pools[device] == nullptr) {
344
  pools[device] = new_pool_for_device(stream(device,0), device);
 
349
  ggml_sycl_pool & pool() {
350
  return pool(device);
351
  }
352
+
353
+ ggml_sycl_pool & host_pool(int device) {
354
+ if (host_pools[device] == nullptr) {
355
+ host_pools[device] = new_pool_for_host(stream(device, 0), device);
356
+ }
357
+ return *host_pools[device];
358
+ }
359
+
360
+ ggml_sycl_pool & host_pool() { return host_pool(device); }
361
  };
362
 
363
  // common device functions
ggml/src/ggml-sycl/dpct/helper.hpp CHANGED
@@ -82,6 +82,14 @@ inline std::string get_device_backend_and_type(const sycl::device &device) {
82
  return device_type.str();
83
  }
84
 
 
 
 
 
 
 
 
 
85
  namespace dpct
86
  {
87
  typedef sycl::queue *queue_ptr;
@@ -1727,26 +1735,13 @@ namespace dpct
1727
  };
1728
 
1729
  template <class Ta, class Tb, class Tc, class Ts>
1730
- inline void gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans,
1731
- oneapi::mkl::transpose b_trans, int m, int n, int k,
1732
- const void *alpha, const void **a, int lda,
1733
- const void **b, int ldb, const void *beta, void **c,
1734
- int ldc, int batch_size)
1735
- {
1736
- struct matrix_info_t
1737
- {
1738
- oneapi::mkl::transpose transpose_info[2];
1739
- Ts value_info[2];
1740
- std::int64_t size_info[3];
1741
- std::int64_t ld_info[3];
1742
- std::int64_t groupsize_info;
1743
- };
1744
-
1745
  Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
1746
  Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
1747
 
1748
- matrix_info_t *matrix_info =
1749
- (matrix_info_t *)std::malloc(sizeof(matrix_info_t));
1750
  matrix_info->transpose_info[0] = a_trans;
1751
  matrix_info->transpose_info[1] = b_trans;
1752
  matrix_info->value_info[0] = alpha_value;
@@ -1763,23 +1758,18 @@ namespace dpct
1763
  sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
1764
  oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info,
1765
  matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1,
1766
- matrix_info->size_info + 2, matrix_info->value_info, reinterpret_cast<const Ta **>(a),
1767
- matrix_info->ld_info, reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
1768
- matrix_info->value_info + 1, reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1,
1769
- &(matrix_info->groupsize_info));
1770
  #else
1771
  sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
1772
  q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info,
1773
- matrix_info->size_info + 1, matrix_info->size_info + 2, matrix_info->value_info,
1774
  reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
1775
- matrix_info->ld_info + 1, matrix_info->value_info + 1, reinterpret_cast<Tc **>(c),
1776
- matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
1777
  #endif
1778
-
1779
- q.submit([&](sycl::handler &cgh)
1780
- {
1781
- cgh.depends_on(e);
1782
- cgh.host_task([=] { std::free(matrix_info); }); });
1783
  }
1784
 
1785
  template <class Ta, class Tb, class Tc, class Ts>
@@ -2422,25 +2412,11 @@ namespace dpct
2422
  /// \param [in] ldc Leading dimension of C.
2423
  /// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
2424
  /// \param [in] scaling_type Data type of the scaling factors.
2425
- inline void gemm_batch(sycl::queue &q, oneapi::mkl::transpose a_trans,
2426
- oneapi::mkl::transpose b_trans, int m, int n, int k,
2427
- const void *alpha, const void *a[],
2428
- library_data_t a_type, int lda, const void *b[],
2429
- library_data_t b_type, int ldb, const void *beta,
2430
- void *c[], library_data_t c_type, int ldc,
2431
- int batch_size, library_data_t scaling_type)
2432
- {
2433
- if (scaling_type == library_data_t::real_float &&
2434
- c_type == library_data_t::complex_float)
2435
- {
2436
- scaling_type = library_data_t::complex_float;
2437
- }
2438
- else if (scaling_type == library_data_t::real_double &&
2439
- c_type == library_data_t::complex_double)
2440
- {
2441
- scaling_type = library_data_t::complex_double;
2442
- }
2443
-
2444
  std::uint64_t key =
2445
  detail::get_type_combination_id(a_type, b_type, c_type, scaling_type);
2446
  switch (key)
@@ -2449,48 +2425,24 @@ namespace dpct
2449
  library_data_t::real_float, library_data_t::real_float,
2450
  library_data_t::real_float, library_data_t::real_float):
2451
  {
2452
- detail::gemm_batch_impl<float, float, float, float>(
2453
- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2454
- batch_size);
2455
  break;
2456
  }
2457
  case detail::get_type_combination_id(
2458
  library_data_t::real_double, library_data_t::real_double,
2459
  library_data_t::real_double, library_data_t::real_double):
2460
  {
2461
- detail::gemm_batch_impl<double, double, double, double>(
2462
- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2463
- batch_size);
2464
- break;
2465
- }
2466
- case detail::get_type_combination_id(
2467
- library_data_t::complex_float, library_data_t::complex_float,
2468
- library_data_t::complex_float, library_data_t::complex_float):
2469
- {
2470
- detail::gemm_batch_impl<std::complex<float>, std::complex<float>,
2471
- std::complex<float>, std::complex<float>>(
2472
- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2473
- batch_size);
2474
- break;
2475
- }
2476
- case detail::get_type_combination_id(
2477
- library_data_t::complex_double, library_data_t::complex_double,
2478
- library_data_t::complex_double, library_data_t::complex_double):
2479
- {
2480
- detail::gemm_batch_impl<std::complex<double>, std::complex<double>,
2481
- std::complex<double>, std::complex<double>>(
2482
- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2483
- batch_size);
2484
  break;
2485
  }
2486
  case detail::get_type_combination_id(
2487
  library_data_t::real_half, library_data_t::real_half,
2488
  library_data_t::real_half, library_data_t::real_half):
2489
  {
2490
- detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half,
2491
- sycl::half>(q, a_trans, b_trans, m, n, k, alpha,
2492
- a, lda, b, ldb, beta, c, ldc,
2493
- batch_size);
2494
  break;
2495
  }
2496
  #ifdef __INTEL_MKL__
@@ -2498,19 +2450,16 @@ namespace dpct
2498
  library_data_t::real_bfloat16, library_data_t::real_bfloat16,
2499
  library_data_t::real_bfloat16, library_data_t::real_float):
2500
  {
2501
- detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16,
2502
- oneapi::mkl::bfloat16, float>(
2503
- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2504
- batch_size);
2505
  break;
2506
  }
2507
  case detail::get_type_combination_id(
2508
  library_data_t::real_bfloat16, library_data_t::real_bfloat16,
2509
  library_data_t::real_float, library_data_t::real_float):
2510
  {
2511
- detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float,
2512
- float>(q, a_trans, b_trans, m, n, k, alpha, a, lda,
2513
- b, ldb, beta, c, ldc, batch_size);
2514
  break;
2515
  }
2516
  #endif
@@ -2522,10 +2471,9 @@ namespace dpct
2522
  dpct::get_value(reinterpret_cast<const std::int32_t *>(alpha), q);
2523
  float beta_float =
2524
  dpct::get_value(reinterpret_cast<const std::int32_t *>(beta), q);
2525
- detail::gemm_batch_impl<std::int8_t, std::int8_t, std::int32_t,
2526
- float>(q, a_trans, b_trans, m, n, k, &alpha_float,
2527
- a, lda, b, ldb, &beta_float, c, ldc,
2528
- batch_size);
2529
  break;
2530
  }
2531
  case detail::get_type_combination_id(
@@ -2533,8 +2481,7 @@ namespace dpct
2533
  library_data_t::real_float, library_data_t::real_float):
2534
  {
2535
  detail::gemm_batch_impl<std::int8_t, std::int8_t, float, float>(
2536
- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2537
- batch_size);
2538
  break;
2539
  }
2540
  case detail::get_type_combination_id(
@@ -2542,8 +2489,7 @@ namespace dpct
2542
  library_data_t::real_float, library_data_t::real_float):
2543
  {
2544
  detail::gemm_batch_impl<sycl::half, sycl::half, float, float>(
2545
- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2546
- batch_size);
2547
  break;
2548
  }
2549
  case detail::get_type_combination_id(
@@ -2557,8 +2503,7 @@ namespace dpct
2557
  sycl::half alpha_half(alpha_value);
2558
  sycl::half beta_half(beta_value);
2559
  detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
2560
- q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc,
2561
- batch_size);
2562
  break;
2563
  }
2564
  default:
 
82
  return device_type.str();
83
  }
84
 
85
+ template <typename Ts> struct matrix_info_t {
86
+ oneapi::mkl::transpose transpose_info[2];
87
+ Ts value_info[2];
88
+ std::int64_t size_info[3];
89
+ std::int64_t ld_info[3];
90
+ std::int64_t groupsize_info;
91
+ };
92
+
93
  namespace dpct
94
  {
95
  typedef sycl::queue *queue_ptr;
 
1735
  };
1736
 
1737
  template <class Ta, class Tb, class Tc, class Ts>
1738
+ inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans,
1739
+ int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b,
1740
+ int ldb, const void * beta, void ** c, int ldc, int batch_size,
1741
+ matrix_info_t<float> * matrix_info) {
 
 
 
 
 
 
 
 
 
 
 
1742
  Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
1743
  Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
1744
 
 
 
1745
  matrix_info->transpose_info[0] = a_trans;
1746
  matrix_info->transpose_info[1] = b_trans;
1747
  matrix_info->value_info[0] = alpha_value;
 
1758
  sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
1759
  oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info,
1760
  matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1,
1761
+ matrix_info->size_info + 2, reinterpret_cast<Ts *>(matrix_info->value_info),
1762
+ reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
1763
+ matrix_info->ld_info + 1, reinterpret_cast<Ts *>(matrix_info->value_info + 1),
1764
+ reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
1765
  #else
1766
  sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
1767
  q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info,
1768
+ matrix_info->size_info + 1, matrix_info->size_info + 2, reinterpret_cast<Ts *>(matrix_info->value_info),
1769
  reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
1770
+ matrix_info->ld_info + 1, reinterpret_cast<Ts *>(matrix_info->value_info + 1),
1771
+ reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
1772
  #endif
 
 
 
 
 
1773
  }
1774
 
1775
  template <class Ta, class Tb, class Tc, class Ts>
 
2412
  /// \param [in] ldc Leading dimension of C.
2413
  /// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
2414
  /// \param [in] scaling_type Data type of the scaling factors.
2415
+ inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,
2416
+ int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda,
2417
+ const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[],
2418
+ library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type,
2419
+ matrix_info_t<float> * matrix_info) {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2420
  std::uint64_t key =
2421
  detail::get_type_combination_id(a_type, b_type, c_type, scaling_type);
2422
  switch (key)
 
2425
  library_data_t::real_float, library_data_t::real_float,
2426
  library_data_t::real_float, library_data_t::real_float):
2427
  {
2428
+ detail::gemm_batch_impl<float, float, float, float>(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb,
2429
+ beta, c, ldc, batch_size, matrix_info);
 
2430
  break;
2431
  }
2432
  case detail::get_type_combination_id(
2433
  library_data_t::real_double, library_data_t::real_double,
2434
  library_data_t::real_double, library_data_t::real_double):
2435
  {
2436
+ detail::gemm_batch_impl<double, double, double, double>(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb,
2437
+ beta, c, ldc, batch_size, matrix_info);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2438
  break;
2439
  }
2440
  case detail::get_type_combination_id(
2441
  library_data_t::real_half, library_data_t::real_half,
2442
  library_data_t::real_half, library_data_t::real_half):
2443
  {
2444
+ detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
2445
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
 
 
2446
  break;
2447
  }
2448
  #ifdef __INTEL_MKL__
 
2450
  library_data_t::real_bfloat16, library_data_t::real_bfloat16,
2451
  library_data_t::real_bfloat16, library_data_t::real_float):
2452
  {
2453
+ detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float>(
2454
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
 
 
2455
  break;
2456
  }
2457
  case detail::get_type_combination_id(
2458
  library_data_t::real_bfloat16, library_data_t::real_bfloat16,
2459
  library_data_t::real_float, library_data_t::real_float):
2460
  {
2461
+ detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, float>(
2462
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
 
2463
  break;
2464
  }
2465
  #endif
 
2471
  dpct::get_value(reinterpret_cast<const std::int32_t *>(alpha), q);
2472
  float beta_float =
2473
  dpct::get_value(reinterpret_cast<const std::int32_t *>(beta), q);
2474
+ detail::gemm_batch_impl<std::int8_t, std::int8_t, std::int32_t, float>(
2475
+ q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc, batch_size,
2476
+ matrix_info);
 
2477
  break;
2478
  }
2479
  case detail::get_type_combination_id(
 
2481
  library_data_t::real_float, library_data_t::real_float):
2482
  {
2483
  detail::gemm_batch_impl<std::int8_t, std::int8_t, float, float>(
2484
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
 
2485
  break;
2486
  }
2487
  case detail::get_type_combination_id(
 
2489
  library_data_t::real_float, library_data_t::real_float):
2490
  {
2491
  detail::gemm_batch_impl<sycl::half, sycl::half, float, float>(
2492
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
 
2493
  break;
2494
  }
2495
  case detail::get_type_combination_id(
 
2503
  sycl::half alpha_half(alpha_value);
2504
  sycl::half beta_half(beta_value);
2505
  detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
2506
+ q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, batch_size, matrix_info);
 
2507
  break;
2508
  }
2509
  default:
ggml/src/ggml-sycl/ggml-sycl.cpp CHANGED
@@ -1173,6 +1173,85 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool {
1173
  }
1174
  };
1175
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1176
  std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_device(queue_ptr qptr, int device) {
1177
  // TBD: NO VMM support
1178
  // if (ggml_sycl_info().devices[device].vmm) {
@@ -3363,6 +3442,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
3363
 
3364
  ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
3365
  ggml_sycl_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23);
 
3366
 
3367
  sycl::range<3> block_dims(1, ne12, ne13);
3368
  /*
@@ -3391,14 +3471,10 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
3391
  });
3392
  }
3393
  SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
3394
- *main_stream, oneapi::mkl::transpose::trans,
3395
- oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
3396
- (const void **)(ptrs_src.get() + 0 * ne23),
3397
- dpct::library_data_t::real_half, nb01 / nb00,
3398
- (const void **)(ptrs_src.get() + 1 * ne23),
3399
- dpct::library_data_t::real_half, nb11 / nb10, beta,
3400
- (void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23,
3401
- cu_compute_type)));
3402
  }
3403
  }
3404
  catch (sycl::exception const &exc) {
 
1173
  }
1174
  };
1175
 
1176
+ struct ggml_sycl_pool_host : public ggml_sycl_pool {
1177
+ queue_ptr qptr;
1178
+ int device;
1179
+
1180
+ inline static int counter{ 0 };
1181
+
1182
+ struct ggml_sycl_buffer {
1183
+ void * ptr = nullptr;
1184
+ size_t size = 0;
1185
+ };
1186
+
1187
+ // Set arbitrarly to 64
1188
+ static constexpr int MAX_POOL_SIZE{ 64 };
1189
+ std::vector<ggml_sycl_buffer> buffer_pool = std::vector<ggml_sycl_buffer>(MAX_POOL_SIZE);
1190
+ size_t pool_size = 0;
1191
+
1192
+ explicit ggml_sycl_pool_host(queue_ptr qptr_, int device_) : qptr(qptr_), device(device_) {}
1193
+
1194
+ ~ggml_sycl_pool_host() {
1195
+ for (int i = 0; i < MAX_POOL_SIZE; ++i) {
1196
+ ggml_sycl_buffer & b = buffer_pool[i];
1197
+ if (b.ptr != nullptr) {
1198
+ SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(b.ptr, *qptr)));
1199
+ b.ptr = nullptr;
1200
+ pool_size -= b.size;
1201
+ b.size = 0;
1202
+ }
1203
+ }
1204
+ counter = 0;
1205
+ }
1206
+
1207
+ void * alloc(size_t size, size_t * actual_size) override {
1208
+ if (counter == MAX_POOL_SIZE) {
1209
+ ggml_sycl_buffer b = buffer_pool[0];
1210
+ void * ptr = b.ptr;
1211
+ *actual_size = b.size;
1212
+ counter = 1;
1213
+ return ptr;
1214
+ }
1215
+ ggml_sycl_buffer & b = buffer_pool[counter];
1216
+
1217
+ if (b.ptr == nullptr) {
1218
+ void * ptr;
1219
+
1220
+ SYCL_CHECK(CHECK_TRY_ERROR(ptr = (void *) sycl::malloc_host(size, *qptr)));
1221
+ if (!ptr) {
1222
+ GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on host\n", __func__, size);
1223
+ return nullptr;
1224
+ }
1225
+ pool_size += size;
1226
+ *actual_size = size;
1227
+ counter = counter + 1;
1228
+ return ptr;
1229
+ } else {
1230
+ ++counter;
1231
+ b.size = size;
1232
+ return b.ptr;
1233
+ }
1234
+ }
1235
+
1236
+ void free(void * ptr, size_t size) override {
1237
+ // if the pool is not completed add the pointer to it in place of the first nullptr found.
1238
+ // Otherwise do nothing, pointers will be freed once the pool is deallocated.
1239
+ for (int i = 0; i < MAX_POOL_SIZE; ++i) {
1240
+ ggml_sycl_buffer & b = buffer_pool[i];
1241
+ if (b.ptr == nullptr) {
1242
+ b.ptr = ptr;
1243
+ b.size = size;
1244
+ return;
1245
+ }
1246
+ }
1247
+ }
1248
+ };
1249
+
1250
+ std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_host(queue_ptr qptr, int device) {
1251
+ // return pool for the host to speed up memory management
1252
+ return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_host(qptr, device));
1253
+ }
1254
+
1255
  std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_device(queue_ptr qptr, int device) {
1256
  // TBD: NO VMM support
1257
  // if (ggml_sycl_info().devices[device].vmm) {
 
3442
 
3443
  ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
3444
  ggml_sycl_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23);
3445
+ ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1);
3446
 
3447
  sycl::range<3> block_dims(1, ne12, ne13);
3448
  /*
 
3471
  });
3472
  }
3473
  SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
3474
+ *main_stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
3475
+ (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
3476
+ (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, nb11 / nb10, beta,
3477
+ (void **) (ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, cu_compute_type, matrix_info.get())));
 
 
 
 
3478
  }
3479
  }
3480
  catch (sycl::exception const &exc) {