Skip to content

Commit 24d1dff

Browse files
[executor] feat: add supports_pp attr to executors (#21786)
Signed-off-by: Haibin Lin <[email protected]>
1 parent 7de45db commit 24d1dff

File tree

4 files changed

+17
-8
lines changed

4 files changed

+17
-8
lines changed

vllm/engine/arg_utils.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1490,14 +1490,18 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
14901490
and _warn_or_fallback("Engine in background thread")):
14911491
return False
14921492

1493-
if (self.pipeline_parallel_size > 1
1494-
and self.distributed_executor_backend
1495-
not in (ParallelConfig.distributed_executor_backend, "ray",
1496-
"mp", "external_launcher")):
1497-
name = "Pipeline Parallelism without Ray distributed executor " \
1498-
"or multiprocessing executor or external launcher"
1499-
_raise_or_fallback(feature_name=name, recommend_to_remove=False)
1500-
return False
1493+
if self.pipeline_parallel_size > 1:
1494+
supports_pp = getattr(self.distributed_executor_backend,
1495+
'supports_pp', False)
1496+
if not supports_pp and self.distributed_executor_backend not in (
1497+
ParallelConfig.distributed_executor_backend, "ray", "mp",
1498+
"external_launcher"):
1499+
name = "Pipeline Parallelism without Ray distributed " \
1500+
"executor or multiprocessing executor or external " \
1501+
"launcher"
1502+
_raise_or_fallback(feature_name=name,
1503+
recommend_to_remove=False)
1504+
return False
15011505

15021506
# The platform may be supported on V1, but off by default for now.
15031507
if not current_platform.default_v1( # noqa: SIM103

vllm/executor/executor_base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ class ExecutorBase(ABC):
3535
"""
3636

3737
uses_ray: bool # whether the executor uses Ray for orchestration.
38+
supports_pp: bool = False # whether the executor supports PP
3839

3940
def __init__(
4041
self,

vllm/v1/executor/multiproc_executor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@
4141

4242
class MultiprocExecutor(Executor):
4343

44+
supports_pp: bool = True
45+
4446
def _init_executor(self) -> None:
4547
# Call self.shutdown at exit to clean up
4648
# and ensure workers will be terminated.

vllm/v1/executor/ray_distributed_executor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ def result(self, timeout=None):
4343
class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
4444
"""Ray distributed executor using Ray Compiled Graphs."""
4545

46+
supports_pp: bool = True
47+
4648
def _init_executor(self) -> None:
4749
super()._init_executor()
4850

0 commit comments

Comments
 (0)