diff --git a/test/cpp/test_tensor.cpp b/test/cpp/test_tensor.cpp
index 61627f242a8..a75580d22f4 100644
--- a/test/cpp/test_tensor.cpp
+++ b/test/cpp/test_tensor.cpp
@@ -4,6 +4,7 @@
 #include
 #include
+#include "absl/base/nullability.h"
 #include "test/cpp/cpp_test_util.h"
 #include "test/cpp/torch_xla_test.h"
 #include "torch/csrc/autograd/variable.h"
@@ -297,14 +298,18 @@ TEST_F(TensorTest, TestMaxPool2D) {
                             /*padding=*/{padding, padding},
                             /*dilation=*/{1, 1}, /*ceil_mode=*/false);
   ForEachDevice([&](const torch::lazy::BackendDevice& device) {
-    XLA_ASSIGN_OR_THROW(XLATensorPtr dev_input,
+    XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr dev_input,
                         XLATensor::Create(input, device));
-    auto dev_output = tensor_methods::max_pool_nd(
-        dev_input,
-        /*spatial_dim_count=*/2,
-        /*kernel_size=*/{kernel_size, kernel_size},
-        /*stride=*/{stride, stride},
-        /*padding=*/{padding, padding}, /*ceil_mode=*/false);
+    std::tuple<XLATensorPtr, XLATensorPtr> dev_output;
+    XLA_ASSIGN_OR_THROW(
+        dev_output,
+        tensor_methods::max_pool_nd(
+            dev_input,
+            /*spatial_dim_count=*/2,
+            /*kernel_size=*/{kernel_size, kernel_size},
+            /*stride=*/{stride, stride},
+            /*padding=*/{padding, padding}, /*ceil_mode=*/false));
     AllClose(output, std::get<0>(dev_output));
   });
 }
@@ -322,15 +327,18 @@ TEST_F(TensorTest, TestMaxPool2DNonSquare) {
                             /*padding=*/{padding, padding + 1},
                             /*dilation=*/{1, 1}, /*ceil_mode=*/false);
   ForEachDevice([&](const torch::lazy::BackendDevice& device) {
-    XLA_ASSIGN_OR_THROW(XLATensorPtr dev_input,
+    XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr dev_input,
                         XLATensor::Create(input, device));
-    auto dev_output = tensor_methods::max_pool_nd(
-        dev_input,
-        /*spatial_dim_count=*/2,
-        /*kernel_size=*/{kernel_size, kernel_size + 1},
-        /*stride=*/{stride, stride + 1},
-        /*padding=*/{padding, padding + 1},
-        /*ceil_mode=*/false);
+    std::tuple<XLATensorPtr, XLATensorPtr> dev_output;
+    XLA_ASSIGN_OR_THROW(dev_output,
+                        tensor_methods::max_pool_nd(
+                            dev_input,
+                            /*spatial_dim_count=*/2,
+                            /*kernel_size=*/{kernel_size, kernel_size + 1},
+                            /*stride=*/{stride, stride + 1},
+                            /*padding=*/{padding, padding + 1},
+                            /*ceil_mode=*/false));
     AllClose(output, std::get<0>(dev_output));
   });
 }
@@ -351,16 +359,17 @@ TEST_F(TensorTest, TestAvgPool2D) {
                             /*ceil_mode=*/false, count_include_pad,
                             /*divisor_override=*/std::nullopt);
   ForEachDevice([&](const torch::lazy::BackendDevice& device) {
-    XLA_ASSIGN_OR_THROW(XLATensorPtr dev_input,
+    XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr dev_input,
                         XLATensor::Create(input, device));
-    XLATensorPtr dev_output = tensor_methods::avg_pool_nd(
-        dev_input,
-        /*spatial_dim_count=*/2,
-        /*kernel_size=*/{kernel_size, kernel_size},
-        /*stride=*/{stride, stride},
-        /*padding=*/{padding, padding},
-        /*ceil_mode=*/false, count_include_pad,
-        /*divisor_override=*/std::nullopt);
+    XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr dev_output,
+                        tensor_methods::avg_pool_nd(
+                            dev_input,
+                            /*spatial_dim_count=*/2,
+                            /*kernel_size=*/{kernel_size, kernel_size},
+                            /*stride=*/{stride, stride},
+                            /*padding=*/{padding, padding},
+                            /*ceil_mode=*/false, count_include_pad,
+                            /*divisor_override=*/std::nullopt));
     AllClose(output, dev_output);
   });
 }
@@ -382,17 +391,19 @@ TEST_F(TensorTest, TestAvgPool2DNonSquare) {
                             /*count_include_pad=*/count_include_pad,
                             /*divisor_override=*/std::nullopt);
   ForEachDevice([&](const torch::lazy::BackendDevice& device) {
-    XLA_ASSIGN_OR_THROW(XLATensorPtr dev_input,
+    XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr dev_input,
                         XLATensor::Create(input, device));
-    XLATensorPtr dev_output = tensor_methods::avg_pool_nd(
-        dev_input,
-        /*spatial_dim_count=*/2,
-        /*kernel_size=*/{kernel_size, kernel_size + 1},
-        /*stride=*/{stride, stride + 1},
-        /*padding=*/{padding, padding + 1},
-        /*ceil_mode=*/false,
-        /*count_include_pad=*/count_include_pad,
-        /*divisor_override=*/std::nullopt);
+    XLA_ASSIGN_OR_THROW(
+        absl_nonnull XLATensorPtr dev_output,
+        tensor_methods::avg_pool_nd(
+            dev_input,
+            /*spatial_dim_count=*/2,
+            /*kernel_size=*/{kernel_size, kernel_size + 1},
+            /*stride=*/{stride, stride + 1},
+            /*padding=*/{padding, padding + 1},
+            /*ceil_mode=*/false,
+            /*count_include_pad=*/count_include_pad,
+            /*divisor_override=*/std::nullopt));
     AllClose(output, dev_output);
   });
 }
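Note (illustrative, not part of the patch): the tests above now declare `dev_output` before invoking `XLA_ASSIGN_OR_THROW`, rather than declaring it inline in the macro. A plausible reason, sketched below with a hypothetical `TAKES_TWO_ARGS` macro standing in for the real one, is the usual preprocessor limitation: a comma inside template angle brackets splits a macro argument, while commas inside parentheses do not.

    #include <tuple>

    // Hypothetical stand-in for an assign-or-throw style macro.
    #define TAKES_TWO_ARGS(lhs, expr) lhs = (expr);

    int main() {
      // TAKES_TWO_ARGS(std::tuple<int, int> t, std::make_tuple(1, 2));
      // ^ does not compile: the comma inside "std::tuple<int, int>" splits
      //   the invocation into three macro arguments (angle brackets do not
      //   protect commas; parentheses, as in make_tuple(1, 2), do).
      std::tuple<int, int> t;  // declare first, so the macro call is clean
      TAKES_TWO_ARGS(t, std::make_tuple(1, 2));
      return 0;
    }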
diff --git a/test/test_ops_error_message.py b/test/test_ops_error_message.py
index 8fbbf66d69e..796f9aa7333 100644
--- a/test/test_ops_error_message.py
+++ b/test/test_ops_error_message.py
@@ -180,6 +180,31 @@ def test():
         expect="""mm(): cannot matrix-multiply tensors f32[2,5] and f32[8,2]. Expected the size of dimension 1 of the first input tensor (5) to be equal the size of dimension 0 of the second input tensor (8)."""
     )
 
+  def test_avg_pool_3d_raises_error_on_bad_spec(self):
+    device = torch_xla.device()
+    a = torch.rand(1, 1, 4, 4, 4, device=device)
+
+    def gen_test_fn(kernel_size=[2, 2, 2], stride=[], padding=[0]):
+      return lambda: torch.nn.functional.avg_pool3d(a, kernel_size, stride, padding)
+
+    self.assertExpectedRaisesInline(
+        exc_type=RuntimeError,
+        callable=gen_test_fn(kernel_size=[2, 2]),
+        expect="""avg_pool3d(): expected argument kernel_size [2, 2] (size: 2) to have size of 3."""
+    )
+
+    self.assertExpectedRaisesInline(
+        exc_type=RuntimeError,
+        callable=gen_test_fn(stride=[1, 2]),
+        expect="""avg_pool3d(): expected argument stride [1, 2] (size: 2) to have size of 3."""
+    )
+
+    self.assertExpectedRaisesInline(
+        exc_type=RuntimeError,
+        callable=gen_test_fn(padding=[1, 2]),
+        expect="""avg_pool3d(): expected argument padding [1, 2] (size: 2) to have size of 3."""
+    )
+
 
 if __name__ == "__main__":
   unittest.main()
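Note (illustrative, not part of the patch): the expected strings above are produced by the new `CheckPoolNdInputHasSize()` helper added in tensor_methods.cpp later in this diff. A minimal standalone reproduction of the message layout, assuming only that abseil's StrCat/StrJoin are available:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    #include "absl/strings/str_cat.h"
    #include "absl/strings/str_join.h"

    int main() {
      // Mirrors the message built by CheckPoolNdInputHasSize(); this
      // standalone main() exists only to show the formatting.
      std::vector<int64_t> kernel_size = {2, 2};
      std::cout << absl::StrCat("avg_pool3d", "(): expected argument ",
                                "kernel_size", " [",
                                absl::StrJoin(kernel_size, ", "),
                                "] (size: ", kernel_size.size(),
                                ") to have size of ", 3, ".")
                << std::endl;
      // Prints:
      // avg_pool3d(): expected argument kernel_size [2, 2] (size: 2) to have size of 3.
      return 0;
    }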
diff --git a/torch_xla/csrc/aten_autograd_ops.cpp b/torch_xla/csrc/aten_autograd_ops.cpp
index b86a4055270..ad5d25deba7 100644
--- a/torch_xla/csrc/aten_autograd_ops.cpp
+++ b/torch_xla/csrc/aten_autograd_ops.cpp
@@ -192,11 +192,13 @@ torch::Tensor MaxPool3dAutogradFunction::forward(
     return std::get<0>(results);
   }
   ctx->save_for_backward({self});
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
-  auto outputs = tensor_methods::max_pool_nd(
-      xla_self, /*spatial_dim_count=*/3, XlaHelpers::I64List(kernel_size),
-      XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), ceil_mode);
-  return bridge::AtenFromXlaTensor(std::get<0>(outputs));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
+                      bridge::GetXlaTensor(self));
+  std::tuple<XLATensorPtr, XLATensorPtr> output;
+  XLA_ASSIGN_OR_THROW(output, tensor_methods::max_pool_nd(
+                                  xla_self, /*spatial_dim_count=*/3,
+                                  kernel_size, stride, padding, ceil_mode));
+  return bridge::AtenFromXlaTensor(std::get<0>(output));
 }
 
 torch::autograd::variable_list MaxPool3dAutogradFunction::backward(
@@ -220,13 +222,15 @@ torch::autograd::variable_list MaxPool3dAutogradFunction::backward(
                                           padding, dilation, ceil_mode,
                                           indices);
   }
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output_0,
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_grad_output_0,
                       bridge::GetXlaTensor(grad_output[0]));
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
-  grad = bridge::AtenFromXlaTensor(tensor_methods::max_pool_nd_backward(
-      xla_grad_output_0, xla_self, /*spatial_dim_count=*/3,
-      XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride),
-      XlaHelpers::I64List(padding), ceil_mode));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
+                      bridge::GetXlaTensor(self));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr output,
+                      tensor_methods::max_pool_nd_backward(
+                          xla_grad_output_0, xla_self, /*spatial_dim_count=*/3,
+                          kernel_size, stride, padding, ceil_mode));
+  grad = bridge::AtenFromXlaTensor(std::move(output));
 
   torch::Tensor undef;
   torch::autograd::variable_list grad_inputs = {grad, undef, undef,
@@ -239,24 +243,28 @@ torch::Tensor max_pool2d_forward(torch::Tensor self,
                                  torch::IntArrayRef stride,
                                  torch::IntArrayRef padding,
                                  torch::IntArrayRef dilation, bool ceil_mode) {
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
-  auto outputs = tensor_methods::max_pool_nd(
-      xla_self, /*spatial_dim_count=*/2, XlaHelpers::I64List(kernel_size),
-      XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), ceil_mode);
-  return bridge::AtenFromXlaTensor(std::get<0>(outputs));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
+                      bridge::GetXlaTensor(self));
+  std::tuple<XLATensorPtr, XLATensorPtr> output;
+  XLA_ASSIGN_OR_THROW(output, tensor_methods::max_pool_nd(
+                                  xla_self, /*spatial_dim_count=*/2,
+                                  kernel_size, stride, padding, ceil_mode));
+  return bridge::AtenFromXlaTensor(std::get<0>(output));
 }
 
 torch::Tensor max_pool2d_backward(torch::Tensor grad_output,
                                   torch::Tensor self,
                                   torch::IntArrayRef kernel_size,
                                   torch::IntArrayRef stride,
                                   torch::IntArrayRef padding, bool ceil_mode) {
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output,
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_grad_output,
                       bridge::GetXlaTensor(grad_output));
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
-  auto grad = bridge::AtenFromXlaTensor(tensor_methods::max_pool_nd_backward(
-      xla_grad_output, xla_self, /*spatial_dim_count=*/2,
-      XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride),
-      XlaHelpers::I64List(padding), ceil_mode));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
+                      bridge::GetXlaTensor(self));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr output,
+                      tensor_methods::max_pool_nd_backward(
+                          xla_grad_output, xla_self, /*spatial_dim_count=*/2,
+                          kernel_size, stride, padding, ceil_mode));
+  auto grad = bridge::AtenFromXlaTensor(std::move(output));
   return grad;
 }
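Note (illustrative, not part of the patch): the call sites above stop wrapping kernel_size/stride/padding in `XlaHelpers::I64List()`. This works because the new tensor_methods signatures take `absl::Span<const int64_t>`, and `absl::Span`'s implicit container constructor accepts any type exposing `data()` and `size()` over the right element type, which `at::IntArrayRef` (an alias of `c10::ArrayRef<int64_t>`) does. A hedged sketch, with `Sum` as a hypothetical function of the new signature shape:

    #include <cstdint>
    #include <vector>

    #include "absl/types/span.h"
    #include "c10/util/ArrayRef.h"

    // Hypothetical function with the new parameter style.
    int64_t Sum(absl::Span<const int64_t> xs) {
      int64_t total = 0;
      for (int64_t x : xs) total += x;
      return total;
    }

    int main() {
      std::vector<int64_t> storage = {3, 3};
      c10::IntArrayRef kernel_size(storage);
      // Implicit IntArrayRef -> Span conversion: no I64List() copy needed.
      return Sum(kernel_size) == 6 ? 0 : 1;
    }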
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index c042d703aa3..e22fe0d4783 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -1182,11 +1182,14 @@ at::Tensor XLANativeFunctions::avg_pool2d(
     at::IntArrayRef padding, bool ceil_mode, bool count_include_pad,
     std::optional<int64_t> divisor_override) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
-  return bridge::AtenFromXlaTensor(tensor_methods::avg_pool_nd(
-      xla_self, /*spatial_dim_count=*/2, XlaHelpers::I64List(kernel_size),
-      XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), ceil_mode,
-      count_include_pad, divisor_override));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
+                      bridge::GetXlaTensor(self));
+  XLA_ASSIGN_OR_THROW(
+      absl_nonnull XLATensorPtr output,
+      tensor_methods::avg_pool_nd(xla_self, /*spatial_dim_count=*/2,
+                                  kernel_size, stride, padding, ceil_mode,
+                                  count_include_pad, divisor_override));
+  return bridge::AtenFromXlaTensor(std::move(output));
 }
 
 at::Tensor XLANativeFunctions::avg_pool2d_backward(
@@ -1203,13 +1206,16 @@ at::Tensor XLANativeFunctions::avg_pool2d_backward(
                                              count_include_pad,
                                              divisor_override);
   }
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output,
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_grad_output,
                       bridge::GetXlaTensor(grad_output));
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
-  return bridge::AtenFromXlaTensor(tensor_methods::avg_pool_nd_backward(
-      xla_grad_output, xla_self, /*spatial_dim_count=*/2,
-      XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride),
-      XlaHelpers::I64List(padding), ceil_mode, count_include_pad));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
+                      bridge::GetXlaTensor(self));
+  XLA_ASSIGN_OR_THROW(
+      absl_nonnull XLATensorPtr output,
+      tensor_methods::avg_pool_nd_backward(
+          xla_grad_output, xla_self, /*spatial_dim_count=*/2, kernel_size,
+          stride, padding, ceil_mode, count_include_pad));
+  return bridge::AtenFromXlaTensor(std::move(output));
 }
 
 at::Tensor XLANativeFunctions::avg_pool3d(
@@ -1217,11 +1223,14 @@ at::Tensor XLANativeFunctions::avg_pool3d(
     at::IntArrayRef padding, bool ceil_mode, bool count_include_pad,
     std::optional<int64_t> divisor_override) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
-  return bridge::AtenFromXlaTensor(tensor_methods::avg_pool_nd(
-      xla_self, /*spatial_dim_count=*/3, XlaHelpers::I64List(kernel_size),
-      XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), ceil_mode,
-      count_include_pad, divisor_override));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
+                      bridge::GetXlaTensor(self));
+  XLA_ASSIGN_OR_THROW(
+      absl_nonnull XLATensorPtr output,
+      tensor_methods::avg_pool_nd(xla_self, /*spatial_dim_count=*/3,
+                                  kernel_size, stride, padding, ceil_mode,
+                                  count_include_pad, divisor_override));
+  return bridge::AtenFromXlaTensor(std::move(output));
 }
 
 at::Tensor XLANativeFunctions::avg_pool3d_backward(
@@ -1238,13 +1247,16 @@ at::Tensor XLANativeFunctions::avg_pool3d_backward(
                                              count_include_pad,
                                              divisor_override);
   }
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output,
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_grad_output,
                       bridge::GetXlaTensor(grad_output));
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
-  return bridge::AtenFromXlaTensor(tensor_methods::avg_pool_nd_backward(
-      xla_grad_output, xla_self, /*spatial_dim_count=*/3,
-      XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride),
-      XlaHelpers::I64List(padding), ceil_mode, count_include_pad));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
+                      bridge::GetXlaTensor(self));
+  XLA_ASSIGN_OR_THROW(
+      absl_nonnull XLATensorPtr output,
+      tensor_methods::avg_pool_nd_backward(
+          xla_grad_output, xla_self, /*spatial_dim_count=*/3, kernel_size,
+          stride, padding, ceil_mode, count_include_pad));
+  return bridge::AtenFromXlaTensor(std::move(output));
 }
 
 at::Tensor XLANativeFunctions::baddbmm(const at::Tensor& self,
@@ -2327,12 +2339,14 @@ std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::max_pool2d_with_indices(
                                                 dilation, ceil_mode);
   }
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
-  auto outputs = tensor_methods::max_pool_nd(
-      xla_self, /*spatial_dim_count=*/2, XlaHelpers::I64List(kernel_size),
-      XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), ceil_mode);
-  return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(outputs)),
-                         bridge::AtenFromXlaTensor(std::get<1>(outputs)));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
+                      bridge::GetXlaTensor(self));
+  std::tuple<XLATensorPtr, XLATensorPtr> output;
+  XLA_ASSIGN_OR_THROW(output, tensor_methods::max_pool_nd(
+                                  xla_self, /*spatial_dim_count=*/2,
+                                  kernel_size, stride, padding, ceil_mode));
+  return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(output)),
+                         bridge::AtenFromXlaTensor(std::get<1>(output)));
 }
 
 at::Tensor XLANativeFunctions::max_pool2d_with_indices_backward(
@@ -2350,13 +2364,15 @@ at::Tensor XLANativeFunctions::max_pool2d_with_indices_backward(
                                              padding, dilation, ceil_mode,
                                              indices);
   }
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output,
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_grad_output,
                       bridge::GetXlaTensor(grad_output));
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
-  return bridge::AtenFromXlaTensor(tensor_methods::max_pool_nd_backward(
-      xla_grad_output, xla_self, /*spatial_dim_count=*/2,
-      XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride),
-      XlaHelpers::I64List(padding), ceil_mode));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
+                      bridge::GetXlaTensor(self));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr output,
+                      tensor_methods::max_pool_nd_backward(
+                          xla_grad_output, xla_self, /*spatial_dim_count=*/2,
+                          kernel_size, stride, padding, ceil_mode));
+  return bridge::AtenFromXlaTensor(std::move(output));
 }
 
 at::Tensor XLANativeFunctions::max_pool3d(
@@ -2382,13 +2398,15 @@ at::Tensor XLANativeFunctions::max_pool3d_with_indices_backward(
                                              padding, dilation, ceil_mode,
                                              indices);
   }
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output,
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_grad_output,
                       bridge::GetXlaTensor(grad_output));
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
-  return bridge::AtenFromXlaTensor(tensor_methods::max_pool_nd_backward(
-      xla_grad_output, xla_self, /*spatial_dim_count=*/3,
-      XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride),
-      XlaHelpers::I64List(padding), ceil_mode));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
+                      bridge::GetXlaTensor(self));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr output,
+                      tensor_methods::max_pool_nd_backward(
+                          xla_grad_output, xla_self, /*spatial_dim_count=*/3,
+                          kernel_size, stride, padding, ceil_mode));
+  return bridge::AtenFromXlaTensor(std::move(output));
 }
 
 std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::max_pool3d_with_indices(
@@ -2404,12 +2422,14 @@ std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::max_pool3d_with_indices(
                                                 dilation, ceil_mode);
   }
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
-  auto outputs = tensor_methods::max_pool_nd(
-      xla_self, /*spatial_dim_count=*/3, XlaHelpers::I64List(kernel_size),
-      XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), ceil_mode);
-  return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(outputs)),
-                         bridge::AtenFromXlaTensor(std::get<1>(outputs)));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
+                      bridge::GetXlaTensor(self));
+  std::tuple<XLATensorPtr, XLATensorPtr> output;
+  XLA_ASSIGN_OR_THROW(output, tensor_methods::max_pool_nd(
+                                  xla_self, /*spatial_dim_count=*/3,
+                                  kernel_size, stride, padding, ceil_mode));
+  return std::make_tuple(bridge::AtenFromXlaTensor(std::get<0>(output)),
+                         bridge::AtenFromXlaTensor(std::get<1>(output)));
 }
 
 at::Tensor XLANativeFunctions::max_unpool2d(const at::Tensor& self,
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index 21f4db59713..c05bbe0d591 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -13,6 +13,7 @@
 #include "absl/log/absl_check.h"
 #include "absl/status/status.h"
 #include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
 #include "absl/strings/str_split.h"
 #include "torch_xla/csrc/LazyIr.h"
 #include "torch_xla/csrc/aten_xla_bridge.h"
@@ -168,6 +169,20 @@ struct MinMaxValues {
   torch::lazy::Value max;
 };
 
+// Gathers the common inputs among `*_pool_nd` operations.
+// This is specifically used for input checking purposes.
+template <typename T>
+struct _PoolNdInputs {
+  T kernel_size;
+  T stride;
+  T padding;
+};
+
+// Convenience aliases for representing `_PoolNdInputs` that either own or
+// reference the storage.
+using PoolNdInputsOwner = _PoolNdInputs<std::vector<int64_t>>;
+using PoolNdInputsRef = _PoolNdInputs<absl::Span<const int64_t>>;
+
 torch::lazy::Value MaybeExpand(const torch::lazy::Value& input,
                                const xla::Shape& target_shape) {
   if (GetXlaShape(input).dimensions() == target_shape.dimensions()) {
"absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "absl/strings/str_split.h" #include "torch_xla/csrc/LazyIr.h" #include "torch_xla/csrc/aten_xla_bridge.h" @@ -168,6 +169,20 @@ struct MinMaxValues { torch::lazy::Value max; }; +// Gathers the common inputs among `*_pool_nd` operations. +// This is specifically used for input checking purposes. +template +struct _PoolNdInputs { + T kernel_size; + T stride; + T padding; +}; + +// Convenience aliases for representing `_PoolNdInputs` that either own or +// reference the storage. +using PoolNdInputsOwner = _PoolNdInputs>; +using PoolNdInputsRef = _PoolNdInputs>; + torch::lazy::Value MaybeExpand(const torch::lazy::Value& input, const xla::Shape& target_shape) { if (GetXlaShape(input).dimensions() == target_shape.dimensions()) { @@ -245,6 +260,56 @@ std::vector GetExpandDimensions(const xla::Shape& shape, return dimensions; } +std::vector RepeatIfSingleElement(const absl::Span span, + int64_t n) { + return (span.size() == 1 && n > 1) + ? std::vector(n, span[0]) + : std::vector(span.begin(), span.end()); +} + +absl::Status CheckPoolNdInputHasSize(const std::string_view op, + const std::string_view arg, + const absl::Span input, + int64_t size) { + if (input.size() != size) { + return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat( + op, "(): expected argument ", arg, " [", + absl::StrJoin(input, /* sep= */ ", "), "] (size: ", input.size(), + ") to have size of ", size, "."))); + } + return absl::OkStatus(); +} + +// Fills each field of `inputs` (only those that have only 1 element) so that +// they have exactly `spatial_dim_count` elements, and check that they actually +// have that. +absl::StatusOr FillAndCheckPoolNdInputs( + const std::string_view op, int64_t spatial_dim_count, + const PoolNdInputsRef& inputs) { + // Fill and check `inputs.kernel_size`. + std::vector kernel_size = + RepeatIfSingleElement(inputs.kernel_size, spatial_dim_count); + XLA_RETURN_IF_ERROR(CheckPoolNdInputHasSize(op, "kernel_size", kernel_size, + spatial_dim_count)); + + // Fill and check `inputs.stride`. + // Only for this field, if it is empty, copy from `kernel_size`. + std::vector stride = + inputs.stride.empty() + ? kernel_size + : RepeatIfSingleElement(inputs.stride, spatial_dim_count); + XLA_RETURN_IF_ERROR( + CheckPoolNdInputHasSize(op, "stride", stride, spatial_dim_count)); + + // Fill and check `inputs.padding`. + std::vector padding = + RepeatIfSingleElement(inputs.padding, spatial_dim_count); + XLA_RETURN_IF_ERROR( + CheckPoolNdInputHasSize(op, "padding", padding, spatial_dim_count)); + + return PoolNdInputsOwner{kernel_size, stride, padding}; +} + // Resizes and / or checks whether a list is of the given size. The list is only // resized if its size is 1. If it's empty, it's replaced with the provided // default first. 
@@ -1183,35 +1248,37 @@ void as_strided_(XLATensorPtr& input, std::vector<int64_t> size,
   }
 }
 
-XLATensorPtr avg_pool_nd(const XLATensorPtr& input, int64_t spatial_dim_count,
-                         std::vector<int64_t> kernel_size,
-                         std::vector<int64_t> stride,
-                         std::vector<int64_t> padding, bool ceil_mode,
-                         bool count_include_pad,
-                         std::optional<int64_t> divisor_override) {
-  kernel_size = CheckIntList(kernel_size, spatial_dim_count, "kernel_size");
-  stride = CheckIntList(stride, spatial_dim_count, "stride", kernel_size);
-  padding = CheckIntList(padding, spatial_dim_count, "padding");
+absl::StatusOr<XLATensorPtr> avg_pool_nd(
+    const XLATensorPtr& input, int64_t spatial_dim_count,
+    const absl::Span<const int64_t> kernel_size,
+    const absl::Span<const int64_t> stride,
+    const absl::Span<const int64_t> padding, bool ceil_mode,
+    bool count_include_pad, std::optional<int64_t> divisor_override) {
+  XLA_ASSIGN_OR_RETURN(PoolNdInputsOwner inputs,
+                       FillAndCheckPoolNdInputs(
+                           absl::StrCat("avg_pool", spatial_dim_count, "d"),
+                           spatial_dim_count, {kernel_size, stride, padding}));
   return input->CreateFrom(torch_xla::MakeNode<AvgPoolNd>(
-      input->GetIrValue(), spatial_dim_count, std::move(kernel_size),
-      std::move(stride), std::move(padding), ceil_mode, count_include_pad,
-      divisor_override));
+      input->GetIrValue(), spatial_dim_count, std::move(inputs.kernel_size),
+      std::move(inputs.stride), std::move(inputs.padding), ceil_mode,
+      count_include_pad, divisor_override));
 }
 
-XLATensorPtr avg_pool_nd_backward(const XLATensorPtr& out_backprop,
-                                  const XLATensorPtr& input,
-                                  int64_t spatial_dim_count,
-                                  std::vector<int64_t> kernel_size,
-                                  std::vector<int64_t> stride,
-                                  std::vector<int64_t> padding, bool ceil_mode,
-                                  bool count_include_pad) {
-  kernel_size = CheckIntList(kernel_size, spatial_dim_count, "kernel_size");
-  stride = CheckIntList(stride, spatial_dim_count, "stride", kernel_size);
-  padding = CheckIntList(padding, spatial_dim_count, "padding");
+absl::StatusOr<XLATensorPtr> avg_pool_nd_backward(
+    const XLATensorPtr& out_backprop, const XLATensorPtr& input,
+    int64_t spatial_dim_count, const absl::Span<const int64_t> kernel_size,
+    const absl::Span<const int64_t> stride,
+    const absl::Span<const int64_t> padding, bool ceil_mode,
+    bool count_include_pad) {
+  XLA_ASSIGN_OR_RETURN(
+      PoolNdInputsOwner inputs,
+      FillAndCheckPoolNdInputs(
+          absl::StrCat("avg_pool", spatial_dim_count, "d_backward"),
+          spatial_dim_count, {kernel_size, stride, padding}));
   return out_backprop->CreateFrom(torch_xla::MakeNode<AvgPoolNdBackward>(
       out_backprop->GetIrValue(), input->GetIrValue(), spatial_dim_count,
-      std::move(kernel_size), std::move(stride), std::move(padding), ceil_mode,
-      count_include_pad));
+      std::move(inputs.kernel_size), std::move(inputs.stride),
+      std::move(inputs.padding), ceil_mode, count_include_pad));
 }
 
 XLATensorPtr baddbmm(const XLATensorPtr& input, const XLATensorPtr& batch1,
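Note (illustrative, not part of the patch): with these changes the validation error flows as absl::Status end to end, using names from this patch:

    // 1. FillAndCheckPoolNdInputs() returns absl::InvalidArgumentError on a
    //    bad argument size; XLA_ASSIGN_OR_RETURN propagates it, so
    //    tensor_methods::avg_pool_nd() yields absl::StatusOr<XLATensorPtr>.
    // 2. The ATen entry points (aten_xla_type.cpp, aten_autograd_ops.cpp)
    //    unwrap with XLA_ASSIGN_OR_THROW, which surfaces the status as the
    //    RuntimeError checked by the Python test added in this patch.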
@@ -2262,16 +2329,18 @@ void max_out(XLATensorPtr& max, XLATensorPtr& max_values,
   }
 }
 
-std::tuple<XLATensorPtr, XLATensorPtr> max_pool_nd(
-    const XLATensorPtr& input, int64_t spatial_dim_count,
-    std::vector<int64_t> kernel_size, std::vector<int64_t> stride,
-    std::vector<int64_t> padding, bool ceil_mode) {
-  kernel_size = CheckIntList(kernel_size, spatial_dim_count, "kernel_size");
-  stride = CheckIntList(stride, spatial_dim_count, "stride", kernel_size);
-  padding = CheckIntList(padding, spatial_dim_count, "padding");
+absl::StatusOr<std::tuple<XLATensorPtr, XLATensorPtr>>
+max_pool_nd(const XLATensorPtr& input, int64_t spatial_dim_count,
+            const absl::Span<const int64_t> kernel_size,
+            const absl::Span<const int64_t> stride,
+            const absl::Span<const int64_t> padding, bool ceil_mode) {
+  XLA_ASSIGN_OR_RETURN(PoolNdInputsOwner inputs,
+                       FillAndCheckPoolNdInputs(
+                           absl::StrCat("max_pool", spatial_dim_count, "d"),
+                           spatial_dim_count, {kernel_size, stride, padding}));
   torch::lazy::NodePtr node = torch_xla::MakeNode<MaxPoolNd>(
-      input->GetIrValue(), spatial_dim_count, std::move(kernel_size),
-      std::move(stride), std::move(padding), ceil_mode);
+      input->GetIrValue(), spatial_dim_count, std::move(inputs.kernel_size),
+      std::move(inputs.stride), std::move(inputs.padding), ceil_mode);
 
   XLATensorPtr t1 = input->CreateFrom(torch::lazy::Value(node, 0),
                                       /*delay_eager_execution=*/true);
@@ -2287,17 +2356,20 @@ std::tuple<XLATensorPtr, XLATensorPtr> max_pool_nd(
   return std::make_tuple(t1, t2);
 }
 
-XLATensorPtr max_pool_nd_backward(
+absl::StatusOr<XLATensorPtr> max_pool_nd_backward(
     const XLATensorPtr& out_backprop, const XLATensorPtr& input,
-    int64_t spatial_dim_count, std::vector<int64_t> kernel_size,
-    std::vector<int64_t> stride, std::vector<int64_t> padding,
-    bool ceil_mode) {
-  kernel_size = CheckIntList(kernel_size, spatial_dim_count, "kernel_size");
-  stride = CheckIntList(stride, spatial_dim_count, "stride", kernel_size);
-  padding = CheckIntList(padding, spatial_dim_count, "padding");
+    int64_t spatial_dim_count, const absl::Span<const int64_t> kernel_size,
+    const absl::Span<const int64_t> stride,
+    const absl::Span<const int64_t> padding, bool ceil_mode) {
+  XLA_ASSIGN_OR_RETURN(
+      PoolNdInputsOwner inputs,
+      FillAndCheckPoolNdInputs(
+          absl::StrCat("max_pool", spatial_dim_count, "d_backward"),
+          spatial_dim_count, {kernel_size, stride, padding}));
   return out_backprop->CreateFrom(torch_xla::MakeNode<MaxPoolNdBackward>(
       out_backprop->GetIrValue(), input->GetIrValue(), spatial_dim_count,
-      std::move(kernel_size), std::move(stride), std::move(padding),
-      ceil_mode));
+      std::move(inputs.kernel_size), std::move(inputs.stride),
+      std::move(inputs.padding), ceil_mode));
 }
 
 XLATensorPtr max_unpool(const XLATensorPtr& input, const XLATensorPtr& indices,
diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h
index b25b423d49c..a96d752e5f4 100644
--- a/torch_xla/csrc/tensor_methods.h
+++ b/torch_xla/csrc/tensor_methods.h
@@ -257,20 +257,19 @@ void as_strided_(XLATensorPtr& input, std::vector<int64_t> size,
                  std::vector<int64_t> stride,
                  std::optional<int64_t> storage_offset);
 
-XLATensorPtr avg_pool_nd(const XLATensorPtr& input, int64_t spatial_dim_count,
-                         std::vector<int64_t> kernel_size,
-                         std::vector<int64_t> stride,
-                         std::vector<int64_t> padding, bool ceil_mode,
-                         bool count_include_pad,
-                         std::optional<int64_t> divisor_override);
-
-XLATensorPtr avg_pool_nd_backward(const XLATensorPtr& out_backprop,
-                                  const XLATensorPtr& input,
-                                  int64_t spatial_dim_count,
-                                  std::vector<int64_t> kernel_size,
-                                  std::vector<int64_t> stride,
-                                  std::vector<int64_t> padding, bool ceil_mode,
-                                  bool count_include_pad);
+absl::StatusOr<XLATensorPtr> avg_pool_nd(
+    const XLATensorPtr& input, int64_t spatial_dim_count,
+    const absl::Span<const int64_t> kernel_size,
+    const absl::Span<const int64_t> stride,
+    const absl::Span<const int64_t> padding, bool ceil_mode,
+    bool count_include_pad, std::optional<int64_t> divisor_override);
+
+absl::StatusOr<XLATensorPtr> avg_pool_nd_backward(
+    const XLATensorPtr& out_backprop, const XLATensorPtr& input,
+    int64_t spatial_dim_count, const absl::Span<const int64_t> kernel_size,
+    const absl::Span<const int64_t> stride,
+    const absl::Span<const int64_t> padding, bool ceil_mode,
+    bool count_include_pad);
 
 XLATensorPtr baddbmm(const XLATensorPtr& input, const XLATensorPtr& batch1,
                      const XLATensorPtr& batch2, const at::Scalar& beta,
@@ -608,17 +607,17 @@ std::tuple<XLATensorPtr, XLATensorPtr> max(const XLATensorPtr& input,
 void max_out(XLATensorPtr& max, XLATensorPtr& max_values,
              const XLATensorPtr& input, int64_t dim, bool keepdim);
 
-std::tuple<XLATensorPtr, XLATensorPtr> max_pool_nd(
-    const XLATensorPtr& input, int64_t spatial_dim_count,
-    std::vector<int64_t> kernel_size, std::vector<int64_t> stride,
-    std::vector<int64_t> padding, bool ceil_mode);
+absl::StatusOr<std::tuple<XLATensorPtr, XLATensorPtr>>
+max_pool_nd(const XLATensorPtr& input, int64_t spatial_dim_count,
+            const absl::Span<const int64_t> kernel_size,
+            const absl::Span<const int64_t> stride,
+            const absl::Span<const int64_t> padding, bool ceil_mode);
 
-XLATensorPtr max_pool_nd_backward(const XLATensorPtr& out_backprop,
-                                  const XLATensorPtr& input,
-                                  int64_t spatial_dim_count,
-                                  std::vector<int64_t> kernel_size,
-                                  std::vector<int64_t> stride,
-                                  std::vector<int64_t> padding, bool ceil_mode);
+absl::StatusOr<XLATensorPtr> max_pool_nd_backward(
+    const XLATensorPtr& out_backprop, const XLATensorPtr& input,
+    int64_t spatial_dim_count, const absl::Span<const int64_t> kernel_size,
+    const absl::Span<const int64_t> stride,
+    const absl::Span<const int64_t> padding, bool ceil_mode);
 
 XLATensorPtr max_unpool(const XLATensorPtr& input, const XLATensorPtr& indices,
                         std::vector<int64_t> output_size);
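Note (illustrative, not part of the patch): a hedged usage sketch of the updated API from a hypothetical call site; `input` is an existing XLATensorPtr and the error text shown is illustrative.

    absl::StatusOr<XLATensorPtr> pooled = tensor_methods::avg_pool_nd(
        input, /*spatial_dim_count=*/2,
        /*kernel_size=*/{3, 3},
        /*stride=*/{},    // empty: defaults to kernel_size
        /*padding=*/{1},  // single element: expands to {1, 1}
        /*ceil_mode=*/false, /*count_include_pad=*/true,
        /*divisor_override=*/std::nullopt);
    if (!pooled.ok()) {
      // e.g. "avg_pool2d(): expected argument kernel_size ... to have size
      // of 2." on a bad spec.
    }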