@@ -168,6 +168,36 @@ def compute_loss(params, buffers, sample, target):
 # we can double check that the results using ``grad`` and ``vmap`` match the
 # results of hand processing each one individually:
 
+# Create a float64 baseline for more precise comparison
+def compute_grad_fp64(sample, target):
+    # Convert the sample to float64 for higher precision
+    sample_fp64 = sample.to(torch.float64)
+    target_fp64 = target
+
+    # Create a float64 version of the model
+    model_fp64 = SimpleCNN().to(device=device, dtype=torch.float64)
+    # Copy parameters from the original model to the float64 model
+    with torch.no_grad():
+        for param_fp32, param_fp64 in zip(model.parameters(), model_fp64.parameters()):
+            param_fp64.copy_(param_fp32.to(torch.float64))
+
+    sample_fp64 = sample_fp64.unsqueeze(0)  # prepend batch dimension
+    target_fp64 = target_fp64.unsqueeze(0)
+
+    prediction = model_fp64(sample_fp64)
+    loss = loss_fn(prediction, target_fp64)
+
+    return torch.autograd.grad(loss, list(model_fp64.parameters()))
+
+
+def compute_fp64_baseline(data, targets, indices):
+    """Compute the float64 gradient for a specific sample."""
+    # Only compute for the sample with the largest difference to save computation
+    i = indices[0]  # sample index within the batch
+    sample_grad = compute_grad_fp64(data[i], targets[i])
+    return sample_grad
+
+
 for i, (per_sample_grad, ft_per_sample_grad) in enumerate(
     zip(per_sample_grads, ft_per_sample_grads.values())
 ):
@@ -181,7 +211,8 @@ def compute_loss(params, buffers, sample, target):
         print(f" Max absolute difference: {max_diff}")
         print(f" Mean absolute difference: {mean_diff}")
         print(f" Shape of tensors: {per_sample_grad.shape}")
-        # Print a sample of values from both tensors where the difference is largest
+
+        # Find the location of maximum difference
         max_idx = abs_diff.argmax().item()
         flat_idx = max_idx
         if len(abs_diff.shape) > 1:
@@ -197,6 +228,30 @@ def compute_loss(params, buffers, sample, target):
             f" Functional gradient value: {ft_per_sample_grad[tuple(indices)].item()}"
         )
 
+        # Compute float64 baseline for the sample with the largest difference
+        print("\n Computing float64 baseline for comparison...")
+        try:
+            fp64_grads = compute_fp64_baseline(data, targets, indices)
+            fp64_value = fp64_grads[i][
+                tuple(indices[1:])
+            ].item()  # Skip batch dimension
+            print(f" Float64 baseline value: {fp64_value}")
+
+            # Compare both methods against float64 baseline
+            manual_diff = abs(per_sample_grad[tuple(indices)].item() - fp64_value)
+            functional_diff = abs(
+                ft_per_sample_grad[tuple(indices)].item() - fp64_value
+            )
+            print(f" Manual method vs float64 difference: {manual_diff}")
+            print(f" Functional method vs float64 difference: {functional_diff}")
+
+            if manual_diff < functional_diff:
+                print(" Manual method is closer to float64 baseline")
+            else:
+                print(" Functional method is closer to float64 baseline")
+        except Exception as e:
+            print(f" Error computing float64 baseline: {e}")
+
     # Keep the original assertion
     assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=3e-3, rtol=1e-5)
 