Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 45 additions & 34 deletions test/cpp/test_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <limits>
#include <vector>

#include "absl/base/nullability.h"
#include "test/cpp/cpp_test_util.h"
#include "test/cpp/torch_xla_test.h"
#include "torch/csrc/autograd/variable.h"
Expand Down Expand Up @@ -297,14 +298,18 @@ TEST_F(TensorTest, TestMaxPool2D) {
/*padding=*/{padding, padding}, /*dilation=*/{1, 1},
/*ceil_mode=*/false);
ForEachDevice([&](const torch::lazy::BackendDevice& device) {
XLA_ASSIGN_OR_THROW(XLATensorPtr dev_input,
XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr dev_input,
XLATensor::Create(input, device));
auto dev_output = tensor_methods::max_pool_nd(
dev_input,
/*spatial_dim_count=*/2,
/*kernel_size=*/{kernel_size, kernel_size},
/*stride=*/{stride, stride},
/*padding=*/{padding, padding}, /*ceil_mode=*/false);
std::tuple<absl_nonnull XLATensorPtr, absl_nonnull XLATensorPtr>
dev_output;
XLA_ASSIGN_OR_THROW(
dev_output,
tensor_methods::max_pool_nd(
dev_input,
/*spatial_dim_count=*/2,
/*kernel_size=*/{kernel_size, kernel_size},
/*stride=*/{stride, stride},
/*padding=*/{padding, padding}, /*ceil_mode=*/false));
AllClose(output, std::get<0>(dev_output));
});
}
Expand All @@ -322,15 +327,18 @@ TEST_F(TensorTest, TestMaxPool2DNonSquare) {
/*padding=*/{padding, padding + 1}, /*dilation=*/{1, 1},
/*ceil_mode=*/false);
ForEachDevice([&](const torch::lazy::BackendDevice& device) {
XLA_ASSIGN_OR_THROW(XLATensorPtr dev_input,
XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr dev_input,
XLATensor::Create(input, device));
auto dev_output = tensor_methods::max_pool_nd(
dev_input,
/*spatial_dim_count=*/2,
/*kernel_size=*/{kernel_size, kernel_size + 1},
/*stride=*/{stride, stride + 1},
/*padding=*/{padding, padding + 1},
/*ceil_mode=*/false);
std::tuple<absl_nonnull XLATensorPtr, absl_nonnull XLATensorPtr>
dev_output;
XLA_ASSIGN_OR_THROW(dev_output,
tensor_methods::max_pool_nd(
dev_input,
/*spatial_dim_count=*/2,
/*kernel_size=*/{kernel_size, kernel_size + 1},
/*stride=*/{stride, stride + 1},
/*padding=*/{padding, padding + 1},
/*ceil_mode=*/false));
AllClose(output, std::get<0>(dev_output));
});
}
Expand All @@ -351,16 +359,17 @@ TEST_F(TensorTest, TestAvgPool2D) {
/*ceil_mode=*/false, count_include_pad,
/*divisor_override=*/std::nullopt);
ForEachDevice([&](const torch::lazy::BackendDevice& device) {
XLA_ASSIGN_OR_THROW(XLATensorPtr dev_input,
XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr dev_input,
XLATensor::Create(input, device));
XLATensorPtr dev_output = tensor_methods::avg_pool_nd(
dev_input,
/*spatial_dim_count=*/2,
/*kernel_size=*/{kernel_size, kernel_size},
/*stride=*/{stride, stride},
/*padding=*/{padding, padding},
/*ceil_mode=*/false, count_include_pad,
/*divisor_override=*/std::nullopt);
XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr dev_output,
tensor_methods::avg_pool_nd(
dev_input,
/*spatial_dim_count=*/2,
/*kernel_size=*/{kernel_size, kernel_size},
/*stride=*/{stride, stride},
/*padding=*/{padding, padding},
/*ceil_mode=*/false, count_include_pad,
/*divisor_override=*/std::nullopt));
AllClose(output, dev_output);
});
}
Expand All @@ -382,17 +391,19 @@ TEST_F(TensorTest, TestAvgPool2DNonSquare) {
/*count_include_pad=*/count_include_pad,
/*divisor_override=*/std::nullopt);
ForEachDevice([&](const torch::lazy::BackendDevice& device) {
XLA_ASSIGN_OR_THROW(XLATensorPtr dev_input,
XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr dev_input,
XLATensor::Create(input, device));
XLATensorPtr dev_output = tensor_methods::avg_pool_nd(
dev_input,
/*spatial_dim_count=*/2,
/*kernel_size=*/{kernel_size, kernel_size + 1},
/*stride=*/{stride, stride + 1},
/*padding=*/{padding, padding + 1},
/*ceil_mode=*/false,
/*count_include_pad=*/count_include_pad,
/*divisor_override=*/std::nullopt);
XLA_ASSIGN_OR_THROW(
absl_nonnull XLATensorPtr dev_output,
tensor_methods::avg_pool_nd(
dev_input,
/*spatial_dim_count=*/2,
/*kernel_size=*/{kernel_size, kernel_size + 1},
/*stride=*/{stride, stride + 1},
/*padding=*/{padding, padding + 1},
/*ceil_mode=*/false,
/*count_include_pad=*/count_include_pad,
/*divisor_override=*/std::nullopt));
AllClose(output, dev_output);
});
}
Expand Down
25 changes: 25 additions & 0 deletions test/test_ops_error_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,31 @@ def test():
expect="""mm(): cannot matrix-multiply tensors f32[2,5] and f32[8,2]. Expected the size of dimension 1 of the first input tensor (5) to be equal the size of dimension 0 of the second input tensor (8)."""
)

