@@ -6,10 +6,11 @@
 import torch.distributed

 from vllm.distributed.communication_op import (  # noqa
-    graph_capture, tensor_model_parallel_all_reduce)
+    tensor_model_parallel_all_reduce)
 from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
 from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
 from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
+                                             get_world_group, graph_capture,
                                              init_distributed_environment)
 from vllm.utils import update_environment_variables

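The import hunk above moves `graph_capture` from `vllm.distributed.communication_op` to `vllm.distributed.parallel_state`, and pulls in the new `get_world_group` helper used by the constructor calls in the hunks below. A minimal sketch of the updated call site, assuming the zero-argument context-manager form these tests rely on:

```python
from vllm.distributed.parallel_state import graph_capture

# graph_capture() puts vLLM's communicators into a CUDA-graph-safe state for
# the duration of the block; communication_op no longer exports it.
with graph_capture():
    ...  # capture a torch.cuda.CUDAGraph that contains collective ops
```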
@@ -53,7 +54,8 @@ def wrapped_fn(env):

 @worker_fn_wrapper
 def worker_fn():
-    pynccl_comm = PyNcclCommunicator()
+    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
+                                     device=get_world_group().device)
     tensor = torch.ones(16, 1024, 1024,
                         dtype=torch.float32).cuda(pynccl_comm.rank)
     with pynccl_comm.change_state(enable=True):
@@ -129,7 +131,8 @@ def test_pynccl_multiple_allreduce_with_vllm():
 def worker_fn_with_cudagraph():
     with torch.no_grad():
         graph = torch.cuda.CUDAGraph()
-        pynccl_comm = PyNcclCommunicator()
+        pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
+                                         device=get_world_group().device)
         # run something in the default stream to initialize torch engine
         a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}')
         torch.cuda.synchronize()
@@ -154,7 +157,8 @@ def test_pynccl_with_cudagraph():

 @worker_fn_wrapper
 def send_recv_worker_fn():
-    pynccl_comm = PyNcclCommunicator()
+    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
+                                     device=get_world_group().device)
     if pynccl_comm.rank == 0:
         tensor = torch.ones(16, 1024, 1024,
                             dtype=torch.float32).cuda(pynccl_comm.rank)
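Taken together, the hunks show the other half of the change: `PyNcclCommunicator` no longer builds itself from ambient global state and instead takes an explicit process group and device. A minimal sketch of the new construction pattern, assuming a distributed environment has already been initialized (as the `worker_fn_wrapper` used by these tests arranges):

```python
import torch

from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.parallel_state import get_world_group

# The communicator is built from an explicit CPU process group plus this
# rank's CUDA device, both obtained from the world group.
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
                                 device=get_world_group().device)

# Same pattern as worker_fn above: allocate on this rank's device, then run
# collectives while the communicator is enabled. The in-place all_reduce call
# is an assumption about the elided body of the `with` block.
tensor = torch.ones(16, 1024, 1024,
                    dtype=torch.float32).cuda(pynccl_comm.rank)
with pynccl_comm.change_state(enable=True):
    pynccl_comm.all_reduce(tensor)
```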