@@ -168,23 +168,8 @@ def compute_loss(params, buffers, sample, target):
# we can double check that the results using ``grad`` and ``vmap`` match the
# results of hand processing each one individually:
-# Get the parameter names in the same order as per_sample_grads
-
-for name, ft_per_sample_grad in ft_per_sample_grads.items():
-    # Find the corresponding manually computed gradient
-    idx = list(model.named_parameters()).index((name, model.get_parameter(name)))
-    per_sample_grad = per_sample_grads[idx]
-
-    # Check if shapes match and reshape if needed
-    if per_sample_grad.shape != ft_per_sample_grad.shape and per_sample_grad.numel() == ft_per_sample_grad.numel():
-        ft_per_sample_grad = ft_per_sample_grad.view(per_sample_grad.shape)
-
-    # Print differences instead of asserting
-    max_diff = (per_sample_grad - ft_per_sample_grad).abs().max().item()
-    print(f"Parameter {name}: max difference = {max_diff}")
-
-    # Optional: still assert for very large differences that might indicate real problems
-    assert max_diff < 0.5, f"Extremely large difference in {name}: {max_diff}"
+for per_sample_grad, ft_per_sample_grad in zip(per_sample_grads, ft_per_sample_grads.values()):
+    assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=3e-3, rtol=1e-5)
######################################################################
# A quick note: there are limitations around what types of functions can be
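For context, here is a minimal, self-contained sketch of the pattern this hunk verifies: compute per-sample gradients once by hand (one backward pass per sample) and once with ``vmap(grad(...))``, then compare them. The toy ``nn.Linear`` model, random data, loss choice, and the ``compute_sample_grads`` helper below are illustrative assumptions, not the tutorial's exact code; only ``compute_loss``, ``per_sample_grads``, and ``ft_per_sample_grads`` are names taken from the patch:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call, grad, vmap

torch.manual_seed(0)
model = nn.Linear(4, 3)              # toy stand-in for the tutorial's model
data = torch.randn(8, 4)             # batch of 8 samples
targets = torch.randint(3, (8,))

params = {k: v.detach() for k, v in model.named_parameters()}
buffers = {k: v.detach() for k, v in model.named_buffers()}

def compute_loss(params, buffers, sample, target):
    # Treat each sample as a batch of one so the model sees a batch dim.
    prediction = functional_call(model, (params, buffers), (sample.unsqueeze(0),))
    return F.cross_entropy(prediction, target.unsqueeze(0))

# Reference: hand-process each sample with one backward pass apiece.
def compute_sample_grads(data, targets):
    sample_grads = [
        torch.autograd.grad(
            F.cross_entropy(model(data[i].unsqueeze(0)), targets[i].unsqueeze(0)),
            list(model.parameters()),
        )
        for i in range(data.shape[0])
    ]
    # Regroup: one stacked tensor per parameter, leading dim = batch size.
    return [torch.stack(shards) for shards in zip(*sample_grads)]

per_sample_grads = compute_sample_grads(data, targets)

# Vectorized: grad w.r.t. params (arg 0), vmapped over the sample/target dims.
ft_compute_sample_grad = vmap(grad(compute_loss), in_dims=(None, None, 0, 0))
ft_per_sample_grads = ft_compute_sample_grad(params, buffers, data, targets)

# The comparison from the patch; dict values come back in parameter order.
for per_sample_grad, ft_per_sample_grad in zip(per_sample_grads, ft_per_sample_grads.values()):
    assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=3e-3, rtol=1e-5)

The relatively loose ``atol=3e-3`` leaves room for the small floating-point differences that the vectorized computation can introduce by reordering reductions.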