Skip to content

Commit ba3995c

Browse files
committed
Add additional test cases for ConvertV2ShardingToV1Test
1 parent 6f3cf07 commit ba3995c

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

test/spmd/test_spmd_debugging.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -846,6 +846,11 @@ def test_tiled_sharding(self):
846846
[str(i) for i in range(self.n_devices)]))
847847
self.run_test()
848848

849+
self.partition_spec = (1, 0)
850+
self.expected_str = '{devices=[%d,1]%s}' % (self.n_devices, ','.join(
851+
[str(i) for i in range(self.n_devices)]))
852+
self.run_test()
853+
849854
@unittest.skipIf(xr.global_runtime_device_count() < 2,
850855
f"Requires at least 2 devices.")
851856
def test_tupled_tiled_sharding(self):
@@ -873,6 +878,11 @@ def test_partial_replication_sharding(self):
873878
self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices)))
874879
self.run_test()
875880

881+
self.partition_spec = (None, 0)
882+
self.expected_str = '{devices=[1,2,%d]%s last_tile_dim_replicate}' % (
883+
self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices)))
884+
self.run_test()
885+
876886
@unittest.skipIf(xr.global_runtime_device_count() < 4,
877887
f"Requires at least 4 devices.")
878888
def test_tupled_partial_replication_sharding(self):

0 commit comments

Comments
 (0)