File tree Expand file tree Collapse file tree 1 file changed +8
-6
lines changed Expand file tree Collapse file tree 1 file changed +8
-6
lines changed Original file line number Diff line number Diff line change @@ -173,12 +173,14 @@ def call_trtllm_fused_allreduce_norm(
173
173
max_token_num : int ,
174
174
norm_out : Optional [torch .Tensor ] = None ,
175
175
) -> None :
176
- use_flashinfer = allreduce_in .shape [0 ] * allreduce_in .shape [
177
- 1 ] * allreduce_in .element_size () <= min (
178
- _FI_MAX_SIZES [world_size ],
179
- max_token_num * allreduce_in .shape [1 ] *
180
- allreduce_in .element_size (),
181
- )
176
+
177
+ num_tokens , hidden_size = allreduce_in .shape
178
+ element_size = allreduce_in .element_size ()
179
+ current_tensor_size = num_tokens * hidden_size * element_size
180
+ max_fusion_size = max_token_num * hidden_size * element_size
181
+ use_flashinfer = current_tensor_size <= min (
182
+ _FI_MAX_SIZES [world_size ], max_fusion_size )
183
+
182
184
if use_flashinfer :
183
185
assert (_FI_WORKSPACE_TENSOR is not None
184
186
), "Flashinfer must be enabled when using flashinfer"
You can’t perform that action at this time.
0 commit comments