Commit 49a3b8a

Fix fused_scaled_matmul_reduce_scatter signature for PyTorch update
Updated the torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter calls to match the new PyTorch API signature. The signature changed relative to PyTorch 2.7.1 and now requires additional positional parameters.

Changes:
- Added orig_scatter_dim and scatter_dim_after_maybe_reshape as positional parameters
- Added the output_shape calculation: [*input.shape[:-1], mat2.shape[1]]
- Changed all optional parameters (bias, result_scale, out_dtype, use_fast_accum) from keyword arguments to positional arguments to match PyTorch's torch._inductor implementation

References:
- PyTorch function definition: torch/distributed/_symmetric_memory/__init__.py:454-461
- PyTorch test usage: test/distributed/test_symmetric_memory.py:579-590
- PyTorch inductor usage: torch/_inductor/fx_passes/micro_pipeline_tp.py:816-834

Signed-off-by: jasonlizhengjian <[email protected]>
1 parent 5db1870 commit 49a3b8a
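For reference, a minimal sketch of the call-shape change described above. The helper fused_gemm_rs and its parameters are illustrative only (not part of the commit); the positional ordering simply mirrors the updated calls in the diff below and the PyTorch files referenced in the commit message.

    # Sketch (not from the commit): old keyword-style call vs. new positional call,
    # assuming the argument order used in the updated vLLM code below.
    import torch

    def fused_gemm_rs(input: torch.Tensor, mat2: torch.Tensor,
                      scale_a: torch.Tensor, scale_b: torch.Tensor,
                      group_name: str, out_dtype: torch.dtype) -> torch.Tensor:
        # Shape of input @ mat2 before the reduce-scatter splits scatter_dim.
        output_shape = [*input.shape[:-1], mat2.shape[1]]
        scatter_dim = 0

        # Old (PyTorch 2.7.1-era) call used keyword arguments:
        #   torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
        #       input, mat2, scale_a, scale_b, "avg",
        #       scatter_dim=0, out_dtype=out_dtype, group_name=group_name)
        #
        # New call passes everything positionally:
        return torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
            input,
            mat2,
            scale_a,
            scale_b,
            "avg",          # reduce op
            scatter_dim,    # orig_scatter_dim
            scatter_dim,    # scatter_dim_after_maybe_reshape
            group_name,
            output_shape,
            None,           # bias
            None,           # result_scale
            out_dtype,      # out_dtype
            False,          # use_fast_accum
        )

The previously keyword-only optionals (bias, result_scale, out_dtype, use_fast_accum) are now supplied positionally, matching the torch._inductor usage referenced above.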

File tree: 1 file changed (+22, -6 lines)

vllm/compilation/collective_fusion.py

Lines changed: 22 additions & 6 deletions
@@ -156,15 +156,23 @@ def pattern(input: torch.Tensor, mat2: torch.Tensor,
     def replacement(input: torch.Tensor, mat2: torch.Tensor,
                     scale_a: torch.Tensor,
                     scale_b: torch.Tensor) -> torch.Tensor:
+        # Calculate output shape: input @ mat2 with scatter_dim reduced
+        output_shape = [*input.shape[:-1], mat2.shape[1]]
+        scatter_dim = 0
         gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
             input,
             mat2,
             scale_a,
             scale_b,
             "avg",
-            scatter_dim=0,
-            out_dtype=self.dtype,
-            group_name=self.tp.device_group.group_name,
+            scatter_dim,  # orig_scatter_dim
+            scatter_dim,  # scatter_dim_after_maybe_reshape
+            self.tp.device_group.group_name,
+            output_shape,
+            None,  # bias
+            None,  # result_scale
+            self.dtype,  # out_dtype
+            False,  # use_fast_accum
         )
 
         return gemm_rs
@@ -268,15 +276,23 @@ def pattern(input: torch.Tensor, weight: torch.Tensor,
     def replacement(input: torch.Tensor, mat2: torch.Tensor,
                     scale_a: torch.Tensor, scale_b: torch.Tensor,
                     cutlass_mm_output: torch.Tensor) -> torch.Tensor:
+        # Calculate output shape: input @ mat2 with scatter_dim reduced
+        output_shape = [*input.shape[:-1], mat2.shape[1]]
+        scatter_dim = 0
         gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
             input,
             mat2,
             scale_a,
             scale_b,
             "avg",
-            scatter_dim=0,
-            out_dtype=self.dtype,
-            group_name=self.tp.device_group.group_name,
+            scatter_dim,  # orig_scatter_dim
+            scatter_dim,  # scatter_dim_after_maybe_reshape
+            self.tp.device_group.group_name,
+            output_shape,
+            None,  # bias
+            None,  # result_scale
+            self.dtype,  # out_dtype
+            False,  # use_fast_accum
         )
 
         return gemm_rs
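As a small worked example of the output_shape calculation added in both hunks (the shapes here are hypothetical, not from the commit):

    # Hypothetical shapes illustrating output_shape = [*input.shape[:-1], mat2.shape[1]].
    import torch

    input = torch.empty(128, 4096)   # [M, K]
    mat2 = torch.empty(4096, 1024)   # [K, N]

    output_shape = [*input.shape[:-1], mat2.shape[1]]
    assert output_shape == [128, 1024]  # [M, N]; the op then reduce-scatters along dim 0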
