
Commit 5d63fd4

Add rotary_embedding kernel (#7)
* add rope
* fix potential acc issue
* fix format
* fix format of rebase

Signed-off-by: Ma, Liangliang <[email protected]>
1 parent ccc900b

7 files changed: +452 −1 lines

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -148,6 +148,7 @@ if(VLLM_GPU_LANG STREQUAL "SYCL")
   set(VLLM_EXT_SRC
     "csrc/xpu/cache.cpp"
     "csrc/xpu/layernorm.cpp"
+    "csrc/xpu/pos_encoding_kernels.cpp"
     "csrc/xpu/torch_bindings.cpp"
   )
   include_directories("/usr/include")

csrc/xpu/ops.h

Lines changed: 4 additions & 0 deletions
@@ -8,6 +8,10 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
 void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
                         torch::Tensor& weight, double epsilon);
 
+void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
+                      std::optional<torch::Tensor> key, int64_t head_size,
+                      torch::Tensor& cos_sin_cache, bool is_neox);
+
 void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
                        torch::Tensor& key_cache, torch::Tensor& value_cache,
                        torch::Tensor& slot_mapping,

csrc/xpu/pos_encoding_kernels.cpp

Lines changed: 212 additions & 0 deletions
@@ -0,0 +1,212 @@
#include <sycl/sycl.hpp>
#include "utils.h"
#include "dispatch_utils.h"
#include <cmath>
#include <c10/macros/Macros.h>

namespace vllm {

template <typename scalar_t, bool IS_NEOX>
inline void apply_token_rotary_embedding(scalar_t* __restrict__ arr,
                                         const scalar_t* __restrict__ cos_ptr,
                                         const scalar_t* __restrict__ sin_ptr,
                                         int rot_offset, int embed_dim) {
  int x_index, y_index;
  scalar_t cos, sin;
  if (IS_NEOX) {
    // GPT-NeoX style rotary embedding.
    x_index = rot_offset;
    y_index = embed_dim + rot_offset;
    cos = cos_ptr[x_index];
    sin = sin_ptr[x_index];
  } else {
    // GPT-J style rotary embedding.
    x_index = 2 * rot_offset;
    y_index = 2 * rot_offset + 1;
    cos = cos_ptr[x_index / 2];
    sin = sin_ptr[x_index / 2];
  }

  const scalar_t x = arr[x_index];
  const scalar_t y = arr[y_index];
  arr[x_index] = x * cos - y * sin;
  arr[y_index] = y * cos + x * sin;
}

template <typename scalar_t, bool IS_NEOX>
inline void apply_rotary_embedding(
    scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads,
                                   // head_size] or [num_tokens, num_heads,
                                   // head_size]
    scalar_t* __restrict__ key,    // nullptr or
                                   // [batch_size, seq_len, num_kv_heads,
                                   // head_size] or [num_tokens, num_kv_heads,
                                   // head_size]
    const scalar_t* cache_ptr, const int head_size, const int num_heads,
    const int num_kv_heads, const int rot_dim, const int token_idx,
    const int64_t query_stride, const int64_t key_stride,
    const int64_t head_stride, const sycl::nd_item<3>& item_ct1) {
  const int embed_dim = rot_dim / 2;
  const scalar_t* cos_ptr = cache_ptr;
  const scalar_t* sin_ptr = cache_ptr + embed_dim;

  const int nq = num_heads * embed_dim;
  for (int i = item_ct1.get_local_id(2); i < nq;
       i += item_ct1.get_local_range(2)) {
    const int head_idx = i / embed_dim;
    const int64_t token_head =
        token_idx * query_stride + head_idx * head_stride;
    const int rot_offset = i % embed_dim;
    apply_token_rotary_embedding<scalar_t, IS_NEOX>(
        query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
  }

  if (key != nullptr) {
    const int nk = num_kv_heads * embed_dim;
    for (int i = item_ct1.get_local_id(2); i < nk;
         i += item_ct1.get_local_range(2)) {
      const int head_idx = i / embed_dim;
      const int64_t token_head =
          token_idx * key_stride + head_idx * head_stride;
      const int rot_offset = i % embed_dim;
      apply_token_rotary_embedding<scalar_t, IS_NEOX>(
          key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
    }
  }
}

template <typename scalar_t, bool IS_NEOX>
void rotary_embedding_kernel(
    const int64_t* __restrict__ positions,  // [batch_size, seq_len] or
                                            // [num_tokens]
    scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads,
                                   // head_size] or [num_tokens, num_heads,
                                   // head_size]
    scalar_t* __restrict__ key,    // nullptr or
                                   // [batch_size, seq_len, num_kv_heads,
                                   // head_size] or [num_tokens, num_kv_heads,
                                   // head_size]
    const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
                                                 // 2]
    const int rot_dim, const int64_t query_stride, const int64_t key_stride,
    const int64_t head_stride, const int num_heads, const int num_kv_heads,
    const int head_size, const sycl::nd_item<3>& item_ct1) {
  // Each thread block is responsible for one token.
  const int token_idx = item_ct1.get_group(2);
  int64_t pos = positions[token_idx];
  const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;

  apply_rotary_embedding<scalar_t, IS_NEOX>(
      query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
      token_idx, query_stride, key_stride, head_stride, item_ct1);
}

}  // namespace vllm

template <typename scalar_t>
void call_rotary_embedding_kernel(
    torch::Tensor& positions, torch::Tensor& query,
    std::optional<torch::Tensor> key, int64_t head_size,
    torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
    bool is_neox) {
  using sycl_t = vllm::xpu::SyclTypeTrait<scalar_t>::Type;
  // num_tokens = batch_size * seq_len
  int64_t num_tokens = positions.numel();
  int positions_ndim = positions.dim();

  // Make sure num_tokens dim is consistent across positions, query, and key
  TORCH_CHECK(
      positions_ndim == 1 || positions_ndim == 2,
      "positions must have shape [num_tokens] or [batch_size, seq_len]");
  if (positions_ndim == 1) {
    TORCH_CHECK(query.size(0) == positions.size(0) &&
                    (!key.has_value() || key->size(0) == positions.size(0)),
                "query, key and positions must have the same number of tokens");
  }
  if (positions_ndim == 2) {
    TORCH_CHECK(
        query.size(0) == positions.size(0) &&
            (!key.has_value() || key->size(0) == positions.size(0)) &&
            query.size(1) == positions.size(1) &&
            (!key.has_value() || key->size(1) == positions.size(1)),
        "query, key and positions must have the same batch_size and seq_len");
  }

  // Make sure head_size is valid for query and key
  // hidden_size = num_heads * head_size
  int query_hidden_size = query.numel() / num_tokens;
  int key_hidden_size = key.has_value() ? key->numel() / num_tokens : 0;
  TORCH_CHECK(query_hidden_size % head_size == 0);
  TORCH_CHECK(key_hidden_size % head_size == 0);

  // Make sure query and key have consistent number of heads
  int num_heads = query_hidden_size / head_size;
  int num_kv_heads = key.has_value() ? key_hidden_size / head_size : num_heads;
  TORCH_CHECK(num_heads % num_kv_heads == 0);

  int rot_dim = cos_sin_cache.size(1);
  int seq_dim_idx = positions_ndim - 1;
  int64_t query_stride = query.stride(seq_dim_idx);
  int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
  // Determine head stride: for [*, heads, head_size] use stride of last dim;
  // for flat [*, heads*head_size], heads blocks are contiguous of size
  // head_size
  int query_ndim = query.dim();
  int64_t head_stride =
      (query_ndim == positions_ndim + 2) ? query.stride(-2) : head_size;

  auto positions_ptr = positions.data_ptr<int64_t>();
  auto query_ptr = query.data_ptr<scalar_t>();
  auto key_ptr = key.has_value() ? key->data_ptr<scalar_t>() : nullptr;
  auto cos_sin_cache_ptr = cos_sin_cache.data_ptr<scalar_t>();

  sycl::range<3> grid(1, 1, num_tokens);
  sycl::range<3> block(1, 1, std::min<int64_t>(num_heads * rot_dim / 2, 512));

  at::DeviceGuard device_guard(query.device());
  auto& queue = vllm::xpu::vllmGetQueue();
  if (is_neox) {
    queue.submit([&](sycl::handler& cgh) {
      cgh.parallel_for(
          sycl::nd_range<3>(grid * block, block),
          [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
            vllm::rotary_embedding_kernel<sycl_t, true>(
                positions_ptr, (sycl_t*)query_ptr, (sycl_t*)key_ptr,
                (sycl_t*)cos_sin_cache_ptr, rot_dim, query_stride, key_stride,
                head_stride, num_heads, num_kv_heads, head_size, item_ct1);
          });
    });
  } else {
    queue.submit([&](sycl::handler& cgh) {
      cgh.parallel_for(
          sycl::nd_range<3>(grid * block, block),
          [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
            vllm::rotary_embedding_kernel<sycl_t, false>(
                positions_ptr, (sycl_t*)query_ptr, (sycl_t*)key_ptr,
                (sycl_t*)cos_sin_cache_ptr, rot_dim, query_stride, key_stride,
                head_stride, num_heads, num_kv_heads, head_size, item_ct1);
          });
    });
  }
}

void rotary_embedding(
    torch::Tensor& positions,  // [batch_size, seq_len] or [num_tokens]
    torch::Tensor& query,  // [batch_size, seq_len, num_heads * head_size] or
                           // [num_tokens, num_heads * head_size] or
                           // [batch_size, seq_len, num_heads, head_size] or
                           // [num_tokens, num_heads, head_size]
    std::optional<torch::Tensor> key,
    // null or
    // [batch_size, seq_len, num_kv_heads * head_size] or
    // [num_tokens, num_kv_heads * head_size] or
    // [batch_size, seq_len, num_heads, head_size] or
    // [num_tokens, num_heads, head_size]
    int64_t head_size,
    torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
    bool is_neox) {
  VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
    call_rotary_embedding_kernel<scalar_t>(positions, query, key, head_size,
                                           cos_sin_cache, is_neox);
  });
}
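
For reference, the per-token, per-head rotation performed by apply_token_rotary_embedding above can be reproduced on the host in a few lines of PyTorch. This is only an illustrative sketch (the helper name rotate_one_head is hypothetical, not part of the commit); it mirrors the kernel's index arithmetic, assuming cache_row holds the rot_dim // 2 cosines followed by the rot_dim // 2 sines for one position, as cos_sin_cache + pos * rot_dim does in the kernel.

import torch

def rotate_one_head(head: torch.Tensor, cache_row: torch.Tensor,
                    is_neox: bool) -> torch.Tensor:
    # head: [head_size] values for one (token, head) pair; only the first
    # rot_dim entries are rotated, the rest pass through unchanged.
    # cache_row: [rot_dim] = [cos_0 .. cos_{d-1}, sin_0 .. sin_{d-1}], d = rot_dim // 2.
    rot_dim = cache_row.numel()
    embed_dim = rot_dim // 2
    cos, sin = cache_row[:embed_dim], cache_row[embed_dim:]
    out = head.clone()
    for rot_offset in range(embed_dim):
        if is_neox:
            # GPT-NeoX style: rotate the (i, d + i) pair.
            x_index, y_index = rot_offset, embed_dim + rot_offset
        else:
            # GPT-J style: rotate the (2i, 2i + 1) pair.
            x_index, y_index = 2 * rot_offset, 2 * rot_offset + 1
        c, s = cos[rot_offset], sin[rot_offset]
        x, y = head[x_index], head[y_index]
        out[x_index] = x * c - y * s
        out[y_index] = y * c + x * s
    return out

In the kernel itself, one work-group handles one token and its work-items stride over all (head, rot_offset) pairs, so each pair above corresponds to one iteration of the kernel's inner loop.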

csrc/xpu/torch_bindings.cpp

Lines changed: 7 additions & 0 deletions
@@ -31,6 +31,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, "
       "float epsilon) -> ()");
   ops.impl("fused_add_rms_norm", torch::kXPU, &fused_add_rms_norm);
+
+  // pos_embedding
+  ops.def(
+      "rotary_embedding(Tensor positions, Tensor! query,"
+      " Tensor!? key, int head_size,"
+      " Tensor cos_sin_cache, bool is_neox) -> ()");
+  ops.impl("rotary_embedding", torch::kXPU, &rotary_embedding);
 }
 
 TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
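
The schema marks query as Tensor! and key as Tensor!?, so the op mutates both in place (key may be omitted) and returns nothing. A minimal call sketch, assuming the built XPU extension is importable through the test-side wrapper tests.register_ops (which the test file below uses) and that an XPU device is available; sizes and the random cos_sin_cache are stand-ins chosen only to show the expected shapes and dtypes:

import torch
import tests.register_ops as ops  # assumes the XPU extension is built and importable

num_tokens, num_heads, head_size = 4, 8, 64
rot_dim, max_pos = head_size, 1024

positions = torch.randint(0, max_pos, (num_tokens,), device="xpu")  # int64 token positions
query = torch.randn(num_tokens, num_heads * head_size, device="xpu", dtype=torch.float16)
key = torch.randn(num_tokens, num_heads * head_size, device="xpu", dtype=torch.float16)
cos_sin_cache = torch.randn(max_pos, rot_dim, device="xpu", dtype=torch.float16)  # random stand-in

# In-place: query and key are rotated, nothing is returned.
ops.rotary_embedding(positions, query, key, head_size, cos_sin_cache, True)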

tests/ops/rotary_embedding_op.py

Lines changed: 142 additions & 0 deletions
@@ -0,0 +1,142 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Rotary Positional Embeddings Base Class."""
from typing import Optional

import torch

import tests.register_ops as ops
from tests.ops.custom_ops import CustomOp


def apply_rotary_emb_torch(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    is_neox_style: bool,
) -> torch.Tensor:
    cos = cos.unsqueeze(-2).to(x.dtype)
    sin = sin.unsqueeze(-2).to(x.dtype)
    if is_neox_style:
        x1, x2 = torch.chunk(x, 2, dim=-1)
    else:
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
    o1 = x1 * cos - x2 * sin
    o2 = x2 * cos + x1 * sin
    if is_neox_style:
        return torch.cat((o1, o2), dim=-1)
    else:
        return torch.stack((o1, o2), dim=-1).flatten(-2)


class RotaryEmbedding(CustomOp):
    """Original rotary positional embedding."""

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
    ) -> None:
        super().__init__()
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.is_neox_style = is_neox_style
        self.dtype = dtype

        cache = self._compute_cos_sin_cache()
        cache = cache.to(dtype)
        self.cos_sin_cache: torch.Tensor
        self.register_buffer("cos_sin_cache", cache, persistent=False)

    def _compute_inv_freq(self, base: float) -> torch.Tensor:
        """Compute the inverse frequency."""
        # NOTE(woosuk): To exactly match the HF implementation, we need to
        # use CPU to compute the cache and then move it to GPU. However, we
        # create the cache on GPU for faster initialization. This may cause
        # a slight numerical difference between the HF implementation and ours.
        inv_freq = 1.0 / (base**(torch.arange(
            0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
        return inv_freq

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Compute the cos and sin cache."""
        inv_freq = self._compute_inv_freq(self.base)
        t = torch.arange(self.max_position_embeddings, dtype=torch.float)

        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        return cache

    def forward_native(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """A PyTorch-native implementation of forward()."""
        if offsets is not None:
            positions = positions + offsets
        positions = positions.flatten()
        num_tokens = positions.shape[0]
        cos_sin = self.cos_sin_cache.index_select(0, positions)
        cos, sin = cos_sin.chunk(2, dim=-1)

        query_shape = query.shape
        query = query.view(num_tokens, -1, self.head_size)
        query_rot = query[..., :self.rotary_dim]
        query_pass = query[..., self.rotary_dim:]
        query_rot = apply_rotary_emb_torch(query_rot, cos, sin,
                                           self.is_neox_style)
        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

        # key may be None in some cases, e.g. cross-layer KV sharing
        if key is not None:
            key_shape = key.shape
            key = key.view(num_tokens, -1, self.head_size)
            key_rot = key[..., :self.rotary_dim]
            key_pass = key[..., self.rotary_dim:]
            key_rot = apply_rotary_emb_torch(key_rot, cos, sin,
                                             self.is_neox_style)
            key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
        return query, key

    def forward_xpu(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:

        # __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`)
        # is expensive, so avoid calling it if possible
        if self.cos_sin_cache.device != query.device or \
                self.cos_sin_cache.dtype != query.dtype:
            self.cos_sin_cache = self.cos_sin_cache.to(query.device,
                                                       dtype=query.dtype)

        # ops.rotary_embedding()/batched_rotary_embedding()
        # are in-place operations that update the query and key tensors.
        if offsets is not None:
            raise NotImplementedError(
                "batched_rotary_embedding is not implemented yet.")
        else:
            ops.rotary_embedding(positions, query, key, self.head_size,
                                 self.cos_sin_cache, self.is_neox_style)
        return query, key

    def extra_repr(self) -> str:
        s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
        s += f", max_position_embeddings={self.max_position_embeddings}"
        s += f", base={self.base}, is_neox_style={self.is_neox_style}"
        return s
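
A sketch of how this class might be exercised in a test, assuming an XPU device and the built extension are available and that CustomOp subclasses behave as plain nn.Modules (so .to("xpu") moves the cos_sin_cache buffer). The pure-PyTorch forward_native path serves as the reference for the SYCL kernel; the sizes and tolerances below are illustrative, not taken from the commit.

import torch
from tests.ops.rotary_embedding_op import RotaryEmbedding

head_size, rotary_dim, max_pos, base = 64, 64, 2048, 10000.0
num_tokens, num_heads = 16, 8
dtype = torch.float16

rope = RotaryEmbedding(head_size, rotary_dim, max_pos, base,
                       is_neox_style=True, dtype=dtype).to("xpu")

positions = torch.randint(0, max_pos, (num_tokens,), device="xpu")
query = torch.randn(num_tokens, num_heads * head_size, device="xpu", dtype=dtype)
key = torch.randn(num_tokens, num_heads * head_size, device="xpu", dtype=dtype)

# forward_native is the pure-PyTorch reference; forward_xpu rotates its inputs
# in place via the SYCL kernel, so each path gets its own copies.
ref_q, ref_k = rope.forward_native(positions, query.clone(), key.clone())
out_q, out_k = rope.forward_xpu(positions, query.clone(), key.clone())

torch.testing.assert_close(out_q, ref_q, atol=1e-2, rtol=1e-2)
torch.testing.assert_close(out_k, ref_k, atol=1e-2, rtol=1e-2)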
