
Commit 5efd690

[CLI][Doc] Formalize --mm-encoder-tp-mode (#23190)
Signed-off-by: DarkLight1337 <[email protected]>
1 parent b17109b commit 5efd690

File tree

7 files changed (+104, -24 lines)

docs/configuration/optimization.md

Lines changed: 45 additions & 0 deletions
@@ -129,6 +129,51 @@ Data parallelism replicates the entire model across multiple GPU sets and proces
 Data parallelism can be combined with the other parallelism strategies and is set by `data_parallel_size=N`.
 Note that MoE layers will be sharded according to the product of the tensor parallel size and data parallel size.

+### Batch-level DP for Multi-Modal Encoders
+
+By default, TP is used to shard the weights of multi-modal encoders just like for language decoders,
+in order to reduce the memory and compute load on each GPU.
+
+However, since multi-modal encoders are very small compared to language decoders,
+there is relatively little gain from TP. On the other hand, TP incurs significant communication
+overhead because an all-reduce is performed after every layer.
+
+Given this, it may be advantageous to instead shard the batched input data across the TP ranks, essentially
+performing batch-level DP. This has been shown to improve throughput by around 10% for
+`tensor_parallel_size=8`. For vision encoders that use hardware-unoptimized Conv3D operations,
+batch-level DP can provide a further 40% increase in throughput compared to regular TP.
+
+Nevertheless, since the weights of the multi-modal encoder are replicated across the TP ranks,
+there is a minor increase in memory consumption, which may cause OOM if you can barely fit the model already.
+
+You can enable batch-level DP by setting `mm_encoder_tp_mode="data"`, for example:
+
+```python
+from vllm import LLM
+
+llm = LLM(
+    model="Qwen/Qwen2.5-VL-72B-Instruct",
+    # Create two EngineCore instances, one per DP rank
+    data_parallel_size=2,
+    # Within each EngineCore instance:
+    # The vision encoder uses TP=4 (not DP=2) to shard the input data
+    # The language decoder uses TP=4 to shard the weights as usual
+    tensor_parallel_size=4,
+    mm_encoder_tp_mode="data",
+)
+```
+
+!!! important
+    Batch-level DP is not to be confused with API request-level DP
+    (which is instead controlled by `data_parallel_size`).
+
+    The availability of batch-level DP depends on the model implementation.
+    Currently, the following models support `mm_encoder_tp_mode="data"`:
+
+    - Llama4 (<gh-pr:18368>)
+    - Qwen2.5-VL (<gh-pr:22742>)
+    - Step3 (<gh-pr:22697>)
+
 ## Input Processing

 ### Parallel Processing
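For comparison with the example above, here is a minimal sketch of the pure tensor-parallel case mentioned in the prose (no request-level DP). The model name is reused from the documentation example and is only illustrative:

```python
from vllm import LLM

# Minimal sketch: TP=8 for the language decoder, while the vision encoder
# replicates its weights on all 8 ranks and splits the batched image inputs
# across them (batch-level DP).
llm = LLM(
    model="Qwen/Qwen2.5-VL-72B-Instruct",
    tensor_parallel_size=8,
    mm_encoder_tp_mode="data",
)
```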

vllm/config/__init__.py

Lines changed: 33 additions & 1 deletion
@@ -258,6 +258,7 @@ def is_init_field(cls: ConfigType, name: str) -> bool:
 ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
 LogprobsMode = Literal["raw_logprobs", "raw_logits", "processed_logprobs",
                        "processed_logits"]
+MMEncoderTPMode = Literal["weights", "data"]


 @config
@@ -438,6 +439,19 @@ class ModelConfig:
     `mm_processor_cache_gb * (api_server_count + data_parallel_size)`.

     Set to `0` to disable this cache completely (not recommended)."""
+    mm_encoder_tp_mode: MMEncoderTPMode = "weights"
+    """Indicates how to optimize multi-modal encoder inference using
+    tensor parallelism (TP).
+
+    - `"weights"`: Within the same vLLM engine, split the weights of
+      each layer across TP ranks. (default TP behavior)
+    - `"data"`: Within the same vLLM engine, split the batched input data
+      across TP ranks to process the data in parallel, while hosting
+      the full weights on each TP rank.
+      This batch-level DP is not to be confused with API request-level
+      DP (which is controlled by `--data-parallel-size`).
+      This is only supported on a per-model basis and falls back to
+      `"weights"` if the encoder does not support DP."""
     override_neuron_config: dict[str, Any] = field(default_factory=dict)
     """Initialize non-default neuron config or override default neuron config
     that are specific to Neuron devices, this argument will be used to
@@ -856,8 +870,10 @@ def _init_multimodal_config(self) -> Optional["MultiModalConfig"]:
                 media_io_kwargs=self.media_io_kwargs,
                 mm_processor_kwargs=self.mm_processor_kwargs,
                 mm_processor_cache_gb=self.mm_processor_cache_gb,
+                mm_encoder_tp_mode=self.mm_encoder_tp_mode,
                 interleave_mm_strings=self.interleave_mm_strings,
-                skip_mm_profiling=self.skip_mm_profiling)
+                skip_mm_profiling=self.skip_mm_profiling,
+            )

         return None

@@ -2547,6 +2563,22 @@ class MultiModalConfig:
     Set to `0` to disable this cache completely (not recommended).
     """

+    mm_encoder_tp_mode: MMEncoderTPMode = "weights"
+    """
+    Indicates how to optimize multi-modal encoder inference using
+    tensor parallelism (TP).
+
+    - `"weights"`: Within the same vLLM engine, split the weights of
+      each layer across TP ranks. (default TP behavior)
+    - `"data"`: Within the same vLLM engine, split the batched input data
+      across TP ranks to process the data in parallel, while hosting
+      the full weights on each TP rank.
+      This batch-level DP is not to be confused with API request-level
+      DP (which is controlled by `--data-parallel-size`).
+      This is only supported on a per-model basis and falls back to
+      `"weights"` if the encoder does not support DP.
+    """
+
     interleave_mm_strings: bool = False
     """
     Enable fully interleaved support for multimodal prompts.

vllm/config/parallel.py

Lines changed: 0 additions & 4 deletions
@@ -137,10 +137,6 @@ class is dynamically inherited by the worker class. This is used to inject
     rank: int = 0
     """Global rank in distributed setup."""

-    enable_multimodal_encoder_data_parallel: bool = False
-    """ Use data parallelism instead of tensor parallelism for vision encoder.
-    Only support LLama4 for now"""
-
     @property
     def world_size_across_dp(self) -> int:
         """world_size_across_dp is TPxPPxDP, it is the size of the world

vllm/engine/arg_utils.py

Lines changed: 22 additions & 13 deletions
@@ -28,12 +28,12 @@
                          DeviceConfig, DistributedExecutorBackend,
                          GuidedDecodingBackend, HfOverrides, KVEventsConfig,
                          KVTransferConfig, LoadConfig, LogprobsMode,
-                         LoRAConfig, MambaDType, ModelConfig, ModelDType,
-                         ModelImpl, MultiModalConfig, ObservabilityConfig,
-                         ParallelConfig, PoolerConfig, PrefixCachingHashAlgo,
-                         RunnerOption, SchedulerConfig, SchedulerPolicy,
-                         SpeculativeConfig, TaskOption, TokenizerMode,
-                         VllmConfig, get_attr_docs, get_field)
+                         LoRAConfig, MambaDType, MMEncoderTPMode, ModelConfig,
+                         ModelDType, ModelImpl, MultiModalConfig,
+                         ObservabilityConfig, ParallelConfig, PoolerConfig,
+                         PrefixCachingHashAlgo, RunnerOption, SchedulerConfig,
+                         SchedulerPolicy, SpeculativeConfig, TaskOption,
+                         TokenizerMode, VllmConfig, get_attr_docs, get_field)
 from vllm.logger import init_logger
 from vllm.platforms import CpuArchEnum, current_platform
 from vllm.plugins import load_general_plugins
@@ -352,6 +352,7 @@ class EngineArgs:
         MultiModalConfig.mm_processor_kwargs
     disable_mm_preprocessor_cache: bool = False  # DEPRECATED
     mm_processor_cache_gb: int = MultiModalConfig.mm_processor_cache_gb
+    mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode
     skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
     # LoRA fields
     enable_lora: bool = False
@@ -434,16 +435,14 @@ class EngineArgs:
     use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
     pt_load_map_location: str = LoadConfig.pt_load_map_location

-    enable_multimodal_encoder_data_parallel: bool = \
-        ParallelConfig.enable_multimodal_encoder_data_parallel
+    # DEPRECATED
+    enable_multimodal_encoder_data_parallel: bool = False

     logits_processors: Optional[list[Union[
         str, type[LogitsProcessor]]]] = ModelConfig.logits_processors
     """Custom logitproc types"""

     async_scheduling: bool = SchedulerConfig.async_scheduling
-    # DEPRECATED
-    enable_prompt_adapter: bool = False

     kv_sharing_fast_prefill: bool = \
         CacheConfig.kv_sharing_fast_prefill
@@ -685,7 +684,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             **parallel_kwargs["worker_extension_cls"])
         parallel_group.add_argument(
             "--enable-multimodal-encoder-data-parallel",
-            **parallel_kwargs["enable_multimodal_encoder_data_parallel"])
+            action="store_true",
+            deprecated=True)

         # KV cache arguments
         cache_kwargs = get_kwargs(CacheConfig)
@@ -735,6 +735,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         multimodal_group.add_argument("--disable-mm-preprocessor-cache",
                                       action="store_true",
                                       deprecated=True)
+        multimodal_group.add_argument(
+            "--mm-encoder-tp-mode", **multimodal_kwargs["mm_encoder_tp_mode"])
         multimodal_group.add_argument(
             "--interleave-mm-strings",
             **multimodal_kwargs["interleave_mm_strings"])
@@ -909,6 +911,14 @@ def create_model_config(self) -> ModelConfig:

         self.mm_processor_cache_gb = envs.VLLM_MM_INPUT_CACHE_GIB

+        if self.enable_multimodal_encoder_data_parallel:
+            logger.warning(
+                "`--enable-multimodal-encoder-data-parallel` is deprecated "
+                "and will be removed in v0.13. "
+                "Please use `--mm-encoder-tp-mode data` instead.")
+
+            self.mm_encoder_tp_mode = "data"
+
         return ModelConfig(
             model=self.model,
             hf_config_path=self.hf_config_path,
@@ -947,6 +957,7 @@ def create_model_config(self) -> ModelConfig:
             config_format=self.config_format,
             mm_processor_kwargs=self.mm_processor_kwargs,
             mm_processor_cache_gb=self.mm_processor_cache_gb,
+            mm_encoder_tp_mode=self.mm_encoder_tp_mode,
             override_neuron_config=self.override_neuron_config,
             override_pooler_config=self.override_pooler_config,
             logits_processor_pattern=self.logits_processor_pattern,
@@ -1258,8 +1269,6 @@ def create_engine_config(
             distributed_executor_backend=self.distributed_executor_backend,
             worker_cls=self.worker_cls,
             worker_extension_cls=self.worker_extension_cls,
-            enable_multimodal_encoder_data_parallel=self.
-            enable_multimodal_encoder_data_parallel,
         )

         if model_config.is_multimodal_model:
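The deprecation fallback above can be illustrated with a short sketch. It assumes the model's HF config can be fetched (or is already cached locally), since `create_model_config()` builds a full `ModelConfig`:

```python
from vllm.engine.arg_utils import EngineArgs

# Sketch of the fallback implemented above: the deprecated boolean flag is
# translated into the new mm_encoder_tp_mode="data" setting (with a warning)
# when the model config is created.
args = EngineArgs(
    model="Qwen/Qwen2.5-VL-72B-Instruct",
    enable_multimodal_encoder_data_parallel=True,  # deprecated flag
)
model_config = args.create_model_config()
assert model_config.mm_encoder_tp_mode == "data"
```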

vllm/model_executor/models/mllama4.py

Lines changed: 2 additions & 2 deletions
@@ -728,8 +728,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         config = vllm_config.model_config.hf_config
         quant_config = vllm_config.quant_config
         multimodal_config = vllm_config.model_config.multimodal_config
-        self.use_data_parallel = (vllm_config.parallel_config.
-                                  enable_multimodal_encoder_data_parallel)
+        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
+
         self.config = config
         self.quant_config = quant_config
         self.multimodal_config = multimodal_config

vllm/model_executor/models/qwen2_5_vl.py

Lines changed: 1 addition & 2 deletions
@@ -877,8 +877,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config
         multimodal_config = vllm_config.model_config.multimodal_config

-        self.use_data_parallel = (vllm_config.parallel_config.
-                                  enable_multimodal_encoder_data_parallel)
+        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
         self.config = config
         self.multimodal_config = multimodal_config

vllm/model_executor/models/step3_vl.py

Lines changed: 1 addition & 2 deletions
@@ -882,8 +882,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:

         self.config = config
         self.multimodal_config = multimodal_config
-        self.use_data_parallel = (vllm_config.parallel_config.
-                                  enable_multimodal_encoder_data_parallel)
+        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"

         if multimodal_config.get_limit_per_prompt("image"):
             self.vision_model = Step3VisionTransformer(
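The three model files above all switch on the new config field in the same way. Below is a hedged sketch of that pattern for a hypothetical encoder; the class name and constructor wiring are illustrative, and only the `mm_encoder_tp_mode` check mirrors the committed code:

```python
import torch.nn as nn


class MyVisionEncoder(nn.Module):
    """Hypothetical vision encoder showing the batch-level DP opt-in."""

    def __init__(self, *, vllm_config, prefix: str = "") -> None:
        super().__init__()
        multimodal_config = vllm_config.model_config.multimodal_config
        # Batch-level DP: keep the full encoder weights on every TP rank and
        # split the batched multi-modal inputs across ranks instead of
        # sharding the weights.
        self.use_data_parallel = (
            multimodal_config.mm_encoder_tp_mode == "data")
```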
