In this design, we propose that `model.fit` take a dataset function or factory, instead of a dataset instance (which is [what is currently supported](https://github.com/tensorflow/tensorflow/blob/6b9e35f1c0410607bf956c0f27e5d3e1456b4899/tensorflow/python/keras/engine/training.py#L887-L889)), for the following reasons:

* With `dataset` instances, there is a complication brought by the need to replicate `dataset`s to workers.
* With `dataset` replication, we have in the past observed higher memory consumption, less flexibility in users' dataset transformations, and suboptimal performance.

In the simplest case, we can allow any kind of `callable` to be passed in:

```
def dataset_fn(input_context):
  return tf.data.Dataset.from_tensor_slices(...)

history = model.fit(dataset_fn, epochs=..., steps_per_epoch=..., callbacks=[...])
```
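
For concreteness, a user-provided `dataset_fn` could look roughly like the following (a hedged sketch only; the `features`/`labels` tensors, batch size, and `fit` arguments are placeholders, not part of the proposal):

```
def dataset_fn(input_context):
  # `features` and `labels` stand in for the user's own training data.
  dataset = tf.data.Dataset.from_tensor_slices((features, labels))
  # Shuffle and repeat so that workers which rejoin can keep drawing data.
  dataset = dataset.shuffle(buffer_size=1024).repeat()
  return dataset.batch(32).prefetch(tf.data.AUTOTUNE)

history = model.fit(dataset_fn, epochs=5, steps_per_epoch=100)
```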

With an additionally defined class `DatasetFactory`:

```
class tf.keras.experimental.DatasetFactory(Factory):

  def __init__(self, x):
    if not callable(x):
      raise TypeError('Input for `DatasetFactory` must be a `callable`.')
    self.x = x

  def __call__(self, *args, **kwargs):
    # Validate that the wrapped `callable` indeed produces a `Dataset`.
    dataset = self.x(*args, **kwargs)
    if not isinstance(dataset, Dataset):
      raise TypeError('The `callable` provided to `DatasetFactory` must return '
                      'a `Dataset`.')
    return dataset
```

The following discussion is based on option 1, where a simple callable is taken.

##### Implications for no strategy/other strategies

If `model.fit` is allowed to take a `dataset_fn`, use cases for synchronous strategies and no strategy can be readily supported. That is, we provide the `dataset_fn` to `distribute_datasets_from_function`, which correctly places `dataset`s on devices for synchronous training.
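
For illustration, under a synchronous strategy such as `MirroredStrategy`, the flow is roughly the following (a minimal sketch of what `model.fit` would do internally, not the actual implementation; `features`/`labels` are placeholders):

```
strategy = tf.distribute.MirroredStrategy()

def dataset_fn(input_context):
  # The `InputContext` gives each replica its share of the global batch.
  batch_size = input_context.get_per_replica_batch_size(global_batch_size=64)
  return tf.data.Dataset.from_tensor_slices((features, labels)).batch(batch_size)

# Internally, `model.fit` would hand the callable to the strategy.
distributed_dataset = strategy.distribute_datasets_from_function(dataset_fn)
```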
##### Signature of `dataset_fn`

For compatibility with other strategies, we propose that `dataset_fn` take a single argument, `input_context`, and return a `tf.data.Dataset`. This `dataset_fn` will be used in `strategy.distribute_datasets_from_function`, wrapped by a `per_worker_dataset_fn`* that is passed to `create_per_worker_dataset`. See the "DataAdapter and DataHandler changes" section below for how this can be implemented in `model.fit`. Though sharding is not necessary in PS training, it is fine for users to shard with `Dataset.shard` using the `input_context` (which has sensible default attributes) in `dataset_fn`, if they need to use it across multiple strategies.

*This is also in preparation for multi-replica support in the future. See the [tutorial](https://www.tensorflow.org/tutorials/distribute/parameter_server_training?hl=uk#dispatch_training_steps_to_remote_workers) for more information.
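
A rough sketch of the wrapping described above, assuming `model.fit` already holds the `strategy` and a `ClusterCoordinator` named `coordinator`; the local variable names are illustrative only:

```
def per_worker_dataset_fn():
  # `distribute_datasets_from_function` supplies the `InputContext` expected
  # by the user's one-argument `dataset_fn`.
  return strategy.distribute_datasets_from_function(dataset_fn)

per_worker_dataset = coordinator.create_per_worker_dataset(per_worker_dataset_fn)
per_worker_iterator = iter(per_worker_dataset)
```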
#### The setup of `ClusterCoordinator`

In the `model.fit` code path, if `self._cluster_coordinator` is set, the `train_function` has to be a `tf.function`, and `steps_per_epoch` must be specified with `ParameterServerStrategy`.
#### Metrics variables

In Keras training APIs, users can specify custom metrics or metric strings in `model.compile`, and there is also a built-in loss. The variables involved are either created at `compile` time, which is under `strategy.scope`, or the first time they are updated (at `fit` time, which is also under `strategy.scope`). Therefore the variables will be placed correctly on parameter servers.
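
As a concrete illustration of that placement (a sketch only; `cluster_resolver` is assumed to exist):

```
strategy = tf.distribute.experimental.ParameterServerStrategy(cluster_resolver)
with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
  # Metric (and loss) variables are created under `strategy.scope`, either
  # here at compile time...
  model.compile(optimizer='sgd', loss='mse', metrics=['accuracy'])
# ...or lazily inside `model.fit` the first time they are updated, which also
# runs under `strategy.scope`, so they land on the parameter servers.
```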
There is also an option to place the metric variables on workers and aggregate the metric results to the parameter servers periodically. In theory, this results in fewer round trips between workers and parameter servers and hence better performance, but it would require an additional `ClusterCoordinator` API for explicit placement of variables on workers.
#### Optimizer variables

Similarly, the hyper and slot variables an `optimizer` object uses are created at gradient application, at which point the Keras `optimizer` has [entered](https://github.com/tensorflow/tensorflow/blob/4d1142b04b708372203e15abc4934f7289fd2255/tensorflow/python/keras/optimizer_v2/optimizer_v2.py#L956) `strategy.scope` for correct placement. Variables that need to be colocated with other variables, such as slot variables, should continue to work because `tf.keras.optimizers.Optimizer` ensures the [`colocate_vars_with` variable creator scope is used](https://github.com/tensorflow/tensorflow/blob/4d1142b04b708372203e15abc4934f7289fd2255/tensorflow/python/keras/optimizer_v2/optimizer_v2.py#L904-L909), which is recognized by `ParameterServerStrategy` when these variables are created, so they end up placed accordingly.
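
For example (a sketch; the model, data, and optimizer choice are placeholders):

```
with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
  optimizer = tf.keras.optimizers.Adam()

@tf.function
def train_step(x, y):
  with tf.GradientTape() as tape:
    loss = tf.reduce_mean(tf.square(model(x) - y))
  grads = tape.gradient(loss, model.trainable_variables)
  # Adam's slot variables (`m`, `v`) are created on this first application,
  # colocated with the variables they track, hence on the parameter servers.
  optimizer.apply_gradients(zip(grads, model.trainable_variables))
```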
#### `model.evaluate` and `model.predict`

Initially, we aim for `model.evaluate` and `model.predict` to be carried out only on the coordinator. That is, they do not involve distribution via a `ClusterCoordinator`, and the evaluate function is executed on the coordinator.

In the longer term, we seek distributed support for `model.evaluate`, where the evaluate function is scheduled onto the workers to execute. A visitation guarantee cannot currently be supported with the parameter server training API, so we can either implement distributed evaluation without it, or wait until it is supported and then integrate it. Things possibly involved with distributed `model.evaluate` include:
* support for local variables
* support for local resources
* efficient skipping of dataset batches, or `dataset.shard`, that can be wrapped in a `tf.function`
With those, we do not expect an API change at the `model.fit` level, but if we do encounter something that results in a change, it is reasonable to add an argument such as `model.fit(distribute_eval=...)`.

See the “Evaluation” section below for other proposed evaluation solutions accompanying `model.fit` usage.
### Changes in tf.distribute

Coordinator-based distributed training was made available with the introduction of the `ClusterCoordinator` API, where a `Strategy` should be used in conjunction with it. In contrast, classic `strategy.run`-based distributed training only requires a `Strategy` object. The code written for the two schemes, with custom training loops, is easily distinguishable by the presence or absence of a `ClusterCoordinator` object. However, with `model.fit`, users are not expected to create a `ClusterCoordinator` object, and thus there needs to be a way for the user to specify whether training should be performed with a `ClusterCoordinator`. This can possibly be done at `Strategy.__init__`, so that `model.fit` knows whether it is intended for coordinator-based single-client training or traditional multi-client training.

For now, it seems feasible that `ParameterServerStrategy` has a field `should_use_with_coordinator`, which is always `True` until usage without a `ClusterCoordinator` is supported, at which point it can become an argument of `__init__`:
```
class ParameterServerStrategy(Strategy):

  def __init__(self):
    # Always `True` until usage without a `ClusterCoordinator` is supported.
    self.should_use_with_coordinator = True
```
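
A sketch of how `model.fit` could consult this field to decide whether to create the coordinator itself; the attribute name comes from this design, while the surrounding logic is an assumption:

```
strategy = self.distribute_strategy
if getattr(strategy, 'should_use_with_coordinator', False):
  # Coordinator-based single-client training: `model.fit` creates and owns
  # the `ClusterCoordinator`.
  self._cluster_coordinator = (
      tf.distribute.experimental.coordinator.ClusterCoordinator(strategy))
```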
* also accept the checkpoint files saved by `ModelCheckpoint` callback for periodic evaluation.
* accept arbitrary callbacks to be used in its internal `model.evaluate` call
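
For orientation, usage of such a sidecar evaluator could look roughly as follows; the argument names mirror the eventual `tf.keras.utils.SidecarEvaluator` API and are not taken verbatim from this excerpt:

```
evaluator = SidecarEvaluator(
    model=eval_model,
    data=eval_dataset,
    checkpoint_dir='/tmp/checkpoint_dir',  # Where training writes checkpoints.
    steps=None,             # Evaluate until `eval_dataset` is exhausted.
    max_evaluations=None)   # Keep evaluating until training ends.
evaluator.start()
```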

##### A sidecar evaluation thread on the coordinator

A potentially more seamless and encapsulated sidecar evaluation, where the user is not required to allocate an evaluator task or run separate code, can be done with an evaluation thread on the coordinator. This thread would remotely execute an evaluation function on a worker and wait for its result synchronously. Once the result is returned, it can write a summary, adjust the learning rate, or signal to end the training. Then it re-`schedule`s an evaluation function, and so on:

```
def _continuously_evaluate(self):
  # The following is mostly existing `model.evaluate` logic.
  data_handler = ...
  self.test_function = self.make_test_function()
  while self.should_eval:  # This stops when `fit` ends.
    # Each iteration loads the latest weights saved by training.
    eval_model.load_weights(weights_path)
    for _, iterator in data_handler.enumerate_epochs():
      ...  # Callbacks, tracing, etc.
      with tf.device(eval_worker):
        tmp_logs = self.test_function(iterator)
      ...  # Callbacks, etc.

def fit(self, ...):
  # At some point, we start a thread for sidecar evaluation.
  t = threading.Thread(target=self._continuously_evaluate)
  t.start()
  ...
  self.should_eval = False
  t.join()
```

If we compare the sidecar evaluator thread solution with the sidecar evaluator task (process) solution:
Pros:
* This does not require a task to be set aside as an evaluator
* There is easier communication between the sidecar evaluator (thread) and the coordinator main thread, which is important for many callbacks
Cons:
* This solution presents a challenge when workers can easily become unavailable, in which case it is not straightforward to immediately find another available worker to take over*
* This solution is blocked on `tf.keras.models.load_model` being available on PS (if we have another cloning solution, that works too)
* Users who can afford to allocate a high priority to an evaluator task cannot do so with workers; workers would simply have the same, usually lower, priority (and thus more frequent function takeovers)
*Fault tolerance, the first con, may be further addressed with possibly another `ClusterCoordinator`, if it shares the threads with the existing `ClusterCoordinator` and the library allows multiple function queues to be accessed by those threads. More details may be discussed in a separate RFC.