woachk commited on
Commit
fa0872f
·
1 Parent(s): f22c7e4

kompute : implement op_getrows_f32 (llama/6403)

Browse files

op_getrows_f32 is required since https://github.com/ggerganov/llama.cpp/pull/6122
for the Vulkan w/ Kompute backend to be functional.

As such, implement this op to make this backend functional again.

Files changed (1) hide show
  1. ggml-kompute.cpp +13 -1
ggml-kompute.cpp CHANGED
@@ -22,6 +22,7 @@
22
  #include "shaderop_mul_mat_q4_1.h"
23
  #include "shaderop_mul_mat_q6_k.h"
24
  #include "shaderop_mul_mat_mat_f32.h"
 
25
  #include "shaderop_getrows_f16.h"
26
  #include "shaderop_getrows_q4_0.h"
27
  #include "shaderop_getrows_q4_1.h"
@@ -1146,6 +1147,14 @@ static void ggml_vk_get_rows(
1146
  seq.record<kp::OpAlgoDispatch>(s_algo);
1147
  }
1148
 
 
 
 
 
 
 
 
 
1149
  template <typename... Args>
1150
  static void ggml_vk_get_rows_f16(Args&&... args) {
1151
  const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f16_comp_spv,
@@ -1371,6 +1380,7 @@ static bool ggml_vk_supports_op(const struct ggml_tensor * op) {
1371
  return op->ne[3] == 1;
1372
  case GGML_OP_GET_ROWS:
1373
  switch (op->src[0]->type) {
 
1374
  case GGML_TYPE_F16:
1375
  case GGML_TYPE_Q4_0:
1376
  case GGML_TYPE_Q4_1:
@@ -1661,7 +1671,9 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
1661
  } break;
1662
  case GGML_OP_GET_ROWS:
1663
  {
1664
- if (src0t == GGML_TYPE_F16) {
 
 
1665
  ggml_vk_get_rows_f16(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
1666
  } else if (src0t == GGML_TYPE_Q4_0) {
1667
  ggml_vk_get_rows_q4_0(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
 
22
  #include "shaderop_mul_mat_q4_1.h"
23
  #include "shaderop_mul_mat_q6_k.h"
24
  #include "shaderop_mul_mat_mat_f32.h"
25
+ #include "shaderop_getrows_f32.h"
26
  #include "shaderop_getrows_f16.h"
27
  #include "shaderop_getrows_q4_0.h"
28
  #include "shaderop_getrows_q4_1.h"
 
1147
  seq.record<kp::OpAlgoDispatch>(s_algo);
1148
  }
1149
 
1150
+ template <typename... Args>
1151
+ static void ggml_vk_get_rows_f32(Args&&... args) {
1152
+ const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f32_comp_spv,
1153
+ kp::shader_data::op_getrows_f32_comp_spv_len);
1154
+
1155
+ ggml_vk_get_rows(spirv, "f32", sizeof(float), 0, std::forward<Args>(args)...);
1156
+ }
1157
+
1158
  template <typename... Args>
1159
  static void ggml_vk_get_rows_f16(Args&&... args) {
1160
  const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f16_comp_spv,
 
1380
  return op->ne[3] == 1;
1381
  case GGML_OP_GET_ROWS:
1382
  switch (op->src[0]->type) {
1383
+ case GGML_TYPE_F32:
1384
  case GGML_TYPE_F16:
1385
  case GGML_TYPE_Q4_0:
1386
  case GGML_TYPE_Q4_1:
 
1671
  } break;
1672
  case GGML_OP_GET_ROWS:
1673
  {
1674
+ if (src0t == GGML_TYPE_F32) {
1675
+ ggml_vk_get_rows_f32(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
1676
+ } else if (src0t == GGML_TYPE_F16) {
1677
  ggml_vk_get_rows_f16(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
1678
  } else if (src0t == GGML_TYPE_Q4_0) {
1679
  ggml_vk_get_rows_q4_0(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));