jeffbolznv committed on
Commit
97d9aa6
·
1 Parent(s): 36a3b4e

vulkan: use scalar FA rather than coopmat2 when N==1 (llama/13554)

Browse files
ggml/src/ggml-vulkan/ggml-vulkan.cpp CHANGED
@@ -5872,10 +5872,17 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
5872
  vk_pipeline *pipelines;
5873
  bool small_rows = N <= get_fa_num_small_rows(path);
5874
 
 
 
5875
  if (small_rows && path == FA_COOPMAT1) {
5876
  path = FA_SCALAR;
5877
  }
5878
 
 
 
 
 
 
5879
  bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
5880
 
5881
  switch (path) {
 
5872
  vk_pipeline *pipelines;
5873
  bool small_rows = N <= get_fa_num_small_rows(path);
5874
 
5875
+ // coopmat1 does not actually support "small rows" (it needs 16 rows).
5876
+ // So use scalar instead.
5877
  if (small_rows && path == FA_COOPMAT1) {
5878
  path = FA_SCALAR;
5879
  }
5880
 
5881
+ // scalar is faster than coopmat2 when N==1
5882
+ if (N == 1 && path == FA_COOPMAT2) {
5883
+ path = FA_SCALAR;
5884
+ }
5885
+
5886
  bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
5887
 
5888
  switch (path) {