import unittest
from unittest.mock import patch
import sys
+ import os

import torch
from torch import nn

from torch_xla._internal import tpu


+ def should_convert_to_shardy():
+   return os.environ.get("CONVERT_SHLO_TO_SHARDY",
+                         "").lower() in ("1", "true", "yes")
+
+
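Not part of the diff: a minimal sketch of how the new helper is expected to behave, assuming CONVERT_SHLO_TO_SHARDY is set in the test environment before the suite runs (the values below are illustrative only).

  import os

  # Mirrors the helper's parsing rule: only "1", "true", or "yes"
  # (case-insensitive) enable Shardy conversion.
  for value in ("1", "true", "Yes", "0", ""):
    os.environ["CONVERT_SHLO_TO_SHARDY"] = value
    enabled = os.environ.get("CONVERT_SHLO_TO_SHARDY",
                             "").lower() in ("1", "true", "yes")
    print(repr(value), "->", enabled)  # True, True, True, False, False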
class BasicXlaShardingTest(test_xla_sharding_base.XlaShardingTest):

  @classmethod
@@ -238,6 +244,8 @@ def test_custom_tile_assignment(self):
    if self.n_devices > 1:
      annotation = '{devices=[1,%d]%s}' % (self.n_devices, ','.join(
          [str(i) for i in reversed(range(self.n_devices))]))
+     if should_convert_to_shardy():
+       annotation = '{devices=[1,%d]<=[%d]}' % (self.n_devices, self.n_devices)
      self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(xt))

  def test_mark_sharding_2d(self):
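Not part of the diff: the "<=[N]" form used in the Shardy expectations is XLA's iota tile-assignment notation (hence the CreateIotaOpSharding counter further down), which denotes devices 0..N-1 as a reshaped iota rather than an explicit list. A small comparison of the two spellings, with the device count chosen arbitrarily for illustration:

  n_devices = 4  # illustrative only
  explicit = '{devices=[1,%d]%s}' % (n_devices,
                                     ','.join(str(i) for i in range(n_devices)))
  iota = '{devices=[1,%d]<=[%d]}' % (n_devices, n_devices)
  print(explicit)  # {devices=[1,4]0,1,2,3}
  print(iota)      # {devices=[1,4]<=[4]}, the same device order written as an iota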
@@ -252,6 +260,8 @@ def test_mark_sharding_2d(self):
    if self.n_devices > 1:
      annotation = '{devices=[1,%d]%s}' % (self.n_devices, ','.join(
          [str(i) for i in range(self.n_devices)]))
+     if should_convert_to_shardy():
+       annotation = '{devices=[1,%d]<=[%d]}' % (self.n_devices, self.n_devices)
      self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(xt1))

    actual = (xt1 + xt2).cpu()
@@ -271,6 +281,9 @@ def test_mark_sharding_4d(self):
      annotation = '{devices=[1,1,%d,%d]%s}' % (
          z_dim, self.n_devices // z_dim, ','.join(
              [str(i) for i in range(self.n_devices)]))
+     if should_convert_to_shardy():
+       annotation = '{devices=[1,1,%d,%d]<=[%d]}' % (z_dim, self.n_devices //
+                                                     z_dim, self.n_devices)
      self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(xt))

    actual = (xt + xt).cpu()
@@ -403,9 +416,11 @@ def test_tupled_partition_spec(self):
    mesh = self._get_mesh((2, self.n_devices // 2))
    t = torch.randn(16).to('xla')
    xs.mark_sharding(t, mesh, ((0, 1),))
-   self.assertEqual(
-       torch_xla._XLAC._get_xla_sharding_spec(t), "{devices=[%d]%s}" %
-       (self.n_devices, ','.join(str(x) for x in range(self.n_devices))))
+   annotation = "{devices=[%d]%s}" % (self.n_devices, ','.join(
+       str(x) for x in range(self.n_devices)))
+   if should_convert_to_shardy():
+     annotation = "{devices=[%d]<=[%d]}" % (self.n_devices, self.n_devices)
+   self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation)

  @unittest.skipUnless(xr.global_runtime_device_count() >= 4,
                       "Multiple devices required for tupled partition spec")
@@ -452,19 +467,25 @@ def test_multiple_tuples_in_spec(self):
                          ('a', 'b', 'c', 'd'))
    t = torch.randn(2, 2).to('xla')
    xs.mark_sharding(t, mesh, (('a', 'b'), ('c', 'd')))
-   self.assertEqual(
-       torch_xla._XLAC._get_xla_sharding_spec(t), "{devices=[2,%d]%s}" %
-       (self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices))))
+   annotation = "{devices=[2,%d]%s}" % (self.n_devices // 2, ','.join(
+       str(x) for x in range(self.n_devices)))
+   if should_convert_to_shardy():
+     annotation = "{devices=[2,%d]<=[%d]}" % (self.n_devices // 2,
+                                              self.n_devices)
+   self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation)

  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
                       'At least 2 devices needed for 2D mesh')
  def test_3d_tensor_2d_mesh(self):
    mesh = self._get_mesh((2, self.n_devices // 2))
    t = torch.randn(16, 16, 16).to('xla')
    xs.mark_sharding(t, mesh, (None, 0, 1))
-   self.assertEqual(
-       torch_xla._XLAC._get_xla_sharding_spec(t), '{devices=[1,2,%d]%s}' %
-       (self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices))))
+   expected = '{devices=[1,2,%d]%s}' % (self.n_devices // 2, ','.join(
+       str(x) for x in range(self.n_devices)))
+   if should_convert_to_shardy():
+     expected = '{devices=[1,2,%d]<=[%d] last_tile_dim_replicate}' % (
+         self.n_devices // 2, self.n_devices)
+   self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), expected)

  def test_partial_replication_addmm(self):
    device = torch_xla.device()
@@ -983,18 +1004,21 @@ def test_op_sharding_cache(self):

    t = torch.randn(1, self.n_devices).to('xla')
    xs.mark_sharding(t, mesh, (0, 1))
-   self.assertIn("CreateOpSharding", met.counter_names())
-   self.assertEqual(met.counter_value("CreateOpSharding"), 1)
+   counter_name = "CreateIotaOpSharding" if should_convert_to_shardy(
+   ) else "CreateOpSharding"
+   self.assertIn(counter_name, met.counter_names())
+   self.assertEqual(met.counter_value(counter_name), 1)

    # Sharding with the same partition spec should not result in another call
    u = torch.randn(1, self.n_devices).to('xla')
    xs.mark_sharding(u, mesh, (0, 1))
-   self.assertEqual(met.counter_value("CreateOpSharding"), 1)
+   self.assertEqual(met.counter_value(counter_name), 1)

-   # Changing the partition spec will result in another CreateOpSharding
+   # Changing the partition spec will result in another
+   # CreateOpSharding or CreateIotaOpSharding call
    v = torch.randn(1, self.n_devices).to('xla')
    xs.mark_sharding(v, mesh, (0, None))
-   self.assertEqual(met.counter_value("CreateOpSharding"), 2)
+   self.assertEqual(met.counter_value(counter_name), 2)

  def test_from_cpu_shards_replicated(self):
    from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards