Skip to content
This repository was archived by the owner on Jul 10, 2025. It is now read-only.

Commit 77a4425

Browse files
committed
User flow with a dataset instance moved to alternative considered
1 parent 30b69f1 commit 77a4425

File tree

1 file changed

+17
-17
lines changed

1 file changed

+17
-17
lines changed

rfcs/20201121-keras-model-fit-ps.md

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -87,23 +87,6 @@ logging.info("result: %r", history)
8787
```
8888

8989

90-
with a dataset instance:
91-
92-
93-
```
94-
cluster_resolver = ...
95-
strategy = tf.distribute.experimental.ParameterServerStrategy(cluster_resolver)
96-
with strategy.scope():
97-
preproc_stage = ... # Some Keras preproc layers
98-
model = ... # Building a Keras model
99-
model.compile(optimizer=..., loss=...) # `ClusterCoordinator` is created
100-
dataset = tf.data.Dataset.X... # Make use of `preproc_stage` for transformation
101-
102-
# `model.fit` serializes and deserializes dataset onto workers
103-
history = model.fit(dataset, epochs=..., steps_per_epoch=..., callbacks=[...])
104-
logging.info("result: %r", history)
105-
```
106-
10790
#### Notable differences of user code between PS and other strategies
10891

10992
There are a couple of points worth noting in the above user code:
@@ -715,3 +698,20 @@ Asynchronous `Callback`s might be worth exploring in a future extension to the f
715698
### Support of `dataset` in `ClusterCoordinator`
716699

717700
Previously, we have considered the possibility to support `dataset` instance in `model.fit` to keep the existing API contract. In this case, it should be preferred that `ClusterCoordinator` provides native `dataset` support, which `model.fit` can readily use, rather than `model.fit` implementing replication logic to accommodate that. Similar to `experimental_distribute_dataset` API, `ClusterCoordinator` can use `tf.data`’s `replicate` API to serialize the dataset graph, and unserialize onto workers.
701+
702+
User flow with a dataset instance:
703+
704+
705+
```
706+
cluster_resolver = ...
707+
strategy = tf.distribute.experimental.ParameterServerStrategy(cluster_resolver)
708+
with strategy.scope():
709+
preproc_stage = ... # Some Keras preproc layers
710+
model = ... # Building a Keras model
711+
model.compile(optimizer=..., loss=...) # `ClusterCoordinator` is created
712+
dataset = tf.data.Dataset.X... # Make use of `preproc_stage` for transformation
713+
714+
# `model.fit` serializes and deserializes dataset onto workers
715+
history = model.fit(dataset, epochs=..., steps_per_epoch=..., callbacks=[...])
716+
logging.info("result: %r", history)
717+
```

0 commit comments

Comments
 (0)