Skip to content

Commit 87934a7

Browse files
authored
Add an example of tf.estimator (#2525)
* Add an example of tf.estimator * Import hooks from elasticai_api * Fix by comments
1 parent ddfa3b7 commit 87934a7

File tree

2 files changed

+170
-0
lines changed

2 files changed

+170
-0
lines changed

model_zoo/iris/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# Copyright 2021 The ElasticDL Authors. All rights reserved.
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.

model_zoo/iris/dnn_estimator.py

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
# Copyright 2021 The ElasticDL Authors. All rights reserved.
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
14+
import csv
15+
import os
16+
17+
import tensorflow as tf
18+
19+
from elasticai_api.common.data_shard_service import build_data_shard_service
20+
from elasticai_api.tensorflow.hooks import ElasticDataShardReportHook
21+
22+
tf.logging.set_verbosity(tf.logging.INFO)
23+
24+
CATEGORY_CODE = {"Iris-setosa": 0, "Iris-versicolor": 1, "Iris-virginica": 2}
25+
DATASET_DIR = "/data/iris.data"
26+
27+
28+
def read_csv(file_path):
29+
rows = []
30+
with open(file_path) as csvfile:
31+
csv_reader = csv.reader(csvfile)
32+
for row in csv_reader:
33+
rows.append(row)
34+
return rows
35+
36+
37+
def model_fn(features, labels, mode, params):
38+
net = tf.feature_column.input_layer(features, params["feature_columns"])
39+
40+
for units in params["hidden_units"]:
41+
net = tf.layers.dense(net, units=units, activation=tf.nn.relu)
42+
logits = tf.layers.dense(net, params["n_classes"], activation=None)
43+
44+
predicted_classes = tf.argmax(logits, 1)
45+
if mode == tf.estimator.ModeKeys.PREDICT:
46+
predictions = {
47+
"classes": predicted_classes[:, tf.newaxis],
48+
"probs": tf.nn.softmax(logits),
49+
"logits": logits,
50+
}
51+
export_outputs = {
52+
"prediction": tf.estimator.export.PredictOutput(predictions)
53+
}
54+
return tf.estimator.EstimatorSpec(
55+
mode, predictions=predictions, export_outputs=export_outputs
56+
)
57+
58+
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
59+
if mode == tf.estimator.ModeKeys.TRAIN:
60+
optimizer = tf.train.AdagradOptimizer(learning_rate=0.1)
61+
train_op = optimizer.minimize(
62+
loss, global_step=tf.train.get_global_step()
63+
)
64+
logging_hook = tf.train.LoggingTensorHook(
65+
{"loss": loss}, every_n_iter=10
66+
)
67+
return tf.estimator.EstimatorSpec(
68+
mode, loss=loss, train_op=train_op, training_hooks=[logging_hook]
69+
)
70+
71+
accuracy = tf.metrics.accuracy(
72+
labels=labels, predictions=predicted_classes, name="acc"
73+
)
74+
metrics = {"accuracy": accuracy}
75+
if mode == tf.estimator.ModeKeys.EVAL:
76+
return tf.estimator.EstimatorSpec(
77+
mode, loss=loss, eval_metric_ops=metrics
78+
)
79+
80+
81+
def train_generator(shard_service):
82+
rows = read_csv(DATASET_DIR)
83+
while True:
84+
# Read samples by the range of the shard from
85+
# the data shard serice.
86+
shard = shard_service.fetch_shard()
87+
if not shard:
88+
break
89+
for i in range(shard.start, shard.end):
90+
label = CATEGORY_CODE[rows[i][-1]]
91+
yield rows[i][0:-1], [label]
92+
93+
94+
def eval_generator():
95+
rows = read_csv(DATASET_DIR)
96+
for row in rows:
97+
label = CATEGORY_CODE[row[-1]]
98+
yield row[0:-1], [label]
99+
100+
101+
def input_fn(sample_generator, batch_size):
102+
dataset = tf.data.Dataset.from_generator(
103+
sample_generator,
104+
output_types=(tf.float32, tf.int32),
105+
output_shapes=(4, 1),
106+
)
107+
dataset = dataset.shuffle(100).batch(batch_size)
108+
feature_values, label_values = dataset.make_one_shot_iterator().get_next()
109+
features = {"x": feature_values}
110+
return features, label_values
111+
112+
113+
if __name__ == "__main__":
114+
model_dir = "/data/ckpts/"
115+
batch_size = 64
116+
feature_columns = [
117+
tf.feature_column.numeric_column(key="x", shape=(4,), dtype=tf.float32)
118+
]
119+
os.makedirs(model_dir, exist_ok=True)
120+
121+
config = tf.estimator.RunConfig(
122+
model_dir=model_dir, save_checkpoints_steps=300, keep_checkpoint_max=5
123+
)
124+
classifier = tf.estimator.Estimator(
125+
model_fn=model_fn,
126+
config=config,
127+
params={
128+
"hidden_units": [8, 4],
129+
"n_classes": 3,
130+
"feature_columns": feature_columns,
131+
},
132+
)
133+
134+
# Create a data shard service which can split the dataset
135+
# into shards.
136+
rows = read_csv(DATASET_DIR)
137+
training_data_shard_svc = build_data_shard_service(
138+
batch_size=batch_size,
139+
num_epochs=100,
140+
dataset_size=len(rows),
141+
num_minibatches_per_shard=1,
142+
dataset_name="iris_training_data",
143+
)
144+
145+
# Add a hook to report the shard done so that the data
146+
# shard service will not reassign the shard to other workers.
147+
hooks = [ElasticDataShardReportHook(training_data_shard_svc)]
148+
149+
def train_input_fn():
150+
return input_fn(
151+
lambda: train_generator(training_data_shard_svc), batch_size
152+
)
153+
154+
train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, hooks=hooks,)
155+
eval_spec = tf.estimator.EvalSpec(
156+
input_fn=lambda: input_fn(eval_generator, batch_size)
157+
)
158+
tf.estimator.train_and_evaluate(classifier, train_spec, eval_spec)

0 commit comments

Comments
 (0)