Skip to content

Commit 699289f

Browse files
authored
GetOutputOp: Refactor and Improve error handling. (#9721)
This PR improves error handling of `LoweringContext::GetOutputOp()`, by creating a new `LoweringContext::SafeGetOutputOp()` function that returns an status-like variable. It also refactors the `lowering_context.{h,cpp}` C++ files. **Key Changes:** - Created `SafeGetOutputOp()`, which is an implementation of `GetOutputOp()` function, but propagates the error status down the line - `GetOutputOp()` calls `SafeGetOutputOp()` - `GetOutputOp()` users are not immediately affected - Add missing includes ([IWYU ref][1]) - Remove `OutputMap<T>` declaration (using PyTorch `torch::lazy::OutputMap<T>`) [1]: https://google.github.io/styleguide/cppguide.html#Include_What_You_Use
1 parent 29c3efa commit 699289f

File tree

7 files changed

+75
-32
lines changed

7 files changed

+75
-32
lines changed

torch_xla/csrc/ir.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "torch_xla/csrc/runtime/cache.h"
1818
#include "torch_xla/csrc/runtime/debug_macros.h"
1919
#include "torch_xla/csrc/runtime/sys_util.h"
20+
#include "torch_xla/csrc/status.h"
2021

2122
namespace torch_xla {
2223
namespace {

torch_xla/csrc/ir.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,6 @@ class LoweringContext;
3333

3434
using XlaOpVector = absl::InlinedVector<xla::XlaOp, 1>;
3535

36-
template <typename T>
37-
using OutputMap =
38-
std::unordered_map<torch::lazy::Output, T, torch::lazy::Output::Hasher>;
39-
4036
void DetectDynamicShape(torch::lazy::NodePtr node);
4137

4238
template <typename T, typename... Args>

torch_xla/csrc/lowering_context.cpp

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,35 @@
11
#include "torch_xla/csrc/lowering_context.h"
22

3+
#include <cstddef>
4+
#include <cstdint>
35
#include <memory>
46
#include <optional>
5-
#include <sstream>
6-
#include <stdexcept>
7+
#include <string>
78
#include <unordered_set>
89
#include <utility>
9-
10+
#include <vector>
11+
12+
#include <c10/util/ArrayRef.h>
13+
#include <torch/csrc/lazy/backend/backend_data.h>
14+
#include <torch/csrc/lazy/backend/backend_device.h>
15+
#include <torch/csrc/lazy/backend/lowering_context.h>
16+
#include <torch/csrc/lazy/core/config.h>
17+
#include <torch/csrc/lazy/core/ir.h>
1018
#include <torch/csrc/lazy/core/ir_metadata.h>
19+
#include <torch/csrc/lazy/core/ir_util.h>
1120

12-
#include "absl/log/absl_check.h"
13-
#include "absl/log/absl_log.h"
1421
#include "absl/status/status.h"
22+
#include "absl/status/statusor.h"
1523
#include "absl/strings/str_cat.h"
1624
#include "absl/strings/str_join.h"
1725
#include "absl/strings/str_replace.h"
26+
#include "xla/hlo/builder/xla_builder.h"
27+
#include "xla/hlo/builder/xla_computation.h"
28+
#include "xla/shape.h"
29+
#include "xla/xla_data.pb.h"
1830

1931
#include "torch_xla/csrc/ir.h"
2032
#include "torch_xla/csrc/runtime/computation_client.h"
21-
#include "torch_xla/csrc/runtime/debug_macros.h"
2233
#include "torch_xla/csrc/runtime/sys_util.h"
2334
#include "torch_xla/csrc/shape_helper.h"
2435
#include "torch_xla/csrc/stack_frame_index_builder.h"
@@ -242,21 +253,23 @@ void LoweringContext::AssignOutputOp(const torch::lazy::Output& output,
242253
}
243254

244255
xla::XlaOp LoweringContext::GetOutputOp(const torch::lazy::Output& output) {
245-
auto it = emitted_outputs_.find(output);
256+
XLA_ASSIGN_OR_THROW(xla::XlaOp op, SafeGetOutputOp(output));
257+
return op;
258+
}
246259

247-
if (it == emitted_outputs_.end()) {
248-
const auto post_order =
260+
absl::StatusOr<xla::XlaOp> LoweringContext::SafeGetOutputOp(
261+
const torch::lazy::Output& output) {
262+
if (!CheckOutputIsEmitted(output).ok()) {
263+
const std::vector<const torch::lazy::Node*> post_order =
249264
torch::lazy::Util::ComputePostOrder(output.node, &emit_status_);
250-
for (const auto* const node : post_order) {
251-
XLA_THROW_IF_ERROR(LowerNode(*node));
265+
for (const torch::lazy::Node* const node : post_order) {
266+
XLA_RETURN_IF_ERROR(LowerNode(*node));
252267
}
253268
// At this point the output better be present, otherwise there is an issue
254269
// with the lowering code.
255-
it = emitted_outputs_.find(output);
256-
ABSL_CHECK(it != emitted_outputs_.end())
257-
<< "No XLA operation emitted for output: " << output;
270+
XLA_CHECK_OK(CheckOutputIsEmitted(output));
258271
}
259-
return it->second;
272+
return emitted_outputs_.at(output);
260273
}
261274

262275
absl::StatusOr<XlaOpVector> LoweringContext::LowerNode(
@@ -329,4 +342,15 @@ torch::lazy::ComputationPtr LoweringContext::Build() {
329342
builder_.name(), std::move(xla_computation), device_);
330343
}
331344

345+
absl::Status LoweringContext::CheckOutputIsEmitted(
346+
const torch::lazy::Output& output) const {
347+
torch::lazy::OutputMap<xla::XlaOp>::const_iterator it =
348+
emitted_outputs_.find(output);
349+
if (it == emitted_outputs_.end()) {
350+
return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(
351+
absl::StrCat("could not find output: ", output.ToString())));
352+
}
353+
return absl::OkStatus();
354+
}
355+
332356
} // namespace torch_xla

torch_xla/csrc/lowering_context.h

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,29 @@
11
#ifndef XLA_TORCH_XLA_CSRC_LOWERING_CONTEXT_H_
22
#define XLA_TORCH_XLA_CSRC_LOWERING_CONTEXT_H_
33

4+
#include <cstddef>
5+
#include <cstdint>
46
#include <memory>
57
#include <optional>
68
#include <string>
7-
#include <string_view>
89
#include <unordered_map>
9-
#include <utility>
10+
#include <unordered_set>
1011
#include <vector>
1112

13+
#include <c10/util/ArrayRef.h>
1214
#include <torch/csrc/lazy/backend/backend_data.h>
15+
#include <torch/csrc/lazy/backend/backend_device.h>
1316
#include <torch/csrc/lazy/backend/lowering_context.h>
17+
#include <torch/csrc/lazy/core/ir.h>
18+
#include <torch/csrc/lazy/core/ir_metadata.h>
1419
#include <torch/csrc/lazy/core/ir_util.h>
1520

1621
#include "absl/status/status.h"
17-
#include "absl/types/span.h"
18-
#include "tsl/platform/macros.h"
22+
#include "absl/status/statusor.h"
1923
#include "xla/hlo/builder/xla_builder.h"
20-
#include "xla/types.h"
24+
#include "xla/hlo/builder/xla_computation.h"
2125

22-
#include "torch_xla/csrc/device.h"
2326
#include "torch_xla/csrc/ir.h"
24-
#include "torch_xla/csrc/runtime/computation_client.h"
2527

2628
namespace torch_xla {
2729

@@ -74,10 +76,23 @@ class LoweringContext : public torch::lazy::LoweringContext {
7476
// operands among the emitted outputs.
7577
void AssignOutputOp(const torch::lazy::Output& output, xla::XlaOp op);
7678

77-
// Retrieves the lowered operation for a output. If the requested output is
78-
// not available yet, the graph behind the output's XlaNode is lowered, and
79-
// the corresponding XLA operation returned.
80-
xla::XlaOp GetOutputOp(const torch::lazy::Output& output);
79+
// Retrieves the lowered operation for a output.
80+
//
81+
// If the requested output is not available yet, the graph behind the output's
82+
// XlaNode is lowered, and the corresponding XLA operation returned.
83+
[[deprecated("Use SafeGetOutputOp for better error handling.")]] xla::XlaOp
84+
GetOutputOp(const torch::lazy::Output& output);
85+
// Retrieves the lowered operation for a output.
86+
//
87+
// If the requested output is not available yet, the graph behind the output's
88+
// XlaNode is lowered, and the corresponding XLA operation returned.
89+
//
90+
// This function shall return an error status if the lowering the underlying
91+
// `output`, or any other dependent nodes, returns an error status.
92+
// Additionally, it might abort if after the lowering of `output` and its
93+
// dependent nodes, the lowered node for `output` is not available, i.e. not
94+
// in `emitted_outputs_`.
95+
absl::StatusOr<xla::XlaOp> SafeGetOutputOp(const torch::lazy::Output& output);
8196

8297
// Build the XLA computation capturing all the operations created with the
8398
// embedded XLA builder (returned by the builder() API).
@@ -110,7 +125,7 @@ class LoweringContext : public torch::lazy::LoweringContext {
110125

111126
torch::lazy::ComputationPtr Build() override;
112127

113-
const OutputMap<xla::XlaOp> GetEmittedOutputs() const {
128+
const torch::lazy::OutputMap<xla::XlaOp> GetEmittedOutputs() const {
114129
return emitted_outputs_;
115130
}
116131

@@ -124,11 +139,15 @@ class LoweringContext : public torch::lazy::LoweringContext {
124139
size_t index = 0;
125140
};
126141

142+
// Checks whether the given output is already emitted. In other words, whether
143+
// we can find it inside `emitted_outputs_`.
144+
absl::Status CheckOutputIsEmitted(const torch::lazy::Output& output) const;
145+
127146
xla::XlaBuilder builder_;
128147
std::unordered_map<torch::lazy::BackendData::Handle, Parameter>
129148
parameters_map_;
130149
std::vector<xla::XlaOp> root_tuple_;
131-
OutputMap<xla::XlaOp> emitted_outputs_;
150+
torch::lazy::OutputMap<xla::XlaOp> emitted_outputs_;
132151
std::string name_;
133152

134153
std::shared_ptr<StackFrameIndexBuilder> stack_frame_index_builder_;

torch_xla/csrc/ops/custom_call.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
#include "torch_xla/csrc/lowering_context.h"
66
#include "torch_xla/csrc/ops/xla_ops.h"
7+
#include "torch_xla/csrc/runtime/debug_macros.h"
78
#include "torch_xla/csrc/shape_helper.h"
89

910
namespace torch_xla {

torch_xla/csrc/ops/dot_general.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include "torch_xla/csrc/lowering_context.h"
88
#include "torch_xla/csrc/ops/infer_output_shape.h"
99
#include "torch_xla/csrc/ops/xla_ops.h"
10+
#include "torch_xla/csrc/runtime/debug_macros.h"
1011

1112
namespace torch_xla {
1213

torch_xla/csrc/ops/symeig.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "xla/hlo/builder/lib/self_adjoint_eig.h"
66

77
#include "torch_xla/csrc/lowering_context.h"
8+
#include "torch_xla/csrc/runtime/debug_macros.h"
89
#include "torch_xla/csrc/shape_helper.h"
910

1011
namespace torch_xla {

0 commit comments

Comments
 (0)