Skip to content

Commit 095faec

Browse files
authored
Error Handling: make XLATensor::Create() return status type. (#9544)
1 parent 8c1449f commit 095faec

File tree

5 files changed

+85
-50
lines changed

5 files changed

+85
-50
lines changed

test/cpp/test_tensor.cpp

Lines changed: 64 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,8 @@ TEST_F(TensorTest, TestAdd) {
101101
at::Tensor c = a.add(b, 1.0);
102102

103103
ForEachDevice([&](const torch::lazy::BackendDevice& device) {
104-
XLATensorPtr dev_a = XLATensor::Create(a, device);
105-
XLATensorPtr dev_b = XLATensor::Create(b, device);
104+
XLATensorPtr dev_a = GetValueOrThrow(XLATensor::Create(a, device));
105+
XLATensorPtr dev_b = GetValueOrThrow(XLATensor::Create(b, device));
106106
XLATensorPtr dev_c = tensor_methods::add(dev_a, dev_b, 1.0);
107107

108108
AllClose(c, dev_c);
@@ -121,8 +121,8 @@ TEST_F(TensorTest, TestIntegerAdd) {
121121
at::isIntegralType(type) ? at::Scalar(int64_t(1)) : at::Scalar(1.0);
122122
at::Tensor c = a.add(b, one);
123123

124-
XLATensorPtr dev_a = XLATensor::Create(a, device);
125-
XLATensorPtr dev_b = XLATensor::Create(b, device);
124+
XLATensorPtr dev_a = GetValueOrThrow(XLATensor::Create(a, device));
125+
XLATensorPtr dev_b = GetValueOrThrow(XLATensor::Create(b, device));
126126
XLATensorPtr dev_c = tensor_methods::add(dev_a, dev_b, one);
127127

128128
EXPECT_TRUE(EqualValuesNoElementTypeCheck(
@@ -135,7 +135,7 @@ TEST_F(TensorTest, TestSize) {
135135
at::Tensor input = at::rand({2, 1, 4, 6}, at::TensorOptions(at::kFloat));
136136
int rank = input.dim();
137137
ForEachDevice([&](const torch::lazy::BackendDevice& device) {
138-
XLATensorPtr dev_input = XLATensor::Create(input, device);
138+
XLATensorPtr dev_input = GetValueOrThrow(XLATensor::Create(input, device));
139139
for (int dim = -rank; dim < rank; ++dim) {
140140
EXPECT_EQ(input.size(dim), dev_input->size(dim));
141141
}
@@ -151,8 +151,10 @@ TEST_F(TensorTest, TestRrelu) {
151151
at::Tensor noise = at::zeros_like(input);
152152
at::Tensor output =
153153
at::rrelu_with_noise(input, noise, lower, upper, training);
154-
XLATensorPtr dev_input = XLATensor::Create(input, device);
155-
XLATensorPtr dev_noise = XLATensor::Create(noise, device);
154+
XLATensorPtr dev_input =
155+
GetValueOrThrow(XLATensor::Create(input, device));
156+
XLATensorPtr dev_noise =
157+
GetValueOrThrow(XLATensor::Create(noise, device));
156158
XLATensorPtr dev_outputs = tensor_methods::rrelu_with_noise(
157159
dev_input, dev_noise, lower, upper, training);
158160
AllClose(output, dev_outputs);
@@ -167,7 +169,7 @@ TEST_F(TensorTest, TestThreshold) {
167169
float value = 20;
168170
at::Tensor output = at::threshold(input, threshold, value);
169171
ForEachDevice([&](const torch::lazy::BackendDevice& device) {
170-
XLATensorPtr dev_input = XLATensor::Create(input, device);
172+
XLATensorPtr dev_input = GetValueOrThrow(XLATensor::Create(input, device));
171173
XLATensorPtr dev_output =
172174
tensor_methods::threshold(dev_input, threshold, value);
173175
AllClose(output, dev_output);
@@ -185,9 +187,10 @@ TEST_F(TensorTest, TestAddMatMul) {
185187
at::Tensor bias = at::rand({labels}, at::TensorOptions(at::kFloat));
186188
at::Tensor output = at::addmm(bias, input, weight);
187189
ForEachDevice([&](const torch::lazy::BackendDevice& device) {
188-
XLATensorPtr dev_input = XLATensor::Create(input, device);
189-
XLATensorPtr dev_weight = XLATensor::Create(weight, device);
190-
XLATensorPtr dev_bias = XLATensor::Create(bias, device);
190+
XLATensorPtr dev_input = GetValueOrThrow(XLATensor::Create(input, device));
191+
XLATensorPtr dev_weight =
192+
GetValueOrThrow(XLATensor::Create(weight, device));
193+
XLATensorPtr dev_bias = GetValueOrThrow(XLATensor::Create(bias, device));
191194
XLATensorPtr dev_output =
192195
tensor_methods::addmm(dev_input, dev_weight, dev_bias);
193196
AllClose(output, dev_output);
@@ -198,7 +201,7 @@ TEST_F(TensorTest, TestTranspose) {
198201
at::Tensor input = at::rand({2, 3}, at::TensorOptions(at::kFloat));
199202
at::Tensor output = at::transpose(input, 0, 1);
200203
ForEachDevice([&](const torch::lazy::BackendDevice& device) {
201-
XLATensorPtr dev_input = XLATensor::Create(input, device);
204+
XLATensorPtr dev_input = GetValueOrThrow(XLATensor::Create(input, device));
202205
XLATensorPtr dev_output = tensor_methods::transpose(dev_input, 0, 1);
203206
AllClose(output, dev_output);
204207
});
@@ -208,7 +211,7 @@ TEST_F(TensorTest, TestView) {
208211
at::Tensor input = at::rand({32, 20, 4, 4}, at::TensorOptions(at::kFloat));
209212
at::Tensor output = input.view({-1, 320});
210213
ForEachDevice([&](const torch::lazy::BackendDevice& device) {
211-
XLATensorPtr dev_input = XLATensor::Create(input, device);
214+
XLATensorPtr dev_input = GetValueOrThrow(XLATensor::Create(input, device));
212215
XLATensorPtr dev_output = tensor_methods::view(dev_input, {-1, 320});
213216
AllClose(output, dev_output);
214217
});
@@ -289,7 +292,8 @@ TEST_F(TensorTest, TestMaxPool2D) {
289292
/*padding=*/{padding, padding}, /*dilation=*/{1, 1},
290293
/*ceil_mode=*/false);
291294
ForEachDevice([&](const torch::lazy::BackendDevice& device) {
292-
XLATensorPtr dev_input = XLATensor::Create(input, device);
295+
XLATensorPtr dev_input =
296+
GetValueOrThrow(XLATensor::Create(input, device));
293297
auto dev_output = tensor_methods::max_pool_nd(
294298
dev_input,
295299
/*spatial_dim_count=*/2,
@@ -313,7 +317,8 @@ TEST_F(TensorTest, TestMaxPool2DNonSquare) {
313317
/*padding=*/{padding, padding + 1}, /*dilation=*/{1, 1},
314318
/*ceil_mode=*/false);
315319
ForEachDevice([&](const torch::lazy::BackendDevice& device) {
316-
XLATensorPtr dev_input = XLATensor::Create(input, device);
320+
XLATensorPtr dev_input =
321+
GetValueOrThrow(XLATensor::Create(input, device));
317322
auto dev_output = tensor_methods::max_pool_nd(
318323
dev_input,
319324
/*spatial_dim_count=*/2,
@@ -341,7 +346,8 @@ TEST_F(TensorTest, TestAvgPool2D) {
341346
/*ceil_mode=*/false, count_include_pad,
342347
/*divisor_override=*/std::nullopt);
343348
ForEachDevice([&](const torch::lazy::BackendDevice& device) {
344-
XLATensorPtr dev_input = XLATensor::Create(input, device);
349+
XLATensorPtr dev_input =
350+
GetValueOrThrow(XLATensor::Create(input, device));
345351
XLATensorPtr dev_output = tensor_methods::avg_pool_nd(
346352
dev_input,
347353
/*spatial_dim_count=*/2,
@@ -371,7 +377,8 @@ TEST_F(TensorTest, TestAvgPool2DNonSquare) {
371377
/*count_include_pad=*/count_include_pad,
372378
/*divisor_override=*/std::nullopt);
373379
ForEachDevice([&](const torch::lazy::BackendDevice& device) {
374-
XLATensorPtr dev_input = XLATensor::Create(input, device);
380+
XLATensorPtr dev_input =
381+
GetValueOrThrow(XLATensor::Create(input, device));
375382
XLATensorPtr dev_output = tensor_methods::avg_pool_nd(
376383
dev_input,
377384
/*spatial_dim_count=*/2,
@@ -409,15 +416,20 @@ TEST_F(TensorTest, TestBatchNorm1D) {
409416
/*running_mean=*/running_mean, /*running_var=*/running_var,
410417
/*training=*/training, /*momentum=*/momentum, /*eps=*/eps);
411418
ForEachDevice([&](const torch::lazy::BackendDevice& device) {
412-
XLATensorPtr xla_input = XLATensor::Create(input, device);
413-
XLATensorPtr xla_weight = undef_weight_bias
414-
? XLATensorPtr()
415-
: XLATensor::Create(weight, device);
416-
XLATensorPtr xla_bias = undef_weight_bias
417-
? XLATensorPtr()
418-
: XLATensor::Create(bias, device);
419-
XLATensorPtr xla_running_mean = XLATensor::Create(running_mean, device);
420-
XLATensorPtr xla_running_var = XLATensor::Create(running_var, device);
419+
XLATensorPtr xla_input =
420+
GetValueOrThrow(XLATensor::Create(input, device));
421+
XLATensorPtr xla_weight =
422+
undef_weight_bias
423+
? XLATensorPtr()
424+
: GetValueOrThrow(XLATensor::Create(weight, device));
425+
XLATensorPtr xla_bias =
426+
undef_weight_bias
427+
? XLATensorPtr()
428+
: GetValueOrThrow(XLATensor::Create(bias, device));
429+
XLATensorPtr xla_running_mean =
430+
GetValueOrThrow(XLATensor::Create(running_mean, device));
431+
XLATensorPtr xla_running_var =
432+
GetValueOrThrow(XLATensor::Create(running_var, device));
421433
auto xla_output = tensor_methods::native_batch_norm(
422434
/*input=*/xla_input, /*weight=*/xla_weight, /*bias=*/xla_bias,
423435
/*running_mean=*/xla_running_mean, /*running_var=*/xla_running_var,
@@ -474,11 +486,14 @@ TEST_F(TensorTest, TestConv2D) {
474486
/*output_padding=*/{output_padding, output_padding},
475487
/*groups=*/groups, false, false, false);
476488
ForEachDevice([&](const torch::lazy::BackendDevice& device) {
477-
XLATensorPtr dev_input = XLATensor::Create(input, device);
478-
XLATensorPtr dev_weight = XLATensor::Create(weight, device);
489+
XLATensorPtr dev_input =
490+
GetValueOrThrow(XLATensor::Create(input, device));
491+
XLATensorPtr dev_weight =
492+
GetValueOrThrow(XLATensor::Create(weight, device));
479493
XLATensorPtr dev_output;
480494
if (with_bias) {
481-
XLATensorPtr dev_bias = XLATensor::Create(bias, device);
495+
XLATensorPtr dev_bias =
496+
GetValueOrThrow(XLATensor::Create(bias, device));
482497
dev_output = tensor_methods::convolution_overrideable(
483498
dev_input, dev_weight, dev_bias,
484499
/*stride=*/{stride, stride},
@@ -543,11 +558,14 @@ TEST_F(TensorTest, TestConv2DNonSquare) {
543558
/*groups=*/groups, false, false, false);
544559

545560
ForEachDevice([&](const torch::lazy::BackendDevice& device) {
546-
XLATensorPtr dev_input = XLATensor::Create(input, device);
547-
XLATensorPtr dev_weight = XLATensor::Create(weight, device);
561+
XLATensorPtr dev_input =
562+
GetValueOrThrow(XLATensor::Create(input, device));
563+
XLATensorPtr dev_weight =
564+
GetValueOrThrow(XLATensor::Create(weight, device));
548565
XLATensorPtr dev_output;
549566
if (with_bias) {
550-
XLATensorPtr dev_bias = XLATensor::Create(bias, device);
567+
XLATensorPtr dev_bias =
568+
GetValueOrThrow(XLATensor::Create(bias, device));
551569
dev_output = tensor_methods::convolution_overrideable(
552570
dev_input, dev_weight, dev_bias,
553571
/*stride=*/{stride, stride + 1},
@@ -616,11 +634,14 @@ TEST_F(TensorTest, TestConv3D) {
616634
{output_padding, output_padding, output_padding},
617635
/*groups=*/groups, false, false, false);
618636
ForEachDevice([&](const torch::lazy::BackendDevice& device) {
619-
XLATensorPtr dev_input = XLATensor::Create(input, device);
620-
XLATensorPtr dev_weight = XLATensor::Create(weight, device);
637+
XLATensorPtr dev_input =
638+
GetValueOrThrow(XLATensor::Create(input, device));
639+
XLATensorPtr dev_weight =
640+
GetValueOrThrow(XLATensor::Create(weight, device));
621641
XLATensorPtr dev_output;
622642
if (with_bias) {
623-
XLATensorPtr dev_bias = XLATensor::Create(bias, device);
643+
XLATensorPtr dev_bias =
644+
GetValueOrThrow(XLATensor::Create(bias, device));
624645
dev_output = tensor_methods::convolution_overrideable(
625646
dev_input, dev_weight, dev_bias,
626647
/*stride=*/{stride, stride, stride},
@@ -688,10 +709,14 @@ TEST_F(TensorTest, TestConv3D) {
688709
// {output_padding, output_padding + 1, output_padding},
689710
// /*groups=*/groups, false, false, false);
690711
// ForEachDevice([&](const torch::lazy::BackendDevice& device) {
691-
// XLATensorPtr dev_input = XLATensor::Create(input, device);
692-
// XLATensorPtr dev_weight = XLATensor::Create(weight,
693-
// device); XLATensorPtr dev_output; if (with_bias) {
694-
// XLATensorPtr dev_bias = XLATensor::Create(bias, device);
712+
// XLATensorPtr dev_input =
713+
// GetValueOrThrow(XLATensor::Create(input, device));
714+
// XLATensorPtr dev_weight =
715+
// GetValueOrThrow(XLATensor::Create(weight, device));
716+
// XLATensorPtr dev_output;
717+
// if (with_bias) {
718+
// XLATensorPtr dev_bias =
719+
// GetValueOrThrow(XLATensor::Create(bias, device));
695720
// dev_output = tensor_methods::convolution_overrideable(
696721
// dev_input, dev_weight, dev_bias,
697722
// /*stride=*/{stride, stride + 1, stride + 1},

torch_xla/csrc/aten_xla_bridge.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -186,8 +186,9 @@ XLATensorPtr GetOrCreateXlaTensor(const at::Tensor& tensor,
186186
}
187187

188188
auto xtensor = GetXlaTensor(tensor);
189-
return xtensor.ok() ? xtensor.value()
190-
: XLATensor::Create(inner_tensor, device);
189+
return xtensor.ok()
190+
? xtensor.value()
191+
: GetValueOrThrow(XLATensor::Create(inner_tensor, device));
191192
}
192193

193194
XLATensorPtr GetOrCreateXlaTensor(const std::optional<at::Tensor>& tensor,
@@ -478,7 +479,8 @@ at::Tensor CreateXlaTensor(
478479
at::Tensor tensor,
479480
const std::optional<torch::lazy::BackendDevice>& device) {
480481
if (tensor.defined() && device) {
481-
XLATensorPtr xla_tensor = XLATensor::Create(std::move(tensor), *device);
482+
XLATensorPtr xla_tensor =
483+
GetValueOrThrow(XLATensor::Create(std::move(tensor), *device));
482484
tensor = AtenFromXlaTensor(xla_tensor);
483485
}
484486
return tensor;

torch_xla/csrc/tensor.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,15 @@ bool CanApplySharding(const XLATensor::ShardingSpecPtr sharding) {
6161

6262
XLATensor::Data::~Data() { XLAGraphExecutor::Get()->UnregisterTensor(this); }
6363

64-
XLATensorPtr XLATensor::Create(const at::Tensor& tensor,
65-
const torch::lazy::BackendDevice& device) {
66-
XLA_CHECK_EQ(tensor.device().type(), at::kCPU);
64+
absl::StatusOr<absl_nonnull XLATensorPtr> XLATensor::Create(
65+
const at::Tensor& tensor, const torch::lazy::BackendDevice& device) {
66+
if (!tensor.is_cpu()) {
67+
return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat(
68+
"Could not create an XLATensor out of the provided tensor. Expected "
69+
"tensor data to be on CPU. Got: ",
70+
c10::DeviceTypeName(tensor.device().type()),
71+
". Consider moving the tensor to CPU.")));
72+
}
6773
XLATensorPtr xtensor =
6874
c10::make_intrusive<XLATensor>(XLATensor(tensor, device));
6975
XLAGraphExecutor::Get()->RegisterTensor(xtensor->data());
@@ -621,7 +627,7 @@ std::vector<XLATensorPtr> XLATensor::MakeOutputTensors(
621627
XLATensorPtr XLATensor::CopyTensorToDevice(
622628
const torch::lazy::BackendDevice& device) {
623629
// TODO: This can be optimized via proper XRT/XLA computation.
624-
return Create(ToTensor(/*detached=*/true), device);
630+
return GetValueOrThrow(Create(ToTensor(/*detached=*/true), device));
625631
}
626632

627633
torch::lazy::Value XLATensor::MaybeCastIrValue(

torch_xla/csrc/tensor.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <memory>
1010
#include <string>
1111

12+
#include "absl/base/nullability.h"
1213
#include "torch_xla/csrc/runtime/util.h"
1314
#include "torch_xla/csrc/view.h"
1415

@@ -149,8 +150,8 @@ class XLATensor : public torch::lazy::LazyTensor {
149150
bool is_cloned = false;
150151
};
151152

152-
static XLATensorPtr Create(const at::Tensor& tensor,
153-
const torch::lazy::BackendDevice& device);
153+
static absl::StatusOr<absl_nonnull XLATensorPtr> Create(
154+
const at::Tensor& tensor, const torch::lazy::BackendDevice& device);
154155
static XLATensorPtr Create(
155156
torch::lazy::BackendDataPtr handle,
156157
std::optional<at::ScalarType> logical_element_type = std::nullopt);

torch_xla/csrc/tensor_methods.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1336,8 +1336,9 @@ std::tuple<XLATensorPtr, XLATensorPtr> cummax(const XLATensorPtr& input,
13361336
at::Tensor val =
13371337
at::empty(shape_, at::TensorOptions().dtype(input->dtype()));
13381338
at::Tensor idx = at::empty(shape_, at::TensorOptions().dtype(at::kLong));
1339-
return std::make_tuple(input->Create(val, input->GetDevice()),
1340-
input->Create(idx, input->GetDevice()));
1339+
return std::make_tuple(
1340+
GetValueOrThrow(XLATensor::Create(val, input->GetDevice())),
1341+
GetValueOrThrow(XLATensor::Create(idx, input->GetDevice())));
13411342
}
13421343
torch::lazy::NodePtr node =
13431344
torch_xla::MakeNode<CumMax>(input->GetIrValue(), canonical_dim);

0 commit comments

Comments
 (0)