rfcs/20201121-keras-model-fit-ps.md (37 additions, 33 deletions)
@@ -197,10 +197,14 @@ For compatibility with other strategies, we propose that `dataset_fn` takes a si
*This is also in preparation for multi-replica support in the future. See the [tutorial](https://www.tensorflow.org/tutorials/distribute/parameter_server_training?hl=uk#dispatch_training_steps_to_remote_workers) for more information.

-#### The setup of `ClusterCoordinator`
+#### The setup of `ClusterCoordinator` with `model.fit` usage
+
+##### Basic use case: `ClusterCoordinator` being internal

To take advantage of TF2 support of parameter server training, a `ClusterCoordinator` should be created to handle asynchronous function scheduling and joining. The preferred route is for such an object to be abstracted away from the user by the `model.fit` training API as an implementation detail, since we do not expect users to `schedule` functions themselves, or to synchronize the cluster, in the basic workflow.

+##### Advanced use case: `ClusterCoordinator` as a singleton
+
Since a `ClusterCoordinator` instance spins off worker and failure-handling threads, there should only be one `ClusterCoordinator` at any given time, and making it a singleton ensures that those threads are only created once:

```
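To make the basic use case above concrete, here is a minimal sketch of the intended user-side workflow, in which no `ClusterCoordinator` appears in user code. The `ParameterServerStrategy`, `DatasetCreator`, and resolver endpoints are the ones associated with this design in later TF releases, but the snippet itself is illustrative and not taken from the RFC:

```python
import tensorflow as tf

# Cluster configuration is assumed to come from TF_CONFIG; any resolver works.
cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
strategy = tf.distribute.experimental.ParameterServerStrategy(cluster_resolver)

with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
  model.compile(optimizer="rmsprop", loss="mse", steps_per_execution=10)

def dataset_fn(input_context):
  # Per-worker dataset; note the single `input_context` argument, matching the
  # `dataset_fn` signature discussed earlier in this section.
  features = tf.data.Dataset.from_tensor_slices([[1.0]] * 64)
  labels = tf.data.Dataset.from_tensor_slices([[2.0]] * 64)
  return tf.data.Dataset.zip((features, labels)).repeat().batch(8)

# `model.fit` creates and drives the `ClusterCoordinator` internally; the user
# never calls `schedule` or joins the cluster explicitly.
model.fit(tf.keras.utils.experimental.DatasetCreator(dataset_fn),
          epochs=2, steps_per_epoch=20)
```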
@@ -212,41 +216,13 @@ class ClusterCoordinator(object):
    return ClusterCoordinator.instance
```

-Being a singleton is important considering there are power users who would like to `schedule` functions themselves in addition to `model.fit` usage. That is, they can instantiate one before `model.fit` does, or use one after `model.fit` has instantiate one. In either case, they should access the same `ClusterCoordinator` instance.
+Being a singleton is important considering there are power users who would like to `schedule` functions themselves in addition to `model.fit` usage. That is, they can instantiate one before `model.fit` does, or use one after `model.fit` has instantiated one. In either case, they should access the same `ClusterCoordinator` instance.

-In terms of who keeps track of the `ClusterCoordinator` for `model.fit`, and when it starts allocating threads, there are a few options. Here, we assume that the distribution `Strategy` object can determine whether or not it is supposed to be used with a `ClusterCoordinator`. See below “Changes in tf.distribute” section for more information.
-
-##### Option 1: Attach the `ClusterCoordinator`’s lifecycle to `model.fit`
-
-With this option, an attribute is added to the `Model` that keeps track of the `ClusterCoordinator`, and it is instantiated when `model.fit` is called.
+##### Have an attribute in `ParameterServerStrategy` that holds the `ClusterCoordinator`
+
+We propose that an attribute is added to the `ParameterServerStrategy` to keep track of the `ClusterCoordinator`. We instantiate `ClusterCoordinator` as soon as `ParameterServerStrategy` is instantiated:

```
-class Model(...):
-  def __init__(self):
-    self._cluster_coordinator = None
-    ...
-
-  def fit(self, ...):
-    if (self.distribute_strategy.should_use_with_coordinator() and
-    self._cluster_coordinator.shut_down()  # Shut down at the end of `fit`
-    self._cluster_coordinator = None
-
-class ClusterCoodinator(object):
-  def shut_down(self):
-    # Join the threads and terminate resources. We don't have this implemented yet.
-```
-
-
-##### Option 2: Have an attribute in `ParameterServerStrategy` that holds the `ClusterCoordinator`
-
-With this option, an attribute is added to the `ParameterServerStrategy` to keep track of the `ClusterCoordinator`. We start the `ClusterCoordinator` as soon as the `model.fit` is called for the first time, and do not attempt to shut it down after `fit` completes. It will then be reused for the next `fit`, or on a different model.


```
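For orientation, the pieces visible in this hunk (a `__new__` that returns `ClusterCoordinator.instance`, and the proposal to create the coordinator as soon as `ParameterServerStrategy` is constructed) could be combined roughly as in the sketch below. The names `_cluster_coordinator` and `should_use_with_coordinator` and the constructor signatures are inferred from this diff; this is illustrative, not the RFC's actual implementation:

```python
class ClusterCoordinator(object):
  """Singleton that owns the worker and failure-handling threads."""
  instance = None

  def __new__(cls, strategy):
    if ClusterCoordinator.instance is None:
      ClusterCoordinator.instance = super(ClusterCoordinator, cls).__new__(cls)
      ClusterCoordinator.instance._strategy = strategy
      # The worker and failure-handling threads would be started here, once.
    return ClusterCoordinator.instance


class ParameterServerStrategy(object):  # stand-in for the real strategy class
  def __init__(self, cluster_resolver):
    self._cluster_resolver = cluster_resolver
    # Per the proposal above: create the coordinator as soon as the strategy
    # itself is created, and keep it for the strategy's lifetime.
    self._cluster_coordinator = ClusterCoordinator(self)

  def should_use_with_coordinator(self):
    return True
```

Because `__new__` always hands back the same instance, a power user who constructs a `ClusterCoordinator` before or after `model.fit` ends up sharing the same threads as the one `model.fit` uses, which is the point of the singleton requirement above.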
@@ -591,7 +567,7 @@ SidecarEvaluator(
* also accept the checkpoint files saved by `ModelCheckpoint` callback for periodic evaluation.
* accept arbitrary callbacks to be used in its internal `model.evaluate` call

-##### An sidecar evaluation thread on coordinator
+##### A sidecar evaluation thread on the coordinator

A potentially more seamless and encapsulated sidecar evaluation, where the user is not required to allocate an evaluator task or run separate code, can be done with an evaluation thread on the coordinator. This thread would remotely execute an evaluation function on a worker, and wait for its result synchronously. Once the result is returned, it can write a summary, adjust the learning rate, or signal to end the training. Then, it re-`schedule`s an evaluation function, and so on:
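As a concrete illustration of the evaluation thread sketched in the paragraph above, the loop below uses the coordinator's `schedule`/`fetch` pattern. The helper name and its arguments (`coordinator`, `eval_fn`, `should_stop`) are placeholders supplied by the caller, not part of the proposal's API surface:

```python
import threading

import tensorflow as tf


def start_sidecar_evaluation(coordinator, eval_fn, should_stop):
  """Runs evaluation repeatedly on a worker from a coordinator-side thread.

  `coordinator` is assumed to be a
  tf.distribute.experimental.coordinator.ClusterCoordinator, `eval_fn` a
  tf.function that runs one evaluation pass, and `should_stop` a zero-argument
  predicate that tells the loop when training has ended.
  """
  def _loop():
    while not should_stop():
      remote_value = coordinator.schedule(eval_fn)  # run remotely on a worker
      metrics = remote_value.fetch()                # block until it finishes
      # With the result in hand, the thread could write summaries, adjust the
      # learning rate, or signal the end of training, then schedule again.
      tf.print("sidecar evaluation result:", metrics)

  thread = threading.Thread(target=_loop, daemon=True)
  thread.start()
  return thread
```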
@@ -785,3 +761,31 @@ dataset = tf.data.Dataset.X... # Make use of `preproc_stage` for transformation
history = model.fit(dataset, epochs=..., steps_per_epoch=..., callbacks=[...])
logging.info("result: %r", history)
```
+
+
+### Attach the `ClusterCoordinator`’s lifecycle to `model.fit`
+
+With this option, an attribute is added to the `Model` that keeps track of the `ClusterCoordinator`, and it is instantiated when `model.fit` is called.
+
+
+```
+class Model(...):
+  def __init__(self):
+    self._cluster_coordinator = None
+    ...
+
+  def fit(self, ...):
+    if (self.distribute_strategy.should_use_with_coordinator() and
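The appendix code is cut off at this point in the view. Pieced together from the fragments of the same block that appear as deletions in the earlier hunk (the `shut_down()` call at the end of `fit` and the `ClusterCoordinator.shut_down` stub), the lifecycle-attached-to-`model.fit` alternative reads roughly as follows. This is a hedged reconstruction in valid Python, not the RFC's verbatim text; in particular, the lines joining the `if` condition to the coordinator creation are assumptions:

```python
class Model:  # the RFC writes `class Model(...)`; a bare class keeps this runnable
  def __init__(self):
    self._cluster_coordinator = None

  def fit(self, *args, **kwargs):
    if (self.distribute_strategy.should_use_with_coordinator() and
        self._cluster_coordinator is None):
      # Assumed continuation: create the coordinator lazily on the first `fit`.
      self._cluster_coordinator = ClusterCoordinator(self.distribute_strategy)
    ...  # run the training loop through the coordinator
    self._cluster_coordinator.shut_down()  # Shut down at the end of `fit`
    self._cluster_coordinator = None


class ClusterCoordinator(object):
  def shut_down(self):
    # Join the threads and terminate resources (not implemented at the time of
    # the RFC).
    ...
```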