
Commit 16cb2be

Updates on model and tf.distribute parts. Reorg of sections.
1 parent 4569980 · commit 16cb2be

rfcs/20201121-keras-model-fit-ps.md

Lines changed: 47 additions & 40 deletions
````diff
@@ -198,53 +198,30 @@ For compatibility with other strategies, we propose that `dataset_fn` takes a si

 *This is also in preparation for a multi-replica support in the future. See [tutorial](https://www.tensorflow.org/tutorials/distribute/parameter_server_training?hl=uk#dispatch_training_steps_to_remote_workers) for more information.

-#### The setup of `ClusterCoordinator` with `model.fit` usage

-##### Basic use case: `ClusterCoordinator` being internal

-To take advantage of TF2 support of parameter server training, a `ClusterCoordinator` should be created for handling asynchronous function scheduling and joining. The preferred route should be that such an object is abstracted away from the user with `model.fit` training API as an implementation detail, since we do not expect users to `schedule` functions themselves, or synchronize the cluster in the basic workflow.
-
-##### Advanced use case: `ClusterCoordinator` as a singleton
-
-Now, let's consider a more advanced use case where the `ClusterCoordinator` instance is needed by users. Since `ClusterCoordinator` instance spins off worker and failure handling threads, there should only be one `ClusterCoordinator` at any given time, and making it a singleton ensures that those threads are only created once:
+#### Keras `Model` changes

-```
-class ClusterCoordinator(object):
-  def __new__(cls, strategy):
-    if not strategy.cluster_coordinator: # TODO: Needs a lock for thread-safety
-      strategy.cluster_coordinator = super(ClusterCoordinator, cls).__new__(cls)
-    return strategy.cluster_coordinator
-```
+##### `Model` abstracting the concept of `ClusterCoordinator` for `model.fit`

-Being a singleton is important considering there are power users who would like to `schedule` functions themselves in addition to `model.fit` usage. That is, they can instantiate one before `model.fit` does, or use one after `model.fit` has instantiated one. In either case, they should access the same `ClusterCoordinator` instance, as the one `model.fit` uses.
+To take advantage of TF2's support for parameter server training, a `ClusterCoordinator` should be created to handle asynchronous function scheduling and joining. The preferred route is that such an object is abstracted away from the user by the `model.fit` training API as an implementation detail. For power users who need a `ClusterCoordinator` instance for their custom `schedule`s and `join`s, the instance is available as a singleton through a constructor call. See the "`ClusterCoordinator` as a singleton" section below for more information.

-##### Have an attribute in `ParameterServerStrategy` that holds the `ClusterCoordinator`
+A `ClusterCoordinator` instance can be created at any point prior to `Model`'s use of it, but `model.fit` seems a natural place, since calling it indicates the user's intention to use the compile-fit API as opposed to a custom training loop (CTL), where we expect users to create one themselves.

-We propose that an attribute is added to `ParameterServerStrategy` to keep track of the `ClusterCoordinator`. When a `ClusterCoordinator` is instantiated, such attribute will be set. Here, we assume that the distribution `Strategy` object can determine whether or not it is supposed to be used with a `ClusterCoordinator`. See below “Changes in tf.distribute” section for more information.
-
-```
-class ClusterCoordinator(...):
-  def __init__(self, strategy):
-    self.strategy = weakref.ref(strategy)
-    strategy.cluster_coordinator = self
-```
-
-And, we instantiate the `ClusterCoordinator` as soon as `model.fit` is called for the first time. Note that if users have instantiated it prior to `model.fit` calls, the same instance is returned from the `ClusterCoordinator` constructor. It will then be reused for the next `fit`, or on a different model.
+`model.fit` obtains the `ClusterCoordinator` instance, and links it to the strategy via `strategy._cluster_coordinator`, as soon as `model.fit` is called for the first time. Note that if users have created the `ClusterCoordinator` instance prior to any `model.fit` call, that same instance is returned from the `ClusterCoordinator` constructor. This `ClusterCoordinator` instance will then be used for later `schedule`s and `join`s, as shown in the sections below.

 ```
 class Model(...):

   def fit(self, ...):
-    if (self.distribute_strategy.should_use_with_coordinator() and
-        not self.distribute_strategy.cluster_coordinator):
+    if (self.distribute_strategy.should_use_with_coordinator and
+        not self.distribute_strategy._cluster_coordinator):
       cluster_coordinator.ClusterCoordinator(self.distribute_strategy)
     ... # the rest of fit

 ```
-To avoid the leak resulting from the circular referencing between `ParameterServerStrategy` and `ClusterCoordinator`, the `coordinator`’s reference to `strategy` should be a `weakref`.

-
-#### Keras `Model` changes
+##### `make_train_function` changes

 The train function in `Model.make_train_function` can be swapped with a wrapper that takes a `distributed_iterator` (when the scheduled function is executed on remote workers, the function will receive the actual worker-specific iterator inside the function being executed), and returns the resulting `RemoteValue`.

````
````diff
@@ -262,16 +239,15 @@ class Model(...):

     self.train_function = ...

-    if self.distribute_strategy.cluster_coordinator:
+    if self.distribute_strategy._cluster_coordinator:
       # Note that `train_function` has to be a `tf.function`.
-      self.train_function = lambda distributed_iterator: self.distribute_strategy.cluster_coordinator.schedule(
+      self.train_function = lambda distributed_iterator: self.distribute_strategy._cluster_coordinator.schedule(
          train_function, args=(distributed_iterator,))

     return self.train_function
 ```


-
 #### DataAdapter and DataHandler changes

 Most challenges of supporting `model.fit` with `ParameterServerStrategy` are coming from the asynchronicity of dataset creation, where datasets are only created on workers when they are needed. This means the concrete dataset is not existent at the time the `DataHandler` class is instantiated, and thus some information extraction is not available, such as size of a batch, number of batches, etc.
````
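To make the challenge described in the closing paragraph above concrete before the `ClusterCoordinatorDataHandler` hunks that follow, here is a sketch of the per-worker dataset creation they rely on, reusing the hypothetical `strategy` and `dataset_fn` from the earlier sketch:

```python
# Sketch only: mirrors the per_worker_dataset_fn pattern shown in the next hunk.
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(strategy)

def per_worker_dataset_fn():
  return strategy.distribute_datasets_from_function(dataset_fn)

per_worker_dataset = coordinator.create_per_worker_dataset(per_worker_dataset_fn)
# There is no concrete tf.data.Dataset on the coordinator at this point, so
# properties such as batch size or number of batches cannot be extracted when
# the DataHandler is constructed; datasets materialize on the workers lazily.
per_worker_iterator = iter(per_worker_dataset)
```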
````diff
@@ -291,7 +267,7 @@ class ClusterCoordinatorDataHandler(DataHandler):
     def per_worker_dataset_fn():
       return strategy.distribute_datasets_from_function(x)

-    coordinator = self._model.distribute_strategy.cluster_coordinator
+    coordinator = self._model.distribute_strategy._cluster_coordinator
     self._dataset = coordinator.create_per_worker_dataset(per_worker_dataset_fn)

     if steps_per_epoch is None:
````
````diff
@@ -300,7 +276,7 @@ class ClusterCoordinatorDataHandler(DataHandler):
       self._inferred_steps = steps_per_epoch

   def sync(self):
-    self._model.distribute_strategy.cluster_coordinator.join()
+    self._model.distribute_strategy._cluster_coordinator.join()

   def resolve_logs(self, logs):
     return logs.fetch()
````
````diff
@@ -337,7 +313,7 @@ The `DataHandler` `model.fit` uses depends on whether or not it is using a `Clus

 ```
 def get_data_handler(*args, **kwargs):
-  if model.distribute_strategy.cluster_coordinator:
+  if model.distribute_strategy._cluster_coordinator:
     return ClusterCoordinatorDataHandler(*args, **kwargs)
   return DataHandler(*args, **kwargs)
 ```
````
````diff
@@ -378,7 +354,7 @@ With `ParameterServerStrategy`, the return value of `Model.train_function` is a
 ```
 def to_numpy_or_python_type(logs):
   if isinstance(logs, RemoteValue):
-    get_strategy().cluster_coordinator.join() # Sync the workers.
+    get_strategy()._cluster_coordinator.join() # Sync the workers.
     return logs.fetch() # Return the NumPy results.
   else:
     ... # Existing logic.
````
````diff
@@ -497,9 +473,13 @@ See below “Evaluation” section for other proposed evaluation solutions accom

 ### Changes in tf.distribute

-Coordinator-based distributed training was made available with the introduction of a `ClusterCoordinator` API, where a `Strategy` should be used in conjunction with it. In contrast, classic `strategy.run`-based distributed training only requires a `Strategy` object to be used. The code written for two schemes, with custom training loops, is easily distinguishable by the presence or absence of a `ClusterCoordinator` object. However, with `model.fit`, users are not expected to create a `ClusterCoordinator` object, and thus there needs to be a way for the user to specify whether the training should be performed with a `ClusterCoordinator` object. This can possibly be done at `Strategy.__init__`, so that `model.fit` knows whether or not it is intended for a coordinator-based single-client training, or a traditional multi-client training.
+#### `Strategy` indicating whether it should be used with `ClusterCoordinator`
+
+Coordinator-based distributed training was made available with the introduction of the `ClusterCoordinator` API, where a `Strategy` should be used in conjunction with it. In contrast, classic `strategy.run`-based distributed training only requires a `Strategy` object.

-For now, it seems feasible that `ParameterServerStrategy` has a field `should_use_with_coordinator`, which is always True until usage without a `ClusterCoordinator` is supported, at which point it can be an argument of `__init__`.
+The code written for the two schemes, with custom training loops, is easily distinguishable by the presence or absence of a `ClusterCoordinator` object. However, with `model.fit`, users are not expected to create a `ClusterCoordinator` object, so there needs to be a way for the user to specify whether training should be performed with one. This can possibly be done at `Strategy.__init__`, so that `model.fit` knows whether it is intended for coordinator-based single-client training or traditional multi-client training.
+
+We propose that `ParameterServerStrategy` has an attribute `should_use_with_coordinator`, which is always `True` until usage without a `ClusterCoordinator` is supported, at which point it can become an argument of `__init__`.


 ```
````
````diff
@@ -508,6 +488,33 @@ For now, it seems feasible that `ParameterServerStrategy` has a field `should_us
     self.should_use_with_coordinator = True
 ```

+#### `ClusterCoordinator` as a singleton
+
+Since a `ClusterCoordinator` instance spins off worker and failure handling threads, there should only be one `ClusterCoordinator` per `strategy` instance at any given time, and making it a singleton ensures that those threads are only created once. The singleton is accessible through a constructor call:
+
+```
+class ClusterCoordinator(object):
+  def __new__(cls, strategy):
+    if not strategy._cluster_coordinator: # TODO: Needs a lock for thread-safety
+      strategy._cluster_coordinator = super(ClusterCoordinator, cls).__new__(cls)
+    return strategy._cluster_coordinator
+```
+
+Here, we have created an attribute on `strategy` that references the `ClusterCoordinator`. This is necessary because `Model` only keeps a reference to `strategy`, and this gives `Model` access to the `ClusterCoordinator` instance.
+
+Being a singleton is important considering there are power users who would like to `schedule` functions themselves in addition to `model.fit` usage. That is, they can instantiate one before `model.fit` does, or use one after `model.fit` has instantiated one. In either case, they should access the same `ClusterCoordinator` instance as the one `model.fit` uses.
+
+Obtaining the singleton by calling the constructor of `ClusterCoordinator`, as opposed to an instance getter, provides future compatibility should we allow multiple `ClusterCoordinator`s later.
+
+#### `ClusterCoordinator`’s reference to `ParameterServerStrategy` as a `weakref`
+
+`ClusterCoordinator` currently holds a reference to `ParameterServerStrategy`. To avoid the leak resulting from this circular reference between `ParameterServerStrategy` and `ClusterCoordinator`, the coordinator’s reference to `strategy` should be a `weakref`:
+
+```
+class ClusterCoordinator(...):
+  def __init__(self, strategy):
+    self.strategy = weakref.ref(strategy)
+```


 ### Workers and parameter servers
````
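To illustrate the singleton behavior added in the hunk above, a power user could share the coordinator with `model.fit` roughly as follows. This is a sketch rather than part of the commit, reusing the hypothetical `strategy`, `model`, and `dataset_fn` from the earlier sketches together with the `schedule`/`join`/`fetch` APIs referenced in this RFC.

```python
# Constructing the coordinator before (or after) the first model.fit call is
# intended to return the same singleton attached to `strategy`.
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(strategy)

model.fit(dataset_fn, epochs=1, steps_per_epoch=10)  # reuses the same coordinator

@tf.function
def worker_step(iterator):
  # A user-defined function scheduled onto workers alongside model.fit usage.
  features, labels = next(iterator)
  return tf.reduce_mean(labels)

per_worker_iterator = iter(coordinator.create_per_worker_dataset(
    lambda: strategy.distribute_datasets_from_function(dataset_fn)))
remote_value = coordinator.schedule(worker_step, args=(per_worker_iterator,))
coordinator.join()           # block until all scheduled functions have finished
print(remote_value.fetch())  # materialize the RemoteValue on the coordinator
```

Whichever is constructed first, only one set of worker and failure handling threads is created, which is the point of the singleton constraint.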
