Skip to content

Commit 67ffc20

Browse files
jd7-trfacebook-github-bot
authored andcommitted
Move KJT GPU test to its own test file (#2868)
Summary: Pull Request resolved: #2868 Move KJT tests to its own test file. Reviewed By: TroyGarden Differential Revision: D72346833 fbshipit-source-id: ac01f6a57bcbfe79695865c535601030cbe0f11c
1 parent 2268800 commit 67ffc20

File tree

1 file changed

+193
-0
lines changed

1 file changed

+193
-0
lines changed

torchrec/sparse/tests/test_keyed_jagged_tensor.py

Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
KeyedJaggedTensor,
2222
kjt_is_equal,
2323
)
24+
from torchrec.test_utils import skip_if_asan_class
2425

2526
torch.fx.wrap("len")
2627

@@ -1329,3 +1330,195 @@ def test_key_lookup(self) -> None:
13291330
)
13301331
self.assertTrue(torch.equal(i1._lengths, torch.tensor([0, 1, 1, 1, 0, 3])))
13311332
self.assertTrue(torch.equal(i1._offsets, torch.tensor([0, 0, 1, 2, 3, 3, 6])))
1333+
1334+
1335+
@skip_if_asan_class
1336+
class TestKeyedJaggedTensorGPU(unittest.TestCase):
1337+
def setUp(self) -> None:
1338+
super().setUp()
1339+
self.device = torch.cuda.current_device()
1340+
1341+
# pyre-ignore
1342+
@unittest.skipIf(
1343+
torch.cuda.device_count() <= 0,
1344+
"Not enough GPUs, this test requires at least one GPUs",
1345+
)
1346+
def test_permute(self) -> None:
1347+
values = torch.tensor(
1348+
[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], device=self.device
1349+
)
1350+
lengths = torch.tensor([0, 2, 0, 1, 1, 1, 0, 3, 0], device=self.device)
1351+
keys = ["index_0", "index_1", "index_2"]
1352+
1353+
jag_tensor = KeyedJaggedTensor.from_lengths_sync(
1354+
values=values,
1355+
keys=keys,
1356+
lengths=lengths,
1357+
)
1358+
indices = [1, 0, 2]
1359+
permuted_jag_tensor = jag_tensor.permute(indices)
1360+
1361+
self.assertEqual(permuted_jag_tensor.keys(), ["index_1", "index_0", "index_2"])
1362+
self.assertEqual(
1363+
permuted_jag_tensor.offset_per_key(),
1364+
[0, 3, 5, 8],
1365+
)
1366+
self.assertEqual(
1367+
permuted_jag_tensor.values().tolist(),
1368+
[3.0, 4.0, 5.0, 1.0, 2.0, 6.0, 7.0, 8.0],
1369+
)
1370+
self.assertEqual(
1371+
permuted_jag_tensor.lengths().tolist(), [1, 1, 1, 0, 2, 0, 0, 3, 0]
1372+
)
1373+
self.assertEqual(permuted_jag_tensor.weights_or_none(), None)
1374+
1375+
# pyre-ignore
1376+
@unittest.skipIf(
1377+
torch.cuda.device_count() <= 0,
1378+
"Not enough GPUs, this test requires at least one GPUs",
1379+
)
1380+
def test_permute_vb(self) -> None:
1381+
values = torch.tensor(
1382+
[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], device=self.device
1383+
)
1384+
lengths = torch.tensor([1, 0, 1, 3, 0, 1, 0, 2, 0], device=self.device)
1385+
keys = ["index_0", "index_1", "index_2"]
1386+
stride_per_key_per_rank = [[2], [4], [3]]
1387+
1388+
jag_tensor = KeyedJaggedTensor.from_lengths_sync(
1389+
values=values,
1390+
keys=keys,
1391+
lengths=lengths,
1392+
stride_per_key_per_rank=stride_per_key_per_rank,
1393+
)
1394+
1395+
indices = [1, 0, 2]
1396+
permuted_jag_tensor = jag_tensor.permute(indices)
1397+
1398+
self.assertEqual(permuted_jag_tensor.keys(), ["index_1", "index_0", "index_2"])
1399+
self.assertEqual(
1400+
permuted_jag_tensor.offset_per_key(),
1401+
[0, 5, 6, 8],
1402+
)
1403+
self.assertEqual(
1404+
permuted_jag_tensor.values().tolist(),
1405+
[2.0, 3.0, 4.0, 5.0, 6.0, 1.0, 7.0, 8.0],
1406+
)
1407+
self.assertEqual(
1408+
permuted_jag_tensor.lengths().tolist(), [1, 3, 0, 1, 1, 0, 0, 2, 0]
1409+
)
1410+
self.assertEqual(permuted_jag_tensor.weights_or_none(), None)
1411+
1412+
# pyre-ignore
1413+
@unittest.skipIf(
1414+
torch.cuda.device_count() <= 0,
1415+
"Not enough GPUs, this test requires at least one GPUs",
1416+
)
1417+
def test_permute_vb_duplicate(self) -> None:
1418+
values = torch.tensor(
1419+
[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], device=self.device
1420+
)
1421+
lengths = torch.tensor([1, 0, 1, 3, 0, 1, 0, 2, 0], device=self.device)
1422+
keys = ["index_0", "index_1", "index_2"]
1423+
stride_per_key_per_rank = [[2], [4], [3]]
1424+
1425+
jag_tensor = KeyedJaggedTensor.from_lengths_sync(
1426+
values=values,
1427+
keys=keys,
1428+
lengths=lengths,
1429+
stride_per_key_per_rank=stride_per_key_per_rank,
1430+
)
1431+
1432+
indices = [1, 1, 0, 0, 2, 2]
1433+
permuted_jag_tensor = jag_tensor.permute(indices)
1434+
1435+
self.assertEqual(
1436+
permuted_jag_tensor.keys(),
1437+
["index_1", "index_1", "index_0", "index_0", "index_2", "index_2"],
1438+
)
1439+
self.assertTrue(
1440+
torch.equal(
1441+
permuted_jag_tensor.values().cpu(),
1442+
torch.Tensor(
1443+
[
1444+
2.0,
1445+
3.0,
1446+
4.0,
1447+
5.0,
1448+
6.0,
1449+
2.0,
1450+
3.0,
1451+
4.0,
1452+
5.0,
1453+
6.0,
1454+
1.0,
1455+
1.0,
1456+
7.0,
1457+
8.0,
1458+
7.0,
1459+
8.0,
1460+
]
1461+
),
1462+
)
1463+
)
1464+
self.assertTrue(
1465+
torch.equal(
1466+
permuted_jag_tensor.lengths().cpu(),
1467+
torch.IntTensor([1, 3, 0, 1, 1, 3, 0, 1, 1, 0, 1, 0, 0, 2, 0, 0, 2, 0]),
1468+
)
1469+
)
1470+
self.assertEqual(permuted_jag_tensor.weights_or_none(), None)
1471+
1472+
# pyre-ignore
1473+
@unittest.skipIf(
1474+
torch.cuda.device_count() <= 0,
1475+
"Not enough GPUs, this test requires at least one GPUs",
1476+
)
1477+
def test_permute_duplicates(self) -> None:
1478+
values = torch.tensor(
1479+
[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], device=self.device
1480+
)
1481+
lengths = torch.tensor([0, 2, 0, 1, 1, 1, 0, 3, 0], device=self.device)
1482+
keys = ["index_0", "index_1", "index_2"]
1483+
1484+
jag_tensor = KeyedJaggedTensor.from_lengths_sync(
1485+
values=values,
1486+
keys=keys,
1487+
lengths=lengths,
1488+
)
1489+
1490+
indices = [1, 0, 2, 1, 1]
1491+
permuted_jag_tensor = jag_tensor.permute(indices)
1492+
1493+
self.assertEqual(
1494+
permuted_jag_tensor.keys(),
1495+
["index_1", "index_0", "index_2", "index_1", "index_1"],
1496+
)
1497+
self.assertEqual(
1498+
permuted_jag_tensor.offset_per_key(),
1499+
[0, 3, 5, 8, 11, 14],
1500+
)
1501+
self.assertEqual(
1502+
permuted_jag_tensor.values().tolist(),
1503+
[
1504+
3.0,
1505+
4.0,
1506+
5.0,
1507+
1.0,
1508+
2.0,
1509+
6.0,
1510+
7.0,
1511+
8.0,
1512+
3.0,
1513+
4.0,
1514+
5.0,
1515+
3.0,
1516+
4.0,
1517+
5.0,
1518+
],
1519+
)
1520+
self.assertEqual(
1521+
permuted_jag_tensor.lengths().tolist(),
1522+
[1, 1, 1, 0, 2, 0, 0, 3, 0, 1, 1, 1, 1, 1, 1],
1523+
)
1524+
self.assertEqual(permuted_jag_tensor.weights_or_none(), None)

0 commit comments

Comments
 (0)