22 commits
924045e
black
NeoLegends Jan 15, 2025
bb03813
Merge branch 'master' into moritz-shard-mgpu
NeoLegends Feb 4, 2025
9410000
add sharding test
NeoLegends Feb 4, 2025
3991edb
Merge branch 'master' into moritz-shard-mgpu
NeoLegends Feb 7, 2025
c0bce2b
initialize num_shards and shard_index in kwargs_update_from_config
NeoLegends Feb 10, 2025
6261afc
sharding parameters cannot be None anymore
NeoLegends Feb 10, 2025
f6a8942
we no longer need the @properties
NeoLegends Feb 10, 2025
c3ad46e
black
NeoLegends Feb 10, 2025
c3b2c04
Merge branch 'master' into moritz-shard-mgpu
NeoLegends Feb 21, 2025
485b6f4
Merge branch 'master' into moritz-shard-mgpu
NeoLegends Feb 27, 2025
653969c
Merge branch 'master' into moritz-shard-mgpu
NeoLegends Mar 5, 2025
7d753f5
take torch num_workers into account for sharding
NeoLegends Mar 5, 2025
2607de2
set sharding config in torch data pipe
NeoLegends Mar 5, 2025
5a14a5e
MultiProcDataset: support sharding on `"sharding_method": "dedicated"`
NeoLegends Mar 5, 2025
8d56a65
Fix missing assignment of distrib_shard_files
NeoLegends Mar 7, 2025
55d244d
fix sharding when distrib_shard_files is set and dataset_distribution…
NeoLegends Mar 7, 2025
90353a9
Merge branch 'master' into moritz-shard-mgpu
NeoLegends May 6, 2025
1011ace
black
NeoLegends May 6, 2025
3b28635
Merge branch 'master' into moritz-shard-mgpu
NeoLegends May 13, 2025
a3e6ad5
fix lints
NeoLegends May 20, 2025
713fde9
Merge branch 'master' into moritz-shard-mgpu
NeoLegends May 20, 2025
148191d
Merge branch 'master' into moritz-shard-mgpu
NeoLegends Jul 14, 2025
81 changes: 67 additions & 14 deletions returnn/datasets/basic.py
@@ -19,7 +19,7 @@
import numpy
import functools
import typing
from typing import TYPE_CHECKING, Optional, Any, Union, Type, Dict, Sequence, List, Callable
from typing import TYPE_CHECKING, Optional, Any, Tuple, Union, Type, Dict, Sequence, List, Callable

from returnn.log import log
from returnn.engine.batch import Batch, BatchSetGenerator
@@ -111,8 +111,8 @@ def __init__(
min_chunk_size=0,
chunking_variance=0,
estimated_num_seqs=None,
_num_shards=1,
_shard_index=0,
_num_shards=None,
_shard_index=None,
):
"""
:param str name: e.g. "train" or "eval"
@@ -136,8 +136,8 @@ def __init__(
:param str|None seq_order_seq_lens_file: for seq order, use the seq length given by this file
:param int shuffle_frames_of_nseqs: shuffles the frames. not always supported
:param None|int estimated_num_seqs: for progress reporting in case the real num_seqs is unknown
:param int _num_shards: number of shards the data is split into
:param int _shard_index: local shard index, when sharding is enabled
:param int|None _num_shards: number of shards the data is split into
:param int|None _shard_index: local shard index, when sharding is enabled
"""
self.name = name or ("dataset_id%s" % id(self))
self.lock = None # type: Optional[RLock] # Used when manipulating our data potentially from multiple threads.
@@ -171,7 +171,7 @@ def __init__(
self._chunking = chunking
self.chunk_size, self.chunk_step, self.custom_chunking_func = self._parse_chunking(chunking)
self._context_window = context_window
assert 0 <= _shard_index < _num_shards
assert (_shard_index is None and _num_shards is None) or 0 <= _shard_index < _num_shards
self._num_shards = _num_shards
self._shard_index = _shard_index
if isinstance(context_window, (tuple, list)):
@@ -249,6 +249,59 @@ def __reduce__(self):
state = {attr: getattr(self, attr) for attr in ["epoch", "zpad"]}
return Dataset._create_from_reduce, (self.__class__, kwargs, state)

@staticmethod
def _get_rank_and_size() -> Tuple[int, int]:
"""
:return: tuple (rank, size): the global rank and size for distributed trainings
"""
from returnn.config import get_global_config

config = get_global_config(raise_exception=False)
if not config:
return 0, 1
if config.typed_value("torch_distributed") is not None:
import returnn.torch.distributed

ctx = returnn.torch.distributed.get_ctx(config=config)
return ctx.rank(), ctx.size()
elif config.is_true("use_horovod"):
assert config.bool("use_tensorflow", False) or config.value("backend", "").startswith("tensorflow")

import returnn.tf.horovod

ctx = returnn.tf.horovod.get_ctx(config=config)
return ctx.rank(), ctx.size()
else:
return 0, 1

@staticmethod
def _get_default_shard_config():
"""
:return: default shard index and number of shards based on the global config
"""
from returnn.config import get_global_config

config = get_global_config(raise_exception=False)

Member:

I don't like that we access the global config here. I know this follows similar code to _get_default_random_seed_offset, but I don't like it there either. Why is this needed? This should come from outside, or not? Specifically at the place where we call init_dataset, e.g. in the __main__ module. There we also call Dataset.kwargs_update_from_config.

Also, the code is wrong. Distributed training is only one possible source which defines/influences the shard index and the number of shards. There are others, for example MultiProcDataset, or the PyTorch DataLoader num_workers.

Member Author (@NeoLegends), Feb 10, 2025:

> MultiProcDataset

This already sets num_shards and the shard index for its children. The code was always designed such that it only looks at the global config if no value has been set yet. But I agree, this is better now.

> PyTorch DataLoader num_workers

I think this is actually not that trivial to implement, because the torch engine is already given initialized datasets, and it is difficult to change the sharding config after a dataset has been initialized. So factoring in the PyTorch num_workers would need to happen during the initial dataset initialization, which mixes PyTorch code with data initialization code a bit, and I feel that is going to be messy. Do you know a good way? Maybe this is fine after all.

Member Author (@NeoLegends):

I think we cannot achieve both

> But e.g. there could be a test with PyTorch DataLoader num_workers=2 which checks that all data from the dataset was properly covered.

and

> I know this follows similar code as _get_default_random_seed_offset but I also don't like it there. Why is this needed? This should come from outside, or not?

because of

> I think this is actually not that trivial to implement because the torch engine already is given initialized datasets and it's difficult to change the sharding config after having initialized a dataset

However, I think it is worth having proper support for torch_dataloader_opts = {"num_workers": n} with n > 1, because it makes multi-process data loading much simpler for the end user, and this feature can replace MultiProcDataset for simple use cases. So I think I need to revert the changes where num_shards and shard_index are set from the outside, and rather fetch them from inside the dataset when they are needed. At that point the torch worker info is also available, which means we can take it into account properly.

Member:

> I think this is actually not that trivial to implement because the torch engine already is given initialized datasets and it's difficult to change the sharding config after having initialized a dataset.

Why difficult? Maybe we just need a clean dataset API for that, some setter function like set_num_shards_and_shard_idx. And then in ReturnnDatasetIterDataPipe.reset or so we just call that.

Member Author (@NeoLegends):

Hm, I was originally not a fan of the mutability of these properties, but it seems ok now.

Member:

I think we cannot avoid an API like set_num_shards_and_shard_idx because of how the PyTorch data pipeline works.
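
A minimal sketch of the setter idea from this thread, using hypothetical, simplified stand-in classes (set_num_shards_and_shard_idx and the reset hook are assumed names; this is not code from the PR, and the real RETURNN classes differ):

import torch.utils.data


class _ShardableDataset:
    """Stand-in for returnn.datasets.basic.Dataset; only the sharding part is sketched."""

    def __init__(self):
        self._num_shards = 1
        self._shard_index = 0

    def set_num_shards_and_shard_idx(self, num_shards: int, shard_index: int):
        """Hypothetical setter: lets the data pipeline override the sharding config after construction."""
        assert 0 <= shard_index < num_shards
        self._num_shards = num_shards
        self._shard_index = shard_index


class _DatasetIterPipe(torch.utils.data.IterableDataset):
    """Stand-in for ReturnnDatasetIterDataPipe; reset() would run whenever iteration (re)starts."""

    def __init__(self, dataset: _ShardableDataset):
        self._dataset = dataset

    def reset(self):
        # Fold the DataLoader worker id/count into the dataset's sharding config.
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            self._dataset.set_num_shards_and_shard_idx(
                num_shards=worker_info.num_workers, shard_index=worker_info.id
            )

    def __iter__(self):
        self.reset()
        return iter(())  # the real pipe would iterate over the dataset's shard here

In a distributed multi-GPU run, the worker count would presumably have to be combined with the distributed rank/size; the later commit "take torch num_workers into account for sharding" goes in that direction.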

if not config:
return 0, 1
dd_cfg = config.typed_value("dataset_distribution", "random_seed_offset")
assert dd_cfg in ["random_seed_offset", "shard"]
return Dataset._get_rank_and_size() if dd_cfg == "shard" else (0, 1)

@property
def num_shards(self) -> int:
""":return: number of shards the data is split into"""
if self._num_shards is None:
self._shard_index, self._num_shards = self._get_default_shard_config()
return self._num_shards

@property
def shard_index(self) -> int:
""":return: local shard index, when sharding is enabled"""
if self._shard_index is None:
self._shard_index, self._num_shards = self._get_default_shard_config()
return self._shard_index

@property
def random_seed_offset(self) -> int:
""":return: random seed offset for shuffling"""
@@ -258,10 +311,10 @@ def random_seed_offset(self) -> int:

def _uses_custom_distributed_sharding(self) -> bool:
"""
:return: if dataset has its own data sharding logic independent of TF/PT.
:return: if the dataset has its own data sharding logic independent of TF/PT.
Leads to a fixed random_seed_offset independent of the workers local rank.
"""
return False
return self.num_shards > 1

def _get_default_random_seed_offset(self):
"""
@@ -642,9 +695,9 @@ def get_seq_order_for_epoch(
seq_index = [
i for i in seq_index if (all_seq_tags[i] not in used_seq_tags, used_seq_tags.add(all_seq_tags[i]))[0]
]
if partition_epoch > 1 or self._num_shards > 1:
if partition_epoch > 1 or self.num_shards > 1:
seq_index = self._apply_partition_epoch_and_sharding(
seq_index, partition_epoch, epoch, self._num_shards, self._shard_index
seq_index, partition_epoch, epoch, self.num_shards, self.shard_index
)
if repeat_epoch > 1:
seq_index = list(seq_index) * repeat_epoch
@@ -736,8 +789,8 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
self.epoch = epoch
self.rnd_seq_drop = Random(self._get_random_seed_for_epoch(epoch=epoch))
assert (
self._num_shards == 1 or self.supports_sharding()
), f"{self}: does not support sharding, but got num_shards == {self._num_shards}"
self.num_shards == 1 or self.supports_sharding()
), f"{self}: does not support sharding, but got num_shards == {self.num_shards}"
return False

def finish_epoch(self, *, free_resources: bool = False):
@@ -1553,9 +1606,9 @@ def _dataset_extend_default_kwargs_from_parent_dataset(
default_kwargs = default_kwargs.copy() if default_kwargs else {}
default_kwargs.setdefault("random_seed_offset", parent_dataset.random_seed_offset)
# noinspection PyProtectedMember
default_kwargs.setdefault("_num_shards", parent_dataset._num_shards)
default_kwargs.setdefault("_num_shards", parent_dataset.num_shards)
# noinspection PyProtectedMember
default_kwargs.setdefault("_shard_index", parent_dataset._shard_index)
default_kwargs.setdefault("_shard_index", parent_dataset.shard_index)
return default_kwargs


85 changes: 31 additions & 54 deletions returnn/datasets/distrib_files.py
@@ -137,9 +137,8 @@
get_sub_epoch_dataset: Callable[[List[FileTree]], Dict[str, Any]],
preload_next_n_sub_epochs: int = 1,
buffer_size: int = 1,
distrib_shard_files: bool = False,
distrib_shard_files: Optional[bool] = None,
_meta_info_cache: Optional[Dict[str, Any]] = None,
_distrib_info: Optional[Dict[str, int]] = None,
**kwargs,
):
"""
@@ -148,10 +147,10 @@
:param get_sub_epoch_dataset: callable which returns a dataset dict for a given subset of files
:param preload_next_n_sub_epochs: how many sub epoch datasets to preload
:param buffer_size: buffer size for each worker, amount of seqs to prefetch
:param distrib_shard_files: set to true to shard the data across worker processes in
distributed training scenarios
:param distrib_shard_files: deprecated. Replaced by global config option ``dataset_distribution="shard"``.

Set to true to shard the data across worker processes in distributed training scenarios.

Member, commenting on lines +155 to +157:

I don't think this is necessary to mark it as deprecated.

Suggested change
:param distrib_shard_files: deprecated. Replaced by global config option ``dataset_distribution="shard"``.
Set to true to shard the data across worker processes in distributed training scenarios.
:param distrib_shard_files: deprecated. set to true to shard the data across worker processes in
distributed training scenarios

:param _meta_info_cache: for internal use
:param _distrib_info: for internal use
"""
super().__init__(**kwargs)
self.files = files
@@ -166,21 +165,13 @@
self._workers: Dict[int, _WorkerProcParent] = {} # epoch -> worker
self._files_order_cache: Dict[int, List[List[FileTree]]] = {} # full epoch (0-indexed) -> files order

self.distrib_shard_files = distrib_shard_files
if distrib_shard_files:
assert self._num_shards == 1 and self._shard_index == 0, ( # ensure defaults are set
f"{self}: Cannot use both dataset-sharding via properties _num_shards and _shard index "
f"and {self.__class__.__name__}'s own sharding implementation based on the trainings rank and size."
if distrib_shard_files is not None:
log.print_deprecation_warning(
f"{self.__class__.__name__}' `distrib_shard_files` config option is set. "
"Use global config option `dataset_distribution` instead "
"for the same behavior across more types of datasets."
)

Member, commenting on lines +175 to 179:

I'm not sure this is really necessary to mark this as deprecated.

Suggested change
log.print_deprecation_warning(
f"{self.__class__.__name__}' `distrib_shard_files` config option is set. "
"Use global config option `dataset_distribution` instead "
"for the same behavior across more types of datasets."
)

if _distrib_info:
# If we're in a child process `_get_rank_and_size()` no longer works,
# so we pass the info about the shards via a pickled property.
# See also Dataset.__reduce__.
self._shard_index = _distrib_info["_shard_index"]
self._num_shards = _distrib_info["_num_shards"]
else:
self._shard_index, self._num_shards = _get_rank_and_size()
assert 0 <= self._shard_index < self._num_shards
self._validate_global_shard_cfg(distrib_shard_files)

if _meta_info_cache:
# This allows to skip the lazy init in self.initialize().
@@ -204,10 +195,6 @@ def supports_sharding(self) -> bool:
""":return: whether the dataset supports sharding based on the seq_order"""
return True

@property
def _distrib_info(self):
return {"_num_shards": self._num_shards, "_shard_index": self._shard_index}

@property
def _meta_info_cache(self):
if not self.num_outputs:
@@ -220,9 +207,6 @@ def _meta_info_cache(self):
"file_sizes": self._file_sizes,
}

def _uses_custom_distributed_sharding(self) -> bool:
return self._num_shards > 1

def _lazy_init_num_outputs(self):
if self.num_outputs:
return
@@ -290,11 +274,11 @@ def init_seq_order(self, epoch: Optional[int] = None, seq_list=None, seq_order=None):
else:
raise ValueError(f"{self}: seq_ordering {self.seq_ordering!r} not supported")
file_bins = self._distribute_evenly_by_size(
num_bins=self._num_shards * self.partition_epoch,
num_bins=self.num_shards * self.partition_epoch,
file_sizes=self._file_sizes,
files_order=files_order_flat,
)
self_index_base = self.partition_epoch * self._shard_index
self_index_base = self.partition_epoch * self.shard_index
self_index_end = self_index_base + self.partition_epoch
self._files_order_cache[full_epoch_0idx_] = file_bins[self_index_base:self_index_end]

@@ -328,6 +312,10 @@ def _get_sub_dataset_dict(self, files: List[FileTree]) -> Dict[str, Any]:

dataset_dict = self.get_sub_epoch_dataset(files)
dataset_dict = extend_dataset_dict_from_parent_dataset(dataset_dict, parent_dataset=self)
# We shard by splitting the files list into shards; the sub datasets must not shard any further themselves
if self.num_shards > 1:
dataset_dict["_num_shards"] = 1
dataset_dict["_shard_index"] = 0

flat_sub_dset = tree.flatten_with_path(dataset_dict)

@@ -452,6 +440,21 @@ def get_data_keys(self) -> List[str]:
self._lazy_init_num_outputs()
return self._data_keys

@classmethod
def _validate_global_shard_cfg(cls, distrib_shard_files: bool):
from returnn.config import get_global_config

config = get_global_config(raise_exception=False)
if not config:
return

dd_cfg = config.typed_value("dataset_distribution", None)
if dd_cfg and (distrib_shard_files and dd_cfg != "shard") or (not distrib_shard_files and dd_cfg == "shard"):
raise ValueError(
f"{cls.__name__}: `distrib_shard_files` config ({distrib_shard_files}) mismatch "
f"with global config option `dataset_distribution` ({dd_cfg})."
)


def _get_key_for_file_tree(t: FileTree) -> str:
"""generates a deterministic key given a file tree"""
@@ -460,32 +463,6 @@ def _get_key_for_file_tree(t: FileTree) -> str:
return ":".join(tree.flatten(t))


def _get_rank_and_size() -> Tuple[int, int]:
"""
:return: tuple (rank, size): the global rank and size for distributed trainings
"""

from returnn.config import get_global_config

config = get_global_config(raise_exception=False)
if not config:
return 0, 1
if config.typed_value("torch_distributed") is not None:
import returnn.torch.distributed

ctx = returnn.torch.distributed.get_ctx(config=config)
return ctx.rank(), ctx.size()
elif config.is_true("use_horovod"):
assert config.bool("use_tensorflow", False) or config.value("backend", "").startswith("tensorflow")

import returnn.tf.horovod

ctx = returnn.tf.horovod.get_ctx(config=config)
return ctx.rank(), ctx.size()
else:
return 0, 1


class _WorkerProcParent:
def __init__(
self,
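
Taken together, the changes in this file move the sharding decision from the per-dataset distrib_shard_files flag to the global dataset_distribution option validated above. A sketch of the relevant lines in a user config; the option names (backend, torch_distributed, dataset_distribution) appear in this diff, while the concrete values are illustrative assumptions rather than code from the PR:

# Illustrative RETURNN config snippet, not taken from the PR.
backend = "torch"
torch_distributed = {}  # multi-GPU training; rank and size come from the torch distributed context

# Global switch for how data is distributed across ranks:
#   "random_seed_offset" (default): every rank iterates over all data, shuffled with a per-rank seed offset.
#   "shard": every rank only sees its own shard, e.g. DistributeFilesDataset splits its file list by rank.
dataset_distribution = "shard"

The old per-dataset flag distrib_shard_files on DistributeFilesDataset keeps working, but _validate_global_shard_cfg raises a ValueError if it disagrees with dataset_distribution.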
24 changes: 22 additions & 2 deletions tests/test_Dataset.py
@@ -561,10 +561,9 @@ def create_ogg_zip_txt_only_dataset(*, text: str = "hello world", seq_tag: str =


@contextlib.contextmanager
def create_ogg_zip_txt_only_dataset_mult_seqs(*, seed: int = 1, num_seqs: int = 100, max_seq_len: int = 100):
def create_ogg_zip_txt_only_dataset_mult_seqs_opts(*, seed: int = 1, num_seqs: int = 100, max_seq_len: int = 100):
"""create OggZipDataset"""
import zipfile
from returnn.datasets.audio import OggZipDataset

rnd = numpy.random.RandomState(seed)

@@ -593,6 +592,15 @@ def create_ogg_zip_txt_only_dataset_mult_seqs(*, seed: int = 1, num_seqs: int =
"audio": None,
"targets": {"class": "CharacterTargets", "vocab_file": tmp_vocab_file.name, "seq_postfix": []},
}
yield opts


@contextlib.contextmanager
def create_ogg_zip_txt_only_dataset_mult_seqs(*, seed: int = 1, num_seqs: int = 100, max_seq_len: int = 100):
"""create OggZipDataset"""
from returnn.datasets.audio import OggZipDataset

with create_ogg_zip_txt_only_dataset_mult_seqs_opts(seed=seed, num_seqs=num_seqs, max_seq_len=max_seq_len) as opts:
dataset = init_dataset(opts)
assert isinstance(dataset, OggZipDataset)
yield dataset
@@ -1212,6 +1220,18 @@ def _collect_single_seq(self, seq_idx: int) -> Optional[DatasetSeq]:
assert sub_ep == outer_epoch * multi_epoch + 1 and sub_seq_idx == 0


def test_dataset_sharding():
from returnn.datasets.audio import OggZipDataset

with create_ogg_zip_txt_only_dataset_mult_seqs_opts(num_seqs=10) as dataset_opts:

Member:

I think the test would be a bit nicer if num_seqs were uneven, i.e. not divisible by num_shards.

Suggested change
with create_ogg_zip_txt_only_dataset_mult_seqs_opts(num_seqs=10) as dataset_opts:
with create_ogg_zip_txt_only_dataset_mult_seqs_opts(num_seqs=11) as dataset_opts:

datasets = [init_dataset({**dataset_opts, "_num_shards": 2, "_shard_index": i}) for i in range(2)]
for dataset in datasets:
assert isinstance(dataset, OggZipDataset)
dataset.init_seq_order(epoch=1)
assert dataset.shard_index < dataset.num_shards == 2
assert dataset.num_seqs == 5


if __name__ == "__main__":
better_exchook.install()
if len(sys.argv) <= 1:
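
The review also suggests a test with PyTorch DataLoader num_workers=2 that checks all data from the dataset is covered exactly once. A self-contained sketch of that idea in plain torch (it does not use ReturnnDatasetIterDataPipe or the RETURNN test helpers; all names here are made up for illustration):

import torch.utils.data


class _ShardedRange(torch.utils.data.IterableDataset):
    """Toy iterable dataset that shards range(num_seqs) across DataLoader workers."""

    def __init__(self, num_seqs: int):
        self.num_seqs = num_seqs

    def __iter__(self):
        info = torch.utils.data.get_worker_info()
        num_shards = info.num_workers if info else 1
        shard_index = info.id if info else 0
        # Strided sharding for the sketch; RETURNN's _apply_partition_epoch_and_sharding may split differently.
        return iter(range(shard_index, self.num_seqs, num_shards))


def test_dataloader_sharding_covers_all_seqs():
    num_seqs = 11  # uneven on purpose, following the review suggestion above
    loader = torch.utils.data.DataLoader(_ShardedRange(num_seqs), num_workers=2, batch_size=None)
    seen = sorted(int(x) for x in loader)
    assert seen == list(range(num_seqs))  # every seq covered exactly once across both workers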