
Commit 38645e8

Revert "Fix unbind_copy and add its decomposition (pytorch#134319)"
This reverts commit 8aedc64. Reverted pytorch#134319 on behalf of https://github.com/huydhn: "Sorry for reverting your PR, but this is still failing the same test on ExecuTorch" (see the comment thread on pytorch#134319).
1 parent: ea93e09

13 files changed (+28, −107 lines)

aten/src/ATen/native/TensorShape.cpp

Lines changed: 15 additions & 28 deletions
@@ -26,7 +26,6 @@
 #include <ATen/native/cpu/SerialStackImpl.h>
 #include <ATen/native/cpu/StackKernel.h>
 #include <ATen/quantized/QTensorImpl.h>
-#include <c10/core/GradMode.h>
 #include <c10/util/Exception.h>
 #include <optional>
 #include <c10/util/SmallVector.h>
@@ -4072,41 +4071,29 @@ void split_copy_Tensor_out(const at::Tensor & self, int64_t split_size, int64_t
   }
 }
 
-namespace {
+void split_with_sizes_copy_out(const at::Tensor & self, at::IntArrayRef split_sizes, int64_t dim, at::TensorList out) {
+  auto tmp = self.split_with_sizes(split_sizes, dim);
 
-void copy_tensor_array_to_out(const char* name, const std::vector<Tensor>& array, at::TensorList out) {
-  TORCH_CHECK(out.size() == array.size(), name, " expected an out= argument of size ", array.size(), ", got size ", out.size());
+  TORCH_CHECK(out.size() == tmp.size(), "split_with_sizes_copy_out() expected an out= argument of size ", tmp.size(), ", got size ", out.size());
   for (const auto i : c10::irange(out.size())) {
-    if (resize_output_check(out[i], array[i].sizes())) {
-      out[i].resize_(array[i].sizes());
+    if (resize_output_check(out[i], tmp[i].sizes())) {
+      out[i].resize_(tmp[i].sizes());
     }
-    TORCH_CHECK(out[i].dtype() == array[i].dtype(),
-      "Expected out tensor to have dtype ", array[i].dtype(), ", but got ", out[i].dtype(), " instead");
-    TORCH_CHECK(out[i].device() == array[i].device(),
-      "Expected out tensor to have device ", array[i].device(), ", but got ", out[i].device(), " instead");
-    out[i].copy_(array[i]);
+    TORCH_CHECK(out[i].dtype() == tmp[i].dtype(),
+      "Expected out tensor to have dtype ", tmp[i].dtype(), ", but got ", out[i].dtype(), " instead");
+    TORCH_CHECK(out[i].device() == tmp[i].device(),
+      "Expected out tensor to have device ", tmp[i].device(), ", but got ", out[i].device(), " instead");
+    out[i].copy_(tmp[i]);
   }
 }
 
-}
-
-void split_with_sizes_copy_out(const at::Tensor & self, at::IntArrayRef split_sizes, int64_t dim, at::TensorList out) {
-  auto tmp = self.split_with_sizes(split_sizes, dim);
-  copy_tensor_array_to_out("split_with_sizes_copy_out()", tmp, out);
-}
+void unbind_copy_int_out(const at::Tensor & self, int64_t dim, at::TensorList out) {
+  auto tmp = self.unbind(dim);
 
-void unbind_copy_int_out(const at::Tensor & self, int64_t dim, at::TensorList out) {
-  if (at::GradMode::is_enabled()) {
-    for (const auto i : c10::irange(out.size())) {
-      TORCH_CHECK(!out[i].requires_grad(),
-        "unbind_copy(): functions with out=... arguments don't support automatic differentiation, "
-        "but one of the arguments requires grad."
-      );
-    }
+  TORCH_CHECK(out.size() == tmp.size(), "unbind_copy_int_out() expected an out= argument of size ", tmp.size(), ", got size ", out.size());
+  for (const auto i : c10::irange(out.size())) {
+    out[i].copy_(tmp[i]);
   }
-
-  auto tmp = self.unbind(dim);
-  copy_tensor_array_to_out("unbind_copy_int_out()", tmp, out);
 }
 
 int64_t sparse_dim_default(const Tensor& self) {
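Net effect of this hunk: the revert removes the shared copy_tensor_array_to_out helper and the at::GradMode guard, so the restored unbind_copy_int_out validates only the length of the out list before copying, while split_with_sizes_copy_out keeps its resize, dtype, and device checks inline. A minimal sketch of the out= call path this kernel backs, assuming a build where torch.unbind_copy is exposed with out= support:

import torch

x = torch.arange(6.0).reshape(2, 3)

# unbind_copy.int_out copies each slice along `dim` into caller-provided
# tensors; per this revert, the C++ kernel checks only the list length.
outs = (torch.empty(3), torch.empty(3))
torch.unbind_copy(x, dim=0, out=outs)
print(outs[0])  # tensor([0., 1., 2.])
print(outs[1])  # tensor([3., 4., 5.])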

test/distributed/_tensor/test_dtensor_ops.py

Lines changed: 0 additions & 1 deletion
@@ -446,7 +446,6 @@ def wrapped(fn):
    xfail("trapz"),
    xfail("triangular_solve"),
    xfail("unbind"),
-   xfail("unbind_copy"),
    xfail("unfold"),
    xfail("unfold_copy"),
    xfail("uniform"),

test/expect/HasDecompTest.test_aten_core_operators.expect

Lines changed: 0 additions & 2 deletions
@@ -506,8 +506,6 @@ aten::triu_indices.out
 aten::trunc
 aten::trunc.out
 aten::trunc_
-aten::unbind_copy.int
-aten::unbind_copy.int_out
 aten::unfold
 aten::uniform
 aten::uniform.out

test/expect/HasDecompTest.test_has_decomposition.expect

Lines changed: 2 additions & 0 deletions
@@ -1292,6 +1292,8 @@ aten::topk.values
 aten::transpose_
 aten::triangular_solve
 aten::triangular_solve.X
+aten::unbind_copy.int
+aten::unbind_copy.int_out
 aten::unique_consecutive
 aten::unique_consecutive.out
 aten::unique_dim
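Taken together with the removal from test_aten_core_operators.expect above, these expect-file changes record that aten::unbind_copy.int and aten::unbind_copy.int_out no longer have a registered decomposition once the revert lands. A hedged way to probe that state from Python, assuming the public decomposition table keeps its current location:

import torch

# The expect files track decomposition coverage; post-revert, unbind_copy
# should be absent from the global decomposition table.
op = torch.ops.aten.unbind_copy.int
print(op in torch._decomp.decomposition_table)  # expected: False after revert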

test/functorch/test_ops.py

Lines changed: 0 additions & 15 deletions
@@ -1039,9 +1039,6 @@ def fn(inp, *args, **kwargs):
             xfail("_native_batch_norm_legit"),
             # TODO: implement batching rule
             xfail("_batch_norm_with_update"),
-            xfail(
-                "unbind_copy"
-            ),  # Batching rule not implemented for aten::unbind_copy.int.
         }
     ),
 )
@@ -1181,9 +1178,6 @@ def vjp_of_vjp(*args_and_cotangents):
             xfail("sparse.mm", "reduce"),
             xfail("as_strided_scatter", ""),  # calls as_strided
             xfail("index_reduce", "prod"),  # .item() call
-            xfail(
-                "unbind_copy"
-            ),  # Batching rule not implemented for aten::unbind_copy.int.
             # ---------------------------------------------------------------------
         }
     )
@@ -1322,9 +1316,6 @@ def test_vmapvjp(self, device, dtype, op):
             xfail("_native_batch_norm_legit"),
             # TODO: implement batching rule
             xfail("_batch_norm_with_update"),
-            xfail(
-                "unbind_copy"
-            ),  # Batching rule not implemented for aten::unbind_copy.int.
             # ----------------------------------------------------------------------
         }
 
@@ -1638,9 +1629,6 @@ def test():
             xfail("__getitem__", ""),
             xfail("index_put", ""),
             xfail("view_as_complex"),
-            xfail(
-                "unbind_copy"
-            ),  # Batching rule not implemented for aten::unbind_copy.int.
             xfail("nn.functional.gaussian_nll_loss"),
             xfail("masked_select"),
             xfail(
@@ -1935,9 +1923,6 @@ def reference(primals, cotangents, primals_tangents, cotangents_tangents):
             xfail(
                 "as_strided_scatter"
             ),  # AssertionError: Tensor-likes are not close!
-            xfail(
-                "unbind_copy"
-            ),  # Batching rule not implemented for aten::unbind_copy.int.
             xfail("bernoulli"),  # calls random op
             xfail("bfloat16"),  # required rank 4 tensor to use channels_last format
             xfail("cdist"),  # Forward AD not implemented and no decomposition

test/functorch/test_vmap.py

Lines changed: 0 additions & 6 deletions
@@ -4375,9 +4375,6 @@ def sample_vmap_out_dim_numpy_split_copy_with_int(x, splits, dim):
             xfail("torch.ops.aten._efficient_attention_forward"),  # outputs ints
             # TypeError: expected Tensor as element 0 in argument 0, but got float
             xfail("item"),
-            xfail(
-                "unbind_copy"
-            ),  # Batching rule not implemented for aten::unbind_copy.int.
         }
     ),
 )
@@ -4453,9 +4450,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
             xfail("item"),
             xfail("tril"),  # Exception not raised on error input
             xfail("triu"),  # Exception not raised on error input
-            xfail(
-                "unbind_copy"
-            ),  # Batching rule not implemented for aten::unbind_copy.int.
             xfail("__getitem__", ""),
             xfail("count_nonzero"),
             xfail(

test/test_mps.py

Lines changed: 0 additions & 1 deletion
@@ -350,7 +350,6 @@ def mps_ops_modifier(ops):
        'transpose_copy',
        'T',
        'unbind',
-       'unbind_copy',
        'unflatten',
        'unfold',
        'unfold_copy',

tools/autograd/gen_variable_type.py

Lines changed: 0 additions & 1 deletion
@@ -241,7 +241,6 @@
     "slice",
     "constant_pad_nd",
     "unbind",
-    "unbind_copy",
     "split",
     "split_with_sizes",
     "unsafe_split",

torch/_inductor/decomposition.py

Lines changed: 0 additions & 1 deletion
@@ -83,7 +83,6 @@
         aten._to_copy,
         aten.tril_indices,
         aten.triu_indices,
-        aten.unbind_copy.int,
         aten.upsample_bilinear2d.vec,
         quantized.linear_dynamic_fp16_unpacked_weight,
         _quantized.wrapped_quantized_linear,
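Dropping aten.unbind_copy.int from this list removes it from the set of decompositions inductor registers. A quick hedged check, assuming torch._inductor.decomposition keeps exposing its assembled decompositions dict:

import torch
from torch._inductor.decomposition import decompositions

# Membership here is exactly what the deleted list entry controlled:
# whether inductor lowers unbind_copy.int via a decomposition.
print(torch.ops.aten.unbind_copy.int in decompositions)  # expected: False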

torch/_prims/context.py

Lines changed: 0 additions & 2 deletions
@@ -129,8 +129,6 @@ def __torch_function__(
             func = torch._decomp.decomposition_table.get(orig_func, None)
         elif func is None and isinstance(orig_func, torch._ops.OpOverloadPacket):
             default = getattr(orig_func, "default", None)
-            if default is None and orig_func._dir:
-                default = getattr(orig_func, orig_func._dir[0], None)
             if default is not None:
                 func = torch._decomp.decomposition_table.get(default, None)
 
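The deleted fallback matters for packets such as unbind_copy that define no "default" overload, only named ones like "int". A small sketch of the resolution it performed; note that _dir is the packet's internal list of overload names seen so far, as used by the deleted lines:

import torch

packet = torch.ops.aten.unbind_copy        # an OpOverloadPacket
print(getattr(packet, "default", None))    # None: no "default" overload exists
print(packet.int)                          # the aten::unbind_copy.int overload

# In effect, the deleted lines resolved the packet to its first recorded
# overload when "default" was missing:
if getattr(packet, "default", None) is None and packet._dir:
    print(getattr(packet, packet._dir[0], None))  # aten::unbind_copy.int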