Skip to content

Commit 863e6e4

Browse files
jiayisunxpytorchmergebot
authored andcommitted
Improve input dimensions check for reflection_pad1d, reflection_pad2d and reflection_pad3d (pytorch#141670)
Fix pytorch#141447. Pull Request resolved: pytorch#141670 Approved by: https://github.com/mingfeima, https://github.com/malfet
1 parent b588a78 commit 863e6e4

File tree

2 files changed

+31
-1
lines changed

2 files changed

+31
-1
lines changed

aten/src/ATen/native/Padding.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,10 @@ inline void check_valid_input(const Tensor& input, IntArrayRef padding) {
3535
int input_dim = input.dim();
3636

3737
bool is_batch_mode = input_dim == (dim + 2);
38+
bool is_non_batch_mode = input_dim == (dim + 1);
3839

3940
bool valid_batch_mode = is_batch_mode;
40-
bool valid_non_batch_mode = !is_batch_mode;
41+
bool valid_non_batch_mode = is_non_batch_mode;
4142

4243
if (is_batch_mode) {
4344
// allow batch size of 0-dim.

test/test_nn.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8841,6 +8841,35 @@ def test_ReflectionPad_empty(self, device, dtype):
88418841
inp = torch.randn(3, 0, 10, 10, 10, device=device, dtype=dtype)
88428842
mod(inp)
88438843

8844+
@onlyNativeDeviceTypes
8845+
def test_ReflectionPad_fails(self, device):
8846+
with self.assertRaisesRegex(RuntimeError, 'Only 2D, 3D, 4D, 5D'):
8847+
mod = torch.nn.ReflectionPad1d(2)
8848+
inp = torch.randn(3, 3, 10, 10, device=device)
8849+
mod(inp)
8850+
8851+
with self.assertRaisesRegex(RuntimeError, '2D or 3D'):
8852+
inp = torch.randn(3, 3, 10, 10, device=device)
8853+
torch.ops.aten.reflection_pad1d(inp, (2, 2))
8854+
8855+
with self.assertRaisesRegex(RuntimeError, 'Only 2D, 3D, 4D, 5D'):
8856+
mod = torch.nn.ReflectionPad2d(2)
8857+
inp = torch.randn(3, 3, 10, 10, 10, device=device)
8858+
mod(inp)
8859+
8860+
with self.assertRaisesRegex(RuntimeError, '3D or 4D'):
8861+
inp = torch.randn(3, 3, 10, 10, 10, device=device)
8862+
torch.ops.aten.reflection_pad2d(inp, (2, 2, 2, 2))
8863+
8864+
with self.assertRaisesRegex(RuntimeError, 'Only 2D, 3D, 4D, 5D'):
8865+
mod = torch.nn.ReflectionPad3d(3)
8866+
inp = torch.randn(3, 3, 10, 10, 10, 10, device=device)
8867+
mod(inp)
8868+
8869+
with self.assertRaisesRegex(RuntimeError, '4D or 5D'):
8870+
inp = torch.randn(3, 3, 10, 10, 10, 10, device=device)
8871+
torch.ops.aten.reflection_pad3d(inp, (2, 2, 2, 2, 2, 2))
8872+
88448873
@onlyCUDA # Test if CPU and GPU results match
88458874
def test_ReflectionPad2d_large(self, device):
88468875
shapes = ([2, 65736, 6, 6], [65736, 2, 6, 6])

0 commit comments

Comments
 (0)