
Commit 3de6e6a

[core][distributed] support n layers % pp size != 0 (#6115)
1 parent 966fe72 commit 3de6e6a

File tree

7 files changed: +19 −10 lines changed

.buildkite/test-pipeline.yaml

Lines changed: 1 addition & 0 deletions
@@ -80,6 +80,7 @@ steps:
   commands:
   - TP_SIZE=2 PP_SIZE=2 EAGER_MODE=1 CHUNKED_PREFILL=1 pytest -v -s distributed/test_pipeline_parallel.py
   - TP_SIZE=2 PP_SIZE=2 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py
+  - TP_SIZE=1 PP_SIZE=3 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py
   - PP_SIZE=4 EAGER_MODE=1 CHUNKED_PREFILL=1 pytest -v -s distributed/test_pipeline_parallel.py
   - PP_SIZE=4 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py

vllm/config.py

Lines changed: 6 additions & 9 deletions
@@ -265,8 +265,6 @@ def verify_with_parallel_config(
                 " must be divisible by tensor parallel size "
                 f"({tensor_parallel_size}).")
 
-        total_num_hidden_layers = getattr(self.hf_text_config,
-                                          "num_hidden_layers", 0)
         pipeline_parallel_size = parallel_config.pipeline_parallel_size
         architectures = getattr(self.hf_config, "architectures", [])
         if not all(arch in _PP_SUPPORTED_MODELS
@@ -275,12 +273,6 @@ def verify_with_parallel_config(
                 "Pipeline parallelism is only supported for the following "
                 f" architectures: {_PP_SUPPORTED_MODELS}.")
 
-        if total_num_hidden_layers % pipeline_parallel_size != 0:
-            raise ValueError(
-                f"Total number of hidden layers ({total_num_hidden_layers}) "
-                "must be divisible by pipeline parallel size "
-                f"({pipeline_parallel_size}).")
-
         if self.quantization == "bitsandbytes" and (
                 parallel_config.tensor_parallel_size > 1
                 or parallel_config.pipeline_parallel_size > 1):
@@ -385,9 +377,13 @@ def get_num_attention_heads(self,
         return num_heads // parallel_config.tensor_parallel_size
 
     def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
+        from vllm.distributed.utils import get_pp_indices
         total_num_hidden_layers = getattr(self.hf_text_config,
                                           "num_hidden_layers", 0)
-        return total_num_hidden_layers // parallel_config.pipeline_parallel_size
+        pp_rank = parallel_config.rank // parallel_config.tensor_parallel_size
+        pp_size = parallel_config.pipeline_parallel_size
+        start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size)
+        return end - start
 
     def contains_seqlen_agnostic_layers(
             self, parallel_config: "ParallelConfig") -> bool:
@@ -709,6 +705,7 @@ def __init__(
             {"CUDA_VISIBLE_DEVICES": envs.CUDA_VISIBLE_DEVICES})
 
         self._verify_args()
+        self.rank = 0
 
     def _verify_args(self) -> None:
         if (self.pipeline_parallel_size > 1
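
The new get_num_layers is rank-aware: a worker's pipeline stage is derived from its global rank and the tensor-parallel size, and only the layers assigned to that stage (via get_pp_indices, shown in the vllm/distributed/utils.py diff below) are counted. A small illustrative sketch of just that mapping; the parallel sizes are made up, not taken from the commit:

# Illustration only: how a global worker rank maps to a pipeline stage,
# mirroring `pp_rank = parallel_config.rank // parallel_config.tensor_parallel_size`.
tensor_parallel_size = 2
pipeline_parallel_size = 3

for rank in range(tensor_parallel_size * pipeline_parallel_size):
    pp_rank = rank // tensor_parallel_size
    print(f"global rank {rank} -> pipeline stage {pp_rank}")

# Ranks 0-1 form stage 0, ranks 2-3 stage 1, ranks 4-5 stage 2; every rank in
# a stage receives the same (start, end) slice, so get_num_layers now returns
# the per-stage layer count rather than a global average.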

vllm/distributed/utils.py

Lines changed: 8 additions & 1 deletion
@@ -50,8 +50,15 @@ def split_tensor_along_last_dim(
 
 def get_pp_indices(num_hidden_layers: int, pp_rank: int,
                    pp_size: int) -> Tuple[int, int]:
-    layers_per_partition = divide(num_hidden_layers, pp_size)
+    """Try to evenly distribute layers across partitions.
+    If the number of layers is not divisible by the number of partitions,
+    the last partition will have the remaining layers.
+    """
+    layers_per_partition = num_hidden_layers // pp_size
     start_layer = pp_rank * layers_per_partition
     end_layer = start_layer + layers_per_partition
 
+    if pp_rank == pp_size - 1:
+        end_layer = num_hidden_layers
+
     return (start_layer, end_layer)
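
With the divisibility check removed, get_pp_indices gives every stage floor(num_hidden_layers / pp_size) layers and folds the remainder into the last stage. A self-contained sketch of that behavior, using a made-up 26-layer model split across 3 stages (the function body is copied from the diff above):

from typing import Tuple


def get_pp_indices(num_hidden_layers: int, pp_rank: int,
                   pp_size: int) -> Tuple[int, int]:
    # Same logic as the patched vllm/distributed/utils.py above.
    layers_per_partition = num_hidden_layers // pp_size
    start_layer = pp_rank * layers_per_partition
    end_layer = start_layer + layers_per_partition
    if pp_rank == pp_size - 1:
        end_layer = num_hidden_layers
    return (start_layer, end_layer)


num_hidden_layers, pp_size = 26, 3  # 26 % 3 != 0, previously rejected
bounds = [get_pp_indices(num_hidden_layers, r, pp_size) for r in range(pp_size)]
print(bounds)  # [(0, 8), (8, 16), (16, 26)] -> the last stage holds 10 layers

# Sanity check: the stages are contiguous and cover every layer exactly once.
assert bounds[0][0] == 0 and bounds[-1][1] == num_hidden_layers
assert all(bounds[i][1] == bounds[i + 1][0] for i in range(pp_size - 1))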

vllm/worker/openvino_worker.py

Lines changed: 1 addition & 0 deletions
@@ -154,6 +154,7 @@ def __init__(
     ) -> None:
         self.model_config = model_config
         self.parallel_config = parallel_config
+        self.parallel_config.rank = rank
         self.scheduler_config = scheduler_config
         self.device_config = device_config
         self.cache_config = cache_config

vllm/worker/tpu_worker.py

Lines changed: 1 addition & 0 deletions
@@ -39,6 +39,7 @@ def __init__(
     ) -> None:
         self.model_config = model_config
         self.parallel_config = parallel_config
+        self.parallel_config.rank = rank
         self.scheduler_config = scheduler_config
         self.device_config = device_config
         self.cache_config = cache_config

vllm/worker/worker.py

Lines changed: 1 addition & 0 deletions
@@ -50,6 +50,7 @@ def __init__(
     ) -> None:
         self.model_config = model_config
         self.parallel_config = parallel_config
+        self.parallel_config.rank = rank
         self.scheduler_config = scheduler_config
         self.device_config = device_config
         self.cache_config = cache_config

vllm/worker/xpu_worker.py

Lines changed: 1 addition & 0 deletions
@@ -54,6 +54,7 @@ def __init__(
 
         self.model_config = model_config
         self.parallel_config = parallel_config
+        self.parallel_config.rank = rank
         self.scheduler_config = scheduler_config
         self.device_config = device_config
         self.cache_config = cache_config
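
Each worker backend (CUDA, OpenVINO, TPU, XPU) now writes its own rank into parallel_config during __init__, so later calls to get_num_layers see the correct rank; ParallelConfig itself defaults rank to 0 so code paths that never set it still behave as before. A minimal, hypothetical sketch of the pattern (the Stub* classes are stand-ins for illustration, not vLLM classes):

from dataclasses import dataclass


@dataclass
class StubParallelConfig:
    pipeline_parallel_size: int
    tensor_parallel_size: int
    rank: int = 0  # mirrors the new default set in ParallelConfig.__init__


class StubWorker:
    def __init__(self, parallel_config: StubParallelConfig, rank: int) -> None:
        self.parallel_config = parallel_config
        # The one-line change repeated across the worker backends:
        self.parallel_config.rank = rank


cfg = StubParallelConfig(pipeline_parallel_size=4, tensor_parallel_size=1)
worker = StubWorker(cfg, rank=3)
print(worker.parallel_config.rank)  # 3 -> this worker owns the last pipeline stage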
