
Commit e773b9b

Open source the second half of multi-task library
PiperOrigin-RevId: 365085378
1 parent 983837f commit e773b9b

File tree: 10 files changed, +1037 -0 lines changed
official/modeling/multitask/base_trainer.py

Lines changed: 176 additions & 0 deletions
@@ -0,0 +1,176 @@
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Multitask base trainer implementation.

The trainer derives from the Orbit `StandardTrainer` class.
"""
from typing import Union
import gin
import orbit
import tensorflow as tf

from official.modeling.multitask import base_model
from official.modeling.multitask import multitask


@gin.configurable
class MultiTaskBaseTrainer(orbit.StandardTrainer):
  """Multitask base trainer."""

  def __init__(self,
               multi_task: multitask.MultiTask,
               multi_task_model: Union[tf.keras.Model,
                                       base_model.MultiTaskBaseModel],
               optimizer: tf.optimizers.Optimizer,
               trainer_options=None):
    self._strategy = tf.distribute.get_strategy()
    self._multi_task = multi_task
    self._multi_task_model = multi_task_model
    self._optimizer = optimizer

    self._training_losses = None
    self._training_metrics = None
    self._global_step = orbit.utils.create_global_step()

    if hasattr(self.multi_task_model, "checkpoint_items"):
      checkpoint_items = self.multi_task_model.checkpoint_items
    else:
      checkpoint_items = {}

    self._checkpoint = tf.train.Checkpoint(
        model=self.multi_task_model,
        optimizer=self.optimizer,
        global_step=self.global_step,
        **checkpoint_items)

    train_datasets = {}
    for name, task in self.multi_task.tasks.items():
      train_datasets[name] = orbit.utils.make_distributed_dataset(
          self.strategy, task.build_inputs, task.task_config.train_data)

    super().__init__(
        train_dataset=train_datasets,
        options=trainer_options or orbit.StandardTrainerOptions())

  def train_loop_begin(self):
    """Cleans up states that hold losses and metrics."""
    for _, train_loss_metric in self.training_losses.items():
      train_loss_metric.reset_states()

    for _, metrics in self.training_metrics.items():
      for metric in metrics:
        metric.reset_states()

  def train_loop_end(self):
    """Records loss and metric values per task."""
    result = {}
    for task_name, loss in self.training_losses.items():
      result[task_name] = {loss.name: loss.result()}
    for task_name, task_metrics in self.training_metrics.items():
      result[task_name].update(
          {metric.name: metric.result() for metric in task_metrics})
    # Note that the learning rate schedule is managed by the Keras optimizer
    # internally, which counts the number of backward passes as `iterations`.
    # The learning rate schedule therefore does not follow the trainer's
    # logical global step across multiple tasks.
    if callable(self.optimizer.learning_rate):
      result["learning_rate"] = self.optimizer.learning_rate(
          self.optimizer.iterations)
    else:
      result["learning_rate"] = self.optimizer.learning_rate
    return result

  @property
  def checkpoint(self):
    """Accesses the training checkpoint."""
    return self._checkpoint

  @property
  def training_losses(self):
    """Accesses the training loss metric objects for all tasks."""
    if self._training_losses is None:
      # Builds the per-task losses. `total_loss` is the total summed
      # training loss of the tasks in the joint training.
      self._training_losses = dict(
          total_loss=tf.keras.metrics.Mean("training_loss", dtype=tf.float32))
      for name in self.multi_task.tasks:
        self._training_losses[name] = tf.keras.metrics.Mean(
            "training_loss", dtype=tf.float32)
    return self._training_losses

  @property
  def training_metrics(self):
    """Accesses the training metric objects for all tasks."""
    if self._training_metrics is None:
      # Builds the per-task metrics.
      self._training_metrics = {}
      for name, task in self.multi_task.tasks.items():
        self._training_metrics[name] = task.build_metrics(training=True)
    return self._training_metrics

  @property
  def strategy(self):
    return self._strategy

  @property
  def multi_task(self):
    return self._multi_task

  @property
  def multi_task_model(self):
    return self._multi_task_model

  @property
  def optimizer(self):
    return self._optimizer

  @property
  def global_step(self):
    return self._global_step

  def train_step(self, iterator_map):
    """The default train step calling the multi-task train step.

    Args:
      iterator_map: a dictionary of task names and per-task dataset iterators.
    """

    def step_fn(inputs):
      losses = self.multi_task.joint_train_step(
          inputs,
          multi_task_model=self.multi_task_model,
          optimizer=self.optimizer,
          task_metrics=self.training_metrics)
      for key, loss in losses.items():
        self.training_losses[key].update_state(loss)

    self.strategy.run(
        step_fn, args=(tf.nest.map_structure(next, iterator_map),))
    self.global_step.assign_add(1)
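
For orientation, here is a minimal usage sketch mirroring the tests below. `my_tasks` (a list of constructed `Task` objects whose configs provide `train_data`) and `my_model` are hypothetical placeholders, not names from the library:

# A sketch only: `my_tasks` and `my_model` are hypothetical placeholders.
import tensorflow as tf

from official.modeling.multitask import base_trainer
from official.modeling.multitask import multitask

task_group = multitask.MultiTask(
    tasks=my_tasks, task_weights={"foo": 1.0, "bar": 0.5})
trainer = base_trainer.MultiTaskBaseTrainer(
    multi_task=task_group,
    multi_task_model=my_model,  # tf.keras.Model or base_model.MultiTaskBaseModel
    optimizer=tf.keras.optimizers.SGD(0.1))

# Runs five joint train steps; returns the per-task losses and metrics
# assembled in `train_loop_end`, plus the current learning rate.
results = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))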
official/modeling/multitask/base_trainer_test.py

Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for multitask.base_trainer."""
from absl.testing import parameterized
import tensorflow as tf

from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.modeling.multitask import base_trainer
from official.modeling.multitask import configs
from official.modeling.multitask import multitask
from official.modeling.multitask import test_utils


def all_strategy_combinations():
  return combinations.combine(
      distribution=[
          strategy_combinations.default_strategy,
          strategy_combinations.cloud_tpu_strategy,
          strategy_combinations.one_device_strategy_gpu,
      ],
      mode="eager",
  )


class BaseTrainerTest(tf.test.TestCase, parameterized.TestCase):

  @combinations.generate(all_strategy_combinations())
  def test_multitask_joint_trainer(self, distribution):
    with distribution.scope():
      tasks = [
          test_utils.MockFooTask(params=test_utils.FooConfig(), name="foo"),
          test_utils.MockBarTask(params=test_utils.BarConfig(), name="bar")
      ]
      task_weights = {"foo": 1.0, "bar": 1.0}
      test_multitask = multitask.MultiTask(
          tasks=tasks, task_weights=task_weights)
      test_optimizer = tf.keras.optimizers.SGD(0.1)
      model = test_utils.MockMultiTaskModel()
      test_trainer = base_trainer.MultiTaskBaseTrainer(
          multi_task=test_multitask,
          multi_task_model=model,
          optimizer=test_optimizer)
      results = test_trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
      self.assertContainsSubset(["training_loss", "bar_acc"],
                                results["bar"].keys())
      self.assertContainsSubset(["training_loss", "foo_acc"],
                                results["foo"].keys())

  def test_trainer_with_configs(self):
    config = configs.MultiTaskConfig(
        task_routines=(configs.TaskRoutine(
            task_name="foo",
            task_config=test_utils.FooConfig(),
            task_weight=0.5),
                       configs.TaskRoutine(
                           task_name="bar",
                           task_config=test_utils.BarConfig(),
                           task_weight=0.5)))
    test_multitask = multitask.MultiTask.from_config(config)
    test_optimizer = tf.keras.optimizers.SGD(0.1)
    model = test_utils.MockMultiTaskModel()
    test_trainer = base_trainer.MultiTaskBaseTrainer(
        multi_task=test_multitask,
        multi_task_model=model,
        optimizer=test_optimizer)
    results = test_trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
    self.assertContainsSubset(["training_loss", "bar_acc"],
                              results["bar"].keys())
    self.assertContainsSubset(["training_loss", "foo_acc"],
                              results["foo"].keys())
    self.assertEqual(test_multitask.task_weight("foo"), 0.5)
    self.assertEqual(test_trainer.global_step.numpy(), 5)
    self.assertIn("learning_rate", results)


if __name__ == "__main__":
  tf.test.main()
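
In a full training job the trainer would normally be driven by an Orbit `Controller` rather than by calling `train` directly. A hedged sketch, assuming `trainer` is an already-constructed `MultiTaskBaseTrainer` and using an arbitrary `steps_per_loop`:

# Sketch: wiring the trainer into an Orbit Controller; values are illustrative.
import orbit

controller = orbit.Controller(
    trainer=trainer,
    global_step=trainer.global_step,
    steps_per_loop=100)
controller.train(steps=1000)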

official/modeling/multitask/configs.py

Lines changed: 33 additions & 0 deletions
@@ -36,6 +36,39 @@ class MultiTaskConfig(hyperparams.Config):
   task_routines: Tuple[TaskRoutine, ...] = ()


+@dataclasses.dataclass
+class ProportionalSampleConfig(hyperparams.Config):
+  alpha: float = 1.0
+
+
+@dataclasses.dataclass
+class AnnealingSampleConfig(hyperparams.Config):
+  steps_per_epoch: int = 5
+  total_steps: int = 20
+
+
+@dataclasses.dataclass
+class TaskSamplingConfig(hyperparams.OneOfConfig):
+  type: str = ""
+  uniform: hyperparams.Config = hyperparams.Config()
+  proportional: ProportionalSampleConfig = ProportionalSampleConfig()
+  annealing: AnnealingSampleConfig = AnnealingSampleConfig()
+
+
+@dataclasses.dataclass
+class MultiTaskTrainerConfig(cfg.TrainerConfig):
+  trainer_type: str = "interleaving"
+  task_sampler: TaskSamplingConfig = TaskSamplingConfig(type="proportional")
+
+
+@dataclasses.dataclass
+class MultiTaskExperimentConfig(hyperparams.Config):
+  """An experiment config for multi-task training and multi-task evaluation."""
+  task: MultiTaskConfig = MultiTaskConfig()
+  trainer: MultiTaskTrainerConfig = MultiTaskTrainerConfig()
+  runtime: cfg.RuntimeConfig = cfg.RuntimeConfig()
+
+
 @dataclasses.dataclass
 class MultiEvalExperimentConfig(cfg.ExperimentConfig):
   """An experiment config for single-task training and multi-task evaluation.
