Skip to content

Commit 145ae51

Browse files
authored
[Diffusion] Revert 18619 (#19510)
1 parent 6822941 commit 145ae51

File tree

1 file changed: +20 additions, −182 deletions

python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py

Lines changed: 20 additions & 182 deletions
Original file line number | Diff line number | Diff line change
@@ -30,9 +30,8 @@
3030
apply_qk_norm,
3131
)
3232
from sglang.multimodal_gen.runtime.layers.linear import (
33-
ColumnParallelLinear,
3433
MergedColumnParallelLinear,
35-
RowParallelLinear,
34+
ReplicatedLinear,
3635
)
3736
from sglang.multimodal_gen.runtime.layers.quantization.configs.base_config import (
3837
QuantizationConfig,
@@ -89,109 +88,6 @@ def _get_qkv_projections(
8988
return img_query, img_key, img_value, txt_query, txt_key, txt_value
9089

9190

92-
class GELU(nn.Module):
93-
r"""
94-
GELU activation function with tanh approximation support with `approximate="tanh"`.
95-
96-
Parameters:
97-
dim_in (`int`): The number of channels in the input.
98-
dim_out (`int`): The number of channels in the output.
99-
approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
100-
bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
101-
quant_config: Quantization configure.
102-
prefix: The name of the layer in the state dict.
103-
"""
104-
105-
def __init__(
106-
self,
107-
dim_in: int,
108-
dim_out: int,
109-
approximate: str = "none",
110-
bias: bool = True,
111-
quant_config=None,
112-
prefix: str = "",
113-
):
114-
super().__init__()
115-
self.proj = ColumnParallelLinear(
116-
dim_in,
117-
dim_out,
118-
bias=bias,
119-
gather_output=False,
120-
quant_config=quant_config,
121-
prefix=f"{prefix}.proj" if prefix else "",
122-
)
123-
self.approximate = approximate
124-
125-
def forward(self, hidden_states):
126-
hidden_states = self.proj(hidden_states)
127-
return F.gelu(hidden_states[0], approximate=self.approximate)
128-
129-
130-
class FeedForward(nn.Module):
131-
r"""
132-
A feed-forward layer.
133-
134-
Parameters:
135-
dim (`int`): The number of channels in the input.
136-
dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
137-
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
138-
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
139-
bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
140-
quant_config: Quantization configure.
141-
prefix: The name of the layer in the state dict.
142-
"""
143-
144-
def __init__(
145-
self,
146-
dim: int,
147-
dim_out: Optional[int] = None,
148-
mult: int = 4,
149-
activation_fn: str = "geglu",
150-
inner_dim=None,
151-
bias: bool = True,
152-
quant_config=None,
153-
prefix: str = "",
154-
):
155-
super().__init__()
156-
if inner_dim is None:
157-
inner_dim = int(dim * mult)
158-
dim_out = dim_out if dim_out is not None else dim
159-
160-
if activation_fn == "gelu":
161-
act_fn = GELU(dim, inner_dim, bias=bias, quant_config=None, prefix=prefix)
162-
if activation_fn == "gelu-approximate":
163-
act_fn = GELU(
164-
dim,
165-
inner_dim,
166-
approximate="tanh",
167-
bias=bias,
168-
quant_config=None,
169-
prefix=prefix,
170-
)
171-
else:
172-
raise NotImplementedError(
173-
f"activation_fn '{activation_fn}' is not supported."
174-
)
175-
176-
self.net = nn.ModuleList([])
177-
self.net.append(act_fn)
178-
self.net.append(nn.Identity())
179-
self.net.append(
180-
RowParallelLinear(
181-
inner_dim,
182-
dim_out,
183-
bias=True,
184-
input_is_parallel=True,
185-
quant_config=None,
186-
)
187-
)
188-
189-
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
190-
for module in self.net:
191-
hidden_states = module(hidden_states)
192-
return hidden_states
193-
194-
19591
class QwenTimestepProjEmbeddings(nn.Module):
19692
def __init__(self, embedding_dim, use_additional_t_cond=False):
19793
super().__init__()
@@ -624,27 +520,9 @@ def __init__(
624520
)
625521
else:
626522
# Use separate Q/K/V projections for non-quantized models
627-
self.to_q = ColumnParallelLinear(
628-
dim,
629-
self.inner_dim,
630-
bias=True,
631-
quant_config=quant_config,
632-
prefix=f"{prefix}.to_q",
633-
)
634-
self.to_k = ColumnParallelLinear(
635-
dim,
636-
self.inner_dim,
637-
bias=True,
638-
quant_config=quant_config,
639-
prefix=f"{prefix}.to_k",
640-
)
641-
self.to_v = ColumnParallelLinear(
642-
dim,
643-
self.inner_dim,
644-
bias=True,
645-
quant_config=quant_config,
646-
prefix=f"{prefix}.to_v",
647-
)
523+
self.to_q = ReplicatedLinear(dim, self.inner_dim, bias=True)
524+
self.to_k = ReplicatedLinear(dim, self.inner_dim, bias=True)
525+
self.to_v = ReplicatedLinear(dim, self.inner_dim, bias=True)
648526

649527
if self.qk_norm:
650528
self.norm_q = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity()
@@ -662,51 +540,25 @@ def __init__(
662540
)
663541
else:
664542
# Use separate Q/K/V projections for non-quantized models
665-
self.add_q_proj = ColumnParallelLinear(
666-
added_kv_proj_dim,
667-
self.inner_dim,
668-
bias=True,
669-
quant_config=quant_config,
670-
prefix=f"{prefix}.add_q_proj",
543+
self.add_q_proj = ReplicatedLinear(
544+
added_kv_proj_dim, self.inner_dim, bias=True
671545
)
672-
self.add_k_proj = ColumnParallelLinear(
673-
added_kv_proj_dim,
674-
self.inner_dim,
675-
bias=True,
676-
quant_config=quant_config,
677-
prefix=f"{prefix}.add_k_proj",
546+
self.add_k_proj = ReplicatedLinear(
547+
added_kv_proj_dim, self.inner_dim, bias=True
678548
)
679-
self.add_v_proj = ColumnParallelLinear(
680-
added_kv_proj_dim,
681-
self.inner_dim,
682-
bias=True,
683-
quant_config=quant_config,
684-
prefix=f"{prefix}.add_v_proj",
549+
self.add_v_proj = ReplicatedLinear(
550+
added_kv_proj_dim, self.inner_dim, bias=True
685551
)
686552

687553
if context_pre_only is not None and not context_pre_only:
688-
self.to_add_out = ColumnParallelLinear(
689-
self.inner_dim,
690-
self.dim,
691-
bias=out_bias,
692-
gather_output=True,
693-
quant_config=quant_config,
694-
prefix=f"{prefix}.to_add_out",
695-
)
554+
self.to_add_out = ReplicatedLinear(self.inner_dim, self.dim, bias=out_bias)
696555
else:
697556
self.to_add_out = None
698557

699558
if not pre_only:
700559
self.to_out = nn.ModuleList([])
701560
self.to_out.append(
702-
ColumnParallelLinear(
703-
self.inner_dim,
704-
self.dim,
705-
bias=out_bias,
706-
gather_output=True,
707-
quant_config=quant_config,
708-
prefix=f"{prefix}.to_out.0",
709-
)
561+
ReplicatedLinear(self.inner_dim, self.dim, bias=out_bias)
710562
)
711563
else:
712564
self.to_out = None
@@ -848,13 +700,8 @@ def __init__(
848700
# Image processing modules
849701
self.img_mod = nn.Sequential(
850702
nn.SiLU(),
851-
ColumnParallelLinear(
852-
dim,
853-
6 * dim,
854-
bias=True,
855-
gather_output=True,
856-
quant_config=mod_quant_config,
857-
prefix=f"{prefix}.img_mod",
703+
nn.Linear(
704+
dim, 6 * dim, bias=True
858705
), # For scale, shift, gate for norm1 and norm2
859706
)
860707
self.img_norm1 = LayerNormScaleShift(
@@ -877,13 +724,8 @@ def __init__(
877724
# Text processing modules
878725
self.txt_mod = nn.Sequential(
879726
nn.SiLU(),
880-
ColumnParallelLinear(
881-
dim,
882-
6 * dim,
883-
bias=True,
884-
gather_output=True,
885-
quant_config=mod_quant_config,
886-
prefix=f"{prefix}.txt_mod",
727+
nn.Linear(
728+
dim, 6 * dim, bias=True
887729
), # For scale, shift, gate for norm1 and norm2
888730
)
889731
self.txt_norm1 = LayerNormScaleShift(
@@ -919,15 +761,11 @@ def __init__(
919761
dim=dim,
920762
dim_out=dim,
921763
activation_fn="gelu-approximate",
922-
quant_config=quant_config,
923-
prefix=f"{prefix}.img_mlp",
924764
)
925765
self.txt_mlp = FeedForward(
926766
dim=dim,
927767
dim_out=dim,
928768
activation_fn="gelu-approximate",
929-
quant_config=quant_config,
930-
prefix=f"{prefix}.txt_mlp",
931769
)
932770

933771
if nunchaku_enabled:
@@ -1043,8 +881,8 @@ def forward(
1043881
modulate_index: Optional[List[int]] = None,
1044882
) -> Tuple[torch.Tensor, torch.Tensor]:
1045883
# Get modulation parameters for both streams
1046-
img_mod_params = self.img_mod[1](temb_img_silu)[0] # [B, 6*dim]
1047-
txt_mod_params = self.txt_mod[1](temb_txt_silu)[0] # [B, 6*dim]
884+
img_mod_params = self.img_mod[1](temb_img_silu) # [B, 6*dim]
885+
txt_mod_params = self.txt_mod[1](temb_txt_silu) # [B, 6*dim]
1048886

1049887
if (
1050888
self.quant_config is not None
@@ -1107,7 +945,7 @@ def forward(
1107945
gate_x=img_gate1,
1108946
residual_x=hidden_states,
1109947
)
1110-
img_mlp_output = self.img_mlp(img_modulated2)[0]
948+
img_mlp_output = self.img_mlp(img_modulated2)
1111949

1112950
if img_mlp_output.dim() == 2:
1113951
img_mlp_output = img_mlp_output.unsqueeze(0)
@@ -1123,7 +961,7 @@ def forward(
1123961
scale=txt_scale2,
1124962
)
1125963
txt_gate2 = txt_gate2_raw.unsqueeze(1)
1126-
txt_mlp_output = self.txt_mlp(txt_modulated2)[0]
964+
txt_mlp_output = self.txt_mlp(txt_modulated2)
1127965

1128966
if txt_mlp_output.dim() == 2:
1129967
txt_mlp_output = txt_mlp_output.unsqueeze(0)

0 commit comments

Comments (0)