@@ -1702,16 +1702,14 @@ at::Tensor XLANativeFunctions::empty_symint(
   // does not actually end up doing any memory initialization, we use that and
   // avoid going to CPU for it. A common PT pattern is indeed doing empty() plus
   // s_copy_().
-  XLATensorPtr xla_tensor;
-  if (all_dims_static) {
-    xla_tensor = tensor_methods::full(XlaHelpers::I64List(int_sizes.value()), 0,
-                                      GetXlaDeviceOrCurrent(device),
-                                      at::dtype_or_default(dtype));
-  } else {
-    xla_tensor =
-        tensor_methods::full_symint(sym_size, 0, GetXlaDeviceOrCurrent(device),
-                                    at::dtype_or_default(dtype));
-  }
+  XLATensorPtr xla_tensor = GetValueOrThrow(
+      all_dims_static
+          ? tensor_methods::full(XlaHelpers::I64List(int_sizes.value()), 0,
+                                 GetXlaDeviceOrCurrent(device),
+                                 at::dtype_or_default(dtype))
+          : tensor_methods::full_symint(sym_size, 0,
+                                        GetXlaDeviceOrCurrent(device),
+                                        at::dtype_or_default(dtype)));
   // `tensor.to` will trigger an `empty` + `_to_copy`. In the eager mode, the
   // `full` will be evaluated eagerly and get a replicated sharding. We should
   // leave the sharding to be empty.
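(For readers unfamiliar with the pattern: below is a minimal, self-contained sketch of the unwrap this hunk adopts, assuming `tensor_methods::full` / `full_symint` now return `absl::StatusOr<XLATensorPtr>`. The `Tensor`, `MakeStatic`, and `MakeSymbolic` names are hypothetical stand-ins, and `GetValueOrThrow` is sketched here, not copied from the tree.)

```cpp
#include <stdexcept>
#include <string>
#include <utility>

#include "absl/status/statusor.h"

// Hypothetical stand-ins for tensor_methods::full / full_symint, which are
// assumed to return absl::StatusOr<T> after this change.
struct Tensor {};
absl::StatusOr<Tensor> MakeStatic() { return Tensor{}; }
absl::StatusOr<Tensor> MakeSymbolic() { return Tensor{}; }

// Sketch of an unwrap helper in the spirit of GetValueOrThrow: surface a
// failed Status as a C++ exception so the call stays expression-shaped.
template <typename T>
T GetValueOrThrow(absl::StatusOr<T> status_or) {
  if (!status_or.ok()) {
    throw std::runtime_error(std::string(status_or.status().message()));
  }
  return *std::move(status_or);
}

Tensor Build(bool all_dims_static) {
  // Mirrors the diff: both branches produce a StatusOr of the same type, so
  // a single GetValueOrThrow around the ternary replaces two separate unwraps.
  return GetValueOrThrow(all_dims_static ? MakeStatic() : MakeSymbolic());
}
```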
@@ -1858,9 +1856,9 @@ at::Tensor XLANativeFunctions::full(at::IntArrayRef size,
   } else {
     intend_dtype = fill_value.type();
   }
-  return bridge::AtenFromXlaTensor(
+  return bridge::AtenFromXlaTensor(GetValueOrThrow(
       tensor_methods::full(absl::Span<const int64_t>(size), fill_value,
-                           GetXlaDeviceOrCurrent(device), intend_dtype));
+                           GetXlaDeviceOrCurrent(device), intend_dtype)));
 }
 
 at::Tensor XLANativeFunctions::gather(const at::Tensor& self, int64_t dim,
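The context lines above also show the dtype fallback in `full`: use the caller-supplied `dtype` when present, otherwise derive it from `fill_value`. A small illustration with public ATen types; the `PickDtype` helper is hypothetical, while `fill_value.type()` is the same `at::Scalar` accessor the code calls:

```cpp
#include <optional>

#include <ATen/ATen.h>

// Hypothetical helper mirroring the intend_dtype selection above: prefer an
// explicit dtype, otherwise fall back to the fill value's own scalar type.
at::ScalarType PickDtype(std::optional<at::ScalarType> dtype,
                         const at::Scalar& fill_value) {
  return dtype.has_value() ? *dtype : fill_value.type();
}
```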
@@ -2681,8 +2679,8 @@ std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::nll_loss2d_forward(
     int64_t ignore_index) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self));
-  XLATensorPtr total_weight = tensor_methods::full(
-      {}, 1, self_tensor->GetDevice(), self_tensor->dtype());
+  XLATensorPtr total_weight = GetValueOrThrow(tensor_methods::full(
+      {}, 1, self_tensor->GetDevice(), self_tensor->dtype()));
   return std::make_tuple(
       bridge::AtenFromXlaTensor(tensor_methods::nll_loss2d(
           self_tensor, GetValueOrThrow(bridge::GetXlaTensor(target)),
@@ -2716,8 +2714,8 @@ std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::nll_loss_forward(
     int64_t ignore_index) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self));
-  XLATensorPtr total_weight = tensor_methods::full(
-      {}, 1, self_tensor->GetDevice(), self_tensor->dtype());
+  XLATensorPtr total_weight = GetValueOrThrow(tensor_methods::full(
+      {}, 1, self_tensor->GetDevice(), self_tensor->dtype()));
   return std::make_tuple(
       bridge::AtenFromXlaTensor(tensor_methods::nll_loss(
           self_tensor, GetValueOrThrow(bridge::GetXlaTensor(target)),
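In both nll_loss hunks, `total_weight` is created with an empty size list, i.e. a 0-dim scalar tensor holding 1. The same shape semantics, shown with the public `at::full` for illustration (the diff itself uses the XLA-internal `tensor_methods::full`):

```cpp
#include <iostream>

#include <ATen/ATen.h>

int main() {
  // full({}, 1) with an empty size list yields a 0-dim (scalar) tensor,
  // which is what total_weight is in the hunks above.
  at::Tensor total_weight = at::full({}, 1, at::kFloat);
  std::cout << total_weight.dim() << "\n";          // prints 0
  std::cout << total_weight.item<float>() << "\n";  // prints 1
  return 0;
}
```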
@@ -4038,10 +4036,10 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> XLANativeFunctions::_linalg_svd(
   if (!compute_uv) {
     // When compute_uv is false, torch::_linalg_svd returns an empty tensor for
     // u and vh.
-    u = tensor_methods::full({0}, 0, self_tensor->GetDevice(),
-                             self_tensor->dtype());
-    vh = tensor_methods::full({0}, 0, self_tensor->GetDevice(),
-                              self_tensor->dtype());
+    u = GetValueOrThrow(tensor_methods::full({0}, 0, self_tensor->GetDevice(),
+                                             self_tensor->dtype()));
+    vh = GetValueOrThrow(tensor_methods::full({0}, 0, self_tensor->GetDevice(),
+                                              self_tensor->dtype()));
   }
   return std::make_tuple(bridge::AtenFromXlaTensor(u),
                          bridge::AtenFromXlaTensor(s),
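By contrast with the scalar `total_weight` above, the SVD placeholders are built with size `{0}`: a rank-1 tensor holding zero elements, matching the empty `u`/`vh` that `torch::_linalg_svd` returns when `compute_uv` is false. Again illustrated with the public API:

```cpp
#include <iostream>

#include <ATen/ATen.h>

int main() {
  // full({0}, 0) yields a rank-1 tensor with zero elements: an empty
  // placeholder, not a 0-dim scalar like full({}, ...).
  at::Tensor placeholder = at::full({0}, 0, at::kFloat);
  std::cout << placeholder.dim() << "\n";    // prints 1
  std::cout << placeholder.numel() << "\n";  // prints 0
  return 0;
}
```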