@@ -613,7 +613,7 @@ def test_inplace_add_with_sharding(self):
     self.assertEqual(sharding_spec, torch_xla._XLAC._get_xla_sharding_spec(xt))
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([xt])
     self.assertIn(
-        '%custom-call.7 = f32[2,2]{1,0} custom-call(f32[2,2]{1,0} %add.6), custom_call_target="Sharding", sharding=',
+        '%custom-call.1 = f32[2,2]{1,0} custom-call(f32[2,2]{1,0} %add.1), custom_call_target="Sharding", sharding=',
         hlo)
 
   # avoid calling xr.addressable_device_count here otherwise it will init the test
@@ -713,7 +713,8 @@ def test_xla_sharded_hlo_dump(self):
                             partition_spec)
     xst2 = xst1 + 5
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([xst2.global_tensor])
-    self.assertIn('%p1.3 = f32[1,8]{1,0} parameter(1), sharding', hlo)
+    print(hlo)
+    self.assertIn('%p1.1 = f32[1,8]{1,0} parameter(1), sharding', hlo)
     if torch_xla._XLAC._xla_get_auto_sharding():
       # scalar 5 should be implicitly replicated, so the pre-optimization HLO
       # shouldn't mark it with sharding.
@@ -828,13 +829,13 @@ def test_mark_sharding_ir(self):
                        (0, 1))
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([actual.global_tensor])
     self.assertIn(
-        '%custom-call.7 = f32[1,128]{1,0} custom-call(f32[1,128]{1,0} %add.6), custom_call_target="Sharding", sharding=',
+        '%custom-call.1 = f32[1,128]{1,0} custom-call(f32[1,128]{1,0} %add.1), custom_call_target="Sharding", sharding=',
         hlo)
 
     actual += 0
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([actual.global_tensor])
     self.assertIn(
-        '%add.12 = f32[1,128]{1,0} add(f32[1,128]{1,0} %custom-call.9, f32[1,128]{1,0} %broadcast.11)',
+        '%add.3 = f32[1,128]{1,0} add(f32[1,128]{1,0} %custom-call.1, f32[1,128]{1,0} %broadcast.3)',
         hlo)
 
     self.assertTrue(torch.allclose(expected, actual.cpu()))
@@ -1141,7 +1142,7 @@ def test_backward_optimization_barrier(self):
 
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([model.fc2.weight.grad])
     self.assertIn(
-        '%opt-barrier.37 = (f32[1,64]{0,1}, f32[1]{0}, f32[2,64]{1,0}) opt-barrier((f32[1,64]{0,1}, f32[1]{0}, f32[2,64]{1,0}) %tuple.36)',
+        '%opt-barrier.1 = (f32[1,64]{0,1}, f32[1]{0}, f32[2,64]{1,0}) opt-barrier((f32[1,64]{0,1}, f32[1]{0}, f32[2,64]{1,0}) %tuple.2)',
         hlo)
 
   def test_mark_shard_scalar(self):
@@ -1198,7 +1199,7 @@ def test_spmd_full_to_shard_shape(self):
 
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([xx])
     self.assertEqual(xx.shape, (8, 8 // self.n_devices))
-    self.assertIn(f'%custom-call.2 = f32[8,{8 // self.n_devices}]{{1,0}}', hlo)
+    self.assertIn(f'%custom-call.1 = f32[8,{8 // self.n_devices}]{{1,0}}', hlo)
     self.assertIn(
         f'custom_call_target="SPMDFullToShardShape", sharding={{manual}}', hlo)
     self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(xx), "{manual}")
@@ -1215,7 +1216,7 @@ def test_spmd_full_to_shard_shape(self):
 
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([xx])
     self.assertEqual(xx.shape, (8, 4))
-    self.assertIn(f'%custom-call.2 = f32[8,4]{{1,0}}', hlo)
+    self.assertIn(f'%custom-call.1 = f32[8,4]{{1,0}}', hlo)
     self.assertIn(
         f'custom_call_target="SPMDFullToShardShape", sharding={{manual}}', hlo)
     self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(xx), "{manual}")
@@ -1246,7 +1247,7 @@ def test_spmd_shard_to_full_shape(self):
 
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([xx])
     self.assertEqual(xx.shape, x.shape)
-    self.assertIn('%custom-call.9 = f32[8,8]{1,0}', hlo)
+    self.assertIn('%custom-call.5 = f32[8,8]{1,0}', hlo)
     self.assertIn(
         'custom_call_target="SPMDShardToFullShape", sharding={replicated}', hlo)
     self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(xx), "{replicated}")
@@ -1297,7 +1298,7 @@ def test_spmd_reduce_scatter(self):
 
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([x])
     self.assertIn(
-        f"reduce-scatter(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, dimensions={{0}}, to_apply=%AddComputation.3",
+        f"reduce-scatter(f32[8,8]{{1,0}} %custom-call.3), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, dimensions={{0}}, to_apply=%AddComputation.1",
         hlo)
 
     expected_x = torch.ones(8 // self.n_devices, 8) * self.n_devices
@@ -1318,7 +1319,7 @@ def test_spmd_reduce_scatter_canonical_index(self):
 
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([x])
     self.assertIn(
-        f"reduce-scatter(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, dimensions={{1}}, to_apply=%AddComputation.3",
+        f"reduce-scatter(f32[8,8]{{1,0}} %custom-call.3), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, dimensions={{1}}, to_apply=%AddComputation.1",
         hlo)
 
     expected_x = torch.ones(8, 8 // self.n_devices) * self.n_devices
@@ -1338,7 +1339,7 @@ def test_spmd_all_reduce(self):
 
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([x])
     self.assertIn(
-        f"all-reduce(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, to_apply=%AddComputation.3",
+        f"all-reduce(f32[8,8]{{1,0}} %custom-call.3), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, to_apply=%AddComputation.1",
         hlo)
 
     expected_x = torch.ones(8, 8) * self.n_devices
@@ -1359,7 +1360,7 @@ def test_spmd_all_reduce_scale(self):
 
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([x])
     self.assertIn(
-        f"all-reduce(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, to_apply=%AddComputation.3",
+        f"all-reduce(f32[8,8]{{1,0}} %custom-call.3), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, to_apply=%AddComputation.1",
         hlo)
 
     expected_x = torch.ones(8, 8) * int(self.n_devices * scale)
@@ -1713,7 +1714,7 @@ def test_annotate_custom_sharding(self):
         f'%p0.1 = f32[2,4,64,64]{{3,2,1,0}} parameter(0), sharding={original_sharding_spec}',
         hlo)
     self.assertIn(
-        f'%custom-call.2 = f32[2,4,64,64]{{3,2,1,0}} custom-call(f32[2,4,64,64]{{3,2,1,0}} %p0.1), custom_call_target="Sharding", sharding={custom_sharding_spec}',
+        f'%custom-call.1 = f32[2,4,64,64]{{3,2,1,0}} custom-call(f32[2,4,64,64]{{3,2,1,0}} %p0.1), custom_call_target="Sharding", sharding={custom_sharding_spec}',
         hlo)
     xm.mark_step()
     # Ensure that the resulting sharding spec is preserved
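
Every hunk above makes the same kind of fix: the hard-coded SSA instruction
ids in the expected HLO snippets (%custom-call.7, %add.6, %opt-barrier.37,
and so on) shifted, apparently after an upstream HLO renumbering, so each
assertion has to be updated by hand. If that churn becomes a maintenance
burden, one option is to mask the ids out before matching. Below is a
minimal sketch of that idea using only the Python standard library; the
assertInHloIgnoringIds helper is a hypothetical addition to these tests,
not an existing torch_xla API.

import re
import unittest


class HloAssertions(unittest.TestCase):
  """Hypothetical mixin: match HLO text while letting SSA ids float."""

  def assertInHloIgnoringIds(self, expected, hlo):
    # re.escape turns the expected snippet into a literal pattern
    # ('%add.6' becomes '%add\.6'); the substitution then relaxes every
    # escaped id suffix '\.<digits>' into '\.\d+' so any numbering passes.
    pattern = re.sub(r'\\\.\d+', r'\\.\\d+', re.escape(expected))
    self.assertRegex(hlo, pattern)

With such a helper, the assertion from the first hunk could be written once,
e.g. self.assertInHloIgnoringIds('%custom-call.1 = f32[2,2]{1,0} ...', hlo),
and would survive the next id renumbering without further edits.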