Spaces:
Running
Running
Masaya, Kato
commited on
Commit
·
51f504f
1
Parent(s):
9f41704
ggml: aarch64: SVE kernels for q8_0_q8_0, q4_0_q8_0 vector dot (llama/7433)
Browse files- ggml-impl.h +4 -0
- ggml-quants.c +64 -2
- ggml.c +10 -0
- ggml.h +1 -0
ggml-impl.h
CHANGED
|
@@ -144,6 +144,10 @@ extern "C" {
|
|
| 144 |
#endif
|
| 145 |
#endif
|
| 146 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
// 16-bit float
|
| 148 |
// on Arm, we use __fp16
|
| 149 |
// on x86, we use uint16_t
|
|
|
|
| 144 |
#endif
|
| 145 |
#endif
|
| 146 |
|
| 147 |
+
#if defined(__ARM_FEATURE_SVE)
|
| 148 |
+
#include <arm_sve.h>
|
| 149 |
+
#endif
|
| 150 |
+
|
| 151 |
// 16-bit float
|
| 152 |
// on Arm, we use __fp16
|
| 153 |
// on x86, we use uint16_t
|
ggml-quants.c
CHANGED
|
@@ -3813,7 +3813,44 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
|
|
| 3813 |
return;
|
| 3814 |
}
|
| 3815 |
#endif
|
| 3816 |
-
#if defined(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3817 |
float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
| 3818 |
float32x4_t sumv1 = vdupq_n_f32(0.0f);
|
| 3819 |
|
|
@@ -5384,7 +5421,32 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
|
|
| 5384 |
return;
|
| 5385 |
}
|
| 5386 |
#endif
|
| 5387 |
-
#if defined(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5388 |
float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
| 5389 |
float32x4_t sumv1 = vdupq_n_f32(0.0f);
|
| 5390 |
|
|
|
|
| 3813 |
return;
|
| 3814 |
}
|
| 3815 |
#endif
|
| 3816 |
+
#if defined(__ARM_FEATURE_SVE)
|
| 3817 |
+
const svbool_t ptrueh = svptrue_pat_b8(SV_VL16);
|
| 3818 |
+
const svbool_t ptruel = svnot_b_z(svptrue_b8(), ptrueh);
|
| 3819 |
+
|
| 3820 |
+
svfloat32_t sumv0 = svdup_n_f32(0.0f);
|
| 3821 |
+
svfloat32_t sumv1 = svdup_n_f32(0.0f);
|
| 3822 |
+
|
| 3823 |
+
assert(nb % 2 == 0); // TODO: handle odd nb
|
| 3824 |
+
|
| 3825 |
+
for (int i = 0; i < nb; i += 2) {
|
| 3826 |
+
const block_q4_0 * restrict x0 = &x[i + 0];
|
| 3827 |
+
const block_q4_0 * restrict x1 = &x[i + 1];
|
| 3828 |
+
const block_q8_0 * restrict y0 = &y[i + 0];
|
| 3829 |
+
const block_q8_0 * restrict y1 = &y[i + 1];
|
| 3830 |
+
|
| 3831 |
+
// load x
|
| 3832 |
+
const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
|
| 3833 |
+
const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);
|
| 3834 |
+
|
| 3835 |
+
// 4-bit -> 8-bit
|
| 3836 |
+
const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx0r, 0x0F), 0x04));
|
| 3837 |
+
const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx1r, 0x0F), 0x04));
|
| 3838 |
+
|
| 3839 |
+
// sub 8
|
| 3840 |
+
const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8);
|
| 3841 |
+
const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8);
|
| 3842 |
+
|
| 3843 |
+
// load y
|
| 3844 |
+
const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
|
| 3845 |
+
const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
|
| 3846 |
+
|
| 3847 |
+
// dot product
|
| 3848 |
+
sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
|
| 3849 |
+
sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
|
| 3850 |
+
}
|
| 3851 |
+
|
| 3852 |
+
*s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
|
| 3853 |
+
#elif defined(__ARM_NEON)
|
| 3854 |
float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
| 3855 |
float32x4_t sumv1 = vdupq_n_f32(0.0f);
|
| 3856 |
|
|
|
|
| 5421 |
return;
|
| 5422 |
}
|
| 5423 |
#endif
|
| 5424 |
+
#if defined(__ARM_FEATURE_SVE)
|
| 5425 |
+
svfloat32_t sumv0 = svdup_n_f32(0.0f);
|
| 5426 |
+
svfloat32_t sumv1 = svdup_n_f32(0.0f);
|
| 5427 |
+
|
| 5428 |
+
assert(nb % 2 == 0); // TODO: handle odd nb
|
| 5429 |
+
|
| 5430 |
+
for (int i = 0; i < nb; i += 2) {
|
| 5431 |
+
const block_q8_0 * restrict x0 = &x[i + 0];
|
| 5432 |
+
const block_q8_0 * restrict x1 = &x[i + 1];
|
| 5433 |
+
const block_q8_0 * restrict y0 = &y[i + 0];
|
| 5434 |
+
const block_q8_0 * restrict y1 = &y[i + 1];
|
| 5435 |
+
|
| 5436 |
+
// load x
|
| 5437 |
+
const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs);
|
| 5438 |
+
const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs);
|
| 5439 |
+
|
| 5440 |
+
// load y
|
| 5441 |
+
const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
|
| 5442 |
+
const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
|
| 5443 |
+
|
| 5444 |
+
sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
|
| 5445 |
+
sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
|
| 5446 |
+
}
|
| 5447 |
+
|
| 5448 |
+
*s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
|
| 5449 |
+
#elif defined(__ARM_NEON)
|
| 5450 |
float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
| 5451 |
float32x4_t sumv1 = vdupq_n_f32(0.0f);
|
| 5452 |
|
ggml.c
CHANGED
|
@@ -22742,6 +22742,16 @@ int ggml_cpu_has_neon(void) {
|
|
| 22742 |
#endif
|
| 22743 |
}
|
| 22744 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22745 |
int ggml_cpu_has_arm_fma(void) {
|
| 22746 |
#if defined(__ARM_FEATURE_FMA)
|
| 22747 |
return 1;
|
|
|
|
| 22742 |
#endif
|
| 22743 |
}
|
| 22744 |
|
| 22745 |
+
int ggml_cpu_has_sve(void) {
|
| 22746 |
+
#if defined(__ARM_FEATURE_SVE)
|
| 22747 |
+
// TODO: Currently, SVE 256 bit is only supported.
|
| 22748 |
+
GGML_ASSERT(svcntb() == QK8_0);
|
| 22749 |
+
return 1;
|
| 22750 |
+
#else
|
| 22751 |
+
return 0;
|
| 22752 |
+
#endif
|
| 22753 |
+
}
|
| 22754 |
+
|
| 22755 |
int ggml_cpu_has_arm_fma(void) {
|
| 22756 |
#if defined(__ARM_FEATURE_FMA)
|
| 22757 |
return 1;
|
ggml.h
CHANGED
|
@@ -2404,6 +2404,7 @@ extern "C" {
|
|
| 2404 |
GGML_API int ggml_cpu_has_avx512_bf16(void);
|
| 2405 |
GGML_API int ggml_cpu_has_fma (void);
|
| 2406 |
GGML_API int ggml_cpu_has_neon (void);
|
|
|
|
| 2407 |
GGML_API int ggml_cpu_has_arm_fma (void);
|
| 2408 |
GGML_API int ggml_cpu_has_metal (void);
|
| 2409 |
GGML_API int ggml_cpu_has_f16c (void);
|
|
|
|
| 2404 |
GGML_API int ggml_cpu_has_avx512_bf16(void);
|
| 2405 |
GGML_API int ggml_cpu_has_fma (void);
|
| 2406 |
GGML_API int ggml_cpu_has_neon (void);
|
| 2407 |
+
GGML_API int ggml_cpu_has_sve (void);
|
| 2408 |
GGML_API int ggml_cpu_has_arm_fma (void);
|
| 2409 |
GGML_API int ggml_cpu_has_metal (void);
|
| 2410 |
GGML_API int ggml_cpu_has_f16c (void);
|