@@ -938,11 +938,12 @@ Buffer::low_latency_dispatch(const at::Tensor &x, const at::Tensor &topk_idx,
938938 this ->new_topk_idx = torch::cat (topk_blocks, 0 );
939939 }
940940
941+ EP_HOST_ASSERT (num_max_dispatch_tokens_per_rank >= new_x.size (0 ));
942+
941943 auto num_tokens = static_cast <int >(new_x.size (0 )), hidden = static_cast <int >(new_x.size (1 ));
942944 auto num_scales = hidden / 128 , num_topk = static_cast <int >(new_topk_idx.size (1 ));
943945 auto num_local_experts = num_experts / (num_ranks - shared_expert_rank_num);
944-
945- int64_t global_bs = std::max (new_topk_idx.size (0 ), num_max_dispatch_tokens_per_rank) * num_ranks;
946+ int64_t global_bs = num_max_dispatch_tokens_per_rank * num_ranks;
946947 auto num_max_tokens = 0 ;
947948 if (rank < shared_expert_rank_num) {
948949 num_max_tokens = global_bs / shared_expert_rank_num;
@@ -1059,6 +1060,7 @@ std::tuple<at::Tensor, std::optional<EventHandle>, std::optional<std::function<v
10591060 }
10601061 // Tensor checks
10611062 EP_HOST_ASSERT (x.dim () == 2 and x.is_contiguous () and x.scalar_type () == at::kBFloat16 );
1063+ EP_HOST_ASSERT (num_max_dispatch_tokens_per_rank >= new_idx.size (0 ));
10621064 // EP_HOST_ASSERT(x.size(0) == num_experts / num_ranks);
10631065
10641066 // get ep & tp name
@@ -1082,7 +1084,7 @@ std::tuple<at::Tensor, std::optional<EventHandle>, std::optional<std::function<v
10821084 int64_t tp_world_size = 1 ;
10831085 int64_t tp_rankId = 0 ;
10841086 int64_t expert_shared_type = 0 ;
1085- int64_t global_bs = std::max (new_idx. size ( 0 ), num_max_dispatch_tokens_per_rank) * num_ranks;
1087+ int64_t global_bs = num_max_dispatch_tokens_per_rank * num_ranks;
10861088 int64_t out_dtype = 0 ;
10871089 int64_t comm_quant_mode = 0 ;
10881090 int64_t group_list_type = 0 ;
0 commit comments