Skip to content

Commit 0ea12cb

Browse files
NicholasTaonicholastao
authored andcommitted
fix TestSetCosSinCache
Signed-off-by: taoyuxiang <[email protected]>
1 parent 9faa4bf commit 0ea12cb

File tree

1 file changed

+41
-39
lines changed

1 file changed

+41
-39
lines changed

tests/ut/ops/test_rotary_embedding.py

Lines changed: 41 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import math
22
from unittest.mock import MagicMock, patch
33

4+
import pytest
45
import torch
56

67
from tests.ut.base import TestBase
@@ -316,48 +317,49 @@ def test_scale_greater_than_1(self):
316317
msg=f"Failed for scale={scale}, mscale={mscale}")
317318

318319

319-
class MockRotaryEmbedding(torch.nn.Module):
320+
class MockRotaryEmbedding:
320321

321322
def __init__(self, base, rotary_dim, max_position_embeddings):
322-
super().__init__()
323-
324323
self.base = base
325-
326324
self.rotary_dim = rotary_dim
327-
328325
self.max_position_embeddings = max_position_embeddings
329326

330-
def _set_cos_sin_cache(self, seq_len, device, dtype):
331-
return raw__set_cos_sin_cache(self, seq_len, device, dtype)
332-
333-
334-
class TestSetCosSinCache(TestBase):
335-
336-
def test_set_cos_sin_cache(self):
337-
# prepare an instance with reasonable values
338-
base = 10000.0
339-
rotary_dim = 4
340-
max_pos = 10
341-
model = MockRotaryEmbedding(base, rotary_dim, max_pos)
342-
# mock out register_buffer
343-
model.register_buffer = MagicMock()
344-
# call the private method via name mangling
345-
model._set_cos_sin_cache(seq_len=8, device="cpu", dtype=torch.float32)
346-
# expect three calls: inv_freq, cos, sin
347-
assert model.register_buffer.call_count == 3
348-
names = [call.args[0] for call in model.register_buffer.call_args_list]
349-
assert set(names) == {"inv_freq", "cos", "sin"}
350-
# verify inv_freq shape
351-
inv_freq = model.register_buffer.call_args_list[0].args[1]
352-
assert isinstance(inv_freq, torch.Tensor)
353-
assert inv_freq.shape == (rotary_dim // 2, )
354-
# verify cos buffer
355-
cos = model.register_buffer.call_args_list[1].args[1]
356-
assert isinstance(cos, torch.Tensor)
357-
assert cos.shape == (max_pos, rotary_dim)
358-
assert cos.dtype == torch.float32
359-
# verify sin buffer
360-
sin = model.register_buffer.call_args_list[2].args[1]
361-
assert isinstance(sin, torch.Tensor)
362-
assert sin.shape == (max_pos, rotary_dim)
363-
assert sin.dtype == torch.float32
327+
328+
@pytest.fixture
329+
def dummy_module():
330+
return MockRotaryEmbedding(base=10000.0,
331+
rotary_dim=64,
332+
max_position_embeddings=512)
333+
334+
335+
class TestSetCosSinCache:
336+
337+
def test_set_cos_sin_cache_generates_real_tensors(self, dummy_module):
338+
calls = []
339+
340+
def fake_register_buffer(name, tensor, persistent=True):
341+
setattr(dummy_module, name, tensor)
342+
calls.append(name)
343+
344+
dummy_module.register_buffer = fake_register_buffer
345+
seq_len = 128
346+
device = torch.device("cpu")
347+
dtype = torch.float32
348+
349+
raw__set_cos_sin_cache(dummy_module, seq_len, device, dtype)
350+
351+
assert calls == ['inv_freq', 'cos', 'sin']
352+
353+
assert isinstance(dummy_module.inv_freq, torch.Tensor)
354+
assert dummy_module.inv_freq.shape == (dummy_module.rotary_dim // 2, )
355+
assert dummy_module.inv_freq.device == device
356+
assert dummy_module.inv_freq.dtype == torch.float32
357+
358+
expected_shape = (dummy_module.max_position_embeddings,
359+
dummy_module.rotary_dim)
360+
for name in ('cos', 'sin'):
361+
buf = getattr(dummy_module, name)
362+
assert isinstance(buf, torch.Tensor)
363+
assert buf.shape == expected_shape
364+
assert buf.device == device
365+
assert buf.dtype == torch.float32

0 commit comments

Comments
 (0)