3030 apply_qk_norm ,
3131)
3232from sglang .multimodal_gen .runtime .layers .linear import (
33- ColumnParallelLinear ,
3433 MergedColumnParallelLinear ,
35- RowParallelLinear ,
34+ ReplicatedLinear ,
3635)
3736from sglang .multimodal_gen .runtime .layers .quantization .configs .base_config import (
3837 QuantizationConfig ,
@@ -89,109 +88,6 @@ def _get_qkv_projections(
8988 return img_query , img_key , img_value , txt_query , txt_key , txt_value
9089
9190
92- class GELU (nn .Module ):
93- r"""
94- GELU activation function with tanh approximation support with `approximate="tanh"`.
95-
96- Parameters:
97- dim_in (`int`): The number of channels in the input.
98- dim_out (`int`): The number of channels in the output.
99- approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
100- bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
101- quant_config: Quantization configure.
102- prefix: The name of the layer in the state dict.
103- """
104-
105- def __init__ (
106- self ,
107- dim_in : int ,
108- dim_out : int ,
109- approximate : str = "none" ,
110- bias : bool = True ,
111- quant_config = None ,
112- prefix : str = "" ,
113- ):
114- super ().__init__ ()
115- self .proj = ColumnParallelLinear (
116- dim_in ,
117- dim_out ,
118- bias = bias ,
119- gather_output = False ,
120- quant_config = quant_config ,
121- prefix = f"{ prefix } .proj" if prefix else "" ,
122- )
123- self .approximate = approximate
124-
125- def forward (self , hidden_states ):
126- hidden_states = self .proj (hidden_states )
127- return F .gelu (hidden_states [0 ], approximate = self .approximate )
128-
129-
130- class FeedForward (nn .Module ):
131- r"""
132- A feed-forward layer.
133-
134- Parameters:
135- dim (`int`): The number of channels in the input.
136- dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
137- mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
138- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
139- bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
140- quant_config: Quantization configure.
141- prefix: The name of the layer in the state dict.
142- """
143-
144- def __init__ (
145- self ,
146- dim : int ,
147- dim_out : Optional [int ] = None ,
148- mult : int = 4 ,
149- activation_fn : str = "geglu" ,
150- inner_dim = None ,
151- bias : bool = True ,
152- quant_config = None ,
153- prefix : str = "" ,
154- ):
155- super ().__init__ ()
156- if inner_dim is None :
157- inner_dim = int (dim * mult )
158- dim_out = dim_out if dim_out is not None else dim
159-
160- if activation_fn == "gelu" :
161- act_fn = GELU (dim , inner_dim , bias = bias , quant_config = None , prefix = prefix )
162- if activation_fn == "gelu-approximate" :
163- act_fn = GELU (
164- dim ,
165- inner_dim ,
166- approximate = "tanh" ,
167- bias = bias ,
168- quant_config = None ,
169- prefix = prefix ,
170- )
171- else :
172- raise NotImplementedError (
173- f"activation_fn '{ activation_fn } ' is not supported."
174- )
175-
176- self .net = nn .ModuleList ([])
177- self .net .append (act_fn )
178- self .net .append (nn .Identity ())
179- self .net .append (
180- RowParallelLinear (
181- inner_dim ,
182- dim_out ,
183- bias = True ,
184- input_is_parallel = True ,
185- quant_config = None ,
186- )
187- )
188-
189- def forward (self , hidden_states : torch .Tensor ) -> torch .Tensor :
190- for module in self .net :
191- hidden_states = module (hidden_states )
192- return hidden_states
193-
194-
19591class QwenTimestepProjEmbeddings (nn .Module ):
19692 def __init__ (self , embedding_dim , use_additional_t_cond = False ):
19793 super ().__init__ ()
@@ -624,27 +520,9 @@ def __init__(
624520 )
625521 else :
626522 # Use separate Q/K/V projections for non-quantized models
627- self .to_q = ColumnParallelLinear (
628- dim ,
629- self .inner_dim ,
630- bias = True ,
631- quant_config = quant_config ,
632- prefix = f"{ prefix } .to_q" ,
633- )
634- self .to_k = ColumnParallelLinear (
635- dim ,
636- self .inner_dim ,
637- bias = True ,
638- quant_config = quant_config ,
639- prefix = f"{ prefix } .to_k" ,
640- )
641- self .to_v = ColumnParallelLinear (
642- dim ,
643- self .inner_dim ,
644- bias = True ,
645- quant_config = quant_config ,
646- prefix = f"{ prefix } .to_v" ,
647- )
523+ self .to_q = ReplicatedLinear (dim , self .inner_dim , bias = True )
524+ self .to_k = ReplicatedLinear (dim , self .inner_dim , bias = True )
525+ self .to_v = ReplicatedLinear (dim , self .inner_dim , bias = True )
648526
649527 if self .qk_norm :
650528 self .norm_q = RMSNorm (head_dim , eps = eps ) if qk_norm else nn .Identity ()
@@ -662,51 +540,25 @@ def __init__(
662540 )
663541 else :
664542 # Use separate Q/K/V projections for non-quantized models
665- self .add_q_proj = ColumnParallelLinear (
666- added_kv_proj_dim ,
667- self .inner_dim ,
668- bias = True ,
669- quant_config = quant_config ,
670- prefix = f"{ prefix } .add_q_proj" ,
543+ self .add_q_proj = ReplicatedLinear (
544+ added_kv_proj_dim , self .inner_dim , bias = True
671545 )
672- self .add_k_proj = ColumnParallelLinear (
673- added_kv_proj_dim ,
674- self .inner_dim ,
675- bias = True ,
676- quant_config = quant_config ,
677- prefix = f"{ prefix } .add_k_proj" ,
546+ self .add_k_proj = ReplicatedLinear (
547+ added_kv_proj_dim , self .inner_dim , bias = True
678548 )
679- self .add_v_proj = ColumnParallelLinear (
680- added_kv_proj_dim ,
681- self .inner_dim ,
682- bias = True ,
683- quant_config = quant_config ,
684- prefix = f"{ prefix } .add_v_proj" ,
549+ self .add_v_proj = ReplicatedLinear (
550+ added_kv_proj_dim , self .inner_dim , bias = True
685551 )
686552
687553 if context_pre_only is not None and not context_pre_only :
688- self .to_add_out = ColumnParallelLinear (
689- self .inner_dim ,
690- self .dim ,
691- bias = out_bias ,
692- gather_output = True ,
693- quant_config = quant_config ,
694- prefix = f"{ prefix } .to_add_out" ,
695- )
554+ self .to_add_out = ReplicatedLinear (self .inner_dim , self .dim , bias = out_bias )
696555 else :
697556 self .to_add_out = None
698557
699558 if not pre_only :
700559 self .to_out = nn .ModuleList ([])
701560 self .to_out .append (
702- ColumnParallelLinear (
703- self .inner_dim ,
704- self .dim ,
705- bias = out_bias ,
706- gather_output = True ,
707- quant_config = quant_config ,
708- prefix = f"{ prefix } .to_out.0" ,
709- )
561+ ReplicatedLinear (self .inner_dim , self .dim , bias = out_bias )
710562 )
711563 else :
712564 self .to_out = None
@@ -848,13 +700,8 @@ def __init__(
848700 # Image processing modules
849701 self .img_mod = nn .Sequential (
850702 nn .SiLU (),
851- ColumnParallelLinear (
852- dim ,
853- 6 * dim ,
854- bias = True ,
855- gather_output = True ,
856- quant_config = mod_quant_config ,
857- prefix = f"{ prefix } .img_mod" ,
703+ nn .Linear (
704+ dim , 6 * dim , bias = True
858705 ), # For scale, shift, gate for norm1 and norm2
859706 )
860707 self .img_norm1 = LayerNormScaleShift (
@@ -877,13 +724,8 @@ def __init__(
877724 # Text processing modules
878725 self .txt_mod = nn .Sequential (
879726 nn .SiLU (),
880- ColumnParallelLinear (
881- dim ,
882- 6 * dim ,
883- bias = True ,
884- gather_output = True ,
885- quant_config = mod_quant_config ,
886- prefix = f"{ prefix } .txt_mod" ,
727+ nn .Linear (
728+ dim , 6 * dim , bias = True
887729 ), # For scale, shift, gate for norm1 and norm2
888730 )
889731 self .txt_norm1 = LayerNormScaleShift (
@@ -919,15 +761,11 @@ def __init__(
919761 dim = dim ,
920762 dim_out = dim ,
921763 activation_fn = "gelu-approximate" ,
922- quant_config = quant_config ,
923- prefix = f"{ prefix } .img_mlp" ,
924764 )
925765 self .txt_mlp = FeedForward (
926766 dim = dim ,
927767 dim_out = dim ,
928768 activation_fn = "gelu-approximate" ,
929- quant_config = quant_config ,
930- prefix = f"{ prefix } .txt_mlp" ,
931769 )
932770
933771 if nunchaku_enabled :
@@ -1043,8 +881,8 @@ def forward(
1043881 modulate_index : Optional [List [int ]] = None ,
1044882 ) -> Tuple [torch .Tensor , torch .Tensor ]:
1045883 # Get modulation parameters for both streams
1046- img_mod_params = self .img_mod [1 ](temb_img_silu )[ 0 ] # [B, 6*dim]
1047- txt_mod_params = self .txt_mod [1 ](temb_txt_silu )[ 0 ] # [B, 6*dim]
884+ img_mod_params = self .img_mod [1 ](temb_img_silu ) # [B, 6*dim]
885+ txt_mod_params = self .txt_mod [1 ](temb_txt_silu ) # [B, 6*dim]
1048886
1049887 if (
1050888 self .quant_config is not None
@@ -1107,7 +945,7 @@ def forward(
1107945 gate_x = img_gate1 ,
1108946 residual_x = hidden_states ,
1109947 )
1110- img_mlp_output = self .img_mlp (img_modulated2 )[ 0 ]
948+ img_mlp_output = self .img_mlp (img_modulated2 )
1111949
1112950 if img_mlp_output .dim () == 2 :
1113951 img_mlp_output = img_mlp_output .unsqueeze (0 )
@@ -1123,7 +961,7 @@ def forward(
1123961 scale = txt_scale2 ,
1124962 )
1125963 txt_gate2 = txt_gate2_raw .unsqueeze (1 )
1126- txt_mlp_output = self .txt_mlp (txt_modulated2 )[ 0 ]
964+ txt_mlp_output = self .txt_mlp (txt_modulated2 )
1127965
1128966 if txt_mlp_output .dim () == 2 :
1129967 txt_mlp_output = txt_mlp_output .unsqueeze (0 )
0 commit comments