def test_avg_pool_3d_raises_error_on_bad_spec(self):
device = torch_xla.device()
a = torch.rand(1, 1, 4, 4, 4, device=device)

def gen_test_fn(kernel_size=[2, 2, 2], stride=[], padding=[0]):
return lambda: torch.nn.functional.avg_pool3d(a, kernel_size, stride, padding)

self.assertExpectedRaisesInline(
exc_type=RuntimeError,
callable=gen_test_fn(kernel_size=[2, 2]),
expect="""avg_pool3d(): expected argument kernel_size [2, 2] (size: 2) to have size of 3."""
)

self.assertExpectedRaisesInline(
exc_type=RuntimeError,
callable=gen_test_fn(stride=[1, 2]),
expect="""avg_pool3d(): expected argument stride [1, 2] (size: 2) to have size of 3."""
)

self.assertExpectedRaisesInline(
exc_type=RuntimeError,
callable=gen_test_fn(padding=[1, 2]),
expect="""avg_pool3d(): expected argument padding [1, 2] (size: 2) to have size of 3."""
)


if __name__ == "__main__":
unittest.main()
52 changes: 30 additions & 22 deletions torch_xla/csrc/aten_autograd_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,11 +192,13 @@ torch::Tensor MaxPool3dAutogradFunction::forward(
return std::get<0>(results);
}
ctx->save_for_backward({self});
XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
auto outputs = tensor_methods::max_pool_nd(
xla_self, /*spatial_dim_count=*/3, XlaHelpers::I64List(kernel_size),
XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), ceil_mode);
return bridge::AtenFromXlaTensor(std::get<0>(outputs));
XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
bridge::GetXlaTensor(self));
std::tuple<absl_nonnull XLATensorPtr, absl_nonnull XLATensorPtr> output;
XLA_ASSIGN_OR_THROW(output, tensor_methods::max_pool_nd(
xla_self, /*spatial_dim_count=*/3,
kernel_size, stride, padding, ceil_mode));
return bridge::AtenFromXlaTensor(std::get<0>(output));
}

torch::autograd::variable_list MaxPool3dAutogradFunction::backward(
Expand All @@ -220,13 +222,15 @@ torch::autograd::variable_list MaxPool3dAutogradFunction::backward(
padding, dilation,
ceil_mode, indices);
}
XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output_0,
XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_grad_output_0,
bridge::GetXlaTensor(grad_output[0]));
XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
grad = bridge::AtenFromXlaTensor(tensor_methods::max_pool_nd_backward(
xla_grad_output_0, xla_self, /*spatial_dim_count=*/3,
XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride),
XlaHelpers::I64List(padding), ceil_mode));
XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
bridge::GetXlaTensor(self));
XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr output,
tensor_methods::max_pool_nd_backward(
xla_grad_output_0, xla_self, /*spatial_dim_count=*/3,
kernel_size, stride, padding, ceil_mode));
grad = bridge::AtenFromXlaTensor(std::move(output));

torch::Tensor undef;
torch::autograd::variable_list grad_inputs = {grad, undef, undef,
Expand All @@ -239,24 +243,28 @@ torch::Tensor max_pool2d_forward(torch::Tensor self,
torch::IntArrayRef stride,
torch::IntArrayRef padding,
torch::IntArrayRef dilation, bool ceil_mode) {
XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
auto outputs = tensor_methods::max_pool_nd(
xla_self, /*spatial_dim_count=*/2, XlaHelpers::I64List(kernel_size),
XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), ceil_mode);
return bridge::AtenFromXlaTensor(std::get<0>(outputs));
XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
bridge::GetXlaTensor(self));
std::tuple<absl_nonnull XLATensorPtr, absl_nonnull XLATensorPtr> output;
XLA_ASSIGN_OR_THROW(output, tensor_methods::max_pool_nd(
xla_self, /*spatial_dim_count=*/2,
kernel_size, stride, padding, ceil_mode));
return bridge::AtenFromXlaTensor(std::get<0>(output));
}

torch::Tensor max_pool2d_backward(torch::Tensor grad_output, torch::Tensor self,
torch::IntArrayRef kernel_size,
torch::IntArrayRef stride,
torch::IntArrayRef padding, bool ceil_mode) {
XLA_ASSIGN_OR_THROW(XLATensorPtr xla_grad_output,
XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_grad_output,
bridge::GetXlaTensor(grad_output));
XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
auto grad = bridge::AtenFromXlaTensor(tensor_methods::max_pool_nd_backward(
xla_grad_output, xla_self, /*spatial_dim_count=*/2,
XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride),
XlaHelpers::I64List(padding), ceil_mode));
XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
bridge::GetXlaTensor(self));
XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr output,
tensor_methods::max_pool_nd_backward(
xla_grad_output, xla_self, /*spatial_dim_count=*/2,
kernel_size, stride, padding, ceil_mode));
auto grad = bridge::AtenFromXlaTensor(std::move(output));
return grad;
}

Expand Down
Loading