From d5e68e92c6f68ff372ea6322054ad27175aff4b9 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Wed, 19 Nov 2025 17:30:01 +0800 Subject: [PATCH 1/4] add fused notify_combine --- .../collective/deep_ep/deep_ep.cpp | 259 +++++++- .../collective/deep_ep/deep_ep.hpp | 34 + .../collective/deep_ep/kernels/api.cuh | 46 ++ .../collective/deep_ep/kernels/internode.cu | 593 +++++++++++++++++- paddle/fluid/pybind/deep_ep_api.cc | 2 + .../communication/deep_ep/buffer.py | 58 ++ 6 files changed, 988 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/distributed/collective/deep_ep/deep_ep.cpp b/paddle/fluid/distributed/collective/deep_ep/deep_ep.cpp index bbabc308742b0b..5aadc01ec8df8b 100644 --- a/paddle/fluid/distributed/collective/deep_ep/deep_ep.cpp +++ b/paddle/fluid/distributed/collective/deep_ep/deep_ep.cpp @@ -131,8 +131,8 @@ Buffer::Buffer(int rank, CUDA_CHECK(cudaMemsetAsync(workspace, 0, NUM_WORKSPACE_BYTES, comm_stream)); // MoE counter - CUDA_CHECK( - cudaMallocHost(&moe_recv_counter, sizeof(int64_t), cudaHostAllocMapped)); + CUDA_CHECK(cudaMallocHost( + &moe_recv_counter, sizeof(int64_t) * 3, cudaHostAllocMapped)); CUDA_CHECK(cudaHostGetDevicePointer( &moe_recv_counter_mapped, const_cast(moe_recv_counter), 0)); *moe_recv_counter = -1; @@ -150,7 +150,7 @@ Buffer::Buffer(int rank, // MoE RDMA-level counter if (num_rdma_ranks > 0) { CUDA_CHECK(cudaMallocHost( - &moe_recv_rdma_counter, sizeof(int), cudaHostAllocMapped)); + &moe_recv_rdma_counter, sizeof(int) * 3, cudaHostAllocMapped)); CUDA_CHECK(cudaHostGetDevicePointer(&moe_recv_rdma_counter_mapped, const_cast(moe_recv_rdma_counter), 0)); @@ -1881,6 +1881,192 @@ Buffer::internode_notify_combine( send_nvl_head}; } +std::tuple, + std::vector, + deep_ep::detail::Tensor, + deep_ep::detail::Tensor, + deep_ep::detail::Tensor, + deep_ep::detail::Tensor, + deep_ep::detail::Tensor> +Buffer::internode_fused_notify_combine( + const deep_ep::detail::Tensor& x, + const std::optional& x_scales, + const std::optional& topk_idx, + const std::optional& num_tokens_per_rank, + const std::optional& num_tokens_per_rdma_rank, + const deep_ep::detail::Tensor& is_token_in_rank, + int num_loop_stage, + const Config& config) { + const int num_channels = config.num_sms / 2; + EP_HOST_ASSERT(config.num_sms % 2 == 0); + EP_HOST_ASSERT(0 < get_num_rdma_ranks() && + get_num_rdma_ranks() <= NUM_MAX_RDMA_PEERS); + + EP_HOST_ASSERT(num_tokens_per_rank->scalar_type() == deep_ep::detail::kInt32); + EP_HOST_ASSERT(num_tokens_per_rdma_rank->scalar_type() == + deep_ep::detail::kInt32); + + // Shape and contiguous checks + EP_HOST_ASSERT(x.dim() == 2 && x.is_contiguous()); + EP_HOST_ASSERT((x.size(1) * x.element_size()) % sizeof(int4) == 0); + EP_HOST_ASSERT(num_tokens_per_rank->dim() == 1 && + num_tokens_per_rank->is_contiguous()); + EP_HOST_ASSERT(num_tokens_per_rdma_rank->dim() == 1 && + num_tokens_per_rdma_rank->is_contiguous()); + EP_HOST_ASSERT(num_tokens_per_rank->size(0) == num_ranks); + EP_HOST_ASSERT(num_tokens_per_rdma_rank->size(0) == num_rdma_ranks); + + int num_scales = 0; + if (x_scales.has_value()) { + num_scales = x_scales->dim() == 1 ? 
1 : static_cast<int>(x_scales->size(1));
+  }
+
+  auto num_tokens = static_cast<int>(x.size(0)),
+       hidden = static_cast<int>(x.size(1)),
+       hidden_int4 =
+           static_cast<int>(x.size(1) * x.element_size() / sizeof(int4));
+
+  // Top-k checks
+  int num_topk = 0;
+  if (topk_idx.has_value()) {
+    num_topk = static_cast<int>(topk_idx->size(1));
+    EP_HOST_ASSERT(topk_idx->dim() == 2 && topk_idx->is_contiguous());
+    EP_HOST_ASSERT(num_tokens == topk_idx->size(0));
+  }
+
+  // Allocate all tensors on comm stream if set
+  // NOTES: do not allocate tensors upfront!
+  auto compute_stream = calc_ctx->stream();
+  stream_wait(comm_stream, compute_stream);
+
+  auto rdma_channel_prefix_matrix = ConvertPaddleTensorToDetailTensor(
+      paddle::experimental::empty({num_rdma_ranks, num_channels},
+                                  phi::DataType::INT32,
+                                  phi::GPUPlace(device_id)));
+  auto recv_rdma_rank_prefix_sum =
+      ConvertPaddleTensorToDetailTensor(paddle::experimental::empty(
+          {num_rdma_ranks}, phi::DataType::INT32, phi::GPUPlace(device_id)));
+  auto gbl_channel_prefix_matrix = ConvertPaddleTensorToDetailTensor(
+      paddle::experimental::empty({num_ranks, num_channels},
+                                  phi::DataType::INT32,
+                                  phi::GPUPlace(device_id)));
+  auto recv_gbl_rank_prefix_sum =
+      ConvertPaddleTensorToDetailTensor(paddle::experimental::empty(
+          {num_ranks}, phi::DataType::INT32, phi::GPUPlace(device_id)));
+
+  auto recv_rdma_channel_prefix_matrix = ConvertPaddleTensorToDetailTensor(
+      paddle::experimental::empty({num_rdma_ranks, num_channels},
+                                  phi::DataType::INT32,
+                                  phi::GPUPlace(device_id)));
+  auto recv_gbl_channel_prefix_matrix = ConvertPaddleTensorToDetailTensor(
+      paddle::experimental::empty({num_ranks, num_channels},
+                                  phi::DataType::INT32,
+                                  phi::GPUPlace(device_id)));
+
+  auto send_rdma_head = ConvertPaddleTensorToDetailTensor(
+      paddle::experimental::empty({num_tokens, num_ranks / NUM_MAX_NVL_PEERS},
+                                  phi::DataType::INT32,
+                                  phi::GPUPlace(device_id)));
+  auto send_nvl_head =
+      ConvertPaddleTensorToDetailTensor(paddle::experimental::empty(
+          {num_tokens, num_ranks / NUM_MAX_NVL_PEERS, 8},
+          phi::DataType::INT32,
+          phi::GPUPlace(device_id)));
+
+  // Send sizes
+  for (int i = 0; i < num_loop_stage; ++i) {
+    moe_recv_counter[i] = -1;
+    moe_recv_rdma_counter[i] = -1;
+  }
+  internode::fused_notify_combine(
+      num_tokens_per_rank->data_ptr<int>(),
+      moe_recv_counter_mapped,
+      num_ranks,
+      num_tokens_per_rdma_rank->data_ptr<int>(),
+      moe_recv_rdma_counter_mapped,
+      is_token_in_rank.data_ptr<bool>(),
+      num_tokens,
+      num_channels,
+      hidden_int4,
+      num_scales,
+      num_topk,
+      num_loop_stage,
+      rdma_channel_prefix_matrix.data_ptr<int>(),
+      recv_rdma_rank_prefix_sum.data_ptr<int>(),
+      gbl_channel_prefix_matrix.data_ptr<int>(),
+      recv_gbl_rank_prefix_sum.data_ptr<int>(),
+      send_rdma_head.data_ptr<int>(),
+      send_nvl_head.data_ptr<int>(),
+      rdma_buffer_ptr,
+      config.num_max_rdma_chunked_recv_tokens,
+      buffer_ptrs_gpu,
+      config.num_max_nvl_chunked_recv_tokens,
+      task_fifo_ptrs_gpu,
+      head,
+      rank,
+      comm_stream,
+      config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),
+      num_nvl_bytes,
+      low_latency_mode);
+
+  internode::fused_notify_combine_post_step(
+      num_ranks,
+      num_channels,
+      num_loop_stage,
+      recv_gbl_rank_prefix_sum.data_ptr<int>(),
+      rdma_channel_prefix_matrix.data_ptr<int>(),
+      gbl_channel_prefix_matrix.data_ptr<int>(),
+      recv_rdma_channel_prefix_matrix.data_ptr<int>(),
+      recv_gbl_channel_prefix_matrix.data_ptr<int>(),
+      rdma_buffer_ptr,
+      buffer_ptrs_gpu,
+      task_fifo_ptrs_gpu,
+      head,
+      rank,
+      comm_stream,
+      low_latency_mode);
+
+  // Synchronize total received tokens and tokens per expert
+  auto start_time =
std::chrono::high_resolution_clock::now(); + while (true) { + // Read total count + bool ready = true; + for (int i = 0; i < num_loop_stage; ++i) { + ready = ready && (moe_recv_counter[i] >= 0) && + (moe_recv_rdma_counter[i] >= 0); + } + + if (ready) break; + + // Timeout check + if (std::chrono::duration_cast( + std::chrono::high_resolution_clock::now() - start_time) + .count() > NUM_CPU_TIMEOUT_SECS) { + LOG(INFO) << "Global rank: " << rank + << ", num_recv_tokens: " << moe_recv_counter[0] + << ", num_rdma_recv_tokens: " << moe_recv_rdma_counter[0]; + throw std::runtime_error("DeepEP error: timeout (dispatch CPU)"); + } + } + + std::vector num_recv_tokens(moe_recv_counter_mapped, + moe_recv_counter_mapped + num_loop_stage); + std::vector num_rdma_recv_tokens( + moe_recv_rdma_counter_mapped, + moe_recv_rdma_counter_mapped + num_loop_stage); + + // Wait streams + stream_wait(compute_stream, comm_stream); + + return {num_recv_tokens, + num_rdma_recv_tokens, + recv_rdma_rank_prefix_sum, + recv_rdma_channel_prefix_matrix, + recv_gbl_channel_prefix_matrix, + send_rdma_head, + send_nvl_head}; +} + #endif // PADDLE_WITH_NVSHMEM void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, @@ -3011,6 +3197,73 @@ Buffer::internode_notify_combine_api( #endif } +std::tuple, + std::vector, + paddle::Tensor, + paddle::Tensor, + paddle::Tensor, + paddle::Tensor, + paddle::Tensor> +Buffer::internode_fused_notify_combine_api( + const paddle::Tensor& x, + const std::optional& x_scales, + const std::optional& topk_idx, + const std::optional& num_tokens_per_rank, + const std::optional& num_tokens_per_rdma_rank, + const paddle::Tensor& is_token_in_rank, + int num_loop_stages, + const Config& config) { +#ifdef PADDLE_WITH_NVSHMEM + const auto& x_ = ConvertPaddleTensorToDetailTensor(x); + std::optional x_scales_ = + ConvertOptionalPaddleTensorToDetailTensor(x_scales); + + std::optional topk_idx_ = + ConvertOptionalPaddleTensorToDetailTensor(topk_idx); + std::optional num_tokens_per_rank_ = + ConvertOptionalPaddleTensorToDetailTensor(num_tokens_per_rank); + std::optional num_tokens_per_rdma_rank_ = + ConvertOptionalPaddleTensorToDetailTensor(num_tokens_per_rdma_rank); + const auto& is_token_in_rank_ = + ConvertPaddleTensorToDetailTensor(is_token_in_rank); + + auto res = internode_fused_notify_combine(x_, + x_scales_, + topk_idx_, + num_tokens_per_rank_, + num_tokens_per_rdma_rank_, + is_token_in_rank_, + num_loop_stages, + config); + + auto num_recv_tokens_ = std::get<0>(res); + auto num_rdma_recv_tokens_ = std::get<1>(res); + auto recv_rdma_rank_prefix_sum_ = + ConvertDetailTensorToPaddleTensor(std::get<2>(res)); + + auto recv_rdma_channel_prefix_matrix_ = + ConvertDetailTensorToPaddleTensor(std::get<3>(res)); + + auto recv_gbl_channel_prefix_matrix_ = + ConvertDetailTensorToPaddleTensor(std::get<4>(res)); + + auto send_rdma_head_ = ConvertDetailTensorToPaddleTensor(std::get<5>(res)); + auto send_nvl_head_ = ConvertDetailTensorToPaddleTensor(std::get<6>(res)); + + return {num_recv_tokens_, + num_rdma_recv_tokens_, + recv_rdma_rank_prefix_sum_, + recv_rdma_channel_prefix_matrix_, + recv_gbl_channel_prefix_matrix_, + send_rdma_head_, + send_nvl_head_}; +#else + LOG(ERROR) << "NVSHMEM is not enabled. 
You can enable it by setting cmake " + "option WITH_NVSHMEM=ON."; + return {}; +#endif +} + std::tuple, paddle::Tensor, diff --git a/paddle/fluid/distributed/collective/deep_ep/deep_ep.hpp b/paddle/fluid/distributed/collective/deep_ep/deep_ep.hpp index 8185ae7e7a51ec..b3a0c32a0e2cee 100644 --- a/paddle/fluid/distributed/collective/deep_ep/deep_ep.hpp +++ b/paddle/fluid/distributed/collective/deep_ep/deep_ep.hpp @@ -302,6 +302,23 @@ struct Buffer { int expert_alignment, const Config& config); + std::tuple, + std::vector, + deep_ep::detail::Tensor, + deep_ep::detail::Tensor, + deep_ep::detail::Tensor, + deep_ep::detail::Tensor, + deep_ep::detail::Tensor> + internode_fused_notify_combine( + const deep_ep::detail::Tensor& x, + const std::optional& x_scales, + const std::optional& topk_idx, + const std::optional& num_tokens_per_rank, + const std::optional& num_tokens_per_rdma_rank, + const deep_ep::detail::Tensor& is_token_in_rank, + int num_loop_stage, + const Config& config); + #endif // PADDLE_WITH_NVSHMEM void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, @@ -464,6 +481,23 @@ struct Buffer { int expert_alignment, const Config& config); + std::tuple, + std::vector, + paddle::Tensor, + paddle::Tensor, + paddle::Tensor, + paddle::Tensor, + paddle::Tensor> + internode_fused_notify_combine_api( + const paddle::Tensor& x, + const std::optional& x_scales, + const std::optional& topk_idx, + const std::optional& num_tokens_per_rank, + const std::optional& num_tokens_per_rdma_rank, + const paddle::Tensor& is_token_in_rank, + int num_loop_stages, + const Config& config); + std::tuple, paddle::Tensor, diff --git a/paddle/fluid/distributed/collective/deep_ep/kernels/api.cuh b/paddle/fluid/distributed/collective/deep_ep/kernels/api.cuh index b474af49c4b1b5..c8a28a46208ffd 100644 --- a/paddle/fluid/distributed/collective/deep_ep/kernels/api.cuh +++ b/paddle/fluid/distributed/collective/deep_ep/kernels/api.cuh @@ -243,6 +243,52 @@ void notify_combine_post_step(int num_ranks, cudaStream_t stream, bool low_latency_mode); +void fused_notify_combine(const int* num_tokens_per_rank, + int* moe_recv_counter_mapped, + int num_ranks, + const int* num_tokens_per_rdma_rank, + int* moe_recv_rdma_counter_mapped, + const bool* is_token_in_rank, + int num_tokens, + int num_channels, + int hidden_int4, + int num_scales, + int num_topk, + int num_loop_stage, + int* rdma_channel_prefix_matrix, + int* recv_rdma_rank_prefix_sum, + int* gbl_channel_prefix_matrix, + int* recv_gbl_rank_prefix_sum, + int* send_rdma_head, + int* send_nvl_head, + void* rdma_buffer_ptr, + int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, + int num_max_nvl_chunked_recv_tokens, + int** task_fifo_ptrs, + int head, + int rank, + cudaStream_t stream, + int64_t num_rdma_bytes, + int64_t num_nvl_bytes, + bool low_latency_mode); + +void fused_notify_combine_post_step(int num_ranks, + int num_channels, + int num_loop_stage, + const int* recv_gbl_rank_prefix_sum, + const int* rdma_channel_prefix_matrix, + const int* gbl_channel_prefix_matrix, + int* recv_rdma_channel_prefix_matrix, + int* recv_gbl_channel_prefix_matrix, + void* rdma_buffer_ptr, + void** buffer_ptrs, + int** task_fifo_ptrs, + int head, + int rank, + cudaStream_t stream, + bool low_latency_mode); + void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, diff --git a/paddle/fluid/distributed/collective/deep_ep/kernels/internode.cu b/paddle/fluid/distributed/collective/deep_ep/kernels/internode.cu index e38239c4bdeadc..50dc911bb10297 100644 --- 
a/paddle/fluid/distributed/collective/deep_ep/kernels/internode.cu +++ b/paddle/fluid/distributed/collective/deep_ep/kernels/internode.cu @@ -1077,6 +1077,597 @@ void notify_combine_post_step(int num_ranks, #undef NOTIFY_DISPATCH_LAUNCH_CASE } +template +__global__ void fused_notify_combine( + const int* num_tokens_per_rank, // [num_loop_stage, 2, num_ranks] + int* moe_recv_counter_mapped, + int num_ranks, + const int* num_tokens_per_rdma_rank, + int* moe_recv_rdma_counter_mapped, + const bool* is_token_in_rank, + int num_tokens, + int num_channels, + int num_loop_stage, + const int rdma_clean_offset, + const int rdma_num_int_clean, + const int nvl_clean_offset, + const int nvl_num_int_clean, + int* rdma_channel_prefix_matrix, + int* gbl_channel_prefix_matrix, + int* recv_rdma_rank_prefix_sum, + int* recv_gbl_rank_prefix_sum, + int* send_rdma_head, + int* send_nvl_head, + void* rdma_buffer_ptr, + void** buffer_ptrs, + int** task_fifo_ptrs, + int head, + int rank, + const nvshmem_team_t rdma_team) { + auto sm_id = static_cast(blockIdx.x); + auto thread_id = static_cast(threadIdx.x), warp_id = thread_id / 32, + lane_id = get_lane_id(); + auto num_threads = static_cast(blockDim.x), num_warps = num_threads / 32; + + auto rdma_rank = rank / NUM_MAX_NVL_PEERS, + nvl_rank = rank % NUM_MAX_NVL_PEERS; + + if (sm_id == 0) { + // Communication with others + // Global barrier: the first warp does intra-node sync, the second warp does + // internode sync + EP_DEVICE_ASSERT(num_warps > 1); + EP_DEVICE_ASSERT(kNumRDMARanks <= num_threads); + if (thread_id == 32) + nvshmem_barrier_with_same_gpu_idx(rdma_team); + barrier_device(task_fifo_ptrs, head, nvl_rank); + move_fifo_slots(head); + __syncthreads(); + + // Send numbers of tokens per rank/expert to RDMA ranks + auto rdma_buffer_ptr_int = static_cast(rdma_buffer_ptr); + auto rdma_recv_num_tokens_mixed = + SymBuffer(rdma_buffer_ptr, + (NUM_MAX_NVL_PEERS + 1) * num_loop_stage, + kNumRDMARanks); + + // Clean up for later data dispatch + EP_DEVICE_ASSERT(rdma_recv_num_tokens_mixed.total_bytes <= + rdma_clean_offset * sizeof(int)); + +#pragma unroll + for (int i = thread_id; i < rdma_num_int_clean; i += num_threads) + rdma_buffer_ptr_int[rdma_clean_offset + i] = 0; + +// Copy to send buffer +#pragma unroll + for (int i = thread_id; i < num_ranks; i += num_threads) { + for (int j = 0; j < num_loop_stage; ++j) { + rdma_recv_num_tokens_mixed.send_buffer( + i / NUM_MAX_NVL_PEERS)[j * (NUM_MAX_NVL_PEERS + 1) + + (i % NUM_MAX_NVL_PEERS)] = + num_tokens_per_rank[(j * 2 + 1) * num_ranks + i]; + } + } + +#pragma unroll + if (thread_id < kNumRDMARanks) { + for (int j = 0; j < num_loop_stage; ++j) { + rdma_recv_num_tokens_mixed.send_buffer( + thread_id)[j * (NUM_MAX_NVL_PEERS + 1) + NUM_MAX_NVL_PEERS] = + num_tokens_per_rdma_rank[j * num_ranks + thread_id]; + } + } + + __syncthreads(); + + // Issue send + if (thread_id < kNumRDMARanks) { + nvshmem_int_put_nbi( + rdma_recv_num_tokens_mixed.recv_buffer(rdma_rank), + rdma_recv_num_tokens_mixed.send_buffer(thread_id), + (NUM_MAX_NVL_PEERS + 1) * num_loop_stage, + translate_dst_rdma_rank(thread_id, nvl_rank)); + } + __syncthreads(); + + // Barrier + if (thread_id == 0) { + nvshmem_barrier_with_same_gpu_idx(rdma_team); + } + __syncthreads(); + + // NVL buffers + auto nvl_send_buffer = + thread_id < NUM_MAX_NVL_PEERS ? 
buffer_ptrs[thread_id] : nullptr; + auto nvl_recv_buffer = buffer_ptrs[nvl_rank]; + + auto nvl_send_num_tokens_per_rank = AsymBuffer( + nvl_send_buffer, kNumRDMARanks * num_loop_stage, NUM_MAX_NVL_PEERS); + auto nvl_recv_num_tokens_per_rank = AsymBuffer( + nvl_recv_buffer, kNumRDMARanks * num_loop_stage, NUM_MAX_NVL_PEERS); + + // Clean up for later data dispatch + auto nvl_buffer_ptr_int = static_cast(buffer_ptrs[nvl_rank]); + EP_DEVICE_ASSERT(nvl_send_num_tokens_per_rank.total_bytes <= + nvl_clean_offset * sizeof(int)); +#pragma unroll + for (int i = thread_id; i < nvl_num_int_clean; i += num_threads) + nvl_buffer_ptr_int[nvl_clean_offset + i] = 0; + __syncthreads(); + + // Reduce RDMA received tokens + if (thread_id < num_loop_stage) { + int sum = 0; +#pragma unroll + for (int i = 0; i < kNumRDMARanks; ++i) { + sum += rdma_recv_num_tokens_mixed.recv_buffer( + i)[thread_id * (NUM_MAX_NVL_PEERS + 1) + NUM_MAX_NVL_PEERS]; + recv_rdma_rank_prefix_sum[thread_id * kNumRDMARanks + i] = sum; + } + while (ld_volatile_global(moe_recv_rdma_counter_mapped + thread_id) != + -1) { + } + moe_recv_rdma_counter_mapped[thread_id] = sum; + } + + // Send numbers of tokens per rank/expert to NVL ranks + EP_DEVICE_ASSERT(NUM_MAX_NVL_PEERS <= num_threads); + if (thread_id < NUM_MAX_NVL_PEERS) { +#pragma unroll + for (int i = 0; i < kNumRDMARanks; ++i) { + for (int j = 0; j < num_loop_stage; ++j) { + nvl_send_num_tokens_per_rank.buffer(nvl_rank)[j * kNumRDMARanks + i] = + rdma_recv_num_tokens_mixed.recv_buffer( + i)[j * (NUM_MAX_NVL_PEERS + 1) + thread_id]; + } + } + } + + memory_fence(); + __syncthreads(); + barrier_device(task_fifo_ptrs, head, nvl_rank); + move_fifo_slots(head); + __syncthreads(); + + if (thread_id < num_loop_stage) { + int sum = 0; +#pragma unroll + for (int i = 0; i < num_ranks; ++i) { + int src_rdma_rank = i / NUM_MAX_NVL_PEERS, + src_nvl_rank = i % NUM_MAX_NVL_PEERS; + sum += nvl_recv_num_tokens_per_rank.buffer( + src_nvl_rank)[thread_id * kNumRDMARanks + src_rdma_rank]; + recv_gbl_rank_prefix_sum[thread_id * num_ranks + i] = sum; + } + while (ld_volatile_global(moe_recv_counter_mapped + thread_id) != -1) { + } + moe_recv_counter_mapped[thread_id] = sum; + } + + // Finally barrier + if (thread_id == 32) + nvshmem_barrier_with_same_gpu_idx(rdma_team); + barrier_device(task_fifo_ptrs, head, nvl_rank); + move_fifo_slots(head); + } else { + // Calculate meta data + int stage_id = (sm_id - 1) / num_loop_stage; + int dst_rdma_rank = (sm_id - 1) % num_loop_stage; + for (int channel_id = warp_id; channel_id < num_channels; + channel_id += num_warps) { + int token_start_idx, token_end_idx; + get_channel_task_range( + num_tokens, num_channels, channel_id, token_start_idx, token_end_idx); + + // Iterate over tokens + int total_count = 0, per_nvl_rank_count[NUM_MAX_NVL_PEERS] = {0}; + int global_rdma_tail_idx = 0, + global_nvl_tail_idx[NUM_MAX_NVL_PEERS] = {0}; + for (int i = token_start_idx + lane_id; i < token_end_idx; i += 32) { + EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * sizeof(bool) == sizeof(uint64_t), + "Invalid number of NVL peers"); + auto is_token_in_rank_uint64 = *reinterpret_cast( + is_token_in_rank + (stage_id * num_tokens + i) * num_ranks + + dst_rdma_rank * NUM_MAX_NVL_PEERS); + auto is_token_in_rank_values = + reinterpret_cast(&is_token_in_rank_uint64); + + total_count += (is_token_in_rank_uint64 != 0); + + // Calculate RDMA tail index for combine + auto warp_valid_tokens = std::min(token_end_idx - (i - lane_id), 32); + unsigned int mask = 0xffffffff >> (32 - warp_valid_tokens); + int 
warp_rdma_tail_idx = (is_token_in_rank_uint64 != 0); + global_rdma_tail_idx += warp_scan(warp_rdma_tail_idx, mask); + auto rdma_tail_idx = + is_token_in_rank_uint64 == 0 ? -1 : global_rdma_tail_idx - 1; + send_rdma_head[(stage_id * num_tokens + i) * kNumRDMARanks + + dst_rdma_rank] = rdma_tail_idx; + global_rdma_tail_idx = + __shfl_sync(mask, global_rdma_tail_idx, warp_valid_tokens - 1); + +#pragma unroll + for (int j = 0; j < NUM_MAX_NVL_PEERS; ++j) { + per_nvl_rank_count[j] += is_token_in_rank_values[j]; + int warp_nvl_tail_idx = (is_token_in_rank_values[j]); + global_nvl_tail_idx[j] += warp_scan(warp_nvl_tail_idx, mask); + auto nvl_tail_idx = + is_token_in_rank_values[j] == 0 ? -1 : global_nvl_tail_idx[j] - 1; + send_nvl_head[(stage_id * num_tokens + i) * kNumRDMARanks * + NUM_MAX_NVL_PEERS + + dst_rdma_rank * NUM_MAX_NVL_PEERS + j] = nvl_tail_idx; + global_nvl_tail_idx[j] = + __shfl_sync(mask, global_nvl_tail_idx[j], warp_valid_tokens - 1); + } + } + + // Warp reduce + total_count = warp_reduce_sum(total_count); +#pragma unroll + for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i) + per_nvl_rank_count[i] = warp_reduce_sum(per_nvl_rank_count[i]); + + // Write into channel matrix + if (lane_id == 0) { +#pragma unroll + for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i) + gbl_channel_prefix_matrix[(stage_id * num_ranks + + dst_rdma_rank * NUM_MAX_NVL_PEERS + i) * + num_channels + + channel_id] = per_nvl_rank_count[i]; + rdma_channel_prefix_matrix[(stage_id * kNumRDMARanks + dst_rdma_rank) * + num_channels + + channel_id] = total_count; + } + } + + // Calculate prefix sum + __syncthreads(); + if (thread_id == 0) { + auto prefix_row = + rdma_channel_prefix_matrix + + (stage_id * kNumRDMARanks + dst_rdma_rank) * num_channels; +#pragma unroll + for (int i = 1; i < num_channels; ++i) { + prefix_row[i] += prefix_row[i - 1]; + } + } + + EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= 32, "Invalid number of NVL peers"); + if (thread_id < NUM_MAX_NVL_PEERS) { + auto prefix_row = gbl_channel_prefix_matrix + + (stage_id * num_ranks + + dst_rdma_rank * NUM_MAX_NVL_PEERS + thread_id) * + num_channels; +#pragma unroll + for (int i = 1; i < num_channels; ++i) { + prefix_row[i] += prefix_row[i - 1]; + } + } + } +} + +void fused_notify_combine(const int* num_tokens_per_rank, + int* moe_recv_counter_mapped, + int num_ranks, + const int* num_tokens_per_rdma_rank, + int* moe_recv_rdma_counter_mapped, + const bool* is_token_in_rank, + int num_tokens, + int num_channels, + int hidden_int4, + int num_scales, + int num_topk, + int num_loop_stage, + int* rdma_channel_prefix_matrix, + int* recv_rdma_rank_prefix_sum, + int* gbl_channel_prefix_matrix, + int* recv_gbl_rank_prefix_sum, + int* send_rdma_head, + int* send_nvl_head, + void* rdma_buffer_ptr, + int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, + int num_max_nvl_chunked_recv_tokens, + int** task_fifo_ptrs, + int head, + int rank, + cudaStream_t stream, + int64_t num_rdma_bytes, + int64_t num_nvl_bytes, + bool low_latency_mode) { +#define NOTIFY_COMBINE_LAUNCH_CASE(num_rdma_ranks) \ + { \ + auto notify_combine_func = \ + low_latency_mode ? 
fused_notify_combine \ + : fused_notify_combine; \ + LAUNCH_KERNEL(&cfg, \ + notify_combine_func, \ + num_tokens_per_rank, \ + moe_recv_counter_mapped, \ + num_ranks, \ + num_tokens_per_rdma_rank, \ + moe_recv_rdma_counter_mapped, \ + is_token_in_rank, \ + num_tokens, \ + num_channels, \ + num_loop_stage, \ + rdma_clean_meta.first, \ + rdma_clean_meta.second, \ + nvl_clean_meta.first, \ + nvl_clean_meta.second, \ + rdma_channel_prefix_matrix, \ + gbl_channel_prefix_matrix, \ + recv_rdma_rank_prefix_sum, \ + recv_gbl_rank_prefix_sum, \ + send_rdma_head, \ + send_nvl_head, \ + rdma_buffer_ptr, \ + buffer_ptrs, \ + task_fifo_ptrs, \ + head, \ + rank, \ + cpu_rdma_team); \ + } \ + break + + constexpr int kNumThreads = 512; + const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; + + // Get clean meta + auto rdma_clean_meta = get_rdma_clean_meta(hidden_int4, + num_scales, + num_topk, + num_topk, + num_rdma_ranks, + num_max_rdma_chunked_recv_tokens, + num_channels); + auto nvl_clean_meta = get_nvl_clean_meta(hidden_int4, + num_scales, + num_topk, + num_topk, + num_rdma_ranks, + NUM_MAX_NVL_PEERS, + num_max_nvl_chunked_recv_tokens, + num_channels); + EP_HOST_ASSERT((rdma_clean_meta.first + rdma_clean_meta.second) * + sizeof(int) <= + num_rdma_bytes); + EP_HOST_ASSERT((nvl_clean_meta.first + nvl_clean_meta.second) * sizeof(int) <= + num_nvl_bytes); + EP_HOST_ASSERT(num_rdma_bytes < std::numeric_limits::max()); + EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits::max()); + + // Launch kernel + SETUP_LAUNCH_CONFIG(1 + num_rdma_ranks * num_loop_stage, kNumThreads, stream); + SWITCH_RDMA_RANKS(NOTIFY_COMBINE_LAUNCH_CASE); +#undef NOTIFY_DISPATCH_LAUNCH_CASE +} + +template +__global__ void fused_notify_combine_post_step( + int num_channels, + int num_loop_stage, + const int* recv_gbl_rank_prefix_sum, + const int* rdma_channel_prefix_matrix, + const int* gbl_channel_prefix_matrix, + int* recv_rdma_channel_prefix_matrix, + int* recv_gbl_channel_prefix_matrix, + void* rdma_buffer_ptr, + void** buffer_ptrs, + int** task_fifo_ptrs, + int head, + int rank, + const nvshmem_team_t rdma_team) { + auto sm_id = static_cast(blockIdx.x); + EP_DEVICE_ASSERT(sm_id == 0); + auto thread_id = static_cast(threadIdx.x), warp_id = thread_id / 32, + lane_id = get_lane_id(); + auto num_threads = static_cast(blockDim.x), num_warps = num_threads / 32; + + auto rdma_rank = rank / NUM_MAX_NVL_PEERS, + nvl_rank = rank % NUM_MAX_NVL_PEERS; + + auto rdma_channel_meta = + SymBuffer(rdma_buffer_ptr, + (num_channels + num_channels * NUM_MAX_NVL_PEERS) * + num_loop_stage, // (rdma_channel_meta + + // nvl_channel_meta) * num_loop_stage + kNumRDMARanks); + + // NVL buffers + auto nvl_send_buffer = + warp_id - 1 < NUM_MAX_NVL_PEERS ? 
buffer_ptrs[warp_id - 1] : nullptr; + auto nvl_recv_buffer = buffer_ptrs[nvl_rank]; + + auto nvl_send_channel_meta = + AsymBuffer(nvl_send_buffer, + num_loop_stage * kNumRDMARanks * num_channels, + NUM_MAX_NVL_PEERS); + auto nvl_recv_channel_meta = + AsymBuffer(nvl_recv_buffer, + num_loop_stage * kNumRDMARanks * num_channels, + NUM_MAX_NVL_PEERS); + + if (thread_id == 32) + nvshmem_barrier_with_same_gpu_idx(rdma_team); + barrier_device(task_fifo_ptrs, head, nvl_rank); + move_fifo_slots(head); + __syncthreads(); + + if (warp_id == 0) { // rdma_channel_prefix_matrix data + // land_id -> dst_rdma_rank + if (lane_id < kNumRDMARanks) { + for (int j = 0; j < num_loop_stage; ++j) { + for (int i = 0; i < num_channels; ++i) { + rdma_channel_meta.send_buffer( + lane_id)[j * (num_channels + num_channels * NUM_MAX_NVL_PEERS) + + i] = + -rdma_channel_prefix_matrix[(j * kNumRDMARanks + lane_id) * + num_channels + + i] - + 1; + } + } + } + } else if (warp_id < + NUM_MAX_NVL_PEERS + 1) { // gbl_channel_prefix_matrix data + // land_id -> dst_rdma_rank + // warp_id -1 -> dst_nvl_rank + if (lane_id < kNumRDMARanks) { + for (int j = 0; j < num_loop_stage; ++j) { + auto dst_ptr = rdma_channel_meta.send_buffer(lane_id) + + j * (num_channels + num_channels * NUM_MAX_NVL_PEERS) + + num_channels + (warp_id - 1) * num_channels; + dst_ptr[0] = -0 - 1; + for (int i = 1; i < num_channels; ++i) { + dst_ptr[i] = -gbl_channel_prefix_matrix[(j * kNumRDMARanks * + NUM_MAX_NVL_PEERS + + lane_id * NUM_MAX_NVL_PEERS + + warp_id - 1) * + num_channels + + i - 1] - + 1; + } + } + } + } + __syncthreads(); + + // Issue send + if (thread_id < kNumRDMARanks) { + nvshmem_int_put_nbi( + rdma_channel_meta.recv_buffer(rdma_rank), + rdma_channel_meta.send_buffer(thread_id), + (num_channels + num_channels * NUM_MAX_NVL_PEERS) * num_loop_stage, + translate_dst_rdma_rank(thread_id, nvl_rank)); + } + __syncthreads(); + + // Barrier + if (thread_id == 0) { + nvshmem_barrier_with_same_gpu_idx(rdma_team); + } + __syncthreads(); + + // Receive RDMA + if (warp_id == 0) { + // lane_id -> src_rdma_rank + if (lane_id < kNumRDMARanks) { + for (int j = 0; j < num_loop_stage; ++j) { + for (int i = 0; i < num_channels; ++i) { + recv_rdma_channel_prefix_matrix[(j * kNumRDMARanks + lane_id) * + num_channels + + i] = + -rdma_channel_meta.recv_buffer(lane_id) + [j * (num_channels + num_channels * NUM_MAX_NVL_PEERS) + i] - + 1; + } + } + } + + } else if (warp_id < NUM_MAX_NVL_PEERS + 1) { + // lane_id -> src_rdma_rank + // warp_id - 1 -> dst_nvl_rank + if (lane_id < kNumRDMARanks) { + for (int j = 0; j < num_loop_stage; ++j) { + auto recv_ptr = rdma_channel_meta.recv_buffer(lane_id) + + j * (num_channels + num_channels * NUM_MAX_NVL_PEERS) + + num_channels + (warp_id - 1) * num_channels; + for (int i = 0; i < num_channels; ++i) { + st_relaxed_sys_global( + nvl_send_channel_meta.buffer(nvl_rank) + + (j * kNumRDMARanks + lane_id) * num_channels + i, + recv_ptr[i]); + } + } + } + } + memory_fence(); + __syncthreads(); + barrier_device(task_fifo_ptrs, head, nvl_rank); + move_fifo_slots(head); + __syncthreads(); + + EP_DEVICE_ASSERT(kNumRDMARanks * NUM_MAX_NVL_PEERS <= num_threads); + if (thread_id < kNumRDMARanks * NUM_MAX_NVL_PEERS) { + const auto src_rdma_rank = thread_id / NUM_MAX_NVL_PEERS; + const auto src_nvl_rank = thread_id % NUM_MAX_NVL_PEERS; + for (int j = 0; j < num_loop_stage; ++j) { + int rank_offset = + thread_id > 0 + ? 
recv_gbl_rank_prefix_sum[j * kNumRDMARanks * NUM_MAX_NVL_PEERS + + thread_id - 1] + : 0; +#pragma unroll + for (int i = 0; i < num_channels; ++i) { + recv_gbl_channel_prefix_matrix[(j * kNumRDMARanks * NUM_MAX_NVL_PEERS + + thread_id) * + num_channels + + i] = + rank_offset - + nvl_recv_channel_meta.buffer(src_nvl_rank) + [(j * kNumRDMARanks + src_rdma_rank) * num_channels + i] - + 1; + } + } + } + + // Finally barrier + if (thread_id == 32) + nvshmem_barrier_with_same_gpu_idx(rdma_team); + barrier_device(task_fifo_ptrs, head, nvl_rank); + move_fifo_slots(head); +} + +void fused_notify_combine_post_step(int num_ranks, + int num_channels, + int num_loop_stage, + const int* recv_gbl_rank_prefix_sum, + const int* rdma_channel_prefix_matrix, + const int* gbl_channel_prefix_matrix, + int* recv_rdma_channel_prefix_matrix, + int* recv_gbl_channel_prefix_matrix, + void* rdma_buffer_ptr, + void** buffer_ptrs, + int** task_fifo_ptrs, + int head, + int rank, + cudaStream_t stream, + bool low_latency_mode) { +#define NOTIFY_COMBINE_S1_LAUNCH_CASE(num_rdma_ranks) \ + { \ + auto notify_combine_post_step_func = \ + low_latency_mode \ + ? fused_notify_combine_post_step \ + : fused_notify_combine_post_step; \ + LAUNCH_KERNEL(&cfg, \ + notify_combine_post_step_func, \ + num_channels, \ + num_loop_stage, \ + recv_gbl_rank_prefix_sum, \ + rdma_channel_prefix_matrix, \ + gbl_channel_prefix_matrix, \ + recv_rdma_channel_prefix_matrix, \ + recv_gbl_channel_prefix_matrix, \ + rdma_buffer_ptr, \ + buffer_ptrs, \ + task_fifo_ptrs, \ + head, \ + rank, \ + cpu_rdma_team); \ + } \ + break + + constexpr int kNumThreads = 512; + + // Launch kernel + SETUP_LAUNCH_CONFIG(1, kNumThreads, stream); + SWITCH_RDMA_RANKS(NOTIFY_COMBINE_S1_LAUNCH_CASE); +#undef NOTIFY_DISPATCH_LAUNCH_CASE +} + // At most 8 RDMA ranks to be sent constexpr int get_num_topk_rdma_ranks(int num_rdma_ranks) { return num_rdma_ranks < 8 ? num_rdma_ranks : 8; @@ -2660,7 +3251,7 @@ template < int kNumWarpsPerForwarder = (kNumCombineForwarderWarps / kNumRDMARanks > 0) ? kNumCombineForwarderWarps / kNumRDMARanks : 1, - int kNumForwarders = kNumRDMARanks* kNumWarpsPerForwarder, + int kNumForwarders = kNumRDMARanks * kNumWarpsPerForwarder, int kNumRDMAReceivers = kNumForwarders + NUM_MAX_NVL_PEERS> __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32, 1) diff --git a/paddle/fluid/pybind/deep_ep_api.cc b/paddle/fluid/pybind/deep_ep_api.cc index 001d134ff245d1..ae79ff62a3878f 100644 --- a/paddle/fluid/pybind/deep_ep_api.cc +++ b/paddle/fluid/pybind/deep_ep_api.cc @@ -100,6 +100,8 @@ void BindDeepEPApi(pybind11::module *m) { .def("clear_buffer", &deep_ep::Buffer::clear_buffer_api) .def("internode_notify_combine", &deep_ep::Buffer::internode_notify_combine_api) + .def("internode_fused_notify_combine_api", + &deep_ep::Buffer::internode_fused_notify_combine_api) .def("internode_combine", &deep_ep::Buffer::internode_combine_api) .def("barrier_all", &deep_ep::Buffer::barrier_all) .def("clean_low_latency_buffer", diff --git a/python/paddle/distributed/communication/deep_ep/buffer.py b/python/paddle/distributed/communication/deep_ep/buffer.py index c04eb9c1afaebd..5ae7f358e529b9 100644 --- a/python/paddle/distributed/communication/deep_ep/buffer.py +++ b/python/paddle/distributed/communication/deep_ep/buffer.py @@ -65,6 +65,7 @@ def __init__( num_rdma_bytes: int = 0, low_latency_mode: bool = False, num_qps_per_rank: int = 12, + # num_loop_stages: int = 1, ) -> None: """ Initialize the communication buffer. 
@@ -957,6 +958,63 @@ def internode_notify_combine( send_nvl_head, ) + def internode_fused_notify_combine( + self, + x: paddle.Tensor | tuple[paddle.Tensor, paddle.Tensor], + topk_idx: paddle.Tensor | None = None, + num_tokens_per_rank: paddle.Tensor | None = None, + num_tokens_per_rdma_rank: paddle.Tensor | None = None, + is_token_in_rank: paddle.Tensor | None = None, + num_loop_stages: int = 1, + config: Config | None = None, + ) -> tuple[ + list[int], + list[int], + paddle.Tensor, + paddle.Tensor, + paddle.Tensor, + paddle.Tensor, + paddle.Tensor, + ]: + # Default config + config = ( + self.get_dispatch_config(self.group_size) + if config is None + else config + ) + # Launch the kernel with cached or non-cached mode + x, x_scales = x if isinstance(x, tuple) else (x, None) + assert num_tokens_per_rank is not None and is_token_in_rank is not None + + ( + num_recv_tokens, + num_rdma_recv_tokens, + recv_rdma_rank_prefix_sum, + recv_rdma_channel_prefix_matrix, + recv_gbl_channel_prefix_matrix, + send_rdma_head, + send_nvl_head, + ) = self.runtime.internode_fused_notify_combine( + x, + x_scales, + topk_idx, + num_tokens_per_rank, + num_tokens_per_rdma_rank, + is_token_in_rank, + num_loop_stages, + config, + ) + + return ( + num_recv_tokens, + num_rdma_recv_tokens, + recv_rdma_rank_prefix_sum, + recv_rdma_channel_prefix_matrix, + recv_gbl_channel_prefix_matrix, + send_rdma_head, + send_nvl_head, + ) + # noinspection PyTypeChecker def internode_combine( self, From 3b4bbdb993e580f79a7de014cb815c2a55c351d6 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Mon, 24 Nov 2025 11:01:39 +0800 Subject: [PATCH 2/4] add fused_nodtify_combine --- .../collective/deep_ep/deep_ep.cpp | 60 ++++---- .../collective/deep_ep/deep_ep.hpp | 1 + .../collective/deep_ep/kernels/internode.cu | 144 ++++++++++++++++-- paddle/fluid/pybind/deep_ep_api.cc | 4 +- .../communication/deep_ep/buffer.py | 3 +- 5 files changed, 171 insertions(+), 41 deletions(-) diff --git a/paddle/fluid/distributed/collective/deep_ep/deep_ep.cpp b/paddle/fluid/distributed/collective/deep_ep/deep_ep.cpp index 5aadc01ec8df8b..f18db8468fbfaa 100644 --- a/paddle/fluid/distributed/collective/deep_ep/deep_ep.cpp +++ b/paddle/fluid/distributed/collective/deep_ep/deep_ep.cpp @@ -49,6 +49,7 @@ void SetAllocatorStreamForGPUContext(cudaStream_t stream, Buffer::Buffer(int rank, int num_ranks, + int num_loop_stage, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode, @@ -131,8 +132,9 @@ Buffer::Buffer(int rank, CUDA_CHECK(cudaMemsetAsync(workspace, 0, NUM_WORKSPACE_BYTES, comm_stream)); // MoE counter - CUDA_CHECK(cudaMallocHost( - &moe_recv_counter, sizeof(int64_t) * 3, cudaHostAllocMapped)); + CUDA_CHECK(cudaMallocHost(&moe_recv_counter, + sizeof(int64_t) * num_loop_stage, + cudaHostAllocMapped)); CUDA_CHECK(cudaHostGetDevicePointer( &moe_recv_counter_mapped, const_cast(moe_recv_counter), 0)); *moe_recv_counter = -1; @@ -149,8 +151,9 @@ Buffer::Buffer(int rank, // MoE RDMA-level counter if (num_rdma_ranks > 0) { - CUDA_CHECK(cudaMallocHost( - &moe_recv_rdma_counter, sizeof(int) * 3, cudaHostAllocMapped)); + CUDA_CHECK(cudaMallocHost(&moe_recv_rdma_counter, + sizeof(int) * num_loop_stage, + cudaHostAllocMapped)); CUDA_CHECK(cudaHostGetDevicePointer(&moe_recv_rdma_counter_mapped, const_cast(moe_recv_rdma_counter), 0)); @@ -1909,12 +1912,12 @@ Buffer::internode_fused_notify_combine( // Shape and contiguous checks EP_HOST_ASSERT(x.dim() == 2 && x.is_contiguous()); EP_HOST_ASSERT((x.size(1) * x.element_size()) % sizeof(int4) == 0); - 
EP_HOST_ASSERT(num_tokens_per_rank->dim() == 1 && + EP_HOST_ASSERT(num_tokens_per_rank->dim() == 2 && num_tokens_per_rank->is_contiguous()); - EP_HOST_ASSERT(num_tokens_per_rdma_rank->dim() == 1 && + EP_HOST_ASSERT(num_tokens_per_rdma_rank->dim() == 2 && num_tokens_per_rdma_rank->is_contiguous()); - EP_HOST_ASSERT(num_tokens_per_rank->size(0) == num_ranks); - EP_HOST_ASSERT(num_tokens_per_rdma_rank->size(0) == num_rdma_ranks); + EP_HOST_ASSERT(num_tokens_per_rank->size(0) == num_loop_stage); + EP_HOST_ASSERT(num_tokens_per_rdma_rank->size(0) == num_loop_stage); int num_scales = 0; if (x_scales.has_value()) { @@ -1939,37 +1942,42 @@ Buffer::internode_fused_notify_combine( auto compute_stream = calc_ctx->stream(); stream_wait(comm_stream, compute_stream); - auto rdma_channel_prefix_matrix = ConvertPaddleTensorToDetailTensor( - paddle::experimental::empty({num_rdma_ranks, num_channels}, + auto rdma_channel_prefix_matrix = + ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( + {num_loop_stage, num_rdma_ranks, num_channels}, + phi::DataType::INT32, + phi::GPUPlace(device_id))); + auto recv_rdma_rank_prefix_sum = ConvertPaddleTensorToDetailTensor( + paddle::experimental::empty({num_loop_stage, num_rdma_ranks}, phi::DataType::INT32, phi::GPUPlace(device_id))); - auto recv_rdma_rank_prefix_sum = - ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( - {num_rdma_ranks}, phi::DataType::INT32, phi::GPUPlace(device_id))); auto gbl_channel_prefix_matrix = ConvertPaddleTensorToDetailTensor( - paddle::experimental::empty({num_ranks, num_channels}, + paddle::experimental::empty({num_loop_stage, num_ranks, num_channels}, phi::DataType::INT32, phi::GPUPlace(device_id))); - auto recv_gbl_rank_prefix_sum = - ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( - {num_ranks}, phi::DataType::INT32, phi::GPUPlace(device_id))); - - auto recv_rdma_channel_prefix_matrix = ConvertPaddleTensorToDetailTensor( - paddle::experimental::empty({num_rdma_ranks, num_channels}, + auto recv_gbl_rank_prefix_sum = ConvertPaddleTensorToDetailTensor( + paddle::experimental::empty({num_loop_stage, num_ranks}, phi::DataType::INT32, phi::GPUPlace(device_id))); + + auto recv_rdma_channel_prefix_matrix = + ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( + {num_loop_stage, num_rdma_ranks, num_channels}, + phi::DataType::INT32, + phi::GPUPlace(device_id))); auto recv_gbl_channel_prefix_matrix = ConvertPaddleTensorToDetailTensor( - paddle::experimental::empty({num_ranks, num_channels}, + paddle::experimental::empty({num_loop_stage, num_ranks, num_channels}, phi::DataType::INT32, phi::GPUPlace(device_id))); - auto send_rdma_head = ConvertPaddleTensorToDetailTensor( - paddle::experimental::empty({num_tokens, num_ranks / NUM_MAX_NVL_PEERS}, - phi::DataType::INT32, - phi::GPUPlace(device_id))); + auto send_rdma_head = + ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( + {num_loop_stage, num_tokens, num_ranks / NUM_MAX_NVL_PEERS}, + phi::DataType::INT32, + phi::GPUPlace(device_id))); auto send_nvl_head = ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( - {num_tokens, num_ranks / NUM_MAX_NVL_PEERS, 8}, + {num_loop_stage, num_tokens, num_ranks / NUM_MAX_NVL_PEERS, 8}, phi::DataType::INT32, phi::GPUPlace(device_id))); diff --git a/paddle/fluid/distributed/collective/deep_ep/deep_ep.hpp b/paddle/fluid/distributed/collective/deep_ep/deep_ep.hpp index b3a0c32a0e2cee..3b84c4b8bc6726 100644 --- a/paddle/fluid/distributed/collective/deep_ep/deep_ep.hpp +++ 
b/paddle/fluid/distributed/collective/deep_ep/deep_ep.hpp @@ -103,6 +103,7 @@ struct Buffer { public: Buffer(int rank, int num_ranks, + int num_loop_stage, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode, diff --git a/paddle/fluid/distributed/collective/deep_ep/kernels/internode.cu b/paddle/fluid/distributed/collective/deep_ep/kernels/internode.cu index 50dc911bb10297..7f089ff9c61743 100644 --- a/paddle/fluid/distributed/collective/deep_ep/kernels/internode.cu +++ b/paddle/fluid/distributed/collective/deep_ep/kernels/internode.cu @@ -1134,7 +1134,6 @@ __global__ void fused_notify_combine( // Clean up for later data dispatch EP_DEVICE_ASSERT(rdma_recv_num_tokens_mixed.total_bytes <= rdma_clean_offset * sizeof(int)); - #pragma unroll for (int i = thread_id; i < rdma_num_int_clean; i += num_threads) rdma_buffer_ptr_int[rdma_clean_offset + i] = 0; @@ -1146,16 +1145,32 @@ __global__ void fused_notify_combine( rdma_recv_num_tokens_mixed.send_buffer( i / NUM_MAX_NVL_PEERS)[j * (NUM_MAX_NVL_PEERS + 1) + (i % NUM_MAX_NVL_PEERS)] = - num_tokens_per_rank[(j * 2 + 1) * num_ranks + i]; + num_tokens_per_rank[j * num_ranks + i]; + printf( + "# SM%d Send Rdma rank: %d -> %d, stage:%d, num_tokens_per_rank: " + "%d\n", + sm_id, + rank, + thread_id, + j, + num_tokens_per_rank[j * num_ranks + i]); } } -#pragma unroll if (thread_id < kNumRDMARanks) { +#pragma unroll for (int j = 0; j < num_loop_stage; ++j) { rdma_recv_num_tokens_mixed.send_buffer( thread_id)[j * (NUM_MAX_NVL_PEERS + 1) + NUM_MAX_NVL_PEERS] = - num_tokens_per_rdma_rank[j * num_ranks + thread_id]; + num_tokens_per_rdma_rank[j * kNumRDMARanks + thread_id]; + printf( + "# SM%d Send Rdma rank: %d -> %d, stage:%d, " + "num_tokens_per_rdma_rank: %d\n", + sm_id, + rank, + thread_id, + j, + num_tokens_per_rdma_rank[j * kNumRDMARanks + thread_id]); } } @@ -1163,6 +1178,13 @@ __global__ void fused_notify_combine( // Issue send if (thread_id < kNumRDMARanks) { + printf( + "# SM%d rank: %d, thread_id: %d, kNumRDMARanks: %d " + "nvshmem_int_put_nbi\n", + sm_id, + rank, + thread_id, + kNumRDMARanks); nvshmem_int_put_nbi( rdma_recv_num_tokens_mixed.recv_buffer(rdma_rank), rdma_recv_num_tokens_mixed.send_buffer(thread_id), @@ -1173,7 +1195,19 @@ __global__ void fused_notify_combine( // Barrier if (thread_id == 0) { + printf( + "#SM%d rank: %d, thread_id: %d, Before " + "nvshmem_barrier_with_same_gpu_idx\n", + sm_id, + rank, + thread_id); nvshmem_barrier_with_same_gpu_idx(rdma_team); + printf( + "#SM%d rank: %d, thread_id: %d, After " + "nvshmem_barrier_with_same_gpu_idx\n", + sm_id, + rank, + thread_id); } __syncthreads(); @@ -1204,19 +1238,31 @@ __global__ void fused_notify_combine( sum += rdma_recv_num_tokens_mixed.recv_buffer( i)[thread_id * (NUM_MAX_NVL_PEERS + 1) + NUM_MAX_NVL_PEERS]; recv_rdma_rank_prefix_sum[thread_id * kNumRDMARanks + i] = sum; + printf( + "####### SM0 Rdma Recv rank: %d <- %d, stage: %d " + "moe_recv_rdma_counter_mapped %d\n", + rank, + i, + thread_id, + sum); } while (ld_volatile_global(moe_recv_rdma_counter_mapped + thread_id) != -1) { } moe_recv_rdma_counter_mapped[thread_id] = sum; + printf( + "####### rank: %d, thread_id: %d moe_recv_rdma_counter_mapped %d\n", + rank, + thread_id, + sum); } // Send numbers of tokens per rank/expert to NVL ranks EP_DEVICE_ASSERT(NUM_MAX_NVL_PEERS <= num_threads); if (thread_id < NUM_MAX_NVL_PEERS) { + for (int j = 0; j < num_loop_stage; ++j) { #pragma unroll - for (int i = 0; i < kNumRDMARanks; ++i) { - for (int j = 0; j < num_loop_stage; ++j) { + for (int i = 0; i < 
kNumRDMARanks; ++i) { nvl_send_num_tokens_per_rank.buffer(nvl_rank)[j * kNumRDMARanks + i] = rdma_recv_num_tokens_mixed.recv_buffer( i)[j * (NUM_MAX_NVL_PEERS + 1) + thread_id]; @@ -1231,6 +1277,9 @@ __global__ void fused_notify_combine( __syncthreads(); if (thread_id < num_loop_stage) { + printf("####### rank: %d, thread_id: %d recv_gbl_rank_prefix_sum\n", + rank, + thread_id); int sum = 0; #pragma unroll for (int i = 0; i < num_ranks; ++i) { @@ -1243,6 +1292,10 @@ __global__ void fused_notify_combine( while (ld_volatile_global(moe_recv_counter_mapped + thread_id) != -1) { } moe_recv_counter_mapped[thread_id] = sum; + printf("####### rank: %d, thread_id: %d moe_recv_counter_mapped %d\n", + rank, + thread_id, + sum); } // Finally barrier @@ -1250,15 +1303,25 @@ __global__ void fused_notify_combine( nvshmem_barrier_with_same_gpu_idx(rdma_team); barrier_device(task_fifo_ptrs, head, nvl_rank); move_fifo_slots(head); + printf( + "####### SM: %d, thread_id: %d, Send Recv Finish\n", sm_id, thread_id); } else { // Calculate meta data - int stage_id = (sm_id - 1) / num_loop_stage; - int dst_rdma_rank = (sm_id - 1) % num_loop_stage; + int stage_id = (sm_id - 1) / kNumRDMARanks; + int dst_rdma_rank = (sm_id - 1) % kNumRDMARanks; for (int channel_id = warp_id; channel_id < num_channels; channel_id += num_warps) { int token_start_idx, token_end_idx; get_channel_task_range( num_tokens, num_channels, channel_id, token_start_idx, token_end_idx); + if (lane_id == 0) + printf( + "####### SM: %d, thread_id: %d, channel_id: %d, token[%d-%d], \n", + sm_id, + thread_id, + channel_id, + token_start_idx, + token_end_idx); // Iterate over tokens int total_count = 0, per_nvl_rank_count[NUM_MAX_NVL_PEERS] = {0}; @@ -1277,6 +1340,11 @@ __global__ void fused_notify_combine( // Calculate RDMA tail index for combine auto warp_valid_tokens = std::min(token_end_idx - (i - lane_id), 32); + // if (warp_valid_tokens < 32) { + // printf("####### SM: %d, channel: %d, lane_id: %d, token(%d-%d)[%d], + // warp_valid_tokens: %d \n", sm_id, channel_id, lane_id, + // token_start_idx, token_end_idx, i, warp_valid_tokens); + // } unsigned int mask = 0xffffffff >> (32 - warp_valid_tokens); int warp_rdma_tail_idx = (is_token_in_rank_uint64 != 0); global_rdma_tail_idx += warp_scan(warp_rdma_tail_idx, mask); @@ -1286,6 +1354,10 @@ __global__ void fused_notify_combine( dst_rdma_rank] = rdma_tail_idx; global_rdma_tail_idx = __shfl_sync(mask, global_rdma_tail_idx, warp_valid_tokens - 1); + // if (lane_id == 0) + // printf("####### SM: %d, channel: %d, lane_id: %d, token(%d-%d)[%d], + // global_rdma_tail_idx: %d, \n", sm_id, channel_id, lane_id, + // token_start_idx, token_end_idx, i, global_rdma_tail_idx); #pragma unroll for (int j = 0; j < NUM_MAX_NVL_PEERS; ++j) { @@ -1300,22 +1372,62 @@ __global__ void fused_notify_combine( global_nvl_tail_idx[j] = __shfl_sync(mask, global_nvl_tail_idx[j], warp_valid_tokens - 1); } + if (lane_id == 0) + printf( + "####### SM: %d, channel: %d, lane_id: %d, token(%d-%d)[%d], " + "global_nvl_tail_idx: %d\n", + sm_id, + channel_id, + lane_id, + token_start_idx, + token_end_idx, + i, + global_nvl_tail_idx[0]); } // Warp reduce total_count = warp_reduce_sum(total_count); + if (lane_id == 0) + printf("####### SM: %d, stage: %d, [%d -> %d] channel[%d]=%d\n", + sm_id, + stage_id, + rank, + dst_rdma_rank, + channel_id, + total_count); #pragma unroll for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i) per_nvl_rank_count[i] = warp_reduce_sum(per_nvl_rank_count[i]); // Write into channel matrix if (lane_id == 0) { + // 
printf("####### sm: %d warp_id: %d, compute + // gbl_channel_prefix_matrix\n", sm_id, warp_id); #pragma unroll for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i) - gbl_channel_prefix_matrix[(stage_id * num_ranks + - dst_rdma_rank * NUM_MAX_NVL_PEERS + i) * - num_channels + - channel_id] = per_nvl_rank_count[i]; + printf( + "####### sm: %d warp_id: %d, compute " + "gbl_channel_prefix_matrix[%d][%d] per_nvl_rank_count[%d]=%d\n", + sm_id, + warp_id, + (stage_id * num_ranks + dst_rdma_rank * NUM_MAX_NVL_PEERS + i), + channel_id, + i, + per_nvl_rank_count[i]); + // gbl_channel_prefix_matrix[(stage_id * num_ranks + + // dst_rdma_rank * NUM_MAX_NVL_PEERS + i) * + // num_channels + + // channel_id] = per_nvl_rank_count[i]; + printf( + "####### sm: %d warp_id: %d, dst_rdma_rank: %d, stage_id: %d, " + "compute rdma_channel_prefix_matrix[%d][%d]=%d\n", + sm_id, + warp_id, + dst_rdma_rank, + stage_id, + (stage_id * kNumRDMARanks + dst_rdma_rank), + channel_id, + total_count); rdma_channel_prefix_matrix[(stage_id * kNumRDMARanks + dst_rdma_rank) * num_channels + channel_id] = total_count; @@ -1325,6 +1437,13 @@ __global__ void fused_notify_combine( // Calculate prefix sum __syncthreads(); if (thread_id == 0) { + printf( + "####### rank: %d, sm: %d warp_id: %d, thread_id: %d, " + "rdma_channel_prefix_matrix\n", + rank, + sm_id, + warp_id, + thread_id); auto prefix_row = rdma_channel_prefix_matrix + (stage_id * kNumRDMARanks + dst_rdma_rank) * num_channels; @@ -1345,6 +1464,7 @@ __global__ void fused_notify_combine( prefix_row[i] += prefix_row[i - 1]; } } + printf("####### SM: %d, thread_id: %d, Compute Finish\n", sm_id, thread_id); } } diff --git a/paddle/fluid/pybind/deep_ep_api.cc b/paddle/fluid/pybind/deep_ep_api.cc index ae79ff62a3878f..3dddc79940947c 100644 --- a/paddle/fluid/pybind/deep_ep_api.cc +++ b/paddle/fluid/pybind/deep_ep_api.cc @@ -61,7 +61,7 @@ void BindDeepEPApi(pybind11::module *m) { &deep_ep::GetEventHandleFromCustomStream); pybind11::class_(*m, "Buffer") - .def(pybind11::init()) + .def(pybind11::init()) .def("is_available", &deep_ep::Buffer::is_available) .def("get_num_rdma_ranks", &deep_ep::Buffer::get_num_rdma_ranks) .def("get_rdma_rank", &deep_ep::Buffer::get_rdma_rank) @@ -100,7 +100,7 @@ void BindDeepEPApi(pybind11::module *m) { .def("clear_buffer", &deep_ep::Buffer::clear_buffer_api) .def("internode_notify_combine", &deep_ep::Buffer::internode_notify_combine_api) - .def("internode_fused_notify_combine_api", + .def("internode_fused_notify_combine", &deep_ep::Buffer::internode_fused_notify_combine_api) .def("internode_combine", &deep_ep::Buffer::internode_combine_api) .def("barrier_all", &deep_ep::Buffer::barrier_all) diff --git a/python/paddle/distributed/communication/deep_ep/buffer.py b/python/paddle/distributed/communication/deep_ep/buffer.py index 5ae7f358e529b9..dbd77b35c89339 100644 --- a/python/paddle/distributed/communication/deep_ep/buffer.py +++ b/python/paddle/distributed/communication/deep_ep/buffer.py @@ -65,7 +65,7 @@ def __init__( num_rdma_bytes: int = 0, low_latency_mode: bool = False, num_qps_per_rank: int = 12, - # num_loop_stages: int = 1, + num_loop_stages: int = 3, ) -> None: """ Initialize the communication buffer. 
@@ -89,6 +89,7 @@ def __init__( self.runtime = CppBuffer( self.rank, self.group_size, + num_loop_stages, num_nvl_bytes, num_rdma_bytes, low_latency_mode, From 6ba074a5299f45ba2bfd729f416240c2d22564d5 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Mon, 24 Nov 2025 12:16:09 +0800 Subject: [PATCH 3/4] polish code --- .../collective/deep_ep/kernels/internode.cu | 129 +----------------- 1 file changed, 4 insertions(+), 125 deletions(-) diff --git a/paddle/fluid/distributed/collective/deep_ep/kernels/internode.cu b/paddle/fluid/distributed/collective/deep_ep/kernels/internode.cu index 7f089ff9c61743..e867794129f05a 100644 --- a/paddle/fluid/distributed/collective/deep_ep/kernels/internode.cu +++ b/paddle/fluid/distributed/collective/deep_ep/kernels/internode.cu @@ -1146,14 +1146,6 @@ __global__ void fused_notify_combine( i / NUM_MAX_NVL_PEERS)[j * (NUM_MAX_NVL_PEERS + 1) + (i % NUM_MAX_NVL_PEERS)] = num_tokens_per_rank[j * num_ranks + i]; - printf( - "# SM%d Send Rdma rank: %d -> %d, stage:%d, num_tokens_per_rank: " - "%d\n", - sm_id, - rank, - thread_id, - j, - num_tokens_per_rank[j * num_ranks + i]); } } @@ -1163,14 +1155,6 @@ __global__ void fused_notify_combine( rdma_recv_num_tokens_mixed.send_buffer( thread_id)[j * (NUM_MAX_NVL_PEERS + 1) + NUM_MAX_NVL_PEERS] = num_tokens_per_rdma_rank[j * kNumRDMARanks + thread_id]; - printf( - "# SM%d Send Rdma rank: %d -> %d, stage:%d, " - "num_tokens_per_rdma_rank: %d\n", - sm_id, - rank, - thread_id, - j, - num_tokens_per_rdma_rank[j * kNumRDMARanks + thread_id]); } } @@ -1178,13 +1162,6 @@ __global__ void fused_notify_combine( // Issue send if (thread_id < kNumRDMARanks) { - printf( - "# SM%d rank: %d, thread_id: %d, kNumRDMARanks: %d " - "nvshmem_int_put_nbi\n", - sm_id, - rank, - thread_id, - kNumRDMARanks); nvshmem_int_put_nbi( rdma_recv_num_tokens_mixed.recv_buffer(rdma_rank), rdma_recv_num_tokens_mixed.send_buffer(thread_id), @@ -1195,19 +1172,7 @@ __global__ void fused_notify_combine( // Barrier if (thread_id == 0) { - printf( - "#SM%d rank: %d, thread_id: %d, Before " - "nvshmem_barrier_with_same_gpu_idx\n", - sm_id, - rank, - thread_id); nvshmem_barrier_with_same_gpu_idx(rdma_team); - printf( - "#SM%d rank: %d, thread_id: %d, After " - "nvshmem_barrier_with_same_gpu_idx\n", - sm_id, - rank, - thread_id); } __syncthreads(); @@ -1238,23 +1203,11 @@ __global__ void fused_notify_combine( sum += rdma_recv_num_tokens_mixed.recv_buffer( i)[thread_id * (NUM_MAX_NVL_PEERS + 1) + NUM_MAX_NVL_PEERS]; recv_rdma_rank_prefix_sum[thread_id * kNumRDMARanks + i] = sum; - printf( - "####### SM0 Rdma Recv rank: %d <- %d, stage: %d " - "moe_recv_rdma_counter_mapped %d\n", - rank, - i, - thread_id, - sum); } while (ld_volatile_global(moe_recv_rdma_counter_mapped + thread_id) != -1) { } moe_recv_rdma_counter_mapped[thread_id] = sum; - printf( - "####### rank: %d, thread_id: %d moe_recv_rdma_counter_mapped %d\n", - rank, - thread_id, - sum); } // Send numbers of tokens per rank/expert to NVL ranks @@ -1277,9 +1230,6 @@ __global__ void fused_notify_combine( __syncthreads(); if (thread_id < num_loop_stage) { - printf("####### rank: %d, thread_id: %d recv_gbl_rank_prefix_sum\n", - rank, - thread_id); int sum = 0; #pragma unroll for (int i = 0; i < num_ranks; ++i) { @@ -1292,10 +1242,6 @@ __global__ void fused_notify_combine( while (ld_volatile_global(moe_recv_counter_mapped + thread_id) != -1) { } moe_recv_counter_mapped[thread_id] = sum; - printf("####### rank: %d, thread_id: %d moe_recv_counter_mapped %d\n", - rank, - thread_id, - sum); } // Finally barrier @@ -1303,8 
+1249,6 @@ __global__ void fused_notify_combine( nvshmem_barrier_with_same_gpu_idx(rdma_team); barrier_device(task_fifo_ptrs, head, nvl_rank); move_fifo_slots(head); - printf( - "####### SM: %d, thread_id: %d, Send Recv Finish\n", sm_id, thread_id); } else { // Calculate meta data int stage_id = (sm_id - 1) / kNumRDMARanks; @@ -1314,14 +1258,6 @@ __global__ void fused_notify_combine( int token_start_idx, token_end_idx; get_channel_task_range( num_tokens, num_channels, channel_id, token_start_idx, token_end_idx); - if (lane_id == 0) - printf( - "####### SM: %d, thread_id: %d, channel_id: %d, token[%d-%d], \n", - sm_id, - thread_id, - channel_id, - token_start_idx, - token_end_idx); // Iterate over tokens int total_count = 0, per_nvl_rank_count[NUM_MAX_NVL_PEERS] = {0}; @@ -1340,11 +1276,6 @@ __global__ void fused_notify_combine( // Calculate RDMA tail index for combine auto warp_valid_tokens = std::min(token_end_idx - (i - lane_id), 32); - // if (warp_valid_tokens < 32) { - // printf("####### SM: %d, channel: %d, lane_id: %d, token(%d-%d)[%d], - // warp_valid_tokens: %d \n", sm_id, channel_id, lane_id, - // token_start_idx, token_end_idx, i, warp_valid_tokens); - // } unsigned int mask = 0xffffffff >> (32 - warp_valid_tokens); int warp_rdma_tail_idx = (is_token_in_rank_uint64 != 0); global_rdma_tail_idx += warp_scan(warp_rdma_tail_idx, mask); @@ -1354,10 +1285,6 @@ __global__ void fused_notify_combine( dst_rdma_rank] = rdma_tail_idx; global_rdma_tail_idx = __shfl_sync(mask, global_rdma_tail_idx, warp_valid_tokens - 1); - // if (lane_id == 0) - // printf("####### SM: %d, channel: %d, lane_id: %d, token(%d-%d)[%d], - // global_rdma_tail_idx: %d, \n", sm_id, channel_id, lane_id, - // token_start_idx, token_end_idx, i, global_rdma_tail_idx); #pragma unroll for (int j = 0; j < NUM_MAX_NVL_PEERS; ++j) { @@ -1372,62 +1299,22 @@ __global__ void fused_notify_combine( global_nvl_tail_idx[j] = __shfl_sync(mask, global_nvl_tail_idx[j], warp_valid_tokens - 1); } - if (lane_id == 0) - printf( - "####### SM: %d, channel: %d, lane_id: %d, token(%d-%d)[%d], " - "global_nvl_tail_idx: %d\n", - sm_id, - channel_id, - lane_id, - token_start_idx, - token_end_idx, - i, - global_nvl_tail_idx[0]); } // Warp reduce total_count = warp_reduce_sum(total_count); - if (lane_id == 0) - printf("####### SM: %d, stage: %d, [%d -> %d] channel[%d]=%d\n", - sm_id, - stage_id, - rank, - dst_rdma_rank, - channel_id, - total_count); #pragma unroll for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i) per_nvl_rank_count[i] = warp_reduce_sum(per_nvl_rank_count[i]); // Write into channel matrix if (lane_id == 0) { - // printf("####### sm: %d warp_id: %d, compute - // gbl_channel_prefix_matrix\n", sm_id, warp_id); #pragma unroll for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i) - printf( - "####### sm: %d warp_id: %d, compute " - "gbl_channel_prefix_matrix[%d][%d] per_nvl_rank_count[%d]=%d\n", - sm_id, - warp_id, - (stage_id * num_ranks + dst_rdma_rank * NUM_MAX_NVL_PEERS + i), - channel_id, - i, - per_nvl_rank_count[i]); - // gbl_channel_prefix_matrix[(stage_id * num_ranks + - // dst_rdma_rank * NUM_MAX_NVL_PEERS + i) * - // num_channels + - // channel_id] = per_nvl_rank_count[i]; - printf( - "####### sm: %d warp_id: %d, dst_rdma_rank: %d, stage_id: %d, " - "compute rdma_channel_prefix_matrix[%d][%d]=%d\n", - sm_id, - warp_id, - dst_rdma_rank, - stage_id, - (stage_id * kNumRDMARanks + dst_rdma_rank), - channel_id, - total_count); + gbl_channel_prefix_matrix[(stage_id * num_ranks + + dst_rdma_rank * NUM_MAX_NVL_PEERS + i) * + num_channels + + 
channel_id] = per_nvl_rank_count[i]; rdma_channel_prefix_matrix[(stage_id * kNumRDMARanks + dst_rdma_rank) * num_channels + channel_id] = total_count; @@ -1437,13 +1324,6 @@ __global__ void fused_notify_combine( // Calculate prefix sum __syncthreads(); if (thread_id == 0) { - printf( - "####### rank: %d, sm: %d warp_id: %d, thread_id: %d, " - "rdma_channel_prefix_matrix\n", - rank, - sm_id, - warp_id, - thread_id); auto prefix_row = rdma_channel_prefix_matrix + (stage_id * kNumRDMARanks + dst_rdma_rank) * num_channels; @@ -1464,7 +1344,6 @@ __global__ void fused_notify_combine( prefix_row[i] += prefix_row[i - 1]; } } - printf("####### SM: %d, thread_id: %d, Compute Finish\n", sm_id, thread_id); } } From 5f9a5377bcbd8df990ffa52e21b9180ce1f3bbd0 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Mon, 24 Nov 2025 19:07:06 +0800 Subject: [PATCH 4/4] add fused_notify_dispatch --- .../collective/deep_ep/deep_ep.cpp | 286 ++++++++++++- .../collective/deep_ep/deep_ep.hpp | 38 ++ .../collective/deep_ep/kernels/api.cuh | 32 ++ .../collective/deep_ep/kernels/internode.cu | 399 ++++++++++++++++++ paddle/fluid/pybind/deep_ep_api.cc | 2 + .../communication/deep_ep/buffer.py | 74 ++++ 6 files changed, 823 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/distributed/collective/deep_ep/deep_ep.cpp b/paddle/fluid/distributed/collective/deep_ep/deep_ep.cpp index f18db8468fbfaa..671961ce88a50c 100644 --- a/paddle/fluid/distributed/collective/deep_ep/deep_ep.cpp +++ b/paddle/fluid/distributed/collective/deep_ep/deep_ep.cpp @@ -140,9 +140,10 @@ Buffer::Buffer(int rank, *moe_recv_counter = -1; // MoE expert-level counter - CUDA_CHECK(cudaMallocHost(&moe_recv_expert_counter, - sizeof(int) * NUM_MAX_LOCAL_EXPERTS, - cudaHostAllocMapped)); + CUDA_CHECK( + cudaMallocHost(&moe_recv_expert_counter, + sizeof(int) * NUM_MAX_LOCAL_EXPERTS * num_loop_stage, + cudaHostAllocMapped)); CUDA_CHECK(cudaHostGetDevicePointer(&moe_recv_expert_counter_mapped, const_cast(moe_recv_expert_counter), 0)); @@ -2057,11 +2058,10 @@ Buffer::internode_fused_notify_combine( } } - std::vector num_recv_tokens(moe_recv_counter_mapped, - moe_recv_counter_mapped + num_loop_stage); - std::vector num_rdma_recv_tokens( - moe_recv_rdma_counter_mapped, - moe_recv_rdma_counter_mapped + num_loop_stage); + std::vector num_recv_tokens(moe_recv_counter, + moe_recv_counter + num_loop_stage); + std::vector num_rdma_recv_tokens(moe_recv_rdma_counter, + moe_recv_rdma_counter + num_loop_stage); // Wait streams stream_wait(compute_stream, comm_stream); @@ -2884,6 +2884,201 @@ Buffer::internode_notify_dispatch( recv_gbl_rank_prefix_sum}; } +std::tuple>, + std::vector, + std::vector, + deep_ep::detail::Tensor, + deep_ep::detail::Tensor, + deep_ep::detail::Tensor, + deep_ep::detail::Tensor> +Buffer::internode_fused_notify_dispatch( + const deep_ep::detail::Tensor& x, + const std::optional& x_scales, + const std::optional& topk_idx, + const std::optional& num_tokens_per_rank, + const std::optional& num_tokens_per_rdma_rank, + const std::optional& num_tokens_per_expert, + const deep_ep::detail::Tensor& is_token_in_rank, + int expert_alignment, + int num_loop_stage, + const Config& config) { + const int num_channels = config.num_sms / 2; + EP_HOST_ASSERT(config.num_sms % 2 == 0); + EP_HOST_ASSERT(0 < get_num_rdma_ranks() && + get_num_rdma_ranks() <= NUM_MAX_RDMA_PEERS); + + EP_HOST_ASSERT(num_tokens_per_rank.has_value()); + EP_HOST_ASSERT(num_tokens_per_rdma_rank.has_value()); + EP_HOST_ASSERT(num_tokens_per_expert.has_value()); + + // Type checks + 
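+  // NOTE: the fused path packs the counts of all pipeline stages into 2-D
+  // tensors, i.e. [num_loop_stage, num_ranks], [num_loop_stage,
+  // num_rdma_ranks] and [num_loop_stage, num_experts]; the shape checks
+  // below assume this per-stage layout.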
EP_HOST_ASSERT(num_tokens_per_rank->scalar_type() == deep_ep::detail::kInt32); + EP_HOST_ASSERT(num_tokens_per_rdma_rank->scalar_type() == + deep_ep::detail::kInt32); + EP_HOST_ASSERT(num_tokens_per_expert->scalar_type() == + deep_ep::detail::kInt32); + + // Shape and contiguous checks + EP_HOST_ASSERT(x.dim() == 2 && x.is_contiguous()); + EP_HOST_ASSERT((x.size(1) * x.element_size()) % sizeof(int4) == 0); + EP_HOST_ASSERT(num_tokens_per_rank->dim() == 2 && + num_tokens_per_rank->is_contiguous()); + EP_HOST_ASSERT(num_tokens_per_rdma_rank->dim() == 2 && + num_tokens_per_rdma_rank->is_contiguous()); + EP_HOST_ASSERT(num_tokens_per_expert->dim() == 2 && + num_tokens_per_expert->is_contiguous()); + EP_HOST_ASSERT(num_tokens_per_rank->size(1) == num_ranks); + EP_HOST_ASSERT(num_tokens_per_rdma_rank->size(1) == num_rdma_ranks); + EP_HOST_ASSERT(num_tokens_per_expert->size(1) % num_ranks == 0); + EP_HOST_ASSERT(num_tokens_per_expert->size(1) / num_ranks <= + NUM_MAX_LOCAL_EXPERTS); + + auto num_tokens = static_cast(x.size(0)), + hidden = static_cast(x.size(1)), + hidden_int4 = + static_cast(x.size(1) * x.element_size() / sizeof(int4)); + + // Top-k checks + int num_topk = 0; + int64_t* topk_idx_ptr = nullptr; + if (topk_idx.has_value()) { + num_topk = static_cast(topk_idx->size(1)); + EP_HOST_ASSERT(topk_idx->dim() == 2 && topk_idx->is_contiguous()); + EP_HOST_ASSERT(num_tokens == topk_idx->size(0)); + EP_HOST_ASSERT(num_topk == topk_idx->size(1)); + topk_idx_ptr = topk_idx->data_ptr(); + } + auto num_experts = static_cast(num_tokens_per_expert->size(1)); + int num_local_experts = num_experts / num_ranks; + + // FP8 scales checks + float* x_scales_ptr = nullptr; + int num_scales = 0; + if (x_scales.has_value()) { + EP_HOST_ASSERT(x.element_size() == 1); + EP_HOST_ASSERT(x_scales->scalar_type() == deep_ep::detail::kFloat32); + EP_HOST_ASSERT(x_scales->dim() > 0 && x_scales->dim() < 3 && + x_scales->is_contiguous()); + EP_HOST_ASSERT(x_scales->size(0) == num_tokens); + num_scales = x_scales->dim() == 1 ? 
1 : static_cast(x_scales->size(1));
+  }
+
+  auto rdma_channel_prefix_matrix =
+      ConvertPaddleTensorToDetailTensor(paddle::experimental::empty(
+          {num_loop_stage, num_rdma_ranks, num_channels},
+          phi::DataType::INT32,
+          phi::GPUPlace(device_id)));
+  auto recv_rdma_rank_prefix_sum = ConvertPaddleTensorToDetailTensor(
+      paddle::experimental::empty({num_loop_stage, num_rdma_ranks},
+                                  phi::DataType::INT32,
+                                  phi::GPUPlace(device_id)));
+  auto gbl_channel_prefix_matrix = ConvertPaddleTensorToDetailTensor(
+      paddle::experimental::empty({num_loop_stage, num_ranks, num_channels},
+                                  phi::DataType::INT32,
+                                  phi::GPUPlace(device_id)));
+  auto recv_gbl_rank_prefix_sum = ConvertPaddleTensorToDetailTensor(
+      paddle::experimental::empty({num_loop_stage, num_ranks},
+                                  phi::DataType::INT32,
+                                  phi::GPUPlace(device_id)));
+
+  auto compute_stream = calc_ctx->stream();
+  stream_wait(comm_stream, compute_stream);
+
+  // Send sizes
+  for (int s = 0; s < num_loop_stage; ++s) {
+    moe_recv_counter[s] = -1;
+    moe_recv_rdma_counter[s] = -1;
+    for (int i = 0; i < num_local_experts; ++i)
+      moe_recv_expert_counter[s * num_local_experts + i] = -1;
+  }
+
+  internode::fused_notify_dispatch(
+      num_tokens_per_rank->data_ptr(),
+      moe_recv_counter_mapped,
+      num_ranks,
+      num_tokens_per_rdma_rank->data_ptr(),
+      moe_recv_rdma_counter_mapped,
+      num_tokens_per_expert->data_ptr(),
+      moe_recv_expert_counter_mapped,
+      num_experts,
+      is_token_in_rank.data_ptr(),
+      num_tokens,
+      num_channels,
+      hidden_int4,
+      num_scales,
+      num_topk,
+      expert_alignment,
+      num_loop_stage,
+      rdma_channel_prefix_matrix.data_ptr(),
+      recv_rdma_rank_prefix_sum.data_ptr(),
+      gbl_channel_prefix_matrix.data_ptr(),
+      recv_gbl_rank_prefix_sum.data_ptr(),
+      rdma_buffer_ptr,
+      config.num_max_rdma_chunked_recv_tokens,
+      buffer_ptrs_gpu,
+      config.num_max_nvl_chunked_recv_tokens,
+      task_fifo_ptrs_gpu,
+      head,
+      rank,
+      comm_stream,
+      config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),
+      num_nvl_bytes,
+      low_latency_mode);
+  move_fifo_slots(3);
+
+  // Synchronize total received tokens and tokens per expert
+  auto start_time = std::chrono::high_resolution_clock::now();
+  while (true) {
+    bool ready = true;
+    for (int s = 0; s < num_loop_stage && ready; ++s) {
+      // Read total count
+      ready &= (moe_recv_counter[s] >= 0) && (moe_recv_rdma_counter[s] >= 0);
+      // Read per-expert count
+      for (int i = 0; i < num_local_experts && ready; ++i)
+        ready &= moe_recv_expert_counter[s * num_local_experts + i] >= 0;
+    }
+
+    if (ready) break;
+
+    // Timeout check
+    if (std::chrono::duration_cast(
+            std::chrono::high_resolution_clock::now() - start_time)
+            .count() > NUM_CPU_TIMEOUT_SECS) {
+      // Dump the counters of every stage before aborting
+      for (int s = 0; s < num_loop_stage; ++s) {
+        LOG(INFO) << "Global rank: " << rank << ", stage: " << s
+                  << ", num_recv_tokens: " << moe_recv_counter[s]
+                  << ", num_rdma_recv_tokens: " << moe_recv_rdma_counter[s];
+        for (int i = 0; i < num_local_experts; ++i)
+          LOG(INFO) << " moe_recv_expert_counter[" << i << "]: "
+                    << moe_recv_expert_counter[s * num_local_experts + i];
+      }
+      throw std::runtime_error("DeepEP error: timeout (dispatch CPU)");
+    }
+  }
+  std::vector num_recv_tokens(moe_recv_counter,
+                              moe_recv_counter + num_loop_stage);
+  std::vector num_rdma_recv_tokens(moe_recv_rdma_counter,
+                                   moe_recv_rdma_counter + num_loop_stage);
+
+  std::vector> num_recv_tokens_per_expert_list;
+  num_recv_tokens_per_expert_list.reserve(num_loop_stage);
+  for (int s = 0; s < num_loop_stage; ++s) {
+    num_recv_tokens_per_expert_list.emplace_back(
+        moe_recv_expert_counter + s * num_local_experts,
+ moe_recv_expert_counter + (s + 1) * num_local_experts); + } + + stream_wait(compute_stream, comm_stream); + + return {num_recv_tokens_per_expert_list, + num_recv_tokens, + num_rdma_recv_tokens, + rdma_channel_prefix_matrix, + gbl_channel_prefix_matrix, + recv_rdma_rank_prefix_sum, + recv_gbl_rank_prefix_sum}; +} + #endif // PADDLE_WITH_NVSHMEM std::tuple>, + std::vector, + std::vector, + paddle::Tensor, + paddle::Tensor, + paddle::Tensor, + paddle::Tensor> +Buffer::internode_fused_notify_dispatch_api( + const paddle::Tensor& x, + const std::optional& x_scales, + const std::optional& topk_idx, + const std::optional& num_tokens_per_rank, + const std::optional& num_tokens_per_rdma_rank, + const std::optional& num_tokens_per_expert, + const paddle::Tensor& is_token_in_rank, + int expert_alignment, + int num_loop_stage, + const Config& config) { +#ifdef PADDLE_WITH_NVSHMEM + const auto& x_ = ConvertPaddleTensorToDetailTensor(x); + std::optional x_scales_ = + ConvertOptionalPaddleTensorToDetailTensor(x_scales); + + std::optional topk_idx_ = + ConvertOptionalPaddleTensorToDetailTensor(topk_idx); + std::optional num_tokens_per_rank_ = + ConvertOptionalPaddleTensorToDetailTensor(num_tokens_per_rank); + std::optional num_tokens_per_rdma_rank_ = + ConvertOptionalPaddleTensorToDetailTensor(num_tokens_per_rdma_rank); + std::optional num_tokens_per_expert_ = + ConvertOptionalPaddleTensorToDetailTensor(num_tokens_per_expert); + const auto& is_token_in_rank_ = + ConvertPaddleTensorToDetailTensor(is_token_in_rank); + + auto res = internode_fused_notify_dispatch(x_, + x_scales_, + topk_idx_, + num_tokens_per_rank_, + num_tokens_per_rdma_rank_, + num_tokens_per_expert_, + is_token_in_rank_, + expert_alignment, + num_loop_stage, + config); + + auto num_recv_tokens_per_expert_list_ = std::get<0>(res); + auto num_recv_tokens_ = std::get<1>(res); + auto num_rdma_recv_tokens_ = std::get<2>(res); + + auto rdma_channel_prefix_matrix_ = + ConvertDetailTensorToPaddleTensor(std::get<3>(res)); + + auto gbl_channel_prefix_matrix_ = + ConvertDetailTensorToPaddleTensor(std::get<4>(res)); + + auto recv_rdma_rank_prefix_sum_ = + ConvertDetailTensorToPaddleTensor(std::get<5>(res)); + + auto recv_gbl_rank_prefix_sum_ = + ConvertDetailTensorToPaddleTensor(std::get<6>(res)); + + return {num_recv_tokens_per_expert_list_, + num_recv_tokens_, + num_rdma_recv_tokens_, + rdma_channel_prefix_matrix_, + gbl_channel_prefix_matrix_, + recv_rdma_rank_prefix_sum_, + recv_gbl_rank_prefix_sum_}; +#else + LOG(ERROR) << "NVSHMEM is not enabled. 
You can enable it by setting cmake " + "option WITH_NVSHMEM=ON."; + return {}; +#endif +} + deep_ep::detail::Tensor ConvertPaddleTensorToDetailTensor( const paddle::Tensor& tensor) { deep_ep::detail::Tensor res(tensor); diff --git a/paddle/fluid/distributed/collective/deep_ep/deep_ep.hpp b/paddle/fluid/distributed/collective/deep_ep/deep_ep.hpp index 3b84c4b8bc6726..5503ffd715f811 100644 --- a/paddle/fluid/distributed/collective/deep_ep/deep_ep.hpp +++ b/paddle/fluid/distributed/collective/deep_ep/deep_ep.hpp @@ -286,6 +286,25 @@ struct Buffer { int expert_alignment, const Config& config); + std::tuple>, + std::vector, + std::vector, + deep_ep::detail::Tensor, + deep_ep::detail::Tensor, + deep_ep::detail::Tensor, + deep_ep::detail::Tensor> + internode_fused_notify_dispatch( + const deep_ep::detail::Tensor& x, + const std::optional& x_scales, + const std::optional& topk_idx, + const std::optional& num_tokens_per_rank, + const std::optional& num_tokens_per_rdma_rank, + const std::optional& num_tokens_per_expert, + const deep_ep::detail::Tensor& is_token_in_rank, + int expert_alignment, + int num_loop_stage, + const Config& config); + std::tuple>, + std::vector, + std::vector, + paddle::Tensor, + paddle::Tensor, + paddle::Tensor, + paddle::Tensor> + internode_fused_notify_dispatch_api( + const paddle::Tensor& x, + const std::optional& x_scales, + const std::optional& topk_idx, + const std::optional& num_tokens_per_rank, + const std::optional& num_tokens_per_rdma_rank, + const std::optional& num_tokens_per_expert, + const paddle::Tensor& is_token_in_rank, + int expert_alignment, + int num_loop_stage, + const Config& config); + void clear_buffer_api(const paddle::Tensor& x, const std::optional& x_scales, const std::optional& topk_idx, diff --git a/paddle/fluid/distributed/collective/deep_ep/kernels/api.cuh b/paddle/fluid/distributed/collective/deep_ep/kernels/api.cuh index c8a28a46208ffd..9797b9e502ebfb 100644 --- a/paddle/fluid/distributed/collective/deep_ep/kernels/api.cuh +++ b/paddle/fluid/distributed/collective/deep_ep/kernels/api.cuh @@ -243,6 +243,38 @@ void notify_combine_post_step(int num_ranks, cudaStream_t stream, bool low_latency_mode); +void fused_notify_dispatch(const int* num_tokens_per_rank, + int* moe_recv_counter_mapped, + int num_ranks, + const int* num_tokens_per_rdma_rank, + int* moe_recv_rdma_counter_mapped, + const int* num_tokens_per_expert, + int* moe_recv_expert_counter_mapped, + int num_experts, + const bool* is_token_in_rank, + int num_tokens, + int num_channels, + int hidden_int4, + int num_scales, + int num_topk, + int expert_alignment, + int num_loop_stage, + int* rdma_channel_prefix_matrix, + int* recv_rdma_rank_prefix_sum, + int* gbl_channel_prefix_matrix, + int* recv_gbl_rank_prefix_sum, + void* rdma_buffer_ptr, + int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, + int num_max_nvl_chunked_recv_tokens, + int** task_fifo_ptrs, + int head, + int rank, + cudaStream_t stream, + int64_t num_rdma_bytes, + int64_t num_nvl_bytes, + bool low_latency_mode); + void fused_notify_combine(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks, diff --git a/paddle/fluid/distributed/collective/deep_ep/kernels/internode.cu b/paddle/fluid/distributed/collective/deep_ep/kernels/internode.cu index e867794129f05a..85ce27165483cb 100644 --- a/paddle/fluid/distributed/collective/deep_ep/kernels/internode.cu +++ b/paddle/fluid/distributed/collective/deep_ep/kernels/internode.cu @@ -1077,6 +1077,405 @@ void notify_combine_post_step(int num_ranks, #undef 
NOTIFY_DISPATCH_LAUNCH_CASE } +template +__global__ void fused_notify_dispatch(const int* num_tokens_per_rank, + int* moe_recv_counter_mapped, + int num_ranks, + const int* num_tokens_per_rdma_rank, + int* moe_recv_rdma_counter_mapped, + const int* num_tokens_per_expert, + int* moe_recv_expert_counter_mapped, + int num_experts, + const bool* is_token_in_rank, + int num_tokens, + int num_channels, + int expert_alignment, + int num_loop_stage, + const int rdma_clean_offset, + const int rdma_num_int_clean, + const int nvl_clean_offset, + const int nvl_num_int_clean, + int* rdma_channel_prefix_matrix, + int* recv_rdma_rank_prefix_sum, + int* gbl_channel_prefix_matrix, + int* recv_gbl_rank_prefix_sum, + void* rdma_buffer_ptr, + void** buffer_ptrs, + int** task_fifo_ptrs, + int head, + int rank, + const nvshmem_team_t rdma_team) { + auto sm_id = static_cast(blockIdx.x); + auto thread_id = static_cast(threadIdx.x), warp_id = thread_id / 32, + lane_id = get_lane_id(); + auto num_threads = static_cast(blockDim.x), num_warps = num_threads / 32; + + auto rdma_rank = rank / NUM_MAX_NVL_PEERS, + nvl_rank = rank % NUM_MAX_NVL_PEERS; + auto num_rdma_experts = num_experts / kNumRDMARanks, + num_nvl_experts = num_rdma_experts / NUM_MAX_NVL_PEERS; + + if (sm_id == 0) { + // Communication with others + // Global barrier: the first warp do intra-node sync, the second warp do + // internode sync + EP_DEVICE_ASSERT(num_warps > 1); + EP_DEVICE_ASSERT(kNumRDMARanks <= num_threads); + if (thread_id == 32) + nvshmem_barrier_with_same_gpu_idx(rdma_team); + barrier_device(task_fifo_ptrs, head, nvl_rank); + move_fifo_slots(head); + __syncthreads(); + + // Send numbers of tokens per rank/expert to RDMA ranks + auto rdma_buffer_ptr_int = reinterpret_cast(rdma_buffer_ptr); + auto rdma_recv_num_tokens_mixed = SymBuffer( + rdma_buffer_ptr, + (NUM_MAX_NVL_PEERS + num_rdma_experts + 1) * num_loop_stage, + kNumRDMARanks); + + // Clean up for later data dispatch + EP_DEVICE_ASSERT(rdma_recv_num_tokens_mixed.total_bytes <= + rdma_clean_offset * sizeof(int)); +#pragma unroll + for (int i = thread_id; i < rdma_num_int_clean; i += num_threads) + rdma_buffer_ptr_int[rdma_clean_offset + i] = 0; + +// Copy to send buffer +#pragma unroll + for (int i = thread_id; i < num_ranks; i += num_threads) { + for (int j = 0; j < num_loop_stage; ++j) { + rdma_recv_num_tokens_mixed.send_buffer( + i / + NUM_MAX_NVL_PEERS)[j * (NUM_MAX_NVL_PEERS + num_rdma_experts + 1) + + (i % NUM_MAX_NVL_PEERS)] = + num_tokens_per_rank[j * num_ranks + i]; + } + } +#pragma unroll + for (int i = thread_id; i < num_experts; i += num_threads) { + for (int j = 0; j < num_loop_stage; ++j) { + rdma_recv_num_tokens_mixed.send_buffer( + i / + num_rdma_experts)[j * (NUM_MAX_NVL_PEERS + num_rdma_experts + 1) + + NUM_MAX_NVL_PEERS + i % num_rdma_experts] = + num_tokens_per_expert[j * num_experts + i]; + } + } + + if (thread_id < kNumRDMARanks) { +#pragma unroll + for (int j = 0; j < num_loop_stage; ++j) { + rdma_recv_num_tokens_mixed.send_buffer( + thread_id)[j * (NUM_MAX_NVL_PEERS + num_rdma_experts + 1) + + NUM_MAX_NVL_PEERS + num_rdma_experts] = + num_tokens_per_rdma_rank[j * kNumRDMARanks + thread_id]; + } + } + + __syncthreads(); + + if (thread_id < kNumRDMARanks) { + nvshmem_int_put_nbi( + rdma_recv_num_tokens_mixed.recv_buffer(rdma_rank), + rdma_recv_num_tokens_mixed.send_buffer(thread_id), + (NUM_MAX_NVL_PEERS + num_rdma_experts + 1) * num_loop_stage, + translate_dst_rdma_rank(thread_id, nvl_rank)); + } + __syncthreads(); + if (thread_id == 0) + 
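+    // The per-stage counts were issued with nvshmem_int_put_nbi above; this
+    // barrier makes them visible on every RDMA peer before they are reduced
+    // below.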
nvshmem_barrier_with_same_gpu_idx(rdma_team); + __syncthreads(); + + // NVL buffers + auto nvl_send_buffer = + thread_id < NUM_MAX_NVL_PEERS ? buffer_ptrs[thread_id] : nullptr; + auto nvl_recv_buffer = buffer_ptrs[nvl_rank]; + auto nvl_reduced_num_tokens_per_expert = + Buffer(nvl_recv_buffer, num_rdma_experts * num_loop_stage) + .advance_also(nvl_send_buffer); + auto nvl_send_num_tokens_per_rank = AsymBuffer( + nvl_send_buffer, kNumRDMARanks * num_loop_stage, NUM_MAX_NVL_PEERS); + auto nvl_send_num_tokens_per_expert = AsymBuffer( + nvl_send_buffer, num_nvl_experts * num_loop_stage, NUM_MAX_NVL_PEERS); + auto nvl_recv_num_tokens_per_rank = AsymBuffer( + nvl_recv_buffer, kNumRDMARanks * num_loop_stage, NUM_MAX_NVL_PEERS); + auto nvl_recv_num_tokens_per_expert = AsymBuffer( + nvl_recv_buffer, num_nvl_experts * num_loop_stage, NUM_MAX_NVL_PEERS); + + // Clean up for later data dispatch + auto nvl_buffer_ptr_int = reinterpret_cast(buffer_ptrs[nvl_rank]); + EP_DEVICE_ASSERT(nvl_reduced_num_tokens_per_expert.total_bytes + + nvl_send_num_tokens_per_rank.total_bytes + + nvl_send_num_tokens_per_expert.total_bytes <= + nvl_clean_offset * sizeof(int)); +#pragma unroll + for (int i = thread_id; i < nvl_num_int_clean; i += num_threads) + nvl_buffer_ptr_int[nvl_clean_offset + i] = 0; + + // Reduce number of tokens per expert into the NVL send buffer + // TODO(Xreki): may use NVSHMEM reduction + EP_DEVICE_ASSERT(num_rdma_experts <= num_threads); + if (thread_id < num_rdma_experts) { + for (int j = 0; j < num_loop_stage; ++j) { + int sum = 0; +#pragma unroll + for (int i = 0; i < kNumRDMARanks; ++i) + sum += rdma_recv_num_tokens_mixed.recv_buffer( + i)[j * (NUM_MAX_NVL_PEERS + num_rdma_experts + 1) + + NUM_MAX_NVL_PEERS + thread_id]; + nvl_reduced_num_tokens_per_expert[j * num_rdma_experts + thread_id] = + sum; + } + } + __syncthreads(); + + // Reduce RDMA received tokens + if (thread_id < num_loop_stage) { + int sum = 0; +#pragma unroll + for (int i = 0; i < kNumRDMARanks; ++i) { + sum += rdma_recv_num_tokens_mixed.recv_buffer( + i)[thread_id * (NUM_MAX_NVL_PEERS + num_rdma_experts + 1) + + NUM_MAX_NVL_PEERS + num_rdma_experts]; + recv_rdma_rank_prefix_sum[thread_id * kNumRDMARanks + i] = sum; + } + while (ld_volatile_global(moe_recv_rdma_counter_mapped + thread_id) != + -1) { + } + moe_recv_rdma_counter_mapped[thread_id] = sum; + } + + // Send numbers of tokens per rank/expert to NVL ranks + EP_DEVICE_ASSERT(NUM_MAX_NVL_PEERS <= num_threads); + if (thread_id < NUM_MAX_NVL_PEERS) { + for (int j = 0; j < num_loop_stage; ++j) { +#pragma unroll + for (int i = 0; i < kNumRDMARanks; ++i) + nvl_send_num_tokens_per_rank.buffer(nvl_rank)[j * kNumRDMARanks + i] = + rdma_recv_num_tokens_mixed.recv_buffer( + i)[j * (NUM_MAX_NVL_PEERS + num_rdma_experts + 1) + + thread_id]; +#pragma unroll + for (int i = 0; i < num_nvl_experts; ++i) + nvl_send_num_tokens_per_expert.buffer( + nvl_rank)[j * num_nvl_experts + i] = + nvl_reduced_num_tokens_per_expert[j * num_rdma_experts + + thread_id * num_nvl_experts + + i]; + } + } + memory_fence(); + __syncthreads(); + barrier_device(task_fifo_ptrs, head, nvl_rank); + move_fifo_slots(head); + __syncthreads(); + + // Reduce number of tokens per rank/expert + EP_DEVICE_ASSERT(num_nvl_experts <= num_threads); + if (thread_id < num_loop_stage) { + int sum = 0; +#pragma unroll + for (int i = 0; i < num_ranks; ++i) { + int src_rdma_rank = i / NUM_MAX_NVL_PEERS, + src_nvl_rank = i % NUM_MAX_NVL_PEERS; + sum += nvl_recv_num_tokens_per_rank.buffer( + src_nvl_rank)[thread_id * kNumRDMARanks + 
src_rdma_rank]; + recv_gbl_rank_prefix_sum[thread_id * num_ranks + i] = sum; + } + while (ld_volatile_global(moe_recv_counter_mapped + thread_id) != -1) { + } + moe_recv_counter_mapped[thread_id] = sum; + } + + EP_DEVICE_ASSERT(num_nvl_experts * num_loop_stage <= num_threads); + if (thread_id < num_nvl_experts * num_loop_stage) { + int sum = 0; +#pragma unroll + for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i) + sum += nvl_recv_num_tokens_per_expert.buffer(i)[thread_id]; + sum = (sum + expert_alignment - 1) / expert_alignment * expert_alignment; + while (ld_volatile_global(moe_recv_expert_counter_mapped + thread_id) != + -1) { + } + moe_recv_expert_counter_mapped[thread_id] = sum; + } + + // Finally barrier + __syncthreads(); + if (thread_id == 32) + nvshmem_barrier_with_same_gpu_idx(rdma_team); + barrier_device(task_fifo_ptrs, head, nvl_rank); + move_fifo_slots(head); + } else { + // Calculate meta data + int stage_id = (sm_id - 1) / kNumRDMARanks; + int dst_rdma_rank = (sm_id - 1) % kNumRDMARanks; + for (int channel_id = warp_id; channel_id < num_channels; + channel_id += num_warps) { + int token_start_idx, token_end_idx; + get_channel_task_range( + num_tokens, num_channels, channel_id, token_start_idx, token_end_idx); + + // Iterate over tokens + int total_count = 0, per_nvl_rank_count[NUM_MAX_NVL_PEERS] = {0}; + for (int64_t i = token_start_idx + lane_id; i < token_end_idx; i += 32) { + EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * sizeof(bool) == sizeof(uint64_t), + "Invalid number of NVL peers"); + auto is_token_in_rank_uint64 = *reinterpret_cast( + is_token_in_rank + (stage_id * num_tokens + i) * num_ranks + + dst_rdma_rank * NUM_MAX_NVL_PEERS); + auto is_token_in_rank_values = + reinterpret_cast(&is_token_in_rank_uint64); +#pragma unroll + for (int j = 0; j < NUM_MAX_NVL_PEERS; ++j) + per_nvl_rank_count[j] += is_token_in_rank_values[j]; + total_count += (is_token_in_rank_uint64 != 0); + } + + // Warp reduce + total_count = warp_reduce_sum(total_count); +#pragma unroll + for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i) + per_nvl_rank_count[i] = warp_reduce_sum(per_nvl_rank_count[i]); + + // Write into channel matrix + if (lane_id == 0) { +#pragma unroll + for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i) + gbl_channel_prefix_matrix[(stage_id * num_ranks + + dst_rdma_rank * NUM_MAX_NVL_PEERS + i) * + num_channels + + channel_id] = per_nvl_rank_count[i]; + rdma_channel_prefix_matrix[(stage_id * kNumRDMARanks + dst_rdma_rank) * + num_channels + + channel_id] = total_count; + } + } + + // Calculate prefix sum + __syncthreads(); + if (thread_id == 0) { + auto prefix_row = + rdma_channel_prefix_matrix + + (stage_id * kNumRDMARanks + dst_rdma_rank) * num_channels; +#pragma unroll + for (int i = 1; i < num_channels; ++i) prefix_row[i] += prefix_row[i - 1]; + } + + EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= 32, "Invalid number of NVL peers"); + if (thread_id < NUM_MAX_NVL_PEERS) { + auto prefix_row = gbl_channel_prefix_matrix + + (stage_id * num_ranks + + dst_rdma_rank * NUM_MAX_NVL_PEERS + thread_id) * + num_channels; +#pragma unroll + for (int i = 1; i < num_channels; ++i) prefix_row[i] += prefix_row[i - 1]; + } + } +} + +void fused_notify_dispatch(const int* num_tokens_per_rank, + int* moe_recv_counter_mapped, + int num_ranks, + const int* num_tokens_per_rdma_rank, + int* moe_recv_rdma_counter_mapped, + const int* num_tokens_per_expert, + int* moe_recv_expert_counter_mapped, + int num_experts, + const bool* is_token_in_rank, + int num_tokens, + int num_channels, + int hidden_int4, + int num_scales, + int num_topk, + 
int expert_alignment, + int num_loop_stage, + int* rdma_channel_prefix_matrix, + int* recv_rdma_rank_prefix_sum, + int* gbl_channel_prefix_matrix, + int* recv_gbl_rank_prefix_sum, + void* rdma_buffer_ptr, + int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, + int num_max_nvl_chunked_recv_tokens, + int** task_fifo_ptrs, + int head, + int rank, + cudaStream_t stream, + int64_t num_rdma_bytes, + int64_t num_nvl_bytes, + bool low_latency_mode) { +#define NOTIFY_DISPATCH_LAUNCH_CASE(num_rdma_ranks) \ + { \ + auto fused_notify_dispatch_func = \ + low_latency_mode ? fused_notify_dispatch \ + : fused_notify_dispatch; \ + LAUNCH_KERNEL(&cfg, \ + fused_notify_dispatch_func, \ + num_tokens_per_rank, \ + moe_recv_counter_mapped, \ + num_ranks, \ + num_tokens_per_rdma_rank, \ + moe_recv_rdma_counter_mapped, \ + num_tokens_per_expert, \ + moe_recv_expert_counter_mapped, \ + num_experts, \ + is_token_in_rank, \ + num_tokens, \ + num_channels, \ + expert_alignment, \ + num_loop_stage, \ + rdma_clean_meta.first, \ + rdma_clean_meta.second, \ + nvl_clean_meta.first, \ + nvl_clean_meta.second, \ + rdma_channel_prefix_matrix, \ + recv_rdma_rank_prefix_sum, \ + gbl_channel_prefix_matrix, \ + recv_gbl_rank_prefix_sum, \ + rdma_buffer_ptr, \ + buffer_ptrs, \ + task_fifo_ptrs, \ + head, \ + rank, \ + cpu_rdma_team); \ + } \ + break + + constexpr int kNumThreads = 512; + const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; + + // Get clean meta + auto rdma_clean_meta = get_rdma_clean_meta(hidden_int4, + num_scales, + num_topk, + num_topk, + num_rdma_ranks, + num_max_rdma_chunked_recv_tokens, + num_channels); + auto nvl_clean_meta = get_nvl_clean_meta(hidden_int4, + num_scales, + num_topk, + num_topk, + num_rdma_ranks, + NUM_MAX_NVL_PEERS, + num_max_nvl_chunked_recv_tokens, + num_channels); + EP_HOST_ASSERT((rdma_clean_meta.first + rdma_clean_meta.second) * + sizeof(int) <= + num_rdma_bytes); + EP_HOST_ASSERT((nvl_clean_meta.first + nvl_clean_meta.second) * sizeof(int) <= + num_nvl_bytes); + EP_HOST_ASSERT(num_rdma_bytes < std::numeric_limits::max()); + EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits::max()); + + // Launch kernel + SETUP_LAUNCH_CONFIG(1 + num_rdma_ranks * num_loop_stage, kNumThreads, stream); + SWITCH_RDMA_RANKS(NOTIFY_DISPATCH_LAUNCH_CASE); +#undef NOTIFY_DISPATCH_LAUNCH_CASE +} + template __global__ void fused_notify_combine( const int* num_tokens_per_rank, // [num_loop_stage, 2, num_ranks] diff --git a/paddle/fluid/pybind/deep_ep_api.cc b/paddle/fluid/pybind/deep_ep_api.cc index 3dddc79940947c..7228f773d87a42 100644 --- a/paddle/fluid/pybind/deep_ep_api.cc +++ b/paddle/fluid/pybind/deep_ep_api.cc @@ -97,6 +97,8 @@ void BindDeepEPApi(pybind11::module *m) { .def("internode_dispatch", &deep_ep::Buffer::internode_dispatch_api) .def("internode_notify_dispatch", &deep_ep::Buffer::internode_notify_dispatch_api) + .def("internode_fused_notify_dispatch", + &deep_ep::Buffer::internode_fused_notify_dispatch_api) .def("clear_buffer", &deep_ep::Buffer::clear_buffer_api) .def("internode_notify_combine", &deep_ep::Buffer::internode_notify_combine_api) diff --git a/python/paddle/distributed/communication/deep_ep/buffer.py b/python/paddle/distributed/communication/deep_ep/buffer.py index dbd77b35c89339..e68f5b517e2bd2 100644 --- a/python/paddle/distributed/communication/deep_ep/buffer.py +++ b/python/paddle/distributed/communication/deep_ep/buffer.py @@ -902,6 +902,80 @@ def internode_notify_dispatch( handle, ) + def internode_fused_notify_dispatch( + self, + x: paddle.Tensor | 
tuple[paddle.Tensor, paddle.Tensor], + topk_idx: paddle.Tensor | None = None, + num_tokens_per_rank: paddle.Tensor | None = None, + num_tokens_per_rdma_rank: paddle.Tensor | None = None, + num_tokens_per_expert: paddle.Tensor | None = None, + is_token_in_rank: paddle.Tensor | None = None, + num_loop_stage: int = 1, + expert_alignment: int = 1, + config: Config | None = None, + ) -> tuple[ + list[list[int]], + list[int], + list[int], + list[tuple], + ]: + # Default config + config = ( + self.get_dispatch_config(self.group_size) + if config is None + else config + ) + # Launch the kernel with cached or non-cached mode + x, x_scales = x if isinstance(x, tuple) else (x, None) + assert ( + num_tokens_per_rank is not None + and is_token_in_rank is not None + and num_tokens_per_expert is not None + ) + + ( + num_recv_tokens_per_expert_list, + num_recv_tokens, + num_rdma_recv_tokens, + rdma_channel_prefix_matrix, + gbl_channel_prefix_matrix, + recv_rdma_rank_prefix_sum, + recv_gbl_rank_prefix_sum, + ) = self.runtime.internode_fused_notify_dispatch( + x, + x_scales, + topk_idx, + num_tokens_per_rank, + num_tokens_per_rdma_rank, + num_tokens_per_expert, + is_token_in_rank, + expert_alignment, + num_loop_stage, + config, + ) + handles = [] + for loop_idx in range(num_loop_stage): + handle = ( + is_token_in_rank[loop_idx], + rdma_channel_prefix_matrix[loop_idx], + gbl_channel_prefix_matrix[loop_idx], + None, + recv_rdma_rank_prefix_sum[loop_idx], + None, + recv_gbl_rank_prefix_sum[loop_idx], + paddle.empty([num_recv_tokens[loop_idx], 0]), + None, + paddle.empty([num_rdma_recv_tokens[loop_idx], 0]), + ) + handles.append(handle) + + return ( + num_recv_tokens_per_expert_list, + num_recv_tokens, + num_rdma_recv_tokens, + handles, + ) + def internode_notify_combine( self, x: paddle.Tensor | tuple[paddle.Tensor, paddle.Tensor],