rfcs/20201121-keras-model-fit-ps.md
@@ -198,53 +198,30 @@ For compatibility with other strategies, we propose that `dataset_fn` takes a si

*This is also in preparation for multi-replica support in the future. See the [tutorial](https://www.tensorflow.org/tutorials/distribute/parameter_server_training?hl=uk#dispatch_training_steps_to_remote_workers) for more information.
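
As a hedged illustration (not text from this RFC), a `dataset_fn` that takes a single `tf.distribute.InputContext` argument, in the spirit of the linked tutorial, might look like the following; the data and batch size here are placeholders:

```
import tensorflow as tf

# Placeholder data purely for illustration.
features = tf.random.normal([1000, 10])
labels = tf.random.uniform([1000], maxval=2, dtype=tf.int32)

def dataset_fn(input_context):
  # Derive the per-replica batch size from an assumed global batch size.
  global_batch_size = 64
  batch_size = input_context.get_per_replica_batch_size(global_batch_size)
  dataset = tf.data.Dataset.from_tensor_slices((features, labels))
  # Shard the data across the input pipelines (one per worker).
  dataset = dataset.shard(input_context.num_input_pipelines,
                          input_context.input_pipeline_id)
  return dataset.shuffle(1024).batch(batch_size).repeat().prefetch(2)
```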

#### Keras `Model` changes

##### `Model` abstracting the concept of `ClusterCoordinator` for `model.fit`

To take advantage of the TF2 support for parameter server training, a `ClusterCoordinator` should be created to handle asynchronous function scheduling and joining. The preferred route is for such an object to be abstracted away from the user by the `model.fit` training API as an implementation detail. For power users who need a `ClusterCoordinator` instance for their own custom `schedule`s and `join`s, the instance is available as a singleton through a constructor call. See the "`ClusterCoordinator` as a singleton" section below for more information.

A `ClusterCoordinator` instance can be created at any point prior to `Model`'s use of it, but `model.fit` seems like a natural place, since calling it indicates the user's intention to use the compile-fit API, as opposed to a custom training loop (CTL), where we expect users to create the instance themselves.
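
As a usage-level sketch (the API paths here follow the TF2 parameter server training tutorial and are an assumption, not part of this RFC's text), the two entry points differ only in who creates the coordinator:

```
import tensorflow as tf

# Assumes a TF_CONFIG environment variable describing chief/worker/ps tasks.
cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
strategy = tf.distribute.experimental.ParameterServerStrategy(cluster_resolver)

# Custom training loop: the user creates the coordinator explicitly.
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(strategy)

# `model.fit`: no coordinator is created by the user; one is created internally
# the first time `fit` is called.
with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
  model.compile(optimizer="sgd", loss="mse")
```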

`model.fit` obtains such a `ClusterCoordinator` instance, and sets up the `strategy._cluster_coordinator` link, as soon as `model.fit` is called for the first time. Note that if users have used a `ClusterCoordinator` instance prior to any `model.fit` call, that same instance is returned from the `ClusterCoordinator` constructor. This `ClusterCoordinator` instance will then be used for later `schedule`s and `join`s, as shown in the sections below.

```
class Model(...):

  def fit(self, ...):
    if (self.distribute_strategy.should_use_with_coordinator and
        not self.distribute_strategy._cluster_coordinator):
      ...
```
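
As a sketch of the elided body (an assumption, not verbatim from the RFC), the `if` branch only needs to invoke the `ClusterCoordinator` constructor, since the constructor both returns the singleton and records it on the strategy, as described in the singleton section below:

```
      # Inside the `if` above: constructing the coordinator both returns the
      # singleton and records it as
      # `self.distribute_strategy._cluster_coordinator`
      # (see the singleton section below).
      ClusterCoordinator(self.distribute_strategy)
```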

##### `make_train_function` changes

The train function in `Model.make_train_function` can be swapped with a wrapper that takes a `distributed_iterator` (when the scheduled function is executed on remote workers, it receives the actual worker-specific iterator inside the function being executed), and returns the resulting `RemoteValue`.

@@ -262,16 +239,15 @@ class Model(...):

```
    self.train_function = ...

    if self.distribute_strategy._cluster_coordinator:
      # Note that `train_function` has to be a `tf.function`.
      ...
```
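
To make the wrapper concrete, here is a minimal sketch (an assumption consistent with the snippet above, not the RFC's exact implementation) of how `make_train_function` could return a wrapper that schedules the step on a remote worker and returns a `RemoteValue`:

```
class Model(...):

  def make_train_function(self):
    ...
    # `train_function` is the usual `tf.function` that runs one train step; it
    # receives the worker-specific iterator when executed on a remote worker.
    self.train_function = train_function

    if self.distribute_strategy._cluster_coordinator:
      coordinator = self.distribute_strategy._cluster_coordinator

      def schedule_train_function(distributed_iterator):
        # Non-blocking: returns a `RemoteValue` whose `fetch()` blocks until a
        # worker has executed `train_function`.
        return coordinator.schedule(train_function, args=(distributed_iterator,))

      self.train_function = schedule_train_function

    return self.train_function
```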

Most of the challenges of supporting `model.fit` with `ParameterServerStrategy` come from the asynchronicity of dataset creation: datasets are only created on workers when they are needed. This means the concrete dataset does not exist at the time the `DataHandler` class is instantiated, so some information, such as the batch size or the number of batches, cannot be extracted at that point.

@@ -291,7 +267,7 @@ class ClusterCoordinatorDataHandler(DataHandler):
@@ -378,7 +354,7 @@ With `ParameterServerStrategy`, the return value of `Model.train_function` is a

```
def to_numpy_or_python_type(logs):
  if isinstance(logs, RemoteValue):
    get_strategy()._cluster_coordinator.join()  # Sync the workers.
    return logs.fetch()  # Return the NumPy results.
  else:
    ...  # Existing logic.
```
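
As a standalone illustration of the `join`/`fetch` contract this relies on (a hedged sketch; `strategy`, `train_function`, and `per_worker_iterator` are assumed to already exist):

```
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(strategy)
remote_logs = coordinator.schedule(train_function, args=(per_worker_iterator,))
coordinator.join()  # Block until every scheduled function has executed.
numpy_logs = remote_logs.fetch()  # Materialize the `RemoteValue` as concrete values.
```
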
@@ -497,9 +473,13 @@ See below “Evaluation” section for other proposed evaluation solutions accom

### Changes in tf.distribute

#### `Strategy` indicating whether it should be used with `ClusterCoordinator`

Coordinator-based distributed training was made available with the introduction of the `ClusterCoordinator` API, where a `Strategy` should be used in conjunction with it. In contrast, classic `strategy.run`-based distributed training only requires a `Strategy` object.

The code written for the two schemes, with custom training loops, is easily distinguishable by the presence or absence of a `ClusterCoordinator` object. However, with `model.fit`, users are not expected to create a `ClusterCoordinator` object, so there needs to be a way for the user to specify whether training should be performed with one. This can be done at `Strategy.__init__`, so that `model.fit` knows whether it is intended for coordinator-based single-client training or traditional multi-client training.

We propose that `ParameterServerStrategy` have an attribute `should_use_with_coordinator`, which is always `True` until usage without a `ClusterCoordinator` is supported, at which point it can become an argument of `__init__`.

```
class ParameterServerStrategy(...):

  def __init__(self, ...):
    ...
    self.should_use_with_coordinator = True
```

#### `ClusterCoordinator` as a singleton

Since a `ClusterCoordinator` instance spins off worker and failure-handling threads, there should only be one `ClusterCoordinator` per `strategy` instance at any given time, and making it a singleton ensures that those threads are only created once. The singleton is accessible through a constructor call:

```
class ClusterCoordinator(object):

  def __new__(cls, strategy):
    if not strategy._cluster_coordinator:  # TODO: Needs a lock for thread-safety
      ...
```

Here, we create an attribute on `strategy` that references the `ClusterCoordinator`. This is necessary because `Model` only keeps a reference to `strategy`, and this attribute is what gives `Model` access to the `ClusterCoordinator` instance.
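
For concreteness, a minimal sketch of the full singleton constructor (the body below the `if` is an assumption; the exact lines are elided in this excerpt):

```
class ClusterCoordinator(object):

  def __new__(cls, strategy):
    if not strategy._cluster_coordinator:  # TODO: Needs a lock for thread-safety
      # First construction: create the instance and cache it on the strategy.
      strategy._cluster_coordinator = super(ClusterCoordinator, cls).__new__(cls)
    # Every later constructor call returns the same cached instance.
    return strategy._cluster_coordinator
```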

Being a singleton is important considering that there are power users who would like to `schedule` functions themselves in addition to their `model.fit` usage. That is, they can instantiate one before `model.fit` does, or use one after `model.fit` has instantiated one. In either case, they should access the same `ClusterCoordinator` instance as the one `model.fit` uses.

Obtaining the singleton by calling the constructor of `ClusterCoordinator`, as opposed to an instance getter, provides forward compatibility in case we allow multiple `ClusterCoordinator`s in the future.

#### `ClusterCoordinator`’s reference to `ParameterServerStrategy` as a `weakref`

Note that since `ClusterCoordinator` currently holds a reference to `ParameterServerStrategy`, the coordinator’s reference to `strategy` should be a `weakref`, in order to avoid the leak that would result from the circular references between `ParameterServerStrategy` and `ClusterCoordinator`:
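
A minimal sketch of this (the `_cluster_coordinator` attribute name matches its use above; the exact `__init__` body is otherwise an assumption):

```
import weakref

class ClusterCoordinator(object):

  def __init__(self, strategy):
    # Keep only a weak reference to the strategy so that the
    # strategy -> coordinator -> strategy cycle does not leak.
    self.strategy = weakref.ref(strategy)
    # The strategy holds the (strong) back-reference to the coordinator.
    strategy._cluster_coordinator = self
```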