 
 from vllm import pos_encoding_ops
 
+IS_NEOX_STYLE = [True, False]
 DTYPES = [torch.half, torch.bfloat16, torch.float]
 HEAD_SIZES = [64, 80, 96, 112, 128, 256]
 ROTARY_DIMS = [None, 32]  # None means rotary dim == head size
 NUM_HEADS = [7, 12, 40, 52]  # Arbitrary values for testing
-NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
+NUM_TOKENS = [11, 83, 2048]  # Arbitrary values for testing
 SEEDS = [0]
 
 
-def rotate_half(x: torch.Tensor) -> torch.Tensor:
+def rotate_neox(x: torch.Tensor) -> torch.Tensor:
     x1 = x[..., :x.shape[-1] // 2]
     x2 = x[..., x.shape[-1] // 2:]
     return torch.cat((-x2, x1), dim=-1)
 
 
-def apply_rotary_pos_emb(
+def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
+    x1 = x[..., ::2]
+    x2 = x[..., 1::2]
+    x = torch.stack((-x2, x1), dim=-1)
+    return x.flatten(-2)
+
+
+def apply_rope(
     q: torch.Tensor,
     k: torch.Tensor,
     cos: torch.Tensor,
     sin: torch.Tensor,
+    is_neox_style: bool,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
+    rotate_fn = rotate_neox if is_neox_style else rotate_gptj
+    q_embed = (q * cos) + (rotate_fn(q) * sin)
+    k_embed = (k * cos) + (rotate_fn(k) * sin)
     return q_embed, k_embed
 
 
-class RefRotaryEmbeddingNeox(nn.Module):
-    """Reference implementation of the GPT-NeoX style rotary embedding."""
+class RefRotaryEmbedding(nn.Module):
+    """Reference implementation of rotary embedding."""
 
     def __init__(
         self,
         dim: int,
-        max_position_embeddings: int = 2048,
+        is_neox_style: bool,
+        max_position_embeddings: int = 8192,
         base: int = 10000,
     ) -> None:
         super().__init__()
         self.rotary_dim = dim
+        self.is_neox_style = is_neox_style
         self.max_position_embeddings = max_position_embeddings
 
         # Create cos and sin embeddings.
         inv_freq = 1.0 / (base**(torch.arange(0, dim, 2) / dim))
         t = torch.arange(max_position_embeddings).float()
         freqs = torch.einsum("i,j->ij", t, inv_freq.float())
-        emb = torch.cat((freqs, freqs), dim=-1)
+        if is_neox_style:
+            emb = torch.cat((freqs, freqs), dim=-1)
+        else:
+            emb = torch.repeat_interleave(freqs, 2, -1)
         cos = emb.cos().to(dtype=inv_freq.dtype)
         sin = emb.sin().to(dtype=inv_freq.dtype)
         self.register_buffer("cos_cached", cos, persistent=False)
@@ -61,7 +76,6 @@ def forward(
         query: torch.Tensor,  # [num_tokens, num_heads, head_size]
         key: torch.Tensor,  # [num_tokens, num_heads, head_size]
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-
         query_rot = query[..., :self.rotary_dim]
         query_pass = query[..., self.rotary_dim:]
         key_rot = key[..., :self.rotary_dim]
@@ -71,7 +85,9 @@ def forward(
         key_rot = key_rot.transpose(0, 1)
         cos = F.embedding(positions, self.cos_cached)
         sin = F.embedding(positions, self.sin_cached)
-        query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
+
+        query_rot, key_rot = apply_rope(query_rot, key_rot, cos, sin,
+                                        self.is_neox_style)
         query_rot = query_rot.transpose(0, 1).contiguous()
         key_rot = key_rot.transpose(0, 1).contiguous()
 
@@ -82,14 +98,16 @@ def forward(
         return query, key
 
 
+@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
-def test_rotary_embedding_neox(
+def test_rotary_embedding(
+    is_neox_style: bool,
     num_tokens: int,
     num_heads: int,
     head_size: int,
@@ -104,15 +122,15 @@ def test_rotary_embedding_neox(
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
 
-    positions = torch.randint(0, max_position, (num_tokens, ), device='cuda')
+    positions = torch.randint(0, max_position, (num_tokens, ), device="cuda")
     query = torch.randn(num_tokens,
                         num_heads * head_size,
                         dtype=dtype,
-                        device='cuda')
+                        device="cuda")
     key = torch.randn(num_tokens,
                       num_heads * head_size,
                       dtype=dtype,
-                      device='cuda')
+                      device="cuda")
 
     # Create the rotary embedding.
     inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
@@ -126,20 +144,22 @@ def test_rotary_embedding_neox(
     # Run the kernel. The kernel is in-place, so we need to clone the inputs.
     out_query = query.clone()
     out_key = key.clone()
-    pos_encoding_ops.rotary_embedding_neox(
+    pos_encoding_ops.rotary_embedding(
         positions,
         out_query,
         out_key,
         head_size,
         cos_sin_cache,
+        is_neox_style,
     )
 
     # Run the reference implementation.
-    ref_rotary_embedding = RefRotaryEmbeddingNeox(
+    ref_rotary_embedding = RefRotaryEmbedding(
        dim=rotary_dim,
+        is_neox_style=is_neox_style,
         max_position_embeddings=max_position,
         base=base,
-    ).to(dtype=dtype, device='cuda')
+    ).to(dtype=dtype, device="cuda")
     ref_query, ref_key = ref_rotary_embedding(
         positions,
         query.view(num_tokens, num_heads, head_size),
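For context (not part of the diff above): a minimal standalone sketch of the two rotation styles the new is_neox_style flag selects between, mirroring the rotate_neox and rotate_gptj helpers added in this change. NeoX-style rotation swaps the two contiguous halves of the last dimension, while GPT-J-style rotation rotates within interleaved even/odd pairs.

import torch


def rotate_neox(x: torch.Tensor) -> torch.Tensor:
    # Swap the two contiguous halves of the last dim: [x1, x2] -> [-x2, x1].
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
    # Rotate within even/odd pairs: [x0, x1, x2, x3] -> [-x1, x0, -x3, x2].
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)


x = torch.arange(4.0)   # tensor([0., 1., 2., 3.])
print(rotate_neox(x))   # tensor([-2., -3., 0., 1.])
print(rotate_gptj(x))   # tensor([-1., 0., -3., 2.])

Each rotation must be paired with the matching cos/sin layout, which is why the reference class in the diff builds its cache with torch.cat for the NeoX style and torch.repeat_interleave for the GPT-J style.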