Spaces:
Running
Running
woachk
committed on
Commit
·
fa0872f
1
Parent(s):
f22c7e4
kompute : implement op_getrows_f32 (llama/6403)
Browse files
op_getrows_f32 is required since https://github.com/ggerganov/llama.cpp/pull/6122
for the Vulkan w/ Kompute backend to be functional.
As such, implement this op to make this backend functional again.
- ggml-kompute.cpp +13 -1
ggml-kompute.cpp
CHANGED
|
@@ -22,6 +22,7 @@
|
|
| 22 |
#include "shaderop_mul_mat_q4_1.h"
|
| 23 |
#include "shaderop_mul_mat_q6_k.h"
|
| 24 |
#include "shaderop_mul_mat_mat_f32.h"
|
|
|
|
| 25 |
#include "shaderop_getrows_f16.h"
|
| 26 |
#include "shaderop_getrows_q4_0.h"
|
| 27 |
#include "shaderop_getrows_q4_1.h"
|
|
@@ -1146,6 +1147,14 @@ static void ggml_vk_get_rows(
|
|
| 1146 |
seq.record<kp::OpAlgoDispatch>(s_algo);
|
| 1147 |
}
|
| 1148 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1149 |
template <typename... Args>
|
| 1150 |
static void ggml_vk_get_rows_f16(Args&&... args) {
|
| 1151 |
const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f16_comp_spv,
|
|
@@ -1371,6 +1380,7 @@ static bool ggml_vk_supports_op(const struct ggml_tensor * op) {
|
|
| 1371 |
return op->ne[3] == 1;
|
| 1372 |
case GGML_OP_GET_ROWS:
|
| 1373 |
switch (op->src[0]->type) {
|
|
|
|
| 1374 |
case GGML_TYPE_F16:
|
| 1375 |
case GGML_TYPE_Q4_0:
|
| 1376 |
case GGML_TYPE_Q4_1:
|
|
@@ -1661,7 +1671,9 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
|
|
| 1661 |
} break;
|
| 1662 |
case GGML_OP_GET_ROWS:
|
| 1663 |
{
|
| 1664 |
-
if (src0t == GGML_TYPE_F16) {
|
|
|
|
|
|
|
| 1665 |
ggml_vk_get_rows_f16(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
|
| 1666 |
} else if (src0t == GGML_TYPE_Q4_0) {
|
| 1667 |
ggml_vk_get_rows_q4_0(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
|
|
|
|
| 22 |
#include "shaderop_mul_mat_q4_1.h"
|
| 23 |
#include "shaderop_mul_mat_q6_k.h"
|
| 24 |
#include "shaderop_mul_mat_mat_f32.h"
|
| 25 |
+
#include "shaderop_getrows_f32.h"
|
| 26 |
#include "shaderop_getrows_f16.h"
|
| 27 |
#include "shaderop_getrows_q4_0.h"
|
| 28 |
#include "shaderop_getrows_q4_1.h"
|
|
|
|
| 1147 |
seq.record<kp::OpAlgoDispatch>(s_algo);
|
| 1148 |
}
|
| 1149 |
|
| 1150 |
+
template <typename... Args>
|
| 1151 |
+
static void ggml_vk_get_rows_f32(Args&&... args) {
|
| 1152 |
+
const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f32_comp_spv,
|
| 1153 |
+
kp::shader_data::op_getrows_f32_comp_spv_len);
|
| 1154 |
+
|
| 1155 |
+
ggml_vk_get_rows(spirv, "f32", sizeof(float), 0, std::forward<Args>(args)...);
|
| 1156 |
+
}
|
| 1157 |
+
|
| 1158 |
template <typename... Args>
|
| 1159 |
static void ggml_vk_get_rows_f16(Args&&... args) {
|
| 1160 |
const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f16_comp_spv,
|
|
|
|
| 1380 |
return op->ne[3] == 1;
|
| 1381 |
case GGML_OP_GET_ROWS:
|
| 1382 |
switch (op->src[0]->type) {
|
| 1383 |
+
case GGML_TYPE_F32:
|
| 1384 |
case GGML_TYPE_F16:
|
| 1385 |
case GGML_TYPE_Q4_0:
|
| 1386 |
case GGML_TYPE_Q4_1:
|
|
|
|
| 1671 |
} break;
|
| 1672 |
case GGML_OP_GET_ROWS:
|
| 1673 |
{
|
| 1674 |
+
if (src0t == GGML_TYPE_F32) {
|
| 1675 |
+
ggml_vk_get_rows_f32(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
|
| 1676 |
+
} else if (src0t == GGML_TYPE_F16) {
|
| 1677 |
ggml_vk_get_rows_f16(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
|
| 1678 |
} else if (src0t == GGML_TYPE_Q4_0) {
|
| 1679 |
ggml_vk_get_rows_q4_0(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
|