Commit 0bb46a4

Update per_sample_grads.py
float64 baseline comparison
1 parent 1e4f251 · commit 0bb46a4

File tree

1 file changed: +56 −1 lines changed

intermediate_source/per_sample_grads.py

Lines changed: 56 additions & 1 deletion
@@ -168,6 +168,36 @@ def compute_loss(params, buffers, sample, target):
 # we can double check that the results using ``grad`` and ``vmap`` match the
 # results of hand processing each one individually:
 
+# Create a float64 baseline for more precise comparison
+def compute_grad_fp64(sample, target):
+    # Convert the input to float64 for higher precision
+    sample_fp64 = sample.to(torch.float64)
+    target_fp64 = target  # integer class targets need no dtype conversion
+
+    # Create a float64 version of the model
+    model_fp64 = SimpleCNN().to(device=device).to(torch.float64)
+    # Copy parameters from the original model into the float64 model
+    with torch.no_grad():
+        for param_fp32, param_fp64 in zip(model.parameters(), model_fp64.parameters()):
+            param_fp64.copy_(param_fp32.to(torch.float64))
+
+    sample_fp64 = sample_fp64.unsqueeze(0)  # prepend batch dimension
+    target_fp64 = target_fp64.unsqueeze(0)
+
+    prediction = model_fp64(sample_fp64)
+    loss = loss_fn(prediction, target_fp64)
+
+    return torch.autograd.grad(loss, list(model_fp64.parameters()))
+
+
+def compute_fp64_baseline(data, targets, indices):
+    """Compute the float64 gradient for a specific sample."""
+    # Only compute for the sample with the largest difference to save computation
+    i = indices[0]  # sample index
+    sample_grad = compute_grad_fp64(data[i], targets[i])
+    return sample_grad
+
+
 for i, (per_sample_grad, ft_per_sample_grad) in enumerate(
     zip(per_sample_grads, ft_per_sample_grads.values())
 ):
@@ -181,7 +211,8 @@ def compute_loss(params, buffers, sample, target):
     print(f" Max absolute difference: {max_diff}")
     print(f" Mean absolute difference: {mean_diff}")
     print(f" Shape of tensors: {per_sample_grad.shape}")
-    # Print a sample of values from both tensors where the difference is largest
+
+    # Find the location of maximum difference
     max_idx = abs_diff.argmax().item()
     flat_idx = max_idx
     if len(abs_diff.shape) > 1:
@@ -197,6 +228,30 @@ def compute_loss(params, buffers, sample, target):
         f" Functional gradient value: {ft_per_sample_grad[tuple(indices)].item()}"
     )
 
+    # Compute float64 baseline for the sample with the largest difference
+    print("\nComputing float64 baseline for comparison...")
+    try:
+        fp64_grads = compute_fp64_baseline(data, targets, indices)
+        fp64_value = fp64_grads[i][
+            tuple(indices[1:])
+        ].item()  # skip the batch dimension
+        print(f" Float64 baseline value: {fp64_value}")
+
+        # Compare both methods against the float64 baseline
+        manual_diff = abs(per_sample_grad[tuple(indices)].item() - fp64_value)
+        functional_diff = abs(
+            ft_per_sample_grad[tuple(indices)].item() - fp64_value
+        )
+        print(f" Manual method vs float64 difference: {manual_diff}")
+        print(f" Functional method vs float64 difference: {functional_diff}")
+
+        if manual_diff < functional_diff:
+            print(" Manual method is closer to float64 baseline")
+        else:
+            print(" Functional method is closer to float64 baseline")
+    except Exception as e:
+        print(f" Error computing float64 baseline: {e}")
+
     # Keep the original assertion
     assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=3e-3, rtol=1e-5)
 
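Note on the pattern this commit adds: the baseline is obtained by re-running one sample's forward and backward pass in double precision and measuring how far each float32 result is from it. Below is a minimal, self-contained sketch of that idea; it uses a hypothetical nn.Linear toy model and randomly generated data in place of the tutorial's SimpleCNN, so names and shapes are illustrative only.

# Sketch only (not part of the commit): compare a float32 gradient against a
# float64 reference, using a hypothetical toy model in place of SimpleCNN.
import torch
import torch.nn as nn

torch.manual_seed(0)
loss_fn = nn.CrossEntropyLoss()

model_fp32 = nn.Linear(4, 3)            # stand-in for the tutorial's model
sample = torch.randn(1, 4)              # one sample with a batch dimension
target = torch.tensor([1])              # integer class label

# Gradient in ordinary float32
loss_fp32 = loss_fn(model_fp32(sample), target)
grads_fp32 = torch.autograd.grad(loss_fp32, list(model_fp32.parameters()))

# Float64 reference: a double-precision copy of the model and input
model_fp64 = nn.Linear(4, 3).to(torch.float64)
with torch.no_grad():
    for p32, p64 in zip(model_fp32.parameters(), model_fp64.parameters()):
        p64.copy_(p32.to(torch.float64))
loss_fp64 = loss_fn(model_fp64(sample.to(torch.float64)), target)
grads_fp64 = torch.autograd.grad(loss_fp64, list(model_fp64.parameters()))

# Distance of the float32 result from the double-precision reference
for g32, g64 in zip(grads_fp32, grads_fp64):
    print((g32.to(torch.float64) - g64).abs().max().item())

The commit applies the same comparison only at the single tensor position where the manual and vmap gradients disagree the most, which keeps the extra double-precision work negligible.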