Spaces:
Running
Running
metal : more precise Q*K in FA vec kernel (llama/10247)
Browse files
- ggml/src/ggml-metal.metal +22 -13
ggml/src/ggml-metal.metal
CHANGED
|
@@ -2942,6 +2942,7 @@ kernel void kernel_flash_attn_ext(
|
|
| 2942 |
half smax = -INFINITY;
|
| 2943 |
|
| 2944 |
// load the mask in shared memory
|
|
|
|
| 2945 |
for (short j = 0; j < Q; ++j) {
|
| 2946 |
device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*nb31);
|
| 2947 |
|
|
@@ -2968,7 +2969,7 @@ kernel void kernel_flash_attn_ext(
|
|
| 2968 |
// we can read directly from global memory
|
| 2969 |
device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
|
| 2970 |
|
| 2971 |
-
#pragma unroll
|
| 2972 |
for (short i = 0; i < D8; ++i) {
|
| 2973 |
k8x8_t mk;
|
| 2974 |
simdgroup_load(mk, pk + i*8, nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10
|
|
@@ -2989,7 +2990,7 @@ kernel void kernel_flash_attn_ext(
|
|
| 2989 |
|
| 2990 |
simdgroup_barrier(mem_flags::mem_threadgroup);
|
| 2991 |
|
| 2992 |
-
#pragma unroll
|
| 2993 |
for (short k = 0; k < 4; ++k) {
|
| 2994 |
k8x8_t mk;
|
| 2995 |
|
|
@@ -3067,7 +3068,7 @@ kernel void kernel_flash_attn_ext(
|
|
| 3067 |
s8x8_t mm;
|
| 3068 |
simdgroup_load(mm, ss + 2*C, TS, 0, false);
|
| 3069 |
|
| 3070 |
-
#pragma unroll
|
| 3071 |
for (short i = 0; i < D8; ++i) {
|
| 3072 |
simdgroup_multiply(lo[i], mm, lo[i]);
|
| 3073 |
}
|
|
@@ -3082,7 +3083,8 @@ kernel void kernel_flash_attn_ext(
|
|
| 3082 |
if (is_same<vd4x4_t, v4x4_t>::value) {
|
| 3083 |
// we can read directly from global memory
|
| 3084 |
device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
|
| 3085 |
-
|
|
|
|
| 3086 |
for (short i = 0; i < D8; ++i) {
|
| 3087 |
v8x8_t mv;
|
| 3088 |
simdgroup_load(mv, pv + i*8, nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20
|
|
@@ -3103,7 +3105,7 @@ kernel void kernel_flash_attn_ext(
|
|
| 3103 |
|
| 3104 |
simdgroup_barrier(mem_flags::mem_threadgroup);
|
| 3105 |
|
| 3106 |
-
#pragma unroll
|
| 3107 |
for (short k = 0; k < 4; ++k) {
|
| 3108 |
v8x8_t mv;
|
| 3109 |
|
|
@@ -3196,6 +3198,7 @@ kernel void kernel_flash_attn_ext(
|
|
| 3196 |
simdgroup_load(ms0, ss + 2*C, TS, 0, false);
|
| 3197 |
simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false);
|
| 3198 |
|
|
|
|
| 3199 |
for (short i = 0; i < D8; ++i) {
|
| 3200 |
o8x8_t t;
|
| 3201 |
|
|
@@ -3413,6 +3416,7 @@ kernel void kernel_flash_attn_ext_vec(
|
|
| 3413 |
// load the queries from shared memory into local memory
|
| 3414 |
q4x4_t mq[D16/NL];
|
| 3415 |
|
|
|
|
| 3416 |
for (short ii = 0; ii < D16; ii += NL) {
|
| 3417 |
mq[ii/NL] = sq4x4[ii + tx];
|
| 3418 |
}
|
|
@@ -3454,17 +3458,23 @@ kernel void kernel_flash_attn_ext_vec(
|
|
| 3454 |
|
| 3455 |
device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
|
| 3456 |
|
| 3457 |
-
#pragma unroll
|
| 3458 |
for (short ii = 0; ii < D16; ii += NL) {
|
| 3459 |
const short i = ii + tx;
|
| 3460 |
|
| 3461 |
k4x4_t mk;
|
| 3462 |
deq_k(pk + i/nl_k, i%nl_k, mk);
|
| 3463 |
|
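| 3464 |
-
mqka[0] += dot(mq[ii/NL][0], mk[0]);
|
| 3465 |
-
mqka[1] += dot(mq[ii/NL][1], mk[1]);
|
| 3466 |
-
mqka[2] += dot(mq[ii/NL][2], mk[2]);
|
| 3467 |
-
mqka[3] += dot(mq[ii/NL][3], mk[3]);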
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3468 |
}
|
| 3469 |
|
| 3470 |
qk_t mqk = mqka[0] + mqka[1] + mqka[2] + mqka[3];
|
|
@@ -3513,7 +3523,7 @@ kernel void kernel_flash_attn_ext_vec(
|
|
| 3513 |
ss[tiisg] = vs;
|
| 3514 |
|
| 3515 |
// O = diag(ms)*O
|
| 3516 |
-
#pragma unroll
|
| 3517 |
for (short ii = 0; ii < D16; ii += NL) {
|
| 3518 |
lo[ii/NL] *= ms;
|
| 3519 |
}
|
|
@@ -3523,13 +3533,12 @@ kernel void kernel_flash_attn_ext_vec(
|
|
| 3523 |
|
| 3524 |
// O = O + (Q*K^T)*V
|
| 3525 |
{
|
| 3526 |
-
#pragma unroll
|
| 3527 |
for (short cc = 0; cc < C/4; ++cc) {
|
| 3528 |
device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
|
| 3529 |
|
| 3530 |
const s4x4_t ms(ss[4*cc + ty]);
|
| 3531 |
|
| 3532 |
-
#pragma unroll
|
| 3533 |
for (short ii = 0; ii < D16; ii += NL) {
|
| 3534 |
const short i = ii + tx;
|
| 3535 |
|
|
|
|
| 2942 |
half smax = -INFINITY;
|
| 2943 |
|
| 2944 |
// load the mask in shared memory
|
| 2945 |
+
#pragma unroll(Q)
|
| 2946 |
for (short j = 0; j < Q; ++j) {
|
| 2947 |
device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*nb31);
|
| 2948 |
|
|
|
|
| 2969 |
// we can read directly from global memory
|
| 2970 |
device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
|
| 2971 |
|
| 2972 |
+
#pragma unroll(D8)
|
| 2973 |
for (short i = 0; i < D8; ++i) {
|
| 2974 |
k8x8_t mk;
|
| 2975 |
simdgroup_load(mk, pk + i*8, nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10
|
|
|
|
| 2990 |
|
| 2991 |
simdgroup_barrier(mem_flags::mem_threadgroup);
|
| 2992 |
|
| 2993 |
+
#pragma unroll(4)
|
| 2994 |
for (short k = 0; k < 4; ++k) {
|
| 2995 |
k8x8_t mk;
|
| 2996 |
|
|
|
|
| 3068 |
s8x8_t mm;
|
| 3069 |
simdgroup_load(mm, ss + 2*C, TS, 0, false);
|
| 3070 |
|
| 3071 |
+
#pragma unroll(D8)
|
| 3072 |
for (short i = 0; i < D8; ++i) {
|
| 3073 |
simdgroup_multiply(lo[i], mm, lo[i]);
|
| 3074 |
}
|
|
|
|
| 3083 |
if (is_same<vd4x4_t, v4x4_t>::value) {
|
| 3084 |
// we can read directly from global memory
|
| 3085 |
device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
|
| 3086 |
+
|
| 3087 |
+
#pragma unroll(D8)
|
| 3088 |
for (short i = 0; i < D8; ++i) {
|
| 3089 |
v8x8_t mv;
|
| 3090 |
simdgroup_load(mv, pv + i*8, nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20
|
|
|
|
| 3105 |
|
| 3106 |
simdgroup_barrier(mem_flags::mem_threadgroup);
|
| 3107 |
|
| 3108 |
+
#pragma unroll(4)
|
| 3109 |
for (short k = 0; k < 4; ++k) {
|
| 3110 |
v8x8_t mv;
|
| 3111 |
|
|
|
|
| 3198 |
simdgroup_load(ms0, ss + 2*C, TS, 0, false);
|
| 3199 |
simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false);
|
| 3200 |
|
| 3201 |
+
#pragma unroll(D8)
|
| 3202 |
for (short i = 0; i < D8; ++i) {
|
| 3203 |
o8x8_t t;
|
| 3204 |
|
|
|
|
| 3416 |
// load the queries from shared memory into local memory
|
| 3417 |
q4x4_t mq[D16/NL];
|
| 3418 |
|
| 3419 |
+
#pragma unroll(D16/NL)
|
| 3420 |
for (short ii = 0; ii < D16; ii += NL) {
|
| 3421 |
mq[ii/NL] = sq4x4[ii + tx];
|
| 3422 |
}
|
|
|
|
| 3458 |
|
| 3459 |
device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
|
| 3460 |
|
| 3461 |
+
#pragma unroll(D16/NL)
|
| 3462 |
for (short ii = 0; ii < D16; ii += NL) {
|
| 3463 |
const short i = ii + tx;
|
| 3464 |
|
| 3465 |
k4x4_t mk;
|
| 3466 |
deq_k(pk + i/nl_k, i%nl_k, mk);
|
| 3467 |
|
| 3468 |
+
// note: this is less precise than the version below
|
| 3469 |
+
//mqka[0] += dot(mq[ii/NL][0], mk[0]);
|
| 3470 |
+
//mqka[1] += dot(mq[ii/NL][1], mk[1]);
|
| 3471 |
+
//mqka[2] += dot(mq[ii/NL][2], mk[2]);
|
| 3472 |
+
//mqka[3] += dot(mq[ii/NL][3], mk[3]);
|
| 3473 |
+
|
| 3474 |
+
mqka[0] += dot((float4) mq[ii/NL][0], (float4) mk[0]);
|
| 3475 |
+
mqka[1] += dot((float4) mq[ii/NL][1], (float4) mk[1]);
|
| 3476 |
+
mqka[2] += dot((float4) mq[ii/NL][2], (float4) mk[2]);
|
| 3477 |
+
mqka[3] += dot((float4) mq[ii/NL][3], (float4) mk[3]);
|
| 3478 |
}
|
| 3479 |
|
| 3480 |
qk_t mqk = mqka[0] + mqka[1] + mqka[2] + mqka[3];
|
|
|
|
| 3523 |
ss[tiisg] = vs;
|
| 3524 |
|
| 3525 |
// O = diag(ms)*O
|
| 3526 |
+
#pragma unroll(D16/NL)
|
| 3527 |
for (short ii = 0; ii < D16; ii += NL) {
|
| 3528 |
lo[ii/NL] *= ms;
|
| 3529 |
}
|
|
|
|
| 3533 |
|
| 3534 |
// O = O + (Q*K^T)*V
|
| 3535 |
{
|
|
|
|
| 3536 |
for (short cc = 0; cc < C/4; ++cc) {
|
| 3537 |
device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
|
| 3538 |
|
| 3539 |
const s4x4_t ms(ss[4*cc + ty]);
|
| 3540 |
|
| 3541 |
+
#pragma unroll(D16/NL)
|
| 3542 |
for (short ii = 0; ii < D16; ii += NL) {
|
| 3543 |
const short i = ii + tx;
|
| 3544 |
|