Skip to content

Commit 46b73de

Browse files
authored
Added the verification of num_max_dispatch_tokens_per_rank to the decode operator adaptation layer. (#330)
1 parent ef80a67 commit 46b73de

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

csrc/deepep/deep_ep.cpp

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -938,11 +938,12 @@ Buffer::low_latency_dispatch(const at::Tensor &x, const at::Tensor &topk_idx,
938938
this->new_topk_idx = torch::cat(topk_blocks, 0);
939939
}
940940

941+
EP_HOST_ASSERT(num_max_dispatch_tokens_per_rank >= new_x.size(0));
942+
941943
auto num_tokens = static_cast<int>(new_x.size(0)), hidden = static_cast<int>(new_x.size(1));
942944
auto num_scales = hidden / 128, num_topk = static_cast<int>(new_topk_idx.size(1));
943945
auto num_local_experts = num_experts / (num_ranks - shared_expert_rank_num);
944-
945-
int64_t global_bs = std::max(new_topk_idx.size(0), num_max_dispatch_tokens_per_rank) * num_ranks;
946+
int64_t global_bs = num_max_dispatch_tokens_per_rank * num_ranks;
946947
auto num_max_tokens = 0;
947948
if (rank < shared_expert_rank_num) {
948949
num_max_tokens = global_bs / shared_expert_rank_num;
@@ -1059,6 +1060,7 @@ std::tuple<at::Tensor, std::optional<EventHandle>, std::optional<std::function<v
10591060
}
10601061
// Tensor checks
10611062
EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous() and x.scalar_type() == at::kBFloat16);
1063+
EP_HOST_ASSERT(num_max_dispatch_tokens_per_rank >= new_idx.size(0));
10621064
// EP_HOST_ASSERT(x.size(0) == num_experts / num_ranks);
10631065

10641066
// get ep & tp name
@@ -1082,7 +1084,7 @@ std::tuple<at::Tensor, std::optional<EventHandle>, std::optional<std::function<v
10821084
int64_t tp_world_size = 1;
10831085
int64_t tp_rankId = 0;
10841086
int64_t expert_shared_type = 0;
1085-
int64_t global_bs = std::max(new_idx.size(0), num_max_dispatch_tokens_per_rank) * num_ranks;
1087+
int64_t global_bs = num_max_dispatch_tokens_per_rank * num_ranks;
10861088
int64_t out_dtype = 0;
10871089
int64_t comm_quant_mode = 0;
10881090
int64_t group_list_type = 0;

0 commit comments

Comments
 (0)