Skip to content

Commit 191d99a

Browse files
authored
Make boosted_trees Garden-official (#4377)
* Make boosted_trees Garden-official
* Fix nits
1 parent 1886043 commit 191d99a

File tree

4 files changed

+148
-116
lines changed

4 files changed

+148
-116
lines changed

official/boosted_trees/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ Note that the model_dir is cleaned up before every time training starts.
3939

4040
Model parameters can be adjusted by flags, like `--n_trees`, `--max_depth`, `--learning_rate` and so on. Check out the code for details.
4141

42-
The final accuacy will be around 74% and loss will be around 0.516 over the eval set, when trained with the default parameters.
42+
The final accuracy will be around 74% and loss will be around 0.516 over the eval set, when trained with the default parameters.
4343

4444
By default, the first 1 million of the 11 million examples are used for training, and the last 1 million examples are used for evaluation.
4545
The training/evaluation data can be selected as index ranges by flags `--train_start`, `--train_count`, `--eval_start`, `--eval_count`, etc.

official/boosted_trees/data_download.py

Lines changed: 32 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -12,59 +12,54 @@
1212
from __future__ import division
1313
from __future__ import print_function
1414

15-
import argparse
15+
import gzip
1616
import os
17-
import sys
1817
import tempfile
1918

19+
# pylint: disable=g-bad-import-order
2020
import numpy as np
2121
import pandas as pd
2222
from six.moves import urllib
23+
from absl import app as absl_app
24+
from absl import flags
2325
import tensorflow as tf
2426

25-
URL_ROOT = 'https://archive.ics.uci.edu/ml/machine-learning-databases/00280'
26-
INPUT_FILE = 'HIGGS.csv.gz'
27-
NPZ_FILE = 'HIGGS.csv.gz.npz' # numpy compressed file to contain 'data' array.
27+
from official.utils.flags import core as flags_core
2828

29-
30-
def parse_args():
31-
"""Parses arguments and returns a tuple (known_args, unparsed_args)."""
32-
parser = argparse.ArgumentParser()
33-
parser.add_argument(
34-
'--data_dir', type=str, default='/tmp/higgs_data',
35-
help='Directory to download higgs dataset and store training/eval data.')
36-
return parser.parse_known_args()
29+
URL_ROOT = "https://archive.ics.uci.edu/ml/machine-learning-databases/00280"
30+
INPUT_FILE = "HIGGS.csv.gz"
31+
NPZ_FILE = "HIGGS.csv.gz.npz" # numpy compressed file to contain "data" array.
3732

3833

3934
def _download_higgs_data_and_save_npz(data_dir):
4035
"""Download higgs data and store as a numpy compressed file."""
4136
input_url = os.path.join(URL_ROOT, INPUT_FILE)
4237
np_filename = os.path.join(data_dir, NPZ_FILE)
4338
if tf.gfile.Exists(np_filename):
44-
raise ValueError('data_dir already has the processed data file: {}'.format(
39+
raise ValueError("data_dir already has the processed data file: {}".format(
4540
np_filename))
4641
if not tf.gfile.Exists(data_dir):
4742
tf.gfile.MkDir(data_dir)
4843
# 2.8 GB to download.
4944
try:
50-
print('Data downloading..')
45+
tf.logging.info("Data downloading...")
5146
temp_filename, _ = urllib.request.urlretrieve(input_url)
52-
5347
# Reading and parsing 11 million csv lines takes 2~3 minutes.
54-
print('Data processing.. taking multiple minutes..')
55-
data = pd.read_csv(
56-
temp_filename,
57-
dtype=np.float32,
58-
names=['c%02d' % i for i in range(29)] # label + 28 features.
59-
).as_matrix()
48+
tf.logging.info("Data processing... taking multiple minutes...")
49+
with gzip.open(temp_filename, "rb") as csv_file:
50+
data = pd.read_csv(
51+
csv_file,
52+
dtype=np.float32,
53+
names=["c%02d" % i for i in range(29)] # label + 28 features.
54+
).as_matrix()
6055
finally:
61-
os.remove(temp_filename)
56+
tf.gfile.Remove(temp_filename)
6257

6358
# Writing to temporary location then copy to the data_dir (0.8 GB).
6459
f = tempfile.NamedTemporaryFile()
6560
np.savez_compressed(f, data=data)
6661
tf.gfile.Copy(f.name, np_filename)
67-
print('Data saved to: {}'.format(np_filename))
62+
tf.logging.info("Data saved to: {}".format(np_filename))
6863

6964

7065
def main(unused_argv):
@@ -73,6 +68,16 @@ def main(unused_argv):
7368
_download_higgs_data_and_save_npz(FLAGS.data_dir)
7469

7570

76-
if __name__ == '__main__':
77-
FLAGS, unparsed = parse_args()
78-
tf.app.run(argv=[sys.argv[0]] + unparsed)
71+
def define_data_download_flags():
72+
"""Add flags specifying data download arguments."""
73+
flags.DEFINE_string(
74+
name="data_dir", default="/tmp/higgs_data",
75+
help=flags_core.help_wrap(
76+
"Directory to download higgs dataset and store training/eval data."))
77+
78+
79+
if __name__ == "__main__":
80+
tf.logging.set_verbosity(tf.logging.INFO)
81+
define_data_download_flags()
82+
FLAGS = flags.FLAGS
83+
absl_app.run(main)

official/boosted_trees/train_higgs.py

Lines changed: 83 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -29,64 +29,44 @@
2929
from __future__ import division
3030
from __future__ import print_function
3131

32-
import argparse
3332
import os
34-
import sys
3533

34+
# pylint: disable=g-bad-import-order
35+
import numpy as np
3636
from absl import app as absl_app
3737
from absl import flags
38-
import numpy as np # pylint: disable=wrong-import-order
39-
import tensorflow as tf # pylint: disable=wrong-import-order
38+
import tensorflow as tf
39+
# pylint: enable=g-bad-import-order
4040

4141
from official.utils.flags import core as flags_core
4242
from official.utils.flags._conventions import help_wrap
43+
from official.utils.logs import logger
4344

45+
NPZ_FILE = "HIGGS.csv.gz.npz" # numpy compressed file containing "data" array
4446

45-
NPZ_FILE = 'HIGGS.csv.gz.npz' # numpy compressed file containing 'data' array
46-
47-
48-
def define_train_higgs_flags():
49-
"""Add tree related flags as well as training/eval configuration."""
50-
flags_core.define_base(stop_threshold=False, batch_size=False, num_gpu=False)
51-
flags.adopt_module_key_flags(flags_core)
52-
53-
flags.DEFINE_integer(
54-
name='train_start', default=0,
55-
help=help_wrap('Start index of train examples within the data.'))
56-
flags.DEFINE_integer(
57-
name='train_count', default=1000000,
58-
help=help_wrap('Number of train examples within the data.'))
59-
flags.DEFINE_integer(
60-
name='eval_start', default=10000000,
61-
help=help_wrap('Start index of eval examples within the data.'))
62-
flags.DEFINE_integer(
63-
name='eval_count', default=1000000,
64-
help=help_wrap('Number of eval examples within the data.'))
65-
66-
flags.DEFINE_integer(
67-
'n_trees', default=100, help=help_wrap('Number of trees to build.'))
68-
flags.DEFINE_integer(
69-
'max_depth', default=6, help=help_wrap('Maximum depths of each tree.'))
70-
flags.DEFINE_float(
71-
'learning_rate', default=0.1,
72-
help=help_wrap('Maximum depths of each tree.'))
73-
74-
flags_core.set_defaults(data_dir='/tmp/higgs_data',
75-
model_dir='/tmp/higgs_model')
7647

48+
def read_higgs_data(data_dir, train_start, train_count, eval_start, eval_count):
49+
"""Reads higgs data from csv and returns train and eval data.
7750
51+
Args:
52+
data_dir: A string, the directory of higgs dataset.
53+
train_start: An integer, the start index of train examples within the data.
54+
train_count: An integer, the number of train examples within the data.
55+
eval_start: An integer, the start index of eval examples within the data.
56+
eval_count: An integer, the number of eval examples within the data.
7857
79-
def read_higgs_data(data_dir, train_start, train_count, eval_start, eval_count):
80-
"""Reads higgs data from csv and returns train and eval data."""
58+
Returns:
59+
Numpy array of train data and eval data.
60+
"""
8161
npz_filename = os.path.join(data_dir, NPZ_FILE)
8262
try:
8363
# gfile allows numpy to read data from network data sources as well.
84-
with tf.gfile.Open(npz_filename, 'rb') as npz_file:
64+
with tf.gfile.Open(npz_filename, "rb") as npz_file:
8565
with np.load(npz_file) as npz:
86-
data = npz['data']
66+
data = npz["data"]
8767
except Exception as e:
8868
raise RuntimeError(
89-
'Error loading data; use data_download.py to prepare the data:\n{}: {}'
69+
"Error loading data; use data_download.py to prepare the data:\n{}: {}"
9070
.format(type(e).__name__, e))
9171
return (data[train_start:train_start+train_count],
9272
data[eval_start:eval_start+eval_count])
@@ -105,18 +85,18 @@ def make_inputs_from_np_arrays(features_np, label_np):
10585
as a single tensor. Don't use batch.
10686
10787
Args:
108-
features_np: a numpy ndarray (shape=[batch_size, num_features]) for
88+
features_np: A numpy ndarray (shape=[batch_size, num_features]) for
10989
float32 features.
110-
label_np: a numpy ndarray (shape=[batch_size, 1]) for labels.
90+
label_np: A numpy ndarray (shape=[batch_size, 1]) for labels.
11191
11292
Returns:
113-
input_fn: a function returning a Dataset of feature dict and label.
114-
feature_column: a list of tf.feature_column.BucketizedColumn.
93+
input_fn: A function returning a Dataset of feature dict and label.
94+
feature_column: A list of tf.feature_column.BucketizedColumn.
11595
"""
11696
num_features = features_np.shape[1]
11797
features_np_list = np.split(features_np, num_features, axis=1)
11898
# 1-based feature names.
119-
feature_names = ['feature_%02d' % (i + 1) for i in range(num_features)]
99+
feature_names = ["feature_%02d" % (i + 1) for i in range(num_features)]
120100

121101
# Create source feature_columns and bucketized_columns.
122102
def get_bucket_boundaries(feature):
@@ -155,16 +135,16 @@ def make_eval_inputs_from_np_arrays(features_np, label_np):
155135
num_features = features_np.shape[1]
156136
features_np_list = np.split(features_np, num_features, axis=1)
157137
# 1-based feature names.
158-
feature_names = ['feature_%02d' % (i + 1) for i in range(num_features)]
138+
feature_names = ["feature_%02d" % (i + 1) for i in range(num_features)]
159139

160140
def input_fn():
161141
features = {
162142
feature_name: tf.constant(features_np_list[i])
163143
for i, feature_name in enumerate(feature_names)
164144
}
165-
return tf.data.Dataset.zip(
166-
(tf.data.Dataset.from_tensor_slices(features),
167-
tf.data.Dataset.from_tensor_slices(label_np),)).batch(1000)
145+
return tf.data.Dataset.zip((
146+
tf.data.Dataset.from_tensor_slices(features),
147+
tf.data.Dataset.from_tensor_slices(label_np),)).batch(1000)
168148

169149
return input_fn
170150

@@ -175,22 +155,37 @@ def train_boosted_trees(flags_obj):
175155
Args:
176156
flags_obj: An object containing parsed flag values.
177157
"""
178-
179158
# Clean up the model directory if present.
180159
if tf.gfile.Exists(flags_obj.model_dir):
181160
tf.gfile.DeleteRecursively(flags_obj.model_dir)
182-
print('## data loading..')
161+
tf.logging.info("## Data loading...")
183162
train_data, eval_data = read_higgs_data(
184163
flags_obj.data_dir, flags_obj.train_start, flags_obj.train_count,
185164
flags_obj.eval_start, flags_obj.eval_count)
186-
print('## data loaded; train: {}{}, eval: {}{}'.format(
165+
tf.logging.info("## Data loaded; train: {}{}, eval: {}{}".format(
187166
train_data.dtype, train_data.shape, eval_data.dtype, eval_data.shape))
188-
# data consists of one label column and 28 feature columns following.
167+
168+
# Data consists of one label column followed by 28 feature columns.
189169
train_input_fn, feature_columns = make_inputs_from_np_arrays(
190170
features_np=train_data[:, 1:], label_np=train_data[:, 0:1])
191171
eval_input_fn = make_eval_inputs_from_np_arrays(
192172
features_np=eval_data[:, 1:], label_np=eval_data[:, 0:1])
193-
print('## features prepared. training starts..')
173+
tf.logging.info("## Features prepared. Training starts...")
174+
175+
# Create benchmark logger to log info about the training and metric values
176+
run_params = {
177+
"train_start": flags_obj.train_start,
178+
"train_count": flags_obj.train_count,
179+
"eval_start": flags_obj.eval_start,
180+
"eval_count": flags_obj.eval_count,
181+
"n_trees": flags_obj.n_trees,
182+
"max_depth": flags_obj.max_depth,
183+
}
184+
benchmark_logger = logger.config_benchmark_logger(flags_obj)
185+
benchmark_logger.log_run_info(
186+
model_name="boosted_trees",
187+
dataset_name="higgs",
188+
run_params=run_params)
194189

195190
# Though BoostedTreesClassifier is under tf.estimator, faster in-memory
196191
# training is yet provided as a contrib library.
@@ -203,7 +198,9 @@ def train_boosted_trees(flags_obj):
203198
learning_rate=flags_obj.learning_rate)
204199

205200
# Evaluation.
206-
eval_result = classifier.evaluate(eval_input_fn)
201+
eval_results = classifier.evaluate(eval_input_fn)
202+
# Benchmark the evaluation results
203+
benchmark_logger.log_evaluation_result(eval_results)
207204

208205
# Exporting the savedmodel.
209206
if flags_obj.export_dir is not None:
@@ -216,7 +213,37 @@ def main(_):
216213
train_boosted_trees(flags.FLAGS)
217214

218215

219-
if __name__ == '__main__':
216+
def define_train_higgs_flags():
217+
"""Add tree related flags as well as training/eval configuration."""
218+
flags_core.define_base(stop_threshold=False, batch_size=False, num_gpu=False)
219+
flags.adopt_module_key_flags(flags_core)
220+
221+
flags.DEFINE_integer(
222+
name="train_start", default=0,
223+
help=help_wrap("Start index of train examples within the data."))
224+
flags.DEFINE_integer(
225+
name="train_count", default=1000000,
226+
help=help_wrap("Number of train examples within the data."))
227+
flags.DEFINE_integer(
228+
name="eval_start", default=10000000,
229+
help=help_wrap("Start index of eval examples within the data."))
230+
flags.DEFINE_integer(
231+
name="eval_count", default=1000000,
232+
help=help_wrap("Number of eval examples within the data."))
233+
234+
flags.DEFINE_integer(
235+
"n_trees", default=100, help=help_wrap("Number of trees to build."))
236+
flags.DEFINE_integer(
237+
"max_depth", default=6, help=help_wrap("Maximum depths of each tree."))
238+
flags.DEFINE_float(
239+
"learning_rate", default=0.1,
240+
help=help_wrap("The learning rate."))
241+
242+
flags_core.set_defaults(data_dir="/tmp/higgs_data",
243+
model_dir="/tmp/higgs_model")
244+
245+
246+
if __name__ == "__main__":
220247
# Training progress and eval results are logged at logging.INFO, so enable that level.
221248
tf.logging.set_verbosity(tf.logging.INFO)
222249
define_train_higgs_flags()

0 commit comments

Comments
 (0)