
Commit c62c630

galrotem authored and facebook-github-bot committed
multidataloader - ensure state is restored for round robin (#823)
Summary:
Pull Request resolved: #823

Add stateful API to `MultiIterator` and implement it for `RoundRobinIterator`.

Reviewed By: diego-urgell, JKSenthil

Differential Revision: D57012694

fbshipit-source-id: bc0ef26ede428bcc20009757c240c4968e89f029
1 parent ad59364 commit c62c630

File tree: 3 files changed, +142 −2 lines

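At a glance, the change lets a round-robin MultiDataLoader resume from a checkpoint instead of restarting at the first dataloader. A minimal sketch of the intended flow, assuming two illustrative dataloaders named "a" and "b" (the helper function below is hypothetical, not part of this commit):

from torch.utils.data import DataLoader

from torchtnt.utils.data.iterators import RoundRobin
from torchtnt.utils.data.multi_dataloader import MultiDataLoader


def checkpoint_round_trip(dl_a: DataLoader, dl_b: DataLoader):
    # Build a round-robin MultiDataLoader over two named dataloaders.
    multi_dl = MultiDataLoader({"a": dl_a, "b": dl_b}, RoundRobin())
    it = iter(multi_dl)
    next(it)  # consume a batch so the iterator has a position worth saving

    # With this commit, state_dict() also captures the iterator position
    # under the "iterator_state" key (see the diffs below).
    checkpoint = multi_dl.state_dict()

    # A fresh MultiDataLoader applies that state on its next __iter__ call.
    restored = MultiDataLoader({"a": dl_a, "b": dl_b}, RoundRobin())
    restored.load_state_dict(checkpoint)
    return next(iter(restored))  # resumes from the saved round-robin position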

tests/utils/data/test_multi_dataloader.py

Lines changed: 88 additions & 0 deletions
@@ -14,6 +14,7 @@
 
 import torch
 from torch.utils.data import DataLoader, Dataset
+from torchtnt.framework._test_utils import generate_random_dataloader
 
 from torchtnt.utils.data.iterators import (
     AllDatasetBatches,
@@ -22,6 +23,7 @@
     MultiIterator,
     RandomizedBatchSampler,
     RoundRobin,
+    RoundRobinIterator,
     StoppingMechanism,
 )
 from torchtnt.utils.data.multi_dataloader import MultiDataLoader
@@ -488,6 +490,92 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
             0,
         )
 
+    def test_multi_dataloader_state_dict_with_iterator_state(self) -> None:
+        dataloader_1 = generate_random_dataloader(
+            num_samples=8, input_dim=1, batch_size=8
+        )
+        dataloader_2 = generate_random_dataloader(
+            num_samples=16, input_dim=1, batch_size=8
+        )
+        multi_dataloader = MultiDataLoader(
+            self._get_dataloaders_dict(dataloader_1, dataloader_2),
+            RoundRobin(),
+        )
+
+        multi_dl_state_dict = multi_dataloader.state_dict()
+        # before creating the iterator we don't expect the iterator_state to be present in the dl state dict
+        self.assertFalse("iterator_state" in multi_dl_state_dict)
+
+        multi_dl_iter = iter(multi_dataloader)
+        multi_dl_state_dict = multi_dataloader.state_dict()
+        self.assertTrue("iterator_state" in multi_dl_state_dict)
+        self.assertEqual(
+            multi_dl_state_dict["iterator_state"],
+            {"cur_dataloader": "1", "finished_dataloaders": []},
+        )
+        next(multi_dl_iter)  # should return batch from 1
+        next(multi_dl_iter)  # should return batch from 2
+        next(
+            multi_dl_iter
+        )  # should return batch from 2 after raising StopIteration from the first dl
+        multi_dl_state_dict = multi_dataloader.state_dict()
+        self.assertTrue("iterator_state" in multi_dl_state_dict)
+        self.assertEqual(
+            multi_dl_state_dict["iterator_state"],
+            {"cur_dataloader": "2", "finished_dataloaders": ["1"]},
+        )
+
+        # create fresh dl and load state dict. assert that the initial values are updated.
+        multi_dataloader_2 = MultiDataLoader(
+            self._get_dataloaders_dict(dataloader_1, dataloader_2),
+            RoundRobin(),
+        )
+        multi_dataloader_2.load_state_dict(multi_dl_state_dict)
+        round_robin_iter = cast(RoundRobinIterator, iter(multi_dataloader_2))
+        self.assertEqual(round_robin_iter.cur_dataloader, "2")
+        self.assertEqual(round_robin_iter.finished_dataloaders, ["1"])
+
+        # verify that after calling iter() again, values are reset
+        round_robin_iter = cast(RoundRobinIterator, iter(multi_dataloader_2))
+        self.assertEqual(round_robin_iter.cur_dataloader, "1")
+        self.assertEqual(round_robin_iter.finished_dataloaders, [])
+
+    def test_invalid_load_state_dict(self) -> None:
+        dataloader_1 = generate_random_dataloader(
+            num_samples=8, input_dim=1, batch_size=8
+        )
+        dataloader_2 = generate_random_dataloader(
+            num_samples=16, input_dim=1, batch_size=8
+        )
+        multi_dataloader = MultiDataLoader(
+            self._get_dataloaders_dict(dataloader_1, dataloader_2),
+            RoundRobin(),
+        )
+
+        # invalid state dict - finished dataloaders and cur dataloader do not exist
+        multi_dataloader.load_state_dict(
+            {"finished_dataloaders": ["3"], "cur_dataloader": "4"}
+        )
+        round_robin_iter = cast(RoundRobinIterator, iter(multi_dataloader))
+        # ensure the iterator state is not changed
+        self.assertEqual(round_robin_iter.cur_dataloader, "1")
+        self.assertEqual(round_robin_iter.finished_dataloaders, [])
+
+    def test_state_dict_with_non_stateful_iterator(self) -> None:
+        dataloader_1 = generate_random_dataloader(
+            num_samples=8, input_dim=1, batch_size=8
+        )
+        dataloader_2 = generate_random_dataloader(
+            num_samples=16, input_dim=1, batch_size=8
+        )
+        multi_dataloader = MultiDataLoader(
+            self._get_dataloaders_dict(dataloader_1, dataloader_2),
+            DataIterationStrategy(),
+            CustomRandomIterator,
+        )
+        iter(multi_dataloader)
+        self.assertFalse("iterator_state" in multi_dataloader.state_dict())
+
     def _get_dataloaders_dict(
         self, first_dataloader: DataLoader, second_dataloader: DataLoader
     ) -> Dict[str, Union[DataLoader, Iterable[object]]]:
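For orientation, the order asserted above follows from simple batch math: with batch_size=8, dataloader "1" (8 samples) yields one batch while dataloader "2" (16 samples) yields two, so the third next() falls through to "2" after "1" raises StopIteration. A sketch of that count, with values taken from the test arguments:

# Batch counts implied by generate_random_dataloader(num_samples=..., batch_size=8)
batches_from_dl_1 = 8 // 8    # dataloader "1": 1 batch
batches_from_dl_2 = 16 // 8   # dataloader "2": 2 batches
# Round robin therefore visits "1", "2", then "2" again, which is why the
# saved state ends as {"cur_dataloader": "2", "finished_dataloaders": ["1"]}
expected_visit_order = ["1", "2", "2"]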

torchtnt/utils/data/iterators.py

Lines changed: 31 additions & 1 deletion
@@ -9,6 +9,8 @@
 
 from __future__ import annotations
 
+import logging
+
 import random
 from abc import abstractmethod
 from dataclasses import dataclass
@@ -32,11 +34,13 @@
 import torch
 import torch.distributed as dist
 
-
 if TYPE_CHECKING:
     from torch.utils.data import DataLoader
 
 
+logger: logging.Logger = logging.getLogger(__name__)
+
+
 @dataclass
 class DataIterationStrategy:
     pass
@@ -75,6 +79,12 @@ def __str__(self) -> str:
     def __next__(self) -> Dict[str, Any]:
         pass
 
+    def state_dict(self) -> Dict[str, Any]:
+        return {}
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        pass
+
 
 class StoppingMechanism(Enum):
     ALL_DATASETS_EXHAUSTED = "ALL_DATASETS_EXHAUSTED"
@@ -176,6 +186,26 @@ def __next__(self) -> Dict[str, Any]:
 
         return self.__next__()
 
+    def state_dict(self) -> Dict[str, Any]:
+        return {
+            "finished_dataloaders": self.finished_dataloaders,
+            "cur_dataloader": self.cur_dataloader,
+        }
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        logger.info(
+            f"Loading RoundRobinIterator state. Finished dataloaders: {state_dict['finished_dataloaders']} and trying to set cur_dataloader to {self.cur_dataloader}"
+        )
+        self.finished_dataloaders = state_dict["finished_dataloaders"]
+        cur_dataloader = state_dict["cur_dataloader"]
+        if cur_dataloader not in self.dataloader_cycle:
+            logger.warning(
+                f"Did not find {cur_dataloader} in {list(self.dataloader_cycle)}. Skipping setting cur_dataloader"
+            )
+            return
+        while self.cur_dataloader != cur_dataloader:
+            self.cur_dataloader = next(self.dataloader_cycle)
+
 
 @dataclass
 class AllDatasetBatches(DataIterationStrategy):
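The restore path above advances the existing round-robin cycle in place until it reaches the saved name, rather than rebuilding the cycle. A standalone sketch of that idea, using itertools.cycle as a stand-in for the iterator's internal dataloader_cycle (an assumption, since this diff does not show how that attribute is constructed):

from itertools import cycle

names = ["1", "2", "3"]           # dataloader names in round-robin order
dataloader_cycle = cycle(names)
cur = next(dataloader_cycle)      # a fresh iterator starts at "1"

saved = "3"                       # e.g. state_dict["cur_dataloader"]
if saved in names:                # guard against unknown names, as load_state_dict does
    while cur != saved:           # step the cycle until positions match
        cur = next(dataloader_cycle)
print(cur)                        # "3"; the next call will wrap around to "1"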

torchtnt/utils/data/multi_dataloader.py

Lines changed: 23 additions & 1 deletion
@@ -12,13 +12,16 @@
 import logging
 from typing import Any, Dict, Iterable, Iterator, Optional, Type, TYPE_CHECKING, Union
 
+from pyre_extensions import none_throws
+
 from torchtnt.utils.data.iterators import (
     DataIterationStrategy,
     DataIterationStrategyRegistry,
     MultiIterator,
 )
 from torchtnt.utils.stateful import Stateful
 
+
 if TYPE_CHECKING:
     from torch.utils.data import DataLoader
 
@@ -53,6 +56,7 @@ def __init__(
         self.individual_dataloaders = individual_dataloaders
         self.iteration_strategy = iteration_strategy
         self.iterator_cls = iterator_cls
+        self.current_iterator: Optional[MultiIterator] = None
         for name in list(individual_dataloaders.keys()):
             try:
                 next(iter(self.individual_dataloaders[name]))
@@ -64,6 +68,7 @@
                     f"Dataloader '{name}' which contains no data. "
                     "You might have empty dataloaders in the input dict."
                 )
+        self.iterator_state: Optional[Dict[str, Any]] = None
 
     def __iter__(self) -> Iterator[Dict[str, Any]]:
         """Iterator functions for the collection of dataloaders.
@@ -77,10 +82,15 @@ def __iter__(self) -> Iterator[Dict[str, Any]]:
         iterator_cls = DataIterationStrategyRegistry.get(self.iteration_strategy)
         # in practice, DataIterationStrategyRegistry.get() returns just concrete classes
        # pyre-ignore[45]: Cannot instantiate abstract class `MultiIterator`.
-        return iterator_cls(
+        self.current_iterator = iterator_cls(
             individual_dataloaders=self.individual_dataloaders,
             iteration_strategy=self.iteration_strategy,
         )
+        if self.iterator_state is not None:
+            self.current_iterator.load_state_dict(self.iterator_state)
+
+        self.iterator_state = None
+        return none_throws(self.current_iterator)
 
     def state_dict(self) -> Dict[str, Any]:
         """Return an aggregated state dict based on individual dataloaders.
@@ -95,6 +105,14 @@ def state_dict(self) -> Dict[str, Any]:
             if isinstance(dl, Stateful):
                 state_dict[name] = dl.state_dict()
 
+        if (current_iterator := self.current_iterator) is not None:
+            iterator_state = current_iterator.state_dict()
+            if iterator_state:
+                logger.info("Storing iterator state in MultiDataLoader state_dict")
+                # we make an implicit assumption here that none of the dataloaders have the "iterator_state" key in order to be backwards compatible
+                # with already saved checkpoints (we don't want to modify the dataloaders stateful names)
+                state_dict["iterator_state"] = iterator_state
+
         return state_dict
 
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
@@ -114,3 +132,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
                 )
                 continue
             dl.load_state_dict(contents)
+
+        if "iterator_state" in state_dict:
+            # this will be used during the next __iter__ call
+            self.iterator_state = state_dict["iterator_state"]
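Put together, the aggregated state dict mixes per-dataloader entries keyed by dataloader name with the one reserved top-level key added here. A sketch of the resulting shape for a round-robin setup with stateful dataloaders named "a" and "b" (the names and per-dataloader contents are placeholders, not from the commit):

checkpoint = {
    "a": {},  # whatever dataloader "a" returns from its own state_dict()
    "b": {},  # likewise for dataloader "b"
    # Reserved key written by MultiDataLoader.state_dict(); the code above
    # assumes no dataloader is itself named "iterator_state", to stay
    # backwards compatible with already saved checkpoints.
    "iterator_state": {
        "cur_dataloader": "b",
        "finished_dataloaders": [],
    },
}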
