
Commit 147d2c2

full: improve error handling and error messages. (#9564)
This PR refactors the `tensor_methods::full` and `tensor_methods::full_symint` implementations by improving their error messages and returning a status-wrapped value.

**Key Changes:**

- Make `tensor_methods::full` and `tensor_methods::full_symint` return `StatusOr<absl_nonnull XLATensorPtr>`
- Improve the error message on invalid arguments
1 parent b098be8 commit 147d2c2
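For context, the status-based pattern described above follows Abseil's `absl::StatusOr` idiom. The sketch below is a minimal, self-contained illustration of that idiom using plain Abseil only; it is not the torch_xla code, which instead uses `XLA_RETURN_IF_ERROR`, `XLA_ERROR_WITH_LOCATION`, and `XLATensorPtr` as shown in the diffs further down.

```cpp
#include <cstdint>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"

// Validation helper: returns OkStatus or a descriptive InvalidArgumentError,
// echoing the error message wording introduced by this commit.
absl::Status CheckSizesArePositive(const std::vector<int64_t>& sizes) {
  for (int64_t size : sizes) {
    if (size < 0) {
      return absl::InvalidArgumentError(absl::StrCat(
          "full(): expected concrete sizes (i.e. non-symbolic) to be "
          "positive values. However found negative ones: [",
          absl::StrJoin(sizes, ", "), "]."));
    }
  }
  return absl::OkStatus();
}

// A fallible factory returns StatusOr instead of crashing on bad input.
absl::StatusOr<int64_t> NumElements(const std::vector<int64_t>& sizes) {
  absl::Status status = CheckSizesArePositive(sizes);
  if (!status.ok()) {
    return status;  // plays the role of XLA_RETURN_IF_ERROR in the real code
  }
  int64_t numel = 1;
  for (int64_t size : sizes) {
    numel *= size;
  }
  return numel;
}
```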

File tree

6 files changed, +83 -51 lines changed:

- test/test_operations.py (+10, -0)
- torch_xla/csrc/aten_xla_type.cpp (+18, -20)
- torch_xla/csrc/ops/index_ops.cpp (+2, -1)
- torch_xla/csrc/tensor_methods.cpp (+41, -17)
- torch_xla/csrc/tensor_methods.h (+6, -7)
- torch_xla/csrc/tensor_ops.cpp (+6, -6)

test/test_operations.py

Lines changed: 10 additions & 0 deletions
```diff
@@ -2511,6 +2511,16 @@ def test_flip_raises_error_on_duplicated_dims(self):
                         f"from {dims} to {dims_suggestion}.")
       self.assertEqual(str(e), expected_error)
 
+  def test_full_raises_error_on_negative_size(self):
+    shape = [2, -2, 2]
+    try:
+      torch.full(shape, 1.5, device="xla")
+    except RuntimeError as e:
+      expected_error = (
+          "full(): expected concrete sizes (i.e. non-symbolic) to be "
+          f"positive values. However found negative ones: {shape}.")
+      self.assertEqual(str(e), expected_error)
+
 
 class MNISTComparator(nn.Module):
```

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 18 additions & 20 deletions
```diff
@@ -1702,16 +1702,14 @@ at::Tensor XLANativeFunctions::empty_symint(
   // does not actually end up doing any memory initialization, we use that and
   // avoid going to CPU for it. A common PT pattern is indeed doing empty() plus
   // s_copy_().
-  XLATensorPtr xla_tensor;
-  if (all_dims_static) {
-    xla_tensor = tensor_methods::full(XlaHelpers::I64List(int_sizes.value()), 0,
-                                      GetXlaDeviceOrCurrent(device),
-                                      at::dtype_or_default(dtype));
-  } else {
-    xla_tensor =
-        tensor_methods::full_symint(sym_size, 0, GetXlaDeviceOrCurrent(device),
-                                    at::dtype_or_default(dtype));
-  }
+  XLATensorPtr xla_tensor = GetValueOrThrow(
+      all_dims_static
+          ? tensor_methods::full(XlaHelpers::I64List(int_sizes.value()), 0,
+                                 GetXlaDeviceOrCurrent(device),
+                                 at::dtype_or_default(dtype))
+          : tensor_methods::full_symint(sym_size, 0,
+                                        GetXlaDeviceOrCurrent(device),
+                                        at::dtype_or_default(dtype)));
   // `tensor.to` will trigger an `empty` + `_to_copy`. In the egaer mode, the
   // `full` will be evulated eagerly and got a replicated sharding. We should
   // leave the sharding to be empty.
@@ -1858,9 +1856,9 @@ at::Tensor XLANativeFunctions::full(at::IntArrayRef size,
   } else {
     intend_dtype = fill_value.type();
   }
-  return bridge::AtenFromXlaTensor(
+  return bridge::AtenFromXlaTensor(GetValueOrThrow(
       tensor_methods::full(absl::Span<const int64_t>(size), fill_value,
-                           GetXlaDeviceOrCurrent(device), intend_dtype));
+                           GetXlaDeviceOrCurrent(device), intend_dtype)));
 }
 
 at::Tensor XLANativeFunctions::gather(const at::Tensor& self, int64_t dim,
@@ -2681,8 +2679,8 @@ std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::nll_loss2d_forward(
     int64_t ignore_index) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self));
-  XLATensorPtr total_weight = tensor_methods::full(
-      {}, 1, self_tensor->GetDevice(), self_tensor->dtype());
+  XLATensorPtr total_weight = GetValueOrThrow(tensor_methods::full(
+      {}, 1, self_tensor->GetDevice(), self_tensor->dtype()));
   return std::make_tuple(
       bridge::AtenFromXlaTensor(tensor_methods::nll_loss2d(
           self_tensor, GetValueOrThrow(bridge::GetXlaTensor(target)),
@@ -2716,8 +2714,8 @@ std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::nll_loss_forward(
     int64_t ignore_index) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self));
-  XLATensorPtr total_weight = tensor_methods::full(
-      {}, 1, self_tensor->GetDevice(), self_tensor->dtype());
+  XLATensorPtr total_weight = GetValueOrThrow(tensor_methods::full(
+      {}, 1, self_tensor->GetDevice(), self_tensor->dtype()));
   return std::make_tuple(
       bridge::AtenFromXlaTensor(tensor_methods::nll_loss(
           self_tensor, GetValueOrThrow(bridge::GetXlaTensor(target)),
@@ -4038,10 +4036,10 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> XLANativeFunctions::_linalg_svd(
   if (!compute_uv) {
     // When compute_uv is false, torch::_linalg_svd returns an empty tensor for
     // u and vh.
-    u = tensor_methods::full({0}, 0, self_tensor->GetDevice(),
-                             self_tensor->dtype());
-    vh = tensor_methods::full({0}, 0, self_tensor->GetDevice(),
-                              self_tensor->dtype());
+    u = GetValueOrThrow(tensor_methods::full({0}, 0, self_tensor->GetDevice(),
+                                             self_tensor->dtype()));
+    vh = GetValueOrThrow(tensor_methods::full({0}, 0, self_tensor->GetDevice(),
+                                              self_tensor->dtype()));
   }
   return std::make_tuple(bridge::AtenFromXlaTensor(u),
                          bridge::AtenFromXlaTensor(s),
```
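The call sites above wrap the new `StatusOr` results in `GetValueOrThrow`. As a rough mental model only (a hypothetical stand-in, not the actual torch_xla helper), such an unwrapper can be sketched as below: it returns the contained value on success and converts an error status into a C++ exception, which PyTorch then surfaces to Python as a `RuntimeError`.

```cpp
#include <stdexcept>
#include <string>
#include <utility>

#include "absl/status/statusor.h"

// Hypothetical stand-in for a GetValueOrThrow-style helper: unwrap on
// success, throw on error so the failure propagates across the ATen boundary.
template <typename T>
T GetValueOrThrowSketch(absl::StatusOr<T> status_or) {
  if (!status_or.ok()) {
    throw std::runtime_error(std::string(status_or.status().message()));
  }
  return *std::move(status_or);
}
```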

torch_xla/csrc/ops/index_ops.cpp

Lines changed: 2 additions & 1 deletion
```diff
@@ -315,7 +315,8 @@ XLATensorPtr GetZeroElementTensor(const XLATensorPtr& base,
                                 base_dimensions.begin() + start_dim + indices.size(),
                                 base_dimensions.end());
 
-  return tensor_methods::full(dimensions, 0, base->GetDevice(), base->dtype());
+  return GetValueOrThrow(
+      tensor_methods::full(dimensions, 0, base->GetDevice(), base->dtype()));
 }
 
 XLATensorPtr IndexByTensors(const XLATensorPtr& base,
```

torch_xla/csrc/tensor_methods.cpp

Lines changed: 41 additions & 17 deletions
```diff
@@ -409,6 +409,39 @@ absl::Status CheckFlipDimensionsAreUnique(
   return absl::OkStatus();
 }
 
+template <class F>
+absl::Status CheckFullSizesArePositiveImpl(absl::Span<const int64_t> sizes,
+                                           const F& original_sizes_as_str) {
+  const bool has_concrete_negative_size = std::any_of(
+      sizes.begin(), sizes.end(), [](const int64_t size) { return size < 0; });
+  if (has_concrete_negative_size) {
+    return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(
+        absl::StrCat("full(): expected concrete sizes (i.e. non-symbolic) to "
+                     "be positive values. However found negative ones: [",
+                     original_sizes_as_str(), "].")));
+  }
+  return absl::OkStatus();
+}
+
+absl::Status CheckFullSizesArePositive(absl::Span<const int64_t> sizes) {
+  return CheckFullSizesArePositiveImpl(
+      sizes, [&]() { return absl::StrJoin(sizes, /* sep= */ ", "); });
+}
+
+absl::Status CheckFullConcreteSizesArePositive(at::SymIntArrayRef sym_sizes) {
+  std::vector<int64_t> concrete_sizes_or_zero;
+  std::transform(sym_sizes.begin(), sym_sizes.end(),
+                 std::back_inserter(concrete_sizes_or_zero),
+                 [](at::SymInt sym) { return sym.maybe_as_int().value_or(0); });
+  return CheckFullSizesArePositiveImpl(concrete_sizes_or_zero, [&]() {
+    return absl::StrJoin(sym_sizes.begin(), sym_sizes.end(), /* sep= */ ", ",
+                         [](std::string* out, at::SymInt sym) {
+                           absl::StrAppendFormat(out, "%s",
+                                                 absl::FormatStreamed(sym));
+                         });
+  });
+}
+
 }  // namespace
 
 //////////////////////////////////////////////////////////////////////////////
@@ -1767,10 +1800,10 @@ XLATensorPtr fmod(const XLATensorPtr& input, const at::Scalar& other,
                     logical_element_type);
 }
 
-XLATensorPtr full(absl::Span<const int64_t> size, const at::Scalar& fill_value,
-                  const torch::lazy::BackendDevice& device,
-                  at::ScalarType scalar_type) {
-  CheckShapeDimensions(size);
+absl::StatusOr<absl_nonnull XLATensorPtr> full(
+    absl::Span<const int64_t> size, const at::Scalar& fill_value,
+    const torch::lazy::BackendDevice& device, at::ScalarType scalar_type) {
+  XLA_RETURN_IF_ERROR(CheckFullSizesArePositive(size));
   xla::Shape shape =
       MakeArrayShapeFromDimensions(size, /*dynamic_dimensions=*/{},
                                    MakeXlaPrimitiveType(scalar_type, &device),
@@ -1794,19 +1827,10 @@ XLATensorPtr full_like(const XLATensorPtr& input, const at::Scalar& fill_value,
                          device, *scalar_type);
 }
 
-XLATensorPtr full_symint(at::SymIntArrayRef sym_size,
-                         const at::Scalar& fill_value,
-                         const torch::lazy::BackendDevice& device,
-                         at::ScalarType scalar_type) {
-  XLA_CHECK(std::all_of(sym_size.begin(), sym_size.end(), [](at::SymInt dim) {
-    // TODO: It should be OK to perform this test on symbolic ints too, not
-    // sure why you conditionalized it.
-    if (auto c = dim.maybe_as_int()) {
-      return *c >= 0;
-    }
-    return true;
-  })) << "Dimensions cannot be negative numbers";
-
+absl::StatusOr<absl_nonnull XLATensorPtr> full_symint(
+    at::SymIntArrayRef sym_size, const at::Scalar& fill_value,
+    const torch::lazy::BackendDevice& device, at::ScalarType scalar_type) {
+  XLA_RETURN_IF_ERROR(CheckFullConcreteSizesArePositive(sym_size));
   return XLATensor::Create(
       XLAGraphExecutor::Get()->GetIrValueForScalar(
           fill_value, MakeXlaPrimitiveType(scalar_type, &device), sym_size,
```
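One detail worth noting in `CheckFullConcreteSizesArePositive` above: `sym.maybe_as_int().value_or(0)` maps symbolic (unknown) dimensions to 0, so only concrete negative sizes trigger the error. The self-contained sketch below uses `std::optional` as a stand-in for `at::SymInt` to illustrate that behavior; it is an illustration, not the torch_xla code itself.

```cpp
#include <cstdint>
#include <optional>
#include <vector>

// Stand-in for the concrete-vs-symbolic size check: an empty optional models
// a symbolic dimension whose value is unknown at trace time.
bool HasConcreteNegativeSize(const std::vector<std::optional<int64_t>>& sizes) {
  for (const std::optional<int64_t>& size : sizes) {
    // value_or(0) mirrors sym.maybe_as_int().value_or(0): a symbolic size
    // becomes 0, which never compares < 0, so only concrete negatives match.
    if (size.value_or(0) < 0) {
      return true;
    }
  }
  return false;
}

// HasConcreteNegativeSize({2, std::nullopt, 2}) -> false (symbolic dim passes)
// HasConcreteNegativeSize({2, -2, 2})           -> true  (rejected by full())
```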

torch_xla/csrc/tensor_methods.h

Lines changed: 6 additions & 7 deletions
```diff
@@ -460,16 +460,15 @@ XLATensorPtr fmod(
     const XLATensorPtr& input, const at::Scalar& other,
     std::optional<at::ScalarType> logical_element_type = std::nullopt);
 
-XLATensorPtr full(absl::Span<const int64_t> size, const at::Scalar& fill_value,
-                  const torch::lazy::BackendDevice& device,
-                  at::ScalarType scalar_type);
+absl::StatusOr<absl_nonnull XLATensorPtr> full(
+    absl::Span<const int64_t> size, const at::Scalar& fill_value,
+    const torch::lazy::BackendDevice& device, at::ScalarType scalar_type);
 XLATensorPtr full_like(const XLATensorPtr& input, const at::Scalar& fill_value,
                        const torch::lazy::BackendDevice& device,
                        std::optional<at::ScalarType> scalar_type);
-XLATensorPtr full_symint(at::SymIntArrayRef sym_size,
-                         const at::Scalar& fill_value,
-                         const torch::lazy::BackendDevice& device,
-                         at::ScalarType scalar_type);
+absl::StatusOr<absl_nonnull XLATensorPtr> full_symint(
+    at::SymIntArrayRef sym_size, const at::Scalar& fill_value,
+    const torch::lazy::BackendDevice& device, at::ScalarType scalar_type);
 
 XLATensorPtr gather(const XLATensorPtr& input, int64_t dim,
                     const XLATensorPtr& index);
```

torch_xla/csrc/tensor_ops.cpp

Lines changed: 6 additions & 6 deletions
```diff
@@ -207,16 +207,16 @@ XLATensorPtr EmbeddingDenseBackward(const XLATensorPtr& grad_output,
   int64_t numel = xla::ShapeUtil::ElementsIn(indices_shape_ref.get());
   XLATensorPtr grad =
       tensor_methods::view(grad_output, {numel, grad_output->size(-1)});
-  XLATensorPtr grad_weight =
+  XLATensorPtr grad_weight = GetValueOrThrow(
       tensor_methods::full({num_weights, grad_output->size(-1)}, 0,
-                           grad_output->GetDevice(), grad_output->dtype());
+                           grad_output->GetDevice(), grad_output->dtype()));
   XLATensorPtr indices_rank1 = tensor_methods::view(indices, {numel});
   if (scale_grad_by_freq) {
     // Compute the histogram of index values.
-    XLATensorPtr counts = tensor_methods::full(
-        {num_weights}, 0, indices->GetDevice(), indices->dtype());
-    XLATensorPtr ones = tensor_methods::full({numel}, 1, indices->GetDevice(),
-                                             indices->dtype());
+    XLATensorPtr counts = GetValueOrThrow(tensor_methods::full(
+        {num_weights}, 0, indices->GetDevice(), indices->dtype()));
+    XLATensorPtr ones = GetValueOrThrow(tensor_methods::full(
+        {numel}, 1, indices->GetDevice(), indices->dtype()));
     tensor_methods::index_put_(counts, counts, {indices_rank1}, /*start_dim=*/0,
                                /*values=*/ones,
                                /*accumulate=*/true, /*result_permutation=*/{0});
```
