Commit f5a2218

torch_xla: Use new macros for throwing exceptions (part 1). (#9593)
Follow-up: #9588 and #9580
Target: `torch_xla/csrc` directory

In summary, this PR:

- Replaces all calls to `OkOrThrow()` and `GetValueOrThrow()` (which throw an exception without source-location information for the *"throw-site"*) with the macros `XLA_THROW_IF_ERROR()` and `XLA_ASSIGN_OR_THROW()`.
- Corresponds to the fine-grained set of PRs that came from breaking down PR #9580.
- Focuses on the `torch_xla/csrc` directory, replacing every use of those now-deprecated functions with the newly introduced macros.

_Note: since many files in `torch_xla/csrc` needed updating, the changes were split into multiple parts._
1 parent: d4cf42a

10 files changed, +48 -33 lines changed
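For context, the change applied throughout the diffs below is a mechanical rewrite from the deprecated throwing helpers to the new macros. The following is a minimal, illustrative C++ sketch of that before/after shape; the header path, the `absl::StatusOr`/`absl::Status` return types, and the `LoadName`/`Ping` helpers are assumptions made for illustration, not code from this PR.

```cpp
// Illustrative sketch only (not part of this commit).
#include <string>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "torch_xla/csrc/status.h"  // assumed location of the XLA_* macros

// Hypothetical status-returning helpers standing in for calls such as
// bridge::GetXlaTensor() or future.Await().
absl::StatusOr<std::string> LoadName(bool ok) {
  if (!ok) return absl::InvalidArgumentError("no name");
  return std::string("xla");
}
absl::Status Ping(bool ok) {
  return ok ? absl::OkStatus() : absl::InternalError("ping failed");
}

std::string OldPattern() {
  // Deprecated helpers: they unwrap the value or rethrow, but the resulting
  // exception carries no source location for this call site.
  OkOrThrow(Ping(/*ok=*/true));
  return GetValueOrThrow(LoadName(/*ok=*/true));
}

std::string NewPattern() {
  // New macros: on error they throw an exception annotated with this
  // file/line, so the throw-site is visible in the error message.
  XLA_THROW_IF_ERROR(Ping(/*ok=*/true));
  XLA_ASSIGN_OR_THROW(std::string name, LoadName(/*ok=*/true));
  return name;
}
```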

torch_xla/csrc/aten_xla_bridge.cpp

Lines changed: 12 additions & 6 deletions
@@ -170,7 +170,9 @@ torch_xla::XLATensorPtr GetXlaTensorOrCreateForWrappedNumber(
       (tensor.dim() == 0 && tensor.numel() == 1)) {
     return torch_xla::bridge::GetOrCreateXlaTensor(tensor, device);
   } else {
-    return GetValueOrThrow(torch_xla::bridge::GetXlaTensor(tensor));
+    XLA_ASSIGN_OR_THROW(XLATensorPtr xla_tensor,
+                        torch_xla::bridge::GetXlaTensor(tensor));
+    return xla_tensor;
   }
 }

@@ -186,9 +188,13 @@ XLATensorPtr GetOrCreateXlaTensor(const at::Tensor& tensor,
   }

   auto xtensor = GetXlaTensor(tensor);
-  return xtensor.ok()
-             ? xtensor.value()
-             : GetValueOrThrow(XLATensor::Create(inner_tensor, device));
+  if (xtensor.ok()) {
+    return xtensor.value();
+  }
+
+  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_tensor,
+                      XLATensor::Create(inner_tensor, device));
+  return xla_tensor;
 }

 XLATensorPtr GetOrCreateXlaTensor(const std::optional<at::Tensor>& tensor,
@@ -479,8 +485,8 @@ at::Tensor CreateXlaTensor(
     at::Tensor tensor,
     const std::optional<torch::lazy::BackendDevice>& device) {
   if (tensor.defined() && device) {
-    XLATensorPtr xla_tensor =
-        GetValueOrThrow(XLATensor::Create(std::move(tensor), *device));
+    XLA_ASSIGN_OR_THROW(XLATensorPtr xla_tensor,
+                        XLATensor::Create(std::move(tensor), *device));
     tensor = AtenFromXlaTensor(xla_tensor);
   }
   return tensor;

torch_xla/csrc/debug_util.cpp

Lines changed: 3 additions & 3 deletions
@@ -450,9 +450,9 @@ void DebugUtil::post_compilation_analysis(
   // This can be used to verify the hash of the underlying computation proto.
   // Note that for UserComputation computations, the protobuf is factored in
   // the graph hash.
-  std::string serialized_computation =
-      GetValueOrThrow(runtime::util::GetDeterministicSerializedModuleProto(
-          computation->computation().proto()));
+  XLA_ASSIGN_OR_THROW(std::string serialized_computation,
+                      runtime::util::GetDeterministicSerializedModuleProto(
+                          computation->computation().proto()));
   ss << "\n"
      << "Computation hash: "
      << torch::lazy::HashToString(torch::lazy::Hash(serialized_computation))

torch_xla/csrc/dl_convertor.cpp

Lines changed: 6 additions & 5 deletions
@@ -141,10 +141,10 @@ DLManagedTensor* toDLPack(const at::Tensor& input) {
   DLTensor& dt = pack->tensor.dl_tensor;
   {
     // AcquireExternalReference may block
-    pack->external_reference =
-        GetValueOrThrow(pjrt_buffer->AcquireExternalReference());
+    XLA_ASSIGN_OR_THROW(pack->external_reference,
+                        pjrt_buffer->AcquireExternalReference());
     xla::PjRtFuture<> future = pjrt_buffer->GetReadyFuture();
-    OkOrThrow(future.Await());
+    XLA_THROW_IF_ERROR(future.Await());
   }
   pack->buffer_reference = pjrt_buffer;

@@ -329,8 +329,9 @@ at::Tensor fromDLPack(DLManagedTensor* dlmt) {
   if (dlmt->deleter) {
     on_delete_callback = [dlmt]() { dlmt->deleter(dlmt); };
   }
-  std::unique_ptr<xla::PjRtBuffer> pjrt_buffer =
-      GetValueOrThrow(device->client()->CreateViewOfDeviceBuffer(
+  XLA_ASSIGN_OR_THROW(
+      std::unique_ptr<xla::PjRtBuffer> pjrt_buffer,
+      device->client()->CreateViewOfDeviceBuffer(
           static_cast<char*>(dlmt->dl_tensor.data) +
               dlmt->dl_tensor.byte_offset,
           shape, *device->default_memory_space(), on_delete_callback));

torch_xla/csrc/helpers.cpp

Lines changed: 7 additions & 4 deletions
@@ -41,7 +41,8 @@ xla::XlaComputation CreateComputation(
       xla::Parameter(&builder, 0, xla::ShapeUtil::MakeShape(type, {}), "x");
   xla::XlaOp y =
       xla::Parameter(&builder, 1, xla::ShapeUtil::MakeShape(type, {}), "y");
-  return GetValueOrThrow(builder.Build(op(x, y)));
+  XLA_ASSIGN_OR_THROW(xla::XlaComputation computation, builder.Build(op(x, y)));
+  return computation;
 }

 xla::XlaComputation CreateMinMaxComputation(const std::string& name,
@@ -66,7 +67,8 @@ xla::XlaComputation CreateMinMaxComputation(const std::string& name,
   xla::XlaOp tie_id = xla::Min(lhs_index, rhs_index);
   arg_max = xla::Select(eq, tie_id, arg_max);
   xla::Tuple(&builder, {max, arg_max});
-  return GetValueOrThrow(builder.Build());
+  XLA_ASSIGN_OR_THROW(xla::XlaComputation min_max_computation, builder.Build());
+  return min_max_computation;
 }

 }  // namespace
@@ -697,7 +699,8 @@ std::vector<int64_t> XlaHelpers::getBroadcastDimensions(xla::XlaOp op1,
 xla::Shape XlaHelpers::GetPromotedBinaryOpShape(const xla::Shape& shape1,
                                                 const xla::Shape& shape2) {
   if (!shape1.is_dynamic() && !shape2.is_dynamic()) {
-    auto promoted_shape = GetValueOrThrow(GetPromotedShape(shape1, shape2));
+    XLA_ASSIGN_OR_THROW(xla::Shape promoted_shape,
+                        GetPromotedShape(shape1, shape2));
     return xla::ShapeUtil::MakeShape(
         PromoteType(shape1.element_type(), shape2.element_type()),
         promoted_shape.dimensions());
@@ -776,7 +779,7 @@ std::pair<xla::XlaOp, xla::XlaOp> XlaHelpers::PromoteShapes(xla::XlaOp op1,
   const xla::Shape& shape1 = ShapeHelper::ShapeOfXlaOp(op1);
   const xla::Shape& shape2 = ShapeHelper::ShapeOfXlaOp(op2);

-  xla::Shape shape = GetValueOrThrow(GetPromotedShape(shape1, shape2));
+  XLA_ASSIGN_OR_THROW(xla::Shape shape, GetPromotedShape(shape1, shape2));
   if (shape1.is_unbounded_dynamic() || shape2.is_unbounded_dynamic()) {
     return ImplicitBroadcastWithUnboundedDynamicShapes(op1, op2, shape);
   }

torch_xla/csrc/lowering_context.cpp

Lines changed: 1 addition & 1 deletion
@@ -325,7 +325,7 @@ void LoweringContext::AddParameter(const torch::lazy::Output& output,
 }

 torch::lazy::ComputationPtr LoweringContext::Build() {
-  xla::XlaComputation xla_computation = GetValueOrThrow(BuildXla());
+  XLA_ASSIGN_OR_THROW(xla::XlaComputation xla_computation, BuildXla());
   return std::make_shared<runtime::ComputationClient::Computation>(
       builder_.name(), std::move(xla_computation), device_);
 }

torch_xla/csrc/tensor.cpp

Lines changed: 5 additions & 3 deletions
@@ -518,8 +518,8 @@ at::Tensor XLATensor::ToTensor(bool detached) {
     XLAGraphExecutor::Get()->DeviceBarrier(GetDevice());
     // The GetXlaData() call will trigger an ApplyPendingGraph() if an IR
     // XlaNode is available on the tensor.
-    std::vector<at::Tensor> tensors =
-        GetValueOrThrow(XlaDataToTensors({GetXlaData()}, {dtype()}));
+    XLA_ASSIGN_OR_THROW(std::vector<at::Tensor> tensors,
+                        XlaDataToTensors({GetXlaData()}, {dtype()}));
     tensor = std::move(tensors.front());
     if (!detached) {
       SetTensorData(tensor);
@@ -627,7 +627,9 @@ std::vector<XLATensorPtr> XLATensor::MakeOutputTensors(
 XLATensorPtr XLATensor::CopyTensorToDevice(
     const torch::lazy::BackendDevice& device) {
   // TODO: This can be optimized via proper XRT/XLA computation.
-  return GetValueOrThrow(Create(ToTensor(/*detached=*/true), device));
+  XLA_ASSIGN_OR_THROW(XLATensorPtr result,
+                      Create(ToTensor(/*detached=*/true), device));
+  return result;
 }

 torch::lazy::Value XLATensor::MaybeCastIrValue(

torch_xla/csrc/xla_backend_impl.cpp

Lines changed: 5 additions & 2 deletions
@@ -94,7 +94,9 @@ class XlaBackendImpl : public torch::lazy::BackendImplInterface {
       const torch::lazy::BackendDataPtr data,
       std::optional<at::ScalarType> logical_scalar_type) const override {
     // TODO(JackCaoG): handle the logical_scalar_type == nullptr case
-    return GetValueOrThrow(XlaDataToTensors({data}, {*logical_scalar_type}))[0];
+    XLA_ASSIGN_OR_THROW(std::vector<at::Tensor> tensors,
+                        XlaDataToTensors({data}, {*logical_scalar_type}));
+    return tensors[0];
   }

   std::unique_ptr<torch::lazy::LoweringContext> CreateLoweringContext(
@@ -163,7 +165,8 @@ class XlaBackendImpl : public torch::lazy::BackendImplInterface {
       torch::lazy::ComputationPtr computation,
      c10::ArrayRef<torch::lazy::BackendDataPtr> arguments,
       const torch::lazy::BackendDevice& device) const override {
-    std::vector<runtime::ComputationClient::DataPtr> results = GetValueOrThrow(
+    XLA_ASSIGN_OR_THROW(
+        std::vector<runtime::ComputationClient::DataPtr> results,
         runtime::GetComputationClientOrDie()->ExecuteComputation(
             *std::dynamic_pointer_cast<runtime::ComputationClient::Computation>(
                 computation),

torch_xla/csrc/xla_manual_registration.cpp

Lines changed: 2 additions & 2 deletions
@@ -38,8 +38,8 @@ at::Tensor nms_kernel(const at::Tensor& boxes, const at::Tensor& scores,
   XLA_CHECK_EQ(boxes.size(0), scores.size(0))
       << "nms(): boxes and scores should have the same size for dimension 0.";

-  XLATensorPtr xla_boxes = GetValueOrThrow(bridge::GetXlaTensor(boxes));
-  XLATensorPtr xla_scores = GetValueOrThrow(bridge::GetXlaTensor(scores));
+  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_boxes, bridge::GetXlaTensor(boxes));
+  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_scores, bridge::GetXlaTensor(scores));
   return bridge::AtenFromXlaTensor(
       tensor_methods::nms(xla_boxes, xla_scores, iou_threshold),
       /*skip_functionalization=*/true);

torch_xla/csrc/xla_op_builder.cpp

Lines changed: 6 additions & 6 deletions
@@ -690,8 +690,8 @@ xla::XlaOp ConcatInDim(const BuilderPtr& builder,
 xla::XlaOp Convert(const BuilderPtr& builder,
                    const std::vector<OpPtr>& operands, py::dict args) {
   std::string type = args["to_type"].cast<std::string>();
-  xla::PrimitiveType xla_type =
-      GetValueOrThrow(xla::primitive_util::StringToPrimitiveType(type));
+  XLA_ASSIGN_OR_THROW(xla::PrimitiveType xla_type,
+                      xla::primitive_util::StringToPrimitiveType(type));
   return MaybeConvertTo(operands.at(0)->op, xla_type);
 }

@@ -717,8 +717,8 @@ xla::XlaOp SetDimensionSize(const BuilderPtr& builder,
 xla::XlaOp BitcastConvert(const BuilderPtr& builder,
                           const std::vector<OpPtr>& operands, py::dict args) {
   std::string type = args["to_type"].cast<std::string>();
-  xla::PrimitiveType xla_type =
-      GetValueOrThrow(xla::primitive_util::StringToPrimitiveType(type));
+  XLA_ASSIGN_OR_THROW(xla::PrimitiveType xla_type,
+                      xla::primitive_util::StringToPrimitiveType(type));
   return xla::BitcastConvertType(operands.at(0)->op, xla_type);
 }

@@ -873,8 +873,8 @@ xla::Shape PyShapeToShape(py::object shape) {
   std::string type = py_shape["type"].cast<std::string>();
   std::vector<int64_t> dimensions =
       GetTupleVector<int64_t>(py_shape["sizes"].cast<py::tuple>());
-  xla::PrimitiveType xla_type =
-      GetValueOrThrow(xla::primitive_util::StringToPrimitiveType(type));
+  XLA_ASSIGN_OR_THROW(xla::PrimitiveType xla_type,
+                      xla::primitive_util::StringToPrimitiveType(type));
   if (py_shape.contains("dynamic_dimensions")) {
     std::vector<bool> dynamic_dimensions =
         GetTupleVector<bool>(py_shape["dynamic_dimensions"]);

torch_xla/csrc/xla_sharding_util.cpp

Lines changed: 1 addition & 1 deletion
@@ -767,7 +767,7 @@ void ShardingUtil::XlaMarkSharding(const at::Tensor& input,
       << "Please enable SPMD via `torch_xla.runtime.use_spmd()`";
   XLA_CHECK(sharding.type() != xla::OpSharding::UNKNOWN)
       << "Can't explicilty annotate with UNKNOWN sharding type.";
-  XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(input));
+  XLA_ASSIGN_OR_THROW(XLATensorPtr xtensor, bridge::GetXlaTensor(input));

   // For Non DeviceData IR values, we directly attach the sharding spec to the
   // xtensor.
