
Commit d16aa3d

[Model] Add option to run Step3VisionEncoder in DP (#22697)

Signed-off-by: zzh142857 <[email protected]>
1 parent 6807af8 commit d16aa3d

File tree: 1 file changed (+91 additions, -41 deletions)

vllm/model_executor/models/step3_vl.py (91 additions & 41 deletions)

@@ -21,6 +21,7 @@
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
+                                               ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
@@ -33,6 +34,7 @@
                                         BaseProcessingInfo, PromptReplacement,
                                         PromptUpdate, PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.multimodal.utils import run_dp_sharded_vision_model
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs import Step3VisionEncoderConfig
 from vllm.transformers_utils.tokenizer import AnyTokenizer
@@ -650,7 +652,8 @@ class Step3VisionAttention(nn.Module):
     def __init__(self,
                  config,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+                 prefix: str = "",
+                 use_data_parallel: bool = False):
         super().__init__()
         self.config = config
         self.embed_dim = config.hidden_size
@@ -659,20 +662,42 @@ def __init__(self,
 
         self.scale = self.head_dim**-0.5
 
-        tp_size = get_tensor_model_parallel_world_size()
+        tp_size = (1 if use_data_parallel else
+                   get_tensor_model_parallel_world_size())
         assert self.total_num_heads % tp_size == 0
         self.num_heads = self.total_num_heads // tp_size
-        self.qkv_proj = QKVParallelLinear(self.embed_dim,
-                                          self.head_dim,
-                                          self.total_num_heads,
-                                          bias=True,
-                                          quant_config=quant_config,
-                                          prefix=prefix)
-        self.out_proj = RowParallelLinear(self.embed_dim,
-                                          self.embed_dim,
-                                          bias=True,
-                                          quant_config=quant_config,
-                                          prefix=prefix)
+
+        self.q_size = self.num_heads * self.head_dim
+
+        if use_data_parallel:
+            self.qkv_proj = ReplicatedLinear(
+                self.embed_dim,
+                3 * self.q_size,
+                bias=True,
+                quant_config=quant_config,
+                prefix=prefix,
+            )
+            self.out_proj = ReplicatedLinear(
+                self.total_num_heads * self.head_dim,
+                self.embed_dim,
+                bias=True,
+                quant_config=quant_config,
+                prefix=prefix,
+            )
+        else:
+            self.qkv_proj = QKVParallelLinear(
+                self.embed_dim,
+                self.head_dim,
+                self.total_num_heads,
+                bias=True,
+                quant_config=quant_config,
+                prefix=prefix,
+            )
+            self.out_proj = RowParallelLinear(self.embed_dim,
+                                              self.embed_dim,
+                                              bias=True,
+                                              quant_config=quant_config,
+                                              prefix=prefix)
 
     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
         return tensor.view(bsz, seq_len, self.num_heads,
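
Note: a quick shape sketch of what the DP branch above changes, using hypothetical sizes rather than the real Step3 vision config. Under TP, QKVParallelLinear splits the attention heads across ranks, so each rank projects to a smaller QKV slice; under DP, tp_size is treated as 1 and ReplicatedLinear keeps the full 3 * q_size projection on every rank.

# Illustrative only: hypothetical sizes, not taken from the Step3 config.
embed_dim = 1792
total_num_heads = 16
head_dim = embed_dim // total_num_heads       # 112

# TP branch: heads are sharded across tensor-parallel ranks.
tp_size = 4
num_heads_tp = total_num_heads // tp_size
qkv_out_per_rank_tp = 3 * num_heads_tp * head_dim   # 1344 output columns per rank

# DP branch: every rank holds all heads, so the projection is full width.
q_size = total_num_heads * head_dim
qkv_out_per_rank_dp = 3 * q_size                    # 5376 output columns per rank

print(qkv_out_per_rank_tp, qkv_out_per_rank_dp)
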
@@ -712,20 +737,25 @@ class Step3VisionMLP(nn.Module):
     def __init__(self,
                  config,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+                 prefix: str = "",
+                 use_data_parallel: bool = False):
         super().__init__()
         self.config = config
         self.activation_fn = get_act_fn(config.hidden_act)
-        self.fc1 = ColumnParallelLinear(config.hidden_size,
-                                        config.intermediate_size,
-                                        bias=True,
-                                        quant_config=quant_config,
-                                        prefix=prefix)
-        self.fc2 = RowParallelLinear(config.intermediate_size,
-                                     config.hidden_size,
-                                     bias=True,
-                                     quant_config=quant_config,
-                                     prefix=prefix)
+        cls_fc1 = (ReplicatedLinear
+                   if use_data_parallel else ColumnParallelLinear)
+        self.fc1 = cls_fc1(config.hidden_size,
+                           config.intermediate_size,
+                           bias=True,
+                           quant_config=quant_config,
+                           prefix=prefix)
+        cls_fc2 = (ReplicatedLinear
+                   if use_data_parallel else RowParallelLinear)
+        self.fc2 = cls_fc2(config.intermediate_size,
+                           config.hidden_size,
+                           bias=True,
+                           quant_config=quant_config,
+                           prefix=prefix)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.fc1(hidden_states)
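
Why swapping in ReplicatedLinear is a drop-in change here: vLLM's linear layers return an (output, output_bias) tuple from forward(), so the existing "hidden_states, _ = self.fc1(hidden_states)" unpacking works unchanged with either class. A minimal stand-in mimicking that contract (not the actual vLLM classes):

import torch
import torch.nn as nn

class TupleReturningLinear(nn.Module):
    """Stand-in that mimics vLLM's (output, output_bias) return convention."""

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        self.inner = nn.Linear(in_features, out_features, bias=bias)

    def forward(self, x: torch.Tensor):
        return self.inner(x), None  # (output, output_bias)

fc1 = TupleReturningLinear(16, 64)
hidden_states, _ = fc1(torch.randn(2, 16))  # same unpacking as Step3VisionMLP.forward
print(hidden_states.shape)                  # torch.Size([2, 64])
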
@@ -739,15 +769,22 @@ class Step3VisionEncoderLayer(nn.Module):
     def __init__(self,
                  config: Step3VisionEncoderConfig,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+                 prefix: str = "",
+                 use_data_parallel: bool = False):
         super().__init__()
+        self.use_data_parallel = use_data_parallel
         self.embed_dim = config.hidden_size
-        self.self_attn = Step3VisionAttention(config,
-                                              quant_config,
-                                              prefix=f"{prefix}.self_attn")
+        self.self_attn = Step3VisionAttention(
+            config,
+            quant_config,
+            prefix=f"{prefix}.self_attn",
+            use_data_parallel=self.use_data_parallel)
         self.layer_norm1 = nn.LayerNorm(self.embed_dim,
                                         eps=config.layer_norm_eps)
-        self.mlp = Step3VisionMLP(config, quant_config, prefix=f"{prefix}.mlp")
+        self.mlp = Step3VisionMLP(config,
+                                  quant_config,
+                                  prefix=f"{prefix}.mlp",
+                                  use_data_parallel=self.use_data_parallel)
         self.layer_norm2 = nn.LayerNorm(self.embed_dim,
                                         eps=config.layer_norm_eps)
 
@@ -767,13 +804,16 @@ class Step3VisionEncoder(nn.Module):
     def __init__(self,
                  config: Step3VisionEncoderConfig,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+                 prefix: str = "",
+                 use_data_parallel: bool = False):
         super().__init__()
         self.config = config
+        self.use_data_parallel = use_data_parallel
         self.layers = nn.ModuleList([
             Step3VisionEncoderLayer(config,
                                     quant_config,
-                                    prefix=f"{prefix}.layers.{i}")
+                                    prefix=f"{prefix}.layers.{i}",
+                                    use_data_parallel=self.use_data_parallel)
             for i in range(config.num_hidden_layers)
         ])
 
@@ -792,21 +832,29 @@ class Step3VisionTransformer(nn.Module):
     def __init__(self,
                  config: Step3VisionEncoderConfig,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+                 prefix: str = "",
+                 use_data_parallel: bool = False):
         super().__init__()
         self.config = config
+        self.use_data_parallel = use_data_parallel
         self.image_size = config.image_size
         self.embeddings = Step3VisionEmbeddings(config)
-        self.transformer = Step3VisionEncoder(config,
-                                              quant_config,
-                                              prefix=f"{prefix}.transformer")
+        self.transformer = Step3VisionEncoder(
+            config,
+            quant_config,
+            prefix=f"{prefix}.transformer",
+            use_data_parallel=self.use_data_parallel)
 
     def forward(
         self,
         pixel_values: torch.Tensor,
     ):
         hidden_states = self.embeddings(pixel_values)
-        hidden_states = self.transformer(inputs_embeds=hidden_states)
+        if self.use_data_parallel:
+            hidden_states = run_dp_sharded_vision_model(
+                hidden_states, self.transformer)
+        else:
+            hidden_states = self.transformer(inputs_embeds=hidden_states)
         return hidden_states
 
 
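For readers unfamiliar with run_dp_sharded_vision_model: conceptually it splits the image batch across data-parallel ranks, runs the (replicated) encoder on each shard, and gathers the per-shard outputs back. The sketch below is a simplified single-process stand-in for that idea, assuming the encoder treats each sample independently; the real helper in vllm.multimodal.utils does this with torch.distributed collectives.

import torch
import torch.nn as nn

def fake_dp_sharded_run(batch: torch.Tensor, model: nn.Module,
                        dp_size: int = 2) -> torch.Tensor:
    """Single-process stand-in: shard the batch, run each shard, gather."""
    shards = torch.chunk(batch, dp_size, dim=0)    # one slice per "rank"
    outputs = [model(shard) for shard in shards]   # each rank runs the full model
    return torch.cat(outputs, dim=0)               # stand-in for the all-gather

encoder = nn.Linear(8, 8)                          # toy stand-in for the ViT
images = torch.randn(4, 8)
assert torch.allclose(fake_dp_sharded_run(images, encoder), encoder(images))
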
@@ -836,13 +884,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
 
         self.config = config
         self.multimodal_config = multimodal_config
+        self.use_data_parallel = (vllm_config.parallel_config.
+                                  enable_multimodal_encoder_data_parallel)
 
         if multimodal_config.get_limit_per_prompt("image"):
-            self.vision_model = Step3VisionTransformer(config.vision_config,
-                                                       None,
-                                                       prefix=maybe_prefix(
-                                                           prefix,
-                                                           "vision_model"))
+            self.vision_model = Step3VisionTransformer(
+                config.vision_config,
+                None,
+                prefix=maybe_prefix(prefix, "vision_model"),
+                use_data_parallel=self.use_data_parallel)
             self.vit_downsampler = nn.Conv2d(
                 config.vision_config.hidden_size,
                 config.vision_config.output_hidden_size,
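
Usage note (hedged): the behavior is gated by parallel_config.enable_multimodal_encoder_data_parallel, read in the hunk above, which is presumably surfaced as an engine argument of the same name. A sketch of how enabling it might look; the model id, parallel sizes, and the 1:1 kwarg mapping are assumptions, not taken from this commit.

from vllm import LLM

# Assumption: the parallel-config field maps directly to an LLM/engine kwarg.
llm = LLM(
    model="stepfun-ai/step3",                      # example Step3 checkpoint id
    tensor_parallel_size=4,                        # language model stays tensor-parallel
    enable_multimodal_encoder_data_parallel=True,  # run Step3VisionEncoder in DP
)
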
