26 | 26 | #include <ATen/native/cpu/SerialStackImpl.h> |
27 | 27 | #include <ATen/native/cpu/StackKernel.h> |
28 | 28 | #include <ATen/quantized/QTensorImpl.h> |
29 | | -#include <c10/core/GradMode.h> |
30 | 29 | #include <c10/util/Exception.h> |
31 | 30 | #include <optional> |
32 | 31 | #include <c10/util/SmallVector.h> |
@@ -4072,41 +4071,29 @@ void split_copy_Tensor_out(const at::Tensor & self, int64_t split_size, int64_t |
4072 | 4071 | } |
4073 | 4072 | } |
4074 | 4073 |
4075 | | -namespace { |
| 4074 | +void split_with_sizes_copy_out(const at::Tensor & self, at::IntArrayRef split_sizes, int64_t dim, at::TensorList out) { |
| 4075 | + auto tmp = self.split_with_sizes(split_sizes, dim); |
4076 | 4076 |
4077 | | -void copy_tensor_array_to_out(const char* name, const std::vector<Tensor>& array, at::TensorList out) { |
4078 | | - TORCH_CHECK(out.size() == array.size(), name, " expected an out= argument of size ", array.size(), ", got size ", out.size()); |
| 4077 | + TORCH_CHECK(out.size() == tmp.size(), "split_with_sizes_copy_out() expected an out= argument of size ", tmp.size(), ", got size ", out.size()); |
4079 | 4078 | for (const auto i : c10::irange(out.size())) { |
4080 | | - if (resize_output_check(out[i], array[i].sizes())) { |
4081 | | - out[i].resize_(array[i].sizes()); |
| 4079 | + if (resize_output_check(out[i], tmp[i].sizes())) { |
| 4080 | + out[i].resize_(tmp[i].sizes()); |
4082 | 4081 | } |
4083 | | - TORCH_CHECK(out[i].dtype() == array[i].dtype(), |
4084 | | - "Expected out tensor to have dtype ", array[i].dtype(), ", but got ", out[i].dtype(), " instead"); |
4085 | | - TORCH_CHECK(out[i].device() == array[i].device(), |
4086 | | - "Expected out tensor to have device ", array[i].device(), ", but got ", out[i].device(), " instead"); |
4087 | | - out[i].copy_(array[i]); |
| 4082 | + TORCH_CHECK(out[i].dtype() == tmp[i].dtype(), |
| 4083 | + "Expected out tensor to have dtype ", tmp[i].dtype(), ", but got ", out[i].dtype(), " instead"); |
| 4084 | + TORCH_CHECK(out[i].device() == tmp[i].device(), |
| 4085 | + "Expected out tensor to have device ", tmp[i].device(), ", but got ", out[i].device(), " instead"); |
| 4086 | + out[i].copy_(tmp[i]); |
4088 | 4087 | } |
4089 | 4088 | } |
4090 | 4089 |
4091 | | -} |
4092 | | - |
4093 | | -void split_with_sizes_copy_out(const at::Tensor & self, at::IntArrayRef split_sizes, int64_t dim, at::TensorList out) { |
4094 | | - auto tmp = self.split_with_sizes(split_sizes, dim); |
4095 | | - copy_tensor_array_to_out("split_with_sizes_copy_out()", tmp, out); |
4096 | | -} |
| 4090 | +void unbind_copy_int_out(const at::Tensor & self, int64_t dim, at::TensorList out) { |
| 4091 | + auto tmp = self.unbind(dim); |
4097 | 4092 |
4098 | | -void unbind_copy_int_out(const at::Tensor & self, int64_t dim, at::TensorList out) { |
4099 | | - if (at::GradMode::is_enabled()) { |
4100 | | - for (const auto i : c10::irange(out.size())) { |
4101 | | - TORCH_CHECK(!out[i].requires_grad(), |
4102 | | - "unbind_copy(): functions with out=... arguments don't support automatic differentiation, " |
4103 | | - "but one of the arguments requires grad." |
4104 | | - ); |
4105 | | - } |
| 4093 | + TORCH_CHECK(out.size() == tmp.size(), "unbind_copy_int_out() expected an out= argument of size ", tmp.size(), ", got size ", out.size()); |
| 4094 | + for (const auto i : c10::irange(out.size())) { |
| 4095 | + out[i].copy_(tmp[i]); |
4106 | 4096 | } |
4107 | | - |
4108 | | - auto tmp = self.unbind(dim); |
4109 | | - copy_tensor_array_to_out("unbind_copy_int_out()", tmp, out); |
4110 | 4097 | } |
4111 | 4098 |
4112 | 4099 | int64_t sparse_dim_default(const Tensor& self) { |
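
Illustrative usage (not part of the diff): a minimal sketch of the out= contract that the rewritten split_with_sizes_copy_out enforces, where the caller supplies one pre-allocated tensor per split with matching dtype and device. The direct at::native:: call path, the NativeFunctions.h include, and the example shapes are assumptions made for illustration; user code would normally go through the generated at:: dispatcher wrapper instead.

#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>  // assumed include for the at::native:: declaration
#include <vector>

int main() {
  // A 5x2 tensor to be split along dim 0 into chunks of 2 and 3 rows.
  at::Tensor self = at::arange(10, at::kFloat).reshape({5, 2});

  // Pre-allocate one out tensor per split. dtype and device must match self;
  // size mismatches are handled by the resize_output_check/resize_ path above.
  std::vector<at::Tensor> outs = {
      at::empty({2, 2}, self.options()),
      at::empty({3, 2}, self.options())};

  // Signature as defined in the diff: (self, split_sizes, dim, out).
  at::native::split_with_sizes_copy_out(self, /*split_sizes=*/{2, 3}, /*dim=*/0, outs);
  return 0;
}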