Skip to content

Commit de25e89

Browse files
authored
ReturnnForwardJobV2 (#441)
1 parent 43dfdef commit de25e89

File tree

1 file changed

+200
-9
lines changed

1 file changed

+200
-9
lines changed

returnn/forward.py

Lines changed: 200 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,28 @@
1-
__all__ = ["ReturnnForwardJob"]
1+
"""
2+
RETURNN forward jobs
3+
"""
4+
5+
__all__ = ["ReturnnForwardJob", "ReturnnForwardJobV2"]
26

37
from sisyphus import *
48

59
import copy
610
import glob
711
import os
812
import shutil
9-
import subprocess as sp
13+
import subprocess
1014
import tempfile
1115
from typing import List, Optional, Union
1216

1317
from i6_core.returnn.config import ReturnnConfig
14-
from i6_core.returnn.training import Checkpoint, PtCheckpoint
18+
from i6_core.returnn.training import Checkpoint as TfCheckpoint, PtCheckpoint
1519
import i6_core.util as util
1620

21+
1722
Path = setup_path(__package__)
1823

24+
Checkpoint = Union[TfCheckpoint, PtCheckpoint, tk.Path]
25+
1926

2027
class ReturnnForwardJob(Job):
2128
"""
@@ -32,7 +39,7 @@ class ReturnnForwardJob(Job):
3239

3340
def __init__(
3441
self,
35-
model_checkpoint: Optional[Union[Checkpoint, PtCheckpoint]],
42+
model_checkpoint: Optional[Checkpoint],
3643
returnn_config: ReturnnConfig,
3744
returnn_python_exe: tk.Path,
3845
returnn_root: tk.Path,
@@ -111,7 +118,9 @@ def create_files(self):
111118

112119
# check here if model actually exists
113120
if self.model_checkpoint is not None:
114-
assert self.model_checkpoint.exists(), "Provided model does not exists: %s" % str(self.model_checkpoint)
121+
assert os.path.exists(
122+
_get_model_path(self.model_checkpoint).get_path()
123+
), f"Provided model checkpoint does not exists: {self.model_checkpoint}"
115124

116125
def run(self):
117126
# run everything in a TempDir as writing HDFs can cause heavy load
@@ -127,7 +136,7 @@ def run(self):
127136
env = os.environ.copy()
128137
env["OMP_NUM_THREADS"] = str(self.rqmt["cpu"])
129138
env["MKL_NUM_THREADS"] = str(self.rqmt["cpu"])
130-
sp.check_call(call, cwd=d, env=env)
139+
subprocess.check_call(call, cwd=d, env=env)
131140
except Exception as e:
132141
print("Run crashed - copy temporary work folder as 'crash_dir'")
133142
shutil.copytree(d, "crash_dir")
@@ -151,17 +160,17 @@ def create_returnn_config(
151160
eval_mode: bool,
152161
log_verbosity: int,
153162
device: str,
154-
**kwargs,
155-
):
163+
**_kwargs_unused,
164+
) -> ReturnnConfig:
156165
"""
157166
Update the config locally to make it ready for the forward/eval task.
158167
The resulting config will be used for hashing.
159168
160169
:param model_checkpoint:
161170
:param returnn_config:
171+
:param eval_mode:
162172
:param log_verbosity:
163173
:param device:
164-
:param kwargs:
165174
:return:
166175
"""
167176
assert device in ["gpu", "cpu"]
@@ -207,3 +216,185 @@ def hash(cls, kwargs):
207216
}
208217

209218
return super().hash(d)
219+
220+
221+
class ReturnnForwardJobV2(Job):
    """
    Generic forward job.

    The user specifies the outputs in the RETURNN config
    via `forward_callback`.
    That is expected to be an instance of `returnn.forward_iface.ForwardCallbackIface`
    or a callable/function which returns such an instance.

    The callback is supposed to generate the output files in the current directory.
    The current directory will be a local temporary directory
    and the files are moved to the output directory at the end.

    Nothing is enforced here by intention, to keep it generic.
    The task by default is set to "forward",
    but other tasks of RETURNN might be used as well.
    """

    def __init__(
        self,
        *,
        model_checkpoint: Optional[Checkpoint],
        returnn_config: ReturnnConfig,
        returnn_python_exe: tk.Path,
        returnn_root: tk.Path,
        output_files: List[str],
        log_verbosity: int = 5,
        device: str = "gpu",
        time_rqmt: float = 4,
        mem_rqmt: float = 4,
        cpu_rqmt: int = 2,
    ):
        """
        :param model_checkpoint: Checkpoint object pointing to a stored RETURNN Tensorflow/PyTorch model
            or None if network has no parameters or should be randomly initialized
        :param returnn_config: RETURNN config object
        :param returnn_python_exe: path to the RETURNN executable (python binary or launch script)
        :param returnn_root: path to the RETURNN src folder
        :param output_files: list of output file names that will be generated. These are just the basenames,
            and they are supposed to be created in the current directory.
        :param log_verbosity: RETURNN log verbosity
        :param device: RETURNN device, cpu or gpu
        :param time_rqmt: job time requirement in hours
        :param mem_rqmt: job memory requirement in GB
        :param cpu_rqmt: job cpu requirement
        """
        self.returnn_config = returnn_config
        self.model_checkpoint = model_checkpoint
        self.returnn_python_exe = returnn_python_exe
        self.returnn_root = returnn_root
        self.log_verbosity = log_verbosity
        self.device = device

        self.out_returnn_config_file = self.output_path("returnn.config")
        # One registered job output per requested basename; `run` moves them out of the temp dir.
        self.out_files = {output: self.output_path(output) for output in output_files}

        self.rqmt = {
            "gpu": 1 if device == "gpu" else 0,
            "cpu": cpu_rqmt,
            "mem": mem_rqmt,
            "time": time_rqmt,
        }

    def tasks(self):
        """:return: the Sisyphus tasks of this job"""
        yield Task("create_files", mini_task=True)
        yield Task("run", resume="run", rqmt=self.rqmt)

    def create_files(self):
        """Write the final RETURNN config and an `rnn.sh` launch script, and sanity-check the checkpoint."""
        config = self.create_returnn_config(
            model_checkpoint=self.model_checkpoint,
            returnn_config=self.returnn_config,
            log_verbosity=self.log_verbosity,
            device=self.device,
        )
        config.write(self.out_returnn_config_file.get_path())

        cmd = [
            self.returnn_python_exe.get_path(),
            os.path.join(self.returnn_root.get_path(), "rnn.py"),
            self.out_returnn_config_file.get_path(),
        ]
        util.create_executable("rnn.sh", cmd)

        # Check here if the model actually exists, to fail early instead of inside `run`.
        if self.model_checkpoint is not None:
            assert os.path.exists(
                _get_model_path(self.model_checkpoint).get_path()
            ), f"Provided model checkpoint does not exist: {self.model_checkpoint}"

    def run(self):
        """Run RETURNN in a temp dir, then collect the declared output files and the logs."""
        # run everything in a TempDir as writing files can cause heavy load
        with tempfile.TemporaryDirectory(prefix=gs.TMP_PREFIX) as tmp_dir:
            print("using temp-dir: %s" % tmp_dir)
            call = [
                self.returnn_python_exe.get_path(),
                os.path.join(self.returnn_root.get_path(), "rnn.py"),
                self.out_returnn_config_file.get_path(),
            ]

            try:
                env = os.environ.copy()
                # Limit math-library threading to the requested CPU count.
                env["OMP_NUM_THREADS"] = str(self.rqmt["cpu"])
                env["MKL_NUM_THREADS"] = str(self.rqmt["cpu"])
                subprocess.check_call(call, cwd=tmp_dir, env=env)
            except Exception:
                # Preserve the temp dir content for debugging before the context manager deletes it.
                print("Run crashed - copy temporary work folder as 'crash_dir'")
                if os.path.exists("crash_dir"):
                    shutil.rmtree("crash_dir")
                shutil.copytree(tmp_dir, "crash_dir", dirs_exist_ok=True)
                raise

            # move outputs to output folder
            for k, v in self.out_files.items():
                assert os.path.exists(f"{tmp_dir}/{k}"), f"Output file {k} does not exist"
                shutil.move(f"{tmp_dir}/{k}", v.get_path())

            # copy logs and anything else. don't make assumptions on filenames
            shutil.copytree(tmp_dir, ".", dirs_exist_ok=True)

    @classmethod
    def create_returnn_config(
        cls,
        *,
        model_checkpoint: Optional[Checkpoint],
        returnn_config: ReturnnConfig,
        log_verbosity: int,
        device: str,
        **_kwargs,
    ) -> ReturnnConfig:
        """
        Update the config locally to make it ready for the forward/eval task.
        The resulting config will be used for hashing.

        :param model_checkpoint: set as "load" in the config, or random init when None
        :param returnn_config: base config; deep-copied, never modified in place
        :param log_verbosity: RETURNN log verbosity (goes into post_config, i.e. not hashed)
        :param device: "gpu" or "cpu" (goes into post_config, i.e. not hashed)
        :return: the adapted config
        """
        assert device in ["gpu", "cpu"]
        assert "load" not in returnn_config.config
        assert "model" not in returnn_config.config

        res = copy.deepcopy(returnn_config)

        res.config.setdefault("task", "forward")
        if model_checkpoint is not None:
            res.config["load"] = model_checkpoint
        else:
            res.config.setdefault("allow_random_model_init", True)

        # post_config entries do not influence the job hash
        res.post_config.setdefault("device", device)
        res.post_config.setdefault("log", ["./returnn.log"])
        res.post_config.setdefault("tf_log_dir", "returnn-tf-log")
        res.post_config.setdefault("log_verbosity", log_verbosity)

        res.check_consistency()

        return res

    @classmethod
    def hash(cls, kwargs):
        """Hash only what influences the result: final config, RETURNN executable and source tree."""
        d = {
            "returnn_config": cls.create_returnn_config(**kwargs),
            "returnn_python_exe": kwargs["returnn_python_exe"],
            "returnn_root": kwargs["returnn_root"],
        }

        return super().hash(d)
391+
392+
393+
def _get_model_path(model: Checkpoint) -> tk.Path:
    """Resolve any supported checkpoint representation to the tk.Path that backs it on disk."""
    # Ordered dispatch: first matching checkpoint kind wins.
    resolvers = (
        (tk.Path, lambda m: m),
        (TfCheckpoint, lambda m: m.index_path),
        (PtCheckpoint, lambda m: m.path),
    )
    for kind, resolve in resolvers:
        if isinstance(model, kind):
            return resolve(model)
    raise TypeError(f"Unknown model checkpoint type: {type(model)}")

0 commit comments

Comments
 (0)