@@ -201,11 +201,11 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
201201 const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
202202 const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
203203 cudaStream_t stream) {
204- const int threads = 128 ;
205204 // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
206205 if (src3_nb1 == sizeof (float )) {
207206 // Mamba-2
208207 if (d_state == 128 ) {
208+ const int threads = 128 ;
209209 GGML_ASSERT (d_state % threads == 0 );
210210 // NOTE: can be any power of two between 4 and 64
211211 const int splitH = 16 ;
@@ -215,10 +215,21 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
215215 src0, src1, src2, src3, src4, src5, src6, dst,
216216 src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
217217 src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
218+ } else if (d_state == 256 ) { // Falcon-H1
219+ const int threads = 256 ;
220+ // NOTE: can be any power of two between 8 and 64
221+ const int splitH = 16 ;
222+ GGML_ASSERT (head_dim % splitH == 0 );
223+ const dim3 blocks ((n_head * head_dim + (splitH - 1 )) / splitH, n_seq, 1 );
224+ ssm_scan_f32_group<16 , 256 ><<<blocks, threads, 0 , stream>>> (
225+ src0, src1, src2, src3, src4, src5, src6, dst,
226+ src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
227+ src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
218228 } else {
219- GGML_ABORT (" doesn't support d_state!=128." );
229+ GGML_ABORT (" doesn't support d_state!=( 128 or 256) ." );
220230 }
221231 } else {
232+ const int threads = 128 ;
222233 // Mamba-1
223234 GGML_ASSERT (n_head % threads == 0 );
224235 GGML_ASSERT (head_dim == 1 );
0 commit comments