|
14 | 14 |
|
15 | 15 | import torch
|
16 | 16 | from torch.utils.data import DataLoader, Dataset
|
17 |
| -from torchtnt.framework._test_utils import generate_random_dataloader |
| 17 | +from torchtnt.framework._test_utils import ( |
| 18 | + generate_random_dataloader, |
| 19 | + generate_tensor_dataloader, |
| 20 | +) |
18 | 21 |
|
19 | 22 | from torchtnt.utils.data.iterators import (
|
20 | 23 | AllDatasetBatches,
|
21 | 24 | DataIterationStrategy,
|
22 | 25 | InOrder,
|
| 26 | + InOrderIterator, |
23 | 27 | MultiIterator,
|
24 | 28 | RandomizedBatchSampler,
|
25 | 29 | RoundRobin,
|
@@ -580,3 +584,68 @@ def _get_dataloaders_dict(
|
580 | 584 | self, first_dataloader: DataLoader, second_dataloader: DataLoader
|
581 | 585 | ) -> Dict[str, Union[DataLoader, Iterable[object]]]:
|
582 | 586 | return {"1": first_dataloader, "2": second_dataloader}
|
| 587 | + |
| 588 | + def test_state_dict_with_inorder_iterator(self) -> None: |
| 589 | + dataloader_1 = generate_tensor_dataloader(torch.tensor([1, 2]), batch_size=1) |
| 590 | + dataloader_2 = generate_tensor_dataloader(torch.tensor([3, 4]), batch_size=1) |
| 591 | + dataloader_3 = generate_tensor_dataloader(torch.tensor([5, 6]), batch_size=1) |
| 592 | + |
| 593 | + dataloaders_dict: Dict[str, Union[DataLoader, Iterable[object]]] = { |
| 594 | + "dataloader_1": dataloader_1, |
| 595 | + "dataloader_2": dataloader_2, |
| 596 | + "dataloader_3": dataloader_3, |
| 597 | + } |
| 598 | + |
| 599 | + multi_dataloader = MultiDataLoader( |
| 600 | + dataloaders_dict, |
| 601 | + InOrder(iteration_order=["dataloader_1", "dataloader_2", "dataloader_3"]), |
| 602 | + ) |
| 603 | + |
| 604 | + mdl_iter = iter(multi_dataloader) |
| 605 | + |
| 606 | + # Exhaust first dataset |
| 607 | + self.assertEqual(next(mdl_iter)["dataloader_1"], [torch.tensor([1])]) |
| 608 | + self.assertEqual(next(mdl_iter)["dataloader_1"], [torch.tensor([2])]) |
| 609 | + |
| 610 | + # We expect the same iterator since we haven't raised StopIteration |
| 611 | + mdl_sd = multi_dataloader.state_dict() |
| 612 | + self.assertEqual( |
| 613 | + mdl_sd["iterator_state"], |
| 614 | + { |
| 615 | + "iterators_finished": 0, |
| 616 | + "cur_iterator": "dataloader_1", |
| 617 | + }, |
| 618 | + ) |
| 619 | + |
| 620 | + # Start second dataset |
| 621 | + self.assertEqual(next(mdl_iter)["dataloader_2"], [torch.tensor([3])]) |
| 622 | + |
| 623 | + mdl_sd = multi_dataloader.state_dict() |
| 624 | + self.assertEqual( |
| 625 | + mdl_sd["iterator_state"], |
| 626 | + { |
| 627 | + "iterators_finished": 1, |
| 628 | + "cur_iterator": "dataloader_2", |
| 629 | + }, |
| 630 | + ) |
| 631 | + |
| 632 | + # Create new dataloader and verify restore |
| 633 | + multi_dataloader_2 = MultiDataLoader( |
| 634 | + dataloaders_dict, |
| 635 | + InOrder(iteration_order=["dataloader_1", "dataloader_2", "dataloader_3"]), |
| 636 | + ) |
| 637 | + multi_dataloader_2.load_state_dict(mdl_sd) |
| 638 | + in_order_iter = cast(InOrderIterator, iter(multi_dataloader_2)) |
| 639 | + self.assertEqual(in_order_iter.cur_iterator, "dataloader_1") |
| 640 | + self.assertEqual(in_order_iter.iterators_finished, 1) |
| 641 | + |
| 642 | + # Calling next should update the current iterator |
| 643 | + # individual dl is not stateful |
| 644 | + self.assertEqual(next(in_order_iter)["dataloader_2"], [torch.tensor([3])]) |
| 645 | + self.assertEqual(in_order_iter.cur_iterator, "dataloader_2") |
| 646 | + |
| 647 | + # verify that after calling iter(), everything is reset |
| 648 | + in_order_iter_2 = cast(InOrderIterator, iter(multi_dataloader_2)) |
| 649 | + self.assertEqual(in_order_iter_2.cur_iterator, "dataloader_1") |
| 650 | + self.assertEqual(in_order_iter_2.iterators_finished, 0) |
| 651 | + self.assertEqual(next(in_order_iter_2)["dataloader_1"], [torch.tensor([1])]) |
0 commit comments