@@ -1,105 +1,23 @@
-from typing import Optional, Tuple
+from typing import Optional
 
 import pytest
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
 
-from vllm._C import ops
+from vllm.model_executor.layers.rotary_embedding import get_rope
 
 IS_NEOX_STYLE = [True, False]
 DTYPES = [torch.half, torch.bfloat16, torch.float]
 HEAD_SIZES = [64, 80, 96, 112, 128, 256]
 ROTARY_DIMS = [None, 32]  # None means rotary dim == head size
-NUM_HEADS = [7, 12, 40, 52]  # Arbitrary values for testing
-NUM_TOKENS = [11, 83, 2048]  # Arbitrary values for testing
+NUM_HEADS = [7, 17]  # Arbitrary values for testing
+BATCH_SIZES = [1, 5]  # Arbitrary values for testing
+SEQ_LENS = [11, 8192]  # Arbitrary values for testing
 SEEDS = [0]
 
 
-def rotate_neox(x: torch.Tensor) -> torch.Tensor:
-    x1 = x[..., :x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2:]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
-    x1 = x[..., ::2]
-    x2 = x[..., 1::2]
-    x = torch.stack((-x2, x1), dim=-1)
-    return x.flatten(-2)
-
-
-def apply_rope(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    cos: torch.Tensor,
-    sin: torch.Tensor,
-    is_neox_style: bool,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    rotate_fn = rotate_neox if is_neox_style else rotate_gptj
-    q_embed = (q * cos) + (rotate_fn(q) * sin)
-    k_embed = (k * cos) + (rotate_fn(k) * sin)
-    return q_embed, k_embed
-
-
-class RefRotaryEmbedding(nn.Module):
-    """Reference implementation of rotary embedding."""
-
-    def __init__(
-        self,
-        dim: int,
-        is_neox_style: bool,
-        max_position_embeddings: int = 8192,
-        base: int = 10000,
-    ) -> None:
-        super().__init__()
-        self.rotary_dim = dim
-        self.is_neox_style = is_neox_style
-        self.max_position_embeddings = max_position_embeddings
-
-        # Create cos and sin embeddings.
-        inv_freq = 1.0 / (base**(torch.arange(0, dim, 2) / dim))
-        t = torch.arange(max_position_embeddings).float()
-        freqs = torch.einsum("i,j->ij", t, inv_freq.float())
-        if is_neox_style:
-            emb = torch.cat((freqs, freqs), dim=-1)
-        else:
-            emb = torch.repeat_interleave(freqs, 2, -1)
-        cos = emb.cos().to(dtype=inv_freq.dtype)
-        sin = emb.sin().to(dtype=inv_freq.dtype)
-        self.register_buffer("cos_cached", cos, persistent=False)
-        self.register_buffer("sin_cached", sin, persistent=False)
-
-    def forward(
-        self,
-        positions: torch.Tensor,  # [num_tokens]
-        query: torch.Tensor,  # [num_tokens, num_heads, head_size]
-        key: torch.Tensor,  # [num_tokens, num_heads, head_size]
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        query_rot = query[..., :self.rotary_dim]
-        query_pass = query[..., self.rotary_dim:]
-        key_rot = key[..., :self.rotary_dim]
-        key_pass = key[..., self.rotary_dim:]
-
-        query_rot = query_rot.transpose(0, 1)
-        key_rot = key_rot.transpose(0, 1)
-        cos = F.embedding(positions, self.cos_cached)
-        sin = F.embedding(positions, self.sin_cached)
-
-        query_rot, key_rot = apply_rope(query_rot, key_rot, cos, sin,
-                                        self.is_neox_style)
-        query_rot = query_rot.transpose(0, 1).contiguous()
-        key_rot = key_rot.transpose(0, 1).contiguous()
-
-        query = torch.cat((query_rot, query_pass), dim=-1)
-        key = torch.cat((key_rot, key_pass), dim=-1)
-
-        # Output query/key shape: [num_tokens, num_tokens, head_size]
-        return query, key
-
-
 @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
-@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("batch_size", BATCH_SIZES)
+@pytest.mark.parametrize("seq_len", SEQ_LENS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@@ -108,7 +26,8 @@ def forward(
 @torch.inference_mode()
 def test_rotary_embedding(
     is_neox_style: bool,
-    num_tokens: int,
+    batch_size: int,
+    seq_len: int,
     num_heads: int,
     head_size: int,
     rotary_dim: Optional[int],
@@ -122,53 +41,25 @@ def test_rotary_embedding(
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
 
-    positions = torch.randint(0, max_position, (num_tokens, ), device="cuda")
-    query = torch.randn(num_tokens,
+    if rotary_dim is None:
+        rotary_dim = head_size
+    rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style)
+    rope = rope.to(dtype).cuda()
+
+    positions = torch.randint(0,
+                              max_position, (batch_size, seq_len),
+                              device="cuda")
+    query = torch.randn(batch_size,
+                        seq_len,
                         num_heads * head_size,
                         dtype=dtype,
                         device="cuda")
-    key = torch.randn(num_tokens,
-                      num_heads * head_size,
-                      dtype=dtype,
-                      device="cuda")
-
-    # Create the rotary embedding.
-    inv_freq = 1.0 / (base**(
-        torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
-    t = torch.arange(max_position).float()
-    freqs = torch.einsum("i,j -> ij", t, inv_freq)
-    cos = freqs.cos()
-    sin = freqs.sin()
-    cos_sin_cache = torch.cat((cos, sin), dim=-1)
-    cos_sin_cache = cos_sin_cache.to(dtype=dtype, device="cuda")
-
-    # Run the kernel. The kernel is in-place, so we need to clone the inputs.
-    out_query = query.clone()
-    out_key = key.clone()
-    ops.rotary_embedding(
-        positions,
-        out_query,
-        out_key,
-        head_size,
-        cos_sin_cache,
-        is_neox_style,
-    )
-
-    # Run the reference implementation.
-    ref_rotary_embedding = RefRotaryEmbedding(
-        dim=rotary_dim,
-        is_neox_style=is_neox_style,
-        max_position_embeddings=max_position,
-        base=base,
-    ).to(dtype=dtype, device="cuda")
-    ref_query, ref_key = ref_rotary_embedding(
-        positions,
-        query.view(num_tokens, num_heads, head_size),
-        key.view(num_tokens, num_heads, head_size),
-    )
-    ref_query = ref_query.view(num_tokens, num_heads * head_size)
-    ref_key = ref_key.view(num_tokens, num_heads * head_size)
+    key = torch.randn_like(query)
 
+    # NOTE(woosuk): The reference implementation should be executed first
+    # because the custom kernel is in-place.
+    ref_query, ref_key = rope._forward(positions, query, key)
+    out_query, out_key = rope.forward(positions, query, key)
     # Compare the results.
     assert torch.allclose(out_query, ref_query, atol=1e-5, rtol=1e-5)
     assert torch.allclose(out_key, ref_key, atol=1e-5, rtol=1e-5)