Dataset: implement global dataset_distribution option
#1676
base: master
Changes from 3 commits
```diff
@@ -137,9 +137,8 @@ def __init__(
         get_sub_epoch_dataset: Callable[[List[FileTree]], Dict[str, Any]],
         preload_next_n_sub_epochs: int = 1,
         buffer_size: int = 1,
-        distrib_shard_files: bool = False,
+        distrib_shard_files: Optional[bool] = None,
         _meta_info_cache: Optional[Dict[str, Any]] = None,
-        _distrib_info: Optional[Dict[str, int]] = None,
         **kwargs,
     ):
         """
```
```diff
@@ -148,10 +147,10 @@ def __init__(
         :param get_sub_epoch_dataset: callable which returns a dataset dict for a given subset of files
         :param preload_next_n_sub_epochs: how many sub epoch datasets to preload
         :param buffer_size: buffer size for each worker, amount of seqs to prefetch
-        :param distrib_shard_files: set to true to shard the data across worker processes in
-            distributed training scenarios
+        :param distrib_shard_files: deprecated. Replaced by global config option ``dataset_distribution="shard"``.
+            Set to true to shard the data across worker processes in distributed training scenarios.
```
Comment on lines +155 to +157: I don't think it is necessary to mark this as deprecated.
```diff
         :param _meta_info_cache: for internal use
-        :param _distrib_info: for internal use
         """
         super().__init__(**kwargs)
         self.files = files
```
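For context, this is roughly what the change means on the config level (a minimal sketch, assuming the class being modified here is RETURNN's DistributeFilesDataset; the file list and sub-dataset callable are illustrative):

```python
# Old, per-dataset style (deprecated by this PR):
train = {
    "class": "DistributeFilesDataset",
    "files": ["data/part1.hdf", "data/part2.hdf"],  # illustrative
    "get_sub_epoch_dataset": ...,  # as before, returns a dataset dict for a file subset
    "distrib_shard_files": True,  # shard the file list across distributed workers
}

# New, global style introduced by this PR: one option, uniform across dataset types.
dataset_distribution = "shard"
```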
```diff
@@ -166,21 +165,13 @@ def __init__(
         self._workers: Dict[int, _WorkerProcParent] = {}  # epoch -> worker
         self._files_order_cache: Dict[int, List[List[FileTree]]] = {}  # full epoch (0-indexed) -> files order

         self.distrib_shard_files = distrib_shard_files
-        if distrib_shard_files:
-            assert self._num_shards == 1 and self._shard_index == 0, (  # ensure defaults are set
-                f"{self}: Cannot use both dataset-sharding via properties _num_shards and _shard_index "
-                f"and {self.__class__.__name__}'s own sharding implementation based on the training's rank and size."
+        if distrib_shard_files is not None:
+            log.print_deprecation_warning(
+                f"{self.__class__.__name__}'s `distrib_shard_files` config option is set. "
+                "Use global config option `dataset_distribution` instead "
+                "for the same behavior across more types of datasets."
             )
```
Comment on lines +175 to +179: I'm not sure it is really necessary to mark this as deprecated.
```diff
-        if _distrib_info:
-            # If we're in a child process `_get_rank_and_size()` no longer works,
-            # so we pass the info about the shards via a pickled property.
-            # See also Dataset.__reduce__.
-            self._shard_index = _distrib_info["_shard_index"]
-            self._num_shards = _distrib_info["_num_shards"]
-        else:
-            self._shard_index, self._num_shards = _get_rank_and_size()
-        assert 0 <= self._shard_index < self._num_shards
+        self._validate_global_shard_cfg(distrib_shard_files)

         if _meta_info_cache:
             # This allows to skip the lazy init in self.initialize().
```
```diff
@@ -204,10 +195,6 @@ def supports_sharding(self) -> bool:
         """:return: whether the dataset supports sharding based on the seq_order"""
         return True

-    @property
-    def _distrib_info(self):
-        return {"_num_shards": self._num_shards, "_shard_index": self._shard_index}
-
     @property
     def _meta_info_cache(self):
         if not self.num_outputs:
```
```diff
@@ -220,9 +207,6 @@ def _meta_info_cache(self):
             "file_sizes": self._file_sizes,
         }

-    def _uses_custom_distributed_sharding(self) -> bool:
-        return self._num_shards > 1
-
     def _lazy_init_num_outputs(self):
         if self.num_outputs:
             return
```
```diff
@@ -290,11 +274,11 @@ def init_seq_order(self, epoch: Optional[int] = None, seq_list=None, seq_order=None):
         else:
             raise ValueError(f"{self}: seq_ordering {self.seq_ordering!r} not supported")
         file_bins = self._distribute_evenly_by_size(
-            num_bins=self._num_shards * self.partition_epoch,
+            num_bins=self.num_shards * self.partition_epoch,
             file_sizes=self._file_sizes,
             files_order=files_order_flat,
         )
-        self_index_base = self.partition_epoch * self._shard_index
+        self_index_base = self.partition_epoch * self.shard_index
         self_index_end = self_index_base + self.partition_epoch
         self._files_order_cache[full_epoch_0idx_] = file_bins[self_index_base:self_index_end]
```
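To make the indexing concrete, a small worked example (a sketch; the numbers are chosen for illustration):

```python
# Assume num_shards=2 and partition_epoch=3.
# _distribute_evenly_by_size then produces num_bins = 2 * 3 = 6 file bins.
# Each shard takes partition_epoch consecutive bins, one per sub-epoch:
for shard_index in range(2):
    base = 3 * shard_index  # partition_epoch * shard_index
    end = base + 3          # one bin per sub-epoch of the partition
    print(f"shard {shard_index} -> bins {list(range(base, end))}")
# shard 0 -> bins [0, 1, 2]
# shard 1 -> bins [3, 4, 5]
```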
```diff
@@ -328,6 +312,10 @@ def _get_sub_dataset_dict(self, files: List[FileTree]) -> Dict[str, Any]:

         dataset_dict = self.get_sub_epoch_dataset(files)
         dataset_dict = extend_dataset_dict_from_parent_dataset(dataset_dict, parent_dataset=self)
+        # We shard by splitting the files list into shards; the sub datasets must not shard any further by themselves.
+        if self.num_shards > 1:
+            dataset_dict["_num_shards"] = 1
+            dataset_dict["_shard_index"] = 0

         flat_sub_dset = tree.flatten_with_path(dataset_dict)
```
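For illustration, a hypothetical sub-dataset dict after this step (all field values are made up; only the two `_num_shards`/`_shard_index` overrides come from the diff above):

```python
sub_dataset_dict = {
    "class": "HDFDataset",        # whatever get_sub_epoch_dataset returned (illustrative)
    "files": ["data/part1.hdf"],  # only this shard's slice of the file list (illustrative)
    "_num_shards": 1,             # forced to 1: the parent already sharded by files,
    "_shard_index": 0,            # so the child must not shard a second time
}
```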
```diff
@@ -452,6 +440,21 @@ def get_data_keys(self) -> List[str]:
         self._lazy_init_num_outputs()
         return self._data_keys

+    @classmethod
+    def _validate_global_shard_cfg(cls, distrib_shard_files: Optional[bool]):
+        from returnn.config import get_global_config
+
+        config = get_global_config(raise_exception=False)
+        if not config:
+            return
+
+        dd_cfg = config.typed_value("dataset_distribution", None)
+        # Only validate if both options are explicitly set; mismatch iff the deprecated
+        # bool disagrees with whether the global option requests sharding.
+        if dd_cfg is not None and distrib_shard_files is not None and distrib_shard_files != (dd_cfg == "shard"):
+            raise ValueError(
+                f"{cls.__name__}: `distrib_shard_files` config ({distrib_shard_files}) mismatches "
+                f"global config option `dataset_distribution` ({dd_cfg})."
+            )


 def _get_key_for_file_tree(t: FileTree) -> str:
     """generates a deterministic key given a file tree"""
```
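Spelled out, the intended outcomes of this validation (a sketch of the logic above; `None` means the option is unset):

```python
# distrib_shard_files   dataset_distribution    result
# -------------------   --------------------    ----------------------------------------
# None                  "shard"                 ok: only the new global option is used
# True                  "shard"                 ok: consistent (deprecation warning only)
# True                  None                    ok: legacy behavior (deprecation warning)
# False                 "shard"                 ValueError: options contradict
# True                  anything but "shard"    ValueError: options contradict
```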
```diff
@@ -460,32 +463,6 @@ def _get_key_for_file_tree(t: FileTree) -> str:
     return ":".join(tree.flatten(t))


-def _get_rank_and_size() -> Tuple[int, int]:
-    """
-    :return: tuple (rank, size): the global rank and size for distributed trainings
-    """
-    from returnn.config import get_global_config
-
-    config = get_global_config(raise_exception=False)
-    if not config:
-        return 0, 1
-    if config.typed_value("torch_distributed") is not None:
-        import returnn.torch.distributed
-
-        ctx = returnn.torch.distributed.get_ctx(config=config)
-        return ctx.rank(), ctx.size()
-    elif config.is_true("use_horovod"):
-        assert config.bool("use_tensorflow", False) or config.value("backend", "").startswith("tensorflow")
-
-        import returnn.tf.horovod
-
-        ctx = returnn.tf.horovod.get_ctx(config=config)
-        return ctx.rank(), ctx.size()
-    else:
-        return 0, 1


 class _WorkerProcParent:
     def __init__(
         self,
```
```diff
@@ -561,10 +561,9 @@ def create_ogg_zip_txt_only_dataset(*, text: str = "hello world", seq_tag: str = ...):


 @contextlib.contextmanager
-def create_ogg_zip_txt_only_dataset_mult_seqs(*, seed: int = 1, num_seqs: int = 100, max_seq_len: int = 100):
+def create_ogg_zip_txt_only_dataset_mult_seqs_opts(*, seed: int = 1, num_seqs: int = 100, max_seq_len: int = 100):
     """create OggZipDataset"""
     import zipfile
-    from returnn.datasets.audio import OggZipDataset

     rnd = numpy.random.RandomState(seed)
```
```diff
@@ -593,6 +592,15 @@ def create_ogg_zip_txt_only_dataset_mult_seqs(*, seed: int = 1, num_seqs: int = ...):
         "audio": None,
         "targets": {"class": "CharacterTargets", "vocab_file": tmp_vocab_file.name, "seq_postfix": []},
     }
+    yield opts
+
+
+@contextlib.contextmanager
+def create_ogg_zip_txt_only_dataset_mult_seqs(*, seed: int = 1, num_seqs: int = 100, max_seq_len: int = 100):
+    """create OggZipDataset"""
+    from returnn.datasets.audio import OggZipDataset
+
+    with create_ogg_zip_txt_only_dataset_mult_seqs_opts(seed=seed, num_seqs=num_seqs, max_seq_len=max_seq_len) as opts:
+        dataset = init_dataset(opts)
+        assert isinstance(dataset, OggZipDataset)
+        yield dataset
```
```diff
@@ -1212,6 +1220,18 @@ def _collect_single_seq(self, seq_idx: int) -> Optional[DatasetSeq]:
     assert sub_ep == outer_epoch * multi_epoch + 1 and sub_seq_idx == 0


+def test_dataset_sharding():
+    from returnn.datasets.audio import OggZipDataset
+
+    with create_ogg_zip_txt_only_dataset_mult_seqs_opts(num_seqs=10) as dataset_opts:
```
Comment on the `num_seqs=10` line: I think the test would be a bit nicer if `num_seqs` were uneven, i.e. not divisible by `num_shards` (see the sketch after this diff).
```diff
+        datasets = [init_dataset({**dataset_opts, "_num_shards": 2, "_shard_index": i}) for i in range(2)]
+        for dataset in datasets:
+            assert isinstance(dataset, OggZipDataset)
+            dataset.init_seq_order(epoch=1)
+            assert dataset.shard_index < dataset.num_shards == 2
+            assert dataset.num_seqs == 5


 if __name__ == "__main__":
     better_exchook.install()
     if len(sys.argv) <= 1:
```
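A sketch of the variant the reviewer suggests, with an uneven sequence count (hypothetical; the exact per-shard counts depend on how the dataset rounds the uneven split):

```python
def test_dataset_sharding_uneven():
    with create_ogg_zip_txt_only_dataset_mult_seqs_opts(num_seqs=9) as dataset_opts:
        datasets = [init_dataset({**dataset_opts, "_num_shards": 2, "_shard_index": i}) for i in range(2)]
        for dataset in datasets:
            dataset.init_seq_order(epoch=1)
        # With 9 seqs over 2 shards, one shard should get 5 seqs and the other 4;
        # together the shards must cover every seq exactly once.
        assert sorted(ds.num_seqs for ds in datasets) == [4, 5]
        assert sum(ds.num_seqs for ds in datasets) == 9
```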
I don't like that we access the global config here. I know this follows similar code as `_get_default_random_seed_offset`, but I don't like it there either. Why is this needed? This should come from outside, or not? Specifically at the place where we call `init_dataset`, e.g. in the `__main__`. There we also call `Dataset.kwargs_update_from_config`.

Also, the code is wrong. Distributed training is only one possible source which defines/influences the shard index and num shards. There are other sources, for example MultiProcDataset, or the PyTorch DataLoader `num_workers`.
This is already setting the num shards and shard index for its children. The code was always designed in such a way that it would only look at the global config if no value was already set. But I agree, now it's better.

I think this is actually not trivial to implement, because the torch engine is already given initialized datasets, and it's difficult to change the sharding config after a dataset has been initialized. So factoring in the PyTorch `num_workers` needs to happen during the initial dataset initialization, which mixes PyTorch code with data-initialization code a bit, and I feel that is going to be messy. Do you know a good way? Maybe this is fine after all.
I think we cannot achieve both … and … because of … [quoted snippets failed to load].

However, I think it's worth it to have proper support for `torch_dataloader_opts = {"num_workers": n}` with `n > 1`, because it makes multi-process data loading much simpler for the end user, and this feature can replace MultiProcDataset for simple use cases. So I think I need to revert the changes where `num_shards` and `shard_index` are set from the outside, and rather fetch them from inside the dataset when they are needed. At that point the torch worker info is also available, which means we can take it into account properly.
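For illustration, the kind of user config this would enable (a minimal sketch; `torch_dataloader_opts` is the option named in the comment above, the rest is an assumed typical setup):

```python
# RETURNN config fragment (sketch): multi-process data loading via the
# PyTorch DataLoader instead of wrapping the dataset in MultiProcDataset.
backend = "torch"
torch_dataloader_opts = {"num_workers": 4}
# With proper sharding support, each of the 4 loader workers would then
# read a distinct shard of the data instead of duplicating it.
```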
Why difficult? We maybe just need a clean dataset API for that, some setter function `set_num_shards_and_shard_idx` or so. And then in `ReturnnDatasetIterDataPipe.reset` or so, we just need to call that.
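A minimal sketch of what such an API could look like (hypothetical: `set_num_shards_and_shard_idx` is only being proposed here, and the `reset` wiring is an assumption about where `ReturnnDatasetIterDataPipe` could call it):

```python
import torch.utils.data


class Dataset:  # sketch of the relevant part of the RETURNN base dataset
    def set_num_shards_and_shard_idx(self, num_shards: int, shard_idx: int):
        """Proposed setter; would have to be called before init_seq_order()."""
        assert 0 <= shard_idx < num_shards
        self._num_shards = num_shards
        self._shard_index = shard_idx


class ReturnnDatasetIterDataPipe:  # sketch, only the relevant part
    def __init__(self, dataset: Dataset):
        self._dataset = dataset

    def reset(self):
        # Inside a DataLoader worker, get_worker_info() tells us which worker we are,
        # so each worker can claim its own shard of the dataset.
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            self._dataset.set_num_shards_and_shard_idx(worker_info.num_workers, worker_info.id)
```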
Hm, I was originally not a fan of the mutability of these properties, but it seems ok now.
I think we cannot avoid such an API like `set_num_shards_and_shard_idx`, because of how the PyTorch data pipeline works.