ggerganov, JohannesGaessler, and phymbert committed
Commit 34d3b03 · 1 Parent(s): a83f2ae

ggml : add Flash Attention (llama/5021)


* ggml : add ggml_flash_attn_ext API

* ggml : fix GQA support in ggml_flash_attn_ext

* ggml : online attention (CPU)

* metal : initial implementation

* metal : f16 precision

* metal : reduce branches

* metal : specialize for head size

* wip : 8 rows per simd group

* wip : 4 rows per simd group

* wip : template for rows per warp

* metal : parallelize across KV size

* metal : parallel reduce across heads

* metal : efficient flash_attn_f16 implementation

* metal : avoid redundant loads of the attention

* metal : scale and mask in matrix form

* metal : fix comment

* llama : avoid ggml_cast, use F32 query

* metal : add parallel reduce version (disabled)

* metal : move output into local memory + optimize

- the result from each simdgroup now stays in the registers
- significantly reduced SRAM usage
- more efficient skipping of -INF blocks
- avoid simdgroup barrier in hot loop
- add comments

* metal : add tests, fix scaling, support C > 32

* metal : improve precision

* ggml : fix f16 mad

* metal : minor

* metal : support Q > 8

* tests : add ATTN tests

* metal : disable buffer allocation logs

* tests : more

* metal : faster inner loop for C == 32

* metal : fix array initialization

* tests : ifdef

* ggml : switch to padded F16 mask for ggml_soft_max, ggml_flash_attn_ext

* ggml : fix ggml_soft_max mask requirement

* cuda : fix soft_max to use correct mask size

* cuda : add flash_attn kernel (wip)

* metal : optimize softmax for C > 32

* metal : optimize softmax

* tests : minor fix

* cuda : avoid zeroing fragments

* tests : update dims

* cuda : fix __hisinf() result check

* cuda : avoid warp_reduce for smax

* cuda : use int instead of int64_t

Noticeably improves performance (thanks to Johannes)

* cuda : make loops use the same loop values

Thanks Johannes again for the tip

* cuda : unroll some of the loops

* cuda : avoid __hisinf branches

* cuda : use half2 in softmax

* cuda : switch to 1 warp for bs > 16

* cuda : speed-up reduce part of the kernel

* cuda : unroll Q*K^T loop

* cuda : fix -INF block check

* cuda : simplify softmax

* cuda : fix matrix names

* cuda : minor

* llama : adapt to F16 KQ_pos

* llama : adapt new models to F16 KQ_mask

* ggml : fix F16 store (ARM NEON)

* llama : fix type of KQ_mask and KQ_pos

* ggml : fix CPU soft_max

* tests : add hs=256

* cuda : fix build

* metal : improve perf via smaller int registers

* cuda : adapt soft_max to F16 mask and pos

* CUDA: faster FlashAttention, kernel for bs == 1

* 16 cols for Phi-2

* no vec for hs, no hs==256 ncols==32 for Volta

* adjust kernel selection logic

* 4 warps, 256 stride for all D

* no ncols == 64

* Multiple parallel blocks for batch size 1

* fix compile warnings

* fix excessive KQ_b loads

* fix cmake build

* fix KV cache padding, NaN from INFINITY (llama/6438)

* llama : flash_attn cparam + fix defrag

* server: support flash_attn param

* server: bench: enable flash_attn param

* CUDA: refactor host code, dyn. par. blocks

* fix flash_attn_vec_f16 race condition

* flush softmax exp below threshold to 0

* store temp KQ in registers

* Calculate KQ as FP32 if KQV has GGML_PREC_F32

* Add __hgt2_mask implementation for CUDA 11

* fix KQ FP32 precision for parallel_blocks > 1

* llama-bench : add -fa,--flash-attn arg

* metal : add BS=1 kernel for flash attention (llama/6508)

* metal : add BS=1 kernel for flash attention (wip)

* metal : support more than 1 warps

* metal : opts

* metal : opt

* metal : switch to parallel reduce

* metal : reduce registers

* metal : simplify

* metal : initial FA vec kernel

* metal : use F32 attention accumulators

* batched-bench : add fattn arg

* llama : simplify llama_build_kv_store

ggml-ci

* llama : adapt build_olmo to changes

* ggml : fix arm fp16 store on windows

* metal : clean-up

* metal : clean-up kernel code

* metal : minor

* tests : remove benchmarks

ggml-ci

* ggml : fix avx512 const correctness

ggml-ci

* ggml : fix soft_max with bias on CPU

ggml-ci

* common : print --flash-attn in help

* ggml : fix num dimensions in ggml_flash_attn_ext

* llama : force disable flash attention for incompatible models

* ggml : ggml_soft_max support F16/F32 mask/pos

ggml-ci

* cuda : uint -> uint32_t

* cuda : "constexpr dim3" -> "const dim3"

ggml-ci

* cuda : try to fix __hgt2_mask

ggml-ci

* ggml : add TODO's for F16/F32 mask/pos support in other backends

* llama : replace bool need_kq_pos with use_alibi

* llama : prep ALiBi support for BERT models

ggml-ci

* llama : fix n_batch requirements

ggml-ci

* cont

* server : add help for --flash-attn arg

* llama : disable FA for AMD

* tests : remove TMP_ATTN_BENCH

ggml-ci

* llama : support save/load state with FA enabled

ggml-ci

* ci : add CUDA save-load-state tests

ggml-ci

* llama : llama_kv_cache_clear zeroes data + fix save-load seq

ggml-ci

* llama : fix copy-paste errors, add TODO

* llama : disallow incompatible states

* llama : update llama_state_get_size after v_trans field

* metal : remove tmp log

* llama : add static reminder for llama_state_get_size

* metal : fix max nsg

ggml-ci

* ci : fix arg order

ggml-ci

---------

Co-authored-by: Johannes Gäßler <[email protected]>
Co-authored-by: Pierrick HYMBERT <[email protected]>
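
For orientation, the core of the change is the fused attention operator announced in the first bullet ("ggml : add ggml_flash_attn_ext API"). The ggml.h hunk is not reproduced in this excerpt, so the call below is only a sketch: the argument order and tensor types are inferred from the CUDA host code further down (F32 Q, F16 K/V, an optional F16 mask padded to GGML_PAD(n_queries, 16), and a scale stored in op_params). All variable names are placeholders.

#include <math.h>
#include "ggml.h"

// All names below (ctx, Q, K, V, KQ_mask, n_embd_head) are placeholders.
static struct ggml_tensor * build_attn_fused(
        struct ggml_context * ctx,
        struct ggml_tensor  * Q,        // F32, [n_embd_head, n_queries, n_head, n_batch]
        struct ggml_tensor  * K,        // F16 view of the KV cache
        struct ggml_tensor  * V,        // F16 view of the KV cache (same shape as K)
        struct ggml_tensor  * KQ_mask,  // F16, padded to GGML_PAD(n_queries, 16) rows
        int                   n_embd_head) {
    const float kq_scale = 1.0f/sqrtf((float) n_embd_head);
    // Fused path added by this commit (signature assumed; the ggml.h hunk adds 20 lines):
    return ggml_flash_attn_ext(ctx, Q, K, V, KQ_mask, kq_scale);
}

Without the fused op, the same result is built from separate ggml_mul_mat and ggml_soft_max_ext nodes; the commit keeps that path and adds the -fa/--flash-attn switch to opt into the fused one.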

Files changed (12)
  1. ggml-cuda.cu +6 -0
  2. ggml-cuda/common.cuh +26 -14
  3. ggml-cuda/fattn.cu +944 -0
  4. ggml-cuda/fattn.cuh +3 -0
  5. ggml-cuda/softmax.cu +36 -10
  6. ggml-kompute.cpp +7 -0
  7. ggml-metal.m +372 -177
  8. ggml-metal.metal +654 -18
  9. ggml-sycl.cpp +5 -1
  10. ggml-vulkan.cpp +5 -0
  11. ggml.c +360 -15
  12. ggml.h +20 -0
ggml-cuda.cu CHANGED
@@ -14,6 +14,7 @@
 #include "ggml-cuda/cpy.cuh"
 #include "ggml-cuda/diagmask.cuh"
 #include "ggml-cuda/dmmv.cuh"
+#include "ggml-cuda/fattn.cuh"
 #include "ggml-cuda/getrows.cuh"
 #include "ggml-cuda/im2col.cuh"
 #include "ggml-cuda/mmq.cuh"
@@ -140,6 +141,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
         info.devices[id].cc = 100*prop.major + 10*prop.minor;
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
         info.devices[id].smpb = prop.sharedMemPerBlock;
+        info.devices[id].nsm  = prop.multiProcessorCount;
     }

     for (int id = 0; id < info.device_count; ++id) {
@@ -2293,6 +2295,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_ARGSORT:
             ggml_cuda_op_argsort(ctx, dst);
             break;
+        case GGML_OP_FLASH_ATTN_EXT:
+            ggml_cuda_flash_attn_ext(ctx, dst);
+            break;
         default:
             return false;
     }
@@ -2568,6 +2573,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_ARANGE:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_LEAKY_RELU:
+        case GGML_OP_FLASH_ATTN_EXT:
             return true;
         default:
             return false;
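
The new nsm field (streaming multiprocessor count) recorded above has a single consumer in this commit: the flash-attention launcher in ggml-cuda/fattn.cu uses it to decide how many parallel blocks should cooperate on one column of queries when the grid would otherwise be too small to fill the GPU. Restated on its own, with the same thresholds as launch_fattn_f16 in the diff below:

// Condensed restatement of the parallel_blocks heuristic from launch_fattn_f16.
// blocks_num_pb1 = number of thread blocks the kernel would launch with parallel_blocks == 1.
static int choose_parallel_blocks(int blocks_num_pb1, int nsm) {
    if (4*blocks_num_pb1 < 2*nsm) {
        return 4; // few query columns: split the KV dimension across 4 blocks to occupy the SMs
    }
    if (2*blocks_num_pb1 < 2*nsm) {
        return 2;
    }
    return 1;     // enough blocks already; skip the extra combine pass
}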
ggml-cuda/common.cuh CHANGED
@@ -142,6 +142,7 @@
 #define CC_PASCAL     600
 #define MIN_CC_DP4A   610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
 #define CC_VOLTA      700
+#define CC_AMPERE     800
 #define CC_OFFSET_AMD 1000000
 #define CC_RDNA1      (CC_OFFSET_AMD + 1010)
 #define CC_RDNA2      (CC_OFFSET_AMD + 1030)
@@ -271,7 +272,6 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
     return a;
 }

-#ifdef GGML_CUDA_F16
 static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
 #pragma unroll
@@ -284,7 +284,6 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
     NO_DEVICE_CODE;
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
 }
-#endif // GGML_CUDA_F16

 static __device__ __forceinline__ float warp_reduce_max(float x) {
 #pragma unroll
@@ -294,19 +293,26 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
     return x;
 }

-//static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
-//#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
-//#pragma unroll
-//   for (int mask = 16; mask > 0; mask >>= 1) {
-//       x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
-//   }
-//   return x;
-//#else
-//   GGML_UNUSED(x);
-//   NO_DEVICE_CODE;
-//#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
-//}
+static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
+#pragma unroll
+   for (int mask = 16; mask > 0; mask >>= 1) {
+       x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+   }
+   return x;
+#else
+   GGML_UNUSED(x);
+   NO_DEVICE_CODE;
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
+}

+#if CUDART_VERSION < 12000
+static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half2 b) {
+    const uint32_t mask_low  = 0x0000FFFF * (float( __low2half(a)) > float( __low2half(b)));
+    const uint32_t mask_high = 0xFFFF0000 * (float(__high2half(a)) > float(__high2half(b)));
+    return mask_low | mask_high;
+}
+#endif // CUDART_VERSION < 12000

 #if defined(GGML_USE_HIPBLAS)
 #define __CUDA_ARCH__ 1300
@@ -391,6 +397,11 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
 }
 #endif // defined(GGML_USE_HIPBLAS)

+#define FP16_AVAILABLE     defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) ? \
+    defined(RDNA1) || defined(RDNA2) || defined(RDNA3) : __CUDA_ARCH__ >= CC_PASCAL
+
+#define FP16_MMA_AVAILABLE !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
+
 // TODO: move to ggml-common.h
 static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};

@@ -404,6 +415,7 @@ struct ggml_cuda_device_info {

     struct cuda_device_info {
         int     cc;                 // compute capability
+        int     nsm;                // number of streaming multiprocessors
         size_t  smpb;               // max. shared memory per block
         bool    vmm;                // virtual memory support
         size_t  vmm_granularity;    // granularity of virtual memory
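
Two of the additions above are easy to misread. warp_reduce_max(half2) is the previously commented-out helper, now enabled because the F16 kernels track two running KQ maxima per register. __hgt2_mask is a fallback for CUDA toolkits older than 12.0 ("Add __hgt2_mask implementation for CUDA 11" in the commit log); it packs two per-lane greater-than results into one 32-bit mask that fattn.cu later ANDs against half2 exponentials to flush tiny softmax terms to zero. A host-side illustration of the same semantics, in plain C:

#include <stdint.h>

// Illustration of what __hgt2_mask computes, using plain floats instead of a half2 pair.
// Each 16-bit lane of the result is all-ones when a > b in that lane, else zero.
static uint32_t hgt2_mask_ref(float a_lo, float a_hi, float b_lo, float b_hi) {
    const uint32_t mask_low  = 0x0000FFFFu * (a_lo > b_lo);
    const uint32_t mask_high = 0xFFFF0000u * (a_hi > b_hi);
    return mask_low | mask_high;
}

// In the kernels the mask is applied as  *(uint32_t *)&val &= mask;  so a lane whose
// exponent argument is <= SOFTMAX_FTZ_THRESHOLD becomes exactly zero instead of a denormal/NaN.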
ggml-cuda/fattn.cu ADDED
@@ -0,0 +1,944 @@
1
+ #include "common.cuh"
2
+ #include "fattn.cuh"
3
+
4
+ #include <cstdint>
5
+
6
+ #if FP16_MMA_AVAILABLE
7
+ #include <mma.h>
8
+ #endif
9
+
10
+ #define FATTN_KQ_STRIDE 256
11
+ #define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
12
+ #define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
13
+
14
+ template<int D, int parallel_blocks> // D == head size
15
+ __launch_bounds__(((D + WARP_SIZE - 1) / WARP_SIZE)*WARP_SIZE, 1)
16
+ static __global__ void flash_attn_vec_ext_f16(
17
+ const char * __restrict__ Q,
18
+ const char * __restrict__ K,
19
+ const char * __restrict__ V,
20
+ const char * __restrict__ mask,
21
+ float * __restrict__ dst,
22
+ float2 * __restrict__ dst_meta,
23
+ const float scale,
24
+ const int ne00,
25
+ const int ne01,
26
+ const int ne02,
27
+ const int ne03,
28
+ const int ne10,
29
+ const int ne11,
30
+ const int ne12,
31
+ const int ne13,
32
+ const int ne31,
33
+ const int nb31,
34
+ const int nb01,
35
+ const int nb02,
36
+ const int nb03,
37
+ const int nb11,
38
+ const int nb12,
39
+ const int nb13,
40
+ const int ne0,
41
+ const int ne1,
42
+ const int ne2,
43
+ const int ne3) {
44
+ #if FP16_AVAILABLE
45
+ //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
46
+
47
+ const int ic = blockIdx.x / parallel_blocks; // Index of the Q/QKV column to work on.
48
+ const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
49
+
50
+ const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
51
+ const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic);
52
+ const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio));
53
+ const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
54
+ const half * maskh = (const half *) mask + ne11*ic;
55
+
56
+ const int stride_KV = nb11 / sizeof(half);
57
+ const int stride_KV2 = nb11 / sizeof(half2);
58
+
59
+ constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE;
60
+ const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
61
+ __builtin_assume(tid < nwarps*WARP_SIZE);
62
+
63
+ __shared__ half KQ[nwarps*WARP_SIZE];
64
+ KQ[tid] = -INFINITY;
65
+ half2 * KQ2 = (half2 *) KQ;
66
+
67
+ half kqmax = -HALF_MAX_HALF;
68
+ half kqsum = 0.0f;
69
+
70
+ __shared__ half kqmax_shared[WARP_SIZE];
71
+ __shared__ half kqsum_shared[WARP_SIZE];
72
+ if (threadIdx.y == 0) {
73
+ kqmax_shared[threadIdx.x] = -HALF_MAX_HALF;
74
+ kqsum_shared[threadIdx.x] = 0.0f;
75
+ }
76
+ __syncthreads();
77
+
78
+ // Convert Q to half2 and store in registers:
79
+ half2 Q_h2[(D/2 + WARP_SIZE - 1) / WARP_SIZE];
80
+ #pragma unroll
81
+ for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
82
+ const int i = i0 + threadIdx.x;
83
+ if (i0 + WARP_SIZE > D/2 && i >= D/2) {
84
+ break;
85
+ }
86
+
87
+ Q_h2[i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(Q_f2[i].x, Q_f2[i].y);
88
+ }
89
+
90
+ half2 VKQ = make_half2(0.0f, 0.0f); // Each thread calculates a single VKQ value.
91
+
92
+ const int k_start = parallel_blocks == 1 ? 0 : ip*D;
93
+ for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
94
+ // Calculate KQ tile and keep track of new maximum KQ values:
95
+ half kqmax_new = kqmax;
96
+ #pragma unroll
97
+ for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) {
98
+ const int i_KQ = i_KQ_0 + threadIdx.y;
99
+
100
+ if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) {
101
+ break;
102
+ }
103
+
104
+ half2 sum2 = make_half2(0.0f, 0.0f);
105
+ #pragma unroll
106
+ for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
107
+ const int k_KQ = k_KQ_0 + threadIdx.x;
108
+ if (k_KQ_0 + WARP_SIZE > D/2 && k_KQ >= D/2) {
109
+ break;
110
+ }
111
+
112
+ const half2 K_ik = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
113
+ sum2 += K_ik * Q_h2[k_KQ_0/WARP_SIZE];
114
+ }
115
+
116
+ sum2 = warp_reduce_sum(sum2);
117
+ half sum = __low2half(sum2) + __high2half(sum2);
118
+ sum += mask ? maskh[k_VKQ_0 + i_KQ] : __float2half(0.0f);
119
+ kqmax_new = __hmax(kqmax_new, sum);
120
+ if (threadIdx.x == 0) {
121
+ KQ[i_KQ] = sum;
122
+ }
123
+ }
124
+
125
+ kqmax_new = warp_reduce_max(kqmax_new);
126
+ if (threadIdx.x == 0) {
127
+ kqmax_shared[threadIdx.y] = kqmax_new;
128
+ }
129
+ __syncthreads();
130
+ kqmax_new = kqmax_shared[threadIdx.x];
131
+ kqmax_new = warp_reduce_max(kqmax_new);
132
+
133
+ const half KQ_max_scale = hexp(kqmax - kqmax_new);
134
+ kqmax = kqmax_new;
135
+
136
+ const half val = hexp(KQ[tid] - kqmax);
137
+ kqsum = kqsum*KQ_max_scale + val;
138
+ KQ[tid] = val;
139
+
140
+ VKQ *= __half2half2(KQ_max_scale);
141
+
142
+ __syncthreads();
143
+
144
+ if (tid < D) {
145
+ #pragma unroll
146
+ for (int k0 = 0; k0 < D; k0 += 2) {
147
+ if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) {
148
+ break;
149
+ }
150
+
151
+ half2 V_k;
152
+ reinterpret_cast<half&>(V_k.x) = V_h[(k_VKQ_0 + k0 + 0)*stride_KV + tid];
153
+ reinterpret_cast<half&>(V_k.y) = V_h[(k_VKQ_0 + k0 + 1)*stride_KV + tid];
154
+ VKQ += V_k*KQ2[k0/2];
155
+ }
156
+ }
157
+
158
+ __syncthreads();
159
+ }
160
+
161
+ if (tid >= D) {
162
+ kqsum = 0.0f;
163
+ }
164
+
165
+ kqsum = warp_reduce_sum(kqsum);
166
+ if (threadIdx.x == 0) {
167
+ kqsum_shared[threadIdx.y] = kqsum;
168
+ }
169
+ __syncthreads();
170
+ kqsum = kqsum_shared[threadIdx.x];
171
+ kqsum = warp_reduce_sum(kqsum);
172
+
173
+ if (tid >= D) {
174
+ return;
175
+ }
176
+
177
+ half dst_val = (__low2half(VKQ) + __high2half(VKQ));
178
+ if (parallel_blocks == 1) {
179
+ dst_val /= kqsum;
180
+ }
181
+ dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = dst_val;
182
+
183
+ if (parallel_blocks == 1 || tid != 0) {
184
+ return;
185
+ }
186
+ dst_meta[ic*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax, kqsum);
187
+ #else
188
+ NO_DEVICE_CODE;
189
+ #endif // FP16_AVAILABLE
190
+ }
191
+
192
+ // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
193
+ template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t>
194
+ __launch_bounds__(nwarps*WARP_SIZE, 1)
195
+ static __global__ void flash_attn_ext_f16(
196
+ const char * __restrict__ Q,
197
+ const char * __restrict__ K,
198
+ const char * __restrict__ V,
199
+ const char * __restrict__ mask,
200
+ float * __restrict__ dst,
201
+ float2 * __restrict__ dst_meta,
202
+ const float scale,
203
+ const int ne00,
204
+ const int ne01,
205
+ const int ne02,
206
+ const int ne03,
207
+ const int ne10,
208
+ const int ne11,
209
+ const int ne12,
210
+ const int ne13,
211
+ const int ne31,
212
+ const int nb31,
213
+ const int nb01,
214
+ const int nb02,
215
+ const int nb03,
216
+ const int nb11,
217
+ const int nb12,
218
+ const int nb13,
219
+ const int ne0,
220
+ const int ne1,
221
+ const int ne2,
222
+ const int ne3) {
223
+ #if FP16_MMA_AVAILABLE
224
+ //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
225
+
226
+ const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
227
+ const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
228
+
229
+ static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
230
+ static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
231
+ constexpr int frag_m = ncols == 8 ? 32 : 16;
232
+ constexpr int frag_n = ncols == 8 ? 8 : 16;
233
+ static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
234
+ typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
235
+ typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
236
+ typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
237
+ typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
238
+ typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
239
+
240
+ constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
241
+ constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
242
+ static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps.");
243
+
244
+ // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts:
245
+ constexpr int D_padded = D + 8;
246
+ constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
247
+ constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
248
+
249
+ const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
250
+ const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0);
251
+ const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio));
252
+ const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
253
+ const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0;
254
+ const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2);
255
+
256
+ const int stride_Q = nb01 / sizeof(float);
257
+ const int stride_KV = nb11 / sizeof(half);
258
+
259
+ frag_b Q_b[D/16][ncols/frag_n];
260
+
261
+ // A single buffer for temporarily holding tiles of KQ and VKQ parts:
262
+ constexpr int mem_KQ = ncols*kqs_padded*kqar;
263
+ constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded;
264
+ __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts];
265
+ float * KQ_f = (float *) KQ;
266
+ half2 * KQ2 = (half2 *) KQ;
267
+
268
+ float KQ_rowsum_f[ncols/nwarps] = {0.0f};
269
+ float KQ_max_f[ncols/nwarps];
270
+ float KQ_max_scale_f[ncols/nwarps] = {0.0f};
271
+
272
+ #pragma unroll
273
+ for (int j = 0; j < ncols/nwarps; ++j) {
274
+ KQ_max_f[j] = -FLT_MAX/2.0f;
275
+ }
276
+
277
+ half2 KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}};
278
+ half2 KQ_max_h2[ncols/nwarps];
279
+ half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}};
280
+
281
+ #pragma unroll
282
+ for (int j = 0; j < ncols/nwarps; ++j) {
283
+ KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF);
284
+ }
285
+
286
+ __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
287
+ half2 * VKQ2 = (half2 *) VKQ;
288
+ #pragma unroll
289
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
290
+ const int j = j0 + threadIdx.y;
291
+ #pragma unroll
292
+ for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
293
+ const int i = i0 + threadIdx.x;
294
+ if (i0 + WARP_SIZE > D/2 && i >= D/2) {
295
+ break;
296
+ }
297
+ VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f);
298
+ }
299
+ }
300
+
301
+ // Convert Q to half and apply scale, temporarily store in KQ:
302
+ #pragma unroll
303
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
304
+ const int j = j0 + threadIdx.y;
305
+ #pragma unroll
306
+ for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
307
+ const int i = i0 + threadIdx.x;
308
+ if (i0 + WARP_SIZE > D && i >= D) {
309
+ break;
310
+ }
311
+ KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
312
+ }
313
+ }
314
+
315
+ __syncthreads();
316
+
317
+ // Load Q into tensor core fragments/registers since it will be used frequently:
318
+ #pragma unroll
319
+ for (int i0 = 0; i0 < D; i0 += 16) {
320
+ #pragma unroll
321
+ for (int j0 = 0; j0 < ncols; j0 += frag_n) {
322
+ nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
323
+ }
324
+ }
325
+
326
+ __syncthreads();
327
+
328
+ // Iterate over ne11 == previous tokens:
329
+ for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) {
330
+ // Calculate tile of KQ:
331
+ #pragma unroll
332
+ for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
333
+ frag_c_KQ KQ_c[ncols/frag_n];
334
+ #pragma unroll
335
+ for (int j = 0; j < ncols/frag_n; ++j) {
336
+ nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
337
+ }
338
+ #pragma unroll
339
+ for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
340
+ frag_a_K K_a;
341
+ nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
342
+ #pragma unroll
343
+ for (int j = 0; j < ncols/frag_n; ++j) {
344
+ nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
345
+ }
346
+ }
347
+ #pragma unroll
348
+ for (int j0 = 0; j0 < ncols; j0 += frag_n) {
349
+ nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
350
+ }
351
+ }
352
+
353
+ __syncthreads();
354
+
355
+ // Calculate softmax for each KQ column using the current max. value.
356
+ // The divisor is stored in KQ_rowsum and will be applied at the end.
357
+ #pragma unroll
358
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
359
+ const int j = j0 + threadIdx.y;
360
+
361
+ if (std::is_same<KQ_acc_t, float>::value) {
362
+ float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE];
363
+ #pragma unroll
364
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
365
+ const int k = k0 + threadIdx.x;
366
+
367
+ KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k];
368
+ }
369
+
370
+ float KQ_max_new = KQ_max_f[j0/nwarps];
371
+ #pragma unroll
372
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
373
+ const int k = k0 + threadIdx.x;
374
+
375
+ KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
376
+ KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]);
377
+ }
378
+ KQ_max_new = warp_reduce_max(KQ_max_new);
379
+
380
+ const float diff = KQ_max_f[j0/nwarps] - KQ_max_new;
381
+ KQ_max_scale_f[j0/nwarps] = expf(diff);
382
+ if (diff <= SOFTMAX_FTZ_THRESHOLD) {
383
+ KQ_max_scale_f[j0/nwarps] = 0.0f;
384
+ }
385
+ KQ_max_f[j0/nwarps] = KQ_max_new;
386
+
387
+ float KQ_rowsum_add = 0.0f;
388
+ #pragma unroll
389
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
390
+ const int k = k0 + threadIdx.x;
391
+
392
+ const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps];
393
+ KQ_f_tmp[k0/WARP_SIZE] = expf(diff);
394
+ if (diff <= SOFTMAX_FTZ_THRESHOLD) {
395
+ KQ_f_tmp[k0/WARP_SIZE] = 0.0f;
396
+ }
397
+ KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE];
398
+ KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE];
399
+ }
400
+ KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
401
+
402
+ // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
403
+ KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;
404
+ } else {
405
+ half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)];
406
+ #pragma unroll
407
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
408
+ const int k = k0 + threadIdx.x;
409
+
410
+ KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
411
+ }
412
+
413
+ half2 KQ_max_new = KQ_max_h2[j0/nwarps];
414
+ #pragma unroll
415
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
416
+ const int k = k0 + threadIdx.x;
417
+
418
+ KQ2_tmp[k0/WARP_SIZE] += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
419
+ KQ_max_new = __hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
420
+ }
421
+ KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
422
+ const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
423
+ KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
424
+ const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
425
+ *((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask;
426
+ KQ_max_h2[j0/nwarps] = KQ_max_new;
427
+
428
+ half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
429
+ #pragma unroll
430
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
431
+ const int k = k0 + threadIdx.x;
432
+
433
+ const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps];
434
+ KQ2_tmp[k0/WARP_SIZE] = h2exp(diff);
435
+ const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
436
+ *((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask;
437
+ KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE];
438
+ KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE];
439
+ }
440
+ KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
441
+
442
+ // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
443
+ KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add;
444
+ }
445
+ }
446
+
447
+ __syncthreads();
448
+
449
+ frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n];
450
+ #pragma unroll
451
+ for (int j0 = 0; j0 < ncols; j0 += frag_n) {
452
+ #pragma unroll
453
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
454
+ const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
455
+ nvcuda::wmma::load_matrix_sync(
456
+ KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
457
+ KQ + j0*(kqar*kqs_padded) + k,
458
+ kqar*kqs_padded);
459
+ }
460
+ }
461
+
462
+ frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n];
463
+ #pragma unroll
464
+ for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
465
+ #pragma unroll
466
+ for (int j = 0; j < ncols/frag_n; ++j) {
467
+ nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
468
+ }
469
+
470
+ #pragma unroll
471
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
472
+ const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
473
+
474
+ frag_a_V v_a;
475
+ nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
476
+ #pragma unroll
477
+ for (int j = 0; j < ncols/frag_n; ++j) {
478
+ nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
479
+ }
480
+ }
481
+ }
482
+
483
+ __syncthreads();
484
+
485
+ const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded);
486
+ #pragma unroll
487
+ for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
488
+ #pragma unroll
489
+ for (int j0 = 0; j0 < ncols; j0 += frag_n) {
490
+ nvcuda::wmma::store_matrix_sync(
491
+ KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
492
+ VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
493
+ D_padded, nvcuda::wmma::mem_col_major);
494
+ }
495
+ }
496
+
497
+ __syncthreads();
498
+
499
+ #pragma unroll
500
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
501
+ const int j = j0 + threadIdx.y;
502
+
503
+ half2 VKQ_scale;
504
+ if (std::is_same<KQ_acc_t, float>::value) {
505
+ VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]);
506
+ } else {
507
+ VKQ_scale = KQ_max_scale_h2[j0/nwarps];
508
+ }
509
+
510
+ #pragma unroll
511
+ for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
512
+ const int i = i0 + threadIdx.x;
513
+ if (i0 + WARP_SIZE > D/2 && i >= D/2) {
514
+ break;
515
+ }
516
+
517
+ half2 VKQ_add = make_half2(0.0f, 0.0f);
518
+ #pragma unroll
519
+ for (int l = 0; l < VKQ_ratio; ++l) {
520
+ VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i];
521
+ }
522
+ VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add;
523
+ }
524
+ }
525
+
526
+ __syncthreads();
527
+ }
528
+
529
+ #pragma unroll
530
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
531
+ const int j_VKQ = j0 + threadIdx.y;
532
+ if (ic0 + j_VKQ >= ne01) {
533
+ return;
534
+ }
535
+ const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
536
+
537
+ float KQ_rowsum_j;
538
+ if (std::is_same<KQ_acc_t, float>::value) {
539
+ KQ_rowsum_j = KQ_rowsum_f[j0/nwarps];
540
+ } else {
541
+ KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
542
+ }
543
+
544
+ #pragma unroll
545
+ for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
546
+ const int i = i0 + threadIdx.x;
547
+ if (i0 + WARP_SIZE > D && i >= D) {
548
+ break;
549
+ }
550
+ float dst_val = VKQ[j_VKQ*D_padded + i];
551
+ if (parallel_blocks == 1) {
552
+ dst_val /= KQ_rowsum_j;
553
+ }
554
+ dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val;
555
+ }
556
+
557
+ if (parallel_blocks == 1 || threadIdx.x != 0) {
558
+ continue;
559
+ }
560
+
561
+ float2 dst_meta_val;
562
+ if (std::is_same<KQ_acc_t, float>::value) {
563
+ dst_meta_val.x = KQ_max_f[j0/nwarps];
564
+ } else {
565
+ dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
566
+ }
567
+ dst_meta_val.y = KQ_rowsum_j;
568
+ dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val;
569
+ }
570
+ #else
571
+ NO_DEVICE_CODE;
572
+ #endif // FP16_MMA_AVAILABLE
573
+ }
574
+
575
+ template<int D, int parallel_blocks> // D == head size
576
+ __launch_bounds__(D, 1)
577
+ static __global__ void flash_attn_combine_results(
578
+ const float * __restrict__ VKQ_parts,
579
+ const float2 * __restrict__ VKQ_meta,
580
+ float * __restrict__ dst) {
581
+ #if FP16_AVAILABLE
582
+ VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x;
583
+ VKQ_meta += parallel_blocks * gridDim.y*blockIdx.x;
584
+ dst += D * gridDim.y*blockIdx.x;
585
+
586
+ const int tid = threadIdx.x;
587
+ __builtin_assume(tid < D);
588
+
589
+ __shared__ float2 meta[parallel_blocks];
590
+ if (tid < 2*parallel_blocks) {
591
+ ((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.y*(2*parallel_blocks) + tid];
592
+ }
593
+
594
+ __syncthreads();
595
+
596
+ float kqmax = meta[0].x;
597
+ #pragma unroll
598
+ for (int l = 1; l < parallel_blocks; ++l) {
599
+ kqmax = max(kqmax, meta[l].x);
600
+ }
601
+
602
+ float VKQ_numerator = 0.0f;
603
+ float VKQ_denominator = 0.0f;
604
+ #pragma unroll
605
+ for (int l = 0; l < parallel_blocks; ++l) {
606
+ const float diff = meta[l].x - kqmax;
607
+ const float KQ_max_scale = expf(diff);
608
+ const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
609
+ *((uint32_t *) &KQ_max_scale) &= ftz_mask;
610
+
611
+ VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid];
612
+ VKQ_denominator += KQ_max_scale * meta[l].y;
613
+ }
614
+
615
+ dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
616
+ #else
617
+ NO_DEVICE_CODE;
618
+ #endif // FP16_AVAILABLE
619
+ }
620
+
621
+ constexpr int get_max_power_of_2(int x) {
622
+ return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1;
623
+ }
624
+
625
+ static_assert(get_max_power_of_2(1) == 1, "Test failed.");
626
+ static_assert(get_max_power_of_2(2) == 2, "Test failed.");
627
+ static_assert(get_max_power_of_2(4) == 4, "Test failed.");
628
+ static_assert(get_max_power_of_2(6) == 2, "Test failed.");
629
+
630
+ // Number of VKQ rows calculated in parallel:
631
+ constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) {
632
+ return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m;
633
+ }
634
+
635
+ static_assert(get_VKQ_stride(128, 1, 32) == 32, "Test failed.");
636
+ static_assert(get_VKQ_stride(128, 2, 32) == 64, "Test failed.");
637
+ static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed.");
638
+ static_assert(get_VKQ_stride( 64, 1, 32) == 32, "Test failed.");
639
+ static_assert(get_VKQ_stride( 64, 2, 32) == 64, "Test failed.");
640
+ static_assert(get_VKQ_stride( 64, 4, 32) == 64, "Test failed.");
641
+ static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed.");
642
+ static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed.");
643
+ static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");
644
+
645
+ template <int D, int parallel_blocks> void launch_fattn_vec_f16(
646
+ const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
647
+ ggml_cuda_pool & pool, cudaStream_t main_stream
648
+ ) {
649
+ ggml_cuda_pool_alloc<float> dst_tmp(pool);
650
+ ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
651
+
652
+ if (parallel_blocks > 1) {
653
+ dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
654
+ dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
655
+ }
656
+
657
+ constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE;
658
+ const dim3 block_dim(WARP_SIZE, nwarps, 1);
659
+ const dim3 blocks_num(parallel_blocks*Q->ne[1], Q->ne[2], Q->ne[3]);
660
+ const int shmem = 0;
661
+
662
+ float scale;
663
+ memcpy(&scale, KQV->op_params, sizeof(float));
664
+
665
+ flash_attn_vec_ext_f16<D, parallel_blocks>
666
+ <<<blocks_num, block_dim, shmem, main_stream>>> (
667
+ (const char *) Q->data,
668
+ (const char *) K->data,
669
+ (const char *) V->data,
670
+ mask ? ((const char *) mask->data) : nullptr,
671
+ parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
672
+ scale,
673
+ Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
674
+ K->ne[0], K->ne[1], K->ne[2], K->ne[3],
675
+ mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
676
+ Q->nb[1], Q->nb[2], Q->nb[3],
677
+ K->nb[1], K->nb[2], K->nb[3],
678
+ KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
679
+ );
680
+ CUDA_CHECK(cudaGetLastError());
681
+
682
+ if (parallel_blocks == 1) {
683
+ return;
684
+ }
685
+
686
+ const dim3 block_dim_combine(D, 1, 1);
687
+ const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
688
+ const int shmem_combine = 0;
689
+
690
+ flash_attn_combine_results<D, parallel_blocks>
691
+ <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
692
+ (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
693
+ CUDA_CHECK(cudaGetLastError());
694
+ }
695
+
696
+ template <int D, int cols_per_block, int nwarps, int parallel_blocks, typename KQ_acc_t> void launch_fattn_f16_impl(
697
+ const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
698
+ ggml_cuda_pool & pool, cudaStream_t main_stream
699
+ ) {
700
+ ggml_cuda_pool_alloc<float> dst_tmp(pool);
701
+ ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
702
+
703
+ if (parallel_blocks > 1) {
704
+ dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
705
+ dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
706
+ }
707
+
708
+ constexpr int frag_m = (cols_per_block) == 8 && (D) % 32 == 0 ? 32 : 16;
709
+ const dim3 block_dim(WARP_SIZE, nwarps, 1);
710
+ const dim3 blocks_num(parallel_blocks*(Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]);
711
+ const int shmem = 0;
712
+
713
+ float scale;
714
+ memcpy(&scale, KQV->op_params, sizeof(float));
715
+
716
+ flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>
717
+ <<<blocks_num, block_dim, shmem, main_stream>>> (
718
+ (const char *) Q->data,
719
+ (const char *) K->data,
720
+ (const char *) V->data,
721
+ mask ? ((const char *) mask->data) : nullptr,
722
+ (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
723
+ scale,
724
+ Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
725
+ K->ne[0], K->ne[1], K->ne[2], K->ne[3],
726
+ mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
727
+ Q->nb[1], Q->nb[2], Q->nb[3],
728
+ K->nb[1], K->nb[2], K->nb[3],
729
+ KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
730
+ );
731
+ CUDA_CHECK(cudaGetLastError());
732
+
733
+ if ((parallel_blocks) == 1) {
734
+ return;
735
+ }
736
+
737
+ const dim3 block_dim_combine(D, 1, 1);
738
+ const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
739
+ const int shmem_combine = 0;
740
+
741
+ flash_attn_combine_results<D, parallel_blocks>
742
+ <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
743
+ (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
744
+ CUDA_CHECK(cudaGetLastError());
745
+ }
746
+
747
+ template <int D, int cols_per_block, int nwarps, typename KQ_acc_t> void launch_fattn_f16(
748
+ const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
749
+ const int nsm, ggml_cuda_pool & pool, cudaStream_t main_stream
750
+ ) {
751
+ const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
752
+
753
+ if (4*blocks_num_pb1 < 2*nsm) {
754
+ launch_fattn_f16_impl<D, cols_per_block, nwarps, 4, KQ_acc_t>(Q, K, V, KQV, mask, pool, main_stream);
755
+ return;
756
+ }
757
+ if (2*blocks_num_pb1 < 2*nsm) {
758
+ launch_fattn_f16_impl<D, cols_per_block, nwarps, 2, KQ_acc_t>(Q, K, V, KQV, mask, pool, main_stream);
759
+ return;
760
+ }
761
+ launch_fattn_f16_impl<D, cols_per_block, nwarps, 1, KQ_acc_t>(Q, K, V, KQV, mask, pool, main_stream);
762
+ }
763
+
764
+ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
765
+ const ggml_tensor * Q = dst->src[0];
766
+ const ggml_tensor * K = dst->src[1];
767
+ const ggml_tensor * V = dst->src[2];
768
+
769
+ const ggml_tensor * mask = dst->src[3];
770
+
771
+ ggml_tensor * KQV = dst;
772
+
773
+ GGML_ASSERT(Q->type == GGML_TYPE_F32);
774
+ GGML_ASSERT(K->type == GGML_TYPE_F16);
775
+ GGML_ASSERT(V->type == GGML_TYPE_F16);
776
+ GGML_ASSERT(KQV->type == GGML_TYPE_F32);
777
+
778
+ GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
779
+ GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
780
+ "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
781
+
782
+ GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
783
+
784
+ ggml_cuda_set_device(ctx.device);
785
+
786
+ const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
787
+
788
+ const int32_t precision = KQV->op_params[1];
789
+
790
+ if (precision != GGML_PREC_DEFAULT) {
791
+ if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
792
+ constexpr int cols_per_block = 16;
793
+ constexpr int nwarps = 4;
794
+ switch (Q->ne[0]) {
795
+ case 64:
796
+ launch_fattn_f16< 64, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
797
+ break;
798
+ case 80:
799
+ launch_fattn_f16< 80, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
800
+ break;
801
+ case 96:
802
+ launch_fattn_f16< 96, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
803
+ break;
804
+ case 112:
805
+ launch_fattn_f16<112, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
806
+ break;
807
+ case 128:
808
+ launch_fattn_f16<128, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
809
+ break;
810
+ case 256:
811
+ launch_fattn_f16<256, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
812
+ break;
813
+ default:
814
+ GGML_ASSERT(false);
815
+ break;
816
+ }
817
+ } else {
818
+ constexpr int cols_per_block = 32;
819
+ constexpr int nwarps = 4;
820
+ switch (Q->ne[0]) {
821
+ case 64:
822
+ launch_fattn_f16< 64, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
823
+ break;
824
+ case 80:
825
+ launch_fattn_f16< 80, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
826
+ break;
827
+ case 96:
828
+ launch_fattn_f16< 96, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
829
+ break;
830
+ case 112:
831
+ launch_fattn_f16<112, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
832
+ break;
833
+ case 128:
834
+ launch_fattn_f16<128, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
835
+ break;
836
+ // case 256:
837
+ // launch_fattn_f16<256, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
838
+ // break;
839
+ default:
840
+ GGML_ASSERT(false);
841
+ break;
842
+ }
843
+ }
844
+ return;
845
+ }
846
+
847
+ if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
848
+ constexpr int parallel_blocks = 4;
849
+ switch (Q->ne[0]) {
850
+ case 64:
851
+ launch_fattn_vec_f16< 64, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
852
+ break;
853
+ case 128:
854
+ launch_fattn_vec_f16<128, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
855
+ break;
856
+ case 256:
857
+ launch_fattn_vec_f16<256, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
858
+ break;
859
+ default:
860
+ GGML_ASSERT(false);
861
+ break;
862
+ }
863
+ return;
864
+ }
865
+
866
+ if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
867
+ constexpr int cols_per_block = 8;
868
+ constexpr int nwarps = 4;
869
+ switch (Q->ne[0]) {
870
+ case 64:
871
+ launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
872
+ break;
873
+ case 96:
874
+ launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
875
+ break;
876
+ case 128:
877
+ launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
878
+ break;
879
+ case 256:
880
+ launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
881
+ break;
882
+ default:
883
+ GGML_ASSERT(false);
884
+ break;
885
+ }
886
+ return;
887
+ }
888
+
889
+ if (Q->ne[1] <= 32) {
890
+ constexpr int cols_per_block = 16;
891
+ constexpr int nwarps = 4;
892
+ switch (Q->ne[0]) {
893
+ case 64:
894
+ launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
895
+ break;
896
+ case 80:
897
+ launch_fattn_f16< 80, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
898
+ break;
899
+ case 96:
900
+ launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
901
+ break;
902
+ case 112:
903
+ launch_fattn_f16<112, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
904
+ break;
905
+ case 128:
906
+ launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
907
+ break;
908
+ case 256:
909
+ launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
910
+ break;
911
+ default:
912
+ GGML_ASSERT(false);
913
+ break;
914
+ }
915
+ return;
916
+ }
917
+
918
+ constexpr int cols_per_block = 32;
919
+ constexpr int nwarps = 4;
920
+ switch (Q->ne[0]) {
921
+ case 64:
922
+ launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
923
+ break;
924
+ case 80:
925
+ launch_fattn_f16< 80, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
926
+ break;
927
+ case 96:
928
+ launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
929
+ break;
930
+ case 112:
931
+ launch_fattn_f16<112, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
932
+ break;
933
+ case 128:
934
+ launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
935
+ break;
936
+ case 256:
937
+ launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
938
+ break;
939
+ default:
940
+ GGML_ASSERT(false);
941
+ break;
942
+ }
943
+ return;
944
+ }
ggml-cuda/fattn.cuh ADDED
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
ggml-cuda/softmax.cu CHANGED
@@ -1,7 +1,17 @@
 #include "softmax.cuh"

-template <bool vals_smem, int ncols_template, int block_size_template>
-static __global__ void soft_max_f32(const float * x, const float * mask, const float * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
+template <typename T>
+static __device__ __forceinline__ float t2f32(T val) {
+    return (float) val;
+}
+
+template <>
+__device__ float __forceinline__ t2f32<half>(half val) {
+    return __half2float(val);
+}
+
+template <bool vals_smem, int ncols_template, int block_size_template, typename T>
+static __global__ void soft_max_f32(const float * x, const T * mask, const T * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
     const int ncols = ncols_template == 0 ? ncols_par : ncols_template;

     const int tid = threadIdx.x;
@@ -43,7 +53,7 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f
         const int64_t ix = (int64_t)rowx*ncols + col;
         const int64_t iy = (int64_t)rowy*ncols + col;

-        const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + (pos ? slope*pos[col] : 0.0f);
+        const float val = x[ix]*scale + (mask ? t2f32(mask[iy]) : 0.0f) + (pos ? slope*t2f32(pos[col]) : 0.0f);

         vals[col] = val;
         max_val = max(max_val, val);
@@ -114,7 +124,8 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f
     }
 }

-static void soft_max_f32_cuda(const float * x, const float * mask, const float * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
+template<typename T>
+static void soft_max_f32_cuda(const float * x, const T * mask, const T * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
     int nth = WARP_SIZE;
     while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
     const dim3 block_dims(nth, 1, 1);
@@ -167,15 +178,19 @@ static void soft_max_f32_cuda(const float * x, const float * mask, const float *
 void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
+
     const float * src0_d = (const float *)src0->data;
-    const float * src1_d = src1 ? (const float *)src1->data : nullptr;
+    const void  * src1_d = src1 ? (const void *)src1->data : nullptr;
+
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();

     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);

-    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+    GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional

     const int64_t ne00 = src0->ne[0];
     const int64_t nrows_x = ggml_nrows(src0);
@@ -188,14 +203,25 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));

     // positions tensor
-    float * src2_dd = nullptr;
+    void * src2_d = nullptr;

-    ggml_tensor * src2 = dst->src[2];
     const bool use_src2 = src2 != nullptr;

     if (use_src2) {
-        src2_dd = (float *)src2->data;
+        src2_d = (void *)src2->data;
     }

-    soft_max_f32_cuda(src0_d, src1_d, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
+
+    if (use_f16) {
+        const half * src1_dd = (const half *)src1_d;
+        const half * src2_dd = (const half *)src2_d;
+
+        soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+    } else {
+        const float * src1_dd = (const float *)src1_d;
+        const float * src2_dd = (const float *)src2_d;
+
+        soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+    }
 }
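
The templated kernel above adds the ALiBi bias as slope*t2f32(pos[col]), with the per-head slope encoded by the m0, m1, and n_head_log2 parameters. The slope derivation itself lies outside the shown hunks, so the sketch below is an assumption about how those parameters are conventionally used in ggml's soft_max path, not a quotation of this diff.

#include <math.h>
#include <stdint.h>

// Assumed per-head ALiBi slope derivation (not part of the hunks shown above).
// n_head_log2 is the largest power of two <= n_head; m0 and m1 are derived from max_bias,
// e.g. m0 = 2^(-max_bias/n_head_log2) and m1 = 2^(-(max_bias/2)/n_head_log2).
static float alibi_slope(uint32_t head, uint32_t n_head_log2, float m0, float m1) {
    if (head < n_head_log2) {
        return powf(m0, (float)(head + 1));
    }
    return powf(m1, (float)(2*(head - n_head_log2) + 1));
}

// Each KQ element then receives  val = x*scale + mask + slope*pos[col]  before the softmax,
// which is exactly the line changed in the @@ -43,7 +53,7 @@ hunk.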
ggml-kompute.cpp CHANGED
@@ -1427,6 +1427,7 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
         for (int i = node_start; i < node_end; ++i) {
             struct ggml_tensor * src0 = gf->nodes[i]->src[0];
             struct ggml_tensor * src1 = gf->nodes[i]->src[1];
+            struct ggml_tensor * src2 = gf->nodes[i]->src[2]; GGML_UNUSED(src2);
             struct ggml_tensor * dst = gf->nodes[i];
             GGML_ASSERT(dst->data != nullptr);

@@ -1559,6 +1560,12 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
                 {
                     float scale;
                     memcpy(&scale, dst->op_params, sizeof(float));
+
+#pragma message("TODO: add ggml_vk_soft_max() F16/F32 src1 and src2 support")
+#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
+                    GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32);
+                    GGML_ASSERT(src2 == nullptr);
+
                     ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale);
                 } break;
             case GGML_OP_DIAG_MASK_INF:
ggml-metal.m CHANGED
@@ -47,8 +47,10 @@ enum ggml_metal_kernel_type {
47
  GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
48
  GGML_METAL_KERNEL_TYPE_SILU,
49
  GGML_METAL_KERNEL_TYPE_SILU_4,
50
- GGML_METAL_KERNEL_TYPE_SOFT_MAX,
51
- GGML_METAL_KERNEL_TYPE_SOFT_MAX_4,
52
  GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
53
  GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
54
  GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
@@ -178,6 +180,14 @@ enum ggml_metal_kernel_type {
178
  GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
179
  GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
180
  GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
181
  GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
182
  GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
183
  GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
@@ -444,7 +454,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
444
  }
445
 
446
  /*
447
- GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
448
  (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
449
  (int) kernel->pipeline.threadExecutionWidth); \
450
  */
@@ -460,173 +470,183 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
460
  return NULL; \
461
  } \
462
  } else { \
463
- GGML_METAL_LOG_WARN("%s: skipping %-32s (not supported)\n", __func__, "kernel_"#name); \
464
  }
465
 
466
  // simd_sum and simd_max requires MTLGPUFamilyApple7
467
 
468
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
469
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
470
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true);
471
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
472
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
473
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
474
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
475
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
476
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
477
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
478
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
479
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
480
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
481
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
482
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
483
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
484
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
485
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
486
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction);
487
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction);
488
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
489
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
490
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
491
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true);
492
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true);
493
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true);
494
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
495
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true);
496
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true);
497
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true);
498
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true);
499
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true);
500
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true);
501
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true);
502
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
503
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
504
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
505
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true);
506
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true);
507
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
508
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true);
509
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
510
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
511
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
512
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
513
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
514
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
515
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction);
516
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction);
517
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction);
518
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction);
519
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction);
520
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction);
521
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction);
522
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction);
523
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction);
524
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction);
525
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction);
526
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction);
527
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction);
528
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction);
529
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction);
530
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
531
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
532
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction);
533
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, ctx->support_simdgroup_reduction);
534
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction);
535
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction);
536
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, ctx->support_simdgroup_reduction);
537
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
538
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction);
539
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
540
- //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
541
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
542
- //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction);
543
- //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction);
544
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction);
545
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction);
546
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction);
547
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction);
548
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction);
549
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction);
550
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction);
551
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction);
552
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction);
553
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction);
554
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
555
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
556
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction);
557
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, ctx->support_simdgroup_reduction);
558
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction);
559
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction);
560
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, ctx->support_simdgroup_reduction);
561
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
562
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction);
563
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
564
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
565
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
566
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm);
567
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm);
568
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm);
569
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm);
570
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm);
571
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm);
572
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm);
573
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm);
574
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm);
575
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
576
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
577
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm);
578
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, ctx->support_simdgroup_mm);
579
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm);
580
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm);
581
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, ctx->support_simdgroup_mm);
582
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
583
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm);
584
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
585
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
586
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
587
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm);
588
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm);
589
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm);
590
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm);
591
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm);
592
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm);
593
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm);
594
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm);
595
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm);
596
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
597
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
598
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm);
599
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, ctx->support_simdgroup_mm);
600
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm);
601
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm);
602
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm);
603
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
604
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
605
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
606
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
607
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
608
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
609
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
610
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
611
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
612
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
613
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true);
614
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
615
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
616
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
617
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
618
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
619
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
620
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
621
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
622
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
623
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
624
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
625
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
626
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
627
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
628
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
629
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
 
630
  }
631
 
632
  [metal_library release];
@@ -746,6 +766,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
746
  case GGML_OP_TIMESTEP_EMBEDDING:
747
  case GGML_OP_ARGSORT:
748
  case GGML_OP_LEAKY_RELU:
 
749
  return true;
750
  case GGML_OP_MUL_MAT:
751
  case GGML_OP_MUL_MAT_ID:
@@ -1341,20 +1362,33 @@ static enum ggml_status ggml_metal_graph_compute(
1341
  } break;
1342
  case GGML_OP_SOFT_MAX:
1343
  {
1344
  int nth = 32; // SIMD width
1345
 
1346
  id<MTLComputePipelineState> pipeline = nil;
1347
 
1348
  if (ne00%4 == 0) {
1349
  while (nth < ne00/4 && nth < 256) {
1350
  nth *= 2;
1351
  }
1352
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_4].pipeline;
1353
  } else {
1354
  while (nth < ne00 && nth < 1024) {
1355
  nth *= 2;
1356
  }
1357
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline;
1358
  }
1359
 
1360
  float scale;
@@ -2518,6 +2552,161 @@ static enum ggml_status ggml_metal_graph_compute(
2518
 
2519
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2520
  } break;
2521
  case GGML_OP_DUP:
2522
  case GGML_OP_CPY:
2523
  case GGML_OP_CONT:
@@ -2721,10 +2910,13 @@ GGML_CALL static const char * ggml_backend_metal_buffer_type_get_name(ggml_backe
2721
  UNUSED(buft);
2722
  }
2723
 
2724
- static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device) {
 
2725
  #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
2726
  if (@available(macOS 10.12, iOS 16.0, *)) {
2727
- GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
2728
  device.currentAllocatedSize / 1024.0 / 1024.0,
2729
  device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
2730
 
@@ -2734,10 +2926,15 @@ static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device) {
2734
  GGML_METAL_LOG_INFO("\n");
2735
  }
2736
  } else {
2737
- GGML_METAL_LOG_INFO(", (%8.2f)\n", device.currentAllocatedSize / 1024.0 / 1024.0);
2738
  }
 
2739
  #endif
2740
  UNUSED(device);
 
2741
  }
2742
 
2743
  GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
@@ -2771,8 +2968,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buff
2771
  return NULL;
2772
  }
2773
 
2774
- GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0);
2775
- ggml_backend_metal_log_allocated_size(device);
2776
 
2777
  return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size);
2778
  }
@@ -2859,7 +3055,7 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data,
2859
  return false;
2860
  }
2861
 
2862
- GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0);
2863
 
2864
  ++ctx->n_buffers;
2865
  } else {
@@ -2882,7 +3078,8 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data,
2882
  return false;
2883
  }
2884
 
2885
- GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, offs = %12ld", __func__, size_step_aligned / 1024.0 / 1024.0, i);
 
2886
  if (i + size_step < size) {
2887
  GGML_METAL_LOG_INFO("\n");
2888
  }
@@ -2891,8 +3088,6 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data,
2891
  }
2892
  }
2893
 
2894
- ggml_backend_metal_log_allocated_size(device);
2895
-
2896
  return ggml_backend_buffer_init(ggml_backend_metal_buffer_type(), ggml_backend_metal_buffer_i, ctx, size);
2897
  }
2898
 
 
47
  GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
48
  GGML_METAL_KERNEL_TYPE_SILU,
49
  GGML_METAL_KERNEL_TYPE_SILU_4,
50
+ GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
51
+ GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
52
+ GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
53
+ GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4,
54
  GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
55
  GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
56
  GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
 
180
  GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
181
  GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
182
  GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
183
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64,
184
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80,
185
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
186
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
187
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
188
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
189
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
190
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
191
  GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
192
  GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
193
  GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
 
454
  }
455
 
456
  /*
457
+ GGML_METAL_LOG_INFO("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
458
  (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
459
  (int) kernel->pipeline.threadExecutionWidth); \
460
  */
 
470
  return NULL; \
471
  } \
472
  } else { \
473
+ GGML_METAL_LOG_WARN("%s: skipping %-40s (not supported)\n", __func__, "kernel_"#name); \
474
  }
475
 
476
  // simd_sum and simd_max requires MTLGPUFamilyApple7
477
 
478
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
479
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
480
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true);
481
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
482
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
483
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
484
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
485
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
486
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
487
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
488
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
489
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
490
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
491
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
492
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
493
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
494
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
495
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
496
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, ctx->support_simdgroup_reduction);
497
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, ctx->support_simdgroup_reduction);
498
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, ctx->support_simdgroup_reduction);
499
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, ctx->support_simdgroup_reduction);
500
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
501
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
502
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
503
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true);
504
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true);
505
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true);
506
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
507
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true);
508
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true);
509
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true);
510
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true);
511
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true);
512
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true);
513
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true);
514
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
515
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
516
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
517
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true);
518
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true);
519
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
520
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true);
521
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
522
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
523
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
524
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
525
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
526
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
527
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction);
528
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction);
529
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction);
530
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction);
531
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction);
532
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction);
533
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction);
534
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction);
535
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction);
536
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction);
537
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction);
538
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction);
539
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction);
540
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction);
541
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction);
542
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
543
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
544
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction);
545
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, ctx->support_simdgroup_reduction);
546
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction);
547
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction);
548
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, ctx->support_simdgroup_reduction);
549
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
550
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction);
551
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
552
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
553
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
554
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction);
555
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction);
556
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction);
557
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction);
558
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction);
559
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction);
560
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction);
561
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction);
562
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction);
563
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction);
564
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction);
565
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction);
566
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
567
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
568
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction);
569
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, ctx->support_simdgroup_reduction);
570
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction);
571
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction);
572
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, ctx->support_simdgroup_reduction);
573
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
574
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction);
575
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
576
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
577
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
578
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm);
579
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm);
580
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm);
581
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm);
582
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm);
583
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm);
584
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm);
585
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm);
586
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm);
587
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
588
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
589
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm);
590
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, ctx->support_simdgroup_mm);
591
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm);
592
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm);
593
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, ctx->support_simdgroup_mm);
594
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
595
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm);
596
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
597
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
598
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
599
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm);
600
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm);
601
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm);
602
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm);
603
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm);
604
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm);
605
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm);
606
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm);
607
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm);
608
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
609
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
610
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm);
611
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, ctx->support_simdgroup_mm);
612
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm);
613
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm);
614
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm);
615
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
616
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
617
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
618
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
619
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
620
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
621
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
622
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
623
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
624
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
625
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true);
626
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
627
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
628
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
629
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true);
630
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true);
631
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true);
632
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true);
633
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true);
634
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true);
635
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, true);
636
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, true);
637
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
638
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
639
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
640
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
641
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
642
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
643
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
644
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
645
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
646
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
647
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
648
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
649
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
650
  }
651
 
652
  [metal_library release];
 
766
  case GGML_OP_TIMESTEP_EMBEDDING:
767
  case GGML_OP_ARGSORT:
768
  case GGML_OP_LEAKY_RELU:
769
+ case GGML_OP_FLASH_ATTN_EXT:
770
  return true;
771
  case GGML_OP_MUL_MAT:
772
  case GGML_OP_MUL_MAT_ID:
 
1362
  } break;
1363
  case GGML_OP_SOFT_MAX:
1364
  {
1365
+ GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
1366
+ GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32);
1367
+
1368
  int nth = 32; // SIMD width
1369
 
1370
  id<MTLComputePipelineState> pipeline = nil;
1371
 
1372
+ const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
1373
+
1374
  if (ne00%4 == 0) {
1375
  while (nth < ne00/4 && nth < 256) {
1376
  nth *= 2;
1377
  }
1378
+ if (use_f16) {
1379
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4].pipeline;
1380
+ } else {
1381
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline;
1382
+ }
1383
  } else {
1384
  while (nth < ne00 && nth < 1024) {
1385
  nth *= 2;
1386
  }
1387
+ if (use_f16) {
1388
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16].pipeline;
1389
+ } else {
1390
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32].pipeline;
1391
+ }
1392
  }
1393
 
1394
  float scale;
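The pipeline choice above combines two independent decisions: the F16 versus F32 variant (any half-precision mask or positions selects the F16 kernels) and the vectorized _4 variant (row length divisible by 4). A stand-alone C++ sketch of the same selection, using a local stand-in enum rather than the real ggml_metal_kernel_type values:

    // stand-in for the four soft_max entries of ggml_metal_kernel_type registered above
    enum soft_max_kernel {
        SOFT_MAX_F16, SOFT_MAX_F16_4, SOFT_MAX_F32, SOFT_MAX_F32_4,
    };

    // F16 variant if any of mask/pos is half precision, _4 variant if the row size
    // is a multiple of 4; e.g. pick_soft_max_kernel(true, false, 64) -> SOFT_MAX_F16_4
    static soft_max_kernel pick_soft_max_kernel(bool mask_is_f16, bool pos_is_f16, long ne00) {
        const bool use_f16 = mask_is_f16 || pos_is_f16;
        if (ne00 % 4 == 0) {
            return use_f16 ? SOFT_MAX_F16_4 : SOFT_MAX_F32_4;
        }
        return use_f16 ? SOFT_MAX_F16 : SOFT_MAX_F32;
    }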
 
2552
 
2553
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2554
  } break;
2555
+ case GGML_OP_FLASH_ATTN_EXT:
2556
+ {
2557
+ GGML_ASSERT(ne00 % 4 == 0);
2558
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
2559
+
2560
+ struct ggml_tensor * src3 = gf->nodes[i]->src[3];
2561
+
2562
+ GGML_ASSERT(ggml_are_same_shape(src1, src2));
2563
+ GGML_ASSERT(src3);
2564
+
2565
+ size_t offs_src3 = 0;
2566
+
2567
+ id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
2568
+
2569
+ GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16);
2570
+ GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
2571
+ "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
2572
+
2573
+ const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
2574
+ const int64_t ne31 = src3 ? src3->ne[1] : 0;
2575
+ const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
2576
+ const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33);
2577
+
2578
+ const uint64_t nb30 = src3 ? src3->nb[0] : 0; GGML_UNUSED(nb30);
2579
+ const uint64_t nb31 = src3 ? src3->nb[1] : 0;
2580
+ const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32);
2581
+ const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33);
2582
+
2583
+ const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
2584
+
2585
+ float scale;
2586
+ memcpy(&scale, dst->op_params, sizeof(float));
2587
+
2588
+ id<MTLComputePipelineState> pipeline = nil;
2589
+
2590
+ bool use_vec_kernel = false;
2591
+
2592
+ if (ne01 >= 4 || (ne00%128 != 0)) {
2593
+ switch (ne00) {
2594
+ case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
2595
+ case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
2596
+ case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
2597
+ case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
2598
+ case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
2599
+ case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
2600
+ default:
2601
+ {
2602
+ GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
2603
+ GGML_METAL_LOG_ERROR("add template specialization for this size\n");
2604
+ GGML_ASSERT(false && "add template specialization for this size");
2605
+ }
2606
+ }
2607
+ } else {
2608
+ use_vec_kernel = true;
2609
+
2610
+ switch (ne00) {
2611
+ case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
2612
+ case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
2613
+ default:
2614
+ {
2615
+ GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
2616
+ GGML_METAL_LOG_ERROR("add template specialization for this size\n");
2617
+ GGML_ASSERT(false && "add template specialization for this size");
2618
+ }
2619
+ }
2620
+ }
2621
+
2622
+ [encoder setComputePipelineState:pipeline];
2623
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2624
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2625
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
2626
+ [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
2627
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
2628
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:5];
2629
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:6];
2630
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:7];
2631
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:8];
2632
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:9];
2633
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10];
2634
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11];
2635
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12];
2636
+ [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:13];
2637
+ [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:14];
2638
+ [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:15];
2639
+ [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:16];
2640
+ [encoder setBytes:&nb10 length:sizeof(uint64_t) atIndex:17];
2641
+ [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:18];
2642
+ [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:19];
2643
+ [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:20];
2644
+ [encoder setBytes:&ne31 length:sizeof( int64_t) atIndex:21];
2645
+ [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:22];
2646
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:23];
2647
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:24];
2648
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:25];
2649
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
2650
+ [encoder setBytes:&scale length:sizeof( float) atIndex:27];
2651
+
2652
+ if (!use_vec_kernel) {
2653
+ // half8x8 kernel
2654
+ const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
2655
+ const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
2656
+
2657
+ GGML_ASSERT(nqptg <= 32);
2658
+ GGML_ASSERT(nqptg % 8 == 0);
2659
+ GGML_ASSERT(ncpsg % 32 == 0);
2660
+
2661
+ int64_t nsgmax = 2;
2662
+
2663
+ while (true) {
2664
+ const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2);
2665
+ if (smem > ctx->device.maxThreadgroupMemoryLength) {
2666
+ break;
2667
+ }
2668
+ nsgmax *= 2;
2669
+ }
2670
+ nsgmax /= 2;
2671
+
2672
+ // simdgroups per threadgroup (a.k.a. warps)
2673
+ const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
2674
+
2675
+ const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2);
2676
+
2677
+ //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
2678
+ GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
2679
+
2680
+ [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
2681
+
2682
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
2683
+ } else {
2684
+ // half1x4 kernel
2685
+ const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
2686
+ const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
2687
+
2688
+ GGML_ASSERT(nqptg <= 32);
2689
+ GGML_ASSERT(nqptg % 1 == 0);
2690
+ GGML_ASSERT(ncpsg % 32 == 0);
2691
+
2692
+ // simdgroups per threadgroup (a.k.a. warps)
2693
+ const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
2694
+
2695
+ int64_t nsg = 1;
2696
+ while (nsg <= nsgt) {
2697
+ nsg *= 2;
2698
+ }
2699
+ nsg /= 2;
2700
+
2701
+ const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);
2702
+
2703
+ //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
2704
+ GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
2705
+ [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
2706
+
2707
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
2708
+ }
2709
+ } break;
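Both flash-attention dispatch branches above size the threadgroup memory in half-precision elements, which is where the sizeof(float)/2 factor comes from, and the half8x8 path first probes for the largest power-of-two simdgroup count that still fits the device limit. The same arithmetic lifted into plain C++, with the device limit passed as an ordinary integer (helper names are illustrative):

    #include <cstddef>
    #include <cstdint>

    // threadgroup memory used by the half8x8 kernel, in bytes:
    //   nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) half-precision values, 2 bytes each
    static size_t fattn_smem_bytes(int64_t ne00, int64_t nqptg, int64_t ncpsg, int64_t nsg) {
        return (size_t) (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))) * (sizeof(float)/2);
    }

    // largest power-of-two simdgroup count whose threadgroup memory still fits,
    // mirroring the probe loop in the dispatch code above
    static int64_t fattn_max_simdgroups(int64_t ne00, int64_t nqptg, int64_t ncpsg,
                                        size_t max_tg_memory) {
        int64_t nsgmax = 2;
        while (fattn_smem_bytes(ne00, nqptg, ncpsg, nsgmax) <= max_tg_memory) {
            nsgmax *= 2;
        }
        return nsgmax/2;  // step back to the last count that fit (1 if even 2 did not)
    }

The final simdgroup count is then clamped by ne11/ncpsg and by maxTotalThreadsPerThreadgroup/32, falling back to 4 simdgroups when the query batch exceeds nqptg; the half1x4 (vector) path reserves an additional nsg*ne00 halves on top of this, presumably for the per-simdgroup output accumulators.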
2710
  case GGML_OP_DUP:
2711
  case GGML_OP_CPY:
2712
  case GGML_OP_CONT:
 
2910
  UNUSED(buft);
2911
  }
2912
 
2913
+ static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device, size_t size_aligned) {
2914
+ #ifndef GGML_METAL_NDEBUG
2915
  #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
2916
  if (@available(macOS 10.12, iOS 16.0, *)) {
2917
+ GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f / %8.2f)",
2918
+ __func__,
2919
+ size_aligned / 1024.0 / 1024.0,
2920
  device.currentAllocatedSize / 1024.0 / 1024.0,
2921
  device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
2922
 
 
2926
  GGML_METAL_LOG_INFO("\n");
2927
  }
2928
  } else {
2929
+ GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f)\n",
2930
+ __func__,
2931
+ size_aligned / 1024.0 / 1024.0,
2932
+ device.currentAllocatedSize / 1024.0 / 1024.0);
2933
  }
2934
+ #endif
2935
  #endif
2936
  UNUSED(device);
2937
+ UNUSED(size_aligned);
2938
  }
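The allocation-size message is now emitted from a single helper that also receives the aligned buffer size, and the whole body compiles away when GGML_METAL_NDEBUG is defined. A reduced C++ sketch of the same guard, with illustrative names and without the macOS working-set query:

    #include <cstddef>
    #include <cstdio>

    // debug-only allocation log, mirroring the GGML_METAL_NDEBUG guard in the helper above
    // (illustrative names; the real helper also reports the device's allocated/working-set sizes)
    static void log_allocated_size(const char * func, size_t size_aligned) {
    #ifndef GGML_METAL_NDEBUG
        std::fprintf(stderr, "%s: allocated buffer, size = %8.2f MiB\n",
                     func, size_aligned / 1024.0 / 1024.0);
    #else
        (void) func;
        (void) size_aligned;
    #endif
    }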
2939
 
2940
  GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
 
2968
  return NULL;
2969
  }
2970
 
2971
+ //ggml_backend_metal_log_allocated_size(device, size_aligned);
 
2972
 
2973
  return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size);
2974
  }
 
3055
  return false;
3056
  }
3057
 
3058
+ ggml_backend_metal_log_allocated_size(device, size_aligned);
3059
 
3060
  ++ctx->n_buffers;
3061
  } else {
 
3078
  return false;
3079
  }
3080
 
3081
+ ggml_backend_metal_log_allocated_size(device, size_step_aligned);
3082
+
3083
  if (i + size_step < size) {
3084
  GGML_METAL_LOG_INFO("\n");
3085
  }
 
3088
  }
3089
  }
3090
 
 
 
3091
  return ggml_backend_buffer_init(ggml_backend_metal_buffer_type(), ggml_backend_metal_buffer_i, ctx, size);
3092
  }
3093
 
ggml-metal.metal CHANGED
@@ -359,11 +359,12 @@ kernel void kernel_sum_rows(
359
  dst_row[0] = row_sum;
360
  }
361
 
 
362
  kernel void kernel_soft_max(
363
- device const float * src0,
364
- device const float * src1,
365
- device const float * src2,
366
- device float * dst,
367
  constant int64_t & ne00,
368
  constant int64_t & ne01,
369
  constant int64_t & ne02,
@@ -382,10 +383,10 @@ kernel void kernel_soft_max(
382
  const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
383
  const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
384
 
385
- device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
386
- device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
387
- device const float * ppos = src2 != src0 ? src2 : nullptr;
388
- device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
389
 
390
  float slope = 0.0f;
391
 
@@ -463,11 +464,12 @@ kernel void kernel_soft_max(
463
  }
464
  }
465
 
 
466
  kernel void kernel_soft_max_4(
467
- device const float * src0,
468
- device const float * src1,
469
- device const float * src2,
470
- device float * dst,
471
  constant int64_t & ne00,
472
  constant int64_t & ne01,
473
  constant int64_t & ne02,
@@ -486,10 +488,10 @@ kernel void kernel_soft_max_4(
486
  const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
487
  const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
488
 
489
- device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
490
- device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
491
- device const float4 * ppos = src2 != src0 ? (device const float4 *)(src2) : nullptr;
492
- device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
493
 
494
  float slope = 0.0f;
495
 
@@ -506,7 +508,7 @@ kernel void kernel_soft_max_4(
506
  float4 lmax4 = -INFINITY;
507
 
508
  for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
509
- lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f));
510
  }
511
 
512
  const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
@@ -532,7 +534,7 @@ kernel void kernel_soft_max_4(
532
  // parallel sum
533
  float4 lsum4 = 0.0f;
534
  for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
535
- const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)) - max_val);
536
  lsum4 += exp_psrc4;
537
  pdst4[i00] = exp_psrc4;
538
  }
@@ -569,6 +571,14 @@ kernel void kernel_soft_max_4(
569
  }
570
  }
571
 
572
  kernel void kernel_diag_mask_inf(
573
  device const float * src0,
574
  device float * dst,
@@ -2091,6 +2101,632 @@ kernel void kernel_leaky_relu_f32(
2091
  dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
2092
  }
2093
 
2094
  kernel void kernel_cpy_f16_f16(
2095
  device const half * src0,
2096
  device half * dst,
 
359
  dst_row[0] = row_sum;
360
  }
361
 
362
+ template<typename T>
363
  kernel void kernel_soft_max(
364
+ device const char * src0,
365
+ device const char * src1,
366
+ device const char * src2,
367
+ device char * dst,
368
  constant int64_t & ne00,
369
  constant int64_t & ne01,
370
  constant int64_t & ne02,
 
383
  const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
384
  const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
385
 
386
+ device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
387
+ device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
388
+ device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
389
+ device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
390
 
391
  float slope = 0.0f;
392
 
 
464
  }
465
  }
466
 
467
+ template<typename T>
468
  kernel void kernel_soft_max_4(
469
+ device const char * src0,
470
+ device const char * src1,
471
+ device const char * src2,
472
+ device char * dst,
473
  constant int64_t & ne00,
474
  constant int64_t & ne01,
475
  constant int64_t & ne02,
 
488
  const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
489
  const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
490
 
491
+ device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
492
+ device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr;
493
+ device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
494
+ device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
495
 
496
  float slope = 0.0f;
497
 
 
508
  float4 lmax4 = -INFINITY;
509
 
510
  for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
511
+ lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)));
512
  }
513
 
514
  const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
 
534
  // parallel sum
535
  float4 lsum4 = 0.0f;
536
  for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
537
+ const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f))) - max_val);
538
  lsum4 += exp_psrc4;
539
  pdst4[i00] = exp_psrc4;
540
  }
 
571
  }
572
  }
573
 
574
+ typedef decltype(kernel_soft_max<float>) kernel_soft_max_t;
575
+ typedef decltype(kernel_soft_max_4<float4>) kernel_soft_max_4_t;
576
+
577
+ template [[host_name("kernel_soft_max_f16")]] kernel kernel_soft_max_t kernel_soft_max<half>;
578
+ template [[host_name("kernel_soft_max_f32")]] kernel kernel_soft_max_t kernel_soft_max<float>;
579
+ template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<half4>;
580
+ template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<float4>;
581
+
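With kernel_soft_max templated on the mask/pos element type, the host side has to pick one of the [[host_name]] instantiations above. A minimal selection sketch in C (the helper name and structure are illustrative only, not part of this patch):

    #include <stdbool.h>
    #include "ggml.h"

    // hypothetical helper: map the mask/pos tensor type to one of the
    // instantiated Metal soft_max pipelines
    static const char * soft_max_kernel_name(enum ggml_type tmask, bool use_vec4) {
        if (tmask == GGML_TYPE_F16) {
            return use_vec4 ? "kernel_soft_max_f16_4" : "kernel_soft_max_f16";
        }
        return use_vec4 ? "kernel_soft_max_f32_4" : "kernel_soft_max_f32";
    }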
582
  kernel void kernel_diag_mask_inf(
583
  device const float * src0,
584
  device float * dst,
 
2101
  dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
2102
  }
2103
 
2104
+ typedef void (flash_attn_ext_f16_t)(
2105
+ device const char * q,
2106
+ device const char * k,
2107
+ device const char * v,
2108
+ device const char * mask,
2109
+ device float * dst,
2110
+ constant int64_t & ne00,
2111
+ constant int64_t & ne01,
2112
+ constant int64_t & ne02,
2113
+ constant int64_t & ne03,
2114
+ constant uint64_t & nb00,
2115
+ constant uint64_t & nb01,
2116
+ constant uint64_t & nb02,
2117
+ constant uint64_t & nb03,
2118
+ constant int64_t & ne10,
2119
+ constant int64_t & ne11,
2120
+ constant int64_t & ne12,
2121
+ constant int64_t & ne13,
2122
+ constant uint64_t & nb10,
2123
+ constant uint64_t & nb11,
2124
+ constant uint64_t & nb12,
2125
+ constant uint64_t & nb13,
2126
+ constant int64_t & ne31,
2127
+ constant uint64_t & nb31,
2128
+ constant int64_t & ne0,
2129
+ constant int64_t & ne1,
2130
+ constant int64_t & ne2,
2131
+ constant int64_t & ne3,
2132
+ constant float & scale,
2133
+ threadgroup half * shared,
2134
+ uint3 tgpig[[threadgroup_position_in_grid]],
2135
+ uint3 tpitg[[thread_position_in_threadgroup]],
2136
+ uint3 ntg[[threads_per_threadgroup]],
2137
+ ushort tiisg[[thread_index_in_simdgroup]],
2138
+ ushort sgitg[[simdgroup_index_in_threadgroup]]);
2139
+
2140
+ // ref: https://arxiv.org/pdf/2307.08691.pdf
2141
+ template<int64_t D, int64_t Q = 8, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
2142
+ kernel void kernel_flash_attn_ext_f16(
2143
+ device const char * q,
2144
+ device const char * k,
2145
+ device const char * v,
2146
+ device const char * mask,
2147
+ device float * dst,
2148
+ constant int64_t & ne00,
2149
+ constant int64_t & ne01,
2150
+ constant int64_t & ne02,
2151
+ constant int64_t & ne03,
2152
+ constant uint64_t & nb00,
2153
+ constant uint64_t & nb01,
2154
+ constant uint64_t & nb02,
2155
+ constant uint64_t & nb03,
2156
+ constant int64_t & ne10,
2157
+ constant int64_t & ne11,
2158
+ constant int64_t & ne12,
2159
+ constant int64_t & ne13,
2160
+ constant uint64_t & nb10,
2161
+ constant uint64_t & nb11,
2162
+ constant uint64_t & nb12,
2163
+ constant uint64_t & nb13,
2164
+ constant int64_t & ne31,
2165
+ constant uint64_t & nb31,
2166
+ constant int64_t & ne0,
2167
+ constant int64_t & ne1,
2168
+ constant int64_t & ne2,
2169
+ constant int64_t & ne3,
2170
+ constant float & scale,
2171
+ threadgroup half * shared [[threadgroup(0)]],
2172
+ uint3 tgpig[[threadgroup_position_in_grid]],
2173
+ uint3 tpitg[[thread_position_in_threadgroup]],
2174
+ uint3 ntg[[threads_per_threadgroup]],
2175
+ ushort tiisg[[thread_index_in_simdgroup]],
2176
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
2177
+ const short nsg = ntg.y; // number of simdgroups
2178
+
2179
+ const short iq3 = tgpig[2];
2180
+ const short iq2 = tgpig[1];
2181
+ const short iq1 = tgpig[0]*Q;
2182
+
2183
+ const short D4 = D/4;
2184
+ const short D8 = D/8;
2185
+ const short Q8 = Q/8;
2186
+ const short NW = N_SIMDWIDTH;
2187
+ const short SH = (C + Q); // shared memory per simdgroup in (half)
2188
+
2189
+ const short T = D + 2*nsg*SH; // shared memory size per query in (half)
2190
+ const short TF = T/2; // shared memory size per query in (float)
2191
+ const short T4 = T/4; // shared memory size per query in (half4)
2192
+
2193
+ threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
2194
+ threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
2195
+ threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
2196
+
2197
+ // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
2198
+ simdgroup_half8x8 lo[D8];
2199
+
2200
+ // load heads from Q to shared memory
2201
+ for (short j = sgitg; j < Q; j += nsg) {
2202
+ device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
2203
+
2204
+ for (short i = tiisg; i < D4; i += NW) {
2205
+ if (iq1 + j < ne01) {
2206
+ sq4[j*T4 + i] = (half4) q4[i];
2207
+ } else {
2208
+ sq4[j*T4 + i] = 0.0h;
2209
+ }
2210
+ }
2211
+ }
2212
+
2213
+ // zero out lo
2214
+ for (short i = 0; i < D8; ++i) {
2215
+ lo[i] = make_filled_simdgroup_matrix<half, 8>(0.0h);
2216
+ }
2217
+
2218
+ // zero out shared memory SH
2219
+ for (short j = 0; j < Q; ++j) {
2220
+ for (short i = tiisg; i < SH; i += NW) {
2221
+ ss[j*TF + i] = 0.0f;
2222
+ }
2223
+ }
2224
+
2225
+ threadgroup_barrier(mem_flags::mem_threadgroup);
2226
+
2227
+ {
2228
+ float S[Q] = { [0 ... Q-1] = 0.0h };
2229
+ float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 };
2230
+
2231
+ // assume K and V are same shape
2232
+ const short ne22 = ne12;
2233
+ const short ne23 = ne13;
2234
+
2235
+ const uint nb21 = nb11;
2236
+ const uint nb22 = nb12;
2237
+ const uint nb23 = nb13;
2238
+
2239
+ // broadcast
2240
+ const short rk2 = ne02/ne12;
2241
+ const short rk3 = ne03/ne13;
2242
+
2243
+ const short rv2 = ne02/ne22;
2244
+ const short rv3 = ne03/ne23;
2245
+
2246
+ // k indices
2247
+ const short ik2 = iq2/rk2;
2248
+ const short ik3 = iq3/rk3;
2249
+
2250
+ // v indices
2251
+ const short iv2 = iq2/rv2;
2252
+ const short iv3 = iq3/rv3;
2253
+
2254
+ // load the queries from shared memory into local memory
2255
+ simdgroup_half8x8 mq[D8];
2256
+
2257
+ for (short i = 0; i < D8; ++i) {
2258
+ simdgroup_load(mq[i], sq + i*8, T);
2259
+ }
2260
+
2261
+ // pointer to the mask
2262
+ device const half * mp = (device const half *) (mask + iq1*nb31);
2263
+
2264
+ // prepare diagonal scale matrix
2265
+ simdgroup_float8x8 mscale(scale);
2266
+
2267
+ // loop over the KV cache
2268
+ // each simdgroup handles blocks of Q rows and C columns
2269
+ for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
2270
+ const int ic = ic0 + C*sgitg;
2271
+ if (ic >= ne11) {
2272
+ break;
2273
+ }
2274
+
2275
+ // Q*K^T
2276
+ {
2277
+ for (short cc = 0; cc < C/8; ++cc) {
2278
+ simdgroup_float8x8 mqk = make_filled_simdgroup_matrix<float, 8>(0.h);
2279
+
2280
+ device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
2281
+
2282
+ for (short i = 0; i < D8; ++i) {
2283
+ simdgroup_half8x8 mk;
2284
+ simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose
2285
+
2286
+ simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
2287
+ }
2288
+
2289
+ // mqk = mqk*scale + mask
2290
+ simdgroup_half8x8 mm;
2291
+ simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);
2292
+ simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
2293
+
2294
+ simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
2295
+ }
2296
+ }
2297
+
2298
+ // used to detect blocks full of -INF
2299
+ float smax = -INFINITY;
2300
+
2301
+ // online softmax
2302
+ {
2303
+ float ms[Q];
2304
+
2305
+ for (short j = 0; j < Q; ++j) {
2306
+ const short p = tiisg;
2307
+
2308
+ const float m = M[j];
2309
+ const float s = ss[j*TF + p];
2310
+
2311
+ smax = simd_max(max(smax, s));
2312
+ M[j] = simd_max(max(M[j], s));
2313
+
2314
+ ms[j] = exp(m - M[j]);
2315
+ const float vs = exp(s - M[j]);
2316
+
2317
+ S[j] = S[j]*ms[j] + simd_sum(vs);
2318
+
2319
+ // the P matrix from the paper (Q rows, C columns)
2320
+ ss[j*TF + p] = vs;
2321
+ }
2322
+
2323
+ // create a QxQ diagonal matrix for rescaling the output
2324
+ if (tiisg < Q) {
2325
+ ss[tiisg*TF + C + tiisg] = ms[tiisg];
2326
+ }
2327
+ }
2328
+
2329
+ // skip -INF blocks
2330
+ if (smax == -INFINITY) {
2331
+ continue;
2332
+ }
2333
+
2334
+ // O = diag(ms)*O
2335
+ {
2336
+ simdgroup_float8x8 mm;
2337
+ simdgroup_load(mm, ss + C, TF, 0, false);
2338
+
2339
+ for (short i = 0; i < D8; ++i) {
2340
+ simdgroup_multiply(lo[i], mm, lo[i]);
2341
+ }
2342
+ }
2343
+
2344
+ // O = O + (Q*K^T)*V
2345
+ {
2346
+ for (short cc = 0; cc < C/8; ++cc) {
2347
+ device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
2348
+
2349
+ for (short i = 0; i < D8; ++i) {
2350
+ simdgroup_half8x8 mk;
2351
+ simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false);
2352
+
2353
+ simdgroup_float8x8 mv;
2354
+ simdgroup_load(mv, ss + 8*cc, TF, 0, false);
2355
+
2356
+ simdgroup_multiply_accumulate(lo[i], mv, mk, lo[i]);
2357
+ }
2358
+ }
2359
+ }
2360
+ }
2361
+
2362
+ // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
2363
+ for (short j = 0; j < Q; ++j) {
2364
+ if (tiisg == 0) {
2365
+ ss[j*TF + 0] = S[j];
2366
+ ss[j*TF + 1] = M[j];
2367
+ }
2368
+ }
2369
+ }
2370
+
2371
+ // reduce the warps sequentially
2372
+ for (short sg = 1; sg < nsg; ++sg) {
2373
+ float S = { 0.0h };
2374
+ float M = { -FLT_MAX/2 };
2375
+
2376
+ threadgroup_barrier(mem_flags::mem_threadgroup);
2377
+
2378
+ // each simdgroup stores its output to shared memory, reusing sq
2379
+ if (sgitg == sg) {
2380
+ for (short i = 0; i < D8; ++i) {
2381
+ simdgroup_store(lo[i], sq + i*8, T, 0, false);
2382
+ }
2383
+ }
2384
+
2385
+ threadgroup_barrier(mem_flags::mem_threadgroup);
2386
+
2387
+ // the first simdgroup accumulates the results from the other simdgroups
2388
+ if (sgitg == 0) {
2389
+ for (short j = 0; j < Q; ++j) {
2390
+ const float S0 = ss[j*TF + 0];
2391
+ const float S1 = ss[j*TF + sg*SH + 0];
2392
+
2393
+ const float M0 = ss[j*TF + 1];
2394
+ const float M1 = ss[j*TF + sg*SH + 1];
2395
+
2396
+ M = max(M0, M1);
2397
+
2398
+ const float ms0 = exp(M0 - M);
2399
+ const float ms1 = exp(M1 - M);
2400
+
2401
+ S = S0*ms0 + S1*ms1;
2402
+
2403
+ if (tiisg == 0) {
2404
+ ss[j*TF + 0] = S;
2405
+ ss[j*TF + 1] = M;
2406
+
2407
+ ss[j*TF + C + j ] = ms0;
2408
+ ss[j*TF + C + j + sg*SH] = ms1;
2409
+ }
2410
+ }
2411
+
2412
+ // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
2413
+ {
2414
+ simdgroup_half8x8 t;
2415
+ simdgroup_float8x8 ms0;
2416
+ simdgroup_float8x8 ms1;
2417
+
2418
+ simdgroup_load(ms0, ss + C, TF, 0, false);
2419
+ simdgroup_load(ms1, ss + C + sg*SH, TF, 0, false);
2420
+
2421
+ for (short i = 0; i < D8; ++i) {
2422
+ simdgroup_load (t, sq + i*8, T, 0, false);
2423
+ simdgroup_multiply(t, ms1, t);
2424
+
2425
+ simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
2426
+ }
2427
+ }
2428
+ }
2429
+ }
2430
+
2431
+ // store result to shared memory (reuse sq)
2432
+ if (sgitg == 0) {
2433
+ for (short i = 0; i < D8; ++i) {
2434
+ simdgroup_store(lo[i], sq + i*8, T, 0, false);
2435
+ }
2436
+ }
2437
+
2438
+ device float4 * dst4 = (device float4 *) dst;
2439
+
2440
+ // final rescale with 1/S and store to global memory
2441
+ if (sgitg == 0) {
2442
+ for (short j = 0; j < Q && iq1 + j < ne01; ++j) {
2443
+ const float S = ss[j*TF + 0];
2444
+
2445
+ for (short i = tiisg; i < D4; i += NW) {
2446
+ dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S;
2447
+ }
2448
+ }
2449
+ }
2450
+ }
2451
+
2452
+ template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>;
2453
+ template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>;
2454
+ template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>;
2455
+ template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>;
2456
+ template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>;
2457
+ template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
2458
+
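The indexing above fixes the kernel's threadgroup-memory budget: each of the Q query rows owns T = D + 2*nsg*(C + Q) half-precision elements (D for the query itself plus two SH-sized scratch strips per simdgroup). A rough size estimate derived from that layout; the exact expression used by the host code in ggml-metal.m is not shown in this hunk:

    #include <stddef.h>
    #include <stdint.h>

    // halves per query row: T = D + 2*nsg*(C + Q); Q rows are resident at once
    static size_t flash_attn_smem_bytes(int D, int Q, int C, int nsg) {
        const int T = D + 2*nsg*(C + Q);
        return (size_t) Q * (size_t) T * sizeof(uint16_t); // half == 2 bytes
    }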
2459
+ template<int64_t D, int64_t Q = 1, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
2460
+ kernel void kernel_flash_attn_ext_vec_f16(
2461
+ device const char * q,
2462
+ device const char * k,
2463
+ device const char * v,
2464
+ device const char * mask,
2465
+ device float * dst,
2466
+ constant int64_t & ne00,
2467
+ constant int64_t & ne01,
2468
+ constant int64_t & ne02,
2469
+ constant int64_t & ne03,
2470
+ constant uint64_t & nb00,
2471
+ constant uint64_t & nb01,
2472
+ constant uint64_t & nb02,
2473
+ constant uint64_t & nb03,
2474
+ constant int64_t & ne10,
2475
+ constant int64_t & ne11,
2476
+ constant int64_t & ne12,
2477
+ constant int64_t & ne13,
2478
+ constant uint64_t & nb10,
2479
+ constant uint64_t & nb11,
2480
+ constant uint64_t & nb12,
2481
+ constant uint64_t & nb13,
2482
+ constant int64_t & ne31,
2483
+ constant uint64_t & nb31,
2484
+ constant int64_t & ne0,
2485
+ constant int64_t & ne1,
2486
+ constant int64_t & ne2,
2487
+ constant int64_t & ne3,
2488
+ constant float & scale,
2489
+ threadgroup half * shared [[threadgroup(0)]],
2490
+ uint3 tgpig[[threadgroup_position_in_grid]],
2491
+ uint3 tpitg[[thread_position_in_threadgroup]],
2492
+ uint3 ntg[[threads_per_threadgroup]],
2493
+ ushort tiisg[[thread_index_in_simdgroup]],
2494
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
2495
+ const short nsg = ntg.y; // number of simdgroups
2496
+
2497
+ const short iq3 = tgpig[2];
2498
+ const short iq2 = tgpig[1];
2499
+ const short iq1 = tgpig[0];
2500
+
2501
+ const short D4 = D/4;
2502
+ const short NW = N_SIMDWIDTH;
2503
+ const short SH = (C + Q); // shared memory per simdgroup in (half)
2504
+
2505
+ const short T = D + 2*nsg*SH; // shared memory size per query in (half)
2506
+
2507
+ //threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
2508
+ threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
2509
+ threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
2510
 + threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in float4
2511
+ threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + 1*T); // scratch buffer for the results
2512
+
2513
 + // store the result for all queries in local memory (the O matrix from the paper)
2514
+ half4 lo[D4/NW];
2515
+
2516
+ // load heads from Q to shared memory
2517
+ device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
2518
+
2519
+ for (short i = tiisg; i < D4; i += NW) {
2520
+ if (iq1 < ne01) {
2521
+ sq4[i] = (half4) q4[i];
2522
+ } else {
2523
+ sq4[i] = 0.0h;
2524
+ }
2525
+ }
2526
+
2527
+ // zero out lo
2528
+ for (short i = tiisg; i < D4; i += NW) {
2529
+ lo[i/NW] = 0.0h;
2530
+ }
2531
+
2532
+ // zero out shared memory SH
2533
+ for (short i = tiisg; i < SH/4; i += NW) {
2534
+ ss4[i] = 0.0h;
2535
+ }
2536
+
2537
+ threadgroup_barrier(mem_flags::mem_threadgroup);
2538
+
2539
+ {
2540
+ float S = { 0.0h };
2541
+ float M = { -FLT_MAX/2 };
2542
+
2543
+ // assume K and V are same shape
2544
+ const short ne22 = ne12;
2545
+ const short ne23 = ne13;
2546
+
2547
+ const uint nb21 = nb11;
2548
+ const uint nb22 = nb12;
2549
+ const uint nb23 = nb13;
2550
+
2551
+ // broadcast
2552
+ const short rk2 = ne02/ne12;
2553
+ const short rk3 = ne03/ne13;
2554
+
2555
+ const short rv2 = ne02/ne22;
2556
+ const short rv3 = ne03/ne23;
2557
+
2558
+ // k indices
2559
+ const short ik2 = iq2 / rk2;
2560
+ const short ik3 = iq3 / rk3;
2561
+
2562
+ // v indices
2563
+ const short iv2 = iq2 / rv2;
2564
+ const short iv3 = iq3 / rv3;
2565
+
2566
+ // load the queries from shared memory into local memory
2567
+ half4 mq[D4];
2568
+
2569
+ for (short ii = 0; ii < D4; ii += NW) {
2570
+ short i = ii + tiisg;
2571
+ mq[i] = sq4[i];
2572
+ }
2573
+
2574
+ // pointer to the mask
2575
+ device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31);
2576
+
2577
+ // loop over the KV cache
2578
+ // each simdgroup handles blocks of Q rows and C columns
2579
+ for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
2580
+ const int ic = ic0 + C*sgitg;
2581
+ if (ic >= ne11) {
2582
+ break;
2583
+ }
2584
+
2585
+ // Q*K^T
2586
+ {
2587
+ #pragma unroll
2588
+ for (short cc = 0; cc < C/4; ++cc) {
2589
+ float4 mqk = { 0.0h };
2590
+
2591
+ device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13));
2592
+
2593
+ #pragma unroll
2594
+ for (short ii = 0; ii < D4; ii += NW) {
2595
+ const short i = ii + tiisg;
2596
+
2597
+ half4x4 mk;
2598
+ mk[0] = pk4[i + 0*(nb11/8)];
2599
+ mk[1] = pk4[i + 1*(nb11/8)];
2600
+ mk[2] = pk4[i + 2*(nb11/8)];
2601
+ mk[3] = pk4[i + 3*(nb11/8)];
2602
+
2603
+ mqk += (float4) (mq[i] * mk);
2604
+ }
2605
+
2606
+ // reduce the results from the threads in the simdgroup
2607
+ mqk += simd_shuffle_down(mqk, 16);
2608
+ mqk += simd_shuffle_down(mqk, 8);
2609
+ mqk += simd_shuffle_down(mqk, 4);
2610
+ mqk += simd_shuffle_down(mqk, 2);
2611
+ mqk += simd_shuffle_down(mqk, 1);
2612
+
2613
+ // mqk = mqk*scale + mask
2614
+ if (tiisg == 0) {
2615
+ float4 mm = (float4) mp4[ic/4 + cc];
2616
+ mqk = mqk*scale + mm;
2617
+
2618
+ ss4[cc] = mqk;
2619
+ }
2620
+ }
2621
+ }
2622
+
2623
+ // online softmax
2624
+ {
2625
+ const short p = tiisg;
2626
+
2627
+ const float m = M;
2628
+ const float s = ss[p];
2629
+
2630
+ M = simd_max(max(M, s));
2631
+
2632
+ const float ms = exp(m - M);
2633
+ const float vs = exp(s - M);
2634
+
2635
+ S = S*ms + simd_sum(vs);
2636
+
2637
+ // the P matrix from the paper (Q rows, C columns)
2638
+ ss[p] = vs;
2639
+
2640
+ // O = diag(ms)*O
2641
+ #pragma unroll
2642
+ for (short ii = 0; ii < D4; ii += NW) {
2643
+ const short i = ii + tiisg;
2644
+ lo[i/NW] *= ms;
2645
+ }
2646
+ }
2647
+
2648
+ // O = O + (Q*K^T)*V
2649
+ {
2650
+ #pragma unroll
2651
+ for (short cc = 0; cc < C/4; ++cc) {
2652
+ device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + 4*cc)*nb21 + iv2*nb22 + iv3*nb23));
2653
+
2654
+ #pragma unroll
2655
+ for (short ii = 0; ii < D4; ii += NW) {
2656
+ const short i = ii + tiisg;
2657
+
2658
+ lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0];
2659
+ lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1];
2660
+ lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2];
2661
+ lo[i/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3];
2662
+ }
2663
+ }
2664
+ }
2665
+
2666
+ }
2667
+
2668
+ // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
2669
+ if (tiisg == 0) {
2670
+ ss[0] = S;
2671
+ ss[1] = M;
2672
+ }
2673
+ }
2674
+
2675
+ // store results to shared memory
2676
+ for (short ii = 0; ii < D4; ii += NW) {
2677
+ short i = ii + tiisg;
2678
+ sr4[i] = lo[ii/NW];
2679
+ }
2680
+
2681
+ threadgroup_barrier(mem_flags::mem_threadgroup);
2682
+
2683
+ // parallel reduce
2684
+ for (short r = nsg/2; r > 0; r >>= 1) {
2685
+ if (sgitg < r) {
2686
+ const float S0 = ss[ 0];
2687
+ const float S1 = ss[r*SH + 0];
2688
+
2689
+ const float M0 = ss[ 1];
2690
+ const float M1 = ss[r*SH + 1];
2691
+
2692
+ const float M = max(M0, M1);
2693
+
2694
+ const float ms0 = exp(M0 - M);
2695
+ const float ms1 = exp(M1 - M);
2696
+
2697
+ const float S = S0*ms0 + S1*ms1;
2698
+
2699
+ if (tiisg == 0) {
2700
+ ss[0] = S;
2701
+ ss[1] = M;
2702
+ }
2703
+
2704
+ // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
2705
+ for (short ii = 0; ii < D4; ii += NW) {
2706
+ short i = ii + tiisg;
2707
+ sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1;
2708
+ }
2709
+ }
2710
+
2711
+ threadgroup_barrier(mem_flags::mem_threadgroup);
2712
+ }
2713
+
2714
+ device float4 * dst4 = (device float4 *) dst;
2715
+
2716
+ // final rescale with 1/S and store to global memory
2717
+ if (sgitg == 0) {
2718
+ const float S = ss[0];
2719
+
2720
+ for (short ii = 0; ii < D4; ii += NW) {
2721
+ short i = ii + tiisg;
2722
+ dst4[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D4 + i] = (float4) sr4[i]/S;
2723
+ }
2724
+ }
2725
+ }
2726
+
2727
+ template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
2728
+ template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
2729
+
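The vector kernel above is only instantiated for head sizes 128 and 256 and handles one query per threadgroup (Q == 1), i.e. the batch-size-1 decode case mentioned in the commit log. An illustrative dispatch rule; the real selection logic lives in ggml-metal.m and is not part of this hunk:

    #include <stdbool.h>

    // sketch only: prefer the vec kernel for single-query decoding with a head
    // size it was instantiated for, otherwise use the 8x8 simdgroup kernel
    static bool use_flash_attn_vec_kernel(int n_queries, int head_size) {
        return n_queries == 1 && (head_size == 128 || head_size == 256);
    }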
2730
  kernel void kernel_cpy_f16_f16(
2731
  device const half * src0,
2732
  device half * dst,
ggml-sycl.cpp CHANGED
@@ -14744,7 +14744,12 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0,
14744
  GGML_ASSERT(src0->type == GGML_TYPE_F32);
14745
  GGML_ASSERT( dst->type == GGML_TYPE_F32);
14746
14747
  GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
 
14748
 
14749
  const int64_t ne00 = src0->ne[0];
14750
  const int64_t nrows_x = ggml_nrows(src0);
@@ -14760,7 +14765,6 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0,
14760
  float * src2_dd = nullptr;
14761
  sycl_pool_alloc<float> src2_f;
14762
 
14763
- ggml_tensor * src2 = dst->src[2];
14764
  const bool use_src2 = src2 != nullptr;
14765
 
14766
  if (use_src2) {
 
14744
  GGML_ASSERT(src0->type == GGML_TYPE_F32);
14745
  GGML_ASSERT( dst->type == GGML_TYPE_F32);
14746
 
14747
+ const ggml_tensor * src2 = dst->src[2];
14748
+
14749
+ #pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 and src2 support")
14750
+ #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
14751
  GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
14752
+ GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional
14753
 
14754
  const int64_t ne00 = src0->ne[0];
14755
  const int64_t nrows_x = ggml_nrows(src0);
 
14765
  float * src2_dd = nullptr;
14766
  sycl_pool_alloc<float> src2_f;
14767
 
 
14768
  const bool use_src2 = src2 != nullptr;
14769
 
14770
  if (use_src2) {
ggml-vulkan.cpp CHANGED
@@ -3178,6 +3178,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
3178
  }
3179
  return nullptr;
3180
  case GGML_OP_SOFT_MAX:
3181
  if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && (src2 == nullptr || src2->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
3182
  return ctx->device->pipeline_soft_max_f32;
3183
  }
 
3178
  }
3179
  return nullptr;
3180
  case GGML_OP_SOFT_MAX:
3181
+ #pragma message("TODO: add ggml_vk_soft_max() F16 src1 and src2 support")
3182
+ #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
3183
+ GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32);
3184
+ GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32);
3185
+
3186
  if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && (src2 == nullptr || src2->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
3187
  return ctx->device->pipeline_soft_max_f32;
3188
  }
ggml.c CHANGED
@@ -951,7 +951,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
951
  #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO
952
  #define GGML_F16_VEC_SET1 GGML_F16x8_SET1
953
  #define GGML_F16_VEC_LOAD(p, i) GGML_F16x8_LOAD(p)
954
- #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE(p, r[i])
955
  #define GGML_F16_VEC_FMA GGML_F16x8_FMA
956
  #define GGML_F16_VEC_ADD GGML_F16x8_ADD
957
  #define GGML_F16_VEC_MUL GGML_F16x8_MUL
@@ -977,7 +977,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
977
  #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO
978
  #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1
979
  #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p)
980
- #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i])
981
  #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA
982
  #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD
983
  #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
@@ -1046,7 +1046,7 @@ do { \
1046
 
1047
  // unlike _mm256_cvt intrinsics that require F16C, _mm512_cvt is defined in AVX512F
1048
  // so F16C guard isn't required
1049
- #define GGML_F32Cx16_LOAD(x) _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(x)))
1050
  #define GGML_F32Cx16_STORE(x, y) _mm256_storeu_si256((__m256i *)(x), _mm512_cvtps_ph(y, 0))
1051
 
1052
  #define GGML_F32Cx16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a)
@@ -1144,7 +1144,7 @@ do { \
1144
 
1145
  #if defined(__F16C__)
1146
  // the _mm256_cvt intrinsics require F16C
1147
- #define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x)))
1148
  #define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0))
1149
  #else
1150
  static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) {
@@ -1662,6 +1662,37 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float
1662
  #endif
1663
  }
1664
1665
  // xs and vs are byte strides of x and v
1666
  inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) {
1667
 
@@ -1746,6 +1777,35 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
1746
  #endif
1747
  }
1748
1749
  inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s); }
1750
  inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; }
1751
  inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
@@ -2001,6 +2061,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
2001
  "LEAKY_RELU",
2002
 
2003
  "FLASH_ATTN",
 
2004
  "FLASH_FF",
2005
  "FLASH_ATTN_BACK",
2006
  "SSM_CONV",
@@ -2027,7 +2088,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
2027
  "CROSS_ENTROPY_LOSS_BACK",
2028
  };
2029
 
2030
- static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
2031
 
2032
  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
2033
  "none",
@@ -2091,6 +2152,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
2091
  "leaky_relu(x)",
2092
 
2093
  "flash_attn(x)",
 
2094
  "flash_ff(x)",
2095
  "flash_attn_back(x)",
2096
  "ssm_conv(x)",
@@ -2117,7 +2179,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
2117
  "cross_entropy_loss_back(x,y)",
2118
  };
2119
 
2120
- static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
2121
 
2122
  static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
2123
 
@@ -4575,6 +4637,8 @@ struct ggml_tensor * ggml_mul_mat(
4575
  void ggml_mul_mat_set_prec(
4576
  struct ggml_tensor * a,
4577
  enum ggml_prec prec) {
4578
  const int32_t prec_i32 = (int32_t) prec;
4579
 
4580
  ggml_set_op_params_i32(a, 0, prec_i32);
@@ -5413,17 +5477,23 @@ static struct ggml_tensor * ggml_soft_max_impl(
5413
  GGML_ASSERT(ggml_is_contiguous(a));
5414
 
5415
  if (mask) {
 
5416
  GGML_ASSERT(ggml_is_contiguous(mask));
5417
  GGML_ASSERT(ggml_is_matrix(mask));
5418
- GGML_ASSERT(ggml_can_repeat_rows(mask, a));
 
5419
  }
5420
 
5421
  if (pos) {
5422
  GGML_ASSERT(ggml_is_vector(pos));
5423
- GGML_ASSERT(pos->type == GGML_TYPE_F32);
5424
  GGML_ASSERT(pos->ne[0] == a->ne[0]);
5425
  }
5426
5427
  if (max_bias > 0.0f) {
5428
  GGML_ASSERT(pos);
5429
  }
@@ -6232,6 +6302,59 @@ struct ggml_tensor * ggml_flash_attn(
6232
  return result;
6233
  }
6234
6235
  // ggml_flash_ff
6236
 
6237
  struct ggml_tensor * ggml_flash_ff(
@@ -12317,7 +12440,7 @@ static void ggml_compute_forward_soft_max_f32(
12317
 
12318
  GGML_TENSOR_UNARY_OP_LOCALS
12319
 
12320
- const int64_t ne11 = src1 ? src1->ne[1] : 1;
12321
 
12322
  // TODO: is this supposed to be ceil instead of floor?
12323
  // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
@@ -12340,19 +12463,31 @@ static void ggml_compute_forward_soft_max_f32(
12340
  float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
12341
 
12342
  // when max_bias <= 0.0f, src2 is not used and we default it to src0 to avoid branching
12343
 - float * pos = src2 ? (float *) src2->data : src0->data;
12344
 
12345
  for (int i1 = ir0; i1 < ir1; i1++) {
12346
  float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
12347
  float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
12348
 
12349
  // broadcast the mask across rows
12350
- float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL;
 
12351
 
12352
  ggml_vec_cpy_f32 (nc, wp, sp);
12353
  ggml_vec_scale_f32(nc, wp, scale);
12354
- if (mp) {
12355
 - ggml_vec_acc_f32(nc, wp, mp);
12356
  }
12357
 
12358
  // ALiBi bias
@@ -12360,8 +12495,14 @@ static void ggml_compute_forward_soft_max_f32(
12360
  const uint32_t h = (i1/ne01)%ne02; // head
12361
  const float slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
12362
 
12363
- for (int i = 0; i < nc; i++) {
12364
 - wp[i] = wp[i] + slope*pos[i];
  }
12366
  }
12367
 
@@ -14631,6 +14772,198 @@ static void ggml_compute_forward_flash_attn(
14631
  }
14632
  }
14633
14634
  // ggml_compute_forward_flash_ff
14635
 
14636
  static void ggml_compute_forward_flash_ff_f16(
@@ -16442,6 +16775,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
16442
  const bool masked = t != 0;
16443
  ggml_compute_forward_flash_attn(params, masked, tensor);
16444
  } break;
16445
  case GGML_OP_FLASH_FF:
16446
  {
16447
  ggml_compute_forward_flash_ff(params, tensor);
@@ -17454,6 +17791,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
17454
  GGML_ASSERT(false); // TODO: not implemented
17455
  } break;
17456
  case GGML_OP_FLASH_ATTN:
 
17457
  {
17458
  struct ggml_tensor * flash_grad = NULL;
17459
  if (src0->grad || src1->grad || tensor->src[2]->grad) {
@@ -18231,6 +18569,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
18231
  n_tasks = n_threads;
18232
  } break;
18233
  case GGML_OP_FLASH_ATTN:
 
18234
  {
18235
  n_tasks = n_threads;
18236
  } break;
@@ -18634,6 +18973,12 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
18634
  cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
18635
  }
18636
  } break;
18637
  case GGML_OP_FLASH_FF:
18638
  {
18639
  if (node->src[1]->type == GGML_TYPE_F32) {
 
951
  #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO
952
  #define GGML_F16_VEC_SET1 GGML_F16x8_SET1
953
  #define GGML_F16_VEC_LOAD(p, i) GGML_F16x8_LOAD(p)
954
+ #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE((ggml_fp16_internal_t *)(p), r[i])
955
  #define GGML_F16_VEC_FMA GGML_F16x8_FMA
956
  #define GGML_F16_VEC_ADD GGML_F16x8_ADD
957
  #define GGML_F16_VEC_MUL GGML_F16x8_MUL
 
977
  #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO
978
  #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1
979
  #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p)
980
+ #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE((ggml_fp16_internal_t *)(p), r[i])
981
  #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA
982
  #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD
983
  #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
 
1046
 
1047
  // unlike _mm256_cvt intrinsics that require F16C, _mm512_cvt is defined in AVX512F
1048
  // so F16C guard isn't required
1049
+ #define GGML_F32Cx16_LOAD(x) _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x)))
1050
  #define GGML_F32Cx16_STORE(x, y) _mm256_storeu_si256((__m256i *)(x), _mm512_cvtps_ph(y, 0))
1051
 
1052
  #define GGML_F32Cx16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a)
 
1144
 
1145
  #if defined(__F16C__)
1146
  // the _mm256_cvt intrinsics require F16C
1147
+ #define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
1148
  #define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0))
1149
  #else
1150
  static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) {
 
1662
  #endif
1663
  }
1664
 
1665
+ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, const ggml_fp16_t * restrict x, const float v) {
1666
+ #if defined(GGML_SIMD)
1667
+ const int np = (n & ~(GGML_F16_STEP - 1));
1668
+
1669
+ GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
1670
+
1671
+ GGML_F16_VEC ax[GGML_F16_ARR];
1672
+ GGML_F16_VEC ay[GGML_F16_ARR];
1673
+
1674
+ for (int i = 0; i < np; i += GGML_F16_STEP) {
1675
+ for (int j = 0; j < GGML_F16_ARR; j++) {
1676
+ ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
1677
+ ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
1678
+ ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
1679
+
1680
+ GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
1681
+ }
1682
+ }
1683
+
1684
+ // leftovers
1685
+ for (int i = np; i < n; ++i) {
1686
+ y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
1687
+ }
1688
+ #else
1689
+ // scalar
1690
+ for (int i = 0; i < n; ++i) {
1691
+ y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
1692
+ }
1693
+ #endif
1694
+ }
1695
+
1696
  // xs and vs are byte strides of x and v
1697
  inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) {
1698
 
 
1777
  #endif
1778
  }
1779
 
1780
+ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) {
1781
+ #if defined(GGML_SIMD)
1782
+ const int np = (n & ~(GGML_F16_STEP - 1));
1783
+
1784
+ GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
1785
+
1786
+ GGML_F16_VEC ay[GGML_F16_ARR];
1787
+
1788
+ for (int i = 0; i < np; i += GGML_F16_STEP) {
1789
+ for (int j = 0; j < GGML_F16_ARR; j++) {
1790
+ ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
1791
+ ay[j] = GGML_F16_VEC_MUL(ay[j], vx);
1792
+
1793
+ GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
1794
+ }
1795
+ }
1796
+
1797
+ // leftovers
1798
+ for (int i = np; i < n; ++i) {
1799
+ y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v);
1800
+ }
1801
+ #else
1802
+ // scalar
1803
+ for (int i = 0; i < n; ++i) {
1804
+ y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v);
1805
+ }
1806
+ #endif
1807
+ }
1808
+
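Taken together, the two new f16 helpers implement the rescale-and-accumulate step of the streaming softmax used by the CPU flash-attention path added further below. A minimal sketch of one KV-column update, assuming the caller provides the accumulator V16, the value row v16 and the factors ms/vs:

    // V <- V*exp(M_old - M_new) + v*exp(s - M_new)
    static void flash_attn_accumulate(int D, ggml_fp16_t * V16, const ggml_fp16_t * v16, float ms, float vs) {
        ggml_vec_scale_f16(D, V16, ms);     // V  = V*exp(M_old - M_new)
        ggml_vec_mad_f16 (D, V16, v16, vs); // V += v*exp(s - M_new)
    }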
1809
  inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s); }
1810
  inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; }
1811
  inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
 
2061
  "LEAKY_RELU",
2062
 
2063
  "FLASH_ATTN",
2064
+ "FLASH_ATTN_EXT",
2065
  "FLASH_FF",
2066
  "FLASH_ATTN_BACK",
2067
  "SSM_CONV",
 
2088
  "CROSS_ENTROPY_LOSS_BACK",
2089
  };
2090
 
2091
+ static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");
2092
 
2093
  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
2094
  "none",
 
2152
  "leaky_relu(x)",
2153
 
2154
  "flash_attn(x)",
2155
+ "flash_attn_ext(x)",
2156
  "flash_ff(x)",
2157
  "flash_attn_back(x)",
2158
  "ssm_conv(x)",
 
2179
  "cross_entropy_loss_back(x,y)",
2180
  };
2181
 
2182
+ static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");
2183
 
2184
  static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
2185
 
 
4637
  void ggml_mul_mat_set_prec(
4638
  struct ggml_tensor * a,
4639
  enum ggml_prec prec) {
4640
+ GGML_ASSERT(a->op == GGML_OP_MUL_MAT);
4641
+
4642
  const int32_t prec_i32 = (int32_t) prec;
4643
 
4644
  ggml_set_op_params_i32(a, 0, prec_i32);
 
5477
  GGML_ASSERT(ggml_is_contiguous(a));
5478
 
5479
  if (mask) {
5480
+ GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);
5481
  GGML_ASSERT(ggml_is_contiguous(mask));
5482
  GGML_ASSERT(ggml_is_matrix(mask));
5483
+ GGML_ASSERT(mask->ne[0] == a->ne[0]);
5484
+ GGML_ASSERT(mask->ne[1] >= a->ne[1]);
5485
  }
5486
 
5487
  if (pos) {
5488
  GGML_ASSERT(ggml_is_vector(pos));
5489
+ GGML_ASSERT(pos->type == GGML_TYPE_F16 || pos->type == GGML_TYPE_F32);
5490
  GGML_ASSERT(pos->ne[0] == a->ne[0]);
5491
  }
5492
 
5493
+ if (pos && mask) {
5494
+ GGML_ASSERT(pos->type == mask->type);
5495
+ }
5496
+
5497
  if (max_bias > 0.0f) {
5498
  GGML_ASSERT(pos);
5499
  }
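The relaxed checks above replace ggml_can_repeat_rows: the mask now only needs to match the row width of a, provide at least as many rows, and may be F16. A minimal allocation sketch that satisfies them, assuming the caller pads the row count the same way the flash-attention path does (GGML_KQ_MASK_PAD is the constant this change adds to ggml.h):

    // width n_kv must equal a->ne[0]; the padded row count covers a->ne[1]
    static struct ggml_tensor * make_kq_mask(struct ggml_context * ctx, int64_t n_kv, int64_t n_tokens) {
        return ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
    }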
 
6302
  return result;
6303
  }
6304
 
6305
+ // ggml_flash_attn_ext
6306
+
6307
+ struct ggml_tensor * ggml_flash_attn_ext(
6308
+ struct ggml_context * ctx,
6309
+ struct ggml_tensor * q,
6310
+ struct ggml_tensor * k,
6311
+ struct ggml_tensor * v,
6312
+ struct ggml_tensor * mask,
6313
+ float scale) {
6314
+ GGML_ASSERT(ggml_can_mul_mat(k, q));
6315
+ // TODO: check if vT can be multiplied by (k*qT)
6316
+ if (mask) {
6317
+ GGML_ASSERT(ggml_is_contiguous(mask));
6318
+ GGML_ASSERT(mask->ne[2] == 1);
6319
+ GGML_ASSERT(mask->ne[3] == 1);
6320
+ GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) &&
6321
+ "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big");
6322
+ //GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
6323
+ }
6324
+
6325
+ bool is_node = false;
6326
+
6327
+ if (q->grad || k->grad || v->grad) {
6328
+ is_node = true;
6329
+ }
6330
+
6331
+ // permute(0, 2, 1, 3)
6332
+ int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
6333
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
6334
+
6335
+ float params[] = { scale };
6336
+ ggml_set_op_params(result, params, sizeof(params));
6337
+
6338
+ result->op = GGML_OP_FLASH_ATTN_EXT;
6339
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6340
+ result->src[0] = q;
6341
+ result->src[1] = k;
6342
+ result->src[2] = v;
6343
+ result->src[3] = mask;
6344
+
6345
+ return result;
6346
+ }
6347
+
6348
+ void ggml_flash_attn_ext_set_prec(
6349
+ struct ggml_tensor * a,
6350
+ enum ggml_prec prec) {
6351
+ GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
6352
+
6353
+ const int32_t prec_i32 = (int32_t) prec;
6354
+
6355
+ ggml_set_op_params_i32(a, 1, prec_i32); // scale is on first pos
6356
+ }
6357
+
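A minimal usage sketch of the new API, following the shape comments added to ggml.h; the tensors and the 1/sqrt(head size) scale are the caller's responsibility, and the GGML_PREC_F32 call is optional:

    #include <math.h>

    //   q:    [n_embd_head, n_tokens, n_head,    1], F32
    //   k, v: [n_embd_head, n_kv,     n_head_kv, 1], F16 (v is NOT transposed)
    //   mask: [n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)], F16
    static struct ggml_tensor * build_flash_attn(
            struct ggml_context * ctx,
            struct ggml_tensor  * q,
            struct ggml_tensor  * k,
            struct ggml_tensor  * v,
            struct ggml_tensor  * kq_mask) {
        const float scale = 1.0f/sqrtf((float) q->ne[0]); // 1/sqrt(head size)

        struct ggml_tensor * cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, scale);

        // optionally request F32 accumulators (stored in op_params[1], see above)
        ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);

        return cur; // [n_embd_head, n_head, n_tokens, 1], permuted as documented
    }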
6358
  // ggml_flash_ff
6359
 
6360
  struct ggml_tensor * ggml_flash_ff(
 
12440
 
12441
  GGML_TENSOR_UNARY_OP_LOCALS
12442
 
12443
+ //const int64_t ne11 = src1 ? src1->ne[1] : 1;
12444
 
12445
  // TODO: is this supposed to be ceil instead of floor?
12446
  // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
 
12463
  float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
12464
 
12465
  // when max_bias <= 0.0f, src2 is not used and we default it to src0 to avoid branching
12466
+ ggml_fp16_t * pos_f16 = src2 ? (ggml_fp16_t *) src2->data : src0->data;
12467
+ float * pos_f32 = src2 ? (float *) src2->data : src0->data;
12468
+
12469
+ const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
12470
 
12471
  for (int i1 = ir0; i1 < ir1; i1++) {
12472
  float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
12473
  float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
12474
 
12475
  // broadcast the mask across rows
12476
+ ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
12477
+ float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
12478
 
12479
  ggml_vec_cpy_f32 (nc, wp, sp);
12480
  ggml_vec_scale_f32(nc, wp, scale);
12481
+ if (mp_f32) {
12482
+ if (use_f16) {
12483
+ for (int i = 0; i < nc; ++i) {
12484
+ wp[i] += GGML_FP16_TO_FP32(mp_f16[i]);
12485
+ }
12486
+ } else {
12487
+ for (int i = 0; i < nc; ++i) {
12488
+ wp[i] += mp_f32[i];
12489
+ }
12490
+ }
12491
  }
12492
 
12493
  // ALiBi bias
 
12495
  const uint32_t h = (i1/ne01)%ne02; // head
12496
  const float slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
12497
 
12498
+ if (use_f16) {
12499
+ for (int i = 0; i < nc; ++i) {
12500
+ wp[i] += slope*GGML_FP16_TO_FP32(pos_f16[i]);
12501
+ }
12502
+ } else {
12503
+ for (int i = 0; i < nc; ++i) {
12504
+ wp[i] += slope*pos_f32[i];
12505
+ }
12506
  }
12507
  }
12508
 
 
14772
  }
14773
  }
14774
 
14775
+ // ggml_compute_forward_flash_attn_ext
14776
+
14777
+ static void ggml_compute_forward_flash_attn_ext_f16(
14778
+ const struct ggml_compute_params * params,
14779
+ const struct ggml_tensor * q,
14780
+ const struct ggml_tensor * k,
14781
+ const struct ggml_tensor * v,
14782
+ const struct ggml_tensor * mask,
14783
+ struct ggml_tensor * dst) {
14784
+ int64_t t0 = ggml_perf_time_us();
14785
+ UNUSED(t0);
14786
+
14787
+ GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
14788
+ GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
14789
+ GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
14790
+ GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
14791
+ GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
14792
+ GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
14793
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
14794
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
14795
+
14796
+ const int ith = params->ith;
14797
+ const int nth = params->nth;
14798
+
14799
+ const int64_t D = neq0;
14800
+ const int64_t N = neq1;
14801
+
14802
+ GGML_ASSERT(ne0 == D);
14803
+ GGML_ASSERT(ne2 == N);
14804
+
14805
+ GGML_ASSERT(nbq0 == sizeof(float));
14806
+ GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
14807
+ GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));
14808
+
14809
+ GGML_ASSERT(neq0 == D);
14810
+ GGML_ASSERT(nek0 == D);
14811
+ GGML_ASSERT(nev0 == D);
14812
+
14813
+ GGML_ASSERT(neq1 == N);
14814
+ GGML_ASSERT(nev0 == D);
14815
+
14816
+ // dst cannot be transposed or permuted
14817
+ GGML_ASSERT(nb0 == sizeof(float));
14818
+ GGML_ASSERT(nb0 <= nb1);
14819
+ GGML_ASSERT(nb1 <= nb2);
14820
+ GGML_ASSERT(nb2 <= nb3);
14821
+
14822
+ // broadcast factors
14823
+ const int64_t rk2 = neq2/nek2;
14824
+ const int64_t rk3 = neq3/nek3;
14825
+
14826
+ const int64_t rv2 = neq2/nev2;
14827
+ const int64_t rv3 = neq3/nev3;
14828
+
14829
+ if (params->type == GGML_TASK_TYPE_INIT) {
14830
+ return;
14831
+ }
14832
+
14833
+ if (params->type == GGML_TASK_TYPE_FINALIZE) {
14834
+ return;
14835
+ }
14836
+
14837
+ // parallelize by q rows using ggml_vec_dot_f32
14838
+
14839
+ // total rows in q
14840
+ const int nr = neq1*neq2*neq3;
14841
+
14842
+ // rows per thread
14843
+ const int dr = (nr + nth - 1)/nth;
14844
+
14845
+ // row range for this thread
14846
+ const int ir0 = dr*ith;
14847
+ const int ir1 = MIN(ir0 + dr, nr);
14848
+
14849
+ float scale = 1.0f;
14850
+ memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
14851
+
14852
+ // loop over n_batch and n_head
14853
+ for (int ir = ir0; ir < ir1; ++ir) {
14854
+ // q indices
14855
+ const int iq3 = ir/(neq2*neq1);
14856
+ const int iq2 = (ir - iq3*neq2*neq1)/neq1;
14857
+ const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
14858
+
14859
+ float S = 0.0f;
14860
+ float M = -INFINITY;
14861
+
14862
+ float * V32 = (float *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32);
14863
+ ggml_fp16_t * Q16 = (ggml_fp16_t *) (V32); // reuse memory
14864
+ ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D);
14865
+
14866
+ memset(V16, 0, D*sizeof(ggml_fp16_t));
14867
+
14868
+ const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
14869
+
14870
+ // k indices
14871
+ const int ik3 = iq3 / rk3;
14872
+ const int ik2 = iq2 / rk2;
14873
+
14874
+ // v indices
14875
+ const int iv3 = iq3 / rv3;
14876
+ const int iv2 = iq2 / rv2;
14877
+
14878
+ // online softmax / attention
14879
+ // loop over n_kv and n_head_kv
14880
+ // ref: https://arxiv.org/pdf/2112.05682.pdf
14881
+ for (int64_t ic = 0; ic < nek1; ++ic) {
14882
+ const float mv = mp ? GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
14883
+ if (mv == -INFINITY) {
14884
+ continue;
14885
+ }
14886
+
14887
+ float s;
14888
+
14889
+ // convert Q to F16 in V32
14890
+ {
14891
+ const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
14892
+
14893
+ for (int64_t d = 0; d < D; ++d) {
14894
+ Q16[d] = GGML_FP32_TO_FP16(pq[d]);
14895
+ }
14896
+ }
14897
+
14898
+ ggml_vec_dot_f16(D,
14899
+ &s, 0,
14900
+ (ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
14901
+ Q16, 0, 1);
14902
+
14903
+ s = s*scale + mv;
14904
+
14905
+ const float Mold = M;
14906
+
14907
+ float ms = 1.0f;
14908
+ float vs = 1.0f;
14909
+
14910
+ if (s > M) {
14911
+ M = s;
14912
+ ms = expf(Mold - M);
14913
+
14914
+ // V = V*expf(Mold - M)
14915
+ ggml_vec_scale_f16(D, V16, ms);
14916
+ } else {
14917
+ vs = expf(s - M);
14918
+ }
14919
+
14920
+ const ggml_fp16_t * v16 = (const ggml_fp16_t *) ((char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
14921
+
14922
+ // V += v*expf(s - M)
14923
+ ggml_vec_mad_f16(D, V16, v16, vs);
14924
+
14925
+ S = S*ms + vs;
14926
+ }
14927
+
14928
+ // V /= S
14929
+ for (int64_t d = 0; d < D; ++d) {
14930
+ V32[d] = GGML_FP16_TO_FP32(V16[d])/S;
14931
+ }
14932
+
14933
+ // dst indices
14934
+ const int i1 = iq1;
14935
+ const int i2 = iq2;
14936
+ const int i3 = iq3;
14937
+
14938
+ // original
14939
+ //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
14940
+
14941
+ // permute(0, 2, 1, 3)
14942
+ memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, V32, nb1);
14943
+ }
14944
+ }
14945
+
14946
+ static void ggml_compute_forward_flash_attn_ext(
14947
+ const struct ggml_compute_params * params,
14948
+ const struct ggml_tensor * q,
14949
+ const struct ggml_tensor * k,
14950
+ const struct ggml_tensor * v,
14951
+ const struct ggml_tensor * mask,
14952
+ struct ggml_tensor * dst) {
14953
+ switch (dst->op_params[1]) {
14954
+ case GGML_PREC_DEFAULT:
14955
+ case GGML_PREC_F32:
14956
+ {
14957
+ // uses F32 accumulators
14958
+ ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst);
14959
+ } break;
14960
+ default:
14961
+ {
14962
+ GGML_ASSERT(false);
14963
+ } break;
14964
+ }
14965
+ }
14966
+
14967
  // ggml_compute_forward_flash_ff
14968
 
14969
  static void ggml_compute_forward_flash_ff_f16(
 
16775
  const bool masked = t != 0;
16776
  ggml_compute_forward_flash_attn(params, masked, tensor);
16777
  } break;
16778
+ case GGML_OP_FLASH_ATTN_EXT:
16779
+ {
16780
+ ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
16781
+ } break;
16782
  case GGML_OP_FLASH_FF:
16783
  {
16784
  ggml_compute_forward_flash_ff(params, tensor);
 
17791
  GGML_ASSERT(false); // TODO: not implemented
17792
  } break;
17793
  case GGML_OP_FLASH_ATTN:
17794
+ case GGML_OP_FLASH_ATTN_EXT:
17795
  {
17796
  struct ggml_tensor * flash_grad = NULL;
17797
  if (src0->grad || src1->grad || tensor->src[2]->grad) {
 
18569
  n_tasks = n_threads;
18570
  } break;
18571
  case GGML_OP_FLASH_ATTN:
18572
+ case GGML_OP_FLASH_ATTN_EXT:
18573
  {
18574
  n_tasks = n_threads;
18575
  } break;
 
18973
  cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
18974
  }
18975
  } break;
18976
+ case GGML_OP_FLASH_ATTN_EXT:
18977
+ {
18978
+ const int64_t ne00 = node->src[0]->ne[0]; // D
18979
+
18980
+ cur = 2*sizeof(float)*ne00*n_tasks; // 2x head size
18981
+ } break;
18982
  case GGML_OP_FLASH_FF:
18983
  {
18984
  if (node->src[1]->type == GGML_TYPE_F32) {
ggml.h CHANGED
@@ -475,6 +475,7 @@ extern "C" {
475
  GGML_OP_LEAKY_RELU,
476
 
477
  GGML_OP_FLASH_ATTN,
 
478
  GGML_OP_FLASH_FF,
479
  GGML_OP_FLASH_ATTN_BACK,
480
  GGML_OP_SSM_CONV,
@@ -1731,6 +1732,25 @@ extern "C" {
1731
  struct ggml_tensor * v,
1732
  bool masked);
1733
1734
  GGML_API struct ggml_tensor * ggml_flash_attn_back(
1735
  struct ggml_context * ctx,
1736
  struct ggml_tensor * q,
 
475
  GGML_OP_LEAKY_RELU,
476
 
477
  GGML_OP_FLASH_ATTN,
478
+ GGML_OP_FLASH_ATTN_EXT,
479
  GGML_OP_FLASH_FF,
480
  GGML_OP_FLASH_ATTN_BACK,
481
  GGML_OP_SSM_CONV,
 
1732
  struct ggml_tensor * v,
1733
  bool masked);
1734
 
1735
+ #define GGML_KQ_MASK_PAD 32
1736
+
1737
+ // q: [n_embd, n_batch, n_head, 1]
1738
+ // k: [n_embd, n_kv, n_head_kv, 1]
1739
+ // v: [n_embd, n_kv, n_head_kv, 1] !! not transposed !!
1740
+ // mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
1741
+ // res: [n_embd, n_head, n_batch, 1] !! permuted !!
1742
+ GGML_API struct ggml_tensor * ggml_flash_attn_ext(
1743
+ struct ggml_context * ctx,
1744
+ struct ggml_tensor * q,
1745
+ struct ggml_tensor * k,
1746
+ struct ggml_tensor * v,
1747
+ struct ggml_tensor * mask,
1748
+ float scale);
1749
+
1750
+ GGML_API void ggml_flash_attn_ext_set_prec(
1751
+ struct ggml_tensor * a,
1752
+ enum ggml_prec prec);
1753
+
1754
  GGML_API struct ggml_tensor * ggml_flash_attn_back(
1755
  struct ggml_context * ctx,
1756
  struct ggml_tensor * q,