jeffbolznv committed on
Commit
4d0d8b8
·
1 Parent(s): bac21a7

vulkan: fix noncontig check for mat_mul_id splitting (llama/14683)

Browse files

* vulkan: fix noncontig check for mat_mul_id splitting

Remove the supports_op check that rejected op->src[2]->ne[0] > 4096 (splitting makes this limit unnecessary)

* vulkan: fix batched matmul dequant for Q*_K

ggml/src/ggml-vulkan/ggml-vulkan.cpp CHANGED
@@ -4922,7 +4922,7 @@ static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) {
4922
  return
4923
  tensor->nb[0] == ggml_type_size(tensor->type) &&
4924
  tensor->nb[1] == (tensor->nb[0]*tensor->ne[0])/ggml_blck_size(tensor->type) &&
4925
- tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
4926
  }
4927
 
4928
  static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src, const ggml_tensor * dst, ggml_type to) {
@@ -10356,10 +10356,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
10356
  // If there's not enough shared memory for row_ids and the result tile, fallback to CPU
10357
  return false;
10358
  }
10359
- // Check against size of shared memory variable
10360
- if (op->src[2]->ne[0] > 4096) {
10361
- return false;
10362
- }
10363
  }
10364
  switch (src0_type) {
10365
  case GGML_TYPE_F32:
 
4922
  return
4923
  tensor->nb[0] == ggml_type_size(tensor->type) &&
4924
  tensor->nb[1] == (tensor->nb[0]*tensor->ne[0])/ggml_blck_size(tensor->type) &&
4925
+ (tensor->ne[3] == 1 || tensor->nb[3] == tensor->nb[2]*tensor->ne[2]);
4926
  }
4927
 
4928
  static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src, const ggml_tensor * dst, ggml_type to) {
 
10356
  // If there's not enough shared memory for row_ids and the result tile, fallback to CPU
10357
  return false;
10358
  }
 
 
 
 
10359
  }
10360
  switch (src0_type) {
10361
  case GGML_TYPE_F32:
ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp CHANGED
@@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
10
  void main() {
11
  [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
12
  const uint i = gl_WorkGroupID.x * 256 + wgy;
13
- if (i >= p.M * p.K / QUANT_K) {
14
  return;
15
  }
16
 
 
10
  void main() {
11
  [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
12
  const uint i = gl_WorkGroupID.x * 256 + wgy;
13
+ if (i >= p.nel / QUANT_K) {
14
  return;
15
  }
16
 
ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp CHANGED
@@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
10
  void main() {
11
  [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
12
  const uint i = uint(gl_WorkGroupID.x * 256 + wgy);
13
- if (i >= p.M * p.K / QUANT_K) {
14
  return;
15
  }
16
 
 
10
  void main() {
11
  [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
12
  const uint i = uint(gl_WorkGroupID.x * 256 + wgy);
13
+ if (i >= p.nel / QUANT_K) {
14
  return;
15
  }
16
 
ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp CHANGED
@@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
10
  void main() {
11
  [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
12
  const uint ib = gl_WorkGroupID.x * 256 + wgy;
13
- if (ib >= p.M * p.K / QUANT_K) {
14
  return;
15
  }
16
 
 
10
  void main() {
11
  [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
12
  const uint ib = gl_WorkGroupID.x * 256 + wgy;
13
+ if (ib >= p.nel / QUANT_K) {
14
  return;
15
  }
16
 
ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp CHANGED
@@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
10
  void main() {
11
  [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
12
  const uint ib = gl_WorkGroupID.x * 256 + wgy;
13
- if (ib >= p.M * p.K / QUANT_K) {
14
  return;
15
  }
16
 
 
10
  void main() {
11
  [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
12
  const uint ib = gl_WorkGroupID.x * 256 + wgy;
13
+ if (ib >= p.nel / QUANT_K) {
14
  return;
15
  }
16
 
ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp CHANGED
@@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
10
  void main() {
11
  [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
12
  const uint i = gl_WorkGroupID.x * 256 + wgy;
13
- if (i >= p.M * p.K / QUANT_K) {
14
  return;
15
  }
16
  const uint tid = gl_LocalInvocationID.x;
 
10
  void main() {
11
  [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
12
  const uint i = gl_WorkGroupID.x * 256 + wgy;
13
+ if (i >= p.nel / QUANT_K) {
14
  return;
15
  }
16
  const uint tid = gl_LocalInvocationID.x;