     regroup_kts,
 )
 from torchrec.sparse.tests.utils import build_groups, build_kts
+from torchrec.test_utils import skip_if_asan_class
 
 torch.fx.wrap("len")
 
@@ -947,3 +948,111 @@ def test_keyed_tensor_regroup_backward(
         val_grad, ref_grad = val.grad, ref.grad
         assert isinstance(val_grad, torch.Tensor)
         self.assertTrue(torch.allclose(val_grad, ref_grad))
+
+
+@skip_if_asan_class
+class TestKeyedTensorGPU(unittest.TestCase):
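+    # Both tests compare gradients computed through KeyedTensor.regroup
+    # against gradients from the eager _regroup_keyed_tensors path on CUDA.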
+    def setUp(self) -> None:
+        super().setUp()
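+        # current_device() returns an int index; torch factory functions
+        # accept an int device as shorthand for that CUDA device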
+        self.device = torch.cuda.current_device()
+
+    # pyre-ignore
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 0,
+        "Not enough GPUs, this test requires at least one GPU",
+    )
+    def test_regroup_backward_skips_and_duplicates(self) -> None:
+        kts = build_kts(
+            dense_features=20,
+            sparse_features=20,
+            dim_dense=64,
+            dim_sparse=128,
+            batch_size=128,
+            device=self.device,
+            run_backward=True,
+        )
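+        # skips/duplicates presumably make some keys absent from the groups
+        # and others requested more than once (inferred from the flag names)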
+        groups = build_groups(kts=kts, num_groups=2, skips=True, duplicates=True)
+        # randint's upper bound is exclusive, so use 2 for binary {0, 1} labels
+        labels = torch.randint(0, 2, (128,), device=self.device).float()
+
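+        # gradients through the KeyedTensor.regroup path under test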
+        tensor_groups = KeyedTensor.regroup(kts, groups)
+        pred0 = tensor_groups[0].sum(dim=1).mul(tensor_groups[1].sum(dim=1))
+        loss = torch.nn.functional.l1_loss(pred0, labels).sum()
+        actual_kt_0_grad = torch.autograd.grad(
+            loss, kts[0].values(), retain_graph=True
+        )[0]
+        actual_kt_1_grad = torch.autograd.grad(
+            loss, kts[1].values(), retain_graph=True
+        )[0]
+
+        # reset .grad on the values so the reference pass starts clean
+        kts[0].values().grad = None
+        kts[1].values().grad = None
+
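+        # recompute via _regroup_keyed_tensors, used as the expected baseline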
+        tensor_groups = _regroup_keyed_tensors(kts, groups)
+        pred1 = tensor_groups[0].sum(dim=1).mul(tensor_groups[1].sum(dim=1))
+        loss = torch.nn.functional.l1_loss(pred1, labels).sum()
+        expected_kt_0_grad = torch.autograd.grad(
+            loss, kts[0].values(), retain_graph=True
+        )[0]
+        expected_kt_1_grad = torch.autograd.grad(
+            loss, kts[1].values(), retain_graph=True
+        )[0]
+
+        self.assertTrue(torch.allclose(actual_kt_0_grad, expected_kt_0_grad))
+        self.assertTrue(torch.allclose(actual_kt_1_grad, expected_kt_1_grad))
+
+    # pyre-ignore
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 0,
+        "Not enough GPUs, this test requires at least one GPU",
+    )
+    def test_regroup_backward(self) -> None:
+        kts = build_kts(
+            dense_features=20,
+            sparse_features=20,
+            dim_dense=64,
+            dim_sparse=128,
+            batch_size=128,
+            device=self.device,
+            run_backward=True,
+        )
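+        # the straightforward case: every key lands in exactly one group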
+        groups = build_groups(kts=kts, num_groups=2, skips=False, duplicates=False)
+        # randint's upper bound is exclusive, so use 2 for binary {0, 1} labels
+        labels = torch.randint(0, 2, (128,), device=self.device).float()
+
+        tensor_groups = KeyedTensor.regroup(kts, groups)
+        pred0 = tensor_groups[0].sum(dim=1).mul(tensor_groups[1].sum(dim=1))
+        loss = torch.nn.functional.l1_loss(pred0, labels).sum()
+        actual_kt_0_grad = torch.autograd.grad(
+            loss, kts[0].values(), retain_graph=True
+        )[0]
+        actual_kt_1_grad = torch.autograd.grad(
+            loss, kts[1].values(), retain_graph=True
+        )[0]
+
+        # reset .grad on the values so the reference pass starts clean
+        kts[0].values().grad = None
+        kts[1].values().grad = None
+
+        tensor_groups = _regroup_keyed_tensors(kts, groups)
+        pred1 = tensor_groups[0].sum(dim=1).mul(tensor_groups[1].sum(dim=1))
+        loss = torch.nn.functional.l1_loss(pred1, labels).sum()
+        expected_kt_0_grad = torch.autograd.grad(
+            loss, kts[0].values(), retain_graph=True
+        )[0]
+        expected_kt_1_grad = torch.autograd.grad(
+            loss, kts[1].values(), retain_graph=True
+        )[0]
+
+        self.assertTrue(torch.allclose(actual_kt_0_grad, expected_kt_0_grad))
+        self.assertTrue(torch.allclose(actual_kt_1_grad, expected_kt_1_grad))