Skip to content

Commit 30b722e

Browse files
committed
address copilot feedback
1 parent d37cbf0 commit 30b722e

File tree

1 file changed: +8 −6 lines changed

vllm/compilation/collective_fusion.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -173,12 +173,14 @@ def call_trtllm_fused_allreduce_norm(
     max_token_num: int,
     norm_out: Optional[torch.Tensor] = None,
 ) -> None:
-    use_flashinfer = allreduce_in.shape[0] * allreduce_in.shape[
-        1] * allreduce_in.element_size() <= min(
-            _FI_MAX_SIZES[world_size],
-            max_token_num * allreduce_in.shape[1] *
-            allreduce_in.element_size(),
-        )
+    num_tokens, hidden_size = allreduce_in.shape
+    element_size = allreduce_in.element_size()
+    current_tensor_size = num_tokens * hidden_size * element_size
+    max_fusion_size = max_token_num * hidden_size * element_size
+    use_flashinfer = current_tensor_size <= min(_FI_MAX_SIZES[world_size],
+                                                max_fusion_size)
 
     if use_flashinfer:
         assert (_FI_WORKSPACE_TENSOR is not None
                 ), "Flashinfer must be enabled when using flashinfer"

0 commit comments

Comments
 (0)