Skip to content

Commit 9747c93

Browse files
committed
fix codecov
Signed-off-by: taoyuxiang <[email protected]>
1 parent d4a3fbe commit 9747c93

File tree

2 files changed

+94
-2
lines changed

2 files changed

+94
-2
lines changed

tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,4 +108,46 @@ def test_eagle_correctness(
108108
model_name: str,
109109
use_eagle3: bool,
110110
):
111-
pass
111+
'''
112+
Compare the outputs of a original LLM and a speculative LLM
113+
should be the same when using eagle speculative decoding.
114+
'''
115+
if not use_eagle3:
116+
pytest.skip("Not current support for the test.")
117+
118+
ref_llm = LLM(model=model_name, max_model_len=2048, enforce_eager=True)
119+
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
120+
del ref_llm
121+
122+
spec_model_name = eagle3_model_name() if use_eagle3 else eagle_model_name()
123+
spec_llm = LLM(
124+
model=model_name,
125+
trust_remote_code=True,
126+
enable_chunked_prefill=True,
127+
max_num_seqs=1,
128+
max_num_batched_tokens=2048,
129+
gpu_memory_utilization=0.6,
130+
speculative_config={
131+
"method": "eagle3" if use_eagle3 else "eagle",
132+
"model": spec_model_name,
133+
"num_speculative_tokens": 2,
134+
"max_model_len": 128,
135+
},
136+
max_model_len=128,
137+
enforce_eager=True,
138+
)
139+
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
140+
matches = 0
141+
misses = 0
142+
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
143+
if ref_output.outputs[0].text == spec_output.outputs[0].text:
144+
matches += 1
145+
else:
146+
misses += 1
147+
print(f"ref_output: {ref_output.outputs[0].text}")
148+
print(f"spec_output: {spec_output.outputs[0].text}")
149+
150+
# Heuristic: expect at least 66% of the prompts to match exactly
151+
# Upon failure, inspect the outputs to check for inaccuracy.
152+
assert matches > int(0.66 * len(ref_outputs))
153+
del spec_llm

tests/ut/ops/test_rotary_embedding.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
import torch
55

66
from tests.ut.base import TestBase
7-
from vllm_ascend.ops.rotary_embedding import (custom_rotary_embedding_enabled,
7+
from vllm_ascend.ops.rotary_embedding import (__set_cos_sin_cache,
8+
custom_rotary_embedding_enabled,
89
native_rope_deepseek_forward,
910
rope_forward_oot, rotate_half,
1011
yarn_find_correction_dim,
@@ -312,3 +313,52 @@ def test_scale_greater_than_1(self):
312313
expected,
313314
places=6,
314315
msg=f"Failed for scale={scale}, mscale={mscale}")
316+
317+
318+
class MockRotaryEmbedding(torch.nn.Module):
319+
320+
def __init__(self, base, rotary_dim, max_position_embeddings):
321+
super().__init__()
322+
323+
self.base = base
324+
325+
self.rotary_dim = rotary_dim
326+
327+
self.max_position_embeddings = max_position_embeddings
328+
329+
def _set_cos_sin_cache(self, seq_len, device, dtype):
330+
return __set_cos_sin_cache(self, seq_len, device, dtype)
331+
332+
333+
class TestSetCosSinCache(TestBase):
334+
335+
def test_set_cos_sin_cache_registers_buffers_and_sets_embed(self):
336+
# prepare an instance with reasonable values
337+
base = 10000.0
338+
rotary_dim = 4
339+
max_pos = 10
340+
model = MockRotaryEmbedding(base, rotary_dim, max_pos)
341+
# mock out register_buffer
342+
model.register_buffer = MagicMock()
343+
# call the private method via name mangling
344+
model._RotaryEmbedding._set_cos_sin_cache(seq_len=8,
345+
device="cpu",
346+
dtype=torch.float32)
347+
# expect three calls: inv_freq, cos, sin
348+
assert model.register_buffer.call_count == 3
349+
names = [call.args[0] for call in model.register_buffer.call_args_list]
350+
assert set(names) == {"inv_freq", "cos", "sin"}
351+
# verify inv_freq shape
352+
inv_freq = model.register_buffer.call_args_list[0].args[1]
353+
assert isinstance(inv_freq, torch.Tensor)
354+
assert inv_freq.shape == (rotary_dim // 2, )
355+
# verify cos buffer
356+
cos = model.register_buffer.call_args_list[1].args[1]
357+
assert isinstance(cos, torch.Tensor)
358+
assert cos.shape == (max_pos, rotary_dim)
359+
assert cos.dtype == torch.float32
360+
# verify sin buffer
361+
sin = model.register_buffer.call_args_list[2].args[1]
362+
assert isinstance(sin, torch.Tensor)
363+
assert sin.shape == (max_pos, rotary_dim)
364+
assert sin.dtype == torch.float32

0 commit comments

Comments
 (0)