Fix fused_scaled_matmul_reduce_scatter signature for PyTorch update
Updated torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter calls to match
the new PyTorch API. The signature changed after PyTorch 2.7.1 and now
requires additional positional parameters.
Changes:
- Added orig_scatter_dim and scatter_dim_after_maybe_reshape as positional parameters
- Added output_shape calculation: [*input.shape[:-1], mat2.shape[1]]
- Changed all optional parameters (bias, result_scale, out_dtype, use_fast_accum)
from keyword arguments to positional arguments to match PyTorch's torch._inductor
implementation
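As a sketch of the updated call shape (a minimal illustration, not the actual change: the wrapper name, tensor arguments, and the assumption that both scatter-dim parameters receive the same value are hypothetical; the positional layout follows the PyTorch references below):

```python
def compute_output_shape(input_shape, mat2_shape):
    # New required argument: the full matmul output shape,
    # computed as [*input.shape[:-1], mat2.shape[1]]
    return [*input_shape[:-1], mat2_shape[1]]


def call_fused_scaled_matmul_reduce_scatter(
    input, mat2, scale_a, scale_b, reduce_op, scatter_dim, group_name
):
    # Illustrative wrapper assuming the post-2.7.1 positional layout.
    import torch

    output_shape = compute_output_shape(input.shape, mat2.shape)
    return torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
        input,
        mat2,
        scale_a,
        scale_b,
        reduce_op,
        scatter_dim,  # orig_scatter_dim
        scatter_dim,  # scatter_dim_after_maybe_reshape (same here, by assumption)
        group_name,
        output_shape,
        None,   # bias (now positional)
        None,   # result_scale (now positional)
        None,   # out_dtype (now positional)
        False,  # use_fast_accum (now positional)
    )
```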
References:
- PyTorch function definition: torch/distributed/_symmetric_memory/__init__.py:454-461
- PyTorch test usage: test/distributed/test_symmetric_memory.py:579-590
- PyTorch inductor usage: torch/_inductor/fx_passes/micro_pipeline_tp.py:816-834
Signed-off-by: jasonlizhengjian <[email protected]>