|
4 | 4 | import torch
|
5 | 5 |
|
6 | 6 | from tests.ut.base import TestBase
|
7 |
| -from vllm_ascend.ops.rotary_embedding import (custom_rotary_embedding_enabled, |
| 7 | +from vllm_ascend.ops.rotary_embedding import (__set_cos_sin_cache, |
| 8 | + custom_rotary_embedding_enabled, |
8 | 9 | native_rope_deepseek_forward,
|
9 | 10 | rope_forward_oot, rotate_half,
|
10 | 11 | yarn_find_correction_dim,
|
@@ -312,3 +313,52 @@ def test_scale_greater_than_1(self):
|
312 | 313 | expected,
|
313 | 314 | places=6,
|
314 | 315 | msg=f"Failed for scale={scale}, mscale={mscale}")
|
| 316 | + |
| 317 | + |
class MockRotaryEmbedding(torch.nn.Module):
    """Minimal stand-in module used to exercise ``__set_cos_sin_cache``.

    It only stores the three constructor arguments as attributes and
    forwards ``_set_cos_sin_cache`` to the module-level helper imported
    from ``vllm_ascend.ops.rotary_embedding``.
    """

    def __init__(self, base, rotary_dim, max_position_embeddings):
        super().__init__()
        self.base = base
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Call the imported ``__set_cos_sin_cache`` helper on this instance.

        Returns whatever the helper returns.
        """
        # A bare ``__set_cos_sin_cache`` written inside a class body is
        # name-mangled to ``_MockRotaryEmbedding__set_cos_sin_cache`` and
        # fails with NameError. String keys are never mangled, so fetch the
        # helper from the module globals by name instead.
        helper = globals()["__set_cos_sin_cache"]
        return helper(self, seq_len, device, dtype)
| 331 | + |
| 332 | + |
class TestSetCosSinCache(TestBase):
    """Checks the buffers that ``_set_cos_sin_cache`` registers on a mock module."""

    def test_set_cos_sin_cache_registers_buffers_and_sets_embed(self):
        """Expect ``inv_freq``, ``cos`` and ``sin`` to be registered once each.

        NOTE(review): the test name mentions "sets_embed", but no embed
        attribute is asserted here -- confirm whether such a check is wanted.
        """
        # prepare an instance with reasonable values
        base = 10000.0
        rotary_dim = 4
        max_pos = 10
        model = MockRotaryEmbedding(base, rotary_dim, max_pos)
        # record register_buffer calls instead of really registering buffers
        # NOTE(review): if the helper reads a buffer back from ``self`` after
        # registering it, this mock may need ``wraps=`` -- verify against the
        # helper's implementation.
        model.register_buffer = MagicMock()
        # MockRotaryEmbedding exposes the helper as ``_set_cos_sin_cache``;
        # it has no ``_RotaryEmbedding`` attribute, so call the method directly.
        model._set_cos_sin_cache(seq_len=8,
                                 device="cpu",
                                 dtype=torch.float32)
        # expect three calls: inv_freq, cos, sin
        assert model.register_buffer.call_count == 3
        # key the recorded tensors by buffer name so the checks below do not
        # depend on the order in which the buffers were registered
        registered = {
            call.args[0]: call.args[1]
            for call in model.register_buffer.call_args_list
        }
        assert set(registered) == {"inv_freq", "cos", "sin"}
        # verify inv_freq shape
        inv_freq = registered["inv_freq"]
        assert isinstance(inv_freq, torch.Tensor)
        assert inv_freq.shape == (rotary_dim // 2, )
        # verify cos and sin buffers share the same shape and dtype
        for name in ("cos", "sin"):
            buffer = registered[name]
            assert isinstance(buffer, torch.Tensor)
            assert buffer.shape == (max_pos, rotary_dim)
            assert buffer.dtype == torch.float32
0 commit comments