@@ -1215,19 +1215,37 @@ def test_deform_conv2d_opcheck(dtype, device, requires_grad):
12151215 out_h = (height + 2 * padding [0 ] - dilation [0 ] * (kernel_size [0 ] - 1 ) - 1 ) // stride [0 ] + 1
12161216 out_w = (width + 2 * padding [1 ] - dilation [1 ] * (kernel_size [1 ] - 1 ) - 1 ) // stride [1 ] + 1
12171217 x = torch .randn (batch_size , channels_in , height , width , dtype = dtype , device = device , requires_grad = requires_grad )
1218- offset = torch .randn (batch_size , 2 * kernel_size [0 ] * kernel_size [1 ], out_h , out_w ,
1219- dtype = dtype , device = device , requires_grad = requires_grad )
1220- weight = torch .randn (out_channels , channels_in // groups , kernel_size [0 ], kernel_size [1 ],
1221- dtype = dtype , device = device , requires_grad = requires_grad )
1222- bias = torch .randn (out_channels , dtype = dtype , device = device , requires_grad = requires_grad )
1223- use_mask = True
1224- mask = torch .sigmoid (torch .randn (
1218+ offset = torch .randn (
12251219 batch_size ,
1226- kernel_size [0 ] * kernel_size [1 ],
1220+ 2 * kernel_size [0 ] * kernel_size [1 ],
12271221 out_h ,
12281222 out_w ,
1229- dtype = dtype , device = device , requires_grad = requires_grad
1230- ))
1223+ dtype = dtype ,
1224+ device = device ,
1225+ requires_grad = requires_grad ,
1226+ )
1227+ weight = torch .randn (
1228+ out_channels ,
1229+ channels_in // groups ,
1230+ kernel_size [0 ],
1231+ kernel_size [1 ],
1232+ dtype = dtype ,
1233+ device = device ,
1234+ requires_grad = requires_grad ,
1235+ )
1236+ bias = torch .randn (out_channels , dtype = dtype , device = device , requires_grad = requires_grad )
1237+ use_mask = True
1238+ mask = torch .sigmoid (
1239+ torch .randn (
1240+ batch_size ,
1241+ kernel_size [0 ] * kernel_size [1 ],
1242+ out_h ,
1243+ out_w ,
1244+ dtype = dtype ,
1245+ device = device ,
1246+ requires_grad = requires_grad ,
1247+ )
1248+ )
12311249 kwargs = {
12321250 "offset" : offset ,
12331251 "weight" : weight ,
@@ -1246,7 +1264,6 @@ def test_deform_conv2d_opcheck(dtype, device, requires_grad):
12461264 optests .opcheck (torch .ops .torchvision .deform_conv2d , args = (x ,), kwargs = kwargs )
12471265
12481266
1249-
12501267class TestFrozenBNT :
12511268 def test_frozenbatchnorm2d_repr (self ):
12521269 num_features = 32
0 commit comments