2626 theano = None
2727
2828
29- def build_env ():
29+ def build_env (env_update = None ):
30+ """
31+ :param dict[str,str]|None env_update:
32+ :return: env dict for Popen
33+ :rtype: dict[str,str]
34+ """
3035 theano_flags = {key : value for (key , value )
3136 in [s .split ("=" , 1 ) for s in os .environ .get ("THEANO_FLAGS" , "" ).split ("," ) if s ]}
3237 # First set some sane default for compile dir.
@@ -38,16 +43,18 @@ def build_env():
3843 theano_flags ["compiledir_format" ] += "--nosetests"
3944 # Nose-tests will set mode=FAST_COMPILE. We don't want this for our tests as it is way too slow.
4045 theano_flags ["mode" ] = "FAST_RUN"
41- env_update = os .environ .copy ()
42- env_update ["THEANO_FLAGS" ] = "," .join (["%s=%s" % (key , value ) for (key , value ) in theano_flags .items ()])
43- return env_update
46+ env_update_ = os .environ .copy ()
47+ env_update_ ["THEANO_FLAGS" ] = "," .join (["%s=%s" % (key , value ) for (key , value ) in theano_flags .items ()])
48+ if env_update :
49+ env_update_ .update (env_update )
50+ return env_update_
4451
4552
46- def run (* args ):
53+ def run (* args , env_update = None ):
4754 args = list (args )
4855 print ("run:" , args )
4956 # RETURNN by default outputs on stderr, so just merge both together
50- p = Popen (args , stdout = PIPE , stderr = STDOUT , env = build_env ())
57+ p = Popen (args , stdout = PIPE , stderr = STDOUT , env = build_env (env_update = env_update ))
5158 out , _ = p .communicate ()
5259 if p .returncode != 0 :
5360 print ("Return code is %i" % p .returncode )
@@ -56,8 +63,8 @@ def run(*args):
5663 return out .decode ("utf8" )
5764
5865
59- def run_and_parse_last_fer (* args ):
60- out = run (* args )
66+ def run_and_parse_last_fer (* args , ** kwargs ):
67+ out = run (* args , ** kwargs )
6168 parsed_fer = None
6269 for l in out .splitlines ():
6370 # example: epoch 5 score: 0.0231807245472 elapsed: 0:00:04 dev: score 0.0137521058997 error 0.00268961807423
@@ -69,9 +76,10 @@ def run_and_parse_last_fer(*args):
6976 return parsed_fer
7077
7178
72- def run_config_get_fer (config_filename ):
79+ def run_config_get_fer (config_filename , env_update = None ):
7380 cleanup_tmp_models (config_filename )
74- fer = run_and_parse_last_fer (py , "rnn.py" , config_filename , "++log_verbosity" , "5" )
81+ fer = run_and_parse_last_fer (
82+ py , "rnn.py" , config_filename , "++log_verbosity" , "5" , env_update = env_update )
7583 cleanup_tmp_models (config_filename )
7684 return fer
7785
@@ -90,11 +98,21 @@ def cleanup_tmp_models(config_filename):
9098
9199
92100@unittest .skipIf (not theano , "Theano not installed" )
93- def test_demo_task12ax ():
101+ def test_demo_theano_task12ax ():
94102 fer = run_config_get_fer ("demos/demo-theano-task12ax.config" )
95103 assert_less (fer , 0.01 )
96104
97105
106+ def test_demo_tf_task12ax ():
107+ fer = run_config_get_fer ("demos/demo-tf-native-lstm.12ax.config" )
108+ assert_less (fer , 0.01 )
109+
110+
111+ def test_demo_tf_task12ax_no_test_env ():
112+ fer = run_config_get_fer ("demos/demo-tf-native-lstm.12ax.config" , env_update = {"RETURNN_TEST" : "" })
113+ assert_less (fer , 0.01 )
114+
115+
98116def test_demo_iter_dataset_task12ax ():
99117 cleanup_tmp_models ("demos/demo-theano-task12ax.config" )
100118 out = run (py , "demos/demo-iter-dataset.py" , "demos/demo-theano-task12ax.config" )
0 commit comments