|
21 | 21 | KeyedJaggedTensor,
|
22 | 22 | kjt_is_equal,
|
23 | 23 | )
|
| 24 | +from torchrec.test_utils import skip_if_asan_class |
24 | 25 |
|
25 | 26 | torch.fx.wrap("len")
|
26 | 27 |
|
@@ -1329,3 +1330,195 @@ def test_key_lookup(self) -> None:
|
1329 | 1330 | )
|
1330 | 1331 | self.assertTrue(torch.equal(i1._lengths, torch.tensor([0, 1, 1, 1, 0, 3])))
|
1331 | 1332 | self.assertTrue(torch.equal(i1._offsets, torch.tensor([0, 0, 1, 2, 3, 3, 6])))
|
| 1333 | + |
| 1334 | + |
| 1335 | +@skip_if_asan_class |
| 1336 | +class TestKeyedJaggedTensorGPU(unittest.TestCase): |
| 1337 | + def setUp(self) -> None: |
| 1338 | + super().setUp() |
| 1339 | + self.device = torch.cuda.current_device() |
| 1340 | + |
| 1341 | + # pyre-ignore |
| 1342 | + @unittest.skipIf( |
| 1343 | + torch.cuda.device_count() <= 0, |
| 1344 | + "Not enough GPUs, this test requires at least one GPUs", |
| 1345 | + ) |
| 1346 | + def test_permute(self) -> None: |
| 1347 | + values = torch.tensor( |
| 1348 | + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], device=self.device |
| 1349 | + ) |
| 1350 | + lengths = torch.tensor([0, 2, 0, 1, 1, 1, 0, 3, 0], device=self.device) |
| 1351 | + keys = ["index_0", "index_1", "index_2"] |
| 1352 | + |
| 1353 | + jag_tensor = KeyedJaggedTensor.from_lengths_sync( |
| 1354 | + values=values, |
| 1355 | + keys=keys, |
| 1356 | + lengths=lengths, |
| 1357 | + ) |
| 1358 | + indices = [1, 0, 2] |
| 1359 | + permuted_jag_tensor = jag_tensor.permute(indices) |
| 1360 | + |
| 1361 | + self.assertEqual(permuted_jag_tensor.keys(), ["index_1", "index_0", "index_2"]) |
| 1362 | + self.assertEqual( |
| 1363 | + permuted_jag_tensor.offset_per_key(), |
| 1364 | + [0, 3, 5, 8], |
| 1365 | + ) |
| 1366 | + self.assertEqual( |
| 1367 | + permuted_jag_tensor.values().tolist(), |
| 1368 | + [3.0, 4.0, 5.0, 1.0, 2.0, 6.0, 7.0, 8.0], |
| 1369 | + ) |
| 1370 | + self.assertEqual( |
| 1371 | + permuted_jag_tensor.lengths().tolist(), [1, 1, 1, 0, 2, 0, 0, 3, 0] |
| 1372 | + ) |
| 1373 | + self.assertEqual(permuted_jag_tensor.weights_or_none(), None) |
| 1374 | + |
| 1375 | + # pyre-ignore |
| 1376 | + @unittest.skipIf( |
| 1377 | + torch.cuda.device_count() <= 0, |
| 1378 | + "Not enough GPUs, this test requires at least one GPUs", |
| 1379 | + ) |
| 1380 | + def test_permute_vb(self) -> None: |
| 1381 | + values = torch.tensor( |
| 1382 | + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], device=self.device |
| 1383 | + ) |
| 1384 | + lengths = torch.tensor([1, 0, 1, 3, 0, 1, 0, 2, 0], device=self.device) |
| 1385 | + keys = ["index_0", "index_1", "index_2"] |
| 1386 | + stride_per_key_per_rank = [[2], [4], [3]] |
| 1387 | + |
| 1388 | + jag_tensor = KeyedJaggedTensor.from_lengths_sync( |
| 1389 | + values=values, |
| 1390 | + keys=keys, |
| 1391 | + lengths=lengths, |
| 1392 | + stride_per_key_per_rank=stride_per_key_per_rank, |
| 1393 | + ) |
| 1394 | + |
| 1395 | + indices = [1, 0, 2] |
| 1396 | + permuted_jag_tensor = jag_tensor.permute(indices) |
| 1397 | + |
| 1398 | + self.assertEqual(permuted_jag_tensor.keys(), ["index_1", "index_0", "index_2"]) |
| 1399 | + self.assertEqual( |
| 1400 | + permuted_jag_tensor.offset_per_key(), |
| 1401 | + [0, 5, 6, 8], |
| 1402 | + ) |
| 1403 | + self.assertEqual( |
| 1404 | + permuted_jag_tensor.values().tolist(), |
| 1405 | + [2.0, 3.0, 4.0, 5.0, 6.0, 1.0, 7.0, 8.0], |
| 1406 | + ) |
| 1407 | + self.assertEqual( |
| 1408 | + permuted_jag_tensor.lengths().tolist(), [1, 3, 0, 1, 1, 0, 0, 2, 0] |
| 1409 | + ) |
| 1410 | + self.assertEqual(permuted_jag_tensor.weights_or_none(), None) |
| 1411 | + |
| 1412 | + # pyre-ignore |
| 1413 | + @unittest.skipIf( |
| 1414 | + torch.cuda.device_count() <= 0, |
| 1415 | + "Not enough GPUs, this test requires at least one GPUs", |
| 1416 | + ) |
| 1417 | + def test_permute_vb_duplicate(self) -> None: |
| 1418 | + values = torch.tensor( |
| 1419 | + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], device=self.device |
| 1420 | + ) |
| 1421 | + lengths = torch.tensor([1, 0, 1, 3, 0, 1, 0, 2, 0], device=self.device) |
| 1422 | + keys = ["index_0", "index_1", "index_2"] |
| 1423 | + stride_per_key_per_rank = [[2], [4], [3]] |
| 1424 | + |
| 1425 | + jag_tensor = KeyedJaggedTensor.from_lengths_sync( |
| 1426 | + values=values, |
| 1427 | + keys=keys, |
| 1428 | + lengths=lengths, |
| 1429 | + stride_per_key_per_rank=stride_per_key_per_rank, |
| 1430 | + ) |
| 1431 | + |
| 1432 | + indices = [1, 1, 0, 0, 2, 2] |
| 1433 | + permuted_jag_tensor = jag_tensor.permute(indices) |
| 1434 | + |
| 1435 | + self.assertEqual( |
| 1436 | + permuted_jag_tensor.keys(), |
| 1437 | + ["index_1", "index_1", "index_0", "index_0", "index_2", "index_2"], |
| 1438 | + ) |
| 1439 | + self.assertTrue( |
| 1440 | + torch.equal( |
| 1441 | + permuted_jag_tensor.values().cpu(), |
| 1442 | + torch.Tensor( |
| 1443 | + [ |
| 1444 | + 2.0, |
| 1445 | + 3.0, |
| 1446 | + 4.0, |
| 1447 | + 5.0, |
| 1448 | + 6.0, |
| 1449 | + 2.0, |
| 1450 | + 3.0, |
| 1451 | + 4.0, |
| 1452 | + 5.0, |
| 1453 | + 6.0, |
| 1454 | + 1.0, |
| 1455 | + 1.0, |
| 1456 | + 7.0, |
| 1457 | + 8.0, |
| 1458 | + 7.0, |
| 1459 | + 8.0, |
| 1460 | + ] |
| 1461 | + ), |
| 1462 | + ) |
| 1463 | + ) |
| 1464 | + self.assertTrue( |
| 1465 | + torch.equal( |
| 1466 | + permuted_jag_tensor.lengths().cpu(), |
| 1467 | + torch.IntTensor([1, 3, 0, 1, 1, 3, 0, 1, 1, 0, 1, 0, 0, 2, 0, 0, 2, 0]), |
| 1468 | + ) |
| 1469 | + ) |
| 1470 | + self.assertEqual(permuted_jag_tensor.weights_or_none(), None) |
| 1471 | + |
| 1472 | + # pyre-ignore |
| 1473 | + @unittest.skipIf( |
| 1474 | + torch.cuda.device_count() <= 0, |
| 1475 | + "Not enough GPUs, this test requires at least one GPUs", |
| 1476 | + ) |
| 1477 | + def test_permute_duplicates(self) -> None: |
| 1478 | + values = torch.tensor( |
| 1479 | + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], device=self.device |
| 1480 | + ) |
| 1481 | + lengths = torch.tensor([0, 2, 0, 1, 1, 1, 0, 3, 0], device=self.device) |
| 1482 | + keys = ["index_0", "index_1", "index_2"] |
| 1483 | + |
| 1484 | + jag_tensor = KeyedJaggedTensor.from_lengths_sync( |
| 1485 | + values=values, |
| 1486 | + keys=keys, |
| 1487 | + lengths=lengths, |
| 1488 | + ) |
| 1489 | + |
| 1490 | + indices = [1, 0, 2, 1, 1] |
| 1491 | + permuted_jag_tensor = jag_tensor.permute(indices) |
| 1492 | + |
| 1493 | + self.assertEqual( |
| 1494 | + permuted_jag_tensor.keys(), |
| 1495 | + ["index_1", "index_0", "index_2", "index_1", "index_1"], |
| 1496 | + ) |
| 1497 | + self.assertEqual( |
| 1498 | + permuted_jag_tensor.offset_per_key(), |
| 1499 | + [0, 3, 5, 8, 11, 14], |
| 1500 | + ) |
| 1501 | + self.assertEqual( |
| 1502 | + permuted_jag_tensor.values().tolist(), |
| 1503 | + [ |
| 1504 | + 3.0, |
| 1505 | + 4.0, |
| 1506 | + 5.0, |
| 1507 | + 1.0, |
| 1508 | + 2.0, |
| 1509 | + 6.0, |
| 1510 | + 7.0, |
| 1511 | + 8.0, |
| 1512 | + 3.0, |
| 1513 | + 4.0, |
| 1514 | + 5.0, |
| 1515 | + 3.0, |
| 1516 | + 4.0, |
| 1517 | + 5.0, |
| 1518 | + ], |
| 1519 | + ) |
| 1520 | + self.assertEqual( |
| 1521 | + permuted_jag_tensor.lengths().tolist(), |
| 1522 | + [1, 1, 1, 0, 2, 0, 0, 3, 0, 1, 1, 1, 1, 1, 1], |
| 1523 | + ) |
| 1524 | + self.assertEqual(permuted_jag_tensor.weights_or_none(), None) |
0 commit comments