From 8c5f46505f0dab2727be5efd70f5d3a7f55535c1 Mon Sep 17 00:00:00 2001
From: Moritz Gunz
Date: Tue, 10 Sep 2024 15:53:13 +0200
Subject: [PATCH 1/3] `OggZipDataset`: normalize // to / when reading files
 from archive

The dataset implementation joins some paths wrongly at times, leading to
double slashes breaking lookup in the zip file, even though there is a file
at that path. By normalizing // to / the file can be read properly.
---
 returnn/datasets/audio.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/returnn/datasets/audio.py b/returnn/datasets/audio.py
index 2803f2599..2add4a106 100644
--- a/returnn/datasets/audio.py
+++ b/returnn/datasets/audio.py
@@ -179,7 +179,7 @@ def _read(self, filename, zip_index):
             return gzip.open(self._separate_txt_files[name], "rb").read()
         if self._zip_files is not None:
-            return self._zip_files[zip_index].read(filename)
+            return self._zip_files[zip_index].read(filename.replace("//", "/"))
         return open("%s/%s" % (self.paths[0], filename), "rb").read()
 
     def _collect_data_part(self, zip_index) -> Tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:

From f7947a42597b6a0742d65e7ccf1abbe7a4b5716f Mon Sep 17 00:00:00 2001
From: Moritz Gunz
Date: Fri, 23 May 2025 15:27:54 +0200
Subject: [PATCH 2/3] PT: in param avg, only sync data every sync step

---
 returnn/torch/distributed.py | 14 +++++++++++++-
 returnn/torch/engine.py      |  7 ++++++-
 2 files changed, 19 insertions(+), 2 deletions(-)

diff --git a/returnn/torch/distributed.py b/returnn/torch/distributed.py
index b82702128..a3bc87c88 100644
--- a/returnn/torch/distributed.py
+++ b/returnn/torch/distributed.py
@@ -126,9 +126,21 @@ def maybe_make_distributed_module(self, module: torch.nn.Module) -> Optional[Dis
             **kwargs,
         )
 
+    def should_sync_now(self, *, epoch_step_idx: int) -> bool:
+        """
+        :param epoch_step_idx: current step index
+        :return: whether to sync the training processes in this step
+        """
+        if self._reduce_type == "grad":
+            return True
+        elif self._reduce_type == "param":
+            return (epoch_step_idx % self._param_sync_step) == (self._param_sync_step - 1)
+        else:
+            raise ValueError(f"invalid reduce_type {self._reduce_type}")
+
     def step_after_param_update(self, *, module: torch.nn.Module, epoch_step_idx: int):
         """one train step"""
-        if self._reduce_type == "param" and ((epoch_step_idx % self._param_sync_step) == (self._param_sync_step - 1)):
+        if self._reduce_type == "param" and self.should_sync_now(epoch_step_idx=epoch_step_idx):
             _sync_params_avg(module=module, sync_on_cpu=self._opts.get("sync_on_cpu", False))
 
diff --git a/returnn/torch/engine.py b/returnn/torch/engine.py
index b07be4052..757a35d72 100644
--- a/returnn/torch/engine.py
+++ b/returnn/torch/engine.py
@@ -405,7 +405,12 @@ def train_epoch(self):
                 print("Time to get first batch data:", hms(step_begin_time - epoch_start_time), file=log.v5)
 
             _has_data = torch.tensor([extern_data_raw is not None], dtype=torch.int8)
-            if self._torch_distributed_ctx:
+            # Sync only on first train step, when we have run out of data and every time we synchronize
+            # the model between workers.
+            # This allows the different workers to progress independently between synchronizations.
+            if self._torch_distributed_ctx and (
+                self._torch_distributed_ctx.should_sync_now() or step_idx == 0 or extern_data_raw is None
+            ):
                 # use all reduce to check if all workers have data, if at least one worker does not have data,
                 # all workers finish this epoch
                 torch.distributed.all_reduce(_has_data, op=torch.distributed.ReduceOp.MIN)

From bc10da0ae3c4debebaf0c2f76c63afb4525fc31c Mon Sep 17 00:00:00 2001
From: Moritz Gunz
Date: Fri, 23 May 2025 15:38:01 +0200
Subject: [PATCH 3/3] fix

---
 returnn/torch/engine.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/returnn/torch/engine.py b/returnn/torch/engine.py
index 757a35d72..28fd1ce68 100644
--- a/returnn/torch/engine.py
+++ b/returnn/torch/engine.py
@@ -409,7 +409,9 @@ def train_epoch(self):
             # the model between workers.
             # This allows the different workers to progress independently between synchronizations.
             if self._torch_distributed_ctx and (
-                self._torch_distributed_ctx.should_sync_now() or step_idx == 0 or extern_data_raw is None
+                self._torch_distributed_ctx.should_sync_now(epoch_step_idx=step_idx)
+                or step_idx == 0
+                or extern_data_raw is None
             ):
                 # use all reduce to check if all workers have data, if at least one worker does not have data,
                 # all workers finish this epoch
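Illustration for PATCH 1/3 (not part of the patches themselves): zipfile.ZipFile.read() looks entries up by their exact archive name, so a path joined with a stray double slash misses an entry that does exist under the single-slash name. A minimal self-contained sketch with hypothetical entry names, not RETURNN code:

    import io
    import zipfile

    # Build a small in-memory archive with one entry named "corpus/audio/1.ogg".
    buf = io.BytesIO()
    with zipfile.ZipFile(buf, "w") as zf:
        zf.writestr("corpus/audio/1.ogg", b"dummy audio bytes")

    with zipfile.ZipFile(buf, "r") as zf:
        joined = "corpus/" + "/audio/1.ogg"  # sloppy join -> "corpus//audio/1.ogg"
        try:
            zf.read(joined)  # raises KeyError: no entry with a double slash
        except KeyError:
            pass
        # Collapsing "//" to "/" restores the exact archive name, as the patch does.
        assert zf.read(joined.replace("//", "/")) == b"dummy audio bytes"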
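Illustration for PATCH 2/3 and 3/3 (not part of the patches themselves): with reduce_type "param", parameters are only averaged every _param_sync_step steps, so the per-step all_reduce of the has-data flag can be skipped in between and workers progress independently. A simplified standalone sketch of the decision, written as a free function that mirrors should_sync_now rather than the actual RETURNN class:

    def should_sync_now(*, reduce_type: str, param_sync_step: int, epoch_step_idx: int) -> bool:
        """Decide whether the training processes need to synchronize in this step."""
        if reduce_type == "grad":
            return True  # gradients are all-reduced every step anyway
        elif reduce_type == "param":
            # last step of each param-sync window, e.g. steps 99, 199, ... for a window of 100
            return (epoch_step_idx % param_sync_step) == (param_sync_step - 1)
        raise ValueError(f"invalid reduce_type {reduce_type!r}")

    # In train_epoch, the all_reduce then only runs on the first step, on sync steps,
    # or when this worker has run out of data:
    for step_idx, have_data in [(0, True), (1, True), (99, True), (100, True), (250, False)]:
        sync = should_sync_now(reduce_type="param", param_sync_step=100, epoch_step_idx=step_idx)
        print(step_idx, sync or step_idx == 0 or not have_data)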