Commit 9c89029

Some updates

1 parent d8d821c commit 9c89029
rfcs/20201121-keras-model-fit-ps.md

Lines changed: 25 additions & 12 deletions
@@ -89,12 +89,13 @@ logging.info("result: %r", history)

#### Notable differences of user code between PS and other strategies

-There are a couple of points worth noting in the above user code:
-* The `dataset` argument of `model.fit` can no longer be a dataset instance. In fact, in the short term, it most likely will be some form of dataset factory, due to the challenges discussed below.
-* `steps_per_epoch` argument will be required for PS training, at least in the short term. This is because `OutOfRangeError` is raised from `ClusterCoordinator` APIs as soon as one worker exhausts its worker dataset, at which point other workers may have datasets remaining to be processed, and this `OutOfRangeError` indicates neither every dataset is visited roughly once, nor every dataset is visited roughly number of workers times. We thus require explicit steps per epoch, and recommend users to always repeat and shuffle the input dataset.
-* Batch level callback will be disabled when `ParameterServerStrategy` is used; that is, if users override `on_batch_begin` and `on_batch_end`, an error will be raised. This is necessary for reasonable performance as described below.
+There are a few points worth noting in the above user code when using PS training (a short sketch follows these notes):
+* A `dataset` callable, or `dataset` factory, will be added as another supported type of `x`, and it is now the only type that can be passed as the `x` argument of `model.fit` when used with PS training. This is due to the challenges discussed below.
+* The `steps_per_epoch` argument will be required, at least in the short term. This is because an `OutOfRangeError` is raised from the `ClusterCoordinator` APIs as soon as one worker exhausts its worker dataset, at which point other workers may still have data remaining to be processed; this `OutOfRangeError` therefore indicates neither that every dataset has been visited roughly once, nor that every dataset has been visited roughly as many times as there are workers. We thus require an explicit number of steps per epoch, and recommend that users always repeat and shuffle the input dataset.
+* Concept-wise, a step is one batch processed on one worker, as opposed to one batch distributed across all replicas as with some other strategies such as `MultiWorkerMirroredStrategy`.
+* Batch-level callbacks will be disabled; that is, if users override `on_batch_begin` and `on_batch_end`, an error will be raised. This is necessary for reasonable performance, as described below.
* The cluster is synced at the end of every epoch. This is an implementation detail users do not necessarily need to be aware of; however, it is important for the correctness of epoch-level callbacks.
-* `run_eagerly=True` case is not supported. This is because `ClusterCoordinator.schedule` requires a `tf.function` to be `schedule`d*, and regular python function cannot.
+* The `model.fit(..., run_eagerly=True)` case is not supported. This is because `ClusterCoordinator.schedule` requires a `tf.function` to be `schedule`d*, and a regular Python function cannot be.

*There are a couple of reasons why we chose to only support scheduling `tf.function`s. Primarily, we in general have better control over the behavior of `tf.function`s, including variable and resource creation. Furthermore, this forces the content of the function to be executed on remote workers, as opposed to possible execution of Python code on the coordinator.
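For illustration, here is a minimal sketch of a basic `model.fit` call that reflects the notes above; `cluster_resolver`, `build_model`, `features`, and `labels` are hypothetical placeholders rather than part of the proposal:

```
import logging

import tensorflow as tf

# Hypothetical user code reflecting the notes above; names are placeholders.
strategy = tf.distribute.experimental.ParameterServerStrategy(cluster_resolver)

with strategy.scope():
  model = build_model()  # user-defined model construction
  model.compile(optimizer="adam", loss="mse")

def dataset_fn():
  # Repeat and shuffle, since `steps_per_epoch` (not dataset exhaustion)
  # bounds each epoch under PS training.
  return tf.data.Dataset.from_tensor_slices(
      (features, labels)).shuffle(1000).repeat().batch(32)

# The `x` argument of `fit` is the dataset callable; `steps_per_epoch` is
# required for PS training.
history = model.fit(dataset_fn, epochs=5, steps_per_epoch=100)
logging.info("result: %r", history)
```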

@@ -121,7 +122,7 @@ Current `model.fit` API takes a dataset from which an iterator is created, and t

*The idea behind a no-argument function is that the workers are deemed the same, and thus the datasets should be the same on every worker. At this time, we do not recommend sharding.

-**The rationale behind using a `dataset_fn` as opposed to `dataset` was a historical choice as we could not get sharding to work well with fault tolerance.
+**A `dataset_fn` was initially supported in parameter server training, as opposed to a `dataset` instance, because it allows simpler fault-tolerance logic and avoids having to deal with replicating a `dataset` instance.

In terms of how users pass a dataset factory into `model.fit`, there are a couple of options:

@@ -159,7 +160,7 @@ With an additionally defined class `DatasetFactory`:


```
-class DatasetFactory(object):
+class DatasetFactory(Factory):

  def __init__(self, x):
    if not callable(x):
@@ -177,6 +178,7 @@ class DatasetFactory(object):
```
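For concreteness, a hedged sketch of how such a factory might then be handed to `model.fit` under this option; `dataset_fn`, `model`, and the `fit` arguments are placeholders rather than part of the proposed class above:

```
# Hypothetical usage of the `DatasetFactory` option; names are placeholders.
factory = DatasetFactory(dataset_fn)  # wraps a no-argument dataset callable
history = model.fit(factory, epochs=5, steps_per_epoch=100)
```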

Pros:
+* If there are other types to be supported by `model.fit` in the future, we no longer need to add another type to `x`; it would simply be another subclass of `Factory`.
* If `dataset` has a different interpretation, for example it takes an argument instead of none, we get an adapting layer with a `DatasetFactory`.

Cons:
@@ -200,9 +202,20 @@ The signature (input argument and return value) of `dataset_fn` taken by `model.

To take advantage of TF2 support of parameter server training, a `ClusterCoordinator` should be created for handling asynchronous function scheduling and joining. The preferred route should be that such an object is abstracted away from the user by the `model.fit` training API as an implementation detail, since we do not expect users to `schedule` functions themselves, or to synchronize the cluster, in the basic workflow.

-For power users who would like to `schedule` functions in addition to `model.fit` usage, we need to restrict them to use the `ClusterCoordinator` the library creates, because `ClusterCoordinator` does not have a graceful cleanup mechanism yet. We should error out if `ClusterCoordinator` is instantiated more than once, until we have support for that.
+Since a `ClusterCoordinator` instance spins off worker and failure-handling threads, there should only be one `ClusterCoordinator` at any given time, and making it a singleton ensures that those threads are only created once:

-In terms of who keeps track of the `ClusterCoordinator`, and when it starts allocating threads, there are a few options. Here, we assume that the distribution `Strategy` object can determine whether or not it is supposed to be used with a `ClusterCoordinator`. See below “Changes in tf.distribute” section for more information.
+```
+class ClusterCoordinator(object):
+  instance = None
+  def __new__(cls):
+    if not ClusterCoordinator.instance:
+      ClusterCoordinator.instance = super(ClusterCoordinator, cls).__new__(cls)
+    return ClusterCoordinator.instance
+```
+
+Being a singleton is important because there are power users who would like to `schedule` functions themselves in addition to using `model.fit`. That is, they may instantiate a `ClusterCoordinator` before `model.fit` does, or use one after `model.fit` has instantiated one. In either case, they should access the same `ClusterCoordinator` instance.
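With the `__new__` override above, repeated construction hands back the same coordinator, so `model.fit` and a power user share one instance. The snippet below is illustrative only and elides any constructor arguments (such as the strategy) that the real coordinator takes:

```
# Illustrative only: both constructions resolve to the same singleton object,
# so the worker and failure-handling threads are created exactly once.
coordinator_from_user = ClusterCoordinator()
coordinator_from_fit = ClusterCoordinator()
assert coordinator_from_user is coordinator_from_fit
```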
+In terms of who keeps track of the `ClusterCoordinator` for `model.fit`, and when it starts allocating threads, there are a few options. Here, we assume that the distribution `Strategy` object can determine whether or not it is supposed to be used with a `ClusterCoordinator`. See the “Changes in tf.distribute” section below for more information.


##### Option 1: Attach the `ClusterCoordinator`’s lifecycle to `model.fit`
@@ -257,7 +270,7 @@ This option is with the assumption that there is always only one `ParameterServe

#### Keras `Model` changes

-The train function in `Model.make_train_function` can be swapped with a wrapper that takes a `distribute_iterator` (when the scheduled function is executed on remote workers, the function will receive the actual worker-specific iterator inside the function being executed), and returns the resulting `RemoteValue`.
+The train function in `Model.make_train_function` can be swapped with a wrapper that takes a `distributed_iterator` (when the scheduled function is executed on remote workers, the function will receive the actual worker-specific iterator inside the function being executed), and returns the resulting `RemoteValue`.


```
@@ -275,8 +288,8 @@ class Model(...):

    if self._cluster_coordinator:
      # Note that `train_function` has to be a `tf.function`.
-      self.train_function = lambda distribute_iterator: self._cluster_coordinator.schedule(  # pylint: disable=g-long-lambda
-          train_function, args=(distribute_iterator,))
+      self.train_function = lambda distributed_iterator: self._cluster_coordinator.schedule(  # pylint: disable=g-long-lambda
+          train_function, args=(distributed_iterator,))

    return self.train_function
```
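For context, a hedged sketch of how the fit loop might drive this wrapper; `coordinator.join()` and `RemoteValue.fetch()` are existing `ClusterCoordinator`/`RemoteValue` APIs, while the loop structure and variable names here (`epochs`, `steps_per_epoch`, `distributed_iterator`, `coordinator`) are illustrative assumptions rather than the RFC's exact implementation:

```
# Illustrative only: scheduling steps and syncing once per epoch.
for epoch in range(epochs):
  for _ in range(steps_per_epoch):
    # Each call schedules one step on some worker and immediately returns a
    # `RemoteValue`, without blocking on the step's completion.
    logs = model.train_function(distributed_iterator)
  # The cluster is synced at the end of every epoch (see above), so
  # epoch-level callbacks observe completed steps.
  coordinator.join()
  last_step_logs = logs.fetch()  # materialize the last step's logs if needed
```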
