Commit 8d20a86
ops: Use new macros for throwing exceptions. (#9592)
Follow-up to #9588 and #9580.

Target:
- `torch_xla/csrc/ops` directory
- Files related to the tracing of tensor operations

In summary, this PR:
- Replaces all calls to `OkOrThrow()` and `GetValueOrThrow()` (which throw exceptions without the source location of the *throw-site*) with the macros `XLA_THROW_IF_ERROR()` and `XLA_ASSIGN_OR_THROW()`.
- Corresponds to one of the fine-grained PRs that came from breaking down PR #9580.
- Focuses on the `torch_xla/csrc/ops` directory and other files related to the tracing of tensor operations, replacing every use of the now-deprecated functions with the newly introduced macros.
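For context, here is a minimal sketch of the semantics these macros are expected to have, assuming they wrap `absl::Status`/`absl::StatusOr<T>` and record the throw-site via `__FILE__`/`__LINE__`. This is not the actual implementation; the real definitions live in `torch_xla/csrc/status.h` and differ in detail:

```cpp
// Hypothetical sketch only, NOT the torch_xla implementation.
#include <sstream>
#include <stdexcept>
#include <utility>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

// Two-level concatenation so __LINE__ expands before token pasting.
#define XLA_SKETCH_CONCAT_(a, b) a##b
#define XLA_SKETCH_CONCAT(a, b) XLA_SKETCH_CONCAT_(a, b)

// Throws if `rexpr` evaluates to a non-OK absl::Status, attaching the
// file and line of the macro invocation (the throw-site) to the message.
#define XLA_THROW_IF_ERROR(rexpr)                              \
  do {                                                         \
    const absl::Status xla_sketch_status = (rexpr);            \
    if (!xla_sketch_status.ok()) {                             \
      std::ostringstream oss;                                  \
      oss << __FILE__ << ":" << __LINE__ << ": "               \
          << xla_sketch_status.ToString();                     \
      throw std::runtime_error(oss.str());                     \
    }                                                          \
  } while (false)

// Evaluates `rexpr` (an absl::StatusOr<T>), throws on error as above,
// and otherwise assigns the unwrapped value to `lhs`. Note that `lhs`
// may be a declaration ("xla::XlaOp op") or an existing variable.
#define XLA_ASSIGN_OR_THROW_IMPL_(statusor, lhs, rexpr)        \
  auto statusor = (rexpr);                                     \
  if (!statusor.ok()) {                                        \
    std::ostringstream oss;                                    \
    oss << __FILE__ << ":" << __LINE__ << ": "                 \
        << statusor.status().ToString();                       \
    throw std::runtime_error(oss.str());                       \
  }                                                            \
  lhs = std::move(statusor).value();

#define XLA_ASSIGN_OR_THROW(lhs, rexpr)                        \
  XLA_ASSIGN_OR_THROW_IMPL_(                                   \
      XLA_SKETCH_CONCAT(xla_sketch_statusor_, __LINE__), lhs, rexpr)
```

Under these assumptions, usage matches the call sites in the diff below: `XLA_ASSIGN_OR_THROW(XLATensorPtr self_tensor, bridge::GetXlaTensor(self));` both declares `self_tensor` and unwraps the `StatusOr`, throwing with the caller's file and line on error.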
1 parent d9a9e44 commit 8d20a86

11 files changed: +61 -38 lines

torch_xla/csrc/convolution.cpp (10 additions, 7 deletions)

```diff
@@ -218,9 +218,11 @@ xla::XlaOp BuildConvBackwardInput(xla::XlaOp grad_output, xla::XlaOp kernel,
       MakeConvOpAttrs(spatial_stride, spatial_padding, spatial_dilation, false);
   xla::XlaOp kernel_transposed = xla::Transpose(
       kernel, FilterTransposePermutation(input_shape.dimensions_size()));
-  return GetValueOrThrow(MakeXlaBackpropInputConvOp(
-      "conv_backward_input", input_shape, kernel_transposed, grad_output,
-      conv_op_attrs));
+  XLA_ASSIGN_OR_THROW(xla::XlaOp conv_backward_input,
+                      MakeXlaBackpropInputConvOp("conv_backward_input",
+                                                 input_shape, kernel_transposed,
+                                                 grad_output, conv_op_attrs));
+  return conv_backward_input;
 }

 // Computes the kernel gradient for a convolution.
@@ -238,14 +240,15 @@ xla::XlaOp BuildConvBackwardWeight(xla::XlaOp grad_output, xla::XlaOp input,
       xla::InversePermutation(transpose_permutation);
   xla::Shape transposed_weight_shape =
       xla::ShapeUtil::PermuteDimensions(transpose_permutation, kernel_shape);
-  xla::XlaOp conv = GetValueOrThrow(MakeXlaBackpropFilterConvOp(
-      "conv_backward_weight", input, transposed_weight_shape, grad_output,
-      conv_op_attrs));
+  XLA_ASSIGN_OR_THROW(xla::XlaOp conv_backward_weight,
+                      MakeXlaBackpropFilterConvOp("conv_backward_weight", input,
+                                                  transposed_weight_shape,
+                                                  grad_output, conv_op_attrs));

   // Reorder the dimensions of the filter gradient to match the NCHW convention
   // of PyTorch. The original result of the convolution has the spatial and
   // feature dimensions swapped and the spatial dimensions reversed.
-  return xla::Transpose(conv, inv_transpose_permutation);
+  return xla::Transpose(conv_backward_weight, inv_transpose_permutation);
 }

 xla::XlaOp BuildGradBias(xla::XlaOp grad_output) {
```
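One pattern worth noting here and in the files below: assuming the usual `ASSIGN_OR_RETURN`-style design sketched above, `XLA_ASSIGN_OR_THROW` is a multi-statement macro that declares its first argument rather than an expression that yields a value, so expression-position uses of `GetValueOrThrow(...)` (inside a `return`, a `std::make_tuple`, or a dereference) have to be split into an assignment followed by a separate use of the new variable.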

torch_xla/csrc/cross_replica_reduces.cpp (5 additions, 5 deletions)

```diff
@@ -116,7 +116,7 @@ std::shared_ptr<torch::lazy::Value> CreateToken(
 at::Tensor all_reduce(const at::Tensor& self, std::string reduceOp,
                       std::string /*group_name*/) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
-  auto self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self));
+  XLA_ASSIGN_OR_THROW(XLATensorPtr self_tensor, bridge::GetXlaTensor(self));
   // TODO(alanwaketan): Use group_name to generate groups. Currently we just
   // use {} as a workaround. Scale is always 1.0 here, and we always pin
   // layout.
@@ -270,7 +270,7 @@ AllGatherResult BuildAllGather(xla::XlaOp input, xla::XlaOp token, int64_t dim,
 at::Tensor all_gather_into_tensor(const at::Tensor& self, int64_t group_size,
                                   std::string group_name) {
   TORCH_LAZY_FN_COUNTER("xla::");
-  auto self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self));
+  XLA_ASSIGN_OR_THROW(XLATensorPtr self_tensor, bridge::GetXlaTensor(self));
   std::vector<int64_t> all_groups(group_size);
   std::iota(all_groups.begin(), all_groups.end(), 0);
   auto result = tensor_methods::all_gather(self_tensor, 0, group_size,
@@ -349,9 +349,9 @@ at::Tensor all_to_all_single(const at::Tensor& input,
   }
   XLATensorPtr result_ptr;
   torch::lazy::Value new_token;
+  XLA_ASSIGN_OR_THROW(XLATensorPtr input_tensor, bridge::GetXlaTensor(input));
   std::tie(result_ptr, new_token) = tensor_methods::all_to_all(
-      GetValueOrThrow(bridge::GetXlaTensor(input)), token, 0, 0, split_count,
-      {all_groups}, pin_layout);
+      input_tensor, token, 0, 0, split_count, {all_groups}, pin_layout);
   at::Tensor result = bridge::AtenFromXlaTensor(std::move(result_ptr));

   at::Tensor result_with_grad = torch::autograd::make_variable(
@@ -481,7 +481,7 @@ xla::XlaOp BuildReduceScatter(AllReduceType reduce_type, xla::XlaOp input,
 at::Tensor reduce_scatter_tensor(const at::Tensor& input, std::string reduce_op,
                                  int64_t group_size, std::string group_name) {
   TORCH_LAZY_FN_COUNTER("xla::");
-  auto self = GetValueOrThrow(bridge::GetXlaTensor(input));
+  XLA_ASSIGN_OR_THROW(XLATensorPtr self, bridge::GetXlaTensor(input));
   std::vector<int64_t> all_groups(group_size);
   std::iota(all_groups.begin(), all_groups.end(), 0);
   int64_t shard_count = group_size;
```

torch_xla/csrc/data_ops.cpp (2 additions, 2 deletions)

```diff
@@ -197,8 +197,8 @@ xla::XlaOp BuildMaskedFillScalar(xla::XlaOp input, xla::XlaOp mask,
   const xla::Shape& mask_shape = ShapeHelper::ShapeOfXlaOp(mask);

   if (!xla::ShapeUtil::Compatible(input_shape, mask_shape)) {
-    xla::Shape shape =
-        GetValueOrThrow(XlaHelpers::GetPromotedShape(input_shape, mask_shape));
+    XLA_ASSIGN_OR_THROW(xla::Shape shape,
+                        XlaHelpers::GetPromotedShape(input_shape, mask_shape));
     input = BuildExpand(input, shape.dimensions());
     mask = BuildExpand(mask, shape.dimensions());
   }
```

torch_xla/csrc/ops/index_ops.cpp (4 additions, 1 deletion)

```diff
@@ -18,6 +18,7 @@
 #include "torch_xla/csrc/ops/scalar.h"
 #include "torch_xla/csrc/runtime/debug_macros.h"
 #include "torch_xla/csrc/runtime/util.h"
+#include "torch_xla/csrc/status.h"
 #include "torch_xla/csrc/tensor_methods.h"
 #include "torch_xla/csrc/tensor_util.h"
 #include "torch_xla/csrc/xla_graph_executor.h"
@@ -315,8 +316,10 @@ XLATensorPtr GetZeroElementTensor(const XLATensorPtr& base,
       base_dimensions.begin() + start_dim + indices.size(),
       base_dimensions.end());

-  return GetValueOrThrow(
+  XLA_ASSIGN_OR_THROW(
+      XLATensorPtr output,
       tensor_methods::full(dimensions, 0, base->GetDevice(), base->dtype()));
+  return output;
 }

 XLATensorPtr IndexByTensors(const XLATensorPtr& base,
```

torch_xla/csrc/ops/triangular_solve.cpp (2 additions, 1 deletion)

```diff
@@ -33,7 +33,8 @@ std::pair<xla::Shape, xla::Shape> InferTriangularSolveShape(
     return std::pair<xla::Shape, xla::Shape>(rhs_batch_shape, lhs_batch_shape);
   }
   // Obtain the promoted shapes and add back the trailing dimension.
-  xla::Shape rhs_batch_promoted_shape = GetValueOrThrow(
+  XLA_ASSIGN_OR_THROW(
+      xla::Shape rhs_batch_promoted_shape,
       XlaHelpers::GetPromotedShape(rhs_batch_shape, lhs_batch_shape));
   xla::Shape lhs_batch_promoted_shape(rhs_batch_promoted_shape);
   rhs_batch_promoted_shape.add_dimensions(nrhs);
```

torch_xla/csrc/pooling.cpp (5 additions, 2 deletions)

```diff
@@ -49,7 +49,9 @@ xla::XlaComputation CreateGeComputation(xla::PrimitiveType type) {
   xla::XlaOp y = xla::Parameter(&reduction_builder, 1,
                                 xla::ShapeUtil::MakeShape(type, {}), "y");
   xla::Ge(x, y);
-  return GetValueOrThrow(reduction_builder.Build());
+  XLA_ASSIGN_OR_THROW(xla::XlaComputation ge_computation,
+                      reduction_builder.Build());
+  return ge_computation;
 }

 xla::TensorFormat MakeNCHWFormat(int64_t spatial_dim_count) {
@@ -367,7 +369,8 @@ xla::XlaOp ComputeMaxPoolIndices(
     return results;
   };

-  std::vector<xla::XlaOp> results = GetValueOrThrow(
+  XLA_ASSIGN_OR_THROW(
+      std::vector<xla::XlaOp> results,
       xla::WhileLoopHelper(cond_fn, body_fn, initial_values.values,
                            "ComputeMaxPoolIndices", padded_input.builder()));
```

torch_xla/csrc/reduction.cpp (4 additions, 2 deletions)

```diff
@@ -60,7 +60,8 @@ xla::XlaComputation CreateAllComputation(xla::PrimitiveType type) {
   xla::XlaOp zero = xla::Zero(&builder, type);
   xla::XlaOp one = xla::One(&builder, type);
   xla::Select(xla::And(xla::Ne(x, zero), xla::Ne(y, zero)), one, zero);
-  return GetValueOrThrow(builder.Build());
+  XLA_ASSIGN_OR_THROW(xla::XlaComputation all_computation, builder.Build());
+  return all_computation;
 }

 xla::XlaComputation CreateAnyComputation(xla::PrimitiveType type) {
@@ -72,7 +73,8 @@ xla::XlaComputation CreateAnyComputation(xla::PrimitiveType type) {
   xla::XlaOp zero = xla::Zero(&builder, type);
   xla::XlaOp one = xla::One(&builder, type);
   xla::Select(xla::Or(xla::Ne(x, zero), xla::Ne(y, zero)), one, zero);
-  return GetValueOrThrow(builder.Build());
+  XLA_ASSIGN_OR_THROW(xla::XlaComputation any_computation, builder.Build());
+  return any_computation;
 }

 xla::XlaOp GetScaleValue(xla::XlaOp input, xla::XlaOp count,
```

torch_xla/csrc/shape_helper.cpp (2 additions, 1 deletion)

```diff
@@ -6,7 +6,8 @@
 namespace torch_xla {

 const xla::Shape& ShapeHelper::ShapeOfXlaOp(xla::XlaOp op) {
-  return *GetValueOrThrow(GetShape(op));
+  XLA_ASSIGN_OR_THROW(const xla::Shape* shape, GetShape(op));
+  return *shape;
 }

 absl::StatusOr<const xla::Shape * absl_nonnull> GetShape(xla::XlaOp op) {
```

torch_xla/csrc/tensor_methods.cpp (9 additions, 7 deletions)

```diff
@@ -1479,9 +1479,11 @@ std::tuple<XLATensorPtr, XLATensorPtr> cummax(const XLATensorPtr& input,
     at::Tensor val =
         at::empty(shape_, at::TensorOptions().dtype(input->dtype()));
     at::Tensor idx = at::empty(shape_, at::TensorOptions().dtype(at::kLong));
-    return std::make_tuple(
-        GetValueOrThrow(XLATensor::Create(val, input->GetDevice())),
-        GetValueOrThrow(XLATensor::Create(idx, input->GetDevice())));
+    XLA_ASSIGN_OR_THROW(XLATensorPtr xla_val,
+                        XLATensor::Create(val, input->GetDevice()));
+    XLA_ASSIGN_OR_THROW(XLATensorPtr xla_idx,
+                        XLATensor::Create(idx, input->GetDevice()));
+    return std::make_tuple(xla_val, xla_idx);
   }
   torch::lazy::NodePtr node =
       torch_xla::MakeNode<CumMax>(input->GetIrValue(), canonical_dim);
@@ -2533,10 +2535,10 @@ std::tuple<XLATensorPtr, XLATensorPtr, XLATensorPtr> native_batch_norm(
     }
   } else {
     at::Tensor at_input = bridge::AtenFromXlaTensor(input);
-    mean = GetValueOrThrow(
-        bridge::GetXlaTensor(at::empty({0}, at_input.options())));
-    variance_inverse = GetValueOrThrow(
-        bridge::GetXlaTensor(at::empty({0}, at_input.options())));
+    XLA_ASSIGN_OR_THROW(
+        mean, bridge::GetXlaTensor(at::empty({0}, at_input.options())));
+    XLA_ASSIGN_OR_THROW(variance_inverse, bridge::GetXlaTensor(at::empty(
+                                              {0}, at_input.options())));
   }

   XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
```

torch_xla/csrc/tensor_ops.cpp (17 additions, 9 deletions)

```diff
@@ -8,6 +8,7 @@
 #include "torch_xla/csrc/runtime/computation_client.h"
 #include "torch_xla/csrc/runtime/debug_macros.h"
 #include "torch_xla/csrc/runtime/util.h"
+#include "torch_xla/csrc/status.h"
 #include "torch_xla/csrc/tensor_methods.h"

 namespace torch_xla {
@@ -148,7 +149,8 @@ XLATensorPtr SmoothL1LossBackward(const XLATensorPtr& grad_output,
   XLATensorPtr grad_scale = tensor_methods::get_dimensions_size(
       broadcasted_input,
       XlaHelpers::GetAllDimensions(broadcasted_input->shape()));
-  XLATensorPtr div_result = GetValueOrThrow(
+  XLA_ASSIGN_OR_THROW(
+      XLATensorPtr div_result,
       tensor_methods::div(elementwise_loss_backward, grad_scale));
   return tensor_methods::mul(div_result, grad_output);
 }
@@ -174,7 +176,8 @@ XLATensorPtr SoftplusBackward(const XLATensorPtr& grad_output,
   XLATensorPtr z = tensor_methods::exp(scaled_input);
   XLATensorPtr one_vec =
       tensor_methods::full_like(z, 1, z->GetDevice(), z->dtype());
-  XLATensorPtr div = GetValueOrThrow(
+  XLA_ASSIGN_OR_THROW(
+      XLATensorPtr div,
       tensor_methods::div(z, tensor_methods::add(z, one_vec, 1)));

   return tensor_methods::where(tensor_methods::gt(scaled_input, threshold),
@@ -207,24 +210,29 @@ XLATensorPtr EmbeddingDenseBackward(const XLATensorPtr& grad_output,
   int64_t numel = xla::ShapeUtil::ElementsIn(indices_shape_ref.get());
   XLATensorPtr grad =
       tensor_methods::view(grad_output, {numel, grad_output->size(-1)});
-  XLATensorPtr grad_weight = GetValueOrThrow(
+  XLA_ASSIGN_OR_THROW(
+      XLATensorPtr grad_weight,
       tensor_methods::full({num_weights, grad_output->size(-1)}, 0,
                            grad_output->GetDevice(), grad_output->dtype()));
   XLATensorPtr indices_rank1 = tensor_methods::view(indices, {numel});
   if (scale_grad_by_freq) {
     // Compute the histogram of index values.
-    XLATensorPtr counts = GetValueOrThrow(tensor_methods::full(
-        {num_weights}, 0, indices->GetDevice(), indices->dtype()));
-    XLATensorPtr ones = GetValueOrThrow(tensor_methods::full(
-        {numel}, 1, indices->GetDevice(), indices->dtype()));
+    XLA_ASSIGN_OR_THROW(
+        XLATensorPtr counts,
+        tensor_methods::full({num_weights}, 0, indices->GetDevice(),
+                             indices->dtype()));
+    XLA_ASSIGN_OR_THROW(XLATensorPtr ones,
+                        tensor_methods::full({numel}, 1, indices->GetDevice(),
+                                             indices->dtype()));
     tensor_methods::index_put_(counts, counts, {indices_rank1}, /*start_dim=*/0,
                                /*values=*/ones,
                                /*accumulate=*/true, /*result_permutation=*/{0});
     XLATensorPtr grad_weights_scale =
         tensor_methods::index(counts, {indices_rank1}, 0);
     // Scale the value of the gradient by the histogram.
-    grad = GetValueOrThrow(tensor_methods::div(
-        grad, tensor_methods::unsqueeze(grad_weights_scale, 1)));
+    XLA_ASSIGN_OR_THROW(
+        grad, tensor_methods::div(
+                  grad, tensor_methods::unsqueeze(grad_weights_scale, 1)));
   }
   // Don't accumulate gradients for indices which are equal with the given
   // padding_idx.
```
