@@ -1200,12 +1200,28 @@ def test_forward_scriptability(self):
         # Non-regression test for https://github.com/pytorch/vision/issues/4078
         torch.jit.script(ops.DeformConv2d(in_channels=8, out_channels=8, kernel_size=3))
 
+# NS: Remove me once backward is implemented
+def xfail_if_mps(x):
+    mps_xfail_param = pytest.param("mps", marks=(pytest.mark.needs_mps, pytest.mark.xfail))
+    new_pytestmark = []
+    for mark in x.pytestmark:
+        if isinstance(mark, pytest.Mark) and mark.name == "parametrize":
+            if mark.args[0] == "device":
+                params = cpu_and_cuda() + (mps_xfail_param,)
+                new_pytestmark.append(pytest.mark.parametrize("device", params))
+                continue
+        new_pytestmark.append(mark)
+    x.__dict__["pytestmark"] = new_pytestmark
+    return x
+
 
 optests.generate_opcheck_tests(
     testcase=TestDeformConv,
     namespaces=["torchvision"],
     failures_dict_path=os.path.join(os.path.dirname(__file__), "optests_failures_dict.json"),
-    additional_decorators=[],
+    # Skip tests due to unimplemented backward
+    additional_decorators={"test_aot_dispatch_dynamic__test_forward": [xfail_if_mps],
+                           "test_autograd_registration__test_forward": [xfail_if_mps]},
     test_utils=OPTESTS,
 )
 
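For reference, a minimal sketch (not part of the commit) of how a helper like xfail_if_mps is applied: it rewrites a parametrized test's pytestmark so the device parametrization gains an "mps" entry marked as an expected failure. The test name and the cpu_and_cuda stand-in below are hypothetical; in the real suite cpu_and_cuda() and the needs_mps mark come from torchvision's existing test utilities.

import pytest

# Hypothetical stand-in for torchvision's cpu_and_cuda() helper, included only
# so this sketch is self-contained; needs_cuda/needs_mps are custom marks
# registered in the real test suite.
def cpu_and_cuda():
    return (pytest.param("cpu"), pytest.param("cuda", marks=pytest.mark.needs_cuda))

@pytest.mark.parametrize("device", cpu_and_cuda())
def test_some_forward(device):  # hypothetical test name, for illustration only
    ...

# Applying the helper rebuilds test_some_forward.pytestmark: the "device"
# parametrize mark is replaced by one whose params are cpu_and_cuda() plus an
# "mps" param carrying needs_mps and xfail, so the MPS case still runs but its
# failure (the unimplemented backward) is reported as expected rather than fatal.
test_some_forward = xfail_if_mps(test_some_forward)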