
Commit 5ad7ff8

Support Data Parallel
- enable profile run

Signed-off-by: Wuxun Zhang <[email protected]>
1 parent 0492c55 commit 5ad7ff8

File tree

5 files changed: +441 -38 lines


examples/data_parallel.py

Lines changed: 254 additions & 0 deletions
@@ -0,0 +1,254 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Usage:
Single node:
    python examples/offline_inference/data_parallel.py \
            --model="ibm-research/PowerMoE-3b" \
            --dp-size=2 \
            --tp-size=2

Multi-node:
    Node 0 (assume the node has ip of 10.99.48.128):
            python examples/offline_inference/data_parallel.py \
                    --model="ibm-research/PowerMoE-3b" \
                    --dp-size=2 \
                    --tp-size=2 \
                    --node-size=2 \
                    --node-rank=0 \
                    --master-addr=10.99.48.128 \
                    --master-port=13345
    Node 1:
            python examples/offline_inference/data_parallel.py \
                    --model="ibm-research/PowerMoE-3b" \
                    --dp-size=2 \
                    --tp-size=2 \
                    --node-size=2 \
                    --node-rank=1 \
                    --master-addr=10.99.48.128 \
                    --master-port=13345
"""

import os
from time import sleep
import torch

from vllm import LLM, SamplingParams
from vllm.utils import get_open_port


def parse_args():
    import argparse

    parser = argparse.ArgumentParser(description="Data Parallel Inference")
    parser.add_argument(
        "--model",
        type=str,
        default="ibm-research/PowerMoE-3b",
        help="Model name or path",
    )
    parser.add_argument(
        "--dp-size", type=int, default=2, help="Data parallel size"
    )
    parser.add_argument(
        "--tp-size", type=int, default=2, help="Tensor parallel size"
    )
    parser.add_argument(
        "--node-size", type=int, default=1, help="Total number of nodes"
    )
    parser.add_argument(
        "--node-rank", type=int, default=0, help="Rank of the current node"
    )
    parser.add_argument(
        "--master-addr", type=str, default="", help="Master node IP address"
    )
    parser.add_argument(
        "--master-port", type=int, default=0, help="Master node port"
    )
    parser.add_argument(
        "--enforce-eager",
        action="store_true",
        help="Enforce eager mode execution.",
    )
    parser.add_argument(
        "--trust-remote-code", action="store_true", help="Trust remote code."
    )
    parser.add_argument(
        "--max-num-seqs",
        type=int,
        default=64,
        help=(
            "Maximum number of sequences to be processed in a single iteration."
        ),
    )
    parser.add_argument(
        "--gpu-memory-utilization",
        type=float,
        default=0.8,
        help=("Fraction of GPU memory vLLM is allowed to allocate (0.0, 1.0]."),
    )
    parser.add_argument(
        "--random-input",
        action="store_true",
        help="Use randomly generated input tokens.",
    )
    return parser.parse_args()


def generate_random_token_ids(repeat=1) -> list[list[int]]:
    """
    For testing different sequence lengths in the data parallel scenario.
    """
    candidate_lens = [130, 560]
    prompts = []
    for num_tokens in candidate_lens:
        tokens = torch.randint(
            low=0, high=10000, size=(num_tokens,), dtype=torch.int32
        )
        for _ in range(repeat):
            prompts.append(tokens.tolist())
    return prompts


def main(
    model,
    dp_size,
    local_dp_rank,
    global_dp_rank,
    dp_master_ip,
    dp_master_port,
    GPUs_per_dp_rank,
    enforce_eager,
    trust_remote_code,
    max_num_seqs,
    gpu_memory_utilization,
):
    os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
    os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
    os.environ["VLLM_DP_SIZE"] = str(dp_size)
    os.environ["VLLM_DP_MASTER_IP"] = dp_master_ip
    os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port)

    # CUDA_VISIBLE_DEVICES for each DP rank is set automatically inside the
    # engine processes.

    # Sample prompts.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ] * 40

    # Generate prompts with different lengths to demonstrate DP-aware padding.
    if args.random_input:
        prompts = generate_random_token_ids(40)

    # With DP, each rank should process different prompts.
    # Usually all the DP ranks process a full dataset,
    # and each rank processes a different part of the dataset.
    floor = len(prompts) // dp_size
    remainder = len(prompts) % dp_size

    # Distribute prompts into even groups.
    def start(rank):
        return rank * floor + min(rank, remainder)

    prompts = prompts[start(global_dp_rank) : start(global_dp_rank + 1)]
    if len(prompts) == 0:
        # If any rank has no prompts to process,
        # we need to set a placeholder prompt.
        prompts = ["Placeholder"]
    print(f"DP rank {global_dp_rank} needs to process {len(prompts)} prompts")
    # Create a sampling params object.
    # Since we are doing data parallel, every rank can have different
    # sampling params. Here we set different max_tokens for different
    # ranks for demonstration.
    sampling_params = SamplingParams(
        temperature=0.8, top_p=0.95, max_tokens=[16, 20][global_dp_rank % 2]
    )

    # Create an LLM.
    llm = LLM(
        model=model,
        tensor_parallel_size=GPUs_per_dp_rank,
        enforce_eager=enforce_eager,
        enable_expert_parallel=True,
        trust_remote_code=trust_remote_code,
        max_num_seqs=max_num_seqs,
        gpu_memory_utilization=gpu_memory_utilization,
    )
    if not args.random_input:
        outputs = llm.generate(prompts, sampling_params)
    else:
        outputs = llm.generate(None, sampling_params, prompts)
    # Print the outputs.
    for i, output in enumerate(outputs):
        if i >= 5:
            # Print only 5 outputs.
            break
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(
            f"DP rank {global_dp_rank}, Prompt: {prompt!r}, "
            f"Generated text: {generated_text!r}"
        )

    # Give engines time to pause their processing loops before exiting.
    sleep(1)


if __name__ == "__main__":
    args = parse_args()

    dp_size = args.dp_size
    tp_size = args.tp_size
    node_size = args.node_size
    node_rank = args.node_rank

    if node_size == 1:
        dp_master_ip = "127.0.0.1"
        dp_master_port = get_open_port()
    else:
        dp_master_ip = args.master_addr
        dp_master_port = args.master_port

    assert dp_size % node_size == 0, "dp_size should be divisible by node_size"
    dp_per_node = dp_size // node_size

    from multiprocessing import Process

    procs = []
    for local_dp_rank, global_dp_rank in enumerate(
        range(node_rank * dp_per_node, (node_rank + 1) * dp_per_node)
    ):
        proc = Process(
            target=main,
            args=(
                args.model,
                dp_size,
                local_dp_rank,
                global_dp_rank,
                dp_master_ip,
                dp_master_port,
                tp_size,
                args.enforce_eager,
                args.trust_remote_code,
                args.max_num_seqs,
                args.gpu_memory_utilization,
            ),
        )
        proc.start()
        procs.append(proc)
    exit_code = 0
    for proc in procs:
        proc.join(timeout=300)
        if proc.exitcode is None:
            print(
                f"Killing process {proc.pid} that didn't stop within 5 minutes."
            )
            proc.kill()
            exit_code = 1
        elif proc.exitcode:
            exit_code = proc.exitcode

    exit(exit_code)
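
Note on the prompt split above: the floor/remainder arithmetic in main() gives each DP rank a contiguous, near-even slice of the prompt list, with the first `remainder` ranks receiving one extra prompt. A minimal standalone sketch of just that slicing logic (the helper name split_for_rank is chosen here for illustration and is not part of the commit):

def split_for_rank(prompts: list, dp_size: int, rank: int) -> list:
    # Same arithmetic as main(): contiguous, near-even slices per DP rank.
    floor = len(prompts) // dp_size
    remainder = len(prompts) % dp_size

    def start(r: int) -> int:
        return r * floor + min(r, remainder)

    return prompts[start(rank):start(rank + 1)]


# 10 prompts over 4 DP ranks -> slice sizes 3, 3, 2, 2.
sizes = [len(split_for_rank(list(range(10)), 4, r)) for r in range(4)]
assert sizes == [3, 3, 2, 2]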

vllm_gaudi/distributed/device_communicators/hpu_communicator.py

Lines changed: 45 additions & 0 deletions
@@ -5,10 +5,30 @@
 
 from vllm.distributed.device_communicators.base_device_communicator \
     import DeviceCommunicatorBase
+from vllm.distributed.parallel_state import get_dp_group
+from vllm.forward_context import get_forward_context
 
 import habana_frameworks.torch as htorch  # noqa: F401
 
 
+def naive_multicast(x: torch.Tensor,
+                    cu_tokens_across_dp_cpu: torch.Tensor) -> torch.Tensor:
+    assert x.dim() == 2, "Input tensor must be 2D"
+    dp_rank = get_dp_group().rank_in_group
+    dp_world_size = get_dp_group().world_size
+    buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
+                         device=x.device,
+                         dtype=x.dtype)
+    start = 0 if dp_rank == 0 else cu_tokens_across_dp_cpu[dp_rank - 1]
+    end = cu_tokens_across_dp_cpu[dp_rank]
+    buffer[start:end, :].copy_(x)
+    for idx in range(dp_world_size):
+        start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
+        end = cu_tokens_across_dp_cpu[idx]
+        get_dp_group().broadcast(buffer[start:end, :], idx)
+    return buffer
+
+
 class HpuCommunicator(DeviceCommunicatorBase):
 
     def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
@@ -41,3 +61,28 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
                                                input_size[dim], ) +
                                               input_size[dim + 1:])
         return output_tensor
+
+    def dispatch(
+            self, hidden_states: torch.Tensor,
+            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        all-gather based dispatch for HPUCommunicator.
+        """
+        cu_tokens_across_dp_cpu = get_forward_context(
+        ).dp_metadata.cu_tokens_across_dp_cpu
+        hidden_states_across_dp = naive_multicast(hidden_states,
+                                                  cu_tokens_across_dp_cpu)
+        router_logits_across_dp = naive_multicast(router_logits,
+                                                  cu_tokens_across_dp_cpu)
+        return hidden_states_across_dp, router_logits_across_dp
+
+    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        dp_rank = get_dp_group().rank_in_group
+        cu_tokens_across_dp_cpu = get_forward_context(
+        ).dp_metadata.cu_tokens_across_dp_cpu
+        start = 0 if dp_rank == 0 else cu_tokens_across_dp_cpu[dp_rank - 1]
+        end = cu_tokens_across_dp_cpu[dp_rank]
+
+        all_hidden_states = get_dp_group().all_reduce(hidden_states)
+        hidden_states = all_hidden_states[start:end, :]
+        return hidden_states
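
Note on the indexing used above: naive_multicast(), dispatch() and combine() all slice the flattened token dimension using the cumulative per-rank token counts taken from the forward context (dp_metadata.cu_tokens_across_dp_cpu). A minimal single-process sketch of that indexing, with illustrative shapes and no collectives (plain PyTorch on CPU, not part of the commit):

import torch

# Suppose the three DP ranks hold 3, 1 and 2 tokens; the cumulative counts
# then delimit each rank's rows in the gathered (token, hidden) buffer.
cu_tokens = torch.tensor([3, 4, 6])
hidden_size = 8
gathered = torch.zeros((int(cu_tokens[-1]), hidden_size))

for rank in range(len(cu_tokens)):
    start = 0 if rank == 0 else cu_tokens[rank - 1]
    end = cu_tokens[rank]
    # In naive_multicast() this slice is filled by a broadcast from `rank`;
    # here it is just tagged so the per-rank row ranges are visible.
    gathered[start:end, :] = rank

# combine() keeps only the local rank's rows after the all-reduce.
local_rank = 1
start = 0 if local_rank == 0 else cu_tokens[local_rank - 1]
local_rows = gathered[start:cu_tokens[local_rank], :]
assert local_rows.shape == (1, hidden_size)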

vllm_gaudi/platform.py

Lines changed: 1 addition & 0 deletions
@@ -31,6 +31,7 @@ class HpuPlatform(Platform):
     supported_quantization: list[str] = [
         "compressed-tensors", "fp8", "inc", "awq_hpu", "gptq_hpu"
     ]
+    simple_compile_backend = "hpu_backend"
 
     @classmethod
     def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
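
Note on the one-line platform change: simple_compile_backend tells vLLM which torch.compile backend string to use for its simple compilation path, and "hpu_backend" is the backend Gaudi registers for torch.compile. A rough, hedged illustration of compiling a module with that backend string on a Gaudi host (TinyMLP is invented for the example; it assumes habana_frameworks is installed and registers the backend on import, as the communicator file above does):

import torch
import habana_frameworks.torch as htorch  # noqa: F401  # registers "hpu_backend"


class TinyMLP(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc = torch.nn.Linear(16, 16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.fc(x))


model = TinyMLP().to("hpu")
# Same backend string the platform stores in simple_compile_backend.
compiled = torch.compile(model, backend="hpu_backend")
out = compiled(torch.randn(4, 16, device="hpu"))
print(out.shape)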
