     Mamba2Metadata,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.radix_linear_attention import RadixLinearAttention
 from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, MambaPool
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
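The hunks below stop unpacking per-layer constants out of `**kwargs` and instead read them as attributes of the new `RadixLinearAttention` layer object. The following is only an illustrative sketch of the attribute surface those hunks rely on, inferred from the accesses in this diff; the real module lives in `sglang.srt.layers.radix_linear_attention` and may differ.

```python
# Hypothetical sketch, not the actual RadixLinearAttention implementation.
# Only attributes that the backend code in this diff reads are listed.
import torch
from torch import nn


class RadixLinearAttentionSketch(nn.Module):
    def __init__(
        self,
        layer_id: int,
        key_dim: int,
        value_dim: int,
        head_k_dim: int,
        head_v_dim: int,
        num_v_heads: int,
        attention_tp_size: int,
        conv_weights: torch.Tensor,
        bias: torch.Tensor,
        A_log: torch.Tensor,
        dt_bias: torch.Tensor,
        activation: str = "silu",
    ):
        super().__init__()
        # Static per-layer metadata the backend now reads instead of **kwargs.
        self.layer_id = layer_id
        self.key_dim = key_dim
        self.value_dim = value_dim
        self.head_k_dim = head_k_dim
        self.head_v_dim = head_v_dim
        self.num_v_heads = num_v_heads
        self.attention_tp_size = attention_tp_size
        self.activation = activation
        # Per-layer weights for the causal conv and gated-delta-rule gating.
        self.conv_weights = conv_weights
        self.bias = bias
        self.A_log = A_log
        self.dt_bias = dt_bias
```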
@@ -833,30 +834,23 @@ def __init__(self, model_runner: ModelRunner):
 
     def forward_decode(
         self,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
-        layer: RadixAttention,
+        layer: RadixLinearAttention,
         forward_batch: ForwardBatch,
-        save_kv_cache: bool = True,
-        **kwargs,
+        mixed_qkv: torch.Tensor,
+        a: torch.Tensor,
+        b: torch.Tensor,
+        **kwargs,  # Unused, for compatibility with HybridLinearAttnBackend
     ):
-        mixed_qkv = kwargs["mixed_qkv"]
-        conv_weights = kwargs["conv_weights"]
-        bias = kwargs["bias"]
-        activation = kwargs["activation"]
-        key_dim = kwargs["key_dim"]
-        value_dim = kwargs["value_dim"]
-        attn_tp_size = kwargs["attention_tp_size"]
-        head_k_dim = kwargs["head_k_dim"]
-        head_v_dim = kwargs["head_v_dim"]
-        a = kwargs["a"]
-        b = kwargs["b"]
-        A_log = kwargs["A_log"]
-        dt_bias = kwargs["dt_bias"]
-        layer_id = kwargs["layer_id"]
-
-        layer_cache = self.req_to_token_pool.mamba2_layer_cache(layer_id)
+        conv_weights = layer.conv_weights
+        bias = layer.bias
+        activation = layer.activation
+        key_dim = layer.key_dim
+        value_dim = layer.value_dim
+        attn_tp_size = layer.attention_tp_size
+        head_k_dim = layer.head_k_dim
+        head_v_dim = layer.head_v_dim
+
+        layer_cache = self.req_to_token_pool.mamba2_layer_cache(layer.layer_id)
         conv_states = layer_cache.conv[0]
         ssm_states = layer_cache.temporal
         query_start_loc = self.forward_metadata.query_start_loc
@@ -888,8 +882,8 @@ def forward_decode(
         value = value.view(1, seq_len, value.shape[1] // head_v_dim, head_v_dim)
 
         core_attn_out = self._kernel_func(
-            A_log=A_log,
-            dt_bias=dt_bias,
+            A_log=layer.A_log,
+            dt_bias=layer.dt_bias,
             q=query,
             k=key,
             v=value,
@@ -911,29 +905,23 @@ def forward_decode(
 
     def forward_extend(
         self,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
-        layer: RadixAttention,
+        layer: RadixLinearAttention,
         forward_batch: ForwardBatch,
-        save_kv_cache: bool = True,
-        **kwargs,
+        mixed_qkv: torch.Tensor,
+        a: torch.Tensor,
+        b: torch.Tensor,
+        **kwargs,  # Unused, for compatibility with HybridLinearAttnBackend
     ):
-        mixed_qkv = kwargs["mixed_qkv"]
-        conv_weights = kwargs["conv_weights"]
-        bias = kwargs["bias"]
-        activation = kwargs["activation"]
-        key_dim = kwargs["key_dim"]
-        value_dim = kwargs["value_dim"]
-        attn_tp_size = kwargs["attention_tp_size"]
-        head_k_dim = kwargs["head_k_dim"]
-        head_v_dim = kwargs["head_v_dim"]
-        a = kwargs["a"]
-        b = kwargs["b"]
-        A_log = kwargs["A_log"]
-        dt_bias = kwargs["dt_bias"]
-        layer_id = kwargs["layer_id"]
-        seq_len = kwargs["seq_len"]
+        seq_len = mixed_qkv.shape[0]
+
+        conv_weights = layer.conv_weights
+        bias = layer.bias
+        activation = layer.activation
+        key_dim = layer.key_dim
+        value_dim = layer.value_dim
+        attn_tp_size = layer.attention_tp_size
+        head_k_dim = layer.head_k_dim
+        head_v_dim = layer.head_v_dim
 
         is_target_verify = forward_batch.forward_mode.is_target_verify()
         forward_metadata = self.forward_metadata
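With the new signatures above, a caller no longer packs conv weights, gating parameters, and head dimensions into `**kwargs`; it passes the layer plus the per-token tensors, and `seq_len` is derived from `mixed_qkv` inside the backend. A hedged sketch of such a call site follows; the function name and the backend handle are assumptions for illustration, not taken from this diff.

```python
# Hypothetical call site, for illustration only.
def run_gdn_extend(backend, layer, forward_batch, mixed_qkv, a, b):
    # The backend reads conv_weights, A_log, dt_bias, head dims, etc. from
    # `layer`; seq_len is taken from mixed_qkv.shape[0] inside forward_extend.
    return backend.forward_extend(
        layer=layer,
        forward_batch=forward_batch,
        mixed_qkv=mixed_qkv,
        a=a,
        b=b,
    )
```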
@@ -944,7 +932,7 @@ def forward_extend(
         retrieve_next_sibling = forward_metadata.retrieve_next_sibling
         retrieve_parent_token = forward_metadata.retrieve_parent_token
 
-        mamba_cache_params = self.req_to_token_pool.mamba2_layer_cache(layer_id)
+        mamba_cache_params = self.req_to_token_pool.mamba2_layer_cache(layer.layer_id)
         conv_states = mamba_cache_params.conv[0]
         ssm_states = mamba_cache_params.temporal
         if is_target_verify:
@@ -1029,7 +1017,7 @@ def forward_extend(
         key = key.view(1, actual_seq_len, num_heads, head_k_dim)
         value = value.view(1, actual_seq_len, num_value_heads, head_v_dim)
 
-        g, beta = fused_gdn_gating(A_log, a, b, dt_bias)
+        g, beta = fused_gdn_gating(layer.A_log, a, b, layer.dt_bias)
 
         if is_target_verify:
             core_attn_out = fused_recurrent_gated_delta_rule_update(
@@ -1240,75 +1228,114 @@ def get_cuda_graph_seq_len_fill_value(self):
 
     def forward_decode(
         self,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        q: Optional[torch.Tensor] = None,  # For full attention
+        k: Optional[torch.Tensor] = None,  # For full attention
+        v: Optional[torch.Tensor] = None,  # For full attention
+        mixed_qkv: Optional[torch.Tensor] = None,  # For GDN linear attention
+        a: Optional[torch.Tensor] = None,  # For GDN linear attention
+        b: Optional[torch.Tensor] = None,  # For GDN linear attention
         **kwargs,
     ):
         layer_id = layer.layer_id if layer else kwargs["layer_id"]
         if layer_id in self.full_attn_layers:
             return self.full_attn_backend.forward_decode(
                 q, k, v, layer, forward_batch, save_kv_cache, **kwargs
             )
+        # Linear attention backend
         return self.linear_attn_backend.forward_decode(
-            q, k, v, layer, forward_batch, save_kv_cache, **kwargs
+            q=q,
+            k=k,
+            v=v,
+            layer=layer,
+            forward_batch=forward_batch,
+            save_kv_cache=save_kv_cache,
+            mixed_qkv=mixed_qkv,
+            a=a,
+            b=b,
+            **kwargs,
         )
 
     def forward_extend(
         self,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        q: Optional[torch.Tensor] = None,  # For full attention
+        k: Optional[torch.Tensor] = None,  # For full attention
+        v: Optional[torch.Tensor] = None,  # For full attention
+        mixed_qkv: Optional[torch.Tensor] = None,  # For GDN linear attention
+        a: Optional[torch.Tensor] = None,  # For GDN linear attention
+        b: Optional[torch.Tensor] = None,  # For GDN linear attention
         **kwargs,
     ):
         layer_id = layer.layer_id if layer else kwargs["layer_id"]
         if layer_id in self.full_attn_layers:
             return self.full_attn_backend.forward_extend(
                 q, k, v, layer, forward_batch, save_kv_cache, **kwargs
             )
+        # Linear attention backend
         return self.linear_attn_backend.forward_extend(
-            q, k, v, layer, forward_batch, save_kv_cache, **kwargs
+            q=q,
+            k=k,
+            v=v,
+            layer=layer,
+            forward_batch=forward_batch,
+            save_kv_cache=save_kv_cache,
+            mixed_qkv=mixed_qkv,
+            a=a,
+            b=b,
+            **kwargs,
        )
 
     def forward(
         self,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
-        layer: RadixAttention,
-        forward_batch: ForwardBatch,
+        q: Optional[torch.Tensor] = None,  # For full attention
+        k: Optional[torch.Tensor] = None,  # For full attention
+        v: Optional[torch.Tensor] = None,  # For full attention
+        layer: RadixAttention = None,
+        forward_batch: ForwardBatch = None,
         save_kv_cache: bool = True,
+        mixed_qkv: Optional[torch.Tensor] = None,  # For GDN linear attention
+        a: Optional[torch.Tensor] = None,  # For GDN linear attention
+        b: Optional[torch.Tensor] = None,  # For GDN linear attention
         **kwargs,
     ):
-        """Run forward on an attention layer."""
+        layer_id = layer.layer_id if layer else kwargs["layer_id"]
+        is_linear_attn = layer_id not in self.full_attn_layers
+
         if forward_batch.forward_mode.is_idle():
-            if layer is None:
-                return torch.empty_like(kwargs["z"])
+            if is_linear_attn:
+                return mixed_qkv.new_empty(
+                    mixed_qkv.shape[0], layer.num_v_heads, layer.head_v_dim
+                )
             return q.new_empty(q.shape[0], layer.tp_q_head_num * layer.v_head_dim)
         elif forward_batch.forward_mode.is_decode():
             return self.forward_decode(
+                layer,
+                forward_batch,
+                save_kv_cache,
                 q,
                 k,
                 v,
-                layer,
-                forward_batch,
-                save_kv_cache=save_kv_cache,
+                mixed_qkv,
+                a,
+                b,
                 **kwargs,
             )
         else:
             return self.forward_extend(
+                layer,
+                forward_batch,
+                save_kv_cache,
                 q,
                 k,
                 v,
-                layer,
-                forward_batch,
-                save_kv_cache=save_kv_cache,
+                mixed_qkv,
+                a,
+                b,
                 **kwargs,
             )
 
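Taken together, `HybridLinearAttnBackend` now routes on `layer_id`: layers listed in `full_attn_layers` keep the positional `q`/`k`/`v` path, while every other layer goes to the linear-attention backend with `mixed_qkv`/`a`/`b`, and idle batches short-circuit to an empty `[seq_len, num_v_heads, head_v_dim]` tensor for linear layers. The sketch below only illustrates the two call styles a caller could use; `backend`, `full_layer`, and `gdn_layer` are placeholder names, not identifiers from this diff.

```python
# Illustrative wiring only; the wrapper and argument sources are assumptions.
def hybrid_forward(backend, layer, forward_batch, *,
                   q=None, k=None, v=None, mixed_qkv=None, a=None, b=None):
    # Full-attention layers supply q/k/v; GDN layers supply mixed_qkv/a/b.
    # The backend itself picks the route via `layer.layer_id in full_attn_layers`,
    # so the caller only chooses which tensors to pass.
    return backend.forward(
        q=q,
        k=k,
        v=v,
        layer=layer,
        forward_batch=forward_batch,
        mixed_qkv=mixed_qkv,
        a=a,
        b=b,
    )

# Full-attention layer:
#   hybrid_forward(backend, full_layer, forward_batch, q=q, k=k, v=v)
# GDN linear-attention layer:
#   hybrid_forward(backend, gdn_layer, forward_batch, mixed_qkv=mixed_qkv, a=a, b=b)
```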