@@ -1,5 +1,6 @@
 # Owner(s): ["module: c10d"]
 
+import itertools
 import os
 from unittest import skipIf
 
@@ -860,55 +861,69 @@ def test_multimem_one_shot_all_reduce( |
 
     @skipIfRocm
     @skip_if_lt_x_gpu(4)
-    @parametrize("dtype", [torch.float, torch.bfloat16])
-    @parametrize("align_bytes", [4, 8, 16])
-    @parametrize("size_bytes", [4, 8192, 8196])
-    def test_one_shot_all_reduce(
-        self, dtype: torch.dtype, size_bytes: int, align_bytes: int
-    ) -> None:
+    def test_one_shot_all_reduce(self) -> None:
         self._init_process()
         group_name = dist.group.WORLD.group_name
 
-        inp = symm_mem.empty(
-            size_bytes // dtype.itemsize, dtype=dtype, device=self.device
-        ).normal_()
-        symm_mem.rendezvous(inp, group=group_name)
-
-        res = torch.ops.symm_mem.one_shot_all_reduce(inp, "sum", group_name)
-        self._verify_all_reduce_result(inp, res)
+        for dtype, size_bytes, align_bytes, copy, offset in itertools.product(
+            [torch.float, torch.bfloat16],
+            [4, 8192, 8196],
+            [4, 8, 16],  # align_bytes (not referenced in this loop body)
+            [True, False],
+            [0, 16],
+        ):
+            inp = symm_mem.empty(
+                size_bytes // dtype.itemsize + offset, dtype=dtype, device=self.device
+            )
+            symm_mem.rendezvous(inp, group=group_name)
+            if copy:
+                local_inp = torch.randn_like(inp[offset:])
+                res = torch.ops.symm_mem.one_shot_all_reduce_copy(
+                    inp[offset:], local_inp, "sum", group_name
+                )
+            else:
+                inp.normal_()
+                res = torch.ops.symm_mem.one_shot_all_reduce(
+                    inp[offset:], "sum", group_name
+                )
+            self._verify_all_reduce_result(local_inp if copy else inp[offset:], res)
 
         dist.destroy_process_group()
 
     @skipIfRocm
     @skip_if_lt_x_gpu(4)
-    @parametrize("dtype", [torch.float, torch.bfloat16])
-    @parametrize("align_bytes", [4, 8, 16])
-    @parametrize("size_bytes", [4, 8192, 8196])
-    def test_two_shot_all_reduce(
-        self, dtype: torch.dtype, size_bytes: int, align_bytes: int
-    ) -> None:
+    def test_two_shot_all_reduce(self) -> None:
         self._init_process()
         group_name = dist.group.WORLD.group_name
 
-        t = symm_mem.empty(16384, dtype=dtype, device=self.device).fill_(0)
-        symm_mem.rendezvous(t, group=group_name)
-
-        self.assertTrue(t.data_ptr() % 16 == 0)
-        self.assertTrue(align_bytes % t.element_size() == 0)
-        self.assertTrue(size_bytes % t.element_size() == 0)
-
-        shift = align_bytes // t.element_size()
-        numel = size_bytes // t.element_size()
-        res = t[shift : shift + numel]
-        res.normal_()
-        inp = res.clone()
-
-        torch.ops.symm_mem.two_shot_all_reduce_(res, "sum", group_name)
+        for dtype, size_bytes, align_bytes, inplace in itertools.product(
+            [torch.float, torch.bfloat16],
+            [4, 8192, 8196],
+            [4, 8, 16],
+            [True, False],
+        ):
+            t = symm_mem.empty(16384, dtype=dtype, device=self.device).fill_(0)
+            symm_mem.rendezvous(t, group=group_name)
+
+            self.assertTrue(t.data_ptr() % 16 == 0)
+            self.assertTrue(align_bytes % t.element_size() == 0)
+            self.assertTrue(size_bytes % t.element_size() == 0)
+
+            shift = align_bytes // t.element_size()
+            numel = size_bytes // t.element_size()
+            res = t[shift : shift + numel]
+            res.normal_()
+            inp = res.clone()
+            if inplace:
+                torch.ops.symm_mem.two_shot_all_reduce_(res, "sum", group_name)
+            else:
+                out = torch.empty_like(inp)
+                torch.ops.symm_mem.two_shot_all_reduce_out(res, "sum", group_name, out)
 
-        # Head and tail should not be written
-        self.assertTrue(t[:shift].eq(0).all().item())
-        self.assertTrue(t[shift + numel :].eq(0).all().item())
-        self._verify_all_reduce_result(inp, res)
+            # Head and tail should not be written
+            self.assertTrue(t[:shift].eq(0).all().item())
+            self.assertTrue(t[shift + numel :].eq(0).all().item())
+            self._verify_all_reduce_result(inp, res if inplace else out)
 
         dist.destroy_process_group()
 
|
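A note on the two one-shot variants exercised above: `one_shot_all_reduce` reduces a tensor that already lives in symmetric memory, while the new `one_shot_all_reduce_copy` reduces an ordinary local tensor and uses the symmetric-memory tensor only as a staging buffer, which is why the test calls `normal_()` on `inp` only in the non-copy path. A minimal standalone sketch, assuming a multi-GPU `torchrun` launch, an NCCL backend, and the test file's `import torch.distributed._symmetric_memory as symm_mem` alias; the setup below is illustrative, not the test's actual `_init_process()` helper:

```python
import os

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

rank = int(os.environ["RANK"])  # set by torchrun; assumes a single-node launch
torch.cuda.set_device(rank)
dist.init_process_group("nccl")
group_name = dist.group.WORLD.group_name

# Allocate a symmetric-memory buffer and rendezvous with peer ranks.
inp = symm_mem.empty(8192, dtype=torch.bfloat16, device="cuda")
symm_mem.rendezvous(inp, group=group_name)

# Variant 1: the input itself lives in symmetric memory.
inp.normal_()
res = torch.ops.symm_mem.one_shot_all_reduce(inp, "sum", group_name)

# Variant 2: reduce an ordinary tensor; `inp` is only a staging buffer.
local = torch.randn(8192, dtype=torch.bfloat16, device="cuda")
res = torch.ops.symm_mem.one_shot_all_reduce_copy(inp, local, "sum", group_name)

dist.destroy_process_group()
```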
|
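Likewise for the two-shot pair: `two_shot_all_reduce_` reduces in place, overwriting the symmetric-memory slice, while the new `two_shot_all_reduce_out` writes the full result into a caller-provided output tensor. A sketch under the same assumptions as the previous example; note that in the `_out` path the symmetric buffer may still be mutated as communication scratch, and the test verifies only `out`:

```python
# Continues the setup from the previous sketch (process group, group_name).
t = symm_mem.empty(8192, dtype=torch.float, device="cuda")
symm_mem.rendezvous(t, group=group_name)
t.normal_()

# In place: t is overwritten with the reduced values.
torch.ops.symm_mem.two_shot_all_reduce_(t, "sum", group_name)

# Out of place: the full result lands in `out`.
out = torch.empty_like(t)
torch.ops.symm_mem.two_shot_all_reduce_out(t, "sum", group_name, out)
```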
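`_verify_all_reduce_result` is defined elsewhere in this file; conceptually it checks the symmetric-memory result against a reference computed with a regular collective. A hypothetical reconstruction for readers of this hunk (names and tolerances are assumptions, not the helper's actual body):

```python
def _verify_all_reduce_result(self, inp: torch.Tensor, res: torch.Tensor) -> None:
    # Hypothetical sketch: gather every rank's input and compare `res`
    # against the explicit sum of the gathered inputs. A loose tolerance
    # accommodates bf16 and reduction-order differences.
    world_size = dist.get_world_size()
    gathered = [torch.empty_like(inp) for _ in range(world_size)]
    dist.all_gather(gathered, inp.contiguous())
    expected = torch.stack(gathered).sum(dim=0)
    torch.testing.assert_close(res, expected, rtol=1e-2, atol=1e-2)
```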