Skip to content

Commit 9cc7eac

Browse files
author
Taylor Robie
authored
Add End-to-end tests for wide deep, and fix "wide" and "deep" configurations. (#3798)
* add end-to-end tests for wide_deep
* delint
* address PR comments
1 parent eb73a85 commit 9cc7eac

File tree

3 files changed

+49
-5
lines changed

3 files changed

+49
-5
lines changed

official/utils/testing/integration.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
import tempfile
2727

2828

29-
def run_synthetic(main, tmp_root, extra_flags=None):
29+
def run_synthetic(main, tmp_root, extra_flags=None, synth=True, max_train=1):
3030
"""Performs a minimal run of a model.
3131
3232
This function is intended to test for syntax errors throughout a model. A
@@ -37,15 +37,22 @@ def run_synthetic(main, tmp_root, extra_flags=None):
3737
function is "<MODULE>.main(argv)".
3838
tmp_root: Root path for the temp directory created by the test class.
3939
extra_flags: Additional flags passed by the caller of this function.
40+
synth: Use synthetic data.
41+
max_train: Maximum number of allowed training steps.
4042
"""
4143

4244
extra_flags = [] if extra_flags is None else extra_flags
4345

4446
model_dir = tempfile.mkdtemp(dir=tmp_root)
4547

4648
args = [sys.argv[0], "--model_dir", model_dir, "--train_epochs", "1",
47-
"--epochs_between_evals", "1", "--use_synthetic_data",
48-
"--max_train_steps", "1"] + extra_flags
49+
"--epochs_between_evals", "1"] + extra_flags
50+
51+
if synth:
52+
args.append("--use_synthetic_data")
53+
54+
if max_train is not None:
55+
args.extend(["--max_train_steps", str(max_train)])
4956

5057
try:
5158
main(args)

official/wide_deep/wide_deep.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@
4343
}
4444

4545

# Maps --model_type to the name-scope prefix under which the canned
# estimator places its head ops: LinearClassifier builds under 'linear/'
# and DNNClassifier under 'dnn/'.  The combined 'wide_deep' estimator is
# intentionally absent (callers fall back to an empty prefix).
LOSS_PREFIX = dict(wide='linear/', deep='dnn/')
4649
def build_model_columns():
4750
"""Builds a set of wide and deep feature columns."""
4851
# Continuous columns
@@ -190,10 +193,11 @@ def train_input_fn():
190193
def eval_input_fn():
191194
return input_fn(test_file, 1, False, flags.batch_size)
192195

196+
loss_prefix = LOSS_PREFIX.get(flags.model_type, '')
193197
train_hooks = hooks_helper.get_train_hooks(
194198
flags.hooks, batch_size=flags.batch_size,
195-
tensors_to_log={'average_loss': 'head/truediv',
196-
'loss': 'head/weighted_loss/Sum'})
199+
tensors_to_log={'average_loss': loss_prefix + 'head/truediv',
200+
'loss': loss_prefix + 'head/weighted_loss/Sum'})
197201

198202
# Train and evaluate the model every `flags.epochs_between_evals` epochs.
199203
for n in range(flags.train_epochs // flags.epochs_between_evals):

official/wide_deep/wide_deep_test.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import tensorflow as tf # pylint: disable=g-bad-import-order
2323

24+
from official.utils.testing import integration
2425
from official.wide_deep import wide_deep
2526

2627
tf.logging.set_verbosity(tf.logging.ERROR)
@@ -54,6 +55,14 @@ def setUp(self):
5455
with tf.gfile.Open(self.input_csv, 'w') as temp_csv:
5556
temp_csv.write(TEST_INPUT)
5657

58+
with tf.gfile.Open(TEST_CSV, "r") as temp_csv:
59+
test_csv_contents = temp_csv.read()
60+
61+
# Used for end-to-end tests.
62+
for fname in ['adult.data', 'adult.test']:
63+
with tf.gfile.Open(os.path.join(self.temp_dir, fname), 'w') as test_csv:
64+
test_csv.write(test_csv_contents)
65+
5766
def test_input_fn(self):
5867
dataset = wide_deep.input_fn(self.input_csv, 1, False, 1)
5968
features, labels = dataset.make_one_shot_iterator().get_next()
@@ -107,6 +116,30 @@ def input_fn():
107116
def test_wide_deep_estimator_training(self):
108117
self.build_and_test_estimator('wide_deep')
109118

119+
def _run_end_to_end(self, model_type):
  """Run wide_deep.main() end to end for one --model_type.

  Uses the real census CSV fixtures written by setUp (synth=False) and
  places no cap on training steps (max_train=None), so the run exercises
  the full input pipeline and training loop.

  Args:
    model_type: One of 'wide', 'deep', or 'wide_deep'.
  """
  integration.run_synthetic(
      main=wide_deep.main, tmp_root=self.get_temp_dir(), extra_flags=[
          '--data_dir', self.get_temp_dir(),
          '--model_type', model_type,
      ],
      synth=False, max_train=None)

def test_end_to_end_wide(self):
  self._run_end_to_end('wide')

def test_end_to_end_deep(self):
  self._run_end_to_end('deep')

def test_end_to_end_wide_deep(self):
  self._run_end_to_end('wide_deep')
142+
110143

111144
if __name__ == '__main__':
112145
tf.test.main()

0 commit comments

Comments (0)