[DP] support torchrun external launcher with Data Parallelism #24899
Merged
Changes from all commits (13 commits):
1b3e596 luccafong: support torchrun dp
5ddc0f4 luccafong: resolve torch compile piecewise + dp hanging
9977530 luccafong: ensure all dp ranks step when there are remaining requests on other dp…
cedff2f luccafong: add ci tests and safe cleanup
4fe81d7 luccafong: address comments and safer cleanup
bf70583 luccafong: fix lint
39f66e5 zhuohan123: minor fix
04a487d luccafong: cleanup
f710310 luccafong: Merge branch 'main' into torchrun_dp
5106a65 luccafong: deprecate v0 test
2d88de3 luccafong: Merge branch 'main' into torchrun_dp
84de3f1 zhuohan123: Merge branch 'main' into torchrun_dp
b4ab007 luccafong: Merge branch 'main' into torchrun_dp
@@ -0,0 +1,81 @@

````python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Experimental support for data-parallel inference with torchrun.
Note that data load balancing and distribution are done outside the vLLM
engine; no internal load balancing is supported in external_launcher mode.
"""

from vllm import LLM, SamplingParams

# Create prompts, the same across all ranks
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
] * 50

# Create sampling parameters, the same across all ranks
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Use `distributed_executor_backend="external_launcher"` so that
# this LLM engine/instance only creates one worker.
# It is important to set an explicit seed to make sure that
# all ranks have the same random seed, so that sampling is
# deterministic across ranks.
llm = LLM(
    model="microsoft/Phi-mini-MoE-instruct",
    tensor_parallel_size=1,
    data_parallel_size=2,
    pipeline_parallel_size=1,
    enable_expert_parallel=False,
    distributed_executor_backend="external_launcher",
    max_model_len=4096,
    gpu_memory_utilization=0.6,
    seed=1,
)

dp_rank = llm.llm_engine.vllm_config.parallel_config.data_parallel_rank
dp_size = llm.llm_engine.vllm_config.parallel_config.data_parallel_size

# Shard the prompts across DP ranks, prefixing each with its original index
prompts = [
    f"{idx}.{prompt}" for idx, prompt in enumerate(prompts) if idx % dp_size == dp_rank
]

outputs = llm.generate(prompts, sampling_params)

# Each rank prints the outputs for its own shard of the prompts
print("-" * 50)
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}\n")
    print("-" * 50)
"""
Further tips:

1. To communicate control messages across all ranks, use the CPU group,
   a PyTorch ProcessGroup with the GLOO backend.

```python
from vllm.distributed.parallel_state import get_world_group
cpu_group = get_world_group().cpu_group
torch_rank = dist.get_rank(group=cpu_group)
if torch_rank == 0:
    # do something for rank 0, e.g. saving the results to disk.
```

2. To communicate data across all ranks, use the model's device group,
   a PyTorch ProcessGroup with the NCCL backend.

```python
from vllm.distributed.parallel_state import get_world_group
device_group = get_world_group().device_group
```

3. To access the model directly in every rank, use the following code:

```python
llm.llm_engine.model_executor.driver_worker.worker.model_runner.model
```
"""
````
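The two ProcessGroup tips in the closing docstring can be combined into a small post-processing step once `llm.generate` returns. The sketch below is illustrative only and is not part of this diff: it assumes it is appended to the example above (launched once per rank, e.g. `torchrun --nproc-per-node=2 <path-to-the-example>`), and the helper names `local_results`, `gathered`, `merged`, and `counter` are made up for illustration.

```python
# Illustrative sketch (not part of this PR): post-process per-rank results
# using the CPU (GLOO) group and the device (NCCL) group described in the
# "Further tips" docstring. Assumes the example above has already run, so
# `outputs` exists and vLLM has initialized torch.distributed.
import torch
import torch.distributed as dist

from vllm.distributed.parallel_state import get_world_group

# 1. Control messages over the CPU group: gather every rank's results on
#    rank 0, e.g. to write a single results file.
cpu_group = get_world_group().cpu_group
rank = dist.get_rank(group=cpu_group)
world_size = dist.get_world_size(group=cpu_group)

local_results = [(o.prompt, o.outputs[0].text) for o in outputs]
gathered = [None] * world_size if rank == 0 else None
dist.gather_object(local_results, gathered, dst=0, group=cpu_group)
if rank == 0:
    merged = [item for per_rank in gathered for item in per_rank]
    print(f"rank 0 collected {len(merged)} generations from {world_size} ranks")

# 2. Data over the device group: e.g. sum a per-rank counter on the GPU.
device_group = get_world_group().device_group
counter = torch.tensor([len(outputs)], device="cuda", dtype=torch.int64)
dist.all_reduce(counter, op=dist.ReduceOp.SUM, group=device_group)
print(f"total generations across all ranks: {counter.item()}")
```

Object collectives such as `gather_object` run over the GLOO-backed CPU group, while tensor collectives such as `all_reduce` run over the NCCL-backed device group, matching the split described in the tips.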
@@ -0,0 +1,81 @@

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# unit test for `examples/offline_inference/torchrun_example.py`
import os
import random

import torch.distributed as dist

from vllm import LLM, SamplingParams
from vllm.distributed.parallel_state import get_tp_group, get_world_group

dist.init_process_group(backend="gloo")

# Create prompts
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
] * 10
dp_size = int(os.getenv("DP_SIZE", "1"))
dp_rank = int(os.getenv("DP_RANK", "0"))

if dp_size > 1:
    # distribute the prompts across the data parallel ranks
    prompts = [
        prompt for idx, prompt in enumerate(prompts)
        if idx % dp_size == dp_rank
    ]

sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# set different `gpu_memory_utilization` and `swap_space` for different ranks,
# to test if all ranks agree on the same kv cache configuration.
llm = LLM(model="microsoft/Phi-mini-MoE-instruct",
          tensor_parallel_size=int(os.getenv("TP_SIZE", "1")),
          pipeline_parallel_size=int(os.getenv("PP_SIZE", "1")),
          enable_expert_parallel=int(os.getenv("ENABLE_EP", "0")) == 1,
          distributed_executor_backend="external_launcher",
          gpu_memory_utilization=random.uniform(0.7, 0.9),
          swap_space=random.randint(1, 4),
          seed=0)

outputs = llm.generate(prompts, sampling_params)

group = get_world_group() if dp_size == 1 else get_tp_group()
cpu_group = group.cpu_group
group_rank = dist.get_rank(group=cpu_group)


def test_consistent_across_ranks(obj):
    if group_rank == 0:
        dist.broadcast_object_list([obj], src=group.ranks[0], group=cpu_group)
    else:
        container = [None]
        dist.broadcast_object_list(container,
                                   src=group.ranks[0],
                                   group=cpu_group)
        assert container[0] == obj


test_consistent_across_ranks(
    llm.llm_engine.vllm_config.cache_config.num_cpu_blocks)
test_consistent_across_ranks(
    llm.llm_engine.vllm_config.cache_config.num_gpu_blocks)

# make sure we can access the model parameters from the calling process
# of the `LLM` instance.
params = list(llm.llm_engine.model_executor.driver_worker.worker.model_runner.
              model.parameters())
test_consistent_across_ranks(len(params))

# all ranks in the same group should have the same outputs
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    test_consistent_across_ranks(prompt)
    test_consistent_across_ranks(generated_text)
    print(f"Rank {group_rank}, Prompt: {prompt!r}, "
          f"Generated text: {generated_text!r}")
```
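The test reads `DP_SIZE`, `DP_RANK`, `TP_SIZE`, and `PP_SIZE` from the environment, but this diff does not show how the launcher sets `DP_RANK` for each process. Below is a minimal sketch of one plausible wrapper, written under the assumption that DP is the outermost rank dimension (consecutive blocks of TP×PP torchrun ranks form one DP replica); the actual CI invocation may differ.

```python
# Hypothetical launcher-side helper (assumption: DP is the outermost rank
# dimension; the real CI invocation is not shown in this PR). torchrun sets
# RANK and WORLD_SIZE for every process; DP_SIZE/TP_SIZE/PP_SIZE come from
# the user. DP_RANK is then derived per process before the test imports vLLM.
import os

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
dp_size = int(os.getenv("DP_SIZE", "1"))
tp_size = int(os.getenv("TP_SIZE", "1"))
pp_size = int(os.getenv("PP_SIZE", "1"))

# the torchrun world must cover every (dp, tp, pp) combination exactly once
assert world_size == dp_size * tp_size * pp_size

# every process inside the same DP replica gets the same DP_RANK
os.environ["DP_RANK"] = str(rank // (tp_size * pp_size))
```

With that layout, `torchrun --nproc-per-node=4` with `DP_SIZE=2` and `TP_SIZE=2` would give ranks 0 and 1 `DP_RANK=0`, and ranks 2 and 3 `DP_RANK=1`.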