diff --git a/test/test_ops.py b/test/test_ops.py index 26d13bbe208..9cb0cddedf7 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1201,11 +1201,30 @@ def test_forward_scriptability(self): torch.jit.script(ops.DeformConv2d(in_channels=8, out_channels=8, kernel_size=3)) +# NS: Remove me once backward is implemented for MPS +def xfail_if_mps(x): + mps_xfail_param = pytest.param("mps", marks=(pytest.mark.needs_mps, pytest.mark.xfail)) + new_pytestmark = [] + for mark in x.pytestmark: + if isinstance(mark, pytest.Mark) and mark.name == "parametrize": + if mark.args[0] == "device": + params = cpu_and_cuda() + (mps_xfail_param,) + new_pytestmark.append(pytest.mark.parametrize("device", params)) + continue + new_pytestmark.append(mark) + x.__dict__["pytestmark"] = new_pytestmark + return x + + optests.generate_opcheck_tests( testcase=TestDeformConv, namespaces=["torchvision"], failures_dict_path=os.path.join(os.path.dirname(__file__), "optests_failures_dict.json"), - additional_decorators=[], + # Skip tests due to unimplemented backward + additional_decorators={ + "test_aot_dispatch_dynamic__test_forward": [xfail_if_mps], + "test_autograd_registration__test_forward": [xfail_if_mps], + }, test_utils=OPTESTS, )