compilade committed on
Commit
92b2d32
·
1 Parent(s): 573d50a

cuda : support Falcon-H1 state size for SSM_SCAN (llama/14602)

Browse files
ggml/src/ggml-cuda/ggml-cuda.cu CHANGED
@@ -3335,8 +3335,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3335
  case GGML_OP_SSM_SCAN: {
3336
  if (op->src[3]->ne[0] == 1) {
3337
  // Mamba2
3338
- // (kernel only supports d_state == 128 && d_head % 16 == 0)
3339
- return op->src[0]->ne[0] == 128 && op->src[0]->ne[1] % 16 == 0;
3340
  } else {
3341
  // Mamba
3342
  // (kernel only supports d_state == 16, d_head == 1, n_head % 128 == 0, n_group == 1)
 
3335
  case GGML_OP_SSM_SCAN: {
3336
  if (op->src[3]->ne[0] == 1) {
3337
  // Mamba2
3338
+ // (kernel only supports (d_state == 128 || d_state == 256) && d_head % 16 == 0)
3339
+ return (op->src[0]->ne[0] == 128 || op->src[0]->ne[0] == 256) && op->src[0]->ne[1] % 16 == 0;
3340
  } else {
3341
  // Mamba
3342
  // (kernel only supports d_state == 16, d_head == 1, n_head % 128 == 0, n_group == 1)
ggml/src/ggml-cuda/ssm-scan.cu CHANGED
@@ -201,11 +201,11 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
201
  const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
202
  const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
203
  cudaStream_t stream) {
204
- const int threads = 128;
205
  // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
206
  if (src3_nb1 == sizeof(float)) {
207
  // Mamba-2
208
  if (d_state == 128) {
 
209
  GGML_ASSERT(d_state % threads == 0);
210
  // NOTE: can be any power of two between 4 and 64
211
  const int splitH = 16;
@@ -215,10 +215,21 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
215
  src0, src1, src2, src3, src4, src5, src6, dst,
216
  src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
217
  src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
 
 
 
 
 
 
 
 
 
 
218
  } else {
219
- GGML_ABORT("doesn't support d_state!=128.");
220
  }
221
  } else {
 
222
  // Mamba-1
223
  GGML_ASSERT(n_head % threads == 0);
224
  GGML_ASSERT(head_dim == 1);
 
201
  const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
202
  const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
203
  cudaStream_t stream) {
 
204
  // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
205
  if (src3_nb1 == sizeof(float)) {
206
  // Mamba-2
207
  if (d_state == 128) {
208
+ const int threads = 128;
209
  GGML_ASSERT(d_state % threads == 0);
210
  // NOTE: can be any power of two between 4 and 64
211
  const int splitH = 16;
 
215
  src0, src1, src2, src3, src4, src5, src6, dst,
216
  src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
217
  src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
218
+ } else if (d_state == 256) { // Falcon-H1
219
+ const int threads = 256;
220
+ // NOTE: can be any power of two between 8 and 64
221
+ const int splitH = 16;
222
+ GGML_ASSERT(head_dim % splitH == 0);
223
+ const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1);
224
+ ssm_scan_f32_group<16, 256><<<blocks, threads, 0, stream>>>(
225
+ src0, src1, src2, src3, src4, src5, src6, dst,
226
+ src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
227
+ src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
228
  } else {
229
+ GGML_ABORT("doesn't support d_state!=(128 or 256).");
230
  }
231
  } else {
232
+ const int threads = 128;
233
  // Mamba-1
234
  GGML_ASSERT(n_head % threads == 0);
235
  GGML_ASSERT(head_dim == 1);