Skip to content

Commit bf247da

Browse files
committed
Address review feedback: fall back to a default max size when world size is not in _FI_MAX_SIZES
1 parent bf3fa63 commit bf247da

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

vllm/compilation/collective_fusion.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,9 @@ def __call__(self, graph: fx.Graph):
159159
6: MiB // 2, # 512KB
160160
8: MiB // 2, # 512KB
161161
}
162+
# Fall back to a conservative default max size
163+
# when world_size has no entry in _FI_MAX_SIZES.
164+
_DEFAULT_FI_MAX_SIZE = MiB // 2
162165

163166
def call_trtllm_fused_allreduce_norm(
164167
allreduce_in: torch.Tensor,
@@ -178,8 +181,10 @@ def call_trtllm_fused_allreduce_norm(
178181
element_size = allreduce_in.element_size()
179182
current_tensor_size = num_tokens * hidden_size * element_size
180183
max_fusion_size = max_token_num * hidden_size * element_size
181-
use_flashinfer = current_tensor_size <= min(_FI_MAX_SIZES[world_size],
182-
max_fusion_size)
184+
use_flashinfer = current_tensor_size <= min(
185+
_FI_MAX_SIZES.get(world_size, _DEFAULT_FI_MAX_SIZE),
186+
max_fusion_size,
187+
)
183188

184189
if use_flashinfer:
185190
assert (_FI_WORKSPACE_TENSOR is not None

0 commit comments

Comments
 (0)