
Commit 9c90c5e: DatasetFactory changes

1 parent 01f1814 commit 9c90c5e


rfcs/20201121-keras-model-fit-ps.md

Lines changed: 55 additions & 44 deletions
@@ -63,12 +63,9 @@ In this design, we will discuss the changes in `model.fit` API, and its integrat

## Proposed options and solutions

-Let’s first take a look at the proposed user flow (on the coordinator). It is expected to be largely the same with other strategies, but notable differences are highlighted in the "Notable differences" section below. Unless mentioned otherwise, the discussion here applies to the python program intended to be run on the coordinator.
-
-
### User Journey

-With a dataset factory:
+Let’s first take a look at the proposed user flow (on the coordinator). It is expected to be largely the same as with other strategies, but notable differences are highlighted in the "Notable differences" section below. Unless mentioned otherwise, the discussion here applies to the Python program intended to be run on the coordinator.


```
@@ -82,16 +79,17 @@ def dataset_fn(input_context):
  return tf.data.Dataset.from_tensors(...).repeat(...).batch(...)

# `ClusterCoordinator` is created at `fit`
-history = model.fit(dataset_fn, epochs=..., steps_per_epoch=..., callbacks=[...])
+dataset_factory = tf.data.experimental.DatasetFactory(dataset_fn)
+history = model.fit(dataset_factory, epochs=..., steps_per_epoch=..., callbacks=[...])
logging.info("result: %r", history)
```


#### Notable differences of user code between PS and other strategies

There are a few points worth noting in the above user code when using PS training:
-* A `dataset` callable or `dataset` factory will be added as another supported type of `x`, and is now the only type supported to be passed as `x` argument of `model.fit` (when used with PS training). This is due to the challenges discussed below.
-* `steps_per_epoch` argument will be required, at least in the short term. This is because `OutOfRangeError` is raised from `ClusterCoordinator` APIs as soon as one worker exhausts its worker dataset, at which point other workers may have datasets remaining to be processed, and this `OutOfRangeError` indicates neither every dataset is visited roughly once, nor every dataset is visited roughly number of workers times. We thus require explicit steps per epoch, and recommend users to always repeat and shuffle the input dataset.
+* A `tf.data.experimental.DatasetFactory` will be added as another supported type of `x`, and is now the only supported type for the `x` argument of `model.fit` when used with PS training. This is due to the challenges discussed below.
+* The `steps_per_epoch` argument will be required, at least in the short term. This is because `OutOfRangeError` is raised from `ClusterCoordinator` APIs as soon as one worker exhausts its worker `dataset`, at which point other workers may have datasets remaining to be processed; this `OutOfRangeError` therefore guarantees neither that every dataset has been visited roughly once, nor that every dataset has been visited roughly as many times as there are workers. We thus require an explicit number of steps per epoch, and recommend that users always repeat and shuffle the input dataset.
* Concept-wise, a step is one batch processed on one worker, as opposed to one batch distributed across all replicas when using some other strategies such as `MultiWorkerMirroredStrategy`.
* Batch-level callbacks will be disabled; that is, if users override `on_batch_begin` and `on_batch_end`, an error will be raised. This is necessary for reasonable performance as described below.
* The cluster is synced at the end of every epoch. This is an implementation detail users do not necessarily need to be aware of, but it is important for the correctness of epoch-level callbacks.
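
For concreteness, the snippet above can be placed in a fuller program along the following lines. This is only a sketch: the cluster and strategy setup follow the parameter server training tutorial, `tf.data.experimental.DatasetFactory` is the symbol proposed in this RFC rather than an existing API, and the model and data are placeholders.

```
import logging
import tensorflow as tf

# Cluster/strategy setup as in the parameter server training tutorial;
# assumes TF_CONFIG describes the coordinator, workers, and parameter servers.
cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
strategy = tf.distribute.experimental.ParameterServerStrategy(cluster_resolver)

with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
  model.compile(optimizer="sgd", loss="mse")

def dataset_fn(input_context):
  # Placeholder data; repeat and shuffle so an explicit `steps_per_epoch`
  # can be honored on every worker.
  features = tf.random.uniform((1000, 10))
  labels = tf.random.uniform((1000, 1))
  return (tf.data.Dataset.from_tensor_slices((features, labels))
          .shuffle(1000).repeat().batch(32))

# `DatasetFactory` is the type proposed in this design (not an existing symbol).
dataset_factory = tf.data.experimental.DatasetFactory(dataset_fn)

# `ClusterCoordinator` is created inside `fit`; `steps_per_epoch` is required.
history = model.fit(dataset_factory, epochs=5, steps_per_epoch=100)
logging.info("result: %r", history)
```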
@@ -106,19 +104,21 @@ There are a few points worth noting in the above user code, when using PS traini
This section discusses the changes that need to be made in the `model` API and assumes the reader has basic familiarity with Keras training APIs.


-#### Dataset function or factory in `model.fit`
+#### Acceptance of `DatasetFactory` in `model.fit`

-In this design, we propose `model.fit` to take a dataset function or factory, instead of a dataset instance (which is [what is currently supported](https://github.com/tensorflow/tensorflow/blob/6b9e35f1c0410607bf956c0f27e5d3e1456b4899/tensorflow/python/keras/engine/training.py#L887-L889)), for the following reasons:
+In this design, we propose that `model.fit` take a new type, `tf.data.experimental.DatasetFactory`, instead of a dataset instance (which is [what is currently supported](https://github.com/tensorflow/tensorflow/blob/6b9e35f1c0410607bf956c0f27e5d3e1456b4899/tensorflow/python/keras/engine/training.py#L887-L889)), for the following reasons:

* With `dataset` instances, there is complication brought by the need to replicate `dataset`s to workers.

* With `dataset` replication, in the past we have observed more memory consumption, less flexibility for users' dataset transformations, and suboptimal performance.

* When using Keras preprocessing layers (KPL), read-only resources are created at layer creation, and they end up being placed on the coordinator. However, the `tf.data` `replicate` API does not support accessing the resources referenced in the dataset graph once it has been serialized and deserialized on the remote worker. This prevents the `dataset` instance path from supporting resources, and thus KPLs.

+Please see below for the rationale for using a `DatasetFactory` type instead of a simple `callable`.
+
##### Implementation

-Current `model.fit` API takes a dataset from which an iterator is created, and the train function is built with this iterator. However, `ClusterCoordinator` only supports taking a no-argument* function** that returns a `Dataset`. This is done by the `create_per_worker_dataset` API, which creates datasets on remote workers. By leveraging such data factory support, `model.fit` with `dataset_fn` can be implemented by subclassing the existing Keras `DataHandler` (a Keras internal private API) to provide a worker-distributed dataset for Keras to use (i.e. call `iter` on). Please see the `DataHandler` section below for proposed changes.
+Currently, `ClusterCoordinator` supports taking a no-argument* function** that returns a `Dataset`. This is done by the `create_per_worker_dataset` API, which creates datasets on remote workers. By leveraging such `Dataset` function support, `model.fit` with a `DatasetFactory` can be implemented by subclassing the existing Keras `DataHandler` (a Keras internal private API) to provide a worker-distributed dataset for Keras to use (i.e. call `iter` on). Please see the `DataHandler` section below for proposed changes.

*The idea behind a no-argument function is that the workers are deemed the same, and thus the datasets should be the same on every worker. At this time, we do not recommend sharding.

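As an illustration of the `create_per_worker_dataset` path described above, the underlying wiring could look roughly like this. It is a sketch of the existing `ClusterCoordinator` calls the proposal builds on, not the actual Keras `DataHandler` internals; `strategy` and `dataset_fn` are assumed to be defined as in the user journey.

```
import tensorflow as tf

# Sketch only; the real DataHandler subclass is Keras-internal.
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(strategy)

def per_worker_dataset_fn():
  # `distribute_datasets_from_function` supplies the `InputContext` argument
  # that `dataset_fn` expects.
  return strategy.distribute_datasets_from_function(dataset_fn)

# Creates the dataset on each remote worker; Keras would then call `iter` on it.
per_worker_dataset = coordinator.create_per_worker_dataset(per_worker_dataset_fn)
per_worker_iterator = iter(per_worker_dataset)
```
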
@@ -127,27 +127,15 @@ Current `model.fit` API takes a dataset from which an iterator is created, and t
In terms of how users pass a dataset factory into `model.fit`, there are a couple of options:


-###### Option 1: any `callable`
+###### `DatasetFactory` class

-In the simplest case, we can allow any kind of `callable` to be passed in:
+We propose to define a new class, `DatasetFactory`, that holds a reference to the `dataset_fn`, for the following reasons:

+* The input argument `x` of `model.fit` is already heavily overloaded with different types. With `DatasetFactory`, we can potentially introduce a `DataFactory` superclass in the future that covers other types of callable, e.g., a callable that returns a numpy array.

-```
-def dataset_fn(input_context):
-  return tf.data.Dataset.from_tensor_slices(...)
-history = model.fit(dataset_fn, epochs=..., steps_per_epoch=..., callbacks=[...])
-```
+* With `DatasetFactory`, we learn the user's intention to provide a function that returns a `Dataset`. If needed, this allows us to perform logic that is only applicable to a `Dataset` input, prior to invoking the `dataset_fn`.

-Pros:
-* `callable` does not require users to use additional APIs and may be less overhead.
-
-Cons:
-* Less future proof as there could be different interpretation of callable passed as `dataset` to `model.fit` in the future.
-
-
-###### Option 2: dataset factory
-
-For future-compatibility of `model.fit` API where a `dataset_fn` may have a signature change, a `DatasetFactory` can come handy which determines how the function is supposed to be used.
+* The library gets to verify the type of the return value before it is used.


```
@@ -156,12 +144,10 @@ def dataset_fn(input_context):
history = model.fit(DatasetFactory(dataset_fn), epochs=..., steps_per_epoch=..., callbacks=[...])
```

-
-With an additionally defined class `DatasetFactory`:
-
+where

```
-class tf.keras.experimental.DatasetFactory(Factory):
+class tf.data.experimental.DatasetFactory(Factory):

  def __init__(self, x):
    if not callable(x):
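
Since the class definition is split across the surrounding hunks, a self-contained sketch of the proposed class is shown below; the `Factory` base class and the error messages are illustrative rather than part of the diff.

```
import tensorflow as tf

class Factory(object):
  """Illustrative base class; the proposal leaves room for other factories."""

class DatasetFactory(Factory):
  """Wraps a `dataset_fn` taking an `InputContext` and returning a `tf.data.Dataset`."""

  def __init__(self, x):
    if not callable(x):
      raise TypeError(
          "`DatasetFactory` expects a callable that returns a "
          "`tf.data.Dataset`; got %r." % (x,))
    self._dataset_fn = x

  def __call__(self, input_context):
    dataset = self._dataset_fn(input_context)
    if not isinstance(dataset, tf.data.Dataset):
      raise TypeError(
          "The callable wrapped by `DatasetFactory` must return a "
          "`tf.data.Dataset`; got %r." % (dataset,))
    return dataset
```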
@@ -176,25 +162,17 @@ class tf.keras.experimental.DatasetFactory(Factory):
    return dataset
```

-Pros:
-* If there are other types in the future to be supported in `model.fit`, we no longer need another type to be added to `x`; it will be another subclass of `Factory`.
-* If `dataset` has a different interpretation, for example it takes an argument instead of none, we get an adapting layer with a `DatasetFactory`.
-
-Cons:
-* This requires users to use an additional symbol.
-
-
-The following discussion is based on option 1, where a simple callable is taken.
+We believe the effort users will spend learning and using this API is marginal, and the benefit we gain from such a class is worthwhile.


##### Implication on no strategy/other strategies

-If `model.fit` is allowed to take a `dataset_fn`, use cases for synchronous strategies, and no strategy, can be readily applied. That is, we provide the `dataset_fn` to `distribute_datasets_from_function`, which correctly places `dataset`s on devices in synchronous training.
+If `model.fit` is allowed to take a `DatasetFactory`, use cases for synchronous strategies, and no strategy, can be readily supported. That is, we pass the `dataset_fn` wrapped by the `DatasetFactory` to `distribute_datasets_from_function`, which correctly places `dataset`s on devices in synchronous training.


##### Signature of `dataset_fn`

-For compatibility with other strategies, we propose that `dataset_fn` takes a single argument `input_context`, and returns a `tf.data.Dataset`. This `dataset_fn` will be used in `strategy.distribute_datasets_from_function`, wrapped by a `per_worker_dataset_fn`*, passed to `create_per_worker_dataset`. See below "DataAdapter and DataHandler changes" section for how this can be implemented in `model.fit`. Though sharding is not necessary in PS training, it is fine that users shard with `Dataset.shard` using the `input_context` (which has sensible default attributes) in `dataset_fn`, if they need to use it across multiple strategies.
+For compatibility with other strategies, we propose that `dataset_fn` (which the `DatasetFactory` wraps) take a single argument, `input_context`, and return a `tf.data.Dataset`. This `dataset_fn` will be used in `strategy.distribute_datasets_from_function`, wrapped by a `per_worker_dataset_fn`*, and passed to `create_per_worker_dataset`. See the "DataAdapter and DataHandler changes" section below for how this can be implemented in `model.fit`. Though sharding is not necessary in PS training, it is fine for users to shard with `Dataset.shard` using the `input_context` (which has sensible default attributes) in `dataset_fn`, if they need to use it across multiple strategies.

*This is also in preparation for multi-replica support in the future. See [tutorial](https://www.tensorflow.org/tutorials/distribute/parameter_server_training?hl=uk#dispatch_training_steps_to_remote_workers) for more information.

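To illustrate the proposed signature, a `dataset_fn` that consults `input_context` might be sketched as follows; the data is a placeholder, and the sharding and batch-size handling are optional conveniences for reuse under synchronous strategies rather than requirements of this design.

```
import tensorflow as tf

def dataset_fn(input_context):
  # Placeholder data; a real pipeline would typically read from files.
  features = tf.random.uniform((1024, 10))
  labels = tf.random.uniform((1024, 1))
  batch_size = input_context.get_per_replica_batch_size(global_batch_size=64)
  dataset = tf.data.Dataset.from_tensor_slices((features, labels))
  # Optional: shard using the (sensible default) attributes of `input_context`.
  dataset = dataset.shard(input_context.num_input_pipelines,
                          input_context.input_pipeline_id)
  return dataset.shuffle(1024).repeat().batch(batch_size)
```
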
@@ -371,7 +349,7 @@ Since the workers will sync every epoch anyway, fetching the remote values incur

##### Batch-level callbacks

-##### What constitutes a `step` in `Model.fit` with `ParameterServerStrategy`?
+###### What constitutes a `step` in `Model.fit` with `ParameterServerStrategy`?

There are two mental models users might have of what constitutes a single `step` when running `Model.fit` with `ParameterServerStrategy`. The mental models are clarified below, as each has implications for how users specify `steps_per_epoch` and how we handle batch-level `Callback`s.

@@ -394,6 +372,20 @@ With Mental Model 1, we cannot sync every batch. If we did so, only one worker w

For now, we will throw an error if a user provides a `Callback` that overrides `Callback.on_train_batch_begin` or `Callback.on_train_batch_end`, warning that batch-level Callbacks are not supported at this time. However, this design does not preclude supporting batch-level Callbacks in the future, as long as we give the user control of when to perform a sync. See the section Future Work on Batch-Level Callbacks below for a detailed discussion of this possibility.

+###### Built-in callbacks that have batch-level calls
+
+What about the existing callbacks Keras provides that have batch-level calls? There is one built-in callback, added by default, that involves batch-level calls:
+
+* `ProgbarLogger`: We'll make the default to log every epoch rather than every batch. If the user sets `verbose` such that it would log every batch, an error is raised.
+
+In addition, there are three built-in callbacks, not added by default, that have batch-level calls:
+
+1. `ModelCheckpoint`: The default use case (checkpointing every epoch) works as-is. For users who checkpoint every N examples (which involves batch-level calls), we will make it remote-aware, i.e., `ModelCheckpoint` knows that what it receives is a `RemoteValue` and that it needs to sync. With this, it is fine for it to be called at every batch while only syncing every N examples.
+
+2. `TensorBoard`: Batch-level calls do not need output from `train_function`, so it can be called anyway (by making it remote-aware as well).
+
+3. `TerminateOnNaN`: We should disable this in PS training.
+
##### Timing-Based Callbacks

Users who wish to create `Callback`s that execute on a timing interval rather than a step interval can do so by launching a thread in `Callback.on_train_begin`. An example is shown below:
@@ -421,6 +413,8 @@ class MyTimingCallback(tf.keras.callbacks.Callback):
self.model.save(self.save_dir)
```

+We plan to provide built-in timing-based callbacks for common functionality such as model checkpointing. The asynchronous nature of calls at intervals limits those usages to PS training only, for now. The detailed design of built-in timing-based callbacks will be discussed separately and is not covered in this proposal.
+
##### Future Work on Batch-level Callbacks

Although we will not support batch-level Callbacks with the current proposal, it is worth noting that this design does not preclude us from supporting some form of batch-level Callbacks in the future.
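
Only the tail of the `MyTimingCallback` example is visible in this hunk; a self-contained sketch in the same spirit (the body below is a reconstruction with illustrative names, not the RFC's original code) could look like:

```
import threading
import tensorflow as tf

class MyTimingCallback(tf.keras.callbacks.Callback):
  """Saves the model every `interval_secs` seconds from a background thread."""

  def __init__(self, save_dir, interval_secs=300):
    super().__init__()
    self.save_dir = save_dir
    self.interval_secs = interval_secs
    self._stop = threading.Event()

  def on_train_begin(self, logs=None):
    self._thread = threading.Thread(target=self._periodic_save, daemon=True)
    self._thread.start()

  def on_train_end(self, logs=None):
    self._stop.set()
    self._thread.join()

  def _periodic_save(self):
    # Variables live on the parameter servers, so saving from the coordinator
    # does not have to wait for in-flight steps on the workers.
    while not self._stop.wait(self.interval_secs):
      self.model.save(self.save_dir)
```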
@@ -467,7 +461,7 @@ In the longer term, we seek distributed support for `model.evaluate`, where the
1. Implement distributed `model.evaluate` without a visitation guarantee, but require the user's opt-in because of the behavior change (by `model.evaluate(..., distributed_eval=True)`)
2. Support distributed `model.evaluate` only after `ClusterCoordinator` provides a visitation guarantee mechanism

-Note that similar to the dataset factory change for `model.fit`, validation dataset will also need to be a function. That is, `model.fit` will take a `validation_data_fn` instead of a `validation_data`, and `model.evaluate` will take a `dataset_fn` as opposed to a `dataset` instance.
+Note that, similar to the dataset factory change for `model.fit`, the validation dataset will also need to be a dataset factory. That is, `model.fit` will take a `DatasetFactory` for the `validation_data` argument, and `model.evaluate` will take a `DatasetFactory` for `x`, as opposed to a `dataset` instance.

See the “Evaluation” section below for other proposed evaluation solutions accompanying `model.fit` usage.

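For illustration, evaluation inputs would then be passed the same way; `eval_dataset_fn` and the data below are hypothetical, `model` and `dataset_factory` are assumed from the earlier sketches, and `DatasetFactory` remains the symbol proposed in this RFC.

```
import tensorflow as tf

# Hypothetical evaluation counterpart of the training `dataset_fn`.
def eval_dataset_fn(input_context):
  eval_features = tf.random.uniform((256, 10))
  eval_labels = tf.random.uniform((256, 1))
  return tf.data.Dataset.from_tensor_slices(
      (eval_features, eval_labels)).batch(32)

eval_factory = tf.data.experimental.DatasetFactory(eval_dataset_fn)  # proposed API

# Proposed usage: a factory for `validation_data` in `fit`, and for `x` in `evaluate`.
model.fit(dataset_factory, epochs=5, steps_per_epoch=100,
          validation_data=eval_factory, validation_steps=20)
model.evaluate(eval_factory, steps=20)
```
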
@@ -778,6 +772,23 @@ history = model.fit(dataset, epochs=..., steps_per_epoch=..., callbacks=[...])
logging.info("result: %r", history)
```

+### Using a simple `callable` rather than `DatasetFactory`
+
+In the simplest case, we could allow any kind of `callable` to be passed in:
+
+
+```
+def dataset_fn(input_context):
+  return tf.data.Dataset.from_tensor_slices(...)
+history = model.fit(dataset_fn, epochs=..., steps_per_epoch=..., callbacks=[...])
+```
+
+Pros:
+* A `callable` does not require users to use additional APIs and may involve less overhead.
+
+Cons:
+* Less future-proof, as there could be different interpretations of a callable passed as `dataset` to `model.fit` in the future.
+

### Attach the `ClusterCoordinator`’s lifecycle to `model.fit`
