__all__ = ["ExtractPerplexityFromLearningRatesFileJob"]

import ast
from typing import Any, Dict, List

from sisyphus import Job, Task, tk


def _parse_learning_rates(text: str) -> Dict[Any, Any]:
    """
    Parse the content of a RETURNN learning rates file without executing it.

    A plain Python dict literal is accepted as-is. In addition, calls of the form
    ``EpochData(key=value, ...)`` are rewritten into ``{"key": value, ...}`` and the
    bare names ``nan`` / ``inf`` are replaced by the corresponding floats, so that
    :func:`ast.literal_eval` can be used instead of ``eval``.

    NOTE(review): the ``EpochData(learningRate=..., error={...})`` layout is what RETURNN
    is expected to write; confirm against a real learning rates file.

    :param text: raw file content
    :return: the parsed mapping, presumably epoch -> epoch data
    :raises SyntaxError: if the text is empty or not a valid Python expression
    :raises ValueError: if the text contains anything but literals / ``EpochData`` calls,
        or if the top-level object is not a dict
    """
    constants = {"nan": float("nan"), "inf": float("inf")}

    class _ToLiteral(ast.NodeTransformer):
        """Rewrites the few non-literal constructs into literal nodes."""

        def visit_Call(self, node):
            # rewrite nested content first, e.g. "nan" inside the error dict
            self.generic_visit(node)
            is_epoch_data = isinstance(node.func, ast.Name) and node.func.id == "EpochData"
            only_named_args = not node.args and all(kw.arg is not None for kw in node.keywords)
            if is_epoch_data and only_named_args:
                as_dict = ast.Dict(
                    keys=[ast.Constant(value=kw.arg) for kw in node.keywords],
                    values=[kw.value for kw in node.keywords],
                )
                return ast.copy_location(as_dict, node)
            # any other call is left untouched and will be rejected by literal_eval
            return node

        def visit_Name(self, node):
            if node.id in constants:
                return ast.copy_location(ast.Constant(value=constants[node.id]), node)
            return node

    tree = _ToLiteral().visit(ast.parse(text.strip(), mode="eval"))
    parsed = ast.literal_eval(tree)
    if not isinstance(parsed, dict):
        raise ValueError(f"expected a dict in the learning rates file, got {type(parsed).__name__}")
    return parsed


class ExtractPerplexityFromLearningRatesFileJob(Job):
    """
    Extracts the perplexity from the RETURNN learning rates files.

    The values of the last (highest) epoch are written in a human readable form
    to ``out_ppl_file`` and are additionally stored in ``out_perplexities``,
    one output variable ``ppl_<dataset>`` per requested dataset.
    """

    def __init__(
        self,
        returnn_learning_rates: tk.Path,
        eval_datasets: List[str],
    ):
        """
        :param returnn_learning_rates: path to the RETURNN learning rates file
        :param eval_datasets: names of the datasets for which the perplexity is extracted,
            the entry ``<name>_loss_ppl`` is looked up for each of them
        """
        self.returnn_learning_rates = returnn_learning_rates
        self.eval_datasets = sorted(eval_datasets)

        self.out_ppl_file = self.output_path("ppl.txt")

        self.out_perplexities = {f"ppl_{d}": self.output_var(f"ppl_{d}") for d in self.eval_datasets}

        self.rqmt = {"gpu": 0, "cpu": 1, "mem": 1, "time": 1}

    def tasks(self):
        yield Task("run", resume="run", mini_task=True)

    def run(self):
        """
        Read the learning rates file, pick the last epoch and export its perplexities.

        :raises ValueError: if the file contains no epoch at all
        :raises KeyError: if the last epoch has no perplexity entry for a requested dataset
        """
        lr_path = self.returnn_learning_rates.get_path()
        with open(lr_path, "rt", encoding="utf-8") as f_in:
            epochs = _parse_learning_rates(f_in.read())
        if not epochs:
            raise ValueError(f"no epochs found in learning rates file {lr_path}")

        last_epoch = max(epochs)
        scores = epochs[last_epoch]
        # EpochData entries keep the scores in a nested "error" dict,
        # a plain dict entry is used directly
        if isinstance(scores, dict) and isinstance(scores.get("error"), dict):
            scores = scores["error"]

        res = []
        for data_set in self.eval_datasets:
            full_name = f"{data_set}_loss_ppl"  # TODO actually check which name fits
            if full_name not in scores:
                raise KeyError(f"{full_name} not found in epoch {last_epoch} of {lr_path}")
            ppl = scores[full_name]
            res.append(f"{data_set} - ppl: {ppl} \n")
            # the output variables have to be set, otherwise they are never available
            self.out_perplexities[f"ppl_{data_set}"].set(ppl)

        with open(self.out_ppl_file.get_path(), "wt", encoding="utf-8") as f_out:
            f_out.writelines(res)
""" - def __init__(self, returnn_config_file, model, epoch): + def __init__(self, returnn_config_file: Path, model: Path, epoch: int): """ - :param Path returnn_config_file: Path to a returnn config file - :param Path model: Path to a RETURNN checkpoint (only the .meta for Tensorflow) - :param int epoch: + :param returnn_config_file: Path to a returnn config file + :param model: Path to a RETURNN checkpoint (only the .meta for Tensorflow) + :param epoch: """ self.returnn_config_file = returnn_config_file self.model = model @@ -52,7 +52,7 @@ class Checkpoint: Checkpoint object which holds the (Tensorflow) index file path as tk.Path, and will return the checkpoint path as common prefix of the .index/.meta/.data[...] - A checkpoint object should directly assigned to a RasrConfig entry (do not call `.ckpt_path`) + A checkpoint object should be directly assigned to a RasrConfig entry (do not call `.ckpt_path`) so that the hash will resolve correctly """