Add data parallel support #80
Open

wuxun-zhang wants to merge 12 commits into vllm-project:main from wuxun-zhang:wuxun/v1-dp-attention
+544 −33
Commits (12):
b4b908a  Support Data Parallel (wuxun-zhang)
90532c8  fix (wuxun-zhang)
cdcc3cc  fix dummy run (wuxun-zhang)
5620188  fix lazy hang (wuxun-zhang)
7a029da  add dp padding for prefill bs/seqlen/blocks (wuxun-zhang)
5d44aee  add dp into ci test (wuxun-zhang)
37c4485  use reduce_scatter instead of all_reduce (wuxun-zhang)
97a84b0  fix dummy prefill batch for eager (wuxun-zhang)
213f54b  fix rebase error (wuxun-zhang)
ea44413  fix ci error (wuxun-zhang)
2a85ac5  Merge branch 'main' into wuxun/v1-dp-attention (kzawora-intel)
ce53143  Merge branch 'main' into wuxun/v1-dp-attention (adobrzyn)
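The padding and reduce_scatter commits above concern keeping DP ranks in lockstep: when ranks schedule different numbers of tokens, each rank pads its batch so that collective operations (e.g. for expert parallelism) see matching shapes on every rank. A minimal sketch of that idea, not taken from this PR's code, assuming a torch.distributed process group dp_group that spans the DP ranks:

    import torch
    import torch.distributed as dist

    def dp_padded_num_tokens(local_num_tokens: int, dp_group) -> int:
        # Agree on the maximum token count across DP ranks so every rank can
        # pad its batch to the same size before entering collective ops.
        t = torch.tensor([local_num_tokens], dtype=torch.int64)
        dist.all_reduce(t, op=dist.ReduceOp.MAX, group=dp_group)
        return int(t.item())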
@@ -0,0 +1,254 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Usage:
Single node:
    python examples/offline_inference/data_parallel.py \
            --model="ibm-research/PowerMoE-3b" \
            --dp-size=2 \
            --tp-size=2

Multi-node:
    Node 0 (assume the node has ip of 10.99.48.128):
        python examples/offline_inference/data_parallel.py \
                --model="ibm-research/PowerMoE-3b" \
                --dp-size=2 \
                --tp-size=2 \
                --node-size=2 \
                --node-rank=0 \
                --master-addr=10.99.48.128 \
                --master-port=13345
    Node 1:
        python examples/offline_inference/data_parallel.py \
                --model="ibm-research/PowerMoE-3b" \
                --dp-size=2 \
                --tp-size=2 \
                --node-size=2 \
                --node-rank=1 \
                --master-addr=10.99.48.128 \
                --master-port=13345
"""

import os
from time import sleep

import torch

from vllm import LLM, SamplingParams
from vllm.utils import get_open_port

def parse_args():
    import argparse

    parser = argparse.ArgumentParser(description="Data Parallel Inference")
    parser.add_argument(
        "--model",
        type=str,
        default="ibm-research/PowerMoE-3b",
        help="Model name or path",
    )
    parser.add_argument(
        "--dp-size", type=int, default=2, help="Data parallel size"
    )
    parser.add_argument(
        "--tp-size", type=int, default=2, help="Tensor parallel size"
    )
    parser.add_argument(
        "--node-size", type=int, default=1, help="Total number of nodes"
    )
    parser.add_argument(
        "--node-rank", type=int, default=0, help="Rank of the current node"
    )
    parser.add_argument(
        "--master-addr", type=str, default="", help="Master node IP address"
    )
    parser.add_argument(
        "--master-port", type=int, default=0, help="Master node port"
    )
    parser.add_argument(
        "--enforce-eager",
        action="store_true",
        help="Enforce eager mode execution.",
    )
    parser.add_argument(
        "--trust-remote-code", action="store_true", help="Trust remote code."
    )
    parser.add_argument(
        "--max-num-seqs",
        type=int,
        default=64,
        help=(
            "Maximum number of sequences to be processed in a single iteration."
        ),
    )
    parser.add_argument(
        "--gpu-memory-utilization",
        type=float,
        default=0.8,
        help=("Fraction of GPU memory vLLM is allowed to allocate (0.0, 1.0]."),
    )
    parser.add_argument(
        "--random-input",
        action="store_true",
        help="Use random generated input tokens.",
    )
    return parser.parse_args()

def generate_random_token_ids(repeat=1) -> list[list[int]]:
    """
    For testing different sequence lengths in the data parallel scenario.
    """
    candidate_lens = [130, 560]
    prompts = []
    for num_tokens in candidate_lens:
        tokens = torch.randint(
            low=0, high=10000, size=(num_tokens,), dtype=torch.int32
        )
        for _ in range(repeat):
            prompts.append(tokens.tolist())
    return prompts

def main(
    model,
    dp_size,
    local_dp_rank,
    global_dp_rank,
    dp_master_ip,
    dp_master_port,
    GPUs_per_dp_rank,
    enforce_eager,
    trust_remote_code,
    max_num_seqs,
    gpu_memory_utilization,
):
    os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
    os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
    os.environ["VLLM_DP_SIZE"] = str(dp_size)
    os.environ["VLLM_DP_MASTER_IP"] = dp_master_ip
    os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port)
    # CUDA_VISIBLE_DEVICES for each DP rank is set automatically inside the
    # engine processes.

    # Sample prompts.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ] * 40

    # generate prompts with different lengths to demonstrate DP-aware padding.
    if args.random_input:
        prompts = generate_random_token_ids(40)

    # with DP, each rank should process different prompts.
    # usually all the DP ranks process a full dataset,
    # and each rank processes a different part of the dataset.
    floor = len(prompts) // dp_size
    remainder = len(prompts) % dp_size

    # Distribute prompts into even groups.
    def start(rank):
        return rank * floor + min(rank, remainder)
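    # Worked example (illustration only): with 10 prompts and dp_size=4,
    # floor=2 and remainder=2, so start() yields offsets [0, 3, 6, 8, 10]
    # and the four ranks receive 3, 3, 2, and 2 prompts respectively.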
    prompts = prompts[start(global_dp_rank) : start(global_dp_rank + 1)]
    if len(prompts) == 0:
        # if any rank has no prompts to process,
        # we need to set a placeholder prompt
        prompts = ["Placeholder"]
    print(f"DP rank {global_dp_rank} needs to process {len(prompts)} prompts")
    # Create a sampling params object.
    # since we are doing data parallel, every rank can have different
    # sampling params. here we set different max_tokens for different
    # ranks for demonstration.
    sampling_params = SamplingParams(
        temperature=0.8, top_p=0.95, max_tokens=[16, 20][global_dp_rank % 2]
    )

    # Create an LLM.
    llm = LLM(
        model=model,
        tensor_parallel_size=GPUs_per_dp_rank,
        enforce_eager=enforce_eager,
        enable_expert_parallel=True,
        trust_remote_code=trust_remote_code,
        max_num_seqs=max_num_seqs,
        gpu_memory_utilization=gpu_memory_utilization,
    )
    if not args.random_input:
        outputs = llm.generate(prompts, sampling_params)
    else:
        outputs = llm.generate(None, sampling_params, prompts)
    # Print the outputs.
    for i, output in enumerate(outputs):
        if i >= 5:
            # print only 5 outputs
            break
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(
            f"DP rank {global_dp_rank}, Prompt: {prompt!r}, "
            f"Generated text: {generated_text!r}"
        )

    # Give engines time to pause their processing loops before exiting.
    sleep(1)


if __name__ == "__main__":
    args = parse_args()

    dp_size = args.dp_size
    tp_size = args.tp_size
    node_size = args.node_size
    node_rank = args.node_rank

    if node_size == 1:
        dp_master_ip = "127.0.0.1"
        dp_master_port = get_open_port()
    else:
        dp_master_ip = args.master_addr
        dp_master_port = args.master_port

    assert dp_size % node_size == 0, "dp_size should be divisible by node_size"
    dp_per_node = dp_size // node_size

    from multiprocessing import Process

    procs = []
    for local_dp_rank, global_dp_rank in enumerate(
        range(node_rank * dp_per_node, (node_rank + 1) * dp_per_node)
    ):
        proc = Process(
            target=main,
            args=(
                args.model,
                dp_size,
                local_dp_rank,
                global_dp_rank,
                dp_master_ip,
                dp_master_port,
                tp_size,
                args.enforce_eager,
                args.trust_remote_code,
                args.max_num_seqs,
                args.gpu_memory_utilization,
            ),
        )
        proc.start()
        procs.append(proc)
    exit_code = 0
    for proc in procs:
        proc.join(timeout=300)
        if proc.exitcode is None:
            print(
                f"Killing process {proc.pid} that didn't stop within 5 minutes."
            )
            proc.kill()
            exit_code = 1
        elif proc.exitcode:
            exit_code = proc.exitcode

    exit(exit_code)
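For reference, the DP-aware padding that this example demonstrates can be exercised by feeding each rank prompts of different lengths via the script's --random-input flag, e.g. on a single node (the same invocation as in the module docstring, plus the flag):

    python examples/offline_inference/data_parallel.py \
            --model="ibm-research/PowerMoE-3b" \
            --dp-size=2 \
            --tp-size=2 \
            --random-input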