@@ -31,6 +31,7 @@ class BasicXlaShardingTest(test_xla_sharding_base.XlaShardingTest):
   @classmethod
   def setUpClass(cls):
     super().setUpClass()
+    cls.convert_to_shardy = xu.check_env_flag("CONVERT_SHLO_TO_SHARDY")

   def test_xla_sharded_tensor(self):
     partition_spec = (0, 1)
@@ -238,6 +239,8 @@ def test_custom_tile_assignment(self):
     if self.n_devices > 1:
       annotation = '{devices=[1,%d]%s}' % (self.n_devices, ','.join(
           [str(i) for i in reversed(range(self.n_devices))]))
+      if self.convert_to_shardy:
+        annotation = '{devices=[1,%d]<=[%d]}' % (self.n_devices, self.n_devices)
       self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(xt))

   def test_mark_sharding_2d(self):
@@ -252,6 +255,8 @@ def test_mark_sharding_2d(self):
     if self.n_devices > 1:
       annotation = '{devices=[1,%d]%s}' % (self.n_devices, ','.join(
           [str(i) for i in range(self.n_devices)]))
+      if self.convert_to_shardy:
+        annotation = '{devices=[1,%d]<=[%d]}' % (self.n_devices, self.n_devices)
       self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(xt1))

     actual = (xt1 + xt2).cpu()
@@ -271,6 +276,9 @@ def test_mark_sharding_4d(self):
       annotation = '{devices=[1,1,%d,%d]%s}' % (
           z_dim, self.n_devices // z_dim, ','.join(
               [str(i) for i in range(self.n_devices)]))
+      if self.convert_to_shardy:
+        annotation = '{devices=[1,1,%d,%d]<=[%d]}' % (z_dim, self.n_devices //
+                                                      z_dim, self.n_devices)
       self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(xt))

     actual = (xt + xt).cpu()
@@ -403,9 +411,11 @@ def test_tupled_partition_spec(self):
     mesh = self._get_mesh((2, self.n_devices // 2))
     t = torch.randn(16).to('xla')
     xs.mark_sharding(t, mesh, ((0, 1),))
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(t), "{devices=[%d]%s}" %
-        (self.n_devices, ','.join(str(x) for x in range(self.n_devices))))
+    annotation = "{devices=[%d]%s}" % (self.n_devices, ','.join(
+        str(x) for x in range(self.n_devices)))
+    if self.convert_to_shardy:
+      annotation = "{devices=[%d]<=[%d]}" % (self.n_devices, self.n_devices)
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation)

   @unittest.skipUnless(xr.global_runtime_device_count() >= 4,
                        "Multiple devices required for tupled partition spec")
@@ -415,34 +425,43 @@ def test_named_partial_tupled_partition_spec(self):
     # Shard the first dimension on `r` and `b`, replicate the second dimension
     t = torch.randn(16, 16).to('xla')
     xs.mark_sharding(t, mesh, (('r', 'b'), None))
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(t),
-        "{devices=[2,1,%d]%s last_tile_dim_replicate}" %
-        (self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices))))
+    annotation = "{devices=[2,1,%d]%s last_tile_dim_replicate}" % (
+        self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices)))
+    if self.convert_to_shardy:
+      annotation = "{devices=[2,1,%d]<=[%d] last_tile_dim_replicate}" % (
+          self.n_devices // 2, self.n_devices)
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation)

     # Replicate the first dimension, shard the second on `b` and `m`
     u = torch.randn(16, 16).to('xla')
     xs.mark_sharding(u, mesh, (None, ('b', 'm')))
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(u), "{devices=[1,%d]%s}" %
-        (self.n_devices, ','.join(str(x) for x in range(self.n_devices))))
+    annotation = "{devices=[1,%d]%s}" % (self.n_devices, ','.join(
+        str(x) for x in range(self.n_devices)))
+    if self.convert_to_shardy:
+      annotation = "{devices=[1,%d]<=[%d]}" % (self.n_devices, self.n_devices)
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(u), annotation)

     # Replicate the first dimension, shard the second on `r` and `m`
     v = torch.randn(16, 16).to('xla')
     xs.mark_sharding(v, mesh, (None, ('r', 'm')))
     device_order = mesh.get_logical_mesh().transpose((0, 2, 1)).flatten()
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(v),
-        "{devices=[1,%d,2]%s last_tile_dim_replicate}" %
-        (self.n_devices // 2, ','.join(str(x) for x in device_order)))
+    annotation = "{devices=[1,%d,2]%s last_tile_dim_replicate}" % (
+        self.n_devices // 2, ','.join(str(x) for x in device_order))
+    if self.convert_to_shardy:
+      annotation = "{devices=[1,%d,2]<=[2,%d]T(1,0) last_tile_dim_replicate}" % (
+          self.n_devices // 2, self.n_devices // 2)
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(v), annotation)

     # Replicate the first dimension, shard the second on `m` and `b`
     v = torch.randn(16, 16).to('xla')
     xs.mark_sharding(v, mesh, (None, ('m', 'b')))
     device_order = mesh.get_logical_mesh().transpose((2, 1, 0)).flatten()
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(v), "{devices=[1,%d]%s}" %
-        (self.n_devices, ','.join(str(x) for x in device_order)))
+    annotation = "{devices=[1,%d]%s}" % (self.n_devices, ','.join(
+        str(x) for x in device_order))
+    if self.convert_to_shardy:
+      annotation = "{devices=[1,%d]<=[2,%d]T(1,0)}" % (self.n_devices,
+                                                       self.n_devices // 2)
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(v), annotation)

   @unittest.skipUnless(xr.global_runtime_device_count() > 1,
                        'Multiple devices required for tupled partition spec')
@@ -452,19 +471,25 @@ def test_multiple_tuples_in_spec(self):
                           ('a', 'b', 'c', 'd'))
     t = torch.randn(2, 2).to('xla')
     xs.mark_sharding(t, mesh, (('a', 'b'), ('c', 'd')))
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(t), "{devices=[2,%d]%s}" %
-        (self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices))))
+    annotation = "{devices=[2,%d]%s}" % (self.n_devices // 2, ','.join(
+        str(x) for x in range(self.n_devices)))
+    if self.convert_to_shardy:
+      annotation = "{devices=[2,%d]<=[%d]}" % (self.n_devices // 2,
+                                               self.n_devices)
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation)

   @unittest.skipUnless(xr.global_runtime_device_count() > 1,
                        'At least 2 devices needed for 2D mesh')
   def test_3d_tensor_2d_mesh(self):
     mesh = self._get_mesh((2, self.n_devices // 2))
     t = torch.randn(16, 16, 16).to('xla')
     xs.mark_sharding(t, mesh, (None, 0, 1))
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(t), '{devices=[1,2,%d]%s}' %
-        (self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices))))
+    annotation = '{devices=[1,2,%d]%s}' % (self.n_devices // 2, ','.join(
+        str(x) for x in range(self.n_devices)))
+    if self.convert_to_shardy:
+      annotation = '{devices=[1,2,%d]<=[%d]}' % (self.n_devices // 2,
+                                                 self.n_devices)
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation)

   def test_partial_replication_addmm(self):
     device = torch_xla.device()
@@ -983,18 +1008,20 @@ def test_op_sharding_cache(self):

     t = torch.randn(1, self.n_devices).to('xla')
     xs.mark_sharding(t, mesh, (0, 1))
-    self.assertIn("CreateOpSharding", met.counter_names())
-    self.assertEqual(met.counter_value("CreateOpSharding"), 1)
+    counter_name = "CreateIotaOpSharding" if self.convert_to_shardy else "CreateOpSharding"
+    self.assertIn(counter_name, met.counter_names())
+    self.assertEqual(met.counter_value(counter_name), 1)

     # Sharding with the same partition spec should not result in another call
     u = torch.randn(1, self.n_devices).to('xla')
     xs.mark_sharding(u, mesh, (0, 1))
-    self.assertEqual(met.counter_value("CreateOpSharding"), 1)
+    self.assertEqual(met.counter_value(counter_name), 1)

-    # Changing the partition spec will result in another CreateOpSharding
+    # Changing the partition spec will result in another
+    # CreateOpSharding or CreateIotaOpSharding call
     v = torch.randn(1, self.n_devices).to('xla')
     xs.mark_sharding(v, mesh, (0, None))
-    self.assertEqual(met.counter_value("CreateOpSharding"), 2)
+    self.assertEqual(met.counter_value(counter_name), 2)

   def test_from_cpu_shards_replicated(self):
     from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards
@@ -1397,10 +1424,10 @@ def test_data_loader_with_sharding(self):
         input_sharding=xs.ShardingSpec(mesh, ('data', None, None, None)))
     data, _ = iter(train_device_loader).__next__()
     self.assertEqual(data.size(), torch.Size([8, 3, 64, 64]))
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(data),
-        f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
-    )
+    annotation = f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
+    if self.convert_to_shardy:
+      annotation = f"{{devices=[{mesh.size()},1,1,1]<=[{mesh.size()}]}}"
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(data), annotation)

   @unittest.skipUnless(
       xr.global_runtime_device_count() > 1,
@@ -1420,10 +1447,10 @@ def test_data_loader_with_non_batch_size(self):
         input_sharding=xs.ShardingSpec(mesh, ('data', None, None, None)))
     data, _ = iter(train_device_loader).__next__()
     self.assertEqual(data.size(), torch.Size([mesh.size() - 1, 3, 64, 64]))
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(data),
-        f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
-    )
+    annotation = f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
+    if self.convert_to_shardy:
+      annotation = f"{{devices=[{mesh.size()},1,1,1]<=[{mesh.size()}]}}"
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(data), annotation)

   @unittest.skipUnless(
       xr.global_runtime_device_count() > 1,
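For reference, the pattern throughout this diff is that the Shardy path replaces explicit V1 device lists with iota-style (V2) tile assignments in the sharding spec string. Below is a minimal, self-contained Python sketch of the two string forms these tests compare against; the 4-device count and the [2, 2] reshape are illustrative assumptions, not values taken from the PR.

import numpy as np

# Illustrative only: assume 4 devices sharded over the last dimension.
n_devices = 4

# Explicit (V1) annotation: every device id is enumerated.
explicit = '{devices=[1,%d]%s}' % (
    n_devices, ','.join(str(i) for i in range(n_devices)))

# Iota (V2) annotation, the form expected when CONVERT_SHLO_TO_SHARDY is set:
# the same assignment written as a reshape of iota(n_devices).
iota = '{devices=[1,%d]<=[%d]}' % (n_devices, n_devices)

# A transposed iota form such as "<=[2,2]T(1,0)" describes the device order
# obtained by reshaping iota(4) to [2, 2], transposing, then flattening.
order = np.arange(n_devices).reshape(2, 2).transpose(1, 0).flatten()
transposed_explicit = '{devices=[1,%d]%s}' % (
    n_devices, ','.join(str(i) for i in order))
transposed_iota = '{devices=[1,%d]<=[2,2]T(1,0)}' % n_devices

print(explicit)             # {devices=[1,4]0,1,2,3}
print(iota)                 # {devices=[1,4]<=[4]}
print(transposed_explicit)  # {devices=[1,4]0,2,1,3}
print(transposed_iota)      # {devices=[1,4]<=[2,2]T(1,0)}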