|
14 | 14 |
|
15 | 15 | import torch
|
16 | 16 | from torch.utils.data import DataLoader, Dataset
|
| 17 | +from torchtnt.framework._test_utils import generate_random_dataloader |
17 | 18 |
|
18 | 19 | from torchtnt.utils.data.iterators import (
|
19 | 20 | AllDatasetBatches,
|
|
22 | 23 | MultiIterator,
|
23 | 24 | RandomizedBatchSampler,
|
24 | 25 | RoundRobin,
|
| 26 | + RoundRobinIterator, |
25 | 27 | StoppingMechanism,
|
26 | 28 | )
|
27 | 29 | from torchtnt.utils.data.multi_dataloader import MultiDataLoader
|
@@ -488,6 +490,92 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
|
488 | 490 | 0,
|
489 | 491 | )
|
490 | 492 |
|
def test_multi_dataloader_state_dict_with_iterator_state(self) -> None:
    """Check that the round-robin iterator state round-trips through the state dict.

    Verifies that ``iterator_state`` is absent from
    ``MultiDataLoader.state_dict()`` until an iterator has been created, that
    it reflects the current and finished dataloaders as batches are consumed,
    and that ``load_state_dict`` on a fresh ``MultiDataLoader`` affects the
    first iterator created afterwards but not the next one.
    """
    dataloader_1 = generate_random_dataloader(
        num_samples=8, input_dim=1, batch_size=8
    )
    dataloader_2 = generate_random_dataloader(
        num_samples=16, input_dim=1, batch_size=8
    )
    multi_dataloader = MultiDataLoader(
        self._get_dataloaders_dict(dataloader_1, dataloader_2),
        RoundRobin(),
    )

    multi_dl_state_dict = multi_dataloader.state_dict()
    # before creating the iterator we don't expect the iterator_state to be
    # present in the dl state dict
    self.assertNotIn("iterator_state", multi_dl_state_dict)

    multi_dl_iter = iter(multi_dataloader)
    multi_dl_state_dict = multi_dataloader.state_dict()
    # assertIn first so a missing key fails with a readable message rather
    # than a KeyError in the lookup below
    self.assertIn("iterator_state", multi_dl_state_dict)
    self.assertEqual(
        multi_dl_state_dict["iterator_state"],
        {"cur_dataloader": "1", "finished_dataloaders": []},
    )

    next(multi_dl_iter)  # should return batch from 1
    next(multi_dl_iter)  # should return batch from 2
    # should return batch from 2 after raising StopIteration from the first dl
    next(multi_dl_iter)

    multi_dl_state_dict = multi_dataloader.state_dict()
    self.assertIn("iterator_state", multi_dl_state_dict)
    self.assertEqual(
        multi_dl_state_dict["iterator_state"],
        {"cur_dataloader": "2", "finished_dataloaders": ["1"]},
    )

    # create fresh dl and load state dict. assert that the initial values
    # are updated.
    multi_dataloader_2 = MultiDataLoader(
        self._get_dataloaders_dict(dataloader_1, dataloader_2),
        RoundRobin(),
    )
    multi_dataloader_2.load_state_dict(multi_dl_state_dict)
    round_robin_iter = cast(RoundRobinIterator, iter(multi_dataloader_2))
    self.assertEqual(round_robin_iter.cur_dataloader, "2")
    self.assertEqual(round_robin_iter.finished_dataloaders, ["1"])

    # verify that after calling iter() again, values are reset
    round_robin_iter = cast(RoundRobinIterator, iter(multi_dataloader_2))
    self.assertEqual(round_robin_iter.cur_dataloader, "1")
    self.assertEqual(round_robin_iter.finished_dataloaders, [])
| 542 | + |
def test_invalid_load_state_dict(self) -> None:
    """Check that an invalid state dict leaves the round-robin iterator state unchanged."""
    first_dl = generate_random_dataloader(num_samples=8, input_dim=1, batch_size=8)
    second_dl = generate_random_dataloader(
        num_samples=16, input_dim=1, batch_size=8
    )
    loaders = self._get_dataloaders_dict(first_dl, second_dl)
    multi_dataloader = MultiDataLoader(loaders, RoundRobin())

    # invalid state dict - finished dataloaders and curr dataloader do not exist
    invalid_state = {"finished_dataloaders": ["3"], "cur_dataloader": "4"}
    multi_dataloader.load_state_dict(invalid_state)

    # ensure the iterator state is not changed
    rr_iterator = cast(RoundRobinIterator, iter(multi_dataloader))
    self.assertEqual(rr_iterator.cur_dataloader, "1")
    self.assertEqual(rr_iterator.finished_dataloaders, [])
| 563 | + |
def test_state_dict_with_non_stateful_iterator(self) -> None:
    """Check that ``iterator_state`` is not in the state dict when ``CustomRandomIterator`` is used."""
    dataloader_1 = generate_random_dataloader(
        num_samples=8, input_dim=1, batch_size=8
    )
    dataloader_2 = generate_random_dataloader(
        num_samples=16, input_dim=1, batch_size=8
    )
    multi_dataloader = MultiDataLoader(
        self._get_dataloaders_dict(dataloader_1, dataloader_2),
        DataIterationStrategy(),
        CustomRandomIterator,
    )
    # Create an iterator before reading the state dict; the iterator object
    # itself is not needed.
    iter(multi_dataloader)
    self.assertNotIn("iterator_state", multi_dataloader.state_dict())
| 578 | + |
491 | 579 | def _get_dataloaders_dict(
|
492 | 580 | self, first_dataloader: DataLoader, second_dataloader: DataLoader
|
493 | 581 | ) -> Dict[str, Union[DataLoader, Iterable[object]]]:
|
|
0 commit comments