3 files changed, +6 −4 lines changed
@@ -2577,12 +2577,12 @@ void InitXlaModuleBindings(py::module m) {
         return GetXLAShardingSpec(xtensor);
       })
       .def("_get_xla_op_sharding",
-           [](const at::Tensor& input) -> std::optional<xla::OpSharding> {
+           [](const at::Tensor& input) -> std::optional<torch_xla::OpSharding> {
             XLATensorPtr xtensor = bridge::GetXlaTensor(input);
             XLATensor::ShardingSpecPtr sharding_spec =
                 xtensor ? xtensor->sharding_spec() : nullptr;
             if (sharding_spec != nullptr) {
-              return sharding_spec->sharding.GetXlaOpSharding();
+              return sharding_spec->sharding;
             }
             return std::nullopt;
           })
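The rewritten binding leans on std::optional's implicit value conversion: once the lambda's declared return type is std::optional<torch_xla::OpSharding>, returning the stored sharding engages the optional directly, with std::nullopt covering the unsharded case, so no explicit GetXlaOpSharding() conversion is needed. A minimal standalone sketch of that shape, using hypothetical stand-in types (OpSharding, ShardingSpec, and GetOpSharding here are placeholders, not the real torch_xla declarations):

#include <iostream>
#include <optional>
#include <string>

struct OpSharding {          // stand-in for torch_xla::OpSharding
  std::string debug_string;
};

struct ShardingSpec {        // stand-in for XLATensor::ShardingSpec
  OpSharding sharding;
};

std::optional<OpSharding> GetOpSharding(const ShardingSpec* spec) {
  if (spec != nullptr) {
    return spec->sharding;   // implicit OpSharding -> std::optional<OpSharding>
  }
  return std::nullopt;       // tensor carries no sharding annotation
}

int main() {
  ShardingSpec spec{OpSharding{"{replicated}"}};
  std::cout << GetOpSharding(&spec)->debug_string << "\n";  // prints {replicated}
  std::cout << GetOpSharding(nullptr).has_value() << "\n";  // prints 0
}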
@@ -285,7 +285,8 @@ class IfrtComputationClient : public ComputationClient {
           denormalized_tile_assignment.value_or(std::vector<int64_t>{}))) {
     xla_output_shardings_ = this->executable->GetOutputShardings();
     if (xla_output_shardings_.has_value()) {
-      output_shardings_->reserve(xla_output_shardings_->size());
+      output_shardings_ = std::vector<torch_xla::OpSharding>{};
+      output_shardings_->reserve(xla_output_shardings_.value().size());
       for (const auto& sharding : xla_output_shardings_.value()) {
         // convert each into torch_xla::OpSharding object
         torch_xla::OpSharding torch_xla_op_sharding(
@@ -345,7 +345,8 @@ class PjRtComputationClient : public ComputationClient {
           denormalized_tile_assignment.value_or(std::vector<int64_t>{}))) {
     xla_output_shardings_ = this->executable->GetOutputShardings();
     if (xla_output_shardings_.has_value()) {
-      output_shardings_->reserve(xla_output_shardings_->size());
+      output_shardings_ = std::vector<torch_xla::OpSharding>{};
+      output_shardings_->reserve(xla_output_shardings_.value().size());
       for (const auto& sharding : xla_output_shardings_.value()) {
         // convert each into torch_xla::OpSharding object
         torch_xla::OpSharding torch_xla_op_sharding(
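The IfrtComputationClient and PjRtComputationClient hunks make the same fix, and it guards against undefined behavior: the diff implies output_shardings_ is a std::optional<std::vector<torch_xla::OpSharding>>, and the old code called operator-> on it before anything had engaged it. Dereferencing a disengaged std::optional is UB, so the new code assigns an empty vector first. A minimal sketch of the before/after shape, assuming that member type (the actual declaration is not shown in the diff):

#include <cassert>
#include <optional>
#include <vector>

struct OpSharding {};  // stand-in for torch_xla::OpSharding

int main() {
  std::optional<std::vector<OpSharding>> output_shardings;

  // BUG (old code's shape): output_shardings->reserve(n) here would
  // dereference a disengaged optional -- undefined behavior.

  // FIX (new code's shape): engage the optional first, then reserve.
  output_shardings = std::vector<OpSharding>{};
  output_shardings->reserve(8);

  assert(output_shardings.has_value());
  assert(output_shardings->capacity() >= 8);
}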