@@ -36,7 +36,6 @@
 #include "pybind11/pytypes.h"
 #include "pybind11/stl.h"
 #include "pybind11/stl_bind.h"
-#include "status.h"
 #include "torch_xla/csrc/XLANativeFunctions.h"
 #include "torch_xla/csrc/aten_autograd_ops.h"
 #include "torch_xla/csrc/aten_fallback.h"
@@ -892,7 +891,8 @@ std::vector<at::Tensor> XlaUserComputation(
 
 runtime::ComputationClient::ComputationPtr CreateComputation(
     const std::string& name, xla::XlaOp root) {
-  xla::XlaComputation computation = ConsumeValue(root.builder()->Build(root));
+  xla::XlaComputation computation =
+      GetValueOrThrow(root.builder()->Build(root));
   return std::make_shared<runtime::ComputationClient::Computation>(
       name, std::move(computation));
 }
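For readers unfamiliar with the helper being swapped in throughout this diff: `ConsumeValue`-style helpers, as commonly implemented, CHECK-fail and abort the process on a non-OK status, while `GetValueOrThrow` raises a C++ exception that can propagate back to Python. The sketch below is an illustration under assumptions, not the actual torch_xla definition; the real helper's status type, exception class, and message formatting may differ.

```cpp
#include <stdexcept>
#include <string>
#include <utility>

#include "absl/status/statusor.h"

// Assumed-for-illustration helper: unwrap an absl::StatusOr<T>, throwing a
// std::runtime_error carrying the status message when the computation failed.
template <typename T>
T GetValueOrThrow(absl::StatusOr<T> status_or) {
  if (!status_or.ok()) {
    throw std::runtime_error(std::string(status_or.status().message()));
  }
  return std::move(status_or).value();
}
```

Under that reading, a failed `root.builder()->Build(root)` above becomes a catchable error at the Python boundary instead of a hard crash.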
@@ -1152,7 +1152,7 @@ class PyLoweringContext {
 
     ShardingUtil::SetHloSharding(&lowering_ctx);
 
-    computation = ConsumeValue(lowering_ctx.BuildXla());
+    computation = GetValueOrThrow(lowering_ctx.BuildXla());
   }
 
   // Builds a HLO graph given a set of output tensors, and add unused parameters
@@ -1196,7 +1196,7 @@ class PyLoweringContext {
 
     ShardingUtil::SetHloSharding(&lowering_ctx);
 
-    computation = ConsumeValue(lowering_ctx.BuildXla());
+    computation = GetValueOrThrow(lowering_ctx.BuildXla());
 
     // wrap inputs of cond/body_computation
     if ((GetNameString() == "condctx") || (GetNameString() == "bodyctx")) {
@@ -1207,12 +1207,12 @@ class PyLoweringContext {
       param_shardings = XlaHelpers::ExtractInputShardings(computation);
     }
     xla::ProgramShape program_shape =
-        ConsumeValue(computation.GetProgramShape());
+        GetValueOrThrow(computation.GetProgramShape());
     // TODO(@manfei): please confirm whether we check for more than two or use
     // default value true
     bool should_wrap_parameter = (program_shape.parameters_size() >= 2);
     if (should_wrap_parameter) {
-      computation = ConsumeValue(XlaHelpers::WrapXlaComputation(
+      computation = GetValueOrThrow(XlaHelpers::WrapXlaComputation(
           computation, program_shape.parameters(), param_shardings,
           /* buffer_donor_indices */ {}));
     }
@@ -1309,7 +1309,7 @@ class PyLoweringContext {
   // Create a serialized HloModule protobuf from a lowered graph
   py::bytes GetHlo() {
     const xla::HloModuleProto& proto = computation.proto();
-    return ConsumeValue(
+    return GetValueOrThrow(
         runtime::util::GetDeterministicSerializedModuleProto(proto));
   }
 
@@ -2398,7 +2398,7 @@ void InitXlaModuleBindings(py::module m) {
       .def("_xla_set_mat_mul_precision",
           [](const std::string& mat_mul_precision) {
             xla::PrecisionConfig::Precision precision =
-                ConsumeValue(xla::StringToPrecision(mat_mul_precision));
+                GetValueOrThrow(xla::StringToPrecision(mat_mul_precision));
            XlaHelpers::set_mat_mul_precision(precision);
          })
      .def("_xla_get_mat_mul_precision", []() {
@@ -2447,7 +2447,7 @@ void InitXlaModuleBindings(py::module m) {
          std::string hlo_text;
          {
            NoGilSection nogil;
-           hlo_text = ConsumeValue(runtime::util::GetComputationHloText(
+           hlo_text = GetValueOrThrow(runtime::util::GetComputationHloText(
               computation->computation()));
          }
          return hlo_text;
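One subtlety in this last hunk: `GetValueOrThrow` can now throw while the GIL is released inside `NoGilSection`. With an RAII guard, that is still safe: stack unwinding runs the guard's destructor, which reacquires the GIL before pybind11 converts the exception. A sketch of the idiom, assuming `NoGilSection` wraps pybind11's `gil_scoped_release` (the actual torch_xla definition may differ):

```cpp
#include <stdexcept>

#include "pybind11/pybind11.h"

// Assumed shape of NoGilSection: an RAII guard that drops the GIL on entry
// and reacquires it on scope exit, including exit via a thrown exception.
struct NoGilSection {
  pybind11::gil_scoped_release release;
};

// Stand-in for a long-running call like GetComputationHloText(): may throw.
int LongRunningCall() { throw std::runtime_error("HLO text generation failed"); }

int Compute() {
  int result = 0;
  {
    NoGilSection nogil;          // GIL released here
    result = LongRunningCall();  // if this throws, ~NoGilSection reacquires
  }                              // the GIL during unwinding
  return result;
}
```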