|
16 | 16 | */ |
17 | 17 |
|
18 | 18 | #include "tensorrt_llm/common/opUtils.h" |
| 19 | +#include "tensorrt_llm/kernels/helixAllToAll.h" |
19 | 20 | #include "tensorrt_llm/runtime/torchUtils.h" |
20 | 21 | #include "tensorrt_llm/runtime/utils/mpiUtils.h" |
| 22 | +#include "tensorrt_llm/thop/thUtils.h" |
21 | 23 |
|
22 | | -#include <NvInferRuntime.h> |
23 | | -#include <c10/cuda/CUDAStream.h> |
24 | | -#include <cassert> |
25 | | -#include <set> |
26 | | -#include <string> |
27 | | -#include <torch/extension.h> |
28 | 24 | #include <vector> |
29 | | -#if ENABLE_MULTI_DEVICE |
30 | | -#include <nccl.h> |
31 | | -#endif // ENABLE_MULTI_DEVICE |
32 | 25 |
|
33 | 26 | TRTLLM_NAMESPACE_BEGIN |
34 | 27 |
|
@@ -119,16 +112,163 @@ std::vector<torch::Tensor> alltoall_helix( |
119 | 112 | #endif // ENABLE_MULTI_DEVICE |
120 | 113 | } |
121 | 114 |
|
| 115 | +/** |
| 116 | + * Helix All-to-All operation with two fields. |
| 117 | + * |
| 118 | + * Input tensors have shape [..., cp_size, kv_lora_rank] for partial_o and [..., |
| 119 | + * cp_size, 2] for softmax_stats. The operation exchanges data along the cp_size |
| 120 | + * dimension across all ranks. |
| 121 | + * |
| 122 | + * @param partial_o Field 0 tensor (half or bfloat16, shape [..., cp_size, |
| 123 | + * kv_lora_rank]) |
| 124 | + * @param softmax_stats Field 1 tensor (float32, shape [..., cp_size, 2*n]; typically n == 1) |
| 125 | + * @param workspace Workspace tensor (uint64, strided across ranks) |
| 126 | + * @param cp_rank Current context parallel rank |
| 127 | + * @param cp_size Total number of context parallel ranks |
| 128 | + * @return tuple of (partial_o_out, softmax_stats_out) with same shapes as inputs |
| 129 | + */ |
| 130 | +std::tuple<torch::Tensor, torch::Tensor> alltoall_helix_native( |
| 131 | + torch::Tensor partial_o, torch::Tensor softmax_stats, torch::Tensor workspace, int64_t cp_rank, int64_t cp_size) |
| 132 | +{ |
| 133 | + |
| 134 | + // Input validation |
| 135 | + CHECK_TH_CUDA(partial_o); |
| 136 | + CHECK_TH_CUDA(softmax_stats); |
| 137 | + CHECK_TH_CUDA(workspace); |
| 138 | + CHECK_CONTIGUOUS(partial_o); |
| 139 | + CHECK_CONTIGUOUS(softmax_stats); |
| 140 | + |
| 141 | + // Type checks |
| 142 | + TORCH_CHECK(partial_o.scalar_type() == at::ScalarType::Half || partial_o.scalar_type() == at::ScalarType::BFloat16, |
| 143 | + "partial_o must be half or bfloat16"); |
| 144 | + CHECK_TYPE(softmax_stats, at::ScalarType::Float); |
| 145 | + CHECK_TYPE(workspace, at::ScalarType::UInt64); |
| 146 | + |
| 147 | + // Shape validation |
| 148 | + TORCH_CHECK(partial_o.dim() >= 2, "partial_o must have at least 2 dimensions"); |
| 149 | + TORCH_CHECK(softmax_stats.dim() >= 2, "softmax_stats must have at least 2 dimensions"); |
| 150 | + TORCH_CHECK( |
| 151 | + partial_o.dim() == softmax_stats.dim(), "partial_o and softmax_stats must have same number of dimensions"); |
| 152 | + |
| 153 | + // Get dimensions |
| 154 | + int kv_lora_rank = partial_o.size(-1); |
| 155 | + TORCH_CHECK(partial_o.size(-2) == cp_size && softmax_stats.size(-2) == cp_size, |
| 156 | + "partial_o/softmax_stats second-to-last dimension must equal cp_size"); |
| 157 | + TORCH_CHECK(softmax_stats.size(-1) % 2 == 0 && softmax_stats.size(-1) >= 2, |
| 158 | + "softmax_stats last dimension must be divisible by 2 (float2)"); |
| 159 | + bool allowVariableField1 = softmax_stats.size(-1) > 2; |
| 160 | + |
| 161 | + // Check that leading dimensions match |
| 162 | + for (int i = 0; i < partial_o.dim() - 2; i++) |
| 163 | + { |
| 164 | + TORCH_CHECK(partial_o.size(i) == softmax_stats.size(i), |
| 165 | + "partial_o and softmax_stats must have matching dimensions except last two"); |
| 166 | + } |
| 167 | + TORCH_CHECK(partial_o.size(-1) * partial_o.element_size() % 16 == 0, "partial_o row size must be a multiple of 16 bytes"); |
| 168 | + |
| 169 | + TORCH_CHECK(workspace.dim() == 2, "workspace must be 2D (strided across ranks)"); |
| 170 | + TORCH_CHECK(workspace.size(0) == cp_size, "workspace must have cp_size rows"); |
| 171 | + |
| 172 | + // Calculate entry count (product of all dimensions before cp_size) |
| 173 | + // This is the number of entries to process per peer rank |
| 174 | + int entry_count = 1; |
| 175 | + for (int i = 0; i < partial_o.dim() - 2; i++) |
| 176 | + { |
| 177 | + entry_count *= partial_o.size(i); |
| 178 | + } |
| 179 | + |
| 180 | + // Reshape to 3D: [entry_count, cp_size, feature_dim] |
| 181 | + torch::Tensor partial_o_3d = partial_o.reshape({entry_count, cp_size, kv_lora_rank}); |
| 182 | + torch::Tensor softmax_stats_3d = softmax_stats.reshape({entry_count, cp_size, softmax_stats.size(-1)}); |
| 183 | + |
| 184 | + // Allocate output tensors (same shape as input) |
| 185 | + torch::Tensor partial_o_out = torch::empty_like(partial_o); |
| 186 | + torch::Tensor softmax_stats_out = torch::empty_like(softmax_stats); |
| 187 | + |
| 188 | + torch::Tensor partial_o_out_3d = partial_o_out.reshape({entry_count, cp_size, kv_lora_rank}); |
| 189 | + torch::Tensor softmax_stats_out_3d = softmax_stats_out.reshape({entry_count, cp_size, softmax_stats.size(-1)}); |
| 190 | + |
| 191 | + // Setup parameters |
| 192 | + tensorrt_llm::kernels::HelixAllToAllParams params; |
| 193 | + |
| 194 | + // Field 0 (variable-size half/bfloat16 payload) |
| 195 | + params.sendFields[0].dataPtr = reinterpret_cast<uint8_t*>(partial_o_3d.data_ptr()); |
| 196 | + params.sendFields[0].elementCount = kv_lora_rank; |
| 197 | + params.sendFields[0].elementSize = partial_o.element_size(); |
| 198 | + params.sendFields[0].stride = partial_o_3d.stride(1) * partial_o.element_size(); |
| 199 | + |
| 200 | + params.recvFields[0].dataPtr = reinterpret_cast<uint8_t*>(partial_o_out_3d.data_ptr()); |
| 201 | + params.recvFields[0].elementCount = kv_lora_rank; |
| 202 | + params.recvFields[0].elementSize = partial_o.element_size(); |
| 203 | + params.recvFields[0].stride = partial_o_out_3d.stride(1) * partial_o.element_size(); |
| 204 | + |
| 205 | + // Field 1 (softmax stats, one or more float2 values per entry) |
| 206 | + params.sendFields[1].dataPtr = reinterpret_cast<uint8_t*>(softmax_stats_3d.data_ptr<float>()); |
| 207 | + params.sendFields[1].elementCount = softmax_stats.size(-1); |
| 208 | + params.sendFields[1].elementSize = softmax_stats.element_size(); |
| 209 | + params.sendFields[1].stride = softmax_stats_3d.stride(1) * softmax_stats.element_size(); |
| 210 | + |
| 211 | + params.recvFields[1].dataPtr = reinterpret_cast<uint8_t*>(softmax_stats_out_3d.data_ptr<float>()); |
| 212 | + params.recvFields[1].elementCount = softmax_stats.size(-1); |
| 213 | + params.recvFields[1].elementSize = softmax_stats.element_size(); |
| 214 | + params.recvFields[1].stride = softmax_stats_out_3d.stride(1) * softmax_stats.element_size(); |
| 215 | + |
| 216 | + // Entry count and workspace |
| 217 | + params.entryCount = entry_count; |
| 218 | + params.workspace = workspace.data_ptr<uint64_t>(); |
| 219 | + params.workspaceStrideInU64 = workspace.stride(0); |
| 220 | + |
| 221 | + // CP info |
| 222 | + params.cpRank = cp_rank; |
| 223 | + params.cpSize = cp_size; |
| 224 | + params.channelCount = 0; // auto-compute |
| 225 | + params.maxChannelCount = tensorrt_llm::kernels::computeHelixMaxChannelCount(cp_size); |
| 226 | + |
| 227 | + // Launch kernel |
| 228 | + auto stream = at::cuda::getCurrentCUDAStream(); |
| 229 | + tensorrt_llm::kernels::launchHelixAllToAll(params, allowVariableField1, stream); |
| 230 | + |
| 231 | + return std::make_tuple(partial_o_out, softmax_stats_out); |
| 232 | +} |
| 233 | + |
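A minimal usage sketch of the new op, not taken from the repository: the token count and kv_lora_rank below are placeholder values, and the workspace is assumed to have been prepared and initialized as in the workspace sketch further below (its rows must be reachable by every CP rank).

    #include <torch/torch.h>

    // Hypothetical sizes, for illustration only.
    constexpr int64_t kTokens = 8;
    constexpr int64_t kKvLoraRank = 512; // keeps the row size a multiple of 16 bytes

    std::tuple<torch::Tensor, torch::Tensor> run_exchange(
        torch::Tensor workspace, int64_t cp_rank, int64_t cp_size)
    {
        auto opts = torch::TensorOptions().dtype(torch::kFloat).device(torch::kCUDA);
        // Shapes follow the doc comment: [..., cp_size, kv_lora_rank] and [..., cp_size, 2].
        torch::Tensor partial_o
            = torch::randn({kTokens, cp_size, kKvLoraRank}, opts).to(torch::kHalf);
        torch::Tensor softmax_stats = torch::randn({kTokens, cp_size, 2}, opts);

        // Exchange along the cp_size dimension; outputs keep the input shapes.
        return tensorrt_llm::torch_ext::alltoall_helix_native(
            partial_o, softmax_stats, workspace, cp_rank, cp_size);
    }
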
| 234 | +/** |
| 235 | + * Initialize workspace for helix all-to-all |
| 236 | + */ |
| 237 | +void initialize_helix_workspace(torch::Tensor workspace, int64_t cp_rank, int64_t cp_size) |
| 238 | +{ |
| 239 | + CHECK_TH_CUDA(workspace); |
| 240 | + CHECK_TYPE(workspace, at::ScalarType::UInt64); |
| 241 | + TORCH_CHECK(workspace.dim() == 2, "workspace must be 2D"); |
| 242 | + TORCH_CHECK(workspace.size(0) == cp_size, "workspace must have cp_size rows"); |
| 243 | + TORCH_CHECK(cp_rank >= 0 && cp_rank < cp_size, "cp_rank must be in [0, cp_size)"); |
| 244 | + |
| 245 | + auto stream = at::cuda::getCurrentCUDAStream(); |
| 246 | + uint64_t* global_workspace_ptr = workspace.data_ptr<uint64_t>(); |
| 247 | + uint64_t* local_workspace_ptr = workspace[cp_rank].data_ptr<uint64_t>(); |
| 248 | + TORCH_CHECK(local_workspace_ptr == global_workspace_ptr + cp_rank * workspace.stride(0), |
| 249 | + "local_workspace_ptr must be at the correct offset in the global " |
| 250 | + "workspace"); |
| 251 | + tensorrt_llm::kernels::initializeHelixWorkspace(local_workspace_ptr, cp_size, stream); |
| 252 | +} |
| 253 | + |
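A hedged sketch of setting up that workspace, again not repository code: the per-rank extent kWorkspaceU64PerRank is a placeholder, and in a real multi-rank run the buffer must be allocated so that every CP rank can reach all cp_size rows (through whatever shared or fabric allocation the caller already uses); a plain device tensor is shown only to illustrate the shape, dtype, and initialization call.

    // Placeholder extent; the real per-rank workspace size is dictated by the kernel.
    constexpr int64_t kWorkspaceU64PerRank = 1 << 20;

    torch::Tensor make_and_init_workspace(int64_t cp_rank, int64_t cp_size)
    {
        // [cp_size, stride] uint64 tensor, one row per rank.
        auto opts = torch::TensorOptions().dtype(at::ScalarType::UInt64).device(torch::kCUDA);
        torch::Tensor workspace = torch::empty({cp_size, kWorkspaceU64PerRank}, opts);

        // Each rank initializes only its own row before the first exchange.
        tensorrt_llm::torch_ext::initialize_helix_workspace(workspace, cp_rank, cp_size);
        return workspace;
    }
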
122 | 254 | } // namespace torch_ext |
123 | 255 |
|
124 | 256 | TRTLLM_NAMESPACE_END |
125 | 257 |
|
126 | 258 | TORCH_LIBRARY_FRAGMENT(trtllm, m) |
127 | 259 | { |
128 | 260 | m.def("alltoall_helix(Tensor[] input_list, int[] group, int? num_lists) -> Tensor[]"); |
| 261 | + m.def( |
| 262 | + "alltoall_helix_native(Tensor partial_o, Tensor softmax_stats, Tensor(a!) workspace, int " |
| 263 | + "cp_rank, int cp_size) -> (Tensor, Tensor)"); |
| 264 | + m.def( |
| 265 | + "initialize_helix_workspace(Tensor(a!) workspace, int cp_rank, int cp_size) " |
| 266 | + "-> ()"); |
129 | 267 | } |
130 | 268 |
|
131 | 269 | TORCH_LIBRARY_IMPL(trtllm, CUDA, m) |
132 | 270 | { |
133 | 271 | m.impl("alltoall_helix", &tensorrt_llm::torch_ext::alltoall_helix); |
| 272 | + m.impl("alltoall_helix_native", &tensorrt_llm::torch_ext::alltoall_helix_native); |
| 273 | + m.impl("initialize_helix_workspace", &tensorrt_llm::torch_ext::initialize_helix_workspace); |
134 | 274 | } |
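Once the library is loaded, the registered schemas are reachable through the PyTorch dispatcher (torch.ops.trtllm.alltoall_helix_native and torch.ops.trtllm.initialize_helix_workspace from Python). A sketch of the equivalent C++ dispatch path, with the tensor arguments assumed to be prepared as in the earlier snippets:

    #include <ATen/core/dispatch/Dispatcher.h>

    std::tuple<torch::Tensor, torch::Tensor> call_via_dispatcher(torch::Tensor partial_o,
        torch::Tensor softmax_stats, torch::Tensor workspace, int64_t cp_rank, int64_t cp_size)
    {
        // Resolve the schema added in TORCH_LIBRARY_FRAGMENT and dispatch to the
        // CUDA implementation registered in TORCH_LIBRARY_IMPL above.
        static auto op = c10::Dispatcher::singleton()
                             .findSchemaOrThrow("trtllm::alltoall_helix_native", "")
                             .typed<std::tuple<torch::Tensor, torch::Tensor>(
                                 torch::Tensor, torch::Tensor, torch::Tensor, int64_t, int64_t)>();
        return op.call(partial_o, softmax_stats, workspace, cp_rank, cp_size);
    }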