Spaces:
Running
ggml : add Flash Attention (llama/5021)
Browse files* ggml : add ggml_flash_attn_ext API
* ggml : fix GQA support in ggml_flash_attn_ext
* ggml : online attention (CPU)
* metal : initial implementation
* metal : f16 precision
* metal : reduce branches
* metal : specialize for head size
* wip : 8 rows per simd group
* wip : 4 rows per simd group
* wip : template for rows per warp
* metal : parallelize across KV size
* metal : parallel reduce across heads
* metal : efficient flash_attn_f16 implementation
* metal : avoid redundant loads of the attention
* metal : scale and mask in matrix form
* metal : fix comment
* llama : avoid ggml_cast, use F32 query
* metal : add parallel reduce version (disabled)
* metal : move output into local memory + optimize
- the result from each simdgroup now stays in the registers
- significantly reduced SRAM usage
- more efficient skipping of -INF blocks
- avoid simdgroup barrier in hot loop
- add comments
* metal : add tests, fix scaling, support C > 32
* metal : improve precision
* ggml : fix f16 mad
* metal : minor
* metal : support Q > 8
* tests : add ATTN tests
* metal : disable buffer allocation logs
* tests : more
* metal : faster inner loop for C == 32
* metal : fix array initialization
* tests : ifdef
* ggml : switch to padded F16 mask for ggml_soft_max, ggml_flash_attn_ext
* ggml : fix ggml_soft_max mask requirement
* cuda : fix soft_max to use correct mask size
* cuda : add flash_attn kernel (wip)
* metal : optimize softmax for C > 32
* metal : optimize softmax
* tests : minor fix
* cuda : avoid zeroing fragments
* tests : update dims
* cuda : fix __hisinf() result check
* cuda : avoid warp_reduce for smax
* cuda : use int instead of int64_t
Noticeably improves performance (thanks to Johannes)
* cuda : make loops use the same loop values
Thanks Johannes again for the tip
* cuda : unroll some of the loops
* cuda : avoid __hisinf branches
* cuda : use half2 in softmax
* cuda : switch to 1 warp for bs > 16
* cuda : speed-up reduce part of the kernel
* cuda : unroll Q*K^T loop
* cuda : fix -INF block check
* cuda : simplify softmax
* cuda : fix matrix names
* cuda : minor
* llama : adapt to F16 KQ_pos
* llama : adapt new models to F16 KQ_mask
* ggml : fix F16 store (ARM NEON)
* llama : fix type of KQ_mask and KQ_pos
* ggml : fix CPU soft_max
* tests : add hs=256
* cuda : fix build
* metal : improve perf via smaller int registers
* cuda : adapt soft_max to F16 mask and pos
* CUDA: faster FlashAttention, kernel for bs == 1
* 16 cols for Phi-2
* no vec for hs, no hs==256 ncols==32 for Volta
* adjust kernel selection logic
* 4 warps, 256 stride for all D
* no ncols == 64
* Multiple parallel blocks for batch size 1
* fix compile warnings
* fix excessive KQ_b loads
* fix cmake build
* fix KV cache padding, NaN from INFINITY (llama/6438)
* llama : flash_attn cparam + fix defrag
* server: support flash_attn param
* server: bench: enable flash_attn param
* CUDA: refactor host code, dyn. par. blocks
* fix flash_attn_vec_f16 race condition
* flush softmax exp below threshold to 0
* store temp KQ in registers
* Calculate KQ as FP32 if KQV has GGML_PREC_F32
* Add __hgt2_mask implementation for CUDA 11
* fix KQ FP32 precision fpr parallel_blocks > 1
* llama-bench : add -fa,--flash-attn arg
* metal : add BS=1 kernel for flash attention (llama/6508)
* metal : add BS=1 kernel for flash attention (wip)
* metal : support more than 1 warps
* metal : opts
* metal : opt
* metal : switch to parallel reduce
* metal : reduce registers
* metal : simplify
* metal : initial FA vec kernel
* metal : use F32 attention accumulators
* batched-bench : add fattn arg
* llama : simplify llama_build_kv_store
ggml-ci
* llama : adapt build_olmo to changes
* ggml : fix arm fp16 store on windows
* metal : clean-up
* metal : clean-up kernel code
* metal : minor
* tests : remove benchmarks
ggml-ci
* ggml : fix avx512 const correctness
ggml-ci
* ggml : fix soft_max with bias on CPU
ggml-ci
* common : print --flash-attn in help
* ggml : fix num dimensions in ggml_flash_attn_ext
* llama : force disable flash attention for incompatible models
* ggml : ggml_soft_max support F16/F32 mask/pos
ggml-ci
* cuda : uint -> uint32_t
* cuda : "constexpr dim3" -> "const dim3"
ggml-ci
* cuda : try to fix __hgt2_mask
ggml-ci
* ggml : add TODO's for F16/F32 mask/pos support in other backends
* llama : replace bool need_kq_pos with use_alibi
* llama : prep ALiBi support for BERT models
ggml-ci
* llama : fix n_batch requirements
ggml-ci
* cont
* server : add help for --flash-attn arg
* llama : disable FA for AMD
* tests : remove TMP_ATTN_BENCH
ggml-ci
* llama : support save/load state with FA enabled
ggml-ci
* ci : add CUDA save-load-state tests
ggml-ci
* llama : llama_kv_cache_clear zeroes data + fix save-load seq
ggml-ci
* llama : fix copy-paste errors, add TODO
* llama : disallow incompatible states
* llama : update llama_state_get_size after v_trans field
* metal : remove tmp log
* llama : add static reminder for llama_state_get_size
* metal : fix max nsg
ggml-ci
* ci : fix arg order
ggml-ci
---------
Co-authored-by: Johannes Gäßler <[email protected]>
Co-authored-by: Pierrick HYMBERT <[email protected]>
- ggml-cuda.cu +6 -0
- ggml-cuda/common.cuh +26 -14
- ggml-cuda/fattn.cu +944 -0
- ggml-cuda/fattn.cuh +3 -0
- ggml-cuda/softmax.cu +36 -10
- ggml-kompute.cpp +7 -0
- ggml-metal.m +372 -177
- ggml-metal.metal +654 -18
- ggml-sycl.cpp +5 -1
- ggml-vulkan.cpp +5 -0
- ggml.c +360 -15
- ggml.h +20 -0
|
@@ -14,6 +14,7 @@
|
|
| 14 |
#include "ggml-cuda/cpy.cuh"
|
| 15 |
#include "ggml-cuda/diagmask.cuh"
|
| 16 |
#include "ggml-cuda/dmmv.cuh"
|
|
|
|
| 17 |
#include "ggml-cuda/getrows.cuh"
|
| 18 |
#include "ggml-cuda/im2col.cuh"
|
| 19 |
#include "ggml-cuda/mmq.cuh"
|
|
@@ -140,6 +141,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
|
|
| 140 |
info.devices[id].cc = 100*prop.major + 10*prop.minor;
|
| 141 |
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
| 142 |
info.devices[id].smpb = prop.sharedMemPerBlock;
|
|
|
|
| 143 |
}
|
| 144 |
|
| 145 |
for (int id = 0; id < info.device_count; ++id) {
|
|
@@ -2293,6 +2295,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
|
| 2293 |
case GGML_OP_ARGSORT:
|
| 2294 |
ggml_cuda_op_argsort(ctx, dst);
|
| 2295 |
break;
|
|
|
|
|
|
|
|
|
|
| 2296 |
default:
|
| 2297 |
return false;
|
| 2298 |
}
|
|
@@ -2568,6 +2573,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
|
| 2568 |
case GGML_OP_ARANGE:
|
| 2569 |
case GGML_OP_TIMESTEP_EMBEDDING:
|
| 2570 |
case GGML_OP_LEAKY_RELU:
|
|
|
|
| 2571 |
return true;
|
| 2572 |
default:
|
| 2573 |
return false;
|
|
|
|
| 14 |
#include "ggml-cuda/cpy.cuh"
|
| 15 |
#include "ggml-cuda/diagmask.cuh"
|
| 16 |
#include "ggml-cuda/dmmv.cuh"
|
| 17 |
+
#include "ggml-cuda/fattn.cuh"
|
| 18 |
#include "ggml-cuda/getrows.cuh"
|
| 19 |
#include "ggml-cuda/im2col.cuh"
|
| 20 |
#include "ggml-cuda/mmq.cuh"
|
|
|
|
| 141 |
info.devices[id].cc = 100*prop.major + 10*prop.minor;
|
| 142 |
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
| 143 |
info.devices[id].smpb = prop.sharedMemPerBlock;
|
| 144 |
+
info.devices[id].nsm = prop.multiProcessorCount;
|
| 145 |
}
|
| 146 |
|
| 147 |
for (int id = 0; id < info.device_count; ++id) {
|
|
|
|
| 2295 |
case GGML_OP_ARGSORT:
|
| 2296 |
ggml_cuda_op_argsort(ctx, dst);
|
| 2297 |
break;
|
| 2298 |
+
case GGML_OP_FLASH_ATTN_EXT:
|
| 2299 |
+
ggml_cuda_flash_attn_ext(ctx, dst);
|
| 2300 |
+
break;
|
| 2301 |
default:
|
| 2302 |
return false;
|
| 2303 |
}
|
|
|
|
| 2573 |
case GGML_OP_ARANGE:
|
| 2574 |
case GGML_OP_TIMESTEP_EMBEDDING:
|
| 2575 |
case GGML_OP_LEAKY_RELU:
|
| 2576 |
+
case GGML_OP_FLASH_ATTN_EXT:
|
| 2577 |
return true;
|
| 2578 |
default:
|
| 2579 |
return false;
|
|
@@ -142,6 +142,7 @@
|
|
| 142 |
#define CC_PASCAL 600
|
| 143 |
#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
|
| 144 |
#define CC_VOLTA 700
|
|
|
|
| 145 |
#define CC_OFFSET_AMD 1000000
|
| 146 |
#define CC_RDNA1 (CC_OFFSET_AMD + 1010)
|
| 147 |
#define CC_RDNA2 (CC_OFFSET_AMD + 1030)
|
|
@@ -271,7 +272,6 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
|
|
| 271 |
return a;
|
| 272 |
}
|
| 273 |
|
| 274 |
-
#ifdef GGML_CUDA_F16
|
| 275 |
static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
|
| 276 |
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
|
| 277 |
#pragma unroll
|
|
@@ -284,7 +284,6 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
|
|
| 284 |
NO_DEVICE_CODE;
|
| 285 |
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
|
| 286 |
}
|
| 287 |
-
#endif // GGML_CUDA_F16
|
| 288 |
|
| 289 |
static __device__ __forceinline__ float warp_reduce_max(float x) {
|
| 290 |
#pragma unroll
|
|
@@ -294,19 +293,26 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
|
|
| 294 |
return x;
|
| 295 |
}
|
| 296 |
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 310 |
|
| 311 |
#if defined(GGML_USE_HIPBLAS)
|
| 312 |
#define __CUDA_ARCH__ 1300
|
|
@@ -391,6 +397,11 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
|
|
| 391 |
}
|
| 392 |
#endif // defined(GGML_USE_HIPBLAS)
|
| 393 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 394 |
// TODO: move to ggml-common.h
|
| 395 |
static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
|
| 396 |
|
|
@@ -404,6 +415,7 @@ struct ggml_cuda_device_info {
|
|
| 404 |
|
| 405 |
struct cuda_device_info {
|
| 406 |
int cc; // compute capability
|
|
|
|
| 407 |
size_t smpb; // max. shared memory per block
|
| 408 |
bool vmm; // virtual memory support
|
| 409 |
size_t vmm_granularity; // granularity of virtual memory
|
|
|
|
| 142 |
#define CC_PASCAL 600
|
| 143 |
#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
|
| 144 |
#define CC_VOLTA 700
|
| 145 |
+
#define CC_AMPERE 800
|
| 146 |
#define CC_OFFSET_AMD 1000000
|
| 147 |
#define CC_RDNA1 (CC_OFFSET_AMD + 1010)
|
| 148 |
#define CC_RDNA2 (CC_OFFSET_AMD + 1030)
|
|
|
|
| 272 |
return a;
|
| 273 |
}
|
| 274 |
|
|
|
|
| 275 |
static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
|
| 276 |
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
|
| 277 |
#pragma unroll
|
|
|
|
| 284 |
NO_DEVICE_CODE;
|
| 285 |
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
|
| 286 |
}
|
|
|
|
| 287 |
|
| 288 |
static __device__ __forceinline__ float warp_reduce_max(float x) {
|
| 289 |
#pragma unroll
|
|
|
|
| 293 |
return x;
|
| 294 |
}
|
| 295 |
|
| 296 |
+
static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
|
| 297 |
+
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
|
| 298 |
+
#pragma unroll
|
| 299 |
+
for (int mask = 16; mask > 0; mask >>= 1) {
|
| 300 |
+
x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
|
| 301 |
+
}
|
| 302 |
+
return x;
|
| 303 |
+
#else
|
| 304 |
+
GGML_UNUSED(x);
|
| 305 |
+
NO_DEVICE_CODE;
|
| 306 |
+
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
|
| 307 |
+
}
|
| 308 |
|
| 309 |
+
#if CUDART_VERSION < 12000
|
| 310 |
+
static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half2 b) {
|
| 311 |
+
const uint32_t mask_low = 0x0000FFFF * (float( __low2half(a)) > float( __low2half(b)));
|
| 312 |
+
const uint32_t mask_high = 0xFFFF0000 * (float(__high2half(a)) > float(__high2half(b)));
|
| 313 |
+
return mask_low | mask_high;
|
| 314 |
+
}
|
| 315 |
+
#endif // CUDART_VERSION < 12000
|
| 316 |
|
| 317 |
#if defined(GGML_USE_HIPBLAS)
|
| 318 |
#define __CUDA_ARCH__ 1300
|
|
|
|
| 397 |
}
|
| 398 |
#endif // defined(GGML_USE_HIPBLAS)
|
| 399 |
|
| 400 |
+
#define FP16_AVAILABLE defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) ? \
|
| 401 |
+
defined(RDNA1) || defined(RDNA2) || defined(RDNA3) : __CUDA_ARCH__ >= CC_PASCAL
|
| 402 |
+
|
| 403 |
+
#define FP16_MMA_AVAILABLE !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
|
| 404 |
+
|
| 405 |
// TODO: move to ggml-common.h
|
| 406 |
static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
|
| 407 |
|
|
|
|
| 415 |
|
| 416 |
struct cuda_device_info {
|
| 417 |
int cc; // compute capability
|
| 418 |
+
int nsm; // number of streaming multiprocessors
|
| 419 |
size_t smpb; // max. shared memory per block
|
| 420 |
bool vmm; // virtual memory support
|
| 421 |
size_t vmm_granularity; // granularity of virtual memory
|
|
@@ -0,0 +1,944 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include "common.cuh"
|
| 2 |
+
#include "fattn.cuh"
|
| 3 |
+
|
| 4 |
+
#include <cstdint>
|
| 5 |
+
|
| 6 |
+
#if FP16_MMA_AVAILABLE
|
| 7 |
+
#include <mma.h>
|
| 8 |
+
#endif
|
| 9 |
+
|
| 10 |
+
#define FATTN_KQ_STRIDE 256
|
| 11 |
+
#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
|
| 12 |
+
#define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
|
| 13 |
+
|
| 14 |
+
template<int D, int parallel_blocks> // D == head size
|
| 15 |
+
__launch_bounds__(((D + WARP_SIZE - 1) / WARP_SIZE)*WARP_SIZE, 1)
|
| 16 |
+
static __global__ void flash_attn_vec_ext_f16(
|
| 17 |
+
const char * __restrict__ Q,
|
| 18 |
+
const char * __restrict__ K,
|
| 19 |
+
const char * __restrict__ V,
|
| 20 |
+
const char * __restrict__ mask,
|
| 21 |
+
float * __restrict__ dst,
|
| 22 |
+
float2 * __restrict__ dst_meta,
|
| 23 |
+
const float scale,
|
| 24 |
+
const int ne00,
|
| 25 |
+
const int ne01,
|
| 26 |
+
const int ne02,
|
| 27 |
+
const int ne03,
|
| 28 |
+
const int ne10,
|
| 29 |
+
const int ne11,
|
| 30 |
+
const int ne12,
|
| 31 |
+
const int ne13,
|
| 32 |
+
const int ne31,
|
| 33 |
+
const int nb31,
|
| 34 |
+
const int nb01,
|
| 35 |
+
const int nb02,
|
| 36 |
+
const int nb03,
|
| 37 |
+
const int nb11,
|
| 38 |
+
const int nb12,
|
| 39 |
+
const int nb13,
|
| 40 |
+
const int ne0,
|
| 41 |
+
const int ne1,
|
| 42 |
+
const int ne2,
|
| 43 |
+
const int ne3) {
|
| 44 |
+
#if FP16_AVAILABLE
|
| 45 |
+
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
| 46 |
+
|
| 47 |
+
const int ic = blockIdx.x / parallel_blocks; // Index of the Q/QKV column to work on.
|
| 48 |
+
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
|
| 49 |
+
|
| 50 |
+
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
| 51 |
+
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic);
|
| 52 |
+
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio));
|
| 53 |
+
const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
|
| 54 |
+
const half * maskh = (const half *) mask + ne11*ic;
|
| 55 |
+
|
| 56 |
+
const int stride_KV = nb11 / sizeof(half);
|
| 57 |
+
const int stride_KV2 = nb11 / sizeof(half2);
|
| 58 |
+
|
| 59 |
+
constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE;
|
| 60 |
+
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
|
| 61 |
+
__builtin_assume(tid < nwarps*WARP_SIZE);
|
| 62 |
+
|
| 63 |
+
__shared__ half KQ[nwarps*WARP_SIZE];
|
| 64 |
+
KQ[tid] = -INFINITY;
|
| 65 |
+
half2 * KQ2 = (half2 *) KQ;
|
| 66 |
+
|
| 67 |
+
half kqmax = -HALF_MAX_HALF;
|
| 68 |
+
half kqsum = 0.0f;
|
| 69 |
+
|
| 70 |
+
__shared__ half kqmax_shared[WARP_SIZE];
|
| 71 |
+
__shared__ half kqsum_shared[WARP_SIZE];
|
| 72 |
+
if (threadIdx.y == 0) {
|
| 73 |
+
kqmax_shared[threadIdx.x] = -HALF_MAX_HALF;
|
| 74 |
+
kqsum_shared[threadIdx.x] = 0.0f;
|
| 75 |
+
}
|
| 76 |
+
__syncthreads();
|
| 77 |
+
|
| 78 |
+
// Convert Q to half2 and store in registers:
|
| 79 |
+
half2 Q_h2[(D/2 + WARP_SIZE - 1) / WARP_SIZE];
|
| 80 |
+
#pragma unroll
|
| 81 |
+
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
| 82 |
+
const int i = i0 + threadIdx.x;
|
| 83 |
+
if (i0 + WARP_SIZE > D/2 && i >= D/2) {
|
| 84 |
+
break;
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
Q_h2[i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(Q_f2[i].x, Q_f2[i].y);
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
half2 VKQ = make_half2(0.0f, 0.0f); // Each thread calculates a single VKQ value.
|
| 91 |
+
|
| 92 |
+
const int k_start = parallel_blocks == 1 ? 0 : ip*D;
|
| 93 |
+
for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
|
| 94 |
+
// Calculate KQ tile and keep track of new maximum KQ values:
|
| 95 |
+
half kqmax_new = kqmax;
|
| 96 |
+
#pragma unroll
|
| 97 |
+
for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) {
|
| 98 |
+
const int i_KQ = i_KQ_0 + threadIdx.y;
|
| 99 |
+
|
| 100 |
+
if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) {
|
| 101 |
+
break;
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
half2 sum2 = make_half2(0.0f, 0.0f);
|
| 105 |
+
#pragma unroll
|
| 106 |
+
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
|
| 107 |
+
const int k_KQ = k_KQ_0 + threadIdx.x;
|
| 108 |
+
if (k_KQ_0 + WARP_SIZE > D/2 && k_KQ >= D/2) {
|
| 109 |
+
break;
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
const half2 K_ik = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
|
| 113 |
+
sum2 += K_ik * Q_h2[k_KQ_0/WARP_SIZE];
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
sum2 = warp_reduce_sum(sum2);
|
| 117 |
+
half sum = __low2half(sum2) + __high2half(sum2);
|
| 118 |
+
sum += mask ? maskh[k_VKQ_0 + i_KQ] : __float2half(0.0f);
|
| 119 |
+
kqmax_new = __hmax(kqmax_new, sum);
|
| 120 |
+
if (threadIdx.x == 0) {
|
| 121 |
+
KQ[i_KQ] = sum;
|
| 122 |
+
}
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
kqmax_new = warp_reduce_max(kqmax_new);
|
| 126 |
+
if (threadIdx.x == 0) {
|
| 127 |
+
kqmax_shared[threadIdx.y] = kqmax_new;
|
| 128 |
+
}
|
| 129 |
+
__syncthreads();
|
| 130 |
+
kqmax_new = kqmax_shared[threadIdx.x];
|
| 131 |
+
kqmax_new = warp_reduce_max(kqmax_new);
|
| 132 |
+
|
| 133 |
+
const half KQ_max_scale = hexp(kqmax - kqmax_new);
|
| 134 |
+
kqmax = kqmax_new;
|
| 135 |
+
|
| 136 |
+
const half val = hexp(KQ[tid] - kqmax);
|
| 137 |
+
kqsum = kqsum*KQ_max_scale + val;
|
| 138 |
+
KQ[tid] = val;
|
| 139 |
+
|
| 140 |
+
VKQ *= __half2half2(KQ_max_scale);
|
| 141 |
+
|
| 142 |
+
__syncthreads();
|
| 143 |
+
|
| 144 |
+
if (tid < D) {
|
| 145 |
+
#pragma unroll
|
| 146 |
+
for (int k0 = 0; k0 < D; k0 += 2) {
|
| 147 |
+
if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) {
|
| 148 |
+
break;
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
half2 V_k;
|
| 152 |
+
reinterpret_cast<half&>(V_k.x) = V_h[(k_VKQ_0 + k0 + 0)*stride_KV + tid];
|
| 153 |
+
reinterpret_cast<half&>(V_k.y) = V_h[(k_VKQ_0 + k0 + 1)*stride_KV + tid];
|
| 154 |
+
VKQ += V_k*KQ2[k0/2];
|
| 155 |
+
}
|
| 156 |
+
}
|
| 157 |
+
|
| 158 |
+
__syncthreads();
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
if (tid >= D) {
|
| 162 |
+
kqsum = 0.0f;
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
kqsum = warp_reduce_sum(kqsum);
|
| 166 |
+
if (threadIdx.x == 0) {
|
| 167 |
+
kqsum_shared[threadIdx.y] = kqsum;
|
| 168 |
+
}
|
| 169 |
+
__syncthreads();
|
| 170 |
+
kqsum = kqsum_shared[threadIdx.x];
|
| 171 |
+
kqsum = warp_reduce_sum(kqsum);
|
| 172 |
+
|
| 173 |
+
if (tid >= D) {
|
| 174 |
+
return;
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
half dst_val = (__low2half(VKQ) + __high2half(VKQ));
|
| 178 |
+
if (parallel_blocks == 1) {
|
| 179 |
+
dst_val /= kqsum;
|
| 180 |
+
}
|
| 181 |
+
dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = dst_val;
|
| 182 |
+
|
| 183 |
+
if (parallel_blocks == 1 || tid != 0) {
|
| 184 |
+
return;
|
| 185 |
+
}
|
| 186 |
+
dst_meta[ic*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax, kqsum);
|
| 187 |
+
#else
|
| 188 |
+
NO_DEVICE_CODE;
|
| 189 |
+
#endif // FP16_AVAILABLE
|
| 190 |
+
}
|
| 191 |
+
|
| 192 |
+
// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
|
| 193 |
+
template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t>
|
| 194 |
+
__launch_bounds__(nwarps*WARP_SIZE, 1)
|
| 195 |
+
static __global__ void flash_attn_ext_f16(
|
| 196 |
+
const char * __restrict__ Q,
|
| 197 |
+
const char * __restrict__ K,
|
| 198 |
+
const char * __restrict__ V,
|
| 199 |
+
const char * __restrict__ mask,
|
| 200 |
+
float * __restrict__ dst,
|
| 201 |
+
float2 * __restrict__ dst_meta,
|
| 202 |
+
const float scale,
|
| 203 |
+
const int ne00,
|
| 204 |
+
const int ne01,
|
| 205 |
+
const int ne02,
|
| 206 |
+
const int ne03,
|
| 207 |
+
const int ne10,
|
| 208 |
+
const int ne11,
|
| 209 |
+
const int ne12,
|
| 210 |
+
const int ne13,
|
| 211 |
+
const int ne31,
|
| 212 |
+
const int nb31,
|
| 213 |
+
const int nb01,
|
| 214 |
+
const int nb02,
|
| 215 |
+
const int nb03,
|
| 216 |
+
const int nb11,
|
| 217 |
+
const int nb12,
|
| 218 |
+
const int nb13,
|
| 219 |
+
const int ne0,
|
| 220 |
+
const int ne1,
|
| 221 |
+
const int ne2,
|
| 222 |
+
const int ne3) {
|
| 223 |
+
#if FP16_MMA_AVAILABLE
|
| 224 |
+
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
| 225 |
+
|
| 226 |
+
const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
|
| 227 |
+
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
|
| 228 |
+
|
| 229 |
+
static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
|
| 230 |
+
static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
|
| 231 |
+
constexpr int frag_m = ncols == 8 ? 32 : 16;
|
| 232 |
+
constexpr int frag_n = ncols == 8 ? 8 : 16;
|
| 233 |
+
static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
|
| 234 |
+
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
|
| 235 |
+
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
|
| 236 |
+
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
|
| 237 |
+
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
|
| 238 |
+
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
|
| 239 |
+
|
| 240 |
+
constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
|
| 241 |
+
constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
|
| 242 |
+
static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps.");
|
| 243 |
+
|
| 244 |
+
// Pad internal representation of KQ, KQV to reduce shared memory bank conflicts:
|
| 245 |
+
constexpr int D_padded = D + 8;
|
| 246 |
+
constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
|
| 247 |
+
constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
|
| 248 |
+
|
| 249 |
+
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
| 250 |
+
const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0);
|
| 251 |
+
const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio));
|
| 252 |
+
const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
|
| 253 |
+
const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0;
|
| 254 |
+
const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2);
|
| 255 |
+
|
| 256 |
+
const int stride_Q = nb01 / sizeof(float);
|
| 257 |
+
const int stride_KV = nb11 / sizeof(half);
|
| 258 |
+
|
| 259 |
+
frag_b Q_b[D/16][ncols/frag_n];
|
| 260 |
+
|
| 261 |
+
// A single buffer for temporarily holding tiles of KQ and VKQ parts:
|
| 262 |
+
constexpr int mem_KQ = ncols*kqs_padded*kqar;
|
| 263 |
+
constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded;
|
| 264 |
+
__shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts];
|
| 265 |
+
float * KQ_f = (float *) KQ;
|
| 266 |
+
half2 * KQ2 = (half2 *) KQ;
|
| 267 |
+
|
| 268 |
+
float KQ_rowsum_f[ncols/nwarps] = {0.0f};
|
| 269 |
+
float KQ_max_f[ncols/nwarps];
|
| 270 |
+
float KQ_max_scale_f[ncols/nwarps] = {0.0f};
|
| 271 |
+
|
| 272 |
+
#pragma unroll
|
| 273 |
+
for (int j = 0; j < ncols/nwarps; ++j) {
|
| 274 |
+
KQ_max_f[j] = -FLT_MAX/2.0f;
|
| 275 |
+
}
|
| 276 |
+
|
| 277 |
+
half2 KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}};
|
| 278 |
+
half2 KQ_max_h2[ncols/nwarps];
|
| 279 |
+
half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}};
|
| 280 |
+
|
| 281 |
+
#pragma unroll
|
| 282 |
+
for (int j = 0; j < ncols/nwarps; ++j) {
|
| 283 |
+
KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF);
|
| 284 |
+
}
|
| 285 |
+
|
| 286 |
+
__shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
|
| 287 |
+
half2 * VKQ2 = (half2 *) VKQ;
|
| 288 |
+
#pragma unroll
|
| 289 |
+
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
| 290 |
+
const int j = j0 + threadIdx.y;
|
| 291 |
+
#pragma unroll
|
| 292 |
+
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
| 293 |
+
const int i = i0 + threadIdx.x;
|
| 294 |
+
if (i0 + WARP_SIZE > D/2 && i >= D/2) {
|
| 295 |
+
break;
|
| 296 |
+
}
|
| 297 |
+
VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f);
|
| 298 |
+
}
|
| 299 |
+
}
|
| 300 |
+
|
| 301 |
+
// Convert Q to half and apply scale, temporarily store in KQ:
|
| 302 |
+
#pragma unroll
|
| 303 |
+
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
| 304 |
+
const int j = j0 + threadIdx.y;
|
| 305 |
+
#pragma unroll
|
| 306 |
+
for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
|
| 307 |
+
const int i = i0 + threadIdx.x;
|
| 308 |
+
if (i0 + WARP_SIZE > D && i >= D) {
|
| 309 |
+
break;
|
| 310 |
+
}
|
| 311 |
+
KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
|
| 312 |
+
}
|
| 313 |
+
}
|
| 314 |
+
|
| 315 |
+
__syncthreads();
|
| 316 |
+
|
| 317 |
+
// Load Q into tensor core fragments/registers since it will be used frequently:
|
| 318 |
+
#pragma unroll
|
| 319 |
+
for (int i0 = 0; i0 < D; i0 += 16) {
|
| 320 |
+
#pragma unroll
|
| 321 |
+
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
|
| 322 |
+
nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
|
| 323 |
+
}
|
| 324 |
+
}
|
| 325 |
+
|
| 326 |
+
__syncthreads();
|
| 327 |
+
|
| 328 |
+
// Iterate over ne11 == previous tokens:
|
| 329 |
+
for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) {
|
| 330 |
+
// Calculate tile of KQ:
|
| 331 |
+
#pragma unroll
|
| 332 |
+
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
|
| 333 |
+
frag_c_KQ KQ_c[ncols/frag_n];
|
| 334 |
+
#pragma unroll
|
| 335 |
+
for (int j = 0; j < ncols/frag_n; ++j) {
|
| 336 |
+
nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
|
| 337 |
+
}
|
| 338 |
+
#pragma unroll
|
| 339 |
+
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
|
| 340 |
+
frag_a_K K_a;
|
| 341 |
+
nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
|
| 342 |
+
#pragma unroll
|
| 343 |
+
for (int j = 0; j < ncols/frag_n; ++j) {
|
| 344 |
+
nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
|
| 345 |
+
}
|
| 346 |
+
}
|
| 347 |
+
#pragma unroll
|
| 348 |
+
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
|
| 349 |
+
nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
|
| 350 |
+
}
|
| 351 |
+
}
|
| 352 |
+
|
| 353 |
+
__syncthreads();
|
| 354 |
+
|
| 355 |
+
// Calculate softmax for each KQ column using the current max. value.
|
| 356 |
+
// The divisor is stored in KQ_rowsum and will be applied at the end.
|
| 357 |
+
#pragma unroll
|
| 358 |
+
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
| 359 |
+
const int j = j0 + threadIdx.y;
|
| 360 |
+
|
| 361 |
+
if (std::is_same<KQ_acc_t, float>::value) {
|
| 362 |
+
float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE];
|
| 363 |
+
#pragma unroll
|
| 364 |
+
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
|
| 365 |
+
const int k = k0 + threadIdx.x;
|
| 366 |
+
|
| 367 |
+
KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k];
|
| 368 |
+
}
|
| 369 |
+
|
| 370 |
+
float KQ_max_new = KQ_max_f[j0/nwarps];
|
| 371 |
+
#pragma unroll
|
| 372 |
+
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
|
| 373 |
+
const int k = k0 + threadIdx.x;
|
| 374 |
+
|
| 375 |
+
KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
|
| 376 |
+
KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]);
|
| 377 |
+
}
|
| 378 |
+
KQ_max_new = warp_reduce_max(KQ_max_new);
|
| 379 |
+
|
| 380 |
+
const float diff = KQ_max_f[j0/nwarps] - KQ_max_new;
|
| 381 |
+
KQ_max_scale_f[j0/nwarps] = expf(diff);
|
| 382 |
+
if (diff <= SOFTMAX_FTZ_THRESHOLD) {
|
| 383 |
+
KQ_max_scale_f[j0/nwarps] = 0.0f;
|
| 384 |
+
}
|
| 385 |
+
KQ_max_f[j0/nwarps] = KQ_max_new;
|
| 386 |
+
|
| 387 |
+
float KQ_rowsum_add = 0.0f;
|
| 388 |
+
#pragma unroll
|
| 389 |
+
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
|
| 390 |
+
const int k = k0 + threadIdx.x;
|
| 391 |
+
|
| 392 |
+
const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps];
|
| 393 |
+
KQ_f_tmp[k0/WARP_SIZE] = expf(diff);
|
| 394 |
+
if (diff <= SOFTMAX_FTZ_THRESHOLD) {
|
| 395 |
+
KQ_f_tmp[k0/WARP_SIZE] = 0.0f;
|
| 396 |
+
}
|
| 397 |
+
KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE];
|
| 398 |
+
KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE];
|
| 399 |
+
}
|
| 400 |
+
KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
|
| 401 |
+
|
| 402 |
+
// Scale previous KQ_rowsum to account for a potential increase in KQ_max:
|
| 403 |
+
KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;
|
| 404 |
+
} else {
|
| 405 |
+
half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)];
|
| 406 |
+
#pragma unroll
|
| 407 |
+
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
|
| 408 |
+
const int k = k0 + threadIdx.x;
|
| 409 |
+
|
| 410 |
+
KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
|
| 411 |
+
}
|
| 412 |
+
|
| 413 |
+
half2 KQ_max_new = KQ_max_h2[j0/nwarps];
|
| 414 |
+
#pragma unroll
|
| 415 |
+
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
|
| 416 |
+
const int k = k0 + threadIdx.x;
|
| 417 |
+
|
| 418 |
+
KQ2_tmp[k0/WARP_SIZE] += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
|
| 419 |
+
KQ_max_new = __hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
|
| 420 |
+
}
|
| 421 |
+
KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
|
| 422 |
+
const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
|
| 423 |
+
KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
|
| 424 |
+
const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
|
| 425 |
+
*((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask;
|
| 426 |
+
KQ_max_h2[j0/nwarps] = KQ_max_new;
|
| 427 |
+
|
| 428 |
+
half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
|
| 429 |
+
#pragma unroll
|
| 430 |
+
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
|
| 431 |
+
const int k = k0 + threadIdx.x;
|
| 432 |
+
|
| 433 |
+
const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps];
|
| 434 |
+
KQ2_tmp[k0/WARP_SIZE] = h2exp(diff);
|
| 435 |
+
const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
|
| 436 |
+
*((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask;
|
| 437 |
+
KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE];
|
| 438 |
+
KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE];
|
| 439 |
+
}
|
| 440 |
+
KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
|
| 441 |
+
|
| 442 |
+
// Scale previous KQ_rowsum to account for a potential increase in KQ_max:
|
| 443 |
+
KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add;
|
| 444 |
+
}
|
| 445 |
+
}
|
| 446 |
+
|
| 447 |
+
__syncthreads();
|
| 448 |
+
|
| 449 |
+
frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n];
|
| 450 |
+
#pragma unroll
|
| 451 |
+
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
|
| 452 |
+
#pragma unroll
|
| 453 |
+
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
|
| 454 |
+
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
|
| 455 |
+
nvcuda::wmma::load_matrix_sync(
|
| 456 |
+
KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
|
| 457 |
+
KQ + j0*(kqar*kqs_padded) + k,
|
| 458 |
+
kqar*kqs_padded);
|
| 459 |
+
}
|
| 460 |
+
}
|
| 461 |
+
|
| 462 |
+
frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n];
|
| 463 |
+
#pragma unroll
|
| 464 |
+
for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
|
| 465 |
+
#pragma unroll
|
| 466 |
+
for (int j = 0; j < ncols/frag_n; ++j) {
|
| 467 |
+
nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
|
| 468 |
+
}
|
| 469 |
+
|
| 470 |
+
#pragma unroll
|
| 471 |
+
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
|
| 472 |
+
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
|
| 473 |
+
|
| 474 |
+
frag_a_V v_a;
|
| 475 |
+
nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
|
| 476 |
+
#pragma unroll
|
| 477 |
+
for (int j = 0; j < ncols/frag_n; ++j) {
|
| 478 |
+
nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
|
| 479 |
+
}
|
| 480 |
+
}
|
| 481 |
+
}
|
| 482 |
+
|
| 483 |
+
__syncthreads();
|
| 484 |
+
|
| 485 |
+
const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded);
|
| 486 |
+
#pragma unroll
|
| 487 |
+
for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
|
| 488 |
+
#pragma unroll
|
| 489 |
+
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
|
| 490 |
+
nvcuda::wmma::store_matrix_sync(
|
| 491 |
+
KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
|
| 492 |
+
VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
|
| 493 |
+
D_padded, nvcuda::wmma::mem_col_major);
|
| 494 |
+
}
|
| 495 |
+
}
|
| 496 |
+
|
| 497 |
+
__syncthreads();
|
| 498 |
+
|
| 499 |
+
#pragma unroll
|
| 500 |
+
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
| 501 |
+
const int j = j0 + threadIdx.y;
|
| 502 |
+
|
| 503 |
+
half2 VKQ_scale;
|
| 504 |
+
if (std::is_same<KQ_acc_t, float>::value) {
|
| 505 |
+
VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]);
|
| 506 |
+
} else {
|
| 507 |
+
VKQ_scale = KQ_max_scale_h2[j0/nwarps];
|
| 508 |
+
}
|
| 509 |
+
|
| 510 |
+
#pragma unroll
|
| 511 |
+
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
| 512 |
+
const int i = i0 + threadIdx.x;
|
| 513 |
+
if (i0 + WARP_SIZE > D/2 && i >= D/2) {
|
| 514 |
+
break;
|
| 515 |
+
}
|
| 516 |
+
|
| 517 |
+
half2 VKQ_add = make_half2(0.0f, 0.0f);
|
| 518 |
+
#pragma unroll
|
| 519 |
+
for (int l = 0; l < VKQ_ratio; ++l) {
|
| 520 |
+
VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i];
|
| 521 |
+
}
|
| 522 |
+
VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add;
|
| 523 |
+
}
|
| 524 |
+
}
|
| 525 |
+
|
| 526 |
+
__syncthreads();
|
| 527 |
+
}
|
| 528 |
+
|
| 529 |
+
#pragma unroll
|
| 530 |
+
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
| 531 |
+
const int j_VKQ = j0 + threadIdx.y;
|
| 532 |
+
if (ic0 + j_VKQ >= ne01) {
|
| 533 |
+
return;
|
| 534 |
+
}
|
| 535 |
+
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
|
| 536 |
+
|
| 537 |
+
float KQ_rowsum_j;
|
| 538 |
+
if (std::is_same<KQ_acc_t, float>::value) {
|
| 539 |
+
KQ_rowsum_j = KQ_rowsum_f[j0/nwarps];
|
| 540 |
+
} else {
|
| 541 |
+
KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
|
| 542 |
+
}
|
| 543 |
+
|
| 544 |
+
#pragma unroll
|
| 545 |
+
for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
|
| 546 |
+
const int i = i0 + threadIdx.x;
|
| 547 |
+
if (i0 + WARP_SIZE > D && i >= D) {
|
| 548 |
+
break;
|
| 549 |
+
}
|
| 550 |
+
float dst_val = VKQ[j_VKQ*D_padded + i];
|
| 551 |
+
if (parallel_blocks == 1) {
|
| 552 |
+
dst_val /= KQ_rowsum_j;
|
| 553 |
+
}
|
| 554 |
+
dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val;
|
| 555 |
+
}
|
| 556 |
+
|
| 557 |
+
if (parallel_blocks == 1 || threadIdx.x != 0) {
|
| 558 |
+
continue;
|
| 559 |
+
}
|
| 560 |
+
|
| 561 |
+
float2 dst_meta_val;
|
| 562 |
+
if (std::is_same<KQ_acc_t, float>::value) {
|
| 563 |
+
dst_meta_val.x = KQ_max_f[j0/nwarps];
|
| 564 |
+
} else {
|
| 565 |
+
dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
|
| 566 |
+
}
|
| 567 |
+
dst_meta_val.y = KQ_rowsum_j;
|
| 568 |
+
dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val;
|
| 569 |
+
}
|
| 570 |
+
#else
|
| 571 |
+
NO_DEVICE_CODE;
|
| 572 |
+
#endif // FP16_MMA_AVAILABLE
|
| 573 |
+
}
|
| 574 |
+
|
| 575 |
+
template<int D, int parallel_blocks> // D == head size
|
| 576 |
+
__launch_bounds__(D, 1)
|
| 577 |
+
static __global__ void flash_attn_combine_results(
|
| 578 |
+
const float * __restrict__ VKQ_parts,
|
| 579 |
+
const float2 * __restrict__ VKQ_meta,
|
| 580 |
+
float * __restrict__ dst) {
|
| 581 |
+
#if FP16_AVAILABLE
|
| 582 |
+
VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x;
|
| 583 |
+
VKQ_meta += parallel_blocks * gridDim.y*blockIdx.x;
|
| 584 |
+
dst += D * gridDim.y*blockIdx.x;
|
| 585 |
+
|
| 586 |
+
const int tid = threadIdx.x;
|
| 587 |
+
__builtin_assume(tid < D);
|
| 588 |
+
|
| 589 |
+
__shared__ float2 meta[parallel_blocks];
|
| 590 |
+
if (tid < 2*parallel_blocks) {
|
| 591 |
+
((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.y*(2*parallel_blocks) + tid];
|
| 592 |
+
}
|
| 593 |
+
|
| 594 |
+
__syncthreads();
|
| 595 |
+
|
| 596 |
+
float kqmax = meta[0].x;
|
| 597 |
+
#pragma unroll
|
| 598 |
+
for (int l = 1; l < parallel_blocks; ++l) {
|
| 599 |
+
kqmax = max(kqmax, meta[l].x);
|
| 600 |
+
}
|
| 601 |
+
|
| 602 |
+
float VKQ_numerator = 0.0f;
|
| 603 |
+
float VKQ_denominator = 0.0f;
|
| 604 |
+
#pragma unroll
|
| 605 |
+
for (int l = 0; l < parallel_blocks; ++l) {
|
| 606 |
+
const float diff = meta[l].x - kqmax;
|
| 607 |
+
const float KQ_max_scale = expf(diff);
|
| 608 |
+
const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
|
| 609 |
+
*((uint32_t *) &KQ_max_scale) &= ftz_mask;
|
| 610 |
+
|
| 611 |
+
VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid];
|
| 612 |
+
VKQ_denominator += KQ_max_scale * meta[l].y;
|
| 613 |
+
}
|
| 614 |
+
|
| 615 |
+
dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
|
| 616 |
+
#else
|
| 617 |
+
NO_DEVICE_CODE;
|
| 618 |
+
#endif // FP16_AVAILABLE
|
| 619 |
+
}
|
| 620 |
+
|
| 621 |
+
constexpr int get_max_power_of_2(int x) {
|
| 622 |
+
return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1;
|
| 623 |
+
}
|
| 624 |
+
|
| 625 |
+
static_assert(get_max_power_of_2(1) == 1, "Test failed.");
|
| 626 |
+
static_assert(get_max_power_of_2(2) == 2, "Test failed.");
|
| 627 |
+
static_assert(get_max_power_of_2(4) == 4, "Test failed.");
|
| 628 |
+
static_assert(get_max_power_of_2(6) == 2, "Test failed.");
|
| 629 |
+
|
| 630 |
+
// Number of VKQ rows calculated in parallel:
|
| 631 |
+
constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) {
|
| 632 |
+
return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m;
|
| 633 |
+
}
|
| 634 |
+
|
| 635 |
+
static_assert(get_VKQ_stride(128, 1, 32) == 32, "Test failed.");
|
| 636 |
+
static_assert(get_VKQ_stride(128, 2, 32) == 64, "Test failed.");
|
| 637 |
+
static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed.");
|
| 638 |
+
static_assert(get_VKQ_stride( 64, 1, 32) == 32, "Test failed.");
|
| 639 |
+
static_assert(get_VKQ_stride( 64, 2, 32) == 64, "Test failed.");
|
| 640 |
+
static_assert(get_VKQ_stride( 64, 4, 32) == 64, "Test failed.");
|
| 641 |
+
static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed.");
|
| 642 |
+
static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed.");
|
| 643 |
+
static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");
|
| 644 |
+
|
| 645 |
+
template <int D, int parallel_blocks> void launch_fattn_vec_f16(
|
| 646 |
+
const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
|
| 647 |
+
ggml_cuda_pool & pool, cudaStream_t main_stream
|
| 648 |
+
) {
|
| 649 |
+
ggml_cuda_pool_alloc<float> dst_tmp(pool);
|
| 650 |
+
ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
|
| 651 |
+
|
| 652 |
+
if (parallel_blocks > 1) {
|
| 653 |
+
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
|
| 654 |
+
dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
|
| 655 |
+
}
|
| 656 |
+
|
| 657 |
+
constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE;
|
| 658 |
+
const dim3 block_dim(WARP_SIZE, nwarps, 1);
|
| 659 |
+
const dim3 blocks_num(parallel_blocks*Q->ne[1], Q->ne[2], Q->ne[3]);
|
| 660 |
+
const int shmem = 0;
|
| 661 |
+
|
| 662 |
+
float scale;
|
| 663 |
+
memcpy(&scale, KQV->op_params, sizeof(float));
|
| 664 |
+
|
| 665 |
+
flash_attn_vec_ext_f16<D, parallel_blocks>
|
| 666 |
+
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
| 667 |
+
(const char *) Q->data,
|
| 668 |
+
(const char *) K->data,
|
| 669 |
+
(const char *) V->data,
|
| 670 |
+
mask ? ((const char *) mask->data) : nullptr,
|
| 671 |
+
parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
|
| 672 |
+
scale,
|
| 673 |
+
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
| 674 |
+
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
| 675 |
+
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
| 676 |
+
Q->nb[1], Q->nb[2], Q->nb[3],
|
| 677 |
+
K->nb[1], K->nb[2], K->nb[3],
|
| 678 |
+
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
| 679 |
+
);
|
| 680 |
+
CUDA_CHECK(cudaGetLastError());
|
| 681 |
+
|
| 682 |
+
if (parallel_blocks == 1) {
|
| 683 |
+
return;
|
| 684 |
+
}
|
| 685 |
+
|
| 686 |
+
const dim3 block_dim_combine(D, 1, 1);
|
| 687 |
+
const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
|
| 688 |
+
const int shmem_combine = 0;
|
| 689 |
+
|
| 690 |
+
flash_attn_combine_results<D, parallel_blocks>
|
| 691 |
+
<<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
|
| 692 |
+
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
|
| 693 |
+
CUDA_CHECK(cudaGetLastError());
|
| 694 |
+
}
|
| 695 |
+
|
| 696 |
+
template <int D, int cols_per_block, int nwarps, int parallel_blocks, typename KQ_acc_t> void launch_fattn_f16_impl(
|
| 697 |
+
const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
|
| 698 |
+
ggml_cuda_pool & pool, cudaStream_t main_stream
|
| 699 |
+
) {
|
| 700 |
+
ggml_cuda_pool_alloc<float> dst_tmp(pool);
|
| 701 |
+
ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
|
| 702 |
+
|
| 703 |
+
if (parallel_blocks > 1) {
|
| 704 |
+
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
|
| 705 |
+
dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
|
| 706 |
+
}
|
| 707 |
+
|
| 708 |
+
constexpr int frag_m = (cols_per_block) == 8 && (D) % 32 == 0 ? 32 : 16;
|
| 709 |
+
const dim3 block_dim(WARP_SIZE, nwarps, 1);
|
| 710 |
+
const dim3 blocks_num(parallel_blocks*(Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]);
|
| 711 |
+
const int shmem = 0;
|
| 712 |
+
|
| 713 |
+
float scale;
|
| 714 |
+
memcpy(&scale, KQV->op_params, sizeof(float));
|
| 715 |
+
|
| 716 |
+
flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>
|
| 717 |
+
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
| 718 |
+
(const char *) Q->data,
|
| 719 |
+
(const char *) K->data,
|
| 720 |
+
(const char *) V->data,
|
| 721 |
+
mask ? ((const char *) mask->data) : nullptr,
|
| 722 |
+
(parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
|
| 723 |
+
scale,
|
| 724 |
+
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
| 725 |
+
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
| 726 |
+
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
| 727 |
+
Q->nb[1], Q->nb[2], Q->nb[3],
|
| 728 |
+
K->nb[1], K->nb[2], K->nb[3],
|
| 729 |
+
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
| 730 |
+
);
|
| 731 |
+
CUDA_CHECK(cudaGetLastError());
|
| 732 |
+
|
| 733 |
+
if ((parallel_blocks) == 1) {
|
| 734 |
+
return;
|
| 735 |
+
}
|
| 736 |
+
|
| 737 |
+
const dim3 block_dim_combine(D, 1, 1);
|
| 738 |
+
const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
|
| 739 |
+
const int shmem_combine = 0;
|
| 740 |
+
|
| 741 |
+
flash_attn_combine_results<D, parallel_blocks>
|
| 742 |
+
<<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
|
| 743 |
+
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
|
| 744 |
+
CUDA_CHECK(cudaGetLastError());
|
| 745 |
+
}
|
| 746 |
+
|
| 747 |
+
template <int D, int cols_per_block, int nwarps, typename KQ_acc_t> void launch_fattn_f16(
|
| 748 |
+
const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
|
| 749 |
+
const int nsm, ggml_cuda_pool & pool, cudaStream_t main_stream
|
| 750 |
+
) {
|
| 751 |
+
const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
|
| 752 |
+
|
| 753 |
+
if (4*blocks_num_pb1 < 2*nsm) {
|
| 754 |
+
launch_fattn_f16_impl<D, cols_per_block, nwarps, 4, KQ_acc_t>(Q, K, V, KQV, mask, pool, main_stream);
|
| 755 |
+
return;
|
| 756 |
+
}
|
| 757 |
+
if (2*blocks_num_pb1 < 2*nsm) {
|
| 758 |
+
launch_fattn_f16_impl<D, cols_per_block, nwarps, 2, KQ_acc_t>(Q, K, V, KQV, mask, pool, main_stream);
|
| 759 |
+
return;
|
| 760 |
+
}
|
| 761 |
+
launch_fattn_f16_impl<D, cols_per_block, nwarps, 1, KQ_acc_t>(Q, K, V, KQV, mask, pool, main_stream);
|
| 762 |
+
}
|
| 763 |
+
|
| 764 |
+
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
| 765 |
+
const ggml_tensor * Q = dst->src[0];
|
| 766 |
+
const ggml_tensor * K = dst->src[1];
|
| 767 |
+
const ggml_tensor * V = dst->src[2];
|
| 768 |
+
|
| 769 |
+
const ggml_tensor * mask = dst->src[3];
|
| 770 |
+
|
| 771 |
+
ggml_tensor * KQV = dst;
|
| 772 |
+
|
| 773 |
+
GGML_ASSERT(Q->type == GGML_TYPE_F32);
|
| 774 |
+
GGML_ASSERT(K->type == GGML_TYPE_F16);
|
| 775 |
+
GGML_ASSERT(V->type == GGML_TYPE_F16);
|
| 776 |
+
GGML_ASSERT(KQV->type == GGML_TYPE_F32);
|
| 777 |
+
|
| 778 |
+
GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
|
| 779 |
+
GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
|
| 780 |
+
"the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
|
| 781 |
+
|
| 782 |
+
GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
|
| 783 |
+
|
| 784 |
+
ggml_cuda_set_device(ctx.device);
|
| 785 |
+
|
| 786 |
+
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
|
| 787 |
+
|
| 788 |
+
const int32_t precision = KQV->op_params[1];
|
| 789 |
+
|
| 790 |
+
if (precision != GGML_PREC_DEFAULT) {
|
| 791 |
+
if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
|
| 792 |
+
constexpr int cols_per_block = 16;
|
| 793 |
+
constexpr int nwarps = 4;
|
| 794 |
+
switch (Q->ne[0]) {
|
| 795 |
+
case 64:
|
| 796 |
+
+            launch_fattn_f16< 64, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            break;
+        case 80:
+            launch_fattn_f16< 80, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            break;
+        case 96:
+            launch_fattn_f16< 96, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            break;
+        case 112:
+            launch_fattn_f16<112, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            break;
+        case 128:
+            launch_fattn_f16<128, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            break;
+        case 256:
+            launch_fattn_f16<256, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            break;
+        default:
+            GGML_ASSERT(false);
+            break;
+        }
+    } else {
+        constexpr int cols_per_block = 32;
+        constexpr int nwarps = 4;
+        switch (Q->ne[0]) {
+        case 64:
+            launch_fattn_f16< 64, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            break;
+        case 80:
+            launch_fattn_f16< 80, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            break;
+        case 96:
+            launch_fattn_f16< 96, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            break;
+        case 112:
+            launch_fattn_f16<112, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            break;
+        case 128:
+            launch_fattn_f16<128, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            break;
+        // case 256:
+        //     launch_fattn_f16<256, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+        //     break;
+        default:
+            GGML_ASSERT(false);
+            break;
+        }
+    }
+    return;
+}
+
+if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
+    constexpr int parallel_blocks = 4;
+    switch (Q->ne[0]) {
+    case 64:
+        launch_fattn_vec_f16< 64, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+        break;
+    case 128:
+        launch_fattn_vec_f16<128, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+        break;
+    case 256:
+        launch_fattn_vec_f16<256, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+        break;
+    default:
+        GGML_ASSERT(false);
+        break;
+    }
+    return;
+}
+
+if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
+    constexpr int cols_per_block = 8;
+    constexpr int nwarps = 4;
+    switch (Q->ne[0]) {
+    case 64:
+        launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+        break;
+    case 96:
+        launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+        break;
+    case 128:
+        launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+        break;
+    case 256:
+        launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+        break;
+    default:
+        GGML_ASSERT(false);
+        break;
+    }
+    return;
+}
+
+if (Q->ne[1] <= 32) {
+    constexpr int cols_per_block = 16;
+    constexpr int nwarps = 4;
+    switch (Q->ne[0]) {
+    case 64:
+        launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+        break;
+    case 80:
+        launch_fattn_f16< 80, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+        break;
+    case 96:
+        launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+        break;
+    case 112:
+        launch_fattn_f16<112, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+        break;
+    case 128:
+        launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+        break;
+    case 256:
+        launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+        break;
+    default:
+        GGML_ASSERT(false);
+        break;
+    }
+    return;
+}
+
+constexpr int cols_per_block = 32;
+constexpr int nwarps = 4;
+switch (Q->ne[0]) {
+case 64:
+    launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+    break;
+case 80:
+    launch_fattn_f16< 80, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+    break;
+case 96:
+    launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+    break;
+case 112:
+    launch_fattn_f16<112, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+    break;
+case 128:
+    launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+    break;
+case 256:
+    launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+    break;
+default:
+    GGML_ASSERT(false);
+    break;
+}
+return;
+}
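The switch ladders above exist because each head size is a separate template instantiation: the head size is only known at runtime, so the dispatcher maps it onto one of the precompiled kernels and asserts on anything else. Below is a minimal standalone sketch of that pattern; the names run_kernel and dispatch_head_size are illustrative and not taken from the diff, and only the head-size axis is shown, whereas the real code also templates on cols_per_block, nwarps and the accumulator type.

#include <cstdint>
#include <cstdio>
#include <cstdlib>

// Hypothetical stand-in for a kernel that is specialized at compile time on the head size D.
template <int D>
static void run_kernel(/* Q, K, V, ... */) {
    std::printf("running kernel specialized for head size %d\n", D);
}

// Map a runtime head size onto one of the compile-time specializations, mirroring the
// switch statements above. Unsupported sizes abort, like GGML_ASSERT(false) does.
static void dispatch_head_size(int64_t head_size) {
    switch (head_size) {
        case  64: run_kernel< 64>(); break;
        case  80: run_kernel< 80>(); break;
        case  96: run_kernel< 96>(); break;
        case 112: run_kernel<112>(); break;
        case 128: run_kernel<128>(); break;
        case 256: run_kernel<256>(); break;
        default:
            std::fprintf(stderr, "unsupported head size: %lld\n", (long long) head_size);
            std::abort();
    }
}

int main() {
    dispatch_head_size(128); // e.g. LLaMA-style models use head size 128
    return 0;
}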
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
@@ -1,7 +1,17 @@
 #include "softmax.cuh"

-template <
-static
+template <typename T>
+static __device__ __forceinline__ float t2f32(T val) {
+    return (float) val;
+}
+
+template <>
+__device__ float __forceinline__ t2f32<half>(half val) {
+    return __half2float(val);
+}
+
+template <bool vals_smem, int ncols_template, int block_size_template, typename T>
+static __global__ void soft_max_f32(const float * x, const T * mask, const T * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
     const int ncols = ncols_template == 0 ? ncols_par : ncols_template;

     const int tid = threadIdx.x;
@@ -43,7 +53,7 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f
         const int64_t ix = (int64_t)rowx*ncols + col;
         const int64_t iy = (int64_t)rowy*ncols + col;

-        const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + (pos ? slope*pos[col] : 0.0f);
+        const float val = x[ix]*scale + (mask ? t2f32(mask[iy]) : 0.0f) + (pos ? slope*t2f32(pos[col]) : 0.0f);

         vals[col] = val;
         max_val = max(max_val, val);
@@ -114,7 +124,8 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f
     }
 }

+template<typename T>
+static void soft_max_f32_cuda(const float * x, const T * mask, const T * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
     int nth = WARP_SIZE;
     while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
     const dim3 block_dims(nth, 1, 1);
@@ -167,15 +178,19 @@ static void soft_max_f32_cuda(const float * x, const float * mask, const float *
 void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
+
     const float * src0_d = (const float *)src0->data;
+    const void  * src1_d = src1 ? (const void *)src1->data : nullptr;
+
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();

     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);

-    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+    GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional

     const int64_t ne00    = src0->ne[0];
     const int64_t nrows_x = ggml_nrows(src0);
@@ -188,14 +203,25 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));

     // positions tensor
+    void * src2_d = nullptr;

-    ggml_tensor * src2 = dst->src[2];
     const bool use_src2 = src2 != nullptr;

     if (use_src2) {
+        src2_d = (void *)src2->data;
     }

+    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
+
+    if (use_f16) {
+        const half * src1_dd = (const half *)src1_d;
+        const half * src2_dd = (const half *)src2_d;
+
+        soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+    } else {
+        const float * src1_dd = (const float *)src1_d;
+        const float * src2_dd = (const float *)src2_d;
+
+        soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+    }
 }
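The templated soft_max_f32 above takes max_bias, m0, m1 and n_head_log2 and adds slope*pos[col] to the logits, which is the ALiBi-style position bias; the Metal kernels keep an analogous per-head slope. Below is a small host-side sketch of the slope computation these parameters describe. It mirrors the common ggml formulation; the function name compute_alibi_slope and the example head counts are illustrative, not taken from the diff.

#include <cmath>
#include <cstdint>
#include <cstdio>

// Per-head ALiBi slope as parameterized by max_bias, m0, m1 and n_head_log2.
// Treat this as a sketch of the parameterization, not the exact kernel code.
static float compute_alibi_slope(float max_bias, uint32_t n_head, uint32_t head) {
    if (max_bias <= 0.0f) {
        return 0.0f; // no positional bias requested
    }

    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));

    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

    // the first n_head_log2 heads follow the m0 geometric series, the rest follow m1
    return head < n_head_log2 ? powf(m0, head + 1) : powf(m1, 2*(head - n_head_log2) + 1);
}

int main() {
    for (uint32_t h = 0; h < 8; ++h) {
        std::printf("head %u: slope %.5f\n", h, compute_alibi_slope(8.0f, 8, h));
    }
    return 0;
}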
@@ -1427,6 +1427,7 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
         for (int i = node_start; i < node_end; ++i) {
             struct ggml_tensor * src0 = gf->nodes[i]->src[0];
             struct ggml_tensor * src1 = gf->nodes[i]->src[1];
+            struct ggml_tensor * src2 = gf->nodes[i]->src[2]; GGML_UNUSED(src2);
             struct ggml_tensor * dst = gf->nodes[i];
             GGML_ASSERT(dst->data != nullptr);

@@ -1559,6 +1560,12 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
                 {
                     float scale;
                     memcpy(&scale, dst->op_params, sizeof(float));
+
+#pragma message("TODO: add ggml_vk_soft_max() F16/F32 src1 and src2 support")
+#pragma message("ref:  https://github.com/ggerganov/llama.cpp/pull/5021")
+                    GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32);
+                    GGML_ASSERT(src2 == nullptr);
+
                     ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale);
                 } break;
             case GGML_OP_DIAG_MASK_INF:
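The memcpy(&scale, dst->op_params, sizeof(float)) pattern used here (and again in the Metal path below) bit-copies a float parameter back out of the tensor's op_params words rather than casting through a pointer. A small self-contained sketch of that round trip follows; the 16-word array and the example scale value are illustrative stand-ins, assuming only that op_params is a small int32 array on the tensor.

#include <cstdint>
#include <cstring>
#include <cstdio>

int main() {
    int32_t op_params[16] = {0};      // stand-in for ggml_tensor::op_params
    const float scale_in = 0.088388f; // e.g. 1/sqrt(head_size) for head size 128

    // host side: how the graph-building code stores the float into the op_params words
    std::memcpy(&op_params[0], &scale_in, sizeof(float));

    // backend side: what the dispatch code above does to read it back
    float scale = 0.0f;
    std::memcpy(&scale, &op_params[0], sizeof(float));

    std::printf("scale = %f\n", scale);
    return 0;
}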
@@ -47,8 +47,10 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
     GGML_METAL_KERNEL_TYPE_SILU,
     GGML_METAL_KERNEL_TYPE_SILU_4,
-    GGML_METAL_KERNEL_TYPE_SOFT_MAX,
-    GGML_METAL_KERNEL_TYPE_SOFT_MAX_4,
     GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
     GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
@@ -178,6 +180,14 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
     GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
     GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
@@ -444,7 +454,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
     }

     /*
-        GGML_METAL_LOG_INFO("%s: loaded %-
             (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
             (int) kernel->pipeline.threadExecutionWidth); \
     */
@@ -460,173 +470,183 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
                 return NULL; \
             } \
         } else { \
-            GGML_METAL_LOG_WARN("%s: skipping %-
         }

        // simd_sum and simd_max requires MTLGPUFamilyApple7

-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD,
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW,
-        …
    }

    [metal_library release];
@@ -746,6 +766,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_ARGSORT:
         case GGML_OP_LEAKY_RELU:
             return true;
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
@@ -1341,20 +1362,33 @@ static enum ggml_status ggml_metal_graph_compute(
                 } break;
             case GGML_OP_SOFT_MAX:
                 {
                     int nth = 32; // SIMD width

                     id<MTLComputePipelineState> pipeline = nil;

                     if (ne00%4 == 0) {
                         while (nth < ne00/4 && nth < 256) {
                             nth *= 2;
                         }
-                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_4].pipeline;
                     } else {
                         while (nth < ne00 && nth < 1024) {
                             nth *= 2;
                         }
-                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline;
                     }

                     float scale;
@@ -2518,6 +2552,161 @@ static enum ggml_status ggml_metal_graph_compute(

                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
             case GGML_OP_DUP:
             case GGML_OP_CPY:
             case GGML_OP_CONT:
@@ -2721,10 +2910,13 @@ GGML_CALL static const char * ggml_backend_metal_buffer_type_get_name(ggml_backe
     UNUSED(buft);
 }

-static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device) {
 #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
     if (@available(macOS 10.12, iOS 16.0, *)) {
-        GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
                 device.currentAllocatedSize / 1024.0 / 1024.0,
                 device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);

@@ -2734,10 +2926,15 @@ static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device) {
             GGML_METAL_LOG_INFO("\n");
         }
     } else {
-        GGML_METAL_LOG_INFO(", (%8.2f)\n",
     }
 #endif
     UNUSED(device);
 }

 GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
@@ -2771,8 +2968,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buff
         return NULL;
     }

-    ggml_backend_metal_log_allocated_size(device);

     return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size);
 }
@@ -2859,7 +3055,7 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data,
             return false;
         }

         ++ctx->n_buffers;
     } else {
@@ -2882,7 +3078,8 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data,
                 return false;
             }

             if (i + size_step < size) {
                 GGML_METAL_LOG_INFO("\n");
             }
@@ -2891,8 +3088,6 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data,
         }
     }

-    ggml_backend_metal_log_allocated_size(device);
-
     return ggml_backend_buffer_init(ggml_backend_metal_buffer_type(), ggml_backend_metal_buffer_i, ctx, size);
 }
     GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
     GGML_METAL_KERNEL_TYPE_SILU,
     GGML_METAL_KERNEL_TYPE_SILU_4,
+    GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
+    GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
+    GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
+    GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4,
     GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
     GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,

     GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
     GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
     GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,

     }

     /*
+        GGML_METAL_LOG_INFO("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
                 (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
                 (int) kernel->pipeline.threadExecutionWidth); \
     */

                 return NULL; \
             } \
         } else { \
+            GGML_METAL_LOG_WARN("%s: skipping %-40s (not supported)\n", __func__, "kernel_"#name); \
         }

        // simd_sum and simd_max requires MTLGPUFamilyApple7

|
| 478 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
|
| 479 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
|
| 480 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true);
|
| 481 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
|
| 482 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
|
| 483 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
|
| 484 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
|
| 485 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
|
| 486 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
|
| 487 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
|
| 488 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
|
| 489 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
|
| 490 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
|
| 491 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
|
| 492 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
|
| 493 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
|
| 494 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
|
| 495 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
|
| 496 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, ctx->support_simdgroup_reduction);
|
| 497 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, ctx->support_simdgroup_reduction);
|
| 498 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, ctx->support_simdgroup_reduction);
|
| 499 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, ctx->support_simdgroup_reduction);
|
| 500 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
|
| 501 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
|
| 502 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
|
| 503 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true);
|
| 504 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true);
|
| 505 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true);
|
| 506 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
|
| 507 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true);
|
| 508 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true);
|
| 509 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true);
|
| 510 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true);
|
| 511 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true);
|
| 512 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true);
|
| 513 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true);
|
| 514 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
|
| 515 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
|
| 516 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
|
| 517 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true);
|
| 518 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true);
|
| 519 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
|
| 520 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true);
|
| 521 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
|
| 522 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
|
| 523 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
|
| 524 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
|
| 525 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
|
| 526 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
|
| 527 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction);
|
| 528 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction);
|
| 529 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction);
|
| 530 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction);
|
| 531 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction);
|
| 532 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction);
|
| 533 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction);
|
| 534 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction);
|
| 535 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction);
|
| 536 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction);
|
| 537 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction);
|
| 538 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction);
|
| 539 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction);
|
| 540 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction);
|
| 541 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction);
|
| 542 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
|
| 543 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
|
| 544 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction);
|
| 545 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, ctx->support_simdgroup_reduction);
|
| 546 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction);
|
| 547 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction);
|
| 548 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, ctx->support_simdgroup_reduction);
|
| 549 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
|
| 550 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction);
|
| 551 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
|
| 552 |
+
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
|
| 553 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
|
| 554 |
+
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction);
|
| 555 |
+
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction);
|
| 556 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction);
|
| 557 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction);
|
| 558 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction);
|
| 559 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction);
|
| 560 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction);
|
| 561 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction);
|
| 562 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction);
|
| 563 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction);
|
| 564 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction);
|
| 565 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction);
|
| 566 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
|
| 567 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
|
| 568 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction);
|
| 569 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, ctx->support_simdgroup_reduction);
|
| 570 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction);
|
| 571 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction);
|
| 572 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, ctx->support_simdgroup_reduction);
|
| 573 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
|
| 574 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction);
|
| 575 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
|
| 576 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
|
| 577 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
|
| 578 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm);
|
| 579 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm);
|
| 580 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm);
|
| 581 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm);
|
| 582 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm);
|
| 583 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm);
|
| 584 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm);
|
| 585 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm);
|
| 586 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm);
|
| 587 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
|
| 588 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
|
| 589 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm);
|
| 590 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, ctx->support_simdgroup_mm);
|
| 591 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm);
|
| 592 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm);
|
| 593 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, ctx->support_simdgroup_mm);
|
| 594 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
|
| 595 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm);
|
| 596 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
|
| 597 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
|
| 598 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
|
| 599 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm);
|
| 600 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm);
|
| 601 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm);
|
| 602 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm);
|
| 603 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm);
|
| 604 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm);
|
| 605 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm);
|
| 606 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm);
|
| 607 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm);
|
| 608 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
|
| 609 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
|
| 610 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm);
|
| 611 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, ctx->support_simdgroup_mm);
|
| 612 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm);
|
| 613 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm);
|
| 614 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm);
|
| 615 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
|
| 616 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
|
| 617 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
|
| 618 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
|
| 619 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
|
| 620 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
|
| 621 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
|
| 622 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
|
| 623 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
|
| 624 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
|
| 625 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true);
|
| 626 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
|
| 627 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
|
| 628 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
|
| 629 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true);
|
| 630 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true);
|
| 631 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true);
|
| 632 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true);
|
| 633 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true);
|
| 634 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true);
|
| 635 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, true);
|
| 636 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, true);
|
| 637 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
|
| 638 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
| 639 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
|
| 640 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
|
| 641 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
|
| 642 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
|
| 643 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
|
| 644 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
|
| 645 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
|
| 646 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
|
| 647 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
|
| 648 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
|
| 649 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
| 650 |
}
|
| 651 |
|
| 652 |
[metal_library release];
|
|
|
|
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_ARGSORT:
         case GGML_OP_LEAKY_RELU:
+        case GGML_OP_FLASH_ATTN_EXT:
             return true;
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
                 } break;
             case GGML_OP_SOFT_MAX:
                 {
+                    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
+                    GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32);
+
                     int nth = 32; // SIMD width

                     id<MTLComputePipelineState> pipeline = nil;

+                    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
+
                     if (ne00%4 == 0) {
                         while (nth < ne00/4 && nth < 256) {
                             nth *= 2;
                         }
+                        if (use_f16) {
+                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4].pipeline;
+                        } else {
+                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline;
+                        }
                     } else {
                         while (nth < ne00 && nth < 1024) {
                             nth *= 2;
                         }
+                        if (use_f16) {
+                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16].pipeline;
+                        } else {
+                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32].pipeline;
+                        }
                     }

                     float scale;
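The nth-doubling loops above pick the threadgroup width: start at the SIMD width and double until the row (or a vectorized quarter of it) is covered, without exceeding the kernel's cap. A small standalone sketch of that rounding rule follows; the function name threads_for_row and the example row lengths are illustrative.

#include <cstdint>
#include <cstdio>

// Round a per-row thread count up in powers of two, starting from the SIMD width and
// staying under a hard cap, mirroring the nth-doubling loops in the dispatch code above.
static int threads_for_row(int64_t row_len, int simd_width, int max_threads) {
    int nth = simd_width;
    while (nth < row_len && nth < max_threads) {
        nth *= 2;
    }
    return nth;
}

int main() {
    // e.g. ne00 = 4096 with the vectorized (float4) kernel: cover ne00/4 = 1024 elements, cap at 256 threads
    std::printf("nth = %d\n", threads_for_row(4096/4, 32, 256));  // -> 256
    // e.g. ne00 = 80 with the scalar kernel, cap at 1024 threads
    std::printf("nth = %d\n", threads_for_row(80, 32, 1024));     // -> 128
    return 0;
}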

                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
+            case GGML_OP_FLASH_ATTN_EXT:
+                {
+                    GGML_ASSERT(ne00 % 4 == 0);
+                    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+                    struct ggml_tensor * src3 = gf->nodes[i]->src[3];
+
+                    GGML_ASSERT(ggml_are_same_shape(src1, src2));
+                    GGML_ASSERT(src3);
+
+                    size_t offs_src3 = 0;
+
+                    id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
+
+                    GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16);
+                    GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
+                            "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
+
+                    const int64_t  ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
+                    const int64_t  ne31 = src3 ? src3->ne[1] : 0;
+                    const int64_t  ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
+                    const int64_t  ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33);
+
+                    const uint64_t nb30 = src3 ? src3->nb[0] : 0; GGML_UNUSED(nb30);
+                    const uint64_t nb31 = src3 ? src3->nb[1] : 0;
+                    const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32);
+                    const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33);
+
+                    const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
+
+                    float scale;
+                    memcpy(&scale, dst->op_params, sizeof(float));
+
+                    id<MTLComputePipelineState> pipeline = nil;
+
+                    bool use_vec_kernel = false;
+
+                    if (ne01 >= 4 || (ne00%128 != 0)) {
+                        switch (ne00) {
+                            case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
+                            case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
+                            case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
+                            case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
+                            case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
+                            case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
+                            default:
+                                {
+                                    GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                    GGML_METAL_LOG_ERROR("add template specialization for this size\n");
+                                    GGML_ASSERT(false && "add template specialization for this size");
+                                }
+                        }
+                    } else {
+                        use_vec_kernel = true;
+
+                        switch (ne00) {
+                            case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
+                            case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
+                            default:
+                                {
+                                    GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                    GGML_METAL_LOG_ERROR("add template specialization for this size\n");
+                                    GGML_ASSERT(false && "add template specialization for this size");
+                                }
+                        }
+                    }
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                    [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+                    [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:4];
+                    [encoder setBytes:&ne00  length:sizeof( int64_t) atIndex:5];
+                    [encoder setBytes:&ne01  length:sizeof( int64_t) atIndex:6];
+                    [encoder setBytes:&ne02  length:sizeof( int64_t) atIndex:7];
+                    [encoder setBytes:&ne03  length:sizeof( int64_t) atIndex:8];
+                    [encoder setBytes:&nb00  length:sizeof(uint64_t) atIndex:9];
+                    [encoder setBytes:&nb01  length:sizeof(uint64_t) atIndex:10];
+                    [encoder setBytes:&nb02  length:sizeof(uint64_t) atIndex:11];
+                    [encoder setBytes:&nb03  length:sizeof(uint64_t) atIndex:12];
+                    [encoder setBytes:&ne10  length:sizeof( int64_t) atIndex:13];
+                    [encoder setBytes:&ne11  length:sizeof( int64_t) atIndex:14];
+                    [encoder setBytes:&ne12  length:sizeof( int64_t) atIndex:15];
+                    [encoder setBytes:&ne13  length:sizeof( int64_t) atIndex:16];
+                    [encoder setBytes:&nb10  length:sizeof(uint64_t) atIndex:17];
+                    [encoder setBytes:&nb11  length:sizeof(uint64_t) atIndex:18];
+                    [encoder setBytes:&nb12  length:sizeof(uint64_t) atIndex:19];
+                    [encoder setBytes:&nb13  length:sizeof(uint64_t) atIndex:20];
+                    [encoder setBytes:&ne31  length:sizeof( int64_t) atIndex:21];
+                    [encoder setBytes:&nb31  length:sizeof(uint64_t) atIndex:22];
+                    [encoder setBytes:&ne0   length:sizeof( int64_t) atIndex:23];
+                    [encoder setBytes:&ne1   length:sizeof( int64_t) atIndex:24];
+                    [encoder setBytes:&ne2   length:sizeof( int64_t) atIndex:25];
+                    [encoder setBytes:&ne3   length:sizeof( int64_t) atIndex:26];
+                    [encoder setBytes:&scale length:sizeof(   float) atIndex:27];
+
+                    if (!use_vec_kernel) {
+                        // half8x8 kernel
+                        const int64_t nqptg = 8;  // queries per threadgroup    !! sync with kernel template arguments !!
+                        const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
+
+                        GGML_ASSERT(nqptg <= 32);
+                        GGML_ASSERT(nqptg  % 8  == 0);
+                        GGML_ASSERT(ncpsg  % 32 == 0);
+
+                        int64_t nsgmax = 2;
+
+                        while (true) {
+                            const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2);
+                            if (smem > ctx->device.maxThreadgroupMemoryLength) {
+                                break;
+                            }
+                            nsgmax *= 2;
+                        }
+                        nsgmax /= 2;
+
+                        // simdgroups per threadgroup (a.k.a. warps)
+                        const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
+
+                        const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2);
+
+                        //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
+                        GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
+
+                        [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
+
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+                    } else {
+                        // half1x4 kernel
+                        const int64_t nqptg = 1;  // queries per threadgroup    !! sync with kernel template arguments !!
+                        const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
+
+                        GGML_ASSERT(nqptg <= 32);
+                        GGML_ASSERT(nqptg  % 1  == 0);
+                        GGML_ASSERT(ncpsg  % 32 == 0);
+
+                        // simdgroups per threadgroup (a.k.a. warps)
+                        const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
+
+                        int64_t nsg = 1;
+                        while (nsg <= nsgt) {
+                            nsg *= 2;
+                        }
+                        nsg /= 2;
+
+                        const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);
+
+                        //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
+                        GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
+                        [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
+
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+                    }
+                } break;
             case GGML_OP_DUP:
             case GGML_OP_CPY:
             case GGML_OP_CONT:
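The threadgroup-memory size used above follows directly from the tile shape: nqptg query rows of the head vector plus 2*(ncpsg + nqptg) scratch values per simdgroup, all stored in half precision (hence the sizeof(float)/2 factor). A standalone sketch of that sizing rule for the 8-query kernel follows; the 32 KiB limit is an assumed stand-in for device.maxThreadgroupMemoryLength, and the function name fa_smem_bytes is illustrative.

#include <cstdint>
#include <cstdio>
#include <initializer_list>

// Threadgroup-memory estimate for the half8x8 flash-attention kernel, following the
// formula in the dispatch code above: nqptg query rows of the head vector plus
// 2*(ncpsg + nqptg) scratch values per simdgroup, all in half precision (2 bytes).
static size_t fa_smem_bytes(int64_t head_size, int64_t nqptg, int64_t ncpsg, int64_t nsg) {
    return (size_t) (nqptg*(head_size + 2*nsg*(ncpsg + nqptg))) * (sizeof(float)/2);
}

int main() {
    const int64_t nqptg = 8;            // queries per threadgroup
    const int64_t ncpsg = 32;           // cache values per simdgroup
    const size_t  max_tg_mem = 32*1024; // assumed device limit

    // grow the simdgroup count like the nsgmax loop above: double while the estimate still fits
    for (int64_t head_size : {64, 128, 256}) {
        int64_t nsgmax = 2;
        while (fa_smem_bytes(head_size, nqptg, ncpsg, nsgmax) <= max_tg_mem) {
            nsgmax *= 2;
        }
        nsgmax /= 2;
        std::printf("head size %3lld: nsgmax = %lld, smem = %zu bytes\n",
                    (long long) head_size, (long long) nsgmax,
                    fa_smem_bytes(head_size, nqptg, ncpsg, nsgmax));
    }
    return 0;
}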
     UNUSED(buft);
 }

+static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device, size_t size_aligned) {
+#ifndef GGML_METAL_NDEBUG
 #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
     if (@available(macOS 10.12, iOS 16.0, *)) {
+        GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f / %8.2f)",
+                __func__,
+                size_aligned / 1024.0 / 1024.0,
                 device.currentAllocatedSize / 1024.0 / 1024.0,
                 device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);

             GGML_METAL_LOG_INFO("\n");
         }
     } else {
+        GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f)\n",
+                __func__,
+                size_aligned / 1024.0 / 1024.0,
+                device.currentAllocatedSize / 1024.0 / 1024.0);
     }
+#endif
 #endif
     UNUSED(device);
+    UNUSED(size_aligned);
 }

 GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {

         return NULL;
     }

+    //ggml_backend_metal_log_allocated_size(device, size_aligned);

     return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size);
 }

             return false;
         }

+        ggml_backend_metal_log_allocated_size(device, size_aligned);

         ++ctx->n_buffers;
     } else {

                 return false;
             }

+            ggml_backend_metal_log_allocated_size(device, size_step_aligned);
+
             if (i + size_step < size) {
                 GGML_METAL_LOG_INFO("\n");
             }

         }
     }

     return ggml_backend_buffer_init(ggml_backend_metal_buffer_type(), ggml_backend_metal_buffer_i, ctx, size);
 }

@@ -359,11 +359,12 @@ kernel void kernel_sum_rows(
     dst_row[0] = row_sum;
 }

 kernel void kernel_soft_max(
-        device const
-        device const
-        device const
-        device
         constant   int64_t & ne00,
         constant   int64_t & ne01,
         constant   int64_t & ne02,
@@ -382,10 +383,10 @@ kernel void kernel_soft_max(
     const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
     const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);

-    device const float * psrc0 =
-    device const
-    device const
-    device       float * pdst  =

     float slope = 0.0f;

@@ -463,11 +464,12 @@ kernel void kernel_soft_max(
     }
 }

 kernel void kernel_soft_max_4(
-        device const
-        device const
-        device const
-        device
         constant   int64_t & ne00,
         constant   int64_t & ne01,
         constant   int64_t & ne02,
@@ -486,10 +488,10 @@ kernel void kernel_soft_max_4(
     const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
     const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);

-    device const float4 * psrc4 =
-    device const
-    device const
-    device       float4 * pdst4 =

     float slope = 0.0f;

@@ -506,7 +508,7 @@ kernel void kernel_soft_max_4(
     float4 lmax4 = -INFINITY;

     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f));
     }

     const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
@@ -532,7 +534,7 @@ kernel void kernel_soft_max_4(
     // parallel sum
     float4 lsum4 = 0.0f;
     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)) - max_val);
         lsum4 += exp_psrc4;
         pdst4[i00] = exp_psrc4;
     }
@@ -569,6 +571,14 @@ kernel void kernel_soft_max_4(
     }
 }

 kernel void kernel_diag_mask_inf(
         device const float * src0,
         device       float * dst,
@@ -2091,6 +2101,632 @@ kernel void kernel_leaky_relu_f32(
     dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
 }
|
|
 kernel void kernel_cpy_f16_f16(
         device const half * src0,
         device       half * dst,
     dst_row[0] = row_sum;
 }

+template<typename T>
 kernel void kernel_soft_max(
+        device const char * src0,
+        device const char * src1,
+        device const char * src2,
+        device       char * dst,
         constant int64_t & ne00,
         constant int64_t & ne01,
         constant int64_t & ne02,

     const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
     const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);

+    device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    device const T     * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
+    device const T     * ppos  = src2 != src0 ? (device const T *) src2             : nullptr;
+    device       float * pdst  = (device       float *) dst  + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);

     float slope = 0.0f;

     }
 }

+template<typename T>
 kernel void kernel_soft_max_4(
+        device const char * src0,
+        device const char * src1,
+        device const char * src2,
+        device       char * dst,
         constant int64_t & ne00,
         constant int64_t & ne01,
         constant int64_t & ne02,

     const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
     const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);

+    device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
+    device const T      * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr;
+    device const T      * ppos  = src2 != src0 ? (device const T *) src2               : nullptr;
+    device       float4 * pdst4 = (device       float4 *) dst  + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;

     float slope = 0.0f;

     float4 lmax4 = -INFINITY;

     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+        lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)));
     }

     const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));

     // parallel sum
     float4 lsum4 = 0.0f;
     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+        const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f))) - max_val);
         lsum4 += exp_psrc4;
         pdst4[i00] = exp_psrc4;
     }

     }
 }

+typedef decltype(kernel_soft_max<float>)    kernel_soft_max_t;
+typedef decltype(kernel_soft_max_4<float4>) kernel_soft_max_4_t;
+
+template [[host_name("kernel_soft_max_f16")]]   kernel kernel_soft_max_t   kernel_soft_max<half>;
+template [[host_name("kernel_soft_max_f32")]]   kernel kernel_soft_max_t   kernel_soft_max<float>;
+template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<half4>;
+template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<float4>;
+
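In math terms, both template instantiations above compute the same masked, ALiBi-biased softmax per row; with the names used in the kernels (psrc0 = x, pmask = m, ppos = p), a sketch of the intent is:

\[
z_i = \mathrm{scale}\cdot x_i + m_i + \mathrm{slope}\cdot p_i,
\qquad
y_i = \frac{e^{\,z_i - \max_j z_j}}{\sum_k e^{\,z_k - \max_j z_j}}
\]

with \(m_i = 0\) when no mask is given and \(\mathrm{slope}\cdot p_i = 0\) when no positions are given; the template parameter T only changes whether m and p are read as half or float.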
 kernel void kernel_diag_mask_inf(
         device const float * src0,
         device       float * dst,

     dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
 }

+typedef void (flash_attn_ext_f16_t)(
+        device const  char * q,
+        device const  char * k,
+        device const  char * v,
+        device const  char * mask,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant   int64_t & ne13,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
+        constant   int64_t & ne31,
+        constant  uint64_t & nb31,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant     float & scale,
+        threadgroup   half * shared,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        uint3  tpitg[[thread_position_in_threadgroup]],
+        uint3    ntg[[threads_per_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]);
+
+// ref: https://arxiv.org/pdf/2307.08691.pdf
+template<int64_t D, int64_t Q = 8, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
+kernel void kernel_flash_attn_ext_f16(
+        device const  char * q,
+        device const  char * k,
+        device const  char * v,
+        device const  char * mask,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant   int64_t & ne13,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
+        constant   int64_t & ne31,
+        constant  uint64_t & nb31,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant     float & scale,
+        threadgroup   half * shared [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        uint3  tpitg[[thread_position_in_threadgroup]],
+        uint3    ntg[[threads_per_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    const short nsg = ntg.y; // number of simdgroups
+
+    const short iq3 = tgpig[2];
+    const short iq2 = tgpig[1];
+    const short iq1 = tgpig[0]*Q;
+
+    const short D4 = D/4;
+    const short D8 = D/8;
+    const short Q8 = Q/8;
+    const short NW = N_SIMDWIDTH;
+    const short SH = (C + Q); // shared memory per simdgroup in (half)
+
+    const short T  = D + 2*nsg*SH; // shared memory size per query in (half)
+    const short TF = T/2;          // shared memory size per query in (float)
+    const short T4 = T/4;          // shared memory size per query in (half4)
+
+    threadgroup half  * sq  = (threadgroup half  *) (shared +              0*D); // holds the query data
+    threadgroup half4 * sq4 = (threadgroup half4 *) (shared +              0*D); // same as above but in half4
+    threadgroup float * ss  = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
+
+    // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
+    simdgroup_half8x8 lo[D8];
+
+    // load heads from Q to shared memory
+    for (short j = sgitg; j < Q; j += nsg) {
+        device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
+
+        for (short i = tiisg; i < D4; i += NW) {
+            if (iq1 + j < ne01) {
+                sq4[j*T4 + i] = (half4) q4[i];
+            } else {
+                sq4[j*T4 + i] = 0.0h;
+            }
+        }
+    }
+
+    // zero out lo
+    for (short i = 0; i < D8; ++i) {
+        lo[i] = make_filled_simdgroup_matrix<half, 8>(0.0h);
+    }
+
+    // zero out shared memory SH
+    for (short j = 0; j < Q; ++j) {
+        for (short i = tiisg; i < SH; i += NW) {
+            ss[j*TF + i] = 0.0f;
+        }
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    {
+        float S[Q] = { [0 ... Q-1] = 0.0h };
+        float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 };
+
+        // assume K and V are same shape
+        const short ne22 = ne12;
+        const short ne23 = ne13;
+
+        const uint nb21 = nb11;
+        const uint nb22 = nb12;
+        const uint nb23 = nb13;
+
+        // broadcast
+        const short rk2 = ne02/ne12;
+        const short rk3 = ne03/ne13;
+
+        const short rv2 = ne02/ne22;
+        const short rv3 = ne03/ne23;
+
+        // k indices
+        const short ik2 = iq2/rk2;
+        const short ik3 = iq3/rk3;
+
+        // v indices
+        const short iv2 = iq2/rv2;
+        const short iv3 = iq3/rv3;
+
+        // load the queries from shared memory into local memory
+        simdgroup_half8x8 mq[D8];
+
+        for (short i = 0; i < D8; ++i) {
+            simdgroup_load(mq[i], sq + i*8, T);
+        }
+
+        // pointer to the mask
+        device const half * mp = (device const half *) (mask + iq1*nb31);
+
+        // prepare diagonal scale matrix
+        simdgroup_float8x8 mscale(scale);
+
+        // loop over the KV cache
+        // each simdgroup handles blocks of Q rows and C columns
+        for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
+            const int ic = ic0 + C*sgitg;
+            if (ic >= ne11) {
+                break;
+            }
+
+            // Q*K^T
+            {
+                for (short cc = 0; cc < C/8; ++cc) {
+                    simdgroup_float8x8 mqk = make_filled_simdgroup_matrix<float, 8>(0.h);
+
+                    device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
+
+                    for (short i = 0; i < D8; ++i) {
+                        simdgroup_half8x8 mk;
+                        simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose
+
+                        simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
+                    }
+
+                    // mqk = mqk*scale + mask
+                    simdgroup_half8x8 mm;
+                    simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);
+                    simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
+
+                    simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
+                }
+            }
+
+            // used to detect blocks full of -INF
+            float smax = -INFINITY;
+
+            // online softmax
+            {
+                float ms[Q];
+
+                for (short j = 0; j < Q; ++j) {
+                    const short p = tiisg;
+
+                    const float m = M[j];
+                    const float s = ss[j*TF + p];
+
+                    smax = simd_max(max(smax, s));
+                    M[j] = simd_max(max(M[j], s));
+
+                    ms[j] = exp(m - M[j]);
+                    const float vs = exp(s - M[j]);
+
+                    S[j] = S[j]*ms[j] + simd_sum(vs);
+
+                    // the P matrix from the paper (Q rows, C columns)
+                    ss[j*TF + p] = vs;
+                }
+
+                // create a QxQ diagonal matrix for rescaling the output
+                if (tiisg < Q) {
+                    ss[tiisg*TF + C + tiisg] = ms[tiisg];
+                }
+            }
+
+            // skip -INF blocks
+            if (smax == -INFINITY) {
+                continue;
+            }
+
+            // O = diag(ms)*O
+            {
+                simdgroup_float8x8 mm;
+                simdgroup_load(mm, ss + C, TF, 0, false);
+
+                for (short i = 0; i < D8; ++i) {
+                    simdgroup_multiply(lo[i], mm, lo[i]);
+                }
+            }
+
+            // O = O + (Q*K^T)*V
+            {
+                for (short cc = 0; cc < C/8; ++cc) {
+                    device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
+
+                    for (short i = 0; i < D8; ++i) {
+                        simdgroup_half8x8 mk;
+                        simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false);
+
+                        simdgroup_float8x8 mv;
+                        simdgroup_load(mv, ss + 8*cc, TF, 0, false);
+
+                        simdgroup_multiply_accumulate(lo[i], mv, mk, lo[i]);
+                    }
+                }
+            }
+        }
+
+        // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
+        for (short j = 0; j < Q; ++j) {
+            if (tiisg == 0) {
+                ss[j*TF + 0] = S[j];
+                ss[j*TF + 1] = M[j];
+            }
+        }
+    }
+
+    // reduce the warps sequentially
+    for (short sg = 1; sg < nsg; ++sg) {
+        float S = { 0.0h };
+        float M = { -FLT_MAX/2 };
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // each simdgroup stores its output to shared memory, reusing sq
+        if (sgitg == sg) {
+            for (short i = 0; i < D8; ++i) {
+                simdgroup_store(lo[i], sq + i*8, T, 0, false);
+            }
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // the first simdgroup accumulates the results from the other simdgroups
+        if (sgitg == 0) {
+            for (short j = 0; j < Q; ++j) {
+                const float S0 = ss[j*TF +         0];
+                const float S1 = ss[j*TF + sg*SH + 0];
+
+                const float M0 = ss[j*TF +         1];
+                const float M1 = ss[j*TF + sg*SH + 1];
+
+                M = max(M0, M1);
+
+                const float ms0 = exp(M0 - M);
+                const float ms1 = exp(M1 - M);
+
+                S = S0*ms0 + S1*ms1;
+
+                if (tiisg == 0) {
+                    ss[j*TF + 0] = S;
+                    ss[j*TF + 1] = M;
+
+                    ss[j*TF + C + j        ] = ms0;
+                    ss[j*TF + C + j + sg*SH] = ms1;
+                }
+            }
+
+            // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
+            {
+                simdgroup_half8x8  t;
+                simdgroup_float8x8 ms0;
+                simdgroup_float8x8 ms1;
+
+                simdgroup_load(ms0, ss + C,         TF, 0, false);
+                simdgroup_load(ms1, ss + C + sg*SH, TF, 0, false);
+
+                for (short i = 0; i < D8; ++i) {
+                    simdgroup_load    (t, sq + i*8, T, 0, false);
+                    simdgroup_multiply(t, ms1, t);
+
+                    simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
+                }
+            }
+        }
+    }
+
+    // store result to shared memory (reuse sq)
+    if (sgitg == 0) {
+        for (short i = 0; i < D8; ++i) {
+            simdgroup_store(lo[i], sq + i*8, T, 0, false);
+        }
+    }
+
+    device float4 * dst4 = (device float4 *) dst;
+
+    // final rescale with 1/S and store to global memory
+    if (sgitg == 0) {
+        for (short j = 0; j < Q && iq1 + j < ne01; ++j) {
+            const float S = ss[j*TF + 0];
+
+            for (short i = tiisg; i < D4; i += NW) {
+                dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S;
+            }
+        }
+    }
+}
+
+template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>;
+template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>;
+template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>;
+template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>;
+template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>;
+template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
+
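The kernel above keeps one (M, S, O) triple per query row and updates it with the FlashAttention-2 style online softmax; as a sketch in the kernel's own names (the s_c are the scaled, masked Q·K^T scores of the current block):

\[
M' = \max\bigl(M,\ \max_c s_c\bigr), \qquad
m_s = e^{\,M - M'}, \qquad
S' = S\,m_s + \sum_c e^{\,s_c - M'}
\]
\[
O' = \mathrm{diag}(m_s)\,O + P\,V, \qquad P_c = e^{\,s_c - M'}
\]

The final output of a row is \(O/S\) (the "final rescale with 1/S" step); blocks whose scores are all -INF leave (M, S, O) unchanged and are skipped.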
+template<int64_t D, int64_t Q = 1, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
+kernel void kernel_flash_attn_ext_vec_f16(
+        device const  char * q,
+        device const  char * k,
+        device const  char * v,
+        device const  char * mask,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant   int64_t & ne13,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
+        constant   int64_t & ne31,
+        constant  uint64_t & nb31,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant     float & scale,
+        threadgroup   half * shared [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        uint3  tpitg[[thread_position_in_threadgroup]],
+        uint3    ntg[[threads_per_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    const short nsg = ntg.y; // number of simdgroups
+
+    const short iq3 = tgpig[2];
+    const short iq2 = tgpig[1];
+    const short iq1 = tgpig[0];
+
+    const short D4 = D/4;
+    const short NW = N_SIMDWIDTH;
+    const short SH = (C + Q); // shared memory per simdgroup in (half)
+
+    const short T = D + 2*nsg*SH; // shared memory size per query in (half)
+
+  //threadgroup half   * sq  = (threadgroup half   *) (shared +              0*D); // holds the query data
+    threadgroup half4  * sq4 = (threadgroup half4  *) (shared +              0*D); // same as above but in half4
+    threadgroup float  * ss  = (threadgroup float  *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
+    threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in half4
+    threadgroup half4  * sr4 = (threadgroup half4  *) (shared +   sgitg*D  + 1*T); // scratch buffer for the results
+
+    // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
+    half4 lo[D4/NW];
+
+    // load heads from Q to shared memory
+    device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
+
+    for (short i = tiisg; i < D4; i += NW) {
+        if (iq1 < ne01) {
+            sq4[i] = (half4) q4[i];
+        } else {
+            sq4[i] = 0.0h;
+        }
+    }
+
+    // zero out lo
+    for (short i = tiisg; i < D4; i += NW) {
+        lo[i/NW] = 0.0h;
+    }
+
+    // zero out shared memory SH
+    for (short i = tiisg; i < SH/4; i += NW) {
+        ss4[i] = 0.0h;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    {
+        float S = { 0.0h };
+        float M = { -FLT_MAX/2 };
+
+        // assume K and V are same shape
+        const short ne22 = ne12;
+        const short ne23 = ne13;
+
+        const uint nb21 = nb11;
+        const uint nb22 = nb12;
+        const uint nb23 = nb13;
+
+        // broadcast
+        const short rk2 = ne02/ne12;
+        const short rk3 = ne03/ne13;
+
+        const short rv2 = ne02/ne22;
+        const short rv3 = ne03/ne23;
+
+        // k indices
+        const short ik2 = iq2 / rk2;
+        const short ik3 = iq3 / rk3;
+
+        // v indices
+        const short iv2 = iq2 / rv2;
+        const short iv3 = iq3 / rv3;
+
+        // load the queries from shared memory into local memory
+        half4 mq[D4];
+
+        for (short ii = 0; ii < D4; ii += NW) {
+            short i = ii + tiisg;
+            mq[i] = sq4[i];
+        }
+
+        // pointer to the mask
+        device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31);
+
+        // loop over the KV cache
+        // each simdgroup handles blocks of Q rows and C columns
+        for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
+            const int ic = ic0 + C*sgitg;
+            if (ic >= ne11) {
+                break;
+            }
+
+            // Q*K^T
+            {
+#pragma unroll
+                for (short cc = 0; cc < C/4; ++cc) {
+                    float4 mqk = { 0.0h };
+
+                    device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13));
+
+#pragma unroll
+                    for (short ii = 0; ii < D4; ii += NW) {
+                        const short i = ii + tiisg;
+
+                        half4x4 mk;
+                        mk[0] = pk4[i + 0*(nb11/8)];
+                        mk[1] = pk4[i + 1*(nb11/8)];
+                        mk[2] = pk4[i + 2*(nb11/8)];
+                        mk[3] = pk4[i + 3*(nb11/8)];
+
+                        mqk += (float4) (mq[i] * mk);
+                    }
+
+                    // reduce the results from the threads in the simdgroup
+                    mqk += simd_shuffle_down(mqk, 16);
+                    mqk += simd_shuffle_down(mqk,  8);
+                    mqk += simd_shuffle_down(mqk,  4);
+                    mqk += simd_shuffle_down(mqk,  2);
+                    mqk += simd_shuffle_down(mqk,  1);
+
+                    // mqk = mqk*scale + mask
+                    if (tiisg == 0) {
+                        float4 mm = (float4) mp4[ic/4 + cc];
+                        mqk = mqk*scale + mm;
+
+                        ss4[cc] = mqk;
+                    }
+                }
+            }
+
+            // online softmax
+            {
+                const short p = tiisg;
+
+                const float m = M;
+                const float s = ss[p];
+
+                M = simd_max(max(M, s));
+
+                const float ms = exp(m - M);
+                const float vs = exp(s - M);
+
+                S = S*ms + simd_sum(vs);
+
+                // the P matrix from the paper (Q rows, C columns)
+                ss[p] = vs;
+
+                // O = diag(ms)*O
+#pragma unroll
+                for (short ii = 0; ii < D4; ii += NW) {
+                    const short i = ii + tiisg;
+                    lo[i/NW] *= ms;
+                }
+            }
+
+            // O = O + (Q*K^T)*V
+            {
+#pragma unroll
+                for (short cc = 0; cc < C/4; ++cc) {
+                    device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + 4*cc)*nb21 + iv2*nb22 + iv3*nb23));
+
+#pragma unroll
+                    for (short ii = 0; ii < D4; ii += NW) {
+                        const short i = ii + tiisg;
+
+                        lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0];
+                        lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1];
+                        lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2];
+                        lo[i/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3];
+                    }
+                }
+            }
+
+        }
+
+        // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
+        if (tiisg == 0) {
+            ss[0] = S;
+            ss[1] = M;
+        }
+    }
+
+    // store results to shared memory
+    for (short ii = 0; ii < D4; ii += NW) {
+        short i = ii + tiisg;
+        sr4[i] = lo[ii/NW];
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // parallel reduce
+    for (short r = nsg/2; r > 0; r >>= 1) {
+        if (sgitg < r) {
+            const float S0 = ss[       0];
+            const float S1 = ss[r*SH + 0];
+
+            const float M0 = ss[       1];
+            const float M1 = ss[r*SH + 1];
+
+            const float M = max(M0, M1);
+
+            const float ms0 = exp(M0 - M);
+            const float ms1 = exp(M1 - M);
+
+            const float S = S0*ms0 + S1*ms1;
+
+            if (tiisg == 0) {
+                ss[0] = S;
+                ss[1] = M;
+            }
+
+            // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
+            for (short ii = 0; ii < D4; ii += NW) {
+                short i = ii + tiisg;
+                sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1;
+            }
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    device float4 * dst4 = (device float4 *) dst;
+
+    // final rescale with 1/S and store to global memory
+    if (sgitg == 0) {
+        const float S = ss[0];
+
+        for (short ii = 0; ii < D4; ii += NW) {
+            short i = ii + tiisg;
+            dst4[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D4 + i] = (float4) sr4[i]/S;
+        }
+    }
+}
+
+template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
+template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
+
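The parallel reduce above merges per-simdgroup partial results pairwise. A small, self-contained C sketch of the same merge rule (the names S, M, O follow the kernel; this is an illustration, not the ggml implementation):

    #include <math.h>

    // one partial attention state: running max M, running exp-sum S,
    // and the unnormalized output accumulator O of length d
    typedef struct { float M, S; float *O; } attn_partial;

    // merge b into a, exactly like the pairwise reduce in the kernel:
    // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1, S = S0*ms0 + S1*ms1, M = max(M0, M1)
    static void attn_partial_merge(attn_partial *a, const attn_partial *b, int d) {
        const float M   = fmaxf(a->M, b->M);
        const float ms0 = expf(a->M - M);
        const float ms1 = expf(b->M - M);

        for (int i = 0; i < d; ++i) {
            a->O[i] = a->O[i]*ms0 + b->O[i]*ms1;
        }

        a->S = a->S*ms0 + b->S*ms1;
        a->M = M;
    }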
 kernel void kernel_cpy_f16_f16(
         device const half * src0,
         device       half * dst,
@@ -14744,7 +14744,12 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0,
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);

     GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional

     const int64_t ne00    = src0->ne[0];
     const int64_t nrows_x = ggml_nrows(src0);
@@ -14760,7 +14765,6 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0,
     float * src2_dd = nullptr;
     sycl_pool_alloc<float> src2_f;

-    ggml_tensor * src2 = dst->src[2];
     const bool use_src2 = src2 != nullptr;

     if (use_src2) {
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);

+    const ggml_tensor * src2 = dst->src[2];
+
+#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 and src2 support")
+#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
     GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+    GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional

     const int64_t ne00    = src0->ne[0];
     const int64_t nrows_x = ggml_nrows(src0);

     float * src2_dd = nullptr;
     sycl_pool_alloc<float> src2_f;

     const bool use_src2 = src2 != nullptr;

     if (use_src2) {
@@ -3178,6 +3178,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
     }
     return nullptr;
 case GGML_OP_SOFT_MAX:

     if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && (src2 == nullptr || src2->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
         return ctx->device->pipeline_soft_max_f32;
     }
     }
     return nullptr;
 case GGML_OP_SOFT_MAX:
+#pragma message("TODO: add ggml_vk_soft_max() F16 src1 and src2 support")
+#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
+    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32);
+
     if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && (src2 == nullptr || src2->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
         return ctx->device->pipeline_soft_max_f32;
     }
@@ -951,7 +951,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
 #define GGML_F16_VEC_ZERO           GGML_F16x8_ZERO
 #define GGML_F16_VEC_SET1           GGML_F16x8_SET1
 #define GGML_F16_VEC_LOAD(p, i)     GGML_F16x8_LOAD(p)
-#define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE(p, r[i])
 #define GGML_F16_VEC_FMA            GGML_F16x8_FMA
 #define GGML_F16_VEC_ADD            GGML_F16x8_ADD
 #define GGML_F16_VEC_MUL            GGML_F16x8_MUL
@@ -977,7 +977,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
 #define GGML_F16_VEC_ZERO           GGML_F32Cx4_ZERO
 #define GGML_F16_VEC_SET1           GGML_F32Cx4_SET1
 #define GGML_F16_VEC_LOAD(p, i)     GGML_F32Cx4_LOAD(p)
-#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i])
 #define GGML_F16_VEC_FMA            GGML_F32Cx4_FMA
 #define GGML_F16_VEC_ADD            GGML_F32Cx4_ADD
 #define GGML_F16_VEC_MUL            GGML_F32Cx4_MUL
@@ -1046,7 +1046,7 @@ do { \
 // unlike _mm256_cvt intrinsics that require F16C, _mm512_cvt is defined in AVX512F
 // so F16C guard isn't required
-#define GGML_F32Cx16_LOAD(x)     _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(x)))
 #define GGML_F32Cx16_STORE(x, y) _mm256_storeu_si256((__m256i *)(x), _mm512_cvtps_ph(y, 0))

 #define GGML_F32Cx16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a)
@@ -1144,7 +1144,7 @@ do { \
 #if defined(__F16C__)
 // the _mm256_cvt intrinsics require F16C
-#define GGML_F32Cx8_LOAD(x)     _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x)))
 #define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0))
 #else
 static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) {
@@ -1662,6 +1662,37 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float
 #endif
 }

 // xs and vs are byte strides of x and v
 inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) {

@@ -1746,6 +1777,35 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
 #endif
 }

 inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s); }
 inline static void ggml_vec_sqr_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; }
 inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
@@ -2001,6 +2061,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "LEAKY_RELU",

     "FLASH_ATTN",
     "FLASH_FF",
     "FLASH_ATTN_BACK",
     "SSM_CONV",
@@ -2027,7 +2088,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };

-static_assert(GGML_OP_COUNT ==

 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -2091,6 +2152,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "leaky_relu(x)",

     "flash_attn(x)",
     "flash_ff(x)",
     "flash_attn_back(x)",
     "ssm_conv(x)",
@@ -2117,7 +2179,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };

-static_assert(GGML_OP_COUNT ==

 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -4575,6 +4637,8 @@ struct ggml_tensor * ggml_mul_mat(
 void ggml_mul_mat_set_prec(
         struct ggml_tensor * a,
         enum   ggml_prec     prec) {
     const int32_t prec_i32 = (int32_t) prec;

     ggml_set_op_params_i32(a, 0, prec_i32);
@@ -5413,17 +5477,23 @@ static struct ggml_tensor * ggml_soft_max_impl(
     GGML_ASSERT(ggml_is_contiguous(a));

     if (mask) {
         GGML_ASSERT(ggml_is_contiguous(mask));
         GGML_ASSERT(ggml_is_matrix(mask));
-        GGML_ASSERT(
     }

     if (pos) {
         GGML_ASSERT(ggml_is_vector(pos));
-        GGML_ASSERT(pos->type == GGML_TYPE_F32);
         GGML_ASSERT(pos->ne[0] == a->ne[0]);
     }

     if (max_bias > 0.0f) {
         GGML_ASSERT(pos);
     }
@@ -6232,6 +6302,59 @@ struct ggml_tensor * ggml_flash_attn(
     return result;
 }

 // ggml_flash_ff

 struct ggml_tensor * ggml_flash_ff(
@@ -12317,7 +12440,7 @@ static void ggml_compute_forward_soft_max_f32(
     GGML_TENSOR_UNARY_OP_LOCALS

-    const int64_t ne11 = src1 ? src1->ne[1] : 1;

     // TODO: is this supposed to be ceil instead of floor?
     //       https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
@@ -12340,19 +12463,31 @@ static void ggml_compute_forward_soft_max_f32(
     float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;

     // when max_bias <= 0.0f, src2 is not used and we default it to src0 to avoid branching
-
     for (int i1 = ir0; i1 < ir1; i1++) {
         float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
         float * dp = (float *)((char *)  dst->data + i1*dst->nb[1]);

         // broadcast the mask across rows
-
         ggml_vec_cpy_f32  (nc, wp, sp);
         ggml_vec_scale_f32(nc, wp, scale);
-        if (
-
         }

         // ALiBi bias
@@ -12360,8 +12495,14 @@ static void ggml_compute_forward_soft_max_f32(
             const uint32_t h = (i1/ne01)%ne02; // head
             const float slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);

-
-
         }
     }
@@ -14631,6 +14772,198 @@ static void ggml_compute_forward_flash_attn(
     }
 }

 // ggml_compute_forward_flash_ff

 static void ggml_compute_forward_flash_ff_f16(
@@ -16442,6 +16775,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             const bool masked = t != 0;
             ggml_compute_forward_flash_attn(params, masked, tensor);
         } break;
     case GGML_OP_FLASH_FF:
         {
             ggml_compute_forward_flash_ff(params, tensor);
@@ -17454,6 +17791,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             GGML_ASSERT(false); // TODO: not implemented
         } break;
     case GGML_OP_FLASH_ATTN:
         {
             struct ggml_tensor * flash_grad = NULL;
             if (src0->grad || src1->grad || tensor->src[2]->grad) {
@@ -18231,6 +18569,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
             n_tasks = n_threads;
         } break;
     case GGML_OP_FLASH_ATTN:
         {
             n_tasks = n_threads;
         } break;
@@ -18634,6 +18973,12 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
                     cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
                 }
             } break;
         case GGML_OP_FLASH_FF:
             {
                 if (node->src[1]->type == GGML_TYPE_F32) {
 #define GGML_F16_VEC_ZERO           GGML_F16x8_ZERO
 #define GGML_F16_VEC_SET1           GGML_F16x8_SET1
 #define GGML_F16_VEC_LOAD(p, i)     GGML_F16x8_LOAD(p)
+#define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE((ggml_fp16_internal_t *)(p), r[i])
 #define GGML_F16_VEC_FMA            GGML_F16x8_FMA
 #define GGML_F16_VEC_ADD            GGML_F16x8_ADD
 #define GGML_F16_VEC_MUL            GGML_F16x8_MUL

 #define GGML_F16_VEC_ZERO           GGML_F32Cx4_ZERO
 #define GGML_F16_VEC_SET1           GGML_F32Cx4_SET1
 #define GGML_F16_VEC_LOAD(p, i)     GGML_F32Cx4_LOAD(p)
+#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE((ggml_fp16_internal_t *)(p), r[i])
 #define GGML_F16_VEC_FMA            GGML_F32Cx4_FMA
 #define GGML_F16_VEC_ADD            GGML_F32Cx4_ADD
 #define GGML_F16_VEC_MUL            GGML_F32Cx4_MUL

 // unlike _mm256_cvt intrinsics that require F16C, _mm512_cvt is defined in AVX512F
 // so F16C guard isn't required
+#define GGML_F32Cx16_LOAD(x)     _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x)))
 #define GGML_F32Cx16_STORE(x, y) _mm256_storeu_si256((__m256i *)(x), _mm512_cvtps_ph(y, 0))

 #define GGML_F32Cx16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a)

 #if defined(__F16C__)
 // the _mm256_cvt intrinsics require F16C
+#define GGML_F32Cx8_LOAD(x)     _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
 #define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0))
 #else
 static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) {
 #endif
 }

+inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, const ggml_fp16_t * restrict x, const float v) {
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F16_STEP - 1));
+
+    GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
+
+    GGML_F16_VEC ax[GGML_F16_ARR];
+    GGML_F16_VEC ay[GGML_F16_ARR];
+
+    for (int i = 0; i < np; i += GGML_F16_STEP) {
+        for (int j = 0; j < GGML_F16_ARR; j++) {
+            ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
+            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
+            ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
+
+            GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
+        }
+    }
+
+    // leftovers
+    for (int i = np; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
+    }
+#else
+    // scalar
+    for (int i = 0; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
+    }
+#endif
+}
+
 // xs and vs are byte strides of x and v
 inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) {

 #endif
 }

+inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) {
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F16_STEP - 1));
+
+    GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
+
+    GGML_F16_VEC ay[GGML_F16_ARR];
+
+    for (int i = 0; i < np; i += GGML_F16_STEP) {
+        for (int j = 0; j < GGML_F16_ARR; j++) {
+            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
+            ay[j] = GGML_F16_VEC_MUL(ay[j], vx);
+
+            GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
+        }
+    }
+
+    // leftovers
+    for (int i = np; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v);
+    }
+#else
+    // scalar
+    for (int i = 0; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v);
+    }
+#endif
+}
+
 inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s); }
 inline static void ggml_vec_sqr_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; }
 inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
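The two F16 helpers added above mirror their F32 counterparts: ggml_vec_mad_f16 accumulates x*v into y and ggml_vec_scale_f16 scales y by v, both on ggml_fp16_t data. A minimal usage sketch (illustrative only, with a hypothetical caller; the helpers themselves are file-local to ggml.c):

    // hypothetical caller: accumulate one weighted V row into an F16 accumulator,
    // then apply a final 1/S rescale, using the helpers introduced above
    static void accumulate_and_rescale_sketch(int D, ggml_fp16_t * acc,
            const ggml_fp16_t * v_row, float p, float inv_S) {
        ggml_vec_mad_f16  (D, acc, v_row, p);   // acc[i] += v_row[i]*p
        ggml_vec_scale_f16(D, acc, inv_S);      // acc[i] *= inv_S
    }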
     "LEAKY_RELU",

     "FLASH_ATTN",
+    "FLASH_ATTN_EXT",
     "FLASH_FF",
     "FLASH_ATTN_BACK",
     "SSM_CONV",

     "CROSS_ENTROPY_LOSS_BACK",
 };

+static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");

 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",

     "leaky_relu(x)",

     "flash_attn(x)",
+    "flash_attn_ext(x)",
     "flash_ff(x)",
     "flash_attn_back(x)",
     "ssm_conv(x)",

     "cross_entropy_loss_back(x,y)",
 };

+static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");

 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");

 void ggml_mul_mat_set_prec(
         struct ggml_tensor * a,
         enum   ggml_prec     prec) {
+    GGML_ASSERT(a->op == GGML_OP_MUL_MAT);
+
     const int32_t prec_i32 = (int32_t) prec;

     ggml_set_op_params_i32(a, 0, prec_i32);

     GGML_ASSERT(ggml_is_contiguous(a));

     if (mask) {
+        GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);
         GGML_ASSERT(ggml_is_contiguous(mask));
         GGML_ASSERT(ggml_is_matrix(mask));
+        GGML_ASSERT(mask->ne[0] == a->ne[0]);
+        GGML_ASSERT(mask->ne[1] >= a->ne[1]);
     }

     if (pos) {
         GGML_ASSERT(ggml_is_vector(pos));
+        GGML_ASSERT(pos->type == GGML_TYPE_F16 || pos->type == GGML_TYPE_F32);
         GGML_ASSERT(pos->ne[0] == a->ne[0]);
     }

+    if (pos && mask) {
+        GGML_ASSERT(pos->type == mask->type);
+    }
+
     if (max_bias > 0.0f) {
         GGML_ASSERT(pos);
     }

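To illustrate the new constraints, here is a minimal host-side sketch that builds an F16 mask and F16 positions satisfying the asserts above; it assumes the ggml_soft_max_ext(ctx, a, mask, pos, scale, max_bias) signature used at this point in the tree and is not part of the patch:

    #include "ggml.h"

    // hypothetical sizes: n_kv KV columns, n_tokens query rows; kq is [n_kv, n_tokens, ...]
    struct ggml_tensor * soft_max_with_f16_mask_sketch(struct ggml_context * ctx,
            struct ggml_tensor * kq, int64_t n_kv, int64_t n_tokens) {
        // mask: ne[0] must match kq->ne[0], ne[1] must be >= kq->ne[1]
        struct ggml_tensor * mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_kv, n_tokens);
        // pos: one entry per KV column, same type as the mask
        struct ggml_tensor * pos  = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_kv);

        return ggml_soft_max_ext(ctx, kq, mask, pos, 1.0f, /*max_bias =*/ 8.0f);
    }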
     return result;
 }

+// ggml_flash_attn_ext
+
+struct ggml_tensor * ggml_flash_attn_ext(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * q,
+        struct ggml_tensor  * k,
+        struct ggml_tensor  * v,
+        struct ggml_tensor  * mask,
+        float                 scale) {
+    GGML_ASSERT(ggml_can_mul_mat(k, q));
+    // TODO: check if vT can be multiplied by (k*qT)
+    if (mask) {
+        GGML_ASSERT(ggml_is_contiguous(mask));
+        GGML_ASSERT(mask->ne[2] == 1);
+        GGML_ASSERT(mask->ne[3] == 1);
+        GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) &&
+                "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big");
+        //GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
+    }
+
+    bool is_node = false;
+
+    if (q->grad || k->grad || v->grad) {
+        is_node = true;
+    }
+
+    // permute(0, 2, 1, 3)
+    int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    float params[] = { scale };
+    ggml_set_op_params(result, params, sizeof(params));
+
+    result->op   = GGML_OP_FLASH_ATTN_EXT;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = q;
+    result->src[1] = k;
+    result->src[2] = v;
+    result->src[3] = mask;
+
+    return result;
+}
+
+void ggml_flash_attn_ext_set_prec(
+        struct ggml_tensor * a,
+        enum   ggml_prec     prec) {
+    GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
+
+    const int32_t prec_i32 = (int32_t) prec;
+
+    ggml_set_op_params_i32(a, 1, prec_i32); // scale is on first pos
+}
+
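A minimal host-side usage sketch for the new API (illustrative, not from the patch; tensor sizes are hypothetical and the K/V cache is assumed to be stored as F16, matching the asserts in the CPU and Metal kernels):

    #include <math.h>
    #include "ggml.h"

    // build a flash-attention node for one graph:
    // q: [head_dim, n_tokens, n_head, 1] (F32), k/v: [head_dim, n_kv, n_head_kv, 1] (F16)
    struct ggml_tensor * build_attn_sketch(struct ggml_context * ctx,
            struct ggml_tensor * q, struct ggml_tensor * k, struct ggml_tensor * v,
            int64_t n_kv, int64_t n_tokens) {
        // the mask must be padded to GGML_KQ_MASK_PAD rows and be at least n_tokens tall
        struct ggml_tensor * mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F16,
                n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));

        const float scale = 1.0f/sqrtf((float) q->ne[0]); // 1/sqrt(head_dim)

        struct ggml_tensor * cur = ggml_flash_attn_ext(ctx, q, k, v, mask, scale);
        ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); // optional: F32 accumulation

        return cur; // shape: [head_dim, n_head, n_tokens, 1], permuted as in the op above
    }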
 // ggml_flash_ff

 struct ggml_tensor * ggml_flash_ff(
     GGML_TENSOR_UNARY_OP_LOCALS

+    //const int64_t ne11 = src1 ? src1->ne[1] : 1;

     // TODO: is this supposed to be ceil instead of floor?
     //       https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370

     float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;

     // when max_bias <= 0.0f, src2 is not used and we default it to src0 to avoid branching
+    ggml_fp16_t * pos_f16 = src2 ? (ggml_fp16_t *) src2->data : src0->data;
+    float       * pos_f32 = src2 ? (float       *) src2->data : src0->data;
+
+    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);

     for (int i1 = ir0; i1 < ir1; i1++) {
         float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
         float * dp = (float *)((char *)  dst->data + i1*dst->nb[1]);

         // broadcast the mask across rows
+        ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
+        float       * mp_f32 = src1 ? (float       *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;

         ggml_vec_cpy_f32  (nc, wp, sp);
         ggml_vec_scale_f32(nc, wp, scale);
+        if (mp_f32) {
+            if (use_f16) {
+                for (int i = 0; i < nc; ++i) {
+                    wp[i] += GGML_FP16_TO_FP32(mp_f16[i]);
+                }
+            } else {
+                for (int i = 0; i < nc; ++i) {
+                    wp[i] += mp_f32[i];
+                }
+            }
         }

         // ALiBi bias

             const uint32_t h = (i1/ne01)%ne02; // head
             const float slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);

+            if (use_f16) {
+                for (int i = 0; i < nc; ++i) {
+                    wp[i] += slope*GGML_FP16_TO_FP32(pos_f16[i]);
+                }
+            } else {
+                for (int i = 0; i < nc; ++i) {
+                    wp[i] += slope*pos_f32[i];
+                }
             }
         }


    }
}

+// ggml_compute_forward_flash_attn_ext
+
+static void ggml_compute_forward_flash_attn_ext_f16(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * q,
+        const struct ggml_tensor * k,
+        const struct ggml_tensor * v,
+        const struct ggml_tensor * mask,
+        struct ggml_tensor * dst) {
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
+
+    GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
+    GGML_TENSOR_LOCALS(size_t,  nbq, q, nb)
+    GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
+    GGML_TENSOR_LOCALS(size_t,  nbk, k, nb)
+    GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
+    GGML_TENSOR_LOCALS(size_t,  nbv, v, nb)
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t D = neq0;
+    const int64_t N = neq1;
+
+    GGML_ASSERT(ne0 == D);
+    GGML_ASSERT(ne2 == N);
+
+    GGML_ASSERT(nbq0 == sizeof(float));
+    GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));
+
+    GGML_ASSERT(neq0 == D);
+    GGML_ASSERT(nek0 == D);
+    GGML_ASSERT(nev0 == D);
+
+    GGML_ASSERT(neq1 == N);
+    GGML_ASSERT(nev0 == D);
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    // broadcast factors
+    const int64_t rk2 = neq2/nek2;
+    const int64_t rk3 = neq3/nek3;
+
+    const int64_t rv2 = neq2/nev2;
+    const int64_t rv3 = neq3/nev3;
+
+    if (params->type == GGML_TASK_TYPE_INIT) {
+        return;
+    }
+
+    if (params->type == GGML_TASK_TYPE_FINALIZE) {
+        return;
+    }
+
+    // parallelize by q rows using ggml_vec_dot_f32
+
+    // total rows in q
+    const int nr = neq1*neq2*neq3;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    float scale = 1.0f;
+    memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
+
+    // loop over n_batch and n_head
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // q indices
+        const int iq3 = ir/(neq2*neq1);
+        const int iq2 = (ir - iq3*neq2*neq1)/neq1;
+        const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
+
+        float S = 0.0f;
+        float M = -INFINITY;
+
+        float * V32 = (float *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32);
+        ggml_fp16_t * Q16 = (ggml_fp16_t *) (V32); // reuse memory
+        ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D);
+
+        memset(V16, 0, D*sizeof(ggml_fp16_t));
+
+        const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
+
+        // k indices
+        const int ik3 = iq3 / rk3;
+        const int ik2 = iq2 / rk2;
+
+        // v indices
+        const int iv3 = iq3 / rv3;
+        const int iv2 = iq2 / rv2;
+
+        // online softmax / attention
+        // loop over n_kv and n_head_kv
+        // ref: https://arxiv.org/pdf/2112.05682.pdf
+        for (int64_t ic = 0; ic < nek1; ++ic) {
+            const float mv = mp ? GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
+            if (mv == -INFINITY) {
+                continue;
+            }
+
+            float s;
+
+            // convert Q to F16 in V32
+            {
+                const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
+
+                for (int64_t d = 0; d < D; ++d) {
+                    Q16[d] = GGML_FP32_TO_FP16(pq[d]);
+                }
+            }
+
+            ggml_vec_dot_f16(D,
+                    &s, 0,
+                    (ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
+                    Q16, 0, 1);
+
+            s = s*scale + mv;
+
+            const float Mold = M;
+
+            float ms = 1.0f;
+            float vs = 1.0f;
+
+            if (s > M) {
+                M = s;
+                ms = expf(Mold - M);
+
+                // V = V*expf(Mold - M)
+                ggml_vec_scale_f16(D, V16, ms);
+            } else {
+                vs = expf(s - M);
+            }
+
+            const ggml_fp16_t * v16 = (const ggml_fp16_t *) ((char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
+
+            // V += v*expf(s - M)
+            ggml_vec_mad_f16(D, V16, v16, vs);
+
+            S = S*ms + vs;
+        }
+
+        // V /= S
+        for (int64_t d = 0; d < D; ++d) {
+            V32[d] = GGML_FP16_TO_FP32(V16[d])/S;
+        }
+
+        // dst indices
+        const int i1 = iq1;
+        const int i2 = iq2;
+        const int i3 = iq3;
+
+        // original
+        //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
+
+        // permute(0, 2, 1, 3)
+        memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, V32, nb1);
+    }
+}
+
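The kernel above never materializes the full attention matrix: for each query row it keeps only a running maximum M, a running normalizer S, and a rescaled accumulator V, following the one-pass ("online") softmax from the referenced paper. A standalone sketch of the same recurrence on plain float arrays (illustrative only, not part of the patch; the function name and sizes are made up):

#include <math.h>
#include <stdio.h>

// out[j] = sum_i softmax(s)_i * v[i*d + j], computed in a single pass over i.
static void online_softmax_accum(const float * s, const float * v, int n, int d, float * out) {
    float M = -INFINITY; // running maximum of the scores seen so far
    float S = 0.0f;      // running sum of exp(s_i - M)

    for (int j = 0; j < d; ++j) {
        out[j] = 0.0f;
    }

    for (int i = 0; i < n; ++i) {
        const float Mold = M;
        M = fmaxf(M, s[i]);

        const float ms = expf(Mold - M); // rescales the old accumulator when M grows
        const float vs = expf(s[i] - M); // weight of the newly seen value row

        for (int j = 0; j < d; ++j) {
            out[j] = out[j]*ms + v[i*d + j]*vs;
        }
        S = S*ms + vs;
    }

    for (int j = 0; j < d; ++j) {
        out[j] /= S; // final division by the softmax normalizer
    }
}

int main(void) {
    const float s[3]   = { 0.5f, 2.0f, -1.0f };    // attention scores for one query
    const float v[3*2] = { 1, 0,   0, 1,   1, 1 }; // three value rows of size 2
    float out[2];

    online_softmax_accum(s, v, 3, 2, out);
    printf("%f %f\n", out[0], out[1]);
    return 0;
}

The CPU kernel applies the same recurrence with F16 vectors (ggml_vec_scale_f16 / ggml_vec_mad_f16), skips fully masked (-INF) positions, and divides by S only once at the end.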

+static void ggml_compute_forward_flash_attn_ext(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * q,
+        const struct ggml_tensor * k,
+        const struct ggml_tensor * v,
+        const struct ggml_tensor * mask,
+        struct ggml_tensor * dst) {
+    switch (dst->op_params[1]) {
+        case GGML_PREC_DEFAULT:
+        case GGML_PREC_F32:
+            {
+                // uses F32 accumulators
+                ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
// ggml_compute_forward_flash_ff

static void ggml_compute_forward_flash_ff_f16(

                const bool masked = t != 0;
                ggml_compute_forward_flash_attn(params, masked, tensor);
            } break;
+        case GGML_OP_FLASH_ATTN_EXT:
+            {
+                ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
+            } break;
        case GGML_OP_FLASH_FF:
            {
                ggml_compute_forward_flash_ff(params, tensor);

                GGML_ASSERT(false); // TODO: not implemented
            } break;
        case GGML_OP_FLASH_ATTN:
+        case GGML_OP_FLASH_ATTN_EXT:
            {
                struct ggml_tensor * flash_grad = NULL;
                if (src0->grad || src1->grad || tensor->src[2]->grad) {

                n_tasks = n_threads;
            } break;
        case GGML_OP_FLASH_ATTN:
+        case GGML_OP_FLASH_ATTN_EXT:
            {
                n_tasks = n_threads;
            } break;

                    cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
                }
            } break;
+        case GGML_OP_FLASH_ATTN_EXT:
+            {
+                const int64_t ne00 = node->src[0]->ne[0]; // D
+
+                cur = 2*sizeof(float)*ne00*n_tasks; // 2x head size
+            } break;
        case GGML_OP_FLASH_FF:
            {
                if (node->src[1]->type == GGML_TYPE_F32) {
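This work-size estimate mirrors the scratch layout used by the CPU kernel above: each task needs 2*D floats (the V32 region that is reused for Q16 and V16), so with a head size of D = 128 and 8 tasks the reservation works out to 2*4*128*8 = 8192 bytes. The extra CACHE_LINE_SIZE_F32 stride used in the kernel's per-thread indexing is assumed here to be covered by the plan's general per-thread padding.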

@@ -475,6 +475,7 @@ extern "C" {
        GGML_OP_LEAKY_RELU,

        GGML_OP_FLASH_ATTN,
+        GGML_OP_FLASH_ATTN_EXT,
        GGML_OP_FLASH_FF,
        GGML_OP_FLASH_ATTN_BACK,
        GGML_OP_SSM_CONV,

@@ -1731,6 +1732,25 @@ extern "C" {
            struct ggml_tensor * v,
            bool masked);

+#define GGML_KQ_MASK_PAD 32
+
+    // q:    [n_embd, n_batch,     n_head,    1]
+    // k:    [n_embd, n_kv,        n_head_kv, 1]
+    // v:    [n_embd, n_kv,        n_head_kv, 1] !! not transposed !!
+    // mask: [n_kv,   n_batch_pad, 1,         1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
+    // res:  [n_embd, n_head,      n_batch,   1] !! permuted !!
+    GGML_API struct ggml_tensor * ggml_flash_attn_ext(
+            struct ggml_context * ctx,
+            struct ggml_tensor * q,
+            struct ggml_tensor * k,
+            struct ggml_tensor * v,
+            struct ggml_tensor * mask,
+            float scale);
+
+    GGML_API void ggml_flash_attn_ext_set_prec(
+            struct ggml_tensor * a,
+            enum ggml_prec prec);
+
    GGML_API struct ggml_tensor * ggml_flash_attn_back(
            struct ggml_context * ctx,
            struct ggml_tensor * q,
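Finally, a hedged usage sketch of the new public API (not part of the patch): it assumes the usual ggml helpers (ggml_init, ggml_new_tensor_4d/2d, ggml_new_graph, ggml_build_forward_expand) and made-up sizes, and only exercises the declarations and the shape/padding comments above.

#include <math.h>
#include "ggml.h"

int main(void) {
    const int64_t n_embd_head = 128, n_head = 32, n_head_kv = 8;
    const int64_t n_batch = 4, n_kv = 1024;

    struct ggml_init_params ip = { /*.mem_size   =*/ 256*1024*1024,
                                   /*.mem_buffer =*/ NULL,
                                   /*.no_alloc   =*/ false };
    struct ggml_context * ctx = ggml_init(ip);

    // shapes follow the header comments: q is F32, k/v are F16 and not transposed
    struct ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_embd_head, n_batch, n_head,    1);
    struct ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, n_embd_head, n_kv,    n_head_kv, 1);
    struct ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, n_embd_head, n_kv,    n_head_kv, 1);

    // the mask rows must be padded to GGML_KQ_MASK_PAD
    struct ggml_tensor * mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_kv, GGML_PAD(n_batch, GGML_KQ_MASK_PAD));

    struct ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf((float) n_embd_head));
    ggml_flash_attn_ext_set_prec(out, GGML_PREC_F32); // request F32 accumulation

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out); // result: [n_embd_head, n_head, n_batch, 1], permuted

    ggml_free(ctx);
    return 0;
}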