Spaces:
Running
Running
Commit
·
97cb7ce
1
Parent(s):
8d3e707
CUDA: enable Gemma FA for HIP/Pascal (llama/9581)
Browse files- ggml/src/ggml-cuda.cu +8 -8
- ggml/src/ggml-cuda/fattn.cu +1 -1
ggml/src/ggml-cuda.cu
CHANGED
|
@@ -2976,19 +2976,19 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
|
| 2976 |
case GGML_OP_LEAKY_RELU:
|
| 2977 |
case GGML_OP_RWKV_WKV:
|
| 2978 |
return true;
|
| 2979 |
-
case GGML_OP_FLASH_ATTN_EXT:
|
| 2980 |
-
|
| 2981 |
-
|
| 2982 |
-
|
| 2983 |
if (op->src[0]->ne[0] == 128) {
|
| 2984 |
return true;
|
| 2985 |
}
|
| 2986 |
-
if (op->src[0]->ne[0] ==
|
| 2987 |
return true;
|
| 2988 |
}
|
| 2989 |
-
|
| 2990 |
-
|
| 2991 |
-
|
| 2992 |
case GGML_OP_CROSS_ENTROPY_LOSS:
|
| 2993 |
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
|
| 2994 |
case GGML_OP_OPT_STEP_ADAMW:
|
|
|
|
| 2976 |
case GGML_OP_LEAKY_RELU:
|
| 2977 |
case GGML_OP_RWKV_WKV:
|
| 2978 |
return true;
|
| 2979 |
+
case GGML_OP_FLASH_ATTN_EXT: {
|
| 2980 |
+
if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
|
| 2981 |
+
return true;
|
| 2982 |
+
}
|
| 2983 |
if (op->src[0]->ne[0] == 128) {
|
| 2984 |
return true;
|
| 2985 |
}
|
| 2986 |
+
if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
|
| 2987 |
return true;
|
| 2988 |
}
|
| 2989 |
+
const int cc = ggml_cuda_info().devices[cuda_ctx->device].cc;
|
| 2990 |
+
return cc >= CC_VOLTA && cc < CC_OFFSET_AMD && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
|
| 2991 |
+
}
|
| 2992 |
case GGML_OP_CROSS_ENTROPY_LOSS:
|
| 2993 |
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
|
| 2994 |
case GGML_OP_OPT_STEP_ADAMW:
|
ggml/src/ggml-cuda/fattn.cu
CHANGED
|
@@ -314,7 +314,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
|
| 314 |
}
|
| 315 |
|
| 316 |
if (!fast_fp16_available(cc)) {
|
| 317 |
-
if (Q->ne[1] <= 8) {
|
| 318 |
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
| 319 |
} else {
|
| 320 |
ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
|
|
|
|
| 314 |
}
|
| 315 |
|
| 316 |
if (!fast_fp16_available(cc)) {
|
| 317 |
+
if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
|
| 318 |
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
| 319 |
} else {
|
| 320 |
ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
|