@@ -21,6 +21,7 @@
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
+                                               ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
@@ -33,6 +34,7 @@
                                         BaseProcessingInfo, PromptReplacement,
                                         PromptUpdate, PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.multimodal.utils import run_dp_sharded_vision_model
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs import Step3VisionEncoderConfig
 from vllm.transformers_utils.tokenizer import AnyTokenizer
@@ -650,7 +652,8 @@ class Step3VisionAttention(nn.Module):
     def __init__(self,
                  config,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+                 prefix: str = "",
+                 use_data_parallel: bool = False):
         super().__init__()
         self.config = config
         self.embed_dim = config.hidden_size
@@ -659,20 +662,42 @@ def __init__(self,
 
         self.scale = self.head_dim**-0.5
 
-        tp_size = get_tensor_model_parallel_world_size()
+        tp_size = (1 if use_data_parallel else
+                   get_tensor_model_parallel_world_size())
         assert self.total_num_heads % tp_size == 0
         self.num_heads = self.total_num_heads // tp_size
-        self.qkv_proj = QKVParallelLinear(self.embed_dim,
-                                          self.head_dim,
-                                          self.total_num_heads,
-                                          bias=True,
-                                          quant_config=quant_config,
-                                          prefix=prefix)
-        self.out_proj = RowParallelLinear(self.embed_dim,
-                                          self.embed_dim,
-                                          bias=True,
-                                          quant_config=quant_config,
-                                          prefix=prefix)
+
+        self.q_size = self.num_heads * self.head_dim
+
+        if use_data_parallel:
+            self.qkv_proj = ReplicatedLinear(
+                self.embed_dim,
+                3 * self.q_size,
+                bias=True,
+                quant_config=quant_config,
+                prefix=prefix,
+            )
+            self.out_proj = ReplicatedLinear(
+                self.total_num_heads * self.head_dim,
+                self.embed_dim,
+                bias=True,
+                quant_config=quant_config,
+                prefix=prefix,
+            )
+        else:
+            self.qkv_proj = QKVParallelLinear(
+                self.embed_dim,
+                self.head_dim,
+                self.total_num_heads,
+                bias=True,
+                quant_config=quant_config,
+                prefix=prefix,
+            )
+            self.out_proj = RowParallelLinear(self.embed_dim,
+                                              self.embed_dim,
+                                              bias=True,
+                                              quant_config=quant_config,
+                                              prefix=prefix)
 
     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
         return tensor.view(bsz, seq_len, self.num_heads,
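Reviewer note (not part of the diff): a minimal sketch of how the fused QKV output is typically consumed downstream. The attention forward is outside this hunk, so everything below is illustrative; the only facts taken from the diff are that q_size = num_heads * head_dim per rank and that both qkv_proj variants emit a fused tensor of width 3 * q_size on each rank.

import torch

# Illustrative sizes only; nothing here is read from the model config.
bsz, seq_len, num_heads, head_dim = 2, 16, 8, 64
q_size = num_heads * head_dim  # per-rank, mirroring self.q_size above

# Stand-in for the fused projection output: ReplicatedLinear (DP path) and
# QKVParallelLinear (TP path) both yield width 3 * q_size per rank.
qkv = torch.randn(bsz, seq_len, 3 * q_size)

# An equal three-way split recovers per-rank q, k, v (no GQA in this ViT,
# so the query and key/value widths match).
q, k, v = qkv.split([q_size, q_size, q_size], dim=-1)
assert q.shape == k.shape == v.shape == (bsz, seq_len, q_size)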
@@ -712,20 +737,25 @@ class Step3VisionMLP(nn.Module):
     def __init__(self,
                  config,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+                 prefix: str = "",
+                 use_data_parallel: bool = False):
         super().__init__()
         self.config = config
         self.activation_fn = get_act_fn(config.hidden_act)
-        self.fc1 = ColumnParallelLinear(config.hidden_size,
-                                        config.intermediate_size,
-                                        bias=True,
-                                        quant_config=quant_config,
-                                        prefix=prefix)
-        self.fc2 = RowParallelLinear(config.intermediate_size,
-                                     config.hidden_size,
-                                     bias=True,
-                                     quant_config=quant_config,
-                                     prefix=prefix)
+        cls_fc1 = (ReplicatedLinear
+                   if use_data_parallel else ColumnParallelLinear)
+        self.fc1 = cls_fc1(config.hidden_size,
+                           config.intermediate_size,
+                           bias=True,
+                           quant_config=quant_config,
+                           prefix=prefix)
+        cls_fc2 = (ReplicatedLinear
+                   if use_data_parallel else RowParallelLinear)
+        self.fc2 = cls_fc2(config.intermediate_size,
+                           config.hidden_size,
+                           bias=True,
+                           quant_config=quant_config,
+                           prefix=prefix)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.fc1(hidden_states)
@@ -739,15 +769,22 @@ class Step3VisionEncoderLayer(nn.Module):
     def __init__(self,
                  config: Step3VisionEncoderConfig,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+                 prefix: str = "",
+                 use_data_parallel: bool = False):
         super().__init__()
+        self.use_data_parallel = use_data_parallel
         self.embed_dim = config.hidden_size
-        self.self_attn = Step3VisionAttention(config,
-                                              quant_config,
-                                              prefix=f"{prefix}.self_attn")
+        self.self_attn = Step3VisionAttention(
+            config,
+            quant_config,
+            prefix=f"{prefix}.self_attn",
+            use_data_parallel=self.use_data_parallel)
         self.layer_norm1 = nn.LayerNorm(self.embed_dim,
                                         eps=config.layer_norm_eps)
-        self.mlp = Step3VisionMLP(config, quant_config, prefix=f"{prefix}.mlp")
+        self.mlp = Step3VisionMLP(config,
+                                  quant_config,
+                                  prefix=f"{prefix}.mlp",
+                                  use_data_parallel=self.use_data_parallel)
         self.layer_norm2 = nn.LayerNorm(self.embed_dim,
                                         eps=config.layer_norm_eps)
 
@@ -767,13 +804,16 @@ class Step3VisionEncoder(nn.Module):
     def __init__(self,
                  config: Step3VisionEncoderConfig,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+                 prefix: str = "",
+                 use_data_parallel: bool = False):
         super().__init__()
         self.config = config
+        self.use_data_parallel = use_data_parallel
         self.layers = nn.ModuleList([
             Step3VisionEncoderLayer(config,
                                     quant_config,
-                                    prefix=f"{prefix}.layers.{i}")
+                                    prefix=f"{prefix}.layers.{i}",
+                                    use_data_parallel=self.use_data_parallel)
             for i in range(config.num_hidden_layers)
         ])
 
@@ -792,21 +832,29 @@ class Step3VisionTransformer(nn.Module):
     def __init__(self,
                  config: Step3VisionEncoderConfig,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+                 prefix: str = "",
+                 use_data_parallel: bool = False):
         super().__init__()
         self.config = config
+        self.use_data_parallel = use_data_parallel
         self.image_size = config.image_size
         self.embeddings = Step3VisionEmbeddings(config)
-        self.transformer = Step3VisionEncoder(config,
-                                              quant_config,
-                                              prefix=f"{prefix}.transformer")
+        self.transformer = Step3VisionEncoder(
+            config,
+            quant_config,
+            prefix=f"{prefix}.transformer",
+            use_data_parallel=self.use_data_parallel)
 
     def forward(
         self,
         pixel_values: torch.Tensor,
     ):
         hidden_states = self.embeddings(pixel_values)
-        hidden_states = self.transformer(inputs_embeds=hidden_states)
+        if self.use_data_parallel:
+            hidden_states = run_dp_sharded_vision_model(
+                hidden_states, self.transformer)
+        else:
+            hidden_states = self.transformer(inputs_embeds=hidden_states)
         return hidden_states
 
 
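Reviewer note (not part of the diff): run_dp_sharded_vision_model comes from vllm.multimodal.utils and is called here as run_dp_sharded_vision_model(hidden_states, self.transformer). Below is a rough, single-process sketch of the idea it implements (shard the image batch, run the weight-replicated encoder on each shard, reassemble the outputs); it is only a conceptual stand-in, not the helper's actual implementation, which distributes the shards across ranks and gathers with collectives.

import torch
from torch import nn

def dp_sharded_run(encoder, patch_embeds: torch.Tensor,
                   world_size: int) -> torch.Tensor:
    # Conceptual, single-process stand-in for run_dp_sharded_vision_model:
    # split the batch into one shard per rank, run the (weight-replicated)
    # encoder on each shard, then reassemble the outputs. In vLLM the shards
    # run on different ranks and the concat is a collective gather.
    shards = patch_embeds.chunk(world_size, dim=0)
    outputs = [encoder(shard) for shard in shards]
    return torch.cat(outputs, dim=0)

# Toy usage with an identity "encoder"; the real call passes the
# Step3VisionEncoder, which takes its input as inputs_embeds.
out = dp_sharded_run(nn.Identity(), torch.randn(4, 16, 32), world_size=2)
assert out.shape == (4, 16, 32)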
@@ -836,13 +884,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
 
         self.config = config
         self.multimodal_config = multimodal_config
+        self.use_data_parallel = (vllm_config.parallel_config.
+                                  enable_multimodal_encoder_data_parallel)
 
         if multimodal_config.get_limit_per_prompt("image"):
-            self.vision_model = Step3VisionTransformer(config.vision_config,
-                                                       None,
-                                                       prefix=maybe_prefix(
-                                                           prefix,
-                                                           "vision_model"))
+            self.vision_model = Step3VisionTransformer(
+                config.vision_config,
+                None,
+                prefix=maybe_prefix(prefix, "vision_model"),
+                use_data_parallel=self.use_data_parallel)
             self.vit_downsampler = nn.Conv2d(
                 config.vision_config.hidden_size,
                 config.vision_config.output_hidden_size,
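Reviewer note (not part of the diff): the whole feature is gated by parallel_config.enable_multimodal_encoder_data_parallel, the only configuration surface this hunk touches. A tiny sketch of that gating with stub config objects; the stub class names are hypothetical, and only the attribute name is taken from the diff.

from dataclasses import dataclass

# Hypothetical stand-ins for the vLLM config objects; only the attribute
# enable_multimodal_encoder_data_parallel appears in the diff above.
@dataclass
class ParallelConfigStub:
    enable_multimodal_encoder_data_parallel: bool = False

@dataclass
class VllmConfigStub:
    parallel_config: ParallelConfigStub

def vit_parallel_mode(vllm_config: VllmConfigStub) -> str:
    # Mirrors the gating in __init__: the flag decides whether the vision
    # tower is built with replicated weights (batch sharded across ranks)
    # or with tensor-parallel sharded weights.
    use_data_parallel = (
        vllm_config.parallel_config.enable_multimodal_encoder_data_parallel)
    return "data-parallel" if use_data_parallel else "tensor-parallel"

assert vit_parallel_mode(VllmConfigStub(ParallelConfigStub(True))) == "data-parallel"
assert vit_parallel_mode(VllmConfigStub(ParallelConfigStub())) == "tensor-parallel"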