Commit 42714cb by Frank Chen: Add tf.data Snapshot Public RFC
1 parent ca8039f

rfcs/20200107-tf-data-snapshot.md (1 file changed, +368 lines)

# tf.data Snapshot

| Status        | Proposed                                                 |
| :------------ | :------------------------------------------------------- |
| **RFC #**     | [193](https://github.com/tensorflow/community/pull/193)   |
| **Author(s)** | Frank Chen ([email protected]), Rohan Jain                |
| **Sponsor**   | Jiri Simsa ([email protected])                            |
| **Updated**   | 2020-01-07                                                |

## Objective

With ever faster accelerators available in Cloud and hyperparameter tuning
consuming larger chunks of accelerator time, TensorFlow users are increasingly
finding that they don’t have enough CPU resources to keep up with these
accelerators, leaving valuable accelerator resources idle.

To alleviate this problem, we are proposing a `snapshot` API within `tf.data`,
to allow users to transparently persist the output of their preprocessing
pipeline to disk, and materialize the pre-processed data on a different training
run.

This API enables repeated preprocessing steps to be consolidated and
already-processed data to be re-used, trading off disk storage and network
bandwidth to free up more valuable CPU resources and accelerator compute time.

## Motivation

Large TensorFlow users have indicated that they have complicated input
processing pipelines which saturate their CPUs before saturating their
accelerators (TPUs in particular). Since they often experiment with
hyperparameter tuning or tweaks to existing models without changing their input
pipeline, they are asking for ways to avoid repeated preprocessing of the same
data by either saving a dataset or caching it to disk.

## User Benefit

Users will be able to transparently persist partially or fully processed data
from `tf.data` input pipelines to disk or Cloud storage systems, and materialize
the pre-processed data during subsequent runs of the same pipeline. This will
cut down on input pipeline processing overheads during second and subsequent
runs.

## Design Proposal

We propose that we add a new `snapshot` transformation to tf.data. To illustrate
the usage of the transformation, we can start with some sample code:

```python
dataset = Dataset.list_files("/raw/data/*").shard(num_workers, i)
dataset = dataset.parallel_interleave(TFRecordDataset)
dataset = dataset.map(my_preprocessing_fn)
dataset = dataset.apply(tf.data.snapshot("/saved/data", options...))
dataset = dataset.repeat()

model = ...
model.fit(dataset)
```

As we can see, the end user simply has to add this transformation in order to
use this functionality. In essence, the transformation is similar to the
existing `tf.data.Dataset.cache`, with the key difference being that, unlike
`cache`, `snapshot` is intended to be re-used across different executions of
the same input pipeline.

### Proposed API

We are proposing the following API for the snapshot transformation. A usage
sketch combining several of these options follows the parameter descriptions
below.

```python
def snapshot(path,
             compression=None,
             reader_path_prefix=None,
             writer_path_prefix=None,
             shard_size_bytes=None,
             pending_snapshot_expiry_seconds=None,
             num_reader_threads=None,
             reader_buffer_size=None,
             num_writer_threads=None,
             writer_buffer_size=None,
             shuffle_on_read=None,
             shuffle_seed=None,
             mode=None,
             snapshot_name=None):
  pass  # Implementation goes here.
```

1.  `path`: Required. A directory where we want to save our snapshots and/or
    read from a previously saved snapshot.

2.  `compression`: Optional. The type of compression to apply to the snapshot
    written to disk. This will support `GZIP`, `SNAPPY` or None. Defaults to
    None.

3.  `reader_path_prefix`: Optional. A prefix to add to the path when reading
    from snapshots. This is useful for filesystems where configuration is
    passed in through the path. Defaults to None.

4.  `writer_path_prefix`: Optional. A prefix to add to the path when writing to
    snapshots. This is useful for filesystems where configuration is passed in
    through the path. Defaults to None.

5.  `shard_size_bytes`: Optional. The maximum size of each data file to be
    written by the snapshot dataset op. Defaults to 10 GiB.

6.  `pending_snapshot_expiry_seconds`: Optional. How long to wait (in seconds)
    before the snapshot op considers a previously unfinished snapshot to be
    stale and starts writing a snapshot from scratch again. Defaults to 86400
    seconds (1 day).

7.  `num_reader_threads`: Optional. Number of threads used to read from the
    snapshot in parallel. Especially useful if compression is turned on, since
    decompression tends to be CPU-intensive. Defaults to 1. If greater than 1,
    this may introduce non-determinism, i.e. the order in which elements are
    read from the snapshot may differ from the order in which they were
    written.

8.  `reader_buffer_size`: Optional. Maximum number of elements we can prefetch
    when reading from the snapshot. Defaults to 1. Increasing this may improve
    performance but will increase memory consumption.

9.  `num_writer_threads`: Optional. Number of threads used to write the
    snapshot in parallel. We'll open `num_writer_threads` files and write to
    them in parallel. Especially useful if compression is turned on, since
    compression tends to be CPU-intensive. Defaults to 1. If greater than 1,
    this may introduce non-determinism, i.e. the order in which elements are
    read from the upstream iterator may differ from the order in which they
    are written.

10. `writer_buffer_size`: Optional. Maximum number of pipeline elements to
    buffer before writing them out using `num_writer_threads`.

11. `shuffle_on_read`: Optional. If True, the order in which examples are
    produced when reading from a snapshot will be random. Defaults to False.

12. `shuffle_seed`: Optional. If set, the random number generator used for
    shuffling (when `shuffle_on_read` is turned on) is seeded by the given
    seed. Otherwise, it is seeded by a random seed that differs for every run.

13. `mode`: Optional. The mode in which snapshot should operate. Valid options
    are `auto`, `read`, `write`, and `passthrough`. The default mode is `auto`,
    where the snapshot op will automatically determine what mode to operate in.

    1.  `write` mode forces the snapshot transformation to write a new
        materialization to disk, regardless of whether a complete and valid
        materialization currently exists. In other words, we enter the
        **WRITE** state immediately.

    2.  `read` mode forces the snapshot transformation to read from the latest
        version of the materialization on disk, regardless of whether the data
        stored on disk is complete and valid. In other words, we enter the
        **READ** state immediately.

    3.  `passthrough` mode turns the snapshot transformation into a no-op. In
        other words, we enter the **PASSTHROUGH** state immediately.

    4.  `auto` retains the default behavior of snapshot. See the "Standard
        Kernel Workflow" section for the default behavior.

14. `snapshot_name`: Optional. If set, use the supplied string as the snapshot
    name instead of introspecting the data pipeline and automatically
    generating a unique identifier for the specific data pipeline.

    1.  Instead of generating a new fingerprint of the input processing graph
        and a `run_id` (see the _Detailed Design_ section for details), we will
        use the `snapshot_name` to uniquely identify the snapshot.

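To make these options concrete, here is a hypothetical invocation that combines
several of them. This is a sketch against the proposed signature above, not a
finalized API, and the specific option values are illustrative only.

```python
# Hypothetical usage of the proposed snapshot transformation, combining
# several of the options described above. Values are illustrative only.
dataset = dataset.apply(
    tf.data.snapshot(
        "/saved/data",
        compression="GZIP",               # compress chunks written to disk
        shard_size_bytes=1 << 30,         # cap each data file at 1 GiB
        num_writer_threads=8,             # parallel writes (may reorder elements)
        num_reader_threads=8,             # parallel reads (may reorder elements)
        shuffle_on_read=True,             # randomize element order when re-reading
        shuffle_seed=42,                  # make that shuffle reproducible
        mode="auto",                      # let the op pick read/write/passthrough
        pending_snapshot_expiry_seconds=3600))
```
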
### External API Guarantees

Externally, we guarantee that snapshots written by a particular version of
TensorFlow will be readable by that specific version of TensorFlow. Eventually,
we can also guarantee that snapshots written will be readable by all future
versions of TensorFlow.

We are not currently handling the case where workers do not go through the
entire training set at least once.

### Alternatives Considered

An alternative proposal for the API would be `save()` and `load()`, where the
saving and loading of the input pipeline would be made more explicit, avoiding
some of the logic needed to determine whether to write a snapshot or read from
an existing one.

The downside here is that the user would have to split the preprocessing and
training into potentially different files, and users would be forced to select
whether to train or preprocess on their own, which adds friction to the
workflow.

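For illustration, the rejected alternative might have looked roughly like the
following. The `save` and `load` names here are purely hypothetical and are not
part of this proposal.

```python
# Purely hypothetical sketch of the rejected save()/load() alternative;
# these function names are illustrative and not part of this proposal.

# Preprocessing job: run the pipeline once and persist its output.
save(preprocessed_dataset, "/saved/data")

# Training job (possibly a separate program): load the persisted data.
dataset = load("/saved/data")
model.fit(dataset.repeat())
```
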
### Performance Implications

*   Do you expect any (speed / memory)? How will you confirm?
*   There should be microbenchmarks. Are there?
*   There should be end-to-end tests and benchmarks. If there are not (since
    this is still a design), how will you track that these will be created?

### Dependencies

No new dependencies will be introduced to TensorFlow as part of this project.
Dependent projects may be able to use this additional op, but there should be
no significant changes otherwise.

### Engineering Impact

Binary size increases slightly with the inclusion of this new op, and this code
will be maintained by the `tf.data` team.

### Platforms and Environments

This op will work on all TensorFlow-supported platforms. We do not anticipate
supporting this on embedded systems, as it is not useful in
resource-constrained environments.

### Best Practices, Tutorials and Examples

A user guide for snapshot will be published to guide new users in using this
feature.

### Compatibility

This introduces a new op, which will impact future backwards compatibility.

### User Impact

A new Python function and a new op are the only user-facing changes.

## Detailed Design

### Implementation Assumptions

The implementation is based on the following assumptions, which define the MVP
this design targets:

1.  We assume that, for at least one pipeline run, it is possible to go through
    the entire training dataset and store that data on disk. Otherwise, a
    snapshot will never get created.

2.  In case there are multiple workers and the dataset is sharded across
    workers, we assume that the number of workers remains the same from one run
    to another. If the number changes, we’ll trigger another snapshot.

3.  Any `repeat`s in the dataset should be moved to after the `snapshot` op, to
    avoid writing large (or infinite) amounts of data during a snapshot writing
    run, as illustrated in the sketch after this list.

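The following sketch against the proposed API illustrates assumption 3: the
snapshot captures a single pass over the data, and the `repeat` is applied to
the materialized output rather than to the upstream pipeline.

```python
# Sketch against the proposed API: snapshot a single pass over the data,
# then repeat the materialized snapshot for as many epochs as needed.
dataset = dataset.apply(tf.data.snapshot("/saved/data"))
dataset = dataset.repeat()

# Discouraged: repeat() before snapshot() would attempt to materialize a
# multi-epoch (or unbounded) stream to disk during the writing run.
# dataset = dataset.repeat().apply(tf.data.snapshot("/saved/data"))
```
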
### New `SnapshotDatasetOp`

To implement the transformation, we are introducing a new `SnapshotDatasetOp`
dataset kernel that will implement all of the functionality in TensorFlow C++.
The Python code is mostly glue code that passes the relevant parameters into
the op kernel.

### Internal Directory / File Structure

Given a user directory path (e.g. `/path/to/snapshot`), the directory will look
like:

*   /path/to/snapshot
    *   `fingerprint`/
        *   snapshot.metadata
        *   `run-id`/
            *   0000000.snapshot
            *   0000001.snapshot

The `fingerprint` is a hash of the input processing graph. The `run-id` is a
unique ID generated for each training run.

### Standard Kernel Workflow

_Note: This is an implementation detail, and may change in the future. This
should not be relied upon except as a reference to the current implementation._

By default, the `snapshot` operation will, upon startup, make a determination
using the following algorithm as to whether the operation should be in the
WRITE, PASSTHROUGH, or READ state. A pseudocode sketch of this decision logic
follows the list.

1.  We will compute a graph fingerprint containing all the information from the
    Dataset preprocessing graph before the `snapshot` op. We’ll use the
    `AsGraphDefInternal` method on DatasetBase for this.

1.  We will attempt to enter the corresponding fingerprint directory. For
    instance, if the computed fingerprint is `f-abc123` and the base snapshot
    directory is `/saved/data`, then we will attempt to enter
    `/saved/data/f-abc123`.

1.  If the snapshot directory is non-existent, empty, or it doesn’t contain a
    `metadata` file, we will enter the **WRITE** state.

1.  If the snapshot directory contains a `metadata` file, we will read the
    metadata file.

    1.  The metadata file contains the following fields:
        1.  A training run ID
        1.  A boolean indicating if the snapshot is complete
        1.  A training run start-time.

1.  If the time elapsed since the training run start-time exceeds the
    (configurable) training run timeout (set with the
    `pending_snapshot_expiry_seconds` parameter), we will enter the **WRITE**
    state.

1.  If the elapsed time is within the training run timeout, but the snapshot is
    not complete, then we will enter the **PASSTHROUGH** state.

1.  If the snapshot is complete, we will enter the **READ** state.

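The decision procedure above can be summarized with the following Python-style
pseudocode. This is only a sketch of the algorithm as described, not the actual
C++ kernel; `compute_graph_fingerprint`, `read_metadata`, and the metadata
fields are hypothetical stand-ins for the implementation's internals.

```python
import os
import time

def determine_state(snapshot_path, dataset_graph,
                    pending_snapshot_expiry_seconds=86400):
  """Sketch of the WRITE / PASSTHROUGH / READ decision described above."""
  # Hypothetical helper: hash of the serialized preprocessing graph.
  fingerprint = compute_graph_fingerprint(dataset_graph)
  fingerprint_dir = os.path.join(snapshot_path, fingerprint)
  metadata_file = os.path.join(fingerprint_dir, "snapshot.metadata")

  # Missing, empty, or metadata-less snapshot directory: start writing.
  if not os.path.exists(metadata_file):
    return "WRITE"

  # Hypothetical helper: returns run_id, complete, start_time fields.
  metadata = read_metadata(metadata_file)
  expired = time.time() - metadata.start_time > pending_snapshot_expiry_seconds
  if not metadata.complete and expired:
    # A previous writer appears to have stalled or died; start over.
    return "WRITE"
  if not metadata.complete:
    # Someone else is still writing; don't read partial data or race them.
    return "PASSTHROUGH"
  return "READ"
```
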
#### WRITE State

1.  We generate a random training run ID.

1.  We write (possibly overwriting) the `snapshot.metadata` file.

1.  We proceed to create a subdirectory containing the training run ID, and
    start writing data asynchronously in chunks.

1.  At the end of the dataset (when `end_of_sequence == true`), we will check
    the snapshot.metadata file to determine whether it contains the same
    training run ID (see the sketch at the end of this section).

    1.  If it does, we set the complete bit to true to finalize the directory.
    1.  If it does not, it means that someone else is concurrently writing the
        snapshot and we lost the race to them. We delete all data in the
        training run directory.

For the current implementation, we will store the data in chunked TFRecord
files. Eventually we may move to other, higher-performance data stores or
support additional storage systems such as Cloud BigTable.

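A sketch of the end-of-write finalization check described in step 4 above.
`read_metadata`, `write_metadata`, and `delete_recursively` are hypothetical
helpers standing in for the kernel's filesystem calls.

```python
import os

def finalize_write(fingerprint_dir, my_run_id):
  """At end_of_sequence, either finalize our run or concede the race (sketch)."""
  metadata_file = os.path.join(fingerprint_dir, "snapshot.metadata")
  metadata = read_metadata(metadata_file)        # hypothetical helper
  if metadata.run_id == my_run_id:
    # We are still the writer recorded in the metadata: mark the snapshot done.
    metadata.complete = True
    write_metadata(metadata_file, metadata)      # hypothetical helper
  else:
    # Another worker overwrote the metadata; we lost the race, so clean up
    # everything we wrote under our own run directory.
    delete_recursively(os.path.join(fingerprint_dir, my_run_id))  # hypothetical
```
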
#### PASSTHROUGH State

1.  This is a no-op, where we simply pass through the tensors to the downstream
    operations.

#### READ State

1.  We will read from the snapshots contained within the subfolder with the
    correct graph fingerprint and specified training run ID.

1.  Optionally, the user may specify that the snapshots should be read back in
    shuffled order.

### Concurrency: Handling Multiple Input Workers

If input workers are sharded, then they will generate different graph
fingerprints as their shard indexes will be different. This will result in each
worker writing to a different subdirectory.

If input workers are not sharded, then this will result in a race and
potentially multiple workers writing data (still with different training run
IDs). Eventually, if each worker finishes, we will be left with one copy of the
data as all the other workers will determine that they have lost the race and
delete their own copy of the snapshot data.

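As a sketch of the sharded case (mirroring the sample pipeline from the Design
Proposal, with hypothetical `num_workers`, `worker_index`, and
`my_preprocessing_fn`), each worker bakes its own shard index into the dataset
graph, so the computed fingerprints, and therefore the snapshot subdirectories,
differ per worker:

```python
# Per-worker pipeline, mirroring the Design Proposal sample. worker_index
# differs on every worker, so the serialized dataset graph differs, the
# computed fingerprint differs, and each worker writes to its own
# fingerprint subdirectory under /saved/data.
dataset = tf.data.Dataset.list_files("/raw/data/*").shard(num_workers, worker_index)
dataset = dataset.map(my_preprocessing_fn)
dataset = dataset.apply(tf.data.snapshot("/saved/data"))
dataset = dataset.repeat()
```
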
## Questions and Discussion Topics

*   Should we implement this as three ops: a control op (to determine whether a
    snapshot is to be read from or written to) and separate write and read ops
    to do the respective operations?
    *   Pros include:
        *   Modularizes the implementation into smaller chunks
        *   Allows someone else to do the "control"
    *   Challenges include:
        *   Where/how does the "control" run?
        *   How do we construct the dataset graph properly?
*   How should autotuning be integrated into the snapshot transformation?
*   Are the configuration options well named? Is it possible to consolidate
    some of these options?
*   What other compression/decompression options would you like to see
    supported?
*   Any other performance / feature tuning knobs we should make available?

0 commit comments

Comments
 (0)