
Commit 89a10ce

Fix fused_scaled_matmul_reduce_scatter signature for PyTorch update
Updated the torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter calls to match the new PyTorch API signature. The signature changed after PyTorch 2.7.1 and now requires additional positional parameters.

Changes:
- Added orig_scatter_dim and scatter_dim_after_maybe_reshape as positional parameters
- Added output_shape calculation: [*input.shape[:-1], mat2.shape[1]]
- Changed all optional parameters (bias, result_scale, out_dtype, use_fast_accum) from keyword arguments to positional arguments to match PyTorch's torch._inductor implementation

References:
- PyTorch function definition: torch/distributed/_symmetric_memory/__init__.py:454-461
- PyTorch test usage: test/distributed/test_symmetric_memory.py:579-590
- PyTorch inductor usage: torch/_inductor/fx_passes/micro_pipeline_tp.py:816-834

Signed-off-by: jasonlizhengjian <[email protected]>
1 parent 1c0c682 commit 89a10ce
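
For context, here is a minimal sketch of the call order this commit adopts. The wrapper name and tensor arguments are placeholders for illustration; the positional order mirrors the diff below, and actually running the op requires an initialized symmetric-memory process group.

    import torch

    def scaled_mm_reduce_scatter(input, mat2, scale_a, scale_b, group_name, out_dtype):
        # Shape of input @ mat2: keep input's leading dims, take mat2's last dim.
        output_shape = [*input.shape[:-1], mat2.shape[1]]
        scatter_dim = 0
        return torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
            input,
            mat2,
            scale_a,
            scale_b,
            "avg",          # reduce op
            scatter_dim,    # orig_scatter_dim
            scatter_dim,    # scatter_dim_after_maybe_reshape
            group_name,
            output_shape,
            None,           # bias
            None,           # result_scale
            out_dtype,      # out_dtype
            False,          # use_fast_accum
        )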

File tree

1 file changed (+22, −6 lines)


vllm/compilation/collective_fusion.py

Lines changed: 22 additions & 6 deletions
@@ -169,15 +169,23 @@ def replacement(
             scale_a: torch.Tensor,
             scale_b: torch.Tensor,
         ) -> torch.Tensor:
+            # Calculate output shape: input @ mat2 with scatter_dim reduced
+            output_shape = [*input.shape[:-1], mat2.shape[1]]
+            scatter_dim = 0
             gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
                 input,
                 mat2,
                 scale_a,
                 scale_b,
                 "avg",
-                scatter_dim=0,
-                out_dtype=self.dtype,
-                group_name=self.tp.device_group.group_name,
+                scatter_dim,  # orig_scatter_dim
+                scatter_dim,  # scatter_dim_after_maybe_reshape
+                self.tp.device_group.group_name,
+                output_shape,
+                None,  # bias
+                None,  # result_scale
+                self.dtype,  # out_dtype
+                False,  # use_fast_accum
             )

             return gemm_rs
@@ -296,15 +304,23 @@ def replacement(
             scale_b: torch.Tensor,
             cutlass_mm_output: torch.Tensor,
         ) -> torch.Tensor:
+            # Calculate output shape: input @ mat2 with scatter_dim reduced
+            output_shape = [*input.shape[:-1], mat2.shape[1]]
+            scatter_dim = 0
             gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
                 input,
                 mat2,
                 scale_a,
                 scale_b,
                 "avg",
-                scatter_dim=0,
-                out_dtype=self.dtype,
-                group_name=self.tp.device_group.group_name,
+                scatter_dim,  # orig_scatter_dim
+                scatter_dim,  # scatter_dim_after_maybe_reshape
+                self.tp.device_group.group_name,
+                output_shape,
+                None,  # bias
+                None,  # result_scale
+                self.dtype,  # out_dtype
+                False,  # use_fast_accum
             )

             return gemm_rs
