Skip to content

Commit b19c028

Browse files
author
kvshbg-aws
committed
fix for failing ci/cd tests
1 parent 938250c commit b19c028

File tree

3 files changed

+6
-4
lines changed

3 files changed

+6
-4
lines changed

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2577,12 +2577,12 @@ void InitXlaModuleBindings(py::module m) {
25772577
return GetXLAShardingSpec(xtensor);
25782578
})
25792579
.def("_get_xla_op_sharding",
2580-
[](const at::Tensor& input) -> std::optional<xla::OpSharding> {
2580+
[](const at::Tensor& input) -> std::optional<torch_xla::OpSharding> {
25812581
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
25822582
XLATensor::ShardingSpecPtr sharding_spec =
25832583
xtensor ? xtensor->sharding_spec() : nullptr;
25842584
if (sharding_spec != nullptr) {
2585-
return sharding_spec->sharding.GetXlaOpSharding();
2585+
return sharding_spec->sharding;
25862586
}
25872587
return std::nullopt;
25882588
})

torch_xla/csrc/runtime/ifrt_computation_client.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,8 @@ class IfrtComputationClient : public ComputationClient {
268268
denormalized_tile_assignment.value_or(std::vector<int64_t>{}))) {
269269
xla_output_shardings_ = this->executable->GetOutputShardings();
270270
if (xla_output_shardings_.has_value()) {
271-
output_shardings_->reserve(xla_output_shardings_->size());
271+
output_shardings_ = std::vector<torch_xla::OpSharding>{};
272+
output_shardings_->reserve(xla_output_shardings_.value().size());
272273
for (const auto& sharding : xla_output_shardings_.value()) {
273274
// convert each into torch_xla::OpSharding object
274275
torch_xla::OpSharding torch_xla_op_sharding(

torch_xla/csrc/runtime/pjrt_computation_client.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,8 @@ class PjRtComputationClient : public ComputationClient {
328328
denormalized_tile_assignment.value_or(std::vector<int64_t>{}))) {
329329
xla_output_shardings_ = this->executable->GetOutputShardings();
330330
if (xla_output_shardings_.has_value()) {
331-
output_shardings_->reserve(xla_output_shardings_->size());
331+
output_shardings_ = std::vector<torch_xla::OpSharding>{};
332+
output_shardings_->reserve(xla_output_shardings_.value().size());
332333
for (const auto& sharding : xla_output_shardings_.value()) {
333334
// convert each into torch_xla::OpSharding object
334335
torch_xla::OpSharding torch_xla_op_sharding(

0 commit comments

Comments
 (0)