Commit 3fc31ee

[1/N][refactor] torchair deepseek modeling refactor (#2384)
### What this PR does / why we need it?
Move the torchair-related model architecture into the torchair module to make the code clearer. As the next step, we'll remove all torchair-related code that still lives outside the torchair module.

### Does this PR introduce _any_ user-facing change?
No.

- vLLM version: v0.10.0
- vLLM main: vllm-project/vllm@08d5f71

Signed-off-by: linfeng-yuan <[email protected]>
Parent: 19fdc9a
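As context for the tests below, here is a minimal sketch of the import-path change this refactor implies. The new location is confirmed by the test file added in this commit; the old location is an assumption for illustration only, since the moved model files themselves are not shown on this page:

    # New home of the torchair DeepSeek MTP model arch (per this commit's tests):
    from vllm_ascend.torchair.models.torchair_deepseek_mtp import (
        TorchairDeepSeekMTP, TorchairDeepSeekMultiTokenPredictor,
        TorchairDeepSeekMultiTokenPredictorLayer)

    # Old location (assumed, e.g. somewhere under vllm_ascend.models):
    # from vllm_ascend.models.deepseek_mtp import ...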

File tree: 9 files changed, +1863 −0 lines changed
Lines changed: 180 additions & 0 deletions

@@ -0,0 +1,180 @@
import pytest
import torch
from pytest_mock import MockerFixture
from transformers import PretrainedConfig
from vllm.config import CacheConfig, ModelConfig, VllmConfig

from tests.ut.base import PytestBase
from vllm_ascend.torchair.models.torchair_deepseek_mtp import (
    TorchairDeepSeekMTP, TorchairDeepSeekMultiTokenPredictor,
    TorchairDeepSeekMultiTokenPredictorLayer)


class TestTorchairDeepSeekMultiTokenPredictorLayer(PytestBase):

    @pytest.fixture
    def setup_mtp_layer(self, mocker: MockerFixture):
        config = PretrainedConfig(vocab_size=1000,
                                  hidden_size=768,
                                  rms_norm_eps=1e-5)
        mocker.patch(
            "vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__",
            return_value=None)
        mocker.patch("vllm.model_executor.layers.layernorm.RMSNorm.__init__",
                     return_value=None)
        mocker.patch(
            "vllm.model_executor.models.deepseek_mtp.SharedHead.__init__",
            return_value=None)
        mocker.patch(
            "vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekShareHead.__init__",
            return_value=None)
        mocker_deepseek_v2_decode_layer = mocker.patch(
            "vllm_ascend.torchair.models.torchair_deepseek_v2.TorchairDeepseekV2DecoderLayer.__init__",
            return_value=None)

        mtp_layer = TorchairDeepSeekMultiTokenPredictorLayer(config, "", None)
        mocker_deepseek_v2_decode_layer.assert_called_once()
        return mtp_layer

    def test_init(self, mocker: MockerFixture, setup_mtp_layer):
        mtp_layer = setup_mtp_layer
        assert isinstance(mtp_layer, TorchairDeepSeekMultiTokenPredictorLayer)

    def test_forward(self, mocker: MockerFixture, setup_mtp_layer):
        mtp_layer = setup_mtp_layer
        mocker.patch("torch.nn.Module.__setattr__")
        mocker.patch("torch.nn.Module.__getattr__")
        mocker.patch("torch.nn.Module.__delattr__")
        mocker.patch.object(mtp_layer,
                            'eh_proj',
                            return_value=torch.randn(2, 3, 768))
        mocker.patch("torch.cat", return_value=torch.randn(2, 3, 768))
        mtp_layer.mtp_block.return_value = (torch.randn(2, 3, 768),
                                            torch.randn(2, 3, 768))

        input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]])
        positions = torch.tensor([[0, 1, 2], [0, 1, 2]])
        kv_cache = torch.randn(2, 3, 768)
        previous_hidden_states = torch.randn(2, 3, 768)
        inputs_embeds = torch.tensor([[1.0, 2.0, 3.0]])

        output = mtp_layer(input_ids, positions, kv_cache, None,
                           previous_hidden_states, inputs_embeds, 0)
        assert output.shape == (2, 3, 768)


class TestTorchairDeepSeekMultiTokenPredictor(PytestBase):

    @pytest.fixture
    def setup_predictor(self, mocker: MockerFixture):
        mock_vllm_config = mocker.MagicMock(spec=VllmConfig)
        mock_model_config = mocker.MagicMock(spec=ModelConfig)
        mock_hf_config = mocker.MagicMock()
        mock_hf_config.num_hidden_layers = 12
        mock_hf_config.num_nextn_predict_layers = 3
        mock_hf_config.vocab_size = 30000
        mock_model_config.hf_config = mock_hf_config
        mock_vllm_config.model_config = mock_model_config
        mock_vllm_config.cache_config = CacheConfig()
        mock_vllm_config.quant_config = mocker.MagicMock()
        mocker.patch(
            "vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__",
            return_value=None)
        mocker.patch(
            "vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMultiTokenPredictorLayer.__init__",
            return_value=None)

        predictor = TorchairDeepSeekMultiTokenPredictor(
            vllm_config=mock_vllm_config)
        return predictor

    def test_init(self, mocker: MockerFixture, setup_predictor):
        predictor = setup_predictor
        assert predictor.num_mtp_layers == 3
        assert isinstance(predictor, TorchairDeepSeekMultiTokenPredictor)

    @pytest.mark.parametrize(
        'kv_caches, inputs_embeds',
        [(torch.tensor([[[0.1, 0.2, 0.3]]]), torch.tensor([[0.1, 0.2, 0.3]]))])
    def test_forward(self, mocker: MockerFixture, setup_predictor, kv_caches,
                     inputs_embeds):
        predictor = setup_predictor
        mock_layer = mocker.MagicMock()
        mock_layer.return_value = torch.tensor([1.0, 2.0, 3.0])
        predictor.layers_list = [mock_layer]

        # todo: need or not?
        # predictor.num_mtp_layers = 1
        input_ids = torch.tensor([[1, 2, 3]])
        positions = torch.tensor([[0, 1, 2]])
        mocker.patch(
            "vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMultiTokenPredictorLayer.__call__",
            return_value=torch.tensor([[1.0, 2.0, 3.0]]))
        output = predictor.forward(input_ids, positions, kv_caches, None, None,
                                   inputs_embeds, 0)
        mock_layer.assert_called_once()
        assert torch.allclose(output, torch.tensor([1.0, 2.0, 3.0]))

    def test_compute_logits(self, mocker: MockerFixture, setup_predictor):
        hidden_states = torch.tensor([[1, 2, 3], [4, 5, 6]])
        predictor = setup_predictor

        mock_layer = mocker.MagicMock()
        mock_layer.return_value = torch.tensor([1.0, 2.0, 3.0])
        predictor.layers_list = [mock_layer]
        mocker.patch("torch.nn.Module.__setattr__")
        mocker.patch("torch.nn.Module.__getattr__")
        mocker.patch("torch.nn.Module.__delattr__")
        mocker.patch(
            "vllm.model_executor.layers.logits_processor.LogitsProcessor.__init__",
            return_value=None)
        predictor.logits_processor.return_value = torch.tensor([1.0, 2.0, 3.0])

        result_logits = predictor.compute_logits(hidden_states=hidden_states,
                                                 sampling_metadata=None)
        predictor.logits_processor.assert_called_once()
        assert torch.allclose(result_logits, torch.tensor([1.0, 2.0, 3.0]))


class TestTorchairDeepSeekMTP(PytestBase):

    @pytest.fixture
    def setup_mtp(self, mocker: MockerFixture):
        vllm_config = mocker.MagicMock()
        vllm_config.model_config.hf_config.num_hidden_layers = 12
        vllm_config.model_config.hf_config.num_nextn_predict_layers = 3
        vllm_config.cache_config = mocker.MagicMock()
        vllm_config.quant_config = mocker.MagicMock()

        mocker.patch("torch.nn.Module.__setattr__")
        mocker.patch("torch.nn.Module.__getattr__")
        mocker.patch("torch.nn.Module.__delattr__")
        mocker.patch(
            "vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__",
            return_value=None)
        mocker.patch(
            "vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMultiTokenPredictorLayer.__call__",
            return_value=None)
        mocker.patch("vllm.model_executor.layers.sampler.get_sampler",
                     return_value=None)

        mtp = TorchairDeepSeekMTP(vllm_config=vllm_config)
        return mtp

    def test_init(self, mocker: MockerFixture, setup_mtp):
        mtp = setup_mtp
        assert isinstance(mtp, TorchairDeepSeekMTP)

    def test_forward(self, mocker: MockerFixture, setup_mtp):
        input_ids = torch.tensor([[1, 2, 3]])
        positions = torch.tensor([[0, 1, 2]])
        kv_caches = [torch.tensor([[0.1, 0.2, 0.3]])]
        previous_hidden_states = torch.tensor([[0.1, 0.2, 0.3]])
        inputs_embeds = torch.tensor([[0.1, 0.2, 0.3]])
        spec_step_idx = 0
        setup_mtp.model.return_value = torch.tensor([[1.0, 2.0, 3.0]])

        output = setup_mtp.forward(input_ids, positions, kv_caches, None,
                                   previous_hidden_states, inputs_embeds,
                                   spec_step_idx)
        assert torch.allclose(output, torch.tensor([[1.0, 2.0, 3.0]]))
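A note on the recurring mocker.patch calls against torch.nn.Module.__setattr__, __getattr__, and __delattr__: nn.Module overrides these hooks to register parameters and submodules, and that bookkeeping assumes __init__ has run. Because these tests stub out the constructors with return_value=None, the hooks are patched away so mocked attributes such as mtp_block or logits_processor can be injected and read freely. Below is a minimal, self-contained sketch of the pattern, using a bare torch.nn.Module rather than any vllm-ascend class:

    import torch
    from unittest import mock

    def demo():
        # With the attribute hooks patched out, a module that never ran
        # __init__ tolerates arbitrary attribute access: no parameter or
        # submodule registration is attempted.
        with mock.patch("torch.nn.Module.__setattr__"), \
             mock.patch("torch.nn.Module.__getattr__"):
            module = torch.nn.Module.__new__(torch.nn.Module)  # skip __init__
            module.mtp_block = mock.MagicMock()  # swallowed by the patched __setattr__
            _ = module.mtp_block                 # served by the patched __getattr__

    if __name__ == "__main__":
        demo()

Assuming the new file lands somewhere under tests/ut/ (the exact path is not visible on this page), it runs like any other unit test in the suite via pytest.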
