@@ -1201,67 +1201,13 @@ def test_forward_scriptability(self):
12011201 torch .jit .script (ops .DeformConv2d (in_channels = 8 , out_channels = 8 , kernel_size = 3 ))
12021202
12031203
1204- @pytest .mark .parametrize ("dtype" , (torch .float16 , torch .float32 , torch .float64 ))
1205- @pytest .mark .parametrize ("device" , cpu_and_cuda ())
1206- @pytest .mark .parametrize ("requires_grad" , (True , False ))
1207- def test_deform_conv2d_opcheck (dtype , device , requires_grad ):
1208- batch_size , channels_in , height , width = 1 , 6 , 10 , 10
1209- kernel_size = (3 , 3 )
1210- stride = (1 , 1 )
1211- padding = (1 , 1 )
1212- dilation = (1 , 1 )
1213- groups = 2
1214- out_channels = 4
1215- out_h = (height + 2 * padding [0 ] - dilation [0 ] * (kernel_size [0 ] - 1 ) - 1 ) // stride [0 ] + 1
1216- out_w = (width + 2 * padding [1 ] - dilation [1 ] * (kernel_size [1 ] - 1 ) - 1 ) // stride [1 ] + 1
1217- x = torch .randn (batch_size , channels_in , height , width , dtype = dtype , device = device , requires_grad = requires_grad )
1218- offset = torch .randn (
1219- batch_size ,
1220- 2 * kernel_size [0 ] * kernel_size [1 ],
1221- out_h ,
1222- out_w ,
1223- dtype = dtype ,
1224- device = device ,
1225- requires_grad = requires_grad ,
1226- )
1227- weight = torch .randn (
1228- out_channels ,
1229- channels_in // groups ,
1230- kernel_size [0 ],
1231- kernel_size [1 ],
1232- dtype = dtype ,
1233- device = device ,
1234- requires_grad = requires_grad ,
1235- )
1236- bias = torch .randn (out_channels , dtype = dtype , device = device , requires_grad = requires_grad )
1237- use_mask = True
1238- mask = torch .sigmoid (
1239- torch .randn (
1240- batch_size ,
1241- kernel_size [0 ] * kernel_size [1 ],
1242- out_h ,
1243- out_w ,
1244- dtype = dtype ,
1245- device = device ,
1246- requires_grad = requires_grad ,
1247- )
1248- )
1249- kwargs = {
1250- "offset" : offset ,
1251- "weight" : weight ,
1252- "bias" : bias ,
1253- "stride_h" : stride [0 ],
1254- "stride_w" : stride [1 ],
1255- "pad_h" : padding [0 ],
1256- "pad_w" : padding [1 ],
1257- "dilation_h" : dilation [0 ],
1258- "dilation_w" : dilation [1 ],
1259- "groups" : groups ,
1260- "offset_groups" : 1 ,
1261- "use_mask" : use_mask ,
1262- "mask" : mask , # no modulation in this test
1263- }
1264- optests .opcheck (torch .ops .torchvision .deform_conv2d , args = (x ,), kwargs = kwargs )
1204+ optests .generate_opcheck_tests (
1205+ testcase = TestDeformConv ,
1206+ namespaces = ["torchvision" ],
1207+ failures_dict_path = os .path .join (os .path .dirname (__file__ ), "optests_failures_dict.json" ),
1208+ additional_decorators = [],
1209+ test_utils = OPTESTS ,
1210+ )
12651211
12661212
12671213class TestFrozenBNT :
0 commit comments