Spaces:
Running
Running
uvos
commited on
Commit
·
adf6b4b
1
Parent(s):
c3467c7
CUDA/HIP: fix ssm_scan on devices where warp size is not 32 (llama/14196)
Browse files
ggml/src/ggml-cuda/ssm-scan.cu
CHANGED
|
@@ -10,6 +10,8 @@ __global__ void __launch_bounds__(splitD, 2)
|
|
| 10 |
float * __restrict__ dst, const int64_t L) {
|
| 11 |
GGML_UNUSED(src1_nb0);
|
| 12 |
GGML_UNUSED(src2_nb0);
|
|
|
|
|
|
|
| 13 |
const int bidx = blockIdx.x; // split along B
|
| 14 |
const int bidy = blockIdx.y; // split along D
|
| 15 |
const int tid = threadIdx.x;
|
|
@@ -44,16 +46,16 @@ __global__ void __launch_bounds__(splitD, 2)
|
|
| 44 |
if (N == 16) {
|
| 45 |
#pragma unroll
|
| 46 |
for (size_t i = 0; i < splitD / 4; i += 2) {
|
| 47 |
-
float value = A_block[(wid *
|
| 48 |
// todo: bank conflict
|
| 49 |
// I am always confused with how to use the swizzling method to solve
|
| 50 |
// bank conflit. Hoping somebody can tell me.
|
| 51 |
-
smem_A[(wid *
|
| 52 |
}
|
| 53 |
#pragma unroll
|
| 54 |
for (size_t i = 0; i < splitD / 4; i += 2) {
|
| 55 |
-
float value = s0_block[(wid *
|
| 56 |
-
smem_s0[(wid *
|
| 57 |
}
|
| 58 |
}
|
| 59 |
|
|
|
|
| 10 |
float * __restrict__ dst, const int64_t L) {
|
| 11 |
GGML_UNUSED(src1_nb0);
|
| 12 |
GGML_UNUSED(src2_nb0);
|
| 13 |
+
|
| 14 |
+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
| 15 |
const int bidx = blockIdx.x; // split along B
|
| 16 |
const int bidy = blockIdx.y; // split along D
|
| 17 |
const int tid = threadIdx.x;
|
|
|
|
| 46 |
if (N == 16) {
|
| 47 |
#pragma unroll
|
| 48 |
for (size_t i = 0; i < splitD / 4; i += 2) {
|
| 49 |
+
float value = A_block[(wid * warp_size + i) * stride_A + wtid];
|
| 50 |
// todo: bank conflict
|
| 51 |
// I am always confused with how to use the swizzling method to solve
|
| 52 |
// bank conflit. Hoping somebody can tell me.
|
| 53 |
+
smem_A[(wid * warp_size + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
|
| 54 |
}
|
| 55 |
#pragma unroll
|
| 56 |
for (size_t i = 0; i < splitD / 4; i += 2) {
|
| 57 |
+
float value = s0_block[(wid * warp_size + i) * stride_s0 + wtid];
|
| 58 |
+
smem_s0[(wid * warp_size + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
|
| 59 |
}
|
| 60 |
}
|
| 61 |
|