-from typing import Optional, Tuple
+from typing import Optional
 
 import pytest
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
 
-from vllm._C import ops
+from vllm.model_executor.layers.rotary_embedding import get_rope
 
 IS_NEOX_STYLE = [True, False]
 DTYPES = [torch.half, torch.bfloat16, torch.float]
 HEAD_SIZES = [64, 80, 96, 112, 128, 256]
 ROTARY_DIMS = [None, 32]  # None means rotary dim == head size
-NUM_HEADS = [7, 12, 40, 52]  # Arbitrary values for testing
-NUM_TOKENS = [11, 83, 2048]  # Arbitrary values for testing
+NUM_HEADS = [7, 17]  # Arbitrary values for testing
+BATCH_SIZES = [1, 5]  # Arbitrary values for testing
+SEQ_LENS = [11, 8192]  # Arbitrary values for testing
 SEEDS = [0]
 
 
-def rotate_neox(x: torch.Tensor) -> torch.Tensor:
-    x1 = x[..., :x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2:]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
-    x1 = x[..., ::2]
-    x2 = x[..., 1::2]
-    x = torch.stack((-x2, x1), dim=-1)
-    return x.flatten(-2)
-
-
-def apply_rope(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    cos: torch.Tensor,
-    sin: torch.Tensor,
-    is_neox_style: bool,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    rotate_fn = rotate_neox if is_neox_style else rotate_gptj
-    q_embed = (q * cos) + (rotate_fn(q) * sin)
-    k_embed = (k * cos) + (rotate_fn(k) * sin)
-    return q_embed, k_embed
-
-
-class RefRotaryEmbedding(nn.Module):
-    """Reference implementation of rotary embedding."""
-
-    def __init__(
-        self,
-        dim: int,
-        is_neox_style: bool,
-        max_position_embeddings: int = 8192,
-        base: int = 10000,
-    ) -> None:
-        super().__init__()
-        self.rotary_dim = dim
-        self.is_neox_style = is_neox_style
-        self.max_position_embeddings = max_position_embeddings
-
-        # Create cos and sin embeddings.
-        inv_freq = 1.0 / (base**(torch.arange(0, dim, 2) / dim))
-        t = torch.arange(max_position_embeddings).float()
-        freqs = torch.einsum("i,j->ij", t, inv_freq.float())
-        if is_neox_style:
-            emb = torch.cat((freqs, freqs), dim=-1)
-        else:
-            emb = torch.repeat_interleave(freqs, 2, -1)
-        cos = emb.cos().to(dtype=inv_freq.dtype)
-        sin = emb.sin().to(dtype=inv_freq.dtype)
-        self.register_buffer("cos_cached", cos, persistent=False)
-        self.register_buffer("sin_cached", sin, persistent=False)
-
-    def forward(
-        self,
-        positions: torch.Tensor,  # [num_tokens]
-        query: torch.Tensor,  # [num_tokens, num_heads, head_size]
-        key: torch.Tensor,  # [num_tokens, num_heads, head_size]
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        query_rot = query[..., :self.rotary_dim]
-        query_pass = query[..., self.rotary_dim:]
-        key_rot = key[..., :self.rotary_dim]
-        key_pass = key[..., self.rotary_dim:]
-
-        query_rot = query_rot.transpose(0, 1)
-        key_rot = key_rot.transpose(0, 1)
-        cos = F.embedding(positions, self.cos_cached)
-        sin = F.embedding(positions, self.sin_cached)
-
-        query_rot, key_rot = apply_rope(query_rot, key_rot, cos, sin,
-                                        self.is_neox_style)
-        query_rot = query_rot.transpose(0, 1).contiguous()
-        key_rot = key_rot.transpose(0, 1).contiguous()
-
-        query = torch.cat((query_rot, query_pass), dim=-1)
-        key = torch.cat((key_rot, key_pass), dim=-1)
-
-        # Output query/key shape: [num_tokens, num_heads, head_size]
-        return query, key
-
-
 @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
-@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("batch_size", BATCH_SIZES)
+@pytest.mark.parametrize("seq_len", SEQ_LENS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@@ -108,7 +26,8 @@ def forward(
 @torch.inference_mode()
 def test_rotary_embedding(
     is_neox_style: bool,
-    num_tokens: int,
+    batch_size: int,
+    seq_len: int,
     num_heads: int,
     head_size: int,
     rotary_dim: Optional[int],
@@ -122,53 +41,25 @@ def test_rotary_embedding(
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
 
-    positions = torch.randint(0, max_position, (num_tokens, ), device="cuda")
-    query = torch.randn(num_tokens,
+    if rotary_dim is None:
+        rotary_dim = head_size
+    rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style)
+    rope = rope.to(dtype).cuda()
+
+    positions = torch.randint(0,
+                              max_position, (batch_size, seq_len),
+                              device="cuda")
+    query = torch.randn(batch_size,
+                        seq_len,
                         num_heads * head_size,
                         dtype=dtype,
                         device="cuda")
-    key = torch.randn(num_tokens,
-                      num_heads * head_size,
-                      dtype=dtype,
-                      device="cuda")
-
-    # Create the rotary embedding.
-    inv_freq = 1.0 / (base**(
-        torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
-    t = torch.arange(max_position).float()
-    freqs = torch.einsum("i,j -> ij", t, inv_freq)
-    cos = freqs.cos()
-    sin = freqs.sin()
-    cos_sin_cache = torch.cat((cos, sin), dim=-1)
-    cos_sin_cache = cos_sin_cache.to(dtype=dtype, device="cuda")
-
-    # Run the kernel. The kernel is in-place, so we need to clone the inputs.
-    out_query = query.clone()
-    out_key = key.clone()
-    ops.rotary_embedding(
-        positions,
-        out_query,
-        out_key,
-        head_size,
-        cos_sin_cache,
-        is_neox_style,
-    )
-
-    # Run the reference implementation.
-    ref_rotary_embedding = RefRotaryEmbedding(
-        dim=rotary_dim,
-        is_neox_style=is_neox_style,
-        max_position_embeddings=max_position,
-        base=base,
-    ).to(dtype=dtype, device="cuda")
-    ref_query, ref_key = ref_rotary_embedding(
-        positions,
-        query.view(num_tokens, num_heads, head_size),
-        key.view(num_tokens, num_heads, head_size),
-    )
-    ref_query = ref_query.view(num_tokens, num_heads * head_size)
-    ref_key = ref_key.view(num_tokens, num_heads * head_size)
+    key = torch.randn_like(query)
 
+    # NOTE(woosuk): The reference implementation should be executed first
+    # because the custom kernel is in-place.
+    ref_query, ref_key = rope._forward(positions, query, key)
+    out_query, out_key = rope.forward(positions, query, key)
     # Compare the results.
     assert torch.allclose(out_query, ref_query, atol=1e-5, rtol=1e-5)
     assert torch.allclose(out_key, ref_key, atol=1e-5, rtol=1e-5)
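
For reference, a standalone sketch (not part of the diff or the test file) of the two rotation styles the `is_neox_style` parameter toggles, mirroring the deleted rotate_neox / rotate_gptj reference helpers; the position value, rotary_dim, and base below are illustrative only.

import torch


def rotate_neox(x: torch.Tensor) -> torch.Tensor:
    # NeoX style: rotate the two contiguous halves of the rotary dims.
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
    # GPT-J style: rotate interleaved (even, odd) pairs of rotary dims.
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)


rotary_dim = 8
base = 10000
pos = 3  # a single position, for illustration

inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2).float() / rotary_dim))
freqs = pos * inv_freq  # [rotary_dim // 2]
q = torch.randn(rotary_dim)

# NeoX style repeats the frequencies across the two halves.
cos, sin = torch.cat((freqs, freqs)).cos(), torch.cat((freqs, freqs)).sin()
q_neox = q * cos + rotate_neox(q) * sin

# GPT-J style repeats each frequency for its (even, odd) pair.
cos, sin = freqs.repeat_interleave(2).cos(), freqs.repeat_interleave(2).sin()
q_gptj = q * cos + rotate_gptj(q) * sin

The updated test relies on get_rope to build the equivalent embedding and checks its custom CUDA path (forward) against its pure-PyTorch path (_forward) for both styles.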