Charles Xu, Diego Devesa committed
Commit 3541ee8 · Parent: 3dc93f3

backend cpu: add online flow for aarch64 Q4_0 GEMV/GEMM kernels (llama/9921)

* backend-cpu: add online flow for aarch64 Q4_0 GEMV/GEMM kernels

---------

Co-authored-by: Diego Devesa <[email protected]>

ggml/CMakeLists.txt CHANGED
@@ -92,6 +92,7 @@ else()
 endif()
 
 option(GGML_CPU_HBM     "ggml: use memkind for CPU HBM" OFF)
+option(GGML_CPU_AARCH64 "ggml: use runtime weight conversion of Q4_0 to Q4_X_X" ON)
 
 option(GGML_AVX         "ggml: enable AVX" ${INS_ENB})
 option(GGML_AVX2        "ggml: enable AVX2" ${INS_ENB})

ggml/include/ggml-cpu.h CHANGED
@@ -169,6 +169,9 @@ extern "C" {
 GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void);
 #endif
 
+GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cpu_aarch64_buffer_type(void);
+GGML_BACKEND_API bool ggml_backend_cpu_buft_is_aarch64(ggml_backend_buffer_type_t buft);
+
 #ifdef __cplusplus
 }
 #endif
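
Note: a minimal usage sketch (not part of this commit) of the two entry points declared above, assuming the CPU backend was built with GGML_CPU_AARCH64; ggml_backend_buft_name is the existing accessor from ggml-backend.h.

#include <stdio.h>
#include "ggml-backend.h"
#include "ggml-cpu.h"

int main(void) {
    // query the repacking buffer type exposed by the CPU backend
    ggml_backend_buffer_type_t buft = ggml_backend_cpu_aarch64_buffer_type();

    // the predicate added in this commit identifies it
    printf("name: %s, is_aarch64: %d\n",
           ggml_backend_buft_name(buft),            // expected: "CPU_AARCH64"
           ggml_backend_cpu_buft_is_aarch64(buft)); // expected: 1
    return 0;
}
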
ggml/src/ggml-cpu/CMakeLists.txt CHANGED
@@ -236,6 +236,11 @@ else()
     message(STATUS "Unknown architecture")
 endif()
 
+if (GGML_CPU_AARCH64)
+    message(STATUS "Using runtime weight conversion of Q4_0 to Q4_0_x_x to enable optimized GEMM/GEMV kernels")
+    add_compile_definitions(GGML_USE_CPU_AARCH64)
+endif()
+
 target_compile_options(ggml-cpu PRIVATE "$<$<COMPILE_LANGUAGE:CXX>:${ARCH_FLAGS}>")
 target_compile_options(ggml-cpu PRIVATE "$<$<COMPILE_LANGUAGE:C>:${ARCH_FLAGS}>")
 
ggml/src/ggml-cpu/ggml-cpu-aarch64.c CHANGED
@@ -3385,3 +3385,147 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
         }
     }
 }
+
+// FIXME: this code is duplicated from ggml-aarch64.c
+static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave, unsigned int xor_mask) {
+    block_q4_0x4 out;
+
+    for (int i = 0; i < 4; i++) {
+        out.d[i] = in[i].d;
+    }
+
+    for (int i = 0; i < QK4_0 * 2; i++) {
+        int src_offset = (i / (4 * blck_size_interleave)) * blck_size_interleave;
+        int src_id     = (i % (4 * blck_size_interleave)) / blck_size_interleave;
+        src_offset    += (i % blck_size_interleave);
+
+        out.qs[i] = in[src_id].qs[src_offset] ^ xor_mask;
+    }
+
+    return out;
+}
+
+// interleave 8 block_q4_0s in blocks of blck_size_interleave
+// returns an interleaved block_q4_0x8
+// in the interleaved block_q4_0x8, place deltas for 8 block_q4_0 blocks
+// first, then interleave quants from 8 block_q4_0s in blocks of blck_size_interleave
+static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_interleave, unsigned int xor_mask) {
+    block_q4_0x8 out;
+
+    for (int i = 0; i < 8; i++) {
+        out.d[i] = in[i].d;
+    }
+
+    for (int i = 0; i < QK4_0 * 4; i++) {
+        int src_offset = (i / (8 * blck_size_interleave)) * blck_size_interleave;
+        int src_id     = (i % (8 * blck_size_interleave)) / blck_size_interleave;
+        src_offset    += (i % blck_size_interleave);
+
+        out.qs[i] = in[src_id].qs[src_offset] ^ xor_mask;
+    }
+
+    return out;
+}
+
+static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * restrict data, size_t data_size) {
+    GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
+    GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
+
+    block_q4_0x4 * dst = (block_q4_0x4 *)t->data;
+    const block_q4_0 * src = (const block_q4_0 *)data;
+    block_q4_0 dst_tmp[4];
+    int nrow = t->ne[1]; // Number of rows
+    int nrows_interleaved = 4;
+    int nblocks = t->ne[0] / QK4_0;
+
+    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));
+
+    if (nrow % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
+        return -1;
+    }
+
+    for (int b = 0; b < nrow; b += nrows_interleaved) {
+        for (int64_t x = 0; x < nblocks; x++) {
+            for (int i = 0; i < nrows_interleaved; i++) {
+                dst_tmp[i] = src[x + i * nblocks];
+            }
+            *dst++ = make_block_q4_0x4(dst_tmp, interleave_block, 0x88);
+        }
+        src += nrows_interleaved * nblocks;
+    }
+    return 0;
+
+    GGML_UNUSED(data_size);
+}
+
+static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block, const void * restrict data, size_t data_size) {
+    GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
+    GGML_ASSERT(interleave_block == 8);
+
+    block_q4_0x8 * dst = (block_q4_0x8 *)t->data;
+    const block_q4_0 * src = (const block_q4_0 *)data;
+    block_q4_0 dst_tmp[8];
+    int nrow = t->ne[1]; // Number of rows
+    int nrows_interleaved = 8;
+    int nblocks = t->ne[0] / QK4_0;
+
+    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));
+
+    if (nrow % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
+        return -1;
+    }
+
+    for (int b = 0; b < nrow; b += nrows_interleaved) {
+        for (int64_t x = 0; x < nblocks; x++) {
+            for (int i = 0; i < nrows_interleaved; i++) {
+                dst_tmp[i] = src[x + i * nblocks];
+            }
+            *dst++ = make_block_q4_0x8(dst_tmp, interleave_block, 0x88);
+        }
+        src += nrows_interleaved * nblocks;
+    }
+    return 0;
+
+    GGML_UNUSED(data_size);
+}
+
+// Prepare for optimized kernels if applicable
+void ggml_aarch64_repack_tensor(struct ggml_tensor * cur, enum ggml_type repack_type, const void * restrict data, size_t data_size) {
+    if (cur->type == repack_type) {
+        memcpy(cur->data, data, data_size);
+        return;
+    }
+
+    GGML_ASSERT(cur->type == GGML_TYPE_Q4_0);
+
+    switch (repack_type) {
+        case GGML_TYPE_Q4_0_8_8:
+            repack_q4_0_to_q4_0_8_bl(cur, 8, data, data_size);
+            break;
+        case GGML_TYPE_Q4_0_4_8:
+            repack_q4_0_to_q4_0_4_bl(cur, 8, data, data_size);
+            break;
+        case GGML_TYPE_Q4_0_4_4:
+            repack_q4_0_to_q4_0_4_bl(cur, 4, data, data_size);
+            break;
+        default:
+            GGML_ABORT("Unsupported type");
+    }
+}
+
+enum ggml_type ggml_aarch64_get_optimal_repack_type(const struct ggml_tensor * cur) {
+    if (cur->type == GGML_TYPE_Q4_0) {
+        // TODO: enable for AVX2 - currently disabled due to bad gemv performance
+        if (/* ggml_cpu_has_avx2() || */ (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0)) {
+            return GGML_TYPE_Q4_0_8_8;
+        }
+        if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
+            return GGML_TYPE_Q4_0_4_8;
+        }
+        if (ggml_cpu_has_neon()) {
+            return GGML_TYPE_Q4_0_4_4;
+        }
+    }
+
+    return cur->type;
+}
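
Note: the index arithmetic in make_block_q4_0x4 walks the output bytes in chunks of blck_size_interleave, cycling through the four source blocks. A standalone sketch (not part of this commit) that prints the mapping for blck_size_interleave = 4, the value used for Q4_0_4_4:

#include <stdio.h>

int main(void) {
    const int blck_size_interleave = 4;
    // each block_q4_0 holds QK4_0/2 = 16 quant bytes; 4 blocks -> 64 output bytes
    for (int i = 0; i < 16; i++) {
        int src_offset = (i / (4 * blck_size_interleave)) * blck_size_interleave;
        int src_id     = (i % (4 * blck_size_interleave)) / blck_size_interleave;
        src_offset    += (i % blck_size_interleave);
        printf("out.qs[%2d] <- in[%d].qs[%d]\n", i, src_id, src_offset);
    }
    return 0;
}

The first 16 output bytes come from in[0].qs[0..3], in[1].qs[0..3], in[2].qs[0..3], in[3].qs[0..3], i.e. 4-byte groups interleaved across the four rows.
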
ggml/src/ggml-cpu/ggml-cpu-aarch64.h CHANGED
@@ -21,6 +21,9 @@ void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
 void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 
+void ggml_aarch64_repack_tensor(struct ggml_tensor * cur, enum ggml_type repack_type, const void * data, size_t data_size);
+enum ggml_type ggml_aarch64_get_optimal_repack_type(const struct ggml_tensor * cur);
+
 #ifdef __cplusplus
 }
 #endif
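
Note: a hedged sketch (not part of this commit) of how the two helpers declared above pair up, mirroring what the aarch64 buffer type's set_tensor callback does; the function name and arguments are illustrative.

#include "ggml.h"
#include "ggml-cpu-aarch64.h"

// 'w' is a Q4_0 weight tensor, 'q4_0_data' its raw quantized bytes of length 'size'
static void load_weight_repacked(struct ggml_tensor * w, const void * q4_0_data, size_t size) {
    // pick the layout supported by the current CPU (Q4_0_8_8 / Q4_0_4_8 / Q4_0_4_4, or Q4_0 unchanged)
    enum ggml_type target = ggml_aarch64_get_optimal_repack_type(w);

    // convert the Q4_0 data into that layout directly into w->data
    ggml_aarch64_repack_tensor(w, target, q4_0_data, size);
}
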
ggml/src/ggml-cpu/ggml-cpu.c CHANGED
@@ -7330,6 +7330,7 @@ static void ggml_compute_forward_group_norm(
 static void ggml_compute_forward_mul_mat_one_chunk(
     const struct ggml_compute_params * params,
     struct ggml_tensor * dst,
+    const enum ggml_type type,
     const int64_t num_rows_per_vec_dot,
     const int64_t ir0_start,
     const int64_t ir0_end,
@@ -7341,8 +7342,6 @@ static void ggml_compute_forward_mul_mat_one_chunk(
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
-    const enum ggml_type type = src0->type;
-
     const bool src1_cont = ggml_is_contiguous(src1);
 
     ggml_vec_dot_t const vec_dot = type_traits_cpu[type].vec_dot;
@@ -7430,7 +7429,11 @@ static void ggml_compute_forward_mul_mat(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const enum ggml_type type = src0->type;
+    enum ggml_type type = src0->type;
+
+    if (src0->buffer && ggml_backend_cpu_buft_is_aarch64(src0->buffer->buft)) {
+        type = (enum ggml_type)(intptr_t)src0->extra;
+    }
 
     enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;
     ggml_from_float_t const from_float = type_traits_cpu[vec_dot_type].from_float;
@@ -7469,15 +7472,15 @@ static void ggml_compute_forward_mul_mat(
     if (src1_cont) {
         for (int64_t i13 = 0; i13 < ne13; i13++)
             for (int64_t i12 = 0; i12 < ne12; i12++)
-                if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
+                if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(type),
                                      (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
-                                     nb01/ggml_type_size(src0->type),
+                                     nb01/ggml_type_size(type),
                                      (const char *)src1->data + i12*nb12 + i13*nb13,
                                      nb11/ggml_type_size(src1->type),
                                      (char *)dst->data + i12*nb2 + i13*nb3,
                                      nb1/ggml_type_size(dst->type),
                                      ith, nth,
-                                     src0->type,
+                                     type,
                                      src1->type,
                                      dst->type))
                     goto UseGgmlGemm1;
@@ -7530,15 +7533,15 @@ UseGgmlGemm1:;
 
     for (int64_t i13 = 0; i13 < ne13; i13++)
         for (int64_t i12 = 0; i12 < ne12; i12++)
-            if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
+            if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(type),
                                  (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
-                                 nb01/ggml_type_size(src0->type),
+                                 nb01/ggml_type_size(type),
                                  (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size,
                                  row_size/ggml_type_size(vec_dot_type),
                                  (char *)dst->data + i12*nb2 + i13*nb3,
                                  nb1/ggml_type_size(dst->type),
                                  ith, nth,
-                                 src0->type,
+                                 type,
                                  vec_dot_type,
                                  dst->type))
                 goto UseGgmlGemm2;
@@ -7623,7 +7626,7 @@ UseGgmlGemm2:;
         const int64_t ir1_start = dr1 * ith1;
         const int64_t ir1_end = MIN(ir1_start + dr1, nr1);
 
-        ggml_compute_forward_mul_mat_one_chunk(params, dst, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
+        ggml_compute_forward_mul_mat_one_chunk(params, dst, type, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
 
         if (nth >= nchunk0 * nchunk1) {
            break;
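
Note: the change above turns the effective src0 type into a local variable: for tensors placed in the aarch64 buffer type, it is read back from src0->extra (where the buffer's init_tensor callback stored the repacked layout) instead of src0->type. A self-contained toy illustration of that idiom (types and values are illustrative, not the ggml ones):

#include <stdint.h>
#include <stdio.h>

enum toy_type { TOY_Q4_0 = 2, TOY_Q4_0_4_4 = 31 }; // illustrative values only

struct toy_tensor {
    enum toy_type type;   // nominal type stored with the tensor
    void *        extra;  // backend-specific; here: the repacked layout
};

int main(void) {
    struct toy_tensor t = { TOY_Q4_0, NULL };

    // what ggml_backend_cpu_aarch64_buffer_init_tensor does, conceptually
    t.extra = (void *)(intptr_t)TOY_Q4_0_4_4;

    // what ggml_compute_forward_mul_mat now does, conceptually
    enum toy_type type = t.type;
    if (t.extra != NULL) {
        type = (enum toy_type)(intptr_t)t.extra;
    }

    printf("nominal type %d, effective type %d\n", (int)t.type, (int)type);
    return 0;
}
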
ggml/src/ggml-cpu/ggml-cpu.cpp CHANGED
@@ -1,6 +1,7 @@
 #include "ggml-backend.h"
 #include "ggml-backend-impl.h"
 #include "ggml-cpu.h"
+#include "ggml-cpu-aarch64.h"
 #include "ggml-impl.h"
 #include <cctype>
 #include <string>
@@ -69,15 +70,84 @@ ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void) {
 }
 #endif
 
+// buffer type AARCH64
+
+static void ggml_backend_cpu_aarch64_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
+    tensor->extra = (void *)ggml_aarch64_get_optimal_repack_type(tensor); // NOLINT
+
+    GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_cpu_aarch64_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    GGML_ASSERT(offset == 0);
+    GGML_ASSERT(size == ggml_nbytes(tensor));
+
+    enum ggml_type repack_type = (enum ggml_type)(intptr_t)tensor->extra;
+
+    ggml_aarch64_repack_tensor(tensor, repack_type, data, size);
+
+    GGML_UNUSED(buffer);
+}
+
+static const char * ggml_backend_cpu_aarch64_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
+    return "CPU_AARCH64";
+
+    GGML_UNUSED(buft);
+}
+
+static ggml_backend_buffer_t ggml_backend_cpu_aarch64_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+    auto * buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
+
+    if (buffer == NULL) {
+        return NULL;
+    }
+
+    buffer->buft = buft;
+    buffer->iface.init_tensor = ggml_backend_cpu_aarch64_buffer_init_tensor;
+    buffer->iface.set_tensor = ggml_backend_cpu_aarch64_buffer_set_tensor;
+
+    return buffer;
+}
+
+ggml_backend_buffer_type_t ggml_backend_cpu_aarch64_buffer_type(void) {
+    static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_aarch64 = {
+        /* .iface = */ {
+            /* .get_name = */ ggml_backend_cpu_aarch64_buffer_type_get_name,
+            /* .alloc_buffer = */ ggml_backend_cpu_aarch64_buffer_type_alloc_buffer,
+            /* .get_alignment = */ ggml_backend_cpu_buffer_type()->iface.get_alignment,
+            /* .get_max_size = */ NULL, // defaults to SIZE_MAX
+            /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
+            /* .is_host = */ NULL,
+        },
+        /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
+        /* .context = */ NULL,
+    };
+
+    return &ggml_backend_cpu_buffer_type_aarch64;
+}
+
+bool ggml_backend_cpu_buft_is_aarch64(ggml_backend_buffer_type_t buft) {
+    return buft == ggml_backend_cpu_aarch64_buffer_type();
+}
+
 static ggml_backend_buffer_type_t * ggml_backend_cpu_get_extra_bufts(ggml_backend_dev_t device) {
-    static ggml_backend_buffer_type_t bufts[] = {
+    static std::vector<ggml_backend_buffer_type_t> bufts = []() {
+        std::vector<ggml_backend_buffer_type_t> bufts;
+
 #ifdef GGML_USE_CPU_HBM
-        ggml_backend_cpu_hbm_buffer_type(),
+        bufts.push_back(ggml_backend_cpu_hbm_buffer_type());
+#endif
+
+#ifdef GGML_USE_CPU_AARCH64
+        bufts.push_back(ggml_backend_cpu_aarch64_buffer_type());
 #endif
-        NULL
-    };
 
-    return bufts;
+        bufts.push_back(NULL);
+
+        return bufts;
+    }();
+
+    return bufts.data();
 
     GGML_UNUSED(device);
 }
@@ -383,6 +453,21 @@ static ggml_backend_buffer_t ggml_backend_cpu_device_buffer_from_host_ptr(ggml_b
 }
 
 static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
+    const struct ggml_tensor * src0 = op->src[0];
+    const struct ggml_tensor * src1 = op->src[1];
+
+    if (src0 && src0->buffer && ggml_backend_cpu_buft_is_aarch64(src0->buffer->buft)) {
+        if (op->op != GGML_OP_MUL_MAT || src0->type != GGML_TYPE_Q4_0 || ggml_aarch64_get_optimal_repack_type(src0) == GGML_TYPE_Q4_0) {
+            return false;
+        }
+    }
+
+    for (int i = 1; i < GGML_MAX_SRC; i++) {
+        if (op->src[i] && op->src[i]->buffer && ggml_backend_cpu_buft_is_aarch64(op->src[i]->buffer->buft)) {
+            return false;
+        }
+    }
+
     switch (op->op) {
         case GGML_OP_CPY:
             return
@@ -391,13 +476,13 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
                 op->type != GGML_TYPE_IQ1_S &&
                 op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float
         case GGML_OP_MUL_MAT:
-            return op->src[1]->type == GGML_TYPE_F32;// FIXME || op->src[1]->type == ggml_get_type_traits(op->src[0]->type)->vec_dot_type;
+            return src1->type == GGML_TYPE_F32 || src1->type == ggml_get_type_traits_cpu(src0->type)->vec_dot_type;
         case GGML_OP_ROPE_BACK:
            return op->src[2] == NULL && (op->op_params[2] & 4) == 0;
        case GGML_OP_IM2COL_BACK:
-            return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
+            return src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32;
        case GGML_OP_OUT_PROD:
-            return (op->src[0]->type == GGML_TYPE_F32 || ggml_is_quantized(op->src[0]->type)) && op->src[1]->type == GGML_TYPE_F32;
+            return (src0->type == GGML_TYPE_F32 || ggml_is_quantized(src0->type)) && src1->type == GGML_TYPE_F32;
        default:
            return true;
    }
@@ -406,7 +491,7 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
 }
 
 static bool ggml_backend_cpu_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
-    return ggml_backend_buft_is_host(buft);
+    return ggml_backend_buft_is_host(buft) || ggml_backend_cpu_buft_is_aarch64(buft);
 
     GGML_UNUSED(dev);
 }
@@ -566,6 +651,9 @@ static const struct ggml_backend_reg_i ggml_backend_cpu_reg_i = {
 };
 
 ggml_backend_reg_t ggml_backend_cpu_reg(void) {
+    // init CPU feature detection
+    ggml_cpu_init();
+
     static struct ggml_backend_reg ggml_backend_cpu_reg = {
         /* .iface = */ ggml_backend_cpu_reg_i,
         /* .context = */ NULL,
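
Note: putting the pieces together, a hedged end-to-end sketch (not part of this commit) of the online flow: weights allocated in the CPU_AARCH64 buffer type get their repack target recorded by init_tensor, and the plain Q4_0 bytes passed to ggml_backend_tensor_set are converted by the overridden set_tensor callback. The helper below is illustrative; ggml_backend_alloc_ctx_tensors_from_buft and ggml_backend_tensor_set are existing ggml-backend APIs.

#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "ggml-cpu.h"

// 'ctx' holds the (unallocated) Q4_0 weight tensors, 'w' is one of them,
// 'q4_0_bytes' is its raw Q4_0 data as found in the model file.
static ggml_backend_buffer_t upload_weights(struct ggml_context * ctx, struct ggml_tensor * w, const void * q4_0_bytes) {
    // allocate in the repacking buffer type; init_tensor stores the target layout in w->extra
    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, ggml_backend_cpu_aarch64_buffer_type());
    GGML_ASSERT(buf != NULL);

    // set_tensor repacks Q4_0 -> Q4_0_x_x while copying into w->data
    ggml_backend_tensor_set(w, q4_0_bytes, 0, ggml_nbytes(w));
    return buf;
}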