Skip to content

Commit 8d51b58

Browse files
committed
Internal change
PiperOrigin-RevId: 367564187
1 parent 04b94a5 commit 8d51b58

File tree

6 files changed

+374
-0
lines changed

6 files changed

+374
-0
lines changed

orbit/examples/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright 2021 The Orbit Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright 2021 The Orbit Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
# Copyright 2021 The Orbit Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""An evaluator object that can evaluate models with a single output."""
16+
import orbit
17+
import tensorflow as tf
18+
19+
20+
class SingleTaskEvaluator(orbit.StandardEvaluator):
  """Evaluates a single-output model on a given dataset.

  This evaluator will handle running a model with one output on a single
  dataset, and will apply the output of that model to one or more
  `tf.keras.metrics.Metric` objects.
  """

  def __init__(self,
               eval_dataset,
               label_key,
               model,
               metrics,
               evaluator_options=None):
    """Initializes a `SingleTaskEvaluator` instance.

    If the `SingleTaskEvaluator` should run its model under a distribution
    strategy, it should be created within that strategy's scope.

    Args:
      eval_dataset: A `tf.data.Dataset` or `DistributedDataset` that contains a
        string-keyed dict of `Tensor`s.
      label_key: The key corresponding to the label value in feature
        dictionaries dequeued from `eval_dataset`. This key will be removed from
        the dictionary before it is passed to the model.
      model: A `tf.Module` or Keras `Model` object to evaluate.
      metrics: A single `tf.keras.metrics.Metric` object, or a list or tuple of
        `tf.keras.metrics.Metric` objects. `None` is treated as "no metrics".
      evaluator_options: An optional `orbit.StandardEvaluatorOptions` object.
    """
    self.label_key = label_key
    self.model = model

    # `eval_begin`, `eval_step` and `eval_end` all iterate over `self.metrics`,
    # so normalize whatever was passed in to a list. A caller-supplied list is
    # stored as-is (not copied) so later mutations by the caller stay visible.
    if metrics is None:
      self.metrics = []
    elif isinstance(metrics, list):
      self.metrics = metrics
    elif isinstance(metrics, tuple):
      self.metrics = list(metrics)
    else:
      self.metrics = [metrics]

    # Capture the strategy from the containing scope.
    self.strategy = tf.distribute.get_strategy()

    super().__init__(eval_dataset=eval_dataset, options=evaluator_options)

  def eval_begin(self):
    """Resets all metrics. Called once before every eval loop."""
    for metric in self.metrics:
      # Keras renamed `Metric.reset_states` to `Metric.reset_state`; prefer the
      # new name and fall back to the old one for older releases.
      reset = getattr(metric, 'reset_state', None)
      if reset is None:
        reset = metric.reset_states
      reset()

  def eval_step(self, iterator):
    """One eval step. Called multiple times per eval loop by the superclass.

    Args:
      iterator: An iterator over `eval_dataset`; one element is consumed per
        call.
    """

    def step_fn(inputs):
      # Extract the target value and delete it from the input dict, so that
      # the model never sees it.
      target = inputs.pop(self.label_key)
      output = self.model(inputs)
      for metric in self.metrics:
        metric.update_state(target, output)

    # This is needed to handle distributed computation.
    self.strategy.run(step_fn, args=(next(iterator),))

  def eval_end(self):
    """Exports the metric results. Called once after an eval loop.

    Returns:
      A dict mapping each metric's name to its current result.
    """
    with self.strategy.scope():
      # Export the metrics.
      metrics = {metric.name: metric.result() for metric in self.metrics}

    return metrics
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# Copyright 2021 The Orbit Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for the single_task_evaluator."""
16+
import orbit
17+
from orbit.examples.single_task import single_task_evaluator
18+
from orbit.examples.single_task import single_task_trainer
19+
20+
import tensorflow as tf
21+
import tensorflow_datasets as tfds
22+
23+
24+
class SingleTaskEvaluatorTest(tf.test.TestCase):
  """End-to-end check that trains and then evaluates an iris classifier."""

  def test_single_task_evaluation(self):
    """Trains for one pass over the data, evaluates, and checks accuracy."""
    splits = tfds.load('iris')
    dataset = splits['train'].batch(32)

    # Small MLP: 4 input features -> 3 class logits.
    layers = [
        tf.keras.Input(shape=(4,), name='features'),
        tf.keras.layers.Dense(10, activation=tf.nn.relu),
        tf.keras.layers.Dense(10, activation=tf.nn.relu),
        tf.keras.layers.Dense(3),
    ]
    classifier = tf.keras.Sequential(layers)

    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    sgd = tf.keras.optimizers.SGD(learning_rate=0.01)
    trainer = single_task_trainer.SingleTaskTrainer(
        dataset,
        label_key='label',
        model=classifier,
        loss_fn=loss_fn,
        optimizer=sgd)

    # The same dataset is used for both training and evaluation.
    accuracy_metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]
    evaluator = single_task_evaluator.SingleTaskEvaluator(
        dataset,
        label_key='label',
        model=classifier,
        metrics=accuracy_metrics)

    controller = orbit.Controller(
        trainer=trainer,
        evaluator=evaluator,
        steps_per_loop=100,
        global_step=trainer.optimizer.iterations)

    # One step per batch, i.e. a single pass over the dataset.
    num_batches = dataset.cardinality().numpy()
    controller.train(num_batches)
    controller.evaluate()
    accuracy = evaluator.metrics[0].result().numpy()

    # NOTE(review): `assertGreater(a, b)` checks `a > b`, so this asserts that
    # accuracy is *below* 0.925. Confirm that this direction is intended.
    self.assertGreater(0.925, accuracy)
