ggerganov committed on
Commit
0c32e28
·
1 Parent(s): 0641dee

metal : disable FA kernel for HS=256 (llama/7556)

Browse files
Files changed (2) hide show
  1. ggml-metal.m +9 -6
  2. ggml-metal.metal +2 -2
ggml-metal.m CHANGED
@@ -184,9 +184,9 @@ enum ggml_metal_kernel_type {
184
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
185
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
186
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
187
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
188
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
189
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
190
  GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
191
  GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
192
  GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
@@ -634,9 +634,9 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
634
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, ctx->support_simdgroup_mm);
635
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, ctx->support_simdgroup_mm);
636
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, ctx->support_simdgroup_mm);
637
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
638
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction);
639
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
640
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
641
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
642
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
@@ -770,6 +770,9 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
770
  case GGML_OP_LEAKY_RELU:
771
  return true;
772
  case GGML_OP_FLASH_ATTN_EXT:
 
 
 
773
  return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
774
  case GGML_OP_MUL_MAT:
775
  case GGML_OP_MUL_MAT_ID:
@@ -2573,7 +2576,7 @@ static enum ggml_status ggml_metal_graph_compute(
2573
  case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
2574
  case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
2575
  case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
2576
- case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
2577
  default:
2578
  {
2579
  GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
@@ -2586,7 +2589,7 @@ static enum ggml_status ggml_metal_graph_compute(
2586
 
2587
  switch (ne00) {
2588
  case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
2589
- case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
2590
  default:
2591
  {
2592
  GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
 
184
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
185
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
186
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
187
+ //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
188
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
189
+ //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
190
  GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
191
  GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
192
  GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
 
634
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, ctx->support_simdgroup_mm);
635
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, ctx->support_simdgroup_mm);
636
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, ctx->support_simdgroup_mm);
637
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
638
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction);
639
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
640
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
641
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
642
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
 
770
  case GGML_OP_LEAKY_RELU:
771
  return true;
772
  case GGML_OP_FLASH_ATTN_EXT:
773
+ if (op->src[0]->ne[0] == 256) {
774
+ return false;
775
+ }
776
  return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
777
  case GGML_OP_MUL_MAT:
778
  case GGML_OP_MUL_MAT_ID:
 
2576
  case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
2577
  case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
2578
  case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
2579
+ //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
2580
  default:
2581
  {
2582
  GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
 
2589
 
2590
  switch (ne00) {
2591
  case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
2592
+ //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
2593
  default:
2594
  {
2595
  GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
ggml-metal.metal CHANGED
@@ -2418,7 +2418,7 @@ template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f
2418
  template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>;
2419
  template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>;
2420
  template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>;
2421
- template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
2422
 
2423
  template<int64_t D, int64_t Q = 1, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
2424
  kernel void kernel_flash_attn_ext_vec_f16(
@@ -2696,7 +2696,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
2696
  }
2697
 
2698
  template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
2699
- template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
2700
 
2701
  kernel void kernel_cpy_f16_f16(
2702
  device const half * src0,
 
2418
  template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>;
2419
  template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>;
2420
  template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>;
2421
+ //template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
2422
 
2423
  template<int64_t D, int64_t Q = 1, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
2424
  kernel void kernel_flash_attn_ext_vec_f16(
 
2696
  }
2697
 
2698
  template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
2699
+ //template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
2700
 
2701
  kernel void kernel_cpy_f16_f16(
2702
  device const half * src0,