Skip to content

Commit 2c12321

Browse files
authored
Update per_sample_grads.py
1 parent 0bb46a4 commit 2c12321

File tree

1 file changed

+4
-7
lines changed

1 file changed

+4
-7
lines changed

intermediate_source/per_sample_grads.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -174,12 +174,10 @@ def compute_grad_fp64(sample, target):
     sample_fp64 = sample.to(torch.float64)
     target_fp64 = target

-    # Create a float64 version of the model
-    model_fp64 = SimpleCNN().to(device=device)
-    # Copy parameters from original model to float64 model
-    with torch.no_grad():
-        for param_fp32, param_fp64 in zip(model.parameters(), model_fp64.parameters()):
-            param_fp64.copy_(param_fp32.to(torch.float64))
+    # Create a float64 version of the model and explicitly convert it to float64
+    model_fp64 = SimpleCNN().to(device=device).to(torch.float64)
+
+    # No need to manually copy parameters as the model is already in float64

    sample_fp64 = sample_fp64.unsqueeze(0)  # prepend batch dimension
    target_fp64 = target_fp64.unsqueeze(0)

(Note: body indentation reconstructed as four spaces from the hunk header's `def compute_grad_fp64(sample, target):`; the page extraction stripped the original leading whitespace.)
@@ -254,7 +252,6 @@ def compute_fp64_baseline(data, targets, indices):

 # Keep the original assertion
 assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=3e-3, rtol=1e-5)
-
 ######################################################################
 # A quick note: there are limitations around what types of functions can be
 # transformed by ``vmap``. The best functions to transform are ones that are pure

0 commit comments

Comments (0)