# tf.data Snapshot

| Status        | Accepted                                                 |
| :------------ | :------------------------------------------------------- |
| **RFC #**     | [193](https://github.com/tensorflow/community/pull/193)  |
| **Author(s)** | Frank Chen ([email protected]), Rohan Jain               |
| **Sponsor**   | Jiri Simsa ([email protected])                            |
| **Updated**   | 2020-02-10                                               |

## Objective

With ever faster accelerators available in Cloud and hyperparameter tuning
consuming larger chunks of accelerator time, TensorFlow users are increasingly
finding that they don’t have enough CPU resources to keep up with these
accelerators, leaving valuable accelerator resources idle.

To alleviate this problem, we are proposing a `snapshot` API within `tf.data`,
to allow users to transparently persist the output of their preprocessing
pipeline to disk, and materialize the pre-processed data on a different training
run.

This API enables repeated preprocessing steps to be consolidated and allows
re-use of already processed data, trading off disk storage and network bandwidth
for more valuable CPU resources and accelerator compute time.

## Motivation

Large TensorFlow users have indicated that they have complicated input
processing pipelines which saturate their CPUs before saturating their
accelerators (TPUs in particular). Since they often experiment with
hyperparameter tuning or tweaks to existing models without affecting their input
pipeline, they are asking for ways to avoid repeating similar preprocessing of
data by either saving a dataset or caching it to disk.

## User Benefit

Users will be able to transparently persist partially or fully processed data
from `tf.data` input pipelines to disk or Cloud storage systems, and materialize
the pre-processed data during subsequent runs from the same pipeline. This will
cut down on input pipeline processing overhead during the second and subsequent
runs.

## Design Proposal

We propose that we add a new `snapshot` transformation to tf.data. To illustrate
the usage of the transformation, we can start with some sample code:

```python
dataset = Dataset.list_files("/raw/data/*").shard(num_workers, i)
dataset = dataset.parallel_interleave(TFRecordDataset)
dataset = dataset.map(my_preprocessing_fn)
dataset = dataset.apply(tf.data.snapshot("/saved/data", options...))
dataset = dataset.repeat()

model = ...
model.fit(dataset)
```

As we can see, the end user simply has to add this transformation in order to
use this functionality. In essence, the transformation is similar to the
existing `tf.data.Dataset.cache`, with the key difference being that, unlike
`cache`, `snapshot` is intended to be re-used across different executions of the
same input pipeline.

### Proposed API

We are proposing the following API for the snapshot transformation.

```python
def snapshot(path,
             compression=None,
             reader_fn=None,
             writer_fn=None,
             pending_snapshot_expiry_seconds=None):
  pass  # Implementation goes here.
```

1. `path`: Required. A directory where we want to save our snapshots and/or
   read from a previously saved snapshot.

1. `compression`: Optional. The type of compression to apply to the snapshot
   written to disk. This will support `GZIP`, `SNAPPY` or None. Defaults to
   AUTO.

1. `reader_fn`: Optional. The input pipeline transformation specified by
   `reader_fn` is executed when the snapshot detects that there is an existing,
   valid snapshot available.

   `reader_fn` is a user-specified function that accepts a single argument:
   (1) a Dataset of Datasets, each representing a "split" of elements of the
   original dataset. The cardinality of the input dataset matches the
   cardinality of the output of `writer_fn` (see below). The function should
   return a Dataset of elements of the original dataset.

   A default `reader_fn` will look like the following:

   ```python
   def default_reader_fn(datasets):
     # Shuffle the dataset splits.
     datasets = datasets.shuffle(NUM_DATASETS)
     # Read the splits in parallel and interleave their elements.
     return datasets.interleave(lambda x: x, num_parallel_calls=AUTOTUNE)
   ```

1. `writer_fn`: Optional. The input pipeline specified by `writer_fn` is
   executed when the snapshot op detects that there are no valid snapshots
   and no other threads are currently attempting to write a snapshot.

   `writer_fn` is a user-specified function that accepts a single argument:
   (1) a Dataset of elements to be written out. The function should return
   a Dataset of Datasets, each representing a "split" of elements of the
   original dataset. The tf.data snapshot implementation will then persist
   the splits in parallel.

   A default `writer_fn` will look like the following:

   ```python
   def default_writer_fn(dataset):
     # Add a component with the element index.
     dataset = dataset.enumerate()
     # Split the input dataset in a round-robin fashion.
     return dataset.split(num_splits=NUM_CORES,
                          key_fn=lambda i, _: i % NUM_CORES)
   ```

1. `pending_snapshot_expiry_seconds`: Optional. How long to wait (in seconds)
   before the snapshot op considers a previously unfinished snapshot to be
   stale and starts writing a snapshot from scratch again. Defaults to 86400
   seconds (1 day).

#### Achieving Parallelism

By default, `reader_fn` and `writer_fn` pass the dataset through unchanged. In
other words, the default implementation results in single-threaded reads and
writes on snapshots. Parallelism can be achieved in `writer_fn` by splitting up
the dataset into multiple datasets, and in `reader_fn` by using
`num_parallel_calls` in the `interleave` function.
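
For illustration, the following is a rough sketch of what such user-supplied
functions could look like under this proposal. `Dataset.split` and
`tf.data.snapshot` are the APIs proposed in this document, and
`NUM_WRITE_SHARDS` / `READ_PARALLELISM` are hypothetical constants a user might
choose:

```python
# A sketch, not a working implementation: `tf.data.snapshot` and
# `Dataset.split` are the APIs proposed in this RFC, and NUM_WRITE_SHARDS /
# READ_PARALLELISM are hypothetical constants chosen by the user.
NUM_WRITE_SHARDS = 16    # number of splits persisted in parallel
READ_PARALLELISM = 16    # number of splits read back in parallel

def parallel_writer_fn(dataset):
  # Tag each element with its index so it can be assigned to a split.
  dataset = dataset.enumerate()
  # Distribute elements round-robin into NUM_WRITE_SHARDS splits; each split
  # can then be written out by a separate writer.
  return dataset.split(num_splits=NUM_WRITE_SHARDS,
                       key_fn=lambda i, _: i % NUM_WRITE_SHARDS)

def parallel_reader_fn(datasets):
  # Visit the split files in a shuffled order, then read several splits at
  # once and interleave their elements.
  datasets = datasets.shuffle(NUM_WRITE_SHARDS)
  return datasets.interleave(lambda x: x,
                             num_parallel_calls=READ_PARALLELISM)

dataset = dataset.apply(tf.data.snapshot("/saved/data",
                                         reader_fn=parallel_reader_fn,
                                         writer_fn=parallel_writer_fn))
```
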
#### Computing Graph Fingerprints

Snapshot attempts to determine whether a run of an input pipeline is the same
as a previous run by computing the fingerprint of the nodes within the pipeline.

However, some input pipelines might vary in insignificant ways from run to run
that cause their fingerprints to differ. For instance, consider the following
preprocessing function:

```python
features_to_multiply = {"feature1", "feature2", "feature3", "feature4"}

def preprocessing_fn(value):
  keys_to_features = {
      "feature1": tf.FixedLenFeature([], tf.float32, 0.0),
      "feature2": tf.FixedLenFeature([], tf.float32, 0.0),
      "feature3": tf.FixedLenFeature([], tf.float32, 0.0),
      "feature4": tf.FixedLenFeature([], tf.float32, 0.0)
  }

  parsed = tf.parse_single_example(value, keys_to_features)
  combined_feature = 1.0
  for item in features_to_multiply:
    combined_feature *= parsed[item]

  return combined_feature

dataset = ...
dataset = dataset.map(preprocessing_fn)
```

In the above example, our `features_to_multiply` variable uses a `set`, which is
not guaranteed to be ordered in Python. When we iterate over the set in the
for loop within `preprocessing_fn`, we may get a different graph on each
run (i.e. one run could have us multiplying `feature2` first, then `feature4`,
etc..., while another run may have us multiplying `feature1`, then `feature3`,
and so on).

In cases like these, we can ask fingerprinting to use a fixed value for the
fingerprint of the map function with a new `with_snapshot_fingerprint`
transformation, which asks the fingerprinting function to not compute the
fingerprint of the previous node but to use a user-specified value instead:

```python
dataset = ...
dataset = dataset.map(preprocessing_fn)
dataset = tf.data.experimental.with_snapshot_fingerprint(
    dataset, fingerprint="my_fixed_fp")
```

### External API Guarantees

Externally, we guarantee that snapshots written by a particular version of
TensorFlow will be readable by that specific version of TensorFlow.

We are not currently handling the case where workers do not go through the
entire training set at least once.

### Alternatives Considered

An alternative proposal for an API would be `save()` and `load()`, where the
saving and loading of the input pipeline would be made more explicit, avoiding
some of the logic needed in determining whether to snapshot or read from a
snapshot.

The downside here would be that the user would have to split the preprocessing
and training into potentially different files, and users would be forced to
select whether to train or preprocess on their own, which is undesirable.

### Performance Implications

Benchmarks for this feature will be included as part of Dataset microbenchmarks.

### Dependencies

No new dependencies will be introduced to TensorFlow as part of this project.
Dependent projects may be able to use this additional op, but there should be no
significant changes otherwise.

### Engineering Impact

Binary size increases slightly with the inclusion of this new op, and this code
will be maintained by the `tf.data` team.

### Platforms and Environments

This op will work on all TensorFlow-supported platforms. We do not anticipate
it being used on embedded systems, as it is not useful in resource-constrained
environments.

### Best Practices, Tutorials and Examples

A user guide for snapshot will be published to guide new users in using this
feature.

### Compatibility

This introduces a new op, which will impact future backwards compatibility.

### User Impact

A new Python function and a new op are the only user-facing changes.

## Detailed Design

### Implementation Assumptions

The implementation is based on the following assumptions, which define the MVP
this design targets:

1. We assume that at least for one pipeline run, you can go through the entire
   training dataset and be able to store that data on disk. Otherwise, a
   snapshot will never get created.

2. In the cases where there are multiple workers and the dataset is sharded with
   `Dataset.shard`, we assume that the number of workers remains the same from
   the initial (writing) run through to the reading runs.

   If the number of workers changes, then the `num_shards` parameter to
   `Dataset.shard` will change, and this will result in a different graph
   fingerprint and another snapshot write will be triggered.

   If all workers use the exact same input pipeline with no sharding (e.g. all
   workers will read from all the files), then snapshot will still be able to
   read from previous snapshots even if the number of workers is different.

3. Any `repeat`s in the dataset should be moved to after the `snapshot` op, to
   avoid writing large (or infinite) amounts of data during a snapshot writing
   run (see the sketch after this list).
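
As a minimal illustration of this ordering (a sketch assuming the proposed
`tf.data.snapshot` API; `my_preprocessing_fn` is a placeholder for the user's
map function):

```python
# Sketch only: `tf.data.snapshot` is the API proposed in this RFC, and
# `my_preprocessing_fn` is a placeholder for the user's map function.
dataset = dataset.map(my_preprocessing_fn)
dataset = dataset.apply(tf.data.snapshot("/saved/data"))
dataset = dataset.repeat()   # repeat *after* snapshot: only one epoch is written

# By contrast, placing `dataset.repeat()` before `snapshot` would ask the
# writer to persist a repeated (potentially unbounded) stream of elements.
```
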
### New `SnapshotDatasetOp`

To implement the transformation, we are introducing a new `SnapshotDatasetOp`
dataset kernel that will implement all of the functionality in TensorFlow C++.
Python code is mostly glue code to pass relevant parameters into the op kernel.

### Internal Directory / File Structure

Given a user directory path (e.g. `/path/to/snapshot`), the directory will look
like:

* /path/to/snapshot
  * `fingerprint`/
    * snapshot.metadata
    * `run-id`/
      * 0000000.snapshot
      * 0000001.snapshot

The `fingerprint` is a hash of the input processing graph. The `run-id` is a
unique training run ID generated for each run.

### Standard Kernel Workflow

_Note: This is an implementation detail, and may change in the future. This
should not be relied upon except as a reference to the current implementation._

By default, the `snapshot` operation will, upon startup, make a determination
using the following algorithm as to whether the operation should be in the
WRITE, PASSTHROUGH, or READ state (a rough sketch of this decision logic
follows the list below).

1. We will compute a graph fingerprint containing all the information from the
   Dataset preprocessing graph before the `snapshot` op. We’ll use the
   `AsGraphDefInternal` method on DatasetBase for this.

1. We will attempt to enter the corresponding fingerprint directory. For
   instance, if the computed fingerprint is `f-abc123` and the base snapshot
   directory is `/saved/data`, then we will attempt to enter
   `/saved/data/f-abc123`.

1. If the snapshot directory is non-existent, empty or it doesn’t contain a
   `metadata` file, we will enter the **WRITE** state.

1. If the snapshot directory contains a `metadata.final` file, we will read
   the final metadata file and proceed to the **READ** state.

   1. The file contains the following fields:
      1. A training run ID,
      1. A boolean indicating if the snapshot is complete,
      1. A training run start-time.

1. If the snapshot directory contains a `metadata` file but not a
   `metadata.final` file, we will read the metadata file.

1. If the time elapsed since the training run start-time is more than the
   (configurable) training run timeout (set with the
   `pending_snapshot_expiry_seconds` parameter), we will enter the **WRITE**
   state.

1. If the time elapsed since the training run start-time is less than the
   training run timeout, but the snapshot is not complete, then we will enter
   the **PASSTHROUGH** state.

1. If the snapshot is complete, we will enter the **READ** state.
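
The decision logic above can be summarized with the following rough pseudocode
sketch. It is illustrative only: helpers such as `compute_graph_fingerprint()`
and `read_metadata()` are hypothetical stand-ins, not actual TensorFlow
functions.

```python
import os
import time

# Rough pseudocode for the state decision above; compute_graph_fingerprint()
# and read_metadata() are hypothetical helpers, not real TensorFlow APIs.
def determine_snapshot_state(base_path, graph_def, expiry_seconds):
  fingerprint_dir = os.path.join(base_path, compute_graph_fingerprint(graph_def))

  # No directory, empty directory, or no `metadata` file yet: start writing.
  if not os.path.isfile(os.path.join(fingerprint_dir, "metadata")):
    return "WRITE"

  # A `metadata.final` file means a snapshot was written to completion.
  if os.path.isfile(os.path.join(fingerprint_dir, "metadata.final")):
    return "READ"

  # Otherwise a snapshot is pending; decide based on its age and completeness.
  metadata = read_metadata(os.path.join(fingerprint_dir, "metadata"))
  if time.time() - metadata.start_time > expiry_seconds:
    return "WRITE"        # the pending snapshot looks stale; start over
  if not metadata.complete:
    return "PASSTHROUGH"  # another run is still writing; don't block on it
  return "READ"
```
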
#### WRITE State

1. We generate a random training run ID.

1. We write (possibly overwriting) the `snapshot.metadata` file.

1. We proceed to create a subdirectory containing the training run ID, and
   start writing data asynchronously in chunks.

1. At the end of the dataset (when `end_of_sequence == true`), we will check
   the `snapshot.metadata` file to determine whether it contains the same
   training run ID (see the sketch after this list).

   1. If it does, we write a `metadata.final` file containing the
      same information as the `metadata` file but with the complete
      bit set to true.
   1. If it does not, it means that someone else is concurrently writing the
      snapshot and we lost the race to them. We delete all data in the
      training run directory.
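
To illustrate the race check in step 4, here is a rough pseudocode sketch;
`read_metadata()`, `write_metadata_final()` and `delete_recursively()` are
hypothetical names used only for illustration.

```python
import os

# Hypothetical sketch of the end-of-dataset finalization step described above.
def finalize_snapshot(fingerprint_dir, my_run_id):
  metadata = read_metadata(os.path.join(fingerprint_dir, "metadata"))
  if metadata.run_id == my_run_id:
    # We still own the pending snapshot: write `metadata.final` with the
    # complete bit set to true.
    write_metadata_final(fingerprint_dir, metadata, complete=True)
  else:
    # Another worker overwrote `snapshot.metadata` and won the race; discard
    # our partially written training run directory.
    delete_recursively(os.path.join(fingerprint_dir, my_run_id))
```
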
For the current implementation, we will store the data in chunked TFRecord
files. Eventually we may move to other, higher-performance data stores or
support additional storage systems such as Cloud BigTable.

#### PASSTHROUGH State

1. This is a no-op, where we simply pass through the tensors to the downstream
   operations.

#### READ State

1. We will read from the snapshots contained within the subfolder with the
   correct graph fingerprint and specified training run ID.

1. Optionally, the user may specify that the snapshots should be read back in
   shuffled order.

### Concurrency: Handling Multiple Input Workers

If input workers are sharded, then they will generate different graph
fingerprints as their shard indexes will be different. This will result in each
worker writing to a different subdirectory.

If input workers are not sharded, then this will result in a race and
potentially multiple workers writing data (still with different training run
IDs). Eventually, if each worker finishes, we will be left with one copy of the
data as all the other workers will determine that they have lost the race and
delete their own copy of the snapshot data.

## Questions and Discussion Topics

* Should we implement this as three ops (a control op to determine whether a
  snapshot is to be read from or written to, plus a write op and a read op to
  perform the respective operations)?
  * Pros include:
    * Modularizes the implementation into smaller chunks
    * Allows someone else to do the "control"
  * Challenges include:
    * Where/how does the "control" run?
    * How do we construct the dataset graph properly?
* How should autotuning be integrated into the snapshot transformation?
* Are the configuration options well named? Is it possible to consolidate some
  of these options?
* What other compression/decompression options would you like to see
  supported?
* Any other performance / feature tuning knobs we should make available?
