22from typing import Tuple
33
44
5- def get_meshgrid_nd (sizes , dim = 2 ):
5+ def get_meshgrid_nd (sizes ):
66 """
77 Get n-D meshgrid with given sizes.
88
@@ -46,8 +46,8 @@ def apply_rope(
4646 xq_ = torch .view_as_complex (
4747 xq .reshape (* xq .shape [:- 1 ], - 1 , 2 )
4848 ) # [B, S, H, D//2]
49- S , H = xq_ .shape [1 : 3 ]
50- freqs_cis = freqs_cis .view (1 , S , H , - 1 ) # [S, D //2] --> [1, S, H, D//(2H) ]
49+ S = xq_ .shape [1 ]
50+ freqs_cis = freqs_cis .view (1 , S , 1 , - 1 ) # [S, nD //2] --> [1, S, 1, nD//2 ]
5151 # view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real)
5252 xq_out = torch .view_as_real (xq_ * freqs_cis ).flatten (3 ).type_as (xq )
5353 xk_ = torch .view_as_complex (
@@ -73,7 +73,7 @@ def get_nd_rope(
7373 theta: Scaling factor for frequency computation.
7474
7575 Returns:
76- emb: Positional embedding [HW , D/2]
76+ emb: Positional embedding [S , D/2]
7777 """
7878 grid = get_meshgrid_nd (sizes ) # [n, T, H, W]
7979
@@ -84,10 +84,10 @@ def get_nd_rope(
8484 dim_list [i ],
8585 grid [i ].reshape (- 1 ),
8686 theta ,
87- ) # [THW, D /2]
87+ ) # [THW, D_i /2]
8888 embs .append (emb )
8989
90- emb = torch .cat (embs , dim = 1 ) # (THW, nD /2)
90+ emb = torch .cat (embs , dim = 1 ) # (THW, D /2)
9191 return emb
9292
9393
0 commit comments