3 files changed, +6 −4 lines changed.

@@ -2632,12 +2632,12 @@ void InitXlaModuleBindings(py::module m) {
         return GetXLAShardingSpec(xtensor);
       })
       .def("_get_xla_op_sharding",
-           [](const at::Tensor& input) -> std::optional<xla::OpSharding> {
+           [](const at::Tensor& input) -> std::optional<torch_xla::OpSharding> {
             XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(input));
             XLATensor::ShardingSpecPtr sharding_spec =
                 xtensor ? xtensor->sharding_spec() : nullptr;
             if (sharding_spec != nullptr) {
-              return sharding_spec->sharding.GetXlaOpSharding();
+              return sharding_spec->sharding;
             }
             return std::nullopt;
           })
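The binding change above returns the torch_xla::OpSharding wrapper directly instead of unwrapping it via GetXlaOpSharding(), so the Python side receives the registered wrapper type (or None). A minimal sketch of that pattern, assuming the wrapper type is registered with pybind11 — Inner and Wrapper below are hypothetical stand-ins for xla::OpSharding and torch_xla::OpSharding, not names from this PR:

    #include <optional>
    #include <pybind11/pybind11.h>
    #include <pybind11/stl.h>  // enables std::optional <-> None conversion

    namespace py = pybind11;

    struct Inner {};                // stand-in for xla::OpSharding
    struct Wrapper {                // stand-in for torch_xla::OpSharding
      Inner inner;
      const Inner& GetInner() const { return inner; }
    };

    PYBIND11_MODULE(example, m) {
      py::class_<Wrapper>(m, "Wrapper");
      m.def("get_wrapper", [](bool has) -> std::optional<Wrapper> {
        // Returning the registered wrapper type directly avoids the extra
        // unwrap step (the old binding called GetXlaOpSharding() here);
        // std::nullopt maps to Python None via pybind11/stl.h.
        if (has) return Wrapper{};
        return std::nullopt;
      });
    }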
@@ -287,7 +287,8 @@ class IfrtComputationClient : public ComputationClient {
           denormalized_tile_assignment.value_or(std::vector<int64_t>{}))) {
     xla_output_shardings_ = this->executable->GetOutputShardings();
     if (xla_output_shardings_.has_value()) {
-      output_shardings_->reserve(xla_output_shardings_->size());
+      output_shardings_ = std::vector<torch_xla::OpSharding>{};
+      output_shardings_->reserve(xla_output_shardings_.value().size());
       for (const auto& sharding : xla_output_shardings_.value()) {
         // convert each into torch_xla::OpSharding object
         torch_xla::OpSharding torch_xla_op_sharding(
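The two added lines fix a disengaged-optional dereference: output_shardings_ appears to be a std::optional<std::vector<torch_xla::OpSharding>> that was never engaged before the reserve() call, and operator-> on a disengaged std::optional is undefined behavior. A self-contained sketch of the bug and the fix (the shardings name below is illustrative, not the client's actual member):

    #include <cstddef>
    #include <optional>
    #include <vector>

    std::optional<std::vector<int>> shardings;  // starts disengaged: no vector inside

    void broken(std::size_t n) {
      shardings->reserve(n);  // UB: dereferences a disengaged optional
    }

    void fixed(std::size_t n) {
      shardings = std::vector<int>{};  // engage the optional with an empty vector
      shardings->reserve(n);           // safe: the optional now holds a vector
    }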
@@ -347,7 +347,8 @@ class PjRtComputationClient : public ComputationClient {
           denormalized_tile_assignment.value_or(std::vector<int64_t>{}))) {
     xla_output_shardings_ = this->executable->GetOutputShardings();
     if (xla_output_shardings_.has_value()) {
-      output_shardings_->reserve(xla_output_shardings_->size());
+      output_shardings_ = std::vector<torch_xla::OpSharding>{};
+      output_shardings_->reserve(xla_output_shardings_.value().size());
       for (const auto& sharding : xla_output_shardings_.value()) {
         // convert each into torch_xla::OpSharding object
         torch_xla::OpSharding torch_xla_op_sharding(
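PjRtComputationClient receives the identical fix. If the member is indeed a std::optional, an equivalent alternative would be output_shardings_.emplace() followed by the reserve() call, which constructs the empty vector in place instead of move-assigning a temporary; the observable behavior is the same either way.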