
Commit 63a0af6

jd7-tr authored and facebook-github-bot committed
Move KeyedTensor's GPU test to its own test file (#2871)
Summary:
Pull Request resolved: #2871

Move the `TestKeyedTensorGPU` tests to `test_keyed_tensor.py`

Reviewed By: TroyGarden

Differential Revision: D72404538

fbshipit-source-id: c714ea5617cf445412e037c711567ecdcbc8ea66
1 parent 9e57d85 commit 63a0af6
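
The moved tests follow a simple equivalence pattern: regroup the same KeyedTensors with the optimized KeyedTensor.regroup and with the reference _regroup_keyed_tensors, then check that the gradients produced by both paths match. A minimal, self-contained sketch of that pattern (plain torch with toy stand-in ops, not the actual TorchRec API):

    import torch

    x = torch.randn(8, 4, requires_grad=True)

    # "fast" path: stand-in for KeyedTensor.regroup(...)
    fast = (x * 3.0).sum()
    (fast_grad,) = torch.autograd.grad(fast, x, retain_graph=True)

    # reference path: stand-in for _regroup_keyed_tensors(...)
    ref = (x + x + x).sum()
    (ref_grad,) = torch.autograd.grad(ref, x)

    # both implementations must backpropagate the same gradient
    assert torch.allclose(fast_grad, ref_grad)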

File tree

1 file changed: +98 -0 lines changed

torchrec/sparse/tests/test_keyed_tensor.py

Lines changed: 98 additions & 0 deletions
@@ -24,6 +24,7 @@
     regroup_kts,
 )
 from torchrec.sparse.tests.utils import build_groups, build_kts
+from torchrec.test_utils import skip_if_asan_class

 torch.fx.wrap("len")

@@ -947,3 +948,100 @@ def test_keyed_tensor_regroup_backward(
         val_grad, ref_grad = val.grad, ref.grad
         assert isinstance(val_grad, torch.Tensor)
         self.assertTrue(torch.allclose(val_grad, ref_grad))
+
+
+@skip_if_asan_class
+class TestKeyedTensorGPU(unittest.TestCase):
+    def setUp(self) -> None:
+        super().setUp()
+        self.device = torch.cuda.current_device()
+
+    # pyre-ignore
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 0,
+        "Not enough GPUs, this test requires at least one GPU",
+    )
+    def test_regroup_backward_skips_and_duplicates(self) -> None:
+        kts = build_kts(
+            dense_features=20,
+            sparse_features=20,
+            dim_dense=64,
+            dim_sparse=128,
+            batch_size=128,
+            device=self.device,
+            run_backward=True,
+        )
+        groups = build_groups(kts=kts, num_groups=2, skips=True, duplicates=True)
+        labels = torch.randint(0, 1, (128,), device=self.device).float()
+
+        tensor_groups = KeyedTensor.regroup(kts, groups)
+        pred0 = tensor_groups[0].sum(dim=1).mul(tensor_groups[1].sum(dim=1))
+        loss = torch.nn.functional.l1_loss(pred0, labels).sum()
+        actual_kt_0_grad = torch.autograd.grad(
+            loss, kts[0].values(), retain_graph=True
+        )[0]
+        actual_kt_1_grad = torch.autograd.grad(
+            loss, kts[1].values(), retain_graph=True
+        )[0]
+
+        # clear grads before computing the reference gradients
+        kts[0].values().grad = None
+        kts[1].values().grad = None
+
+        tensor_groups = _regroup_keyed_tensors(kts, groups)
+        pred1 = tensor_groups[0].sum(dim=1).mul(tensor_groups[1].sum(dim=1))
+        loss = torch.nn.functional.l1_loss(pred1, labels).sum()
+        expected_kt_0_grad = torch.autograd.grad(
+            loss, kts[0].values(), retain_graph=True
+        )[0]
+        expected_kt_1_grad = torch.autograd.grad(
+            loss, kts[1].values(), retain_graph=True
+        )[0]
+
+        self.assertTrue(torch.allclose(actual_kt_0_grad, expected_kt_0_grad))
+        self.assertTrue(torch.allclose(actual_kt_1_grad, expected_kt_1_grad))
+
+    # pyre-ignore
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 0,
+        "Not enough GPUs, this test requires at least one GPU",
+    )
+    def test_regroup_backward(self) -> None:
+        kts = build_kts(
+            dense_features=20,
+            sparse_features=20,
+            dim_dense=64,
+            dim_sparse=128,
+            batch_size=128,
+            device=self.device,
+            run_backward=True,
+        )
+        groups = build_groups(kts=kts, num_groups=2, skips=False, duplicates=False)
+        labels = torch.randint(0, 1, (128,), device=self.device).float()
+
+        tensor_groups = KeyedTensor.regroup(kts, groups)
+        pred0 = tensor_groups[0].sum(dim=1).mul(tensor_groups[1].sum(dim=1))
+        loss = torch.nn.functional.l1_loss(pred0, labels).sum()
+        actual_kt_0_grad = torch.autograd.grad(
+            loss, kts[0].values(), retain_graph=True
+        )[0]
+        actual_kt_1_grad = torch.autograd.grad(
+            loss, kts[1].values(), retain_graph=True
+        )[0]
+
+        # clear grads before computing the reference gradients
+        kts[0].values().grad = None
+        kts[1].values().grad = None
+
+        tensor_groups = _regroup_keyed_tensors(kts, groups)
+        pred1 = tensor_groups[0].sum(dim=1).mul(tensor_groups[1].sum(dim=1))
+        loss = torch.nn.functional.l1_loss(pred1, labels).sum()
+        expected_kt_0_grad = torch.autograd.grad(
+            loss, kts[0].values(), retain_graph=True
+        )[0]
+        expected_kt_1_grad = torch.autograd.grad(
+            loss, kts[1].values(), retain_graph=True
+        )[0]
+
+        self.assertTrue(torch.allclose(actual_kt_0_grad, expected_kt_0_grad))
+        self.assertTrue(torch.allclose(actual_kt_1_grad, expected_kt_1_grad))
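
To run just the moved GPU tests on a CUDA machine, one option (assuming pytest is available; the class is a plain unittest.TestCase, so python -m unittest works as well):

    python -m pytest torchrec/sparse/tests/test_keyed_tensor.py -k TestKeyedTensorGPU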
