@@ -4,8 +4,9 @@
 from vllm.config import ParallelConfig

 from vllm_ascend.distributed.parallel_state import (
-    _LMTP, _MC2, _OTP, destroy_ascend_model_parallel, get_lmhead_tp_group,
-    get_mc2_group, get_otp_group, init_ascend_model_parallel)
+    _LMTP, _MC2, _OTP, _P_TP, destroy_ascend_model_parallel,
+    get_lmhead_tp_group, get_mc2_group, get_otp_group, get_p_tp_group,
+    init_ascend_model_parallel)


 @pytest.fixture
@@ -30,6 +31,7 @@ def test_init_ascend_model_parallel(mock_distributed, parallel_config):
     mock_ascend_config = MagicMock()
     mock_ascend_config.lmhead_tensor_parallel_size = 2
     mock_ascend_config.oproj_tensor_parallel_size = 2
+    mock_ascend_config.pd_tp_ratio = 2
     with patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized', return_value=False), \
         patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group'), \
         patch('vllm_ascend.distributed.parallel_state.get_ascend_config', return_value=mock_ascend_config):
@@ -38,11 +40,14 @@ def test_init_ascend_model_parallel(mock_distributed, parallel_config):
         mc2_group = get_mc2_group()
         lmheadtp_group = get_lmhead_tp_group()
         otp_group = get_otp_group()
+        p_tp_group = get_p_tp_group()
         assert mc2_group is not None
         assert otp_group is not None
         assert lmheadtp_group is not None
+        assert p_tp_group is not None

         destroy_ascend_model_parallel()
         assert _MC2 is None
         assert _LMTP is None
         assert _OTP is None
+        assert _P_TP is None
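For reference, below is a minimal, self-contained sketch of the mock-and-patch pattern the updated test relies on. It uses only the standard-library unittest.mock; get_ascend_config and init_parallel_groups here are local stand-ins, not the vllm_ascend symbols, and pd_tp_ratio simply mirrors the config field this diff introduces. Run it as a script.

    # Sketch only: local stand-ins for the patched vllm_ascend accessors.
    from unittest.mock import MagicMock, patch

    def get_ascend_config():
        # The real accessor would touch global config state; the test
        # never lets this run.
        raise RuntimeError("should be patched out")

    def init_parallel_groups():
        # Stand-in for init_ascend_model_parallel: reads the (patched) config.
        return get_ascend_config().pd_tp_ratio

    mock_cfg = MagicMock()
    mock_cfg.pd_tp_ratio = 2  # mirrors the field consumed by the new _P_TP group

    # patch() swaps the module-level name for the duration of the block,
    # so the code under test sees the mock instead of the real accessor.
    with patch('__main__.get_ascend_config', return_value=mock_cfg):
        assert init_parallel_groups() == 2

Patching the lookup site (the module where the name is resolved) rather than the definition site is what makes the test's patch('vllm_ascend.distributed.parallel_state.get_ascend_config', ...) take effect inside init_ascend_model_parallel.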