Skip to content

Commit bd2f2d8

Browse files
committed
test demos, include test without test env
1 parent 21d6b64 commit bd2f2d8

File tree

2 files changed

+32
-14
lines changed

2 files changed

+32
-14
lines changed

demos/demo-tf-native-lstm.12ax.config

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,15 @@ max_seqs = 10
1919
chunking = "200:200"
2020

2121
network = {
22-
"fw0": {"class": "rec", "unit": "nativelstm", "dropout": 0.1, "n_out": 10},
23-
"output": {"class": "softmax", "loss": "ce", "from": ["fw0"]}
22+
"fw0": {"class": "rec", "unit": "nativelstm", "dropout": 0.1, "n_out": 10, "from": "data"},
23+
"output": {"class": "softmax", "loss": "ce", "from": "fw0"}
2424
}
2525

2626
# training
2727
adam = True
2828
learning_rate = 0.01
2929
model = "/tmp/%s/returnn/%s/model" % (os.getlogin(), demo_name) # https://github.com/tensorflow/tensorflow/issues/6537
30-
num_epochs = 100
30+
num_epochs = 5
3131
save_interval = 1
3232
gradient_clip = 0
3333

tests/test_demos.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,12 @@
2626
theano = None
2727

2828

29-
def build_env():
29+
def build_env(env_update=None):
30+
"""
31+
:param dict[str,str]|None env_update:
32+
:return: env dict for Popen
33+
:rtype: dict[str,str]
34+
"""
3035
theano_flags = {key: value for (key, value)
3136
in [s.split("=", 1) for s in os.environ.get("THEANO_FLAGS", "").split(",") if s]}
3237
# First set some sane default for compile dir.
@@ -38,16 +43,18 @@ def build_env():
3843
theano_flags["compiledir_format"] += "--nosetests"
3944
# Nose-tests will set mode=FAST_COMPILE. We don't want this for our tests as it is way too slow.
4045
theano_flags["mode"] = "FAST_RUN"
41-
env_update = os.environ.copy()
42-
env_update["THEANO_FLAGS"] = ",".join(["%s=%s" % (key, value) for (key, value) in theano_flags.items()])
43-
return env_update
46+
env_update_ = os.environ.copy()
47+
env_update_["THEANO_FLAGS"] = ",".join(["%s=%s" % (key, value) for (key, value) in theano_flags.items()])
48+
if env_update:
49+
env_update_.update(env_update)
50+
return env_update_
4451

4552

46-
def run(*args):
53+
def run(*args, env_update=None):
4754
args = list(args)
4855
print("run:", args)
4956
# RETURNN by default outputs on stderr, so just merge both together
50-
p = Popen(args, stdout=PIPE, stderr=STDOUT, env=build_env())
57+
p = Popen(args, stdout=PIPE, stderr=STDOUT, env=build_env(env_update=env_update))
5158
out, _ = p.communicate()
5259
if p.returncode != 0:
5360
print("Return code is %i" % p.returncode)
@@ -56,8 +63,8 @@ def run(*args):
5663
return out.decode("utf8")
5764

5865

59-
def run_and_parse_last_fer(*args):
60-
out = run(*args)
66+
def run_and_parse_last_fer(*args, **kwargs):
67+
out = run(*args, **kwargs)
6168
parsed_fer = None
6269
for l in out.splitlines():
6370
# example: epoch 5 score: 0.0231807245472 elapsed: 0:00:04 dev: score 0.0137521058997 error 0.00268961807423
@@ -69,9 +76,10 @@ def run_and_parse_last_fer(*args):
6976
return parsed_fer
7077

7178

72-
def run_config_get_fer(config_filename):
79+
def run_config_get_fer(config_filename, env_update=None):
7380
cleanup_tmp_models(config_filename)
74-
fer = run_and_parse_last_fer(py, "rnn.py", config_filename, "++log_verbosity", "5")
81+
fer = run_and_parse_last_fer(
82+
py, "rnn.py", config_filename, "++log_verbosity", "5", env_update=env_update)
7583
cleanup_tmp_models(config_filename)
7684
return fer
7785

@@ -90,11 +98,21 @@ def cleanup_tmp_models(config_filename):
9098

9199

92100
@unittest.skipIf(not theano, "Theano not installed")
93-
def test_demo_task12ax():
101+
def test_demo_theano_task12ax():
94102
fer = run_config_get_fer("demos/demo-theano-task12ax.config")
95103
assert_less(fer, 0.01)
96104

97105

106+
def test_demo_tf_task12ax():
107+
fer = run_config_get_fer("demos/demo-tf-native-lstm.12ax.config")
108+
assert_less(fer, 0.01)
109+
110+
111+
def test_demo_tf_task12ax_no_test_env():
112+
fer = run_config_get_fer("demos/demo-tf-native-lstm.12ax.config", env_update={"RETURNN_TEST": ""})
113+
assert_less(fer, 0.01)
114+
115+
98116
def test_demo_iter_dataset_task12ax():
99117
cleanup_tmp_models("demos/demo-theano-task12ax.config")
100118
out = run(py, "demos/demo-iter-dataset.py", "demos/demo-theano-task12ax.config")

0 commit comments

Comments
 (0)