Skip to content

Commit 7ea22e4

Browse files
authored
[Misc] Add override for allreduce fusion thresholds (#23639)
Signed-off-by: Julien Lin <[email protected]>
1 parent 9d4183d commit 7ea22e4

File tree

2 files changed

+24
-0
lines changed

2 files changed

+24
-0
lines changed

vllm/compilation/collective_fusion.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from torch._inductor.pattern_matcher import PatternMatcherPass
1111
from torch.distributed._symmetric_memory import enable_symm_mem_for_group
1212

13+
import vllm.envs as envs
1314
from vllm.config import VllmConfig
1415
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
1516
from vllm.distributed.parallel_state import (
@@ -401,6 +402,18 @@ def __call__(self, graph: fx.Graph):
401402
6: MiB // 2, # 512KB
402403
8: MiB // 2, # 512KB
403404
}
405+
406+
try:
407+
_FI_MAX_SIZES.update({
408+
int(k): int(float(v) * MiB)
409+
for k, v in
410+
envs.VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB.items()
411+
})
412+
except Exception as e:
413+
raise ValueError(
414+
"Failed to parse VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB: "
415+
+ str(e)) from e
416+
404417
# opt for a more conservative default value
405418
# when world size is not in _FI_MAX_SIZES
406419
_DEFAULT_FI_MAX_SIZE = MiB // 2

vllm/envs.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

44
import hashlib
5+
import json
56
import os
67
import sys
78
import tempfile
@@ -1046,6 +1047,16 @@ def get_vllm_port() -> Optional[int]:
10461047
"VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE":
10471048
lambda: int(os.getenv("VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE", "163840")),
10481049

1050+
# Specifies the thresholds of the communicated tensor sizes under which
1051+
# vllm should use flashinfer fused allreduce. The variable should be a
1052+
# JSON with the following format:
1053+
# { <world size>: <max size in mb> }
1054+
# Unspecified world sizes will fallback to
1055+
# { 2: 64, 4: 1, <everything else>: 0.5 }
1056+
"VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB":
1057+
lambda: json.loads(os.getenv(
1058+
"VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB", "{}")),
1059+
10491060
# MoE routing strategy selector.
10501061
# See `RoutingSimulator.get_available_strategies()` # for available
10511062
# strategies.

0 commit comments

Comments
 (0)