#include <sycl/sycl.hpp>
#include "utils.h"
#include "dispatch_utils.h"
#include <cmath>
#include <optional>   // std::optional (used in the launcher signatures)
#include <algorithm>  // std::min
#include <c10/macros/Macros.h>

namespace vllm {

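// Rotates one (x, y) element pair of a single head in place:
//   x' = x * cos(theta) - y * sin(theta)
//   y' = y * cos(theta) + x * sin(theta)
// For NeoX-style layouts the pair is (arr[i], arr[i + embed_dim]); for
// GPT-J-style layouts it is the adjacent pair (arr[2 * i], arr[2 * i + 1]).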
template <typename scalar_t, bool IS_NEOX>
inline void apply_token_rotary_embedding(scalar_t* __restrict__ arr,
                                         const scalar_t* __restrict__ cos_ptr,
                                         const scalar_t* __restrict__ sin_ptr,
                                         int rot_offset, int embed_dim) {
  int x_index, y_index;
  scalar_t cos, sin;
  if (IS_NEOX) {
    // GPT-NeoX style rotary embedding.
    x_index = rot_offset;
    y_index = embed_dim + rot_offset;
    cos = cos_ptr[x_index];
    sin = sin_ptr[x_index];
  } else {
    // GPT-J style rotary embedding.
    x_index = 2 * rot_offset;
    y_index = 2 * rot_offset + 1;
    cos = cos_ptr[x_index / 2];
    sin = sin_ptr[x_index / 2];
  }

  const scalar_t x = arr[x_index];
  const scalar_t y = arr[y_index];
  arr[x_index] = x * cos - y * sin;
  arr[y_index] = y * cos + x * sin;
}

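// Rotates all query heads (and key heads, when `key` is non-null) of one
// token. `cache_ptr` points at this position's row of the cos/sin cache: the
// first embed_dim (= rot_dim / 2) entries are cosines, the next embed_dim are
// sines.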
template <typename scalar_t, bool IS_NEOX>
inline void apply_rotary_embedding(
    scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads,
                                   // head_size] or [num_tokens, num_heads,
                                   // head_size]
    scalar_t* __restrict__ key,    // nullptr or
                                   // [batch_size, seq_len, num_kv_heads,
                                   // head_size] or [num_tokens, num_kv_heads,
                                   // head_size]
    const scalar_t* cache_ptr, const int head_size, const int num_heads,
    const int num_kv_heads, const int rot_dim, const int token_idx,
    const int64_t query_stride, const int64_t key_stride,
    const int64_t head_stride, const sycl::nd_item<3>& item_ct1) {
  const int embed_dim = rot_dim / 2;
  const scalar_t* cos_ptr = cache_ptr;
  const scalar_t* sin_ptr = cache_ptr + embed_dim;

  const int nq = num_heads * embed_dim;
  for (int i = item_ct1.get_local_id(2); i < nq;
       i += item_ct1.get_local_range(2)) {
    const int head_idx = i / embed_dim;
    const int64_t token_head =
        token_idx * query_stride + head_idx * head_stride;
    const int rot_offset = i % embed_dim;
    apply_token_rotary_embedding<scalar_t, IS_NEOX>(
        query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
  }

  if (key != nullptr) {
    const int nk = num_kv_heads * embed_dim;
    for (int i = item_ct1.get_local_id(2); i < nk;
         i += item_ct1.get_local_range(2)) {
      const int head_idx = i / embed_dim;
      const int64_t token_head =
          token_idx * key_stride + head_idx * head_stride;
      const int rot_offset = i % embed_dim;
      apply_token_rotary_embedding<scalar_t, IS_NEOX>(
          key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
    }
  }
}

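// Kernel body: one work-group handles one token (group index along dimension
// 2); its work-items stride over the rotation pairs of that token's heads.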
template <typename scalar_t, bool IS_NEOX>
void rotary_embedding_kernel(
    const int64_t* __restrict__ positions,  // [batch_size, seq_len] or
                                            // [num_tokens]
    scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads,
                                   // head_size] or [num_tokens, num_heads,
                                   // head_size]
    scalar_t* __restrict__ key,    // nullptr or
                                   // [batch_size, seq_len, num_kv_heads,
                                   // head_size] or [num_tokens, num_kv_heads,
                                   // head_size]
    const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
                                                 // 2]
    const int rot_dim, const int64_t query_stride, const int64_t key_stride,
    const int64_t head_stride, const int num_heads, const int num_kv_heads,
    const int head_size, const sycl::nd_item<3>& item_ct1) {
  // Each work-group is responsible for one token.
  const int token_idx = item_ct1.get_group(2);
  int64_t pos = positions[token_idx];
  const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;

  apply_rotary_embedding<scalar_t, IS_NEOX>(
      query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
      token_idx, query_stride, key_stride, head_stride, item_ct1);
}

}  // namespace vllm

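// Host-side launcher: validates the tensor shapes, derives head counts and
// strides from the tensors, and submits the SYCL kernel on vLLM's XPU queue.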
template <typename scalar_t>
void call_rotary_embedding_kernel(
    torch::Tensor& positions, torch::Tensor& query,
    std::optional<torch::Tensor> key, int64_t head_size,
    torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
    bool is_neox) {
  using sycl_t = typename vllm::xpu::SyclTypeTrait<scalar_t>::Type;
  // num_tokens = batch_size * seq_len
  int64_t num_tokens = positions.numel();
  int positions_ndim = positions.dim();

  // Make sure the num_tokens dim is consistent across positions, query, and
  // key.
  TORCH_CHECK(
      positions_ndim == 1 || positions_ndim == 2,
      "positions must have shape [num_tokens] or [batch_size, seq_len]");
  if (positions_ndim == 1) {
    TORCH_CHECK(query.size(0) == positions.size(0) &&
                    (!key.has_value() || key->size(0) == positions.size(0)),
                "query, key and positions must have the same number of tokens");
  }
  if (positions_ndim == 2) {
    TORCH_CHECK(
        query.size(0) == positions.size(0) &&
            (!key.has_value() || key->size(0) == positions.size(0)) &&
            query.size(1) == positions.size(1) &&
            (!key.has_value() || key->size(1) == positions.size(1)),
        "query, key and positions must have the same batch_size and seq_len");
  }

  // Make sure head_size is valid for query and key.
  // hidden_size = num_heads * head_size
  int query_hidden_size = query.numel() / num_tokens;
  int key_hidden_size = key.has_value() ? key->numel() / num_tokens : 0;
  TORCH_CHECK(query_hidden_size % head_size == 0);
  TORCH_CHECK(key_hidden_size % head_size == 0);

  // Make sure query and key have a consistent number of heads.
  int num_heads = query_hidden_size / head_size;
  int num_kv_heads = key.has_value() ? key_hidden_size / head_size : num_heads;
  TORCH_CHECK(num_heads % num_kv_heads == 0);

  int rot_dim = cos_sin_cache.size(1);
  int seq_dim_idx = positions_ndim - 1;
  int64_t query_stride = query.stride(seq_dim_idx);
  int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
  // Determine head stride: for [*, heads, head_size] use the stride of the
  // heads (second-to-last) dim; for flat [*, heads * head_size], head blocks
  // are contiguous with size head_size.
  int query_ndim = query.dim();
  int64_t head_stride =
      (query_ndim == positions_ndim + 2) ? query.stride(-2) : head_size;

  auto positions_ptr = positions.data_ptr<int64_t>();
  auto query_ptr = query.data_ptr<scalar_t>();
  auto key_ptr = key.has_value() ? key->data_ptr<scalar_t>() : nullptr;
  auto cos_sin_cache_ptr = cos_sin_cache.data_ptr<scalar_t>();

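  // One work-group per token; each group uses at most 512 work-items to cover
  // the num_heads * (rot_dim / 2) rotation pairs of its token.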
  sycl::range<3> grid(1, 1, num_tokens);
  sycl::range<3> block(1, 1, std::min<int64_t>(num_heads * rot_dim / 2, 512));

  at::DeviceGuard device_guard(query.device());
  auto& queue = vllm::xpu::vllmGetQueue();
  if (is_neox) {
    queue.submit([&](sycl::handler& cgh) {
      cgh.parallel_for(
          sycl::nd_range<3>(grid * block, block),
          [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
            vllm::rotary_embedding_kernel<sycl_t, true>(
                positions_ptr, (sycl_t*)query_ptr, (sycl_t*)key_ptr,
                (sycl_t*)cos_sin_cache_ptr, rot_dim, query_stride, key_stride,
                head_stride, num_heads, num_kv_heads, head_size, item_ct1);
          });
    });
  } else {
    queue.submit([&](sycl::handler& cgh) {
      cgh.parallel_for(
          sycl::nd_range<3>(grid * block, block),
          [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
            vllm::rotary_embedding_kernel<sycl_t, false>(
                positions_ptr, (sycl_t*)query_ptr, (sycl_t*)key_ptr,
                (sycl_t*)cos_sin_cache_ptr, rot_dim, query_stride, key_stride,
                head_stride, num_heads, num_kv_heads, head_size, item_ct1);
          });
    });
  }
}

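// Public entry point; rotates `query` (and `key`, if provided) in place.
// Dispatched over the floating-point dtype of `query`.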
void rotary_embedding(
    torch::Tensor& positions,  // [batch_size, seq_len] or [num_tokens]
    torch::Tensor& query,  // [batch_size, seq_len, num_heads * head_size] or
                           // [num_tokens, num_heads * head_size] or
                           // [batch_size, seq_len, num_heads, head_size] or
                           // [num_tokens, num_heads, head_size]
    std::optional<torch::Tensor> key,
    // null or
    // [batch_size, seq_len, num_kv_heads * head_size] or
    // [num_tokens, num_kv_heads * head_size] or
    // [batch_size, seq_len, num_kv_heads, head_size] or
    // [num_tokens, num_kv_heads, head_size]
    int64_t head_size,
    torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
    bool is_neox) {
  VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
    call_rotary_embedding_kernel<scalar_t>(positions, query, key, head_size,
                                           cos_sin_cache, is_neox);
  });
}
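
// Example call shapes (hypothetical values, flat [num_tokens, hidden] layout):
//   positions:     int64 tensor of shape [num_tokens]
//   query:         [num_tokens, num_heads * head_size]
//   key:           [num_tokens, num_kv_heads * head_size]
//   cos_sin_cache: [max_position, rot_dim]
//   rotary_embedding(positions, query, key, head_size, cos_sin_cache,
//                    /*is_neox=*/true);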