Skip to content

Commit 23682f9

Browse files
kvshbg-awskvshbg-aws
authored andcommitted
fix for failing ci/cd tests
1 parent 2f04535 commit 23682f9

File tree

3 files changed

+6
-4
lines changed

3 files changed

+6
-4
lines changed

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2632,12 +2632,12 @@ void InitXlaModuleBindings(py::module m) {
26322632
return GetXLAShardingSpec(xtensor);
26332633
})
26342634
.def("_get_xla_op_sharding",
2635-
[](const at::Tensor& input) -> std::optional<xla::OpSharding> {
2635+
[](const at::Tensor& input) -> std::optional<torch_xla::OpSharding> {
26362636
XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(input));
26372637
XLATensor::ShardingSpecPtr sharding_spec =
26382638
xtensor ? xtensor->sharding_spec() : nullptr;
26392639
if (sharding_spec != nullptr) {
2640-
return sharding_spec->sharding.GetXlaOpSharding();
2640+
return sharding_spec->sharding;
26412641
}
26422642
return std::nullopt;
26432643
})

torch_xla/csrc/runtime/ifrt_computation_client.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,8 @@ class IfrtComputationClient : public ComputationClient {
287287
denormalized_tile_assignment.value_or(std::vector<int64_t>{}))) {
288288
xla_output_shardings_ = this->executable->GetOutputShardings();
289289
if (xla_output_shardings_.has_value()) {
290-
output_shardings_->reserve(xla_output_shardings_->size());
290+
output_shardings_ = std::vector<torch_xla::OpSharding>{};
291+
output_shardings_->reserve(xla_output_shardings_.value().size());
291292
for (const auto& sharding : xla_output_shardings_.value()) {
292293
// convert each into torch_xla::OpSharding object
293294
torch_xla::OpSharding torch_xla_op_sharding(

torch_xla/csrc/runtime/pjrt_computation_client.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,8 @@ class PjRtComputationClient : public ComputationClient {
347347
denormalized_tile_assignment.value_or(std::vector<int64_t>{}))) {
348348
xla_output_shardings_ = this->executable->GetOutputShardings();
349349
if (xla_output_shardings_.has_value()) {
350-
output_shardings_->reserve(xla_output_shardings_->size());
350+
output_shardings_ = std::vector<torch_xla::OpSharding>{};
351+
output_shardings_->reserve(xla_output_shardings_.value().size());
351352
for (const auto& sharding : xla_output_shardings_.value()) {
352353
// convert each into torch_xla::OpSharding object
353354
torch_xla::OpSharding torch_xla_op_sharding(

0 commit comments

Comments
 (0)