
Commit 419bbce

refactor Qwen3-Next with a new RadixLinearAttention (sgl-project#17373)
1 parent f33022d commit 419bbce

File tree

3 files changed: +200 -106 lines changed
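In short, this commit moves the per-layer GDN parameters (conv_weights, bias, activation, key_dim, value_dim, head dims, A_log, dt_bias, layer_id) out of the **kwargs that Qwen3-Next previously threaded through every attention call and onto a new RadixLinearAttention module; the hybrid backend now reads them from the layer object. The Qwen3-Next model-file diff is not shown on this page, so the caller below is a hypothetical sketch: the class name, parameter shapes, and initial values are made up for illustration only.

```python
# Hypothetical model-side caller; only illustrates the ownership change implied
# by the backend diff below. Qwen3NextGDNLayer and the tensor shapes are not
# from this commit.
from typing import TYPE_CHECKING

import torch
from torch import nn

from sglang.srt.layers.radix_linear_attention import RadixLinearAttention

if TYPE_CHECKING:
    from sglang.srt.model_executor.forward_batch_info import ForwardBatch


class Qwen3NextGDNLayer(nn.Module):
    def __init__(self, layer_id: int, num_qk_heads: int, num_v_heads: int,
                 head_qk_dim: int, head_v_dim: int, tp_size: int):
        super().__init__()
        # Per-layer GDN state (A_log, dt_bias, conv weights, head dims, ...) is
        # now owned by the RadixLinearAttention module instead of being
        # re-passed through **kwargs on every backend call.
        self.attn = RadixLinearAttention(
            layer_id=layer_id,
            num_qk_heads=num_qk_heads,
            num_v_heads=num_v_heads,
            head_qk_dim=head_qk_dim,
            head_v_dim=head_v_dim,
            attention_tp_size=tp_size,
            A_log=torch.zeros(num_v_heads),   # illustrative shape/value
            dt_bias=torch.zeros(num_v_heads), # illustrative shape/value
        )

    def forward(self, forward_batch: "ForwardBatch", mixed_qkv: torch.Tensor,
                a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # The backend pulls conv_weights, head dims, A_log, dt_bias, etc. off
        # the layer; only the per-token tensors are passed explicitly.
        return self.attn(forward_batch=forward_batch, mixed_qkv=mixed_qkv, a=a, b=b)
```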

python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py

Lines changed: 96 additions & 69 deletions
@@ -29,6 +29,7 @@
     Mamba2Metadata,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.radix_linear_attention import RadixLinearAttention
 from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, MambaPool
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
@@ -833,30 +834,23 @@ def __init__(self, model_runner: ModelRunner):
 
     def forward_decode(
         self,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
-        layer: RadixAttention,
+        layer: RadixLinearAttention,
         forward_batch: ForwardBatch,
-        save_kv_cache: bool = True,
-        **kwargs,
+        mixed_qkv: torch.Tensor,
+        a: torch.Tensor,
+        b: torch.Tensor,
+        **kwargs,  # Unused, for compatibility with HybridLinearAttnBackend
     ):
-        mixed_qkv = kwargs["mixed_qkv"]
-        conv_weights = kwargs["conv_weights"]
-        bias = kwargs["bias"]
-        activation = kwargs["activation"]
-        key_dim = kwargs["key_dim"]
-        value_dim = kwargs["value_dim"]
-        attn_tp_size = kwargs["attention_tp_size"]
-        head_k_dim = kwargs["head_k_dim"]
-        head_v_dim = kwargs["head_v_dim"]
-        a = kwargs["a"]
-        b = kwargs["b"]
-        A_log = kwargs["A_log"]
-        dt_bias = kwargs["dt_bias"]
-        layer_id = kwargs["layer_id"]
-
-        layer_cache = self.req_to_token_pool.mamba2_layer_cache(layer_id)
+        conv_weights = layer.conv_weights
+        bias = layer.bias
+        activation = layer.activation
+        key_dim = layer.key_dim
+        value_dim = layer.value_dim
+        attn_tp_size = layer.attention_tp_size
+        head_k_dim = layer.head_k_dim
+        head_v_dim = layer.head_v_dim
+
+        layer_cache = self.req_to_token_pool.mamba2_layer_cache(layer.layer_id)
         conv_states = layer_cache.conv[0]
         ssm_states = layer_cache.temporal
         query_start_loc = self.forward_metadata.query_start_loc
@@ -888,8 +882,8 @@ def forward_decode(
         value = value.view(1, seq_len, value.shape[1] // head_v_dim, head_v_dim)
 
         core_attn_out = self._kernel_func(
-            A_log=A_log,
-            dt_bias=dt_bias,
+            A_log=layer.A_log,
+            dt_bias=layer.dt_bias,
             q=query,
             k=key,
             v=value,
@@ -911,29 +905,23 @@ def forward_decode(
 
     def forward_extend(
         self,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
-        layer: RadixAttention,
+        layer: RadixLinearAttention,
         forward_batch: ForwardBatch,
-        save_kv_cache: bool = True,
-        **kwargs,
+        mixed_qkv: torch.Tensor,
+        a: torch.Tensor,
+        b: torch.Tensor,
+        **kwargs,  # Unused, for compatibility with HybridLinearAttnBackend
     ):
-        mixed_qkv = kwargs["mixed_qkv"]
-        conv_weights = kwargs["conv_weights"]
-        bias = kwargs["bias"]
-        activation = kwargs["activation"]
-        key_dim = kwargs["key_dim"]
-        value_dim = kwargs["value_dim"]
-        attn_tp_size = kwargs["attention_tp_size"]
-        head_k_dim = kwargs["head_k_dim"]
-        head_v_dim = kwargs["head_v_dim"]
-        a = kwargs["a"]
-        b = kwargs["b"]
-        A_log = kwargs["A_log"]
-        dt_bias = kwargs["dt_bias"]
-        layer_id = kwargs["layer_id"]
-        seq_len = kwargs["seq_len"]
+        seq_len = mixed_qkv.shape[0]
+
+        conv_weights = layer.conv_weights
+        bias = layer.bias
+        activation = layer.activation
+        key_dim = layer.key_dim
+        value_dim = layer.value_dim
+        attn_tp_size = layer.attention_tp_size
+        head_k_dim = layer.head_k_dim
+        head_v_dim = layer.head_v_dim
 
         is_target_verify = forward_batch.forward_mode.is_target_verify()
         forward_metadata = self.forward_metadata
@@ -944,7 +932,7 @@ def forward_extend(
         retrieve_next_sibling = forward_metadata.retrieve_next_sibling
         retrieve_parent_token = forward_metadata.retrieve_parent_token
 
-        mamba_cache_params = self.req_to_token_pool.mamba2_layer_cache(layer_id)
+        mamba_cache_params = self.req_to_token_pool.mamba2_layer_cache(layer.layer_id)
        conv_states = mamba_cache_params.conv[0]
        ssm_states = mamba_cache_params.temporal
        if is_target_verify:
@@ -1029,7 +1017,7 @@ def forward_extend(
         key = key.view(1, actual_seq_len, num_heads, head_k_dim)
         value = value.view(1, actual_seq_len, num_value_heads, head_v_dim)
 
-        g, beta = fused_gdn_gating(A_log, a, b, dt_bias)
+        g, beta = fused_gdn_gating(layer.A_log, a, b, layer.dt_bias)
 
         if is_target_verify:
             core_attn_out = fused_recurrent_gated_delta_rule_update(
@@ -1240,75 +1228,114 @@ def get_cuda_graph_seq_len_fill_value(self):
 
     def forward_decode(
         self,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        q: Optional[torch.Tensor] = None,  # For full attention
+        k: Optional[torch.Tensor] = None,  # For full attention
+        v: Optional[torch.Tensor] = None,  # For full attention
+        mixed_qkv: Optional[torch.Tensor] = None,  # For GDN linear attention
+        a: Optional[torch.Tensor] = None,  # For GDN linear attention
+        b: Optional[torch.Tensor] = None,  # For GDN linear attention
         **kwargs,
     ):
         layer_id = layer.layer_id if layer else kwargs["layer_id"]
         if layer_id in self.full_attn_layers:
             return self.full_attn_backend.forward_decode(
                 q, k, v, layer, forward_batch, save_kv_cache, **kwargs
             )
+        # Linear attention backend
         return self.linear_attn_backend.forward_decode(
-            q, k, v, layer, forward_batch, save_kv_cache, **kwargs
+            q=q,
+            k=k,
+            v=v,
+            layer=layer,
+            forward_batch=forward_batch,
+            save_kv_cache=save_kv_cache,
+            mixed_qkv=mixed_qkv,
+            a=a,
+            b=b,
+            **kwargs,
         )
 
     def forward_extend(
         self,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        q: Optional[torch.Tensor] = None,  # For full attention
+        k: Optional[torch.Tensor] = None,  # For full attention
+        v: Optional[torch.Tensor] = None,  # For full attention
+        mixed_qkv: Optional[torch.Tensor] = None,  # For GDN linear attention
+        a: Optional[torch.Tensor] = None,  # For GDN linear attention
+        b: Optional[torch.Tensor] = None,  # For GDN linear attention
         **kwargs,
     ):
         layer_id = layer.layer_id if layer else kwargs["layer_id"]
         if layer_id in self.full_attn_layers:
             return self.full_attn_backend.forward_extend(
                 q, k, v, layer, forward_batch, save_kv_cache, **kwargs
             )
+        # Linear attention backend
         return self.linear_attn_backend.forward_extend(
-            q, k, v, layer, forward_batch, save_kv_cache, **kwargs
+            q=q,
+            k=k,
+            v=v,
+            layer=layer,
+            forward_batch=forward_batch,
+            save_kv_cache=save_kv_cache,
+            mixed_qkv=mixed_qkv,
+            a=a,
+            b=b,
+            **kwargs,
        )
 
     def forward(
         self,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
-        layer: RadixAttention,
-        forward_batch: ForwardBatch,
+        q: Optional[torch.Tensor] = None,  # For full attention
+        k: Optional[torch.Tensor] = None,  # For full attention
+        v: Optional[torch.Tensor] = None,  # For full attention
+        layer: RadixAttention = None,
+        forward_batch: ForwardBatch = None,
         save_kv_cache: bool = True,
+        mixed_qkv: Optional[torch.Tensor] = None,  # For GDN linear attention
+        a: Optional[torch.Tensor] = None,  # For GDN linear attention
+        b: Optional[torch.Tensor] = None,  # For GDN linear attention
         **kwargs,
     ):
-        """Run forward on an attention layer."""
+        layer_id = layer.layer_id if layer else kwargs["layer_id"]
+        is_linear_attn = layer_id not in self.full_attn_layers
+
         if forward_batch.forward_mode.is_idle():
-            if layer is None:
-                return torch.empty_like(kwargs["z"])
+            if is_linear_attn:
+                return mixed_qkv.new_empty(
+                    mixed_qkv.shape[0], layer.num_v_heads, layer.head_v_dim
+                )
             return q.new_empty(q.shape[0], layer.tp_q_head_num * layer.v_head_dim)
         elif forward_batch.forward_mode.is_decode():
             return self.forward_decode(
+                layer,
+                forward_batch,
+                save_kv_cache,
                 q,
                 k,
                 v,
-                layer,
-                forward_batch,
-                save_kv_cache=save_kv_cache,
+                mixed_qkv,
+                a,
+                b,
                 **kwargs,
             )
         else:
             return self.forward_extend(
+                layer,
+                forward_batch,
+                save_kv_cache,
                 q,
                 k,
                 v,
-                layer,
-                forward_batch,
-                save_kv_cache=save_kv_cache,
+                mixed_qkv,
+                a,
+                b,
                 **kwargs,
             )
 
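After this change, HybridLinearAttnBackend.forward accepts two call shapes and routes on layer_id: full-attention layers still pass q/k/v with a RadixAttention, while GDN linear-attention layers pass mixed_qkv/a/b with a RadixLinearAttention. The helper below does not exist in the repository; it is a minimal sketch of the two shapes, assuming the patched backend shown above.

```python
# Sketch only: run_attention is a hypothetical helper contrasting the two call
# shapes accepted by the hybrid backend after this commit.
from typing import Optional

import torch

from sglang.srt.layers.radix_linear_attention import RadixLinearAttention
from sglang.srt.model_executor.forward_batch_info import ForwardBatch


def run_attention(
    forward_batch: ForwardBatch,
    layer,  # RadixAttention (full attention) or RadixLinearAttention (GDN)
    q: Optional[torch.Tensor] = None,
    k: Optional[torch.Tensor] = None,
    v: Optional[torch.Tensor] = None,
    mixed_qkv: Optional[torch.Tensor] = None,
    a: Optional[torch.Tensor] = None,
    b: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Call the hybrid backend with whichever tensor set the layer expects."""
    if isinstance(layer, RadixLinearAttention):
        # GDN linear-attention layer: q/k/v stay None; per-token tensors are
        # mixed_qkv/a/b, and the dispatcher routes on layer.layer_id.
        return forward_batch.attn_backend.forward(
            layer=layer,
            forward_batch=forward_batch,
            mixed_qkv=mixed_qkv,
            a=a,
            b=b,
        )
    # Full-attention layer: the original RadixAttention call shape is unchanged.
    return forward_batch.attn_backend.forward(
        q=q, k=k, v=v, layer=layer, forward_batch=forward_batch
    )
```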

python/sglang/srt/layers/radix_linear_attention.py (new file)

Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
+# Copyright 2025-2026 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Radix linear attention."""
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Optional
+
+import torch
+from torch import nn
+
+if TYPE_CHECKING:
+    from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+
+
+class RadixLinearAttention(nn.Module):
+    """
+    The Linear Attention Layer Implementation.
+    """
+
+    def __init__(
+        self,
+        layer_id: int,
+        num_qk_heads: int,
+        num_v_heads: int,
+        head_qk_dim: int,
+        head_v_dim: int,
+        attention_tp_size: int = 1,
+        conv_weights: Optional[torch.Tensor] = None,
+        bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+        A_log: Optional[torch.Tensor] = None,
+        dt_bias: Optional[torch.Tensor] = None,
+    ):
+        super().__init__()
+        self.layer_id = layer_id
+        # Q and K share the same head count and dimension (per-TP values)
+        self.num_qk_heads = num_qk_heads
+        self.num_v_heads = num_v_heads
+        self.head_qk_dim = head_qk_dim
+        self.head_v_dim = head_v_dim
+        self.attention_tp_size = attention_tp_size
+
+        self.qk_dim_per_tp = num_qk_heads * head_qk_dim
+        self.value_dim_per_tp = num_v_heads * head_v_dim
+
+        self.key_dim = self.qk_dim_per_tp * attention_tp_size
+        self.value_dim = self.value_dim_per_tp * attention_tp_size
+
+        self.num_k_heads = num_qk_heads
+        self.num_q_heads = num_qk_heads
+        self.head_k_dim = head_qk_dim
+
+        self.conv_weights = conv_weights
+        self.bias = bias
+        self.activation = activation
+        self.A_log = A_log
+        self.dt_bias = dt_bias
+
+    def forward(
+        self,
+        forward_batch: ForwardBatch,
+        mixed_qkv: torch.Tensor,
+        a: torch.Tensor,
+        b: torch.Tensor,
+    ) -> torch.Tensor:
+        return forward_batch.attn_backend.forward(
+            layer=self,
+            forward_batch=forward_batch,
+            mixed_qkv=mixed_qkv,
+            a=a,
+            b=b,
+        )
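A minimal construction sketch for the new module, with assumed head counts and tensor-parallel size (the numbers below are illustrative, not taken from Qwen3-Next config). It shows the derived per-TP and global projection widths the backend reads (qk_dim_per_tp, key_dim, value_dim); at runtime the layer simply delegates to forward_batch.attn_backend.

```python
# Construction sketch under assumed sizes; sglang and torch must be installed.
import torch

from sglang.srt.layers.radix_linear_attention import RadixLinearAttention

tp_size = 2        # assumed attention tensor-parallel size
num_qk_heads = 8   # assumed per-TP query/key head count
num_v_heads = 16   # assumed per-TP value head count
head_qk_dim, head_v_dim = 128, 128

attn = RadixLinearAttention(
    layer_id=0,
    num_qk_heads=num_qk_heads,
    num_v_heads=num_v_heads,
    head_qk_dim=head_qk_dim,
    head_v_dim=head_v_dim,
    attention_tp_size=tp_size,
    A_log=torch.zeros(num_v_heads),   # illustrative placeholder
    dt_bias=torch.zeros(num_v_heads), # illustrative placeholder
)

# Per-TP projection widths vs. global widths consumed by the backend kernels.
assert attn.qk_dim_per_tp == num_qk_heads * head_qk_dim           # 1024
assert attn.key_dim == attn.qk_dim_per_tp * tp_size               # 2048
assert attn.value_dim == num_v_heads * head_v_dim * tp_size       # 4096

# At runtime the layer delegates to the attention backend, e.g.:
#   out = attn(forward_batch=fb, mixed_qkv=mixed_qkv, a=a, b=b)
```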
