Skip to content

Commit 1e4f251

Browse files
authored
Update per_sample_grads.py
Printing differences on assertion fail
1 parent 785e38c commit 1e4f251

File tree

1 file changed

+30
-1
lines changed

1 file changed

+30
-1
lines changed

intermediate_source/per_sample_grads.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,36 @@ def compute_loss(params, buffers, sample, target):
168168
# we can double check that the results using ``grad`` and ``vmap`` match the
169169
# results of hand processing each one individually:
170170

171-
for per_sample_grad, ft_per_sample_grad in zip(per_sample_grads, ft_per_sample_grads.values()):
171+
for i, (per_sample_grad, ft_per_sample_grad) in enumerate(
172+
zip(per_sample_grads, ft_per_sample_grads.values())
173+
):
174+
is_close = torch.allclose(per_sample_grad, ft_per_sample_grad, atol=3e-3, rtol=1e-5)
175+
if not is_close:
176+
# Calculate and print the maximum absolute difference
177+
abs_diff = (per_sample_grad - ft_per_sample_grad).abs()
178+
max_diff = abs_diff.max().item()
179+
mean_diff = abs_diff.mean().item()
180+
print(f"Gradient {i} mismatch:")
181+
print(f" Max absolute difference: {max_diff}")
182+
print(f" Mean absolute difference: {mean_diff}")
183+
print(f" Shape of tensors: {per_sample_grad.shape}")
184+
# Print a sample of values from both tensors where the difference is largest
185+
max_idx = abs_diff.argmax().item()
186+
flat_idx = max_idx
187+
if len(abs_diff.shape) > 1:
188+
# Convert flat index to multi-dimensional index
189+
indices = []
190+
temp_shape = abs_diff.shape
191+
for dim in reversed(temp_shape):
192+
indices.insert(0, flat_idx % dim)
193+
flat_idx //= dim
194+
print(f" Max difference at index: {indices}")
195+
print(f" Manual gradient value: {per_sample_grad[tuple(indices)].item()}")
196+
print(
197+
f" Functional gradient value: {ft_per_sample_grad[tuple(indices)].item()}"
198+
)
199+
200+
# Keep the original assertion
172201
assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=3e-3, rtol=1e-5)
173202

174203
######################################################################

0 commit comments

Comments
 (0)