Commit d214faf

test: Use new macros for throwing exceptions. (#9590)
Follow-up: #9588 and #9580
Target: `test` directory

In summary, this PR:

- Replaces all calls to `OkOrThrow()` and `GetValueOrThrow()` (which throw exceptions without source-location information of the *"throw site"*) with the macros `XLA_THROW_IF_ERROR()` and `XLA_ASSIGN_OR_THROW()`.
- Belongs to the fine-grained set of PRs that came from breaking down PR #9580.
- Focuses on the `test` directory, replacing every use of those now-deprecated functions with the newly introduced macros, except for the uses in _test_status_common.h_.
1 parent 4c586bd commit d214faf
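
For context, the pattern applied throughout the diff below looks roughly like this. It is a minimal sketch rather than code from the repository: `CheckDevice()` is a hypothetical status-returning helper, while `GetValueOrThrow()`, `OkOrThrow()`, `XLA_ASSIGN_OR_THROW()`, and `XLA_THROW_IF_ERROR()` are the functions and macros named in the commit message.

// Before: the helper functions throw without source-location information of
// the throw site.
XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(tensor));
OkOrThrow(CheckDevice());  // CheckDevice() is a hypothetical status-returning call.

// After: the macros expand at the call site, so the thrown exception can
// carry the location of the actual throw site.
XLA_ASSIGN_OR_THROW(XLATensorPtr xtensor, bridge::GetXlaTensor(tensor));
XLA_THROW_IF_ERROR(CheckDevice());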

6 files changed: +93 -82 lines changed

test/cpp/cpp_test_util.cpp

Lines changed: 11 additions & 8 deletions
@@ -246,17 +246,17 @@ void WithAllDevices(
 }
 
 std::string GetTensorTextGraph(at::Tensor tensor) {
-  XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(tensor));
+  XLA_ASSIGN_OR_THROW(XLATensorPtr xtensor, bridge::GetXlaTensor(tensor));
   return DumpUtil::ToText({xtensor->GetIrValue().node.get()});
 }
 
 std::string GetTensorDotGraph(at::Tensor tensor) {
-  XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(tensor));
+  XLA_ASSIGN_OR_THROW(XLATensorPtr xtensor, bridge::GetXlaTensor(tensor));
   return DumpUtil::ToDot({xtensor->GetIrValue().node.get()});
 }
 
 std::string GetTensorHloGraph(at::Tensor tensor) {
-  XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(tensor));
+  XLA_ASSIGN_OR_THROW(XLATensorPtr xtensor, bridge::GetXlaTensor(tensor));
   return DumpUtil::ToHlo({xtensor->GetIrValue()}, xtensor->GetDevice());
 }
 
@@ -276,9 +276,9 @@ std::vector<torch_xla::runtime::ComputationClient::DataPtr> Execute(
     lowering_ctx.AddResult(root);
   }
 
-  xla::XlaComputation computation = GetValueOrThrow(lowering_ctx.BuildXla());
-  xla::ProgramShape program_shape =
-      GetValueOrThrow(computation.GetProgramShape());
+  XLA_ASSIGN_OR_THROW(xla::XlaComputation computation, lowering_ctx.BuildXla());
+  XLA_ASSIGN_OR_THROW(xla::ProgramShape program_shape,
+                      computation.GetProgramShape());
   xla::Shape shape = MakeShapeWithDeviceLayout(
       program_shape.result(), static_cast<XlaDeviceType>(device.type()));
 
@@ -295,17 +295,20 @@ std::vector<torch_xla::runtime::ComputationClient::DataPtr> Execute(
       std::move(instances));
 
   torch_xla::runtime::ComputationClient::ExecuteComputationOptions options;
-  return GetValueOrThrow(
+  XLA_ASSIGN_OR_THROW(
+      std::vector<runtime::ComputationClient::DataPtr> outputs,
       torch_xla::runtime::GetComputationClientOrDie()->ExecuteComputation(
           *computations.front(),
          UnwrapXlaData(lowering_ctx.GetParametersData()), device.toString(),
          options));
+  return outputs;
 }
 
 std::vector<at::Tensor> Fetch(
     absl::Span<const torch_xla::runtime::ComputationClient::DataPtr>
         device_data) {
-  std::vector<xla::Literal> literals = GetValueOrThrow(
+  XLA_ASSIGN_OR_THROW(
+      std::vector<xla::Literal> literals,
       runtime::GetComputationClientOrDie()->TransferFromDevice(device_data));
   std::vector<at::Tensor> tensors;
   for (auto& literal : literals) {
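
One consequence visible in the `Execute` and `Fetch` hunks above: `GetValueOrThrow()` was an expression, so it could sit directly inside a `return`, whereas `XLA_ASSIGN_OR_THROW()` appears to be a declaration-style statement macro (an assumption based on how it is used in this diff), so the result is first bound to a named variable and then returned. In sketch form (`ComputeStatusOr()` is a hypothetical StatusOr-returning helper):

// Old expression form: the value is unwrapped inline.
//   return GetValueOrThrow(ComputeStatusOr());
// New statement form: the macro declares and assigns the variable, and the
// value is returned on a separate line.
XLA_ASSIGN_OR_THROW(std::vector<int> values, ComputeStatusOr());
return values;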

test/cpp/test_aten_xla_tensor_1.cpp

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ TEST_F(AtenXlaTensorTest, TestStorage) {
   torch::Tensor a = torch::tensor({0.0});
   ForEachDevice([&](const torch::Device& device) {
     torch::Tensor xla_a = CopyToDevice(a, device);
-    XLATensorPtr xla_tensor_a = GetValueOrThrow(bridge::GetXlaTensor(xla_a));
+    XLA_ASSIGN_OR_THROW(XLATensorPtr xla_tensor_a, bridge::GetXlaTensor(xla_a));
     EXPECT_EQ(xla_a.device(), xla_tensor_a->Storage().device());
     AllClose(a, xla_a);
   });

test/cpp/test_replication.cpp

Lines changed: 6 additions & 3 deletions
@@ -25,7 +25,8 @@ xla::XlaComputation CreateCrsComputation(const xla::Shape& shape) {
   xla::XlaBuilder builder("CrsComputation");
   xla::XlaOp x = xla::Parameter(&builder, 0, shape, "x");
   xla::CrossReplicaSum(x);
-  return GetValueOrThrow(builder.Build());
+  XLA_ASSIGN_OR_THROW(xla::XlaComputation crs_computation, builder.Build());
+  return crs_computation;
 }
 
 void TestSingleReplication(
@@ -65,7 +66,8 @@ void TestSingleReplication(
   torch_xla::runtime::ComputationClient::ExecuteComputationOptions exec_options;
   for (size_t i = 0; i < device_strings.size(); ++i) {
     auto executor = [&, i]() {
-      results[i] = GetValueOrThrow(
+      XLA_ASSIGN_OR_THROW(
+          results[i],
           torch_xla::runtime::GetComputationClientOrDie()->ExecuteComputation(
               *compiled_computations[i],
               {std::dynamic_pointer_cast<
@@ -79,7 +81,8 @@ void TestSingleReplication(
   counter.Wait();
 
   for (size_t i = 0; i < results.size(); ++i) {
-    std::vector<xla::Literal> literals = GetValueOrThrow(
+    XLA_ASSIGN_OR_THROW(
+        std::vector<xla::Literal> literals,
         runtime::GetComputationClientOrDie()->TransferFromDevice(results[i]));
     ASSERT_EQ(literals.size(), 1);

test/cpp/test_tensor.cpp

Lines changed: 69 additions & 65 deletions
@@ -101,8 +101,8 @@ TEST_F(TensorTest, TestAdd) {
   at::Tensor c = a.add(b, 1.0);
 
   ForEachDevice([&](const torch::lazy::BackendDevice& device) {
-    XLATensorPtr dev_a = GetValueOrThrow(XLATensor::Create(a, device));
-    XLATensorPtr dev_b = GetValueOrThrow(XLATensor::Create(b, device));
+    XLA_ASSIGN_OR_THROW(XLATensorPtr dev_a, XLATensor::Create(a, device));
+    XLA_ASSIGN_OR_THROW(XLATensorPtr dev_b, XLATensor::Create(b, device));
     XLATensorPtr dev_c = tensor_methods::add(dev_a, dev_b, 1.0);
 
     AllClose(c, dev_c);
@@ -121,8 +121,8 @@ TEST_F(TensorTest, TestIntegerAdd) {
         at::isIntegralType(type) ? at::Scalar(int64_t(1)) : at::Scalar(1.0);
     at::Tensor c = a.add(b, one);
 
-    XLATensorPtr dev_a = GetValueOrThrow(XLATensor::Create(a, device));
-    XLATensorPtr dev_b = GetValueOrThrow(XLATensor::Create(b, device));
+    XLA_ASSIGN_OR_THROW(XLATensorPtr dev_a, XLATensor::Create(a, device));
+    XLA_ASSIGN_OR_THROW(XLATensorPtr dev_b, XLATensor::Create(b, device));
     XLATensorPtr dev_c = tensor_methods::add(dev_a, dev_b, one);
 
     EXPECT_TRUE(EqualValuesNoElementTypeCheck(
@@ -135,7 +135,8 @@ TEST_F(TensorTest, TestSize) {
   at::Tensor input = at::rand({2, 1, 4, 6}, at::TensorOptions(at::kFloat));
   int rank = input.dim();
   ForEachDevice([&](const torch::lazy::BackendDevice& device) {
-    XLATensorPtr dev_input = GetValueOrThrow(XLATensor::Create(input, device));
+    XLA_ASSIGN_OR_THROW(XLATensorPtr dev_input,
+                        XLATensor::Create(input, device));
     for (int dim = -rank; dim < rank; ++dim) {
       EXPECT_EQ(input.size(dim), dev_input->size(dim));
     }
@@ -151,10 +152,10 @@ TEST_F(TensorTest, TestRrelu) {
     at::Tensor noise = at::zeros_like(input);
     at::Tensor output =
         at::rrelu_with_noise(input, noise, lower, upper, training);
-    XLATensorPtr dev_input =
-        GetValueOrThrow(XLATensor::Create(input, device));
-    XLATensorPtr dev_noise =
-        GetValueOrThrow(XLATensor::Create(noise, device));
+    XLA_ASSIGN_OR_THROW(XLATensorPtr dev_input,
+                        XLATensor::Create(input, device));
+    XLA_ASSIGN_OR_THROW(XLATensorPtr dev_noise,
+                        XLATensor::Create(noise, device));
     XLATensorPtr dev_outputs = tensor_methods::rrelu_with_noise(
         dev_input, dev_noise, lower, upper, training);
     AllClose(output, dev_outputs);
@@ -169,7 +170,8 @@ TEST_F(TensorTest, TestThreshold) {
   float value = 20;
   at::Tensor output = at::threshold(input, threshold, value);
   ForEachDevice([&](const torch::lazy::BackendDevice& device) {
-    XLATensorPtr dev_input = GetValueOrThrow(XLATensor::Create(input, device));
+    XLA_ASSIGN_OR_THROW(XLATensorPtr dev_input,
+                        XLATensor::Create(input, device));
     XLATensorPtr dev_output =
         tensor_methods::threshold(dev_input, threshold, value);
     AllClose(output, dev_output);
@@ -187,10 +189,11 @@ TEST_F(TensorTest, TestAddMatMul) {
   at::Tensor bias = at::rand({labels}, at::TensorOptions(at::kFloat));
   at::Tensor output = at::addmm(bias, input, weight);
   ForEachDevice([&](const torch::lazy::BackendDevice& device) {
-    XLATensorPtr dev_input = GetValueOrThrow(XLATensor::Create(input, device));
-    XLATensorPtr dev_weight =
-        GetValueOrThrow(XLATensor::Create(weight, device));
-    XLATensorPtr dev_bias = GetValueOrThrow(XLATensor::Create(bias, device));
+    XLA_ASSIGN_OR_THROW(XLATensorPtr dev_input,
+                        XLATensor::Create(input, device));
+    XLA_ASSIGN_OR_THROW(XLATensorPtr dev_weight,
+                        XLATensor::Create(weight, device));
+    XLA_ASSIGN_OR_THROW(XLATensorPtr dev_bias, XLATensor::Create(bias, device));
     XLATensorPtr dev_output =
         tensor_methods::addmm(dev_input, dev_weight, dev_bias);
     AllClose(output, dev_output);
@@ -201,7 +204,8 @@ TEST_F(TensorTest, TestTranspose) {
   at::Tensor input = at::rand({2, 3}, at::TensorOptions(at::kFloat));
   at::Tensor output = at::transpose(input, 0, 1);
   ForEachDevice([&](const torch::lazy::BackendDevice& device) {
-    XLATensorPtr dev_input = GetValueOrThrow(XLATensor::Create(input, device));
+    XLA_ASSIGN_OR_THROW(XLATensorPtr dev_input,
+                        XLATensor::Create(input, device));
     XLATensorPtr dev_output = tensor_methods::transpose(dev_input, 0, 1);
     AllClose(output, dev_output);
   });
@@ -211,7 +215,8 @@ TEST_F(TensorTest, TestView) {
   at::Tensor input = at::rand({32, 20, 4, 4}, at::TensorOptions(at::kFloat));
   at::Tensor output = input.view({-1, 320});
   ForEachDevice([&](const torch::lazy::BackendDevice& device) {
-    XLATensorPtr dev_input = GetValueOrThrow(XLATensor::Create(input, device));
+    XLA_ASSIGN_OR_THROW(XLATensorPtr dev_input,
+                        XLATensor::Create(input, device));
     XLATensorPtr dev_output = tensor_methods::view(dev_input, {-1, 320});
     AllClose(output, dev_output);
   });
@@ -292,8 +297,8 @@ TEST_F(TensorTest, TestMaxPool2D) {
           /*padding=*/{padding, padding}, /*dilation=*/{1, 1},
           /*ceil_mode=*/false);
       ForEachDevice([&](const torch::lazy::BackendDevice& device) {
-        XLATensorPtr dev_input =
-            GetValueOrThrow(XLATensor::Create(input, device));
+        XLA_ASSIGN_OR_THROW(XLATensorPtr dev_input,
+                            XLATensor::Create(input, device));
         auto dev_output = tensor_methods::max_pool_nd(
             dev_input,
             /*spatial_dim_count=*/2,
@@ -317,8 +322,8 @@ TEST_F(TensorTest, TestMaxPool2DNonSquare) {
          /*padding=*/{padding, padding + 1}, /*dilation=*/{1, 1},
          /*ceil_mode=*/false);
      ForEachDevice([&](const torch::lazy::BackendDevice& device) {
-        XLATensorPtr dev_input =
-            GetValueOrThrow(XLATensor::Create(input, device));
+        XLA_ASSIGN_OR_THROW(XLATensorPtr dev_input,
+                            XLATensor::Create(input, device));
        auto dev_output = tensor_methods::max_pool_nd(
            dev_input,
            /*spatial_dim_count=*/2,
@@ -346,8 +351,8 @@ TEST_F(TensorTest, TestAvgPool2D) {
          /*ceil_mode=*/false, count_include_pad,
          /*divisor_override=*/std::nullopt);
      ForEachDevice([&](const torch::lazy::BackendDevice& device) {
-        XLATensorPtr dev_input =
-            GetValueOrThrow(XLATensor::Create(input, device));
+        XLA_ASSIGN_OR_THROW(XLATensorPtr dev_input,
+                            XLATensor::Create(input, device));
        XLATensorPtr dev_output = tensor_methods::avg_pool_nd(
            dev_input,
            /*spatial_dim_count=*/2,
@@ -377,8 +382,8 @@ TEST_F(TensorTest, TestAvgPool2DNonSquare) {
          /*count_include_pad=*/count_include_pad,
          /*divisor_override=*/std::nullopt);
      ForEachDevice([&](const torch::lazy::BackendDevice& device) {
-        XLATensorPtr dev_input =
-            GetValueOrThrow(XLATensor::Create(input, device));
+        XLA_ASSIGN_OR_THROW(XLATensorPtr dev_input,
+                            XLATensor::Create(input, device));
        XLATensorPtr dev_output = tensor_methods::avg_pool_nd(
            dev_input,
            /*spatial_dim_count=*/2,
@@ -416,20 +421,20 @@ TEST_F(TensorTest, TestBatchNorm1D) {
       /*running_mean=*/running_mean, /*running_var=*/running_var,
       /*training=*/training, /*momentum=*/momentum, /*eps=*/eps);
   ForEachDevice([&](const torch::lazy::BackendDevice& device) {
-    XLATensorPtr xla_input =
-        GetValueOrThrow(XLATensor::Create(input, device));
-    XLATensorPtr xla_weight =
-        undef_weight_bias
-            ? XLATensorPtr()
-            : GetValueOrThrow(XLATensor::Create(weight, device));
-    XLATensorPtr xla_bias =
-        undef_weight_bias
-            ? XLATensorPtr()
-            : GetValueOrThrow(XLATensor::Create(bias, device));
-    XLATensorPtr xla_running_mean =
-        GetValueOrThrow(XLATensor::Create(running_mean, device));
-    XLATensorPtr xla_running_var =
-        GetValueOrThrow(XLATensor::Create(running_var, device));
+    XLA_ASSIGN_OR_THROW(XLATensorPtr xla_input,
+                        XLATensor::Create(input, device));
+    XLATensorPtr xla_weight;
+    if (!undef_weight_bias) {
+      XLA_ASSIGN_OR_THROW(xla_weight, XLATensor::Create(weight, device));
+    }
+    XLATensorPtr xla_bias;
+    if (!undef_weight_bias) {
+      XLA_ASSIGN_OR_THROW(xla_bias, XLATensor::Create(bias, device));
+    }
+    XLA_ASSIGN_OR_THROW(XLATensorPtr xla_running_mean,
+                        XLATensor::Create(running_mean, device));
+    XLA_ASSIGN_OR_THROW(XLATensorPtr xla_running_var,
+                        XLATensor::Create(running_var, device));
     auto xla_output = tensor_methods::native_batch_norm(
         /*input=*/xla_input, /*weight=*/xla_weight, /*bias=*/xla_bias,
         /*running_mean=*/xla_running_mean, /*running_var=*/xla_running_var,
@@ -486,14 +491,14 @@ TEST_F(TensorTest, TestConv2D) {
            /*output_padding=*/{output_padding, output_padding},
            /*groups=*/groups, false, false, false);
        ForEachDevice([&](const torch::lazy::BackendDevice& device) {
-          XLATensorPtr dev_input =
-              GetValueOrThrow(XLATensor::Create(input, device));
-          XLATensorPtr dev_weight =
-              GetValueOrThrow(XLATensor::Create(weight, device));
+          XLA_ASSIGN_OR_THROW(XLATensorPtr dev_input,
+                              XLATensor::Create(input, device));
+          XLA_ASSIGN_OR_THROW(XLATensorPtr dev_weight,
+                              XLATensor::Create(weight, device));
          XLATensorPtr dev_output;
          if (with_bias) {
-            XLATensorPtr dev_bias =
-                GetValueOrThrow(XLATensor::Create(bias, device));
+            XLA_ASSIGN_OR_THROW(XLATensorPtr dev_bias,
+                                XLATensor::Create(bias, device));
            dev_output = tensor_methods::convolution_overrideable(
                dev_input, dev_weight, dev_bias,
                /*stride=*/{stride, stride},
@@ -558,14 +563,14 @@ TEST_F(TensorTest, TestConv2DNonSquare) {
            /*groups=*/groups, false, false, false);
 
        ForEachDevice([&](const torch::lazy::BackendDevice& device) {
-          XLATensorPtr dev_input =
-              GetValueOrThrow(XLATensor::Create(input, device));
-          XLATensorPtr dev_weight =
-              GetValueOrThrow(XLATensor::Create(weight, device));
+          XLA_ASSIGN_OR_THROW(XLATensorPtr dev_input,
+                              XLATensor::Create(input, device));
+          XLA_ASSIGN_OR_THROW(XLATensorPtr dev_weight,
+                              XLATensor::Create(weight, device));
          XLATensorPtr dev_output;
          if (with_bias) {
-            XLATensorPtr dev_bias =
-                GetValueOrThrow(XLATensor::Create(bias, device));
+            XLA_ASSIGN_OR_THROW(XLATensorPtr dev_bias,
+                                XLATensor::Create(bias, device));
            dev_output = tensor_methods::convolution_overrideable(
                dev_input, dev_weight, dev_bias,
                /*stride=*/{stride, stride + 1},
@@ -634,14 +639,14 @@ TEST_F(TensorTest, TestConv3D) {
            {output_padding, output_padding, output_padding},
            /*groups=*/groups, false, false, false);
        ForEachDevice([&](const torch::lazy::BackendDevice& device) {
-          XLATensorPtr dev_input =
-              GetValueOrThrow(XLATensor::Create(input, device));
-          XLATensorPtr dev_weight =
-              GetValueOrThrow(XLATensor::Create(weight, device));
+          XLA_ASSIGN_OR_THROW(XLATensorPtr dev_input,
+                              XLATensor::Create(input, device));
+          XLA_ASSIGN_OR_THROW(XLATensorPtr dev_weight,
+                              XLATensor::Create(weight, device));
          XLATensorPtr dev_output;
          if (with_bias) {
-            XLATensorPtr dev_bias =
-                GetValueOrThrow(XLATensor::Create(bias, device));
+            XLA_ASSIGN_OR_THROW(XLATensorPtr dev_bias,
+                                XLATensor::Create(bias, device));
            dev_output = tensor_methods::convolution_overrideable(
                dev_input, dev_weight, dev_bias,
                /*stride=*/{stride, stride, stride},
@@ -709,15 +714,14 @@ TEST_F(TensorTest, TestConv3D) {
 //         {output_padding, output_padding + 1, output_padding},
 //         /*groups=*/groups, false, false, false);
 //     ForEachDevice([&](const torch::lazy::BackendDevice& device) {
-//       XLATensorPtr dev_input =
-//           GetValueOrThrow(XLATensor::Create(input, device));
-//       XLATensorPtr dev_weight =
-//           GetValueOrThrow(XLATensor::Create(weight, device);
-//       XLATensorPtr dev_output;
-//       if (with_bias) {
-//         XLATensorPtr dev_bias =
-//             GetValueOrThrow(XLATensor::Create(bias, device));
-//         dev_output = tensor_methods::convolution_overrideable(
+//       XLA_ASSIGN_OR_THROW(XLATensorPtr dev_input,
+//                           XLATensor::Create(input, device));
+//       XLA_ASSIGN_OR_THROW(XLATensorPtr dev_weight,
+//                           XLATensor::Create(weight, device)); XLATensorPtr
+//       dev_output; if (with_bias) {
+//         XLA_ASSIGN_OR_THROW(XLATensorPtr dev_bias,
+//                             XLATensor::Create(bias, device)); dev_output =
+//         tensor_methods::convolution_overrideable(
 //             dev_input, dev_weight, dev_bias,
 //             /*stride=*/{stride, stride + 1, stride + 1},
 //             /*padding=*/{padding, padding + 1, padding + 1},
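
The `TestBatchNorm1D` hunk shows the same constraint from another angle: because the macro is a statement rather than an expression (again an assumption drawn from this diff), it cannot live inside the old `undef_weight_bias ? XLATensorPtr() : GetValueOrThrow(...)` conditional expressions, so the optional tensors are now default-constructed and assigned inside an `if`, roughly:

// The expression form is no longer available:
//   XLATensorPtr xla_weight = undef_weight_bias
//       ? XLATensorPtr()
//       : GetValueOrThrow(XLATensor::Create(weight, device));
// so the diff unfolds it into a declaration plus a conditional assignment:
XLATensorPtr xla_weight;
if (!undef_weight_bias) {
  XLA_ASSIGN_OR_THROW(xla_weight, XLATensor::Create(weight, device));
}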

test/cpp/test_xla_backend_intf.cpp

Lines changed: 2 additions & 1 deletion
@@ -53,7 +53,8 @@ xla::XlaComputation CreateAddComputation(const xla::Shape& shape) {
   xla::XlaOp x = xla::Parameter(&builder, 0, shape, "x");
   xla::XlaOp y = xla::Parameter(&builder, 1, shape, "y");
   xla::XlaOp sum = xla::Add(x, y);
-  return GetValueOrThrow(builder.Build());
+  XLA_ASSIGN_OR_THROW(xla::XlaComputation add_computation, builder.Build());
+  return add_computation;
 }
 
 TEST(XLABackendTest, TestE2E) {

test/cpp/test_xla_sharding.cpp

Lines changed: 4 additions & 4 deletions
@@ -28,8 +28,8 @@ namespace {
 bool XlaDataValuesEqual(torch::lazy::BackendDataPtr a,
                         torch::lazy::BackendDataPtr b,
                         at::ScalarType element_type) {
-  std::vector<at::Tensor> tensors =
-      GetValueOrThrow(XlaDataToTensors({a, b}, {element_type, element_type}));
+  XLA_ASSIGN_OR_THROW(std::vector<at::Tensor> tensors,
+                      XlaDataToTensors({a, b}, {element_type, element_type}));
   return TensorCompare(tensors[0], tensors[1]);
 }
 }  // namespace
@@ -385,8 +385,8 @@ TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) {
   auto x = xla::Parameter(&b, 0, shape, "p0");
   b.ClearSharding();
   auto y = xla::Add(x, xla::ConstantR0<float>(&b, 3));
-  xla::XlaComputation xla_computation =
-      GetValueOrThrow(b.Build(/*remove_dynamic_dimensions=*/false));
+  XLA_ASSIGN_OR_THROW(xla::XlaComputation xla_computation,
+                      b.Build(/*remove_dynamic_dimensions=*/false));
   std::vector<torch_xla::runtime::ComputationClient::CompileInstance> instances;
   instances.push_back({std::move(xla_computation),
                        bridge::GetDefaultDevice()->toString(),
