
Commit 569485c

some working version
1 parent 521eeab commit 569485c

File tree

  examples/offline_inference/torchrun_dp_example.py
  vllm/distributed/parallel_state.py
  vllm/v1/worker/gpu_model_runner.py

3 files changed: +40 -5 lines changed

examples/offline_inference/torchrun_dp_example.py

Lines changed: 9 additions & 4 deletions

@@ -6,8 +6,9 @@
 no internal lb supported in external_launcher mode.
 """
 
-from vllm import LLM, SamplingParams
+from vllm.distributed import cleanup_dist_env_and_memory
 
+from vllm import LLM, SamplingParams
 # Create prompts, the same across all ranks
 prompts = [
     "Hello, my name is",
@@ -26,14 +27,17 @@
 # deterministic across ranks.
 llm = LLM(
     model="/data/local/models/oss/qwen1.5_2.7B_moe_chat",
-    tensor_parallel_size=2,
-    data_parallel_size=4,
+    tensor_parallel_size=1,
+    data_parallel_size=2,
     pipeline_parallel_size=1,
     enable_expert_parallel=True,
     distributed_executor_backend="external_launcher",
     max_model_len=32768,
+    compilation_config={
+        "cudagraph_mode": "FULL",
+    },
     # FIXME: with torch.compile, the torchrun processes do not exit properly
-    enforce_eager=True,
+    # enforce_eager=True,
     seed=1,
 )
 
@@ -55,6 +59,7 @@
     print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}\n")
     print("-" * 50)
 
+cleanup_dist_env_and_memory()
 """
 Further tips:
 
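With these settings (tensor_parallel_size=1, data_parallel_size=2, external_launcher backend), the example expects one process per data-parallel rank, launched via torchrun; the explicit cleanup_dist_env_and_memory() call at the end is what exercises the new teardown path in parallel_state.py below. A launch sketch, assuming a single node with two GPUs and that the hard-coded local model path exists on your machine:

    torchrun --nproc-per-node=2 examples/offline_inference/torchrun_dp_example.py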

vllm/distributed/parallel_state.py

Lines changed: 24 additions & 0 deletions

@@ -820,14 +820,32 @@ def recv(self,
         return self.device_communicator.recv(size, dtype, src)
 
     def destroy(self):
+        print(f"fanglu: Destroying device group, {self.unique_name=}")
+        cudagraph_wrapper = getattr(self, "model", None)
+        if cudagraph_wrapper is not None:
+            print(f"Clean up cudagraph keys")
+            for key in cudagraph_wrapper.concrete_cudagraph_entries:
+                del cudagraph_wrapper.concrete_cudagraph_entries[key].cudagraph
+            torch.cuda.empty_cache()
+            gc.collect()
+            # torch._dynamo.reset_code_caches()
+            # from torch._inductor.cudagraph_trees import reset_cudagraph_trees
+            # reset_cudagraph_trees()
+            # torch._dynamo.reset()
+        print("fanglu: Reset torch._dynamo done")
         if hasattr(self, "device_group"):
             torch.distributed.destroy_process_group(self.device_group)
+            print("fanglu: Destroying device group done")
             del self.device_group
         if hasattr(self, "cpu_group"):
+            print("fanglu: Destroying cpu group")
             torch.distributed.destroy_process_group(self.cpu_group)
+            print("fanglu: Destroying cpu group done")
             del self.cpu_group
         if self.device_communicator is not None:
+            print("fanglu: Destroying device communicator")
             self.device_communicator.destroy()
+            print("fanglu: Destroying device communicator done")
         if self.mq_broadcaster is not None:
             self.mq_broadcaster = None
 
@@ -1317,12 +1335,18 @@ def destroy_distributed_environment():
 
 
 def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
+    print("fanglu: Cleaning up dist env and memory")
+    torch._dynamo.reset()
+    print("fanglu: Reset torch._dynamo done")
     destroy_model_parallel()
+    print("fanglu: Destroy model parallel done")
     destroy_distributed_environment()
+    print("fanglu: Destroy dist env done")
     if shutdown_ray:
         import ray  # Lazy import Ray
         ray.shutdown()
     gc.collect()
+    print("GC done")
     from vllm.platforms import current_platform
     empty_cache = current_platform.empty_cache
     if empty_cache is not None:
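The new teardown logic frees any CUDA graphs that were stashed on the group (see the gpu_model_runner.py change below) before the NCCL/Gloo process groups are destroyed, and cleanup_dist_env_and_memory() now resets torch._dynamo first; per the FIXME in the example, the goal is to let the torchrun processes exit cleanly when full CUDA graphs are enabled. A minimal sketch of that ordering using only plain PyTorch calls (not vLLM's GroupCoordinator); the captured_graphs dict is a hypothetical stand-in for the wrapper's concrete_cudagraph_entries:

    import gc
    import torch
    import torch.distributed as dist

    def release_graphs_then_teardown(captured_graphs: dict) -> None:
        # Drop references to the captured torch.cuda.CUDAGraph objects first,
        # so their memory pools can actually be reclaimed.
        for key in list(captured_graphs):
            del captured_graphs[key]
        gc.collect()
        torch.cuda.empty_cache()
        # Clear torch.compile caches, mirroring cleanup_dist_env_and_memory().
        torch._dynamo.reset()
        # Only then tear down the distributed process group.
        if dist.is_initialized():
            dist.destroy_process_group()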

vllm/v1/worker/gpu_model_runner.py

Lines changed: 7 additions & 1 deletion

@@ -30,7 +30,7 @@
     has_kv_transfer_group)
 from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
 from vllm.distributed.parallel_state import (
-    get_pp_group, get_tp_group, graph_capture, is_global_first_rank,
+    get_dp_group, get_pp_group, get_tp_group, graph_capture, is_global_first_rank,
     prepare_communication_buffer_for_model)
 from vllm.forward_context import (BatchDescriptor, DPMetadata,
                                   set_forward_context)
@@ -2442,6 +2442,7 @@ def load_model(self, eep_scale_up: bool = False) -> None:
             self.model = CUDAGraphWrapper(self.model,
                                           self.vllm_config,
                                           runtime_mode=CUDAGraphMode.FULL)
+            setattr(get_dp_group(), "model", self.model)
 
     def reload_weights(self) -> None:
         assert getattr(self, "model", None) is not None, \
@@ -3093,6 +3094,7 @@ def freeze_gc():
         set_cudagraph_capturing_enabled(True)
         with freeze_gc(), graph_capture(device=self.device):
             cudagraph_mode = self.compilation_config.cudagraph_mode
+            print(f"{cudagraph_mode}")
             if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
                 cudagraph_runtime_mode = cudagraph_mode.mixed_mode()
 
@@ -3306,6 +3308,7 @@ def initialize_cudagraph_capture(self) -> None:
         self.cudagraph_dispatcher.initialize_cudagraph_keys(
             self.compilation_config.cudagraph_mode,
             self.uniform_decode_query_len)
+        setattr(get_dp_group(), "dispatcher", self.cudagraph_dispatcher)
 
     def calculate_reorder_batch_threshold(self) -> None:
         """
@@ -3757,3 +3760,6 @@ def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]:
         self.transfer_event.record()
         self.transfer_event.synchronize()
         return pinned.tolist()
+
+    def __del__(self):
+        print("GPU Model Runner is called.")
