@@ -168,7 +168,36 @@ def compute_loss(params, buffers, sample, target):
# we can double check that the results using ``grad`` and ``vmap`` match the
# results of hand processing each one individually:

-for per_sample_grad, ft_per_sample_grad in zip(per_sample_grads, ft_per_sample_grads.values()):
+for i, (per_sample_grad, ft_per_sample_grad) in enumerate(
+    zip(per_sample_grads, ft_per_sample_grads.values())
+):
+    is_close = torch.allclose(per_sample_grad, ft_per_sample_grad, atol=3e-3, rtol=1e-5)
+    if not is_close:
+        # Calculate and print the maximum absolute difference
+        abs_diff = (per_sample_grad - ft_per_sample_grad).abs()
+        max_diff = abs_diff.max().item()
+        mean_diff = abs_diff.mean().item()
+        print(f"Gradient {i} mismatch:")
+        print(f" Max absolute difference: {max_diff}")
+        print(f" Mean absolute difference: {mean_diff}")
+        print(f" Shape of tensors: {per_sample_grad.shape}")
+        # Print a sample of values from both tensors where the difference is largest
+        max_idx = abs_diff.argmax().item()
+        flat_idx = max_idx
+        if len(abs_diff.shape) > 1:
+            # Convert flat index to multi-dimensional index
+            indices = []
+            temp_shape = abs_diff.shape
+            for dim in reversed(temp_shape):
+                indices.insert(0, flat_idx % dim)
+                flat_idx //= dim
+            print(f" Max difference at index: {indices}")
+            print(f" Manual gradient value: {per_sample_grad[tuple(indices)].item()}")
+            print(
+                f" Functional gradient value: {ft_per_sample_grad[tuple(indices)].item()}"
+            )
+
+    # Keep the original assertion
    assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=3e-3, rtol=1e-5)

######################################################################
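If the tutorial can assume a PyTorch release that ships ``torch.unravel_index``, the manual flat-index-to-coordinates loop above could likely be collapsed into a single call; a minimal sketch, with made-up tensor values purely for illustration:

import torch

abs_diff = torch.tensor([[0.1, 0.2], [0.9, 0.3]])      # stand-in for the real difference tensor
flat_idx = abs_diff.argmax()                            # flat index into the flattened tensor
coords = torch.unravel_index(flat_idx, abs_diff.shape)  # tuple of 0-dim index tensors, one per dim
indices = [c.item() for c in coords]                    # [1, 0] for this example
print(f"Max difference at index: {indices}")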