@@ -1201,67 +1201,32 @@ def test_forward_scriptability(self):
         torch.jit.script(ops.DeformConv2d(in_channels=8, out_channels=8, kernel_size=3))


-@pytest.mark.parametrize("dtype", (torch.float16, torch.float32, torch.float64))
-@pytest.mark.parametrize("device", cpu_and_cuda())
-@pytest.mark.parametrize("requires_grad", (True, False))
-def test_deform_conv2d_opcheck(dtype, device, requires_grad):
-    batch_size, channels_in, height, width = 1, 6, 10, 10
-    kernel_size = (3, 3)
-    stride = (1, 1)
-    padding = (1, 1)
-    dilation = (1, 1)
-    groups = 2
-    out_channels = 4
-    out_h = (height + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) // stride[0] + 1
-    out_w = (width + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) // stride[1] + 1
-    x = torch.randn(batch_size, channels_in, height, width, dtype=dtype, device=device, requires_grad=requires_grad)
-    offset = torch.randn(
-        batch_size,
-        2 * kernel_size[0] * kernel_size[1],
-        out_h,
-        out_w,
-        dtype=dtype,
-        device=device,
-        requires_grad=requires_grad,
-    )
-    weight = torch.randn(
-        out_channels,
-        channels_in // groups,
-        kernel_size[0],
-        kernel_size[1],
-        dtype=dtype,
-        device=device,
-        requires_grad=requires_grad,
-    )
-    bias = torch.randn(out_channels, dtype=dtype, device=device, requires_grad=requires_grad)
-    use_mask = True
-    mask = torch.sigmoid(
-        torch.randn(
-            batch_size,
-            kernel_size[0] * kernel_size[1],
-            out_h,
-            out_w,
-            dtype=dtype,
-            device=device,
-            requires_grad=requires_grad,
-        )
-    )
-    kwargs = {
-        "offset": offset,
-        "weight": weight,
-        "bias": bias,
-        "stride_h": stride[0],
-        "stride_w": stride[1],
-        "pad_h": padding[0],
-        "pad_w": padding[1],
-        "dilation_h": dilation[0],
-        "dilation_w": dilation[1],
-        "groups": groups,
-        "offset_groups": 1,
-        "use_mask": use_mask,
1262-         "mask" : mask ,  # no modulation in this test 
-    }
-    optests.opcheck(torch.ops.torchvision.deform_conv2d, args=(x,), kwargs=kwargs)
+# NS: Remove me once backward is implemented for MPS
+def xfail_if_mps(x):
+    mps_xfail_param = pytest.param("mps", marks=(pytest.mark.needs_mps, pytest.mark.xfail))
+    new_pytestmark = []
+    for mark in x.pytestmark:
+        if isinstance(mark, pytest.Mark) and mark.name == "parametrize":
+            if mark.args[0] == "device":
+                params = cpu_and_cuda() + (mps_xfail_param,)
+                new_pytestmark.append(pytest.mark.parametrize("device", params))
+                continue
+        new_pytestmark.append(mark)
+    x.__dict__["pytestmark"] = new_pytestmark
+    return x
+
+
+optests.generate_opcheck_tests(
+    testcase=TestDeformConv,
+    namespaces=["torchvision"],
+    failures_dict_path=os.path.join(os.path.dirname(__file__), "optests_failures_dict.json"),
+    # Skip tests due to unimplemented backward
+    additional_decorators={
+        "test_aot_dispatch_dynamic__test_forward": [xfail_if_mps],
+        "test_autograd_registration__test_forward": [xfail_if_mps],
+    },
+    test_utils=OPTESTS,
+)


 class TestFrozenBNT:
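
Note on the new machinery: `optests.generate_opcheck_tests` clones every existing test on `TestDeformConv` into one variant per entry in `OPTESTS`, named `{check}__{original_test}` (which is where the `test_aot_dispatch_dynamic__test_forward` keys above come from), and re-runs it with the named check applied to each call into the `torchvision` op namespace. `xfail_if_mps` then rewrites a generated test's `device` parametrization so the MPS case is expected to fail until the MPS backward is implemented. Below is a minimal, self-contained sketch of this pattern, assuming a recent PyTorch where the private `torch.testing._internal.optests` module exposes this API; `TestMyOps` and `my_failures.json` are hypothetical placeholders, not names from this diff.

```python
import os

import pytest
import torch
import torchvision  # noqa: F401  (importing registers the torch.ops.torchvision ops)
from torch.testing._internal import optests


class TestMyOps:
    @pytest.mark.parametrize("device", ("cpu",))
    def test_forward(self, device):
        # Any call to a torch.ops.torchvision.* op inside this test is
        # intercepted and re-validated by each generated check variant.
        boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0]], device=device)
        scores = torch.tensor([1.0], device=device)
        torch.ops.torchvision.nms(boxes, scores, 0.5)


optests.generate_opcheck_tests(
    testcase=TestMyOps,
    namespaces=["torchvision"],
    # Known failures are recorded in a JSON file rather than deleted from
    # the suite; a file containing just {} declares no expected failures.
    failures_dict_path=os.path.join(os.path.dirname(__file__), "my_failures.json"),
    # Same four checks as torchvision's OPTESTS: schema validation, autograd
    # registration, fake-tensor propagation, and dynamic-shape AOT dispatch.
    test_utils=[
        "test_schema",
        "test_autograd_registration",
        "test_faketensor",
        "test_aot_dispatch_dynamic",
    ],
)
```

Running pytest on this module would collect `test_forward` plus the four generated variants (e.g. `test_schema__test_forward`), which is what makes a manually written opcheck test like the removed `test_deform_conv2d_opcheck` redundant.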