rfcs/20201121-keras-model-fit-ps.md
## Proposed options and solutions
### User Journey

Let’s first take a look at the proposed user flow (on the coordinator). It is expected to be largely the same as with other strategies, but notable differences are highlighted in the "Notable differences" section below. Unless mentioned otherwise, the discussion here applies to the Python program intended to run on the coordinator.
```
history = model.fit(dataset_factory, epochs=..., steps_per_epoch=..., callbacks=[...])
logging.info("result: %r", history)
```

#### Notable differences of user code between PS and other strategies

There are a few points worth noting in the above user code when using PS training:

* A `tf.data.experimental.DatasetFactory` will be added as another supported type of `x`, and is now the only supported type for the `x` argument of `model.fit` (when used with PS training). This is due to the challenges discussed below.

* The `steps_per_epoch` argument will be required, at least in the short term. This is because `OutOfRangeError` is raised from `ClusterCoordinator` APIs as soon as one worker exhausts its worker `dataset`, at which point other workers may still have data remaining to be processed, so this `OutOfRangeError` guarantees neither that every dataset has been visited roughly once, nor that every dataset has been visited roughly the number-of-workers times. We thus require an explicit number of steps per epoch, and recommend that users always repeat and shuffle the input dataset.

* Conceptually, a step is one batch processed on one worker, as opposed to one batch distributed across all replicas when using some other strategies such as `MultiWorkerMirroredStrategy`.

* Batch-level callbacks will be disabled; that is, if a user overrides `on_batch_begin` or `on_batch_end`, an error will be raised. This is necessary for reasonable performance, as described below.

* The cluster is synced at the end of every epoch. This is an implementation detail users do not necessarily need to be aware of, but it is important for the correctness of epoch-level callbacks.
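
To make the `steps_per_epoch` requirement concrete, here is a toy plain-Python simulation (not the actual `ClusterCoordinator` behavior; the function name and worker speeds are illustrative) of workers consuming their own copies of a non-repeated dataset at different speeds. The moment the fastest worker exhausts its copy, `OutOfRangeError` would end the epoch, while other workers still have unvisited data:

```
# Toy simulation: each worker holds its own copy of a 10-element dataset,
# but workers make progress at different speeds per scheduling round.
def simulate_until_first_exhaustion(dataset_size, speeds):
    """Return per-worker progress when the fastest worker runs out of data."""
    progress = [0] * len(speeds)
    while all(p < dataset_size for p in progress):
        for i, speed in enumerate(speeds):
            progress[i] = min(dataset_size, progress[i] + speed)
    return progress

# Worker 0 is 3x faster than worker 2; it exhausts its dataset first.
progress = simulate_until_first_exhaustion(10, speeds=[3, 2, 1])
```

At that point neither "every element visited once" nor "visited once per worker" holds, which is why an explicit `steps_per_epoch` together with a repeated, shuffled dataset gives predictable epochs.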
This section discusses the changes needed in the `model` API, and assumes the reader has basic familiarity with Keras training APIs.
#### Acceptance of `DatasetFactory` in `model.fit`
In this design, we propose that `model.fit` take a new type, `tf.data.experimental.DatasetFactory`, instead of a dataset instance (which is [what is currently supported](https://github.com/tensorflow/tensorflow/blob/6b9e35f1c0410607bf956c0f27e5d3e1456b4899/tensorflow/python/keras/engine/training.py#L887-L889)), for the following reasons:

* With `dataset` instances, replicating `dataset`s to workers introduces complications.

* With `dataset` replication, we have in the past observed higher memory consumption, less flexibility in users’ dataset transformations, and suboptimal performance.

* When using Keras preprocessing layers (KPLs), read-only resources are created at layer creation, and they end up being placed on the coordinator. However, the `tf.data` replicate API does not support accessing the resources referenced in the dataset graph once it has been serialized and deserialized on the remote worker. This prevents the `dataset`-instance path from supporting resources, and thus KPLs.

Please see below for the rationale of using a `DatasetFactory` type instead of a simple `callable`.
##### Implementation
Currently, `ClusterCoordinator` supports taking a no-argument* function** that returns a `Dataset`. This is done via the `create_per_worker_dataset` API, which creates datasets on remote workers. By leveraging this `Dataset`-function support, `model.fit` with a `DatasetFactory` can be implemented by subclassing the existing Keras `DataHandler` (a Keras-internal private API) to provide a worker-distributed dataset for Keras to use (i.e., call `iter` on). Please see the `DataHandler` section below for proposed changes.
*The idea behind a no-argument function is that the workers are deemed the same, and thus the datasets should be the same on every worker. At this time, we do not recommend sharding.
In terms of how users pass a dataset factory into `model.fit`, there are a couple of options; we propose the following:
###### `DatasetFactory` class
We propose to define a new class `DatasetFactory` that holds a reference to the `dataset_fn`, for the following reasons:

* The `x` argument of `model.fit` is already heavily overloaded with different types. With `DatasetFactory`, we can potentially introduce a `DataFactory` superclass in the future for other types of callables, e.g., a callable that returns a NumPy array, so that `DataFactory` covers the various callable types.
* With `DatasetFactory`, we know the user’s intention is to provide a function that returns a `Dataset`. If needed, this allows us to perform logic that is only applicable to `Dataset` input prior to invoking the `dataset_fn`.
* The library gets to verify the type of the return value, before it is used.
```
history = model.fit(DatasetFactory(dataset_fn), epochs=..., steps_per_epoch=..., callbacks=[...])
```

where

```
class tf.data.experimental.DatasetFactory(Factory):

  def __init__(self, x):
    if not callable(x):
      ...
    ...
    return dataset
```
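
For illustration, here is a plain-Python sketch of what such a factory wrapper could look like. This is a stand-in, not the proposed `tf.data.experimental.DatasetFactory`: the `expected_type` parameter is hypothetical and stands in for the `tf.data.Dataset` return-type check the bullet above describes.

```
class DatasetFactory:
    """Toy sketch: wraps a dataset_fn, validating it up front and its
    return type when invoked."""

    def __init__(self, dataset_fn, expected_type=object):
        # Mirrors the `if not callable(x)` check in the proposed class.
        if not callable(dataset_fn):
            raise TypeError(
                "`DatasetFactory` expects a callable that returns a dataset, "
                f"got {type(dataset_fn)!r}.")
        self._dataset_fn = dataset_fn
        self._expected_type = expected_type

    def __call__(self, input_context=None):
        # Invoke the wrapped dataset_fn and verify the return type before
        # handing the result to the training loop.
        dataset = self._dataset_fn(input_context)
        if not isinstance(dataset, self._expected_type):
            raise TypeError("dataset_fn must return the expected dataset type.")
        return dataset

# A list stands in for a tf.data.Dataset in this sketch.
factory = DatasetFactory(lambda ctx: [1, 2, 3], expected_type=list)
```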
We believe the effort users will spend learning and using this API is marginal, and the benefit we gain from such a class is worthwhile.
##### Implication on no strategy/other strategies
If `model.fit` is allowed to take a `DatasetFactory`, the use cases for synchronous strategies and for no strategy can be readily supported. That is, we provide the `dataset_fn` obtained by invoking the `DatasetFactory` to `distribute_datasets_from_function`, which correctly places `dataset`s on devices in synchronous training.
##### Signature of `dataset_fn`
For compatibility with other strategies, we propose that `dataset_fn` (which the `DatasetFactory` wraps) take a single argument, `input_context`, and return a `tf.data.Dataset`. This `dataset_fn` will be used in `strategy.distribute_datasets_from_function`, wrapped by a `per_worker_dataset_fn`*, and passed to `create_per_worker_dataset`. See the "DataAdapter and DataHandler changes" section below for how this can be implemented in `model.fit`. Though sharding is not necessary in PS training, it is fine for users to shard with `Dataset.shard` using the `input_context` (which has sensible default attributes) in `dataset_fn`, if they need to use it across multiple strategies.
*This is also in preparation for multi-replica support in the future. See the [tutorial](https://www.tensorflow.org/tutorials/distribute/parameter_server_training?hl=uk#dispatch_training_steps_to_remote_workers) for more information.
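
As a reminder of the `Dataset.shard` selection rule referenced above: it keeps every element whose position modulo `num_shards` equals the shard index. A plain-Python equivalent of that rule (not the `tf.data` implementation) is:

```
def shard(elements, num_shards, shard_index):
    # Same selection rule as Dataset.shard(num_shards, index): keep elements
    # whose position satisfies position % num_shards == index.
    return [x for i, x in enumerate(elements) if i % num_shards == shard_index]

# In a dataset_fn, input_context.num_input_pipelines and
# input_context.input_pipeline_id would supply these two values.
pipeline_0 = shard(list(range(10)), num_shards=3, shard_index=0)
pipeline_1 = shard(list(range(10)), num_shards=3, shard_index=1)
```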
##### Batch-level callbacks
###### What constitutes a `step` in `Model.fit` with `ParameterServerStrategy`?
There are two mental models users might have of what constitutes a single `step` when running `Model.fit` with `ParameterServerStrategy`. The mental models are clarified below, as each has implications for how users specify `steps_per_epoch` and how we handle batch-level `Callback`s.
For now, we will throw an error if a user provides a `Callback` that overrides `Callback.on_train_batch_begin` or `Callback.on_train_batch_end`, warning that batch-level Callbacks are not supported at this time. However, this design does not preclude supporting batch-level Callbacks in the future, as long as we give the user control of when to perform a sync. See the section Future Work on Batch-Level Callbacks below for a detailed discussion of this possibility.
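
One way such an error could be raised (a sketch under assumptions, not the actual Keras implementation; `check_no_batch_level_overrides` is a hypothetical helper) is to compare each callback's batch-level hooks against the base class's attributes:

```
class Callback:
    """Minimal stand-in for tf.keras.callbacks.Callback."""
    def on_train_batch_begin(self, batch, logs=None):
        pass
    def on_train_batch_end(self, batch, logs=None):
        pass

def check_no_batch_level_overrides(callbacks):
    # A hook counts as overridden when the subclass's attribute is not the
    # base class's function object.
    for cb in callbacks:
        for hook in ("on_train_batch_begin", "on_train_batch_end"):
            if getattr(type(cb), hook) is not getattr(Callback, hook):
                raise ValueError(
                    "Batch-level Callbacks are not supported with "
                    f"ParameterServerStrategy: {type(cb).__name__}.{hook}")

class EpochOnly(Callback):
    def on_epoch_end(self, epoch, logs=None):
        pass

class BatchLevel(Callback):
    def on_train_batch_end(self, batch, logs=None):
        pass
```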
###### Built-in callbacks that have batch-level calls

What about the existing callbacks Keras provides that have batch-level calls? There is one built-in callback, added by default, that involves batch-level calls:

* `ProgbarLogger`: We'll make the default to log every epoch, not every batch. If the user sets `verbose` such that it would log every batch, an error is raised.

In addition, there are three built-in callbacks, not added by default, that have batch-level calls:

1. `ModelCheckpoint`: The default use case (checkpointing every epoch) is fine. For users who checkpoint every N examples (and thus involve batch-level calls), we will make it remote-aware, i.e., `ModelCheckpoint` knows that what it receives is a `RemoteValue` and that it needs to sync. With this, it is fine for it to be called at every batch while only syncing every N examples.

2. `TensorBoard`: Batch-level calls do not need the output of `train_function`, so they can be made anyway (by making it remote-aware as well).

3. `TerminateOnNaN`: We should disable this in PS training.
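
To illustrate the remote-aware cadence described for `ModelCheckpoint`, here is a toy sketch: every name is a stand-in (`FakeRemoteValue` mimics only a blocking `fetch`, and `EveryNExamplesCheckpoint` is hypothetical). The callback is invoked on every batch, yet blocks on remote results only at the N-example boundary:

```
class FakeRemoteValue:
    """Toy stand-in for a ClusterCoordinator RemoteValue."""
    def __init__(self, value):
        self._value = value
        self.fetched = False
    def fetch(self):
        # In the real API this blocks until the remote step has completed.
        self.fetched = True
        return self._value

class EveryNExamplesCheckpoint:
    """Toy remote-aware cadence: called every batch, syncs only every N examples."""
    def __init__(self, save_freq_examples, batch_size):
        self._freq = save_freq_examples
        self._batch_size = batch_size
        self._seen = 0
        self.saves = 0
    def on_train_batch_end(self, remote_logs):
        self._seen += self._batch_size
        if self._seen >= self._freq:
            remote_logs.fetch()  # only now do we block on the remote value
            self.saves += 1      # placeholder for the actual checkpoint write
            self._seen = 0

ckpt = EveryNExamplesCheckpoint(save_freq_examples=100, batch_size=32)
fetched_flags = []
for _ in range(8):
    rv = FakeRemoteValue({"loss": 0.1})
    ckpt.on_train_batch_end(rv)
    fetched_flags.append(rv.fetched)
```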
##### Timing-Based Callbacks
Users who wish to create `Callback`s that execute on a timing interval rather than a step interval can do so by launching a thread in `Callback.on_train_begin`. An example is shown below:

```
class MyTimingCallback(tf.keras.callbacks.Callback):
  ...
      self.model.save(self.save_dir)
```
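
Expanded into a self-contained sketch — with a plain class and a counter standing in for `tf.keras.callbacks.Callback` and `model.save(...)`, so the threading pattern itself is the point — the thread-launched-in-`on_train_begin` approach looks like:

```
import threading
import time

class TimingCallback:
    """Runs `_tick` every `interval` seconds on a background thread that is
    started in on_train_begin and stopped in on_train_end. A real version
    would subclass tf.keras.callbacks.Callback and save the model instead."""

    def __init__(self, interval):
        self._interval = interval
        self._stop = threading.Event()
        self.ticks = 0

    def _run(self):
        # Event.wait doubles as an interruptible sleep: it returns True as
        # soon as the stop event is set, ending the loop promptly.
        while not self._stop.wait(self._interval):
            self._tick()

    def _tick(self):
        self.ticks += 1  # placeholder for self.model.save(self.save_dir)

    def on_train_begin(self, logs=None):
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def on_train_end(self, logs=None):
        self._stop.set()
        self._thread.join()

cb = TimingCallback(interval=0.05)
cb.on_train_begin()
time.sleep(0.3)  # stands in for the training loop running
cb.on_train_end()
```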
We plan to provide built-in timing-based callbacks for common functionality such as model checkpointing. The asynchronous nature of calls at intervals limits those usages to PS training only, for now. The detailed design of built-in timing-based callbacks will be discussed separately and is not covered in this proposal.
##### Future Work on Batch-level Callbacks
Although we will not support batch-level Callbacks with the current proposal, it is worth noting that this design does not preclude us from supporting some form of batch-level Callbacks in the future.
1. Implement distributed `model.evaluate` without a visitation guarantee, but require the user's opt-in because of the behavior change (via `model.evaluate(..., distributed_eval=True)`)
2. Support distributed `model.evaluate` only after `ClusterCoordinator` provides a visitation-guarantee mechanism
Note that, similar to the dataset factory change for `model.fit`, the validation dataset will also need to be a dataset factory. That is, `model.fit` will take a `DatasetFactory` for the `validation_data` argument, and `model.evaluate` will take a `DatasetFactory` for `x`, as opposed to a `dataset` instance.
See the "Evaluation" section below for other proposed evaluation solutions accompanying `model.fit` usage.

```
history = model.fit(dataset, epochs=..., steps_per_epoch=..., callbacks=[...])
logging.info("result: %r", history)
```

### Using a simple `callable` rather than `DatasetFactory`
In the simplest case, we can allow any kind of `callable` to be passed in:

```
def dataset_fn(input_context):
  return tf.data.Dataset.from_tensor_slices(...)

history = model.fit(dataset_fn, epochs=..., steps_per_epoch=..., callbacks=[...])
```

Pros:

* A `callable` does not require users to use additional APIs and may incur less overhead.
Cons:

* It is less future-proof, as there could be different interpretations of a callable passed as `x` to `model.fit` in the future.
### Attach the `ClusterCoordinator`’s lifecycle to `model.fit`