Commit 5fc349e

Update per_sample_grads.py
1 parent 2d8bda9

1 file changed: +10 -16 lines changed

intermediate_source/per_sample_grads.py

Lines changed: 10 additions & 16 deletions
@@ -168,22 +168,16 @@ def compute_loss(params, buffers, sample, target):
 # we can double check that the results using ``grad`` and ``vmap`` match the
 # results of hand processing each one individually:
 
-for name, ft_per_sample_grad in ft_per_sample_grads.items():
-    # Find the corresponding manually computed gradient.
-    idx = list(model.named_parameters()).index((name, model.get_parameter(name)))
-    per_sample_grad = per_sample_grads[idx]
-
-    # Check if shapes match
-    if per_sample_grad.shape != ft_per_sample_grad.shape:
-        print(f"Shape mismatch for {name}: {per_sample_grad.shape} vs {ft_per_sample_grad.shape}")
-        # Reshape if needed (sometimes functional API returns different shape)
-        if per_sample_grad.numel() == ft_per_sample_grad.numel():
-            ft_per_sample_grad = ft_per_sample_grad.view(per_sample_grad.shape)
-
-    # Use a higher tolerance for comparison
-    assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=2e-2, rtol=2e-2), \
-        f"Mismatch in {name}: max diff {(per_sample_grad - ft_per_sample_grad).abs().max().item()}"
-
+# Get the parameter names in the same order as per_sample_grads
+param_names = list(params.keys())
+
+# Compare gradients for each parameter
+for i, name in enumerate(param_names):
+    per_sample_grad = per_sample_grads[i]
+    ft_per_sample_grad = ft_per_sample_grads[name]
+
+    assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=3e-3, rtol=1e-5), \
+        f"Gradients don't match for {name}: max diff = {(per_sample_grad - ft_per_sample_grad).abs().max()}"
 
 ######################################################################
 # A quick note: there are limitations around what types of functions can be
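
For reference, the simplified comparison relies on the setup from earlier in this tutorial: ``per_sample_grads`` holds the manually computed gradients (one stacked tensor per parameter, in ``named_parameters()`` order), while ``ft_per_sample_grads`` is a dict keyed by parameter name, as returned when ``grad`` differentiates with respect to a dict of parameters. A minimal sketch of that setup, assuming ``model``, ``loss_fn``, ``data``, and ``targets`` are defined as elsewhere in per_sample_grads.py:

import torch
from torch.func import functional_call, grad, vmap

# Capture the model's state as plain dicts so the model can be called
# functionally; ``params`` iterates in ``named_parameters()`` order.
params = {k: v.detach() for k, v in model.named_parameters()}
buffers = {k: v.detach() for k, v in model.named_buffers()}

def compute_loss(params, buffers, sample, target):
    # ``vmap`` hands us one sample at a time, so restore a batch dim of 1.
    batch = sample.unsqueeze(0)
    targets = target.unsqueeze(0)
    predictions = functional_call(model, (params, buffers), (batch,))
    return loss_fn(predictions, targets)

# ``grad`` differentiates with respect to the first argument (``params``),
# so each call yields a dict of gradients keyed by parameter name;
# ``vmap`` maps that over the leading dim of ``data`` and ``targets``.
ft_compute_sample_grad = vmap(grad(compute_loss), in_dims=(None, None, 0, 0))
ft_per_sample_grads = ft_compute_sample_grad(params, buffers, data, targets)

Because ``params.keys()`` and ``named_parameters()`` enumerate the same names in the same order, indexing ``per_sample_grads[i]`` while looking up ``ft_per_sample_grads[name]`` pairs each manual gradient with its functional counterpart, which is what makes the new loop valid without any shape juggling.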
