Skip to content

Commit c2aa21e

Browse files
Switch XLA_CHECKs to ABSL_CHECKs in lowering_context. (#9338)
Co-authored-by: Zhanyong Wan <[email protected]>
1 parent 1d40939 commit c2aa21e

File tree

4 files changed

+91
-84
lines changed

4 files changed

+91
-84
lines changed

torch_xla/csrc/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,8 @@ ptxla_cc_library(
308308
":unwrap_data",
309309
"//torch_xla/csrc/runtime:cache",
310310
"//torch_xla/csrc/runtime:computation_client",
311+
"@com_google_absl//absl/log:absl_check",
312+
"@com_google_absl//absl/log:absl_log",
311313
"@com_google_absl//absl/types:span",
312314
],
313315
)

torch_xla/csrc/ir.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ class XlaNode : public torch::lazy::Node {
100100
XlaNode(torch::lazy::OpKind op, xla::Shape xla_shape, size_t num_outputs,
101101
torch::lazy::hash_t hash_seed);
102102

103-
virtual ~XlaNode();
103+
~XlaNode() override;
104104

105105
// Retrieves the full shape of the IR XlaNode. Note that if this is a
106106
// multi-output node, the returned shape will be a tuple.

torch_xla/csrc/lowering_context.cpp

Lines changed: 85 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -2,66 +2,72 @@
22

33
#include <torch/csrc/lazy/core/ir_metadata.h>
44

5-
#include <iostream>
65
#include <optional>
76
#include <sstream>
87
#include <stdexcept>
9-
#include <string_view>
8+
#include <utility>
109

10+
#include "absl/log/absl_check.h"
11+
#include "absl/log/absl_log.h"
1112
#include "absl/status/status.h"
1213
#include "absl/strings/str_cat.h"
13-
#include "absl/strings/str_join.h"
1414
#include "absl/strings/str_replace.h"
15-
#include "absl/strings/string_view.h"
1615
#include "torch_xla/csrc/ir.h"
1716
#include "torch_xla/csrc/runtime/computation_client.h"
1817
#include "torch_xla/csrc/runtime/debug_macros.h"
1918
#include "torch_xla/csrc/runtime/sys_util.h"
2019
#include "torch_xla/csrc/shape_helper.h"
2120
#include "torch_xla/csrc/stack_frame_index_builder.h"
22-
#include "torch_xla/csrc/unwrap_data.h"
2321

2422
namespace torch_xla {
2523

2624
namespace {
2725

2826
class HloMetadataSetter {
2927
public:
30-
HloMetadataSetter(LoweringContext* loctx, const torch::lazy::Node* node) {
28+
HloMetadataSetter(LoweringContext& lowering_context,
29+
const torch::lazy::Node& node)
30+
: lowering_context_(lowering_context) {
3131
if (ShouldPopulateXlaOpMetadata()) {
32-
PopulateXlaOpMetadata(loctx, node);
33-
loctx_ = loctx;
32+
PopulateXlaOpMetadata(lowering_context, node);
3433
}
3534
}
3635

36+
// This class is neither copyable nor movable.
37+
HloMetadataSetter(const HloMetadataSetter&) = delete;
38+
HloMetadataSetter& operator=(const HloMetadataSetter&) = delete;
39+
HloMetadataSetter(HloMetadataSetter&&) = delete;
40+
HloMetadataSetter& operator=(HloMetadataSetter&&) = delete;
41+
3742
~HloMetadataSetter() {
38-
if (loctx_ != nullptr) {
39-
loctx_->builder()->ClearOpMetadata();
43+
if (ShouldPopulateXlaOpMetadata()) {
44+
lowering_context_.builder()->ClearOpMetadata();
4045
}
4146
}
4247

4348
private:
49+
// Returns true iff this class should populate XLA op metadata in its
50+
// constructor.
4451
static bool ShouldPopulateXlaOpMetadata() {
45-
static bool op_metadata =
52+
static const bool op_metadata =
4653
runtime::sys_util::GetEnvBool("XLA_HLO_DEBUG", false);
4754
return FLAGS_torch_lazy_ir_debug || op_metadata;
4855
}
4956

50-
static void PopulateXlaOpMetadata(LoweringContext* loctx,
51-
const torch::lazy::Node* node) {
57+
static void PopulateXlaOpMetadata(LoweringContext& lowering_context,
58+
const torch::lazy::Node& node) {
5259
xla::OpMetadata metadata;
5360
// NOTE: we apply some string manipulation as xprof backend utility
5461
// for nesting/grouping traces depends on certain op name/type
5562
// patterns for classification.
5663
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/profiler/utils/tf_op_utils.cc#L55
57-
std::string op_type =
58-
absl::StrReplaceAll(node->op().ToString(), {{":", "_"}});
64+
const std::string op_type =
65+
absl::StrReplaceAll(node.op().ToString(), {{":", "_"}});
5966
metadata.set_op_type(op_type);
6067

61-
const torch::lazy::MetaData& nmeta = node->metadata();
62-
63-
const CustomOpNameMetaData* custom_opname_meta =
64-
dynamic_cast<const CustomOpNameMetaData*>(node->user_metadata());
68+
const torch::lazy::MetaData& nmeta = node.metadata();
69+
auto* const custom_opname_meta =
70+
dynamic_cast<const CustomOpNameMetaData*>(node.user_metadata());
6571

6672
std::string op_name_prefix;
6773
size_t max_stack_depth = nmeta.frame_info.size();
@@ -78,66 +84,65 @@ class HloMetadataSetter {
7884
metadata.set_op_name(absl::StrCat(op_name_prefix, op_type));
7985

8086
// Sets file, line and stack_frame_id in metadata
81-
loctx->stack_frame_index_builder()->AddStackFrameLocations(
82-
nmeta.frame_info, max_stack_depth, metadata);
87+
lowering_context.stack_frame_index_builder()->AddStackFrameLocations(
88+
nmeta.frame_info, static_cast<int>(max_stack_depth), metadata);
8389

84-
loctx->builder()->SetOpMetadata(std::move(metadata));
90+
lowering_context.builder()->SetOpMetadata(std::move(metadata));
8591
}
8692

87-
LoweringContext* loctx_ = nullptr;
93+
LoweringContext& lowering_context_;
8894
};
8995

9096
} // namespace
9197

9298
LoweringContext::LoweringContext(const std::string& name,
9399
torch::lazy::BackendDevice device)
94-
: torch::lazy::LoweringContext(name, device),
100+
: torch::lazy::LoweringContext(name, std::move(device)),
95101
builder_(name),
96102
stack_frame_index_builder_(std::make_shared<StackFrameIndexBuilder>()) {}
97103

98104
LoweringContext::LoweringContext(
99105
const std::string& name, torch::lazy::BackendDevice device,
100-
c10::ArrayRef<const torch::lazy::Node*> post_order,
106+
const c10::ArrayRef<const torch::lazy::Node*> post_order,
101107
torch::lazy::Util::EmissionMap emit_status)
102-
: torch::lazy::LoweringContext(name, device, {}, emit_status),
108+
: torch::lazy::LoweringContext(name, std::move(device), {},
109+
std::move(emit_status)),
103110
builder_(name),
104111
stack_frame_index_builder_(std::make_shared<StackFrameIndexBuilder>()) {
105-
for (auto node : post_order) {
106-
LowerNode(node);
112+
for (const auto* node : post_order) {
113+
LowerNode(*node);
107114
}
108115
}
109116

110-
// TODO(lsy323): Get reserved number for unbounded dim after it's added in XLA.
111-
static constexpr int64_t kUnboundedSize = std::numeric_limits<int64_t>::min();
112-
113117
xla::XlaOp LoweringContext::GetParameter(
114118
const std::shared_ptr<torch::lazy::BackendData>& backend_data,
115119
const std::unordered_set<uint32_t>& unbounded_dynamic_dims) {
116-
torch::lazy::BackendData::Handle handle = backend_data->GetHandle();
120+
const torch::lazy::BackendData::Handle handle = backend_data->GetHandle();
117121
auto it = parameters_map_.find(handle);
118122
if (it == parameters_map_.end()) {
119-
auto data = std::dynamic_pointer_cast<runtime::ComputationClient::Data>(
120-
backend_data);
121-
XLA_CHECK(data != nullptr);
123+
auto* const data =
124+
dynamic_cast<runtime::ComputationClient::Data*>(backend_data.get());
125+
ABSL_CHECK(data != nullptr);
122126
xla::Shape shape = data->shape();
123127
for (const int dim : unbounded_dynamic_dims) {
124128
shape.set_dynamic_dimension(dim, true);
125-
shape.set_dimensions(dim, kUnboundedSize);
129+
shape.set_dimensions(dim, xla::Shape::kUnboundedSize);
126130
}
127-
size_t param_index = parameters_.size();
128-
std::string param_name = absl::StrCat("p", param_index);
131+
const size_t param_index = parameters_.size();
132+
const std::string param_name = absl::StrCat("p", param_index);
129133
xla::XlaOp param;
130134
if (data->HasSharding()) {
131-
xla::OpSharding sharding = data->GetSharding();
132-
xla::XlaScopedShardingAssignment scoped_sharding(builder(), sharding);
135+
const xla::OpSharding sharding = data->GetSharding();
136+
const xla::XlaScopedShardingAssignment scoped_sharding(builder(),
137+
sharding);
133138
param = xla::Parameter(builder(), param_index, shape, param_name);
134139
} else {
135140
param = xla::Parameter(builder(), param_index, shape, param_name);
136141
}
137142
it = parameters_map_.emplace(handle, Parameter{param, param_index}).first;
138143
parameters_.push_back(backend_data);
139144
} else {
140-
XLA_CHECK(unbounded_dynamic_dims.empty())
145+
ABSL_CHECK(unbounded_dynamic_dims.empty())
141146
<< "The unbounded dynamic dims can only be set when Parameter is "
142147
"created.";
143148
}
@@ -147,8 +152,8 @@ xla::XlaOp LoweringContext::GetParameter(
147152

148153
std::optional<size_t> LoweringContext::GetParameterId(
149154
const std::shared_ptr<torch::lazy::BackendData>& backend_data) const {
150-
torch::lazy::BackendData::Handle handle = backend_data->GetHandle();
151-
auto it = parameters_map_.find(handle);
155+
const torch::lazy::BackendData::Handle handle = backend_data->GetHandle();
156+
const auto it = parameters_map_.find(handle);
152157
if (it == parameters_map_.end()) {
153158
return std::nullopt;
154159
}
@@ -164,12 +169,12 @@ const std::vector<size_t>& LoweringContext::GetParameterSequence() const {
164169
return parameter_sequence_;
165170
}
166171

167-
xla::XlaOp LoweringContext::GetResult(size_t index) const {
172+
xla::XlaOp LoweringContext::GetResult(const size_t index) const {
168173
return root_tuple_.at(index);
169174
}
170175

171-
void LoweringContext::SetResult(size_t index, xla::XlaOp op) {
172-
root_tuple_.at(index) = std::move(op);
176+
void LoweringContext::SetResult(const size_t index, const xla::XlaOp op) {
177+
root_tuple_.at(index) = op;
173178
}
174179

175180
absl::StatusOr<xla::XlaComputation> LoweringContext::BuildXla() {
@@ -181,7 +186,7 @@ absl::StatusOr<xla::XlaComputation> LoweringContext::BuildXla() {
181186
((get_name_string() == "condctx") or (get_name_string() == "bodyctx"))) {
182187
xla = builder()->Build(root_tuple_.at(0));
183188
} else if (!root_tuple_.empty()) {
184-
xla::XlaOp root = xla::Tuple(builder(), root_tuple_);
189+
const xla::XlaOp root = xla::Tuple(builder(), root_tuple_);
185190
xla = builder()->Build(root);
186191
} else {
187192
xla = builder()->Build();
@@ -195,8 +200,9 @@ absl::StatusOr<xla::XlaComputation> LoweringContext::BuildXla() {
195200
return xla;
196201
}
197202

198-
absl::StatusOr<xla::XlaComputation> LoweringContext::BuildXla(xla::XlaOp root) {
199-
XLA_CHECK(root_tuple_.empty());
203+
absl::StatusOr<xla::XlaComputation> LoweringContext::BuildXla(
204+
const xla::XlaOp root) {
205+
ABSL_CHECK(root_tuple_.empty());
200206
auto xla = builder()->Build(root);
201207

202208
if (xla.ok()) {
@@ -208,72 +214,71 @@ absl::StatusOr<xla::XlaComputation> LoweringContext::BuildXla(xla::XlaOp root) {
208214
}
209215

210216
void LoweringContext::AssignOutputOp(const torch::lazy::Output& output,
211-
xla::XlaOp op) {
212-
emitted_outputs_[output] = std::move(op);
217+
const xla::XlaOp op) {
218+
emitted_outputs_[output] = op;
213219
}
214220

215221
xla::XlaOp LoweringContext::GetOutputOp(const torch::lazy::Output& output) {
216222
auto it = emitted_outputs_.find(output);
217223

218224
if (it == emitted_outputs_.end()) {
219-
auto post_order =
225+
const auto post_order =
220226
torch::lazy::Util::ComputePostOrder(output.node, &emit_status_);
221-
for (auto node : post_order) {
222-
LowerNode(node);
227+
for (const auto* const node : post_order) {
228+
LowerNode(*node);
223229
}
224230
// At this point the output better be present, otherwise there is an issue
225231
// with the lowering code.
226232
it = emitted_outputs_.find(output);
227-
XLA_CHECK(it != emitted_outputs_.end())
233+
ABSL_CHECK(it != emitted_outputs_.end())
228234
<< "No XLA operation emitted for output: " << output;
229235
}
230236
return it->second;
231237
}
232238

233-
XlaOpVector LoweringContext::LowerNode(const torch::lazy::Node* node) {
239+
XlaOpVector LoweringContext::LowerNode(const torch::lazy::Node& node) {
234240
XlaOpVector result_ops;
235241
try {
236-
HloMetadataSetter meta_setter(this, node);
237-
238-
const XlaNode* casted = dynamic_cast<const XlaNode*>(node);
242+
const HloMetadataSetter meta_setter(*this, node);
243+
const XlaNode* const casted = dynamic_cast<const XlaNode*>(&node);
239244

240245
result_ops = casted->Lower(this);
241246
if (!casted->dynamic_dims().empty()) {
242-
xla::internal::XlaBuilderFriend builder_friend;
243-
auto* inst = builder_friend.GetInstruction(result_ops[0]);
244-
auto* mutable_dynamic =
247+
const xla::internal::XlaBuilderFriend builder_friend;
248+
auto* const inst = builder_friend.GetInstruction(result_ops[0]);
249+
auto* const mutable_dynamic =
245250
inst->mutable_shape()->mutable_is_dynamic_dimension();
246251
if (mutable_dynamic->empty()) {
247252
for (int i = 0; i < inst->dimensions_size(); i++) {
248253
mutable_dynamic->Add(false);
249254
}
250255
}
251-
auto* mutable_dims = inst->mutable_shape()->mutable_dimensions();
256+
auto* const mutable_dims = inst->mutable_shape()->mutable_dimensions();
252257
for (const auto dim : casted->dynamic_dims()) {
253258
mutable_dynamic->Set(dim, true);
254-
mutable_dims->Set(dim, kUnboundedSize);
259+
mutable_dims->Set(dim, xla::Shape::kUnboundedSize);
255260
}
256261
}
257262
} catch (const std::exception& ex) {
258263
ReportBuilderError(node, ex.what());
259264
}
260265
if (!builder()->first_error().ok()) {
261-
ReportBuilderError(node, /*error_msg=*/nullptr);
266+
ReportBuilderError(node, /*error_msg=*/"");
262267
}
263268
return result_ops;
264269
}
265270

266-
void LoweringContext::ReportBuilderError(const torch::lazy::Node* node,
267-
const char* error_msg) {
271+
void LoweringContext::ReportBuilderError(const torch::lazy::Node& node,
272+
const absl::string_view error_msg) {
268273
std::stringstream ss;
269-
ss << "Error while lowering: " << node->ToString() << "\n";
274+
ss << "Error while lowering: " << node.ToString() << "\n";
270275
if (!builder()->first_error().ok()) {
271276
ss << "XLA builder error: " << builder()->GetCurrentStatus() << "\n";
272277
}
273-
if (error_msg != nullptr) {
278+
if (!error_msg.empty()) {
274279
ss << "Error: " << error_msg << "\n";
275280
}
276-
const torch::lazy::MetaData& nmeta = node->metadata();
281+
const torch::lazy::MetaData& nmeta = node.metadata();
277282
if (!nmeta.scope.empty()) {
278283
ss << "Scope: " << nmeta.scope << "\n";
279284
}
@@ -282,17 +287,18 @@ void LoweringContext::ReportBuilderError(const torch::lazy::Node* node,
282287
}
283288

284289
void LoweringContext::SetUpAlias(const std::vector<int64_t>& output_index,
285-
int64_t param_number,
290+
const int64_t param_number,
286291
const std::vector<int64_t>& param_index,
287-
bool must_alias) {
288-
XLA_CHECK_EQ(output_index.size(), 1);
289-
XLA_CHECK_EQ(param_index.size(), 1);
292+
const bool must_alias) {
293+
ABSL_CHECK_EQ(output_index.size(), 1);
294+
ABSL_CHECK_EQ(param_index.size(), 1);
290295
builder_.SetUpAlias({output_index[0]}, param_number, {param_index[0]});
291296
}
292297

293298
bool LoweringContext::CheckResultShape(
294-
const torch::lazy::BackendDataPtr& parameter_data, size_t result_idx) {
295-
xla::XlaOp root = GetResult(result_idx);
299+
const torch::lazy::BackendDataPtr& parameter_data,
300+
const size_t result_idx) {
301+
const xla::XlaOp root = GetResult(result_idx);
296302
const xla::Shape& root_shape = ShapeHelper::ShapeOfXlaOp(root);
297303
return std::dynamic_pointer_cast<runtime::ComputationClient::Data>(
298304
parameter_data)
@@ -304,22 +310,21 @@ size_t LoweringContext::AddResult(const torch::lazy::Output& output) {
304310
return root_tuple_.size() - 1;
305311
}
306312

307-
size_t LoweringContext::AddResult(xla::XlaOp op) {
313+
size_t LoweringContext::AddResult(const xla::XlaOp op) {
308314
root_tuple_.push_back(op);
309315
return root_tuple_.size() - 1;
310316
}
311317

312318
void LoweringContext::AddParameter(const torch::lazy::Output& output,
313-
size_t index,
319+
const size_t index,
314320
const torch::lazy::Shape& shape,
315321
const std::string& name) {
316-
XLA_ERROR() << "not implemented";
322+
ABSL_LOG(FATAL) << "not implemented";
317323
return;
318324
}
319325

320326
torch::lazy::ComputationPtr LoweringContext::Build() {
321327
xla::XlaComputation xla_computation = ConsumeValue(BuildXla());
322-
323328
return std::make_shared<runtime::ComputationClient::Computation>(
324329
builder_.name(), std::move(xla_computation), device_);
325330
}

0 commit comments

Comments
 (0)