@@ -482,41 +482,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       " Tensor page_table, float scale) -> ()");
   ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
 
-  // Mamba selective scan kernel
-  ops.def(
-      "selective_scan_fwd(Tensor! u, Tensor! delta,"
-      " Tensor! A, Tensor! B, Tensor! C,"
-      " Tensor? D_, Tensor!? z_, Tensor? delta_bias_,"
-      " bool delta_softplus,"
-      " Tensor? query_start_loc,"
-      " Tensor? cache_indices,"
-      " Tensor? has_initial_state,"
-      " Tensor! ssm_states,"
-      " int pad_slot_id) -> ()");
-  ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
-
-  ops.def(
-      "causal_conv1d_update(Tensor! x,"
-      " Tensor! conv_state,"
-      " Tensor! weight,"
-      " Tensor? bias_,"
-      " bool silu_activation,"
-      " Tensor? cache_seqlens_,"
-      " Tensor? conv_state_indices,"
-      " int pad_slot_id) -> ()");
-  ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);
-
-  ops.def(
-      "causal_conv1d_fwd(Tensor! x, Tensor! weight,"
-      " Tensor? bias_,"
-      " Tensor!? conv_states,"
-      " Tensor? query_start_loc,"
-      " Tensor? cache_indices,"
-      " Tensor? has_initial_state,"
-      " bool silu_activation,"
-      " int pad_slot_id) -> ()");
-  ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
-
   // Compute NVFP4 block quantized tensor.
   ops.def(
       "scaled_fp4_quant(Tensor! output, Tensor input,"
@@ -584,6 +549,41 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
            &dynamic_scaled_int8_quant);
 
+  // Mamba selective scan kernel
+  ops.def(
+      "selective_scan_fwd(Tensor! u, Tensor! delta,"
+      " Tensor! A, Tensor! B, Tensor! C,"
+      " Tensor? D_, Tensor!? z_, Tensor? delta_bias_,"
+      " bool delta_softplus,"
+      " Tensor? query_start_loc,"
+      " Tensor? cache_indices,"
+      " Tensor? has_initial_state,"
+      " Tensor! ssm_states,"
+      " int pad_slot_id) -> ()");
+  ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
+
+  ops.def(
+      "causal_conv1d_update(Tensor! x,"
+      " Tensor! conv_state,"
+      " Tensor! weight,"
+      " Tensor? bias_,"
+      " bool silu_activation,"
+      " Tensor? cache_seqlens_,"
+      " Tensor? conv_state_indices,"
+      " int pad_slot_id) -> ()");
+  ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);
+
+  ops.def(
+      "causal_conv1d_fwd(Tensor! x, Tensor! weight,"
+      " Tensor? bias_,"
+      " Tensor!? conv_states,"
+      " Tensor? query_start_loc,"
+      " Tensor? cache_indices,"
+      " Tensor? has_initial_state,"
+      " bool silu_activation,"
+      " int pad_slot_id) -> ()");
+  ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
+
 #ifndef USE_ROCM
   // reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel
   ops.def(
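For reference, a minimal sketch of how a def/impl pair like the ones being moved here wires a schema string to a CUDA function. This is not part of the diff: the library name toy_ops and the kernel toy_scale_ are hypothetical stand-ins (the real registrations above go through TORCH_LIBRARY_EXPAND with the extension's own namespace), but the schema conventions are the same: "Tensor!" marks an argument the op mutates, "Tensor?" an optional argument, and "-> ()" an op with no return value.

// Hypothetical example, not from this PR: registers a toy in-place op the
// same way the selective_scan_fwd / causal_conv1d_* bindings above are bound.
#include <torch/library.h>
#include <torch/torch.h>

// "Tensor!" in the schema below tells the dispatcher that x is mutated in place;
// "float" in a schema maps to double in C++.
void toy_scale_(torch::Tensor& x, double scale) { x.mul_(scale); }

TORCH_LIBRARY(toy_ops, ops) {
  ops.def("toy_scale_(Tensor! x, float scale) -> ()");
  // Route calls on CUDA tensors to the function above.
  ops.impl("toy_scale_", torch::kCUDA, &toy_scale_);
}

Once such an extension is loaded, the op is reachable from Python as torch.ops.toy_ops.toy_scale_(x, 2.0).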