Spaces:
Running
Running
Commit
·
97d9aa6
1
Parent(s):
36a3b4e
vulkan: use scalar FA rather than coopmat2 when N==1 (llama/13554)
Browse files
ggml/src/ggml-vulkan/ggml-vulkan.cpp
CHANGED
|
@@ -5872,10 +5872,17 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
| 5872 |
vk_pipeline *pipelines;
|
| 5873 |
bool small_rows = N <= get_fa_num_small_rows(path);
|
| 5874 |
|
|
|
|
|
|
|
| 5875 |
if (small_rows && path == FA_COOPMAT1) {
|
| 5876 |
path = FA_SCALAR;
|
| 5877 |
}
|
| 5878 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5879 |
bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
|
| 5880 |
|
| 5881 |
switch (path) {
|
|
|
|
| 5872 |
vk_pipeline *pipelines;
|
| 5873 |
bool small_rows = N <= get_fa_num_small_rows(path);
|
| 5874 |
|
| 5875 |
+
// coopmat1 does not actually support "small rows" (it needs 16 rows).
|
| 5876 |
+
// So use scalar instead.
|
| 5877 |
if (small_rows && path == FA_COOPMAT1) {
|
| 5878 |
path = FA_SCALAR;
|
| 5879 |
}
|
| 5880 |
|
| 5881 |
+
// scalar is faster than coopmat2 when N==1
|
| 5882 |
+
if (N == 1 && path == FA_COOPMAT2) {
|
| 5883 |
+
path = FA_SCALAR;
|
| 5884 |
+
}
|
| 5885 |
+
|
| 5886 |
bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
|
| 5887 |
|
| 5888 |
switch (path) {
|