77
88from vllm import pos_encoding_ops
99
# Parametrization grids for the rotary-embedding kernel tests below.
IS_NEOX_STYLE = [True, False]
DTYPES = [torch.half, torch.bfloat16, torch.float]
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
ROTARY_DIMS = [None, 32]  # None means rotary dim == head size
NUM_HEADS = [7, 12, 40, 52]  # Arbitrary values for testing
NUM_TOKENS = [11, 83, 2048]  # Arbitrary values for testing
SEEDS = [0]
1617
1718
def rotate_neox(x: torch.Tensor) -> torch.Tensor:
    """Rotate the last dimension GPT-NeoX style.

    Splits the last dimension into two contiguous halves [x1, x2] and
    returns [-x2, x1], i.e. the 90-degree rotation used by the half-split
    rotary-embedding convention.
    """
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)
2223
2324
24- def apply_rotary_pos_emb (
def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
    """Rotate the last dimension GPT-J style.

    Treats the last dimension as interleaved pairs (even, odd) and maps
    each pair (x1, x2) -> (-x2, x1), restoring the interleaved layout
    via stack + flatten.
    """
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    x = torch.stack((-x2, x1), dim=-1)
    return x.flatten(-2)
30+
31+
def apply_rope(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    is_neox_style: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embedding to query and key tensors.

    Args:
        q: Query tensor; rotated along its last dimension.
        k: Key tensor; rotated along its last dimension.
        cos: Cosine table, broadcastable against ``q`` and ``k``.
        sin: Sine table, broadcastable against ``q`` and ``k``.
        is_neox_style: True selects the GPT-NeoX (half-split) rotation,
            False the GPT-J (interleaved) rotation.

    Returns:
        The rotated ``(q, k)`` pair.
    """
    rotate_fn = rotate_neox if is_neox_style else rotate_gptj
    q_embed = (q * cos) + (rotate_fn(q) * sin)
    k_embed = (k * cos) + (rotate_fn(k) * sin)
    return q_embed, k_embed
3343
3444
35- class RefRotaryEmbeddingNeox (nn .Module ):
36- """Reference implementation of the GPT-NeoX style rotary embedding."""
45+ class RefRotaryEmbedding (nn .Module ):
46+ """Reference implementation of rotary embedding."""
3747
3848 def __init__ (
3949 self ,
4050 dim : int ,
41- max_position_embeddings : int = 2048 ,
51+ is_neox_style : bool ,
52+ max_position_embeddings : int = 8192 ,
4253 base : int = 10000 ,
4354 ) -> None :
4455 super ().__init__ ()
4556 self .rotary_dim = dim
57+ self .is_neox_style = is_neox_style
4658 self .max_position_embeddings = max_position_embeddings
4759
4860 # Create cos and sin embeddings.
4961 inv_freq = 1.0 / (base ** (torch .arange (0 , dim , 2 ) / dim ))
5062 t = torch .arange (max_position_embeddings ).float ()
5163 freqs = torch .einsum ("i,j->ij" , t , inv_freq .float ())
52- emb = torch .cat ((freqs , freqs ), dim = - 1 )
64+ if is_neox_style :
65+ emb = torch .cat ((freqs , freqs ), dim = - 1 )
66+ else :
67+ emb = torch .repeat_interleave (freqs , 2 , - 1 )
5368 cos = emb .cos ().to (dtype = inv_freq .dtype )
5469 sin = emb .sin ().to (dtype = inv_freq .dtype )
5570 self .register_buffer ("cos_cached" , cos , persistent = False )
@@ -61,7 +76,6 @@ def forward(
6176 query : torch .Tensor , # [num_tokens, num_heads, head_size]
6277 key : torch .Tensor , # [num_tokens, num_heads, head_size]
6378 ) -> Tuple [torch .Tensor , torch .Tensor ]:
64-
6579 query_rot = query [..., :self .rotary_dim ]
6680 query_pass = query [..., self .rotary_dim :]
6781 key_rot = key [..., :self .rotary_dim ]
@@ -71,7 +85,9 @@ def forward(
7185 key_rot = key_rot .transpose (0 , 1 )
7286 cos = F .embedding (positions , self .cos_cached )
7387 sin = F .embedding (positions , self .sin_cached )
74- query_rot , key_rot = apply_rotary_pos_emb (query_rot , key_rot , cos , sin )
88+
89+ query_rot , key_rot = apply_rope (query_rot , key_rot , cos , sin ,
90+ self .is_neox_style )
7591 query_rot = query_rot .transpose (0 , 1 ).contiguous ()
7692 key_rot = key_rot .transpose (0 , 1 ).contiguous ()
7793
@@ -82,14 +98,16 @@ def forward(
8298 return query , key
8399
84100
101+ @pytest .mark .parametrize ("is_neox_style" , IS_NEOX_STYLE )
85102@pytest .mark .parametrize ("num_tokens" , NUM_TOKENS )
86103@pytest .mark .parametrize ("num_heads" , NUM_HEADS )
87104@pytest .mark .parametrize ("head_size" , HEAD_SIZES )
88105@pytest .mark .parametrize ("rotary_dim" , ROTARY_DIMS )
89106@pytest .mark .parametrize ("dtype" , DTYPES )
90107@pytest .mark .parametrize ("seed" , SEEDS )
91108@torch .inference_mode ()
92- def test_rotary_embedding_neox (
109+ def test_rotary_embedding (
110+ is_neox_style : bool ,
93111 num_tokens : int ,
94112 num_heads : int ,
95113 head_size : int ,
@@ -104,15 +122,15 @@ def test_rotary_embedding_neox(
104122 torch .random .manual_seed (seed )
105123 torch .cuda .manual_seed (seed )
106124
107- positions = torch .randint (0 , max_position , (num_tokens , ), device = ' cuda' )
125+ positions = torch .randint (0 , max_position , (num_tokens , ), device = " cuda" )
108126 query = torch .randn (num_tokens ,
109127 num_heads * head_size ,
110128 dtype = dtype ,
111- device = ' cuda' )
129+ device = " cuda" )
112130 key = torch .randn (num_tokens ,
113131 num_heads * head_size ,
114132 dtype = dtype ,
115- device = ' cuda' )
133+ device = " cuda" )
116134
117135 # Create the rotary embedding.
118136 inv_freq = 1.0 / (base ** (torch .arange (0 , rotary_dim , 2 ) / rotary_dim ))
@@ -126,20 +144,22 @@ def test_rotary_embedding_neox(
126144 # Run the kernel. The kernel is in-place, so we need to clone the inputs.
127145 out_query = query .clone ()
128146 out_key = key .clone ()
129- pos_encoding_ops .rotary_embedding_neox (
147+ pos_encoding_ops .rotary_embedding (
130148 positions ,
131149 out_query ,
132150 out_key ,
133151 head_size ,
134152 cos_sin_cache ,
153+ is_neox_style ,
135154 )
136155
137156 # Run the reference implementation.
138- ref_rotary_embedding = RefRotaryEmbeddingNeox (
157+ ref_rotary_embedding = RefRotaryEmbedding (
139158 dim = rotary_dim ,
159+ is_neox_style = is_neox_style ,
140160 max_position_embeddings = max_position ,
141161 base = base ,
142- ).to (dtype = dtype , device = ' cuda' )
162+ ).to (dtype = dtype , device = " cuda" )
143163 ref_query , ref_key = ref_rotary_embedding (
144164 positions ,
145165 query .view (num_tokens , num_heads , head_size ),
0 commit comments