Spaces:
Running
Running
metal : disable FA kernel for HS=256 (llama/7556)
Browse files- ggml-metal.m +9 -6
- ggml-metal.metal +2 -2
ggml-metal.m
CHANGED
|
@@ -184,9 +184,9 @@ enum ggml_metal_kernel_type {
|
|
| 184 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
|
| 185 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
|
| 186 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
|
| 187 |
-
|
| 188 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
|
| 189 |
-
|
| 190 |
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
|
| 191 |
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
|
| 192 |
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
|
|
@@ -634,9 +634,9 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
| 634 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, ctx->support_simdgroup_mm);
|
| 635 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, ctx->support_simdgroup_mm);
|
| 636 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, ctx->support_simdgroup_mm);
|
| 637 |
-
|
| 638 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction);
|
| 639 |
-
|
| 640 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
|
| 641 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
| 642 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
|
|
@@ -770,6 +770,9 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|
| 770 |
case GGML_OP_LEAKY_RELU:
|
| 771 |
return true;
|
| 772 |
case GGML_OP_FLASH_ATTN_EXT:
|
|
|
|
|
|
|
|
|
|
| 773 |
return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
|
| 774 |
case GGML_OP_MUL_MAT:
|
| 775 |
case GGML_OP_MUL_MAT_ID:
|
|
@@ -2573,7 +2576,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
| 2573 |
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
|
| 2574 |
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
|
| 2575 |
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
|
| 2576 |
-
|
| 2577 |
default:
|
| 2578 |
{
|
| 2579 |
GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
|
|
@@ -2586,7 +2589,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
| 2586 |
|
| 2587 |
switch (ne00) {
|
| 2588 |
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
|
| 2589 |
-
|
| 2590 |
default:
|
| 2591 |
{
|
| 2592 |
GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
|
|
|
|
| 184 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
|
| 185 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
|
| 186 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
|
| 187 |
+
//GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
|
| 188 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
|
| 189 |
+
//GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
|
| 190 |
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
|
| 191 |
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
|
| 192 |
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
|
|
|
|
| 634 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, ctx->support_simdgroup_mm);
|
| 635 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, ctx->support_simdgroup_mm);
|
| 636 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, ctx->support_simdgroup_mm);
|
| 637 |
+
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
|
| 638 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction);
|
| 639 |
+
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
|
| 640 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
|
| 641 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
| 642 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
|
|
|
|
| 770 |
case GGML_OP_LEAKY_RELU:
|
| 771 |
return true;
|
| 772 |
case GGML_OP_FLASH_ATTN_EXT:
|
| 773 |
+
if (op->src[0]->ne[0] == 256) {
|
| 774 |
+
return false;
|
| 775 |
+
}
|
| 776 |
return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
|
| 777 |
case GGML_OP_MUL_MAT:
|
| 778 |
case GGML_OP_MUL_MAT_ID:
|
|
|
|
| 2576 |
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
|
| 2577 |
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
|
| 2578 |
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
|
| 2579 |
+
//case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
|
| 2580 |
default:
|
| 2581 |
{
|
| 2582 |
GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
|
|
|
|
| 2589 |
|
| 2590 |
switch (ne00) {
|
| 2591 |
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
|
| 2592 |
+
//case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
|
| 2593 |
default:
|
| 2594 |
{
|
| 2595 |
GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
|
ggml-metal.metal
CHANGED
|
@@ -2418,7 +2418,7 @@ template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f
|
|
| 2418 |
template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>;
|
| 2419 |
template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>;
|
| 2420 |
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>;
|
| 2421 |
-
template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
|
| 2422 |
|
| 2423 |
template<int64_t D, int64_t Q = 1, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
|
| 2424 |
kernel void kernel_flash_attn_ext_vec_f16(
|
|
@@ -2696,7 +2696,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|
| 2696 |
}
|
| 2697 |
|
| 2698 |
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
|
| 2699 |
-
template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
|
| 2700 |
|
| 2701 |
kernel void kernel_cpy_f16_f16(
|
| 2702 |
device const half * src0,
|
|
|
|
| 2418 |
template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>;
|
| 2419 |
template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>;
|
| 2420 |
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>;
|
| 2421 |
+
//template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
|
| 2422 |
|
| 2423 |
template<int64_t D, int64_t Q = 1, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
|
| 2424 |
kernel void kernel_flash_attn_ext_vec_f16(
|
|
|
|
| 2696 |
}
|
| 2697 |
|
| 2698 |
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
|
| 2699 |
+
//template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
|
| 2700 |
|
| 2701 |
kernel void kernel_cpy_f16_f16(
|
| 2702 |
device const half * src0,
|