
Commit 275f369

sshonTT authored and hshahTT committed
Add V2 sharding support and improve partition spec handling for multichip training (#2)
* Add V2 sharding support and improve partition spec handling for multi-chip training

  These changes are required to support multi-chip training for real models on the torch-xla side.

  - Added V2 OpSharding support in XlaShardingSpec, which is used internally by MpLoader for parallel input loading. The original implementation only supported V1 shardings.
  - Fixed environment variable parsing for CONVERT_SHLO_TO_SHARDY - previous logic treated values like "0" or "false" as truthy.
  - Added logic to compute dims, reshape_dims, and transpose_perm for V2 sharding based on mesh_shape and partition_spec. The new logic now correctly handles cases that were previously unsupported:
    - case 1: mesh_shape=(2,1,1,1), partition_spec=(0,None,None,None) -> dims=[2,1,1,1], reshape_dims=[2,1,1,1], transpose_perm=[0,1,2,3]
    - case 2: mesh_shape=(2,1,1,1), partition_spec=(0,) -> dims=[2], reshape_dims=[2,1,1,1], transpose_perm=[0,1,2,3]
    - case 3: mesh_shape=(2,4), partition_spec=(0,None) -> dims=[2,1,4], reshape_dims=[2,4], transpose_perm=[0,1]

* Fix formatting according to Torch-XLA style guide

---------

Co-authored-by: Het Shah <[email protected]>
1 parent e0cedaf commit 275f369
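A note on what the new V2 parameters mean: in the V2 (iota) sharding representation, dims, reshape_dims, and transpose_perm describe how an iota over the physical device mesh is rearranged into an explicit tile assignment. The sketch below is not part of the commit; it assumes the standard "[dims]<=[reshape_dims]T(transpose_perm)" semantics of HLO sharding V2, and expand_v2 is an illustrative helper name. It expands case 3 from the commit message into the device array that a V1 sharding would have spelled out explicitly.

```python
import numpy as np

def expand_v2(dims, reshape_dims, transpose_perm):
  # Device ids are an iota over the physical mesh, transposed, then
  # reshaped into the tile-assignment dimensions.
  devices = np.arange(np.prod(reshape_dims)).reshape(reshape_dims)
  return devices.transpose(transpose_perm).reshape(dims)

# Case 3: mesh_shape=(2,4), partition_spec=(0,None)
print(expand_v2(dims=[2, 1, 4], reshape_dims=[2, 4], transpose_perm=[0, 1]))
# [[[0 1 2 3]]
#  [[4 5 6 7]]]
# Dim 0 of the tensor is split across two groups of four devices; the
# trailing tile dimension of size 4 is the replication group.
```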

File tree

4 files changed: +58 -20 lines changed


torch_xla/csrc/init_python_bindings.cpp

Lines changed: 11 additions & 13 deletions

@@ -1525,24 +1525,22 @@ void InitXlaModuleBindings(py::module m) {
            const py::list& replication_groups, int sharding_type,
            bool minibatch) {
         xla::Shape global_shape =
-            CreateComputationShapeFromTensor(tensor, nullptr);
-        if (minibatch) {
-          int num_local_devices =
-              runtime::GetComputationClientOrDie()->GetLocalDevices().size();
-          int num_global_devices =
-              runtime::GetComputationClientOrDie()->GetAllDevices().size();
-          XLA_CHECK(tile_assignment.size() == num_global_devices)
-              << "Minibatch sharding only supports sharding along the batch "
-                 "dimension";
-          int batch_dim_shape =
-              tensor.sizes()[0] * num_global_devices / num_local_devices;
-          global_shape.set_dimensions(0, batch_dim_shape);
-        }
+            ShardingUtil::GetAdjustedGlobalShape(tensor, minibatch);
         return std::make_shared<XLATensor::ShardingSpec>(
             ShardingUtil::CreateOpSharding(
                 tile_assignment, group_assignment, replication_groups,
                 ShardingUtil::ShardingType(sharding_type)),
             global_shape, minibatch);
+      })
+      .def_init([](at::Tensor tensor, const py::list& dims,
+                   const py::list& reshape_dims, const py::list& transpose_perm,
+                   bool minibatch) {
+        xla::Shape global_shape =
+            ShardingUtil::GetAdjustedGlobalShape(tensor, minibatch);
+        return std::make_shared<XLATensor::ShardingSpec>(
+            ShardingUtil::CreateIotaOpSharding(dims, reshape_dims,
+                                               transpose_perm),
+            global_shape, minibatch);
       });

   // Define the _XLAC.IrValue class.

torch_xla/csrc/xla_sharding_util.cpp

Lines changed: 16 additions & 0 deletions

@@ -882,4 +882,20 @@ bool ShardingUtil::GetAutoSharding() {
   }
   return use_auto_sharding;
 }
+
+xla::Shape ShardingUtil::GetAdjustedGlobalShape(const at::Tensor& tensor,
+                                                bool minibatch) {
+  xla::Shape global_shape = CreateComputationShapeFromTensor(tensor, nullptr);
+  if (minibatch) {
+    int num_local_devices =
+        runtime::GetComputationClientOrDie()->GetLocalDevices().size();
+    int num_global_devices =
+        runtime::GetComputationClientOrDie()->GetAllDevices().size();
+    int batch_dim_shape =
+        tensor.sizes()[0] * num_global_devices / num_local_devices;
+    global_shape.set_dimensions(0, batch_dim_shape);
+  }
+  return global_shape;
+}
+
 } // namespace torch_xla
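The new GetAdjustedGlobalShape helper centralizes the minibatch adjustment that the Python binding previously performed inline. The arithmetic it applies, restated as a minimal Python sketch (the function name here is illustrative, not code from this commit):

```python
def adjusted_batch_dim(local_batch, num_global_devices, num_local_devices):
  # With minibatch sharding, each host only materializes its local slice of
  # the batch, so the global shape scales the batch dimension by the ratio
  # of global devices to local devices.
  return local_batch * num_global_devices // num_local_devices

# e.g. a per-host batch of 8 on a host holding 4 of 16 devices corresponds
# to a global batch dimension of 32.
assert adjusted_batch_dim(8, 16, 4) == 32
```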

torch_xla/csrc/xla_sharding_util.h

Lines changed: 3 additions & 0 deletions

@@ -155,6 +155,9 @@ class ShardingUtil {

   static void SetAutoSharding();
   static bool GetAutoSharding();
+
+  static xla::Shape GetAdjustedGlobalShape(const at::Tensor& tensor,
+                                           bool minibatch);
 };

 } // namespace torch_xla

torch_xla/distributed/spmd/xla_sharding.py

Lines changed: 28 additions & 7 deletions

@@ -154,12 +154,10 @@ def _get_op_sharding_args(self, partition_spec: PartitionSpec):

   @functools.lru_cache(maxsize=None)
   def _get_op_sharding_args_v2(self, partition_spec: PartitionSpec):
-    """
-    Returns the appropriate dims, reshape_dims, and transpose_perm for the given partition spec.
-    """
     partition_spec = _translate_named_partition_spec(self, partition_spec)
     self._validate_translated_partition_spec(partition_spec)

+    # 1. Calculate the initial part of dims based on the partition_spec.
     dims = []
     used_axes = OrderedDict()
     for axis in partition_spec:
@@ -175,14 +173,22 @@ def _get_op_sharding_args_v2(self, partition_spec: PartitionSpec):
         dims.append(self.mesh_shape[axis])
         used_axes[axis] = True
       else:
-        # Replicated mesh axis
         dims.append(1)

-    transpose_perm = [k for k in used_axes.keys()]
+    # 2. If the product of dims is less than the total number of devices,
+    # append the sizes of the unused mesh axes.
+    if math.prod(dims) < math.prod(self.mesh_shape):
+      for i in range(len(self.mesh_shape)):
+        if i not in used_axes:
+          dims.append(self.mesh_shape[i])
+
+    # 3. Calculate transpose_perm (sharded axes first, then unused axes).
+    transpose_perm = list(used_axes.keys())
     for i in range(len(self.mesh_shape)):
       if i not in used_axes:
-        dims.append(self.mesh_shape[i])
         transpose_perm.append(i)
+
+    # 4. reshape_dims is always the physical mesh shape.
     reshape_dims = list(self.mesh_shape)

     return dims, reshape_dims, transpose_perm
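For reference, the numbered steps above can be restated as a standalone function (simplified: nested-tuple mesh axes and partition-spec validation are omitted) and checked against the three cases from the commit message:

```python
import math
from collections import OrderedDict

def op_sharding_args_v2(mesh_shape, partition_spec):
  dims = []
  used_axes = OrderedDict()
  for axis in partition_spec:
    if axis is not None:
      dims.append(mesh_shape[axis])  # tensor dim sharded over this mesh axis
      used_axes[axis] = True
    else:
      dims.append(1)                 # replicated tensor dim

  # Unused mesh axes become an extra (replicated) tile dimension, but only
  # when dims does not already account for every device.
  if math.prod(dims) < math.prod(mesh_shape):
    for i in range(len(mesh_shape)):
      if i not in used_axes:
        dims.append(mesh_shape[i])

  # Sharded mesh axes first, then the unused ones.
  transpose_perm = list(used_axes.keys())
  for i in range(len(mesh_shape)):
    if i not in used_axes:
      transpose_perm.append(i)

  return dims, list(mesh_shape), transpose_perm

assert op_sharding_args_v2((2, 1, 1, 1), (0, None, None, None)) == \
    ([2, 1, 1, 1], [2, 1, 1, 1], [0, 1, 2, 3])
assert op_sharding_args_v2((2, 1, 1, 1), (0,)) == \
    ([2], [2, 1, 1, 1], [0, 1, 2, 3])
assert op_sharding_args_v2((2, 4), (0, None)) == \
    ([2, 1, 4], [2, 4], [0, 1])
```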
@@ -592,6 +598,11 @@ def _mark_manual_sharding(
   return wrap_as_sharded_tensor(t)


+def _use_shlo_to_shardy() -> bool:
+  return os.environ.get("CONVERT_SHLO_TO_SHARDY",
+                        "").lower() in ("1", "true", "yes")
+
+
 def enable_manual_sharding(t: Union[torch.Tensor, XLAShardedTensor],
                            partition_spec: PartitionSpec,
                            *,
@@ -715,7 +726,7 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
     t.shard_(NamedSharding(jmesh, P(*partition_spec)))
     return t

-  if os.environ.get('CONVERT_SHLO_TO_SHARDY', False):
+  if _use_shlo_to_shardy():
     op_sharding = mesh.get_op_sharding_v2(partition_spec)
   else:
     op_sharding = mesh.get_op_sharding(partition_spec)
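The environment-variable fix matters because os.environ.get returns a string and any non-empty string is truthy in Python, so exporting CONVERT_SHLO_TO_SHARDY=0 or =false previously still selected the Shardy path. A small standalone check of the old condition against the new helper (whose body mirrors the diff):

```python
import os

os.environ["CONVERT_SHLO_TO_SHARDY"] = "0"

# Old condition: truthy for any non-empty string, including "0" and "false".
assert bool(os.environ.get("CONVERT_SHLO_TO_SHARDY", False))

# New helper: only explicit opt-in values enable the V2/Shardy path.
def _use_shlo_to_shardy() -> bool:
  return os.environ.get("CONVERT_SHLO_TO_SHARDY",
                        "").lower() in ("1", "true", "yes")

assert not _use_shlo_to_shardy()
os.environ["CONVERT_SHLO_TO_SHARDY"] = "true"
assert _use_shlo_to_shardy()
```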
@@ -897,6 +908,9 @@ def __post_init__(self):
     self._group_assignment, self._replication_groups = _get_group_assignment(
         self._sharding_type, tile_assignment, len(partition_spec),
         replicate_dims)
+    if _use_shlo_to_shardy():
+      self.dims, self.reshape_dims, self.transpose_dims = mesh._get_op_sharding_args_v2(
+          partition_spec)

   def xla_spec(self, t: torch.Tensor) -> Union['XlaShardingSpec', None]:
     """
@@ -905,6 +919,13 @@ def xla_spec(self, t: torch.Tensor) -> Union['XlaShardingSpec', None]:
     """
     if not self.can_apply(t):
       return None
+
+    if _use_shlo_to_shardy():
+      # Convert to Shardy spec if the environment variable is set.
+      return torch_xla._XLAC.XlaShardingSpec(t, self.dims, self.reshape_dims,
+                                             self.transpose_dims,
+                                             self.minibatch)
+
     return torch_xla._XLAC.XlaShardingSpec(t, self._tile_assignment,
                                            self._group_assignment,
                                            self._replication_groups,
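End to end, the V2 spec is exercised through ShardingSpec, which MpDeviceLoader uses to shard inputs while loading. A hedged usage sketch, assuming the current torch_xla SPMD APIs (xr.use_spmd, xs.Mesh, xs.ShardingSpec, pl.MpDeviceLoader with input_sharding); the dataset, mesh shape, and partition spec are illustrative:

```python
import os

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

os.environ["CONVERT_SHLO_TO_SHARDY"] = "1"  # opt into the V2/Shardy path
xr.use_spmd()

num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1))

# With the env var set, ShardingSpec.__post_init__ precomputes dims,
# reshape_dims and transpose_dims, and xla_spec() builds the V2
# XlaShardingSpec that the loader applies to each input batch.
input_sharding = xs.ShardingSpec(mesh, (0, None))  # shard the batch dim

dataset = torch.utils.data.TensorDataset(torch.randn(64, 16))
loader = torch.utils.data.DataLoader(dataset, batch_size=8)
train_loader = pl.MpDeviceLoader(loader, xm.xla_device(),
                                 input_sharding=input_sharding)
```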
