Skip to content

Commit ccd797d

Browse files
authored
Add test for large batches in DeformConv2d (#2040)
* Add test for large batches in DeformConv2d
* Clean-up and (try) fix DeformConv2d
* Simplifications and bugfixes
* Try fix CUDA now
1 parent 979bb72 commit ccd797d

File tree

3 files changed

+109
-179
lines changed

3 files changed

+109
-179
lines changed

test/test_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,7 @@ def expected_fn(self, x, weight, offset, bias, stride=1, padding=0, dilation=1):
454454
return out
455455

456456
def get_fn_args(self, device, contiguous):
457-
batch_sz = 1
457+
batch_sz = 33
458458
n_in_channels = 6
459459
n_out_channels = 2
460460
n_weight_grps = 2

torchvision/csrc/cpu/DeformConv_cpu.cpp

Lines changed: 52 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -713,55 +713,49 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv2d_backward_input_cpu(
713713

714714
auto grad_input = at::zeros_like(input);
715715
auto grad_offset = at::zeros_like(offset);
716-
auto columns = at::zeros(
716+
auto columns = at::empty(
717717
{n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w},
718718
input.options());
719719

720720
// Separate into blocks
721-
grad_input = grad_input.view(
721+
grad_input = grad_input.reshape(
722722
{batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w});
723-
input = input.view(
723+
input = input.reshape(
724724
{batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w});
725-
grad_offset = grad_offset.view({batch_sz / n_parallel_imgs,
726-
n_parallel_imgs,
727-
n_offset_grps * 2 * weight_h * weight_w,
728-
out_h,
729-
out_w});
730-
offset = offset.view({batch_sz / n_parallel_imgs,
731-
n_parallel_imgs,
732-
n_offset_grps * 2 * weight_h * weight_w,
733-
out_h,
734-
out_w});
735-
736-
grad_out = grad_out.view({batch_sz / n_parallel_imgs,
737-
n_parallel_imgs,
738-
n_out_channels,
739-
out_h,
740-
out_w});
741-
grad_out.transpose_(1, 2);
742-
grad_out = grad_out.view({grad_out.size(0),
743-
n_weight_grps,
744-
grad_out.size(1) / n_weight_grps,
745-
grad_out.size(2),
746-
grad_out.size(3),
747-
grad_out.size(4)});
748-
749-
weight = weight.view({n_weight_grps,
750-
weight.size(0) / n_weight_grps,
751-
weight.size(1),
752-
weight.size(2),
753-
weight.size(3)});
725+
grad_offset = grad_offset.reshape({batch_sz / n_parallel_imgs,
726+
n_parallel_imgs,
727+
n_offset_grps * 2 * weight_h * weight_w,
728+
out_h,
729+
out_w});
730+
offset = offset.reshape({batch_sz / n_parallel_imgs,
731+
n_parallel_imgs,
732+
n_offset_grps * 2 * weight_h * weight_w,
733+
out_h,
734+
out_w});
735+
736+
grad_out = grad_out.reshape({batch_sz / n_parallel_imgs,
737+
n_parallel_imgs,
738+
n_weight_grps,
739+
n_out_channels / n_weight_grps,
740+
out_h,
741+
out_w}).permute({0, 2, 3, 1, 4, 5});
742+
743+
weight = weight.reshape({n_weight_grps,
744+
weight.size(0) / n_weight_grps,
745+
weight.size(1),
746+
weight.size(2),
747+
weight.size(3)});
748+
749+
columns = columns.view(
750+
{n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)});
754751

755752
for (int elt = 0; elt < batch_sz / n_parallel_imgs; elt++) {
753+
columns.zero_();
756754
// Separate into weight groups
757-
columns = columns.view(
758-
{n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)});
759755
for (int g = 0; g < n_weight_grps; g++) {
760756
columns[g] = columns[g].addmm_(
761757
weight[g].flatten(1).transpose(0, 1), grad_out[elt][g].flatten(1));
762758
}
763-
columns =
764-
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
765759

766760
compute_grad_offset(
767761
columns,
@@ -801,20 +795,9 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv2d_backward_input_cpu(
801795
grad_input[elt]);
802796
}
803797

804-
grad_out = grad_out.view({grad_out.size(0),
805-
grad_out.size(1) * grad_out.size(2),
806-
grad_out.size(3),
807-
grad_out.size(4),
808-
grad_out.size(5)});
809-
grad_out.transpose_(1, 2);
810-
grad_out = grad_out.view({batch_sz, n_out_channels, out_h, out_w});
811-
812798
grad_input = grad_input.view({batch_sz, n_in_channels, in_h, in_w});
813-
input = input.view({batch_sz, n_in_channels, in_h, in_w});
814799
grad_offset = grad_offset.view(
815800
{batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w});
816-
offset = offset.view(
817-
{batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w});
818801

819802
return std::make_tuple(grad_input, grad_offset);
820803
}
@@ -854,46 +837,36 @@ static at::Tensor deform_conv2d_backward_parameters_cpu(
854837
long out_w = grad_out.size(3);
855838

856839
auto grad_weight = at::zeros_like(weight);
857-
;
858-
auto columns = at::zeros(
859-
{n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w},
860-
input.options());
861840

862-
grad_out = grad_out.view({batch_sz / n_parallel_imgs,
863-
n_parallel_imgs,
864-
n_out_channels,
865-
out_h,
866-
out_w});
867-
grad_out.transpose_(1, 2);
868-
869-
at::Tensor grad_out_buf = at::zeros_like(grad_out);
870-
grad_out_buf.copy_(grad_out);
871-
grad_out_buf = grad_out_buf.view({batch_sz / n_parallel_imgs,
872-
n_out_channels,
873-
n_parallel_imgs * out_h,
874-
out_w});
875-
grad_out_buf = grad_out_buf.view({grad_out_buf.size(0),
876-
n_weight_grps,
877-
grad_out_buf.size(1) / n_weight_grps,
878-
grad_out_buf.size(2),
879-
grad_out_buf.size(3)});
880-
881-
grad_out.transpose_(1, 2);
882-
grad_out = grad_out.view({batch_sz, n_out_channels, out_h, out_w});
883-
884-
input = input.view(
841+
at::Tensor grad_out_buf = grad_out.reshape(
842+
{batch_sz / n_parallel_imgs,
843+
n_parallel_imgs,
844+
n_weight_grps,
845+
n_out_channels / n_weight_grps,
846+
out_h,
847+
out_w}
848+
).permute({0, 2, 3, 1, 4, 5}).contiguous();
849+
850+
input = input.reshape(
885851
{batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w});
886-
offset = offset.view({batch_sz / n_parallel_imgs,
887-
n_parallel_imgs,
888-
n_offset_grps * 2 * weight_h * weight_w,
889-
out_h,
890-
out_w});
852+
offset = offset.reshape({batch_sz / n_parallel_imgs,
853+
n_parallel_imgs,
854+
n_offset_grps * 2 * weight_h * weight_w,
855+
out_h,
856+
out_w});
891857

892858
grad_weight = grad_weight.view({n_weight_grps,
893859
grad_weight.size(0) / n_weight_grps,
894860
grad_weight.size(1),
895861
grad_weight.size(2),
896862
grad_weight.size(3)});
863+
864+
auto columns = at::empty(
865+
{n_weight_grps,
866+
n_in_channels * weight_w * weight_h / n_weight_grps,
867+
n_parallel_imgs * out_h * out_w},
868+
input.options());
869+
897870
for (int elt = 0; elt < batch_sz / n_parallel_imgs; elt++) {
898871
deformable_im2col(
899872
input[elt],
@@ -915,8 +888,6 @@ static at::Tensor deform_conv2d_backward_parameters_cpu(
915888
n_offset_grps,
916889
columns);
917890

918-
columns = columns.view(
919-
{n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)});
920891
for (int g = 0; g < n_weight_grps; g++) {
921892
grad_weight[g] =
922893
grad_weight[g]
@@ -925,14 +896,8 @@ static at::Tensor deform_conv2d_backward_parameters_cpu(
925896
grad_out_buf[elt][g].flatten(1), columns[g].transpose(1, 0))
926897
.view_as(grad_weight[g]);
927898
}
928-
columns =
929-
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
930899
}
931900

932-
input = input.view({batch_sz, n_in_channels, in_h, in_w});
933-
offset = offset.view(
934-
{batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w});
935-
936901
grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),
937902
grad_weight.size(2),
938903
grad_weight.size(3),

0 commit comments

Comments
 (0)