
Commit 320a622

[BugFix] Implement RoPE for GPT-J (#941)
1 parent c9927c1 commit 320a622

File tree

5 files changed: +122 -72 lines changed

csrc/pos_encoding.cpp

Lines changed: 6 additions & 5 deletions
@@ -1,15 +1,16 @@
 #include <torch/extension.h>
 
-void rotary_embedding_neox(
+void rotary_embedding(
   torch::Tensor& positions,
   torch::Tensor& query,
   torch::Tensor& key,
   int head_size,
-  torch::Tensor& cos_sin_cache);
+  torch::Tensor& cos_sin_cache,
+  bool is_neox);
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def(
-    "rotary_embedding_neox",
-    &rotary_embedding_neox,
-    "Apply GPT-NeoX style rotary embedding to query and key");
+    "rotary_embedding",
+    &rotary_embedding,
+    "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
 }
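For orientation, the renamed op keeps the same positional argument order and simply gains a trailing boolean. A minimal call sketch in Python (assuming a built vLLM extension, a CUDA device, and placeholder shapes chosen only for illustration), mirroring how the updated test invokes it:

import torch
from vllm import pos_encoding_ops

num_tokens, num_heads, head_size = 4, 8, 64
rot_dim, max_position = 64, 2048

positions = torch.randint(0, max_position, (num_tokens,), device="cuda")
query = torch.randn(num_tokens, num_heads * head_size, dtype=torch.half, device="cuda")
key = torch.randn(num_tokens, num_heads * head_size, dtype=torch.half, device="cuda")
# Random cache values are fine for exercising shapes only.
cos_sin_cache = torch.randn(max_position, rot_dim, dtype=torch.half, device="cuda")

# Rotates query and key in place; the last argument selects the style:
# True -> GPT-NeoX (half-split pairs), False -> GPT-J (interleaved pairs).
pos_encoding_ops.rotary_embedding(positions, query, key, head_size,
                                  cos_sin_cache, False)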

csrc/pos_encoding_kernels.cu

Lines changed: 68 additions & 45 deletions
@@ -5,8 +5,38 @@
 
 namespace vllm {
 
-template<typename scalar_t>
-__global__ void rotary_embedding_neox_kernel(
+template<typename scalar_t, bool IS_NEOX>
+inline __device__ void apply_rotary_embedding(
+  scalar_t* __restrict__ arr,
+  const scalar_t* __restrict__ cos_ptr,
+  const scalar_t* __restrict__ sin_ptr,
+  int rot_offset,
+  int embed_dim)
+{
+  int x_index, y_index;
+  scalar_t cos, sin;
+  if (IS_NEOX) {
+    // GPT-NeoX style rotary embedding.
+    x_index = rot_offset;
+    y_index = embed_dim + rot_offset;
+    cos = __ldg(cos_ptr + x_index);
+    sin = __ldg(sin_ptr + x_index);
+  } else {
+    // GPT-J style rotary embedding.
+    x_index = 2 * rot_offset;
+    y_index = 2 * rot_offset + 1;
+    cos = __ldg(cos_ptr + x_index / 2);
+    sin = __ldg(sin_ptr + x_index / 2);
+  }
+
+  const scalar_t x = arr[x_index];
+  const scalar_t y = arr[y_index];
+  arr[x_index] = x * cos - y * sin;
+  arr[y_index] = y * cos + x * sin;
+}
+
+template<typename scalar_t, bool IS_NEOX>
+__global__ void rotary_embedding_kernel(
   const int64_t* __restrict__ positions,  // [num_tokens]
   scalar_t* __restrict__ query,           // [num_tokens, num_heads, head_size]
   scalar_t* __restrict__ key,             // [num_tokens, num_kv_heads, head_size]
@@ -23,58 +53,37 @@ __global__ void rotary_embedding_neox_kernel(
   const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
 
   const int embed_dim = rot_dim / 2;
+  const scalar_t* cos_ptr = cache_ptr;
+  const scalar_t* sin_ptr = cache_ptr + embed_dim;
+
   const int nq = num_heads * embed_dim;
   for (int i = threadIdx.x; i < nq; i += blockDim.x) {
     const int head_idx = i / embed_dim;
     const int token_head = token_idx * query_stride + head_idx * head_size;
-
     const int rot_offset = i % embed_dim;
-    const int x_index = rot_offset;
-    const int y_index = embed_dim + rot_offset;
-
-    const int out_x = token_idx * query_stride + head_idx * head_size + x_index;
-    const int out_y = token_idx * query_stride + head_idx * head_size + y_index;
-
-    const scalar_t cos = __ldg(cache_ptr + x_index);
-    const scalar_t sin = __ldg(cache_ptr + y_index);
-
-    const scalar_t q_x = query[token_head + x_index];
-    const scalar_t q_y = query[token_head + y_index];
-    query[out_x] = q_x * cos - q_y * sin;
-    query[out_y] = q_y * cos + q_x * sin;
+    apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
+                                              sin_ptr, rot_offset, embed_dim);
   }
 
   const int nk = num_kv_heads * embed_dim;
   for (int i = threadIdx.x; i < nk; i += blockDim.x) {
     const int head_idx = i / embed_dim;
     const int token_head = token_idx * key_stride + head_idx * head_size;
-
     const int rot_offset = i % embed_dim;
-    const int x_index = rot_offset;
-    const int y_index = embed_dim + rot_offset;
-
-    const int out_x = token_idx * key_stride + head_idx * head_size + x_index;
-    const int out_y = token_idx * key_stride + head_idx * head_size + y_index;
-
-    const scalar_t cos = __ldg(cache_ptr + x_index);
-    const scalar_t sin = __ldg(cache_ptr + y_index);
-
-    const scalar_t k_x = key[token_head + x_index];
-    const scalar_t k_y = key[token_head + y_index];
-    key[out_x] = k_x * cos - k_y * sin;
-    key[out_y] = k_y * cos + k_x * sin;
+    apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
+                                              sin_ptr, rot_offset, embed_dim);
   }
 }
 
 } // namespace vllm
 
-void rotary_embedding_neox(
+void rotary_embedding(
   torch::Tensor& positions,       // [num_tokens]
   torch::Tensor& query,           // [num_tokens, num_heads * head_size]
   torch::Tensor& key,             // [num_tokens, num_kv_heads * head_size]
   int head_size,
-  torch::Tensor& cos_sin_cache)   // [max_position, rot_dim]
-{
+  torch::Tensor& cos_sin_cache,   // [max_position, rot_dim]
+  bool is_neox) {
   int num_tokens = query.size(0);
   int rot_dim = cos_sin_cache.size(1);
   int num_heads = query.size(1) / head_size;
@@ -87,18 +96,32 @@ void rotary_embedding_neox(
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
     query.scalar_type(),
-    "rotary_embedding_neox",
+    "rotary_embedding",
     [&] {
-      vllm::rotary_embedding_neox_kernel<scalar_t><<<grid, block, 0, stream>>>(
-        positions.data_ptr<int64_t>(),
-        query.data_ptr<scalar_t>(),
-        key.data_ptr<scalar_t>(),
-        cos_sin_cache.data_ptr<scalar_t>(),
-        rot_dim,
-        query_stride,
-        key_stride,
-        num_heads,
-        num_kv_heads,
-        head_size);
+      if (is_neox) {
+        vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
+          positions.data_ptr<int64_t>(),
+          query.data_ptr<scalar_t>(),
+          key.data_ptr<scalar_t>(),
+          cos_sin_cache.data_ptr<scalar_t>(),
+          rot_dim,
+          query_stride,
+          key_stride,
+          num_heads,
+          num_kv_heads,
+          head_size);
+      } else {
+        vllm::rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
+          positions.data_ptr<int64_t>(),
+          query.data_ptr<scalar_t>(),
+          key.data_ptr<scalar_t>(),
+          cos_sin_cache.data_ptr<scalar_t>(),
+          rot_dim,
+          query_stride,
+          key_stride,
+          num_heads,
+          num_kv_heads,
+          head_size);
+      }
     });
 }
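To make the two index patterns concrete, here is a line-by-line Python mirror of the new apply_rotary_embedding device helper (an illustrative sketch, not part of the commit); arr stands for one head's rotary slice, and cos/sin for the two halves of one cos_sin_cache row:

import torch

def apply_rotary_embedding_ref(arr: torch.Tensor, cos: torch.Tensor,
                               sin: torch.Tensor, rot_offset: int,
                               embed_dim: int, is_neox: bool) -> None:
    # Rotate one (x, y) pair of arr in place, as each kernel iteration does.
    if is_neox:
        # GPT-NeoX style: pair element i with element i + embed_dim.
        x_index, y_index = rot_offset, embed_dim + rot_offset
        c, s = cos[x_index], sin[x_index]
    else:
        # GPT-J style: pair adjacent elements 2*i and 2*i + 1.
        x_index, y_index = 2 * rot_offset, 2 * rot_offset + 1
        c, s = cos[rot_offset], sin[rot_offset]
    x, y = arr[x_index].clone(), arr[y_index].clone()
    arr[x_index] = x * c - y * s
    arr[y_index] = y * c + x * s

Looping rot_offset over range(embed_dim), with cos = cache_row[:embed_dim] and sin = cache_row[embed_dim:], reproduces for one head the per-pair updates that the kernel threads perform.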

tests/kernels/test_pos_encoding.py

Lines changed: 38 additions & 18 deletions
@@ -7,49 +7,64 @@
 
 from vllm import pos_encoding_ops
 
+IS_NEOX_STYLE = [True, False]
 DTYPES = [torch.half, torch.bfloat16, torch.float]
 HEAD_SIZES = [64, 80, 96, 112, 128, 256]
 ROTARY_DIMS = [None, 32]  # None means rotary dim == head size
 NUM_HEADS = [7, 12, 40, 52]  # Arbitrary values for testing
-NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
+NUM_TOKENS = [11, 83, 2048]  # Arbitrary values for testing
 SEEDS = [0]
 
 
-def rotate_half(x: torch.Tensor) -> torch.Tensor:
+def rotate_neox(x: torch.Tensor) -> torch.Tensor:
     x1 = x[..., :x.shape[-1] // 2]
     x2 = x[..., x.shape[-1] // 2:]
     return torch.cat((-x2, x1), dim=-1)
 
 
-def apply_rotary_pos_emb(
+def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
+    x1 = x[..., ::2]
+    x2 = x[..., 1::2]
+    x = torch.stack((-x2, x1), dim=-1)
+    return x.flatten(-2)
+
+
+def apply_rope(
     q: torch.Tensor,
     k: torch.Tensor,
     cos: torch.Tensor,
     sin: torch.Tensor,
+    is_neox_style: bool,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
+    rotate_fn = rotate_neox if is_neox_style else rotate_gptj
+    q_embed = (q * cos) + (rotate_fn(q) * sin)
+    k_embed = (k * cos) + (rotate_fn(k) * sin)
     return q_embed, k_embed
 
 
-class RefRotaryEmbeddingNeox(nn.Module):
-    """Reference implementation of the GPT-NeoX style rotary embedding."""
+class RefRotaryEmbedding(nn.Module):
+    """Reference implementation of rotary embedding."""
 
     def __init__(
         self,
         dim: int,
-        max_position_embeddings: int = 2048,
+        is_neox_style: bool,
+        max_position_embeddings: int = 8192,
         base: int = 10000,
     ) -> None:
         super().__init__()
         self.rotary_dim = dim
+        self.is_neox_style = is_neox_style
         self.max_position_embeddings = max_position_embeddings
 
         # Create cos and sin embeddings.
         inv_freq = 1.0 / (base**(torch.arange(0, dim, 2) / dim))
         t = torch.arange(max_position_embeddings).float()
         freqs = torch.einsum("i,j->ij", t, inv_freq.float())
-        emb = torch.cat((freqs, freqs), dim=-1)
+        if is_neox_style:
+            emb = torch.cat((freqs, freqs), dim=-1)
+        else:
+            emb = torch.repeat_interleave(freqs, 2, -1)
         cos = emb.cos().to(dtype=inv_freq.dtype)
         sin = emb.sin().to(dtype=inv_freq.dtype)
         self.register_buffer("cos_cached", cos, persistent=False)
@@ -61,7 +76,6 @@ def forward(
         query: torch.Tensor,  # [num_tokens, num_heads, head_size]
         key: torch.Tensor,  # [num_tokens, num_heads, head_size]
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-
         query_rot = query[..., :self.rotary_dim]
         query_pass = query[..., self.rotary_dim:]
         key_rot = key[..., :self.rotary_dim]
@@ -71,7 +85,9 @@ def forward(
         key_rot = key_rot.transpose(0, 1)
         cos = F.embedding(positions, self.cos_cached)
         sin = F.embedding(positions, self.sin_cached)
-        query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
+
+        query_rot, key_rot = apply_rope(query_rot, key_rot, cos, sin,
+                                        self.is_neox_style)
         query_rot = query_rot.transpose(0, 1).contiguous()
         key_rot = key_rot.transpose(0, 1).contiguous()
 
@@ -82,14 +98,16 @@ def forward(
         return query, key
 
 
+@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
-def test_rotary_embedding_neox(
+def test_rotary_embedding(
+    is_neox_style: bool,
     num_tokens: int,
     num_heads: int,
     head_size: int,
@@ -104,15 +122,15 @@ def test_rotary_embedding_neox(
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
 
-    positions = torch.randint(0, max_position, (num_tokens, ), device='cuda')
+    positions = torch.randint(0, max_position, (num_tokens, ), device="cuda")
     query = torch.randn(num_tokens,
                         num_heads * head_size,
                         dtype=dtype,
-                        device='cuda')
+                        device="cuda")
     key = torch.randn(num_tokens,
                       num_heads * head_size,
                       dtype=dtype,
-                      device='cuda')
+                      device="cuda")
 
     # Create the rotary embedding.
     inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
@@ -126,20 +144,22 @@ def test_rotary_embedding_neox(
     # Run the kernel. The kernel is in-place, so we need to clone the inputs.
     out_query = query.clone()
     out_key = key.clone()
-    pos_encoding_ops.rotary_embedding_neox(
+    pos_encoding_ops.rotary_embedding(
         positions,
         out_query,
         out_key,
         head_size,
         cos_sin_cache,
+        is_neox_style,
     )
 
     # Run the reference implementation.
-    ref_rotary_embedding = RefRotaryEmbeddingNeox(
+    ref_rotary_embedding = RefRotaryEmbedding(
         dim=rotary_dim,
+        is_neox_style=is_neox_style,
         max_position_embeddings=max_position,
         base=base,
-    ).to(dtype=dtype, device='cuda')
+    ).to(dtype=dtype, device="cuda")
     ref_query, ref_key = ref_rotary_embedding(
         positions,
         query.view(num_tokens, num_heads, head_size),
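Both kernel paths take cos from the first half of each cache row and sin from the second half (cos_ptr = cache_ptr, sin_ptr = cache_ptr + embed_dim), so the cos_sin_cache passed above is presumably laid out as in this sketch (values are assumptions for illustration; the exact construction lines fall outside the hunks shown):

import torch

base, rotary_dim, max_position = 10000, 32, 8192  # assumed values for illustration only
inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
t = torch.arange(max_position).float()
freqs = torch.einsum("i,j->ij", t, inv_freq.float())  # [max_position, rotary_dim // 2]
# Each row holds the cos values first, then the sin values, matching the
# kernel's cos_ptr/sin_ptr split.
cos_sin_cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)  # [max_position, rotary_dim]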

vllm/model_executor/layers/attention.py

Lines changed: 5 additions & 2 deletions
@@ -242,7 +242,7 @@ def forward(
 
 
 class PagedAttentionWithRoPE(PagedAttention):
-    """PagedAttention with GPT-NeoX style rotary embedding."""
+    """PagedAttention with rotary embedding."""
 
     def __init__(
         self,
@@ -253,8 +253,10 @@ def __init__(
         max_position: int = 8192,
         base: int = 10000,
         num_kv_heads: Optional[int] = None,
+        is_neox_style: bool = True,
     ) -> None:
         super().__init__(num_heads, head_size, scale, num_kv_heads)
+        self.is_neox_style = is_neox_style
 
         # Create the cos and sin cache.
         inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
@@ -303,12 +305,13 @@ def forward(
 
         # Apply rotary embedding to the query and key before passing them
         # to the attention op.
-        pos_encoding_ops.rotary_embedding_neox(
+        pos_encoding_ops.rotary_embedding(
             positions,
             query,
             key,
             self.head_size,
             self.cos_sin_cache,
+            self.is_neox_style,
         )
         return super().forward(
             query,
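Model code opts into the GPT-J pattern at construction time, while existing callers keep the NeoX default. A minimal sketch with made-up head counts, following the positional argument order used by the GPT-J call site below:

from vllm.model_executor.layers.attention import PagedAttentionWithRoPE

head_size = 256
scaling = head_size**-0.5
# Default is_neox_style=True preserves the previous GPT-NeoX behavior.
neox_style_attn = PagedAttentionWithRoPE(16, head_size, scaling, head_size)
# GPT-J passes is_neox_style=False so the kernel rotates interleaved pairs.
gptj_style_attn = PagedAttentionWithRoPE(16, head_size, scaling, 64,
                                         is_neox_style=False)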

vllm/model_executor/models/gpt_j.py

Lines changed: 5 additions & 2 deletions
@@ -67,8 +67,11 @@ def __init__(self, config: GPTJConfig):
         scaling = self.head_size**-0.5
         assert getattr(config, "rotary", True)
         assert config.rotary_dim % 2 == 0
-        self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_size,
-                                           scaling, config.rotary_dim)
+        self.attn = PagedAttentionWithRoPE(self.num_heads,
+                                           self.head_size,
+                                           scaling,
+                                           config.rotary_dim,
+                                           is_neox_style=False)
         self.warmup = False
 
     def forward(

0 commit comments
