ggerganov committed on
Commit
9160e8f
·
1 Parent(s): 76b8073

metal : more precise Q*K in FA vec kernel (llama/10247)

Browse files
Files changed (1) hide show
  1. ggml/src/ggml-metal.metal +22 -13
ggml/src/ggml-metal.metal CHANGED
@@ -2942,6 +2942,7 @@ kernel void kernel_flash_attn_ext(
2942
  half smax = -INFINITY;
2943
 
2944
  // load the mask in shared memory
 
2945
  for (short j = 0; j < Q; ++j) {
2946
  device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*nb31);
2947
 
@@ -2968,7 +2969,7 @@ kernel void kernel_flash_attn_ext(
2968
  // we can read directly from global memory
2969
  device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
2970
 
2971
- #pragma unroll
2972
  for (short i = 0; i < D8; ++i) {
2973
  k8x8_t mk;
2974
  simdgroup_load(mk, pk + i*8, nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10
@@ -2989,7 +2990,7 @@ kernel void kernel_flash_attn_ext(
2989
 
2990
  simdgroup_barrier(mem_flags::mem_threadgroup);
2991
 
2992
- #pragma unroll
2993
  for (short k = 0; k < 4; ++k) {
2994
  k8x8_t mk;
2995
 
@@ -3067,7 +3068,7 @@ kernel void kernel_flash_attn_ext(
3067
  s8x8_t mm;
3068
  simdgroup_load(mm, ss + 2*C, TS, 0, false);
3069
 
3070
- #pragma unroll
3071
  for (short i = 0; i < D8; ++i) {
3072
  simdgroup_multiply(lo[i], mm, lo[i]);
3073
  }
@@ -3082,7 +3083,8 @@ kernel void kernel_flash_attn_ext(
3082
  if (is_same<vd4x4_t, v4x4_t>::value) {
3083
  // we can read directly from global memory
3084
  device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
3085
- #pragma unroll
 
3086
  for (short i = 0; i < D8; ++i) {
3087
  v8x8_t mv;
3088
  simdgroup_load(mv, pv + i*8, nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20
@@ -3103,7 +3105,7 @@ kernel void kernel_flash_attn_ext(
3103
 
3104
  simdgroup_barrier(mem_flags::mem_threadgroup);
3105
 
3106
- #pragma unroll
3107
  for (short k = 0; k < 4; ++k) {
3108
  v8x8_t mv;
3109
 
@@ -3196,6 +3198,7 @@ kernel void kernel_flash_attn_ext(
3196
  simdgroup_load(ms0, ss + 2*C, TS, 0, false);
3197
  simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false);
3198
 
 
3199
  for (short i = 0; i < D8; ++i) {
3200
  o8x8_t t;
3201
 
@@ -3413,6 +3416,7 @@ kernel void kernel_flash_attn_ext_vec(
3413
  // load the queries from shared memory into local memory
3414
  q4x4_t mq[D16/NL];
3415
 
 
3416
  for (short ii = 0; ii < D16; ii += NL) {
3417
  mq[ii/NL] = sq4x4[ii + tx];
3418
  }
@@ -3454,17 +3458,23 @@ kernel void kernel_flash_attn_ext_vec(
3454
 
3455
  device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
3456
 
3457
- #pragma unroll
3458
  for (short ii = 0; ii < D16; ii += NL) {
3459
  const short i = ii + tx;
3460
 
3461
  k4x4_t mk;
3462
  deq_k(pk + i/nl_k, i%nl_k, mk);
3463
 
3464
- mqka[0] += dot(mq[ii/NL][0], mk[0]);
3465
- mqka[1] += dot(mq[ii/NL][1], mk[1]);
3466
- mqka[2] += dot(mq[ii/NL][2], mk[2]);
3467
- mqka[3] += dot(mq[ii/NL][3], mk[3]);
 
 
 
 
 
 
3468
  }
3469
 
3470
  qk_t mqk = mqka[0] + mqka[1] + mqka[2] + mqka[3];
@@ -3513,7 +3523,7 @@ kernel void kernel_flash_attn_ext_vec(
3513
  ss[tiisg] = vs;
3514
 
3515
  // O = diag(ms)*O
3516
- #pragma unroll
3517
  for (short ii = 0; ii < D16; ii += NL) {
3518
  lo[ii/NL] *= ms;
3519
  }
@@ -3523,13 +3533,12 @@ kernel void kernel_flash_attn_ext_vec(
3523
 
3524
  // O = O + (Q*K^T)*V
3525
  {
3526
- #pragma unroll
3527
  for (short cc = 0; cc < C/4; ++cc) {
3528
  device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
3529
 
3530
  const s4x4_t ms(ss[4*cc + ty]);
3531
 
3532
- #pragma unroll
3533
  for (short ii = 0; ii < D16; ii += NL) {
3534
  const short i = ii + tx;
3535
 
 
2942
  half smax = -INFINITY;
2943
 
2944
  // load the mask in shared memory
2945
+ #pragma unroll(Q)
2946
  for (short j = 0; j < Q; ++j) {
2947
  device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*nb31);
2948
 
 
2969
  // we can read directly from global memory
2970
  device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
2971
 
2972
+ #pragma unroll(D8)
2973
  for (short i = 0; i < D8; ++i) {
2974
  k8x8_t mk;
2975
  simdgroup_load(mk, pk + i*8, nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10
 
2990
 
2991
  simdgroup_barrier(mem_flags::mem_threadgroup);
2992
 
2993
+ #pragma unroll(4)
2994
  for (short k = 0; k < 4; ++k) {
2995
  k8x8_t mk;
2996
 
 
3068
  s8x8_t mm;
3069
  simdgroup_load(mm, ss + 2*C, TS, 0, false);
3070
 
3071
+ #pragma unroll(D8)
3072
  for (short i = 0; i < D8; ++i) {
3073
  simdgroup_multiply(lo[i], mm, lo[i]);
3074
  }
 
3083
  if (is_same<vd4x4_t, v4x4_t>::value) {
3084
  // we can read directly from global memory
3085
  device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
3086
+
3087
+ #pragma unroll(D8)
3088
  for (short i = 0; i < D8; ++i) {
3089
  v8x8_t mv;
3090
  simdgroup_load(mv, pv + i*8, nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20
 
3105
 
3106
  simdgroup_barrier(mem_flags::mem_threadgroup);
3107
 
3108
+ #pragma unroll(4)
3109
  for (short k = 0; k < 4; ++k) {
3110
  v8x8_t mv;
3111
 
 
3198
  simdgroup_load(ms0, ss + 2*C, TS, 0, false);
3199
  simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false);
3200
 
3201
+ #pragma unroll(D8)
3202
  for (short i = 0; i < D8; ++i) {
3203
  o8x8_t t;
3204
 
 
3416
  // load the queries from shared memory into local memory
3417
  q4x4_t mq[D16/NL];
3418
 
3419
+ #pragma unroll(D16/NL)
3420
  for (short ii = 0; ii < D16; ii += NL) {
3421
  mq[ii/NL] = sq4x4[ii + tx];
3422
  }
 
3458
 
3459
  device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
3460
 
3461
+ #pragma unroll(D16/NL)
3462
  for (short ii = 0; ii < D16; ii += NL) {
3463
  const short i = ii + tx;
3464
 
3465
  k4x4_t mk;
3466
  deq_k(pk + i/nl_k, i%nl_k, mk);
3467
 
3468
+ // note: this is less precise than the version below
3469
+ //mqka[0] += dot(mq[ii/NL][0], mk[0]);
3470
+ //mqka[1] += dot(mq[ii/NL][1], mk[1]);
3471
+ //mqka[2] += dot(mq[ii/NL][2], mk[2]);
3472
+ //mqka[3] += dot(mq[ii/NL][3], mk[3]);
3473
+
3474
+ mqka[0] += dot((float4) mq[ii/NL][0], (float4) mk[0]);
3475
+ mqka[1] += dot((float4) mq[ii/NL][1], (float4) mk[1]);
3476
+ mqka[2] += dot((float4) mq[ii/NL][2], (float4) mk[2]);
3477
+ mqka[3] += dot((float4) mq[ii/NL][3], (float4) mk[3]);
3478
  }
3479
 
3480
  qk_t mqk = mqka[0] + mqka[1] + mqka[2] + mqka[3];
 
3523
  ss[tiisg] = vs;
3524
 
3525
  // O = diag(ms)*O
3526
+ #pragma unroll(D16/NL)
3527
  for (short ii = 0; ii < D16; ii += NL) {
3528
  lo[ii/NL] *= ms;
3529
  }
 
3533
 
3534
  // O = O + (Q*K^T)*V
3535
  {
 
3536
  for (short cc = 0; cc < C/4; ++cc) {
3537
  device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
3538
 
3539
  const s4x4_t ms(ss[4*cc + ty]);
3540
 
3541
+ #pragma unroll(D16/NL)
3542
  for (short ii = 0; ii < D16; ii += NL) {
3543
  const short i = ii + tx;
3544