@@ -8841,6 +8841,35 @@ def test_ReflectionPad_empty(self, device, dtype):
88418841 inp = torch.randn(3, 0, 10, 10, 10, device=device, dtype=dtype)
88428842 mod(inp)
88438843
8844+ @onlyNativeDeviceTypes
8845+ def test_ReflectionPad_fails(self, device):
8846+ with self.assertRaisesRegex(RuntimeError, 'Only 2D, 3D, 4D, 5D'):
8847+ mod = torch.nn.ReflectionPad1d(2)
8848+ inp = torch.randn(3, 3, 10, 10, device=device)
8849+ mod(inp)
8850+
8851+ with self.assertRaisesRegex(RuntimeError, '2D or 3D'):
8852+ inp = torch.randn(3, 3, 10, 10, device=device)
8853+ torch.ops.aten.reflection_pad1d(inp, (2, 2))
8854+
8855+ with self.assertRaisesRegex(RuntimeError, 'Only 2D, 3D, 4D, 5D'):
8856+ mod = torch.nn.ReflectionPad2d(2)
8857+ inp = torch.randn(3, 3, 10, 10, 10, device=device)
8858+ mod(inp)
8859+
8860+ with self.assertRaisesRegex(RuntimeError, '3D or 4D'):
8861+ inp = torch.randn(3, 3, 10, 10, 10, device=device)
8862+ torch.ops.aten.reflection_pad2d(inp, (2, 2, 2, 2))
8863+
8864+ with self.assertRaisesRegex(RuntimeError, 'Only 2D, 3D, 4D, 5D'):
8865+ mod = torch.nn.ReflectionPad3d(3)
8866+ inp = torch.randn(3, 3, 10, 10, 10, 10, device=device)
8867+ mod(inp)
8868+
8869+ with self.assertRaisesRegex(RuntimeError, '4D or 5D'):
8870+ inp = torch.randn(3, 3, 10, 10, 10, 10, device=device)
8871+ torch.ops.aten.reflection_pad3d(inp, (2, 2, 2, 2, 2, 2))
8872+
88448873 @onlyCUDA # Test if CPU and GPU results match
88458874 def test_ReflectionPad2d_large(self, device):
88468875 shapes = ([2, 65736, 6, 6], [65736, 2, 6, 6])
0 commit comments