Skip to content

Commit dd25081

Browse files
Author: Tingbo
Commit message: fix test_rope
1 parent: be4eccb; commit: dd25081

File tree

2 files changed: 11 additions (+11) and 10 deletions (-10).

model/rope.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import Tuple
33

44

5-
def get_meshgrid_nd(sizes, dim=2):
5+
def get_meshgrid_nd(sizes):
66
"""
77
Get n-D meshgrid with given sizes.
88
@@ -46,8 +46,8 @@ def apply_rope(
4646
xq_ = torch.view_as_complex(
4747
xq.reshape(*xq.shape[:-1], -1, 2)
4848
) # [B, S, H, D//2]
49-
S, H = xq_.shape[1:3]
50-
freqs_cis = freqs_cis.view(1, S, H, -1) # [S, D//2] --> [1, S, H, D//(2H)]
49+
S = xq_.shape[1]
50+
freqs_cis = freqs_cis.view(1, S, 1, -1) # [S, nD//2] --> [1, S, 1, nD//2]
5151
# view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real)
5252
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
5353
xk_ = torch.view_as_complex(
@@ -73,7 +73,7 @@ def get_nd_rope(
7373
theta: Scaling factor for frequency computation.
7474
7575
Returns:
76-
emb: Positional embedding [HW, D/2]
76+
emb: Positional embedding [S, D/2]
7777
"""
7878
grid = get_meshgrid_nd(sizes) # [n, T, H, W]
7979

@@ -84,10 +84,10 @@ def get_nd_rope(
8484
dim_list[i],
8585
grid[i].reshape(-1),
8686
theta,
87-
) # [THW, D/2]
87+
) # [THW, D_i/2]
8888
embs.append(emb)
8989

90-
emb = torch.cat(embs, dim=1) # (THW, nD/2)
90+
emb = torch.cat(embs, dim=1) # (THW, D/2)
9191
return emb
9292

9393

tests/test_rope.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import unittest
44

5-
from model.rope import apply_rope, get_1d_rope, get_nd_rope
5+
from model.rope import apply_rope, get_1d_rope, get_nd_rope, get_meshgrid_nd
66

77

88
class TestRope(unittest.TestCase):
@@ -14,11 +14,12 @@ def test_get_1d_rope(self):
1414
self.assertEqual(emb.shape, torch.Size([l, dim // 2]))
1515

1616
def test_get_nd_rope(self):
17-
T, H, W, dim = 8, 16, 16, 4
18-
dim_list = [dim, dim, dim]
17+
dim_list = [2, 2, 2]
18+
dim = sum(dim_list)
19+
T, H, W = 8, 16, 16
1920
sizes = [T, H, W]
2021
emb = get_nd_rope(dim_list, sizes)
21-
self.assertEqual(emb.shape, torch.Size([T * H * W, 3 * dim // 2]))
22+
self.assertEqual(emb.shape, torch.Size([T * H * W, dim // 2]))
2223

2324
def test_apply_rope(self):
2425
B, S, H, D = 1, 16, 2, 4

0 commit comments

Comments
 (0)