62+
63+
64+
if __name__ == '__main__':
  # Run the test cases in this module when it is executed as a script.
  tf.test.main()
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
# Copyright 2021 The Orbit Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""A trainer object that can train models with a single output."""
16+
17+
import orbit
18+
import tensorflow as tf
19+
20+
21+
class SingleTaskTrainer(orbit.StandardTrainer):
  """Trains a single-output model on a given dataset.

  This trainer will handle running a model with one output on a single
  dataset. It will apply the provided loss function to the model's output
  to calculate gradients and will apply them via the provided optimizer. It will
  also supply the output of that model to one or more `tf.keras.metrics.Metric`
  objects.
  """

  def __init__(self,
               train_dataset,
               label_key,
               model,
               loss_fn,
               optimizer,
               metrics=None,
               trainer_options=None):
    """Initializes a `SingleTaskTrainer` instance.

    If the `SingleTaskTrainer` should run its model under a distribution
    strategy, it should be created within that strategy's scope.

    This trainer will also calculate metrics during training. The loss metric
    is calculated by default, but other metrics can be passed to the `metrics`
    arg.

    Args:
      train_dataset: A `tf.data.Dataset` or `DistributedDataset` that contains a
        string-keyed dict of `Tensor`s.
      label_key: The key corresponding to the label value in feature
        dictionaries dequeued from `train_dataset`. This key will be removed
        from the dictionary before it is passed to the model.
      model: A `tf.Module` or Keras `Model` object to train. It must accept a
        `training` kwarg.
      loss_fn: A per-element loss function of the form (target, output). The
        output of this loss function will be reduced via `tf.reduce_mean` to
        create the final loss. We recommend using the functions in the
        `tf.keras.losses` package or `tf.keras.losses.Loss` objects with
        `reduction=tf.keras.losses.Reduction.NONE`.
      optimizer: A `tf.keras.optimizers.Optimizer` instance.
      metrics: A single `tf.keras.metrics.Metric` object, or a list or tuple of
        `tf.keras.metrics.Metric` objects. Defaults to no extra metrics.
      trainer_options: An optional `orbit.StandardTrainerOptions` object.
    """
    self.label_key = label_key
    self.model = model
    self.loss_fn = loss_fn
    self.optimizer = optimizer

    # Capture the strategy from the containing scope.
    self.strategy = tf.distribute.get_strategy()

    # We always want to report training loss.
    self.train_loss = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)

    # We need self.metrics to be an iterable later, so we handle that here.
    # A caller-supplied list is stored as-is (not copied); a tuple is converted
    # so that it is not mistaken for a single metric.
    if metrics is None:
      self.metrics = []
    elif isinstance(metrics, list):
      self.metrics = metrics
    elif isinstance(metrics, tuple):
      self.metrics = list(metrics)
    else:
      self.metrics = [metrics]

    super().__init__(train_dataset=train_dataset, options=trainer_options)

  def train_loop_begin(self):
    """Resets the loss and all metrics at the beginning of each train loop."""
    for metric in [self.train_loss, *self.metrics]:
      # Keras renamed `Metric.reset_states` to `Metric.reset_state`; prefer the
      # new name and fall back to the old one for older releases.
      reset = getattr(metric, 'reset_state', None)
      if reset is None:
        reset = metric.reset_states
      reset()

  def train_step(self, iterator):
    """A train step. Called multiple times per train loop by the superclass.

    Args:
      iterator: An iterator over `train_dataset`; one element is consumed per
        call.
    """

    def train_fn(inputs):
      # Extract the target value and delete it from the input dict, so that
      # the model never sees it. This needs no gradient recording, so it is
      # done before the tape is opened.
      target = inputs.pop(self.label_key)

      with tf.GradientTape() as tape:
        # Get the outputs of the model.
        output = self.model(inputs, training=True)

        # Get the average per-batch loss and scale it down by the number of
        # replicas. This ensures that we don't end up multiplying our loss by
        # the number of workers - gradients are summed, not averaged, across
        # replicas during the apply_gradients call.
        loss = tf.reduce_mean(self.loss_fn(target, output))
        scaled_loss = loss / self.strategy.num_replicas_in_sync

      # Get the gradients by applying the loss to the model's trainable
      # variables.
      gradients = tape.gradient(scaled_loss, self.model.trainable_variables)

      # Apply the gradients via the optimizer.
      self.optimizer.apply_gradients(
          list(zip(gradients, self.model.trainable_variables)))

      # Update metrics. The unscaled loss is reported, not `scaled_loss`.
      self.train_loss.update_state(loss)
      for metric in self.metrics:
        metric.update_state(target, output)

    # This is needed to handle distributed computation.
    self.strategy.run(train_fn, args=(next(iterator),))

  def train_loop_end(self):
    """Exports the metric results. Called once after a training loop.

    Returns:
      A dict mapping each metric's name to its current result, including the
      training loss under the name of `self.train_loss`.
    """
    with self.strategy.scope():
      # Export the metrics.
      metrics = {metric.name: metric.result() for metric in self.metrics}
      metrics[self.train_loss.name] = self.train_loss.result()

    return metrics
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright 2021 The Orbit Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for the single_task_trainer."""
16+
import orbit
17+
from orbit.examples.single_task import single_task_trainer
18+
19+
import tensorflow as tf
20+
import tensorflow_datasets as tfds
21+
22+
23+
class SingleTaskTrainerTest(tf.test.TestCase):
  """End-to-end check that `SingleTaskTrainer` reduces the training loss."""

  def test_single_task_training(self):
    """Trains an iris classifier and checks that the loss at least halves."""
    splits = tfds.load('iris')
    # Repeat indefinitely so that training can run for more than one pass.
    dataset = splits['train'].batch(32).repeat()

    # Small MLP: 4 input features -> 3 class logits.
    layers = [
        tf.keras.Input(shape=(4,), name='features'),
        tf.keras.layers.Dense(10, activation=tf.nn.relu),
        tf.keras.layers.Dense(10, activation=tf.nn.relu),
        tf.keras.layers.Dense(3),
    ]
    classifier = tf.keras.Sequential(layers)

    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    sgd = tf.keras.optimizers.SGD(learning_rate=0.01)
    trainer = single_task_trainer.SingleTaskTrainer(
        dataset,
        label_key='label',
        model=classifier,
        loss_fn=loss_fn,
        optimizer=sgd)

    controller = orbit.Controller(
        trainer=trainer,
        steps_per_loop=100,
        global_step=trainer.optimizer.iterations)

    # Record the loss after one step, then again after training to step 500.
    controller.train(1)
    initial_loss = trainer.train_loss.result().numpy()
    controller.train(500)
    final_loss = trainer.train_loss.result().numpy()

    # Assert that the model has trained 'significantly' - that the loss
    # has dropped by over 50%.
    self.assertLess(final_loss, initial_loss / 2)
56+
57+
58+
if __name__ == '__main__':
  # Run the test cases in this module when it is executed as a script.
  tf.test.main()

0 commit comments

Comments
 (0)