smeso commited on
Commit
e2fe267
·
1 Parent(s): 098f7fa

vulkan: add dryrun support to sin and cos ops (ggml/947)

Browse files

sin and cos failed test-backend-ops because they
tried to dereference a context pointer that is null
on dry runs.

This commit prevents that segfault.

Signed-off-by: Salvatore Mesoraca <[email protected]>

Files changed (1) hide show
  1. ggml/src/ggml-vulkan.cpp +6 -6
ggml/src/ggml-vulkan.cpp CHANGED
@@ -4616,7 +4616,7 @@ static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const
4616
  }, dryrun);
4617
  }
4618
 
4619
- static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
4620
  const uint32_t src0_type_size = ggml_type_size(src0->type);
4621
  const uint32_t dst_type_size = ggml_type_size(dst->type);
4622
 
@@ -4626,10 +4626,10 @@ static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const
4626
  (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
4627
  0,
4628
  0.0f, 0.0f,
4629
- });
4630
  }
4631
 
4632
- static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
4633
  const uint32_t src0_type_size = ggml_type_size(src0->type);
4634
  const uint32_t dst_type_size = ggml_type_size(dst->type);
4635
 
@@ -4639,7 +4639,7 @@ static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const
4639
  (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
4640
  0,
4641
  0.0f, 0.0f,
4642
- });
4643
  }
4644
 
4645
  static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
@@ -5783,11 +5783,11 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
5783
 
5784
  break;
5785
  case GGML_OP_SIN:
5786
- ggml_vk_sin(ctx, compute_ctx, src0, node);
5787
 
5788
  break;
5789
  case GGML_OP_COS:
5790
- ggml_vk_cos(ctx, compute_ctx, src0, node);
5791
 
5792
  break;
5793
  case GGML_OP_CLAMP:
 
4616
  }, dryrun);
4617
  }
4618
 
4619
+ static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
4620
  const uint32_t src0_type_size = ggml_type_size(src0->type);
4621
  const uint32_t dst_type_size = ggml_type_size(dst->type);
4622
 
 
4626
  (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
4627
  0,
4628
  0.0f, 0.0f,
4629
+ }, dryrun);
4630
  }
4631
 
4632
+ static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
4633
  const uint32_t src0_type_size = ggml_type_size(src0->type);
4634
  const uint32_t dst_type_size = ggml_type_size(dst->type);
4635
 
 
4639
  (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
4640
  0,
4641
  0.0f, 0.0f,
4642
+ }, dryrun);
4643
  }
4644
 
4645
  static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
 
5783
 
5784
  break;
5785
  case GGML_OP_SIN:
5786
+ ggml_vk_sin(ctx, compute_ctx, src0, node, dryrun);
5787
 
5788
  break;
5789
  case GGML_OP_COS:
5790
+ ggml_vk_cos(ctx, compute_ctx, src0, node, dryrun);
5791
 
5792
  break;
5793
  case GGML_OP_CLAMP: