
Commit 69672f1

[core][distributed] simplify code to support pipeline parallel (#6406)
1 parent 44874a0 commit 69672f1

File tree

5 files changed, +107 −61 lines


.buildkite/test-pipeline.yaml

Lines changed: 1 addition & 3 deletions
@@ -46,9 +46,7 @@ steps:
   fast_check: true
   commands:
   - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl
-  - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_basic_correctness.py
-  - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py
-  - VLLM_ATTENTION_BACKEND=FLASHINFER pytest -v -s basic_correctness/test_basic_correctness.py
+  - pytest -v -s basic_correctness/test_basic_correctness.py
   - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
   - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
   - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py
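Since the attention backend is now parametrized inside test_basic_correctness.py itself (see the next file), the three per-backend invocations of that test collapse into a single pytest command; the chunked-prefill and preemption tests keep their explicit VLLM_ATTENTION_BACKEND settings.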

tests/basic_correctness/test_basic_correctness.py

Lines changed: 8 additions & 3 deletions
@@ -28,10 +28,8 @@ def test_vllm_gc_ed():
     assert weak_llm() is None


-@pytest.mark.skipif(is_hip()
-                    and os.getenv("VLLM_ATTENTION_BACKEND") == "FLASHINFER",
-                    reason="Flashinfer does not support ROCm/HIP.")
 @pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS", "FLASHINFER"])
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [5])
 @pytest.mark.parametrize("enforce_eager", [False, True])
@@ -40,10 +38,17 @@ def test_models(
     vllm_runner,
     example_prompts,
     model: str,
+    backend: str,
     dtype: str,
     max_tokens: int,
     enforce_eager: bool,
 ) -> None:
+
+    if backend == "FLASHINFER" and is_hip():
+        pytest.skip("Flashinfer does not support ROCm/HIP.")
+
+    os.environ["VLLM_ATTENTION_BACKEND"] = backend
+
     with hf_runner(model, dtype=dtype) as hf_model:
         hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

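The hard-coded skipif decorator is replaced by a backend parametrization plus a runtime skip on ROCm, and each case sets VLLM_ATTENTION_BACKEND in os.environ before the models are constructed. To exercise a single backend locally, pytest's -k filter on the parametrize id should still select it, e.g. pytest -v -s basic_correctness/test_basic_correctness.py -k FLASHINFER (the exact test IDs depend on the rest of the parametrize matrix).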
vllm/model_executor/models/gpt2.py

Lines changed: 20 additions & 27 deletions
@@ -27,7 +27,6 @@
 from vllm.config import CacheConfig
 from vllm.distributed.parallel_state import (
     get_pp_group, get_tensor_model_parallel_world_size)
-from vllm.distributed.utils import get_pp_indices
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
@@ -42,6 +41,8 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors, SamplerOutput

+from .utils import is_pp_missing_parameter, make_layers
+

 class GPT2Attention(nn.Module):

@@ -183,18 +184,9 @@ def __init__(
         self.embed_dim = config.hidden_size
         self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
-        self.start_layer, self.end_layer = get_pp_indices(
+        self.start_layer, self.end_layer, self.h = make_layers(
             config.num_hidden_layers,
-            get_pp_group().rank_in_group,
-            get_pp_group().world_size)
-        self.h = nn.ModuleList(
-            [nn.Identity() for _ in range(self.start_layer)] + [
-                GPT2Block(config, cache_config, quant_config)
-                for _ in range(self.start_layer, self.end_layer)
-            ] + [
-                nn.Identity()
-                for _ in range(self.end_layer, config.num_hidden_layers)
-            ])
+            lambda: GPT2Block(config, cache_config, quant_config))
         self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

     def forward(
@@ -291,19 +283,20 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 continue
             if not name.startswith("transformer."):
                 name = "transformer." + name
-            try:
-                param = params_dict[name]
-                # The HF's GPT-2 implementation uses Conv1D instead of Linear.
-                # Because of this, we need to transpose the weights.
-                # Note(zhuohan): the logic below might break quantized models.
-                for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
-                    if conv1d_weight_name not in name:
-                        continue
-                    if not name.endswith(".weight"):
-                        continue
-                    loaded_weight = loaded_weight.t()
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-            except KeyError:
+
+            if is_pp_missing_parameter(name, self):
                 continue
+
+            param = params_dict[name]
+            # The HF's GPT-2 implementation uses Conv1D instead of Linear.
+            # Because of this, we need to transpose the weights.
+            # Note(zhuohan): the logic below might break quantized models.
+            for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
+                if conv1d_weight_name not in name:
+                    continue
+                if not name.endswith(".weight"):
+                    continue
+                loaded_weight = loaded_weight.t()
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
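Net effect for GPT-2: the manual nn.Identity padding around GPT2Block construction is replaced by the new make_layers helper, and the try/except KeyError in load_weights becomes an explicit is_pp_missing_parameter check, so weights owned by other pipeline stages are skipped deliberately instead of being silently swallowed along with any genuine missing-key errors.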

vllm/model_executor/models/llama.py

Lines changed: 22 additions & 28 deletions
@@ -29,8 +29,7 @@

 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig, LoRAConfig
-from vllm.distributed import (get_pp_group, get_pp_indices,
-                              get_tensor_model_parallel_rank,
+from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -51,6 +50,7 @@
 from vllm.utils import is_hip, print_warning_once

 from .interfaces import SupportsLoRA
+from .utils import is_pp_missing_parameter, make_layers


 class LlamaMLP(nn.Module):
@@ -262,20 +262,11 @@ def __init__(
             config.hidden_size,
             org_num_embeddings=config.vocab_size,
         )
-        self.start_layer, self.end_layer = get_pp_indices(
+        self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            get_pp_group().rank_in_group,
-            get_pp_group().world_size)
-        self.layers = nn.ModuleList(
-            [nn.Identity() for _ in range(self.start_layer)] + [
-                LlamaDecoderLayer(config=config,
-                                  cache_config=cache_config,
-                                  quant_config=quant_config)
-                for _ in range(self.start_layer, self.end_layer)
-            ] + [
-                nn.Identity()
-                for _ in range(self.end_layer, config.num_hidden_layers)
-            ])
+            lambda: LlamaDecoderLayer(config=config,
+                                      cache_config=cache_config,
+                                      quant_config=quant_config))
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
@@ -455,12 +446,14 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                try:
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(param, loaded_weight, shard_id)
-                except KeyError:
-                    pass
+
+                if is_pp_missing_parameter(name, self):
+                    continue
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+
                 break
             else:
@@ -479,13 +472,14 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                         continue
                 else:
                     name = remapped_kv_scale_name
-                try:
-                    param = params_dict[name]
-                    weight_loader = getattr(param, "weight_loader",
-                                            default_weight_loader)
-                    weight_loader(param, loaded_weight)
-                except KeyError:
-                    pass
+
+                if is_pp_missing_parameter(name, self):
+                    continue
+
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)

     # If this function is called, it should always initialize KV cache scale
     # factors (or else raise an exception). Thus, handled exceptions should
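Llama follows the same pattern as GPT-2 above, with one extra detail: the is_pp_missing_parameter check is also applied after FP8 kv-scale names are remapped, so remapped parameters that live on another pipeline stage are skipped explicitly rather than relying on except KeyError: pass.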

vllm/model_executor/models/utils.py

Lines changed: 56 additions & 0 deletions
@@ -1,3 +1,5 @@
+from typing import Callable, Dict, List, Tuple
+
 import torch

 from vllm.multimodal import BatchedTensors
@@ -39,3 +41,57 @@ def merge_vision_embeddings(input_ids: torch.Tensor,
     inputs_embeds[mask] = torch.cat(vision_embeddings)

     return inputs_embeds
+
+
+class PPMissingLayer(torch.nn.Identity):
+    """
+    A placeholder layer for missing layers in a pipeline parallel model.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+
+
+def make_layers(
+        num_hidden_layers: int, layer_fn: Callable[[], torch.nn.Module]
+) -> Tuple[int, int, torch.nn.ModuleList]:
+    """Make a list of layers with the given layer function, taking
+    pipeline parallelism into account.
+    """
+    from vllm.distributed.parallel_state import get_pp_group
+    from vllm.distributed.utils import get_pp_indices
+    start_layer, end_layer = get_pp_indices(num_hidden_layers,
+                                            get_pp_group().rank_in_group,
+                                            get_pp_group().world_size)
+    modules = torch.nn.ModuleList(
+        [PPMissingLayer() for _ in range(start_layer)] +
+        [layer_fn() for _ in range(start_layer, end_layer)] +
+        [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
+    return start_layer, end_layer, modules
+
+
+# NOTE: don't use lru_cache here because it can prevent garbage collection
+_model_to_pp_missing_layer_names: Dict[int, List[str]] = {}
+
+
+def get_pp_missing_layer_names(model: torch.nn.Module) -> List[str]:
+    """Get the names of the missing layers in a pipeline parallel model."""
+    model_id = id(model)
+    if model_id in _model_to_pp_missing_layer_names:
+        return _model_to_pp_missing_layer_names[model_id]
+
+    missing_layer_names = []
+    for name, module in model.named_modules():
+        if isinstance(module, PPMissingLayer):
+            missing_layer_names.append(name)
+    _model_to_pp_missing_layer_names[model_id] = missing_layer_names
+
+    return missing_layer_names
+
+
+def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
+    """Check if a parameter is missing in a pipeline parallel model."""
+    for missing_layer_name in get_pp_missing_layer_names(model):
+        if name.startswith(missing_layer_name):
+            return True
+    return False
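For reference, a minimal sketch of how these helpers fit together, using a hypothetical ToyModel built by hand so that no distributed process group is needed (in a real model, make_layers would produce the PPMissingLayer padding shown here for a middle pipeline stage):

import torch

from vllm.model_executor.models.utils import (PPMissingLayer,
                                              is_pp_missing_parameter)


class ToyModel(torch.nn.Module):
    """Hypothetical 4-layer model where only layers 1 and 2 live on this stage."""

    def __init__(self):
        super().__init__()
        # Hand-built equivalent of what make_layers(4, ...) would return for
        # start_layer=1, end_layer=3: non-local layers become placeholders.
        self.layers = torch.nn.ModuleList([
            PPMissingLayer(),       # layer 0 owned by an earlier stage
            torch.nn.Linear(8, 8),  # layer 1, local
            torch.nn.Linear(8, 8),  # layer 2, local
            PPMissingLayer(),       # layer 3 owned by a later stage
        ])


model = ToyModel()
# load_weights-style filtering: parameters under a PPMissingLayer are skipped.
print(is_pp_missing_parameter("layers.0.weight", model))  # True
print(is_pp_missing_parameter("layers.1.weight", model))  # False

This is the same check the load_weights loops in gpt2.py and llama.py above use to decide which checkpoint tensors to ignore on each pipeline rank.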
