This repository was archived by the owner on Jul 10, 2025. It is now read-only.
#### Notable differences of user code between PS and other strategies
There are a few points worth noting in the above user code when using PS training:
* A `dataset` callable (i.e. a `dataset` factory) will be added as a supported type of the `x` argument of `model.fit`, and with PS training it is the only supported type. This is due to the challenges discussed below.
* The `steps_per_epoch` argument will be required, at least in the short term. This is because `OutOfRangeError` is raised from `ClusterCoordinator` APIs as soon as one worker exhausts its worker dataset, at which point other workers may still have data left to process. This `OutOfRangeError` therefore indicates neither that every dataset has been visited roughly once, nor that every dataset has been visited roughly as many times as there are workers. We thus require an explicit `steps_per_epoch`, and recommend that users always repeat and shuffle the input dataset.
* Conceptually, a step is one batch processed on one worker, as opposed to one batch distributed across all replicas as with some other strategies such as `MultiWorkerMirroredStrategy`.
* Batch-level callbacks will be disabled; that is, if users override `on_batch_begin` or `on_batch_end`, an error will be raised. This is necessary for reasonable performance, as described below.
* The cluster is synced at the end of every epoch. This is an implementation detail users do not necessarily need to be aware of, but it is important for the correctness of epoch-level callbacks.
* The `model.fit(..., run_eagerly=True)` case is not supported. This is because `ClusterCoordinator.schedule` requires a `tf.function` to be `schedule`d*, and a regular Python function cannot be.
*There are a couple of reasons why we chose to only support `tf.function`s for scheduling. Primarily, we in general have better control over the behavior of a `tf.function`, including variable and resource creation. Furthermore, this forces the content of the function to be executed on remote workers, as opposed to possibly executing Python code on the coordinator.
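To make the `steps_per_epoch` point concrete, here is a toy, TensorFlow-free simulation; the round-robin scheduler and per-worker speeds are invented for illustration. When workers run at different speeds, the first exhaustion signal says little about overall data coverage:

```python
def run_until_first_exhaustion(worker_speeds, dataset):
    """Round-robin scheduling where worker i performs worker_speeds[i]
    steps per round, each worker holding its own copy of the dataset.
    Returns per-worker elements consumed when the first copy runs out
    (the analogue of `OutOfRangeError` surfacing from the coordinator)."""
    iterators = [iter(list(dataset)) for _ in worker_speeds]
    consumed = [0] * len(worker_speeds)
    while True:
        for i, it in enumerate(iterators):
            for _ in range(worker_speeds[i]):
                try:
                    next(it)
                except StopIteration:
                    return consumed
                consumed[i] += 1

# Two workers, each with its own 6-element copy; worker 0 is 3x faster.
consumed = run_until_first_exhaustion([3, 1], range(6))
assert consumed == [6, 2]
```

At the moment of exhaustion the fast worker has seen all 6 elements while the slow one has seen only 2, so the error signals neither "each element seen once" nor "each element seen once per worker", hence the explicit `steps_per_epoch`.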
*The idea behind a no-argument function is that the workers are deemed identical, and thus the datasets should be the same on every worker. At this time, we do not recommend sharding.
**A `dataset_fn`, as opposed to a `dataset` instance, was supported in parameter server training initially because it provides simpler fault tolerance logic and spared us from having to deal with replicating a `dataset` instance.
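As a sketch of the no-argument `dataset_fn` idea, with plain Python generators standing in for `tf.data` and `make_dataset_fn` a hypothetical helper: every call builds an identically defined shuffled-and-repeated stream, which is what makes per-worker creation and post-failure re-creation simple.

```python
import random

def make_dataset_fn(data, seed=None):
    """Returns a no-argument dataset factory: every call builds a fresh,
    identically defined shuffled-and-repeated stream, so each worker
    (or a restarted worker) constructs its own copy independently."""
    def dataset_fn():
        items = list(data)
        rng = random.Random(seed)
        while True:              # "repeat": the stream never exhausts
            rng.shuffle(items)
            yield from items
    return dataset_fn

dataset_fn = make_dataset_fn([1, 2, 3], seed=0)
worker_a = dataset_fn()          # worker A's copy
worker_b = dataset_fn()          # worker B's copy (or A's, after a restart)
epoch_a = [next(worker_a) for _ in range(3)]
epoch_b = [next(worker_b) for _ in range(3)]
assert sorted(epoch_a) == sorted(epoch_b) == [1, 2, 3]
```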
In terms of how users pass a dataset factory into `model.fit`, there are a couple of options:
With an additionally defined class `DatasetFactory`:
```
class DatasetFactory(Factory):

  def __init__(self, x):
    if not callable(x):
      ...
```
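For concreteness, a runnable sketch of how such a factory might behave; the `Factory` base class body, the error message, and the `__call__` adapter are assumptions, since the excerpt above elides them:

```python
class Factory:
    """Assumed base class; the RFC excerpt only shows the subclassing."""
    def __call__(self):
        raise NotImplementedError

class DatasetFactory(Factory):

    def __init__(self, x):
        if not callable(x):
            # Error message is an assumption, not from the RFC.
            raise ValueError(
                "Expected a no-argument callable returning a dataset, "
                "got %r" % (x,))
        self._x = x

    def __call__(self):
        # Adapting layer: if the wrapped callable's interpretation
        # changes later, only this call site needs to change.
        return self._x()

factory = DatasetFactory(lambda: [1, 2, 3])
assert factory() == [1, 2, 3]
```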
Pros:
* If there are other types to be supported in `model.fit` in the future, we no longer need another type to be added to `x`; it can simply be another subclass of `Factory`.
* If `dataset` has a different interpretation, for example it takes an argument instead of none, we get an adapting layer with a `DatasetFactory`.
Cons:
To take advantage of TF2 support of parameter server training, a `ClusterCoordinator` should be created for handling asynchronous function scheduling and joining. The preferred route should be that such an object is abstracted away from the user by the `model.fit` training API as an implementation detail, since we do not expect users to `schedule` functions themselves or synchronize the cluster in the basic workflow.
Since a `ClusterCoordinator` instance spins off worker and failure-handling threads, there should be only one `ClusterCoordinator` at any given time, and making it a singleton ensures that those threads are only created once.
Being a singleton is also important considering there are power users who would like to `schedule` functions themselves in addition to using `model.fit`. That is, they may instantiate a `ClusterCoordinator` before `model.fit` does, or use one after `model.fit` has instantiated it. In either case, they should access the same `ClusterCoordinator` instance.
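The single-instance behavior could be sketched as follows. This is a stand-in only: the real class is `tf.distribute.experimental.coordinator.ClusterCoordinator`, and this sketch models instance sharing, not thread management or scheduling:

```python
import threading

class CoordinatorSingleton:
    """Stand-in modeling the one-coordinator-at-a-time behavior
    described above: every construction path yields the same object."""
    _instance = None
    _lock = threading.Lock()

    def __new__(cls):
        with cls._lock:              # guard against racing constructors
            if cls._instance is None:
                cls._instance = super().__new__(cls)
            return cls._instance

first = CoordinatorSingleton()       # e.g. created by a power user
second = CoordinatorSingleton()      # e.g. created later inside model.fit
assert first is second               # both paths share one coordinator
```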
In terms of who keeps track of the `ClusterCoordinator` for `model.fit`, and when it starts allocating threads, there are a few options. Here, we assume that the distribution `Strategy` object can determine whether or not it is supposed to be used with a `ClusterCoordinator`. See the “Changes in tf.distribute” section below for more information.
##### Option 1: Attach the `ClusterCoordinator`’s lifecycle to `model.fit`
#### Keras `Model` changes
The train function in `Model.make_train_function` can be swapped with a wrapper that takes a `distributed_iterator` (when the scheduled function is executed on remote workers, it receives the actual worker-specific iterator inside the function being executed), and returns the resulting `RemoteValue`.
```
class Model(...):

  ...

  if self._cluster_coordinator:
    # Note that `train_function` has to be a `tf.function`.
    ...
```
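The schedule-and-fetch pattern the wrapper relies on can be sketched without TensorFlow; `RemoteValue` and `schedule` below are simplified stand-ins, not the real `ClusterCoordinator` APIs:

```python
import threading

class RemoteValue:
    """Simplified stand-in for the handle returned by
    `ClusterCoordinator.schedule`: the result is fetched after the
    asynchronously executed function completes."""
    def __init__(self):
        self._done = threading.Event()
        self._result = None

    def fetch(self):
        self._done.wait()            # block until the worker finishes
        return self._result

def schedule(fn, args=()):
    # Toy dispatcher: run `fn` on a worker thread, return immediately.
    remote_value = RemoteValue()
    def _worker():
        remote_value._result = fn(*args)
        remote_value._done.set()
    threading.Thread(target=_worker).start()
    return remote_value

# A train-function-like callable that consumes a per-worker iterator.
remote_value = schedule(lambda it: sum(it), args=(iter([1, 2, 3]),))
assert remote_value.fetch() == 6
```

In the real implementation the wrapped train function is additionally required to be a `tf.function`, for the reasons discussed earlier.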