Commit 7d989d1

sshonTT and hshahTT authored
Add V2 sharding support and improve partition spec handling for multichip training (#2)
* Add V2 sharding support and improve partition spec handling for multi-chip training

  These changes are required to support multi-chip training for real models on the torch-xla side.

  - Added V2 OpSharding support in XlaShardingSpec, which is used internally by MpLoader for parallel input loading. The original implementation only supported V1 shardings.
  - Fixed environment variable parsing for CONVERT_SHLO_TO_SHARDY: the previous logic treated values like "0" or "false" as truthy.
  - Added logic to compute dims, reshape_dims, and transpose_perm for V2 sharding based on mesh_shape and partition_spec. The new logic correctly handles cases that were previously unsupported:
    - case 1: mesh_shape=(2,1,1,1), partition_spec=(0,None,None,None) -> dims=[2,1,1,1], reshape_dims=[2,1,1,1], transpose_perm=[0,1,2,3]
    - case 2: mesh_shape=(2,1,1,1), partition_spec=(0,) -> dims=[2], reshape_dims=[2,1,1,1], transpose_perm=[0,1,2,3]
    - case 3: mesh_shape=(2,4), partition_spec=(0,None) -> dims=[2,1,4], reshape_dims=[2,4], transpose_perm=[0,1]

* Fix formatting according to Torch-XLA style guide

---------

Co-authored-by: Het Shah <[email protected]>
1 parent b87262d commit 7d989d1
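To make the three cases above concrete, here is a small standalone sketch of the new computation. The helper name op_sharding_args_v2 is made up for illustration; it mirrors the logic added to Mesh._get_op_sharding_args_v2 in this commit but skips the named-axis translation and validation that the real method performs.

import math
from collections import OrderedDict


def op_sharding_args_v2(mesh_shape, partition_spec):
  # Step 1: one dims entry per tensor dimension; a sharded dimension takes the
  # size of its mesh axis, a replicated dimension (None) takes 1.
  dims, used_axes = [], OrderedDict()
  for axis in partition_spec:
    if axis is None:
      dims.append(1)
    else:
      dims.append(mesh_shape[axis])
      used_axes[axis] = True

  # Step 2: if the sharded axes do not cover every device, append the sizes of
  # the unused mesh axes so that prod(dims) equals the device count.
  if math.prod(dims) < math.prod(mesh_shape):
    dims.extend(
        mesh_shape[i] for i in range(len(mesh_shape)) if i not in used_axes)

  # Step 3: transpose_perm lists the sharded mesh axes first, then the unused ones.
  transpose_perm = list(used_axes) + [
      i for i in range(len(mesh_shape)) if i not in used_axes
  ]

  # Step 4: reshape_dims is always the physical mesh shape.
  return dims, list(mesh_shape), transpose_perm


print(op_sharding_args_v2((2, 1, 1, 1), (0, None, None, None)))
# -> ([2, 1, 1, 1], [2, 1, 1, 1], [0, 1, 2, 3])
print(op_sharding_args_v2((2, 1, 1, 1), (0,)))
# -> ([2], [2, 1, 1, 1], [0, 1, 2, 3])
print(op_sharding_args_v2((2, 4), (0, None)))
# -> ([2, 1, 4], [2, 4], [0, 1])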

File tree

4 files changed: +58, -20 lines

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 11 additions & 13 deletions
@@ -1468,24 +1468,22 @@ void InitXlaModuleBindings(py::module m) {
       const py::list& replication_groups, int sharding_type,
       bool minibatch) {
     xla::Shape global_shape =
-        CreateComputationShapeFromTensor(tensor, nullptr);
-    if (minibatch) {
-      int num_local_devices =
-          runtime::GetComputationClientOrDie()->GetLocalDevices().size();
-      int num_global_devices =
-          runtime::GetComputationClientOrDie()->GetAllDevices().size();
-      XLA_CHECK(tile_assignment.size() == num_global_devices)
-          << "Minibatch sharding only supports sharding along the batch "
-             "dimension";
-      int batch_dim_shape =
-          tensor.sizes()[0] * num_global_devices / num_local_devices;
-      global_shape.set_dimensions(0, batch_dim_shape);
-    }
+        ShardingUtil::GetAdjustedGlobalShape(tensor, minibatch);
     return std::make_shared<XLATensor::ShardingSpec>(
         ShardingUtil::CreateOpSharding(
             tile_assignment, group_assignment, replication_groups,
             ShardingUtil::ShardingType(sharding_type)),
         global_shape, minibatch);
+  })
+  .def_init([](at::Tensor tensor, const py::list& dims,
+               const py::list& reshape_dims, const py::list& transpose_perm,
+               bool minibatch) {
+    xla::Shape global_shape =
+        ShardingUtil::GetAdjustedGlobalShape(tensor, minibatch);
+    return std::make_shared<XLATensor::ShardingSpec>(
+        ShardingUtil::CreateIotaOpSharding(dims, reshape_dims,
+                                           transpose_perm),
+        global_shape, minibatch);
   });

   // Define the _XLAC.IrValue class.
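The new overload forwards dims, reshape_dims, and transpose_perm to ShardingUtil::CreateIotaOpSharding. Assuming these follow the usual iota (V2) tile-assignment convention (an assumption, since that helper is not part of this diff), the device order they encode can be reconstructed with a short sketch:

import numpy as np


def iota_tile_assignment(dims, reshape_dims, transpose_perm):
  # Enumerate devices 0..N-1 over the physical mesh (reshape_dims), reorder the
  # mesh axes (transpose_perm), then view the result as the tiling of the
  # tensor dimensions (dims).
  num_devices = int(np.prod(reshape_dims))
  return (np.arange(num_devices).reshape(reshape_dims)
          .transpose(transpose_perm).reshape(dims))


# Case 3 from the commit message: mesh_shape=(2, 4), partition_spec=(0, None).
# Tensor dim 0 is split in two; each shard is replicated over a group of 4 devices.
print(iota_tile_assignment([2, 1, 4], [2, 4], [0, 1]))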

torch_xla/csrc/xla_sharding_util.cpp

Lines changed: 16 additions & 0 deletions
@@ -882,4 +882,20 @@ bool ShardingUtil::GetAutoSharding() {
   }
   return use_auto_sharding;
 }
+
+xla::Shape ShardingUtil::GetAdjustedGlobalShape(const at::Tensor& tensor,
+                                                bool minibatch) {
+  xla::Shape global_shape = CreateComputationShapeFromTensor(tensor, nullptr);
+  if (minibatch) {
+    int num_local_devices =
+        runtime::GetComputationClientOrDie()->GetLocalDevices().size();
+    int num_global_devices =
+        runtime::GetComputationClientOrDie()->GetAllDevices().size();
+    int batch_dim_shape =
+        tensor.sizes()[0] * num_global_devices / num_local_devices;
+    global_shape.set_dimensions(0, batch_dim_shape);
+  }
+  return global_shape;
+}
+
 }  // namespace torch_xla
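GetAdjustedGlobalShape only rescales the leading (batch) dimension when minibatch is set: the tensor holds a per-host batch, so the global batch is local_batch * num_global_devices / num_local_devices. A quick numeric check of that formula, with illustrative numbers that are not taken from the diff:

# Illustrative numbers only: each host loads a per-host batch of 8 over its
# 4 local devices, while the job spans 16 devices in total.
local_batch, num_local_devices, num_global_devices = 8, 4, 16

global_batch = local_batch * num_global_devices // num_local_devices
print(global_batch)  # 32: the batch dimension written into the global shape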

torch_xla/csrc/xla_sharding_util.h

Lines changed: 3 additions & 0 deletions
@@ -155,6 +155,9 @@ class ShardingUtil {

   static void SetAutoSharding();
   static bool GetAutoSharding();
+
+  static xla::Shape GetAdjustedGlobalShape(const at::Tensor& tensor,
+                                           bool minibatch);
 };

 }  // namespace torch_xla

torch_xla/distributed/spmd/xla_sharding.py

Lines changed: 28 additions & 7 deletions
@@ -154,12 +154,10 @@ def _get_op_sharding_args(self, partition_spec: PartitionSpec):

   @functools.lru_cache(maxsize=None)
   def _get_op_sharding_args_v2(self, partition_spec: PartitionSpec):
-    """
-    Returns the appropriate dims, reshape_dims, and transpose_perm for the given partition spec.
-    """
     partition_spec = _translate_named_partition_spec(self, partition_spec)
     self._validate_translated_partition_spec(partition_spec)

+    # 1. Calculate the initial part of dims based on the partition_spec.
     dims = []
     used_axes = OrderedDict()
     for axis in partition_spec:

@@ -175,14 +173,22 @@ def _get_op_sharding_args_v2(self, partition_spec: PartitionSpec):
         dims.append(self.mesh_shape[axis])
         used_axes[axis] = True
       else:
-        # Replicated mesh axis
         dims.append(1)

-    transpose_perm = [k for k in used_axes.keys()]
+    # 2. If the product of dims is less than the total number of devices,
+    # append the sizes of the unused mesh axes.
+    if math.prod(dims) < math.prod(self.mesh_shape):
+      for i in range(len(self.mesh_shape)):
+        if i not in used_axes:
+          dims.append(self.mesh_shape[i])
+
+    # 3. Calculate transpose_perm (sharded axes first, then unused axes).
+    transpose_perm = list(used_axes.keys())
     for i in range(len(self.mesh_shape)):
       if i not in used_axes:
-        dims.append(self.mesh_shape[i])
         transpose_perm.append(i)
+
+    # 4. reshape_dims is always the physical mesh shape.
     reshape_dims = list(self.mesh_shape)

     return dims, reshape_dims, transpose_perm
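As a usage sketch of the cached helper (hypothetical setup: it assumes an SPMD environment with 8 addressable devices so the Mesh below is constructible, and the axis names are illustrative):

import numpy as np
import torch_xla.distributed.spmd as xs

# Hypothetical 2x4 device mesh over 8 devices.
mesh = xs.Mesh(np.arange(8), (2, 4), ('data', 'model'))

# Case 3 from the commit message, written with named axes.
dims, reshape_dims, transpose_perm = mesh._get_op_sharding_args_v2(('data', None))
# Expected: dims=[2, 1, 4], reshape_dims=[2, 4], transpose_perm=[0, 1]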
@@ -591,6 +597,11 @@ def _mark_manual_sharding(
   return wrap_as_sharded_tensor(t)


+def _use_shlo_to_shardy() -> bool:
+  return os.environ.get("CONVERT_SHLO_TO_SHARDY",
+                        "").lower() in ("1", "true", "yes")
+
+
 def enable_manual_sharding(t: Union[torch.Tensor, XLAShardedTensor],
                            partition_spec: PartitionSpec,
                            *,
@@ -710,7 +721,7 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
     t.shard_(NamedSharding(jmesh, P(*partition_spec)))
     return t

-  if os.environ.get('CONVERT_SHLO_TO_SHARDY', False):
+  if _use_shlo_to_shardy():
     op_sharding = mesh.get_op_sharding_v2(partition_spec)
   else:
     op_sharding = mesh.get_op_sharding(partition_spec)
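The replaced check returned the raw string from os.environ.get, so any non-empty value, including "0" and "false", enabled the Shardy path. A standalone comparison of the old and new expressions:

import os

for value in ("1", "true", "0", "false", ""):
  os.environ["CONVERT_SHLO_TO_SHARDY"] = value
  old = bool(os.environ.get('CONVERT_SHLO_TO_SHARDY', False))  # old check
  new = os.environ.get("CONVERT_SHLO_TO_SHARDY", "").lower() in ("1", "true", "yes")
  print(f"{value!r:>8}  old={old!s:5}  new={new}")

# '1'      old=True   new=True
# 'true'   old=True   new=True
# '0'      old=True   new=False   <- previously enabled Shardy by mistake
# 'false'  old=True   new=False   <- previously enabled Shardy by mistake
# ''       old=False  new=False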
@@ -869,6 +880,9 @@ def __post_init__(self):
     self._group_assignment, self._replication_groups = _get_group_assignment(
         self._sharding_type, tile_assignment, len(partition_spec),
         replicate_dims)
+    if _use_shlo_to_shardy():
+      self.dims, self.reshape_dims, self.transpose_dims = mesh._get_op_sharding_args_v2(
+          partition_spec)

   def xla_spec(self, t: torch.Tensor) -> Union['XlaShardingSpec', None]:
     """
@@ -877,6 +891,13 @@ def xla_spec(self, t: torch.Tensor) -> Union['XlaShardingSpec', None]:
     """
     if not self.can_apply(t):
       return None
+
+    if _use_shlo_to_shardy():
+      # Convert to Shardy spec if the environment variable is set.
+      return torch_xla._XLAC.XlaShardingSpec(t, self.dims, self.reshape_dims,
+                                             self.transpose_dims,
+                                             self.minibatch)
+
     return torch_xla._XLAC.XlaShardingSpec(t, self._tile_assignment,
                                            self._group_assignment,
                                            self._replication_groups,