@@ -713,55 +713,49 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv2d_backward_input_cpu(
713
713
714
714
auto grad_input = at::zeros_like (input);
715
715
auto grad_offset = at::zeros_like (offset);
716
- auto columns = at::zeros (
716
+ auto columns = at::empty (
717
717
{n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w},
718
718
input.options ());
719
719
720
720
// Separate into blocks
721
- grad_input = grad_input.view (
721
+ grad_input = grad_input.reshape (
722
722
{batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w});
723
- input = input.view (
723
+ input = input.reshape (
724
724
{batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w});
725
- grad_offset = grad_offset.view ({batch_sz / n_parallel_imgs,
726
- n_parallel_imgs,
727
- n_offset_grps * 2 * weight_h * weight_w,
728
- out_h,
729
- out_w});
730
- offset = offset.view ({batch_sz / n_parallel_imgs,
731
- n_parallel_imgs,
732
- n_offset_grps * 2 * weight_h * weight_w,
733
- out_h,
734
- out_w});
735
-
736
- grad_out = grad_out.view ({batch_sz / n_parallel_imgs,
737
- n_parallel_imgs,
738
- n_out_channels,
739
- out_h,
740
- out_w});
741
- grad_out.transpose_ (1 , 2 );
742
- grad_out = grad_out.view ({grad_out.size (0 ),
743
- n_weight_grps,
744
- grad_out.size (1 ) / n_weight_grps,
745
- grad_out.size (2 ),
746
- grad_out.size (3 ),
747
- grad_out.size (4 )});
748
-
749
- weight = weight.view ({n_weight_grps,
750
- weight.size (0 ) / n_weight_grps,
751
- weight.size (1 ),
752
- weight.size (2 ),
753
- weight.size (3 )});
725
+ grad_offset = grad_offset.reshape ({batch_sz / n_parallel_imgs,
726
+ n_parallel_imgs,
727
+ n_offset_grps * 2 * weight_h * weight_w,
728
+ out_h,
729
+ out_w});
730
+ offset = offset.reshape ({batch_sz / n_parallel_imgs,
731
+ n_parallel_imgs,
732
+ n_offset_grps * 2 * weight_h * weight_w,
733
+ out_h,
734
+ out_w});
735
+
736
+ grad_out = grad_out.reshape ({batch_sz / n_parallel_imgs,
737
+ n_parallel_imgs,
738
+ n_weight_grps,
739
+ n_out_channels / n_weight_grps,
740
+ out_h,
741
+ out_w}).permute ({0 , 2 , 3 , 1 , 4 , 5 });
742
+
743
+ weight = weight.reshape ({n_weight_grps,
744
+ weight.size (0 ) / n_weight_grps,
745
+ weight.size (1 ),
746
+ weight.size (2 ),
747
+ weight.size (3 )});
748
+
749
+ columns = columns.view (
750
+ {n_weight_grps, columns.size (0 ) / n_weight_grps, columns.size (1 )});
754
751
755
752
for (int elt = 0 ; elt < batch_sz / n_parallel_imgs; elt++) {
753
+ columns.zero_ ();
756
754
// Separate into weight groups
757
- columns = columns.view (
758
- {n_weight_grps, columns.size (0 ) / n_weight_grps, columns.size (1 )});
759
755
for (int g = 0 ; g < n_weight_grps; g++) {
760
756
columns[g] = columns[g].addmm_ (
761
757
weight[g].flatten (1 ).transpose (0 , 1 ), grad_out[elt][g].flatten (1 ));
762
758
}
763
- columns =
764
- columns.view ({columns.size (0 ) * columns.size (1 ), columns.size (2 )});
765
759
766
760
compute_grad_offset (
767
761
columns,
@@ -801,20 +795,9 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv2d_backward_input_cpu(
801
795
grad_input[elt]);
802
796
}
803
797
804
- grad_out = grad_out.view ({grad_out.size (0 ),
805
- grad_out.size (1 ) * grad_out.size (2 ),
806
- grad_out.size (3 ),
807
- grad_out.size (4 ),
808
- grad_out.size (5 )});
809
- grad_out.transpose_ (1 , 2 );
810
- grad_out = grad_out.view ({batch_sz, n_out_channels, out_h, out_w});
811
-
812
798
grad_input = grad_input.view ({batch_sz, n_in_channels, in_h, in_w});
813
- input = input.view ({batch_sz, n_in_channels, in_h, in_w});
814
799
grad_offset = grad_offset.view (
815
800
{batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w});
816
- offset = offset.view (
817
- {batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w});
818
801
819
802
return std::make_tuple (grad_input, grad_offset);
820
803
}
@@ -854,46 +837,36 @@ static at::Tensor deform_conv2d_backward_parameters_cpu(
854
837
long out_w = grad_out.size (3 );
855
838
856
839
auto grad_weight = at::zeros_like (weight);
857
- ;
858
- auto columns = at::zeros (
859
- {n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w},
860
- input.options ());
861
840
862
- grad_out = grad_out.view ({batch_sz / n_parallel_imgs,
863
- n_parallel_imgs,
864
- n_out_channels,
865
- out_h,
866
- out_w});
867
- grad_out.transpose_ (1 , 2 );
868
-
869
- at::Tensor grad_out_buf = at::zeros_like (grad_out);
870
- grad_out_buf.copy_ (grad_out);
871
- grad_out_buf = grad_out_buf.view ({batch_sz / n_parallel_imgs,
872
- n_out_channels,
873
- n_parallel_imgs * out_h,
874
- out_w});
875
- grad_out_buf = grad_out_buf.view ({grad_out_buf.size (0 ),
876
- n_weight_grps,
877
- grad_out_buf.size (1 ) / n_weight_grps,
878
- grad_out_buf.size (2 ),
879
- grad_out_buf.size (3 )});
880
-
881
- grad_out.transpose_ (1 , 2 );
882
- grad_out = grad_out.view ({batch_sz, n_out_channels, out_h, out_w});
883
-
884
- input = input.view (
841
+ at::Tensor grad_out_buf = grad_out.reshape (
842
+ {batch_sz / n_parallel_imgs,
843
+ n_parallel_imgs,
844
+ n_weight_grps,
845
+ n_out_channels / n_weight_grps,
846
+ out_h,
847
+ out_w}
848
+ ).permute ({0 , 2 , 3 , 1 , 4 , 5 }).contiguous ();
849
+
850
+ input = input.reshape (
885
851
{batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w});
886
- offset = offset.view ({batch_sz / n_parallel_imgs,
887
- n_parallel_imgs,
888
- n_offset_grps * 2 * weight_h * weight_w,
889
- out_h,
890
- out_w});
852
+ offset = offset.reshape ({batch_sz / n_parallel_imgs,
853
+ n_parallel_imgs,
854
+ n_offset_grps * 2 * weight_h * weight_w,
855
+ out_h,
856
+ out_w});
891
857
892
858
grad_weight = grad_weight.view ({n_weight_grps,
893
859
grad_weight.size (0 ) / n_weight_grps,
894
860
grad_weight.size (1 ),
895
861
grad_weight.size (2 ),
896
862
grad_weight.size (3 )});
863
+
864
+ auto columns = at::empty (
865
+ {n_weight_grps,
866
+ n_in_channels * weight_w * weight_h / n_weight_grps,
867
+ n_parallel_imgs * out_h * out_w},
868
+ input.options ());
869
+
897
870
for (int elt = 0 ; elt < batch_sz / n_parallel_imgs; elt++) {
898
871
deformable_im2col (
899
872
input[elt],
@@ -915,8 +888,6 @@ static at::Tensor deform_conv2d_backward_parameters_cpu(
915
888
n_offset_grps,
916
889
columns);
917
890
918
- columns = columns.view (
919
- {n_weight_grps, columns.size (0 ) / n_weight_grps, columns.size (1 )});
920
891
for (int g = 0 ; g < n_weight_grps; g++) {
921
892
grad_weight[g] =
922
893
grad_weight[g]
@@ -925,14 +896,8 @@ static at::Tensor deform_conv2d_backward_parameters_cpu(
925
896
grad_out_buf[elt][g].flatten (1 ), columns[g].transpose (1 , 0 ))
926
897
.view_as (grad_weight[g]);
927
898
}
928
- columns =
929
- columns.view ({columns.size (0 ) * columns.size (1 ), columns.size (2 )});
930
899
}
931
900
932
- input = input.view ({batch_sz, n_in_channels, in_h, in_w});
933
- offset = offset.view (
934
- {batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w});
935
-
936
901
grad_weight = grad_weight.view ({grad_weight.size (0 ) * grad_weight.size (1 ),
937
902
grad_weight.size (2 ),
938
903
grad_weight.size (3 ),
0 commit comments