Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions returnn/perplexity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
__all__ = ["ExtractPerplexityFromLearningRatesFileJob"]

import ast
from typing import List

from sisyphus import Job, Task, tk



class ExtractPerplexityFromLearningRatesFileJob(Job):
"""
Extracts the perplexity from the RETURNN learning rates files.
"""

def __init__(
self,
returnn_learning_rates: tk.Path,
eval_datasets: List[str],
):
self.returnn_learning_rates = returnn_learning_rates
self.eval_datasets = sorted(eval_datasets)

self.out_ppl_file = self.output_path("ppl.txt")

self.out_perplexities = {f"ppl_{d}": self.output_var(f"ppl_{d}") for d in eval_datasets}

self.rqmt = {"gpu": 0, "cpu": 1, "mem": 1, "time": 1}

def tasks(self):
yield Task("run", resume="run", mini_task=True)

def run(self):
with open(self.returnn_learning_rates.get_path(), "rt", encoding="utf-8") as f_in:
data = f_in.read()
lr_dict = ast.literal_eval(data)
lr_dict = sorted(lr_dict.items(), reverse=True)
last_entry = lr_dict[0]

res = []
for data_set in self.eval_datasets:
full_name = f"{data_set}_loss_ppl" # TODO actually check which name fits
res.append(f"{data_set} - ppl: {last_entry[full_name]} \n")

with open(self.out_ppl_file.get_path(), "wt", encoding="utf-8") as f_out:
f_out.writelines(res)
10 changes: 5 additions & 5 deletions returnn/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@ class ReturnnModel:
This is deprecated, use :class:`Checkpoint` instead.
"""

def __init__(self, returnn_config_file, model, epoch):
def __init__(self, returnn_config_file: Path, model: Path, epoch: int):
"""

:param Path returnn_config_file: Path to a returnn config file
:param Path model: Path to a RETURNN checkpoint (only the .meta for Tensorflow)
:param int epoch:
:param returnn_config_file: Path to a returnn config file
:param model: Path to a RETURNN checkpoint (only the .meta for Tensorflow)
:param epoch:
"""
self.returnn_config_file = returnn_config_file
self.model = model
Expand All @@ -52,7 +52,7 @@ class Checkpoint:
Checkpoint object which holds the (Tensorflow) index file path as tk.Path,
and will return the checkpoint path as common prefix of the .index/.meta/.data[...]

A checkpoint object should directly assigned to a RasrConfig entry (do not call `.ckpt_path`)
A checkpoint object should be directly assigned to a RasrConfig entry (do not call `.ckpt_path`)
so that the hash will resolve correctly
"""

Expand Down
Loading