JohannesGaessler committed
Commit 97cb7ce · 1 parent: 8d3e707

CUDA: enable Gemma FA for HIP/Pascal (llama/9581)

ggml/src/ggml-cuda.cu CHANGED
@@ -2976,19 +2976,19 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_LEAKY_RELU:
         case GGML_OP_RWKV_WKV:
             return true;
-        case GGML_OP_FLASH_ATTN_EXT:
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-            return (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) || op->src[0]->ne[0] == 128;
-#else
+        case GGML_OP_FLASH_ATTN_EXT: {
+            if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
+                return true;
+            }
             if (op->src[0]->ne[0] == 128) {
                 return true;
             }
-            if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
+            if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
                 return true;
             }
-            return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA &&
-                op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+            const int cc = ggml_cuda_info().devices[cuda_ctx->device].cc;
+            return cc >= CC_VOLTA && cc < CC_OFFSET_AMD && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
+        }
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         case GGML_OP_OPT_STEP_ADAMW:
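
For orientation, the new support rule for GGML_OP_FLASH_ATTN_EXT can be paraphrased as a standalone predicate. This is only a sketch: the helper name fa_ext_supported and its simplified parameters (head size and K/V types instead of the full op tensors) are illustrative, and the constant values below are assumptions; the real CC_VOLTA and CC_OFFSET_AMD definitions live in ggml/src/ggml-cuda/common.cuh.

    #include <stdbool.h>
    #include <stdint.h>

    // Assumed values for illustration; the actual constants are defined in ggml-cuda/common.cuh.
    #define CC_VOLTA      700
    #define CC_OFFSET_AMD 1000000  // AMD devices report cc offset by this amount, so cc < CC_OFFSET_AMD means NVIDIA

    static bool fa_ext_supported(int cc, int64_t head_size, bool k_is_f16, bool v_is_f16) {
        if (head_size == 64 && k_is_f16) {
            return true;  // covered by the small-head vector kernels
        }
        if (head_size == 128) {
            return true;  // supported on every device
        }
        if (head_size == 256 && k_is_f16 && v_is_f16) {
            return true;  // newly enabled: Gemma-style head size 256
        }
        // everything else needs an NVIDIA GPU with tensor cores (Volta or newer) and F16 K/V
        return cc >= CC_VOLTA && cc < CC_OFFSET_AMD && k_is_f16 && v_is_f16;
    }

Note that the GGML_USE_HIPBLAS #ifdef is gone: HIP builds now run the same predicate, with the AMD case handled at runtime via the cc < CC_OFFSET_AMD check.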
ggml/src/ggml-cuda/fattn.cu CHANGED
@@ -314,7 +314,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     }
 
     if (!fast_fp16_available(cc)) {
-        if (Q->ne[1] <= 8) {
+        if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
             ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
         } else {
             ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
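
This hunk adjusts the FP32 fallback used when fast FP16 is unavailable (e.g. Pascal): head size 256 now also takes the vector-kernel path, not just small batches. The chooser below is a hypothetical paraphrase of that branch; the enum and function name are illustrative and not part of ggml.

    #include <stdint.h>

    // Hypothetical paraphrase of the FP32 fallback selection in ggml_cuda_flash_attn_ext.
    enum fattn_f32_kernel { FATTN_VEC_F32, FATTN_TILE_F32 };

    static enum fattn_f32_kernel choose_f32_kernel(int64_t batch_size /* Q->ne[1] */, int64_t head_size /* Q->ne[0] */) {
        // small batches and Gemma-style head size 256 go to the vector kernel ...
        if (batch_size <= 8 || head_size == 256) {
            return FATTN_VEC_F32;
        }
        // ... everything else to the tile kernel
        return FATTN_TILE_F32;
    }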