
Commit 9abc9aa (parent 789240b)

garrett361 authored and pytorchmergebot committed
fix: use grad div factor when fsdp_degree=1 (pytorch#167178)
`fully_shard`'s `gradient_divide_factor` isn't currently respected when the sharding degree is 1. This PR ensures the division factor also applies in this case.

This is a bit of an edge case, but it arises in `torchtitan`, e.g. with expert parallelism and `ep_degree=world_size` we still wrap the routed experts in `fully_shard` because:

1) It lets us take advantage of its mixed-precision mechanisms.
2) [A specific gradient_divide_factor is needed for correctness](https://github.com/pytorch/torchtitan/blob/176498cd4edd4d80e95959a618279681f8295f4c/torchtitan/models/llama4/infra/parallelize.py?plain=1#L364-L369)

This PR ensures correctness in the `reduce_scatter_group.size()==1` case. Reproducer and sample failures are in the [gist here](https://gist.github.ibm.com/goon/f67e7559284cc2d322faff1ac59fe382). The net effect is that, in the case described above, the EP grads are too large by a factor of the world size. I checked that the proposed fix makes these tests pass. I guess I should add a test for this, too?

Pull Request resolved: pytorch#167178
Approved by: https://github.com/weifengpy
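For context, a minimal sketch of the wrapping pattern the description refers to: an FSDP group whose shard dimension has size 1 but which still needs a custom gradient divide factor. The mesh dim names and the stand-in module are illustrative (not taken from torchtitan), and the snippet assumes the process group is already initialized, e.g. via `torchrun`:

```python
# Hedged sketch of the edge case: shard dim of size 1, custom divide factor.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard

world_size = dist.get_world_size()
# 2D mesh with a size-1 shard dim, so reduce_scatter_group.size() == 1.
mesh = init_device_mesh("cuda", (world_size, 1), mesh_dim_names=("replicate", "shard"))

experts = torch.nn.Linear(16, 16).cuda()  # stand-in for the routed experts
fully_shard(experts, mesh=mesh)  # wrapped mainly for its mixed-precision machinery
# Before this commit, the factor below was silently dropped because no
# reduce-scatter actually runs on a size-1 shard group, leaving the expert
# grads too large by roughly this factor.
experts.set_gradient_divide_factor(float(world_size))
```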

2 files changed: +43, -16 lines

test/distributed/_composable/fsdp/test_fully_shard_comm.py

Lines changed: 31 additions & 9 deletions

```diff
@@ -428,26 +428,46 @@ def test_manual_reshard_with_reshard_after_forward_false(self):
     @xfailIf(TEST_XPU)  # https://github.com/intel/torch-xpu-ops/issues/1571
     def test_set_reduce_scatter_divide_factor(self):
         self.run_subtests(
-            {"divide_factor": [self.world_size * 2, self.world_size]},
+            {
+                "divide_factor": [self.world_size * 2, self.world_size],
+                "mesh_shape": [
+                    (self.world_size,),
+                    (self.world_size // 2, 2),
+                    (self.world_size, 1),
+                ],
+            },
             self._test_set_reduce_scatter_divide_factor,
         )
         self.run_subtests(
             {"divide_factor": [self.world_size]},
             self._test_set_reduce_scatter_divide_factor_mixed_prevision,
         )
 
-    def _test_set_reduce_scatter_divide_factor(self, divide_factor: float):
+    def _test_set_reduce_scatter_divide_factor(
+        self, divide_factor: float, mesh_shape: tuple[int] | tuple[int, int]
+    ):
         torch.manual_seed(42)
         model_args = ModelArgs(dropout_p=0.0, weight_tying=False)
         model = Transformer(model_args)
         ref_model = copy.deepcopy(model).to(device_type)
         ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2)
+        mesh_dim_names = ("outer",) if len(mesh_shape) == 1 else ("outer", "inner")
+        mesh = init_device_mesh(
+            device_type.type, mesh_shape, mesh_dim_names=mesh_dim_names
+        )
         for module in model.modules():
             if isinstance(module, TransformerBlock):
-                fully_shard(module, reshard_after_forward=False)
-        model = fully_shard(model, reshard_after_forward=False)
+                fully_shard(module, reshard_after_forward=False, mesh=mesh)
+        model = fully_shard(model, reshard_after_forward=False, mesh=mesh)
         optim = torch.optim.AdamW(model.parameters(), lr=1e-2)
-        model.set_reduce_scatter_divide_factor(divide_factor)
+        model.set_gradient_divide_factor(divide_factor)
+
+        # Get ref_model params which should have the specific division factor applied
+        block_params = set()
+        for ref_mod in ref_model.modules():
+            if isinstance(ref_mod, TransformerBlock):
+                block_params.update(ref_mod.parameters())
+        non_block_params = set(ref_model.parameters()) - block_params
 
         torch.manual_seed(42 + self.rank)
         inp = torch.randint(0, model_args.vocab_size, (2, 16), device=device_type.type)
@@ -456,16 +476,18 @@ def _test_set_reduce_scatter_divide_factor(self, divide_factor: float):
         ref_loss = ref_model(inp).sum()
         ref_loss.backward()
         for param in ref_model.parameters():
-            param.grad.mul_(1.0 / divide_factor)
+            factor = divide_factor if param in non_block_params else self.world_size
+            param.grad.mul_(1.0 / factor)
             dist.all_reduce(param.grad)
         loss = model(inp).sum()
         loss.backward()
         ref_optim.step()
         optim.step()
-        ref_optim.zero_grad()
-        optim.zero_grad()
         self.assertEqual(ref_loss, loss)
+        # Check parity before calling zero_grad so that grads are also checked
         check_sharded_parity(self, ref_model, model)
+        ref_optim.zero_grad()
+        optim.zero_grad()
 
     def _test_set_reduce_scatter_divide_factor_mixed_prevision(
         self, divide_factor: float
@@ -484,7 +506,7 @@ def _test_set_reduce_scatter_divide_factor_mixed_prevision(
             fully_shard(mlp, mp_policy=mp_policy)
         model = fully_shard(model, mp_policy=mp_policy)
         optim = torch.optim.AdamW(model.parameters(), lr=1e-2)
-        model.set_reduce_scatter_divide_factor(divide_factor)
+        model.set_gradient_divide_factor(divide_factor)
 
         torch.manual_seed(42 + self.rank)
         inp = torch.randn((4, 16), device=device_type.type, dtype=param_dtype)
```
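The interesting addition to the test is the `mesh_shape` parametrization. As a rough guide (my reading of `fully_shard`'s mesh convention, not something stated in the PR), assuming `world_size == 4`, the three shapes exercise the following group sizes:

```python
# Hedged sketch: a 1D mesh is plain FSDP (the whole mesh is the reduce-scatter
# group); a 2D mesh is HSDP with dim 0 = replicate (all-reduce) and
# dim 1 = shard (reduce-scatter).
world_size = 4
for mesh_shape in [(world_size,), (world_size // 2, 2), (world_size, 1)]:
    if len(mesh_shape) == 1:
        all_reduce_size, reduce_scatter_size = 1, mesh_shape[0]
    else:
        all_reduce_size, reduce_scatter_size = mesh_shape
    print(f"{mesh_shape}: reduce-scatter size {reduce_scatter_size}, "
          f"all-reduce size {all_reduce_size}")
# (4, 1) is the newly covered case: a size-1 reduce-scatter group that must
# still honor the custom gradient divide factor.
```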

torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py

Lines changed: 12 additions & 7 deletions

```diff
@@ -547,8 +547,12 @@ def foreach_reduce(
                 op=reduce_scatter_op,
             )
         else:
-            # For single GPU, just copy the input to output (no actual reduce-scatter needed)
-            reduce_output.copy_(reduce_scatter_input)
+            # For single GPU, just copy the input to output (no actual reduce-scatter needed), and
+            # account for a possible gradient_divide_factor.
+            if gradient_divide_factor is not None:
+                reduce_output.copy_(reduce_scatter_input / gradient_divide_factor)
+            else:
+                reduce_output.copy_(reduce_scatter_input)
         reduce_scatter_event = reduce_scatter_stream.record_event()
         post_reduce_stream = reduce_scatter_stream
         if all_reduce_group is not None:  # HSDP or DDP/replicate
```
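A standalone toy version of this branch (the helper below is illustrative, not the library code path) shows the intended numerics: with a size-1 reduce-scatter group there is no collective to fold the scaling into, so the divide factor has to be applied during the copy itself.

```python
from typing import Optional

import torch


def single_rank_reduce(
    reduce_scatter_input: torch.Tensor, gradient_divide_factor: Optional[float]
) -> torch.Tensor:
    # Mirrors the fixed branch above: divide during the copy when a factor is set.
    reduce_output = torch.empty_like(reduce_scatter_input)
    if gradient_divide_factor is not None:
        reduce_output.copy_(reduce_scatter_input / gradient_divide_factor)
    else:
        reduce_output.copy_(reduce_scatter_input)
    return reduce_output


grads = torch.full((4,), 8.0)
print(single_rank_reduce(grads, 4.0))   # tensor([2., 2., 2., 2.])
print(single_rank_reduce(grads, None))  # tensor([8., 8., 8., 8.])  (pre-fix result for any factor)
```

The second hunk reworks how `_get_gradient_divide_factors` picks the reduce ops when a custom factor is set: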
```diff
@@ -721,20 +725,21 @@ def _get_gradient_divide_factors(
     if all_reduce_group is not None:
         data_parallel_size *= all_reduce_group.size()
 
-    if factor is None:
-        factor = float(data_parallel_size)
-
     if not overflow_risk and not force_sum_reduction_for_comms:
-        if factor == data_parallel_size:
+        if factor is None:
             # Warning: NCCL ReduceOp.AVG may produce incorrect results with
             # world size 1.
             if data_parallel_size == 1:
                 return None, None, ReduceOp.SUM, ReduceOp.SUM
             return None, None, ReduceOp.AVG, ReduceOp.AVG
+        if reduce_scatter_group is not None and factor == reduce_scatter_group.size():
+            reduce_scatter_op = ReduceOp.AVG
         else:
             reduce_scatter_op = torch.distributed._make_nccl_premul_sum(1 / factor)
-            return None, None, reduce_scatter_op, ReduceOp.SUM
+        return None, None, reduce_scatter_op, ReduceOp.SUM
 
+    if factor is None:
+        factor = float(data_parallel_size)
     pre_factor: Optional[float]
     if overflow_risk:
         # Since fp16 has smaller dynamic range than fp32/bf16, we want to avoid
```
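Roughly, the new selection in the fast path (no overflow risk, sum reductions not forced) behaves like the sketch below; the helper and its string return values are illustrative only:

```python
from typing import Optional


def pick_reduce_scatter_op(
    factor: Optional[float], reduce_scatter_size: int, data_parallel_size: int
) -> str:
    if factor is None:
        # Default: average over the data-parallel ranks, except that NCCL
        # ReduceOp.AVG may misbehave at world size 1, so fall back to SUM.
        return "SUM" if data_parallel_size == 1 else "AVG"
    if factor == reduce_scatter_size:
        # A custom factor equal to the reduce-scatter group size is just AVG.
        return "AVG"
    # Otherwise pre-multiply by 1/factor inside the reduction (premul-sum).
    return f"premul_sum(1/{factor})"


print(pick_reduce_scatter_op(None, 4, 4))  # AVG
print(pick_reduce_scatter_op(2.0, 2, 4))   # AVG
print(pick_reduce_scatter_op(8.0, 4, 4))   # premul_sum(1/8.0)
```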
