uvos commited on
Commit
adf6b4b
·
1 Parent(s): c3467c7

CUDA/HIP: fix ssm_scan on devices where warp size is not 32 (llama/14196)

Browse files
Files changed (1) hide show
  1. ggml/src/ggml-cuda/ssm-scan.cu +6 -4
ggml/src/ggml-cuda/ssm-scan.cu CHANGED
@@ -10,6 +10,8 @@ __global__ void __launch_bounds__(splitD, 2)
10
  float * __restrict__ dst, const int64_t L) {
11
  GGML_UNUSED(src1_nb0);
12
  GGML_UNUSED(src2_nb0);
 
 
13
  const int bidx = blockIdx.x; // split along B
14
  const int bidy = blockIdx.y; // split along D
15
  const int tid = threadIdx.x;
@@ -44,16 +46,16 @@ __global__ void __launch_bounds__(splitD, 2)
44
  if (N == 16) {
45
  #pragma unroll
46
  for (size_t i = 0; i < splitD / 4; i += 2) {
47
- float value = A_block[(wid * warpSize + i) * stride_A + wtid];
48
  // todo: bank conflict
49
  // I am always confused with how to use the swizzling method to solve
50
  // bank conflit. Hoping somebody can tell me.
51
- smem_A[(wid * warpSize + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
52
  }
53
  #pragma unroll
54
  for (size_t i = 0; i < splitD / 4; i += 2) {
55
- float value = s0_block[(wid * warpSize + i) * stride_s0 + wtid];
56
- smem_s0[(wid * warpSize + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
57
  }
58
  }
59
 
 
10
  float * __restrict__ dst, const int64_t L) {
11
  GGML_UNUSED(src1_nb0);
12
  GGML_UNUSED(src2_nb0);
13
+
14
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
15
  const int bidx = blockIdx.x; // split along B
16
  const int bidy = blockIdx.y; // split along D
17
  const int tid = threadIdx.x;
 
46
  if (N == 16) {
47
  #pragma unroll
48
  for (size_t i = 0; i < splitD / 4; i += 2) {
49
+ float value = A_block[(wid * warp_size + i) * stride_A + wtid];
50
  // todo: bank conflict
51
  // I am always confused with how to use the swizzling method to solve
52
  // bank conflit. Hoping somebody can tell me.
53
+ smem_A[(wid * warp_size + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
54
  }
55
  #pragma unroll
56
  for (size_t i = 0; i < splitD / 4; i += 2) {
57
+ float value = s0_block[(wid * warp_size + i) * stride_s0 + wtid];
58
+ smem_s0[(wid * warp_size + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
59
  }
60
  }
61