diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index a8b00aaf0842..0e7961841bd3 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -159,6 +159,9 @@ def __call__(self, graph: fx.Graph): 6: MiB // 2, # 512KB 8: MiB // 2, # 512KB } + # opt for a more conservative default value + # when world size is not in _FI_MAX_SIZES + _DEFAULT_FI_MAX_SIZE = MiB // 2 def call_trtllm_fused_allreduce_norm( allreduce_in: torch.Tensor, @@ -173,12 +176,16 @@ def call_trtllm_fused_allreduce_norm( max_token_num: int, norm_out: Optional[torch.Tensor] = None, ) -> None: - use_flashinfer = allreduce_in.shape[0] * allreduce_in.shape[ - 1] * allreduce_in.element_size() <= min( - _FI_MAX_SIZES[world_size], - max_token_num * allreduce_in.shape[0] * - allreduce_in.element_size(), - ) + + num_tokens, hidden_size = allreduce_in.shape + element_size = allreduce_in.element_size() + current_tensor_size = num_tokens * hidden_size * element_size + max_fusion_size = max_token_num * hidden_size * element_size + use_flashinfer = current_tensor_size <= min( + _FI_MAX_SIZES.get(world_size, _DEFAULT_FI_MAX_SIZE), + max_fusion_size, + ) + if use_flashinfer: assert (_FI_WORKSPACE_TENSOR is not None ), "Flashinfer must be enabled when using flashinfer"