3 files changed, +6 −4 lines changed.

```diff
@@ -2577,12 +2577,12 @@ void InitXlaModuleBindings(py::module m) {
             return GetXLAShardingSpec(xtensor);
           })
       .def("_get_xla_op_sharding",
-           [](const at::Tensor& input) -> std::optional<xla::OpSharding> {
+           [](const at::Tensor& input) -> std::optional<torch_xla::OpSharding> {
             XLATensorPtr xtensor = bridge::GetXlaTensor(input);
             XLATensor::ShardingSpecPtr sharding_spec =
                 xtensor ? xtensor->sharding_spec() : nullptr;
             if (sharding_spec != nullptr) {
-              return sharding_spec->sharding.GetXlaOpSharding();
+              return sharding_spec->sharding;
             }
             return std::nullopt;
           })
```
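This hunk changes the `_get_xla_op_sharding` binding to return the `torch_xla::OpSharding` wrapper directly instead of lowering it to the raw `xla::OpSharding` proto at the boundary. Below is a minimal sketch of that pattern; the `sketch::` types are hypothetical stand-ins (not the real torch_xla definitions), and the `denormalized_tile_assignment` member is an assumption inferred from the field name that appears in the other two hunks.

```cpp
// Minimal sketch, with hypothetical stand-in types: return the wrapper
// stored on the tensor rather than unwrapping it at the binding, so
// callers keep any metadata the wrapper carries beyond the raw proto.
#include <cstdint>
#include <optional>
#include <vector>

namespace sketch {

struct XlaOpSharding {};  // stand-in for xla::OpSharding

// Stand-in for torch_xla::OpSharding: the XLA proto plus extra state
// (here, a denormalized tile assignment -- an assumption based on the
// field name in the Ifrt/PjRt hunks below).
struct OpSharding {
  XlaOpSharding proto;
  std::vector<int64_t> denormalized_tile_assignment;
  const XlaOpSharding& GetXlaOpSharding() const { return proto; }
};

struct ShardingSpec {
  OpSharding sharding;
};

// Before the change, the binding returned
// sharding_spec->sharding.GetXlaOpSharding(), dropping the extra state;
// after it, the whole wrapper comes back to the caller.
std::optional<OpSharding> GetOpSharding(const ShardingSpec* sharding_spec) {
  if (sharding_spec != nullptr) {
    return sharding_spec->sharding;
  }
  return std::nullopt;
}

}  // namespace sketch
```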
```diff
@@ -268,7 +268,8 @@ class IfrtComputationClient : public ComputationClient {
           denormalized_tile_assignment.value_or(std::vector<int64_t>{}))) {
     xla_output_shardings_ = this->executable->GetOutputShardings();
     if (xla_output_shardings_.has_value()) {
-      output_shardings_->reserve(xla_output_shardings_->size());
+      output_shardings_ = std::vector<torch_xla::OpSharding>{};
+      output_shardings_->reserve(xla_output_shardings_.value().size());
       for (const auto& sharding : xla_output_shardings_.value()) {
         // convert each into torch_xla::OpSharding object
         torch_xla::OpSharding torch_xla_op_sharding(
```
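The `reserve` change here (and the identical one in `PjRtComputationClient` below) is the substantive fix: the diff implies `output_shardings_` is a `std::optional` holding a `std::vector<torch_xla::OpSharding>`, and calling `operator->` on a disengaged optional is undefined behavior. The hunk engages the optional with an empty vector before reserving. A self-contained sketch of the before and after, assuming that member type:

```cpp
// Sketch of the bug and the fix, assuming output_shardings_ is a
// std::optional<std::vector<...>> (consistent with the diff; the real
// member declaration is not shown in this excerpt).
#include <optional>
#include <vector>

struct OpSharding {};  // stand-in for torch_xla::OpSharding

int main() {
  std::optional<std::vector<OpSharding>> output_shardings;
  std::vector<OpSharding> xla_output_shardings(4);

  // Buggy version: operator-> on a disengaged optional.
  // output_shardings->reserve(xla_output_shardings.size());  // UB

  // Fixed version: engage the optional first, then reserve and fill.
  output_shardings = std::vector<OpSharding>{};
  output_shardings->reserve(xla_output_shardings.size());
  for (const auto& sharding : xla_output_shardings) {
    output_shardings->push_back(sharding);
  }
  return 0;
}
```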
```diff
@@ -328,7 +328,8 @@ class PjRtComputationClient : public ComputationClient {
           denormalized_tile_assignment.value_or(std::vector<int64_t>{}))) {
     xla_output_shardings_ = this->executable->GetOutputShardings();
     if (xla_output_shardings_.has_value()) {
-      output_shardings_->reserve(xla_output_shardings_->size());
+      output_shardings_ = std::vector<torch_xla::OpSharding>{};
+      output_shardings_->reserve(xla_output_shardings_.value().size());
       for (const auto& sharding : xla_output_shardings_.value()) {
         // convert each into torch_xla::OpSharding object
         torch_xla::OpSharding torch_xla_op_sharding(
```
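The PjRt hunk is the same fix as the Ifrt one, applied to the sketch above: engage `output_shardings_` before calling `reserve()`, then convert each `xla::OpSharding` from `GetOutputShardings()` into a `torch_xla::OpSharding`. The accompanying switch from `xla_output_shardings_->size()` to `xla_output_shardings_.value().size()` is purely stylistic, since the two are equivalent once `has_value()` has been checked.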