Commit ede6186

H-Huang authored and pytorchmergebot committed
[PP] Allow intermediate nodes in ZB to have multiple grads (pytorch#159084)
Fixes a ZB regression (https://github.com/pytorch/torchtitan/actions/runs/16478292562/job/46585646792).

Previously we only allowed an intermediate node to have a single gradient. Recently a torchtitan ZB test started failing, and I tracked it back to the FusedRMSNorm grad_fn producing two values `(grad, None)` (see pytorch#153666), which broke our ZB tests. This PR allows `stage_backward_weight` intermediate nodes to have multiple grads: the non-None grads are summed together, and None grad values are ignored.

Here is an example where the backward would have two grad values (gI1, gI2):

```python
class Func(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x, 2

    @staticmethod
    def backward(ctx, gI1, gI2):
        assert gI2 is None
        return gI1
```

Pull Request resolved: pytorch#159084
Approved by: https://github.com/tianyu-l
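For illustration, below is a minimal, self-contained sketch of the summing behavior described above. This is not the code in `_backward.py`; the helper `sum_non_none_grads` and its inputs are made up for the example. Per intermediate, the non-None grads are summed, and an intermediate whose grads are all None is skipped entirely.

```python
import torch

# Hypothetical helper mirroring the described behavior: sum the non-None
# grads for each intermediate and drop intermediates with no real grads.
def sum_non_none_grads(grads_per_intermediate):
    kept_indices, grad_outputs = [], []
    for idx, grads_tuple in enumerate(grads_per_intermediate):
        non_none = [g for g in grads_tuple if g is not None]
        if non_none:  # skip intermediates whose grads are all None
            kept_indices.append(idx)
            grad_outputs.append(sum(non_none))
    return kept_indices, grad_outputs

# Mirrors the FusedRMSNorm case: one real grad plus a None.
g = torch.ones(4, 8)
kept, outs = sum_non_none_grads([(g, None), (None,)])
assert kept == [0]
assert torch.equal(outs[0], g)
```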
1 parent 6d071bd commit ede6186

File tree

2 files changed: +67 −23 lines


test/distributed/pipelining/test_backward.py

Lines changed: 38 additions & 0 deletions
```diff
@@ -183,6 +183,44 @@ def test_stage_backward_weight_multiple_iters(self, device):
                 print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}")
                 raise
 
+    def test_stage_backward_weight_grad_validation(self, device):
+        test_cases = [
+            (
+                "size >= 2",
+                lambda: [
+                    (
+                        torch.randn(batch_size, d_hid, device=device),
+                        torch.randn(batch_size, d_hid, device=device),
+                    )
+                ],
+            ),
+            ("size = 1", lambda: [(torch.randn(batch_size, d_hid, device=device),)]),
+            (
+                "1 grad, 1 None",
+                lambda: [(torch.randn(batch_size, d_hid, device=device), None)],
+            ),
+        ]
+
+        for description, mock_grads_factory in test_cases:
+            with self.subTest(description=description):
+                mod = MLPModule(d_hid).to(device)
+                x = torch.randn(batch_size, d_hid, device=device)
+                x.requires_grad_(True)
+                out = mod(x)
+                loss = torch.sum(out)
+                dinputs, param_groups = stage_backward_input(
+                    stage_outputs_or_loss=[loss],
+                    output_grads=None,
+                    input_values=[x],
+                    weights=mod.parameters(),
+                )
+
+                # Set up mock grads
+                for param_group in param_groups:
+                    param_group["grads"] = mock_grads_factory()
+
+                stage_backward_weight(mod.parameters(), param_groups)
+
 
 devices = ["cpu", "cuda", "hpu", "xpu"]
 instantiate_device_type_tests(StageBackwardTests, globals(), only_for=devices)
```

torch/distributed/pipelining/_backward.py

Lines changed: 29 additions & 23 deletions
```diff
@@ -235,11 +235,17 @@ def stage_backward_weight(
         weight_grads.append(weight.grad)
 
     for param_group in param_groups:
-        # TODO: Handle case where intermediate can have multiple outputs
-        intermediate_edges = tuple(
-            GradientEdge(i, 0) for i in param_group["intermediates"]
-        )
-        weights_edges = tuple(GradientEdge(w, 0) for w in param_group["params"])
+        valid_edges = []
+        valid_grad_outputs: list[torch.Tensor] = []
+
+        for grads_tuple, intermediate in zip(
+            param_group["grads"], param_group["intermediates"]
+        ):
+            non_none_grads = [g for g in grads_tuple if g is not None]
+            if non_none_grads:
+                summed_grad = sum(non_none_grads)
+                valid_edges.append(GradientEdge(intermediate, 0))
+                valid_grad_outputs.append(summed_grad)
 
         # Break a reference cycle caused inside stage_backward_input->get_hook->hook
         # The summarized cycle is:
@@ -248,25 +254,25 @@
         # We need to keep intermediates alive up until backward_weight, but we can free it now.
         del param_group["intermediates"]
 
-        assert all(len(g) == 1 for g in param_group["grads"])
-        # [NEW!] Able to pass a GradientEdge to autograd.grad as output
-        # We do not need to retain_graph because... guarantee no overlap?
-        # print("trying to execute: ", intermediate_edges, weights_edges)
-        dweights = torch.autograd.grad(
-            intermediate_edges,
-            weights_edges,
-            grad_outputs=sum(param_group["grads"], tuple()),
-            retain_graph=retain_graph,
-        )
-        # release grad memory early after use
-        del param_group["grads"]
+        if valid_edges:  # Only call autograd.grad if we have valid gradients
+            # [NEW!] Able to pass a GradientEdge to autograd.grad as output
+            weights_edges = tuple(GradientEdge(w, 0) for w in param_group["params"])
+            dweights = torch.autograd.grad(
+                valid_edges,
+                weights_edges,
+                grad_outputs=valid_grad_outputs,
+                retain_graph=retain_graph,
+            )
 
-        for grad_acc, dw in zip(param_group["params"], dweights):
-            weight, index = grad_acc_to_weight[grad_acc]
-            if weight.grad is None:
-                weight.grad = dw
-            else:
-                weight.grad += dw
+            # release grad memory early after use
+            del param_group["grads"]
+
+            for grad_acc, dw in zip(param_group["params"], dweights):
+                weight, index = grad_acc_to_weight[grad_acc]
+                if weight.grad is None:
+                    weight.grad = dw
+                else:
+                    weight.grad += dw
     # return grads in the original order weights were provided in
     return tuple(weight_grads)
```
