
Commit bff32bd

Update per_sample_grads.py
1 parent d67bcb8 commit bff32bd

File tree

1 file changed: +2 -17 lines changed


intermediate_source/per_sample_grads.py

Lines changed: 2 additions & 17 deletions
@@ -168,23 +168,8 @@ def compute_loss(params, buffers, sample, target):
 # we can double check that the results using ``grad`` and ``vmap`` match the
 # results of hand processing each one individually:
 
-# Get the parameter names in the same order as per_sample_grads
-
-for name, ft_per_sample_grad in ft_per_sample_grads.items():
-    # Find the corresponding manually computed gradient
-    idx = list(model.named_parameters()).index((name, model.get_parameter(name)))
-    per_sample_grad = per_sample_grads[idx]
-
-    # Check if shapes match and reshape if needed
-    if per_sample_grad.shape != ft_per_sample_grad.shape and per_sample_grad.numel() == ft_per_sample_grad.numel():
-        ft_per_sample_grad = ft_per_sample_grad.view(per_sample_grad.shape)
-
-    # Print differences instead of asserting
-    max_diff = (per_sample_grad - ft_per_sample_grad).abs().max().item()
-    print(f"Parameter {name}: max difference = {max_diff}")
-
-    # Optional: still assert for very large differences that might indicate real problems
-    assert max_diff < 0.5, f"Extremely large difference in {name}: {max_diff}"
+for per_sample_grad, ft_per_sample_grad in zip(per_sample_grads, ft_per_sample_grads.values()):
+    assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=3e-3, rtol=1e-5)
 
 ######################################################################
 # A quick note: there are limitations around what types of functions can be
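For context, the restored assertion compares per-sample gradients computed two ways: once per sample with plain autograd (per_sample_grads, a list ordered like model.parameters()) and once with torch.func's grad + vmap (ft_per_sample_grads, a dict keyed by parameter name in the same order). Below is a minimal, self-contained sketch of that comparison; it is not the tutorial file itself, and the tiny model, random data, and helper names are illustrative stand-ins for the tutorial's setup.

import torch
import torch.nn as nn
from torch.func import functional_call, grad, vmap

# Illustrative stand-ins for the tutorial's model and data (assumed, not from the diff).
torch.manual_seed(0)
model = nn.Linear(10, 2)
loss_fn = nn.CrossEntropyLoss()
data = torch.randn(8, 10)
targets = torch.randint(0, 2, (8,))

params = {k: v.detach() for k, v in model.named_parameters()}
buffers = {k: v.detach() for k, v in model.named_buffers()}

def compute_loss(params, buffers, sample, target):
    # Treat one sample as a batch of size 1 so the module's shapes line up.
    batch = sample.unsqueeze(0)
    predictions = functional_call(model, {**params, **buffers}, (batch,))
    return loss_fn(predictions, target.unsqueeze(0))

# Per-sample gradients via grad + vmap: a dict of tensors keyed by parameter name,
# each with a leading batch dimension.
ft_compute_sample_grad = vmap(grad(compute_loss), in_dims=(None, None, 0, 0))
ft_per_sample_grads = ft_compute_sample_grad(params, buffers, data, targets)

# The same gradients computed by hand, one sample at a time with torch.autograd.grad,
# then stacked per parameter (a list ordered like model.parameters()).
sample_grads = []
for i in range(data.shape[0]):
    prediction = model(data[i].unsqueeze(0))
    loss = loss_fn(prediction, targets[i].unsqueeze(0))
    sample_grads.append(torch.autograd.grad(loss, list(model.parameters())))
per_sample_grads = [torch.stack(per_param) for per_param in zip(*sample_grads)]

# The check restored by this commit: both approaches should agree to within tolerance.
for per_sample_grad, ft_per_sample_grad in zip(per_sample_grads, ft_per_sample_grads.values()):
    assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=3e-3, rtol=1e-5)

Both sides iterate parameters in the same order (the insertion order of named_parameters()), which is what makes the positional zip in the simplified check valid.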
