Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions vllm/compilation/collective_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,9 @@ def __call__(self, graph: fx.Graph):
6: MiB // 2, # 512KB
8: MiB // 2, # 512KB
}
# opt for a more conservative default value
# when world size is not in _FI_MAX_SIZES
_DEFAULT_FI_MAX_SIZE = MiB // 2

def call_trtllm_fused_allreduce_norm(
allreduce_in: torch.Tensor,
Expand All @@ -173,12 +176,16 @@ def call_trtllm_fused_allreduce_norm(
max_token_num: int,
norm_out: Optional[torch.Tensor] = None,
) -> None:
use_flashinfer = allreduce_in.shape[0] * allreduce_in.shape[
1] * allreduce_in.element_size() <= min(
_FI_MAX_SIZES[world_size],
max_token_num * allreduce_in.shape[0] *
allreduce_in.element_size(),
)

num_tokens, hidden_size = allreduce_in.shape
element_size = allreduce_in.element_size()
current_tensor_size = num_tokens * hidden_size * element_size
max_fusion_size = max_token_num * hidden_size * element_size
use_flashinfer = current_tensor_size <= min(
_FI_MAX_SIZES.get(world_size, _DEFAULT_FI_MAX_SIZE),
max_fusion_size,
)

if use_flashinfer:
assert (_FI_WORKSPACE_TENSOR is not None
), "Flashinfer must be enabled when using flashinfer"
Expand Down