jeffbolznv commited on
Commit
76b8073
·
1 Parent(s): 661360d

vulkan: Fix newly added tests for permuted mul_mat and 1D im2col (llama/10226)

Browse files
Files changed (1) hide show
  1. ggml/src/ggml-vulkan.cpp +21 -6
ggml/src/ggml-vulkan.cpp CHANGED
@@ -3147,7 +3147,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
3147
  const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
3148
  const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
3149
 
3150
- if (mmp == nullptr) {
3151
  // Fall back to dequant + f16 mulmat
3152
  mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16);
3153
  }
@@ -3630,9 +3630,19 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
3630
 
3631
  static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
3632
  VK_LOG_DEBUG("ggml_vk_mul_mat(" << src0 << ", " << src1 << ", " << dst << ")");
3633
- if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1) {
 
 
 
 
 
 
 
 
 
3634
  ggml_vk_mul_mat_vec_p021_f16_f32(ctx, subctx, src0, src1, dst, dryrun);
3635
- } else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1) {
 
3636
  ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, src0, src1, dst, dryrun);
3637
  } else if (dst->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
3638
  ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun);
@@ -3708,7 +3718,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
3708
  const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
3709
  const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
3710
 
3711
- if (mmp == nullptr) {
3712
  GGML_ABORT("fatal error");
3713
  }
3714
 
@@ -4470,7 +4480,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
4470
  const uint32_t OH = is_2D ? dst->ne[2] : 1;
4471
  const uint32_t OW = dst->ne[1];
4472
 
4473
- const uint32_t batch = src1->ne[3];
4474
 
4475
  elements = { OW * KW * KH, OH, batch * IC };
4476
  } break;
@@ -4915,7 +4925,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
4915
  const uint32_t OW = dst->ne[1];
4916
 
4917
  const uint32_t offset_delta = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
4918
- const uint32_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32
4919
 
4920
  const uint32_t pelements = OW * KW * KH;
4921
 
@@ -6804,6 +6814,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
6804
  if (a->ne[3] != b->ne[3]) {
6805
  return false;
6806
  }
 
 
 
 
 
6807
  return true;
6808
  } break;
6809
  case GGML_OP_GET_ROWS:
 
3147
  const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
3148
  const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
3149
 
3150
+ if (qx_needs_dequant) {
3151
  // Fall back to dequant + f16 mulmat
3152
  mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16);
3153
  }
 
3630
 
3631
  static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
3632
  VK_LOG_DEBUG("ggml_vk_mul_mat(" << src0 << ", " << src1 << ", " << dst << ")");
3633
+ if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1 &&
3634
+ // detect 0213 permutation, and batch size of 1
3635
+ src0->nb[0] <= src0->nb[2] &&
3636
+ src0->nb[2] <= src0->nb[1] &&
3637
+ src0->nb[1] <= src0->nb[3] &&
3638
+ src1->nb[0] <= src1->nb[2] &&
3639
+ src1->nb[2] <= src1->nb[1] &&
3640
+ src1->nb[1] <= src1->nb[3] &&
3641
+ src0->ne[3] == 1 &&
3642
+ src1->ne[3] == 1) {
3643
  ggml_vk_mul_mat_vec_p021_f16_f32(ctx, subctx, src0, src1, dst, dryrun);
3644
+ } else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1 &&
3645
+ !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) {
3646
  ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, src0, src1, dst, dryrun);
3647
  } else if (dst->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
3648
  ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun);
 
3718
  const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
3719
  const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
3720
 
3721
+ if (qx_needs_dequant) {
3722
  GGML_ABORT("fatal error");
3723
  }
3724
 
 
4480
  const uint32_t OH = is_2D ? dst->ne[2] : 1;
4481
  const uint32_t OW = dst->ne[1];
4482
 
4483
+ const uint32_t batch = src1->ne[is_2D ? 3 : 2];
4484
 
4485
  elements = { OW * KW * KH, OH, batch * IC };
4486
  } break;
 
4925
  const uint32_t OW = dst->ne[1];
4926
 
4927
  const uint32_t offset_delta = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
4928
+ const uint32_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
4929
 
4930
  const uint32_t pelements = OW * KW * KH;
4931
 
 
6814
  if (a->ne[3] != b->ne[3]) {
6815
  return false;
6816
  }
6817
+ if (!(ggml_vk_dim01_contiguous(op->src[0]) || op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) ||
6818
+ !(ggml_vk_dim01_contiguous(op->src[1]) || op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16)) {
6819
+ return false;
6820
+ }
6821
+
6822
  return true;
6823
  } break;
6824
  case GGML_OP_GET_ROWS: