Commit 59bfa39

diego-urgell authored and facebook-github-bot committed
Make InOrder iterator stateful (#993)
Summary: Pull Request resolved: #993

Reviewed By: galrotem

Differential Revision: D73863744

fbshipit-source-id: 42eaec46d64945fa0a481fda98a4c15a40a8239b
1 parent a5358ce commit 59bfa39

3 files changed (+118, −2 lines)

tests/utils/data/test_multi_dataloader.py

Lines changed: 70 additions & 1 deletion
@@ -14,12 +14,16 @@
 
 import torch
 from torch.utils.data import DataLoader, Dataset
-from torchtnt.framework._test_utils import generate_random_dataloader
+from torchtnt.framework._test_utils import (
+    generate_random_dataloader,
+    generate_tensor_dataloader,
+)
 
 from torchtnt.utils.data.iterators import (
     AllDatasetBatches,
     DataIterationStrategy,
     InOrder,
+    InOrderIterator,
     MultiIterator,
     RandomizedBatchSampler,
     RoundRobin,
@@ -580,3 +584,68 @@ def _get_dataloaders_dict(
         self, first_dataloader: DataLoader, second_dataloader: DataLoader
     ) -> Dict[str, Union[DataLoader, Iterable[object]]]:
         return {"1": first_dataloader, "2": second_dataloader}
+
+    def test_state_dict_with_inorder_iterator(self) -> None:
+        dataloader_1 = generate_tensor_dataloader(torch.tensor([1, 2]), batch_size=1)
+        dataloader_2 = generate_tensor_dataloader(torch.tensor([3, 4]), batch_size=1)
+        dataloader_3 = generate_tensor_dataloader(torch.tensor([5, 6]), batch_size=1)
+
+        dataloaders_dict: Dict[str, Union[DataLoader, Iterable[object]]] = {
+            "dataloader_1": dataloader_1,
+            "dataloader_2": dataloader_2,
+            "dataloader_3": dataloader_3,
+        }
+
+        multi_dataloader = MultiDataLoader(
+            dataloaders_dict,
+            InOrder(iteration_order=["dataloader_1", "dataloader_2", "dataloader_3"]),
+        )
+
+        mdl_iter = iter(multi_dataloader)
+
+        # Exhaust the first dataset
+        self.assertEqual(next(mdl_iter)["dataloader_1"], [torch.tensor([1])])
+        self.assertEqual(next(mdl_iter)["dataloader_1"], [torch.tensor([2])])
+
+        # We expect the same iterator since we haven't raised StopIteration yet
+        mdl_sd = multi_dataloader.state_dict()
+        self.assertEqual(
+            mdl_sd["iterator_state"],
+            {
+                "iterators_finished": 0,
+                "cur_iterator": "dataloader_1",
+            },
+        )
+
+        # Start the second dataset
+        self.assertEqual(next(mdl_iter)["dataloader_2"], [torch.tensor([3])])
+
+        mdl_sd = multi_dataloader.state_dict()
+        self.assertEqual(
+            mdl_sd["iterator_state"],
+            {
+                "iterators_finished": 1,
+                "cur_iterator": "dataloader_2",
+            },
+        )
+
+        # Create a new dataloader and verify the restore
+        multi_dataloader_2 = MultiDataLoader(
+            dataloaders_dict,
+            InOrder(iteration_order=["dataloader_1", "dataloader_2", "dataloader_3"]),
+        )
+        multi_dataloader_2.load_state_dict(mdl_sd)
+        in_order_iter = cast(InOrderIterator, iter(multi_dataloader_2))
+        self.assertEqual(in_order_iter.cur_iterator, "dataloader_1")
+        self.assertEqual(in_order_iter.iterators_finished, 1)
+
+        # Calling next should update the current iterator;
+        # the individual dataloader is not stateful
+        self.assertEqual(next(in_order_iter)["dataloader_2"], [torch.tensor([3])])
+        self.assertEqual(in_order_iter.cur_iterator, "dataloader_2")
+
+        # Verify that after calling iter(), everything is reset
+        in_order_iter_2 = cast(InOrderIterator, iter(multi_dataloader_2))
+        self.assertEqual(in_order_iter_2.cur_iterator, "dataloader_1")
+        self.assertEqual(in_order_iter_2.iterators_finished, 0)
+        self.assertEqual(next(in_order_iter_2)["dataloader_1"], [torch.tensor([1])])
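Taken together, the test above exercises a full save/restore round trip. Below is a rough, hedged sketch of that flow outside of unittest; it is not part of the commit, the MultiDataLoader import path and exact batch shapes are assumptions based on the diff above, and plain DataLoaders stand in for the stateful test wrapper:

```python
# Rough sketch only -- not part of the commit. The MultiDataLoader import path
# is assumed; the InOrder import mirrors the test diff above.
from typing import Dict, Iterable, Union

import torch
from torch.utils.data import DataLoader, TensorDataset
from torchtnt.utils.data import MultiDataLoader  # assumed export location
from torchtnt.utils.data.iterators import InOrder


def make_loader(values: torch.Tensor) -> DataLoader:
    # Plain DataLoader; the test wraps this in a stateful dummy loader instead.
    return DataLoader(TensorDataset(values), batch_size=1)


dataloaders: Dict[str, Union[DataLoader, Iterable[object]]] = {
    "a": make_loader(torch.tensor([1, 2])),
    "b": make_loader(torch.tensor([3, 4])),
}
mdl = MultiDataLoader(dataloaders, InOrder(iteration_order=["a", "b"]))

it = iter(mdl)
next(it)  # {"a": [tensor([1])]}
next(it)  # {"a": [tensor([2])]}
next(it)  # {"b": [tensor([3])]} -- "a" exhausted, iterator advanced

# The iterator position (iterators_finished / cur_iterator) now travels
# with the dataloader's state dict.
snapshot = mdl.state_dict()

# Restoring into a fresh MultiDataLoader resumes on the second loader; per this
# commit, the underlying iter() call is deferred until the first next().
mdl_restored = MultiDataLoader(dataloaders, InOrder(iteration_order=["a", "b"]))
mdl_restored.load_state_dict(snapshot)
next(iter(mdl_restored))  # {"b": [tensor([3])]} -- the individual loader itself is not stateful
```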

torchtnt/framework/_test_utils.py

Lines changed: 11 additions & 0 deletions
@@ -280,6 +280,17 @@ def generate_dummy_stateful_dataloader(
     )
 
 
+def generate_tensor_dataloader(
+    samples: torch.Tensor, batch_size: int
+) -> DummyStatefulDataLoader:
+    return DummyStatefulDataLoader(
+        DataLoader(
+            dataset=TensorDataset(samples),
+            batch_size=batch_size,
+        )
+    )
+
+
 class DummyMeanMetric:
     def __init__(self) -> None:
         super().__init__()
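One detail worth calling out for readers of the test assertions: DataLoader's default collate over a TensorDataset yields a list with one batched tensor per dataset field, which is why the test compares batches against [torch.tensor([1])] rather than a bare tensor. A quick standalone illustration in plain PyTorch (without the DummyStatefulDataLoader wrapper, which presumably just layers state_dict()/load_state_dict() on top):

```python
# Standalone illustration of the batch shape produced by the new helper's
# underlying DataLoader(TensorDataset(...)) combination.
import torch
from torch.utils.data import DataLoader, TensorDataset

loader = DataLoader(TensorDataset(torch.tensor([1, 2])), batch_size=1)
for batch in loader:
    print(batch)  # [tensor([1])] on the first pass, [tensor([2])] on the second
```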

torchtnt/utils/data/iterators.py

Lines changed: 37 additions & 1 deletion
@@ -511,6 +511,7 @@ def __init__(
         self.cur_iter: Union[Iterator[DataLoader], Iterator[object]] = iter(
             self.individual_dataloaders[self.iteration_order[0]]
         )
+        self.cur_iterator_idx: int = 0
         self.cur_iterator: str = self.iteration_order[0]
         self.num_iterators: int = len(self.iteration_order)
         self.iterators_finished: int = 0
@@ -519,6 +520,17 @@ def __next__(self) -> Dict[str, Any]:
         if self.iterators_finished == self.num_iterators:
             raise StopIteration
 
+        # If the current iterator doesn't match the expected number of finished iterators,
+        # it means we restored from a checkpoint and need to initialize the expected iterator.
+        # This is done here to avoid calling iter() in the load_state_dict() function.
+        if self.iterators_finished != self.cur_iterator_idx:
+            logger.info(
+                f"Initializing iterator {self.cur_iterator} after resuming from checkpoint"
+            )
+            self.cur_iterator_idx = self.iterators_finished
+            self.cur_iterator = self.iteration_order[self.iterators_finished]
+            self.cur_iter = iter(self.individual_dataloaders[self.cur_iterator])
+
         try:
             return {self.cur_iterator: next(self.cur_iter)}
         except StopIteration:
@@ -528,12 +540,36 @@ def __next__(self) -> Dict[str, Any]:
             if self.iterators_finished == self.num_iterators:
                 raise StopIteration
 
+            self.cur_iterator_idx += 1
             self.cur_iterator = self.iteration_order[self.iterators_finished]
-
             self.cur_iter = iter(self.individual_dataloaders[self.cur_iterator])
 
             return self.__next__()
 
+    def state_dict(self) -> Dict[str, Any]:
+        return {
+            "iterators_finished": self.iterators_finished,
+            "cur_iterator": self.cur_iterator,
+        }
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        iterators_finished: int = state_dict["iterators_finished"]
+        cur_iterator: str = state_dict["cur_iterator"]
+        logger.info(
+            f"Loading InOrderIterator state. Trying to set iterators_finished to {iterators_finished} to restore {cur_iterator}"
+        )
+
+        if cur_iterator not in self.iteration_order or iterators_finished > len(
+            self.iteration_order
+        ):
+            logger.warning(
+                f"Will not restore InOrderIterator state, since expected dataloader was not found in available iterators: {cur_iterator}"
+            )
+            return
+
+        self.iterators_finished = iterators_finished
+        # We do not initialize the actual iterator here, to avoid checkpoint restore taking longer
+
 
 class DataIterationStrategyRegistry:
     """A generic iterator registry.
