diff --git a/CMakeLists.txt b/CMakeLists.txt index e3c09e3..4577237 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -148,6 +148,7 @@ if(VLLM_GPU_LANG STREQUAL "SYCL") set(VLLM_EXT_SRC "csrc/xpu/cache.cpp" "csrc/xpu/layernorm.cpp" + "csrc/xpu/grouped_topk.cpp" "csrc/xpu/activation.cpp" "csrc/xpu/pos_encoding_kernels.cpp" "csrc/xpu/torch_bindings.cpp" diff --git a/benchmark/benchmark_grouped_topk.py b/benchmark/benchmark_grouped_topk.py new file mode 100644 index 0000000..a68a65b --- /dev/null +++ b/benchmark/benchmark_grouped_topk.py @@ -0,0 +1,193 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import time +from argparse import ArgumentParser +from typing import Optional + +import torch + +from tests.ops.grouped_topk import grouped_topk, grouped_topk_native + +dpcpp_device = torch.device("xpu") + + +@torch.compile +def grouped_topk_compile( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int, + topk_group: int, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + assert hidden_states.size(0) == gating_output.size(0), ( + "Number of tokens mismatch") + if scoring_func == "softmax": + gating_output = gating_output.to(torch.float32) + scores = torch.softmax(gating_output, dim=-1) + elif scoring_func == "sigmoid": + scores = gating_output.sigmoid() + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + + num_token = scores.shape[0] + if e_score_correction_bias is not None: + # Store original scores before applying correction bias. We use biased + # scores for expert selection but original scores for routing weights + e_score_correction_bias = e_score_correction_bias.to(torch.float32) + original_scores = scores + scores = scores + e_score_correction_bias.unsqueeze(0) + group_scores = (scores.view(num_token, num_expert_group, + -1).topk(2, dim=-1)[0].sum(dim=-1)) + else: + group_scores = (scores.view(num_token, num_expert_group, + -1).max(dim=-1).values) # [n, n_group] + group_idx = torch.topk(group_scores, k=topk_group, dim=-1, + sorted=True)[1] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = (group_mask.unsqueeze(-1).expand( + num_token, num_expert_group, + scores.shape[-1] // num_expert_group).reshape(num_token, -1)) # [n, e] + tmp_scores = scores.masked_fill(~score_mask.bool(), + float("-inf")) # [n, e] + + if e_score_correction_bias is not None: + topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=True)[1] + # Use original unbiased scores for the routing weights + topk_weights = original_scores.gather(1, topk_ids) + else: + topk_weights, topk_ids = torch.topk(tmp_scores, + k=topk, + dim=-1, + sorted=True) + + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + return topk_weights.to(torch.float32), topk_ids.to(torch.int32) + + +@torch.inference_mode() +def main( + dtype: torch.dtype, + num_tokens: int, + num_experts: int, + topk: int, + renormalize: bool, + num_expert_group: int, + topk_group: int, + scoring_func: str = "softmax", + has_bias: bool = False, + seed: int = 0, + num_warmup_iters: int = 5, + num_iters: int = 100, + provider: str = "vllm", +) -> None: + torch.manual_seed(seed) + torch.set_default_device("xpu") + + gating_output = torch.randn(num_tokens, num_experts, + device=dpcpp_device).to(dtype) + hidden_states = torch.zeros(num_tokens, num_experts, + device=dpcpp_device).to(dtype) + bias = None + if has_bias: + if has_bias and scoring_func == "sigmoid" \ + and dtype is not torch.float32: + # using a bias of bigger number to avoid Low-precision + bias = torch.arange(1, num_experts + 1).to(dpcpp_device).to(dtype) + else: + bias = torch.randn(num_experts, device=dpcpp_device).to(dtype) + + def run_xpu_benchmark(num_iters: int) -> float: + torch.xpu.synchronize() + + start_time = time.perf_counter() + + if provider == "vllm": + run_op = grouped_topk + elif provider == "native": + run_op = grouped_topk_native + elif provider == "compile": + run_op = grouped_topk_compile + else: + raise ValueError(f"Unsupported provider: {provider}") + for _ in range(num_iters): + topk_weights, topk_indices = run_op( + hidden_states, + gating_output, + topk, + renormalize, + num_expert_group, + topk_group, + scoring_func=scoring_func, + e_score_correction_bias=bias, + ) + torch.xpu.synchronize() + + end_time = time.perf_counter() + + return (end_time - start_time) / num_iters + + # Warmup. + print("Warming up...") + run_benchmark = run_xpu_benchmark + run_benchmark(num_iters=num_warmup_iters) + + # Benchmark. + latency = run_benchmark(num_iters=num_iters) + print(f"Kernel running time: {latency * 1000000:.3f} us") + + +if __name__ == "__main__": + parser = ArgumentParser(description="Benchmark the grouped topk kernel.") + parser.add_argument("--num-tokens", type=int, default=64) + parser.add_argument("--num-experts", type=int, default=128) + parser.add_argument("--topk", type=int, default=6) + parser.add_argument("--renormalize", action="store_true") + parser.add_argument("--num-expert-group", type=int, default=8) + parser.add_argument("--topk-group", type=int, default=8) + parser.add_argument("--scoring-func", + type=str, + choices=["sigmoid", "softmax"], + default="softmax") + parser.add_argument("--has-bias", action="store_true") + parser.add_argument("--dtype", + type=str, + choices=["half", "bfloat16", "float"], + default="half") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--num-warmup-iters", type=int, default=5) + parser.add_argument("--num-iters", + type=int, + default=100, + help="Number of benchmark iterations. ") + parser.add_argument("--provider", + type=str, + choices=["vllm", "native", "compile"], + default="vllm") + + args = parser.parse_args() + print(args) + + # Convert dtype string to torch.dtype + dtype = getattr(torch, args.dtype) + main( + dtype=dtype, + num_tokens=args.num_tokens, + num_experts=args.num_experts, + topk=args.topk, + renormalize=args.renormalize, + num_expert_group=args.num_expert_group, + topk_group=args.topk_group, + scoring_func=args.scoring_func, + has_bias=args.has_bias, + seed=args.seed, + num_warmup_iters=args.num_warmup_iters, + num_iters=args.num_iters, + provider=args.provider, + ) diff --git a/csrc/xpu/grouped_topk.cpp b/csrc/xpu/grouped_topk.cpp new file mode 100644 index 0000000..cf1c648 --- /dev/null +++ b/csrc/xpu/grouped_topk.cpp @@ -0,0 +1,434 @@ +#include + +#include "utils.h" +#include "dispatch_utils.h" + +namespace vllm { +namespace GroupedTopKImpl { + +enum class ScoringFunc { + DEFAULT = 0, + SOFTMAX = 1, + SIGMOID = 2, +}; + +template +struct Fused_Grouped_Topk { + static constexpr int sub_group_size = 32; + static constexpr int max_group_size = 1024; + static constexpr int malloc_per_item = MAX_EXPERT_GROUPS; + static constexpr float kNegInfinity = INFINITY * -1; + + Fused_Grouped_Topk(float* topk_weights, int* topk_ids, const T* gating_output, + const T* e_score_correction_bias, + const ScoringFunc scoring_mode, const bool renormalize, + const int tokens, const int experts, const int top_k, + const int num_expert_group, const int topk_group) + : topk_weights(topk_weights), + topk_ids(topk_ids), + gating_output(gating_output), + e_score_correction_bias(e_score_correction_bias), + scoring_mode(scoring_mode), + renormalize(renormalize), + tokens(tokens), + experts(experts), + top_k(top_k), + num_expert_group(num_expert_group), + topk_group(topk_group) {} + + static inline sycl::nd_range<3> get_nd_range(const int tokens, + const int experts) { + int calc_per_item = (experts + sub_group_size - 1) / sub_group_size; + int group_size = (experts + calc_per_item - 1) / calc_per_item; + group_size = group_size < sub_group_size ? sub_group_size : group_size; + group_size = group_size < max_group_size ? group_size : max_group_size; + int sub_groups_per_group = + (group_size + sub_group_size - 1) / sub_group_size; + group_size = sub_groups_per_group * sub_group_size; + int global_size = + (tokens + sub_groups_per_group - 1) / sub_groups_per_group; + + sycl::range<3> local(1, 1, group_size); + sycl::range<3> global(1, 1, global_size); + return sycl::nd_range<3>(global * local, local); + } + + static inline float Sigmoid(float x) { + return 1.0f / (1.0f + sycl::native::exp(-x)); + } + + [[sycl::reqd_sub_group_size(sub_group_size)]] void operator()( + sycl::nd_item<3> item) const { + int group_id = item.get_group_linear_id(); + int local_range = item.get_local_range(2); + int sub_groups_per_group = local_range / sub_group_size; + int calc_per_item = (experts + sub_group_size - 1) / sub_group_size; + + int experts_per_group = experts / num_expert_group; + + sycl::sub_group sg = item.get_sub_group(); + int sg_id = sg.get_group_id(); + int sg_local_id = sg.get_local_id(); + + int tid = group_id * sub_groups_per_group + sg_id; + + if (tid >= tokens) { + return; // Out of bounds + } + + T load_elems[malloc_per_item]; + int local_idx[malloc_per_item]; + T bias[malloc_per_item]; + + int start_offset = sg_local_id * calc_per_item; + int local_num = calc_per_item; + + if (start_offset + local_num >= experts) { + local_num = experts - start_offset; + if (local_num < 0) { + local_num = 0; // No elements to process + } + } + + for (int e = 0; e < calc_per_item; ++e) { + load_elems[e] = kNegInfinity; + local_idx[e] = -1; + bias[e] = 0.0f; // Initialize bias to zero + } + + for (int e = 0; e < local_num; ++e) { + load_elems[e] = gating_output[tid * experts + start_offset + e]; + } + + float local_elems[malloc_per_item]; + + for (int e = 0; e < local_num; ++e) { + local_elems[e] = load_elems[e]; + local_idx[e] = start_offset + e; + } + + if (scoring_mode == ScoringFunc::SOFTMAX) { + float softmax_max = kNegInfinity; + for (int e = 0; e < local_num; ++e) { + softmax_max = + (softmax_max > local_elems[e]) ? softmax_max : local_elems[e]; + } + for (int offset = sub_group_size / 2; offset > 0; offset /= 2) { + float other_val = sycl::permute_group_by_xor(sg, softmax_max, offset); + softmax_max = (softmax_max > other_val) ? softmax_max : other_val; + } + float softmax_sum = 0.0f; + for (int e = 0; e < local_num; ++e) { + float s = local_elems[e]; + softmax_sum += sycl::native::exp(s - softmax_max); + } + for (int offset = sub_group_size / 2; offset > 0; offset /= 2) { + float other_val = sycl::permute_group_by_xor(sg, softmax_sum, offset); + softmax_sum += other_val; + } + for (int e = 0; e < local_num; ++e) { + float s = local_elems[e]; + local_elems[e] = sycl::native::exp(s - softmax_max) / softmax_sum; + } + } else if (scoring_mode == ScoringFunc::SIGMOID) { + for (int e = 0; e < local_num; ++e) { + float s = load_elems[e]; + load_elems[e] = Sigmoid(s); + } + for (int e = 0; e < local_num; ++e) { + local_elems[e] = load_elems[e]; + } + } + + bool has_bias = e_score_correction_bias != nullptr; + if (has_bias) { + for (int e = 0; e < local_num; ++e) { + bias[e] = e_score_correction_bias[start_offset + e]; + } + } + + // perform topk_group groups + // 1 calculate each group scores + float group_scores[malloc_per_item * 2]; + for (int i = 0; i < num_expert_group * 2; ++i) { + group_scores[i] = kNegInfinity; + } + for (int i = 0; i < local_num; ++i) { + float b = bias[i]; + float score = local_elems[i] + b; + int i_group = (calc_per_item * sg_local_id + i) / experts_per_group; + float group_max = group_scores[i_group]; + float group_next_max = group_scores[num_expert_group + i_group]; + if (score > group_max) { + group_next_max = group_max; + group_max = score; + } else if (score > group_next_max) { + group_next_max = score; + } + group_scores[i_group] = group_max; + group_scores[num_expert_group + i_group] = group_next_max; + } + for (int i = 0; i < num_expert_group; ++i) { + float group_max = group_scores[i]; + float group_next_max = group_scores[num_expert_group + i]; + + float max1 = sycl::reduce_over_group( + sg, sycl::max(group_max, group_next_max), sycl::maximum<>()); + float local_second = + (group_max < max1 && group_max > -INFINITY) ? group_max : -INFINITY; + local_second = (group_next_max < max1 && group_next_max > local_second) + ? group_next_max + : local_second; + float max2 = sycl::reduce_over_group(sg, local_second, sycl::maximum<>()); + group_scores[i] = max1 + (has_bias ? max2 : 0.0f); + } + + // 2 find topk_group groups as kNegInfinity + int group_topk_idx[malloc_per_item]; + for (int k = 0; k < topk_group; ++k) { + float k_max = group_scores[0]; + int k_max_idx = 0; + for (int e = 1; e < num_expert_group; ++e) { + float score = group_scores[e]; + + if (score > k_max) { + k_max = score; + k_max_idx = e; + } + } + group_scores[k_max_idx] = kNegInfinity; + group_topk_idx[k] = k_max_idx; + } + + // 3 mask no-topk_group groups + for (int i = 0; i < calc_per_item; ++i) { + bool is_masked = true; + for (int k = 0; k < topk_group; ++k) { + if ((local_idx[i] / experts_per_group) == group_topk_idx[k]) { + is_masked = false; + break; + } + } + if (is_masked) { + local_elems[i] = kNegInfinity; + } + } + + // Perform top-k selection + float topk_weights_local[malloc_per_item]; + int topk_ids_local[malloc_per_item]; + + for (int k = 0; k < top_k; ++k) { + float k_max = kNegInfinity; + int k_max_idx = -1; + int remove_ix = -1; + for (int e = 0; e < calc_per_item; ++e) { + float le = local_elems[e]; + float b = bias[e]; + float my_val = le + b; + int my_idx = local_idx[e]; + for (int offset = sub_group_size / 2; offset > 0; offset /= 2) { + float other_val = sycl::permute_group_by_xor(sg, my_val, offset); + int other_idx = sycl::permute_group_by_xor(sg, my_idx, offset); + if (other_val > my_val || + (other_val == my_val && other_idx < my_idx)) { + my_val = other_val; + my_idx = other_idx; + } + } + if (my_val > k_max || (my_val == k_max && my_idx < k_max_idx)) { + k_max = my_val; + k_max_idx = my_idx; + + if (k_max_idx == local_idx[e]) { + remove_ix = e; // Mark this index for removal + } else + remove_ix = -1; + } + } + + int select_item = k_max_idx / calc_per_item; + int select_elem = k_max_idx % calc_per_item; + k_max = local_elems[select_elem]; + k_max = sycl::group_broadcast(sg, k_max, select_item); + if (remove_ix != -1) { + local_elems[remove_ix] = + kNegInfinity; // Reset the score to avoid re-selection + local_idx[remove_ix] = -1; + remove_ix = -1; + } + + topk_weights_local[k] = k_max; + topk_ids_local[k] = k_max_idx < 0 ? k : k_max_idx; + } + + if (renormalize) { + // Renormalize the top-k weights + float sum = 0; + for (int i = 0; i < top_k; ++i) { + sum += topk_weights_local[i]; + } + if (sum > 0) { + for (int i = 0; i < top_k; ++i) { + topk_weights_local[i] /= sum; + } + } + } + + if (sg_local_id == 0) { + int offset = tid * top_k; + for (int i = 0; i < top_k; ++i) { + topk_weights[offset + i] = topk_weights_local[i]; + if (!(topk_ids_local[i] >= 0 && topk_ids_local[i] < experts)) { + // Ensure valid index + topk_ids[offset + i] = 0; + continue; + } + topk_ids[offset + i] = topk_ids_local[i]; + } + } + } + float* topk_weights; + int* topk_ids; + const T* gating_output; + const T* e_score_correction_bias; + const ScoringFunc scoring_mode; + const bool renormalize; + const int tokens; + const int experts; + const int top_k; + const int num_expert_group; + const int topk_group; +}; + +template +void launch_fused_grouped_topk(sycl::queue& queue, float* topk_weights, + int* topk_ids, const T* gating_output, + const T* e_score_correction_bias, + const ScoringFunc scoring_mode, + const bool renormalize, const int tokens, + const int experts, const int top_k, + const int num_expert_group, + const int topk_group) { + using Kernel = Fused_Grouped_Topk; + auto range = Kernel::get_nd_range(tokens, experts); + + queue.submit([&](sycl::handler& cgh) { + Kernel task(topk_weights, topk_ids, gating_output, e_score_correction_bias, + scoring_mode, renormalize, tokens, experts, top_k, + num_expert_group, topk_group); + cgh.parallel_for(range, task); + }); +} + +template +void fused_grouped_topk(float* topk_weights, int* topk_ids, + const T* gating_output, + const T* e_score_correction_bias, + const ScoringFunc scoring_mode, const bool renormalize, + const int tokens, const int experts, const int top_k, + const int num_expert_group, const int topk_group) { + auto& queue = vllm::xpu::vllmGetQueue(); + + TORCH_CHECK(topk_group <= num_expert_group, + "topk_group must be less than or equal to num_expert_group"); + TORCH_CHECK(experts % num_expert_group == 0, + "The number of experts (experts=", experts, + ") must be divisible by num_expert_group (", num_expert_group, + ")."); + + int max_expert_group = ((num_expert_group + 7) / 8) * 8; +#define CASE_TOPK(K) \ + case K: \ + launch_fused_grouped_topk( \ + queue, topk_weights, topk_ids, gating_output, e_score_correction_bias, \ + scoring_mode, renormalize, tokens, experts, top_k, num_expert_group, \ + topk_group); \ + break; + switch (max_expert_group) { + CASE_TOPK(8) + CASE_TOPK(16) + default: + TORCH_CHECK(false, "error: not support num_expert_group=%d,\n", + num_expert_group); + } +#undef CASE_TOPK +} + +}; // namespace GroupedTopKImpl +} // namespace vllm + +/** + * @brief Perform grouped topk after sigmoid/addbias on gating_output. + * @param gating_output The gating output tensor of shape [n_tokens, n_experts]. + * @param n_topk The number of top experts to select. + * @param n_topk_group The number of top experts to select in the group. + * @return A tuple of tensors (topk_weights, topk_indices). + */ +std::tuple grouped_topk( + const torch::Tensor& hidden_states, const torch::Tensor& gating_output, + const int64_t n_topk, const bool renormalize, const int64_t n_expert_group, + const int64_t n_topk_group, const c10::string_view scoring_func, + const c10::optional& bias) { + auto shape = gating_output.sizes().vec(); + TORCH_CHECK(hidden_states.sizes()[0] == gating_output.sizes()[0], + "Number of tokens mismatch") + TORCH_CHECK(shape.size() == 2, "gating_output must be 2D tensor, but got ", + shape.size(), "D"); + if (bias.has_value()) { + auto shape_bias = bias->sizes().vec(); + TORCH_CHECK( + shape_bias[0] == shape[1], + "gating_output and bias must has same innermost dimension, but got ", + shape, " and ", shape_bias); + } + int n_tokens = shape[0]; + int n_experts = shape[1]; + + vllm::GroupedTopKImpl::ScoringFunc scoring_mode; + if (scoring_func == "sigmoid") { + scoring_mode = vllm::GroupedTopKImpl::ScoringFunc::SIGMOID; + } else if (scoring_func == "softmax") { + scoring_mode = vllm::GroupedTopKImpl::ScoringFunc::SOFTMAX; + } else { + scoring_mode = vllm::GroupedTopKImpl::ScoringFunc::DEFAULT; + } + + auto topk_weights = + torch::empty({n_tokens, n_topk}, at::dtype(at::kFloat).device(at::kXPU)); + auto topk_indices = + torch::empty({n_tokens, n_topk}, at::dtype(at::kInt).device(at::kXPU)); + + if (gating_output.scalar_type() == at::kBFloat16) { + using scalar_t = sycl::ext::oneapi::bfloat16; + vllm::GroupedTopKImpl::fused_grouped_topk( + reinterpret_cast(topk_weights.data_ptr()), + reinterpret_cast(topk_indices.data_ptr()), + reinterpret_cast(gating_output.data_ptr()), + bias.has_value() ? reinterpret_cast(bias->data_ptr()) + : nullptr, + scoring_mode, renormalize, n_tokens, n_experts, n_topk, n_expert_group, + n_topk_group); + } else if (gating_output.scalar_type() == at::kHalf) { + using scalar_t = sycl::half; + vllm::GroupedTopKImpl::fused_grouped_topk( + reinterpret_cast(topk_weights.data_ptr()), + reinterpret_cast(topk_indices.data_ptr()), + reinterpret_cast(gating_output.data_ptr()), + bias.has_value() ? reinterpret_cast(bias->data_ptr()) + : nullptr, + scoring_mode, renormalize, n_tokens, n_experts, n_topk, n_expert_group, + n_topk_group); + } else { + using scalar_t = float; + vllm::GroupedTopKImpl::fused_grouped_topk( + reinterpret_cast(topk_weights.data_ptr()), + reinterpret_cast(topk_indices.data_ptr()), + reinterpret_cast(gating_output.data_ptr()), + bias.has_value() ? reinterpret_cast(bias->data_ptr()) + : nullptr, + scoring_mode, renormalize, n_tokens, n_experts, n_topk, n_expert_group, + n_topk_group); + } + return std::make_tuple(topk_weights, topk_indices); +} \ No newline at end of file diff --git a/csrc/xpu/ops.h b/csrc/xpu/ops.h index 8dd9b71..588cd5a 100644 --- a/csrc/xpu/ops.h +++ b/csrc/xpu/ops.h @@ -8,6 +8,12 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, torch::Tensor& weight, double epsilon); +std::tuple grouped_topk( + const torch::Tensor& hidden_states, const torch::Tensor& gating_output, + const int64_t n_topk, const bool renormalize, const int64_t n_expert_group, + const int64_t n_topk_group, const c10::string_view scoring_func, + const c10::optional& bias); + void silu_and_mul(torch::Tensor& out, torch::Tensor& input); void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, diff --git a/csrc/xpu/torch_bindings.cpp b/csrc/xpu/torch_bindings.cpp index 1c8fe57..2c1a485 100644 --- a/csrc/xpu/torch_bindings.cpp +++ b/csrc/xpu/torch_bindings.cpp @@ -32,6 +32,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "float epsilon) -> ()"); ops.impl("fused_add_rms_norm", torch::kXPU, &fused_add_rms_norm); + // Grouped TopK + ops.def( + "grouped_topk(Tensor hidden_states, Tensor gating_output, int n_topk, " + "bool renormalize, int n_expert_group, int n_topk_group, str " + "scoring_func, Tensor? bias=None) -> (Tensor, Tensor)"); + ops.impl("grouped_topk", torch::kXPU, &grouped_topk); + // activation ops ops.def("silu_and_mul(Tensor! out, Tensor! input) -> ()"); ops.impl("silu_and_mul", torch::kXPU, &silu_and_mul); diff --git a/tests/ops/grouped_topk.py b/tests/ops/grouped_topk.py new file mode 100644 index 0000000..fadb250 --- /dev/null +++ b/tests/ops/grouped_topk.py @@ -0,0 +1,80 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch + + +def grouped_topk_native( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int, + topk_group: int, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + assert hidden_states.size(0) == gating_output.size(0), ( + "Number of tokens mismatch") + if scoring_func == "softmax": + gating_output = gating_output.to(torch.float32) + scores = torch.softmax(gating_output, dim=-1) + elif scoring_func == "sigmoid": + scores = gating_output.sigmoid() + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + + num_token = scores.shape[0] + if e_score_correction_bias is not None: + # Store original scores before applying correction bias. We use biased + # scores for expert selection but original scores for routing weights + e_score_correction_bias = e_score_correction_bias.to(torch.float32) + original_scores = scores + scores = scores + e_score_correction_bias.unsqueeze(0) + group_scores = (scores.view(num_token, num_expert_group, + -1).topk(2, dim=-1)[0].sum(dim=-1)) + else: + group_scores = (scores.view(num_token, num_expert_group, + -1).max(dim=-1).values) # [n, n_group] + group_idx = torch.topk(group_scores, k=topk_group, dim=-1, + sorted=True)[1] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = (group_mask.unsqueeze(-1).expand( + num_token, num_expert_group, + scores.shape[-1] // num_expert_group).reshape(num_token, -1)) # [n, e] + tmp_scores = scores.masked_fill(~score_mask.bool(), + float("-inf")) # [n, e] + + if e_score_correction_bias is not None: + topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=True)[1] + # Use original unbiased scores for the routing weights + topk_weights = original_scores.gather(1, topk_ids) + else: + topk_weights, topk_ids = torch.topk(tmp_scores, + k=topk, + dim=-1, + sorted=True) + + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + return topk_weights.to(torch.float32), topk_ids.to(torch.int32) + + +def grouped_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int, + topk_group: int, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + import tests.register_ops as ops + return ops.grouped_topk(hidden_states, gating_output, topk, renormalize, + num_expert_group, topk_group, scoring_func, + e_score_correction_bias) diff --git a/tests/register_ops.py b/tests/register_ops.py index 47df662..b24a75d 100644 --- a/tests/register_ops.py +++ b/tests/register_ops.py @@ -1,10 +1,21 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import TYPE_CHECKING, Optional + import torch -from typing import Optional import vllm_xpu_kernels._C # noqa: F401 +if TYPE_CHECKING: + + def register_fake(fn): + return lambda name: fn +else: + try: + from torch.library import register_fake + except ImportError: + from torch.library import impl_abstract as register_fake + # layer norm ops def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, @@ -20,6 +31,43 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon) +def grouped_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int, + topk_group: int, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + return torch.ops._C.grouped_topk(hidden_states, gating_output, topk, + renormalize, num_expert_group, topk_group, + scoring_func, e_score_correction_bias) + + +if hasattr(torch.ops._C, "grouped_topk"): + + @register_fake("_C::grouped_topk") + def _grouped_topk_fake( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int, + topk_group: int, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + topk_weights = torch.empty((gating_output.size(0), topk), + dtype=torch.float32, + device=hidden_states.device) + topk_indices = torch.empty((gating_output.size(0), topk), + dtype=torch.int32, + device=hidden_states.device) + return topk_weights, topk_indices + + def silu_and_mul(out: torch.Tensor, input: torch.Tensor) -> None: torch.ops._C.silu_and_mul(out, input) diff --git a/tests/test_grouped_topk.py b/tests/test_grouped_topk.py new file mode 100644 index 0000000..8b50730 --- /dev/null +++ b/tests/test_grouped_topk.py @@ -0,0 +1,127 @@ +# SPDX-License-Identifier: Apache-2.0 +import pytest +import torch + +from tests.ops.grouped_topk import grouped_topk, grouped_topk_native +from tests.utils import opcheck + +dpcpp_device = torch.device("xpu") + + +class TestTorchMethod: + + @pytest.mark.parametrize("seed", [123, 356, 478]) + @pytest.mark.parametrize("dtype", + [torch.float16, torch.bfloat16, torch.float32]) + @pytest.mark.parametrize("n_token", [1, 2, 64, 256]) + @pytest.mark.parametrize("n_expert", [64, 128, 256]) + @pytest.mark.parametrize("n_topk", [4, 6, 8]) + @pytest.mark.parametrize("n_topk_group", [4, 6, 8]) + @pytest.mark.parametrize("n_expert_group", [8]) + @pytest.mark.parametrize("renormalize", [True, False]) + @pytest.mark.parametrize("scoring_func", ["sigmoid", "softmax"]) + @pytest.mark.parametrize("has_bias", [True, False]) + def test_grouped_topk( + self, + seed, + dtype, + n_token, + n_expert, + n_topk, + n_expert_group, + n_topk_group, + renormalize, + scoring_func, + has_bias, + ): + + torch.manual_seed(seed) + torch.set_default_device("xpu") + gating_output = torch.randn(n_token, n_expert, + device=dpcpp_device).to(dtype) + hidden_states = torch.zeros(n_token, n_expert, + device=dpcpp_device).to(dtype) + bias = None + if has_bias: + if has_bias and scoring_func == "sigmoid" \ + and dtype is not torch.float32: + # using a bias of bigger number to avoid Low-precision + bias = torch.arange(1, n_expert + 1).to(dpcpp_device).to(dtype) + else: + bias = torch.randn(n_expert, device=dpcpp_device).to(dtype) + + ref_topk_weights, ref_topk_indices = grouped_topk_native( + hidden_states, + gating_output, + n_topk, + renormalize, + n_expert_group, + n_topk_group, + scoring_func=scoring_func, + e_score_correction_bias=bias, + ) + + topk_weights, topk_indices = grouped_topk( + hidden_states, + gating_output, + n_topk, + renormalize, + n_expert_group, + n_topk_group, + scoring_func=scoring_func, + e_score_correction_bias=bias, + ) + + # Compare the results + torch.testing.assert_close(ref_topk_weights, + topk_weights, + atol=2e-2, + rtol=1e-2) + assert torch.equal(ref_topk_indices, topk_indices) + + opcheck( + torch.ops._C.grouped_topk, + (hidden_states, gating_output, n_topk, renormalize, n_expert_group, + n_topk_group, scoring_func, bias), + ) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("renormalize", [True]) + @pytest.mark.parametrize("full_nan", [True, False]) + def test_grouped_topk_sigmoid_nan( + self, + dtype, + renormalize, + full_nan, + ): + n_token = 512 + n_expert = 256 + n_topk = 8 + n_expert_group = 8 + n_topk_group = 4 + + gating_output = torch.randn(n_token, n_expert, + device=dpcpp_device).to(dtype) + hidden_states = torch.zeros(n_token, n_expert, + device=dpcpp_device).to(dtype) + bias = torch.randn(n_expert, device=dpcpp_device).to(dtype) + + if full_nan: + gating_output = torch.full(gating_output.size(), + float("nan"), + device=dpcpp_device, + dtype=dtype).contiguous() + else: + gating_output[0][0] = float("nan") + + topk_weights, topk_indices = grouped_topk( + hidden_states, + gating_output, + n_topk, + renormalize, + n_expert_group, + n_topk_group, + e_score_correction_bias=bias, + ) + + assert torch.all(topk_indices < n_expert)