diff --git a/paddle/fluid/distributed/collective/deep_ep/deep_ep.cpp b/paddle/fluid/distributed/collective/deep_ep/deep_ep.cpp index bbabc308742b0b..671961ce88a50c 100644 --- a/paddle/fluid/distributed/collective/deep_ep/deep_ep.cpp +++ b/paddle/fluid/distributed/collective/deep_ep/deep_ep.cpp @@ -49,6 +49,7 @@ void SetAllocatorStreamForGPUContext(cudaStream_t stream, Buffer::Buffer(int rank, int num_ranks, + int num_loop_stage, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode, @@ -131,16 +132,18 @@ Buffer::Buffer(int rank, CUDA_CHECK(cudaMemsetAsync(workspace, 0, NUM_WORKSPACE_BYTES, comm_stream)); // MoE counter - CUDA_CHECK( - cudaMallocHost(&moe_recv_counter, sizeof(int64_t), cudaHostAllocMapped)); + CUDA_CHECK(cudaMallocHost(&moe_recv_counter, + sizeof(int64_t) * num_loop_stage, + cudaHostAllocMapped)); CUDA_CHECK(cudaHostGetDevicePointer( &moe_recv_counter_mapped, const_cast(moe_recv_counter), 0)); *moe_recv_counter = -1; // MoE expert-level counter - CUDA_CHECK(cudaMallocHost(&moe_recv_expert_counter, - sizeof(int) * NUM_MAX_LOCAL_EXPERTS, - cudaHostAllocMapped)); + CUDA_CHECK( + cudaMallocHost(&moe_recv_expert_counter, + sizeof(int) * NUM_MAX_LOCAL_EXPERTS * num_loop_stage, + cudaHostAllocMapped)); CUDA_CHECK(cudaHostGetDevicePointer(&moe_recv_expert_counter_mapped, const_cast(moe_recv_expert_counter), 0)); @@ -149,8 +152,9 @@ Buffer::Buffer(int rank, // MoE RDMA-level counter if (num_rdma_ranks > 0) { - CUDA_CHECK(cudaMallocHost( - &moe_recv_rdma_counter, sizeof(int), cudaHostAllocMapped)); + CUDA_CHECK(cudaMallocHost(&moe_recv_rdma_counter, + sizeof(int) * num_loop_stage, + cudaHostAllocMapped)); CUDA_CHECK(cudaHostGetDevicePointer(&moe_recv_rdma_counter_mapped, const_cast(moe_recv_rdma_counter), 0)); @@ -1881,6 +1885,196 @@ Buffer::internode_notify_combine( send_nvl_head}; } +std::tuple, + std::vector, + deep_ep::detail::Tensor, + deep_ep::detail::Tensor, + deep_ep::detail::Tensor, + deep_ep::detail::Tensor, + deep_ep::detail::Tensor> +Buffer::internode_fused_notify_combine( + const deep_ep::detail::Tensor& x, + const std::optional& x_scales, + const std::optional& topk_idx, + const std::optional& num_tokens_per_rank, + const std::optional& num_tokens_per_rdma_rank, + const deep_ep::detail::Tensor& is_token_in_rank, + int num_loop_stage, + const Config& config) { + const int num_channels = config.num_sms / 2; + EP_HOST_ASSERT(config.num_sms % 2 == 0); + EP_HOST_ASSERT(0 < get_num_rdma_ranks() && + get_num_rdma_ranks() <= NUM_MAX_RDMA_PEERS); + + EP_HOST_ASSERT(num_tokens_per_rank->scalar_type() == deep_ep::detail::kInt32); + EP_HOST_ASSERT(num_tokens_per_rdma_rank->scalar_type() == + deep_ep::detail::kInt32); + + // Shape and contiguous checks + EP_HOST_ASSERT(x.dim() == 2 && x.is_contiguous()); + EP_HOST_ASSERT((x.size(1) * x.element_size()) % sizeof(int4) == 0); + EP_HOST_ASSERT(num_tokens_per_rank->dim() == 2 && + num_tokens_per_rank->is_contiguous()); + EP_HOST_ASSERT(num_tokens_per_rdma_rank->dim() == 2 && + num_tokens_per_rdma_rank->is_contiguous()); + EP_HOST_ASSERT(num_tokens_per_rank->size(0) == num_loop_stage); + EP_HOST_ASSERT(num_tokens_per_rdma_rank->size(0) == num_loop_stage); + + int num_scales = 0; + if (x_scales.has_value()) { + num_scales = x_scales->dim() == 1 ? 
1 : static_cast(x_scales->size(1)); + } + + auto num_tokens = static_cast(x.size(0)), + hidden = static_cast(x.size(1)), + hidden_int4 = + static_cast(x.size(1) * x.element_size() / sizeof(int4)); + + // Top-k checks + int num_topk = 0; + if (topk_idx.has_value()) { + num_topk = static_cast(topk_idx->size(1)); + EP_HOST_ASSERT(topk_idx->dim() == 2 && topk_idx->is_contiguous()); + EP_HOST_ASSERT(num_tokens == topk_idx->size(0)); + } + + // Allocate all tensors on comm stream if set + // NOTES: do not allocate tensors upfront! + auto compute_stream = calc_ctx->stream(); + stream_wait(comm_stream, compute_stream); + + auto rdma_channel_prefix_matrix = + ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( + {num_loop_stage, num_rdma_ranks, num_channels}, + phi::DataType::INT32, + phi::GPUPlace(device_id))); + auto recv_rdma_rank_prefix_sum = ConvertPaddleTensorToDetailTensor( + paddle::experimental::empty({num_loop_stage, num_rdma_ranks}, + phi::DataType::INT32, + phi::GPUPlace(device_id))); + auto gbl_channel_prefix_matrix = ConvertPaddleTensorToDetailTensor( + paddle::experimental::empty({num_loop_stage, num_ranks, num_channels}, + phi::DataType::INT32, + phi::GPUPlace(device_id))); + auto recv_gbl_rank_prefix_sum = ConvertPaddleTensorToDetailTensor( + paddle::experimental::empty({num_loop_stage, num_ranks}, + phi::DataType::INT32, + phi::GPUPlace(device_id))); + + auto recv_rdma_channel_prefix_matrix = + ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( + {num_loop_stage, num_rdma_ranks, num_channels}, + phi::DataType::INT32, + phi::GPUPlace(device_id))); + auto recv_gbl_channel_prefix_matrix = ConvertPaddleTensorToDetailTensor( + paddle::experimental::empty({num_loop_stage, num_ranks, num_channels}, + phi::DataType::INT32, + phi::GPUPlace(device_id))); + + auto send_rdma_head = + ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( + {num_loop_stage, num_tokens, num_ranks / NUM_MAX_NVL_PEERS}, + phi::DataType::INT32, + phi::GPUPlace(device_id))); + auto send_nvl_head = + ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( + {num_loop_stage, num_tokens, num_ranks / NUM_MAX_NVL_PEERS, 8}, + phi::DataType::INT32, + phi::GPUPlace(device_id))); + + // Send sizes + for (int i = 0; i < num_loop_stage; ++i) { + moe_recv_counter[i] = -1; + moe_recv_rdma_counter[i] = -1; + } + internode::fused_notify_combine( + num_tokens_per_rank->data_ptr(), + moe_recv_counter_mapped, + num_ranks, + num_tokens_per_rdma_rank->data_ptr(), + moe_recv_rdma_counter_mapped, + is_token_in_rank.data_ptr(), + num_tokens, + num_channels, + hidden_int4, + num_scales, + num_topk, + num_loop_stage, + rdma_channel_prefix_matrix.data_ptr(), + recv_rdma_rank_prefix_sum.data_ptr(), + gbl_channel_prefix_matrix.data_ptr(), + recv_gbl_rank_prefix_sum.data_ptr(), + send_rdma_head.data_ptr(), + send_nvl_head.data_ptr(), + rdma_buffer_ptr, + config.num_max_rdma_chunked_recv_tokens, + buffer_ptrs_gpu, + config.num_max_nvl_chunked_recv_tokens, + task_fifo_ptrs_gpu, + head, + rank, + comm_stream, + config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks), + num_nvl_bytes, + low_latency_mode); + + internode::fused_notify_combine_post_step( + num_ranks, + num_channels, + num_loop_stage, + recv_gbl_rank_prefix_sum.data_ptr(), + rdma_channel_prefix_matrix.data_ptr(), + gbl_channel_prefix_matrix.data_ptr(), + recv_rdma_channel_prefix_matrix.data_ptr(), + recv_gbl_channel_prefix_matrix.data_ptr(), + rdma_buffer_ptr, + buffer_ptrs_gpu, + task_fifo_ptrs_gpu, + head, + rank, + comm_stream, 
+ low_latency_mode); + + // Synchronize total received tokens and tokens per expert + auto start_time = std::chrono::high_resolution_clock::now(); + while (true) { + // Read total count + bool ready = true; + for (int i = 0; i < num_loop_stage; ++i) { + ready = ready && (moe_recv_counter[i] >= 0) && + (moe_recv_rdma_counter[i] >= 0); + } + + if (ready) break; + + // Timeout check + if (std::chrono::duration_cast( + std::chrono::high_resolution_clock::now() - start_time) + .count() > NUM_CPU_TIMEOUT_SECS) { + LOG(INFO) << "Global rank: " << rank + << ", num_recv_tokens: " << moe_recv_counter[0] + << ", num_rdma_recv_tokens: " << moe_recv_rdma_counter[0]; + throw std::runtime_error("DeepEP error: timeout (dispatch CPU)"); + } + } + + std::vector num_recv_tokens(moe_recv_counter, + moe_recv_counter + num_loop_stage); + std::vector num_rdma_recv_tokens(moe_recv_rdma_counter, + moe_recv_rdma_counter + num_loop_stage); + + // Wait streams + stream_wait(compute_stream, comm_stream); + + return {num_recv_tokens, + num_rdma_recv_tokens, + recv_rdma_rank_prefix_sum, + recv_rdma_channel_prefix_matrix, + recv_gbl_channel_prefix_matrix, + send_rdma_head, + send_nvl_head}; +} + #endif // PADDLE_WITH_NVSHMEM void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, @@ -2690,6 +2884,201 @@ Buffer::internode_notify_dispatch( recv_gbl_rank_prefix_sum}; } +std::tuple>, + std::vector, + std::vector, + deep_ep::detail::Tensor, + deep_ep::detail::Tensor, + deep_ep::detail::Tensor, + deep_ep::detail::Tensor> +Buffer::internode_fused_notify_dispatch( + const deep_ep::detail::Tensor& x, + const std::optional& x_scales, + const std::optional& topk_idx, + const std::optional& num_tokens_per_rank, + const std::optional& num_tokens_per_rdma_rank, + const std::optional& num_tokens_per_expert, + const deep_ep::detail::Tensor& is_token_in_rank, + int expert_alignment, + int num_loop_stage, + const Config& config) { + const int num_channels = config.num_sms / 2; + EP_HOST_ASSERT(config.num_sms % 2 == 0); + EP_HOST_ASSERT(0 < get_num_rdma_ranks() && + get_num_rdma_ranks() <= NUM_MAX_RDMA_PEERS); + + EP_HOST_ASSERT(num_tokens_per_rank.has_value()); + EP_HOST_ASSERT(num_tokens_per_rdma_rank.has_value()); + EP_HOST_ASSERT(num_tokens_per_expert.has_value()); + + // Type checks + EP_HOST_ASSERT(num_tokens_per_rank->scalar_type() == deep_ep::detail::kInt32); + EP_HOST_ASSERT(num_tokens_per_rdma_rank->scalar_type() == + deep_ep::detail::kInt32); + EP_HOST_ASSERT(num_tokens_per_expert->scalar_type() == + deep_ep::detail::kInt32); + + // Shape and contiguous checks + EP_HOST_ASSERT(x.dim() == 2 && x.is_contiguous()); + EP_HOST_ASSERT((x.size(1) * x.element_size()) % sizeof(int4) == 0); + EP_HOST_ASSERT(num_tokens_per_rank->dim() == 2 && + num_tokens_per_rank->is_contiguous()); + EP_HOST_ASSERT(num_tokens_per_rdma_rank->dim() == 2 && + num_tokens_per_rdma_rank->is_contiguous()); + EP_HOST_ASSERT(num_tokens_per_expert->dim() == 2 && + num_tokens_per_expert->is_contiguous()); + EP_HOST_ASSERT(num_tokens_per_rank->size(1) == num_ranks); + EP_HOST_ASSERT(num_tokens_per_rdma_rank->size(1) == num_rdma_ranks); + EP_HOST_ASSERT(num_tokens_per_expert->size(1) % num_ranks == 0); + EP_HOST_ASSERT(num_tokens_per_expert->size(1) / num_ranks <= + NUM_MAX_LOCAL_EXPERTS); + + auto num_tokens = static_cast(x.size(0)), + hidden = static_cast(x.size(1)), + hidden_int4 = + static_cast(x.size(1) * x.element_size() / sizeof(int4)); + + // Top-k checks + int num_topk = 0; + int64_t* topk_idx_ptr = nullptr; + if (topk_idx.has_value()) { + 
num_topk = static_cast(topk_idx->size(1)); + EP_HOST_ASSERT(topk_idx->dim() == 2 && topk_idx->is_contiguous()); + EP_HOST_ASSERT(num_tokens == topk_idx->size(0)); + EP_HOST_ASSERT(num_topk == topk_idx->size(1)); + topk_idx_ptr = topk_idx->data_ptr(); + } + auto num_experts = static_cast(num_tokens_per_expert->size(1)); + int num_local_experts = num_experts / num_ranks; + + // FP8 scales checks + float* x_scales_ptr = nullptr; + int num_scales = 0; + if (x_scales.has_value()) { + EP_HOST_ASSERT(x.element_size() == 1); + EP_HOST_ASSERT(x_scales->scalar_type() == deep_ep::detail::kFloat32); + EP_HOST_ASSERT(x_scales->dim() > 0 && x_scales->dim() < 3 && + x_scales->is_contiguous()); + EP_HOST_ASSERT(x_scales->size(0) == num_tokens); + num_scales = x_scales->dim() == 1 ? 1 : static_cast(x_scales->size(1)); + } + + auto rdma_channel_prefix_matrix = + ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( + {num_loop_stage, num_rdma_ranks, num_channels}, + phi::DataType::INT32, + phi::GPUPlace(device_id))); + auto recv_rdma_rank_prefix_sum = ConvertPaddleTensorToDetailTensor( + paddle::experimental::empty({num_loop_stage, num_rdma_ranks}, + phi::DataType::INT32, + phi::GPUPlace(device_id))); + auto gbl_channel_prefix_matrix = ConvertPaddleTensorToDetailTensor( + paddle::experimental::empty({num_loop_stage, num_ranks, num_channels}, + phi::DataType::INT32, + phi::GPUPlace(device_id))); + auto recv_gbl_rank_prefix_sum = ConvertPaddleTensorToDetailTensor( + paddle::experimental::empty({num_loop_stage, num_ranks}, + phi::DataType::INT32, + phi::GPUPlace(device_id))); + + auto compute_stream = calc_ctx->stream(); + stream_wait(comm_stream, compute_stream); + + // Send sizes + for (int s = 0; s < num_loop_stage; ++s) { + moe_recv_counter[s] = -1; + moe_recv_rdma_counter[s] = -1; + for (int i = 0; i < num_local_experts; ++i) + moe_recv_expert_counter[s * num_local_experts + i] = -1; + } + + internode::fused_notify_dispatch( + num_tokens_per_rank->data_ptr(), + moe_recv_counter_mapped, + num_ranks, + num_tokens_per_rdma_rank->data_ptr(), + moe_recv_rdma_counter_mapped, + num_tokens_per_expert->data_ptr(), + moe_recv_expert_counter_mapped, + num_experts, + is_token_in_rank.data_ptr(), + num_tokens, + num_channels, + hidden_int4, + num_scales, + num_topk, + expert_alignment, + num_loop_stage, + rdma_channel_prefix_matrix.data_ptr(), + recv_rdma_rank_prefix_sum.data_ptr(), + gbl_channel_prefix_matrix.data_ptr(), + recv_gbl_rank_prefix_sum.data_ptr(), + rdma_buffer_ptr, + config.num_max_rdma_chunked_recv_tokens, + buffer_ptrs_gpu, + config.num_max_nvl_chunked_recv_tokens, + task_fifo_ptrs_gpu, + head, + rank, + comm_stream, + config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks), + num_nvl_bytes, + low_latency_mode); + move_fifo_slots(3); + + // Synchronize total received tokens and tokens per expert + auto start_time = std::chrono::high_resolution_clock::now(); + while (true) { + bool ready = true; + for (int s = 0; s < num_loop_stage && ready; ++s) { + // Read total count + ready &= (moe_recv_counter[s] >= 0) && (moe_recv_rdma_counter[s] >= 0); + // Read per-expert count + for (int i = 0; i < num_local_experts && ready; ++i) + ready &= moe_recv_expert_counter[s * num_local_experts + i] >= 0; + } + + if (ready) break; + + // Timeout check + if (std::chrono::duration_cast( + std::chrono::high_resolution_clock::now() - start_time) + .count() > NUM_CPU_TIMEOUT_SECS) { + for (int s = 0; s < num_loop_stage; ++s) { + LOG(INFO) << "Global rank: " << rank << ", stage: " << s + << ", 
num_recv_tokens: " << moe_recv_counter[s] + << ", num_rdma_recv_tokens: " << moe_recv_rdma_counter[s]; + for (int i = 0; i < num_local_experts; ++i) + LOG(INFO) << " moe_recv_expert_counter[" << i << "]: " + << moe_recv_expert_counter[s * num_local_experts + i]; + throw std::runtime_error("DeepEP error: timeout (dispatch CPU)"); + } + } + } + std::vector num_recv_tokens(moe_recv_counter, + moe_recv_counter + num_loop_stage); + std::vector num_rdma_recv_tokens(moe_recv_rdma_counter, + moe_recv_rdma_counter + num_loop_stage); + + std::vector> num_recv_tokens_per_expert_list; + num_recv_tokens_per_expert_list.reserve(num_loop_stage); + for (int s = 0; s < num_loop_stage; ++s) { + num_recv_tokens_per_expert_list.emplace_back( + moe_recv_expert_counter + s * num_local_experts, + moe_recv_expert_counter + (s + 1) * num_local_experts); + } + + stream_wait(compute_stream, comm_stream); + + return {num_recv_tokens_per_expert_list, + num_recv_tokens, + num_rdma_recv_tokens, + rdma_channel_prefix_matrix, + gbl_channel_prefix_matrix, + recv_rdma_rank_prefix_sum, + recv_gbl_rank_prefix_sum}; +} + #endif // PADDLE_WITH_NVSHMEM std::tuple, + std::vector, + paddle::Tensor, + paddle::Tensor, + paddle::Tensor, + paddle::Tensor, + paddle::Tensor> +Buffer::internode_fused_notify_combine_api( + const paddle::Tensor& x, + const std::optional& x_scales, + const std::optional& topk_idx, + const std::optional& num_tokens_per_rank, + const std::optional& num_tokens_per_rdma_rank, + const paddle::Tensor& is_token_in_rank, + int num_loop_stages, + const Config& config) { +#ifdef PADDLE_WITH_NVSHMEM + const auto& x_ = ConvertPaddleTensorToDetailTensor(x); + std::optional x_scales_ = + ConvertOptionalPaddleTensorToDetailTensor(x_scales); + + std::optional topk_idx_ = + ConvertOptionalPaddleTensorToDetailTensor(topk_idx); + std::optional num_tokens_per_rank_ = + ConvertOptionalPaddleTensorToDetailTensor(num_tokens_per_rank); + std::optional num_tokens_per_rdma_rank_ = + ConvertOptionalPaddleTensorToDetailTensor(num_tokens_per_rdma_rank); + const auto& is_token_in_rank_ = + ConvertPaddleTensorToDetailTensor(is_token_in_rank); + + auto res = internode_fused_notify_combine(x_, + x_scales_, + topk_idx_, + num_tokens_per_rank_, + num_tokens_per_rdma_rank_, + is_token_in_rank_, + num_loop_stages, + config); + + auto num_recv_tokens_ = std::get<0>(res); + auto num_rdma_recv_tokens_ = std::get<1>(res); + auto recv_rdma_rank_prefix_sum_ = + ConvertDetailTensorToPaddleTensor(std::get<2>(res)); + + auto recv_rdma_channel_prefix_matrix_ = + ConvertDetailTensorToPaddleTensor(std::get<3>(res)); + + auto recv_gbl_channel_prefix_matrix_ = + ConvertDetailTensorToPaddleTensor(std::get<4>(res)); + + auto send_rdma_head_ = ConvertDetailTensorToPaddleTensor(std::get<5>(res)); + auto send_nvl_head_ = ConvertDetailTensorToPaddleTensor(std::get<6>(res)); + + return {num_recv_tokens_, + num_rdma_recv_tokens_, + recv_rdma_rank_prefix_sum_, + recv_rdma_channel_prefix_matrix_, + recv_gbl_channel_prefix_matrix_, + send_rdma_head_, + send_nvl_head_}; +#else + LOG(ERROR) << "NVSHMEM is not enabled. 
You can enable it by setting cmake " + "option WITH_NVSHMEM=ON."; + return {}; +#endif +} + std::tuple, paddle::Tensor, @@ -3562,6 +4018,81 @@ Buffer::internode_notify_dispatch_api( #endif } +std::tuple>, + std::vector, + std::vector, + paddle::Tensor, + paddle::Tensor, + paddle::Tensor, + paddle::Tensor> +Buffer::internode_fused_notify_dispatch_api( + const paddle::Tensor& x, + const std::optional& x_scales, + const std::optional& topk_idx, + const std::optional& num_tokens_per_rank, + const std::optional& num_tokens_per_rdma_rank, + const std::optional& num_tokens_per_expert, + const paddle::Tensor& is_token_in_rank, + int expert_alignment, + int num_loop_stage, + const Config& config) { +#ifdef PADDLE_WITH_NVSHMEM + const auto& x_ = ConvertPaddleTensorToDetailTensor(x); + std::optional x_scales_ = + ConvertOptionalPaddleTensorToDetailTensor(x_scales); + + std::optional topk_idx_ = + ConvertOptionalPaddleTensorToDetailTensor(topk_idx); + std::optional num_tokens_per_rank_ = + ConvertOptionalPaddleTensorToDetailTensor(num_tokens_per_rank); + std::optional num_tokens_per_rdma_rank_ = + ConvertOptionalPaddleTensorToDetailTensor(num_tokens_per_rdma_rank); + std::optional num_tokens_per_expert_ = + ConvertOptionalPaddleTensorToDetailTensor(num_tokens_per_expert); + const auto& is_token_in_rank_ = + ConvertPaddleTensorToDetailTensor(is_token_in_rank); + + auto res = internode_fused_notify_dispatch(x_, + x_scales_, + topk_idx_, + num_tokens_per_rank_, + num_tokens_per_rdma_rank_, + num_tokens_per_expert_, + is_token_in_rank_, + expert_alignment, + num_loop_stage, + config); + + auto num_recv_tokens_per_expert_list_ = std::get<0>(res); + auto num_recv_tokens_ = std::get<1>(res); + auto num_rdma_recv_tokens_ = std::get<2>(res); + + auto rdma_channel_prefix_matrix_ = + ConvertDetailTensorToPaddleTensor(std::get<3>(res)); + + auto gbl_channel_prefix_matrix_ = + ConvertDetailTensorToPaddleTensor(std::get<4>(res)); + + auto recv_rdma_rank_prefix_sum_ = + ConvertDetailTensorToPaddleTensor(std::get<5>(res)); + + auto recv_gbl_rank_prefix_sum_ = + ConvertDetailTensorToPaddleTensor(std::get<6>(res)); + + return {num_recv_tokens_per_expert_list_, + num_recv_tokens_, + num_rdma_recv_tokens_, + rdma_channel_prefix_matrix_, + gbl_channel_prefix_matrix_, + recv_rdma_rank_prefix_sum_, + recv_gbl_rank_prefix_sum_}; +#else + LOG(ERROR) << "NVSHMEM is not enabled. 
You can enable it by setting cmake " + "option WITH_NVSHMEM=ON."; + return {}; +#endif +} + deep_ep::detail::Tensor ConvertPaddleTensorToDetailTensor( const paddle::Tensor& tensor) { deep_ep::detail::Tensor res(tensor); diff --git a/paddle/fluid/distributed/collective/deep_ep/deep_ep.hpp b/paddle/fluid/distributed/collective/deep_ep/deep_ep.hpp index 8185ae7e7a51ec..5503ffd715f811 100644 --- a/paddle/fluid/distributed/collective/deep_ep/deep_ep.hpp +++ b/paddle/fluid/distributed/collective/deep_ep/deep_ep.hpp @@ -103,6 +103,7 @@ struct Buffer { public: Buffer(int rank, int num_ranks, + int num_loop_stage, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode, @@ -285,6 +286,25 @@ struct Buffer { int expert_alignment, const Config& config); + std::tuple>, + std::vector, + std::vector, + deep_ep::detail::Tensor, + deep_ep::detail::Tensor, + deep_ep::detail::Tensor, + deep_ep::detail::Tensor> + internode_fused_notify_dispatch( + const deep_ep::detail::Tensor& x, + const std::optional& x_scales, + const std::optional& topk_idx, + const std::optional& num_tokens_per_rank, + const std::optional& num_tokens_per_rdma_rank, + const std::optional& num_tokens_per_expert, + const deep_ep::detail::Tensor& is_token_in_rank, + int expert_alignment, + int num_loop_stage, + const Config& config); + std::tuple, + std::vector, + deep_ep::detail::Tensor, + deep_ep::detail::Tensor, + deep_ep::detail::Tensor, + deep_ep::detail::Tensor, + deep_ep::detail::Tensor> + internode_fused_notify_combine( + const deep_ep::detail::Tensor& x, + const std::optional& x_scales, + const std::optional& topk_idx, + const std::optional& num_tokens_per_rank, + const std::optional& num_tokens_per_rdma_rank, + const deep_ep::detail::Tensor& is_token_in_rank, + int num_loop_stage, + const Config& config); + #endif // PADDLE_WITH_NVSHMEM void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, @@ -464,6 +501,23 @@ struct Buffer { int expert_alignment, const Config& config); + std::tuple, + std::vector, + paddle::Tensor, + paddle::Tensor, + paddle::Tensor, + paddle::Tensor, + paddle::Tensor> + internode_fused_notify_combine_api( + const paddle::Tensor& x, + const std::optional& x_scales, + const std::optional& topk_idx, + const std::optional& num_tokens_per_rank, + const std::optional& num_tokens_per_rdma_rank, + const paddle::Tensor& is_token_in_rank, + int num_loop_stages, + const Config& config); + std::tuple, paddle::Tensor, @@ -605,6 +659,25 @@ struct Buffer { int expert_alignment, const Config& config); + std::tuple>, + std::vector, + std::vector, + paddle::Tensor, + paddle::Tensor, + paddle::Tensor, + paddle::Tensor> + internode_fused_notify_dispatch_api( + const paddle::Tensor& x, + const std::optional& x_scales, + const std::optional& topk_idx, + const std::optional& num_tokens_per_rank, + const std::optional& num_tokens_per_rdma_rank, + const std::optional& num_tokens_per_expert, + const paddle::Tensor& is_token_in_rank, + int expert_alignment, + int num_loop_stage, + const Config& config); + void clear_buffer_api(const paddle::Tensor& x, const std::optional& x_scales, const std::optional& topk_idx, diff --git a/paddle/fluid/distributed/collective/deep_ep/kernels/api.cuh b/paddle/fluid/distributed/collective/deep_ep/kernels/api.cuh index b474af49c4b1b5..9797b9e502ebfb 100644 --- a/paddle/fluid/distributed/collective/deep_ep/kernels/api.cuh +++ b/paddle/fluid/distributed/collective/deep_ep/kernels/api.cuh @@ -243,6 +243,84 @@ void notify_combine_post_step(int num_ranks, cudaStream_t stream, bool 
low_latency_mode); +void fused_notify_dispatch(const int* num_tokens_per_rank, + int* moe_recv_counter_mapped, + int num_ranks, + const int* num_tokens_per_rdma_rank, + int* moe_recv_rdma_counter_mapped, + const int* num_tokens_per_expert, + int* moe_recv_expert_counter_mapped, + int num_experts, + const bool* is_token_in_rank, + int num_tokens, + int num_channels, + int hidden_int4, + int num_scales, + int num_topk, + int expert_alignment, + int num_loop_stage, + int* rdma_channel_prefix_matrix, + int* recv_rdma_rank_prefix_sum, + int* gbl_channel_prefix_matrix, + int* recv_gbl_rank_prefix_sum, + void* rdma_buffer_ptr, + int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, + int num_max_nvl_chunked_recv_tokens, + int** task_fifo_ptrs, + int head, + int rank, + cudaStream_t stream, + int64_t num_rdma_bytes, + int64_t num_nvl_bytes, + bool low_latency_mode); + +void fused_notify_combine(const int* num_tokens_per_rank, + int* moe_recv_counter_mapped, + int num_ranks, + const int* num_tokens_per_rdma_rank, + int* moe_recv_rdma_counter_mapped, + const bool* is_token_in_rank, + int num_tokens, + int num_channels, + int hidden_int4, + int num_scales, + int num_topk, + int num_loop_stage, + int* rdma_channel_prefix_matrix, + int* recv_rdma_rank_prefix_sum, + int* gbl_channel_prefix_matrix, + int* recv_gbl_rank_prefix_sum, + int* send_rdma_head, + int* send_nvl_head, + void* rdma_buffer_ptr, + int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, + int num_max_nvl_chunked_recv_tokens, + int** task_fifo_ptrs, + int head, + int rank, + cudaStream_t stream, + int64_t num_rdma_bytes, + int64_t num_nvl_bytes, + bool low_latency_mode); + +void fused_notify_combine_post_step(int num_ranks, + int num_channels, + int num_loop_stage, + const int* recv_gbl_rank_prefix_sum, + const int* rdma_channel_prefix_matrix, + const int* gbl_channel_prefix_matrix, + int* recv_rdma_channel_prefix_matrix, + int* recv_gbl_channel_prefix_matrix, + void* rdma_buffer_ptr, + void** buffer_ptrs, + int** task_fifo_ptrs, + int head, + int rank, + cudaStream_t stream, + bool low_latency_mode); + void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, diff --git a/paddle/fluid/distributed/collective/deep_ep/kernels/internode.cu b/paddle/fluid/distributed/collective/deep_ep/kernels/internode.cu index e38239c4bdeadc..85ce27165483cb 100644 --- a/paddle/fluid/distributed/collective/deep_ep/kernels/internode.cu +++ b/paddle/fluid/distributed/collective/deep_ep/kernels/internode.cu @@ -1077,6 +1077,995 @@ void notify_combine_post_step(int num_ranks, #undef NOTIFY_DISPATCH_LAUNCH_CASE } +template +__global__ void fused_notify_dispatch(const int* num_tokens_per_rank, + int* moe_recv_counter_mapped, + int num_ranks, + const int* num_tokens_per_rdma_rank, + int* moe_recv_rdma_counter_mapped, + const int* num_tokens_per_expert, + int* moe_recv_expert_counter_mapped, + int num_experts, + const bool* is_token_in_rank, + int num_tokens, + int num_channels, + int expert_alignment, + int num_loop_stage, + const int rdma_clean_offset, + const int rdma_num_int_clean, + const int nvl_clean_offset, + const int nvl_num_int_clean, + int* rdma_channel_prefix_matrix, + int* recv_rdma_rank_prefix_sum, + int* gbl_channel_prefix_matrix, + int* recv_gbl_rank_prefix_sum, + void* rdma_buffer_ptr, + void** buffer_ptrs, + int** task_fifo_ptrs, + int head, + int rank, + const nvshmem_team_t rdma_team) { + auto sm_id = static_cast(blockIdx.x); + auto thread_id = static_cast(threadIdx.x), warp_id = thread_id / 32, + lane_id = 
get_lane_id(); + auto num_threads = static_cast(blockDim.x), num_warps = num_threads / 32; + + auto rdma_rank = rank / NUM_MAX_NVL_PEERS, + nvl_rank = rank % NUM_MAX_NVL_PEERS; + auto num_rdma_experts = num_experts / kNumRDMARanks, + num_nvl_experts = num_rdma_experts / NUM_MAX_NVL_PEERS; + + if (sm_id == 0) { + // Communication with others + // Global barrier: the first warp do intra-node sync, the second warp do + // internode sync + EP_DEVICE_ASSERT(num_warps > 1); + EP_DEVICE_ASSERT(kNumRDMARanks <= num_threads); + if (thread_id == 32) + nvshmem_barrier_with_same_gpu_idx(rdma_team); + barrier_device(task_fifo_ptrs, head, nvl_rank); + move_fifo_slots(head); + __syncthreads(); + + // Send numbers of tokens per rank/expert to RDMA ranks + auto rdma_buffer_ptr_int = reinterpret_cast(rdma_buffer_ptr); + auto rdma_recv_num_tokens_mixed = SymBuffer( + rdma_buffer_ptr, + (NUM_MAX_NVL_PEERS + num_rdma_experts + 1) * num_loop_stage, + kNumRDMARanks); + + // Clean up for later data dispatch + EP_DEVICE_ASSERT(rdma_recv_num_tokens_mixed.total_bytes <= + rdma_clean_offset * sizeof(int)); +#pragma unroll + for (int i = thread_id; i < rdma_num_int_clean; i += num_threads) + rdma_buffer_ptr_int[rdma_clean_offset + i] = 0; + +// Copy to send buffer +#pragma unroll + for (int i = thread_id; i < num_ranks; i += num_threads) { + for (int j = 0; j < num_loop_stage; ++j) { + rdma_recv_num_tokens_mixed.send_buffer( + i / + NUM_MAX_NVL_PEERS)[j * (NUM_MAX_NVL_PEERS + num_rdma_experts + 1) + + (i % NUM_MAX_NVL_PEERS)] = + num_tokens_per_rank[j * num_ranks + i]; + } + } +#pragma unroll + for (int i = thread_id; i < num_experts; i += num_threads) { + for (int j = 0; j < num_loop_stage; ++j) { + rdma_recv_num_tokens_mixed.send_buffer( + i / + num_rdma_experts)[j * (NUM_MAX_NVL_PEERS + num_rdma_experts + 1) + + NUM_MAX_NVL_PEERS + i % num_rdma_experts] = + num_tokens_per_expert[j * num_experts + i]; + } + } + + if (thread_id < kNumRDMARanks) { +#pragma unroll + for (int j = 0; j < num_loop_stage; ++j) { + rdma_recv_num_tokens_mixed.send_buffer( + thread_id)[j * (NUM_MAX_NVL_PEERS + num_rdma_experts + 1) + + NUM_MAX_NVL_PEERS + num_rdma_experts] = + num_tokens_per_rdma_rank[j * kNumRDMARanks + thread_id]; + } + } + + __syncthreads(); + + if (thread_id < kNumRDMARanks) { + nvshmem_int_put_nbi( + rdma_recv_num_tokens_mixed.recv_buffer(rdma_rank), + rdma_recv_num_tokens_mixed.send_buffer(thread_id), + (NUM_MAX_NVL_PEERS + num_rdma_experts + 1) * num_loop_stage, + translate_dst_rdma_rank(thread_id, nvl_rank)); + } + __syncthreads(); + if (thread_id == 0) + nvshmem_barrier_with_same_gpu_idx(rdma_team); + __syncthreads(); + + // NVL buffers + auto nvl_send_buffer = + thread_id < NUM_MAX_NVL_PEERS ? 
buffer_ptrs[thread_id] : nullptr; + auto nvl_recv_buffer = buffer_ptrs[nvl_rank]; + auto nvl_reduced_num_tokens_per_expert = + Buffer(nvl_recv_buffer, num_rdma_experts * num_loop_stage) + .advance_also(nvl_send_buffer); + auto nvl_send_num_tokens_per_rank = AsymBuffer( + nvl_send_buffer, kNumRDMARanks * num_loop_stage, NUM_MAX_NVL_PEERS); + auto nvl_send_num_tokens_per_expert = AsymBuffer( + nvl_send_buffer, num_nvl_experts * num_loop_stage, NUM_MAX_NVL_PEERS); + auto nvl_recv_num_tokens_per_rank = AsymBuffer( + nvl_recv_buffer, kNumRDMARanks * num_loop_stage, NUM_MAX_NVL_PEERS); + auto nvl_recv_num_tokens_per_expert = AsymBuffer( + nvl_recv_buffer, num_nvl_experts * num_loop_stage, NUM_MAX_NVL_PEERS); + + // Clean up for later data dispatch + auto nvl_buffer_ptr_int = reinterpret_cast(buffer_ptrs[nvl_rank]); + EP_DEVICE_ASSERT(nvl_reduced_num_tokens_per_expert.total_bytes + + nvl_send_num_tokens_per_rank.total_bytes + + nvl_send_num_tokens_per_expert.total_bytes <= + nvl_clean_offset * sizeof(int)); +#pragma unroll + for (int i = thread_id; i < nvl_num_int_clean; i += num_threads) + nvl_buffer_ptr_int[nvl_clean_offset + i] = 0; + + // Reduce number of tokens per expert into the NVL send buffer + // TODO(Xreki): may use NVSHMEM reduction + EP_DEVICE_ASSERT(num_rdma_experts <= num_threads); + if (thread_id < num_rdma_experts) { + for (int j = 0; j < num_loop_stage; ++j) { + int sum = 0; +#pragma unroll + for (int i = 0; i < kNumRDMARanks; ++i) + sum += rdma_recv_num_tokens_mixed.recv_buffer( + i)[j * (NUM_MAX_NVL_PEERS + num_rdma_experts + 1) + + NUM_MAX_NVL_PEERS + thread_id]; + nvl_reduced_num_tokens_per_expert[j * num_rdma_experts + thread_id] = + sum; + } + } + __syncthreads(); + + // Reduce RDMA received tokens + if (thread_id < num_loop_stage) { + int sum = 0; +#pragma unroll + for (int i = 0; i < kNumRDMARanks; ++i) { + sum += rdma_recv_num_tokens_mixed.recv_buffer( + i)[thread_id * (NUM_MAX_NVL_PEERS + num_rdma_experts + 1) + + NUM_MAX_NVL_PEERS + num_rdma_experts]; + recv_rdma_rank_prefix_sum[thread_id * kNumRDMARanks + i] = sum; + } + while (ld_volatile_global(moe_recv_rdma_counter_mapped + thread_id) != + -1) { + } + moe_recv_rdma_counter_mapped[thread_id] = sum; + } + + // Send numbers of tokens per rank/expert to NVL ranks + EP_DEVICE_ASSERT(NUM_MAX_NVL_PEERS <= num_threads); + if (thread_id < NUM_MAX_NVL_PEERS) { + for (int j = 0; j < num_loop_stage; ++j) { +#pragma unroll + for (int i = 0; i < kNumRDMARanks; ++i) + nvl_send_num_tokens_per_rank.buffer(nvl_rank)[j * kNumRDMARanks + i] = + rdma_recv_num_tokens_mixed.recv_buffer( + i)[j * (NUM_MAX_NVL_PEERS + num_rdma_experts + 1) + + thread_id]; +#pragma unroll + for (int i = 0; i < num_nvl_experts; ++i) + nvl_send_num_tokens_per_expert.buffer( + nvl_rank)[j * num_nvl_experts + i] = + nvl_reduced_num_tokens_per_expert[j * num_rdma_experts + + thread_id * num_nvl_experts + + i]; + } + } + memory_fence(); + __syncthreads(); + barrier_device(task_fifo_ptrs, head, nvl_rank); + move_fifo_slots(head); + __syncthreads(); + + // Reduce number of tokens per rank/expert + EP_DEVICE_ASSERT(num_nvl_experts <= num_threads); + if (thread_id < num_loop_stage) { + int sum = 0; +#pragma unroll + for (int i = 0; i < num_ranks; ++i) { + int src_rdma_rank = i / NUM_MAX_NVL_PEERS, + src_nvl_rank = i % NUM_MAX_NVL_PEERS; + sum += nvl_recv_num_tokens_per_rank.buffer( + src_nvl_rank)[thread_id * kNumRDMARanks + src_rdma_rank]; + recv_gbl_rank_prefix_sum[thread_id * num_ranks + i] = sum; + } + while (ld_volatile_global(moe_recv_counter_mapped + 
thread_id) != -1) { + } + moe_recv_counter_mapped[thread_id] = sum; + } + + EP_DEVICE_ASSERT(num_nvl_experts * num_loop_stage <= num_threads); + if (thread_id < num_nvl_experts * num_loop_stage) { + int sum = 0; +#pragma unroll + for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i) + sum += nvl_recv_num_tokens_per_expert.buffer(i)[thread_id]; + sum = (sum + expert_alignment - 1) / expert_alignment * expert_alignment; + while (ld_volatile_global(moe_recv_expert_counter_mapped + thread_id) != + -1) { + } + moe_recv_expert_counter_mapped[thread_id] = sum; + } + + // Finally barrier + __syncthreads(); + if (thread_id == 32) + nvshmem_barrier_with_same_gpu_idx(rdma_team); + barrier_device(task_fifo_ptrs, head, nvl_rank); + move_fifo_slots(head); + } else { + // Calculate meta data + int stage_id = (sm_id - 1) / kNumRDMARanks; + int dst_rdma_rank = (sm_id - 1) % kNumRDMARanks; + for (int channel_id = warp_id; channel_id < num_channels; + channel_id += num_warps) { + int token_start_idx, token_end_idx; + get_channel_task_range( + num_tokens, num_channels, channel_id, token_start_idx, token_end_idx); + + // Iterate over tokens + int total_count = 0, per_nvl_rank_count[NUM_MAX_NVL_PEERS] = {0}; + for (int64_t i = token_start_idx + lane_id; i < token_end_idx; i += 32) { + EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * sizeof(bool) == sizeof(uint64_t), + "Invalid number of NVL peers"); + auto is_token_in_rank_uint64 = *reinterpret_cast( + is_token_in_rank + (stage_id * num_tokens + i) * num_ranks + + dst_rdma_rank * NUM_MAX_NVL_PEERS); + auto is_token_in_rank_values = + reinterpret_cast(&is_token_in_rank_uint64); +#pragma unroll + for (int j = 0; j < NUM_MAX_NVL_PEERS; ++j) + per_nvl_rank_count[j] += is_token_in_rank_values[j]; + total_count += (is_token_in_rank_uint64 != 0); + } + + // Warp reduce + total_count = warp_reduce_sum(total_count); +#pragma unroll + for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i) + per_nvl_rank_count[i] = warp_reduce_sum(per_nvl_rank_count[i]); + + // Write into channel matrix + if (lane_id == 0) { +#pragma unroll + for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i) + gbl_channel_prefix_matrix[(stage_id * num_ranks + + dst_rdma_rank * NUM_MAX_NVL_PEERS + i) * + num_channels + + channel_id] = per_nvl_rank_count[i]; + rdma_channel_prefix_matrix[(stage_id * kNumRDMARanks + dst_rdma_rank) * + num_channels + + channel_id] = total_count; + } + } + + // Calculate prefix sum + __syncthreads(); + if (thread_id == 0) { + auto prefix_row = + rdma_channel_prefix_matrix + + (stage_id * kNumRDMARanks + dst_rdma_rank) * num_channels; +#pragma unroll + for (int i = 1; i < num_channels; ++i) prefix_row[i] += prefix_row[i - 1]; + } + + EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= 32, "Invalid number of NVL peers"); + if (thread_id < NUM_MAX_NVL_PEERS) { + auto prefix_row = gbl_channel_prefix_matrix + + (stage_id * num_ranks + + dst_rdma_rank * NUM_MAX_NVL_PEERS + thread_id) * + num_channels; +#pragma unroll + for (int i = 1; i < num_channels; ++i) prefix_row[i] += prefix_row[i - 1]; + } + } +} + +void fused_notify_dispatch(const int* num_tokens_per_rank, + int* moe_recv_counter_mapped, + int num_ranks, + const int* num_tokens_per_rdma_rank, + int* moe_recv_rdma_counter_mapped, + const int* num_tokens_per_expert, + int* moe_recv_expert_counter_mapped, + int num_experts, + const bool* is_token_in_rank, + int num_tokens, + int num_channels, + int hidden_int4, + int num_scales, + int num_topk, + int expert_alignment, + int num_loop_stage, + int* rdma_channel_prefix_matrix, + int* recv_rdma_rank_prefix_sum, + int* 
gbl_channel_prefix_matrix, + int* recv_gbl_rank_prefix_sum, + void* rdma_buffer_ptr, + int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, + int num_max_nvl_chunked_recv_tokens, + int** task_fifo_ptrs, + int head, + int rank, + cudaStream_t stream, + int64_t num_rdma_bytes, + int64_t num_nvl_bytes, + bool low_latency_mode) { +#define NOTIFY_DISPATCH_LAUNCH_CASE(num_rdma_ranks) \ + { \ + auto fused_notify_dispatch_func = \ + low_latency_mode ? fused_notify_dispatch \ + : fused_notify_dispatch; \ + LAUNCH_KERNEL(&cfg, \ + fused_notify_dispatch_func, \ + num_tokens_per_rank, \ + moe_recv_counter_mapped, \ + num_ranks, \ + num_tokens_per_rdma_rank, \ + moe_recv_rdma_counter_mapped, \ + num_tokens_per_expert, \ + moe_recv_expert_counter_mapped, \ + num_experts, \ + is_token_in_rank, \ + num_tokens, \ + num_channels, \ + expert_alignment, \ + num_loop_stage, \ + rdma_clean_meta.first, \ + rdma_clean_meta.second, \ + nvl_clean_meta.first, \ + nvl_clean_meta.second, \ + rdma_channel_prefix_matrix, \ + recv_rdma_rank_prefix_sum, \ + gbl_channel_prefix_matrix, \ + recv_gbl_rank_prefix_sum, \ + rdma_buffer_ptr, \ + buffer_ptrs, \ + task_fifo_ptrs, \ + head, \ + rank, \ + cpu_rdma_team); \ + } \ + break + + constexpr int kNumThreads = 512; + const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; + + // Get clean meta + auto rdma_clean_meta = get_rdma_clean_meta(hidden_int4, + num_scales, + num_topk, + num_topk, + num_rdma_ranks, + num_max_rdma_chunked_recv_tokens, + num_channels); + auto nvl_clean_meta = get_nvl_clean_meta(hidden_int4, + num_scales, + num_topk, + num_topk, + num_rdma_ranks, + NUM_MAX_NVL_PEERS, + num_max_nvl_chunked_recv_tokens, + num_channels); + EP_HOST_ASSERT((rdma_clean_meta.first + rdma_clean_meta.second) * + sizeof(int) <= + num_rdma_bytes); + EP_HOST_ASSERT((nvl_clean_meta.first + nvl_clean_meta.second) * sizeof(int) <= + num_nvl_bytes); + EP_HOST_ASSERT(num_rdma_bytes < std::numeric_limits::max()); + EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits::max()); + + // Launch kernel + SETUP_LAUNCH_CONFIG(1 + num_rdma_ranks * num_loop_stage, kNumThreads, stream); + SWITCH_RDMA_RANKS(NOTIFY_DISPATCH_LAUNCH_CASE); +#undef NOTIFY_DISPATCH_LAUNCH_CASE +} + +template +__global__ void fused_notify_combine( + const int* num_tokens_per_rank, // [num_loop_stage, 2, num_ranks] + int* moe_recv_counter_mapped, + int num_ranks, + const int* num_tokens_per_rdma_rank, + int* moe_recv_rdma_counter_mapped, + const bool* is_token_in_rank, + int num_tokens, + int num_channels, + int num_loop_stage, + const int rdma_clean_offset, + const int rdma_num_int_clean, + const int nvl_clean_offset, + const int nvl_num_int_clean, + int* rdma_channel_prefix_matrix, + int* gbl_channel_prefix_matrix, + int* recv_rdma_rank_prefix_sum, + int* recv_gbl_rank_prefix_sum, + int* send_rdma_head, + int* send_nvl_head, + void* rdma_buffer_ptr, + void** buffer_ptrs, + int** task_fifo_ptrs, + int head, + int rank, + const nvshmem_team_t rdma_team) { + auto sm_id = static_cast(blockIdx.x); + auto thread_id = static_cast(threadIdx.x), warp_id = thread_id / 32, + lane_id = get_lane_id(); + auto num_threads = static_cast(blockDim.x), num_warps = num_threads / 32; + + auto rdma_rank = rank / NUM_MAX_NVL_PEERS, + nvl_rank = rank % NUM_MAX_NVL_PEERS; + + if (sm_id == 0) { + // Communication with others + // Global barrier: the first warp does intra-node sync, the second warp does + // internode sync + EP_DEVICE_ASSERT(num_warps > 1); + EP_DEVICE_ASSERT(kNumRDMARanks <= num_threads); + if (thread_id == 32) + 
nvshmem_barrier_with_same_gpu_idx(rdma_team); + barrier_device(task_fifo_ptrs, head, nvl_rank); + move_fifo_slots(head); + __syncthreads(); + + // Send numbers of tokens per rank/expert to RDMA ranks + auto rdma_buffer_ptr_int = static_cast(rdma_buffer_ptr); + auto rdma_recv_num_tokens_mixed = + SymBuffer(rdma_buffer_ptr, + (NUM_MAX_NVL_PEERS + 1) * num_loop_stage, + kNumRDMARanks); + + // Clean up for later data dispatch + EP_DEVICE_ASSERT(rdma_recv_num_tokens_mixed.total_bytes <= + rdma_clean_offset * sizeof(int)); +#pragma unroll + for (int i = thread_id; i < rdma_num_int_clean; i += num_threads) + rdma_buffer_ptr_int[rdma_clean_offset + i] = 0; + +// Copy to send buffer +#pragma unroll + for (int i = thread_id; i < num_ranks; i += num_threads) { + for (int j = 0; j < num_loop_stage; ++j) { + rdma_recv_num_tokens_mixed.send_buffer( + i / NUM_MAX_NVL_PEERS)[j * (NUM_MAX_NVL_PEERS + 1) + + (i % NUM_MAX_NVL_PEERS)] = + num_tokens_per_rank[j * num_ranks + i]; + } + } + + if (thread_id < kNumRDMARanks) { +#pragma unroll + for (int j = 0; j < num_loop_stage; ++j) { + rdma_recv_num_tokens_mixed.send_buffer( + thread_id)[j * (NUM_MAX_NVL_PEERS + 1) + NUM_MAX_NVL_PEERS] = + num_tokens_per_rdma_rank[j * kNumRDMARanks + thread_id]; + } + } + + __syncthreads(); + + // Issue send + if (thread_id < kNumRDMARanks) { + nvshmem_int_put_nbi( + rdma_recv_num_tokens_mixed.recv_buffer(rdma_rank), + rdma_recv_num_tokens_mixed.send_buffer(thread_id), + (NUM_MAX_NVL_PEERS + 1) * num_loop_stage, + translate_dst_rdma_rank(thread_id, nvl_rank)); + } + __syncthreads(); + + // Barrier + if (thread_id == 0) { + nvshmem_barrier_with_same_gpu_idx(rdma_team); + } + __syncthreads(); + + // NVL buffers + auto nvl_send_buffer = + thread_id < NUM_MAX_NVL_PEERS ? buffer_ptrs[thread_id] : nullptr; + auto nvl_recv_buffer = buffer_ptrs[nvl_rank]; + + auto nvl_send_num_tokens_per_rank = AsymBuffer( + nvl_send_buffer, kNumRDMARanks * num_loop_stage, NUM_MAX_NVL_PEERS); + auto nvl_recv_num_tokens_per_rank = AsymBuffer( + nvl_recv_buffer, kNumRDMARanks * num_loop_stage, NUM_MAX_NVL_PEERS); + + // Clean up for later data dispatch + auto nvl_buffer_ptr_int = static_cast(buffer_ptrs[nvl_rank]); + EP_DEVICE_ASSERT(nvl_send_num_tokens_per_rank.total_bytes <= + nvl_clean_offset * sizeof(int)); +#pragma unroll + for (int i = thread_id; i < nvl_num_int_clean; i += num_threads) + nvl_buffer_ptr_int[nvl_clean_offset + i] = 0; + __syncthreads(); + + // Reduce RDMA received tokens + if (thread_id < num_loop_stage) { + int sum = 0; +#pragma unroll + for (int i = 0; i < kNumRDMARanks; ++i) { + sum += rdma_recv_num_tokens_mixed.recv_buffer( + i)[thread_id * (NUM_MAX_NVL_PEERS + 1) + NUM_MAX_NVL_PEERS]; + recv_rdma_rank_prefix_sum[thread_id * kNumRDMARanks + i] = sum; + } + while (ld_volatile_global(moe_recv_rdma_counter_mapped + thread_id) != + -1) { + } + moe_recv_rdma_counter_mapped[thread_id] = sum; + } + + // Send numbers of tokens per rank/expert to NVL ranks + EP_DEVICE_ASSERT(NUM_MAX_NVL_PEERS <= num_threads); + if (thread_id < NUM_MAX_NVL_PEERS) { + for (int j = 0; j < num_loop_stage; ++j) { +#pragma unroll + for (int i = 0; i < kNumRDMARanks; ++i) { + nvl_send_num_tokens_per_rank.buffer(nvl_rank)[j * kNumRDMARanks + i] = + rdma_recv_num_tokens_mixed.recv_buffer( + i)[j * (NUM_MAX_NVL_PEERS + 1) + thread_id]; + } + } + } + + memory_fence(); + __syncthreads(); + barrier_device(task_fifo_ptrs, head, nvl_rank); + move_fifo_slots(head); + __syncthreads(); + + if (thread_id < num_loop_stage) { + int sum = 0; +#pragma unroll + for (int i = 0; 
i < num_ranks; ++i) { + int src_rdma_rank = i / NUM_MAX_NVL_PEERS, + src_nvl_rank = i % NUM_MAX_NVL_PEERS; + sum += nvl_recv_num_tokens_per_rank.buffer( + src_nvl_rank)[thread_id * kNumRDMARanks + src_rdma_rank]; + recv_gbl_rank_prefix_sum[thread_id * num_ranks + i] = sum; + } + while (ld_volatile_global(moe_recv_counter_mapped + thread_id) != -1) { + } + moe_recv_counter_mapped[thread_id] = sum; + } + + // Finally barrier + if (thread_id == 32) + nvshmem_barrier_with_same_gpu_idx(rdma_team); + barrier_device(task_fifo_ptrs, head, nvl_rank); + move_fifo_slots(head); + } else { + // Calculate meta data + int stage_id = (sm_id - 1) / kNumRDMARanks; + int dst_rdma_rank = (sm_id - 1) % kNumRDMARanks; + for (int channel_id = warp_id; channel_id < num_channels; + channel_id += num_warps) { + int token_start_idx, token_end_idx; + get_channel_task_range( + num_tokens, num_channels, channel_id, token_start_idx, token_end_idx); + + // Iterate over tokens + int total_count = 0, per_nvl_rank_count[NUM_MAX_NVL_PEERS] = {0}; + int global_rdma_tail_idx = 0, + global_nvl_tail_idx[NUM_MAX_NVL_PEERS] = {0}; + for (int i = token_start_idx + lane_id; i < token_end_idx; i += 32) { + EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * sizeof(bool) == sizeof(uint64_t), + "Invalid number of NVL peers"); + auto is_token_in_rank_uint64 = *reinterpret_cast( + is_token_in_rank + (stage_id * num_tokens + i) * num_ranks + + dst_rdma_rank * NUM_MAX_NVL_PEERS); + auto is_token_in_rank_values = + reinterpret_cast(&is_token_in_rank_uint64); + + total_count += (is_token_in_rank_uint64 != 0); + + // Calculate RDMA tail index for combine + auto warp_valid_tokens = std::min(token_end_idx - (i - lane_id), 32); + unsigned int mask = 0xffffffff >> (32 - warp_valid_tokens); + int warp_rdma_tail_idx = (is_token_in_rank_uint64 != 0); + global_rdma_tail_idx += warp_scan(warp_rdma_tail_idx, mask); + auto rdma_tail_idx = + is_token_in_rank_uint64 == 0 ? -1 : global_rdma_tail_idx - 1; + send_rdma_head[(stage_id * num_tokens + i) * kNumRDMARanks + + dst_rdma_rank] = rdma_tail_idx; + global_rdma_tail_idx = + __shfl_sync(mask, global_rdma_tail_idx, warp_valid_tokens - 1); + +#pragma unroll + for (int j = 0; j < NUM_MAX_NVL_PEERS; ++j) { + per_nvl_rank_count[j] += is_token_in_rank_values[j]; + int warp_nvl_tail_idx = (is_token_in_rank_values[j]); + global_nvl_tail_idx[j] += warp_scan(warp_nvl_tail_idx, mask); + auto nvl_tail_idx = + is_token_in_rank_values[j] == 0 ? 
-1 : global_nvl_tail_idx[j] - 1; + send_nvl_head[(stage_id * num_tokens + i) * kNumRDMARanks * + NUM_MAX_NVL_PEERS + + dst_rdma_rank * NUM_MAX_NVL_PEERS + j] = nvl_tail_idx; + global_nvl_tail_idx[j] = + __shfl_sync(mask, global_nvl_tail_idx[j], warp_valid_tokens - 1); + } + } + + // Warp reduce + total_count = warp_reduce_sum(total_count); +#pragma unroll + for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i) + per_nvl_rank_count[i] = warp_reduce_sum(per_nvl_rank_count[i]); + + // Write into channel matrix + if (lane_id == 0) { +#pragma unroll + for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i) + gbl_channel_prefix_matrix[(stage_id * num_ranks + + dst_rdma_rank * NUM_MAX_NVL_PEERS + i) * + num_channels + + channel_id] = per_nvl_rank_count[i]; + rdma_channel_prefix_matrix[(stage_id * kNumRDMARanks + dst_rdma_rank) * + num_channels + + channel_id] = total_count; + } + } + + // Calculate prefix sum + __syncthreads(); + if (thread_id == 0) { + auto prefix_row = + rdma_channel_prefix_matrix + + (stage_id * kNumRDMARanks + dst_rdma_rank) * num_channels; +#pragma unroll + for (int i = 1; i < num_channels; ++i) { + prefix_row[i] += prefix_row[i - 1]; + } + } + + EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= 32, "Invalid number of NVL peers"); + if (thread_id < NUM_MAX_NVL_PEERS) { + auto prefix_row = gbl_channel_prefix_matrix + + (stage_id * num_ranks + + dst_rdma_rank * NUM_MAX_NVL_PEERS + thread_id) * + num_channels; +#pragma unroll + for (int i = 1; i < num_channels; ++i) { + prefix_row[i] += prefix_row[i - 1]; + } + } + } +} + +void fused_notify_combine(const int* num_tokens_per_rank, + int* moe_recv_counter_mapped, + int num_ranks, + const int* num_tokens_per_rdma_rank, + int* moe_recv_rdma_counter_mapped, + const bool* is_token_in_rank, + int num_tokens, + int num_channels, + int hidden_int4, + int num_scales, + int num_topk, + int num_loop_stage, + int* rdma_channel_prefix_matrix, + int* recv_rdma_rank_prefix_sum, + int* gbl_channel_prefix_matrix, + int* recv_gbl_rank_prefix_sum, + int* send_rdma_head, + int* send_nvl_head, + void* rdma_buffer_ptr, + int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, + int num_max_nvl_chunked_recv_tokens, + int** task_fifo_ptrs, + int head, + int rank, + cudaStream_t stream, + int64_t num_rdma_bytes, + int64_t num_nvl_bytes, + bool low_latency_mode) { +#define NOTIFY_COMBINE_LAUNCH_CASE(num_rdma_ranks) \ + { \ + auto notify_combine_func = \ + low_latency_mode ? 
fused_notify_combine \ + : fused_notify_combine; \ + LAUNCH_KERNEL(&cfg, \ + notify_combine_func, \ + num_tokens_per_rank, \ + moe_recv_counter_mapped, \ + num_ranks, \ + num_tokens_per_rdma_rank, \ + moe_recv_rdma_counter_mapped, \ + is_token_in_rank, \ + num_tokens, \ + num_channels, \ + num_loop_stage, \ + rdma_clean_meta.first, \ + rdma_clean_meta.second, \ + nvl_clean_meta.first, \ + nvl_clean_meta.second, \ + rdma_channel_prefix_matrix, \ + gbl_channel_prefix_matrix, \ + recv_rdma_rank_prefix_sum, \ + recv_gbl_rank_prefix_sum, \ + send_rdma_head, \ + send_nvl_head, \ + rdma_buffer_ptr, \ + buffer_ptrs, \ + task_fifo_ptrs, \ + head, \ + rank, \ + cpu_rdma_team); \ + } \ + break + + constexpr int kNumThreads = 512; + const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; + + // Get clean meta + auto rdma_clean_meta = get_rdma_clean_meta(hidden_int4, + num_scales, + num_topk, + num_topk, + num_rdma_ranks, + num_max_rdma_chunked_recv_tokens, + num_channels); + auto nvl_clean_meta = get_nvl_clean_meta(hidden_int4, + num_scales, + num_topk, + num_topk, + num_rdma_ranks, + NUM_MAX_NVL_PEERS, + num_max_nvl_chunked_recv_tokens, + num_channels); + EP_HOST_ASSERT((rdma_clean_meta.first + rdma_clean_meta.second) * + sizeof(int) <= + num_rdma_bytes); + EP_HOST_ASSERT((nvl_clean_meta.first + nvl_clean_meta.second) * sizeof(int) <= + num_nvl_bytes); + EP_HOST_ASSERT(num_rdma_bytes < std::numeric_limits::max()); + EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits::max()); + + // Launch kernel + SETUP_LAUNCH_CONFIG(1 + num_rdma_ranks * num_loop_stage, kNumThreads, stream); + SWITCH_RDMA_RANKS(NOTIFY_COMBINE_LAUNCH_CASE); +#undef NOTIFY_DISPATCH_LAUNCH_CASE +} + +template +__global__ void fused_notify_combine_post_step( + int num_channels, + int num_loop_stage, + const int* recv_gbl_rank_prefix_sum, + const int* rdma_channel_prefix_matrix, + const int* gbl_channel_prefix_matrix, + int* recv_rdma_channel_prefix_matrix, + int* recv_gbl_channel_prefix_matrix, + void* rdma_buffer_ptr, + void** buffer_ptrs, + int** task_fifo_ptrs, + int head, + int rank, + const nvshmem_team_t rdma_team) { + auto sm_id = static_cast(blockIdx.x); + EP_DEVICE_ASSERT(sm_id == 0); + auto thread_id = static_cast(threadIdx.x), warp_id = thread_id / 32, + lane_id = get_lane_id(); + auto num_threads = static_cast(blockDim.x), num_warps = num_threads / 32; + + auto rdma_rank = rank / NUM_MAX_NVL_PEERS, + nvl_rank = rank % NUM_MAX_NVL_PEERS; + + auto rdma_channel_meta = + SymBuffer(rdma_buffer_ptr, + (num_channels + num_channels * NUM_MAX_NVL_PEERS) * + num_loop_stage, // (rdma_channel_meta + + // nvl_channel_meta) * num_loop_stage + kNumRDMARanks); + + // NVL buffers + auto nvl_send_buffer = + warp_id - 1 < NUM_MAX_NVL_PEERS ? 
buffer_ptrs[warp_id - 1] : nullptr; + auto nvl_recv_buffer = buffer_ptrs[nvl_rank]; + + auto nvl_send_channel_meta = + AsymBuffer(nvl_send_buffer, + num_loop_stage * kNumRDMARanks * num_channels, + NUM_MAX_NVL_PEERS); + auto nvl_recv_channel_meta = + AsymBuffer(nvl_recv_buffer, + num_loop_stage * kNumRDMARanks * num_channels, + NUM_MAX_NVL_PEERS); + + if (thread_id == 32) + nvshmem_barrier_with_same_gpu_idx(rdma_team); + barrier_device(task_fifo_ptrs, head, nvl_rank); + move_fifo_slots(head); + __syncthreads(); + + if (warp_id == 0) { // rdma_channel_prefix_matrix data + // land_id -> dst_rdma_rank + if (lane_id < kNumRDMARanks) { + for (int j = 0; j < num_loop_stage; ++j) { + for (int i = 0; i < num_channels; ++i) { + rdma_channel_meta.send_buffer( + lane_id)[j * (num_channels + num_channels * NUM_MAX_NVL_PEERS) + + i] = + -rdma_channel_prefix_matrix[(j * kNumRDMARanks + lane_id) * + num_channels + + i] - + 1; + } + } + } + } else if (warp_id < + NUM_MAX_NVL_PEERS + 1) { // gbl_channel_prefix_matrix data + // land_id -> dst_rdma_rank + // warp_id -1 -> dst_nvl_rank + if (lane_id < kNumRDMARanks) { + for (int j = 0; j < num_loop_stage; ++j) { + auto dst_ptr = rdma_channel_meta.send_buffer(lane_id) + + j * (num_channels + num_channels * NUM_MAX_NVL_PEERS) + + num_channels + (warp_id - 1) * num_channels; + dst_ptr[0] = -0 - 1; + for (int i = 1; i < num_channels; ++i) { + dst_ptr[i] = -gbl_channel_prefix_matrix[(j * kNumRDMARanks * + NUM_MAX_NVL_PEERS + + lane_id * NUM_MAX_NVL_PEERS + + warp_id - 1) * + num_channels + + i - 1] - + 1; + } + } + } + } + __syncthreads(); + + // Issue send + if (thread_id < kNumRDMARanks) { + nvshmem_int_put_nbi( + rdma_channel_meta.recv_buffer(rdma_rank), + rdma_channel_meta.send_buffer(thread_id), + (num_channels + num_channels * NUM_MAX_NVL_PEERS) * num_loop_stage, + translate_dst_rdma_rank(thread_id, nvl_rank)); + } + __syncthreads(); + + // Barrier + if (thread_id == 0) { + nvshmem_barrier_with_same_gpu_idx(rdma_team); + } + __syncthreads(); + + // Receive RDMA + if (warp_id == 0) { + // lane_id -> src_rdma_rank + if (lane_id < kNumRDMARanks) { + for (int j = 0; j < num_loop_stage; ++j) { + for (int i = 0; i < num_channels; ++i) { + recv_rdma_channel_prefix_matrix[(j * kNumRDMARanks + lane_id) * + num_channels + + i] = + -rdma_channel_meta.recv_buffer(lane_id) + [j * (num_channels + num_channels * NUM_MAX_NVL_PEERS) + i] - + 1; + } + } + } + + } else if (warp_id < NUM_MAX_NVL_PEERS + 1) { + // lane_id -> src_rdma_rank + // warp_id - 1 -> dst_nvl_rank + if (lane_id < kNumRDMARanks) { + for (int j = 0; j < num_loop_stage; ++j) { + auto recv_ptr = rdma_channel_meta.recv_buffer(lane_id) + + j * (num_channels + num_channels * NUM_MAX_NVL_PEERS) + + num_channels + (warp_id - 1) * num_channels; + for (int i = 0; i < num_channels; ++i) { + st_relaxed_sys_global( + nvl_send_channel_meta.buffer(nvl_rank) + + (j * kNumRDMARanks + lane_id) * num_channels + i, + recv_ptr[i]); + } + } + } + } + memory_fence(); + __syncthreads(); + barrier_device(task_fifo_ptrs, head, nvl_rank); + move_fifo_slots(head); + __syncthreads(); + + EP_DEVICE_ASSERT(kNumRDMARanks * NUM_MAX_NVL_PEERS <= num_threads); + if (thread_id < kNumRDMARanks * NUM_MAX_NVL_PEERS) { + const auto src_rdma_rank = thread_id / NUM_MAX_NVL_PEERS; + const auto src_nvl_rank = thread_id % NUM_MAX_NVL_PEERS; + for (int j = 0; j < num_loop_stage; ++j) { + int rank_offset = + thread_id > 0 + ? 
recv_gbl_rank_prefix_sum[j * kNumRDMARanks * NUM_MAX_NVL_PEERS + + thread_id - 1] + : 0; +#pragma unroll + for (int i = 0; i < num_channels; ++i) { + recv_gbl_channel_prefix_matrix[(j * kNumRDMARanks * NUM_MAX_NVL_PEERS + + thread_id) * + num_channels + + i] = + rank_offset - + nvl_recv_channel_meta.buffer(src_nvl_rank) + [(j * kNumRDMARanks + src_rdma_rank) * num_channels + i] - + 1; + } + } + } + + // Finally barrier + if (thread_id == 32) + nvshmem_barrier_with_same_gpu_idx(rdma_team); + barrier_device(task_fifo_ptrs, head, nvl_rank); + move_fifo_slots(head); +} + +void fused_notify_combine_post_step(int num_ranks, + int num_channels, + int num_loop_stage, + const int* recv_gbl_rank_prefix_sum, + const int* rdma_channel_prefix_matrix, + const int* gbl_channel_prefix_matrix, + int* recv_rdma_channel_prefix_matrix, + int* recv_gbl_channel_prefix_matrix, + void* rdma_buffer_ptr, + void** buffer_ptrs, + int** task_fifo_ptrs, + int head, + int rank, + cudaStream_t stream, + bool low_latency_mode) { +#define NOTIFY_COMBINE_S1_LAUNCH_CASE(num_rdma_ranks) \ + { \ + auto notify_combine_post_step_func = \ + low_latency_mode \ + ? fused_notify_combine_post_step \ + : fused_notify_combine_post_step; \ + LAUNCH_KERNEL(&cfg, \ + notify_combine_post_step_func, \ + num_channels, \ + num_loop_stage, \ + recv_gbl_rank_prefix_sum, \ + rdma_channel_prefix_matrix, \ + gbl_channel_prefix_matrix, \ + recv_rdma_channel_prefix_matrix, \ + recv_gbl_channel_prefix_matrix, \ + rdma_buffer_ptr, \ + buffer_ptrs, \ + task_fifo_ptrs, \ + head, \ + rank, \ + cpu_rdma_team); \ + } \ + break + + constexpr int kNumThreads = 512; + + // Launch kernel + SETUP_LAUNCH_CONFIG(1, kNumThreads, stream); + SWITCH_RDMA_RANKS(NOTIFY_COMBINE_S1_LAUNCH_CASE); +#undef NOTIFY_DISPATCH_LAUNCH_CASE +} + // At most 8 RDMA ranks to be sent constexpr int get_num_topk_rdma_ranks(int num_rdma_ranks) { return num_rdma_ranks < 8 ? num_rdma_ranks : 8; @@ -2660,7 +3649,7 @@ template < int kNumWarpsPerForwarder = (kNumCombineForwarderWarps / kNumRDMARanks > 0) ? 
kNumCombineForwarderWarps / kNumRDMARanks : 1, - int kNumForwarders = kNumRDMARanks* kNumWarpsPerForwarder, + int kNumForwarders = kNumRDMARanks * kNumWarpsPerForwarder, int kNumRDMAReceivers = kNumForwarders + NUM_MAX_NVL_PEERS> __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32, 1) diff --git a/paddle/fluid/pybind/deep_ep_api.cc b/paddle/fluid/pybind/deep_ep_api.cc index 001d134ff245d1..7228f773d87a42 100644 --- a/paddle/fluid/pybind/deep_ep_api.cc +++ b/paddle/fluid/pybind/deep_ep_api.cc @@ -61,7 +61,7 @@ void BindDeepEPApi(pybind11::module *m) { &deep_ep::GetEventHandleFromCustomStream); pybind11::class_(*m, "Buffer") - .def(pybind11::init()) + .def(pybind11::init()) .def("is_available", &deep_ep::Buffer::is_available) .def("get_num_rdma_ranks", &deep_ep::Buffer::get_num_rdma_ranks) .def("get_rdma_rank", &deep_ep::Buffer::get_rdma_rank) @@ -97,9 +97,13 @@ void BindDeepEPApi(pybind11::module *m) { .def("internode_dispatch", &deep_ep::Buffer::internode_dispatch_api) .def("internode_notify_dispatch", &deep_ep::Buffer::internode_notify_dispatch_api) + .def("internode_fused_notify_dispatch", + &deep_ep::Buffer::internode_fused_notify_dispatch_api) .def("clear_buffer", &deep_ep::Buffer::clear_buffer_api) .def("internode_notify_combine", &deep_ep::Buffer::internode_notify_combine_api) + .def("internode_fused_notify_combine", + &deep_ep::Buffer::internode_fused_notify_combine_api) .def("internode_combine", &deep_ep::Buffer::internode_combine_api) .def("barrier_all", &deep_ep::Buffer::barrier_all) .def("clean_low_latency_buffer", diff --git a/python/paddle/distributed/communication/deep_ep/buffer.py b/python/paddle/distributed/communication/deep_ep/buffer.py index c04eb9c1afaebd..e68f5b517e2bd2 100644 --- a/python/paddle/distributed/communication/deep_ep/buffer.py +++ b/python/paddle/distributed/communication/deep_ep/buffer.py @@ -65,6 +65,7 @@ def __init__( num_rdma_bytes: int = 0, low_latency_mode: bool = False, num_qps_per_rank: int = 12, + num_loop_stages: int = 3, ) -> None: """ Initialize the communication buffer. 
@@ -88,6 +89,7 @@ def __init__( self.runtime = CppBuffer( self.rank, self.group_size, + num_loop_stages, num_nvl_bytes, num_rdma_bytes, low_latency_mode, @@ -900,6 +902,80 @@ def internode_notify_dispatch( handle, ) + def internode_fused_notify_dispatch( + self, + x: paddle.Tensor | tuple[paddle.Tensor, paddle.Tensor], + topk_idx: paddle.Tensor | None = None, + num_tokens_per_rank: paddle.Tensor | None = None, + num_tokens_per_rdma_rank: paddle.Tensor | None = None, + num_tokens_per_expert: paddle.Tensor | None = None, + is_token_in_rank: paddle.Tensor | None = None, + num_loop_stage: int = 1, + expert_alignment: int = 1, + config: Config | None = None, + ) -> tuple[ + list[list[int]], + list[int], + list[int], + list[tuple], + ]: + # Default config + config = ( + self.get_dispatch_config(self.group_size) + if config is None + else config + ) + # Launch the kernel with cached or non-cached mode + x, x_scales = x if isinstance(x, tuple) else (x, None) + assert ( + num_tokens_per_rank is not None + and is_token_in_rank is not None + and num_tokens_per_expert is not None + ) + + ( + num_recv_tokens_per_expert_list, + num_recv_tokens, + num_rdma_recv_tokens, + rdma_channel_prefix_matrix, + gbl_channel_prefix_matrix, + recv_rdma_rank_prefix_sum, + recv_gbl_rank_prefix_sum, + ) = self.runtime.internode_fused_notify_dispatch( + x, + x_scales, + topk_idx, + num_tokens_per_rank, + num_tokens_per_rdma_rank, + num_tokens_per_expert, + is_token_in_rank, + expert_alignment, + num_loop_stage, + config, + ) + handles = [] + for loop_idx in range(num_loop_stage): + handle = ( + is_token_in_rank[loop_idx], + rdma_channel_prefix_matrix[loop_idx], + gbl_channel_prefix_matrix[loop_idx], + None, + recv_rdma_rank_prefix_sum[loop_idx], + None, + recv_gbl_rank_prefix_sum[loop_idx], + paddle.empty([num_recv_tokens[loop_idx], 0]), + None, + paddle.empty([num_rdma_recv_tokens[loop_idx], 0]), + ) + handles.append(handle) + + return ( + num_recv_tokens_per_expert_list, + num_recv_tokens, + num_rdma_recv_tokens, + handles, + ) + def internode_notify_combine( self, x: paddle.Tensor | tuple[paddle.Tensor, paddle.Tensor], @@ -957,6 +1033,63 @@ def internode_notify_combine( send_nvl_head, ) + def internode_fused_notify_combine( + self, + x: paddle.Tensor | tuple[paddle.Tensor, paddle.Tensor], + topk_idx: paddle.Tensor | None = None, + num_tokens_per_rank: paddle.Tensor | None = None, + num_tokens_per_rdma_rank: paddle.Tensor | None = None, + is_token_in_rank: paddle.Tensor | None = None, + num_loop_stages: int = 1, + config: Config | None = None, + ) -> tuple[ + list[int], + list[int], + paddle.Tensor, + paddle.Tensor, + paddle.Tensor, + paddle.Tensor, + paddle.Tensor, + ]: + # Default config + config = ( + self.get_dispatch_config(self.group_size) + if config is None + else config + ) + # Launch the kernel with cached or non-cached mode + x, x_scales = x if isinstance(x, tuple) else (x, None) + assert num_tokens_per_rank is not None and is_token_in_rank is not None + + ( + num_recv_tokens, + num_rdma_recv_tokens, + recv_rdma_rank_prefix_sum, + recv_rdma_channel_prefix_matrix, + recv_gbl_channel_prefix_matrix, + send_rdma_head, + send_nvl_head, + ) = self.runtime.internode_fused_notify_combine( + x, + x_scales, + topk_idx, + num_tokens_per_rank, + num_tokens_per_rdma_rank, + is_token_in_rank, + num_loop_stages, + config, + ) + + return ( + num_recv_tokens, + num_rdma_recv_tokens, + recv_rdma_rank_prefix_sum, + recv_rdma_channel_prefix_matrix, + recv_gbl_channel_prefix_matrix, + send_rdma_head, + send_nvl_head, + ) + # 
noinspection PyTypeChecker def internode_combine( self,
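For illustration only, a minimal usage sketch (Python, not part of the patch) of the fused notify path added above; the stacked tensor shapes follow the host-side shape asserts in internode_fused_notify_dispatch / internode_fused_notify_combine, and x, topk_idx and the per-stage routing tensors are hypothetical inputs produced by the usual gating step:

    # Assumed inputs, stacked along a leading num_loop_stage dimension:
    #   num_tokens_per_rank:      int32, [num_loop_stage, num_ranks]
    #   num_tokens_per_rdma_rank: int32, [num_loop_stage, num_ranks / NUM_MAX_NVL_PEERS]
    #   num_tokens_per_expert:    int32, [num_loop_stage, num_experts]
    #   is_token_in_rank:         bool,  [num_loop_stage, num_tokens, num_ranks]
    #   x:                        [num_tokens, hidden]
    num_loop_stage = 2

    (
        num_recv_tokens_per_expert_list,  # one list of per-expert counts per stage
        num_recv_tokens,                  # list[int], one total per stage
        num_rdma_recv_tokens,             # list[int], one RDMA total per stage
        handles,                          # one dispatch-style handle tuple per stage
    ) = buffer.internode_fused_notify_dispatch(
        x,
        topk_idx=topk_idx,
        num_tokens_per_rank=num_tokens_per_rank,
        num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
        num_tokens_per_expert=num_tokens_per_expert,
        is_token_in_rank=is_token_in_rank,
        num_loop_stage=num_loop_stage,
        expert_alignment=1,
    )

    # The combine-side counterpart returns stage-stacked prefix matrices and
    # the send heads consumed later by internode_combine (names as in this
    # patch); note its keyword is num_loop_stages rather than num_loop_stage.
    (
        num_recv_tokens,
        num_rdma_recv_tokens,
        recv_rdma_rank_prefix_sum,
        recv_rdma_channel_prefix_matrix,
        recv_gbl_channel_prefix_matrix,
        send_rdma_head,
        send_nvl_head,
    ) = buffer.internode_fused_notify_combine(
        x,
        topk_idx=topk_idx,
        num_tokens_per_rank=num_tokens_per_rank,
        num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
        is_token_in_rank=is_token_in_rank,
        num_loop_stages=num_loop_stage,
    )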