#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#


import inspect
from unittest.mock import MagicMock, patch

import pytest
import torch
from vllm.model_executor.layers.rotary_embedding import (RotaryEmbedding,
                                                         YaRNScalingRotaryEmbedding)

from vllm_ascend.ops.rotary_embedding import (AscendRotaryEmbedding,
                                              AscendYaRNRotaryEmbedding)


HEAD_SIZE = 64
ROTARY_DIM = 64
MAX_POS = 2048
BASE = 10000.0
DTYPE = torch.bfloat16
SEQ_LEN = 4
NUM_HEADS = 2


def _make_tensors(seq_len=SEQ_LEN, num_heads=NUM_HEADS, head_size=HEAD_SIZE):
    positions = torch.arange(seq_len, dtype=torch.long)
    query = torch.randn(seq_len, num_heads * head_size)
    key = torch.randn(seq_len, num_heads * head_size)
    return positions, query, key


def check_parent_init_signature_has_not_changed(parent_func, child_func):
    parent_sig = inspect.signature(parent_func)
    parent_params = set(parent_sig.parameters) - {"self"}

    child_sig = inspect.signature(child_func)
    child_params = set(child_sig.parameters) - {"self"}

    added = parent_params - child_params
    removed = child_params - parent_params

    assert not added, (
        f"{parent_func.__name__} added new parameter(s): {added}. "
        f"Check whether {child_func.__name__} needs to forward them."
    )
    assert not removed, (
        f"{parent_func.__name__} removed parameter(s): {removed}. "
        f"Check whether {child_func.__name__} needs to forward them."
    )


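# A quick self-check of the guard above: two toy callables whose signatures
# deliberately drift should trip each assertion. `parent` and `child` are
# hypothetical stand-ins added for illustration, not real vLLM classes.
def test_signature_checker_detects_drift():
    def parent(self, head_size, rotary_dim, new_param):
        pass

    def child(self, head_size, rotary_dim):
        pass

    # The parent gained `new_param` that the child does not accept.
    with pytest.raises(AssertionError, match="added new parameter"):
        check_parent_init_signature_has_not_changed(parent, child)

    # Swapping the roles makes the same drift register as a removal.
    with pytest.raises(AssertionError, match="removed parameter"):
        check_parent_init_signature_has_not_changed(child, parent)

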
@pytest.fixture(autouse=True)
def patch_init_side_effects():
    """
    Suppress all side effects that fire during __init__ so every test starts
    from a clean, predictable state without needing real NPU ops or vLLM
    global config.
    """
    with (
        patch("vllm_ascend.ops.rotary_embedding._record_cos_sin_cache"),
        patch("vllm_ascend.ops.rotary_embedding._record_cos_and_sin_cache_interleaved"),
        patch("vllm_ascend.ops.rotary_embedding.get_current_vllm_config") as mock_cfg,
    ):
        # Default: speculative_config is None → use_mtp = False
        mock_cfg.return_value.speculative_config = None
        yield mock_cfg


@pytest.fixture()
def make_embedding(patch_init_side_effects):
    """Factory that creates an AscendRotaryEmbedding with controllable use_mtp."""

    def _factory(use_mtp: bool = False, is_neox_style: bool = True):
        spec_cfg = MagicMock(method="mtp") if use_mtp else None
        patch_init_side_effects.return_value.speculative_config = spec_cfg

        with patch("vllm_ascend.ops.rotary_embedding.RotaryEmbedding.__init__") as mock_parent_init:
            mock_parent_init.return_value = None

            emb = AscendRotaryEmbedding.__new__(AscendRotaryEmbedding)
            # Manually set the attrs that the real parent __init__ would set
            emb.head_size = HEAD_SIZE
            emb.rotary_dim = ROTARY_DIM
            emb.is_neox_style = is_neox_style
            emb.cos_sin_cache = torch.zeros(MAX_POS, ROTARY_DIM)
            # Call __init__ to exercise our code path
            AscendRotaryEmbedding.__init__(
                emb, HEAD_SIZE, ROTARY_DIM, MAX_POS, BASE, is_neox_style, DTYPE
            )
        return emb

    return _factory


@pytest.fixture()
def make_yarn_embedding(patch_init_side_effects):
    """
    Factory for AscendYaRNRotaryEmbedding with the parent __init__ suppressed.
    patch_init_side_effects is the same autouse fixture as above.
    """

    def _factory(is_neox_style: bool = True):
        with patch("vllm_ascend.ops.rotary_embedding.YaRNScalingRotaryEmbedding.__init__") as mock_parent_init:
            mock_parent_init.return_value = None

            emb = AscendYaRNRotaryEmbedding.__new__(AscendYaRNRotaryEmbedding)
            emb.head_size = HEAD_SIZE
            emb.rotary_dim = ROTARY_DIM
            emb.is_neox_style = is_neox_style
            emb.cos_sin_cache = torch.zeros(MAX_POS, ROTARY_DIM)
            AscendYaRNRotaryEmbedding.__init__(
                emb,
                head_size=HEAD_SIZE,
                rotary_dim=ROTARY_DIM,
                max_position_embeddings=MAX_POS,
                base=BASE,
                is_neox_style=is_neox_style,
                scaling_factor=1.0,
                dtype=DTYPE,
            )
        return emb

    return _factory


class TestAscendEmbeddingForwardOOT:

    @patch("torch.ops.vllm.npu_rotary_embedding")
    @patch("vllm_ascend.ascend_forward_context.get_forward_context")
    def test_basic_call_delegates_to_npu_op(self, mock_get_forward_context, mock_npu_op, make_embedding):
        """forward_oot always calls npu_rotary_embedding and returns its result."""
        mock_get_forward_context.return_value = MagicMock()
        mock_get_forward_context.return_value.is_draft_model = False
        mock_get_forward_context.return_value.flash_comm_v1_enabled = False
        expected_output = (torch.randn(SEQ_LEN, NUM_HEADS * HEAD_SIZE),) * 2
        mock_npu_op.return_value = expected_output

        emb = make_embedding()
        positions, query, key = _make_tensors()

        result = emb.forward_oot(positions, query, key)

        mock_npu_op.assert_called_once_with(
            positions, query, key, emb.cos_sin_cache,
            HEAD_SIZE, ROTARY_DIM, emb.is_neox_style,
        )
        assert result is expected_output

    @patch("torch.ops.vllm.npu_rotary_embedding")
    @patch("vllm_ascend.ascend_forward_context.get_forward_context")
    def test_neox_style_override_true(self, mock_get_forward_context, mock_npu_op, make_embedding):
        """is_neox_style_override=True wins over self.is_neox_style=False."""
        mock_get_forward_context.return_value = MagicMock()
        mock_get_forward_context.return_value.is_draft_model = False
        mock_get_forward_context.return_value.flash_comm_v1_enabled = False
        mock_npu_op.return_value = MagicMock()

        emb = make_embedding(is_neox_style=False)
        positions, query, key = _make_tensors()

        emb.forward_oot(positions, query, key, is_neox_style_override=True)

        # Verify the override was forwarded: the last positional arg is is_neox_style
        assert mock_npu_op.call_args[0][-1] is True

    @patch("torch.ops.vllm.npu_rotary_embedding")
    @patch("vllm_ascend.ascend_forward_context.get_forward_context")
    def test_neox_style_override_false(self, mock_get_forward_context, mock_npu_op, make_embedding):
        """is_neox_style_override=False wins over self.is_neox_style=True."""
        mock_get_forward_context.return_value = MagicMock()
        mock_get_forward_context.return_value.is_draft_model = False
        mock_get_forward_context.return_value.flash_comm_v1_enabled = False
        mock_npu_op.return_value = MagicMock()

        emb = make_embedding(is_neox_style=True)
        positions, query, key = _make_tensors()

        emb.forward_oot(positions, query, key, is_neox_style_override=False)

        assert mock_npu_op.call_args[0][-1] is False

    @patch("torch.ops.vllm.npu_rotary_embedding")
    @patch("vllm_ascend.ascend_forward_context.get_forward_context")
    def test_neox_style_override_none_uses_self(self, mock_get_forward_context, mock_npu_op, make_embedding):
        """When the override is None, self.is_neox_style is used unchanged."""
        mock_get_forward_context.return_value = MagicMock()
        mock_get_forward_context.return_value.is_draft_model = False
        mock_get_forward_context.return_value.flash_comm_v1_enabled = False
        mock_npu_op.return_value = MagicMock()

        emb = make_embedding(is_neox_style=True)
        positions, query, key = _make_tensors()

        emb.forward_oot(positions, query, key, is_neox_style_override=None)

        assert mock_npu_op.call_args[0][-1] is True

    @patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad")
    @patch("torch.ops.vllm.npu_rotary_embedding")
    @patch("vllm_ascend.ascend_forward_context.get_forward_context")
    def test_gather_unpad_called_when_all_conditions_met(
        self, mock_get_forward_context, mock_npu_op, mock_gather, make_embedding
    ):
        """
        maybe_all_gather_and_maybe_unpad is called iff:
        is_draft_model=True AND use_mtp=True AND flash_comm_v1_enabled=True
        """
        mock_get_forward_context.return_value = MagicMock()
        mock_get_forward_context.return_value.is_draft_model = True
        mock_get_forward_context.return_value.flash_comm_v1_enabled = True
        gathered_positions = torch.arange(SEQ_LEN, dtype=torch.long)
        mock_gather.return_value = gathered_positions
        mock_npu_op.return_value = MagicMock()

        emb = make_embedding(use_mtp=True)
        positions, query, key = _make_tensors()

        emb.forward_oot(positions, query, key)

        mock_gather.assert_called_once()
        # The npu op should receive the gathered positions, not the originals
        assert mock_npu_op.call_args[0][0] is gathered_positions

    @pytest.mark.parametrize("is_draft_model,flash_comm,use_mtp", [
        (False, True, True),   # not draft
        (True, False, True),   # flash_comm disabled
        (True, True, False),   # use_mtp disabled
    ])
    @patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad")
    @patch("torch.ops.vllm.npu_rotary_embedding")
    @patch("vllm_ascend.ascend_forward_context.get_forward_context")
    def test_gather_unpad_skipped_unless_all_conditions_met(
        self, mock_get_forward_context, mock_npu_op, mock_gather,
        is_draft_model, flash_comm, use_mtp, make_embedding,
    ):
        """gather/unpad must NOT fire if any one of the three conditions is False."""
        mock_get_forward_context.return_value = MagicMock()
        mock_get_forward_context.return_value.is_draft_model = is_draft_model
        mock_get_forward_context.return_value.flash_comm_v1_enabled = flash_comm
        mock_npu_op.return_value = MagicMock()

        emb = make_embedding(use_mtp=use_mtp)
        positions, query, key = _make_tensors()

        emb.forward_oot(positions, query, key)

        mock_gather.assert_not_called()
        # The original positions tensor is passed through untouched
        assert mock_npu_op.call_args[0][0] is positions

    def test_parent_init_signature_has_not_changed(self):
        """
        Fail loudly if RotaryEmbedding.__init__ adds, removes, or renames
        parameters, so a developer knows to update AscendRotaryEmbedding
        accordingly.
        """
        check_parent_init_signature_has_not_changed(
            RotaryEmbedding.__init__,
            AscendRotaryEmbedding.__init__,
        )


class TestAscendYaRNRotaryEmbeddingForwardOOT:

    @patch("vllm_ascend.ops.rotary_embedding.AscendRotaryEmbedding.forward_oot")
    def test_delegates_to_ascend_rotary_forward_oot(self, mock_delegate, make_yarn_embedding):
        """forward_oot must delegate to AscendRotaryEmbedding.forward_oot."""
        expected = MagicMock()
        mock_delegate.return_value = expected

        emb = make_yarn_embedding()
        positions, query, key = _make_tensors()

        result = emb.forward_oot(positions, query, key)

        mock_delegate.assert_called_once_with(emb, positions, query, key, None, None)
        assert result is expected

    @patch("vllm_ascend.ops.rotary_embedding.AscendRotaryEmbedding.forward_oot")
    def test_return_value_passed_through(self, mock_delegate, make_yarn_embedding):
        """The return value from the delegate is returned unchanged."""
        sentinel = (torch.randn(SEQ_LEN, HEAD_SIZE), torch.randn(SEQ_LEN, HEAD_SIZE))
        mock_delegate.return_value = sentinel

        emb = make_yarn_embedding()
        positions, query, key = _make_tensors()

        result = emb.forward_oot(positions, query, key)

        assert result is sentinel

    @pytest.mark.parametrize("override", [True, False])
    @patch("vllm_ascend.ops.rotary_embedding.AscendRotaryEmbedding.forward_oot")
    def test_is_neox_style_override_forwarded(self, mock_delegate, override, make_yarn_embedding):
        """is_neox_style_override must be forwarded verbatim, both True and False."""
        mock_delegate.return_value = MagicMock()

        emb = make_yarn_embedding()
        positions, query, key = _make_tensors()

        emb.forward_oot(positions, query, key, is_neox_style_override=override)

        _, call_args, _ = mock_delegate.mock_calls[0]
        assert call_args[5] is override  # 6th positional arg = is_neox_style_override

    @patch("vllm_ascend.ops.rotary_embedding.AscendRotaryEmbedding.forward_oot")
    def test_all_args_forwarded_together(self, mock_delegate, make_yarn_embedding):
        """Smoke test: all args passed simultaneously are forwarded correctly."""
        mock_delegate.return_value = MagicMock()

        emb = make_yarn_embedding()
        positions, query, key = _make_tensors()
        offsets = torch.ones(SEQ_LEN, dtype=torch.long)

        emb.forward_oot(positions, query, key, offsets=offsets, is_neox_style_override=False)

        mock_delegate.assert_called_once_with(emb, positions, query, key, offsets, False)

    def test_parent_init_signature_has_not_changed(self):
        """
        Fail loudly if YaRNScalingRotaryEmbedding.__init__ adds, removes, or
        renames parameters, so a developer knows to update
        AscendYaRNRotaryEmbedding accordingly.
        """
        check_parent_init_signature_has_not_changed(
            YaRNScalingRotaryEmbedding.__init__,
            AscendYaRNRotaryEmbedding.__init__,
        )
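

# A typical invocation for just this module (hypothetical path — adjust to
# wherever this file lives in the repo's test tree):
#     pytest tests/ut/ops/test_rotary_embedding.py -v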