Commit ee6f79c

Add ut for test_communicator.py (#2293)
### What this PR does / why we need it?
Add unit tests (UT) for test_communicator.py.

- vLLM version: v0.10.0
- vLLM main: vllm-project/vllm@e5ebeeb

Signed-off-by: yangqinghao-cmss <[email protected]>
Parent commit: 3e65c40

File tree: 1 file changed, 155 insertions(+), 0 deletions(-)
@@ -0,0 +1,155 @@
import unittest
from unittest.mock import MagicMock, Mock, patch

import torch
import torch.distributed as dist

from vllm_ascend.distributed.communicator import NPUCommunicator


class TestNPUCommunicator(unittest.TestCase):

    # The @patch stacks below stub out torch.distributed and torch.npu so that
    # NPUCommunicator can be constructed without a real device or process group.
    @patch("vllm.config.get_current_vllm_config", return_value=None)
    @patch("torch.npu.current_device", return_value=MagicMock())
    @patch("torch.npu.set_device", return_value=MagicMock())
    @patch("torch.distributed.get_process_group_ranks",
           return_value={
               0: 0,
               1: 1
           })
    @patch("torch.distributed.get_group_rank", return_value={0: 0, 1: 1})
    @patch("torch.distributed.is_initialized", return_value=True)
    @patch("torch.distributed.get_rank", return_value=1)
    @patch("torch.distributed.is_initialized", return_value=True)
    @patch("torch.distributed.get_backend", return_value="hccl")
    @patch("torch.distributed.get_rank", return_value=1)
    @patch("torch.distributed.get_world_size", return_value=2)
    @patch("torch.distributed.get_process_group_ranks", return_value=[0, 1])
    @patch("torch.npu.device")
    def test_all_to_all_with_sizes(self, *_):
        # With explicit scatter/gather sizes, all_to_all should concatenate the
        # per-rank chunks produced by the (patched) torch.distributed.all_to_all.
        def patched_all_to_all(output_tensor_list,
                               input_tensor_list,
                               group=None,
                               async_op=False):
            output_tensor_list[:] = [
                torch.tensor([10, 20]),
                torch.tensor([50, 60])
            ]

        torch.distributed.all_to_all = patched_all_to_all

        scatter_sizes = [2, 2]
        gather_sizes = [2, 2]
        input_ = torch.tensor([10, 20, 30, 40])

        comm = NPUCommunicator(cpu_group=dist.group.WORLD)

        output = comm.all_to_all(input_,
                                 scatter_sizes=scatter_sizes,
                                 gather_sizes=gather_sizes)

        assert output.tolist() == [10, 20, 50, 60]

    @patch("vllm.config.get_current_vllm_config", return_value=None)
    @patch("torch.npu.current_device", return_value=MagicMock())
    @patch("torch.npu.set_device", return_value=MagicMock())
    @patch("torch.distributed.get_process_group_ranks",
           return_value={
               0: 0,
               1: 1
           })
    @patch("torch.distributed.get_group_rank", return_value={0: 0, 1: 1})
    @patch("torch.distributed.is_initialized", return_value=True)
    @patch("torch.distributed.get_rank", return_value=1)
    @patch("torch.distributed.is_initialized", return_value=True)
    @patch("torch.distributed.get_backend", return_value="hccl")
    @patch("torch.distributed.get_rank", return_value=1)
    @patch("torch.distributed.get_world_size", return_value=2)
    @patch("torch.distributed.get_process_group_ranks", return_value=[0, 1])
    @patch("torch.npu.device")
    def test_all_to_all_without_sizes(self, *_):
        # Without explicit sizes, the input is split evenly along scatter_dim
        # and the per-rank outputs are concatenated along gather_dim.
        def patched_all_to_all(output_tensor_list,
                               input_tensor_list,
                               group=None,
                               async_op=False):
            output_tensor_list[:] = [
                torch.tensor([[10, 20]]),
                torch.tensor([[50, 60]])
            ]

        torch.distributed.all_to_all = patched_all_to_all

        input_ = torch.tensor([[10, 20], [30, 40]])

        comm = NPUCommunicator(cpu_group=dist.group.WORLD)
        output = comm.all_to_all(input_, scatter_dim=0, gather_dim=0)

        assert output.tolist() == [[10, 20], [50, 60]]

    @patch("vllm.config.get_current_vllm_config", return_value=None)
    @patch("torch.npu.current_device", return_value=MagicMock())
    @patch("torch.npu.set_device", return_value=MagicMock())
    @patch("torch.distributed.get_process_group_ranks",
           return_value={
               0: 0,
               1: 1
           })
    @patch("torch.distributed.get_group_rank", return_value={0: 0, 1: 1})
    @patch("torch.distributed.is_initialized", return_value=True)
    @patch("torch.distributed.get_rank", return_value=1)
    @patch("torch.distributed.is_initialized", return_value=True)
    @patch("torch.distributed.get_backend", return_value="hccl")
    @patch("torch.distributed.get_rank", return_value=1)
    @patch("torch.distributed.get_world_size", return_value=2)
    @patch("torch.distributed.get_process_group_ranks", return_value=[0, 1])
    @patch("torch.npu.device")
    def test_dispatch(self, *_):
        # dispatch() should delegate to the all2all_manager and return its
        # result unchanged.
        comm = NPUCommunicator(cpu_group=dist.group.WORLD)
        comm.all2all_manager = Mock()
        hidden_states = torch.randn(2, 4, 8)
        router_logits = torch.randn(2, 4, 2)

        mock_dispatch_result = (torch.randn(2, 4, 8), torch.randn(2, 4, 2))
        comm.all2all_manager.dispatch.return_value = mock_dispatch_result

        result_hidden, result_logits = comm.dispatch(hidden_states,
                                                     router_logits)

        assert torch.allclose(result_hidden, mock_dispatch_result[0])
        assert torch.allclose(result_logits, mock_dispatch_result[1])

        comm.all2all_manager.dispatch.assert_called_once_with(
            hidden_states, router_logits)

    @patch("vllm.config.get_current_vllm_config", return_value=None)
    @patch("torch.npu.current_device", return_value=MagicMock())
    @patch("torch.npu.set_device", return_value=MagicMock())
    @patch("torch.distributed.get_process_group_ranks",
           return_value={
               0: 0,
               1: 1
           })
    @patch("torch.distributed.get_group_rank", return_value={0: 0, 1: 1})
    @patch("torch.distributed.is_initialized", return_value=True)
    @patch("torch.distributed.get_rank", return_value=1)
    @patch("torch.distributed.is_initialized", return_value=True)
    @patch("torch.distributed.get_backend", return_value="hccl")
    @patch("torch.distributed.get_rank", return_value=1)
    @patch("torch.distributed.get_world_size", return_value=2)
    @patch("torch.distributed.get_process_group_ranks", return_value=[0, 1])
    @patch("torch.npu.device")
    def test_combine(self, *_):
        # combine() should delegate to the all2all_manager and return its
        # result unchanged.
        comm = NPUCommunicator(cpu_group=dist.group.WORLD)
        comm.all2all_manager = Mock()
        hidden_states = torch.randn(2, 4, 8)

        mock_combine_result = torch.randn(2, 4, 8)
        comm.all2all_manager.combine.return_value = mock_combine_result

        result = comm.combine(hidden_states)

        assert torch.allclose(result, mock_combine_result)

        comm.all2all_manager.combine.assert_called_once_with(hidden_states)
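
For local verification, the new module follows standard unittest conventions, so a runner such as `python -m unittest` or `pytest` will discover TestNPUCommunicator on its own; the snippet below is only a minimal sketch of an optional module-level entry point. It is not part of this commit, and the file's path within the repository is not shown in this diff.

# Minimal sketch (not part of the commit): lets the test module be executed
# directly with `python <file>.py`; unittest/pytest discovery works without it.
if __name__ == "__main__":
    unittest.main()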
