
Commit b588a78

jiayisunx authored and pytorchmergebot committed
add grad_output shape check for adaptive_max_pool2d_backward and adaptive_max_pool3d_backward (pytorch#141663)
Fix pytorch#141099, pytorch#141100.
Pull Request resolved: pytorch#141663
Approved by: https://github.com/mingfeima, https://github.com/malfet
1 parent 93e8e32 commit b588a78

File tree

3 files changed: +30 -0 lines changed


aten/src/ATen/native/AdaptiveMaxPooling2d.cpp

Lines changed: 4 additions & 0 deletions
@@ -61,8 +61,12 @@ TORCH_META_FUNC(adaptive_max_pool2d_backward)
 
   at::native::adaptive_pool_empty_output_check(grad_output, "adaptive_max_pool2d_backward");
 
+  TORCH_CHECK(input.ndimension() == indices.ndimension(),
+    "expected dimensions ", input.ndimension(), " for `indices` but got dimensions ", indices.ndimension());
   TORCH_CHECK(input.dtype() == grad_output.dtype(),
     "expected dtype ", input.dtype(), " for `grad_output` but got dtype ", grad_output.dtype());
+  TORCH_CHECK(indices.sizes() == grad_output.sizes(),
+    "expected sizes ", indices.sizes(), " for `grad_output` but got sizes ", grad_output.sizes());
 
   set_output_raw_strided(0, input.sizes(), {}, input.options().memory_format(input.suggest_memory_format()));
 }
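For context, a minimal sketch of the shape contract behind this check (not part of the diff): the forward op returns output and indices with identical sizes, which is why indices.sizes() is the reference that grad_output is validated against. A direct call to the backward op passes the new checks when grad_output matches:

import torch
import torch.nn.functional as F

x = torch.randn(1, 2, 7, 7)
out, idx = F.adaptive_max_pool2d(x, (3, 3), return_indices=True)
print(out.shape, idx.shape)  # both torch.Size([1, 2, 3, 3])

# grad_output with the same dtype as input and the same sizes as indices
# satisfies all three TORCH_CHECKs above
grad_in = torch.ops.aten.adaptive_max_pool2d_backward(torch.ones_like(out), x, idx)
print(grad_in.shape)  # torch.Size([1, 2, 7, 7])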

aten/src/ATen/native/AdaptiveMaxPooling3d.cpp

Lines changed: 12 additions & 0 deletions
@@ -66,7 +66,19 @@ TORCH_META_FUNC(adaptive_max_pool3d) (const Tensor& input, IntArrayRef output_si
 
 TORCH_META_FUNC(adaptive_max_pool3d_backward)
 (const Tensor& gradOutput, const Tensor& input, const Tensor& indices) {
+  int64_t ndim = gradOutput.ndimension();
+  TORCH_CHECK(ndim == 4 || ndim == 5,
+    "adaptive_max_pool3d_backward(): Expected 4D or 5D gradOutput, but got: ", gradOutput.sizes());
+
   at::native::adaptive_pool_empty_output_check(gradOutput, "adaptive_max_pool3d_backward");
+
+  TORCH_CHECK(input.ndimension() == indices.ndimension(),
+    "expected dimensions ", input.ndimension(), " for `indices` but got dimensions ", indices.ndimension());
+  TORCH_CHECK(input.dtype() == gradOutput.dtype(),
+    "expected dtype ", input.dtype(), " for `gradOutput` but got dtype ", gradOutput.dtype());
+  TORCH_CHECK(indices.sizes() == gradOutput.sizes(),
+    "expected sizes ", indices.sizes(), " for `gradOutput` but got sizes ", gradOutput.sizes());
+
   set_output_raw_strided(0, input.sizes(), {}, input.options());
 }
 } // namespace meta
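A minimal sketch of the new rank guard (assuming a build that includes this commit): adaptive max pooling over 3D volumes takes 4D (C, D, H, W) or 5D (N, C, D, H, W) tensors, and the added TORCH_CHECK rejects a gradOutput of any other rank before the kernel is reached:

import torch
import torch.nn.functional as F

x = torch.randn(1, 2, 8, 8, 8)  # 5D input: (N, C, D, H, W)
out, idx = F.adaptive_max_pool3d(x, (3, 3, 3), return_indices=True)

bad_grad = torch.randn(3, 3, 3)  # 3D: neither 4D nor 5D
try:
    torch.ops.aten.adaptive_max_pool3d_backward(bad_grad, x, idx)
except RuntimeError as e:
    print(e)  # adaptive_max_pool3d_backward(): Expected 4D or 5D gradOutput, ...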

test/nn/test_pooling.py

Lines changed: 14 additions & 0 deletions
@@ -556,6 +556,20 @@ def test_adaptive_pooling_empty_output_size(self, dtype, device):
         with self.assertRaisesRegex(RuntimeError, error_msg):
             fn(input2, output_size).sum().backward()
 
+    @onlyNativeDeviceTypes
+    def test_adaptive_pooling_backward_fails(self, device):
+        grad_output = torch.randn(1, 2, 7, 7, device=device)
+        input = torch.randn(1, 2, 7, 7, device=device)
+        indices = torch.ones(1, 2, 3, 3, dtype=torch.long, device=device)
+        with self.assertRaisesRegex(RuntimeError, "expected sizes"):
+            torch.ops.aten.adaptive_max_pool2d_backward(grad_output, input, indices)
+
+        grad_output = torch.randn(1, 2, 7, 7, 7, device=device)
+        input = torch.randn(1, 2, 3, 3, 3, device=device)
+        indices = torch.ones(1, 2, 3, 3, dtype=torch.long, device=device)
+        with self.assertRaisesRegex(RuntimeError, "expected dimensions"):
+            torch.ops.aten.adaptive_max_pool3d_backward(grad_output, input, indices)
+
     @onlyNativeDeviceTypes
     def test_FractionalMaxPool2d_zero_batch(self, device):
         mod = nn.FractionalMaxPool2d(3, output_ratio=(0.5, 0.5))
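Ordinary autograd use is unaffected by these checks (a minimal sketch): when backward runs via .backward(), autograd supplies a grad_output that already matches the forward output in dtype and sizes, so only malformed direct calls to the raw aten ops, like those in the test above, are rejected:

import torch
import torch.nn.functional as F

x = torch.randn(1, 2, 7, 7, requires_grad=True)
out = F.adaptive_max_pool2d(x, (3, 3))
out.sum().backward()  # grad_output is well-formed here; the new checks pass
print(x.grad.shape)   # torch.Size([1, 2, 7, 7])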
