Skip to content

Commit 9da1a9a

Browse files
Samantha Andowzou3519
authored andcommitted
[functorch] make jvpvjp actually run all the tests for decompositions (pytorch/functorch#792)
1 parent ff2d64a commit 9da1a9a

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

functorch/functorch/_src/eager_transforms.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1302,6 +1302,7 @@ def get_function_def(sig):
13021302

13031303
_register_jit_decomposition(torch.ops.aten.trace.default)
13041304
_register_jit_decomposition(torch.ops.aten.nll_loss_backward.default)
1305+
_register_jit_decomposition(torch.ops.aten.nll_loss2d_backward.default)
13051306
_register_jit_decomposition(torch.ops.aten.mse_loss_backward.default)
13061307
_register_jit_decomposition(torch.ops.aten.l1_loss_backward.default)
13071308
_register_jit_decomposition(torch.ops.aten._log_softmax_backward_data.default)

functorch/test/test_ops.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1359,10 +1359,9 @@ def get_vjp(primals, cotangents):
13591359
if op.name == 'nn.functional.binary_cross_entropy': # reverse second derivative wrt target not defined
13601360
in_dims = 1
13611361
compare_jacobians(primals, cotangents, in_dims)
1362-
return
1363-
1364-
expected = reference(primals, cotangents, primals_tangents, cotangents_tangents)
1365-
self.assertEqual(result, expected)
1362+
else:
1363+
expected = reference(primals, cotangents, primals_tangents, cotangents_tangents)
1364+
self.assertEqual(result, expected)
13661365

13671366
@ops(filter(lambda op: op.name == "nn.functional.group_norm", functorch_lagging_op_db + additional_op_db),
13681367
allowed_dtypes=(torch.float32, torch.double)) # TODO: generalize

0 commit comments

Comments
 (0)