Commit 24bb34c

sshonTT and hshahTT committed
Add V2 sharding support and improve partition spec handling for multi-chip training (#2)

* Add V2 sharding support and improve partition spec handling for multi-chip training

  These changes are required to support multi-chip training for real models on the torch-xla side.

  - Added V2 OpSharding support in XlaShardingSpec, which is used internally by MpLoader for parallel input loading. The original implementation only supported V1 shardings.
  - Fixed environment variable parsing for CONVERT_SHLO_TO_SHARDY - the previous logic treated values like "0" or "false" as truthy.
  - Added logic to compute dims, reshape_dims, and transpose_perm for V2 sharding based on mesh_shape and partition_spec. The new logic now correctly handles cases that were previously unsupported:
    case 1: mesh_shape=(2,1,1,1), partition_spec=(0,None,None,None) -> dims=[2,1,1,1], reshape_dims=[2,1,1,1], transpose_perm=[0,1,2,3]
    case 2: mesh_shape=(2,1,1,1), partition_spec=(0,)               -> dims=[2],       reshape_dims=[2,1,1,1], transpose_perm=[0,1,2,3]
    case 3: mesh_shape=(2,4),     partition_spec=(0,None)           -> dims=[2,1,4],   reshape_dims=[2,4],     transpose_perm=[0,1]

* Fix formatting according to Torch-XLA style guide

---------

Co-authored-by: Het Shah <[email protected]>
1 parent 58da15c commit 24bb34c
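
For reference, the dims / reshape_dims / transpose_perm computation described in the commit message can be sketched as a standalone function. This is a simplified illustration of the new _get_op_sharding_args_v2 logic (integer partition specs only; named and tuple axes are omitted), not the torch_xla implementation itself:

# Standalone sketch of the V2 OpSharding argument computation from this commit.
import math
from collections import OrderedDict

def op_sharding_args_v2(mesh_shape, partition_spec):
  dims = []
  used_axes = OrderedDict()
  # 1. One entry per tensor dim: the mesh axis size if sharded, else 1.
  for axis in partition_spec:
    if axis is not None:
      dims.append(mesh_shape[axis])
      used_axes[axis] = True
    else:
      dims.append(1)
  # 2. If the tile grid does not yet cover every device, append unused mesh axes.
  if math.prod(dims) < math.prod(mesh_shape):
    for i in range(len(mesh_shape)):
      if i not in used_axes:
        dims.append(mesh_shape[i])
  # 3. transpose_perm lists the sharded axes first, then the remaining mesh axes.
  transpose_perm = list(used_axes.keys())
  for i in range(len(mesh_shape)):
    if i not in used_axes:
      transpose_perm.append(i)
  # 4. reshape_dims is always the physical mesh shape.
  return dims, list(mesh_shape), transpose_perm

# The three cases from the commit message:
assert op_sharding_args_v2((2, 1, 1, 1), (0, None, None, None)) == \
    ([2, 1, 1, 1], [2, 1, 1, 1], [0, 1, 2, 3])
assert op_sharding_args_v2((2, 1, 1, 1), (0,)) == \
    ([2], [2, 1, 1, 1], [0, 1, 2, 3])
assert op_sharding_args_v2((2, 4), (0, None)) == \
    ([2, 1, 4], [2, 4], [0, 1])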

4 files changed (+61, -21 lines)


torch_xla/csrc/init_python_bindings.cpp

Lines changed: 11 additions & 14 deletions
@@ -1534,25 +1534,22 @@ void InitXlaModuleBindings(py::module m) {
                    const py::list& replication_groups, int sharding_type,
                    bool minibatch) {
                 xla::Shape global_shape =
-                    CreateComputationShapeFromTensor(tensor, nullptr);
-                if (minibatch) {
-                  XLA_ASSIGN_OR_THROW(
-                      runtime::ComputationClient * absl_nonnull const client,
-                      runtime::GetComputationClient());
-                  int num_local_devices = client->GetLocalDevices().size();
-                  int num_global_devices = client->GetAllDevices().size();
-                  XLA_CHECK(tile_assignment.size() == num_global_devices)
-                      << "Minibatch sharding only supports sharding along the batch "
-                         "dimension";
-                  int batch_dim_shape =
-                      tensor.sizes()[0] * num_global_devices / num_local_devices;
-                  global_shape.set_dimensions(0, batch_dim_shape);
-                }
+                    ShardingUtil::GetAdjustedGlobalShape(tensor, minibatch);
                 return std::make_shared<XLATensor::ShardingSpec>(
                     ShardingUtil::CreateOpSharding(
                         tile_assignment, group_assignment, replication_groups,
                         ShardingUtil::ShardingType(sharding_type)),
                     global_shape, minibatch);
+              })
+          .def_init([](at::Tensor tensor, const py::list& dims,
+                       const py::list& reshape_dims, const py::list& transpose_perm,
+                       bool minibatch) {
+            xla::Shape global_shape =
+                ShardingUtil::GetAdjustedGlobalShape(tensor, minibatch);
+            return std::make_shared<XLATensor::ShardingSpec>(
+                ShardingUtil::CreateIotaOpSharding(dims, reshape_dims,
+                                                   transpose_perm),
+                global_shape, minibatch);
               });

        // Define the _XLAC.IrValue class.

torch_xla/csrc/xla_sharding_util.cpp

Lines changed: 19 additions & 0 deletions
@@ -889,4 +889,23 @@ bool ShardingUtil::GetAutoSharding() {
   }
   return use_auto_sharding;
 }
+
+xla::Shape ShardingUtil::GetAdjustedGlobalShape(const at::Tensor& tensor,
+                                                bool minibatch) {
+  xla::Shape global_shape = CreateComputationShapeFromTensor(tensor, nullptr);
+  if (minibatch) {
+    int num_local_devices =
+        runtime::GetComputationClientOrDie()->GetLocalDevices().size();
+    int num_global_devices =
+        runtime::GetComputationClientOrDie()->GetAllDevices().size();
+    XLA_CHECK(tile_assignment.size() == num_global_devices)
+        << "Minibatch sharding only supports sharding along the batch "
+           "dimension";
+    int batch_dim_shape =
+        tensor.sizes()[0] * num_global_devices / num_local_devices;
+    global_shape.set_dimensions(0, batch_dim_shape);
+  }
+  return global_shape;
+}
+
 }  // namespace torch_xla
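
For intuition, the minibatch branch above rescales the batch dimension from the per-host tensor to the global shape. A toy worked example of that arithmetic (the device counts below are made up for illustration):

# Toy illustration of the minibatch global-shape adjustment in
# GetAdjustedGlobalShape; the device counts are hypothetical.
local_batch = 8           # size of dim 0 of the tensor on this host
num_local_devices = 4     # devices attached to this host
num_global_devices = 16   # devices across all hosts

# The global batch dimension spans every host's shard of the input.
global_batch = local_batch * num_global_devices // num_local_devices
assert global_batch == 32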

torch_xla/csrc/xla_sharding_util.h

Lines changed: 3 additions & 0 deletions
@@ -155,6 +155,9 @@ class ShardingUtil {

   static void SetAutoSharding();
   static bool GetAutoSharding();
+
+  static xla::Shape GetAdjustedGlobalShape(const at::Tensor& tensor,
+                                           bool minibatch);
 };

 }  // namespace torch_xla

torch_xla/distributed/spmd/xla_sharding.py

Lines changed: 28 additions & 7 deletions
@@ -154,12 +154,10 @@ def _get_op_sharding_args(self, partition_spec: PartitionSpec):
 
   @functools.lru_cache(maxsize=None)
   def _get_op_sharding_args_v2(self, partition_spec: PartitionSpec):
-    """
-    Returns the appropriate dims, reshape_dims, and transpose_perm for the given partition spec.
-    """
     partition_spec = _translate_named_partition_spec(self, partition_spec)
     self._validate_translated_partition_spec(partition_spec)
 
+    # 1. Calculate the initial part of dims based on the partition_spec.
     dims = []
     used_axes = OrderedDict()
     for axis in partition_spec:
@@ -175,14 +173,22 @@ def _get_op_sharding_args_v2(self, partition_spec: PartitionSpec):
         dims.append(self.mesh_shape[axis])
         used_axes[axis] = True
       else:
-        # Replicated mesh axis
         dims.append(1)
 
-    transpose_perm = [k for k in used_axes.keys()]
+    # 2. If the product of dims is less than the total number of devices,
+    # append the sizes of the unused mesh axes.
+    if math.prod(dims) < math.prod(self.mesh_shape):
+      for i in range(len(self.mesh_shape)):
+        if i not in used_axes:
+          dims.append(self.mesh_shape[i])
+
+    # 3. Calculate transpose_perm (sharded axes first, then unused axes).
+    transpose_perm = list(used_axes.keys())
     for i in range(len(self.mesh_shape)):
       if i not in used_axes:
-        dims.append(self.mesh_shape[i])
         transpose_perm.append(i)
+
+    # 4. reshape_dims is always the physical mesh shape.
     reshape_dims = list(self.mesh_shape)
 
     return dims, reshape_dims, transpose_perm
@@ -592,6 +598,11 @@ def _mark_manual_sharding(
   return wrap_as_sharded_tensor(t)
 
 
+def _use_shlo_to_shardy() -> bool:
+  return os.environ.get("CONVERT_SHLO_TO_SHARDY",
+                        "").lower() in ("1", "true", "yes")
+
+
 def enable_manual_sharding(t: Union[torch.Tensor, XLAShardedTensor],
                            partition_spec: PartitionSpec,
                            *,
@@ -716,7 +727,7 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
     t.shard_(NamedSharding(jmesh, P(*partition_spec)))
     return t
 
-  if os.environ.get('CONVERT_SHLO_TO_SHARDY', False):
+  if _use_shlo_to_shardy():
     op_sharding = mesh.get_op_sharding_v2(partition_spec)
   else:
     op_sharding = mesh.get_op_sharding(partition_spec)
@@ -898,6 +909,9 @@ def __post_init__(self):
     self._group_assignment, self._replication_groups = _get_group_assignment(
         self._sharding_type, tile_assignment, len(partition_spec),
         replicate_dims)
+    if _use_shlo_to_shardy():
+      self.dims, self.reshape_dims, self.transpose_dims = mesh._get_op_sharding_args_v2(
+          partition_spec)
 
   def xla_spec(self, t: torch.Tensor) -> Union['XlaShardingSpec', None]:
     """
@@ -906,6 +920,13 @@ def xla_spec(self, t: torch.Tensor) -> Union['XlaShardingSpec', None]:
     """
     if not self.can_apply(t):
       return None
+
+    if _use_shlo_to_shardy():
+      # Convert to Shardy spec if the environment variable is set.
+      return torch_xla._XLAC.XlaShardingSpec(t, self.dims, self.reshape_dims,
+                                             self.transpose_dims,
+                                             self.minibatch)
+
     return torch_xla._XLAC.XlaShardingSpec(t, self._tile_assignment,
                                            self._group_assignment,
                                            self._replication_groups,
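
Taken together, the Python changes gate the V2 (Shardy) path behind the environment variable. A rough usage sketch follows; only mark_sharding, get_op_sharding_v2, and CONVERT_SHLO_TO_SHARDY come from this diff, while the mesh and tensor setup is illustrative:

# Rough usage sketch: enabling the new V2 sharding path via the env var.
import os
os.environ["CONVERT_SHLO_TO_SHARDY"] = "1"   # "0" / "false" now correctly disable it

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ("data", "model"))

t = torch.zeros(16, 128, device=xm.xla_device())
# With the env var set, mark_sharding calls mesh.get_op_sharding_v2() and builds
# a V2 (iota) OpSharding instead of the V1 tile assignment.
xs.mark_sharding(t, mesh, ("data", None))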
