22
33#include < torch/csrc/lazy/core/ir_metadata.h>
44
5- #include < iostream>
65#include < optional>
76#include < sstream>
87#include < stdexcept>
9- #include < string_view >
8+ #include < utility >
109
10+ #include " absl/log/absl_check.h"
11+ #include " absl/log/absl_log.h"
1112#include " absl/status/status.h"
1213#include " absl/strings/str_cat.h"
13- #include " absl/strings/str_join.h"
1414#include " absl/strings/str_replace.h"
15- #include " absl/strings/string_view.h"
1615#include " torch_xla/csrc/ir.h"
1716#include " torch_xla/csrc/runtime/computation_client.h"
1817#include " torch_xla/csrc/runtime/debug_macros.h"
1918#include " torch_xla/csrc/runtime/sys_util.h"
2019#include " torch_xla/csrc/shape_helper.h"
2120#include " torch_xla/csrc/stack_frame_index_builder.h"
22- #include " torch_xla/csrc/unwrap_data.h"
2321
2422namespace torch_xla {
2523
2624namespace {
2725
2826class HloMetadataSetter {
2927 public:
30- HloMetadataSetter (LoweringContext* loctx, const torch::lazy::Node* node) {
28+ HloMetadataSetter (LoweringContext& lowering_context,
29+ const torch::lazy::Node& node)
30+ : lowering_context_(lowering_context) {
3131 if (ShouldPopulateXlaOpMetadata ()) {
32- PopulateXlaOpMetadata (loctx, node);
33- loctx_ = loctx;
32+ PopulateXlaOpMetadata (lowering_context, node);
3433 }
3534 }
3635
36+ // This class is neither copyable nor movable.
37+ HloMetadataSetter (const HloMetadataSetter&) = delete ;
38+ HloMetadataSetter& operator =(const HloMetadataSetter&) = delete ;
39+ HloMetadataSetter (HloMetadataSetter&&) = delete ;
40+ HloMetadataSetter& operator =(HloMetadataSetter&&) = delete ;
41+
3742 ~HloMetadataSetter () {
38- if (loctx_ != nullptr ) {
39- loctx_-> builder ()->ClearOpMetadata ();
43+ if (ShouldPopulateXlaOpMetadata () ) {
44+ lowering_context_. builder ()->ClearOpMetadata ();
4045 }
4146 }
4247
4348 private:
49+ // Returns true iff this class should populate XLA op metadata in its
50+ // constructor.
4451 static bool ShouldPopulateXlaOpMetadata () {
45- static bool op_metadata =
52+ static const bool op_metadata =
4653 runtime::sys_util::GetEnvBool (" XLA_HLO_DEBUG" , false );
4754 return FLAGS_torch_lazy_ir_debug || op_metadata;
4855 }
4956
50- static void PopulateXlaOpMetadata (LoweringContext* loctx ,
51- const torch::lazy::Node* node) {
57+ static void PopulateXlaOpMetadata (LoweringContext& lowering_context ,
58+ const torch::lazy::Node& node) {
5259 xla::OpMetadata metadata;
5360 // NOTE: we apply some string manipulation as xprof backend utility
5461 // for nesting/grouping traces depends on certain op name/type
5562 // patterns for classification.
5663 // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/profiler/utils/tf_op_utils.cc#L55
57- std::string op_type =
58- absl::StrReplaceAll (node-> op ().ToString (), {{" :" , " _" }});
64+ const std::string op_type =
65+ absl::StrReplaceAll (node. op ().ToString (), {{" :" , " _" }});
5966 metadata.set_op_type (op_type);
6067
61- const torch::lazy::MetaData& nmeta = node->metadata ();
62-
63- const CustomOpNameMetaData* custom_opname_meta =
64- dynamic_cast <const CustomOpNameMetaData*>(node->user_metadata ());
68+ const torch::lazy::MetaData& nmeta = node.metadata ();
69+ auto * const custom_opname_meta =
70+ dynamic_cast <const CustomOpNameMetaData*>(node.user_metadata ());
6571
6672 std::string op_name_prefix;
6773 size_t max_stack_depth = nmeta.frame_info .size ();
@@ -78,66 +84,65 @@ class HloMetadataSetter {
7884 metadata.set_op_name (absl::StrCat (op_name_prefix, op_type));
7985
8086 // Sets file, line and stack_frame_id in metadata
81- loctx-> stack_frame_index_builder ()->AddStackFrameLocations (
82- nmeta.frame_info , max_stack_depth, metadata);
87+ lowering_context. stack_frame_index_builder ()->AddStackFrameLocations (
88+ nmeta.frame_info , static_cast < int >( max_stack_depth) , metadata);
8389
84- loctx-> builder ()->SetOpMetadata (std::move (metadata));
90+ lowering_context. builder ()->SetOpMetadata (std::move (metadata));
8591 }
8692
87- LoweringContext* loctx_ = nullptr ;
93+ LoweringContext& lowering_context_ ;
8894};
8995
9096} // namespace
9197
9298LoweringContext::LoweringContext (const std::string& name,
9399 torch::lazy::BackendDevice device)
94- : torch::lazy::LoweringContext(name, device),
100+ : torch::lazy::LoweringContext(name, std::move( device) ),
95101 builder_ (name),
96102 stack_frame_index_builder_(std::make_shared<StackFrameIndexBuilder>()) {}
97103
98104LoweringContext::LoweringContext (
99105 const std::string& name, torch::lazy::BackendDevice device,
100- c10::ArrayRef<const torch::lazy::Node*> post_order,
106+ const c10::ArrayRef<const torch::lazy::Node*> post_order,
101107 torch::lazy::Util::EmissionMap emit_status)
102- : torch::lazy::LoweringContext(name, device, {}, emit_status),
108+ : torch::lazy::LoweringContext(name, std::move(device), {},
109+ std::move (emit_status)),
103110 builder_(name),
104111 stack_frame_index_builder_(std::make_shared<StackFrameIndexBuilder>()) {
105- for (auto node : post_order) {
106- LowerNode (node);
112+ for (const auto * node : post_order) {
113+ LowerNode (* node);
107114 }
108115}
109116
110- // TODO(lsy323): Get reserved number for unbounded dim after it's added in XLA.
111- static constexpr int64_t kUnboundedSize = std::numeric_limits<int64_t >::min();
112-
113117xla::XlaOp LoweringContext::GetParameter (
114118 const std::shared_ptr<torch::lazy::BackendData>& backend_data,
115119 const std::unordered_set<uint32_t >& unbounded_dynamic_dims) {
116- torch::lazy::BackendData::Handle handle = backend_data->GetHandle ();
120+ const torch::lazy::BackendData::Handle handle = backend_data->GetHandle ();
117121 auto it = parameters_map_.find (handle);
118122 if (it == parameters_map_.end ()) {
119- auto data = std::dynamic_pointer_cast<runtime::ComputationClient::Data>(
120- backend_data);
121- XLA_CHECK (data != nullptr );
123+ auto * const data =
124+ dynamic_cast <runtime::ComputationClient::Data*>( backend_data. get () );
125+ ABSL_CHECK (data != nullptr );
122126 xla::Shape shape = data->shape ();
123127 for (const int dim : unbounded_dynamic_dims) {
124128 shape.set_dynamic_dimension (dim, true );
125- shape.set_dimensions (dim, kUnboundedSize );
129+ shape.set_dimensions (dim, xla::Shape:: kUnboundedSize );
126130 }
127- size_t param_index = parameters_.size ();
128- std::string param_name = absl::StrCat (" p" , param_index);
131+ const size_t param_index = parameters_.size ();
132+ const std::string param_name = absl::StrCat (" p" , param_index);
129133 xla::XlaOp param;
130134 if (data->HasSharding ()) {
131- xla::OpSharding sharding = data->GetSharding ();
132- xla::XlaScopedShardingAssignment scoped_sharding (builder (), sharding);
135+ const xla::OpSharding sharding = data->GetSharding ();
136+ const xla::XlaScopedShardingAssignment scoped_sharding (builder (),
137+ sharding);
133138 param = xla::Parameter (builder (), param_index, shape, param_name);
134139 } else {
135140 param = xla::Parameter (builder (), param_index, shape, param_name);
136141 }
137142 it = parameters_map_.emplace (handle, Parameter{param, param_index}).first ;
138143 parameters_.push_back (backend_data);
139144 } else {
140- XLA_CHECK (unbounded_dynamic_dims.empty ())
145+ ABSL_CHECK (unbounded_dynamic_dims.empty ())
141146 << " The unbounded dynamic dims can only be set when Parameter is "
142147 " created." ;
143148 }
@@ -147,8 +152,8 @@ xla::XlaOp LoweringContext::GetParameter(
147152
148153std::optional<size_t > LoweringContext::GetParameterId (
149154 const std::shared_ptr<torch::lazy::BackendData>& backend_data) const {
150- torch::lazy::BackendData::Handle handle = backend_data->GetHandle ();
151- auto it = parameters_map_.find (handle);
155+ const torch::lazy::BackendData::Handle handle = backend_data->GetHandle ();
156+ const auto it = parameters_map_.find (handle);
152157 if (it == parameters_map_.end ()) {
153158 return std::nullopt ;
154159 }
@@ -164,12 +169,12 @@ const std::vector<size_t>& LoweringContext::GetParameterSequence() const {
164169 return parameter_sequence_;
165170}
166171
167- xla::XlaOp LoweringContext::GetResult (size_t index) const {
172+ xla::XlaOp LoweringContext::GetResult (const size_t index) const {
168173 return root_tuple_.at (index);
169174}
170175
171- void LoweringContext::SetResult (size_t index, xla::XlaOp op) {
172- root_tuple_.at (index) = std::move (op) ;
176+ void LoweringContext::SetResult (const size_t index, const xla::XlaOp op) {
177+ root_tuple_.at (index) = op ;
173178}
174179
175180absl::StatusOr<xla::XlaComputation> LoweringContext::BuildXla () {
@@ -181,7 +186,7 @@ absl::StatusOr<xla::XlaComputation> LoweringContext::BuildXla() {
181186 ((get_name_string () == " condctx" ) or (get_name_string () == " bodyctx" ))) {
182187 xla = builder ()->Build (root_tuple_.at (0 ));
183188 } else if (!root_tuple_.empty ()) {
184- xla::XlaOp root = xla::Tuple (builder (), root_tuple_);
189+ const xla::XlaOp root = xla::Tuple (builder (), root_tuple_);
185190 xla = builder ()->Build (root);
186191 } else {
187192 xla = builder ()->Build ();
@@ -195,8 +200,9 @@ absl::StatusOr<xla::XlaComputation> LoweringContext::BuildXla() {
195200 return xla;
196201}
197202
198- absl::StatusOr<xla::XlaComputation> LoweringContext::BuildXla (xla::XlaOp root) {
199- XLA_CHECK (root_tuple_.empty ());
203+ absl::StatusOr<xla::XlaComputation> LoweringContext::BuildXla (
204+ const xla::XlaOp root) {
205+ ABSL_CHECK (root_tuple_.empty ());
200206 auto xla = builder ()->Build (root);
201207
202208 if (xla.ok ()) {
@@ -208,72 +214,71 @@ absl::StatusOr<xla::XlaComputation> LoweringContext::BuildXla(xla::XlaOp root) {
208214}
209215
210216void LoweringContext::AssignOutputOp (const torch::lazy::Output& output,
211- xla::XlaOp op) {
212- emitted_outputs_[output] = std::move (op) ;
217+ const xla::XlaOp op) {
218+ emitted_outputs_[output] = op ;
213219}
214220
215221xla::XlaOp LoweringContext::GetOutputOp (const torch::lazy::Output& output) {
216222 auto it = emitted_outputs_.find (output);
217223
218224 if (it == emitted_outputs_.end ()) {
219- auto post_order =
225+ const auto post_order =
220226 torch::lazy::Util::ComputePostOrder (output.node , &emit_status_);
221- for (auto node : post_order) {
222- LowerNode (node);
227+ for (const auto * const node : post_order) {
228+ LowerNode (* node);
223229 }
224230 // At this point the output better be present, otherwise there is an issue
225231 // with the lowering code.
226232 it = emitted_outputs_.find (output);
227- XLA_CHECK (it != emitted_outputs_.end ())
233+ ABSL_CHECK (it != emitted_outputs_.end ())
228234 << " No XLA operation emitted for output: " << output;
229235 }
230236 return it->second ;
231237}
232238
233- XlaOpVector LoweringContext::LowerNode (const torch::lazy::Node* node) {
239+ XlaOpVector LoweringContext::LowerNode (const torch::lazy::Node& node) {
234240 XlaOpVector result_ops;
235241 try {
236- HloMetadataSetter meta_setter (this , node);
237-
238- const XlaNode* casted = dynamic_cast <const XlaNode*>(node);
242+ const HloMetadataSetter meta_setter (*this , node);
243+ const XlaNode* const casted = dynamic_cast <const XlaNode*>(&node);
239244
240245 result_ops = casted->Lower (this );
241246 if (!casted->dynamic_dims ().empty ()) {
242- xla::internal::XlaBuilderFriend builder_friend;
243- auto * inst = builder_friend.GetInstruction (result_ops[0 ]);
244- auto * mutable_dynamic =
247+ const xla::internal::XlaBuilderFriend builder_friend;
248+ auto * const inst = builder_friend.GetInstruction (result_ops[0 ]);
249+ auto * const mutable_dynamic =
245250 inst->mutable_shape ()->mutable_is_dynamic_dimension ();
246251 if (mutable_dynamic->empty ()) {
247252 for (int i = 0 ; i < inst->dimensions_size (); i++) {
248253 mutable_dynamic->Add (false );
249254 }
250255 }
251- auto * mutable_dims = inst->mutable_shape ()->mutable_dimensions ();
256+ auto * const mutable_dims = inst->mutable_shape ()->mutable_dimensions ();
252257 for (const auto dim : casted->dynamic_dims ()) {
253258 mutable_dynamic->Set (dim, true );
254- mutable_dims->Set (dim, kUnboundedSize );
259+ mutable_dims->Set (dim, xla::Shape:: kUnboundedSize );
255260 }
256261 }
257262 } catch (const std::exception& ex) {
258263 ReportBuilderError (node, ex.what ());
259264 }
260265 if (!builder ()->first_error ().ok ()) {
261- ReportBuilderError (node, /* error_msg=*/ nullptr );
266+ ReportBuilderError (node, /* error_msg=*/ " " );
262267 }
263268 return result_ops;
264269}
265270
266- void LoweringContext::ReportBuilderError (const torch::lazy::Node* node,
267- const char * error_msg) {
271+ void LoweringContext::ReportBuilderError (const torch::lazy::Node& node,
272+ const absl::string_view error_msg) {
268273 std::stringstream ss;
269- ss << " Error while lowering: " << node-> ToString () << " \n " ;
274+ ss << " Error while lowering: " << node. ToString () << " \n " ;
270275 if (!builder ()->first_error ().ok ()) {
271276 ss << " XLA builder error: " << builder ()->GetCurrentStatus () << " \n " ;
272277 }
273- if (error_msg != nullptr ) {
278+ if (!error_msg. empty () ) {
274279 ss << " Error: " << error_msg << " \n " ;
275280 }
276- const torch::lazy::MetaData& nmeta = node-> metadata ();
281+ const torch::lazy::MetaData& nmeta = node. metadata ();
277282 if (!nmeta.scope .empty ()) {
278283 ss << " Scope: " << nmeta.scope << " \n " ;
279284 }
@@ -282,17 +287,18 @@ void LoweringContext::ReportBuilderError(const torch::lazy::Node* node,
282287}
283288
284289void LoweringContext::SetUpAlias (const std::vector<int64_t >& output_index,
285- int64_t param_number,
290+ const int64_t param_number,
286291 const std::vector<int64_t >& param_index,
287- bool must_alias) {
288- XLA_CHECK_EQ (output_index.size (), 1 );
289- XLA_CHECK_EQ (param_index.size (), 1 );
292+ const bool must_alias) {
293+ ABSL_CHECK_EQ (output_index.size (), 1 );
294+ ABSL_CHECK_EQ (param_index.size (), 1 );
290295 builder_.SetUpAlias ({output_index[0 ]}, param_number, {param_index[0 ]});
291296}
292297
293298bool LoweringContext::CheckResultShape (
294- const torch::lazy::BackendDataPtr& parameter_data, size_t result_idx) {
295- xla::XlaOp root = GetResult (result_idx);
299+ const torch::lazy::BackendDataPtr& parameter_data,
300+ const size_t result_idx) {
301+ const xla::XlaOp root = GetResult (result_idx);
296302 const xla::Shape& root_shape = ShapeHelper::ShapeOfXlaOp (root);
297303 return std::dynamic_pointer_cast<runtime::ComputationClient::Data>(
298304 parameter_data)
@@ -304,22 +310,21 @@ size_t LoweringContext::AddResult(const torch::lazy::Output& output) {
304310 return root_tuple_.size () - 1 ;
305311}
306312
307- size_t LoweringContext::AddResult (xla::XlaOp op) {
313+ size_t LoweringContext::AddResult (const xla::XlaOp op) {
308314 root_tuple_.push_back (op);
309315 return root_tuple_.size () - 1 ;
310316}
311317
312318void LoweringContext::AddParameter (const torch::lazy::Output& output,
313- size_t index,
319+ const size_t index,
314320 const torch::lazy::Shape& shape,
315321 const std::string& name) {
316- XLA_ERROR ( ) << " not implemented" ;
322+ ABSL_LOG (FATAL ) << " not implemented" ;
317323 return ;
318324}
319325
320326torch::lazy::ComputationPtr LoweringContext::Build () {
321327 xla::XlaComputation xla_computation = ConsumeValue (BuildXla ());
322-
323328 return std::make_shared<runtime::ComputationClient::Computation>(
324329 builder_.name (), std::move (xla_computation), device_);
325330}
0 commit comments