Dataset: implement global dataset_distribution option #1676
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Changes from 14 commits
Changes to the Dataset base class:
@@ -20,7 +20,7 @@
 import numpy
 import functools
 import typing
-from typing import TYPE_CHECKING, Optional, Any, Union, Type, Dict, Sequence, List, Callable
+from typing import TYPE_CHECKING, Optional, Any, Tuple, Union, Type, Dict, Sequence, List, Callable

 from returnn.log import log
 from returnn.engine.batch import Batch, BatchSetGenerator
@@ -68,6 +68,12 @@ def set_or_remove(key, value):
         set_or_remove("min_chunk_size", config.opt_typed_value("min_chunk_size", 0) or None)
         set_or_remove("chunking_variance", config.float("chunking_variance", 0))

+        dd_cfg = config.typed_value("dataset_distribution", "random_seed_offset")
+        assert dd_cfg in ["random_seed_offset", "shard"]
+        shard_index, num_shards = Dataset._get_sharding_rank_and_size(config) if dd_cfg == "shard" else (0, 1)
+        set_or_remove("num_shards", num_shards)
+        set_or_remove("shard_index", shard_index)
+
     @staticmethod
     def get_default_kwargs_eval(config: Config) -> Dict[str, Any]:
         """
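For reference, a minimal sketch of how the new option would look in a user config; the surrounding options are illustrative and not part of this diff:

# Hedged config sketch: enable sharded data loading for distributed training.
# "random_seed_offset" (the default) keeps the previous behavior.
torch_distributed = {}          # multi-process training; real configs typically set more options here
dataset_distribution = "shard"  # each rank loads only its own shard of the data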
@@ -112,8 +118,8 @@ def __init__(
         min_chunk_size=0,
         chunking_variance=0,
         estimated_num_seqs=None,
-        _num_shards=1,
-        _shard_index=0,
+        num_shards: int = 1,
+        shard_index: int = 0,
     ):
         """
         :param str name: e.g. "train" or "eval"

Comment on lines +120 to +121: I don't understand: Why are those public (not prefixed with `_`)?
@@ -137,8 +143,8 @@ def __init__(
         :param str|None seq_order_seq_lens_file: for seq order, use the seq length given by this file
         :param int shuffle_frames_of_nseqs: shuffles the frames. not always supported
         :param None|int estimated_num_seqs: for progress reporting in case the real num_seqs is unknown
-        :param int _num_shards: number of shards the data is split into
-        :param int _shard_index: local shard index, when sharding is enabled
+        :param num_shards: number of shards the data is split into
+        :param shard_index: local shard index, when sharding is enabled
         """
         self.name = name or ("dataset_id%s" % id(self))
         self.lock = None  # type: Optional[RLock]  # Used when manipulating our data potentially from multiple threads.
@@ -172,9 +178,9 @@ def __init__(
         self._chunking = chunking
         self.chunk_size, self.chunk_step, self.custom_chunking_func = self._parse_chunking(chunking)
         self._context_window = context_window
-        assert 0 <= _shard_index < _num_shards
-        self._num_shards = _num_shards
-        self._shard_index = _shard_index
+        assert 0 <= shard_index < num_shards
+        self.num_shards = num_shards
+        self.shard_index = shard_index
         if isinstance(context_window, (tuple, list)):
             assert len(context_window) == 2
             for elem in context_window:

Comment on lines +178 to +180, with a suggested change: Slightly cleaner and reuses code, but it's also fine as is. Your choice.

Reply: That function has additional asserts that rely on correct initialization. Not for now.

Reply: I don't understand the comment. What correct initialization? What do you mean by "not for now"?
@@ -250,6 +256,37 @@ def __reduce__(self):
         state = {attr: getattr(self, attr) for attr in ["epoch", "zpad"]}
         return Dataset._create_from_reduce, (self.__class__, kwargs, state)

+    def set_shard_idx_and_num_shards(self, shard_index: int, num_shards: int):
+        """set sharding config for the dataset"""
+        assert 0 <= shard_index < num_shards
+        assert num_shards == 1 or self.supports_sharding()
+        self.num_shards = num_shards
+        self.shard_index = shard_index
+
+    @staticmethod
+    def _get_sharding_rank_and_size(config: Optional[Config] = None) -> Tuple[int, int]:
+        """
+        :return: tuple (rank, size): the global rank and size for distributed trainings
+        """
+        if config is None:
+            from returnn.config import get_global_config
+
+            config = get_global_config(return_empty_if_none=True)
+        if config.typed_value("torch_distributed") is not None:
+            import returnn.torch.distributed
+
+            ctx = returnn.torch.distributed.get_ctx(config=config)
+            return ctx.rank(), ctx.size()
+        elif config.is_true("use_horovod"):
+            assert config.bool("use_tensorflow", False) or config.value("backend", "").startswith("tensorflow")
+
+            import returnn.tf.horovod
+
+            ctx = returnn.tf.horovod.get_ctx(config=config)
+            return ctx.rank(), ctx.size()
+        else:
+            return 0, 1
+
     @property
     def random_seed_offset(self) -> int:
         """:return: random seed offset for shuffling"""

Comment on set_shard_idx_and_num_shards: Small inconsistency: idx vs index.
@@ -259,10 +296,10 @@ def random_seed_offset(self) -> int:

     def _uses_custom_distributed_sharding(self) -> bool:
         """
-        :return: if dataset has its own data sharding logic independent of TF/PT.
+        :return: if the dataset has its own data sharding logic independent of TF/PT.
             Leads to a fixed random_seed_offset independent of the workers local rank.
         """
-        return False
+        return self.num_shards > 1

     def _get_default_random_seed_offset(self):
         """
@@ -643,9 +680,9 @@ def get_seq_order_for_epoch(
             seq_index = [
                 i for i in seq_index if (all_seq_tags[i] not in used_seq_tags, used_seq_tags.add(all_seq_tags[i]))[0]
             ]
-        if partition_epoch > 1 or self._num_shards > 1:
+        if partition_epoch > 1 or self.num_shards > 1:
             seq_index = self._apply_partition_epoch_and_sharding(
-                seq_index, partition_epoch, epoch, self._num_shards, self._shard_index
+                seq_index, partition_epoch, epoch, self.num_shards, self.shard_index
             )
         if repeat_epoch > 1:
             seq_index = list(seq_index) * repeat_epoch
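As a mental model for what the sharding step does to the seq order, here is a toy sketch; the real `_apply_partition_epoch_and_sharding` also handles partition_epoch and may split differently, so treat this only as an illustration:

# Hedged toy model: each shard gets a disjoint slice of the same base order.
def toy_shard(seq_index, shard_index, num_shards):
    return seq_index[shard_index::num_shards]

base_order = list(range(10))
print(toy_shard(base_order, 0, 2))  # [0, 2, 4, 6, 8]
print(toy_shard(base_order, 1, 2))  # [1, 3, 5, 7, 9]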
@@ -737,8 +774,8 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
         self.epoch = epoch
         self.rnd_seq_drop = Random(self._get_random_seed_for_epoch(epoch=epoch))
         assert (
-            self._num_shards == 1 or self.supports_sharding()
-        ), f"{self}: does not support sharding, but got num_shards == {self._num_shards}"
+            self.num_shards == 1 or self.supports_sharding()
+        ), f"{self}: does not support sharding, but got num_shards == {self.num_shards}"
         return False

     def finish_epoch(self, *, free_resources: bool = False):
@@ -1588,9 +1625,9 @@ def _dataset_extend_default_kwargs_from_parent_dataset(
     default_kwargs = default_kwargs.copy() if default_kwargs else {}
     default_kwargs.setdefault("random_seed_offset", parent_dataset.random_seed_offset)
     # noinspection PyProtectedMember
-    default_kwargs.setdefault("_num_shards", parent_dataset._num_shards)
+    default_kwargs.setdefault("num_shards", parent_dataset.num_shards)
     # noinspection PyProtectedMember
-    default_kwargs.setdefault("_shard_index", parent_dataset._shard_index)
+    default_kwargs.setdefault("shard_index", parent_dataset.shard_index)
     return default_kwargs
Changes to DistributeFilesDataset:
@@ -137,9 +137,8 @@ def __init__(
         get_sub_epoch_dataset: Callable[[List[FileTree]], Dict[str, Any]],
         preload_next_n_sub_epochs: int = 1,
         buffer_size: int = 1,
-        distrib_shard_files: bool = False,
+        distrib_shard_files: Optional[bool] = None,
         _meta_info_cache: Optional[Dict[str, Any]] = None,
-        _distrib_info: Optional[Dict[str, int]] = None,
         **kwargs,
     ):
         """
@@ -148,10 +147,10 @@ def __init__(
         :param get_sub_epoch_dataset: callable which returns a dataset dict for a given subset of files
         :param preload_next_n_sub_epochs: how many sub epoch datasets to preload
         :param buffer_size: buffer size for each worker, amount of seqs to prefetch
-        :param distrib_shard_files: set to true to shard the data across worker processes in
-            distributed training scenaria
+        :param distrib_shard_files: deprecated. Replaced by global config option ``dataset_distribution="shard"``.
+
+            Set to true to shard the data across worker processes in distributed training scenaria.
         :param _meta_info_cache: for internal use
-        :param _distrib_info: for internal use
         """
         super().__init__(**kwargs)
         self.files = files

Comment on lines +155 to +157, with a suggested change: I don't think this is necessary to mark it as deprecated.
@@ -166,21 +165,13 @@ def __init__(
         self._workers: Dict[int, _WorkerProcParent] = {}  # epoch -> worker
         self._files_order_cache: Dict[int, List[List[FileTree]]] = {}  # full epoch (0-indexed) -> files order

-        self.distrib_shard_files = distrib_shard_files
-        if distrib_shard_files:
-            assert self._num_shards == 1 and self._shard_index == 0, (  # ensure defaults are set
-                f"{self}: Cannot use both dataset-sharding via properties _num_shards and _shard index "
-                f"and {self.__class__.__name__}'s own sharding implementation based on the trainings rank and size."
+        if distrib_shard_files is not None:
+            log.print_deprecation_warning(
+                f"{self.__class__.__name__}' `distrib_shard_files` config option is set. "
+                "Use global config option `dataset_distribution` instead "
+                "for the same behavior across more types of datasets."
             )
-            if _distrib_info:
-                # If we're in a child process `_get_rank_and_size()` no longer works,
-                # so we pass the info about the shards via a pickled property.
-                # See also Dataset.__reduce__.
-                self._shard_index = _distrib_info["_shard_index"]
-                self._num_shards = _distrib_info["_num_shards"]
-            else:
-                self._shard_index, self._num_shards = _get_rank_and_size()
-            assert 0 <= self._shard_index < self._num_shards
+            self._validate_global_shard_cfg(distrib_shard_files)

         if _meta_info_cache:
             # This allows to skip the lazy init in self.initialize().

Comment on lines +175 to 179, with a suggested change: I'm not sure this is really necessary to mark this as deprecated.
@@ -204,10 +195,6 @@ def supports_sharding(self) -> bool:
         """:return: whether the dataset supports sharding based on the seq_order"""
         return True

-    @property
-    def _distrib_info(self):
-        return {"_num_shards": self._num_shards, "_shard_index": self._shard_index}
-
     @property
     def _meta_info_cache(self):
         if not self.num_outputs:

@@ -220,9 +207,6 @@ def _meta_info_cache(self):
             "file_sizes": self._file_sizes,
         }

-    def _uses_custom_distributed_sharding(self) -> bool:
-        return self._num_shards > 1
-
     def _lazy_init_num_outputs(self):
         if self.num_outputs:
             return
@@ -290,11 +274,11 @@ def init_seq_order(self, epoch: Optional[int] = None, seq_list=None, seq_order=None):
             else:
                 raise ValueError(f"{self}: seq_ordering {self.seq_ordering!r} not supported")
             file_bins = self._distribute_evenly_by_size(
-                num_bins=self._num_shards * self.partition_epoch,
+                num_bins=self.num_shards * self.partition_epoch,
                 file_sizes=self._file_sizes,
                 files_order=files_order_flat,
             )
-            self_index_base = self.partition_epoch * self._shard_index
+            self_index_base = self.partition_epoch * self.shard_index
             self_index_end = self_index_base + self.partition_epoch
             self._files_order_cache[full_epoch_0idx_] = file_bins[self_index_base:self_index_end]
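To make the indexing above concrete, a toy sketch (values and file names are invented, not from the PR):

# Hedged illustration of the bin layout: num_shards * partition_epoch bins,
# of which each shard trains on its own consecutive partition_epoch-sized slice.
partition_epoch, num_shards, shard_index = 2, 3, 1
file_bins = [["a"], ["b"], ["c"], ["d"], ["e"], ["f"]]  # 3 * 2 = 6 bins
base = partition_epoch * shard_index  # -> 2
end = base + partition_epoch          # -> 4
print(file_bins[base:end])            # shard 1 sees [["c"], ["d"]] across its sub-epochs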
@@ -328,6 +312,10 @@ def _get_sub_dataset_dict(self, files: List[FileTree]) -> Dict[str, Any]:

         dataset_dict = self.get_sub_epoch_dataset(files)
         dataset_dict = extend_dataset_dict_from_parent_dataset(dataset_dict, parent_dataset=self)
+        # We shard by splitting the files list into shards, the sub datasets must not shard any further by themselves
+        if self.num_shards > 1:
+            dataset_dict["num_shards"] = 1
+            dataset_dict["shard_index"] = 0

         flat_sub_dset = tree.flatten_with_path(dataset_dict)
@@ -452,6 +440,21 @@ def get_data_keys(self) -> List[str]:
         self._lazy_init_num_outputs()
         return self._data_keys

+    @classmethod
+    def _validate_global_shard_cfg(cls, distrib_shard_files: bool):
+        from returnn.config import get_global_config
+
+        config = get_global_config(raise_exception=False)
+        if not config:
+            return
+
+        dd_cfg = config.typed_value("dataset_distribution", None)
+        if dd_cfg and (distrib_shard_files and dd_cfg != "shard") or (not distrib_shard_files and dd_cfg == "shard"):
+            raise ValueError(
+                f"{cls.__name__}: `distrib_shard_files` config ({distrib_shard_files}) mismatch "
+                f"with global config option `dataset_distribution` ({dd_cfg})."
+            )
+

 def _get_key_for_file_tree(t: FileTree) -> str:
     """generates a deterministic key given a file tree"""
@@ -460,32 +463,6 @@ def _get_key_for_file_tree(t: FileTree) -> str:
     return ":".join(tree.flatten(t))


-def _get_rank_and_size() -> Tuple[int, int]:
-    """
-    :return: tuple (rank, size): the global rank and size for distributed trainings
-    """
-
-    from returnn.config import get_global_config
-
-    config = get_global_config(raise_exception=False)
-    if not config:
-        return 0, 1
-    if config.typed_value("torch_distributed") is not None:
-        import returnn.torch.distributed
-
-        ctx = returnn.torch.distributed.get_ctx(config=config)
-        return ctx.rank(), ctx.size()
-    elif config.is_true("use_horovod"):
-        assert config.bool("use_tensorflow", False) or config.value("backend", "").startswith("tensorflow")
-
-        import returnn.tf.horovod
-
-        ctx = returnn.tf.horovod.get_ctx(config=config)
-        return ctx.rank(), ctx.size()
-    else:
-        return 0, 1


 class _WorkerProcParent:
     def __init__(
         self,
Changes to MultiProcDataset:
@@ -94,6 +94,10 @@ def initialize(self):
         self._lazy_init()
         super().initialize()

+    def supports_sharding(self):
+        """this dataset supports sharding on the ``dedicated`` sharding method"""
+        return self._sharding_method == "dedicated"
+
     @property
     def _meta_info_cache(self):
         if not self.num_outputs:
@@ -144,7 +148,11 @@ def _lazy_init(self):
                     self._sharding_method,
                 )
             elif self._sharding_method == "dedicated":
-                sub_dataset = {**self.dataset, "_num_shards": self.num_workers, "_shard_index": i}
+                sub_dataset = {
+                    **self.dataset,
+                    "num_shards": self.num_workers * self.num_shards,
+                    "shard_index": (self.shard_index * self.num_workers) + i,
+                }

             args = (
                 i,
                 sub_dataset,

Comment on lines -147 to +155: This is something that I overlooked before. We can trivially allow sharding for MultiProcDataset as well.

Reply: I might not know the context of this comment. Isn't sharding already allowed as per your change here?

Reply: This is when sharding was set on the MultiProcDataset itself, and then MultiProcDataset also uses sharding on the subdatasets.
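A small standalone sketch of how the composed shard grid works out, assuming for illustration 2 distributed ranks and 3 MultiProc workers per rank (numbers are invented):

# Hedged illustration of the "dedicated" shard composition above.
num_shards, num_workers = 2, 3            # parent dataset shards (e.g. distributed ranks), workers per rank
for shard_index in range(num_shards):     # shard index of the MultiProcDataset itself
    for i in range(num_workers):          # worker index inside the MultiProcDataset
        sub_num_shards = num_workers * num_shards        # -> 6 global shards
        sub_shard_index = shard_index * num_workers + i  # -> 0..5, unique per (rank, worker)
        print(shard_index, i, "->", sub_shard_index, "of", sub_num_shards)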
Changes to the PyTorch dataset wrapper (class ReturnnDatasetResetMpSharedEpochCallback):
@@ -41,8 +41,22 @@ class ReturnnDatasetResetMpSharedEpochCallback:
     def __init__(self, dataset: ReturnnDataset, epoch_mp_shared: torch.multiprocessing.Value):
         self.dataset = dataset
         self.epoch_mp_shared = epoch_mp_shared
+        self._sharding_config_set = False

     def __call__(self):
+        # Include local worker rank in the dataset sharding config
+        if not self._sharding_config_set:
+            worker_info = torch.utils.data.get_worker_info()
+            if worker_info is not None:
+                assert self.dataset.supports_sharding() or worker_info.num_workers == 1, (
+                    f"Using {worker_info.num_workers} torch data loading workers "
+                    "but dataset does not support sharding. This will result in repeated training data."
+                )
+                self.dataset.set_shard_idx_and_num_shards(
+                    self.dataset.shard_index + worker_info.id, self.dataset.num_shards * worker_info.num_workers
+                )
+            self._sharding_config_set = True
+
         # dataset is likely a copy of the original dataset, either in the main process or in a worker process
         # Use epoch_mp_shared to get the current epoch correctly in worked processes
         epoch = self.epoch_mp_shared.value or None

Comment on lines +55 to +57: I don't understand. Why does this consider the existing shard_index/num_shards? I would expect that you overwrite those here. (Suggested change)

Reply: Or do you expect that the existing shard_index/num_shards were set with the distributed rank/size, and this here additionally adds further sharding for worker id/num_workers? But then this code written in this way is very confusing... I think this should be done differently somehow. Not sure how...

Reply: We could also do the distributed logic directly here. We pass on the rank/size info anyway to the subproc children via the env var. The question is, how to handle ...

Reply: Instead of having all this logic here, we maybe should better move it to ...
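For readers unfamiliar with the torch API used here: `torch.utils.data.get_worker_info()` behaves roughly as sketched below (the helper name is made up, not part of the PR):

import torch.utils.data

def _worker_shard_info():
    # Inside a DataLoader worker process, get_worker_info() returns an object with
    # .id (0-based worker index) and .num_workers; in the main process it returns None,
    # in which case no additional per-worker sharding applies.
    info = torch.utils.data.get_worker_info()
    if info is None:
        return 0, 1
    return info.id, info.num_workers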
Review comment on the kwargs_update_from_config change: I'm not sure whether it is a good idea to do this logic in kwargs_update_from_config. You don't want to just apply this for every dataset. E.g. for the dev/test/eval datasets, we would not want this logic, at least not right now.

I think this should come from the outside, not from within.
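Purely as an illustration of what "from the outside" could mean, a hypothetical sketch (the function name and placement are invented, not proposed in the thread): the engine-side setup would decide which datasets get sharded, instead of Dataset.kwargs_update_from_config applying it to every dataset it constructs.

# Hypothetical sketch, reusing only APIs introduced in this diff; not part of the PR.
def init_train_dataset_sharding(train_dataset, config):
    # train_dataset is assumed to be a Dataset instance; _get_sharding_rank_and_size is the
    # staticmethod added above, reachable via the instance as well.
    if config.typed_value("dataset_distribution", "random_seed_offset") == "shard":
        rank, size = train_dataset._get_sharding_rank_and_size(config)
        train_dataset.set_shard_idx_and_num_shards(rank, size)
    # dev/test/eval datasets would simply not be passed through this helper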