
Commit 0c54fc7

Improve configs - ParallelConfig (#16332)
Signed-off-by: Harry Mellor <[email protected]>
1 parent c1b5785 commit 0c54fc7

File tree: 2 files changed (+182, -85 lines)


vllm/config.py

Lines changed: 118 additions & 28 deletions
@@ -4,13 +4,16 @@
 import copy
 import enum
 import hashlib
+import inspect
 import json
 import sys
+import textwrap
 import warnings
 from collections import Counter
 from collections.abc import Mapping
 from contextlib import contextmanager
-from dataclasses import dataclass, field, replace
+from dataclasses import (MISSING, dataclass, field, fields, is_dataclass,
+                         replace)
 from importlib.util import find_spec
 from pathlib import Path
 from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
@@ -104,6 +107,77 @@ class ModelImpl(str, enum.Enum):
     TRANSFORMERS = "transformers"
 
 
+def get_attr_docs(cls: type[Any]) -> dict[str, str]:
+    """
+    Get any docstrings placed after attribute assignments in a class body.
+
+    https://davidism.com/mit-license/
+    """
+
+    def pairwise(iterable):
+        """
+        Manually implement https://docs.python.org/3/library/itertools.html#itertools.pairwise
+
+        Can be removed when Python 3.9 support is dropped.
+        """
+        iterator = iter(iterable)
+        a = next(iterator, None)
+
+        for b in iterator:
+            yield a, b
+            a = b
+
+    cls_node = ast.parse(textwrap.dedent(inspect.getsource(cls))).body[0]
+
+    if not isinstance(cls_node, ast.ClassDef):
+        raise TypeError("Given object was not a class.")
+
+    out = {}
+
+    # Consider each pair of nodes.
+    for a, b in pairwise(cls_node.body):
+        # Must be an assignment then a constant string.
+        if (not isinstance(a, (ast.Assign, ast.AnnAssign))
+                or not isinstance(b, ast.Expr)
+                or not isinstance(b.value, ast.Constant)
+                or not isinstance(b.value.value, str)):
+            continue
+
+        doc = inspect.cleandoc(b.value.value)
+
+        # An assignment can have multiple targets (a = b = v), but an
+        # annotated assignment only has one target.
+        targets = a.targets if isinstance(a, ast.Assign) else [a.target]
+
+        for target in targets:
+            # Must be assigning to a plain name.
+            if not isinstance(target, ast.Name):
+                continue
+
+            out[target.id] = doc
+
+    return out
+
+
+def config(cls: type[Any]) -> type[Any]:
+    """
+    A decorator that ensures all fields in a dataclass have default values
+    and that each field has a docstring.
+    """
+    if not is_dataclass(cls):
+        raise TypeError("The decorated class must be a dataclass.")
+    attr_docs = get_attr_docs(cls)
+    for f in fields(cls):
+        if f.init and f.default is MISSING and f.default_factory is MISSING:
+            raise ValueError(
+                f"Field '{f.name}' in {cls.__name__} must have a default value."
+            )
+        if f.name not in attr_docs:
+            raise ValueError(
+                f"Field '{f.name}' in {cls.__name__} must have a docstring.")
+    return cls
+
+
 class ModelConfig:
     """Configuration for the model.
 
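Editor's note: the two helpers above establish a contract that the rest of the commit relies on. A `@config`-decorated dataclass must give every field a default value and a docstring placed directly beneath it, and `get_attr_docs` recovers those docstrings by re-parsing the class source with `ast`. The following minimal sketch is not part of the diff; the `ExampleConfig` class is hypothetical, and it assumes a vLLM checkout containing this commit so `config` and `get_attr_docs` can be imported from `vllm.config`.

from dataclasses import dataclass
from typing import Optional

from vllm.config import config, get_attr_docs


@config
@dataclass
class ExampleConfig:
    """Hypothetical config class, used only to illustrate the contract."""

    num_workers: int = 1
    """Number of worker processes to launch."""

    log_dir: Optional[str] = None
    """Directory for log files. If None, file logging is disabled."""


# get_attr_docs() parses the class body and returns the attribute docstrings.
docs = get_attr_docs(ExampleConfig)
assert docs["num_workers"] == "Number of worker processes to launch."

# Declaring a field without a default value or without a docstring would make
# @config raise a ValueError as soon as the class is defined.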
@@ -1432,61 +1506,77 @@ def __post_init__(self):
             self.ignore_patterns = ["original/**/*"]
 
 
+@config
 @dataclass
 class ParallelConfig:
     """Configuration for the distributed execution."""
 
-    pipeline_parallel_size: int = 1  # Number of pipeline parallel groups.
-    tensor_parallel_size: int = 1  # Number of tensor parallel groups.
-    data_parallel_size: int = 1  # Number of data parallel groups.
-    data_parallel_rank: int = 0  # Rank of the data parallel group.
-    # Local rank of the data parallel group, defaults to global rank.
+    pipeline_parallel_size: int = 1
+    """Number of pipeline parallel groups."""
+    tensor_parallel_size: int = 1
+    """Number of tensor parallel groups."""
+    data_parallel_size: int = 1
+    """Number of data parallel groups. MoE layers will be sharded according to
+    the product of the tensor parallel size and data parallel size."""
+    data_parallel_rank: int = 0
+    """Rank of the data parallel group."""
     data_parallel_rank_local: Optional[int] = None
-    # IP of the data parallel master.
+    """Local rank of the data parallel group, defaults to global rank."""
     data_parallel_master_ip: str = "127.0.0.1"
-    data_parallel_master_port: int = 29500  # Port of the data parallel master.
-    enable_expert_parallel: bool = False  # Use EP instead of TP for MoE layers.
+    """IP of the data parallel master."""
+    data_parallel_master_port: int = 29500
+    """Port of the data parallel master."""
+    enable_expert_parallel: bool = False
+    """Use expert parallelism instead of tensor parallelism for MoE layers."""
 
-    # Maximum number of multiple batches
-    # when load model sequentially. To avoid RAM OOM when using tensor
-    # parallel and large models.
     max_parallel_loading_workers: Optional[int] = None
+    """Maximum number of parallel loading workers when loading model
+    sequentially in multiple batches. To avoid RAM OOM when using tensor
+    parallel and large models."""
 
-    # Disable the custom all-reduce kernel and fall back to NCCL.
     disable_custom_all_reduce: bool = False
+    """Disable the custom all-reduce kernel and fall back to NCCL."""
 
-    # Config for the tokenizer pool. If None, will use synchronous tokenization.
     tokenizer_pool_config: Optional[TokenizerPoolConfig] = None
+    """Config for the tokenizer pool. If None, will use synchronous
+    tokenization."""
 
-    # Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.
     ray_workers_use_nsight: bool = False
+    """Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""
 
-    # ray distributed model workers placement group.
     placement_group: Optional["PlacementGroup"] = None
+    """ray distributed model workers placement group."""
 
-    # Backend to use for distributed model
-    # workers, either "ray" or "mp" (multiprocessing). If the product
-    # of pipeline_parallel_size and tensor_parallel_size is less than
-    # or equal to the number of GPUs available, "mp" will be used to
-    # keep processing on a single host. Otherwise, this will default
-    # to "ray" if Ray is installed and fail otherwise. Note that tpu
-    # and hpu only support Ray for distributed inference.
     distributed_executor_backend: Optional[Union[str,
                                                  type["ExecutorBase"]]] = None
+    """Backend to use for distributed model
+    workers, either "ray" or "mp" (multiprocessing). If the product
+    of pipeline_parallel_size and tensor_parallel_size is less than
+    or equal to the number of GPUs available, "mp" will be used to
+    keep processing on a single host. Otherwise, this will default
+    to "ray" if Ray is installed and fail otherwise. Note that tpu
+    and hpu only support Ray for distributed inference."""
 
-    # the full name of the worker class to use. If "auto", the worker class
-    # will be determined based on the platform.
     worker_cls: str = "auto"
+    """The full name of the worker class to use. If "auto", the worker class
+    will be determined based on the platform."""
     sd_worker_cls: str = "auto"
+    """The full name of the worker class to use for speculative decoding.
+    If "auto", the worker class will be determined based on the platform."""
     worker_extension_cls: str = ""
+    """The full name of the worker extension class to use. The worker extension
+    class is dynamically inherited by the worker class. This is used to inject
+    new attributes and methods to the worker class for use in collective_rpc
+    calls."""
 
-    # world_size is TPxPP, it affects the number of workers we create.
     world_size: int = field(init=False)
-    # world_size_across_dp is TPxPPxDP, it is the size of the world
-    # including data parallelism.
+    """world_size is TPxPP, it affects the number of workers we create."""
     world_size_across_dp: int = field(init=False)
+    """world_size_across_dp is TPxPPxDP, it is the size of the world
+    including data parallelism."""
 
     rank: int = 0
+    """Global rank in distributed setup."""
 
     def get_next_dp_init_port(self) -> int:
         """

vllm/engine/arg_utils.py

Lines changed: 64 additions & 57 deletions
@@ -5,9 +5,9 @@
 import json
 import re
 import threading
-from dataclasses import dataclass
+from dataclasses import MISSING, dataclass, fields
 from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
-                    Tuple, Type, Union, cast, get_args)
+                    Tuple, Type, Union, cast, get_args, get_origin)
 
 import torch
 
@@ -19,7 +19,7 @@
                          ModelConfig, ModelImpl, ObservabilityConfig,
                          ParallelConfig, PoolerConfig, PromptAdapterConfig,
                          SchedulerConfig, SpeculativeConfig, TaskOption,
-                         TokenizerPoolConfig, VllmConfig)
+                         TokenizerPoolConfig, VllmConfig, get_attr_docs)
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
@@ -111,14 +111,15 @@ class EngineArgs:
     # Note: Specifying a custom executor backend by passing a class
     # is intended for expert use only. The API may change without
     # notice.
-    distributed_executor_backend: Optional[Union[str,
-                                                 Type[ExecutorBase]]] = None
+    distributed_executor_backend: Optional[Union[
+        str, Type[ExecutorBase]]] = ParallelConfig.distributed_executor_backend
     # number of P/D disaggregation (or other disaggregation) workers
-    pipeline_parallel_size: int = 1
-    tensor_parallel_size: int = 1
-    data_parallel_size: int = 1
-    enable_expert_parallel: bool = False
-    max_parallel_loading_workers: Optional[int] = None
+    pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
+    tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
+    data_parallel_size: int = ParallelConfig.data_parallel_size
+    enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
+    max_parallel_loading_workers: Optional[
+        int] = ParallelConfig.max_parallel_loading_workers
     block_size: Optional[int] = None
     enable_prefix_caching: Optional[bool] = None
     prefix_caching_hash_algo: str = "builtin"
@@ -145,7 +146,7 @@ class EngineArgs:
     quantization: Optional[str] = None
     enforce_eager: Optional[bool] = None
     max_seq_len_to_capture: int = 8192
-    disable_custom_all_reduce: bool = False
+    disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
     tokenizer_pool_size: int = 0
     # Note: Specifying a tokenizer pool by passing a class
     # is intended for expert use only. The API may change without
@@ -170,7 +171,7 @@ class EngineArgs:
     device: str = 'auto'
     num_scheduler_steps: int = 1
     multi_step_stream_outputs: bool = True
-    ray_workers_use_nsight: bool = False
+    ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
     num_gpu_blocks_override: Optional[int] = None
     num_lookahead_slots: int = 0
     model_loader_extra_config: Optional[dict] = None
@@ -197,8 +198,8 @@ class EngineArgs:
     override_neuron_config: Optional[Dict[str, Any]] = None
     override_pooler_config: Optional[PoolerConfig] = None
     compilation_config: Optional[CompilationConfig] = None
-    worker_cls: str = "auto"
-    worker_extension_cls: str = ""
+    worker_cls: str = ParallelConfig.worker_cls
+    worker_extension_cls: str = ParallelConfig.worker_extension_cls
 
     kv_transfer_config: Optional[KVTransferConfig] = None
 
@@ -232,6 +233,31 @@ def __post_init__(self):
     @staticmethod
     def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         """Shared CLI arguments for vLLM engine."""
+
+        def is_optional(cls: type[Any]) -> bool:
+            """Check if the class is an optional type."""
+            return get_origin(cls) is Union and type(None) in get_args(cls)
+
+        def get_kwargs(cls: type[Any]) -> Dict[str, Any]:
+            cls_docs = get_attr_docs(cls)
+            kwargs = {}
+            for field in fields(cls):
+                name = field.name
+                # One of these will always be present
+                default = (field.default_factory
+                           if field.default is MISSING else field.default)
+                kwargs[name] = {"default": default, "help": cls_docs[name]}
+                # When using action="store_true"
+                # add_argument doesn't accept type
+                if field.type is bool:
+                    continue
+                # Handle optional fields
+                if is_optional(field.type):
+                    kwargs[name]["type"] = nullable_str
+                    continue
+                kwargs[name]["type"] = field.type
+            return kwargs
+
         # Model arguments
         parser.add_argument(
             '--model',
@@ -411,52 +437,37 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             '* "transformers" will use the Transformers model '
             'implementation.\n')
         # Parallel arguments
-        parser.add_argument(
+        parallel_kwargs = get_kwargs(ParallelConfig)
+        parallel_group = parser.add_argument_group(
+            title="ParallelConfig",
+            description=ParallelConfig.__doc__,
+        )
+        parallel_group.add_argument(
             '--distributed-executor-backend',
             choices=['ray', 'mp', 'uni', 'external_launcher'],
-            default=EngineArgs.distributed_executor_backend,
-            help='Backend to use for distributed model '
-            'workers, either "ray" or "mp" (multiprocessing). If the product '
-            'of pipeline_parallel_size and tensor_parallel_size is less than '
-            'or equal to the number of GPUs available, "mp" will be used to '
-            'keep processing on a single host. Otherwise, this will default '
-            'to "ray" if Ray is installed and fail otherwise. Note that tpu '
-            'only supports Ray for distributed inference.')
-
-        parser.add_argument('--pipeline-parallel-size',
-                            '-pp',
-                            type=int,
-                            default=EngineArgs.pipeline_parallel_size,
-                            help='Number of pipeline stages.')
-        parser.add_argument('--tensor-parallel-size',
-                            '-tp',
-                            type=int,
-                            default=EngineArgs.tensor_parallel_size,
-                            help='Number of tensor parallel replicas.')
-        parser.add_argument('--data-parallel-size',
-                            '-dp',
-                            type=int,
-                            default=EngineArgs.data_parallel_size,
-                            help='Number of data parallel replicas. '
-                            'MoE layers will be sharded according to the '
-                            'product of the tensor-parallel-size and '
-                            'data-parallel-size.')
-        parser.add_argument(
+            **parallel_kwargs["distributed_executor_backend"])
+        parallel_group.add_argument(
+            '--pipeline-parallel-size', '-pp',
+            **parallel_kwargs["pipeline_parallel_size"])
+        parallel_group.add_argument('--tensor-parallel-size', '-tp',
+                                    **parallel_kwargs["tensor_parallel_size"])
+        parallel_group.add_argument('--data-parallel-size', '-dp',
+                                    **parallel_kwargs["data_parallel_size"])
+        parallel_group.add_argument(
             '--enable-expert-parallel',
             action='store_true',
-            help='Use expert parallelism instead of tensor parallelism '
-            'for MoE layers.')
-        parser.add_argument(
+            **parallel_kwargs["enable_expert_parallel"])
+        parallel_group.add_argument(
             '--max-parallel-loading-workers',
-            type=int,
-            default=EngineArgs.max_parallel_loading_workers,
-            help='Load model sequentially in multiple batches, '
-            'to avoid RAM OOM when using tensor '
-            'parallel and large models.')
-        parser.add_argument(
+            **parallel_kwargs["max_parallel_loading_workers"])
+        parallel_group.add_argument(
             '--ray-workers-use-nsight',
             action='store_true',
-            help='If specified, use nsight to profile Ray workers.')
+            **parallel_kwargs["ray_workers_use_nsight"])
+        parallel_group.add_argument(
+            '--disable-custom-all-reduce',
+            action='store_true',
+            **parallel_kwargs["disable_custom_all_reduce"])
         # KV cache arguments
         parser.add_argument('--block-size',
                             type=int,
@@ -639,10 +650,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             'Additionally for encoder-decoder models, if the '
             'sequence length of the encoder input is larger '
            'than this, we fall back to the eager mode.')
-        parser.add_argument('--disable-custom-all-reduce',
-                            action='store_true',
-                            default=EngineArgs.disable_custom_all_reduce,
-                            help='See ParallelConfig.')
         parser.add_argument('--tokenizer-pool-size',
                             type=int,
                             default=EngineArgs.tokenizer_pool_size,
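Editor's note: the net effect of the `get_kwargs` helper and the new `ParallelConfig` argument group is that CLI defaults and help text are generated from the dataclass instead of being duplicated by hand in `add_cli_args`. Below is a self-contained sketch of that pattern; `DemoConfig`, `DOCS`, and `build_parser` are hypothetical names, and the help strings are hard-coded in place of `get_attr_docs` so the sketch runs without vLLM installed.

import argparse
from dataclasses import MISSING, dataclass, fields


@dataclass
class DemoConfig:
    """Demo parallel settings (illustrative only)."""

    tensor_parallel_size: int = 1
    """Number of tensor parallel groups."""

    pipeline_parallel_size: int = 1
    """Number of pipeline parallel groups."""


# In the commit these strings come from get_attr_docs(DemoConfig).
DOCS = {
    "tensor_parallel_size": "Number of tensor parallel groups.",
    "pipeline_parallel_size": "Number of pipeline parallel groups.",
}


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    group = parser.add_argument_group(title="DemoConfig",
                                      description=DemoConfig.__doc__)
    for f in fields(DemoConfig):
        # One of default / default_factory is always present for a @config class.
        default = f.default_factory if f.default is MISSING else f.default
        group.add_argument("--" + f.name.replace("_", "-"),
                           type=f.type,
                           default=default,
                           help=DOCS[f.name])
    return parser


args = build_parser().parse_args(["--tensor-parallel-size", "4"])
assert args.tensor_parallel_size == 4
assert args.pipeline_parallel_size == 1

Boolean fields take the action='store_true' path in the real code, which is why `get_kwargs` skips setting `type` for them.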

0 commit comments
