From 87864e2c7a9e0ab672507ff01532bcc19d1249d3 Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Fri, 19 Sep 2025 12:02:59 +0200 Subject: [PATCH 01/58] Pull apart map_seq application and dataset iteration --- returnn/datasets/postprocessing.py | 71 ++++++++++++++++++------------ 1 file changed, 43 insertions(+), 28 deletions(-) diff --git a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py index d6ccaed08..676690ba6 100644 --- a/returnn/datasets/postprocessing.py +++ b/returnn/datasets/postprocessing.py @@ -166,7 +166,11 @@ def __init__( self.labels = self._dataset.labels.copy() # update only after _out_tensor_dict_template has been created from _in_tensor_dict_template self._in_tensor_dict_template.update( - {"complete_frac": {"dims": (), "dtype": "float32"}, "seq_tag": {"dims": (), "dtype": "string"}}, + { + "complete_frac": {"dims": (), "dtype": "float32"}, + "seq_idx": {"dims": (), "dtype": "int32"}, + "seq_tag": {"dims": (), "dtype": "string"}, + }, auto_convert=True, ) self.num_outputs = { @@ -322,7 +326,43 @@ def _validate_tensor_dict_iter(inner: Iterator[TensorDict]) -> Iterator[TensorDi ) yield t_dict + def _apply_map_seq(inner: Iterator[TensorDict]) -> Iterator[TensorDict]: + for tensor_dict in inner: + seq_index = int(tensor_dict.data["seq_idx"].raw_tensor.item()) + comp_frac_raw_tensor = ( + tensor_dict.data["complete_frac"].raw_tensor if "complete_frac" in tensor_dict.data else None + ) + seq_tag_raw_tensor = tensor_dict.data["seq_tag"].raw_tensor + + tensor_dict = self._map_seq( + tensor_dict, epoch=self.epoch, seq_idx=seq_index, rng=self._rng, **util.get_fwd_compat_kwargs() + ) + assert isinstance(tensor_dict, TensorDict), ( + f"map_seq must produce a {TensorDict.__name__}, but produced {type(tensor_dict).__name__}" + ) + + # Re-adding the seq_tag/complete_frac here causes no harm in case they are dropped + # since we don't add/drop any segments w/ the non-iterator postprocessing function. 
+ if "complete_frac" not in tensor_dict.data and comp_frac_raw_tensor is not None: + tensor_dict.data["complete_frac"] = Tensor( + "complete_frac", dims=(), dtype="float32", raw_tensor=comp_frac_raw_tensor + ) + if "seq_tag" not in tensor_dict.data: + tensor_dict.data["seq_tag"] = Tensor( + "seq_tag", dims=(), dtype="string", raw_tensor=seq_tag_raw_tensor + ) + + if self._seq_list_for_validation is not None: + seq_tag = self._seq_list_for_validation[seq_index] + tag_of_seq = tensor_dict.data["seq_tag"].raw_tensor.item() + assert tag_of_seq == seq_tag, ( + f"seq tag mismath: {tag_of_seq} != {seq_tag} for seq index {seq_index} when seq list is given" + ) + yield tensor_dict + data_iter = self._iterate_dataset() + if self._map_seq is not None: + data_iter = _apply_map_seq(data_iter) if self._map_seq_stream is not None: data_iter = self._map_seq_stream(data_iter, epoch=self.epoch, rng=self._rng, **util.get_fwd_compat_kwargs()) assert isinstance(data_iter, Iterator), ( @@ -345,39 +385,14 @@ def _iterate_dataset(self) -> Iterator[TensorDict]: tensor_dict.data[data_key].raw_tensor = self._dataset.get_data(seq_index, data_key) complete_frac = self._dataset.get_complete_frac(seq_index, allow_only_lr_suitable=True) - comp_frac_raw_tensor = None if complete_frac is not None: comp_frac_raw_tensor = numpy.array(complete_frac, dtype=numpy.float32) tensor_dict.data["complete_frac"].raw_tensor = comp_frac_raw_tensor + seq_idx_raw_tensor = numpy.array(seq_index, dtype=numpy.int32) + tensor_dict.data["seq_idx"].raw_tensor = seq_idx_raw_tensor seq_tag_raw_tensor = str_to_numpy_array(self._dataset.get_tag(seq_index)) tensor_dict.data["seq_tag"].raw_tensor = seq_tag_raw_tensor - if self._map_seq is not None: - tensor_dict = self._map_seq( - tensor_dict, epoch=self.epoch, seq_idx=seq_index, rng=self._rng, **util.get_fwd_compat_kwargs() - ) - assert isinstance(tensor_dict, TensorDict), ( - f"map_seq must produce a {TensorDict.__name__}, but produced {type(tensor_dict).__name__}" - ) - - # Re-adding the seq_tag/complete_frac here causes no harm in case they are dropped - # since we don't add/drop any segments w/ the non-iterator postprocessing function. 
- if "complete_frac" not in tensor_dict.data and comp_frac_raw_tensor is not None: - tensor_dict.data["complete_frac"] = Tensor( - "complete_frac", dims=(), dtype="float32", raw_tensor=comp_frac_raw_tensor - ) - if "seq_tag" not in tensor_dict.data: - tensor_dict.data["seq_tag"] = Tensor( - "seq_tag", dims=(), dtype="string", raw_tensor=seq_tag_raw_tensor - ) - - if self._seq_list_for_validation is not None: - seq_tag = self._seq_list_for_validation[seq_index] - tag_of_seq = tensor_dict.data["seq_tag"].raw_tensor.item() - assert tag_of_seq == seq_tag, ( - f"seq tag mismath: {tag_of_seq} != {seq_tag} for seq index {seq_index} when seq list is given" - ) - yield tensor_dict seq_index += 1 From 7e6f78ffb3777c0b9545d73bd150786555bcae5b Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Fri, 19 Sep 2025 12:10:35 +0200 Subject: [PATCH 02/58] make mapping iterator builder static --- returnn/datasets/postprocessing.py | 154 +++++++++++++++++------------ 1 file changed, 91 insertions(+), 63 deletions(-) diff --git a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py index 676690ba6..0700ba857 100644 --- a/returnn/datasets/postprocessing.py +++ b/returnn/datasets/postprocessing.py @@ -212,7 +212,17 @@ def init_seq_order( self._rng = RandomState(self._get_random_seed_for_epoch(epoch=epoch)) assert self._dataset is not None self._dataset.init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order) - self._data_iter = enumerate(self._build_mapping_iter()) + self._data_iter = enumerate( + self._build_mapping_iter( + self._iterate_dataset(), + map_seq=self._map_seq, + map_seq_stream=self._map_seq_stream, + epoch=epoch, + out_tensor_dict_template=self._out_tensor_dict_template, + rng=self._rng, + seq_list_for_validation=seq_list, + ) + ) self._data_iter_produced_num_seqs = 0 self._seq_list_for_validation = seq_list if self._map_seq_stream is None or self._map_seq_stream_preserves_num_seqs is True: @@ -290,8 +300,53 @@ def _collect_single_seq(self, seq_idx: int) -> Optional[DatasetSeq]: seq = DatasetSeq(complete_frac=complete_frac, features=features, seq_idx=seq_idx, seq_tag=seq_tag) return seq - def _build_mapping_iter(self) -> Iterator[TensorDict]: + def _iterate_dataset(self) -> Iterator[TensorDict]: + """ + :return: generator providing data samples in the form of a TensorDict + """ + data_keys = self._dataset.get_data_keys() + + seq_index = 0 + while self._dataset.is_less_than_num_seqs(seq_index): + self._dataset.load_seqs(seq_index, seq_index + 1) + + tensor_dict = self._in_tensor_dict_template.copy_template() + for data_key in data_keys: + tensor_dict.data[data_key].raw_tensor = self._dataset.get_data(seq_index, data_key) + + complete_frac = self._dataset.get_complete_frac(seq_index, allow_only_lr_suitable=True) + if complete_frac is not None: + comp_frac_raw_tensor = numpy.array(complete_frac, dtype=numpy.float32) + tensor_dict.data["complete_frac"].raw_tensor = comp_frac_raw_tensor + seq_idx_raw_tensor = numpy.array(seq_index, dtype=numpy.int32) + tensor_dict.data["seq_idx"].raw_tensor = seq_idx_raw_tensor + seq_tag_raw_tensor = str_to_numpy_array(self._dataset.get_tag(seq_index)) + tensor_dict.data["seq_tag"].raw_tensor = seq_tag_raw_tensor + + yield tensor_dict + seq_index += 1 + + @staticmethod + def _build_mapping_iter( + data_iter: Iterator[TensorDict], + *, + map_seq: Optional[Callable[[TensorDict, int, int, RandomState], TensorDict]] = None, + map_seq_stream: Optional[Callable[[Iterator[TensorDict], int, RandomState], Iterator[TensorDict]]] = None, + epoch: int, + 
out_tensor_dict_template: TensorDict, + rng: RandomState, + seq_list_for_validation: Optional[List[str]] = None, + ) -> Iterator[TensorDict]: """ + Build an iterator applying the mapping functions on the given dataset iterator. + + :param data_iter: iterator providing data samples in the form of a TensorDict + :param map_seq: see :class:`PostprocessingDataset` + :param map_seq_stream: see :class:`PostprocessingDataset` + :param epoch: current epoch number + :param out_tensor_dict_template: template for the output TensorDicts, used for validation + :param rng: random number generator to use + :param seq_list_for_validation: optional list of seq tags to validate against when processing the data :return: an iterator applying both the segment level and across-segment transformations on the given dataset """ @@ -310,7 +365,7 @@ def _validate_tensor_dict_iter(inner: Iterator[TensorDict]) -> Iterator[TensorDi f"but got {complete_frac} after {last_complete_frac}" ) last_complete_frac = complete_frac - for data_key, out_t in self._out_tensor_dict_template.data.items(): + for data_key, out_t in out_tensor_dict_template.data.items(): in_t = t_dict.data[data_key] assert in_t.ndim == out_t.batch_ndim, ( f"Dim number mismatch for {data_key}: {in_t.ndim} != {out_t.batch_ndim}. " @@ -326,76 +381,49 @@ def _validate_tensor_dict_iter(inner: Iterator[TensorDict]) -> Iterator[TensorDi ) yield t_dict - def _apply_map_seq(inner: Iterator[TensorDict]) -> Iterator[TensorDict]: - for tensor_dict in inner: - seq_index = int(tensor_dict.data["seq_idx"].raw_tensor.item()) - comp_frac_raw_tensor = ( - tensor_dict.data["complete_frac"].raw_tensor if "complete_frac" in tensor_dict.data else None - ) - seq_tag_raw_tensor = tensor_dict.data["seq_tag"].raw_tensor + def _apply_map_seq(tensor_dict: TensorDict) -> TensorDict: + comp_frac_raw_tensor = ( + tensor_dict.data["complete_frac"].raw_tensor if "complete_frac" in tensor_dict.data else None + ) + seq_index_raw = tensor_dict.data["seq_idx"].raw_tensor + seq_index = int(seq_index_raw.item()) + seq_tag_raw_tensor = tensor_dict.data["seq_tag"].raw_tensor + + tensor_dict = map_seq(tensor_dict, epoch=epoch, seq_idx=seq_index, rng=rng, **util.get_fwd_compat_kwargs()) + assert isinstance(tensor_dict, TensorDict), ( + f"map_seq must produce a {TensorDict.__name__}, but produced {type(tensor_dict).__name__}" + ) - tensor_dict = self._map_seq( - tensor_dict, epoch=self.epoch, seq_idx=seq_index, rng=self._rng, **util.get_fwd_compat_kwargs() + # Re-adding the seq_tag/complete_frac here causes no harm in case they are dropped + # since we don't add/drop any segments w/ the non-iterator postprocessing function. 
+ if "complete_frac" not in tensor_dict.data and comp_frac_raw_tensor is not None: + tensor_dict.data["complete_frac"] = Tensor( + "complete_frac", dims=(), dtype="float32", raw_tensor=comp_frac_raw_tensor ) - assert isinstance(tensor_dict, TensorDict), ( - f"map_seq must produce a {TensorDict.__name__}, but produced {type(tensor_dict).__name__}" + if "seq_idx" not in tensor_dict.data: + tensor_dict.data["seq_idx"] = Tensor("seq_idx", dims=(), dtype="int32", raw_tensor=seq_index_raw) + if "seq_tag" not in tensor_dict.data: + tensor_dict.data["seq_tag"] = Tensor("seq_tag", dims=(), dtype="string", raw_tensor=seq_tag_raw_tensor) + + if seq_list_for_validation is not None: + seq_tag = seq_list_for_validation[seq_index] + tag_of_seq = tensor_dict.data["seq_tag"].raw_tensor.item() + assert tag_of_seq == seq_tag, ( + f"seq tag mismath: {tag_of_seq} != {seq_tag} for seq index {seq_index} when seq list is given" ) - # Re-adding the seq_tag/complete_frac here causes no harm in case they are dropped - # since we don't add/drop any segments w/ the non-iterator postprocessing function. - if "complete_frac" not in tensor_dict.data and comp_frac_raw_tensor is not None: - tensor_dict.data["complete_frac"] = Tensor( - "complete_frac", dims=(), dtype="float32", raw_tensor=comp_frac_raw_tensor - ) - if "seq_tag" not in tensor_dict.data: - tensor_dict.data["seq_tag"] = Tensor( - "seq_tag", dims=(), dtype="string", raw_tensor=seq_tag_raw_tensor - ) - - if self._seq_list_for_validation is not None: - seq_tag = self._seq_list_for_validation[seq_index] - tag_of_seq = tensor_dict.data["seq_tag"].raw_tensor.item() - assert tag_of_seq == seq_tag, ( - f"seq tag mismath: {tag_of_seq} != {seq_tag} for seq index {seq_index} when seq list is given" - ) - yield tensor_dict + return tensor_dict - data_iter = self._iterate_dataset() - if self._map_seq is not None: - data_iter = _apply_map_seq(data_iter) - if self._map_seq_stream is not None: - data_iter = self._map_seq_stream(data_iter, epoch=self.epoch, rng=self._rng, **util.get_fwd_compat_kwargs()) + assert not (map_seq and map_seq_stream), "cannot set both map_seq and map_seq_stream" + if map_seq is not None: + data_iter = (_apply_map_seq(t_dict) for t_dict in data_iter) + if map_seq_stream is not None: + data_iter = map_seq_stream(data_iter, epoch=epoch, rng=rng, **util.get_fwd_compat_kwargs()) assert isinstance(data_iter, Iterator), ( f"map_seq_stream must produce an {Iterator.__name__}, but produced {type(data_iter).__name__}" ) return _validate_tensor_dict_iter(data_iter) - def _iterate_dataset(self) -> Iterator[TensorDict]: - """ - :return: generator providing data samples in the form of a TensorDict - """ - data_keys = self._dataset.get_data_keys() - - seq_index = 0 - while self._dataset.is_less_than_num_seqs(seq_index): - self._dataset.load_seqs(seq_index, seq_index + 1) - - tensor_dict = self._in_tensor_dict_template.copy_template() - for data_key in data_keys: - tensor_dict.data[data_key].raw_tensor = self._dataset.get_data(seq_index, data_key) - - complete_frac = self._dataset.get_complete_frac(seq_index, allow_only_lr_suitable=True) - if complete_frac is not None: - comp_frac_raw_tensor = numpy.array(complete_frac, dtype=numpy.float32) - tensor_dict.data["complete_frac"].raw_tensor = comp_frac_raw_tensor - seq_idx_raw_tensor = numpy.array(seq_index, dtype=numpy.int32) - tensor_dict.data["seq_idx"].raw_tensor = seq_idx_raw_tensor - seq_tag_raw_tensor = str_to_numpy_array(self._dataset.get_tag(seq_index)) - tensor_dict.data["seq_tag"].raw_tensor = 
seq_tag_raw_tensor - - yield tensor_dict - seq_index += 1 - def _make_tensor_template_from_input(self, data_key: str) -> Tensor: dtype = self._dataset.get_data_dtype(data_key) if dtype == "string": From 823434e1a720e4ac34e820f08b70e1e069ea14e8 Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Fri, 19 Sep 2025 14:22:55 +0200 Subject: [PATCH 03/58] Implement multi-process postprocessing --- returnn/datasets/postprocessing.py | 297 +++++++++++++++++++++++++++-- 1 file changed, 286 insertions(+), 11 deletions(-) diff --git a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py index 0700ba857..b0d7ad18f 100644 --- a/returnn/datasets/postprocessing.py +++ b/returnn/datasets/postprocessing.py @@ -4,20 +4,34 @@ from __future__ import annotations +from collections import deque from itertools import islice import numpy from numpy.random import RandomState -from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, TypeVar +import sys +import threading +from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, TypeVar from returnn.datasets.basic import DatasetSeq from returnn.datasets.util.strings import str_to_numpy_array from returnn.datasets.util.vocabulary import Vocabulary from returnn.tensor import Tensor, TensorDict from returnn.tensor.dim import Dim -from returnn.util import basic as util +from returnn.util import basic as util, better_exchook +from returnn.util.multi_proc_non_daemonic_spawn import NonDaemonicSpawnContext +from returnn.config import SubProcCopyGlobalConfigPreInitFunc from .basic import init_dataset from .cached2 import CachedDataset2 +# noinspection PyProtectedMember +from multiprocessing.connection import Connection as mpConnection + +# noinspection PyProtectedMember +from multiprocessing.queues import Queue as mpQueue + +_mp = NonDaemonicSpawnContext(process_pre_init_func=SubProcCopyGlobalConfigPreInitFunc()) + + __all__ = ["PostprocessingDataset", "LaplaceOrdering", "Sequential"] @@ -93,11 +107,14 @@ def my_map_seq_stream(iterator): def __init__( self, + *, dataset: Dict[str, Any], map_seq: Optional[Callable] = None, map_seq_stream: Optional[Callable] = None, map_outputs: Optional[Dict[str, Any]] = None, map_seq_stream_preserves_num_seqs: Optional[bool] = None, + buf_size: int = 1, + num_workers: int = 0, **kwargs, ): """ @@ -123,6 +140,10 @@ def __init__( Example: `map_outputs={"data": {"dim": 42}}` :param map_seq_stream_preserves_num_seqs: whether the function in map_seq_stream preserves the number of sequences, i.e. for every input sequence there is exactly one output sequence. + :param buf_size: buffer size for each worker, number of seqs to prefetch. + :param num_workers: number of worker processes to use for data postprocessing. + This does not apply parallelism to the wrapped dataset, but only to the postprocessing step. + When set to 0, postprocessing happens inline. 
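For illustration, a minimal config sketch using the two new options; the wrapped dataset options and the mapping function below are placeholders, not part of this patch:

    def my_map_seq(tensor_dict, **kwargs):
        # Identity per-seq mapping, for illustration only.
        # With num_workers > 0, this function runs inside the worker processes.
        return tensor_dict

    train = {
        "class": "PostprocessingDataset",
        "dataset": {"class": "HDFDataset", "files": ["train.hdf"]},  # hypothetical wrapped dataset
        "map_seq": my_map_seq,
        "num_workers": 4,  # parallelize postprocessing across 4 worker processes
        "buf_size": 10,  # each worker prefetches up to 10 seqs
    }
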
:param kwargs: see :class:`CachedDataset2`, :class:`Dataset` """ super().__init__(**kwargs) @@ -135,6 +156,10 @@ def __init__( raise ValueError(f"{self}: cannot set both map_seq and map_seq_stream") if map_seq and map_seq_stream_preserves_num_seqs is not None: raise ValueError(f"{self}: map_seq_stream_preserves_num_seqs is only allowed with map_seq_stream") + if buf_size < 1: + raise ValueError(f"{self}: buf_size must be >= 1, but got {buf_size}") + if num_workers < 0: + raise ValueError(f"{self}: num_workers must be >= 0, but got {num_workers}") self._dataset_def = dataset self._map_seq = map_seq @@ -144,7 +169,8 @@ def __init__( assert map_seq_stream_preserves_num_seqs is None or isinstance(map_seq_stream_preserves_num_seqs, bool) self._map_seq_stream_preserves_num_seqs = map_seq_stream_preserves_num_seqs self._map_outputs = map_outputs - self._rng = RandomState(self._get_random_seed_for_epoch(0)) + self._buf_size = buf_size + self._num_workers = num_workers self._seq_list_for_validation: Optional[List[str]] = None self._dataset = init_dataset(self._dataset_def, parent_dataset=self) @@ -209,20 +235,46 @@ def init_seq_order( self._num_seqs = 0 return True - self._rng = RandomState(self._get_random_seed_for_epoch(epoch=epoch)) assert self._dataset is not None self._dataset.init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order) - self._data_iter = enumerate( - self._build_mapping_iter( + if self._num_workers > 0: + assert self._buf_size > 0 + seq_queues = [_mp.Queue(maxsize=self._buf_size) for _ in range(self._num_workers)] + worker_procs = [ + _WorkerProcParent( + name=f"{self.__class__.__name__} {self.name} ep {epoch}", + epoch=epoch, + buffer_size=self._buf_size, + map_seq=self._map_seq, + map_seq_stream=self._map_seq_stream, + out_tensor_dict_template=self._out_tensor_dict_template, + rng_seed=self._get_random_seed_for_epoch(epoch=epoch * self._num_workers + i), + seq_list=seq_list, + seq_queue=seq_queue, + ) + for i, seq_queue in enumerate(seq_queues) + ] + quit_event = threading.Event() + dataset_thread = threading.Thread( + target=self._distribute_seqs_to_children, + kwargs={"child_queues": seq_queues, "quit_event": quit_event}, + name=f"{self.__class__.__name__} {self.name} dataset ep {epoch}", + ) + dataset_thread.start() + data_iter = _MultiProcDataIter( + dataset_thread=dataset_thread, quit_event=quit_event, worker_procs=worker_procs + ) + else: + data_iter = self._build_mapping_iter( self._iterate_dataset(), map_seq=self._map_seq, map_seq_stream=self._map_seq_stream, epoch=epoch, out_tensor_dict_template=self._out_tensor_dict_template, - rng=self._rng, + rng=RandomState(self._get_random_seed_for_epoch(epoch=epoch)), seq_list_for_validation=seq_list, ) - ) + self._data_iter = enumerate(data_iter) self._data_iter_produced_num_seqs = 0 self._seq_list_for_validation = seq_list if self._map_seq_stream is None or self._map_seq_stream_preserves_num_seqs is True: @@ -330,8 +382,8 @@ def _iterate_dataset(self) -> Iterator[TensorDict]: def _build_mapping_iter( data_iter: Iterator[TensorDict], *, - map_seq: Optional[Callable[[TensorDict, int, int, RandomState], TensorDict]] = None, - map_seq_stream: Optional[Callable[[Iterator[TensorDict], int, RandomState], Iterator[TensorDict]]] = None, + map_seq: Optional[Callable] = None, + map_seq_stream: Optional[Callable] = None, epoch: int, out_tensor_dict_template: TensorDict, rng: RandomState, @@ -394,7 +446,7 @@ def _apply_map_seq(tensor_dict: TensorDict) -> TensorDict: f"map_seq must produce a {TensorDict.__name__}, but produced 
{type(tensor_dict).__name__}" ) - # Re-adding the seq_tag/complete_frac here causes no harm in case they are dropped + # Re-adding the complete_frac/seq_idx/seq_tag here causes no harm in case they are dropped # since we don't add/drop any segments w/ the non-iterator postprocessing function. if "complete_frac" not in tensor_dict.data and comp_frac_raw_tensor is not None: tensor_dict.data["complete_frac"] = Tensor( @@ -414,6 +466,7 @@ def _apply_map_seq(tensor_dict: TensorDict) -> TensorDict: return tensor_dict + assert map_seq or map_seq_stream, "need to specify either map_seq or map_seq_stream" assert not (map_seq and map_seq_stream), "cannot set both map_seq and map_seq_stream" if map_seq is not None: data_iter = (_apply_map_seq(t_dict) for t_dict in data_iter) @@ -442,6 +495,228 @@ def _make_tensor_template_from_input(self, data_key: str) -> Tensor: sparse_dim.vocab = Vocabulary.create_vocab_from_labels(self._dataset.labels[data_key]) return Tensor(data_key, dims=dims, dtype=dtype, sparse_dim=sparse_dim) + def _distribute_seqs_to_children(self, *, child_queues: Sequence[mpQueue], quit_event: threading.Event): + num_workers = len(child_queues) + assert num_workers > 0 + for seq_idx, tensor_dict in enumerate(self._iterate_dataset()): + if quit_event.is_set(): + break + try: + child_queues[seq_idx % num_workers].put(tensor_dict, block=True) # block for backpressure + except ValueError: + # queue is closed, i.e. the worker process died + break + for q in child_queues: + try: + q.put(None, block=True) # signal end of data + except ValueError: + # queue is already closed, i.e. the worker process died + pass + + +class _MultiProcDataIter: + """ + Data iter that pulls from the worker processes and manages their lifetime. + """ + + def __init__( + self, + *, + dataset_thread: threading.Thread, + quit_event: threading.Event, + worker_procs: List[_WorkerProcParent], + ): + self.dataset_thread = dataset_thread + self.quit_event = quit_event + self.worker_procs = worker_procs + + self._max_seq_idx: Optional[int] = None + self._seq_idx = 0 + + def __iter__(self): + return self + + def __next__(self) -> Optional[TensorDict]: + if self.quit_event.is_set(): + raise StopIteration + + for _ in range(len(self.worker_procs)): + if self._max_seq_idx is not None and self._seq_idx >= self._max_seq_idx: + break + num_workers = len(self.worker_procs) + seq = self.worker_procs[self._seq_idx % num_workers].get_seq() + if seq is None and self._max_seq_idx is None: + self._max_seq_idx = self._seq_idx + num_workers # drain the other workers + self._seq_idx += 1 + if seq is not None: + return seq + + # when we reach this point, all workers are exhausted and we stop + self.stop() + raise StopIteration + + def stop(self): + if self.quit_event.is_set(): + return + self.quit_event.set() + for wp in self.worker_procs: + wp.exit(join=True) + util.try_run(self.dataset_thread.join) + + def __del__(self): + try: + self.stop() + except Exception: + pass + + +class _WorkerProcParent: + def __init__( + self, + *, + name: str, + epoch: int, + buffer_size: int, + map_seq: Optional[Callable], + map_seq_stream: Optional[Callable], + out_tensor_dict_template: TensorDict, + rng_seed: int, + seq_list: Optional[List[str]], + seq_queue: mpQueue, + ): + parent_conn, child_conn = _mp.Pipe() + self.parent_conn = parent_conn + + self.worker_proc = _mp.Process( + name=f"{name} worker ep {epoch}", + target=_worker_proc_loop, + args=( + epoch, + buffer_size, + map_seq, + map_seq_stream, + out_tensor_dict_template, + rng_seed, + seq_list, + 
child_conn, + seq_queue, + ), + daemon=True, + ) + self.worker_proc.start() + + # Make sure the child connection is closed here. + # It stays open in the child, until the child dies. + # When that happens, now any consecutive read on the pipe + # should yield an exception -- which is what we want, + # otherwise it would just hang. + child_conn.close() + + def get_seq(self) -> Optional[DatasetSeq]: + """get_seq""" + self.parent_conn.send(("get_seq", {})) + msg, seq = self.parent_conn.recv() + assert msg == "seq" + return seq + + def exit(self, *, join: bool = True): + """exit""" + self.parent_conn.send(("exit", {})) + if join: + self.worker_proc.join() + + def __del__(self): + # noinspection PyBroadException + try: + self.exit(join=False) + except Exception: + pass + else: + util.try_run(self.worker_proc.join) + + +def _worker_proc_loop( + epoch: int, + buffer_size: int, + map_seq: Optional[Callable], + map_seq_stream: Optional[Callable], + out_tensor_dict_template: TensorDict, + rng_seed: int, + seq_list: Optional[List[str]], + parent_conn: mpConnection, + seq_queue: mpQueue, +): + if sys.platform == "linux": + with open("/proc/self/comm", "w") as f: + f.write(f"PP worker {epoch}") + better_exchook.setup_all() + + assert isinstance(epoch, int) + assert isinstance(buffer_size, int) + assert buffer_size > 0 + assert map_seq or map_seq_stream, "need to specify either map_seq or map_seq_stream" + assert not (map_seq and map_seq_stream), "cannot set both map_seq and map_seq_stream" + assert map_seq is None or isinstance(map_seq, Callable) + assert map_seq_stream is None or isinstance(map_seq_stream, Callable) + assert isinstance(out_tensor_dict_template, TensorDict) + assert isinstance(rng_seed, int) + assert isinstance(parent_conn, mpConnection) + assert isinstance(seq_queue, mpQueue) + + cache: deque[TensorDict] = deque() + + def _iter_queue(q: mpQueue) -> Iterator[TensorDict]: + while True: + try: + item = q.get(block=True) + except ValueError: + # queue is closed + break + if item is None: + break + yield item + + data_iter = PostprocessingDataset._build_mapping_iter( + _iter_queue(seq_queue), + map_seq=map_seq, + map_seq_stream=map_seq_stream, + epoch=epoch, + out_tensor_dict_template=out_tensor_dict_template, + rng=RandomState(rng_seed), + seq_list_for_validation=seq_list, + ) + assert isinstance(data_iter, Iterator) + + def _add_to_cache(): + nonlocal data_iter + try: + seq = next(data_iter) + except StopIteration: + data_iter = None + return False + cache.append(seq) + return True + + try: + while True: + while not parent_conn.poll(): + while data_iter is not None and len(cache) < buffer_size: + if not _add_to_cache(): + break + msg, kwargs = parent_conn.recv() + if msg == "exit": + break + elif msg == "get_seq": + if not cache and data_iter is not None: + _add_to_cache() + parent_conn.send(("seq", cache.popleft() if cache else None)) + else: + raise Exception(f"unknown msg {msg!r}") + except KeyboardInterrupt: # when parent dies + pass + except EOFError: # when parent dies + pass + class LaplaceOrdering(Callable[[Iterator[TensorDict]], Iterator[TensorDict]]): """ From 7d3573149b52782bdbe15c0f1cb48cd223e529f0 Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Fri, 19 Sep 2025 14:30:02 +0200 Subject: [PATCH 04/58] add test (currently fails due to pickle issues) --- tests/test_Dataset.py | 30 ++++++++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/tests/test_Dataset.py b/tests/test_Dataset.py index 4da968429..25d68ea6a 100644 --- a/tests/test_Dataset.py 
+++ b/tests/test_Dataset.py @@ -1203,7 +1203,7 @@ def _add_1337_to_classes(tdict: TensorDict, **kwargs) -> TensorDict: count = 0 - def _repeat2(input_iter: Iterator[TensorDict], **kwargs) -> Iterator[TensorDict]: + def _repeat2_and_count(input_iter: Iterator[TensorDict], **kwargs) -> Iterator[TensorDict]: nonlocal count for tdict in input_iter: @@ -1216,7 +1216,7 @@ def _repeat2(input_iter: Iterator[TensorDict], **kwargs) -> Iterator[TensorDict] ds_opts = { "class": "PostprocessingDataset", "dataset": sub_ds_opts, - "map_seq_stream": _repeat2, + "map_seq_stream": _repeat2_and_count, } dataset = init_dataset(ds_opts) dataset.init_seq_order(epoch=1) @@ -1267,6 +1267,32 @@ def _repeat2(input_iter: Iterator[TensorDict], **kwargs) -> Iterator[TensorDict] assert func(2) == 21 +def _repeat2(input_iter: Iterator[TensorDict], **kwargs) -> Iterator[TensorDict]: + for tdict in input_iter: + yield tdict + yield tdict + + +def test_PostprocessingDataset_multi_proc(): + _demo_txt = "some utterance text that has a few words" + with create_ogg_zip_txt_only_dataset_opts(text=_demo_txt) as sub_ds_opts: + ds_opts = { + "class": "PostprocessingDataset", + "dataset": sub_ds_opts, + "map_seq_stream": _repeat2, + "buf_size": 1, + "num_workers": 2, + } + dataset = init_dataset(ds_opts) + dataset.init_seq_order(epoch=1) + assert dataset.have_seqs() + + dataset.load_seqs(0, 2) + for i in range(2): + classes = dataset.get_data(i, "classes") + assert len(classes) > 0 + + def _post_process_map_seq_no_op(tdict: TensorDict, **_other) -> TensorDict: return tdict From e70135b7d607a47af52396e91818f6e8d1ebdfdb Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Fri, 19 Sep 2025 14:34:46 +0200 Subject: [PATCH 05/58] add TODO on dataset lock --- returnn/datasets/postprocessing.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py index b0d7ad18f..4bc8ce0e5 100644 --- a/returnn/datasets/postprocessing.py +++ b/returnn/datasets/postprocessing.py @@ -498,6 +498,12 @@ def _make_tensor_template_from_input(self, data_key: str) -> Tensor: def _distribute_seqs_to_children(self, *, child_queues: Sequence[mpQueue], quit_event: threading.Event): num_workers = len(child_queues) assert num_workers > 0 + + # TODO: should we hold a lock around the dataset while this thread is alive? + # + # This would help prevent issues when switching from one epoch to the next + # (where a new thread will be started). 
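As context for this TODO: without such a lock, the feeder thread of epoch N can still be pulling from the wrapped dataset while init_seq_order for epoch N + 1 re-initializes it underneath. A minimal sketch of the guard (simplified, hypothetical names; a later commit in this series introduces exactly this kind of lock and moves the wrapped dataset's init_seq_order into the feeder thread):

    import threading

    dataset_lock = threading.Lock()  # hypothetical guard, one per PostprocessingDataset instance

    def feeder_thread(wrapped_dataset, epoch, distribute):
        # Epoch N's feeder must release the lock before the epoch N + 1 feeder
        # may touch the wrapped dataset; otherwise the old feeder could still be
        # reading seqs while the dataset is re-initialized underneath it.
        with dataset_lock:
            wrapped_dataset.init_seq_order(epoch=epoch)
            seq_idx = 0
            while wrapped_dataset.is_less_than_num_seqs(seq_idx):
                wrapped_dataset.load_seqs(seq_idx, seq_idx + 1)
                distribute(seq_idx)
                seq_idx += 1
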
+
         for seq_idx, tensor_dict in enumerate(self._iterate_dataset()):
             if quit_event.is_set():
                 break

From bf878f29ab11697a5ba56eb55749043987a03bea Mon Sep 17 00:00:00 2001
From: Moritz Gunz
Date: Fri, 19 Sep 2025 14:34:53 +0200
Subject: [PATCH 06/58] ignore broad exception lint

---
 returnn/datasets/postprocessing.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py
index 4bc8ce0e5..16cb2aaf3 100644
--- a/returnn/datasets/postprocessing.py
+++ b/returnn/datasets/postprocessing.py
@@ -570,6 +570,7 @@ def stop(self):
         util.try_run(self.dataset_thread.join)

     def __del__(self):
+        # noinspection PyBroadException
         try:
             self.stop()
         except Exception:

From bc1a7c5b002a4d3d064f301807ae67b6aff169be Mon Sep 17 00:00:00 2001
From: Moritz Gunz
Date: Fri, 19 Sep 2025 15:05:16 +0200
Subject: [PATCH 07/58] extend docs

---
 returnn/datasets/postprocessing.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py
index 16cb2aaf3..c21f635b2 100644
--- a/returnn/datasets/postprocessing.py
+++ b/returnn/datasets/postprocessing.py
@@ -142,7 +142,13 @@ def __init__(
             sequences, i.e. for every input sequence there is exactly one output sequence.
         :param buf_size: buffer size for each worker, number of seqs to prefetch.
         :param num_workers: number of worker processes to use for data postprocessing.
+            This does not apply parallelism to the wrapped dataset, but only to the postprocessing step.
+
+            Conceptually, this achieves similar results to using MultiProcDataset, but with potentially lower
+            memory consumption, since only the postprocessing step is parallelized and not the wrapped/underlying
+            source dataset that is postprocessed.
+
             When set to 0, postprocessing happens inline.
         :param kwargs: see :class:`CachedDataset2`, :class:`Dataset`
         """

From 6c463e9fc625419110fb356b0a4d3a132f3cb813 Mon Sep 17 00:00:00 2001
From: Moritz Gunz
Date: Fri, 19 Sep 2025 15:44:19 +0200
Subject: [PATCH 08/58] test if tensor pickling makes it work

---
 returnn/tensor/_tensor_extra.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/returnn/tensor/_tensor_extra.py b/returnn/tensor/_tensor_extra.py
index 5151af3c3..dacf95208 100644
--- a/returnn/tensor/_tensor_extra.py
+++ b/returnn/tensor/_tensor_extra.py
@@ -588,7 +588,7 @@ def _sis_hash(self):

     def __getstate__(self):
         d = {k: getattr(self, k) for k in self.__slots__}
-        d["_raw_tensor"] = None  # do not store the TF tensors
+        # d["_raw_tensor"] = None  # do not store the TF tensors
         return d

     def __setstate__(self, state):

From 4ecfdfdd71d8a20a606823627c22edf67e57fe84 Mon Sep 17 00:00:00 2001
From: Moritz Gunz
Date: Fri, 19 Sep 2025 23:46:05 +0200
Subject: [PATCH 09/58] late night todo to self: fix draining workers

---
 returnn/datasets/postprocessing.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py
index c21f635b2..11b0332ae 100644
--- a/returnn/datasets/postprocessing.py
+++ b/returnn/datasets/postprocessing.py
@@ -558,6 +558,16 @@ def __next__(self) -> Optional[TensorDict]:
             num_workers = len(self.worker_procs)
             seq = self.worker_procs[self._seq_idx % num_workers].get_seq()
             if seq is None and self._max_seq_idx is None:
+                # TODO: this is wrong, we need to drain all workers until each of them is exhausted
+                # this is not necessarily the case after just one more iteration through all workers.
+ # + # Consider the case where the underlying dataset has just 1 seq, but the workers repeat + # this seq 5 times. Then the first worker has 5 seqs, the others have 0. + # Right now we would stop after noticing the first worker that is exhausted, + # but the first one still has 4 more seqs to provide. + # + # just remove them from the list of worker procs when they are exhausted and only stop + # when the list is empty self._max_seq_idx = self._seq_idx + num_workers # drain the other workers self._seq_idx += 1 if seq is not None: From 55abd831711267aceead942bce1b380bc92b68bd Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Fri, 19 Sep 2025 23:55:09 +0200 Subject: [PATCH 10/58] add docs --- returnn/datasets/postprocessing.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py index 11b0332ae..3b891c4a5 100644 --- a/returnn/datasets/postprocessing.py +++ b/returnn/datasets/postprocessing.py @@ -45,8 +45,14 @@ class PostprocessingDataset(CachedDataset2): SpecAugment or speed perturbation into the data loading pipeline. The integration into the data loading pipeline makes it easy to distribute the - data processing work across multiple CPU cores using `MultiProcDataset` and in - turn frees the GPU from data preprocessing tasks. + data processing work across multiple CPU cores and in turn frees the GPU from + data preprocessing tasks. + + Multiprocessing can either be done using ``MultiProcDataset`` or by specifying + the ``num_workers`` parameter of this class. + The latter only applies parallelism to the post-processing functions themselves, + and does not duplicate the underlying dataset once per worker. + This is often fast enough and has the advantage of lower memory consumption. Example usage:: @@ -75,8 +81,8 @@ class PostprocessingDataset(CachedDataset2): The postprocessor functions operate on ``TensorDict``s, which have entries for all data keys in the underlying dataset. - There may also be additional "meta" entries in the tensor dicts, like ``complete_frac`` - and ``seq_tag``. + There may also be additional "meta" entries in the tensor dicts, like ``complete_frac``, + ``seq_idx`` and ``seq_tag``. These should be copied over in a manner that is reasonable for the use case at hand and ensures forwards compatibility as well as reasonably possible. @@ -141,6 +147,7 @@ def __init__( :param map_seq_stream_preserves_num_seqs: whether the function in map_seq_stream preserves the number of sequences, i.e. for every input sequence there is exactly one output sequence. :param buf_size: buffer size for each worker, number of seqs to prefetch. + Must be > 0. Only relevant when num_workers > 0. :param num_workers: number of worker processes to use for data postprocessing. This does not apply parallelism to the wrapped dataset, but only to the postprocessing step. From b700514a427504a6e443945de1ff4da018c6e233 Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Sat, 20 Sep 2025 00:06:20 +0200 Subject: [PATCH 11/58] CI: ensure numpy remains at v1 even in espnet installation --- .github/workflows/main.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 15ebe72fc..73944233f 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -62,7 +62,7 @@ jobs: # Used in some tests. 
pip install --user --progress-bar=off transformers - pip install --user --progress-bar=off espnet + pip install --user --progress-bar=off espnet "numpy<2" # ensure numpy remains <2 - name: Test Python/Numpy/TF versions. run: | @@ -410,7 +410,7 @@ jobs: # https://github.com/rwth-i6/returnn/issues/1729 pip install --user --progress-bar=off ctc-segmentation==1.6.6 pyworld==0.3.4 fi - pip install --user --progress-bar=off espnet + pip install --user --progress-bar=off espnet "numpy<2" # ensure numpy remains <2 # TorchAudio needed by ESPnet. # https://pytorch.org/audio/stable/installation.html#compatibility-matrix if [[ "${{matrix.torch-version}}" == 2.0.0 ]]; then From 31a541392b887f4db4f994a69de19600bbf28f08 Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Mon, 22 Sep 2025 10:23:08 +0200 Subject: [PATCH 12/58] fix lints --- returnn/datasets/postprocessing.py | 200 ++++++++++++++--------------- 1 file changed, 99 insertions(+), 101 deletions(-) diff --git a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py index 3b891c4a5..0ed5ea101 100644 --- a/returnn/datasets/postprocessing.py +++ b/returnn/datasets/postprocessing.py @@ -278,7 +278,7 @@ def init_seq_order( dataset_thread=dataset_thread, quit_event=quit_event, worker_procs=worker_procs ) else: - data_iter = self._build_mapping_iter( + data_iter = _build_mapping_iter( self._iterate_dataset(), map_seq=self._map_seq, map_seq_stream=self._map_seq_stream, @@ -391,105 +391,6 @@ def _iterate_dataset(self) -> Iterator[TensorDict]: yield tensor_dict seq_index += 1 - @staticmethod - def _build_mapping_iter( - data_iter: Iterator[TensorDict], - *, - map_seq: Optional[Callable] = None, - map_seq_stream: Optional[Callable] = None, - epoch: int, - out_tensor_dict_template: TensorDict, - rng: RandomState, - seq_list_for_validation: Optional[List[str]] = None, - ) -> Iterator[TensorDict]: - """ - Build an iterator applying the mapping functions on the given dataset iterator. - - :param data_iter: iterator providing data samples in the form of a TensorDict - :param map_seq: see :class:`PostprocessingDataset` - :param map_seq_stream: see :class:`PostprocessingDataset` - :param epoch: current epoch number - :param out_tensor_dict_template: template for the output TensorDicts, used for validation - :param rng: random number generator to use - :param seq_list_for_validation: optional list of seq tags to validate against when processing the data - :return: an iterator applying both the segment level and across-segment transformations on the given dataset - """ - - def _validate_tensor_dict_iter(inner: Iterator[TensorDict]) -> Iterator[TensorDict]: - last_complete_frac = 0.0 - for t_dict in inner: - assert isinstance(t_dict, TensorDict), ( - f"postprocessing mapper function must produce a {TensorDict.__name__}, " - f"but got a {type(t_dict).__name__}" - ) - if "complete_frac" in t_dict.data: # sanity check complete_frac - complete_frac = float(t_dict.data["complete_frac"].raw_tensor) - assert 0.0 <= complete_frac <= 1.0, f"complete_frac must be in [0, 1], but got {complete_frac}" - assert complete_frac >= last_complete_frac, ( - "complete_frac must be monotonically increasing, " - f"but got {complete_frac} after {last_complete_frac}" - ) - last_complete_frac = complete_frac - for data_key, out_t in out_tensor_dict_template.data.items(): - in_t = t_dict.data[data_key] - assert in_t.ndim == out_t.batch_ndim, ( - f"Dim number mismatch for {data_key}: {in_t.ndim} != {out_t.batch_ndim}. 
" - "Postprocessing data tensors must not have a batch dimension." - ) - assert in_t.dtype == out_t.dtype, ( - f"dtype mismatch for {data_key}: '{in_t.dtype}' != '{out_t.dtype}'" - ) - for i, (in_dim, out_shape) in enumerate(zip(in_t.dims, out_t.shape)): - assert in_dim.dimension is None or in_dim.dimension == out_shape, ( - f"Dim {i} mismatch on {data_key}: " - f"{in_dim.dimension} must either be `None` or equal {out_shape}" - ) - yield t_dict - - def _apply_map_seq(tensor_dict: TensorDict) -> TensorDict: - comp_frac_raw_tensor = ( - tensor_dict.data["complete_frac"].raw_tensor if "complete_frac" in tensor_dict.data else None - ) - seq_index_raw = tensor_dict.data["seq_idx"].raw_tensor - seq_index = int(seq_index_raw.item()) - seq_tag_raw_tensor = tensor_dict.data["seq_tag"].raw_tensor - - tensor_dict = map_seq(tensor_dict, epoch=epoch, seq_idx=seq_index, rng=rng, **util.get_fwd_compat_kwargs()) - assert isinstance(tensor_dict, TensorDict), ( - f"map_seq must produce a {TensorDict.__name__}, but produced {type(tensor_dict).__name__}" - ) - - # Re-adding the complete_frac/seq_idx/seq_tag here causes no harm in case they are dropped - # since we don't add/drop any segments w/ the non-iterator postprocessing function. - if "complete_frac" not in tensor_dict.data and comp_frac_raw_tensor is not None: - tensor_dict.data["complete_frac"] = Tensor( - "complete_frac", dims=(), dtype="float32", raw_tensor=comp_frac_raw_tensor - ) - if "seq_idx" not in tensor_dict.data: - tensor_dict.data["seq_idx"] = Tensor("seq_idx", dims=(), dtype="int32", raw_tensor=seq_index_raw) - if "seq_tag" not in tensor_dict.data: - tensor_dict.data["seq_tag"] = Tensor("seq_tag", dims=(), dtype="string", raw_tensor=seq_tag_raw_tensor) - - if seq_list_for_validation is not None: - seq_tag = seq_list_for_validation[seq_index] - tag_of_seq = tensor_dict.data["seq_tag"].raw_tensor.item() - assert tag_of_seq == seq_tag, ( - f"seq tag mismath: {tag_of_seq} != {seq_tag} for seq index {seq_index} when seq list is given" - ) - - return tensor_dict - - assert map_seq or map_seq_stream, "need to specify either map_seq or map_seq_stream" - assert not (map_seq and map_seq_stream), "cannot set both map_seq and map_seq_stream" - if map_seq is not None: - data_iter = (_apply_map_seq(t_dict) for t_dict in data_iter) - if map_seq_stream is not None: - data_iter = map_seq_stream(data_iter, epoch=epoch, rng=rng, **util.get_fwd_compat_kwargs()) - assert isinstance(data_iter, Iterator), ( - f"map_seq_stream must produce an {Iterator.__name__}, but produced {type(data_iter).__name__}" - ) - return _validate_tensor_dict_iter(data_iter) - def _make_tensor_template_from_input(self, data_key: str) -> Tensor: dtype = self._dataset.get_data_dtype(data_key) if dtype == "string": @@ -533,6 +434,102 @@ def _distribute_seqs_to_children(self, *, child_queues: Sequence[mpQueue], quit_ pass +def _build_mapping_iter( + data_iter: Iterator[TensorDict], + *, + map_seq: Optional[Callable] = None, + map_seq_stream: Optional[Callable] = None, + epoch: int, + out_tensor_dict_template: TensorDict, + rng: RandomState, + seq_list_for_validation: Optional[List[str]] = None, +) -> Iterator[TensorDict]: + """ + Build an iterator applying the mapping functions on the given dataset iterator. 
+ + :param data_iter: iterator providing data samples in the form of a TensorDict + :param map_seq: see :class:`PostprocessingDataset` + :param map_seq_stream: see :class:`PostprocessingDataset` + :param epoch: current epoch number + :param out_tensor_dict_template: template for the output TensorDicts, used for validation + :param rng: random number generator to use + :param seq_list_for_validation: optional list of seq tags to validate against when processing the data + :return: an iterator applying both the segment level and across-segment transformations on the given dataset + """ + + def _validate_tensor_dict_iter(inner: Iterator[TensorDict]) -> Iterator[TensorDict]: + last_complete_frac = 0.0 + for t_dict in inner: + assert isinstance(t_dict, TensorDict), ( + f"postprocessing mapper function must produce a {TensorDict.__name__}, " + f"but got a {type(t_dict).__name__}" + ) + if "complete_frac" in t_dict.data: # sanity check complete_frac + complete_frac = float(t_dict.data["complete_frac"].raw_tensor) + assert 0.0 <= complete_frac <= 1.0, f"complete_frac must be in [0, 1], but got {complete_frac}" + assert complete_frac >= last_complete_frac, ( + "complete_frac must be monotonically increasing, " + f"but got {complete_frac} after {last_complete_frac}" + ) + last_complete_frac = complete_frac + for data_key, out_t in out_tensor_dict_template.data.items(): + in_t = t_dict.data[data_key] + assert in_t.ndim == out_t.batch_ndim, ( + f"Dim number mismatch for {data_key}: {in_t.ndim} != {out_t.batch_ndim}. " + "Postprocessing data tensors must not have a batch dimension." + ) + assert in_t.dtype == out_t.dtype, f"dtype mismatch for {data_key}: '{in_t.dtype}' != '{out_t.dtype}'" + for i, (in_dim, out_shape) in enumerate(zip(in_t.dims, out_t.shape)): + assert in_dim.dimension is None or in_dim.dimension == out_shape, ( + f"Dim {i} mismatch on {data_key}: {in_dim.dimension} must either be `None` or equal {out_shape}" + ) + yield t_dict + + def _apply_map_seq(tensor_dict: TensorDict) -> TensorDict: + comp_frac_raw_tensor = ( + tensor_dict.data["complete_frac"].raw_tensor if "complete_frac" in tensor_dict.data else None + ) + seq_index_raw = tensor_dict.data["seq_idx"].raw_tensor + seq_index = int(seq_index_raw.item()) + seq_tag_raw_tensor = tensor_dict.data["seq_tag"].raw_tensor + + tensor_dict = map_seq(tensor_dict, epoch=epoch, seq_idx=seq_index, rng=rng, **util.get_fwd_compat_kwargs()) + assert isinstance(tensor_dict, TensorDict), ( + f"map_seq must produce a {TensorDict.__name__}, but produced {type(tensor_dict).__name__}" + ) + + # Re-adding the complete_frac/seq_idx/seq_tag here causes no harm in case they are dropped + # since we don't add/drop any segments w/ the non-iterator postprocessing function. 
+ if "complete_frac" not in tensor_dict.data and comp_frac_raw_tensor is not None: + tensor_dict.data["complete_frac"] = Tensor( + "complete_frac", dims=(), dtype="float32", raw_tensor=comp_frac_raw_tensor + ) + if "seq_idx" not in tensor_dict.data: + tensor_dict.data["seq_idx"] = Tensor("seq_idx", dims=(), dtype="int32", raw_tensor=seq_index_raw) + if "seq_tag" not in tensor_dict.data: + tensor_dict.data["seq_tag"] = Tensor("seq_tag", dims=(), dtype="string", raw_tensor=seq_tag_raw_tensor) + + if seq_list_for_validation is not None: + seq_tag = seq_list_for_validation[seq_index] + tag_of_seq = tensor_dict.data["seq_tag"].raw_tensor.item() + assert tag_of_seq == seq_tag, ( + f"seq tag mismath: {tag_of_seq} != {seq_tag} for seq index {seq_index} when seq list is given" + ) + + return tensor_dict + + assert map_seq or map_seq_stream, "need to specify either map_seq or map_seq_stream" + assert not (map_seq and map_seq_stream), "cannot set both map_seq and map_seq_stream" + if map_seq is not None: + data_iter = (_apply_map_seq(t_dict) for t_dict in data_iter) + if map_seq_stream is not None: + data_iter = map_seq_stream(data_iter, epoch=epoch, rng=rng, **util.get_fwd_compat_kwargs()) + assert isinstance(data_iter, Iterator), ( + f"map_seq_stream must produce an {Iterator.__name__}, but produced {type(data_iter).__name__}" + ) + return _validate_tensor_dict_iter(data_iter) + + class _MultiProcDataIter: """ Data iter that pulls from the worker processes and manages their lifetime. @@ -585,6 +582,7 @@ def __next__(self) -> Optional[TensorDict]: raise StopIteration def stop(self): + """Stop the worker processes and the dataset thread.""" if self.quit_event.is_set(): return self.quit_event.set() @@ -706,7 +704,7 @@ def _iter_queue(q: mpQueue) -> Iterator[TensorDict]: break yield item - data_iter = PostprocessingDataset._build_mapping_iter( + data_iter = _build_mapping_iter( _iter_queue(seq_queue), map_seq=map_seq, map_seq_stream=map_seq_stream, From 4aa13495f54857f8ef2d9a55c236331872cea5d4 Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Mon, 22 Sep 2025 10:23:28 +0200 Subject: [PATCH 13/58] optimize --- returnn/datasets/postprocessing.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py index 0ed5ea101..5138dff22 100644 --- a/returnn/datasets/postprocessing.py +++ b/returnn/datasets/postprocessing.py @@ -717,6 +717,8 @@ def _iter_queue(q: mpQueue) -> Iterator[TensorDict]: def _add_to_cache(): nonlocal data_iter + if data_iter is None: + return False try: seq = next(data_iter) except StopIteration: @@ -728,14 +730,14 @@ def _add_to_cache(): try: while True: while not parent_conn.poll(): - while data_iter is not None and len(cache) < buffer_size: + while len(cache) < buffer_size: if not _add_to_cache(): break msg, kwargs = parent_conn.recv() if msg == "exit": break elif msg == "get_seq": - if not cache and data_iter is not None: + if not cache: _add_to_cache() parent_conn.send(("seq", cache.popleft() if cache else None)) else: From 69a88f382b612f14c50d9ae14a28f90ec7eb920e Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Mon, 22 Sep 2025 10:51:48 +0200 Subject: [PATCH 14/58] clean implementation --- returnn/datasets/postprocessing.py | 102 ++++++++++++++--------------- 1 file changed, 49 insertions(+), 53 deletions(-) diff --git a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py index 5138dff22..c14199755 100644 --- a/returnn/datasets/postprocessing.py +++ 
b/returnn/datasets/postprocessing.py @@ -182,8 +182,6 @@ def __init__( assert map_seq_stream_preserves_num_seqs is None or isinstance(map_seq_stream_preserves_num_seqs, bool) self._map_seq_stream_preserves_num_seqs = map_seq_stream_preserves_num_seqs self._map_outputs = map_outputs - self._buf_size = buf_size - self._num_workers = num_workers self._seq_list_for_validation: Optional[List[str]] = None self._dataset = init_dataset(self._dataset_def, parent_dataset=self) @@ -193,6 +191,12 @@ def __init__( self._data_iter: Optional[Iterator[Tuple[int, TensorDict]]] = None self._data_iter_produced_num_seqs = 0 + self._buf_size = buf_size + # Ensures only one feeder thread at a time accesses the wrapped dataset. Only used if num_workers > 0. + self._dataset_lock: Optional[threading.Lock] = None + self._multi_proc_data_iter: Optional[_MultiProcDataIter] = None + self._num_workers = num_workers + self._in_tensor_dict_template = TensorDict( {name: self._make_tensor_template_from_input(name) for name in self._dataset.get_data_keys()} ) @@ -252,6 +256,8 @@ def init_seq_order( self._dataset.init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order) if self._num_workers > 0: assert self._buf_size > 0 + if self._dataset_lock is None: + self._dataset_lock = threading.Lock() seq_queues = [_mp.Queue(maxsize=self._buf_size) for _ in range(self._num_workers)] worker_procs = [ _WorkerProcParent( @@ -270,11 +276,13 @@ def init_seq_order( quit_event = threading.Event() dataset_thread = threading.Thread( target=self._distribute_seqs_to_children, - kwargs={"child_queues": seq_queues, "quit_event": quit_event}, + kwargs={"child_queues": seq_queues, "dataset_lock": self._dataset_lock, "quit_event": quit_event}, name=f"{self.__class__.__name__} {self.name} dataset ep {epoch}", ) dataset_thread.start() - data_iter = _MultiProcDataIter( + if self._multi_proc_data_iter is not None: + self._multi_proc_data_iter.stop() + data_iter = self._multi_proc_data_iter = _MultiProcDataIter( dataset_thread=dataset_thread, quit_event=quit_event, worker_procs=worker_procs ) else: @@ -409,29 +417,25 @@ def _make_tensor_template_from_input(self, data_key: str) -> Tensor: sparse_dim.vocab = Vocabulary.create_vocab_from_labels(self._dataset.labels[data_key]) return Tensor(data_key, dims=dims, dtype=dtype, sparse_dim=sparse_dim) - def _distribute_seqs_to_children(self, *, child_queues: Sequence[mpQueue], quit_event: threading.Event): - num_workers = len(child_queues) - assert num_workers > 0 - - # TODO: should we hold a lock around the dataset while this thread is alive? - # - # This would help prevent issues when switching from one epoch to the next - # (where a new thread will be started). - - for seq_idx, tensor_dict in enumerate(self._iterate_dataset()): - if quit_event.is_set(): - break - try: - child_queues[seq_idx % num_workers].put(tensor_dict, block=True) # block for backpressure - except ValueError: - # queue is closed, i.e. the worker process died - break - for q in child_queues: - try: - q.put(None, block=True) # signal end of data - except ValueError: - # queue is already closed, i.e. 
the worker process died
+                pass


 def _build_mapping_iter(
@@ -544,10 +548,11 @@ def __init__(
     ):
         self.dataset_thread = dataset_thread
         self.quit_event = quit_event
+        assert len(worker_procs) > 0
         self.worker_procs = worker_procs

-        self._max_seq_idx: Optional[int] = None
-        self._seq_idx = 0
+        self._workers_exhausted = [False for _ in range(len(worker_procs))]
+        self._worker_idx = 0

     def __iter__(self):
         return self
@@ -556,44 +561,35 @@ def __next__(self) -> Optional[TensorDict]:
         if self.quit_event.is_set():
             raise StopIteration

-        for _ in range(len(self.worker_procs)):
-            if self._max_seq_idx is not None and self._seq_idx >= self._max_seq_idx:
-                break
-            num_workers = len(self.worker_procs)
-            seq = self.worker_procs[self._seq_idx % num_workers].get_seq()
-            if seq is None and self._max_seq_idx is None:
-                # TODO: this is wrong, we need to drain all workers until each of them is exhausted
-                # this is not necessarily the case after just one more iteration through all workers.
-
- # - # just remove them from the list of worker procs when they are exhausted and only stop - # when the list is empty - self._max_seq_idx = self._seq_idx + num_workers # drain the other workers - self._seq_idx += 1 - if seq is not None: - return seq + while True: + seq = self.worker_procs[self._worker_idx].get_seq() + self._worker_idx = (self._worker_idx + 1) % len(self.worker_procs) + if seq is None: + self._workers_exhausted[self._worker_idx] = True + if all(self._workers_exhausted): + break + else: + continue + return seq # when we reach this point, all workers are exhausted and we stop self.stop() raise StopIteration - def stop(self): + def stop(self, *, join=True): """Stop the worker processes and the dataset thread.""" if self.quit_event.is_set(): return self.quit_event.set() for wp in self.worker_procs: - wp.exit(join=True) - util.try_run(self.dataset_thread.join) + wp.exit(join=join) + if join: + util.try_run(self.dataset_thread.join) def __del__(self): # noinspection PyBroadException try: - self.stop() + self.stop(join=False) except Exception: pass From 72152bf342e3b0a269d93a12df5270f4a5417067 Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Mon, 22 Sep 2025 11:07:21 +0200 Subject: [PATCH 15/58] init wrapped dataset in feeder thread --- returnn/datasets/postprocessing.py | 34 ++++++++++++++++++++++++------ 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py index c14199755..3d8dba389 100644 --- a/returnn/datasets/postprocessing.py +++ b/returnn/datasets/postprocessing.py @@ -252,8 +252,6 @@ def init_seq_order( self._num_seqs = 0 return True - assert self._dataset is not None - self._dataset.init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order) if self._num_workers > 0: assert self._buf_size > 0 if self._dataset_lock is None: @@ -274,9 +272,17 @@ def init_seq_order( for i, seq_queue in enumerate(seq_queues) ] quit_event = threading.Event() + # dataset thread takes care of init_seq_order of the wrapped dataset dataset_thread = threading.Thread( - target=self._distribute_seqs_to_children, - kwargs={"child_queues": seq_queues, "dataset_lock": self._dataset_lock, "quit_event": quit_event}, + target=self._init_and_distribute_seqs_to_children, + kwargs={ + "child_queues": seq_queues, + "dataset_lock": self._dataset_lock, + "epoch": epoch, + "quit_event": quit_event, + "seq_list": seq_list, + "seq_order": seq_order, + }, name=f"{self.__class__.__name__} {self.name} dataset ep {epoch}", ) dataset_thread.start() @@ -286,6 +292,8 @@ def init_seq_order( dataset_thread=dataset_thread, quit_event=quit_event, worker_procs=worker_procs ) else: + assert self._dataset is not None + self._dataset.init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order) data_iter = _build_mapping_iter( self._iterate_dataset(), map_seq=self._map_seq, @@ -417,11 +425,25 @@ def _make_tensor_template_from_input(self, data_key: str) -> Tensor: sparse_dim.vocab = Vocabulary.create_vocab_from_labels(self._dataset.labels[data_key]) return Tensor(data_key, dims=dims, dtype=dtype, sparse_dim=sparse_dim) - def _distribute_seqs_to_children( - self, *, child_queues: Sequence[mpQueue], dataset_lock: threading.Lock, quit_event: threading.Event + def _init_and_distribute_seqs_to_children( + self, + *, + child_queues: Sequence[mpQueue], + dataset_lock: threading.Lock, + epoch: int, + quit_event: threading.Event, + seq_list: Optional[List[str]] = None, + seq_order: Optional[List[int]] = None, ): assert len(child_queues) > 0 + + 
# Lock ensures that only one thread at a time accesses the wrapped dataset.
+ #
+ # This protects against issues while moving from one epoch to the next.
with dataset_lock:
+ assert self._dataset is not None
+ self._dataset.init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order)
+
for seq_idx, tensor_dict in enumerate(self._iterate_dataset()):
if quit_event.is_set():
break

From 898760828db710187421bb6f9942702fde538f69 Mon Sep 17 00:00:00 2001
From: Moritz Gunz
Date: Mon, 22 Sep 2025 11:34:34 +0200
Subject: [PATCH 16/58] cleaner code

---
 returnn/datasets/postprocessing.py | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py
index 3d8dba389..7885d4e92 100644
--- a/returnn/datasets/postprocessing.py
+++ b/returnn/datasets/postprocessing.py
@@ -584,15 +584,16 @@ def __next__(self) -> Optional[TensorDict]:
raise StopIteration

while True:
- seq = self.worker_procs[self._worker_idx].get_seq()
+ if all(self._workers_exhausted):
+ break
+ worker_idx = self._worker_idx
self._worker_idx = (self._worker_idx + 1) % len(self.worker_procs)
- if seq is None:
- self._workers_exhausted[self._worker_idx] = True
- if all(self._workers_exhausted):
- break
- else:
- continue
- return seq
+ if self._workers_exhausted[worker_idx]:
+ continue
+ seq = self.worker_procs[worker_idx].get_seq()
+ if seq is not None:
+ return seq
+ self._workers_exhausted[worker_idx] = True

# when we reach this point, all workers are exhausted and we stop
self.stop()

From 22455dc1f3c67d6941a1cbe3d6a5d87ad53acf5a Mon Sep 17 00:00:00 2001
From: Moritz Gunz
Date: Mon, 22 Sep 2025 11:55:37 +0200
Subject: [PATCH 17/58] pass correct epoch, and mix worker count and worker index in via constants

---
 returnn/datasets/postprocessing.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py
index 7885d4e92..9b384d5ca 100644
--- a/returnn/datasets/postprocessing.py
+++ b/returnn/datasets/postprocessing.py
@@ -265,7 +265,8 @@ def init_seq_order(
map_seq=self._map_seq,
map_seq_stream=self._map_seq_stream,
out_tensor_dict_template=self._out_tensor_dict_template,
- rng_seed=self._get_random_seed_for_epoch(epoch=epoch * self._num_workers + i),
+ rng_seed=self._get_random_seed_for_epoch(epoch=epoch) * self._num_workers
* 6838594027 - + i * 30411167, + rng_seed=(base_rng_seed + 30411167 * i) % (2**32 - 1), seq_list=seq_list, seq_queue=seq_queue, ) From 27d76ddcec96c874c0df75487bc08cc1a4845dce Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Mon, 22 Sep 2025 16:30:37 +0200 Subject: [PATCH 19/58] move multi proc implementation into subclass --- returnn/datasets/postprocessing.py | 305 ++++++++++++++++------------- tests/test_Dataset.py | 4 +- 2 files changed, 171 insertions(+), 138 deletions(-) diff --git a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py index 52257a234..736b077a6 100644 --- a/returnn/datasets/postprocessing.py +++ b/returnn/datasets/postprocessing.py @@ -20,7 +20,7 @@ from returnn.util import basic as util, better_exchook from returnn.util.multi_proc_non_daemonic_spawn import NonDaemonicSpawnContext from returnn.config import SubProcCopyGlobalConfigPreInitFunc -from .basic import init_dataset +from .basic import Dataset, init_dataset from .cached2 import CachedDataset2 # noinspection PyProtectedMember @@ -32,7 +32,7 @@ _mp = NonDaemonicSpawnContext(process_pre_init_func=SubProcCopyGlobalConfigPreInitFunc()) -__all__ = ["PostprocessingDataset", "LaplaceOrdering", "Sequential"] +__all__ = ["PostprocessingDataset", "MultiProcPostprocessingDataset", "LaplaceOrdering", "Sequential"] class PostprocessingDataset(CachedDataset2): @@ -48,8 +48,8 @@ class PostprocessingDataset(CachedDataset2): data processing work across multiple CPU cores and in turn frees the GPU from data preprocessing tasks. - Multiprocessing can either be done using ``MultiProcDataset`` or by specifying - the ``num_workers`` parameter of this class. + Multiprocessing can either be done using :class:``MultiProcDataset`` or via the subclass + :class:``MultiProcPostprocessingDataset``. The latter only applies parallelism to the post-processing functions themselves, and does not duplicate the underlying dataset once per worker. This is often fast enough and has the advantage of lower memory consumption. @@ -119,8 +119,6 @@ def __init__( map_seq_stream: Optional[Callable] = None, map_outputs: Optional[Dict[str, Any]] = None, map_seq_stream_preserves_num_seqs: Optional[bool] = None, - buf_size: int = 1, - num_workers: int = 0, **kwargs, ): """ @@ -146,17 +144,6 @@ def __init__( Example: `map_outputs={"data": {"dim": 42}}` :param map_seq_stream_preserves_num_seqs: whether the function in map_seq_stream preserves the number of sequences, i.e. for every input sequence there is exactly one output sequence. - :param buf_size: buffer size for each worker, number of seqs to prefetch. - Must be > 0. Only relevant when num_workers > 0. - :param num_workers: number of worker processes to use for data postprocessing. - - This does not apply parallelism to the wrapped dataset, but only to the postprocessing step. - - Conceptually, this achieves similar results to using MultiProcDataset, but with potentially lower - memory consumption, since only the postprocessing step is parallelized and not the wrapped/underlying - source dataset that is postprocessed. - - When set to 0, postprocessing happens inline. 
:param kwargs: see :class:`CachedDataset2`, :class:`Dataset` """ super().__init__(**kwargs) @@ -169,10 +156,6 @@ def __init__( raise ValueError(f"{self}: cannot set both map_seq and map_seq_stream") if map_seq and map_seq_stream_preserves_num_seqs is not None: raise ValueError(f"{self}: map_seq_stream_preserves_num_seqs is only allowed with map_seq_stream") - if buf_size < 1: - raise ValueError(f"{self}: buf_size must be >= 1, but got {buf_size}") - if num_workers < 0: - raise ValueError(f"{self}: num_workers must be >= 0, but got {num_workers}") self._dataset_def = dataset self._map_seq = map_seq @@ -191,12 +174,6 @@ def __init__( self._data_iter: Optional[Iterator[Tuple[int, TensorDict]]] = None self._data_iter_produced_num_seqs = 0 - self._buf_size = buf_size - # Ensures only one feeder thread at a time accesses the wrapped dataset. Only used if num_workers > 0. - self._dataset_lock: Optional[threading.Lock] = None - self._multi_proc_data_iter: Optional[_MultiProcDataIter] = None - self._num_workers = num_workers - self._in_tensor_dict_template = TensorDict( {name: self._make_tensor_template_from_input(name) for name in self._dataset.get_data_keys()} ) @@ -252,58 +229,17 @@ def init_seq_order( self._num_seqs = 0 return True - if self._num_workers > 0: - assert self._buf_size > 0 - if self._dataset_lock is None: - self._dataset_lock = threading.Lock() - seq_queues = [_mp.Queue(maxsize=self._buf_size) for _ in range(self._num_workers)] - base_rng_seed = self._get_random_seed_for_epoch(epoch=epoch) * 6838594027 * self._num_workers - worker_procs = [ - _WorkerProcParent( - name=f"{self.__class__.__name__} {self.name} ep {epoch}", - epoch=epoch, - buffer_size=self._buf_size, - map_seq=self._map_seq, - map_seq_stream=self._map_seq_stream, - out_tensor_dict_template=self._out_tensor_dict_template, - rng_seed=(base_rng_seed + 30411167 * i) % (2**32 - 1), - seq_list=seq_list, - seq_queue=seq_queue, - ) - for i, seq_queue in enumerate(seq_queues) - ] - quit_event = threading.Event() - # dataset thread takes care of init_seq_order of the wrapped dataset - dataset_thread = threading.Thread( - target=self._init_and_distribute_seqs_to_children, - kwargs={ - "child_queues": seq_queues, - "dataset_lock": self._dataset_lock, - "epoch": epoch, - "quit_event": quit_event, - "seq_list": seq_list, - "seq_order": seq_order, - }, - name=f"{self.__class__.__name__} {self.name} dataset ep {epoch}", - ) - dataset_thread.start() - if self._multi_proc_data_iter is not None: - self._multi_proc_data_iter.stop() - data_iter = self._multi_proc_data_iter = _MultiProcDataIter( - dataset_thread=dataset_thread, quit_event=quit_event, worker_procs=worker_procs - ) - else: - assert self._dataset is not None - self._dataset.init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order) - data_iter = _build_mapping_iter( - self._iterate_dataset(), - map_seq=self._map_seq, - map_seq_stream=self._map_seq_stream, - epoch=epoch, - out_tensor_dict_template=self._out_tensor_dict_template, - rng=RandomState(self._get_random_seed_for_epoch(epoch=epoch)), - seq_list_for_validation=seq_list, - ) + assert self._dataset is not None + self._dataset.init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order) + data_iter = _build_mapping_iter( + _iterate_dataset(self._dataset, in_tensor_dict_template=self._in_tensor_dict_template), + map_seq=self._map_seq, + map_seq_stream=self._map_seq_stream, + epoch=epoch, + out_tensor_dict_template=self._out_tensor_dict_template, + 
rng=RandomState(self._get_random_seed_for_epoch(epoch=epoch)), + seq_list_for_validation=seq_list, + ) self._data_iter = enumerate(data_iter) self._data_iter_produced_num_seqs = 0 self._seq_list_for_validation = seq_list @@ -382,32 +318,6 @@ def _collect_single_seq(self, seq_idx: int) -> Optional[DatasetSeq]: seq = DatasetSeq(complete_frac=complete_frac, features=features, seq_idx=seq_idx, seq_tag=seq_tag) return seq - def _iterate_dataset(self) -> Iterator[TensorDict]: - """ - :return: generator providing data samples in the form of a TensorDict - """ - data_keys = self._dataset.get_data_keys() - - seq_index = 0 - while self._dataset.is_less_than_num_seqs(seq_index): - self._dataset.load_seqs(seq_index, seq_index + 1) - - tensor_dict = self._in_tensor_dict_template.copy_template() - for data_key in data_keys: - tensor_dict.data[data_key].raw_tensor = self._dataset.get_data(seq_index, data_key) - - complete_frac = self._dataset.get_complete_frac(seq_index, allow_only_lr_suitable=True) - if complete_frac is not None: - comp_frac_raw_tensor = numpy.array(complete_frac, dtype=numpy.float32) - tensor_dict.data["complete_frac"].raw_tensor = comp_frac_raw_tensor - seq_idx_raw_tensor = numpy.array(seq_index, dtype=numpy.int32) - tensor_dict.data["seq_idx"].raw_tensor = seq_idx_raw_tensor - seq_tag_raw_tensor = str_to_numpy_array(self._dataset.get_tag(seq_index)) - tensor_dict.data["seq_tag"].raw_tensor = seq_tag_raw_tensor - - yield tensor_dict - seq_index += 1 - def _make_tensor_template_from_input(self, data_key: str) -> Tensor: dtype = self._dataset.get_data_dtype(data_key) if dtype == "string": @@ -426,39 +336,32 @@ def _make_tensor_template_from_input(self, data_key: str) -> Tensor: sparse_dim.vocab = Vocabulary.create_vocab_from_labels(self._dataset.labels[data_key]) return Tensor(data_key, dims=dims, dtype=dtype, sparse_dim=sparse_dim) - def _init_and_distribute_seqs_to_children( - self, - *, - child_queues: Sequence[mpQueue], - dataset_lock: threading.Lock, - epoch: int, - quit_event: threading.Event, - seq_list: Optional[List[str]] = None, - seq_order: Optional[List[int]] = None, - ): - assert len(child_queues) > 0 - # Lock ensures that only one thread at a time accesses the wrapped dataset. - # - # This protects against issues while moving from one epoch to the next. - with dataset_lock: - assert self._dataset is not None - self._dataset.init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order) +def _iterate_dataset(dataset: Dataset, *, in_tensor_dict_template: TensorDict) -> Iterator[TensorDict]: + """ + :return: generator providing data samples in the form of a TensorDict + """ + data_keys = dataset.get_data_keys() - for seq_idx, tensor_dict in enumerate(self._iterate_dataset()): - if quit_event.is_set(): - break - try: - child_queues[seq_idx % len(child_queues)].put(tensor_dict, block=True) # block for backpressure - except ValueError: - # queue is closed, i.e. the worker process crashed for some reason -> stop - break - for q in child_queues: - try: - q.put(None, block=True) # signal end of data - except ValueError: - # queue is already closed, i.e. 
the worker process died - pass + seq_index = 0 + while dataset.is_less_than_num_seqs(seq_index): + dataset.load_seqs(seq_index, seq_index + 1) + + tensor_dict = in_tensor_dict_template.copy_template() + for data_key in data_keys: + tensor_dict.data[data_key].raw_tensor = dataset.get_data(seq_index, data_key) + + complete_frac = dataset.get_complete_frac(seq_index, allow_only_lr_suitable=True) + if complete_frac is not None: + comp_frac_raw_tensor = numpy.array(complete_frac, dtype=numpy.float32) + tensor_dict.data["complete_frac"].raw_tensor = comp_frac_raw_tensor + seq_idx_raw_tensor = numpy.array(seq_index, dtype=numpy.int32) + tensor_dict.data["seq_idx"].raw_tensor = seq_idx_raw_tensor + seq_tag_raw_tensor = str_to_numpy_array(dataset.get_tag(seq_index)) + tensor_dict.data["seq_tag"].raw_tensor = seq_tag_raw_tensor + + yield tensor_dict + seq_index += 1 def _build_mapping_iter( @@ -557,6 +460,136 @@ def _apply_map_seq(tensor_dict: TensorDict) -> TensorDict: return _validate_tensor_dict_iter(data_iter) +class MultiProcPostprocessingDataset(PostprocessingDataset): + """ + Subclass of :class:`PostprocessingDataset` that parallelizes the post-processing using multiple processes. + + The underlying dataset is only instantiated once, only the post-processing functions are parallelized. + + Since it is usually the postprocessing itself and not the data loading from the underlying dataset + that is the bottleneck, it is often sufficient to only parallelize the postprocessing step. + The advantage is that this usually has lower memory consumption than using :class:``MultiProcDataset``. + + The dataset interface is the same as for :class:`PostprocessingDataset`, with two additional parameters + to configure the multi-processing behavior. + """ + + def __init__(self, *args, buf_size: int = 1, num_workers: int = 1, **kwargs): + """ + :param args: Same args as :class:``PostprocessingDataset``. + :param buf_size: Buffer size for each worker, number of seqs to prefetch. Must be > 0. + :param num_workers: Number of worker processes to use for data postprocessing. Must be > 0. + :param kwargs: Same args as :class:``PostprocessingDataset``. + """ + + super().__init__(*args, **kwargs) + + if buf_size < 1: + raise ValueError(f"{self}: buf_size must be > 0, but got {buf_size}") + if num_workers < 1: + raise ValueError(f"{self}: num_workers must be > 0, but got {num_workers}") + + self._buf_size = buf_size + # Ensure only one feeder thread at a time accesses the wrapped dataset to + # prevent race conditions while moving from one epoch to the next. 
+ self._dataset_lock = threading.Lock() + self._multi_proc_data_iter: Optional[_MultiProcDataIter] = None # store for cleanup + self._num_workers = num_workers + + def init_seq_order( + self, epoch: Optional[int] = None, seq_list: Optional[List[str]] = None, seq_order: Optional[List[int]] = None + ): + """ + :param epoch: + :param seq_list: + :param seq_order: + :return: whether the order changed (True is always safe to return) + """ + super().init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order) + + if self._multi_proc_data_iter is not None: + self._multi_proc_data_iter.stop() + self._multi_proc_data_iter = None + + if self._num_seqs == 0: + return True + + assert self._buf_size > 0 + assert self._num_workers > 0 + seq_queues = [_mp.Queue(maxsize=self._buf_size) for _ in range(self._num_workers)] + base_rng_seed = self._get_random_seed_for_epoch(epoch=epoch) * 683859 * self._num_workers + worker_procs = [ + _WorkerProcParent( + name=f"{self.__class__.__name__} {self.name} ep {epoch}", + epoch=epoch, + buffer_size=self._buf_size, + map_seq=self._map_seq, + map_seq_stream=self._map_seq_stream, + out_tensor_dict_template=self._out_tensor_dict_template, + rng_seed=(base_rng_seed + 30411 * i) % (2**32 - 1), + seq_list=seq_list, + seq_queue=seq_queue, + ) + for i, seq_queue in enumerate(seq_queues) + ] + quit_event = threading.Event() + dataset_thread = threading.Thread( + target=self._init_seq_order_and_distribute_seqs_to_children, + kwargs={ + "child_queues": seq_queues, + "dataset_lock": self._dataset_lock, + "epoch": epoch, + "quit_event": quit_event, + "seq_list": seq_list, + "seq_order": seq_order, + }, + name=f"{self.__class__.__name__} {self.name} dataset ep {epoch}", + ) + dataset_thread.start() + data_iter = self._multi_proc_data_iter = _MultiProcDataIter( + dataset_thread=dataset_thread, quit_event=quit_event, worker_procs=worker_procs + ) + self._data_iter = enumerate(data_iter) + + def _init_seq_order_and_distribute_seqs_to_children( + self, + *, + child_queues: Sequence[mpQueue], + dataset_lock: threading.Lock, + epoch: int, + quit_event: threading.Event, + seq_list: Optional[List[str]] = None, + seq_order: Optional[List[int]] = None, + ): + """ + Initialize the wrapped dataset and distribute the contained sequences to the child worker processes. + """ + + assert len(child_queues) > 0 + + # Lock ensures that only one thread at a time accesses the wrapped dataset. + # + # This protects against issues while moving from one epoch to the next. + with dataset_lock: + self._dataset.init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order) + data_iter = _iterate_dataset(self._dataset, in_tensor_dict_template=self._in_tensor_dict_template) + + for seq_idx, tensor_dict in enumerate(data_iter): + if quit_event.is_set(): + break + try: + child_queues[seq_idx % len(child_queues)].put(tensor_dict, block=True) # block for backpressure + except ValueError: + # queue is closed, i.e. the worker process crashed for some reason -> stop + break + for q in child_queues: + try: + q.put(None, block=True) # signal end of data + except ValueError: + # queue is already closed, i.e. the worker process died + pass + + class _MultiProcDataIter: """ Data iter that pulls from the worker processes and manages their lifetime. 
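For illustration, a minimal config sketch for the new subclass as of this patch; the wrapped
dataset opts sub_ds_opts and the identity mapping function _noop are placeholders, not part
of the change:

    def _noop(tensor_dict, **_other):
        # identity post-processing function (placeholder)
        return tensor_dict

    train = {
        "class": "MultiProcPostprocessingDataset",
        "dataset": sub_ds_opts,  # opts dict of any wrapped RETURNN dataset
        "map_seq": _noop,  # alternatively map_seq_stream for an iterator-based mapper
        "buf_size": 10,  # number of seqs each worker prefetches
        "num_workers": 4,  # number of post-processing worker processes
    }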
diff --git a/tests/test_Dataset.py b/tests/test_Dataset.py index 25d68ea6a..be5d6a41f 100644 --- a/tests/test_Dataset.py +++ b/tests/test_Dataset.py @@ -1273,11 +1273,11 @@ def _repeat2(input_iter: Iterator[TensorDict], **kwargs) -> Iterator[TensorDict] yield tdict -def test_PostprocessingDataset_multi_proc(): +def test_MultiProcPostprocessingDataset(): _demo_txt = "some utterance text that has a few words" with create_ogg_zip_txt_only_dataset_opts(text=_demo_txt) as sub_ds_opts: ds_opts = { - "class": "PostprocessingDataset", + "class": "MultiProcPostprocessingDataset", "dataset": sub_ds_opts, "map_seq_stream": _repeat2, "buf_size": 1, From 55bf70105a5dcf823fb3f633e4e9b9ae1af4b15e Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Tue, 23 Sep 2025 14:53:40 +0200 Subject: [PATCH 20/58] test dataset across multiple epochs --- tests/test_Dataset.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/test_Dataset.py b/tests/test_Dataset.py index be5d6a41f..7b5ed34e3 100644 --- a/tests/test_Dataset.py +++ b/tests/test_Dataset.py @@ -1284,13 +1284,15 @@ def test_MultiProcPostprocessingDataset(): "num_workers": 2, } dataset = init_dataset(ds_opts) - dataset.init_seq_order(epoch=1) - assert dataset.have_seqs() - dataset.load_seqs(0, 2) - for i in range(2): - classes = dataset.get_data(i, "classes") - assert len(classes) > 0 + for ep in range(1, 20 + 1): + dataset.init_seq_order(epoch=ep) + assert dataset.have_seqs() + dataset.load_seqs(0, 3) + for i in range(2): + classes = dataset.get_data(i, "classes") + assert len(classes) > 0 + assert not dataset.is_less_than_num_seqs(2) def _post_process_map_seq_no_op(tdict: TensorDict, **_other) -> TensorDict: From 0d828297677e6f5659e0342aa165bfe65fce47d2 Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Wed, 24 Sep 2025 09:42:03 -0400 Subject: [PATCH 21/58] switch to pipe --- returnn/datasets/postprocessing.py | 73 ++++++++++++++++++++---------- tests/test_Dataset.py | 2 +- 2 files changed, 49 insertions(+), 26 deletions(-) diff --git a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py index 736b077a6..c03debb1f 100644 --- a/returnn/datasets/postprocessing.py +++ b/returnn/datasets/postprocessing.py @@ -12,6 +12,7 @@ import threading from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, TypeVar +from returnn.config import SubProcCopyGlobalConfigPreInitFunc from returnn.datasets.basic import DatasetSeq from returnn.datasets.util.strings import str_to_numpy_array from returnn.datasets.util.vocabulary import Vocabulary @@ -19,16 +20,12 @@ from returnn.tensor.dim import Dim from returnn.util import basic as util, better_exchook from returnn.util.multi_proc_non_daemonic_spawn import NonDaemonicSpawnContext -from returnn.config import SubProcCopyGlobalConfigPreInitFunc from .basic import Dataset, init_dataset from .cached2 import CachedDataset2 # noinspection PyProtectedMember from multiprocessing.connection import Connection as mpConnection -# noinspection PyProtectedMember -from multiprocessing.queues import Queue as mpQueue - _mp = NonDaemonicSpawnContext(process_pre_init_func=SubProcCopyGlobalConfigPreInitFunc()) @@ -516,27 +513,29 @@ def init_seq_order( assert self._buf_size > 0 assert self._num_workers > 0 - seq_queues = [_mp.Queue(maxsize=self._buf_size) for _ in range(self._num_workers)] + parent_conns, child_conns = zip(*[_mp.Pipe() for _ in range(self._num_workers)]) base_rng_seed = self._get_random_seed_for_epoch(epoch=epoch) * 683859 * self._num_workers worker_procs = 
[ _WorkerProcParent( name=f"{self.__class__.__name__} {self.name} ep {epoch}", epoch=epoch, buffer_size=self._buf_size, + index=i, map_seq=self._map_seq, map_seq_stream=self._map_seq_stream, out_tensor_dict_template=self._out_tensor_dict_template, rng_seed=(base_rng_seed + 30411 * i) % (2**32 - 1), seq_list=seq_list, - seq_queue=seq_queue, + seq_pipe=child_conn, ) - for i, seq_queue in enumerate(seq_queues) + for i, child_conn in enumerate(child_conns) ] quit_event = threading.Event() dataset_thread = threading.Thread( target=self._init_seq_order_and_distribute_seqs_to_children, kwargs={ - "child_queues": seq_queues, + "buf_size": self._buf_size, + "child_queues": parent_conns, "dataset_lock": self._dataset_lock, "epoch": epoch, "quit_event": quit_event, @@ -554,7 +553,8 @@ def init_seq_order( def _init_seq_order_and_distribute_seqs_to_children( self, *, - child_queues: Sequence[mpQueue], + buf_size: int, + child_queues: Sequence[mpConnection], dataset_lock: threading.Lock, epoch: int, quit_event: threading.Event, @@ -565,6 +565,7 @@ def _init_seq_order_and_distribute_seqs_to_children( Initialize the wrapped dataset and distribute the contained sequences to the child worker processes. """ + assert buf_size > 0 assert len(child_queues) > 0 # Lock ensures that only one thread at a time accesses the wrapped dataset. @@ -573,21 +574,27 @@ def _init_seq_order_and_distribute_seqs_to_children( with dataset_lock: self._dataset.init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order) data_iter = _iterate_dataset(self._dataset, in_tensor_dict_template=self._in_tensor_dict_template) + data_iter = enumerate(data_iter) - for seq_idx, tensor_dict in enumerate(data_iter): + for seq_idx, tensor_dict in data_iter: if quit_event.is_set(): break + worker_idx = seq_idx % len(child_queues) try: - child_queues[seq_idx % len(child_queues)].put(tensor_dict, block=True) # block for backpressure - except ValueError: + msg, _ = child_queues[worker_idx].recv() + assert msg == "get_seq" + child_queues[worker_idx].send(("seq", tensor_dict)) + except EOFError: # queue is closed, i.e. the worker process crashed for some reason -> stop break for q in child_queues: try: - q.put(None, block=True) # signal end of data - except ValueError: + q.send(("exit", None)) # signal end of data + except (BrokenPipeError, EOFError): # queue is already closed, i.e. the worker process died pass + finally: + q.close() class _MultiProcDataIter: @@ -657,30 +664,32 @@ def __init__( *, name: str, epoch: int, + index: int, buffer_size: int, map_seq: Optional[Callable], map_seq_stream: Optional[Callable], out_tensor_dict_template: TensorDict, rng_seed: int, seq_list: Optional[List[str]], - seq_queue: mpQueue, + seq_pipe: mpConnection, ): parent_conn, child_conn = _mp.Pipe() self.parent_conn = parent_conn self.worker_proc = _mp.Process( - name=f"{name} worker ep {epoch}", + name=f"{name} worker {index} ep {epoch}", target=_worker_proc_loop, args=( epoch, buffer_size, + index, map_seq, map_seq_stream, out_tensor_dict_template, rng_seed, seq_list, child_conn, - seq_queue, + seq_pipe, ), daemon=True, ) @@ -693,6 +702,10 @@ def __init__( # otherwise it would just hang. 
child_conn.close() + # seq_pipe is owned by the child process, + # and so must be closed in the parent to avoid hangs + seq_pipe.close() + def get_seq(self) -> Optional[DatasetSeq]: """get_seq""" self.parent_conn.send(("get_seq", {})) @@ -719,21 +732,23 @@ def __del__(self): def _worker_proc_loop( epoch: int, buffer_size: int, + index: int, map_seq: Optional[Callable], map_seq_stream: Optional[Callable], out_tensor_dict_template: TensorDict, rng_seed: int, seq_list: Optional[List[str]], parent_conn: mpConnection, - seq_queue: mpQueue, + feeder_conn: mpConnection, ): if sys.platform == "linux": with open("/proc/self/comm", "w") as f: - f.write(f"PP worker {epoch}") + f.write(f"PP worker {index} ep {epoch}") better_exchook.setup_all() assert isinstance(epoch, int) assert isinstance(buffer_size, int) + assert isinstance(index, int) assert buffer_size > 0 assert map_seq or map_seq_stream, "need to specify either map_seq or map_seq_stream" assert not (map_seq and map_seq_stream), "cannot set both map_seq and map_seq_stream" @@ -742,23 +757,28 @@ def _worker_proc_loop( assert isinstance(out_tensor_dict_template, TensorDict) assert isinstance(rng_seed, int) assert isinstance(parent_conn, mpConnection) - assert isinstance(seq_queue, mpQueue) + assert isinstance(feeder_conn, mpConnection) cache: deque[TensorDict] = deque() - def _iter_queue(q: mpQueue) -> Iterator[TensorDict]: + def _iter_pipe(q: mpConnection) -> Iterator[TensorDict]: + assert isinstance(q, mpConnection) + while True: try: - item = q.get(block=True) - except ValueError: + q.send(("get_seq", None)) + msg, item = q.recv() + except (BrokenPipeError, EOFError): # queue is closed break - if item is None: + assert msg in ("seq", "exit") + if msg == "exit" or item is None: break + assert isinstance(item, TensorDict) yield item data_iter = _build_mapping_iter( - _iter_queue(seq_queue), + _iter_pipe(feeder_conn), map_seq=map_seq, map_seq_stream=map_seq_stream, epoch=epoch, @@ -799,6 +819,9 @@ def _add_to_cache(): pass except EOFError: # when parent dies pass + finally: + feeder_conn.close() + parent_conn.close() class LaplaceOrdering(Callable[[Iterator[TensorDict]], Iterator[TensorDict]]): diff --git a/tests/test_Dataset.py b/tests/test_Dataset.py index 7b5ed34e3..80d20e00a 100644 --- a/tests/test_Dataset.py +++ b/tests/test_Dataset.py @@ -1285,7 +1285,7 @@ def test_MultiProcPostprocessingDataset(): } dataset = init_dataset(ds_opts) - for ep in range(1, 20 + 1): + for ep in range(1, 1 + 1): dataset.init_seq_order(epoch=ep) assert dataset.have_seqs() dataset.load_seqs(0, 3) From 3908b7b6de435cd002d39e2c1f63f4340178fb32 Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Thu, 25 Sep 2025 14:24:29 +0200 Subject: [PATCH 22/58] test across multiple epochs --- tests/test_Dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_Dataset.py b/tests/test_Dataset.py index 80d20e00a..cf39ca188 100644 --- a/tests/test_Dataset.py +++ b/tests/test_Dataset.py @@ -1285,7 +1285,7 @@ def test_MultiProcPostprocessingDataset(): } dataset = init_dataset(ds_opts) - for ep in range(1, 1 + 1): + for ep in range(1, 5 + 1): dataset.init_seq_order(epoch=ep) assert dataset.have_seqs() dataset.load_seqs(0, 3) From 39047f17d64867e994569d4e4eabe4a1612f7c76 Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Thu, 25 Sep 2025 14:24:42 +0200 Subject: [PATCH 23/58] re-add cache to feeder thread --- returnn/datasets/postprocessing.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git 
a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py index c03debb1f..57a94d096 100644 --- a/returnn/datasets/postprocessing.py +++ b/returnn/datasets/postprocessing.py @@ -8,6 +8,7 @@ from itertools import islice import numpy from numpy.random import RandomState +import select import sys import threading from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, TypeVar @@ -568,17 +569,34 @@ def _init_seq_order_and_distribute_seqs_to_children( assert buf_size > 0 assert len(child_queues) > 0 + def _any_q_ready() -> bool: + ready, _, _ = select.select(child_queues, [], []) + return len(ready) > 0 + + cache: deque[Tuple[int, TensorDict]] = deque() + # Lock ensures that only one thread at a time accesses the wrapped dataset. - # # This protects against issues while moving from one epoch to the next. with dataset_lock: self._dataset.init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order) data_iter = _iterate_dataset(self._dataset, in_tensor_dict_template=self._in_tensor_dict_template) data_iter = enumerate(data_iter) - for seq_idx, tensor_dict in data_iter: - if quit_event.is_set(): + def _add_to_cache() -> bool: + try: + cache.append(next(data_iter)) + return True + except StopIteration: + return False + + while not quit_event.is_set(): + while len(cache) < buf_size - 1 and not _any_q_ready(): + if not _add_to_cache(): + break + _add_to_cache() + if not cache: break + seq_idx, tensor_dict = cache.popleft() worker_idx = seq_idx % len(child_queues) try: msg, _ = child_queues[worker_idx].recv() @@ -587,6 +605,7 @@ def _init_seq_order_and_distribute_seqs_to_children( except EOFError: # queue is closed, i.e. the worker process crashed for some reason -> stop break + for q in child_queues: try: q.send(("exit", None)) # signal end of data From 7584a0811bf7e119f5c7f72178508f6affe62891 Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Thu, 25 Sep 2025 14:25:16 +0200 Subject: [PATCH 24/58] also catch broken pipe error --- returnn/datasets/postprocessing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py index 57a94d096..dd5ae7c2b 100644 --- a/returnn/datasets/postprocessing.py +++ b/returnn/datasets/postprocessing.py @@ -602,7 +602,7 @@ def _add_to_cache() -> bool: msg, _ = child_queues[worker_idx].recv() assert msg == "get_seq" child_queues[worker_idx].send(("seq", tensor_dict)) - except EOFError: + except (BrokenPipeError, EOFError): # queue is closed, i.e. the worker process crashed for some reason -> stop break From e0e17c9a10d1107528bb4660ddf701670db4d9e6 Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Thu, 25 Sep 2025 14:25:49 +0200 Subject: [PATCH 25/58] we can actually signal exit outside of the dataset lock --- returnn/datasets/postprocessing.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py index dd5ae7c2b..9fa98a1f7 100644 --- a/returnn/datasets/postprocessing.py +++ b/returnn/datasets/postprocessing.py @@ -606,14 +606,14 @@ def _add_to_cache() -> bool: # queue is closed, i.e. the worker process crashed for some reason -> stop break - for q in child_queues: - try: - q.send(("exit", None)) # signal end of data - except (BrokenPipeError, EOFError): - # queue is already closed, i.e. 
the worker process died - pass - finally: - q.close() + for q in child_queues: + try: + q.send(("exit", None)) # signal end of data + except (BrokenPipeError, EOFError): + # queue is already closed, i.e. the worker process died + pass + finally: + q.close() class _MultiProcDataIter: From bc6b0256bd74b63ecb74b180395945cc9956ff9d Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Thu, 25 Sep 2025 14:26:55 +0200 Subject: [PATCH 26/58] fix while loop nesting --- returnn/datasets/postprocessing.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py index 9fa98a1f7..b247ad051 100644 --- a/returnn/datasets/postprocessing.py +++ b/returnn/datasets/postprocessing.py @@ -821,10 +821,9 @@ def _add_to_cache(): try: while True: - while not parent_conn.poll(): - while len(cache) < buffer_size: - if not _add_to_cache(): - break + while len(cache) < buffer_size and not parent_conn.poll(): + if not _add_to_cache(): + break msg, kwargs = parent_conn.recv() if msg == "exit": break From 576c116aa5893f3da005edeb28c0ac77108c845e Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Thu, 25 Sep 2025 14:37:49 +0200 Subject: [PATCH 27/58] re-merge into single class --- returnn/datasets/postprocessing.py | 310 +++++++++++++---------------- tests/test_Dataset.py | 2 +- 2 files changed, 139 insertions(+), 173 deletions(-) diff --git a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py index b247ad051..6847a66e6 100644 --- a/returnn/datasets/postprocessing.py +++ b/returnn/datasets/postprocessing.py @@ -30,7 +30,7 @@ _mp = NonDaemonicSpawnContext(process_pre_init_func=SubProcCopyGlobalConfigPreInitFunc()) -__all__ = ["PostprocessingDataset", "MultiProcPostprocessingDataset", "LaplaceOrdering", "Sequential"] +__all__ = ["PostprocessingDataset", "LaplaceOrdering", "Sequential"] class PostprocessingDataset(CachedDataset2): @@ -46,8 +46,9 @@ class PostprocessingDataset(CachedDataset2): data processing work across multiple CPU cores and in turn frees the GPU from data preprocessing tasks. - Multiprocessing can either be done using :class:``MultiProcDataset`` or via the subclass - :class:``MultiProcPostprocessingDataset``. + Multiprocessing can either be done using :class:``MultiProcDataset`` or by setting + `num_workers > 0` on this class. + The latter only applies parallelism to the post-processing functions themselves, and does not duplicate the underlying dataset once per worker. This is often fast enough and has the advantage of lower memory consumption. @@ -117,6 +118,8 @@ def __init__( map_seq_stream: Optional[Callable] = None, map_outputs: Optional[Dict[str, Any]] = None, map_seq_stream_preserves_num_seqs: Optional[bool] = None, + buf_size: int = 1, + num_workers: int = 1, **kwargs, ): """ @@ -142,6 +145,11 @@ def __init__( Example: `map_outputs={"data": {"dim": 42}}` :param map_seq_stream_preserves_num_seqs: whether the function in map_seq_stream preserves the number of sequences, i.e. for every input sequence there is exactly one output sequence. + :param buf_size: Buffer size for each worker, number of seqs to prefetch. Must be > 0. + :param num_workers: If > 0, configures the number of worker processes to use for data postprocessing. + Only the postprocessing is distributed across subprocesses, + the underlying dataset is only instantiated once. + This usually has lower memory consumption than using :class:``MultiProcDataset``. 
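+ When set to 0, postprocessing happens inline.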
:param kwargs: see :class:`CachedDataset2`, :class:`Dataset` """ super().__init__(**kwargs) @@ -155,6 +163,11 @@ def __init__( if map_seq and map_seq_stream_preserves_num_seqs is not None: raise ValueError(f"{self}: map_seq_stream_preserves_num_seqs is only allowed with map_seq_stream") + if buf_size < 1: + raise ValueError(f"{self}: buf_size must be > 0, but got {buf_size}") + if num_workers < 0: + raise ValueError(f"{self}: num_workers must be >= 0, but got {num_workers}") + self._dataset_def = dataset self._map_seq = map_seq self._map_seq_stream = map_seq_stream @@ -165,6 +178,13 @@ def __init__( self._map_outputs = map_outputs self._seq_list_for_validation: Optional[List[str]] = None + self._buf_size = buf_size + # Ensure only one feeder thread at a time accesses the wrapped dataset to + # prevent race conditions while moving from one epoch to the next. + self._dataset_lock = threading.Lock() + self._multi_proc_data_iter: Optional[_MultiProcDataIter] = None # store for cleanup + self._num_workers = num_workers + self._dataset = init_dataset(self._dataset_def, parent_dataset=self) if self._map_seq_stream is None or self._map_seq_stream_preserves_num_seqs is True: # if the stream mapper is set, the num_seqs may change and the estimation is less accurate @@ -223,21 +243,61 @@ def init_seq_order( if seq_order is not None: raise ValueError("map_seq_stream is set, cannot specify custom seq_order") + if self._multi_proc_data_iter is not None: + self._multi_proc_data_iter.stop() + self._multi_proc_data_iter = None + if epoch is None and seq_list is None and seq_order is None: self._num_seqs = 0 return True - assert self._dataset is not None - self._dataset.init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order) - data_iter = _build_mapping_iter( - _iterate_dataset(self._dataset, in_tensor_dict_template=self._in_tensor_dict_template), - map_seq=self._map_seq, - map_seq_stream=self._map_seq_stream, - epoch=epoch, - out_tensor_dict_template=self._out_tensor_dict_template, - rng=RandomState(self._get_random_seed_for_epoch(epoch=epoch)), - seq_list_for_validation=seq_list, - ) + if self._num_workers > 0: + parent_conns, child_conns = zip(*[_mp.Pipe() for _ in range(self._num_workers)]) + base_rng_seed = self._get_random_seed_for_epoch(epoch=epoch) * 683859 * self._num_workers + worker_procs = [ + _WorkerProcParent( + name=f"{self.__class__.__name__} {self.name} ep {epoch}", + epoch=epoch, + buffer_size=self._buf_size, + index=i, + map_seq=self._map_seq, + map_seq_stream=self._map_seq_stream, + out_tensor_dict_template=self._out_tensor_dict_template, + rng_seed=(base_rng_seed + 30411 * i) % (2**32 - 1), + seq_list=seq_list, + seq_pipe=child_conn, + ) + for i, child_conn in enumerate(child_conns) + ] + quit_event = threading.Event() + dataset_thread = threading.Thread( + target=self._init_seq_order_and_distribute_seqs_to_children, + kwargs={ + "buf_size": self._buf_size, + "child_queues": parent_conns, + "dataset_lock": self._dataset_lock, + "epoch": epoch, + "quit_event": quit_event, + "seq_list": seq_list, + "seq_order": seq_order, + }, + name=f"{self.__class__.__name__} {self.name} dataset ep {epoch}", + ) + dataset_thread.start() + data_iter = self._multi_proc_data_iter = _MultiProcDataIter( + dataset_thread=dataset_thread, quit_event=quit_event, worker_procs=worker_procs + ) + else: + self._dataset.init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order) + data_iter = _build_mapping_iter( + _iterate_dataset(self._dataset, 
in_tensor_dict_template=self._in_tensor_dict_template), + map_seq=self._map_seq, + map_seq_stream=self._map_seq_stream, + epoch=epoch, + out_tensor_dict_template=self._out_tensor_dict_template, + rng=RandomState(self._get_random_seed_for_epoch(epoch=epoch)), + seq_list_for_validation=seq_list, + ) self._data_iter = enumerate(data_iter) self._data_iter_produced_num_seqs = 0 self._seq_list_for_validation = seq_list @@ -334,6 +394,70 @@ def _make_tensor_template_from_input(self, data_key: str) -> Tensor: sparse_dim.vocab = Vocabulary.create_vocab_from_labels(self._dataset.labels[data_key]) return Tensor(data_key, dims=dims, dtype=dtype, sparse_dim=sparse_dim) + def _init_seq_order_and_distribute_seqs_to_children( + self, + *, + buf_size: int, + child_queues: Sequence[mpConnection], + dataset_lock: threading.Lock, + epoch: int, + quit_event: threading.Event, + seq_list: Optional[List[str]] = None, + seq_order: Optional[List[int]] = None, + ): + """ + Initialize the wrapped dataset and distribute the contained sequences to the child worker processes. + """ + + assert buf_size > 0 + assert len(child_queues) > 0 + + def _any_q_ready() -> bool: + ready, _, _ = select.select(child_queues, [], []) + return len(ready) > 0 + + cache: deque[Tuple[int, TensorDict]] = deque() + + # Lock ensures that only one thread at a time accesses the wrapped dataset. + # This protects against issues while moving from one epoch to the next. + with dataset_lock: + self._dataset.init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order) + data_iter = _iterate_dataset(self._dataset, in_tensor_dict_template=self._in_tensor_dict_template) + data_iter = enumerate(data_iter) + + def _add_to_cache() -> bool: + try: + cache.append(next(data_iter)) + return True + except StopIteration: + return False + + while not quit_event.is_set(): + while len(cache) < buf_size - 1 and not _any_q_ready(): + if not _add_to_cache(): + break + _add_to_cache() + if not cache: + break + seq_idx, tensor_dict = cache.popleft() + worker_idx = seq_idx % len(child_queues) + try: + msg, _ = child_queues[worker_idx].recv() + assert msg == "get_seq" + child_queues[worker_idx].send(("seq", tensor_dict)) + except (BrokenPipeError, EOFError): + # queue is closed, i.e. the worker process crashed for some reason -> stop + break + + for q in child_queues: + try: + q.send(("exit", None)) # signal end of data + except (BrokenPipeError, EOFError): + # queue is already closed, i.e. the worker process died + pass + finally: + q.close() + def _iterate_dataset(dataset: Dataset, *, in_tensor_dict_template: TensorDict) -> Iterator[TensorDict]: """ @@ -458,164 +582,6 @@ def _apply_map_seq(tensor_dict: TensorDict) -> TensorDict: return _validate_tensor_dict_iter(data_iter) -class MultiProcPostprocessingDataset(PostprocessingDataset): - """ - Subclass of :class:`PostprocessingDataset` that parallelizes the post-processing using multiple processes. - - The underlying dataset is only instantiated once, only the post-processing functions are parallelized. - - Since it is usually the postprocessing itself and not the data loading from the underlying dataset - that is the bottleneck, it is often sufficient to only parallelize the postprocessing step. - The advantage is that this usually has lower memory consumption than using :class:``MultiProcDataset``. - - The dataset interface is the same as for :class:`PostprocessingDataset`, with two additional parameters - to configure the multi-processing behavior. 
- """ - - def __init__(self, *args, buf_size: int = 1, num_workers: int = 1, **kwargs): - """ - :param args: Same args as :class:``PostprocessingDataset``. - :param buf_size: Buffer size for each worker, number of seqs to prefetch. Must be > 0. - :param num_workers: Number of worker processes to use for data postprocessing. Must be > 0. - :param kwargs: Same args as :class:``PostprocessingDataset``. - """ - - super().__init__(*args, **kwargs) - - if buf_size < 1: - raise ValueError(f"{self}: buf_size must be > 0, but got {buf_size}") - if num_workers < 1: - raise ValueError(f"{self}: num_workers must be > 0, but got {num_workers}") - - self._buf_size = buf_size - # Ensure only one feeder thread at a time accesses the wrapped dataset to - # prevent race conditions while moving from one epoch to the next. - self._dataset_lock = threading.Lock() - self._multi_proc_data_iter: Optional[_MultiProcDataIter] = None # store for cleanup - self._num_workers = num_workers - - def init_seq_order( - self, epoch: Optional[int] = None, seq_list: Optional[List[str]] = None, seq_order: Optional[List[int]] = None - ): - """ - :param epoch: - :param seq_list: - :param seq_order: - :return: whether the order changed (True is always safe to return) - """ - super().init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order) - - if self._multi_proc_data_iter is not None: - self._multi_proc_data_iter.stop() - self._multi_proc_data_iter = None - - if self._num_seqs == 0: - return True - - assert self._buf_size > 0 - assert self._num_workers > 0 - parent_conns, child_conns = zip(*[_mp.Pipe() for _ in range(self._num_workers)]) - base_rng_seed = self._get_random_seed_for_epoch(epoch=epoch) * 683859 * self._num_workers - worker_procs = [ - _WorkerProcParent( - name=f"{self.__class__.__name__} {self.name} ep {epoch}", - epoch=epoch, - buffer_size=self._buf_size, - index=i, - map_seq=self._map_seq, - map_seq_stream=self._map_seq_stream, - out_tensor_dict_template=self._out_tensor_dict_template, - rng_seed=(base_rng_seed + 30411 * i) % (2**32 - 1), - seq_list=seq_list, - seq_pipe=child_conn, - ) - for i, child_conn in enumerate(child_conns) - ] - quit_event = threading.Event() - dataset_thread = threading.Thread( - target=self._init_seq_order_and_distribute_seqs_to_children, - kwargs={ - "buf_size": self._buf_size, - "child_queues": parent_conns, - "dataset_lock": self._dataset_lock, - "epoch": epoch, - "quit_event": quit_event, - "seq_list": seq_list, - "seq_order": seq_order, - }, - name=f"{self.__class__.__name__} {self.name} dataset ep {epoch}", - ) - dataset_thread.start() - data_iter = self._multi_proc_data_iter = _MultiProcDataIter( - dataset_thread=dataset_thread, quit_event=quit_event, worker_procs=worker_procs - ) - self._data_iter = enumerate(data_iter) - - def _init_seq_order_and_distribute_seqs_to_children( - self, - *, - buf_size: int, - child_queues: Sequence[mpConnection], - dataset_lock: threading.Lock, - epoch: int, - quit_event: threading.Event, - seq_list: Optional[List[str]] = None, - seq_order: Optional[List[int]] = None, - ): - """ - Initialize the wrapped dataset and distribute the contained sequences to the child worker processes. - """ - - assert buf_size > 0 - assert len(child_queues) > 0 - - def _any_q_ready() -> bool: - ready, _, _ = select.select(child_queues, [], []) - return len(ready) > 0 - - cache: deque[Tuple[int, TensorDict]] = deque() - - # Lock ensures that only one thread at a time accesses the wrapped dataset. 
- # This protects against issues while moving from one epoch to the next. - with dataset_lock: - self._dataset.init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order) - data_iter = _iterate_dataset(self._dataset, in_tensor_dict_template=self._in_tensor_dict_template) - data_iter = enumerate(data_iter) - - def _add_to_cache() -> bool: - try: - cache.append(next(data_iter)) - return True - except StopIteration: - return False - - while not quit_event.is_set(): - while len(cache) < buf_size - 1 and not _any_q_ready(): - if not _add_to_cache(): - break - _add_to_cache() - if not cache: - break - seq_idx, tensor_dict = cache.popleft() - worker_idx = seq_idx % len(child_queues) - try: - msg, _ = child_queues[worker_idx].recv() - assert msg == "get_seq" - child_queues[worker_idx].send(("seq", tensor_dict)) - except (BrokenPipeError, EOFError): - # queue is closed, i.e. the worker process crashed for some reason -> stop - break - - for q in child_queues: - try: - q.send(("exit", None)) # signal end of data - except (BrokenPipeError, EOFError): - # queue is already closed, i.e. the worker process died - pass - finally: - q.close() - - class _MultiProcDataIter: """ Data iter that pulls from the worker processes and manages their lifetime. diff --git a/tests/test_Dataset.py b/tests/test_Dataset.py index cf39ca188..85dfaa00e 100644 --- a/tests/test_Dataset.py +++ b/tests/test_Dataset.py @@ -1277,7 +1277,7 @@ def test_MultiProcPostprocessingDataset(): _demo_txt = "some utterance text that has a few words" with create_ogg_zip_txt_only_dataset_opts(text=_demo_txt) as sub_ds_opts: ds_opts = { - "class": "MultiProcPostprocessingDataset", + "class": "PostprocessingDataset", "dataset": sub_ds_opts, "map_seq_stream": _repeat2, "buf_size": 1, From 5c0e14a6425b2a18d611db14bed8bab96369fd87 Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Thu, 25 Sep 2025 14:42:07 +0200 Subject: [PATCH 28/58] default num_workers=0 --- returnn/datasets/postprocessing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py index 6847a66e6..1e2eb143e 100644 --- a/returnn/datasets/postprocessing.py +++ b/returnn/datasets/postprocessing.py @@ -119,7 +119,7 @@ def __init__( map_outputs: Optional[Dict[str, Any]] = None, map_seq_stream_preserves_num_seqs: Optional[bool] = None, buf_size: int = 1, - num_workers: int = 1, + num_workers: int = 0, **kwargs, ): """ From a0bc8a1f076f62129ab807772663d2e741b84f0b Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Thu, 25 Sep 2025 14:51:14 +0200 Subject: [PATCH 29/58] ensure seq order and distributor function is only called in multi-worker scenario --- returnn/datasets/postprocessing.py | 1 + 1 file changed, 1 insertion(+) diff --git a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py index 1e2eb143e..b7ffb8911 100644 --- a/returnn/datasets/postprocessing.py +++ b/returnn/datasets/postprocessing.py @@ -411,6 +411,7 @@ def _init_seq_order_and_distribute_seqs_to_children( assert buf_size > 0 assert len(child_queues) > 0 + assert self._num_workers > 0 def _any_q_ready() -> bool: ready, _, _ = select.select(child_queues, [], []) From 5309127decc33c79fe6d7aadde632c504dc27e7a Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Thu, 25 Sep 2025 15:05:51 +0200 Subject: [PATCH 30/58] fix blocking bug --- returnn/datasets/postprocessing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py index 
b7ffb8911..458a276e0 100644 --- a/returnn/datasets/postprocessing.py +++ b/returnn/datasets/postprocessing.py @@ -414,7 +414,7 @@ def _init_seq_order_and_distribute_seqs_to_children( assert self._num_workers > 0 def _any_q_ready() -> bool: - ready, _, _ = select.select(child_queues, [], []) + ready, _, _ = select.select(child_queues, [], [], 0) return len(ready) > 0 cache: deque[Tuple[int, TensorDict]] = deque() From 9420fbf6c6f1ebb7bf8a11a5c6aba9c78969de33 Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Thu, 25 Sep 2025 15:27:36 +0200 Subject: [PATCH 31/58] avoid variable shadowing --- returnn/datasets/postprocessing.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py index 458a276e0..2f73d84f9 100644 --- a/returnn/datasets/postprocessing.py +++ b/returnn/datasets/postprocessing.py @@ -753,12 +753,12 @@ def _iter_pipe(q: mpConnection) -> Iterator[TensorDict]: while True: try: q.send(("get_seq", None)) - msg, item = q.recv() + seq_msg, item = q.recv() except (BrokenPipeError, EOFError): # queue is closed break - assert msg in ("seq", "exit") - if msg == "exit" or item is None: + assert seq_msg in ("seq", "exit") + if seq_msg == "exit" or item is None: break assert isinstance(item, TensorDict) yield item From e4568bd52c0e3384a841254e52f0892a5cd1ad53 Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Fri, 26 Sep 2025 10:33:20 +0200 Subject: [PATCH 32/58] keep worker procs alive across subepochs --- returnn/datasets/postprocessing.py | 264 ++++++++++++++++------------- 1 file changed, 147 insertions(+), 117 deletions(-) diff --git a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py index 2f73d84f9..202d9e379 100644 --- a/returnn/datasets/postprocessing.py +++ b/returnn/datasets/postprocessing.py @@ -178,13 +178,6 @@ def __init__( self._map_outputs = map_outputs self._seq_list_for_validation: Optional[List[str]] = None - self._buf_size = buf_size - # Ensure only one feeder thread at a time accesses the wrapped dataset to - # prevent race conditions while moving from one epoch to the next. - self._dataset_lock = threading.Lock() - self._multi_proc_data_iter: Optional[_MultiProcDataIter] = None # store for cleanup - self._num_workers = num_workers - self._dataset = init_dataset(self._dataset_def, parent_dataset=self) if self._map_seq_stream is None or self._map_seq_stream_preserves_num_seqs is True: # if the stream mapper is set, the num_seqs may change and the estimation is less accurate @@ -192,6 +185,14 @@ def __init__( self._data_iter: Optional[Iterator[Tuple[int, TensorDict]]] = None self._data_iter_produced_num_seqs = 0 + self._buf_size = buf_size + # Ensure only one feeder thread at a time accesses the wrapped dataset to + # prevent race conditions while moving from one epoch to the next. 
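+ # (a previous epoch's feeder thread may still be running when the next epoch is initialized)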
+ self._dataset_lock = threading.Lock() + self._multi_proc_data_iter: Optional[_MultiProcDataIter] = None # store for cleanup + self._num_workers = num_workers + self._worker_procs: Optional[List[_WorkerProcParent]] = None + self._in_tensor_dict_template = TensorDict( {name: self._make_tensor_template_from_input(name) for name in self._dataset.get_data_keys()} ) @@ -252,14 +253,13 @@ def init_seq_order( return True if self._num_workers > 0: + self._lazy_init_worker_procs() + assert self._worker_procs is not None and len(self._worker_procs) == self._num_workers parent_conns, child_conns = zip(*[_mp.Pipe() for _ in range(self._num_workers)]) base_rng_seed = self._get_random_seed_for_epoch(epoch=epoch) * 683859 * self._num_workers - worker_procs = [ - _WorkerProcParent( - name=f"{self.__class__.__name__} {self.name} ep {epoch}", + for i, (worker, child_conn) in enumerate(zip(self._worker_procs, child_conns)): + worker.init_seq_order( epoch=epoch, - buffer_size=self._buf_size, - index=i, map_seq=self._map_seq, map_seq_stream=self._map_seq_stream, out_tensor_dict_template=self._out_tensor_dict_template, @@ -267,25 +267,8 @@ def init_seq_order( seq_list=seq_list, seq_pipe=child_conn, ) - for i, child_conn in enumerate(child_conns) - ] - quit_event = threading.Event() - dataset_thread = threading.Thread( - target=self._init_seq_order_and_distribute_seqs_to_children, - kwargs={ - "buf_size": self._buf_size, - "child_queues": parent_conns, - "dataset_lock": self._dataset_lock, - "epoch": epoch, - "quit_event": quit_event, - "seq_list": seq_list, - "seq_order": seq_order, - }, - name=f"{self.__class__.__name__} {self.name} dataset ep {epoch}", - ) - dataset_thread.start() - data_iter = self._multi_proc_data_iter = _MultiProcDataIter( - dataset_thread=dataset_thread, quit_event=quit_event, worker_procs=worker_procs + data_iter = self._multi_proc_data_iter = self._init_multi_proc_data_iter( + epoch=epoch, parent_conns=parent_conns, seq_list=seq_list, seq_order=seq_order ) else: self._dataset.init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order) @@ -310,6 +293,20 @@ def init_seq_order( pass # some datasets don't know their num_seqs return True + def __del__(self): + if not self._worker_procs: + return + got_exception = False + for parent in self._worker_procs: + try: + parent.exit(join=False) + except Exception: + got_exception = True + if got_exception: + return + for parent in self._worker_procs: + util.try_run(parent.worker_proc.join) + def get_current_seq_order(self): """:return: current seq order of wrapped dataset, if map_seq_stream is not used""" if self._map_seq_stream is not None: @@ -346,6 +343,18 @@ def supports_sharding(self) -> bool: assert self._dataset is not None return self._dataset.supports_sharding() + def finish_epoch(self, *, free_resources=False): + super().finish_epoch(free_resources=free_resources) + if not free_resources: + return + if self._multi_proc_data_iter is not None: + self._multi_proc_data_iter.stop(join=True) + self._multi_proc_data_iter = None + if self._worker_procs is not None: + for wp in self._worker_procs: + wp.exit(join=True) + self._worker_procs = None + def _collect_single_seq(self, seq_idx: int) -> Optional[DatasetSeq]: while True: try: @@ -394,12 +403,48 @@ def _make_tensor_template_from_input(self, data_key: str) -> Tensor: sparse_dim.vocab = Vocabulary.create_vocab_from_labels(self._dataset.labels[data_key]) return Tensor(data_key, dims=dims, dtype=dtype, sparse_dim=sparse_dim) + def _lazy_init_worker_procs(self): + if self._worker_procs is 
not None: + return + self._worker_procs = [ + _WorkerProcParent(name=f"{self.__class__.__name__} {self.name} worker", buffer_size=self._buf_size, index=i) + for i in range(self._num_workers) + ] + + def _init_multi_proc_data_iter( + self, + *, + epoch: int, + parent_conns: Sequence[mpConnection], + seq_list: Optional[List[str]] = None, + seq_order: Optional[List[int]] = None, + ) -> _MultiProcDataIter: + assert len(parent_conns) == self._num_workers + + quit_event = threading.Event() + dataset_thread = threading.Thread( + target=self._init_seq_order_and_distribute_seqs_to_children, + kwargs={ + "child_queues": parent_conns, + "epoch": epoch, + "quit_event": quit_event, + "seq_list": seq_list, + "seq_order": seq_order, + }, + name=f"{self.__class__.__name__} {self.name} dataset ep {epoch}", + ) + # parent_conns are not closed here, because they move to a different thread, not process, + # and so they must remain open. + dataset_thread.start() + data_iter = _MultiProcDataIter( + dataset_thread=dataset_thread, quit_event=quit_event, worker_procs=self._worker_procs + ) + return data_iter + def _init_seq_order_and_distribute_seqs_to_children( self, *, - buf_size: int, child_queues: Sequence[mpConnection], - dataset_lock: threading.Lock, epoch: int, quit_event: threading.Event, seq_list: Optional[List[str]] = None, @@ -409,7 +454,7 @@ def _init_seq_order_and_distribute_seqs_to_children( Initialize the wrapped dataset and distribute the contained sequences to the child worker processes. """ - assert buf_size > 0 + assert self._buf_size > 0 assert len(child_queues) > 0 assert self._num_workers > 0 @@ -421,7 +466,7 @@ def _any_q_ready() -> bool: # Lock ensures that only one thread at a time accesses the wrapped dataset. # This protects against issues while moving from one epoch to the next. 
- with dataset_lock: + with self._dataset_lock: self._dataset.init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order) data_iter = _iterate_dataset(self._dataset, in_tensor_dict_template=self._in_tensor_dict_template) data_iter = enumerate(data_iter) @@ -434,7 +479,7 @@ def _add_to_cache() -> bool: return False while not quit_event.is_set(): - while len(cache) < buf_size - 1 and not _any_q_ready(): + while len(cache) < self._buf_size - 1 and not _any_q_ready(): if not _add_to_cache(): break _add_to_cache() @@ -589,11 +634,7 @@ class _MultiProcDataIter: """ def __init__( - self, - *, - dataset_thread: threading.Thread, - quit_event: threading.Event, - worker_procs: List[_WorkerProcParent], + self, *, dataset_thread: threading.Thread, quit_event: threading.Event, worker_procs: List[_WorkerProcParent] ): self.dataset_thread = dataset_thread self.quit_event = quit_event @@ -631,8 +672,6 @@ def stop(self, *, join=True): if self.quit_event.is_set(): return self.quit_event.set() - for wp in self.worker_procs: - wp.exit(join=join) if join: util.try_run(self.dataset_thread.join) @@ -645,38 +684,14 @@ def __del__(self): class _WorkerProcParent: - def __init__( - self, - *, - name: str, - epoch: int, - index: int, - buffer_size: int, - map_seq: Optional[Callable], - map_seq_stream: Optional[Callable], - out_tensor_dict_template: TensorDict, - rng_seed: int, - seq_list: Optional[List[str]], - seq_pipe: mpConnection, - ): + def __init__(self, *, name: str, index: int, buffer_size: int): parent_conn, child_conn = _mp.Pipe() self.parent_conn = parent_conn self.worker_proc = _mp.Process( - name=f"{name} worker {index} ep {epoch}", + name=f"{name} worker {index}", target=_worker_proc_loop, - args=( - epoch, - buffer_size, - index, - map_seq, - map_seq_stream, - out_tensor_dict_template, - rng_seed, - seq_list, - child_conn, - seq_pipe, - ), + args=(buffer_size, index, child_conn), daemon=True, ) self.worker_proc.start() @@ -688,6 +703,30 @@ def __init__( # otherwise it would just hang. 
child_conn.close() + def init_seq_order( + self, + *, + epoch: int, + map_seq: Optional[Callable], + map_seq_stream: Optional[Callable], + out_tensor_dict_template: TensorDict, + rng_seed: int, + seq_list: Optional[List[str]], + seq_pipe: mpConnection, + ): + """init_seq_order""" + args = { + "epoch": epoch, + "map_seq": map_seq, + "map_seq_stream": map_seq_stream, + "out_tensor_dict_template": out_tensor_dict_template, + "rng_seed": rng_seed, + "seq_list": seq_list, + "seq_pipe": seq_pipe, + } + self.parent_conn.send(("init_seq_order", args)) + msg, _ = self.parent_conn.recv() + assert msg == "init_seq_order" # seq_pipe is owned by the child process, # and so must be closed in the parent to avoid hangs seq_pipe.close() @@ -715,38 +754,33 @@ def __del__(self): util.try_run(self.worker_proc.join) -def _worker_proc_loop( - epoch: int, - buffer_size: int, - index: int, - map_seq: Optional[Callable], - map_seq_stream: Optional[Callable], - out_tensor_dict_template: TensorDict, - rng_seed: int, - seq_list: Optional[List[str]], - parent_conn: mpConnection, - feeder_conn: mpConnection, -): +def _worker_proc_loop(buffer_size: int, index: int, parent_conn: mpConnection): if sys.platform == "linux": with open("/proc/self/comm", "w") as f: - f.write(f"PP worker {index} ep {epoch}") + f.write(f"PP worker {index}") better_exchook.setup_all() - assert isinstance(epoch, int) - assert isinstance(buffer_size, int) + assert isinstance(buffer_size, int) and buffer_size > 0 assert isinstance(index, int) - assert buffer_size > 0 - assert map_seq or map_seq_stream, "need to specify either map_seq or map_seq_stream" - assert not (map_seq and map_seq_stream), "cannot set both map_seq and map_seq_stream" - assert map_seq is None or isinstance(map_seq, Callable) - assert map_seq_stream is None or isinstance(map_seq_stream, Callable) - assert isinstance(out_tensor_dict_template, TensorDict) - assert isinstance(rng_seed, int) assert isinstance(parent_conn, mpConnection) - assert isinstance(feeder_conn, mpConnection) cache: deque[TensorDict] = deque() + data_iter: Optional[Iterator[TensorDict]] = None + feeder_conn: Optional[mpConnection] = None + + def _add_to_cache(): + nonlocal data_iter + if data_iter is None: + return False + try: + seq = next(data_iter) + except StopIteration: + data_iter = None + return False + cache.append(seq) + return True + def _iter_pipe(q: mpConnection) -> Iterator[TensorDict]: assert isinstance(q, mpConnection) @@ -763,29 +797,6 @@ def _iter_pipe(q: mpConnection) -> Iterator[TensorDict]: assert isinstance(item, TensorDict) yield item - data_iter = _build_mapping_iter( - _iter_pipe(feeder_conn), - map_seq=map_seq, - map_seq_stream=map_seq_stream, - epoch=epoch, - out_tensor_dict_template=out_tensor_dict_template, - rng=RandomState(rng_seed), - seq_list_for_validation=seq_list, - ) - assert isinstance(data_iter, Iterator) - - def _add_to_cache(): - nonlocal data_iter - if data_iter is None: - return False - try: - seq = next(data_iter) - except StopIteration: - data_iter = None - return False - cache.append(seq) - return True - try: while True: while len(cache) < buffer_size and not parent_conn.poll(): @@ -798,6 +809,24 @@ def _add_to_cache(): if not cache: _add_to_cache() parent_conn.send(("seq", cache.popleft() if cache else None)) + elif msg == "init_seq_order": + epoch = kwargs["epoch"] + if sys.platform == "linux": + with open("/proc/self/comm", "w") as f: + f.write(f"PP worker {index} ep {epoch}") + feeder_conn = kwargs["seq_pipe"] + data_iter = _build_mapping_iter( + 
_iter_pipe(feeder_conn), + map_seq=kwargs["map_seq"], + map_seq_stream=kwargs["map_seq_stream"], + epoch=epoch, + out_tensor_dict_template=kwargs["out_tensor_dict_template"], + rng=RandomState(kwargs["rng_seed"]), + seq_list_for_validation=kwargs["seq_list"], + ) + assert isinstance(data_iter, Iterator) + cache.clear() + parent_conn.send(("init_seq_order", None)) else: raise Exception(f"unknown msg {msg!r}") except KeyboardInterrupt: # when parent dies @@ -805,7 +834,8 @@ def _add_to_cache(): except EOFError: # when parent dies pass finally: - feeder_conn.close() + if feeder_conn is not None: + feeder_conn.close() parent_conn.close() From a81940ee9890fdc83a345b0cb855672bbf8e42b7 Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Fri, 26 Sep 2025 10:42:15 +0200 Subject: [PATCH 33/58] extend tests --- tests/test_Dataset.py | 47 ++++++++++++++++++++++++++----------------- 1 file changed, 29 insertions(+), 18 deletions(-) diff --git a/tests/test_Dataset.py b/tests/test_Dataset.py index 85dfaa00e..b4c7980a0 100644 --- a/tests/test_Dataset.py +++ b/tests/test_Dataset.py @@ -1274,25 +1274,36 @@ def _repeat2(input_iter: Iterator[TensorDict], **kwargs) -> Iterator[TensorDict] def test_MultiProcPostprocessingDataset(): - _demo_txt = "some utterance text that has a few words" - with create_ogg_zip_txt_only_dataset_opts(text=_demo_txt) as sub_ds_opts: - ds_opts = { - "class": "PostprocessingDataset", - "dataset": sub_ds_opts, - "map_seq_stream": _repeat2, - "buf_size": 1, - "num_workers": 2, - } - dataset = init_dataset(ds_opts) + from test_HDFDataset import generate_hdf_from_dummy + + num_hdfs = 20 + num_seqs = 23 + total_num_seqs = num_hdfs * num_seqs + total_num_seqs_pp = 2 * total_num_seqs + # Create a few HDF files such that we can easily verify the data later. 
+ hdf_files = [generate_hdf_from_dummy() for _ in range(num_hdfs)] - for ep in range(1, 5 + 1): - dataset.init_seq_order(epoch=ep) - assert dataset.have_seqs() - dataset.load_seqs(0, 3) - for i in range(2): - classes = dataset.get_data(i, "classes") - assert len(classes) > 0 - assert not dataset.is_less_than_num_seqs(2) + ds_opts = { + "class": "PostprocessingDataset", + "dataset": { + "class": "HDFDataset", + "files": hdf_files, + "seq_ordering": "default", + }, + "map_seq_stream": _repeat2, + "buf_size": 1, + "num_workers": 2, + } + dataset = init_dataset(ds_opts) + + for ep in range(1, 5 + 1): + dataset.init_seq_order(epoch=ep) + assert dataset.have_seqs() + dataset.load_seqs(0, total_num_seqs_pp + 1) + for i in range(total_num_seqs_pp): + classes = dataset.get_data(i, "classes") + assert len(classes) > 0 + assert not dataset.is_less_than_num_seqs(total_num_seqs_pp + 1) def _post_process_map_seq_no_op(tdict: TensorDict, **_other) -> TensorDict: From 0dcc6eefeccc9fea35e1c4da98d8b96e9447d03a Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Fri, 26 Sep 2025 11:29:35 +0200 Subject: [PATCH 34/58] also clean up background thread in __del__ --- returnn/datasets/postprocessing.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py index 202d9e379..8fe5e0a8e 100644 --- a/returnn/datasets/postprocessing.py +++ b/returnn/datasets/postprocessing.py @@ -294,6 +294,9 @@ def init_seq_order( return True def __del__(self): + if self._multi_proc_data_iter is not None: + self._multi_proc_data_iter.stop(join=True) + self._multi_proc_data_iter = None if not self._worker_procs: return got_exception = False From 19689ab9074cf992065297fc183a45c1f637c8d8 Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Fri, 26 Sep 2025 11:29:40 +0200 Subject: [PATCH 35/58] fix lint errors --- returnn/datasets/postprocessing.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py index 8fe5e0a8e..d90fedcb6 100644 --- a/returnn/datasets/postprocessing.py +++ b/returnn/datasets/postprocessing.py @@ -301,6 +301,7 @@ def __del__(self): return got_exception = False for parent in self._worker_procs: + # noinspection PyBroadException try: parent.exit(join=False) except Exception: @@ -347,6 +348,7 @@ def supports_sharding(self) -> bool: return self._dataset.supports_sharding() def finish_epoch(self, *, free_resources=False): + """finish_epoch""" super().finish_epoch(free_resources=free_resources) if not free_resources: return From f335fb8b3eab7c25a5a569ad2020d231521f8f3b Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Fri, 26 Sep 2025 13:05:53 +0200 Subject: [PATCH 36/58] fix deadlock --- returnn/datasets/postprocessing.py | 44 ++++++++++++++++++++---------- 1 file changed, 29 insertions(+), 15 deletions(-) diff --git a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py index d90fedcb6..7952a4894 100644 --- a/returnn/datasets/postprocessing.py +++ b/returnn/datasets/postprocessing.py @@ -463,11 +463,22 @@ def _init_seq_order_and_distribute_seqs_to_children( assert len(child_queues) > 0 assert self._num_workers > 0 + caches: List[deque[TensorDict]] = [deque() for _ in range(len(child_queues))] + def _any_q_ready() -> bool: ready, _, _ = select.select(child_queues, [], [], 0) return len(ready) > 0 - cache: deque[Tuple[int, TensorDict]] = deque() + def _distrib_seq(): + ready_conns, _, _ = select.select(child_queues, [], []) + assert len(child_queues) == len(caches) + for 
queue, cache in zip(child_queues, caches): + if queue not in ready_conns: + continue + msg, _ = queue.recv() + assert msg == "get_seq" + tensor_dict = cache.popleft() if len(cache) > 0 else None + queue.send(("seq", tensor_dict)) # Lock ensures that only one thread at a time accesses the wrapped dataset. # This protects against issues while moving from one epoch to the next. @@ -478,36 +489,39 @@ def _any_q_ready() -> bool: def _add_to_cache() -> bool: try: - cache.append(next(data_iter)) + idx, tensor_dict = next(data_iter) + caches[idx % len(caches)].append(tensor_dict) return True except StopIteration: return False + def _fill_cache_to_min() -> bool: + while any(len(c) == 0 for c in caches): + if not _add_to_cache(): + return False + return True + while not quit_event.is_set(): - while len(cache) < self._buf_size - 1 and not _any_q_ready(): + _fill_cache_to_min() + while sum(len(cache) for cache in caches) < self._buf_size and not _any_q_ready(): if not _add_to_cache(): break - _add_to_cache() - if not cache: + if all(len(c) == 0 for c in caches): break - seq_idx, tensor_dict = cache.popleft() - worker_idx = seq_idx % len(child_queues) try: - msg, _ = child_queues[worker_idx].recv() - assert msg == "get_seq" - child_queues[worker_idx].send(("seq", tensor_dict)) + _distrib_seq() except (BrokenPipeError, EOFError): # queue is closed, i.e. the worker process crashed for some reason -> stop break - for q in child_queues: + for queue in child_queues: try: - q.send(("exit", None)) # signal end of data + queue.send(("seq", None)) except (BrokenPipeError, EOFError): # queue is already closed, i.e. the worker process died pass finally: - q.close() + queue.close() def _iterate_dataset(dataset: Dataset, *, in_tensor_dict_template: TensorDict) -> Iterator[TensorDict]: @@ -796,8 +810,8 @@ def _iter_pipe(q: mpConnection) -> Iterator[TensorDict]: except (BrokenPipeError, EOFError): # queue is closed break - assert seq_msg in ("seq", "exit") - if seq_msg == "exit" or item is None: + assert seq_msg == "seq" + if item is None: break assert isinstance(item, TensorDict) yield item From 3e5cff1dbebd2a7e06486c4951531065187ada2f Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Fri, 26 Sep 2025 13:19:55 +0200 Subject: [PATCH 37/58] enforce that complete_frac is monotonic --- returnn/datasets/postprocessing.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py index 7952a4894..1d5a1c2dc 100644 --- a/returnn/datasets/postprocessing.py +++ b/returnn/datasets/postprocessing.py @@ -660,6 +660,7 @@ def __init__( assert len(worker_procs) > 0 self.worker_procs = worker_procs + self._complete_frac = 0.0 # need to force monotonicity self._workers_exhausted = [False for _ in range(len(worker_procs))] self._worker_idx = 0 @@ -679,7 +680,7 @@ def __next__(self) -> Optional[TensorDict]: continue seq = self.worker_procs[worker_idx].get_seq() if seq is not None: - return seq + return self._ensure_complete_frac_monotonic(seq) self._workers_exhausted[worker_idx] = True # when we reach this point, all workers are exhausted and we stop @@ -694,6 +695,16 @@ def stop(self, *, join=True): if join: util.try_run(self.dataset_thread.join) + def _ensure_complete_frac_monotonic(self, seq: TensorDict) -> TensorDict: + """Enforces monotonicity in complete_frac across all workers.""" + if "complete_frac" not in seq.data: + return seq + complete_frac = float(seq.data["complete_frac"].raw_tensor) + assert 0.0 <= complete_frac <= 1.0, 
f"complete_frac must be in [0, 1], but got {complete_frac}" + self._complete_frac = max(complete_frac, self._complete_frac) + seq.data["complete_frac"].raw_tensor = numpy.array(self._complete_frac, dtype=numpy.float32) + return seq + def __del__(self): # noinspection PyBroadException try: From 34f829a3b54522242b61b0a67bd0a02eb6ef7643 Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Fri, 26 Sep 2025 13:20:01 +0200 Subject: [PATCH 38/58] fix wrong typing --- returnn/datasets/postprocessing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py index 1d5a1c2dc..dfa829b7a 100644 --- a/returnn/datasets/postprocessing.py +++ b/returnn/datasets/postprocessing.py @@ -761,7 +761,7 @@ def init_seq_order( # and so must be closed in the parent to avoid hangs seq_pipe.close() - def get_seq(self) -> Optional[DatasetSeq]: + def get_seq(self) -> Optional[TensorDict]: """get_seq""" self.parent_conn.send(("get_seq", {})) msg, seq = self.parent_conn.recv() From 0f8bbbf1594ea56cd5f8ee5c12e3ce1735edb4ee Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Fri, 26 Sep 2025 13:51:17 +0200 Subject: [PATCH 39/58] fix variable shadowing lint --- returnn/datasets/postprocessing.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py index dfa829b7a..0ce2dde19 100644 --- a/returnn/datasets/postprocessing.py +++ b/returnn/datasets/postprocessing.py @@ -472,13 +472,13 @@ def _any_q_ready() -> bool: def _distrib_seq(): ready_conns, _, _ = select.select(child_queues, [], []) assert len(child_queues) == len(caches) - for queue, cache in zip(child_queues, caches): - if queue not in ready_conns: + for child_queue, cache in zip(child_queues, caches): + if child_queue not in ready_conns: continue - msg, _ = queue.recv() + msg, _ = child_queue.recv() assert msg == "get_seq" tensor_dict = cache.popleft() if len(cache) > 0 else None - queue.send(("seq", tensor_dict)) + child_queue.send(("seq", tensor_dict)) # Lock ensures that only one thread at a time accesses the wrapped dataset. # This protects against issues while moving from one epoch to the next. 
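The guard added in PATCH 37 boils down to a running maximum: with more than one worker, interleaved seqs can report complete_frac out of order, and clamping to the largest value seen so far restores monotonic progress. A minimal self-contained sketch of that idea (illustrative names, not the patch code):

    class MonotonicCompleteFrac:
        """Running maximum over complete_frac values that may arrive out of order."""

        def __init__(self):
            self._complete_frac = 0.0

        def __call__(self, complete_frac: float) -> float:
            assert 0.0 <= complete_frac <= 1.0
            # never report less progress than what was already reported
            self._complete_frac = max(complete_frac, self._complete_frac)
            return self._complete_frac

    mono = MonotonicCompleteFrac()
    assert [mono(v) for v in (0.1, 0.3, 0.2, 0.4)] == [0.1, 0.3, 0.3, 0.4]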

From f6a44ac665846db1b5393ae5923868212b3748a4 Mon Sep 17 00:00:00 2001
From: Moritz Gunz
Date: Fri, 26 Sep 2025 13:53:05 +0200
Subject: [PATCH 40/58] periodically check quit_event

---
 returnn/datasets/postprocessing.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py
index 0ce2dde19..b179ef78c 100644
--- a/returnn/datasets/postprocessing.py
+++ b/returnn/datasets/postprocessing.py
@@ -469,8 +469,9 @@ def _any_q_ready() -> bool:
             ready, _, _ = select.select(child_queues, [], [], 0)
             return len(ready) > 0
 
-        def _distrib_seq():
-            ready_conns, _, _ = select.select(child_queues, [], [])
+        def _distrib_seq(*, timeout=0.1):
+            assert timeout > 0.0, "must not block indefinitely to check quit_event periodically"
+            ready_conns, _, _ = select.select(child_queues, [], [], timeout)

From 05929ef2a6fc446b8c0b2fac78d0b98db3cb36d3 Mon Sep 17 00:00:00 2001
From: Moritz Gunz
Date: Fri, 26 Sep 2025 13:54:10 +0200
Subject: [PATCH 41/58] assertion msg has nothing to do with what it checks

---
 returnn/datasets/postprocessing.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py
index b179ef78c..a34f76598 100644
--- a/returnn/datasets/postprocessing.py
+++ b/returnn/datasets/postprocessing.py
@@ -470,7 +470,7 @@ def _any_q_ready() -> bool:
             return len(ready) > 0
 
         def _distrib_seq(*, timeout=0.1):
-            assert timeout > 0.0, "must not block indefinitely to check quit_event periodically"
+            assert timeout >= 0.0
             ready_conns, _, _ = select.select(child_queues, [], [], timeout)

From f2c82823aeaca99c61e975786430a28fdff6a762 Mon Sep 17 00:00:00 2001
From: Moritz Gunz
Date: Fri, 26 Sep 2025 13:54:38 +0200
Subject: [PATCH 42/58] naming

---
 returnn/datasets/postprocessing.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py
index a34f76598..9c60fe633 100644
--- a/returnn/datasets/postprocessing.py
+++ b/returnn/datasets/postprocessing.py
@@ -469,7 +469,7 @@ def _any_q_ready() -> bool:
             ready, _, _ = select.select(child_queues, [], [], 0)
             return len(ready) > 0
 
-        def _distrib_seq(*, timeout=0.1):
+        def _maybe_distrib_seq(*, timeout=0.1):
             assert timeout >= 0.0
             ready_conns, _, _ = select.select(child_queues, [], [], timeout)
@@ -510,7 +510,7 @@ def _fill_cache_to_min() -> bool:
             if all(len(c) == 0 for c in caches):
                 break
             try:
-                _distrib_seq()
+                _maybe_distrib_seq()
             except (BrokenPipeError, EOFError):
                 # queue is closed, i.e. the worker process crashed for some reason -> stop
                 break
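PATCH 40-42 converge on a simple pattern: poll the worker pipes with a bounded select() so the feeder can re-check its quit event between waits. A POSIX-only sketch of that pattern (select.select() on pipe connections only works on Unix; all names here are illustrative, not taken from the patches):

    import select
    import threading
    from multiprocessing import Pipe

    quit_event = threading.Event()
    parent_conn, child_conn = Pipe()
    child_conn.send(("get_seq", None))  # simulate one pending worker request

    while not quit_event.is_set():
        # bounded wait: without the timeout, select would block forever
        # and quit_event would never be re-checked
        ready, _, _ = select.select([parent_conn], [], [], 0.1)
        for conn in ready:
            msg, _ = conn.recv()
            assert msg == "get_seq"
            conn.send(("seq", None))  # a real feeder would send a TensorDict here
            quit_event.set()  # end the demo after one round trip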

From 8e6ccdc33c07969817b60b62de90f388fbd78ac9 Mon Sep 17 00:00:00 2001
From: Moritz Gunz
Date: Fri, 26 Sep 2025 13:56:01 +0200
Subject: [PATCH 43/58] simplify

---
 returnn/datasets/postprocessing.py | 11 +++--------
 1 file changed, 3 insertions(+), 8 deletions(-)

diff --git a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py
index 9c60fe633..a28df22a8 100644
--- a/returnn/datasets/postprocessing.py
+++ b/returnn/datasets/postprocessing.py
@@ -496,15 +496,10 @@ def _add_to_cache() -> bool:
             except StopIteration:
                 return False
 
-        def _fill_cache_to_min() -> bool:
-            while any(len(c) == 0 for c in caches):
-                if not _add_to_cache():
-                    return False
-            return True
-
         while not quit_event.is_set():
-            _fill_cache_to_min()
-            while sum(len(cache) for cache in caches) < self._buf_size and not _any_q_ready():
+            while any(len(cache) == 0 for cache in caches) or (
+                sum(len(cache) for cache in caches) < self._buf_size and not _any_q_ready()
+            ):
                 if not _add_to_cache():
                     break

From 093464a21e3faa32140a4e761ee54a6e713ca8d8 Mon Sep 17 00:00:00 2001
From: Moritz Gunz
Date: Fri, 26 Sep 2025 13:57:01 +0200
Subject: [PATCH 44/58] comment

---
 returnn/datasets/postprocessing.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py
index a28df22a8..4f2706fe8 100644
--- a/returnn/datasets/postprocessing.py
+++ b/returnn/datasets/postprocessing.py
@@ -471,6 +471,7 @@ def _any_q_ready() -> bool:
 
         def _maybe_distrib_seq(*, timeout=0.1):
             assert timeout >= 0.0
+            # do not block indefinitely, to periodically check the quit_event
             ready_conns, _, _ = select.select(child_queues, [], [], timeout)
@@ -497,6 +498,8 @@ def _add_to_cache() -> bool:
                 return False
 
         while not quit_event.is_set():
+            # fetch seqs until all caches have at least one seq,
+            # if no child is waiting for seqs, also fill until buf_size
            while any(len(cache) == 0 for cache in caches) or (
                sum(len(cache) for cache in caches) < self._buf_size and not _any_q_ready()
            ):
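The refill policy after PATCH 43/44, distilled: keep pulling from the wrapped dataset until every per-worker cache holds at least one seq, and opportunistically fill up to buf_size in total while no worker is asking for data. A self-contained sketch with the dataset and the pipe polling stubbed out (all names are illustrative):

    from collections import deque

    buf_size = 4
    caches = [deque(), deque()]
    items = iter(range(10))  # stands in for the wrapped dataset iterator

    def _any_q_ready() -> bool:
        return False  # stub: pretend no worker is currently waiting

    def _add_to_cache() -> bool:
        try:
            item = next(items)
        except StopIteration:
            return False
        caches[item % len(caches)].append(item)  # round-robin by seq index
        return True

    while any(len(c) == 0 for c in caches) or (
        sum(len(c) for c in caches) < buf_size and not _any_q_ready()
    ):
        if not _add_to_cache():
            break

    assert sum(len(c) for c in caches) == buf_size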

From a26d7e82c7c421964d1a25d2ea72f539bf53a988 Mon Sep 17 00:00:00 2001
From: Moritz Gunz
Date: Fri, 26 Sep 2025 14:45:12 +0200
Subject: [PATCH 45/58] naming

---
 returnn/datasets/postprocessing.py | 28 ++++++++++++++--------------
 1 file changed, 14 insertions(+), 14 deletions(-)

diff --git a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py
index 4f2706fe8..bb6ed8b9d 100644
--- a/returnn/datasets/postprocessing.py
+++ b/returnn/datasets/postprocessing.py
@@ -268,7 +268,7 @@ def init_seq_order(
                     seq_pipe=child_conn,
                 )
             data_iter = self._multi_proc_data_iter = self._init_multi_proc_data_iter(
-                epoch=epoch, parent_conns=parent_conns, seq_list=seq_list, seq_order=seq_order
+                epoch=epoch, feeder_to_worker_conns=parent_conns, seq_list=seq_list, seq_order=seq_order
             )
         else:
             self._dataset.init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order)
@@ -420,21 +420,21 @@ def _init_multi_proc_data_iter(
         self,
         *,
         epoch: int,
-        parent_conns: Sequence[mpConnection],
+        feeder_to_worker_conns: Sequence[mpConnection],
         seq_list: Optional[List[str]] = None,
         seq_order: Optional[List[int]] = None,
     ) -> _MultiProcDataIter:
-        assert len(parent_conns) == self._num_workers
+        assert len(feeder_to_worker_conns) == self._num_workers
 
        quit_event = threading.Event()
        dataset_thread = threading.Thread(
             target=self._init_seq_order_and_distribute_seqs_to_children,
             kwargs={
-                "child_queues": parent_conns,
                 "epoch": epoch,
                 "quit_event": quit_event,
                 "seq_list": seq_list,
                 "seq_order": seq_order,
+                "worker_conns": feeder_to_worker_conns,
             },
             name=f"{self.__class__.__name__} {self.name} dataset ep {epoch}",
         )
@@ -449,32 +449,32 @@ def _init_multi_proc_data_iter(
     def _init_seq_order_and_distribute_seqs_to_children(
         self,
         *,
-        child_queues: Sequence[mpConnection],
         epoch: int,
         quit_event: threading.Event,
         seq_list: Optional[List[str]] = None,
         seq_order: Optional[List[int]] = None,
+        worker_conns: Sequence[mpConnection],
     ):
         """
         Initialize the wrapped dataset and
         distribute the contained sequences to the child worker processes.
         """
         assert self._buf_size > 0
-        assert len(child_queues) > 0
+        assert len(worker_conns) > 0
         assert self._num_workers > 0
 
-        caches: List[deque[TensorDict]] = [deque() for _ in range(len(child_queues))]
+        caches: List[deque[TensorDict]] = [deque() for _ in range(len(worker_conns))]
 
-        def _any_q_ready() -> bool:
-            ready, _, _ = select.select(child_queues, [], [], 0)
+        def _any_conn_ready() -> bool:
+            ready, _, _ = select.select(worker_conns, [], [], 0)
             return len(ready) > 0
 
         def _maybe_distrib_seq(*, timeout=0.1):
             assert timeout >= 0.0
             # do not block indefinitely, to periodically check the quit_event
-            ready_conns, _, _ = select.select(child_queues, [], [], timeout)
-            assert len(child_queues) == len(caches)
-            for child_queue, cache in zip(child_queues, caches):
+            ready_conns, _, _ = select.select(worker_conns, [], [], timeout)
+            assert len(worker_conns) == len(caches)
+            for child_queue, cache in zip(worker_conns, caches):
                 if child_queue not in ready_conns:
                     continue
                 msg, _ = child_queue.recv()
@@ -501,7 +501,7 @@ def _add_to_cache() -> bool:
             # fetch seqs until all caches have at least one seq,
             # if no child is waiting for seqs, also fill until buf_size
             while any(len(cache) == 0 for cache in caches) or (
-                sum(len(cache) for cache in caches) < self._buf_size and not _any_q_ready()
+                sum(len(cache) for cache in caches) < self._buf_size and not _any_conn_ready()
             ):
                 if not _add_to_cache():
                     break
@@ -513,7 +513,7 @@ def _add_to_cache() -> bool:
 
-        for queue in child_queues:
+        for queue in worker_conns:
             try:
                 queue.send(("seq", None))
             except (BrokenPipeError, EOFError):

From 6cd0e65998c71c1f40496db1deb0eb1bf81bf720 Mon Sep 17 00:00:00 2001
From: Moritz Gunz
Date: Fri, 26 Sep 2025 14:48:16 +0200
Subject: [PATCH 46/58] docs

---
 returnn/datasets/postprocessing.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py
index bb6ed8b9d..93ab7e194 100644
--- a/returnn/datasets/postprocessing.py
+++ b/returnn/datasets/postprocessing.py
@@ -648,7 +648,11 @@ def _apply_map_seq(tensor_dict: TensorDict) -> TensorDict:
 
 class _MultiProcDataIter:
     """
-    Data iter that pulls from the worker processes and manages their lifetime.
+    Data iter that pulls from the worker processes in a well-defined order and
+    manages the lifetime of the feeder thread.
+
+    Also ensures monotonicity of complete_frac, which would otherwise no longer
+    be guaranteed if there is more than one worker.
""" def __init__( From 64a41a547700fbc6ea45ab06096eb52bb73e4be9 Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Fri, 26 Sep 2025 14:48:48 +0200 Subject: [PATCH 47/58] more docs --- returnn/datasets/postprocessing.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py index 93ab7e194..191981029 100644 --- a/returnn/datasets/postprocessing.py +++ b/returnn/datasets/postprocessing.py @@ -691,7 +691,11 @@ def __next__(self) -> Optional[TensorDict]: raise StopIteration def stop(self, *, join=True): - """Stop the worker processes and the dataset thread.""" + """ + Stop the iterator and the dataset thread. + + Once this is called, the iterator cannot be used anymore. + """ if self.quit_event.is_set(): return self.quit_event.set() @@ -699,7 +703,9 @@ def stop(self, *, join=True): util.try_run(self.dataset_thread.join) def _ensure_complete_frac_monotonic(self, seq: TensorDict) -> TensorDict: - """Enforces monotonicity in complete_frac across all workers.""" + """ + Enforce monotonicity of `complete_frac` in the given `TensorDict`. + """ if "complete_frac" not in seq.data: return seq complete_frac = float(seq.data["complete_frac"].raw_tensor) From a302c377f1dc780916cdc35860f90cfd9244011e Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Fri, 26 Sep 2025 14:51:18 +0200 Subject: [PATCH 48/58] close feeder conn after switching epoch --- returnn/datasets/postprocessing.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py index 191981029..93cbb22c5 100644 --- a/returnn/datasets/postprocessing.py +++ b/returnn/datasets/postprocessing.py @@ -853,6 +853,8 @@ def _iter_pipe(q: mpConnection) -> Iterator[TensorDict]: if sys.platform == "linux": with open("/proc/self/comm", "w") as f: f.write(f"PP worker {index} ep {epoch}") + if feeder_conn is not None: + feeder_conn.close() feeder_conn = kwargs["seq_pipe"] data_iter = _build_mapping_iter( _iter_pipe(feeder_conn), From 20f4496e03dde7f7a93f013d52190b6fd87e3c31 Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Fri, 26 Sep 2025 14:51:57 +0200 Subject: [PATCH 49/58] better name for feeder thread --- returnn/datasets/postprocessing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py index 93cbb22c5..7444e63e2 100644 --- a/returnn/datasets/postprocessing.py +++ b/returnn/datasets/postprocessing.py @@ -436,7 +436,7 @@ def _init_multi_proc_data_iter( "seq_order": seq_order, "worker_conns": feeder_to_worker_conns, }, - name=f"{self.__class__.__name__} {self.name} dataset ep {epoch}", + name=f"{self.__class__.__name__} feeder ep {epoch}", ) # parent_conns are not closed here, because they move to a different thread, not process, # and so they must remain open. 
From 2a40c0c49e64a5877e2f6505464e8f930803b5f8 Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Fri, 26 Sep 2025 15:34:14 +0200 Subject: [PATCH 50/58] do not pass postprocessors on every epoch --- returnn/datasets/postprocessing.py | 53 +++++++++++++++++------------- 1 file changed, 31 insertions(+), 22 deletions(-) diff --git a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py index 7444e63e2..c0573b19a 100644 --- a/returnn/datasets/postprocessing.py +++ b/returnn/datasets/postprocessing.py @@ -260,9 +260,6 @@ def init_seq_order( for i, (worker, child_conn) in enumerate(zip(self._worker_procs, child_conns)): worker.init_seq_order( epoch=epoch, - map_seq=self._map_seq, - map_seq_stream=self._map_seq_stream, - out_tensor_dict_template=self._out_tensor_dict_template, rng_seed=(base_rng_seed + 30411 * i) % (2**32 - 1), seq_list=seq_list, seq_pipe=child_conn, @@ -412,7 +409,14 @@ def _lazy_init_worker_procs(self): if self._worker_procs is not None: return self._worker_procs = [ - _WorkerProcParent(name=f"{self.__class__.__name__} {self.name} worker", buffer_size=self._buf_size, index=i) + _WorkerProcParent( + name=f"{self.__class__.__name__} {self.name} worker", + buffer_size=self._buf_size, + index=i, + map_seq=self._map_seq, + map_seq_stream=self._map_seq_stream, + out_tensor_dict_template=self._out_tensor_dict_template, + ) for i in range(self._num_workers) ] @@ -723,14 +727,23 @@ def __del__(self): class _WorkerProcParent: - def __init__(self, *, name: str, index: int, buffer_size: int): + def __init__( + self, + *, + buffer_size: int, + index: int, + name: str, + map_seq: Optional[Callable], + map_seq_stream: Optional[Callable], + out_tensor_dict_template: TensorDict, + ): parent_conn, child_conn = _mp.Pipe() self.parent_conn = parent_conn self.worker_proc = _mp.Process( name=f"{name} worker {index}", target=_worker_proc_loop, - args=(buffer_size, index, child_conn), + args=(index, child_conn, buffer_size, map_seq, map_seq_stream, out_tensor_dict_template), daemon=True, ) self.worker_proc.start() @@ -746,23 +759,12 @@ def init_seq_order( self, *, epoch: int, - map_seq: Optional[Callable], - map_seq_stream: Optional[Callable], - out_tensor_dict_template: TensorDict, rng_seed: int, seq_list: Optional[List[str]], seq_pipe: mpConnection, ): """init_seq_order""" - args = { - "epoch": epoch, - "map_seq": map_seq, - "map_seq_stream": map_seq_stream, - "out_tensor_dict_template": out_tensor_dict_template, - "rng_seed": rng_seed, - "seq_list": seq_list, - "seq_pipe": seq_pipe, - } + args = {"epoch": epoch, "rng_seed": rng_seed, "seq_list": seq_list, "seq_pipe": seq_pipe} self.parent_conn.send(("init_seq_order", args)) msg, _ = self.parent_conn.recv() assert msg == "init_seq_order" @@ -793,7 +795,14 @@ def __del__(self): util.try_run(self.worker_proc.join) -def _worker_proc_loop(buffer_size: int, index: int, parent_conn: mpConnection): +def _worker_proc_loop( + index: int, + parent_conn: mpConnection, + buffer_size: int, + map_seq: Optional[Callable], + map_seq_stream: Optional[Callable], + out_tensor_dict_template: TensorDict, +): if sys.platform == "linux": with open("/proc/self/comm", "w") as f: f.write(f"PP worker {index}") @@ -858,10 +867,10 @@ def _iter_pipe(q: mpConnection) -> Iterator[TensorDict]: feeder_conn = kwargs["seq_pipe"] data_iter = _build_mapping_iter( _iter_pipe(feeder_conn), - map_seq=kwargs["map_seq"], - map_seq_stream=kwargs["map_seq_stream"], epoch=epoch, - out_tensor_dict_template=kwargs["out_tensor_dict_template"], + map_seq=map_seq, + 
map_seq_stream=map_seq_stream, + out_tensor_dict_template=out_tensor_dict_template, rng=RandomState(kwargs["rng_seed"]), seq_list_for_validation=kwargs["seq_list"], ) From d2a910133776278c5d7e3a0ce4fe2d68b8b5053f Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Mon, 29 Sep 2025 15:31:02 +0200 Subject: [PATCH 51/58] add frequent GC collection --- returnn/datasets/postprocessing.py | 20 ++++++++++++++++++-- returnn/util/basic.py | 23 ++++++++++++++++++++++- 2 files changed, 40 insertions(+), 3 deletions(-) diff --git a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py index c0573b19a..d491dd57f 100644 --- a/returnn/datasets/postprocessing.py +++ b/returnn/datasets/postprocessing.py @@ -5,6 +5,7 @@ from __future__ import annotations from collections import deque +import gc from itertools import islice import numpy from numpy.random import RandomState @@ -119,6 +120,7 @@ def __init__( map_outputs: Optional[Dict[str, Any]] = None, map_seq_stream_preserves_num_seqs: Optional[bool] = None, buf_size: int = 1, + gc_interval: Optional[int] = None, num_workers: int = 0, **kwargs, ): @@ -146,6 +148,11 @@ def __init__( :param map_seq_stream_preserves_num_seqs: whether the function in map_seq_stream preserves the number of sequences, i.e. for every input sequence there is exactly one output sequence. :param buf_size: Buffer size for each worker, number of seqs to prefetch. Must be > 0. + :param gc_interval: Specifies after how many seqs garbage collection should be called + in the worker processes. + If > 0, must be >= `buf_size`. + If <= 0, no explicit garbage collection is done. This can lead to suboptimal memory consumption in the workers. + If None (default), uses a reasonable default (`buf_size` seqs). :param num_workers: If > 0, configures the number of worker processes to use for data postprocessing. Only the postprocessing is distributed across subprocesses, the underlying dataset is only instantiated once. @@ -165,6 +172,10 @@ def __init__( if buf_size < 1: raise ValueError(f"{self}: buf_size must be > 0, but got {buf_size}") + if gc_interval is not None and 0 < gc_interval < buf_size: + raise ValueError( + f"{self}: if gc_interval > 0, it must be >= buf_size, but got 0 < {gc_interval} < buf_size ({buf_size})" + ) if num_workers < 0: raise ValueError(f"{self}: num_workers must be >= 0, but got {num_workers}") @@ -189,6 +200,7 @@ def __init__( # Ensure only one feeder thread at a time accesses the wrapped dataset to # prevent race conditions while moving from one epoch to the next. 
self._dataset_lock = threading.Lock() + self._gc_interval = gc_interval self._multi_proc_data_iter: Optional[_MultiProcDataIter] = None # store for cleanup self._num_workers = num_workers self._worker_procs: Optional[List[_WorkerProcParent]] = None @@ -412,6 +424,7 @@ def _lazy_init_worker_procs(self): _WorkerProcParent( name=f"{self.__class__.__name__} {self.name} worker", buffer_size=self._buf_size, + gc_interval=self._gc_interval or self._buf_size, index=i, map_seq=self._map_seq, map_seq_stream=self._map_seq_stream, @@ -731,6 +744,7 @@ def __init__( self, *, buffer_size: int, + gc_interval: int, index: int, name: str, map_seq: Optional[Callable], @@ -743,7 +757,7 @@ def __init__( self.worker_proc = _mp.Process( name=f"{name} worker {index}", target=_worker_proc_loop, - args=(index, child_conn, buffer_size, map_seq, map_seq_stream, out_tensor_dict_template), + args=(index, child_conn, buffer_size, gc_interval, map_seq, map_seq_stream, out_tensor_dict_template), daemon=True, ) self.worker_proc.start() @@ -799,6 +813,7 @@ def _worker_proc_loop( index: int, parent_conn: mpConnection, buffer_size: int, + gc_interval: int, map_seq: Optional[Callable], map_seq_stream: Optional[Callable], out_tensor_dict_template: TensorDict, @@ -809,6 +824,7 @@ def _worker_proc_loop( better_exchook.setup_all() assert isinstance(buffer_size, int) and buffer_size > 0 + assert isinstance(gc_interval, int) and gc_interval >= buffer_size assert isinstance(index, int) assert isinstance(parent_conn, mpConnection) @@ -866,7 +882,7 @@ def _iter_pipe(q: mpConnection) -> Iterator[TensorDict]: feeder_conn.close() feeder_conn = kwargs["seq_pipe"] data_iter = _build_mapping_iter( - _iter_pipe(feeder_conn), + util.iter_with_gc(_iter_pipe(feeder_conn), gc_interval=gc_interval), epoch=epoch, map_seq=map_seq, map_seq_stream=map_seq_stream, diff --git a/returnn/util/basic.py b/returnn/util/basic.py index 0df740229..e4695ec59 100644 --- a/returnn/util/basic.py +++ b/returnn/util/basic.py @@ -5,12 +5,13 @@ """ from __future__ import annotations -from typing import Optional, Union, Any, Generic, TypeVar, Iterable, Tuple, Dict, List, Set, Callable +from typing import Optional, Union, Any, Generic, TypeVar, Iterator, Iterable, Tuple, Dict, List, Set, Callable import subprocess from subprocess import CalledProcessError from collections import deque +import gc import inspect import os import sys @@ -4552,3 +4553,23 @@ def get_fwd_compat_kwargs() -> Dict[str, Any]: """ i = fwd_compatibility_rng.integers(0, 100) return {f"__fwd_compat_random_arg_{i:03}": None} + + +def iter_with_gc(iter: Iterable[T], *, gc_interval: int) -> Iterator[T]: + """ + Iterate and call gc.collect() every `gc_interval` steps. + + See https://github.com/rwth-i6/returnn/pull/1765#issuecomment-3346576978 for discussion. + + :param iter: iterable + :param int gc_interval: call gc.collect() every `gc_interval` steps. + If <= 0, will not call gc.collect(). 
+ :return: iterator + """ + if gc_interval <= 0: + yield from iter + return + for i, item in enumerate(iter): + if i > 0 and i % gc_interval == 0: + gc.collect() + yield item From c2bc4fdbf8ea523547cbb1d02a1a5e4378f80f68 Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Mon, 29 Sep 2025 15:55:14 +0200 Subject: [PATCH 52/58] improve test --- tests/test_Dataset.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/tests/test_Dataset.py b/tests/test_Dataset.py index b4c7980a0..7745f781b 100644 --- a/tests/test_Dataset.py +++ b/tests/test_Dataset.py @@ -1274,21 +1274,17 @@ def _repeat2(input_iter: Iterator[TensorDict], **kwargs) -> Iterator[TensorDict] def test_MultiProcPostprocessingDataset(): - from test_HDFDataset import generate_hdf_from_dummy - - num_hdfs = 20 - num_seqs = 23 - total_num_seqs = num_hdfs * num_seqs + total_num_seqs = 500 total_num_seqs_pp = 2 * total_num_seqs - # Create a few HDF files such that we can easily verify the data later. - hdf_files = [generate_hdf_from_dummy() for _ in range(num_hdfs)] ds_opts = { "class": "PostprocessingDataset", "dataset": { - "class": "HDFDataset", - "files": hdf_files, - "seq_ordering": "default", + "class": "DummyDataset", + "input_dim": 13, + "output_dim": 7, + "num_seqs": total_num_seqs, + "seq_len": 17, }, "map_seq_stream": _repeat2, "buf_size": 1, @@ -1299,8 +1295,9 @@ def test_MultiProcPostprocessingDataset(): for ep in range(1, 5 + 1): dataset.init_seq_order(epoch=ep) assert dataset.have_seqs() - dataset.load_seqs(0, total_num_seqs_pp + 1) for i in range(total_num_seqs_pp): + assert dataset.is_less_than_num_seqs(i) + dataset.load_seqs(i, i + 1) classes = dataset.get_data(i, "classes") assert len(classes) > 0 assert not dataset.is_less_than_num_seqs(total_num_seqs_pp + 1) From 1fa9f50a4a6743db7f46c0504a13afdf3ba584ca Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Mon, 29 Sep 2025 15:58:09 +0200 Subject: [PATCH 53/58] add 100 seqs lower bound if gc_interval is not configured explicitly --- returnn/datasets/postprocessing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py index d491dd57f..5fea27c37 100644 --- a/returnn/datasets/postprocessing.py +++ b/returnn/datasets/postprocessing.py @@ -152,7 +152,7 @@ def __init__( in the worker processes. If > 0, must be >= `buf_size`. If <= 0, no explicit garbage collection is done. This can lead to suboptimal memory consumption in the workers. - If None (default), uses a reasonable default (`buf_size` seqs). + If None (default), uses a reasonable default (`buf_size`, but min 100 seqs). :param num_workers: If > 0, configures the number of worker processes to use for data postprocessing. Only the postprocessing is distributed across subprocesses, the underlying dataset is only instantiated once. 
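A usage sketch for the iter_with_gc helper introduced above: wrap any iterator and gc.collect() runs every gc_interval items, while a non-positive interval degrades to a plain pass-through (assuming the import path from the diff):

    from returnn.util.basic import iter_with_gc

    seqs = iter_with_gc(iter(range(10)), gc_interval=4)  # gc.collect() before items 4 and 8
    assert list(seqs) == list(range(10))  # values pass through unchanged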
@@ -424,7 +424,7 @@ def _lazy_init_worker_procs(self): _WorkerProcParent( name=f"{self.__class__.__name__} {self.name} worker", buffer_size=self._buf_size, - gc_interval=self._gc_interval or self._buf_size, + gc_interval=self._gc_interval or max(self._buf_size, 100), index=i, map_seq=self._map_seq, map_seq_stream=self._map_seq_stream, From a4c509cac8952440c9ef8ff94729e15d1750feaa Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Mon, 29 Sep 2025 15:58:39 +0200 Subject: [PATCH 54/58] fix naming issue --- returnn/util/basic.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/returnn/util/basic.py b/returnn/util/basic.py index e4695ec59..3ddc69ff4 100644 --- a/returnn/util/basic.py +++ b/returnn/util/basic.py @@ -4555,21 +4555,21 @@ def get_fwd_compat_kwargs() -> Dict[str, Any]: return {f"__fwd_compat_random_arg_{i:03}": None} -def iter_with_gc(iter: Iterable[T], *, gc_interval: int) -> Iterator[T]: +def iter_with_gc(iterator: Iterable[T], *, gc_interval: int) -> Iterator[T]: """ Iterate and call gc.collect() every `gc_interval` steps. See https://github.com/rwth-i6/returnn/pull/1765#issuecomment-3346576978 for discussion. - :param iter: iterable + :param iterator: iterable :param int gc_interval: call gc.collect() every `gc_interval` steps. If <= 0, will not call gc.collect(). :return: iterator """ if gc_interval <= 0: - yield from iter + yield from iterator return - for i, item in enumerate(iter): + for i, item in enumerate(iterator): if i > 0 and i % gc_interval == 0: gc.collect() yield item From ad1e87869abe121a88d9a3b1c3c8518fcf99ba95 Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Mon, 29 Sep 2025 16:00:23 +0200 Subject: [PATCH 55/58] pass 0 interval correctly --- returnn/datasets/postprocessing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py index 5fea27c37..f4005b5f5 100644 --- a/returnn/datasets/postprocessing.py +++ b/returnn/datasets/postprocessing.py @@ -424,7 +424,7 @@ def _lazy_init_worker_procs(self): _WorkerProcParent( name=f"{self.__class__.__name__} {self.name} worker", buffer_size=self._buf_size, - gc_interval=self._gc_interval or max(self._buf_size, 100), + gc_interval=self._gc_interval if self._gc_interval is not None else max(self._buf_size, 100), index=i, map_seq=self._map_seq, map_seq_stream=self._map_seq_stream, From c08733f4c995bb703d9f54950357a4833471fab5 Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Mon, 29 Sep 2025 16:00:29 +0200 Subject: [PATCH 56/58] docs --- returnn/datasets/postprocessing.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py index f4005b5f5..ac00c08b3 100644 --- a/returnn/datasets/postprocessing.py +++ b/returnn/datasets/postprocessing.py @@ -424,6 +424,8 @@ def _lazy_init_worker_procs(self): _WorkerProcParent( name=f"{self.__class__.__name__} {self.name} worker", buffer_size=self._buf_size, + # use reasonable default if unset, + # see https://github.com/rwth-i6/returnn/pull/1765 for discussion gc_interval=self._gc_interval if self._gc_interval is not None else max(self._buf_size, 100), index=i, map_seq=self._map_seq, From fb2d6e10737ac2d033f2fc0801934bec229c4697 Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Mon, 29 Sep 2025 16:28:17 +0200 Subject: [PATCH 57/58] fix long line --- returnn/datasets/postprocessing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py 
index ac00c08b3..916cfb89d 100644 --- a/returnn/datasets/postprocessing.py +++ b/returnn/datasets/postprocessing.py @@ -151,7 +151,7 @@ def __init__( :param gc_interval: Specifies after how many seqs garbage collection should be called in the worker processes. If > 0, must be >= `buf_size`. - If <= 0, no explicit garbage collection is done. This can lead to suboptimal memory consumption in the workers. + If <= 0, no explicit garbage collection is done. This can lead to higher memory consumption in the workers. If None (default), uses a reasonable default (`buf_size`, but min 100 seqs). :param num_workers: If > 0, configures the number of worker processes to use for data postprocessing. Only the postprocessing is distributed across subprocesses, From ba991806c15d612d51d5d1a5bb8acaf12c0217c3 Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Tue, 30 Sep 2025 16:06:42 +0200 Subject: [PATCH 58/58] remove unused import --- returnn/datasets/postprocessing.py | 1 - 1 file changed, 1 deletion(-) diff --git a/returnn/datasets/postprocessing.py b/returnn/datasets/postprocessing.py index 916cfb89d..a831abfe8 100644 --- a/returnn/datasets/postprocessing.py +++ b/returnn/datasets/postprocessing.py @@ -5,7 +5,6 @@ from __future__ import annotations from collections import deque -import gc from itertools import islice import numpy from numpy.random import RandomState
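Taken together, the series leaves PostprocessingDataset with three knobs for multi-process postprocessing: num_workers, buf_size, and gc_interval. A hedged config sketch mirroring the test above (the DummyDataset parameters and the identity postprocessor are illustrative, not part of the patches):

    def _identity_map_seq_stream(input_iter, **_kwargs):
        yield from input_iter  # user-supplied postprocessing; identity here

    train = {
        "class": "PostprocessingDataset",
        "dataset": {"class": "DummyDataset", "input_dim": 13, "output_dim": 7, "num_seqs": 100, "seq_len": 17},
        "map_seq_stream": _identity_map_seq_stream,
        "num_workers": 2,  # postprocess in 2 subprocesses; the wrapped dataset is instantiated only once
        "buf_size": 10,  # per-worker prefetch buffer
        "gc_interval": 100,  # gc.collect() in each worker every 100 seqs (default: max(buf_size, 100))
    }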