
Commit b9144a3: "Eval updates"
1 parent 9c89029


rfcs/20201121-keras-model-fit-ps.md

Lines changed: 81 additions & 27 deletions
@@ -75,11 +75,11 @@ With a dataset factory:
cluster_resolver = ...
strategy = tf.distribute.experimental.ParameterServerStrategy(cluster_resolver)
with strategy.scope():
-  preproc_stage = ...  # Some Keras preproc layers
  model = ...  # Building a Keras model
  model.compile(optimizer=..., loss=...)
-def dataset_fn():
-  return tf.data.Dataset.X...  # Make use of `preproc_stage` for transformation
+def dataset_fn(input_context):
+  # User can shard with `input_context` for strategy-compatibility
+  return tf.data.Dataset.from_tensors(...).repeat(...).batch(...)

# `ClusterCoordinator` is created at `fit`
history = model.fit(dataset_fn, epochs=..., steps_per_epoch=..., callbacks=[...])
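
For illustration, a concrete `dataset_fn` matching the proposed one-argument signature might look like the following sketch; the `shard` call is what the `input_context` comment above refers to, and is optional for parameter server training:

```
def dataset_fn(input_context):
  # A toy in-memory dataset; real code would read from files, TFRecords, etc.
  dataset = tf.data.Dataset.from_tensor_slices(
      (tf.random.uniform((64, 10)), tf.random.uniform((64, 1))))
  # Optional sharding keeps the same function usable with other strategies.
  dataset = dataset.shard(input_context.num_input_pipelines,
                          input_context.input_pipeline_id)
  batch_size = input_context.get_per_replica_batch_size(global_batch_size=8)
  return dataset.shuffle(64).repeat().batch(batch_size)
```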
@@ -110,7 +110,7 @@ This section discusses the changes needed to be made in `model` API and assumes

In this design, we propose `model.fit` to take a dataset function or factory, instead of a dataset instance (which is [what is currently supported](https://github.com/tensorflow/tensorflow/blob/6b9e35f1c0410607bf956c0f27e5d3e1456b4899/tensorflow/python/keras/engine/training.py#L887-L889)), for the following reasons:

-* With `dataset` instance, there is complication brought by the need of replicating `dataset`s to workers.
+* With `dataset` instances, there is complication brought by the need of replicating `dataset`s to workers.

* With `dataset` replication, in the past we have observed more memory consumption, less flexibility for user’s dataset transformation, and suboptimal performance.
@@ -133,7 +133,7 @@ In the simplest case, we can allow any kind of `callable` to be passed in:


```
-def dataset_fn():
+def dataset_fn(input_context):
  return tf.data.Dataset.from_tensor_slices(...)
history = model.fit(dataset_fn, epochs=..., steps_per_epoch=..., callbacks=[...])
```
@@ -160,21 +160,19 @@ With an additionally defined class `DatasetFactory`:


```
-class DatasetFactory(Factory):
+class tf.keras.experimental.DatasetFactory(Factory):

  def __init__(self, x):
    if not callable(x):
      raise TypeError('Input for `DataFactory` must be a `callable`.')
    self.x = x

  def __call__(self, *args, **kwargs):
-    # We gain the flexibility of modifying args/kwargs for future-compatibility.
-    # For example, if we allow different argument signature of user-provided
-    # `dataset_fn`, this works as an abstraction layer.
-    # If we now only allow zero-arg, but later extend it to one-arg (e.g. input
-    # context), we omit the arg when we observe that the function doesn't
-    # take any arg.
-    return self.x(*args, **kwargs)
+    dataset = self.x(*args, **kwargs)
+    if not isinstance(dataset, Dataset):
+      raise TypeError('The `callable` provided to `DatasetFactory` must return '
+                      'a `Dataset`.')
+    return dataset
```
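
For illustration, usage under this option might look like the following sketch (`tf.keras.experimental.DatasetFactory`, and `model.fit` accepting it, are part of this proposal rather than an existing API):

```
def dataset_fn(input_context):
  return tf.data.Dataset.from_tensor_slices([1., 2., 3.]).repeat().batch(2)

# Wrap the callable in the proposed factory before handing it to `model.fit`.
factory = tf.keras.experimental.DatasetFactory(dataset_fn)
history = model.fit(factory, epochs=5, steps_per_epoch=10)
```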

Pros:
@@ -190,13 +188,14 @@ The following discussion is based on option 1, where a simple callable is taken.

##### Implication on no strategy/other strategies

-If `model.fit` is allowed to take a `dataset_fn`, use cases for synchronous strategies, and no strategy, can be readily applied. That is, when a dataset is needed, the `callable` is inspected: 1) if the `callable` expects an argument (which is supposed to be the input context), we directly provide it to `distribute_datasets_from_function`, or 2) if the `callable` does not expect an argument, we wrap it in a function (whose sole argument is the discarded `input_context`), which is then provided to `distribute_datasets_from_function`. In either case, we end up obtaining a distributed `dataset`, and the remaining workflow will apply.
+If `model.fit` is allowed to take a `dataset_fn`, use cases for synchronous strategies, and no strategy, can be readily applied. That is, we provide the `dataset_fn` to `distribute_datasets_from_function`, which correctly places `dataset`s on devices in synchronous training.


##### Signature of `dataset_fn`

-The signature (input argument and return value) of `dataset_fn` taken by `model.fit` should basically follow the signature `ClusterCoordinator.create_per_worker_dataset` takes. There has been discussion around whether that should take an `InputContext` for effective sharding. Current decision is that it does not, since we do not expect users to shard the dataset considering workers can get preempted.
+For compatibility with other strategies, we propose that `dataset_fn` takes a single argument, `input_context`, and returns a `tf.data.Dataset`. This `dataset_fn` will be used in `strategy.distribute_datasets_from_function`, wrapped by a `per_worker_dataset_fn`*, and passed to `create_per_worker_dataset`. See the "DataAdapter and DataHandler changes" section below for how this can be implemented in `model.fit`. Although sharding is not necessary in PS training, users may shard with `Dataset.shard` using the `input_context` (which has sensible default attributes) in `dataset_fn` if they need the same function to work across multiple strategies.

+*This is also in preparation for multi-replica support in the future. See the [tutorial](https://www.tensorflow.org/tutorials/distribute/parameter_server_training?hl=uk#dispatch_training_steps_to_remote_workers) for more information.

#### The setup of `ClusterCoordinator`

@@ -288,7 +287,7 @@ class Model(...):

    if self._cluster_coordinator:
      # Note that `train_function` has to be a `tf.function`.
-      self.train_function = lambda distributed_iterator: self._cluster_coordinator.schedule(  # pylint: disable=g-long-lambda
+      self.train_function = lambda distributed_iterator: self._cluster_coordinator.schedule(
          train_function, args=(distributed_iterator,))

    return self.train_function
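
For background on what the lambda above does: `ClusterCoordinator.schedule` dispatches a `tf.function` to an available worker and immediately returns a `tf.distribute.experimental.coordinator.RemoteValue`. A standalone sketch of that behavior, assuming `strategy`, `coordinator`, and `per_worker_iterator` are already set up as in the tf.distribute tutorial:

```
@tf.function
def train_step(iterator):
  def step_fn(inputs):
    ...  # Per-replica forward/backward pass, elided as in the RFC's snippets.
  return strategy.run(step_fn, args=(next(iterator),))

# `schedule` is non-blocking and returns a RemoteValue placeholder.
remote_value = coordinator.schedule(train_step, args=(per_worker_iterator,))
coordinator.join()           # Block until all scheduled functions have run.
logs = remote_value.fetch()  # Materialize the result on the coordinator.
```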
@@ -309,7 +308,14 @@ class ClusterCoordinatorDataHandler(DataHandler):

  def _configure_dataset_and_inferred_steps(self, strategy, x, steps_per_epoch,
                                            class_weight):
-    self._dataset = self._model._cluster_coordinator.create_per_worker_dataset(x)
+    if not callable(x):
+      raise TypeError("When using `ClusterCoordinator`, `x` must be a "
+                      "`callable`")
+    def per_worker_dataset_fn():
+      return strategy.distribute_datasets_from_function(x)
+
+    self._dataset = self._model._cluster_coordinator.create_per_worker_dataset(
+        per_worker_dataset_fn)
    if steps_per_epoch is None:
      raise RuntimeError(
          "Steps per epoch must be specified with `ParameterServerStrategy`.")
@@ -488,34 +494,41 @@ class MyCheckpointCallback(tf.keras.callbacks.Callback):

#### Metrics variables

-In Keras training APIs, users can specify custom metrics or strings for metrics in `model.compile`, and there is also built-in loss. The variables that are involved, are either created at `compile` time, which is under `strategy.scope`, or the first time they are being updated (at `fit` time, which is also under `strategy.scope`. Therefore the variables should be placed correctly in parameter servers.
+In Keras training APIs, users can specify custom metrics or strings for metrics in `model.compile`, and there is also the built-in loss. The variables involved are created either at `compile` time, which is under `strategy.scope`, or the first time they are updated (at `fit` time, which is also under `strategy.scope`). Therefore the variables will be placed correctly on parameter servers.

There is also an option to place the metrics variables on workers, and aggregate the metrics results to the parameter servers periodically. In theory, this results in fewer round trips between workers and parameter servers and hence better performance, but it would require an additional `ClusterCoordinator` API for explicit placement of variables on workers.
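
As a minimal illustration of the placement argument above (a sketch; the model and metric here are arbitrary):

```
with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
  # Metric variables are created under `strategy.scope()` (at `compile` time or
  # on first update), so `ParameterServerStrategy` places them on parameter servers.
  model.compile(optimizer="sgd", loss="mse",
                metrics=[tf.keras.metrics.MeanAbsoluteError()])
```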


#### Optimizer variables

-Similarly, the hyper and slot variables an `optimizer` object uses, would be created at gradient application, at which point Keras `optimizer` has [entered](https://github.com/tensorflow/tensorflow/blob/4d1142b04b708372203e15abc4934f7289fd2255/tensorflow/python/keras/optimizer_v2/optimizer_v2.py#L956) `strategy.scope` for correct placement. For the variables that need to be colocated with other variables, such as slot variables, they should continue to work because Keras has made sure [`colocate_vars_with` variable creator scope is used](https://github.com/tensorflow/tensorflow/blob/4d1142b04b708372203e15abc4934f7289fd2255/tensorflow/python/keras/optimizer_v2/optimizer_v2.py#L904-L909), which gets recognized by `ParameterServerStrategy` when these variables are being created, and the variables end up getting placed accordingly.
+Similarly, the hyper and slot variables an `optimizer` object uses would be created at gradient application, at which point the Keras `optimizer` has [entered](https://github.com/tensorflow/tensorflow/blob/4d1142b04b708372203e15abc4934f7289fd2255/tensorflow/python/keras/optimizer_v2/optimizer_v2.py#L956) `strategy.scope` for correct placement. Variables that need to be colocated with other variables, such as slot variables, should continue to work because `tf.keras.optimizers.Optimizer` has made sure [the `colocate_vars_with` variable creator scope is used](https://github.com/tensorflow/tensorflow/blob/4d1142b04b708372203e15abc4934f7289fd2255/tensorflow/python/keras/optimizer_v2/optimizer_v2.py#L904-L909), which `ParameterServerStrategy` recognizes when these variables are created, so the variables end up placed accordingly.

#### model.evaluate and model.predict

-Initially, we aim to have `model.evaluate` and `model.predict` to only be carried out on the coordinator. That is, it does not involve distribution via a `ClusterCoordinator`, and thus the evaluate function is executed eagerly on the coordinator.
+Initially, we aim to have `model.evaluate` and `model.predict` be carried out only on the coordinator. That is, they do not involve distribution via a `ClusterCoordinator`, and thus the evaluate function is executed on the coordinator.
+
+In the longer term, we seek distributed support for `model.evaluate`, where the evaluate function is scheduled onto the workers to execute. Visitation guarantee cannot currently be supported with the parameter server training API, so we can either implement distributed evaluation without it, or wait until it is supported and integrate it. Things possibly involved in distributed `model.evaluate` include:
+
+* support for local variables
+* support for local resources
+* efficient skipping of dataset batches, or `dataset.shard`, that can be wrapped in a `tf.function`

-In the longer term, we seek distributed support for `model.evaluate`, where the evaluate function is scheduled onto the workers to execute. Visitation guarantee cannot be supported currently with the parameter server training API, so we can implement distributed evaluation without it, or wait until that is supported, and integrate it.
+With those, we do not expect an API change at the `model.fit` level, but if we do encounter something that requires one, it is reasonable to add an argument such as `model.fit(distribute_eval=...)`.

-Also, see below “Evaluation” section for other proposed evaluation solutions accompanying `model.fit` usage.
+See the “Evaluation” section below for other proposed evaluation solutions accompanying `model.fit` usage.

### Changes in tf.distribute

-Coordinator-based distributed training was made available with the introduction of a `ClusterCoordinator` API, where a `Strategy` should be used in conjunction with it. In contrast, classic `strategy.run`-based distributed training only requires a `Strategy` object to be used. The code written for two schemes, with custom training loops, is easily distinguishable by the presence or absence of a `ClusterCoordinator` object. However, with `model.fit`, users are not expected to create a `ClusterCoordinator` object, and thus there needs to be a way for the user to specify whether the training should be performed with a `ClusterCoordinator` object. This can possibly be done at `__init__`, so that `model.fit` knows whether or not it is intended for a coordinator-based single-client training, or a traditional multi-client training.
+Coordinator-based distributed training was made available with the introduction of the `ClusterCoordinator` API, where a `Strategy` is used in conjunction with it. In contrast, classic `strategy.run`-based distributed training only requires a `Strategy` object. Code written for the two schemes, with custom training loops, is easily distinguishable by the presence or absence of a `ClusterCoordinator` object. However, with `model.fit`, users are not expected to create a `ClusterCoordinator` object themselves, so there needs to be a way to specify whether training should be performed with one. This can possibly be done at `Strategy.__init__`, so that `model.fit` knows whether it is intended for coordinator-based single-client training or traditional multi-client training.

For now, it seems feasible that `ParameterServerStrategy` has a field `should_use_with_coordinator`, which is always True until usage without a `ClusterCoordinator` is supported, at which point it can be an argument of `__init__`.


```
class ParameterServerStrategy(Strategy):
-  self.should_use_with_coordinator = True
+  def __init__(self):
+    self.should_use_with_coordinator = True
```
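
To make the intended use concrete, a hypothetical sketch of how `model.fit` could branch on this field; the attribute lookup and the place where the coordinator is constructed are assumptions of this sketch, not settled API:

```
# Hypothetical sketch: decide at `fit` time whether to create a coordinator.
strategy = self.distribute_strategy
if getattr(strategy, "should_use_with_coordinator", False):
  # Coordinator-based, single-client parameter server training.
  self._cluster_coordinator = (
      tf.distribute.experimental.coordinator.ClusterCoordinator(strategy))
else:
  # Traditional synchronous / multi-client path; no coordinator involved.
  self._cluster_coordinator = None
```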
@@ -578,12 +591,53 @@ SidecarEvaluator(
* also accept the checkpoint files saved by `ModelCheckpoint` callback for periodic evaluation.
* accept arbitrary callbacks to be used in its internal `model.evaluate` call

-##### An evaluation thread on coordinator
+##### A sidecar evaluation thread on coordinator

-A potentially more seamless and encapsulated sidecar evaluation, where the user is not required to allocate an evaluator task or run separate code, can be done with an evaluation thread on the coordinator. This thread would `schedule` an evaluation function to be executed on a worker, and wait for its result. One the result is returned, it can write a summary, adjust learning rate, or signal to end the training. Then, it re-`schedule`s an evaluation function, and so on.
+A potentially more seamless and encapsulated sidecar evaluation, where the user is not required to allocate an evaluator task or run separate code, can be done with an evaluation thread on the coordinator. This thread would remotely execute an evaluation function on a worker and wait for its result synchronously. Once the result is returned, it can write a summary, adjust the learning rate, or signal to end the training. Then it re-`schedule`s an evaluation function, and so on:

-In addition to more changes to `model.fit` API, this solution presents a challenge when workers can easily become unavailable, in which case a fault tolerance solution will be needed for evaluation. Moreover, evaluating on moving variables (as they are concurrently being updated by workers) can yield unreproducible evaluations, as opposed to an evaluator task case, where evaluation is always based on a checkpoint file.
+```
+class Model(...):
+  def _continuously_evaluate(
+      self, strategy, train_model, eval_dataset, eval_worker):
+
+    # The following attempts to clone the model
+    train_model.save(model_path)
+    with strategy.scope():
+      eval_model = tf.keras.models.load_model(model_path)
+
+    # The following are mostly existing `model.evaluate` logic
+    data_handler = ...
+    self.test_function = self.make_test_function()
+    while self.should_eval:  # This stops when `fit` ends
+      # Each iteration loads the latest saved by training
+      eval_model.load_weights(weights_path)
+      for _, iterator in data_handler.enumerate_epochs():
+        ...  # Callbacks, tracing, etc.
+        with tf.device(eval_worker):
+          tmp_logs = self.test_function(iterator)
+        ...  # Callbacks, etc.
+
+  def fit(self, ...):
+    # At some point, we start a thread for sidecar eval
+    t = threading.Thread(target=self._continuously_evaluate)
+    t.start()
+    ...
+    self.should_eval = False
+    t.join()
+```
+
+If we compare the sidecar evaluator thread solution with the sidecar evaluator task (process) solution:
+
+Pros:
+* This does not require a task to be set aside as an evaluator
+* There is easier communication between the sidecar evaluator (thread) and the coordinator main thread, which is important for many callbacks
+
+Cons:
+* This solution presents a challenge when workers can easily become unavailable, in which case it is not straightforward to immediately find another available worker to take over*
+* This solution is blocked on `tf.keras.models.load_model` being available on PS (if we have another cloning solution, that works too)
+* Users who can afford to allocate a high priority to an evaluator task cannot do so with workers; workers would simply have the same, usually lower, priority (and thus more frequent function takeovers)

+*Fault tolerance, the first con, may further be addressed, possibly with another `ClusterCoordinator`, if it shares the threads with the other `ClusterCoordinator` and the library allows multiple function queues to be accessed by the threads. More details may be discussed in a separate RFC.

### Fault tolerance

