@@ -168,22 +168,16 @@ def compute_loss(params, buffers, sample, target):
# we can double check that the results using ``grad`` and ``vmap`` match the
# results of hand processing each one individually:
-for name, ft_per_sample_grad in ft_per_sample_grads.items():
-    # Find the corresponding manually computed gradient.
-    idx = list(model.named_parameters()).index((name, model.get_parameter(name)))
-    per_sample_grad = per_sample_grads[idx]
-
-    # Check if shapes match
-    if per_sample_grad.shape != ft_per_sample_grad.shape:
-        print(f"Shape mismatch for {name}: {per_sample_grad.shape} vs {ft_per_sample_grad.shape}")
-        # Reshape if needed (sometimes the functional API returns a different shape)
-        if per_sample_grad.numel() == ft_per_sample_grad.numel():
-            ft_per_sample_grad = ft_per_sample_grad.view(per_sample_grad.shape)
-
-    # Use a higher tolerance for comparison
-    assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=2e-2, rtol=2e-2), \
-        f"Mismatch in {name}: max diff {(per_sample_grad - ft_per_sample_grad).abs().max().item()}"
-
+# Get the parameter names in the same order as per_sample_grads
+param_names = list(params.keys())
+
+# Compare gradients for each parameter
+for i, name in enumerate(param_names):
+    per_sample_grad = per_sample_grads[i]
+    ft_per_sample_grad = ft_per_sample_grads[name]
+
+    assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=3e-3, rtol=1e-5), \
+        f"Gradients don't match for {name}: max diff = {(per_sample_grad - ft_per_sample_grad).abs().max()}"
######################################################################
# A quick note: there are limitations around what types of functions can be
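The hunk above shows only the comparison loop. For context, here is a minimal sketch (not part of the commit) of how the two gradient collections being compared are typically produced in this tutorial's setup; `model`, `loss_fn`, `data`, and `targets` are assumed to be defined earlier in the file, and `compute_loss` is the function named in the hunk header.

import torch
from torch.func import grad, vmap

# Hand-processed reference: one autograd call per sample, then stack per parameter.
# `per_sample_grads` ends up as a list ordered like model.parameters().
def compute_grad(sample, target):
    prediction = model(sample.unsqueeze(0))
    loss = loss_fn(prediction, target.unsqueeze(0))
    return torch.autograd.grad(loss, list(model.parameters()))

sample_grads = [compute_grad(data[i], targets[i]) for i in range(data.shape[0])]
per_sample_grads = [torch.stack(shards) for shards in zip(*sample_grads)]

# torch.func version: vmap a per-sample grad of `compute_loss` over the batch.
# `ft_per_sample_grads` is a dict keyed by parameter name.
params = {k: v.detach() for k, v in model.named_parameters()}
buffers = {k: v.detach() for k, v in model.named_buffers()}
ft_compute_sample_grad = vmap(grad(compute_loss), in_dims=(None, None, 0, 0))
ft_per_sample_grads = ft_compute_sample_grad(params, buffers, data, targets)

Because `params` is built from `model.named_parameters()`, iterating over `params.keys()` in the new code visits parameters in the same order as the manually stacked list, which is what makes the index-by-position comparison in the added loop valid.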