diff --git a/returnn/datasets/basic.py b/returnn/datasets/basic.py index 15d05b631..8dfb4c46f 100644 --- a/returnn/datasets/basic.py +++ b/returnn/datasets/basic.py @@ -19,7 +19,6 @@ import math import numpy import functools -import typing from typing import TYPE_CHECKING, Optional, Any, Set, Tuple, Union, Type, Dict, Sequence, List, Callable from returnn.log import log @@ -68,6 +67,12 @@ def set_or_remove(key, value): set_or_remove("min_chunk_size", config.opt_typed_value("min_chunk_size", 0) or None) set_or_remove("chunking_variance", config.float("chunking_variance", 0)) + dd_cfg = config.typed_value("dataset_distribution", "random_seed_offset") + assert dd_cfg in ["random_seed_offset", "shard"] + shard_index, num_shards = Dataset._get_sharding_rank_and_size(config) if dd_cfg == "shard" else 0, 1 + set_or_remove("num_shards", num_shards) + set_or_remove("shard_index", shard_index) + @staticmethod def get_default_kwargs_eval(config: Config) -> Dict[str, Any]: """ @@ -112,8 +117,8 @@ def __init__( min_chunk_size=0, chunking_variance=0, estimated_num_seqs=None, - _num_shards=1, - _shard_index=0, + num_shards: int = 1, + shard_index: int = 0, ): """ :param str name: e.g. "train" or "eval" @@ -137,8 +142,8 @@ def __init__( :param str|None seq_order_seq_lens_file: for seq order, use the seq length given by this file :param int shuffle_frames_of_nseqs: shuffles the frames. not always supported :param None|int estimated_num_seqs: for progress reporting in case the real num_seqs is unknown - :param int _num_shards: number of shards the data is split into - :param int _shard_index: local shard index, when sharding is enabled + :param num_shards: number of shards the data is split into + :param shard_index: local shard index, when sharding is enabled """ self.name = name or ("dataset_id%s" % id(self)) self.lock: Optional[RLock] = None # Used when manipulating our data potentially from multiple threads. @@ -170,9 +175,9 @@ def __init__( self._chunking = chunking self.chunk_size, self.chunk_step, self.custom_chunking_func = self._parse_chunking(chunking) self._context_window = context_window - assert 0 <= _shard_index < _num_shards - self._num_shards = _num_shards - self._shard_index = _shard_index + assert 0 <= shard_index < num_shards + self.num_shards = num_shards + self.shard_index = shard_index if isinstance(context_window, (tuple, list)): assert len(context_window) == 2 for elem in context_window: @@ -248,6 +253,37 @@ def __reduce__(self): state = {attr: getattr(self, attr) for attr in ["epoch", "zpad"]} return Dataset._create_from_reduce, (self.__class__, kwargs, state) + def set_shard_idx_and_num_shards(self, shard_index: int, num_shards: int): + """set sharding config for the dataset""" + assert 0 <= shard_index < num_shards + assert num_shards == 1 or self.supports_sharding() + self.num_shards = num_shards + self.shard_index = shard_index + + @staticmethod + def _get_sharding_rank_and_size(config: Optional[Config] = None) -> Tuple[int, int]: + """ + :return: tuple (rank, size): the global rank and size for distributed trainings + """ + if config is None: + from returnn.config import get_global_config + + config = get_global_config(return_empty_if_none=True) + if config.typed_value("torch_distributed") is not None: + import returnn.torch.distributed + + ctx = returnn.torch.distributed.get_ctx(config=config) + return ctx.rank(), ctx.size() + elif config.is_true("use_horovod"): + assert config.bool("use_tensorflow", False) or config.value("backend", "").startswith("tensorflow") + + import returnn.tf.horovod + + ctx = returnn.tf.horovod.get_ctx(config=config) + return ctx.rank(), ctx.size() + else: + return 0, 1 + @property def random_seed_offset(self) -> int: """:return: random seed offset for shuffling""" @@ -257,10 +293,10 @@ def random_seed_offset(self) -> int: def _uses_custom_distributed_sharding(self) -> bool: """ - :return: if dataset has its own data sharding logic independent of TF/PT. + :return: if the dataset has its own data sharding logic independent of TF/PT. Leads to a fixed random_seed_offset independent of the workers local rank. """ - return False + return self.num_shards > 1 def _get_default_random_seed_offset(self): """ @@ -559,7 +595,7 @@ def get_seq_order_for_epoch( for i in range(1, num): seq_index[i::num] += i * (num_seqs // num) elif seq_ordering_method == "reverse": - seq_index = range(num_seqs - 1, -1, -1) # type: Union[range, typing.Sequence[int]] + seq_index: Union[range, Sequence[int]] = range(num_seqs - 1, -1, -1) elif seq_ordering_method in ["sorted", "sorted_reverse"]: assert get_seq_len reverse = -1 if seq_ordering_method == "sorted_reverse" else 1 @@ -641,9 +677,9 @@ def get_seq_order_for_epoch( seq_index = [ i for i in seq_index if (all_seq_tags[i] not in used_seq_tags, used_seq_tags.add(all_seq_tags[i]))[0] ] - if partition_epoch > 1 or self._num_shards > 1: + if partition_epoch > 1 or self.num_shards > 1: seq_index = self._apply_partition_epoch_and_sharding( - seq_index, partition_epoch, epoch, self._num_shards, self._shard_index + seq_index, partition_epoch, epoch, self.num_shards, self.shard_index ) if repeat_epoch > 1: seq_index = list(seq_index) * repeat_epoch @@ -735,8 +771,8 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None): """ self.epoch = epoch self.rnd_seq_drop = Random(self._get_random_seed_for_epoch(epoch=epoch)) - assert self._num_shards == 1 or self.supports_sharding(), ( - f"{self}: does not support sharding, but got num_shards == {self._num_shards}" + assert self.num_shards == 1 or self.supports_sharding(), ( + f"{self}: does not support sharding, but got num_shards == {self.num_shards}" ) return False @@ -902,7 +938,7 @@ def get_corpus_seq_idx(self, seq_idx): if self.seq_ordering == "default" and self.partition_epoch == 1: return seq_idx assert self.have_corpus_seq_idx() - raise NotImplemented + raise NotImplementedError def have_get_corpus_seq(self) -> bool: """ @@ -1061,7 +1097,7 @@ def get_data_shape(self, key: str) -> List[int]: if key in self.num_outputs: if self.num_outputs[key][1] <= 1: return [] - res_shape = [None] * (self.num_outputs[key][1] - 1) # type: typing.List[typing.Union[None,int]] + res_shape: List[Union[None, int]] = [None] * (self.num_outputs[key][1] - 1) if not self.is_data_sparse(key): res_shape[-1] = self.get_data_dim(key) return res_shape @@ -1587,9 +1623,9 @@ def _dataset_extend_default_kwargs_from_parent_dataset( default_kwargs = default_kwargs.copy() if default_kwargs else {} default_kwargs.setdefault("random_seed_offset", parent_dataset.random_seed_offset) # noinspection PyProtectedMember - default_kwargs.setdefault("_num_shards", parent_dataset._num_shards) + default_kwargs.setdefault("num_shards", parent_dataset.num_shards) # noinspection PyProtectedMember - default_kwargs.setdefault("_shard_index", parent_dataset._shard_index) + default_kwargs.setdefault("shard_index", parent_dataset.shard_index) return default_kwargs diff --git a/returnn/datasets/distrib_files.py b/returnn/datasets/distrib_files.py index 62cf5f1d9..fc1eb6a0a 100644 --- a/returnn/datasets/distrib_files.py +++ b/returnn/datasets/distrib_files.py @@ -138,9 +138,8 @@ def __init__( get_sub_epoch_dataset: Callable[[List[FileTree]], Dict[str, Any]], preload_next_n_sub_epochs: int = 1, buffer_size: int = 1, - distrib_shard_files: bool = False, + distrib_shard_files: Optional[bool] = None, _meta_info_cache: Optional[Dict[str, Any]] = None, - _distrib_info: Optional[Dict[str, int]] = None, **kwargs, ): """ @@ -153,10 +152,10 @@ def __init__( :param get_sub_epoch_dataset: callable which returns a dataset dict for a given subset of files :param preload_next_n_sub_epochs: how many sub epoch datasets to preload :param buffer_size: buffer size for each worker, amount of seqs to prefetch - :param distrib_shard_files: set to true to shard the data across worker processes in - distributed training scenaria + :param distrib_shard_files: deprecated. Replaced by global config option ``dataset_distribution="shard"``. + + Set to true to shard the data across worker processes in distributed training scenaria. :param _meta_info_cache: for internal use - :param _distrib_info: for internal use """ super().__init__(**kwargs) self.files = files @@ -172,21 +171,16 @@ def __init__( self._workers: Dict[int, _WorkerProcParent] = {} # epoch -> worker self._files_order_cache: Dict[int, List[List[FileTree]]] = {} # full epoch (0-indexed) -> files order - self.distrib_shard_files = distrib_shard_files - if distrib_shard_files: - assert self._num_shards == 1 and self._shard_index == 0, ( # ensure defaults are set - f"{self}: Cannot use both dataset-sharding via properties _num_shards and _shard index " - f"and {self.__class__.__name__}'s own sharding implementation based on the trainings rank and size." + if distrib_shard_files is not None: + log.print_deprecation_warning( + f"{self.__class__.__name__}' `distrib_shard_files` config option is set. " + "Use global config option `dataset_distribution` instead " + "for the same behavior across more types of datasets." ) - if _distrib_info: - # If we're in a child process `_get_rank_and_size()` no longer works, - # so we pass the info about the shards via a pickled property. - # See also Dataset.__reduce__. - self._shard_index = _distrib_info["_shard_index"] - self._num_shards = _distrib_info["_num_shards"] - else: - self._shard_index, self._num_shards = _get_rank_and_size() - assert 0 <= self._shard_index < self._num_shards + sharding_info = self._validate_global_config_and_get_sharding_info(distrib_shard_files) + if sharding_info is not None: + self.shard_index, self.num_shards = sharding_info + self.distrib_shard_files = distrib_shard_files if _meta_info_cache: # This allows to skip the lazy init in self.initialize(). @@ -208,10 +202,6 @@ def supports_sharding(self) -> bool: """:return: whether the dataset supports sharding based on the seq_order""" return True - @property - def _distrib_info(self): - return {"_num_shards": self._num_shards, "_shard_index": self._shard_index} - @property def _meta_info_cache(self): if not self.num_outputs: @@ -225,9 +215,6 @@ def _meta_info_cache(self): "files": self._files, } - def _uses_custom_distributed_sharding(self) -> bool: - return self._num_shards > 1 - def _lazy_init_file_list(self): """ The list of data files can either be provided as python list, or, if that grows @@ -332,11 +319,11 @@ def init_seq_order(self, epoch: Optional[int] = None, seq_list=None, seq_order=N else: raise ValueError(f"{self}: seq_ordering {self.seq_ordering!r} not supported") file_bins = self._distribute_evenly_by_size( - num_bins=self._num_shards * self.partition_epoch, + num_bins=self.num_shards * self.partition_epoch, file_sizes=self._file_sizes, files_order=files_order_flat, ) - self_index_base = self.partition_epoch * self._shard_index + self_index_base = self.partition_epoch * self.shard_index self_index_end = self_index_base + self.partition_epoch self._files_order_cache[full_epoch_0idx_] = file_bins[self_index_base:self_index_end] @@ -370,6 +357,10 @@ def _get_sub_dataset_dict(self, files: List[FileTree]) -> Dict[str, Any]: dataset_dict = self.get_sub_epoch_dataset(files) dataset_dict = extend_dataset_dict_from_parent_dataset(dataset_dict, parent_dataset=self) + # We shard by splitting the files list into shards, the sub datasets must not shard any further by themselves + if self.num_shards > 1: + dataset_dict["num_shards"] = 1 + dataset_dict["shard_index"] = 0 flat_sub_dset = tree.flatten_with_path(dataset_dict) @@ -494,38 +485,33 @@ def get_data_keys(self) -> List[str]: self._lazy_init_num_outputs() return self._data_keys + @classmethod + def _validate_global_config_and_get_sharding_info(cls, distrib_shard_files: bool) -> Optional[Tuple[int, int]]: + from returnn.config import get_global_config -def _get_key_for_file_tree(t: FileTree) -> str: - """generates a deterministic key given a file tree""" - import tree - - return ":".join(tree.flatten(t)) - - -def _get_rank_and_size() -> Tuple[int, int]: - """ - :return: tuple (rank, size): the global rank and size for distributed trainings - """ + config = get_global_config(raise_exception=False) + if not config: + return - from returnn.config import get_global_config + dd_cfg = config.typed_value("dataset_distribution", None) + if dd_cfg and (distrib_shard_files and dd_cfg != "shard") or (not distrib_shard_files and dd_cfg == "shard"): + raise ValueError( + f"{cls.__name__}: `distrib_shard_files` config ({distrib_shard_files}) mismatch " + f"with global config option `dataset_distribution` ({dd_cfg})." + ) - config = get_global_config(raise_exception=False) - if not config: - return 0, 1 - if config.typed_value("torch_distributed") is not None: - import returnn.torch.distributed + if distrib_shard_files and dd_cfg is None: + # RETURNN will set sharding info on the dataset if the global config is set. + # If it's not set, however, we need to respect the existing `distrib_shard_files` property + # for backwards compatibility and load the sharding info ourselves. + return CachedDataset2._get_sharding_rank_and_size(config) - ctx = returnn.torch.distributed.get_ctx(config=config) - return ctx.rank(), ctx.size() - elif config.is_true("use_horovod"): - assert config.bool("use_tensorflow", False) or config.value("backend", "").startswith("tensorflow") - import returnn.tf.horovod +def _get_key_for_file_tree(t: FileTree) -> str: + """generates a deterministic key given a file tree""" + import tree - ctx = returnn.tf.horovod.get_ctx(config=config) - return ctx.rank(), ctx.size() - else: - return 0, 1 + return ":".join(tree.flatten(t)) class _WorkerProcParent: diff --git a/returnn/datasets/meta.py b/returnn/datasets/meta.py index 901fbec8d..d55deeb09 100644 --- a/returnn/datasets/meta.py +++ b/returnn/datasets/meta.py @@ -1162,9 +1162,9 @@ def _get_random_dataset_seq_order(self, total_num_seqs): assert sum(counters) == total_num_seqs - if self.partition_epoch or self._num_shards > 1: + if self.partition_epoch or self.num_shards > 1: seq_order = self._apply_partition_epoch_and_sharding( - seq_order, self.partition_epoch, self.epoch, self._num_shards, self._shard_index + seq_order, self.partition_epoch, self.epoch, self.num_shards, self.shard_index ) if self.repeat_epoch: seq_order = seq_order * self.repeat_epoch diff --git a/returnn/datasets/multi_proc.py b/returnn/datasets/multi_proc.py index 9aeddf85e..11cf8b22e 100644 --- a/returnn/datasets/multi_proc.py +++ b/returnn/datasets/multi_proc.py @@ -94,6 +94,10 @@ def initialize(self): self._lazy_init() super().initialize() + def supports_sharding(self): + """this dataset supports sharding on the ``dedicated`` sharding method""" + return self._sharding_method == "dedicated" + @property def _meta_info_cache(self): if not self.num_outputs: @@ -144,7 +148,11 @@ def _lazy_init(self): self._sharding_method, ) elif self._sharding_method == "dedicated": - sub_dataset = {**self.dataset, "_num_shards": self.num_workers, "_shard_index": i} + sub_dataset = { + **self.dataset, + "num_shards": self.num_workers * self.num_shards, + "shard_index": (self.shard_index * self.num_workers) + i, + } args = ( i, sub_dataset, diff --git a/returnn/torch/data/returnn_dataset_wrapper.py b/returnn/torch/data/returnn_dataset_wrapper.py index ae70ec73c..8842d3fc4 100644 --- a/returnn/torch/data/returnn_dataset_wrapper.py +++ b/returnn/torch/data/returnn_dataset_wrapper.py @@ -41,8 +41,22 @@ class ReturnnDatasetResetMpSharedEpochCallback: def __init__(self, dataset: ReturnnDataset, epoch_mp_shared: torch.multiprocessing.Value): self.dataset = dataset self.epoch_mp_shared = epoch_mp_shared + self._sharding_config_set = False def __call__(self): + # Include local worker rank in the dataset sharding config + if not self._sharding_config_set: + worker_info = torch.utils.data.get_worker_info() + if worker_info is not None: + assert self.dataset.supports_sharding() or worker_info.num_workers == 1, ( + f"Using {worker_info.num_workers} torch data loading workers " + "but dataset does not support sharding. This will result in repeated training data." + ) + self.dataset.set_shard_idx_and_num_shards( + self.dataset.shard_index + worker_info.id, self.dataset.num_shards * worker_info.num_workers + ) + self._sharding_config_set = True + # dataset is likely a copy of the original dataset, either in the main process or in a worker process # Use epoch_mp_shared to get the current epoch correctly in worked processes epoch = self.epoch_mp_shared.value or None diff --git a/tests/test_Dataset.py b/tests/test_Dataset.py index 2b6469b88..eee67b7fc 100644 --- a/tests/test_Dataset.py +++ b/tests/test_Dataset.py @@ -564,10 +564,9 @@ def create_ogg_zip_txt_only_dataset(*, text: str = "hello world", seq_tag: str = @contextlib.contextmanager -def create_ogg_zip_txt_only_dataset_mult_seqs(*, seed: int = 1, num_seqs: int = 100, max_seq_len: int = 100): +def create_ogg_zip_txt_only_dataset_mult_seqs_opts(*, seed: int = 1, num_seqs: int = 100, max_seq_len: int = 100): """create OggZipDataset""" import zipfile - from returnn.datasets.audio import OggZipDataset rnd = numpy.random.RandomState(seed) @@ -596,6 +595,15 @@ def create_ogg_zip_txt_only_dataset_mult_seqs(*, seed: int = 1, num_seqs: int = "audio": None, "targets": {"class": "CharacterTargets", "vocab_file": tmp_vocab_file.name, "seq_postfix": []}, } + yield opts + + +@contextlib.contextmanager +def create_ogg_zip_txt_only_dataset_mult_seqs(*, seed: int = 1, num_seqs: int = 100, max_seq_len: int = 100): + """create OggZipDataset""" + from returnn.datasets.audio import OggZipDataset + + with create_ogg_zip_txt_only_dataset_mult_seqs_opts(seed=seed, num_seqs=num_seqs, max_seq_len=max_seq_len) as opts: dataset = init_dataset(opts) assert isinstance(dataset, OggZipDataset) yield dataset @@ -1304,6 +1312,18 @@ def _collect_single_seq(self, seq_idx: int) -> Optional[DatasetSeq]: assert sub_ep == outer_epoch * multi_epoch + 1 and sub_seq_idx == 0 +def test_dataset_sharding(): + from returnn.datasets.audio import OggZipDataset + + with create_ogg_zip_txt_only_dataset_mult_seqs_opts(num_seqs=10) as dataset_opts: + datasets = [init_dataset({**dataset_opts, "num_shards": 2, "shard_index": i}) for i in range(2)] + for dataset in datasets: + assert isinstance(dataset, OggZipDataset) + dataset.init_seq_order(epoch=1) + assert dataset.shard_index < dataset.num_shards == 2 + assert dataset.num_seqs == 5 + + if __name__ == "__main__": better_exchook.install() if len(sys.argv) <= 1: diff --git a/tests/test_torch_dataset.py b/tests/test_torch_dataset.py index bc2cfa42c..0b13e3789 100644 --- a/tests/test_torch_dataset.py +++ b/tests/test_torch_dataset.py @@ -18,7 +18,12 @@ def get_loader_from_returnn_dataset( - dataset: Dataset, mp_manager: torch.multiprocessing.Manager, *, batch_size: int = 5, max_seqs: int = 2 + dataset: Dataset, + mp_manager: torch.multiprocessing.Manager, + *, + batch_size: int = 5, + max_seqs: int = 2, + num_workers: int = 1, ) -> DataLoader: # Follow mostly similar logic as in the PT engine. @@ -46,7 +51,7 @@ def get_loader_from_returnn_dataset( pickle.loads(pickle.dumps(batches_dataset)) - return data_pipeline.create_data_loader_from_batches(batches_dataset, {"num_workers": 1}) + return data_pipeline.create_data_loader_from_batches(batches_dataset, {"num_workers": num_workers}) def test_pipeline_serialization(): @@ -202,6 +207,13 @@ def test_MultiProcDataset_HDFDataset(): assert c == n +def test_dataset_num_workers_sharding(): + dataset = init_dataset({"class": "Task12AXDataset", "num_seqs": 10}) + mp_manager = torch.multiprocessing.Manager() + loader = get_loader_from_returnn_dataset(dataset, mp_manager, max_seqs=1, num_workers=2) + assert len(list(iter(loader))) == 10 + + if __name__ == "__main__": better_exchook.install() if len(sys.argv) <= 1: