Spaces:
Running
Running
Commit
·
76b8073
1
Parent(s):
661360d
vulkan: Fix newly added tests for permuted mul_mat and 1D im2col (llama/10226)
Browse files
- ggml/src/ggml-vulkan.cpp +21 -6
ggml/src/ggml-vulkan.cpp
CHANGED
|
@@ -3147,7 +3147,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
|
| 3147 |
const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
|
| 3148 |
const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
|
| 3149 |
|
| 3150 |
-
if (mmp == nullptr) {
|
| 3151 |
// Fall back to dequant + f16 mulmat
|
| 3152 |
mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16);
|
| 3153 |
}
|
|
@@ -3630,9 +3630,19 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
|
|
| 3630 |
|
| 3631 |
static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
|
| 3632 |
VK_LOG_DEBUG("ggml_vk_mul_mat(" << src0 << ", " << src1 << ", " << dst << ")");
|
| 3633 |
-
if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1) {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3634 |
ggml_vk_mul_mat_vec_p021_f16_f32(ctx, subctx, src0, src1, dst, dryrun);
|
| 3635 |
-
} else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1) {
|
|
|
|
| 3636 |
ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, src0, src1, dst, dryrun);
|
| 3637 |
} else if (dst->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
|
| 3638 |
ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun);
|
|
@@ -3708,7 +3718,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
|
| 3708 |
const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
|
| 3709 |
const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
|
| 3710 |
|
| 3711 |
-
if (mmp == nullptr) {
|
| 3712 |
GGML_ABORT("fatal error");
|
| 3713 |
}
|
| 3714 |
|
|
@@ -4470,7 +4480,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|
| 4470 |
const uint32_t OH = is_2D ? dst->ne[2] : 1;
|
| 4471 |
const uint32_t OW = dst->ne[1];
|
| 4472 |
|
| 4473 |
-
const uint32_t batch = src1->ne[3];
|
| 4474 |
|
| 4475 |
elements = { OW * KW * KH, OH, batch * IC };
|
| 4476 |
} break;
|
|
@@ -4915,7 +4925,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|
| 4915 |
const uint32_t OW = dst->ne[1];
|
| 4916 |
|
| 4917 |
const uint32_t offset_delta = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
|
| 4918 |
-
const uint32_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32
|
| 4919 |
|
| 4920 |
const uint32_t pelements = OW * KW * KH;
|
| 4921 |
|
|
@@ -6804,6 +6814,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
| 6804 |
if (a->ne[3] != b->ne[3]) {
|
| 6805 |
return false;
|
| 6806 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6807 |
return true;
|
| 6808 |
} break;
|
| 6809 |
case GGML_OP_GET_ROWS:
|
|
|
|
| 3147 |
const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
|
| 3148 |
const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
|
| 3149 |
|
| 3150 |
+
if (qx_needs_dequant) {
|
| 3151 |
// Fall back to dequant + f16 mulmat
|
| 3152 |
mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16);
|
| 3153 |
}
|
|
|
|
| 3630 |
|
| 3631 |
static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
|
| 3632 |
VK_LOG_DEBUG("ggml_vk_mul_mat(" << src0 << ", " << src1 << ", " << dst << ")");
|
| 3633 |
+
if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1 &&
|
| 3634 |
+
// detect 0213 permutation, and batch size of 1
|
| 3635 |
+
src0->nb[0] <= src0->nb[2] &&
|
| 3636 |
+
src0->nb[2] <= src0->nb[1] &&
|
| 3637 |
+
src0->nb[1] <= src0->nb[3] &&
|
| 3638 |
+
src1->nb[0] <= src1->nb[2] &&
|
| 3639 |
+
src1->nb[2] <= src1->nb[1] &&
|
| 3640 |
+
src1->nb[1] <= src1->nb[3] &&
|
| 3641 |
+
src0->ne[3] == 1 &&
|
| 3642 |
+
src1->ne[3] == 1) {
|
| 3643 |
ggml_vk_mul_mat_vec_p021_f16_f32(ctx, subctx, src0, src1, dst, dryrun);
|
| 3644 |
+
} else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1 &&
|
| 3645 |
+
!ggml_is_permuted(src0) && !ggml_is_permuted(src1)) {
|
| 3646 |
ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, src0, src1, dst, dryrun);
|
| 3647 |
} else if (dst->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
|
| 3648 |
ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun);
|
|
|
|
| 3718 |
const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
|
| 3719 |
const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
|
| 3720 |
|
| 3721 |
+
if (qx_needs_dequant) {
|
| 3722 |
GGML_ABORT("fatal error");
|
| 3723 |
}
|
| 3724 |
|
|
|
|
| 4480 |
const uint32_t OH = is_2D ? dst->ne[2] : 1;
|
| 4481 |
const uint32_t OW = dst->ne[1];
|
| 4482 |
|
| 4483 |
+
const uint32_t batch = src1->ne[is_2D ? 3 : 2];
|
| 4484 |
|
| 4485 |
elements = { OW * KW * KH, OH, batch * IC };
|
| 4486 |
} break;
|
|
|
|
| 4925 |
const uint32_t OW = dst->ne[1];
|
| 4926 |
|
| 4927 |
const uint32_t offset_delta = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
|
| 4928 |
+
const uint32_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
|
| 4929 |
|
| 4930 |
const uint32_t pelements = OW * KW * KH;
|
| 4931 |
|
|
|
|
| 6814 |
if (a->ne[3] != b->ne[3]) {
|
| 6815 |
return false;
|
| 6816 |
}
|
| 6817 |
+
if (!(ggml_vk_dim01_contiguous(op->src[0]) || op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) ||
|
| 6818 |
+
!(ggml_vk_dim01_contiguous(op->src[1]) || op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16)) {
|
| 6819 |
+
return false;
|
| 6820 |
+
}
|
| 6821 |
+
|
| 6822 |
return true;
|
| 6823 |
} break;
|
| 6824 |
case GGML_OP_GET_ROWS:
|