Commit 51518f9

Error Handling: replace ConsumeValue with GetValueOrThrow. (#9464)
1 parent 992e87a commit 51518f9

23 files changed, +88 -60 lines
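
The change is mechanical: every call site that unwrapped an `absl::StatusOr<T>` via `ConsumeValue` now calls `GetValueOrThrow` from the new `//torch_xla/csrc:status` target instead, and each affected file picks up the `torch_xla/csrc/status.h` include plus the matching build dependency. The helper itself is not part of this diff; below is a minimal sketch of the contract the call sites rely on, assuming a `std::runtime_error` failure path (not the actual implementation in status.h):

#include <stdexcept>
#include <string>
#include <utility>

#include "absl/status/statusor.h"

// Sketch only: unwraps a StatusOr on success, and converts a non-OK
// status into a C++ exception instead of crashing the process the way
// a CHECK-style unwrap would.
template <typename T>
T GetValueOrThrow(absl::StatusOr<T>&& status_or) {
  if (!status_or.ok()) {
    throw std::runtime_error(std::string(status_or.status().message()));
  }
  return std::move(status_or).value();
}

Call sites keep their shape, e.g. `return GetValueOrThrow(builder.Build());`, so most hunks below are a one-name substitution plus re-wrapped lines where the longer name pushed code past the column limit.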

test/cpp/BUILD

Lines changed: 4 additions & 0 deletions

@@ -40,6 +40,7 @@ ptxla_cc_library(
         "//torch_xla/csrc/runtime:runtime",
         "//torch_xla/csrc/runtime:debug_macros",
         "//torch_xla/csrc/runtime:sys_util",
+        "//torch_xla/csrc:status",
         "//torch_xla/csrc:tensor",
         "@com_google_absl//absl/types:span",
         "@com_google_googletest//:gtest",
@@ -78,6 +79,7 @@ ptxla_cc_test(
         ":torch_xla_test",
         "//torch_xla/csrc/runtime:runtime",
         "//torch_xla/csrc/runtime:debug_macros",
+        "//torch_xla/csrc:status",
         "//torch_xla/csrc:tensor",
         "//torch_xla/csrc:aten_cuda_functions",
         "//torch_xla/csrc:thread_pool",
@@ -109,6 +111,7 @@ ptxla_cc_test(
 #     srcs = ["test_xla_backend_intf.cpp"],
 #     deps = [
 #         ":cpp_test_util",
+#         "//torch_xla/csrc:status",
 #         "//torch_xla/csrc:tensor",
 #         "@com_google_googletest//:gtest_main",
 #     ],
@@ -122,6 +125,7 @@ ptxla_cc_test(
         ":torch_xla_test",
         "//torch_xla/csrc/runtime:env_vars",
         "//torch_xla/csrc/runtime:sys_util",
+        "//torch_xla/csrc:status",
         "//torch_xla/csrc:tensor",
         "//torch_xla/csrc:aten_cuda_functions",
         "@com_google_googletest//:gtest_main",

test/cpp/cpp_test_util.cpp

Lines changed: 4 additions & 2 deletions

@@ -14,6 +14,7 @@
 #include "torch_xla/csrc/runtime/debug_macros.h"
 #include "torch_xla/csrc/runtime/runtime.h"
 #include "torch_xla/csrc/runtime/sys_util.h"
+#include "torch_xla/csrc/status.h"
 #include "torch_xla/csrc/tensor_impl.h"
 #include "torch_xla/csrc/tensor_util.h"
 #include "torch_xla/csrc/torch_util.h"
@@ -275,8 +276,9 @@ std::vector<torch_xla::runtime::ComputationClient::DataPtr> Execute(
     lowering_ctx.AddResult(root);
   }
 
-  xla::XlaComputation computation = ConsumeValue(lowering_ctx.BuildXla());
-  xla::ProgramShape program_shape = ConsumeValue(computation.GetProgramShape());
+  xla::XlaComputation computation = GetValueOrThrow(lowering_ctx.BuildXla());
+  xla::ProgramShape program_shape =
+      GetValueOrThrow(computation.GetProgramShape());
   xla::Shape shape = MakeShapeWithDeviceLayout(
       program_shape.result(), static_cast<XlaDeviceType>(device.type()));
 

test/cpp/test_replication.cpp

Lines changed: 2 additions & 1 deletion

@@ -10,6 +10,7 @@
 #include "torch_xla/csrc/helpers.h"
 #include "torch_xla/csrc/runtime/debug_macros.h"
 #include "torch_xla/csrc/runtime/runtime.h"
+#include "torch_xla/csrc/status.h"
 #include "torch_xla/csrc/tensor_util.h"
 #include "torch_xla/csrc/thread_pool.h"
 #include "torch_xla/csrc/torch_util.h"
@@ -24,7 +25,7 @@ xla::XlaComputation CreateCrsComputation(const xla::Shape& shape) {
   xla::XlaBuilder builder("CrsComputation");
   xla::XlaOp x = xla::Parameter(&builder, 0, shape, "x");
   xla::CrossReplicaSum(x);
-  return ConsumeValue(builder.Build());
+  return GetValueOrThrow(builder.Build());
 }
 
 void TestSingleReplication(

test/cpp/test_xla_backend_intf.cpp

Lines changed: 2 additions & 1 deletion

@@ -2,6 +2,7 @@
 
 #include "test/cpp/cpp_test_util.h"
 #include "torch_xla/csrc/runtime/computation_client.h"
+#include "torch_xla/csrc/status.h"
 #include "torch_xla/csrc/tensor_util.h"
 #include "torch_xla/csrc/xla_backend_impl.h"
 
@@ -52,7 +53,7 @@ xla::XlaComputation CreateAddComputation(const xla::Shape& shape) {
   xla::XlaOp x = xla::Parameter(&builder, 0, shape, "x");
   xla::XlaOp y = xla::Parameter(&builder, 1, shape, "y");
   xla::XlaOp sum = xla::Add(x, y);
-  return ConsumeValue(builder.Build());
+  return GetValueOrThrow(builder.Build());
 }
 
 TEST(XLABackendTest, TestE2E) {

test/cpp/test_xla_sharding.cpp

Lines changed: 2 additions & 1 deletion

@@ -14,6 +14,7 @@
 #include "torch_xla/csrc/runtime/env_vars.h"
 #include "torch_xla/csrc/runtime/runtime.h"
 #include "torch_xla/csrc/runtime/sys_util.h"
+#include "torch_xla/csrc/status.h"
 #include "torch_xla/csrc/tensor.h"
 #include "torch_xla/csrc/tensor_methods.h"
 #include "torch_xla/csrc/tensor_util.h"
@@ -385,7 +386,7 @@ TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) {
   b.ClearSharding();
   auto y = xla::Add(x, xla::ConstantR0<float>(&b, 3));
   xla::XlaComputation xla_computation =
-      ConsumeValue(b.Build(/*remove_dynamic_dimensions=*/false));
+      GetValueOrThrow(b.Build(/*remove_dynamic_dimensions=*/false));
   std::vector<torch_xla::runtime::ComputationClient::CompileInstance> instances;
   instances.push_back({std::move(xla_computation),
                        bridge::GetDefaultDevice()->toString(),

torch_xla/csrc/BUILD

Lines changed: 3 additions & 0 deletions

@@ -125,6 +125,7 @@ ptxla_cc_library(
         ":layout_manager",
         ":shape_builder",
         ":shape_helper",
+        ":status",
         ":version",
         "//torch_xla/csrc:hash_util",
         "//torch_xla/csrc:thread_pool",
@@ -310,6 +311,7 @@ ptxla_cc_library(
     deps = [
         ":device",
         ":shape_helper",
+        ":status",
         ":unwrap_data",
         "//torch_xla/csrc/runtime:cache",
         "//torch_xla/csrc/runtime:computation_client",
@@ -324,6 +326,7 @@ cc_library(
     srcs = ["shape_helper.cpp"],
     hdrs = ["shape_helper.h"],
     deps = [
+        ":status",
         "//torch_xla/csrc/runtime:debug_macros",
         "@xla//xla/hlo/builder:xla_builder",
     ],

torch_xla/csrc/convolution.cpp

Lines changed: 5 additions & 4 deletions

@@ -3,6 +3,7 @@
 #include "torch_xla/csrc/helpers.h"
 #include "torch_xla/csrc/runtime/debug_macros.h"
 #include "torch_xla/csrc/shape_helper.h"
+#include "torch_xla/csrc/status.h"
 #include "torch_xla/csrc/xla_lower_util.h"
 #include "xla/hlo/builder/lib/constants.h"
 
@@ -217,9 +218,9 @@ xla::XlaOp BuildConvBackwardInput(xla::XlaOp grad_output, xla::XlaOp kernel,
       MakeConvOpAttrs(spatial_stride, spatial_padding, spatial_dilation, false);
   xla::XlaOp kernel_transposed = xla::Transpose(
       kernel, FilterTransposePermutation(input_shape.dimensions_size()));
-  return ConsumeValue(MakeXlaBackpropInputConvOp("conv_backward_input",
-                                                 input_shape, kernel_transposed,
-                                                 grad_output, conv_op_attrs));
+  return GetValueOrThrow(MakeXlaBackpropInputConvOp(
+      "conv_backward_input", input_shape, kernel_transposed, grad_output,
+      conv_op_attrs));
 }
 
 // Computes the kernel gradient for a convolution.
@@ -237,7 +238,7 @@ xla::XlaOp BuildConvBackwardWeight(xla::XlaOp grad_output, xla::XlaOp input,
       xla::InversePermutation(transpose_permutation);
   xla::Shape transposed_weight_shape =
       xla::ShapeUtil::PermuteDimensions(transpose_permutation, kernel_shape);
-  xla::XlaOp conv = ConsumeValue(MakeXlaBackpropFilterConvOp(
+  xla::XlaOp conv = GetValueOrThrow(MakeXlaBackpropFilterConvOp(
       "conv_backward_weight", input, transposed_weight_shape, grad_output,
       conv_op_attrs));
 

torch_xla/csrc/debug_util.cpp

Lines changed: 2 additions & 1 deletion

@@ -20,6 +20,7 @@
 #include "torch_xla/csrc/runtime/debug_macros.h"
 #include "torch_xla/csrc/runtime/sys_util.h"
 #include "torch_xla/csrc/runtime/xla_util.h"
+#include "torch_xla/csrc/status.h"
 #include "torch_xla/csrc/xla_graph_executor.h"
 
 namespace torch_xla {
@@ -450,7 +451,7 @@ void DebugUtil::post_compilation_analysis(
   // Note that for UserComputation computations, the protobuf is factored in
   // the graph hash.
   std::string serialized_computation =
-      ConsumeValue(runtime::util::GetDeterministicSerializedModuleProto(
+      GetValueOrThrow(runtime::util::GetDeterministicSerializedModuleProto(
          computation->computation().proto()));
   ss << "\n"
      << "Computation hash: "

torch_xla/csrc/helpers.cpp

Lines changed: 3 additions & 2 deletions

@@ -13,6 +13,7 @@
 #include "torch_xla/csrc/runtime/tf_logging.h"
 #include "torch_xla/csrc/runtime/util.h"
 #include "torch_xla/csrc/shape_helper.h"
+#include "torch_xla/csrc/status.h"
 #include "torch_xla/csrc/tensor_util.h"
 #include "xla/hlo/builder/lib/constants.h"
 #include "xla/primitive_util.h"
@@ -40,7 +41,7 @@ xla::XlaComputation CreateComputation(
       xla::Parameter(&builder, 0, xla::ShapeUtil::MakeShape(type, {}), "x");
   xla::XlaOp y =
       xla::Parameter(&builder, 1, xla::ShapeUtil::MakeShape(type, {}), "y");
-  return ConsumeValue(builder.Build(op(x, y)));
+  return GetValueOrThrow(builder.Build(op(x, y)));
 }
 
 xla::XlaComputation CreateMinMaxComputation(const std::string& name,
@@ -65,7 +66,7 @@ xla::XlaComputation CreateMinMaxComputation(const std::string& name,
   xla::XlaOp tie_id = xla::Min(lhs_index, rhs_index);
   arg_max = xla::Select(eq, tie_id, arg_max);
   xla::Tuple(&builder, {max, arg_max});
-  return ConsumeValue(builder.Build());
+  return GetValueOrThrow(builder.Build());
 }
 
 }  // namespace

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 9 additions & 9 deletions

@@ -36,7 +36,6 @@
 #include "pybind11/pytypes.h"
 #include "pybind11/stl.h"
 #include "pybind11/stl_bind.h"
-#include "status.h"
 #include "torch_xla/csrc/XLANativeFunctions.h"
 #include "torch_xla/csrc/aten_autograd_ops.h"
 #include "torch_xla/csrc/aten_fallback.h"
@@ -892,7 +891,8 @@ std::vector<at::Tensor> XlaUserComputation(
 
 runtime::ComputationClient::ComputationPtr CreateComputation(
     const std::string& name, xla::XlaOp root) {
-  xla::XlaComputation computation = ConsumeValue(root.builder()->Build(root));
+  xla::XlaComputation computation =
+      GetValueOrThrow(root.builder()->Build(root));
   return std::make_shared<runtime::ComputationClient::Computation>(
       name, std::move(computation));
 }
@@ -1152,7 +1152,7 @@ class PyLoweringContext {
 
     ShardingUtil::SetHloSharding(&lowering_ctx);
 
-    computation = ConsumeValue(lowering_ctx.BuildXla());
+    computation = GetValueOrThrow(lowering_ctx.BuildXla());
   }
 
   // Builds a HLO graph given a set of output tensors, and add unused parameters
@@ -1196,7 +1196,7 @@ class PyLoweringContext {
 
     ShardingUtil::SetHloSharding(&lowering_ctx);
 
-    computation = ConsumeValue(lowering_ctx.BuildXla());
+    computation = GetValueOrThrow(lowering_ctx.BuildXla());
 
     // wrap inputs of cond/body_computation
     if ((GetNameString() == "condctx") || (GetNameString() == "bodyctx")) {
@@ -1207,12 +1207,12 @@ class PyLoweringContext {
       param_shardings = XlaHelpers::ExtractInputShardings(computation);
     }
     xla::ProgramShape program_shape =
-        ConsumeValue(computation.GetProgramShape());
+        GetValueOrThrow(computation.GetProgramShape());
     // TODO(@manfei): please confirm whether we check for more than two or use
     // default value true
     bool should_wrap_parameter = (program_shape.parameters_size() >= 2);
     if (should_wrap_parameter) {
-      computation = ConsumeValue(XlaHelpers::WrapXlaComputation(
+      computation = GetValueOrThrow(XlaHelpers::WrapXlaComputation(
          computation, program_shape.parameters(), param_shardings,
          /* buffer_donor_indices */ {}));
     }
@@ -1309,7 +1309,7 @@ class PyLoweringContext {
   // Create a serialized HloModule protobuf from a lowered graph
   py::bytes GetHlo() {
     const xla::HloModuleProto& proto = computation.proto();
-    return ConsumeValue(
+    return GetValueOrThrow(
        runtime::util::GetDeterministicSerializedModuleProto(proto));
   }
 
@@ -2398,7 +2398,7 @@ void InitXlaModuleBindings(py::module m) {
       .def("_xla_set_mat_mul_precision",
           [](const std::string& mat_mul_precision) {
             xla::PrecisionConfig::Precision precision =
-                ConsumeValue(xla::StringToPrecision(mat_mul_precision));
+                GetValueOrThrow(xla::StringToPrecision(mat_mul_precision));
             XlaHelpers::set_mat_mul_precision(precision);
           })
       .def("_xla_get_mat_mul_precision", []() {
@@ -2447,7 +2447,7 @@ void InitXlaModuleBindings(py::module m) {
        std::string hlo_text;
        {
          NoGilSection nogil;
-          hlo_text = ConsumeValue(runtime::util::GetComputationHloText(
+          hlo_text = GetValueOrThrow(runtime::util::GetComputationHloText(
             computation->computation()));
        }
        return hlo_text;
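
One practical consequence in this file: the call sites above execute inside pybind11-bound lambdas, so an exception thrown by `GetValueOrThrow` should surface in Python as a catchable exception (pybind11 translates uncaught `std::exception` subclasses into a Python `RuntimeError` by default), whereas a CHECK-failing unwrap aborts the interpreter. A hypothetical binding illustrating the effect; `ParsePrecision` is invented for this example, and `GetValueOrThrow` is the sketch from the top of this page, not the real helper:

#include <pybind11/pybind11.h>

#include <stdexcept>
#include <string>
#include <utility>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

namespace py = pybind11;

// Copy of the GetValueOrThrow sketch from above (assumption, not the
// actual helper from torch_xla/csrc/status.h).
template <typename T>
T GetValueOrThrow(absl::StatusOr<T>&& status_or) {
  if (!status_or.ok()) {
    throw std::runtime_error(std::string(status_or.status().message()));
  }
  return std::move(status_or).value();
}

// Invented stand-in for a StatusOr-returning API such as
// xla::StringToPrecision.
absl::StatusOr<int> ParsePrecision(const std::string& name) {
  if (name == "default") return 0;
  if (name == "highest") return 2;
  return absl::InvalidArgumentError("unknown precision: " + name);
}

PYBIND11_MODULE(example, m) {
  // On bad input the helper throws, so Python callers see
  // RuntimeError("unknown precision: ...") instead of a fatal log
  // followed by an aborted process.
  m.def("set_precision", [](const std::string& name) {
    return GetValueOrThrow(ParsePrecision(name));
  });
}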
