|
1 | 1 | import math
|
2 | 2 | from unittest.mock import MagicMock, patch
|
3 | 3 |
|
| 4 | +import pytest |
4 | 5 | import torch
|
5 | 6 |
|
6 | 7 | from tests.ut.base import TestBase
|
@@ -316,48 +317,49 @@ def test_scale_greater_than_1(self):
|
316 | 317 | msg=f"Failed for scale={scale}, mscale={mscale}")
|
317 | 318 |
|
318 | 319 |
|
319 |
| -class MockRotaryEmbedding(torch.nn.Module): |
| 320 | +class MockRotaryEmbedding: |
320 | 321 |
|
321 | 322 | def __init__(self, base, rotary_dim, max_position_embeddings):
|
322 |
| - super().__init__() |
323 |
| - |
324 | 323 | self.base = base
|
325 |
| - |
326 | 324 | self.rotary_dim = rotary_dim
|
327 |
| - |
328 | 325 | self.max_position_embeddings = max_position_embeddings
|
329 | 326 |
|
330 |
| - def _set_cos_sin_cache(self, seq_len, device, dtype): |
331 |
| - return raw__set_cos_sin_cache(self, seq_len, device, dtype) |
332 |
| - |
333 |
| - |
334 |
| -class TestSetCosSinCache(TestBase): |
335 |
| - |
336 |
| - def test_set_cos_sin_cache(self): |
337 |
| - # prepare an instance with reasonable values |
338 |
| - base = 10000.0 |
339 |
| - rotary_dim = 4 |
340 |
| - max_pos = 10 |
341 |
| - model = MockRotaryEmbedding(base, rotary_dim, max_pos) |
342 |
| - # mock out register_buffer |
343 |
| - model.register_buffer = MagicMock() |
344 |
| - # call the private method via name mangling |
345 |
| - model._set_cos_sin_cache(seq_len=8, device="cpu", dtype=torch.float32) |
346 |
| - # expect three calls: inv_freq, cos, sin |
347 |
| - assert model.register_buffer.call_count == 3 |
348 |
| - names = [call.args[0] for call in model.register_buffer.call_args_list] |
349 |
| - assert set(names) == {"inv_freq", "cos", "sin"} |
350 |
| - # verify inv_freq shape |
351 |
| - inv_freq = model.register_buffer.call_args_list[0].args[1] |
352 |
| - assert isinstance(inv_freq, torch.Tensor) |
353 |
| - assert inv_freq.shape == (rotary_dim // 2, ) |
354 |
| - # verify cos buffer |
355 |
| - cos = model.register_buffer.call_args_list[1].args[1] |
356 |
| - assert isinstance(cos, torch.Tensor) |
357 |
| - assert cos.shape == (max_pos, rotary_dim) |
358 |
| - assert cos.dtype == torch.float32 |
359 |
| - # verify sin buffer |
360 |
| - sin = model.register_buffer.call_args_list[2].args[1] |
361 |
| - assert isinstance(sin, torch.Tensor) |
362 |
| - assert sin.shape == (max_pos, rotary_dim) |
363 |
| - assert sin.dtype == torch.float32 |
| 327 | + |
| 328 | +@pytest.fixture |
| 329 | +def dummy_module(): |
| 330 | + return MockRotaryEmbedding(base=10000.0, |
| 331 | + rotary_dim=64, |
| 332 | + max_position_embeddings=512) |
| 333 | + |
| 334 | + |
| 335 | +class TestSetCosSinCache: |
| 336 | + |
| 337 | + def test_set_cos_sin_cache_generates_real_tensors(self, dummy_module): |
| 338 | + calls = [] |
| 339 | + |
| 340 | + def fake_register_buffer(name, tensor, persistent=True): |
| 341 | + setattr(dummy_module, name, tensor) |
| 342 | + calls.append(name) |
| 343 | + |
| 344 | + dummy_module.register_buffer = fake_register_buffer |
| 345 | + seq_len = 128 |
| 346 | + device = torch.device("cpu") |
| 347 | + dtype = torch.float32 |
| 348 | + |
| 349 | + raw__set_cos_sin_cache(dummy_module, seq_len, device, dtype) |
| 350 | + |
| 351 | + assert calls == ['inv_freq', 'cos', 'sin'] |
| 352 | + |
| 353 | + assert isinstance(dummy_module.inv_freq, torch.Tensor) |
| 354 | + assert dummy_module.inv_freq.shape == (dummy_module.rotary_dim // 2, ) |
| 355 | + assert dummy_module.inv_freq.device == device |
| 356 | + assert dummy_module.inv_freq.dtype == torch.float32 |
| 357 | + |
| 358 | + expected_shape = (dummy_module.max_position_embeddings, |
| 359 | + dummy_module.rotary_dim) |
| 360 | + for name in ('cos', 'sin'): |
| 361 | + buf = getattr(dummy_module, name) |
| 362 | + assert isinstance(buf, torch.Tensor) |
| 363 | + assert buf.shape == expected_shape |
| 364 | + assert buf.device == device |
| 365 | + assert buf.dtype == torch.float32 |
0 commit comments