Changes from all commits
22 commits
924045e
black
NeoLegends Jan 15, 2025
bb03813
Merge branch 'master' into moritz-shard-mgpu
NeoLegends Feb 4, 2025
9410000
add sharding test
NeoLegends Feb 4, 2025
3991edb
Merge branch 'master' into moritz-shard-mgpu
NeoLegends Feb 7, 2025
c0bce2b
initialize num_shards and shard_index in kwargs_update_from_config
NeoLegends Feb 10, 2025
6261afc
sharding parameters cannot be None anymore
NeoLegends Feb 10, 2025
f6a8942
we no longer need the @properties
NeoLegends Feb 10, 2025
c3ad46e
black
NeoLegends Feb 10, 2025
c3b2c04
Merge branch 'master' into moritz-shard-mgpu
NeoLegends Feb 21, 2025
485b6f4
Merge branch 'master' into moritz-shard-mgpu
NeoLegends Feb 27, 2025
653969c
Merge branch 'master' into moritz-shard-mgpu
NeoLegends Mar 5, 2025
7d753f5
take torch num_workers into account for sharding
NeoLegends Mar 5, 2025
2607de2
set sharding config in torch data pipe
NeoLegends Mar 5, 2025
5a14a5e
MultiProcDataset: support sharding on `"sharding_method": "dedicated"`
NeoLegends Mar 5, 2025
8d56a65
Fix missing assignment of distrib_shard_files
NeoLegends Mar 7, 2025
55d244d
fix sharding when distrib_shard_files is set and dataset_distribution…
NeoLegends Mar 7, 2025
90353a9
Merge branch 'master' into moritz-shard-mgpu
NeoLegends May 6, 2025
1011ace
black
NeoLegends May 6, 2025
3b28635
Merge branch 'master' into moritz-shard-mgpu
NeoLegends May 13, 2025
a3e6ad5
fix lints
NeoLegends May 20, 2025
713fde9
Merge branch 'master' into moritz-shard-mgpu
NeoLegends May 20, 2025
148191d
Merge branch 'master' into moritz-shard-mgpu
NeoLegends Jul 14, 2025
74 changes: 55 additions & 19 deletions returnn/datasets/basic.py
@@ -19,7 +19,6 @@
import math
import numpy
import functools
import typing
from typing import TYPE_CHECKING, Optional, Any, Set, Tuple, Union, Type, Dict, Sequence, List, Callable

from returnn.log import log
@@ -68,6 +67,12 @@ def set_or_remove(key, value):
set_or_remove("min_chunk_size", config.opt_typed_value("min_chunk_size", 0) or None)
set_or_remove("chunking_variance", config.float("chunking_variance", 0))

dd_cfg = config.typed_value("dataset_distribution", "random_seed_offset")
assert dd_cfg in ["random_seed_offset", "shard"]
shard_index, num_shards = Dataset._get_sharding_rank_and_size(config) if dd_cfg == "shard" else (0, 1)
set_or_remove("num_shards", num_shards)
set_or_remove("shard_index", shard_index)
Member:

I'm not sure whether it is a good idea to do this logic in kwargs_update_from_config. You don't want to apply this just for every dataset. E.g. for the dev/test/eval datasets, we would not want this logic, at least not right now.

I think this should come from the outside, not from within.
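
Note: for context, this is roughly how the new global option is meant to be used from a config. The fragment below is an illustrative sketch, not taken from this PR; the training setup and the HDFDataset dict are placeholders, and whether a given dataset is shardable depends on its supports_sharding() implementation.

# Hypothetical RETURNN config fragment (illustration only, not from this PR).
backend = "torch"
torch_distributed = {}  # enable distributed training; real setups pass more options
dataset_distribution = "shard"  # instead of the default "random_seed_offset"
# Placeholder dataset dict; with "shard", kwargs_update_from_config fills
# num_shards/shard_index from the distributed rank/size.
train = {"class": "HDFDataset", "files": ["train.hdf"]}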


@staticmethod
def get_default_kwargs_eval(config: Config) -> Dict[str, Any]:
"""
@@ -112,8 +117,8 @@ def __init__(
min_chunk_size=0,
chunking_variance=0,
estimated_num_seqs=None,
_num_shards=1,
_shard_index=0,
num_shards: int = 1,
shard_index: int = 0,
Comment on lines +120 to +121
Member:

I don't understand: Why are those public (not prefixed with _)? These are never supposed to be set by the user. Those are either internally set via DataLoader num_workers, or via parent dataset logic like MultiProcDataset, or via distributed training logic somehow, or so. But never directly by the user.

):
"""
:param str name: e.g. "train" or "eval"
@@ -137,8 +142,8 @@ def __init__(
:param str|None seq_order_seq_lens_file: for seq order, use the seq length given by this file
:param int shuffle_frames_of_nseqs: shuffles the frames. not always supported
:param None|int estimated_num_seqs: for progress reporting in case the real num_seqs is unknown
:param int _num_shards: number of shards the data is split into
:param int _shard_index: local shard index, when sharding is enabled
:param num_shards: number of shards the data is split into
:param shard_index: local shard index, when sharding is enabled
"""
self.name = name or ("dataset_id%s" % id(self))
self.lock: Optional[RLock] = None # Used when manipulating our data potentially from multiple threads.
@@ -170,9 +175,9 @@ def __init__(
self._chunking = chunking
self.chunk_size, self.chunk_step, self.custom_chunking_func = self._parse_chunking(chunking)
self._context_window = context_window
assert 0 <= _shard_index < _num_shards
self._num_shards = _num_shards
self._shard_index = _shard_index
assert 0 <= shard_index < num_shards
self.num_shards = num_shards
self.shard_index = shard_index
Comment on lines +178 to +180
Collaborator:

Suggested change
assert 0 <= shard_index < num_shards
self.num_shards = num_shards
self.shard_index = shard_index
self.set_shard_idx_and_num_shards(shard_index, num_shards)

Slightly cleaner and reuses code, but it's also fine as is. Your choice.

Member Author:

That function has additional asserts that rely on correct initialization. Not for now.

Member:

> That function has additional asserts that rely on correct initialization. Not for now.

I don't understand the comment. What correct initialization? What do you mean by "not for now"?

if isinstance(context_window, (tuple, list)):
assert len(context_window) == 2
for elem in context_window:
@@ -248,6 +253,37 @@ def __reduce__(self):
state = {attr: getattr(self, attr) for attr in ["epoch", "zpad"]}
return Dataset._create_from_reduce, (self.__class__, kwargs, state)

def set_shard_idx_and_num_shards(self, shard_index: int, num_shards: int):
Member:

Small inconsistency: idx vs index.

"""set sharding config for the dataset"""
assert 0 <= shard_index < num_shards
assert num_shards == 1 or self.supports_sharding()
self.num_shards = num_shards
self.shard_index = shard_index

@staticmethod
def _get_sharding_rank_and_size(config: Optional[Config] = None) -> Tuple[int, int]:
"""
:return: tuple (rank, size): the global rank and size for distributed trainings
"""
if config is None:
from returnn.config import get_global_config

config = get_global_config(return_empty_if_none=True)
if config.typed_value("torch_distributed") is not None:
import returnn.torch.distributed

ctx = returnn.torch.distributed.get_ctx(config=config)
return ctx.rank(), ctx.size()
elif config.is_true("use_horovod"):
assert config.bool("use_tensorflow", False) or config.value("backend", "").startswith("tensorflow")

import returnn.tf.horovod

ctx = returnn.tf.horovod.get_ctx(config=config)
return ctx.rank(), ctx.size()
else:
return 0, 1
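
Note: taken together, the intended flow of the two new helpers above can be sketched as follows. This is a simplified illustration, not code from the PR; init_dataset and the explicit call sequence are assumptions about the surrounding RETURNN code, and _get_sharding_rank_and_size is only called directly here for demonstration.

# Simplified sketch of how the new helpers compose (illustration only).
from returnn.config import get_global_config
from returnn.datasets.basic import Dataset, init_dataset

config = get_global_config()
dataset = init_dataset(config.typed_value("train"))  # placeholder; any dataset config

if config.typed_value("dataset_distribution", "random_seed_offset") == "shard":
    # global rank and size of this process, (0, 1) when not running distributed
    shard_index, num_shards = Dataset._get_sharding_rank_and_size(config)
    # asserts 0 <= shard_index < num_shards and that the dataset supports sharding
    dataset.set_shard_idx_and_num_shards(shard_index, num_shards)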

@property
def random_seed_offset(self) -> int:
""":return: random seed offset for shuffling"""
@@ -257,10 +293,10 @@ def random_seed_offset(self) -> int:

def _uses_custom_distributed_sharding(self) -> bool:
"""
:return: if dataset has its own data sharding logic independent of TF/PT.
:return: if the dataset has its own data sharding logic independent of TF/PT.
Leads to a fixed random_seed_offset independent of the workers local rank.
"""
return False
return self.num_shards > 1

def _get_default_random_seed_offset(self):
"""
@@ -559,7 +595,7 @@ def get_seq_order_for_epoch(
for i in range(1, num):
seq_index[i::num] += i * (num_seqs // num)
elif seq_ordering_method == "reverse":
seq_index = range(num_seqs - 1, -1, -1) # type: Union[range, typing.Sequence[int]]
seq_index: Union[range, Sequence[int]] = range(num_seqs - 1, -1, -1)
elif seq_ordering_method in ["sorted", "sorted_reverse"]:
assert get_seq_len
reverse = -1 if seq_ordering_method == "sorted_reverse" else 1
@@ -641,9 +677,9 @@ def get_seq_order_for_epoch(
seq_index = [
i for i in seq_index if (all_seq_tags[i] not in used_seq_tags, used_seq_tags.add(all_seq_tags[i]))[0]
]
if partition_epoch > 1 or self._num_shards > 1:
if partition_epoch > 1 or self.num_shards > 1:
seq_index = self._apply_partition_epoch_and_sharding(
seq_index, partition_epoch, epoch, self._num_shards, self._shard_index
seq_index, partition_epoch, epoch, self.num_shards, self.shard_index
)
if repeat_epoch > 1:
seq_index = list(seq_index) * repeat_epoch
Expand Down Expand Up @@ -735,8 +771,8 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
"""
self.epoch = epoch
self.rnd_seq_drop = Random(self._get_random_seed_for_epoch(epoch=epoch))
assert self._num_shards == 1 or self.supports_sharding(), (
f"{self}: does not support sharding, but got num_shards == {self._num_shards}"
assert self.num_shards == 1 or self.supports_sharding(), (
f"{self}: does not support sharding, but got num_shards == {self.num_shards}"
)
return False

@@ -902,7 +938,7 @@ def get_corpus_seq_idx(self, seq_idx):
if self.seq_ordering == "default" and self.partition_epoch == 1:
return seq_idx
assert self.have_corpus_seq_idx()
raise NotImplemented
raise NotImplementedError

def have_get_corpus_seq(self) -> bool:
"""
@@ -1061,7 +1097,7 @@ def get_data_shape(self, key: str) -> List[int]:
if key in self.num_outputs:
if self.num_outputs[key][1] <= 1:
return []
res_shape = [None] * (self.num_outputs[key][1] - 1) # type: typing.List[typing.Union[None,int]]
res_shape: List[Union[None, int]] = [None] * (self.num_outputs[key][1] - 1)
if not self.is_data_sparse(key):
res_shape[-1] = self.get_data_dim(key)
return res_shape
@@ -1587,9 +1623,9 @@ def _dataset_extend_default_kwargs_from_parent_dataset(
default_kwargs = default_kwargs.copy() if default_kwargs else {}
default_kwargs.setdefault("random_seed_offset", parent_dataset.random_seed_offset)
# noinspection PyProtectedMember
default_kwargs.setdefault("_num_shards", parent_dataset._num_shards)
default_kwargs.setdefault("num_shards", parent_dataset.num_shards)
# noinspection PyProtectedMember
default_kwargs.setdefault("_shard_index", parent_dataset._shard_index)
default_kwargs.setdefault("shard_index", parent_dataset.shard_index)
return default_kwargs


94 changes: 40 additions & 54 deletions returnn/datasets/distrib_files.py
@@ -138,9 +138,8 @@ def __init__(
get_sub_epoch_dataset: Callable[[List[FileTree]], Dict[str, Any]],
preload_next_n_sub_epochs: int = 1,
buffer_size: int = 1,
distrib_shard_files: bool = False,
distrib_shard_files: Optional[bool] = None,
_meta_info_cache: Optional[Dict[str, Any]] = None,
_distrib_info: Optional[Dict[str, int]] = None,
**kwargs,
):
"""
@@ -153,10 +152,10 @@
:param get_sub_epoch_dataset: callable which returns a dataset dict for a given subset of files
:param preload_next_n_sub_epochs: how many sub epoch datasets to preload
:param buffer_size: buffer size for each worker, amount of seqs to prefetch
:param distrib_shard_files: set to true to shard the data across worker processes in
distributed training scenarios
:param distrib_shard_files: deprecated. Replaced by global config option ``dataset_distribution="shard"``.

Set to true to shard the data across worker processes in distributed training scenarios.
Comment on lines +155 to +157
Member:

I don't think this is necessary to mark it as deprecated.

Suggested change
:param distrib_shard_files: deprecated. Replaced by global config option ``dataset_distribution="shard"``.
Set to true to shard the data across worker processes in distributed training scenarios.
:param distrib_shard_files: deprecated. set to true to shard the data across worker processes in
distributed training scenarios

:param _meta_info_cache: for internal use
:param _distrib_info: for internal use
"""
super().__init__(**kwargs)
self.files = files
Expand All @@ -172,21 +171,16 @@ def __init__(
self._workers: Dict[int, _WorkerProcParent] = {} # epoch -> worker
self._files_order_cache: Dict[int, List[List[FileTree]]] = {} # full epoch (0-indexed) -> files order

self.distrib_shard_files = distrib_shard_files
if distrib_shard_files:
assert self._num_shards == 1 and self._shard_index == 0, ( # ensure defaults are set
f"{self}: Cannot use both dataset-sharding via properties _num_shards and _shard index "
f"and {self.__class__.__name__}'s own sharding implementation based on the trainings rank and size."
if distrib_shard_files is not None:
log.print_deprecation_warning(
f"{self.__class__.__name__}' `distrib_shard_files` config option is set. "
"Use global config option `dataset_distribution` instead "
"for the same behavior across more types of datasets."
)
Comment on lines +175 to 179
Member:

I'm not sure this is really necessary to mark this as deprecated.

Suggested change
log.print_deprecation_warning(
f"{self.__class__.__name__}' `distrib_shard_files` config option is set. "
"Use global config option `dataset_distribution` instead "
"for the same behavior across more types of datasets."
)

if _distrib_info:
# If we're in a child process `_get_rank_and_size()` no longer works,
# so we pass the info about the shards via a pickled property.
# See also Dataset.__reduce__.
self._shard_index = _distrib_info["_shard_index"]
self._num_shards = _distrib_info["_num_shards"]
else:
self._shard_index, self._num_shards = _get_rank_and_size()
assert 0 <= self._shard_index < self._num_shards
sharding_info = self._validate_global_config_and_get_sharding_info(distrib_shard_files)
if sharding_info is not None:
self.shard_index, self.num_shards = sharding_info
self.distrib_shard_files = distrib_shard_files

if _meta_info_cache:
# This allows to skip the lazy init in self.initialize().
Expand All @@ -208,10 +202,6 @@ def supports_sharding(self) -> bool:
""":return: whether the dataset supports sharding based on the seq_order"""
return True

@property
def _distrib_info(self):
return {"_num_shards": self._num_shards, "_shard_index": self._shard_index}

@property
def _meta_info_cache(self):
if not self.num_outputs:
@@ -225,9 +215,6 @@ def _meta_info_cache(self):
"files": self._files,
}

def _uses_custom_distributed_sharding(self) -> bool:
return self._num_shards > 1

def _lazy_init_file_list(self):
"""
The list of data files can either be provided as python list, or, if that grows
@@ -332,11 +319,11 @@ def init_seq_order(self, epoch: Optional[int] = None, seq_list=None, seq_order=N
else:
raise ValueError(f"{self}: seq_ordering {self.seq_ordering!r} not supported")
file_bins = self._distribute_evenly_by_size(
num_bins=self._num_shards * self.partition_epoch,
num_bins=self.num_shards * self.partition_epoch,
file_sizes=self._file_sizes,
files_order=files_order_flat,
)
self_index_base = self.partition_epoch * self._shard_index
self_index_base = self.partition_epoch * self.shard_index
self_index_end = self_index_base + self.partition_epoch
self._files_order_cache[full_epoch_0idx_] = file_bins[self_index_base:self_index_end]
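
Note: to make the bin arithmetic above concrete, a small worked example with made-up numbers (2 shards, partition_epoch 3): the files are spread over 6 bins, and each shard takes a contiguous slice of partition_epoch bins, one bin per sub-epoch.

# Worked example of the bin indexing above (made-up numbers, illustration only).
num_shards, partition_epoch = 2, 3
num_bins = num_shards * partition_epoch  # 6
for shard_index in range(num_shards):
    base = partition_epoch * shard_index
    print(f"shard {shard_index} -> bins {list(range(base, base + partition_epoch))}")
# shard 0 -> bins [0, 1, 2]
# shard 1 -> bins [3, 4, 5]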

@@ -370,6 +357,10 @@ def _get_sub_dataset_dict(self, files: List[FileTree]) -> Dict[str, Any]:

dataset_dict = self.get_sub_epoch_dataset(files)
dataset_dict = extend_dataset_dict_from_parent_dataset(dataset_dict, parent_dataset=self)
# We shard by splitting the files list into shards; the sub datasets must not shard any further by themselves.
if self.num_shards > 1:
dataset_dict["num_shards"] = 1
dataset_dict["shard_index"] = 0

flat_sub_dset = tree.flatten_with_path(dataset_dict)

@@ -494,38 +485,33 @@ def get_data_keys(self) -> List[str]:
self._lazy_init_num_outputs()
return self._data_keys

@classmethod
def _validate_global_config_and_get_sharding_info(cls, distrib_shard_files: bool) -> Optional[Tuple[int, int]]:
from returnn.config import get_global_config

def _get_key_for_file_tree(t: FileTree) -> str:
"""generates a deterministic key given a file tree"""
import tree

return ":".join(tree.flatten(t))


def _get_rank_and_size() -> Tuple[int, int]:
"""
:return: tuple (rank, size): the global rank and size for distributed trainings
"""
config = get_global_config(raise_exception=False)
if not config:
return

from returnn.config import get_global_config
dd_cfg = config.typed_value("dataset_distribution", None)
if dd_cfg and (distrib_shard_files and dd_cfg != "shard") or (not distrib_shard_files and dd_cfg == "shard"):
raise ValueError(
f"{cls.__name__}: `distrib_shard_files` config ({distrib_shard_files}) mismatch "
f"with global config option `dataset_distribution` ({dd_cfg})."
)

config = get_global_config(raise_exception=False)
if not config:
return 0, 1
if config.typed_value("torch_distributed") is not None:
import returnn.torch.distributed
if distrib_shard_files and dd_cfg is None:
# RETURNN will set sharding info on the dataset if the global config is set.
# If it's not set, however, we need to respect the existing `distrib_shard_files` property
# for backwards compatibility and load the sharding info ourselves.
return CachedDataset2._get_sharding_rank_and_size(config)
Member:

Suggested change
return CachedDataset2._get_sharding_rank_and_size(config)
return cls._get_sharding_rank_and_size(config)

?
Or:

Suggested change
return CachedDataset2._get_sharding_rank_and_size(config)
return Dataset._get_sharding_rank_and_size(config)

?
But CachedDataset2 doesn't really make sense?


ctx = returnn.torch.distributed.get_ctx(config=config)
return ctx.rank(), ctx.size()
elif config.is_true("use_horovod"):
assert config.bool("use_tensorflow", False) or config.value("backend", "").startswith("tensorflow")

import returnn.tf.horovod
def _get_key_for_file_tree(t: FileTree) -> str:
"""generates a deterministic key given a file tree"""
import tree

ctx = returnn.tf.horovod.get_ctx(config=config)
return ctx.rank(), ctx.size()
else:
return 0, 1
return ":".join(tree.flatten(t))


class _WorkerProcParent:
4 changes: 2 additions & 2 deletions returnn/datasets/meta.py
@@ -1162,9 +1162,9 @@ def _get_random_dataset_seq_order(self, total_num_seqs):

assert sum(counters) == total_num_seqs

if self.partition_epoch or self._num_shards > 1:
if self.partition_epoch or self.num_shards > 1:
seq_order = self._apply_partition_epoch_and_sharding(
seq_order, self.partition_epoch, self.epoch, self._num_shards, self._shard_index
seq_order, self.partition_epoch, self.epoch, self.num_shards, self.shard_index
)
if self.repeat_epoch:
seq_order = seq_order * self.repeat_epoch
10 changes: 9 additions & 1 deletion returnn/datasets/multi_proc.py
@@ -94,6 +94,10 @@ def initialize(self):
self._lazy_init()
super().initialize()

def supports_sharding(self):
"""this dataset supports sharding on the ``dedicated`` sharding method"""
return self._sharding_method == "dedicated"
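
Note: for reference, a MultiProcDataset configured for the "dedicated" sharding method might look like the sketch below; the wrapped HDFDataset dict and the option values are placeholders, not taken from this PR.

# Illustrative dataset dict for MultiProcDataset with dedicated sharding (placeholders only).
train = {
    "class": "MultiProcDataset",
    "num_workers": 4,
    "buffer_size": 10,
    "sharding_method": "dedicated",  # each worker reads only its own shard of the data
    "dataset": {"class": "HDFDataset", "files": ["train.hdf"]},  # placeholder sub-dataset
}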

@property
def _meta_info_cache(self):
if not self.num_outputs:
@@ -144,7 +148,11 @@ def _lazy_init(self):
self._sharding_method,
)
elif self._sharding_method == "dedicated":
sub_dataset = {**self.dataset, "_num_shards": self.num_workers, "_shard_index": i}
sub_dataset = {
**self.dataset,
"num_shards": self.num_workers * self.num_shards,
"shard_index": (self.shard_index * self.num_workers) + i,
}
Comment on lines -147 to +155
Member Author:

This is something that I overlooked before. We can trivially allow sharding for MultiProcDataset in this case.

Collaborator:

I might not know the context of this comment. Isn't sharding already allowed as per your change here?

Member:

This is when sharding was set on the MultiProcDataset itself, and then MultiProcDataset also uses sharding on the subdatasets.
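
Note: as a worked example of the composition in the code above (made-up numbers: 2 outer shards, e.g. from 2 distributed ranks, and num_workers = 3), every worker of every rank ends up with a distinct global shard out of 6.

# Worked example of the shard composition above (illustration only).
outer_num_shards, num_workers = 2, 3
for outer_shard_index in range(outer_num_shards):
    for i in range(num_workers):
        num_shards = num_workers * outer_num_shards        # 6
        shard_index = outer_shard_index * num_workers + i  # 0..5, no collisions
        print(f"outer shard {outer_shard_index}, worker {i} -> shard {shard_index}/{num_shards}")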

args = (
i,
sub_dataset,
14 changes: 14 additions & 0 deletions returnn/torch/data/returnn_dataset_wrapper.py
@@ -41,8 +41,22 @@ class ReturnnDatasetResetMpSharedEpochCallback:
def __init__(self, dataset: ReturnnDataset, epoch_mp_shared: torch.multiprocessing.Value):
self.dataset = dataset
self.epoch_mp_shared = epoch_mp_shared
self._sharding_config_set = False

def __call__(self):
# Include local worker rank in the dataset sharding config
if not self._sharding_config_set:
worker_info = torch.utils.data.get_worker_info()
if worker_info is not None:
assert self.dataset.supports_sharding() or worker_info.num_workers == 1, (
f"Using {worker_info.num_workers} torch data loading workers "
"but dataset does not support sharding. This will result in repeated training data."
)
self.dataset.set_shard_idx_and_num_shards(
self.dataset.shard_index + worker_info.id, self.dataset.num_shards * worker_info.num_workers
)
Comment on lines +55 to +57
Member:

I don't understand. Why does this consider the existing shard_index/num_shards? I would expect that you overwrite those here.

Suggested change
self.dataset.set_shard_idx_and_num_shards(
self.dataset.shard_index + worker_info.id, self.dataset.num_shards * worker_info.num_workers
)
self.dataset.set_shard_idx_and_num_shards(worker_info.id, worker_info.num_workers)

Member:

Or do you expect that the existing shard_index/num_shards were set with the distributed rank/size, and this here additionally adds further sharding for workerid/num_workers?

But then this code written in this way is very confusing... I think this should be done differently somehow. Not sure how...

Member:

We could also do the distributed logic directly here. We pass on the rank/size info anyway to the subproc children via the env var _RETURNN_TORCH_DISTRIBUTED_INIT_INFO.

The question is, how to handle dataset_distribution. We could also just check the global config here at this point (even though I don't like accessing the global config too much... maybe I get a better idea). Then all the logic about what sharding options to set (or whether to set it at all) is all here in this place, together with the dataloader num workers.

Member:

Instead of having all this logic here, we maybe should better move it to ReturnnDatasetIterDataPipe.reset though?
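
Note: to illustrate the arithmetic the thread above is discussing (without deciding where the logic should live), compare the additive composition used in the code with the multiplicative one used by MultiProcDataset's "dedicated" path; the numbers are made up.

# Illustration of the two compositions (made-up numbers: 2 ranks, 3 workers each).
num_ranks, num_workers = 2, 3
for rank in range(num_ranks):  # assume shard_index was pre-set to the rank, num_shards to num_ranks
    for worker_id in range(num_workers):
        additive = rank + worker_id                      # as written above
        multiplicative = rank * num_workers + worker_id  # as in MultiProcDataset "dedicated"
        print(f"rank {rank}, worker {worker_id}: additive={additive}, multiplicative={multiplicative}")
# The additive form only stays collision-free while shard_index/num_shards still
# have their defaults (0/1); with per-rank values, e.g. rank 0/worker 1 and
# rank 1/worker 0 both map to index 1, whereas the multiplicative form does not collide.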

self._sharding_config_set = True

# dataset is likely a copy of the original dataset, either in the main process or in a worker process
# Use epoch_mp_shared to get the current epoch correctly in worked processes
epoch = self.epoch_mp_shared.value or None