Commit 6a4fb4b

Revert "Align CPU behavior with CUDA for ConvTranspose when out_channels=0 (pytorch#142859)"
This reverts commit cb814c0. Reverted pytorch#142859 on behalf of https://github.com/malfet because it broke ROCm tests again; see https://hud.pytorch.org/hud/pytorch/pytorch/5cd2b34e821c293909fe3f6a958767b9c535c094/1?per_page=50&name_filter=rocm&mergeLF=true ([comment](pytorch#142859 (comment)))
Parent: 5cd2b34 · Commit: 6a4fb4b

3 files changed (+4, -28 lines)

aten/src/ATen/native/NaiveConvolutionTranspose2d.cpp

Lines changed: 2 additions & 2 deletions
@@ -78,8 +78,8 @@ static inline void slow_conv_transpose2d_shape_check(
 
   if (weight.defined()) {
     TORCH_CHECK(
-        (weight.dim() == 2 || weight.dim() == 4),
-        "2D or 4D weight tensor expected, but got: ",
+        weight.numel() != 0 && (weight.dim() == 2 || weight.dim() == 4),
+        "non-empty 2D or 4D weight tensor expected, but got: ",
         weight.sizes());
     if (bias.defined()) {
       check_dim_size(bias, 1, 0, weight.size(1));
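
For context, a minimal sketch of the behavior this restores on CPU (assuming the call dispatches to the naive kernel whose check is shown above; the 3d check below is analogous): a ConvTranspose module with out_channels=0 holds an empty weight, which the restored TORCH_CHECK rejects.

    import torch
    import torch.nn as nn

    # out_channels=0 gives a weight of shape (1, 0, 1, 1), i.e. weight.numel() == 0
    op = nn.ConvTranspose2d(in_channels=1, out_channels=0, kernel_size=1)
    x = torch.randn(1, 1, 4, 4)

    # With the non-empty requirement back in place, the CPU path raises roughly:
    #   RuntimeError: non-empty 2D or 4D weight tensor expected, but got: [1, 0, 1, 1]
    op(x)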

aten/src/ATen/native/NaiveConvolutionTranspose3d.cpp

Lines changed: 2 additions & 2 deletions
@@ -98,8 +98,8 @@ static inline void slow_conv_transpose3d_shape_check(
   if (weight.defined()) {
     /* TODO: TORCH_CHECK just have 2 args: condition and message */
     TORCH_CHECK(
-        weight.dim() == 5,
-        "5D (n_output_plane x n_input_plane x kernel_depth",
+        weight.numel() != 0 && weight.dim() == 5,
+        "non-empty 5D (n_output_plane x n_input_plane x kernel_depth",
         " x kernel_height x kernel_width) tensor ",
         "expected for weight, but got: ",
         weight.sizes());

test/nn/test_convolution.py

Lines changed: 0 additions & 24 deletions
@@ -40,7 +40,6 @@
     skipCUDAIfRocmVersionLessThan,
     skipMeta,
     skipMPS,
-    skipXLA,
 )
 from torch.testing._internal.common_dtype import (
     floating_and_complex_types_and,
@@ -1750,29 +1749,6 @@ def test_conv2d_same_padding(self, device, dtype):
         actual = F.conv2d(x, y, padding="same", dilation=3)
         self.assertEqual(expect, actual, rtol=rtol, atol=atol)
 
-    @dtypes(torch.float)
-    # aten/src/ATen/native/mps/OperationUtils.mm: TORCH_INTERNAL_ASSERT([srcBuf length] > 0, "Placeholder tensor is empty!"); on MPS
-    @expectedFailureMPS
-    @skipXLA
-    def test_ConvTranspose_output_channels_0(self, device, dtype):
-        class Model(nn.Module):
-            def __init__(self, operator, dim):
-                super().__init__()
-                self.op = eval(
-                    f"torch.nn.{operator}{dim}d(in_channels=1, out_channels=0, kernel_size={tuple([1] * dim)})"
-                )
-
-            def forward(self, x):
-                x = self.op(x)
-                return x
-
-        for dim in [1, 2, 3]:
-            x = torch.randn([1] * (dim + 1), device=device, dtype=dtype)
-            model = Model("ConvTranspose", dim).to(device).to(dtype=dtype)
-            y = model(x)
-            self.assertEqual(y.numel(), 0)
-            self.assertEqual(x.shape[1:], y.shape[1:])
-
     @dtypes(torch.float, torch.cfloat)
     def test_conv3d_same_padding(self, device, dtype):
         if dtype is torch.cfloat:
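
For reference, the deleted test asserted the behavior pytorch#142859 had introduced: an out_channels=0 ConvTranspose produces an empty output whose non-channel dimensions match the input. A standalone sketch of that expectation (assumes a CUDA device, where the empty-weight case is accepted):

    import torch
    import torch.nn as nn

    convs = {1: nn.ConvTranspose1d, 2: nn.ConvTranspose2d, 3: nn.ConvTranspose3d}
    for dim, cls in convs.items():
        op = cls(in_channels=1, out_channels=0, kernel_size=1).to("cuda")
        x = torch.randn([1] * (dim + 1), device="cuda")  # unbatched (C, *spatial) input
        y = op(x)
        assert y.numel() == 0              # zero output channels -> empty result
        assert y.shape[1:] == x.shape[1:]  # spatial dims preserved for kernel_size=1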
