Skip to content

Commit bdb6531

Browse files
authored
[UT] Align input arguments with Ascend(Yarn)RotaryEmbedding with vLLM and add ut (#7358)
### What this PR does / why we need it? This PR adds missing arguments in `AscendRotaryEmbedding`, `AscendYarnRotaryEmbedding` to conform with vLLM. Besides, corresponding ut is introduced. - vLLM version: v0.17.0 - vLLM main: vllm-project/vllm@4034c3d --------- Signed-off-by: Angazenn <supperccell@163.com>
1 parent 568b6d0 commit bdb6531

File tree

2 files changed

+344
-1
lines changed

2 files changed

+344
-1
lines changed
Lines changed: 340 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,340 @@
1+
#
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
# This file is a part of the vllm-ascend project.
14+
#
15+
16+
17+
import inspect
18+
import pytest
19+
import torch
20+
from unittest.mock import MagicMock, patch, PropertyMock
21+
22+
from vllm.model_executor.layers.rotary_embedding import (YaRNScalingRotaryEmbedding, RotaryEmbedding)
23+
from vllm_ascend.ops.rotary_embedding import (AscendYaRNRotaryEmbedding, AscendRotaryEmbedding)
24+
25+
26+
# Shared dummy-model dimensions used by every test in this file.
HEAD_SIZE = 64          # per-head hidden size passed to the rotary op
ROTARY_DIM = 64         # rotated dims (== HEAD_SIZE here, i.e. full rotation)
MAX_POS = 2048          # max position embeddings sizing the cos/sin cache
BASE = 10000.0          # rotary frequency base
DTYPE = torch.bfloat16  # dtype forwarded to the embedding constructors
SEQ_LEN = 4             # token count of the dummy batch
NUM_HEADS = 2           # attention heads in the dummy batch
33+
34+
35+
def _make_tensors(seq_len=SEQ_LEN, num_heads=NUM_HEADS, head_size=HEAD_SIZE):
    """Build a (positions, query, key) triple of dummy inputs for forward_oot."""
    hidden = num_heads * head_size
    positions = torch.arange(seq_len, dtype=torch.long)
    # Query first, then key — keeps the RNG draw order stable for callers.
    query, key = torch.randn(seq_len, hidden), torch.randn(seq_len, hidden)
    return positions, query, key
40+
41+
42+
def check_parent_init_signature_has_not_changed(parent_func, child_func):
    """Assert that *child_func* exposes exactly the parameter names of *parent_func*.

    Acts as a tripwire: when the upstream (parent) ``__init__`` gains or loses a
    parameter, these assertions fail so a developer knows to revisit the Ascend
    subclass and forward the new arguments.
    """
    def _param_names(func):
        # Parameter names minus the implicit ``self``.
        return set(inspect.signature(func).parameters) - {"self"}

    parent_params = _param_names(parent_func)
    child_params = _param_names(child_func)

    added = parent_params - child_params
    removed = child_params - parent_params

    assert not added, (
        f"{parent_func.__name__} added new parameter(s): {added}. "
        f"Check whether {child_func.__name__} needs to forward them."
    )
    assert not removed, (
        f"{parent_func.__name__} removed parameter(s): {removed}. "
        f"Check whether {child_func.__name__} needs to forward them."
    )
60+
61+
62+
@pytest.fixture(autouse=True)
def patch_init_side_effects():
    """Neutralise every side effect fired during ``__init__``.

    Lets each test start from a clean, predictable state without real NPU ops
    or a vLLM global config. Yields the mocked ``get_current_vllm_config``.
    """
    with patch("vllm_ascend.ops.rotary_embedding._record_cos_sin_cache"), \
         patch("vllm_ascend.ops.rotary_embedding._record_cos_and_sin_cache_interleaved"), \
         patch("vllm_ascend.ops.rotary_embedding.get_current_vllm_config") as mock_cfg:
        # No speculative config by default, i.e. use_mtp resolves to False.
        mock_cfg.return_value.speculative_config = None
        yield mock_cfg
77+
78+
79+
@pytest.fixture()
def make_embedding(patch_init_side_effects):
    """Factory that creates an AscendRotaryEmbedding with controllable use_mtp."""

    def _factory(use_mtp: bool = False, is_neox_style: bool = True):
        # Drive the use_mtp branch through the already-mocked vllm config.
        patch_init_side_effects.return_value.speculative_config = (
            MagicMock(method="mtp") if use_mtp else None
        )

        with patch("vllm_ascend.ops.rotary_embedding.RotaryEmbedding.__init__") as mock_parent_init:
            mock_parent_init.return_value = None
            from vllm_ascend.ops.rotary_embedding import AscendRotaryEmbedding

            emb = AscendRotaryEmbedding.__new__(AscendRotaryEmbedding)
            # Attributes the real parent __init__ would have populated.
            emb.head_size = HEAD_SIZE
            emb.rotary_dim = ROTARY_DIM
            emb.is_neox_style = is_neox_style
            emb.cos_sin_cache = torch.zeros(MAX_POS, ROTARY_DIM)
            # Run the subclass __init__ to exercise our code path.
            AscendRotaryEmbedding.__init__(
                emb, HEAD_SIZE, ROTARY_DIM, MAX_POS, BASE, is_neox_style, DTYPE
            )
        return emb

    return _factory
104+
105+
106+
@pytest.fixture()
def make_yarn_embedding(patch_init_side_effects):
    """Factory for AscendYaRNRotaryEmbedding with the parent __init__ suppressed.

    Relies on the same autouse fixture as ``make_embedding`` for the remaining
    init side effects.
    """

    def _factory(is_neox_style: bool = True):
        with patch("vllm_ascend.ops.rotary_embedding.YaRNScalingRotaryEmbedding.__init__") as mock_parent_init:
            mock_parent_init.return_value = None
            from vllm_ascend.ops.rotary_embedding import AscendYaRNRotaryEmbedding

            emb = AscendYaRNRotaryEmbedding.__new__(AscendYaRNRotaryEmbedding)
            # Attributes the real parent __init__ would have populated.
            emb.head_size = HEAD_SIZE
            emb.rotary_dim = ROTARY_DIM
            emb.is_neox_style = is_neox_style
            emb.cos_sin_cache = torch.zeros(MAX_POS, ROTARY_DIM)
            AscendYaRNRotaryEmbedding.__init__(
                emb,
                head_size=HEAD_SIZE,
                rotary_dim=ROTARY_DIM,
                max_position_embeddings=MAX_POS,
                base=BASE,
                is_neox_style=is_neox_style,
                dtype=DTYPE,
                scaling_factor=1.0,
            )
        return emb

    return _factory
135+
136+
137+
class TestAscendEmbeddingForwardOOT:
    """Behavioural tests for ``AscendRotaryEmbedding.forward_oot``."""

    @staticmethod
    def _install_forward_ctx(ctx_mock, *, is_draft_model=False, flash_comm=False):
        # Configure the mocked forward context with the two flags forward_oot reads.
        ctx_mock.return_value = MagicMock()
        ctx_mock.return_value.is_draft_model = is_draft_model
        ctx_mock.return_value.flash_comm_v1_enabled = flash_comm

    @patch("torch.ops.vllm.npu_rotary_embedding")
    @patch("vllm_ascend.ascend_forward_context.get_forward_context")
    def test_basic_call_delegates_to_npu_op(self, mock_ctx, mock_npu_op, make_embedding):
        """forward_oot always calls npu_rotary_embedding and returns its result."""
        self._install_forward_ctx(mock_ctx)
        expected_output = (torch.randn(SEQ_LEN, NUM_HEADS * HEAD_SIZE),) * 2
        mock_npu_op.return_value = expected_output

        emb = make_embedding()
        positions, query, key = _make_tensors()

        result = emb.forward_oot(positions, query, key)

        mock_npu_op.assert_called_once_with(
            positions, query, key, emb.cos_sin_cache,
            HEAD_SIZE, ROTARY_DIM, emb.is_neox_style,
        )
        assert result is expected_output

    @patch("torch.ops.vllm.npu_rotary_embedding")
    @patch("vllm_ascend.ascend_forward_context.get_forward_context")
    def test_neox_style_override_true(self, mock_ctx, mock_npu_op, make_embedding):
        """is_neox_style_override=True wins over self.is_neox_style=False."""
        self._install_forward_ctx(mock_ctx)
        mock_npu_op.return_value = MagicMock()

        emb = make_embedding(is_neox_style=False)
        positions, query, key = _make_tensors()

        emb.forward_oot(positions, query, key, is_neox_style_override=True)

        # Last positional argument of the npu op is the effective is_neox_style.
        assert mock_npu_op.call_args[0][-1] is True

    @patch("torch.ops.vllm.npu_rotary_embedding")
    @patch("vllm_ascend.ascend_forward_context.get_forward_context")
    def test_neox_style_override_false(self, mock_ctx, mock_npu_op, make_embedding):
        """is_neox_style_override=False wins over self.is_neox_style=True."""
        self._install_forward_ctx(mock_ctx)
        mock_npu_op.return_value = MagicMock()

        emb = make_embedding(is_neox_style=True)
        positions, query, key = _make_tensors()

        emb.forward_oot(positions, query, key, is_neox_style_override=False)

        assert mock_npu_op.call_args[0][-1] is False

    @patch("torch.ops.vllm.npu_rotary_embedding")
    @patch("vllm_ascend.ascend_forward_context.get_forward_context")
    def test_neox_style_override_none_uses_self(self, mock_ctx, mock_npu_op, make_embedding):
        """When override is None, self.is_neox_style is used unchanged."""
        self._install_forward_ctx(mock_ctx)
        mock_npu_op.return_value = MagicMock()

        emb = make_embedding(is_neox_style=True)
        positions, query, key = _make_tensors()

        emb.forward_oot(positions, query, key, is_neox_style_override=None)

        assert mock_npu_op.call_args[0][-1] is True

    @patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad")
    @patch("torch.ops.vllm.npu_rotary_embedding")
    @patch("vllm_ascend.ascend_forward_context.get_forward_context")
    def test_gather_unpad_called_when_all_conditions_met(
        self, mock_ctx, mock_npu_op, mock_gather, make_embedding
    ):
        """
        maybe_all_gather_and_maybe_unpad is called iff:
        is_draft_model=True AND use_mtp=True AND flash_comm_v1_enabled=True
        """
        self._install_forward_ctx(mock_ctx, is_draft_model=True, flash_comm=True)
        gathered_positions = torch.arange(SEQ_LEN, dtype=torch.long)
        mock_gather.return_value = gathered_positions
        mock_npu_op.return_value = MagicMock()

        emb = make_embedding(use_mtp=True)
        positions, query, key = _make_tensors()

        emb.forward_oot(positions, query, key)

        mock_gather.assert_called_once()
        # The npu op must receive the gathered positions, not the originals.
        assert mock_npu_op.call_args[0][0] is gathered_positions

    @pytest.mark.parametrize("is_draft_model,flash_comm,use_mtp", [
        (False, True, True),   # not draft
        (True, False, True),   # flash_comm disabled
        (True, True, False),   # use_mtp disabled
    ])
    @patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad")
    @patch("torch.ops.vllm.npu_rotary_embedding")
    @patch("vllm_ascend.ascend_forward_context.get_forward_context")
    def test_gather_unpad_skipped_unless_all_conditions_met(
        self, mock_ctx, mock_npu_op, mock_gather,
        is_draft_model, flash_comm, use_mtp, make_embedding,
    ):
        """gather/unpad must NOT fire if any one of the three conditions is False."""
        self._install_forward_ctx(
            mock_ctx, is_draft_model=is_draft_model, flash_comm=flash_comm
        )
        mock_npu_op.return_value = MagicMock()

        emb = make_embedding(use_mtp=use_mtp)
        positions, query, key = _make_tensors()

        emb.forward_oot(positions, query, key)

        mock_gather.assert_not_called()
        # Original positions tensor is passed through untouched.
        assert mock_npu_op.call_args[0][0] is positions

    def test_parent_init_signature_has_not_changed(self):
        """
        Fail loudly if RotaryEmbedding.__init__ adds, removes, or renames
        parameters, so a developer knows to update AscendRotaryEmbedding
        accordingly.
        """
        check_parent_init_signature_has_not_changed(
            RotaryEmbedding.__init__,
            AscendRotaryEmbedding.__init__,
        )
273+
274+
275+
class TestAscendYaRNRotaryEmbeddingForwardOOT:
    """``AscendYaRNRotaryEmbedding.forward_oot`` must be a thin delegate."""

    @patch("vllm_ascend.ops.rotary_embedding.AscendRotaryEmbedding.forward_oot")
    def test_delegates_to_ascend_rotary_forward_oot(self, mock_delegate, make_yarn_embedding):
        """forward_oot must delegate to AscendRotaryEmbedding.forward_oot."""
        expected = MagicMock()
        mock_delegate.return_value = expected

        emb = make_yarn_embedding()
        positions, query, key = _make_tensors()

        result = emb.forward_oot(positions, query, key)

        mock_delegate.assert_called_once_with(emb, positions, query, key, None, None)
        assert result is expected

    @patch("vllm_ascend.ops.rotary_embedding.AscendRotaryEmbedding.forward_oot")
    def test_return_value_passed_through(self, mock_delegate, make_yarn_embedding):
        """Return value from the delegate is returned unchanged."""
        sentinel = (torch.randn(SEQ_LEN, HEAD_SIZE), torch.randn(SEQ_LEN, HEAD_SIZE))
        mock_delegate.return_value = sentinel

        emb = make_yarn_embedding()
        positions, query, key = _make_tensors()

        assert emb.forward_oot(positions, query, key) is sentinel

    @pytest.mark.parametrize("override", [True, False])
    @patch("vllm_ascend.ops.rotary_embedding.AscendRotaryEmbedding.forward_oot")
    def test_is_neox_style_override_forwarded(self, mock_delegate, override, make_yarn_embedding):
        """is_neox_style_override must be forwarded verbatim, both True and False."""
        mock_delegate.return_value = MagicMock()

        emb = make_yarn_embedding()
        positions, query, key = _make_tensors()

        emb.forward_oot(positions, query, key, is_neox_style_override=override)

        # Delegate call shape: (emb, positions, query, key, offsets, override).
        delegate_args = mock_delegate.call_args[0]
        assert delegate_args[5] is override

    @patch("vllm_ascend.ops.rotary_embedding.AscendRotaryEmbedding.forward_oot")
    def test_all_args_forwarded_together(self, mock_delegate, make_yarn_embedding):
        """Smoke test: all args passed simultaneously are all forwarded correctly."""
        mock_delegate.return_value = MagicMock()

        emb = make_yarn_embedding()
        positions, query, key = _make_tensors()
        offsets = torch.ones(SEQ_LEN, dtype=torch.long)

        emb.forward_oot(positions, query, key, offsets=offsets, is_neox_style_override=False)

        mock_delegate.assert_called_once_with(emb, positions, query, key, offsets, False)

    def test_parent_init_signature_has_not_changed(self):
        """
        Fail loudly if YaRNScalingRotaryEmbedding.__init__ adds, removes, or
        renames parameters, so a developer knows to update
        AscendYaRNRotaryEmbedding accordingly.
        """
        check_parent_init_signature_has_not_changed(
            YaRNScalingRotaryEmbedding.__init__,
            AscendYaRNRotaryEmbedding.__init__,
        )

vllm_ascend/ops/rotary_embedding.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,8 +222,9 @@ def __init__(
222222
base: float,
223223
is_neox_style: bool,
224224
dtype: torch.dtype,
225+
init_cache: bool = True,
225226
) -> None:
226-
super().__init__(head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype)
227+
super().__init__(head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype, init_cache)
227228
vllm_config = get_current_vllm_config()
228229
self.use_mtp = vllm_config.speculative_config and vllm_config.speculative_config.method == "mtp"
229230
_record_cos_sin_cache(self.cos_sin_cache)
@@ -264,13 +265,15 @@ def __init__(
264265
attn_factor: float = 1,
265266
beta_fast: int = 32,
266267
beta_slow: int = 1,
268+
apply_yarn_scaling: bool = True,
267269
truncate: bool = False,
268270
) -> None:
269271
extra_kwargs = {
270272
"extrapolation_factor": extrapolation_factor,
271273
"attn_factor": attn_factor,
272274
"beta_fast": beta_fast,
273275
"beta_slow": beta_slow,
276+
"apply_yarn_scaling": apply_yarn_scaling,
274277
# TODO: current not support actual truncate,adaptation for extra parameters to be compatible with vllm
275278
"truncate": truncate,
276279
}

0 commit comments

Comments
 (0)