-__all__ = ["ReturnnForwardJob"]
+"""
+RETURNN forward jobs
+"""
+
+__all__ = ["ReturnnForwardJob", "ReturnnForwardJobV2"]

from sisyphus import *

import copy
import glob
import os
import shutil
-import subprocess as sp
+import subprocess
import tempfile
from typing import List, Optional, Union

from i6_core.returnn.config import ReturnnConfig
-from i6_core.returnn.training import Checkpoint, PtCheckpoint
+from i6_core.returnn.training import Checkpoint as TfCheckpoint, PtCheckpoint
import i6_core.util as util

+
Path = setup_path(__package__)

+Checkpoint = Union[TfCheckpoint, PtCheckpoint, tk.Path]
+


class ReturnnForwardJob(Job):
    """
@@ -32,7 +39,7 @@ class ReturnnForwardJob(Job):

    def __init__(
        self,
-        model_checkpoint: Optional[Union[Checkpoint, PtCheckpoint]],
+        model_checkpoint: Optional[Checkpoint],
        returnn_config: ReturnnConfig,
        returnn_python_exe: tk.Path,
        returnn_root: tk.Path,
@@ -111,7 +118,9 @@ def create_files(self):

        # check here if model actually exists
        if self.model_checkpoint is not None:
-            assert self.model_checkpoint.exists(), "Provided model does not exists: %s" % str(self.model_checkpoint)
+            assert os.path.exists(
+                _get_model_path(self.model_checkpoint).get_path()
+            ), f"Provided model checkpoint does not exist: {self.model_checkpoint}"

    def run(self):
        # run everything in a TempDir as writing HDFs can cause heavy load
@@ -127,7 +136,7 @@ def run(self):
                env = os.environ.copy()
                env["OMP_NUM_THREADS"] = str(self.rqmt["cpu"])
                env["MKL_NUM_THREADS"] = str(self.rqmt["cpu"])
-                sp.check_call(call, cwd=d, env=env)
+                subprocess.check_call(call, cwd=d, env=env)
            except Exception as e:
                print("Run crashed - copy temporary work folder as 'crash_dir'")
                shutil.copytree(d, "crash_dir")
@@ -151,17 +160,17 @@ def create_returnn_config(
        eval_mode: bool,
        log_verbosity: int,
        device: str,
-        **kwargs,
-    ):
+        **_kwargs_unused,
+    ) -> ReturnnConfig:
        """
        Update the config locally to make it ready for the forward/eval task.
        The resulting config will be used for hashing.

        :param model_checkpoint:
        :param returnn_config:
+        :param eval_mode:
        :param log_verbosity:
        :param device:
-        :param kwargs:
        :return:
        """
        assert device in ["gpu", "cpu"]
@@ -207,3 +216,185 @@ def hash(cls, kwargs):
        }

        return super().hash(d)
+
+
+class ReturnnForwardJobV2(Job):
+    """
+    Generic forward job.
+
+    The user specifies the outputs in the RETURNN config
+    via `forward_callback`.
+    That is expected to be an instance of `returnn.forward_iface.ForwardCallbackIface`
+    or a callable/function which returns such an instance.
+
+    The callback is supposed to generate the output files in the current directory.
+    The current directory will be a local temporary directory
+    and the files are moved to the output directory at the end.
+
+    Nothing is enforced here by intention, to keep it generic.
+    The task by default is set to "forward",
+    but other tasks of RETURNN might be used as well.
+    """
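
As a sketch of the callback contract described in the docstring above (not part of this change): a RETURNN config could define `forward_callback` roughly as follows. The method names assume `returnn.forward_iface.ForwardCallbackIface` (init/process_seq/finish) and should be checked against the RETURNN version in use; the file name "output.txt" is a placeholder that would have to be listed in the job's `output_files`.

    # illustrative only; check returnn.forward_iface for the exact interface
    from returnn.forward_iface import ForwardCallbackIface

    class DumpSeqTagsCallback(ForwardCallbackIface):
        def init(self, *, model):
            # called once before forwarding; create the output file in the current directory
            self.out_file = open("output.txt", "w")

        def process_seq(self, *, seq_tag, outputs):
            # called per sequence with the tensors the model marked as outputs
            self.out_file.write(seq_tag + "\n")

        def finish(self):
            # called once at the end; ReturnnForwardJobV2 then moves "output.txt" to its output folder
            self.out_file.close()

    forward_callback = DumpSeqTagsCallback()
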
+
+    def __init__(
+        self,
+        *,
+        model_checkpoint: Optional[Checkpoint],
+        returnn_config: ReturnnConfig,
+        returnn_python_exe: tk.Path,
+        returnn_root: tk.Path,
+        output_files: List[str],
+        log_verbosity: int = 5,
+        device: str = "gpu",
+        time_rqmt: float = 4,
+        mem_rqmt: float = 4,
+        cpu_rqmt: int = 2,
+    ):
+        """
+        :param model_checkpoint: Checkpoint object pointing to a stored RETURNN TensorFlow/PyTorch model,
+            or None if the network has no parameters or should be randomly initialized
+        :param returnn_config: RETURNN config object
+        :param returnn_python_exe: path to the RETURNN executable (python binary or launch script)
+        :param returnn_root: path to the RETURNN src folder
+        :param output_files: list of output file names that will be generated. These are just the basenames,
+            and they are supposed to be created in the current directory.
+        :param log_verbosity: RETURNN log verbosity
+        :param device: RETURNN device, cpu or gpu
+        :param time_rqmt: job time requirement in hours
+        :param mem_rqmt: job memory requirement in GB
+        :param cpu_rqmt: job cpu requirement
+        """
+        self.returnn_config = returnn_config
+        self.model_checkpoint = model_checkpoint
+        self.returnn_python_exe = returnn_python_exe
+        self.returnn_root = returnn_root
+        self.log_verbosity = log_verbosity
+        self.device = device
+
+        self.out_returnn_config_file = self.output_path("returnn.config")
+        self.out_files = {output: self.output_path(output) for output in output_files}
+
+        self.rqmt = {
+            "gpu": 1 if device == "gpu" else 0,
+            "cpu": cpu_rqmt,
+            "mem": mem_rqmt,
+            "time": time_rqmt,
+        }
+
+    def tasks(self):
+        yield Task("create_files", mini_task=True)
+        yield Task("run", resume="run", rqmt=self.rqmt)
+
+    def create_files(self):
+        """create files"""
+        config = self.create_returnn_config(
+            model_checkpoint=self.model_checkpoint,
+            returnn_config=self.returnn_config,
+            log_verbosity=self.log_verbosity,
+            device=self.device,
+        )
+        config.write(self.out_returnn_config_file.get_path())
+
+        cmd = [
+            self.returnn_python_exe.get_path(),
+            os.path.join(self.returnn_root.get_path(), "rnn.py"),
+            self.out_returnn_config_file.get_path(),
+        ]
+        util.create_executable("rnn.sh", cmd)
+
+        # check here if model actually exists
+        if self.model_checkpoint is not None:
+            assert os.path.exists(
+                _get_model_path(self.model_checkpoint).get_path()
+            ), f"Provided model checkpoint does not exist: {self.model_checkpoint}"
+
+    def run(self):
+        """run"""
+        # run everything in a TempDir as writing files can cause heavy load
+        with tempfile.TemporaryDirectory(prefix=gs.TMP_PREFIX) as tmp_dir:
+            print("using temp-dir: %s" % tmp_dir)
+            call = [
+                self.returnn_python_exe.get_path(),
+                os.path.join(self.returnn_root.get_path(), "rnn.py"),
+                self.out_returnn_config_file.get_path(),
+            ]
+
+            try:
+                env = os.environ.copy()
+                env["OMP_NUM_THREADS"] = str(self.rqmt["cpu"])
+                env["MKL_NUM_THREADS"] = str(self.rqmt["cpu"])
+                subprocess.check_call(call, cwd=tmp_dir, env=env)
+            except Exception:
+                print("Run crashed - copy temporary work folder as 'crash_dir'")
+                if os.path.exists("crash_dir"):
+                    shutil.rmtree("crash_dir")
+                shutil.copytree(tmp_dir, "crash_dir", dirs_exist_ok=True)
+                raise
+
+            # move outputs to output folder
+            for k, v in self.out_files.items():
+                assert os.path.exists(f"{tmp_dir}/{k}"), f"Output file {k} does not exist"
+                shutil.move(f"{tmp_dir}/{k}", v.get_path())
+
+            # copy logs and anything else. don't make assumptions on filenames
+            shutil.copytree(tmp_dir, ".", dirs_exist_ok=True)
+
+    @classmethod
+    def create_returnn_config(
+        cls,
+        *,
+        model_checkpoint: Optional[Checkpoint],
+        returnn_config: ReturnnConfig,
+        log_verbosity: int,
+        device: str,
+        **_kwargs,
+    ):
+        """
+        Update the config locally to make it ready for the forward/eval task.
+        The resulting config will be used for hashing.
+
+        :param model_checkpoint:
+        :param returnn_config:
+        :param log_verbosity:
+        :param device:
+        :return:
+        """
+        assert "load" not in returnn_config.config
+        assert "model" not in returnn_config.config
+
+        res = copy.deepcopy(returnn_config)
+
+        res.config.setdefault("task", "forward")
+        if model_checkpoint is not None:
+            res.config["load"] = model_checkpoint
+        else:
+            res.config.setdefault("allow_random_model_init", True)
+
+        res.post_config.setdefault("device", device)
+        res.post_config.setdefault("log", ["./returnn.log"])
+        res.post_config.setdefault("tf_log_dir", "returnn-tf-log")
+        res.post_config.setdefault("log_verbosity", log_verbosity)
+
+        res.check_consistency()
+
+        return res
+
+    @classmethod
+    def hash(cls, kwargs):
+        d = {
+            "returnn_config": cls.create_returnn_config(**kwargs),
+            "returnn_python_exe": kwargs["returnn_python_exe"],
+            "returnn_root": kwargs["returnn_root"],
+        }
+
+        return super().hash(d)
+
+
+def _get_model_path(model: Checkpoint) -> tk.Path:
+    if isinstance(model, tk.Path):
+        return model
+    if isinstance(model, TfCheckpoint):
+        return model.index_path
+    if isinstance(model, PtCheckpoint):
+        return model.path
+    raise TypeError(f"Unknown model checkpoint type: {type(model)}")
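
For context, a hypothetical usage sketch of the new job in a Sisyphus setup; all paths, the config contents, and the output file name are placeholders, and `output_files` must match what the `forward_callback` in the config actually writes:

    from sisyphus import tk

    from i6_core.returnn.config import ReturnnConfig
    from i6_core.returnn.forward import ReturnnForwardJobV2

    # the config must define the model, the data to forward, and forward_callback
    forward_config = ReturnnConfig(config={})

    forward_job = ReturnnForwardJobV2(
        model_checkpoint=tk.Path("/path/to/checkpoint"),  # Checkpoint/PtCheckpoint objects are accepted as well
        returnn_config=forward_config,
        returnn_python_exe=tk.Path("/usr/bin/python3"),
        returnn_root=tk.Path("/path/to/returnn"),
        output_files=["output.txt"],
    )
    tk.register_output("forward/output.txt", forward_job.out_files["output.txt"])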