#include <ATen/ATen.h>
#include <c10/xpu/XPUStream.h>
#include <torch/all.h>

#include <sycl/sycl.hpp>

#include <limits>

#include "SYCLHelpers.h"
#include "Utils.h"

namespace at::native::xpu {

namespace TopKSoftmaxImpl {

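// Fused top-k + softmax over the per-token expert gating scores. Each token is
// handled by one sub-group: every work-item keeps a small private slice of that
// token's expert scores (at most malloc_per_item entries), the top-k experts are
// found with sub-group shuffle reductions, and only the softmax probabilities of
// the selected experts are written back, so the full softmax is never
// materialized in global memory.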
template <typename T>
struct FusedTopkSoftmax {
  static constexpr int sub_group_size = 32;
  static constexpr int max_group_size = 1024;
  static constexpr int malloc_per_item = 8;
  static constexpr float kNegInfinity = -std::numeric_limits<float>::infinity();

  FusedTopkSoftmax(
      float* topk_weights,
      int* topk_ids,
      const T* gating_output,
      const bool renormalize,
      const int tokens,
      const int experts,
      const int top_k)
      : topk_weights(topk_weights),
        topk_ids(topk_ids),
        gating_output(gating_output),
        renormalize(renormalize),
        tokens(tokens),
        experts(experts),
        top_k(top_k) {}

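  // Launch geometry: one sub-group per token. The work-group size is clamped to
  // [sub_group_size, max_group_size] and rounded up to a whole number of
  // sub-groups; the global range then covers ceil(tokens / sub_groups_per_group)
  // work-groups.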
  static inline sycl::nd_range<3> get_nd_range(const int tokens, const int experts) {
    int calc_per_item = div_up(experts, sub_group_size);
    int group_size = div_up(experts, calc_per_item);
    group_size = group_size < sub_group_size ? sub_group_size : group_size;
    group_size = group_size < max_group_size ? group_size : max_group_size;
    int sub_groups_per_group = div_up(group_size, sub_group_size);
    group_size = sub_groups_per_group * sub_group_size;
    int global_size = div_up(tokens, sub_groups_per_group);

    sycl::range<3> local(1, 1, group_size);
    sycl::range<3> global(1, 1, global_size);
    return sycl::nd_range<3>(global * local, local);
  }

  static inline T Sigmoid(T x) {
    float sycl_x = static_cast<float>(x);
    float result = 1.0f / (1.0f + sycl::exp(-sycl_x));
    return static_cast<T>(result);
  }

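  // Kernel body: (1) load this token's expert scores into per-lane registers,
  // (2) run top_k rounds of a sub-group argmax to pick the winning experts,
  // (3) accumulate the softmax denominator over all experts, (4) optionally
  // renormalize the selected weights, and (5) let lane 0 write the results.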
  [[sycl::reqd_sub_group_size(sub_group_size)]] void operator()(sycl::nd_item<3> item) const {
    int group_id = item.get_group_linear_id();
    int local_range = item.get_local_range(2);
    int sub_groups_per_group = local_range / sub_group_size;
    int calc_per_item = (experts + sub_group_size - 1) / sub_group_size;

    sycl::sub_group sg = item.get_sub_group();
    int sg_id = sg.get_group_id();
    int sg_local_id = sg.get_local_id();

    int tid = group_id * sub_groups_per_group + sg_id;

    if (tid >= tokens) {
      return; // Out of bounds
    }

    T local_elems[malloc_per_item];
    int local_idx[malloc_per_item];

    int start_offset = sg_local_id * calc_per_item;
    int local_num = calc_per_item;

    if (start_offset + local_num >= experts) {
      local_num = experts - start_offset;
      if (local_num < 0) {
        local_num = 0; // No elements to process
      }
    }

    for (int e = 0; e < calc_per_item; ++e) {
      local_elems[e] = kNegInfinity;
      local_idx[e] = -1;
    }

    for (int e = 0; e < local_num; ++e) {
      local_elems[e] = gating_output[tid * experts + start_offset + e];
      local_idx[e] = start_offset + e;
    }
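
    // Each lane now holds up to calc_per_item (score, expert index) pairs for
    // this token; unused slots stay at -inf / -1 so they can never win the
    // argmax rounds below.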

    // Perform top-k selection
    T topk_weights_local[malloc_per_item];
    int topk_ids_local[malloc_per_item];

    for (int k = 0; k < top_k; ++k) {
      T k_max = kNegInfinity;
      int k_max_idx = -1;
      int remove_ix = -1;
      for (int e = 0; e < calc_per_item; ++e) {
        T my_val = local_elems[e];
        int my_idx = local_idx[e];
        // Butterfly (XOR) shuffle: after log2(sub_group_size) steps every lane
        // holds the maximum (value, index) pair for this slot, breaking ties in
        // favor of the smaller expert index.
        for (int offset = sub_group_size / 2; offset > 0; offset /= 2) {
          T other_val = sycl::permute_group_by_xor(sg, my_val, offset);
          int other_idx = sycl::permute_group_by_xor(sg, my_idx, offset);
          if (other_val > my_val || (other_val == my_val && other_idx < my_idx)) {
            my_val = other_val;
            my_idx = other_idx;
          }
        }
        if (my_val > k_max || (my_val == k_max && my_idx < k_max_idx)) {
          k_max = my_val;
          k_max_idx = my_idx;

          // Track whether this lane owns the current winner so that only the
          // owning lane clears it once the round is finished.
          if (k_max_idx == local_idx[e]) {
            remove_ix = e; // Mark this slot for removal
          } else {
            remove_ix = -1;
          }
        }
      }
      topk_weights_local[k] = k_max;
      topk_ids_local[k] = k_max_idx;
      if (remove_ix != -1) {
        // Reset the score to avoid re-selection
        local_elems[remove_ix] = kNegInfinity;
        local_idx[remove_ix] = -1;
        remove_ix = -1;
      }
    }

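    // Softmax over all experts, stabilized by the global maximum score (the
    // first selected value). The denominator is accumulated in two parts: the
    // top_k scores saved above, plus the remaining (unselected) scores reduced
    // across the sub-group. Cleared and padded slots hold -inf and contribute
    // exp(-inf) = 0, so nothing is double counted and every lane ends up with
    // the same sum.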
    float max_score = topk_weights_local[0];
    float sum_exp = 0;

    for (int i = 0; i < top_k; ++i) {
      float score = topk_weights_local[i];
      sum_exp += sycl::exp(score - max_score);
    }

    for (int e = 0; e < calc_per_item; ++e) {
      float score = local_elems[e];
      float my_val = sycl::exp(score - max_score);
      for (int offset = sub_group_size / 2; offset > 0; offset /= 2) {
        float other_val = sycl::permute_group_by_xor(sg, my_val, offset);
        my_val += other_val;
      }
      sum_exp += my_val;
    }

    for (int i = 0; i < top_k; ++i) {
      float score = topk_weights_local[i];
      topk_weights_local[i] = sycl::exp(score - max_score) / sum_exp;
    }

    if (renormalize) {
      // Renormalize the top-k weights
      float sum = 0;
      for (int i = 0; i < top_k; ++i) {
        sum += topk_weights_local[i];
      }
      if (sum > 0) {
        for (int i = 0; i < top_k; ++i) {
          topk_weights_local[i] /= sum;
        }
      }
    }

    // Only the first lane of the sub-group writes this token's results.
    if (sg_local_id == 0) {
      int offset = tid * top_k;
      for (int i = 0; i < top_k; ++i) {
        topk_weights[offset + i] = topk_weights_local[i];
        if (topk_ids_local[i] < 0 || topk_ids_local[i] >= experts) {
          // Ensure valid index
          topk_ids[offset + i] = 0;
          continue;
        }
        topk_ids[offset + i] = topk_ids_local[i];
      }
    }
  }

  float* topk_weights;
  int* topk_ids;
  const T* gating_output;
  const bool renormalize;
  const int tokens;
  const int experts;
  const int top_k;
};

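// Builds the nd_range for the given problem size and submits the kernel on the
// provided SYCL queue.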
template <typename T>
void launch_fused_topk_softmax(
    sycl::queue& queue,
    const T* gating_output,
    float* topk_weights,
    int* topk_indices,
    const bool renormalize,
    const int top_k,
    const int num_tokens,
    const int num_experts) {
  using Kernel = FusedTopkSoftmax<T>;
  auto range = Kernel::get_nd_range(num_tokens, num_experts);

  auto global_range = range.get_global_range();
  auto local_range = range.get_local_range();

  Kernel task(topk_weights, topk_indices, gating_output, renormalize, num_tokens, num_experts, top_k);

  sycl_kernel_submit(global_range, local_range, queue, task);
}

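// Host-side entry point: fetches the current XPU stream's SYCL queue and
// launches the fused kernel.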
template <typename T>
void fused_topk_softmax(
    const T* gating_output,
    float* topk_weights,
    int* topk_indices,
    const bool renormalize,
    const int num_tokens,
    const int num_experts,
    const int topk) {
  auto stream = at::xpu::getCurrentXPUStream();
  auto queue = stream.queue();

  launch_fused_topk_softmax(
      queue, gating_output, topk_weights, topk_indices, renormalize, topk, num_tokens, num_experts);
}
} // namespace TopKSoftmaxImpl

/**
 * @brief Perform a top-k selection after softmax on gating_output.
 * @param topk_weights The output top-k weights tensor of shape [n_tokens, n_topk] (Float).
 * @param topk_indices The output top-k expert indices tensor of shape [n_tokens, n_topk] (Int).
 * @param gating_output The gating output tensor of shape [n_tokens, n_experts].
 * @param renormalize Whether to renormalize the top-k weights so that they sum to 1.
 * @return void.
 */
void topk_softmax(at::Tensor& topk_weights, at::Tensor& topk_indices, at::Tensor& gating_output, bool renormalize) {
  auto shape = gating_output.sizes().vec();
  TORCH_CHECK(shape.size() == 2, "gating_output must be a 2D tensor, but got ", shape.size(), "D");
  int64_t n_tokens = shape[0];
  int64_t n_experts = shape[1];

  TORCH_CHECK(n_experts <= 128, "only up to 128 experts are supported, but got ", n_experts);

  TORCH_CHECK(topk_weights.scalar_type() == at::kFloat, "topk_weights should be Float");
  TORCH_CHECK(topk_indices.scalar_type() == at::kInt, "topk_indices should be Int");

  constexpr int64_t alignment = 8;
  int64_t n_experts_aligned = div_up(n_experts, alignment) * alignment; // align to 8

  int64_t n_topk = topk_weights.size(1);

  AT_DISPATCH_REDUCED_FLOATING_TYPES(gating_output.scalar_type(), "fused_topk_softmax_kernel", [&]() {
    TopKSoftmaxImpl::fused_topk_softmax<scalar_t>(
        gating_output.data_ptr<scalar_t>(),
        topk_weights.data_ptr<float>(),
        topk_indices.data_ptr<int>(),
        renormalize,
        n_tokens,
        n_experts,
        n_topk);
  });
}
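
// Example call (shapes are illustrative): given gating logits of shape
// [n_tokens, n_experts], callers preallocate topk_weights as a Float tensor and
// topk_indices as an Int tensor, both of shape [n_tokens, top_k], and then:
//   topk_softmax(topk_weights, topk_indices, gating_output, /*renormalize=*/true);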
} // namespace at::native::xpu