1
import pytest
import torch
from pytest_mock import MockerFixture
from transformers import PretrainedConfig
from vllm.config import CacheConfig, ModelConfig, VllmConfig

from tests.ut.base import PytestBase
from vllm_ascend.torchair.models.torchair_deepseek_mtp import (
    TorchairDeepSeekMTP, TorchairDeepSeekMultiTokenPredictor,
    TorchairDeepSeekMultiTokenPredictorLayer)
11
+
12
+
13
+ class TestTorchairDeepSeekMultiTokenPredictorLayer (PytestBase ):
14
+
15
+ @pytest .fixture
16
+ def setup_mtp_layer (self , mocker : MockerFixture ):
17
+ config = PretrainedConfig (vocab_size = 1000 ,
18
+ hidden_size = 768 ,
19
+ rms_norm_eps = 1e-5 )
20
+ mocker .patch (
21
+ "vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__" ,
22
+ return_value = None )
23
+ mocker .patch ("vllm.model_executor.layers.layernorm.RMSNorm.__init__" ,
24
+ return_value = None )
25
+ mocker .patch (
26
+ "vllm.model_executor.models.deepseek_mtp.SharedHead.__init__" ,
27
+ return_value = None )
28
+ mocker .patch (
29
+ "vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekShareHead.__init__" ,
30
+ return_value = None )
31
+ mocker_deepseek_v2_decode_layer = mocker .patch (
32
+ "vllm_ascend.torchair.models.torchair_deepseek_v2.TorchairDeepseekV2DecoderLayer.__init__" ,
33
+ return_value = None )
34
+
35
+ mtp_layer = TorchairDeepSeekMultiTokenPredictorLayer (config , "" , None )
36
+ mocker_deepseek_v2_decode_layer .assert_called_once ()
37
+ return mtp_layer
38
+
39
+ def test_init (self , mocker : MockerFixture , setup_mtp_layer ):
40
+ mtp_layer = setup_mtp_layer
41
+ assert isinstance (mtp_layer , TorchairDeepSeekMultiTokenPredictorLayer )
42
+
43
+ def test_forward (self , mocker : MockerFixture , setup_mtp_layer ):
44
+ mtp_layer = setup_mtp_layer
45
+ mocker .patch ("torch.nn.Module.__setattr__" )
46
+ mocker .patch ("torch.nn.Module.__getattr__" )
47
+ mocker .patch ("torch.nn.Module.__delattr__" )
48
+ mocker .patch .object (mtp_layer ,
49
+ 'eh_proj' ,
50
+ return_value = torch .randn (2 , 3 , 768 ))
51
+ mocker .patch ("torch.cat" , return_value = torch .randn (2 , 3 , 768 ))
52
+ mtp_layer .mtp_block .return_value = (torch .randn (2 , 3 , 768 ),
53
+ torch .randn (2 , 3 , 768 ))
54
+
55
+ input_ids = torch .tensor ([[1 , 2 , 3 ], [4 , 5 , 6 ]])
56
+ positions = torch .tensor ([[0 , 1 , 2 ], [0 , 1 , 2 ]])
57
+ kv_cache = torch .randn (2 , 3 , 768 )
58
+ previous_hidden_states = torch .randn (2 , 3 , 768 )
59
+ inputs_embeds = torch .tensor ([[1.0 , 2.0 , 3.0 ]])
60
+
61
+ output = mtp_layer (input_ids , positions , kv_cache , None ,
62
+ previous_hidden_states , inputs_embeds , 0 )
63
+ assert output .shape == (2 , 3 , 768 )
64
+
65
+
66
+ class TestTorchairDeepSeekMultiTokenPredictor (PytestBase ):
67
+
68
+ @pytest .fixture
69
+ def setup_predictor (self , mocker : MockerFixture ):
70
+ mock_vllm_config = mocker .MagicMock (spec = VllmConfig )
71
+ mock_model_config = mocker .MagicMock (spec = ModelConfig )
72
+ mock_hf_config = mocker .MagicMock ()
73
+ mock_hf_config .num_hidden_layers = 12
74
+ mock_hf_config .num_nextn_predict_layers = 3
75
+ mock_hf_config .vocab_size = 30000
76
+ mock_model_config .hf_config = mock_hf_config
77
+ mock_vllm_config .model_config = mock_model_config
78
+ mock_vllm_config .cache_config = CacheConfig ()
79
+ mock_vllm_config .quant_config = mocker .MagicMock ()
80
+ mocker .patch (
81
+ "vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__" ,
82
+ return_value = None )
83
+ mocker .patch (
84
+ "vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMultiTokenPredictorLayer.__init__" ,
85
+ return_value = None )
86
+
87
+ predictor = TorchairDeepSeekMultiTokenPredictor (
88
+ vllm_config = mock_vllm_config )
89
+ return predictor
90
+
91
+ def test_init (self , mocker : MockerFixture , setup_predictor ):
92
+ predictor = setup_predictor
93
+ assert predictor .num_mtp_layers == 3
94
+ assert isinstance (predictor , TorchairDeepSeekMultiTokenPredictor )
95
+
96
+ @pytest .mark .parametrize (
97
+ 'kv_caches, inputs_embeds' ,
98
+ [(torch .tensor ([[[0.1 , 0.2 , 0.3 ]]]), torch .tensor ([[0.1 , 0.2 , 0.3 ]]))])
99
+ def test_forward (self , mocker : MockerFixture , setup_predictor , kv_caches ,
100
+ inputs_embeds ):
101
+ predictor = setup_predictor
102
+ mock_layer = mocker .MagicMock ()
103
+ mock_layer .return_value = torch .tensor ([1.0 , 2.0 , 3.0 ])
104
+ predictor .layers_list = [mock_layer ]
105
+
106
+ # todo: need or not?
107
+ # predictor.num_mtp_layers = 1
108
+ input_ids = torch .tensor ([[1 , 2 , 3 ]])
109
+ positions = torch .tensor ([[0 , 1 , 2 ]])
110
+ mocker .patch (
111
+ "vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMultiTokenPredictorLayer.__call__" ,
112
+ return_value = torch .tensor ([[1.0 , 2.0 , 3.0 ]]))
113
+ output = predictor .forward (input_ids , positions , kv_caches , None , None ,
114
+ inputs_embeds , 0 )
115
+ mock_layer .assert_called_once ()
116
+ assert torch .allclose (output , torch .tensor ([1.0 , 2.0 , 3.0 ]))
117
+
118
+ def test_compute_logits (self , mocker : MockerFixture , setup_predictor ):
119
+ hidden_states = torch .tensor ([[1 , 2 , 3 ], [4 , 5 , 6 ]])
120
+ predictor = setup_predictor
121
+
122
+ mock_layer = mocker .MagicMock ()
123
+ mock_layer .return_value = torch .tensor ([1.0 , 2.0 , 3.0 ])
124
+ predictor .layers_list = [mock_layer ]
125
+ mocker .patch ("torch.nn.Module.__setattr__" )
126
+ mocker .patch ("torch.nn.Module.__getattr__" )
127
+ mocker .patch ("torch.nn.Module.__delattr__" )
128
+ mocker .patch (
129
+ "vllm.model_executor.layers.logits_processor.LogitsProcessor.__init__" ,
130
+ return_value = None )
131
+ predictor .logits_processor .return_value = torch .tensor ([1.0 , 2.0 , 3.0 ])
132
+
133
+ result_logits = predictor .compute_logits (hidden_states = hidden_states ,
134
+ sampling_metadata = None )
135
+ predictor .logits_processor .assert_called_once ()
136
+ assert torch .allclose (result_logits , torch .tensor ([1.0 , 2.0 , 3.0 ]))
137
+
138
+
139
+ class TestTorchairDeepSeekMTP (PytestBase ):
140
+
141
+ @pytest .fixture
142
+ def setup_mtp (self , mocker : MockerFixture ):
143
+ vllm_config = mocker .MagicMock ()
144
+ vllm_config .model_config .hf_config .num_hidden_layers = 12
145
+ vllm_config .model_config .hf_config .num_nextn_predict_layers = 3
146
+ vllm_config .cache_config = mocker .MagicMock ()
147
+ vllm_config .quant_config = mocker .MagicMock ()
148
+
149
+ mocker .patch ("torch.nn.Module.__setattr__" )
150
+ mocker .patch ("torch.nn.Module.__getattr__" )
151
+ mocker .patch ("torch.nn.Module.__delattr__" )
152
+ mocker .patch (
153
+ "vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__" ,
154
+ return_value = None )
155
+ mocker .patch (
156
+ "vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMultiTokenPredictorLayer.__call__" ,
157
+ return_value = None )
158
+ mocker .patch ("vllm.model_executor.layers.sampler.get_sampler" ,
159
+ return_value = None )
160
+
161
+ mtp = TorchairDeepSeekMTP (vllm_config = vllm_config )
162
+ return mtp
163
+
164
+ def test_init (self , mocker : MockerFixture , setup_mtp ):
165
+ mtp = setup_mtp
166
+ assert isinstance (mtp , TorchairDeepSeekMTP )
167
+
168
+ def test_forward (self , mocker : MockerFixture , setup_mtp ):
169
+ input_ids = torch .tensor ([[1 , 2 , 3 ]])
170
+ positions = torch .tensor ([[0 , 1 , 2 ]])
171
+ kv_caches = [torch .tensor ([[0.1 , 0.2 , 0.3 ]])]
172
+ previous_hidden_states = torch .tensor ([[0.1 , 0.2 , 0.3 ]])
173
+ inputs_embeds = torch .tensor ([[0.1 , 0.2 , 0.3 ]])
174
+ spec_step_idx = 0
175
+ setup_mtp .model .return_value = torch .tensor ([[1.0 , 2.0 , 3.0 ]])
176
+
177
+ output = setup_mtp .forward (input_ids , positions , kv_caches , None ,
178
+ previous_hidden_states , inputs_embeds ,
179
+ spec_step_idx )
180
+ assert torch .allclose (output , torch .tensor ([[1.0 , 2.0 , 3.0 ]]))
0 commit comments