File tree Expand file tree Collapse file tree 1 file changed +7
-2
lines changed Expand file tree Collapse file tree 1 file changed +7
-2
lines changed Original file line number Diff line number Diff line change @@ -159,6 +159,9 @@ def __call__(self, graph: fx.Graph):
159
159
6 : MiB // 2 , # 512KB
160
160
8 : MiB // 2 , # 512KB
161
161
}
162
+ # opt for a more conservative default value
163
+ # when world size is not in _FI_MAX_SIZES
164
+ _DEFAULT_FI_MAX_SIZE = MiB // 2
162
165
163
166
def call_trtllm_fused_allreduce_norm (
164
167
allreduce_in : torch .Tensor ,
@@ -178,8 +181,10 @@ def call_trtllm_fused_allreduce_norm(
178
181
element_size = allreduce_in .element_size ()
179
182
current_tensor_size = num_tokens * hidden_size * element_size
180
183
max_fusion_size = max_token_num * hidden_size * element_size
181
- use_flashinfer = current_tensor_size <= min (_FI_MAX_SIZES [world_size ],
182
- max_fusion_size )
184
+ use_flashinfer = current_tensor_size <= min (
185
+ _FI_MAX_SIZES .get (world_size , _DEFAULT_FI_MAX_SIZE ),
186
+ max_fusion_size ,
187
+ )
183
188
184
189
if use_flashinfer :
185
190
assert (_FI_WORKSPACE_TENSOR is not None
You can’t perform that action at this time.
0 commit comments