Add V2 sharding support and improve partition spec handling for multi-chip training (#2)
* Add V2 sharding support and improve partition spec handling for multi-chip training
These changes are required to support multi-chip training for real models on the torch-xla side.
- Added V2 OpSharding support in XlaShardingSpec, which MpLoader uses internally for parallel input loading; the original implementation supported only V1 shardings (a loader usage sketch follows the case list below).
- Fixed environment variable parsing for CONVERT_SHLO_TO_SHARDY; the previous logic treated values like "0" or "false" as truthy (a parsing sketch follows the case list below).
- Added logic to compute dims, reshape_dims, and transpose_perm for V2 sharding based on mesh_shape and partition_spec.
The new logic now correctly handles cases that were previously unsupported (a sketch reproducing these follows the list):
case 1: mesh_shape=(2,1,1,1), partition_spec=(0,None,None,None)
-> dims=[2,1,1,1], reshape_dims=[2,1,1,1], transpose_perm=[0,1,2,3]
case 2: mesh_shape=(2,1,1,1), partition_spec=(0,)
-> dims=[2], reshape_dims=[2,1,1,1], transpose_perm=[0,1,2,3]
case 3: mesh_shape=(2,4), partition_spec=(0,None)
-> dims=[2,1,4], reshape_dims=[2,4], transpose_perm=[0,1]
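For concreteness, here is a small Python sketch that re-derives the three tuples above from `mesh_shape` and `partition_spec`. The function name `v2_sharding_params` and the exact rule are an illustrative reconstruction of the behavior described here, not the code that landed:

```python
from math import prod

def v2_sharding_params(mesh_shape, partition_spec):
    """Derive (dims, reshape_dims, transpose_perm) for a V2 iota-style
    OpSharding from a device mesh shape and a per-tensor-dim partition
    spec (mesh axis index, or None for an unsharded tensor dimension)."""
    # Tile count per tensor dimension: the size of the mesh axis that
    # dimension is sharded over, or 1 if it is replicated.
    dims = [1 if axis is None else mesh_shape[axis] for axis in partition_spec]

    # Mesh axes not named in partition_spec hold replicas; fold their
    # combined size into a trailing tile dimension when it is non-trivial
    # (OpSharding's replicate-on-last-tile-dim).
    used = [axis for axis in partition_spec if axis is not None]
    unused = [axis for axis in range(len(mesh_shape)) if axis not in used]
    replicas = prod(mesh_shape[axis] for axis in unused)
    if replicas > 1:
        dims.append(replicas)

    # The device assignment is iota(num_devices) reshaped to the mesh
    # shape, then transposed so used axes precede unused ones.
    reshape_dims = list(mesh_shape)
    transpose_perm = used + unused
    return dims, reshape_dims, transpose_perm

# Reproduces the three cases listed above.
assert v2_sharding_params((2, 1, 1, 1), (0, None, None, None)) == \
    ([2, 1, 1, 1], [2, 1, 1, 1], [0, 1, 2, 3])
assert v2_sharding_params((2, 1, 1, 1), (0,)) == \
    ([2], [2, 1, 1, 1], [0, 1, 2, 3])
assert v2_sharding_params((2, 4), (0, None)) == \
    ([2, 1, 4], [2, 4], [0, 1])
```

Note how case 3 picks up the trailing `4` in `dims`: the unused mesh axis of size 4 becomes a replicate-on-last-tile dimension.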
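The loader usage sketch mentioned in the first bullet: this assumes torch_xla's public SPMD API (`xs.Mesh`, `xs.ShardingSpec`, `pl.MpDeviceLoader`); the toy DataLoader and the 1D data-parallel mesh are placeholders, not part of this change:

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
import torch_xla.distributed.parallel_loader as pl

xr.use_spmd()  # enable SPMD execution mode
device = xm.xla_device()
num_devices = xr.global_runtime_device_count()

# Toy dataset standing in for a real input pipeline (both tensors rank 2
# so the same partition spec applies to inputs and labels).
train_loader = DataLoader(
    TensorDataset(torch.randn(64, 16), torch.zeros(64, 1)), batch_size=8)

# 1D data-parallel mesh; shard the batch dimension of each input over it.
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('data',))
sharding_spec = xs.ShardingSpec(mesh, (0, None))

# MpDeviceLoader applies the spec as it transfers batches to the device;
# this is the path that exercises XlaShardingSpec's V2 OpSharding support.
train_device_loader = pl.MpDeviceLoader(
    train_loader, device, input_sharding=sharding_spec)

for inputs, labels in train_device_loader:
    ...  # training step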
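And the parsing sketch mentioned in the second bullet: a minimal illustration of the stricter boolean handling the fix implies. `env_flag` is a hypothetical helper, not the actual torch-xla code:

```python
import os

# Values that should read as "disabled"; the old check was effectively
# `if os.environ.get(...)`, which treats any non-empty string as true.
_FALSY = {"", "0", "false", "no", "off"}

def env_flag(name: str, default: bool = False) -> bool:
    """Return True unless the variable is unset (-> default) or falsy."""
    value = os.environ.get(name)
    if value is None:
        return default
    return value.strip().lower() not in _FALSY

# CONVERT_SHLO_TO_SHARDY=0 or =false now correctly disables the feature.
convert_to_shardy = env_flag("CONVERT_SHLO_TO_SHARDY")
```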
* Fix formatting according to Torch-XLA style guide
---------
Co-authored-by: Het Shah <[email protected]>