Skip to content

Commit 7a88f24

Browse files
committed
fix test_set_cos_sin_cache
Signed-off-by: taoyuxiang <[email protected]>
1 parent 8a6d5ee commit 7a88f24

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

tests/ut/ops/test_rotary_embedding.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
332332

333333
class TestSetCosSinCache(TestBase):
334334

335-
def test_set_cos_sin_cache_registers_buffers_and_sets_embed(self):
335+
def test_set_cos_sin_cache(self):
336336
# prepare an instance with reasonable values
337337
base = 10000.0
338338
rotary_dim = 4
@@ -341,9 +341,9 @@ def test_set_cos_sin_cache_registers_buffers_and_sets_embed(self):
341341
# mock out register_buffer
342342
model.register_buffer = MagicMock()
343343
# call the private method via name mangling
344-
model._RotaryEmbedding._set_cos_sin_cache(seq_len=8,
345-
device="cpu",
346-
dtype=torch.float32)
344+
model._set_cos_sin_cache(seq_len=8,
345+
device="cpu",
346+
dtype=torch.float32)
347347
# expect three calls: inv_freq, cos, sin
348348
assert model.register_buffer.call_count == 3
349349
names = [call.args[0] for call in model.register_buffer.call_args_list]

0 commit comments

Comments
 (0)