Spaces:
Running
Running
cuda : support Falcon-H1 state size for SSM_SCAN (llama/14602)
Browse files
ggml/src/ggml-cuda/ggml-cuda.cu
CHANGED
|
@@ -3335,8 +3335,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|
| 3335 |
case GGML_OP_SSM_SCAN: {
|
| 3336 |
if (op->src[3]->ne[0] == 1) {
|
| 3337 |
// Mamba2
|
| 3338 |
-
// (kernel only supports d_state == 128 && d_head % 16 == 0)
|
| 3339 |
-
return op->src[0]->ne[0] == 128 && op->src[0]->ne[1] % 16 == 0;
|
| 3340 |
} else {
|
| 3341 |
// Mamba
|
| 3342 |
// (kernel only supports d_state == 16, d_head == 1, n_head % 128 == 0, n_group == 1)
|
|
|
|
| 3335 |
case GGML_OP_SSM_SCAN: {
|
| 3336 |
if (op->src[3]->ne[0] == 1) {
|
| 3337 |
// Mamba2
|
| 3338 |
+
// (kernel only supports (d_state == 128 || d_state == 256) && d_head % 16 == 0)
|
| 3339 |
+
return (op->src[0]->ne[0] == 128 || op->src[0]->ne[0] == 256) && op->src[0]->ne[1] % 16 == 0;
|
| 3340 |
} else {
|
| 3341 |
// Mamba
|
| 3342 |
// (kernel only supports d_state == 16, d_head == 1, n_head % 128 == 0, n_group == 1)
|
ggml/src/ggml-cuda/ssm-scan.cu
CHANGED
|
@@ -201,11 +201,11 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
|
|
| 201 |
const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
|
| 202 |
const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
|
| 203 |
cudaStream_t stream) {
|
| 204 |
-
const int threads = 128;
|
| 205 |
// NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
|
| 206 |
if (src3_nb1 == sizeof(float)) {
|
| 207 |
// Mamba-2
|
| 208 |
if (d_state == 128) {
|
|
|
|
| 209 |
GGML_ASSERT(d_state % threads == 0);
|
| 210 |
// NOTE: can be any power of two between 4 and 64
|
| 211 |
const int splitH = 16;
|
|
@@ -215,10 +215,21 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
|
|
| 215 |
src0, src1, src2, src3, src4, src5, src6, dst,
|
| 216 |
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
|
| 217 |
src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
} else {
|
| 219 |
-
GGML_ABORT("doesn't support d_state!=128.");
|
| 220 |
}
|
| 221 |
} else {
|
|
|
|
| 222 |
// Mamba-1
|
| 223 |
GGML_ASSERT(n_head % threads == 0);
|
| 224 |
GGML_ASSERT(head_dim == 1);
|
|
|
|
| 201 |
const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
|
| 202 |
const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
|
| 203 |
cudaStream_t stream) {
|
|
|
|
| 204 |
// NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
|
| 205 |
if (src3_nb1 == sizeof(float)) {
|
| 206 |
// Mamba-2
|
| 207 |
if (d_state == 128) {
|
| 208 |
+
const int threads = 128;
|
| 209 |
GGML_ASSERT(d_state % threads == 0);
|
| 210 |
// NOTE: can be any power of two between 4 and 64
|
| 211 |
const int splitH = 16;
|
|
|
|
| 215 |
src0, src1, src2, src3, src4, src5, src6, dst,
|
| 216 |
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
|
| 217 |
src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
|
| 218 |
+
} else if (d_state == 256) { // Falcon-H1
|
| 219 |
+
const int threads = 256;
|
| 220 |
+
// NOTE: can be any power of two between 8 and 64
|
| 221 |
+
const int splitH = 16;
|
| 222 |
+
GGML_ASSERT(head_dim % splitH == 0);
|
| 223 |
+
const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1);
|
| 224 |
+
ssm_scan_f32_group<16, 256><<<blocks, threads, 0, stream>>>(
|
| 225 |
+
src0, src1, src2, src3, src4, src5, src6, dst,
|
| 226 |
+
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
|
| 227 |
+
src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
|
| 228 |
} else {
|
| 229 |
+
GGML_ABORT("doesn't support d_state!=(128 or 256).");
|
| 230 |
}
|
| 231 |
} else {
|
| 232 |
+
const int threads = 128;
|
| 233 |
// Mamba-1
|
| 234 |
GGML_ASSERT(n_head % threads == 0);
|
| 235 |
GGML_ASSERT(head_dim == 1);
|