@@ -1,6 +1,7 @@
 import collections
 from collections.abc import Generator, MutableMapping
 import math
+import os
 from collections import OrderedDict, defaultdict
 from dataclasses import dataclass, field
 import torch
@@ -118,9 +119,18 @@ def get_axis_name_idx(self, name: str) -> int:
       return None
     return self.axis_names.index(name)
 
+  def _validate_translated_partition_spec(self, partition_spec: tuple):
+    flat_specs = np.hstack([d for d in partition_spec])
+    specs = [d for d in flat_specs if d is not None]
+    assert all(d >= 0 and d < len(self.mesh_shape) for d in specs), \
+      f"partition_spec ({partition_spec}) contains out of bound index into mesh_shape."
+    assert len(specs) == len(np.unique(specs)), \
+      f"Each device mesh dimension should appear at most once in partition_spec {partition_spec}."
+
   @functools.lru_cache(maxsize=None)
   def _get_op_sharding_args(self, partition_spec: PartitionSpec):
     partition_spec = _translate_named_partition_spec(self, partition_spec)
+    self._validate_translated_partition_spec(partition_spec)
     flat_specs = np.hstack([d for d in partition_spec])
     specs = [d for d in flat_specs if d is not None]
     assert all(d >= 0 and d < len(self.mesh_shape) for d in specs), \
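The new `_validate_translated_partition_spec` helper rejects malformed specs before any sharding arguments are computed. A minimal standalone sketch of the check, assuming a hypothetical 2-axis mesh of shape (2, 4); `mesh_shape` and `validate` below are illustrative stand-ins, not torch_xla API:

```python
import numpy as np

mesh_shape = (2, 4)  # hypothetical device mesh: 2 x 4 = 8 devices

def validate(partition_spec):
  # Flatten nested tuples and drop replicated (None) entries.
  flat_specs = np.hstack([d for d in partition_spec])
  specs = [d for d in flat_specs if d is not None]
  # Every referenced mesh axis must exist ...
  assert all(0 <= d < len(mesh_shape) for d in specs), \
      f"partition_spec ({partition_spec}) contains out of bound index into mesh_shape."
  # ... and may be consumed by at most one tensor dimension.
  assert len(specs) == len(np.unique(specs)), \
      f"Each device mesh dimension should appear at most once in partition_spec {partition_spec}."

validate((0, 1))      # OK: each mesh axis shards one tensor dim
validate((None, 1))   # OK: first tensor dim replicated
# validate((0, 2))    # AssertionError: axis 2 does not exist in a 2-axis mesh
# validate((0, 0))    # AssertionError: mesh axis 0 appears twice
```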
@@ -142,6 +152,57 @@ def _get_op_sharding_args(self, partition_spec: PartitionSpec):
     sharding_type = int(sharding_type)
     return tile_assignment, group_assignment, replication_groups, sharding_type
 
+  @functools.lru_cache(maxsize=None)
+  def _get_op_sharding_args_v2(self, partition_spec: PartitionSpec):
+    """
+    Returns the appropriate dims, reshape_dims, and transpose_perm for the given partition spec.
+    """
+    partition_spec = _translate_named_partition_spec(self, partition_spec)
+    self._validate_translated_partition_spec(partition_spec)
+
+    dims = []
+    used_axes = OrderedDict()
+    for axis in partition_spec:
+      if isinstance(axis, tuple):
+        dim_size = 1
+        for i in axis:
+          assert i is not None, "None not allowed within tuple"
+          dim_size *= self.mesh_shape[i]
+          used_axes[i] = True
+        dims.append(dim_size)
+      elif axis is not None:
+        assert isinstance(axis, int), "Axis must be an int or a tuple of ints"
+        dims.append(self.mesh_shape[axis])
+        used_axes[axis] = True
+      else:
+        # Replicated mesh axis
+        dims.append(1)
+
+    transpose_perm = [k for k in used_axes.keys()]
+    for i in range(len(self.mesh_shape)):
+      if i not in used_axes:
+        dims.append(self.mesh_shape[i])
+        transpose_perm.append(i)
+    reshape_dims = list(self.mesh_shape)
+
+    return dims, reshape_dims, transpose_perm
+
+  @functools.lru_cache(maxsize=None)
+  def get_op_sharding_v2(
+      self, partition_spec: PartitionSpec) -> torch_xla._XLAC.OpSharding:
+    """
+    Return the OpSharding for the given partition spec using V2 annotations.
+    """
+    if len(partition_spec) == 0:
+      return torch_xla._XLAC.OpSharding([], [], [], ShardingType.REPLICATED)
+    sharding_type = _get_sharding_type(partition_spec, self.size())
+    if sharding_type not in (ShardingType.TILED, ShardingType.PARTIAL):
+      return torch_xla._XLAC.OpSharding([], [], [0], sharding_type)
+
+    dims, reshape_dims, transpose_perm = self._get_op_sharding_args_v2(
+        partition_spec)
+    return torch_xla._XLAC.OpSharding(dims, reshape_dims, transpose_perm)
+
   @functools.lru_cache(maxsize=None)
   def get_op_sharding(
       self, partition_spec: PartitionSpec) -> torch_xla._XLAC.OpSharding:
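For intuition, here is a standalone sketch of the dims / reshape_dims / transpose_perm computation that `_get_op_sharding_args_v2` performs, assuming a hypothetical (2, 4) mesh; `op_sharding_args_v2` is an illustrative re-implementation over integer specs, not the torch_xla method itself:

```python
from collections import OrderedDict

mesh_shape = (2, 4)  # hypothetical mesh: axis 0 -> 2 devices, axis 1 -> 4 devices

def op_sharding_args_v2(partition_spec):
  dims, used_axes = [], OrderedDict()
  for axis in partition_spec:
    if isinstance(axis, tuple):
      # One tensor dim sharded over several mesh axes, e.g. (0, 1).
      size = 1
      for i in axis:
        size *= mesh_shape[i]
        used_axes[i] = True
      dims.append(size)
    elif axis is not None:
      dims.append(mesh_shape[axis])
      used_axes[axis] = True
    else:
      dims.append(1)  # replicated tensor dim
  # Mesh axes not named in the spec become a trailing replication dim.
  transpose_perm = list(used_axes)
  for i in range(len(mesh_shape)):
    if i not in used_axes:
      dims.append(mesh_shape[i])
      transpose_perm.append(i)
  return dims, list(mesh_shape), transpose_perm

print(op_sharding_args_v2((1, None)))  # ([4, 1, 2], [2, 4], [1, 0])
print(op_sharding_args_v2(((0, 1),)))  # ([8], [2, 4], [0, 1])
```

The first case would correspond, roughly, to a V2 iota tile assignment of the form `devices=[4,1,2]<=[2,4]T(1,0)` with the last tile dimension replicated.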
@@ -157,6 +218,7 @@ def get_op_sharding(
 
     tile_assignment, group_assignment, replication_groups, sharding_type = self._get_op_sharding_args(
         partition_spec)
+
     return torch_xla._XLAC.OpSharding(tile_assignment, group_assignment,
                                       replication_groups, sharding_type)
 
@@ -653,7 +715,10 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
     t.shard_(NamedSharding(jmesh, P(*partition_spec)))
     return t
 
-  op_sharding = mesh.get_op_sharding(partition_spec)
+  if os.environ.get('CONVERT_SHLO_TO_SHARDY', False):
+    op_sharding = mesh.get_op_sharding_v2(partition_spec)
+  else:
+    op_sharding = mesh.get_op_sharding(partition_spec)
   annotate_func = torch_xla._XLAC._xla_mark_sharding
   annotate_func(unwrap_sharded_tensor(t), op_sharding)
   # Pass mesh and partition spec information for DTensor compatibility
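A hedged usage sketch of the new gate: when the `CONVERT_SHLO_TO_SHARDY` environment variable is set to any non-empty string (note that `os.environ.get` returns a string, so even `'0'` is truthy), `mark_sharding` routes through `get_op_sharding_v2`. The mesh layout and tensor shape below are illustrative; the SPMD entry points (`xr.use_spmd`, `xs.Mesh`, `xs.mark_sharding`) are existing torch_xla API.

```python
import os
os.environ['CONVERT_SHLO_TO_SHARDY'] = '1'  # opt in to V2/Shardy sharding annotations

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
# Illustrative 2-D mesh: shard over 'data', keep 'model' trivial.
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ('data', 'model'))

t = torch.randn(8, 128).to(xm.xla_device())
# With the env var set, this goes through mesh.get_op_sharding_v2 under the hood.
xs.mark_sharding(t, mesh, ('data', 'model'))
```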