Commit 86bac8b

Fix for API match
1 parent 849fe9b commit 86bac8b

3 files changed: +13 -12 lines changed

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 9 additions & 4 deletions
@@ -761,9 +761,12 @@ std::string GetTensorsHloGraph(const std::vector<at::Tensor>& tensors,
 }
 
 std::optional<xla::OpSharding> GetXLAOpSharding(const at::Tensor& input) {
-  XLATensorPtr xtensor = bridge::GetXlaTensor(input);
+  auto xtensor = bridge::GetXlaTensor(input);
+  if (!xtensor.ok()) {
+    return std::nullopt;
+  }
   XLATensor::ShardingSpecPtr sharding_spec =
-      xtensor ? xtensor->sharding_spec() : nullptr;
+      xtensor.value() ? xtensor.value()->sharding_spec() : nullptr;
   if (sharding_spec != nullptr) {
     return sharding_spec->sharding;
   }
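
The hunk above suggests that bridge::GetXlaTensor now returns a status-wrapped value (something with .ok() and .value()) rather than a raw XLATensorPtr. Below is a minimal, self-contained sketch of the same check-then-unwrap pattern, assuming plain absl::StatusOr semantics; LookupValue, its int payload, and GetValueIfPresent are hypothetical stand-ins, not torch_xla APIs:

```cpp
// Sketch only: mirrors the diff's error handling with absl::StatusOr.
#include <optional>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

// Hypothetical fallible lookup, standing in for bridge::GetXlaTensor.
absl::StatusOr<int*> LookupValue(bool present) {
  static int value = 42;
  if (!present) return absl::NotFoundError("no value for this input");
  return &value;
}

std::optional<int> GetValueIfPresent(bool present) {
  absl::StatusOr<int*> maybe = LookupValue(present);
  // First gate: the lookup itself failed, so report "nothing" (nullopt),
  // as the rewritten GetXLAOpSharding does.
  if (!maybe.ok()) {
    return std::nullopt;
  }
  // Second gate: the wrapped pointer may still be null, which is why the
  // diff keeps the `xtensor.value() ? ... : nullptr` ternary.
  if (maybe.value() == nullptr) {
    return std::nullopt;
  }
  return *maybe.value();
}
```

The two gates are distinct: .ok() answers "did the call succeed at all", while the null check answers "did it produce a usable object".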
@@ -3350,8 +3353,10 @@ void InitXlaModuleBindings(py::module m) {
           std::string key = item.first.cast<std::string>();
           options[key] = py::str(item.second).cast<std::string>();
         }
-        runtime::GetComputationClientOrDie()->SetCustomCompileOptions(
-            options);
+        XLA_ASSIGN_OR_THROW(
+            runtime::ComputationClient * absl_nonnull const client,
+            runtime::GetComputationClient());
+        client->SetCustomCompileOptions(options);
       })
       .def(
           // from an XLA tensor to a PyCapsule.
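
XLA_ASSIGN_OR_THROW itself is not shown in this diff; by analogy with the common ASSIGN_OR_RETURN family of macros over absl::StatusOr, it presumably unwraps the value on success and throws on error. A hedged sketch of how such a macro can be written follows; the names (MY_ASSIGN_OR_THROW, MakeClientId) are invented, and torch_xla's real macro may differ:

```cpp
// Sketch of an assign-or-throw macro over absl::StatusOr (invented names).
#include <stdexcept>
#include <string>
#include <utility>

#include "absl/status/statusor.h"

#define MY_CONCAT_INNER(a, b) a##b
#define MY_CONCAT(a, b) MY_CONCAT_INNER(a, b)

// Evaluates `rexpr` once; on failure throws, on success moves the value
// into `lhs` (which may itself be a declaration, as in the diff).
#define MY_ASSIGN_OR_THROW(lhs, rexpr)                                    \
  auto MY_CONCAT(status_or_, __LINE__) = (rexpr);                         \
  if (!MY_CONCAT(status_or_, __LINE__).ok()) {                            \
    throw std::runtime_error(                                             \
        std::string(MY_CONCAT(status_or_, __LINE__).status().message())); \
  }                                                                       \
  lhs = std::move(MY_CONCAT(status_or_, __LINE__)).value()

// Hypothetical fallible factory, standing in for runtime::GetComputationClient.
absl::StatusOr<int> MakeClientId() { return 7; }

void Demo() {
  // Expands to: declare `id`, throw if MakeClientId() failed, else assign.
  MY_ASSIGN_OR_THROW(int id, MakeClientId());
  (void)id;
}
```

Real implementations also hide the temporary behind a scope and handle types containing commas; this sketch skips those details.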

torch_xla/csrc/runtime/pjrt_computation_client.cpp

Lines changed: 0 additions & 1 deletion
@@ -140,7 +140,6 @@ absl::Status PjRtComputationClient::Initialize() {
   auto tracked_devices = GetLocalDevices();
   tracked_devices.emplace_back(spmd_device_str);
   operation_manager_ = std::move(OperationManager(std::move(tracked_devices)));
-
   return absl::OkStatus();
 }
 

torch_xla/csrc/xla_sharding_util.cpp

Lines changed: 4 additions & 7 deletions
@@ -897,13 +897,10 @@ xla::Shape ShardingUtil::GetAdjustedGlobalShape(const at::Tensor& tensor,
                                                 bool minibatch) {
   xla::Shape global_shape = CreateComputationShapeFromTensor(tensor, nullptr);
   if (minibatch) {
-    int num_local_devices =
-        runtime::GetComputationClientOrDie()->GetLocalDevices().size();
-    int num_global_devices =
-        runtime::GetComputationClientOrDie()->GetAllDevices().size();
-    XLA_CHECK(tile_assignment.size() == num_global_devices)
-        << "Minibatch sharding only supports sharding along the batch "
-           "dimension";
+    XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
+                        runtime::GetComputationClient());
+    int num_local_devices = client->GetLocalDevices().size();
+    int num_global_devices = client->GetAllDevices().size();
     int batch_dim_shape =
         tensor.sizes()[0] * num_global_devices / num_local_devices;
     global_shape.set_dimensions(0, batch_dim_shape);
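
For context on the arithmetic this hunk keeps: under minibatch sharding each host materializes only its local slice of the batch, so the global batch dimension is recovered by scaling the local size by the global/local device ratio. A small worked example with illustrative numbers (the device counts are made up, not taken from the source):

```cpp
#include <cstdint>
#include <iostream>

int main() {
  const int64_t local_batch = 8;      // tensor.sizes()[0] as seen on one host
  const int num_local_devices = 4;    // e.g. 4 chips on this host
  const int num_global_devices = 16;  // e.g. 4 hosts x 4 chips
  // Same formula as GetAdjustedGlobalShape: 8 * 16 / 4 = 32.
  const int64_t global_batch =
      local_batch * num_global_devices / num_local_devices;
  std::cout << "adjusted global batch dim = " << global_batch << "\n";
}
```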
