Masaya, Kato committed
Commit 51f504f · 1 parent: 9f41704

ggml: aarch64: SVE kernels for q8_0_q8_0, q4_0_q8_0 vector dot (llama/7433)

Files changed (4):
  1. ggml-impl.h +4 -0
  2. ggml-quants.c +64 -2
  3. ggml.c +10 -0
  4. ggml.h +1 -0
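
In short: the commit adds __ARM_FEATURE_SVE-guarded SVE implementations of the q4_0×q8_0 and q8_0×q8_0 dot-product kernels in ggml-quants.c (the existing NEON paths become the #elif fallback), includes <arm_sve.h> from ggml-impl.h, and exposes a new ggml_cpu_has_sve() feature probe in ggml.c / ggml.h.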
ggml-impl.h CHANGED
@@ -144,6 +144,10 @@ extern "C" {
 #endif
 #endif
 
+#if defined(__ARM_FEATURE_SVE)
+#include <arm_sve.h>
+#endif
+
 // 16-bit float
 // on Arm, we use __fp16
 // on x86, we use uint16_t
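
__ARM_FEATURE_SVE is the ACLE feature-test macro that the compiler defines when SVE code generation is enabled (typically via something like -march=armv8-a+sve), so <arm_sve.h> is only included on builds that can actually use the intrinsics. A minimal standalone probe, as a sketch independent of ggml:

    #include <stdio.h>

    #if defined(__ARM_FEATURE_SVE)
    #include <arm_sve.h>
    #endif

    int main(void) {
    #if defined(__ARM_FEATURE_SVE)
        // svcntb() returns the implementation's vector length in bytes
        printf("SVE enabled, vector length = %llu bytes\n",
               (unsigned long long) svcntb());
    #else
        printf("compiled without SVE\n");
    #endif
        return 0;
    }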
ggml-quants.c CHANGED
@@ -3813,7 +3813,44 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
         return;
     }
 #endif
-#if defined(__ARM_NEON)
+#if defined(__ARM_FEATURE_SVE)
+    const svbool_t ptrueh = svptrue_pat_b8(SV_VL16);
+    const svbool_t ptruel = svnot_b_z(svptrue_b8(), ptrueh);
+
+    svfloat32_t sumv0 = svdup_n_f32(0.0f);
+    svfloat32_t sumv1 = svdup_n_f32(0.0f);
+
+    assert(nb % 2 == 0); // TODO: handle odd nb
+
+    for (int i = 0; i < nb; i += 2) {
+        const block_q4_0 * restrict x0 = &x[i + 0];
+        const block_q4_0 * restrict x1 = &x[i + 1];
+        const block_q8_0 * restrict y0 = &y[i + 0];
+        const block_q8_0 * restrict y1 = &y[i + 1];
+
+        // load x
+        const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
+        const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);
+
+        // 4-bit -> 8-bit
+        const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx0r, 0x0F), 0x04));
+        const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx1r, 0x0F), 0x04));
+
+        // sub 8
+        const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8);
+        const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8);
+
+        // load y
+        const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
+        const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
+
+        // dot product
+        sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+        sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+    }
+
+    *s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
+#elif defined(__ARM_NEON)
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);
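
A note on the unpacking trick above: svld1rq_u8 loads the 16 packed bytes of a q4_0 block and replicates them into each 128-bit segment of the vector, so on the 256-bit implementation this commit targets, the same 16 bytes occupy lanes 0-15 and 16-31. The ptrueh predicate (SV_VL16, lanes 0-15) masks its lanes with 0x0F to extract the low nibbles, while ptruel (its complement, lanes 16-31) shifts its lanes right by 4 to extract the high nibbles, leaving all 32 quants of the block unpacked in a single vector before the subtract-8 recentering. Per block, the arithmetic is equivalent to this scalar sketch (a hypothetical helper, assuming the usual ggml block layouts and QK4_0 == 32):

    // Scalar equivalent of one q4_0 × q8_0 block contribution (sketch)
    static float block_dot_q4_0_q8_0(const block_q4_0 * x, const block_q8_0 * y) {
        int sumi = 0;
        for (int j = 0; j < QK4_0/2; ++j) {
            const int v0 = (x->qs[j] & 0x0F) - 8; // low nibble  -> quant j
            const int v1 = (x->qs[j] >>   4) - 8; // high nibble -> quant j + 16
            sumi += v0*y->qs[j] + v1*y->qs[j + QK4_0/2];
        }
        // each block carries one fp16 scale; the integer dot is scaled once
        return GGML_FP16_TO_FP32(x->d)*GGML_FP16_TO_FP32(y->d)*(float)sumi;
    }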
 
@@ -5384,7 +5421,32 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
         return;
     }
 #endif
-#if defined(__ARM_NEON)
+#if defined(__ARM_FEATURE_SVE)
+    svfloat32_t sumv0 = svdup_n_f32(0.0f);
+    svfloat32_t sumv1 = svdup_n_f32(0.0f);
+
+    assert(nb % 2 == 0); // TODO: handle odd nb
+
+    for (int i = 0; i < nb; i += 2) {
+        const block_q8_0 * restrict x0 = &x[i + 0];
+        const block_q8_0 * restrict x1 = &x[i + 1];
+        const block_q8_0 * restrict y0 = &y[i + 0];
+        const block_q8_0 * restrict y1 = &y[i + 1];
+
+        // load x
+        const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs);
+        const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs);
+
+        // load y
+        const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
+        const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
+
+        sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+        sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+    }
+
+    *s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
+#elif defined(__ARM_NEON)
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);
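
The q8_0 path is simpler: each block is already 32 signed 8-bit quants plus one fp16 scale, and svdot_s32 sums groups of four int8×int8 products into 32-bit lanes, so the per-block integer dot cannot overflow. Per block, it computes the equivalent of this scalar sketch (hypothetical helper, assuming QK8_0 == 32):

    // Scalar equivalent of one q8_0 × q8_0 block contribution (sketch)
    static float block_dot_q8_0_q8_0(const block_q8_0 * x, const block_q8_0 * y) {
        int sumi = 0;
        for (int j = 0; j < QK8_0; ++j) {
            sumi += x->qs[j] * y->qs[j]; // int8 products, int accumulation
        }
        return GGML_FP16_TO_FP32(x->d)*GGML_FP16_TO_FP32(y->d)*(float)sumi;
    }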
 
 
ggml.c CHANGED
@@ -22742,6 +22742,16 @@ int ggml_cpu_has_neon(void) {
 #endif
 }
 
+int ggml_cpu_has_sve(void) {
+#if defined(__ARM_FEATURE_SVE)
+    // TODO: Currently, only SVE 256-bit is supported.
+    GGML_ASSERT(svcntb() == QK8_0);
+    return 1;
+#else
+    return 0;
+#endif
+}
+
 int ggml_cpu_has_arm_fma(void) {
 #if defined(__ARM_FEATURE_FMA)
     return 1;
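
For context: svcntb() returns the runtime SVE vector length in bytes, and QK8_0 is 32, so the assertion pins these kernels to 256-bit SVE implementations, where one q8_0 block fills exactly one vector. On 128-bit or 512-bit hardware the GGML_ASSERT aborts rather than falling back to NEON, which is what the TODO refers to.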
ggml.h CHANGED
@@ -2404,6 +2404,7 @@ extern "C" {
     GGML_API int ggml_cpu_has_avx512_bf16(void);
     GGML_API int ggml_cpu_has_fma        (void);
     GGML_API int ggml_cpu_has_neon       (void);
+    GGML_API int ggml_cpu_has_sve        (void);
     GGML_API int ggml_cpu_has_arm_fma    (void);
     GGML_API int ggml_cpu_has_metal      (void);
     GGML_API int ggml_cpu_has_f16c       (void);