Skip to content

Commit 831c0df

Browse files
authored
[DO NOT MERGE/n00b] Add empty batch support for DeformConv2d (#2782)
* Adding checks on forward and backward passes.
* Adding unit-tests.
1 parent d537965 commit 831c0df

File tree

3 files changed: 29 additions and 4 deletions.

test/test_ops.py

Lines changed: 11 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -478,8 +478,7 @@ def expected_fn(self, x, weight, offset, bias, stride=1, padding=0, dilation=1):
478478
out += bias.view(1, n_out_channels, 1, 1)
479479
return out
480480

481-
def get_fn_args(self, device, contiguous):
482-
batch_sz = 33
481+
def get_fn_args(self, device, contiguous, batch_sz):
483482
n_in_channels = 6
484483
n_out_channels = 2
485484
n_weight_grps = 2
@@ -516,7 +515,11 @@ def get_fn_args(self, device, contiguous):
516515
return x, weight, offset, bias, stride, pad, dilation
517516

518517
def _test_forward(self, device, contiguous):
519-
x, _, offset, _, stride, padding, dilation = self.get_fn_args(device, contiguous)
518+
for batch_sz in [0, 33]:
519+
self._test_forward_with_batchsize(device, contiguous, batch_sz)
520+
521+
def _test_forward_with_batchsize(self, device, contiguous, batch_sz):
522+
x, _, offset, _, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz)
520523
in_channels = 6
521524
out_channels = 2
522525
kernel_size = (3, 2)
@@ -538,7 +541,11 @@ def _test_forward(self, device, contiguous):
538541
res = layer(x, wrong_offset)
539542

540543
def _test_backward(self, device, contiguous):
541-
x, weight, offset, bias, stride, padding, dilation = self.get_fn_args(device, contiguous)
544+
for batch_sz in [0, 33]:
545+
self._test_backward_with_batchsize(device, contiguous, batch_sz)
546+
547+
def _test_backward_with_batchsize(self, device, contiguous, batch_sz):
548+
x, weight, offset, bias, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz)
542549

543550
def func(x_, offset_, weight_, bias_):
544551
return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride, padding=padding, dilation=dilation)

torchvision/csrc/cpu/DeformConv_cpu.cpp

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -326,6 +326,9 @@ at::Tensor DeformConv2d_forward_cpu(
326326
out_w);
327327

328328
auto out = at::zeros({batch_sz, out_channels, out_h, out_w}, input.options());
329+
if (batch_sz == 0) {
330+
return out;
331+
}
329332

330333
// Separate batches into blocks
331334
out = out.view({batch_sz / n_parallel_imgs,
@@ -713,6 +716,9 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv2d_backward_input_cpu(
713716

714717
auto grad_input = at::zeros_like(input);
715718
auto grad_offset = at::zeros_like(offset);
719+
if (batch_sz == 0) {
720+
return std::make_tuple(grad_input, grad_offset);
721+
}
716722
auto columns = at::empty(
717723
{n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w},
718724
input.options());
@@ -839,6 +845,9 @@ static at::Tensor deform_conv2d_backward_parameters_cpu(
839845
long out_w = grad_out.size(3);
840846

841847
auto grad_weight = at::zeros_like(weight);
848+
if (batch_sz == 0) {
849+
return grad_weight;
850+
}
842851

843852
at::Tensor grad_out_buf = grad_out
844853
.reshape({batch_sz / n_parallel_imgs,

torchvision/csrc/cuda/DeformConv_cuda.cu

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -343,6 +343,9 @@ at::Tensor DeformConv2d_forward_cuda(
343343
out_w);
344344
345345
auto out = at::zeros({batch_sz, out_channels, out_h, out_w}, input.options());
346+
if (batch_sz == 0) {
347+
return out;
348+
}
346349
347350
// Separate batches into blocks
348351
out = out.view({batch_sz / n_parallel_imgs,
@@ -743,6 +746,9 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv_backward_input_cuda(
743746
744747
auto grad_input = at::zeros_like(input);
745748
auto grad_offset = at::zeros_like(offset);
749+
if (batch_sz == 0) {
750+
return std::make_tuple(grad_input, grad_offset);
751+
}
746752
auto columns = at::empty(
747753
{n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w},
748754
input.options());
@@ -869,6 +875,9 @@ static at::Tensor deform_conv_backward_parameters_cuda(
869875
long out_w = grad_out.size(3);
870876
871877
auto grad_weight = at::zeros_like(weight);
878+
if (batch_sz == 0) {
879+
return grad_weight;
880+
}
872881
873882
at::Tensor grad_out_buf = grad_out.reshape(
874883
{batch_sz / n_parallel_imgs,

0 commit comments

Comments
 (0)