This repository was archived by the owner on Jul 10, 2025. It is now read-only.

Commit 0b28765

Author: Frank Chen (committed)
Revision to the API after TensorFlow Design Review
1 parent 149a6f9 commit 0b28765


rfcs/20200107-tf-data-snapshot.md

Lines changed: 84 additions & 85 deletions
@@ -70,12 +70,9 @@ We are proposing the following API for the snapshot transformation.
 ```python
 def snapshot(path,
              compression=None,
-             shard_size_bytes=None,
-             pending_snapshot_expiry_seconds=None,
-             num_writer_threads=None,
              reader_fn=None,
-             mode=None,
-             snapshot_name=None):
+             writer_fn=None,
+             pending_snapshot_expiry_seconds=None):
   pass  # Implementation goes here.
 ```
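
For illustration only (an editorial sketch, not part of this commit): one way the revised signature might be invoked, assuming the transformation is exposed as `tf.data.experimental.snapshot` and applied with `Dataset.apply`; the path and option values below are hypothetical.

```python
import tensorflow as tf

# Hypothetical input pipeline whose expensive preprocessing we want to persist.
dataset = tf.data.Dataset.range(1000)
dataset = dataset.map(lambda x: x * 2)

# Write a snapshot on the first run, read it back on later runs.
# reader_fn and writer_fn are omitted here, so the defaults described below
# (single-threaded reads and writes) apply.
dataset = dataset.apply(
    tf.data.experimental.snapshot(
        "/tmp/snapshots/my_pipeline",            # snapshot directory (hypothetical)
        compression="GZIP",
        pending_snapshot_expiry_seconds=86400))
```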

@@ -86,105 +83,107 @@ def snapshot(path,
    written to disk. This will support `GZIP`, `SNAPPY` or None. Defaults to
    AUTO.

-1. `shard_size_bytes`: Optional. The maximum size of each data file to be
-   written by the snapshot dataset op. Defaults to AUTO.
+1. `reader_fn`: Optional. The input pipeline transformation specified by
+   `reader_fn` is executed when the snapshot detects that there is an existing,
+   valid snapshot available.

-1. `pending_snapshot_expiry_seconds`: Optional. How long to wait (in seconds)
-   before the snapshot op considers a previously unfinished snapshot to be
-   stale and starts writing a snapshot from scratch again. Defaults to 86400
-   seconds (1 day).
-
-1. `num_writer_threads`: Optional. Number of threads to parallelize writing
-   from snapshot. We'll open up `num_writer_threads` files and write to them in
-   parallel. Especially useful if compression is turned on since the
-   compression operation tends to be intensive. If > 1, then
-   this might introduce non-determinism i.e. the order in which the elements
-   are read from the upstream iterator are different from the order they're
-   written. Defaults to AUTO.
-
-1. `reader_fn`: Optional. A user provided reader function to use when reading
-   the snapshot back. This allows the user to specify the concurrency and
-   randomization required when reading from the snapshot.
+   `reader_fn` is a user-specified function that accepts a single argument:
+   (1) a Dataset of Datasets, each representing a "split" of elements of the
+   original dataset. The cardinality of the input dataset matches the
+   cardinality of the output of `writer_fn` (see below). The function should
+   return a Dataset of elements of the original dataset.

-   `reader_fn` should be a function that accepts two arguments: (1) a list of
-   snapshot file paths, and (2) a reference to a `SnapshotDataset` class.
-   The function should return a `Dataset` class.
+   A default `reader_fn` will look like the following:

-   The `SnapshotReaderDataset` class is a `Dataset` (similar to other source datasets
-   like `TFRecordDataset` or `TextLineDataset`) with the following constructor:
    ```python
-   class SnapshotDataset(dataset_ops.DatasetSource):
-     def __init__(filenames):
-       """Creates a `SnapshotDataset`.
-
-       Args:
-         filenames: A `tf.string` tensor or a `tf.data.Dataset` containing one or
-           more filenames.
-       """
-       pass
+   def default_reader_fn(datasets):
+     # Shuffle the dataset splits.
+     datasets = datasets.shuffle(NUM_DATASETS)
+     # Read the splits in parallel and interleave their elements.
+     return datasets.interleave(lambda x: x, num_parallel_calls=AUTOTUNE)
    ```

127-
If the `reader_fn` is not specified, a default equivalent to the following
128-
will be used:
129-
```python
130-
def reader_fn(filenames, SnapshotDataset):
131-
return SnapshotDataset(filenames)
132-
```
106+
1. `writer_fn`: Optional. The input pipeline specified by `writer_fn` is
107+
executed when the snapshot op detects that there are no valid snapshots
108+
and no other threads are currently attempting to write a snapshot.
109+
110+
`writer_fn` is a user specified function that accepts a single argument:
111+
(1) a Dataset of elements to be written out. The function should return
112+
a Dataset of Datasets, each representing "splits" of elements of the
113+
original dataset. The tf.data snapshot implementation will then persist
114+
splits in parallel.
115+
116+
A default writer_fn will look like the following:
133117

134-
Users can optionally add snapshot file shuffling and parallelism by passing
135-
a `reader_fn` similar to the one here:
136118
```python
137-
def reader_fn(filenames, SnapshotDataset):
138-
file_ds = Dataset.from_tensor_slices(filenames)
139-
file_ds = file_ds.shuffle(1000)
140-
reader_ds = dataset.interleave(
141-
lambda x: SnapshotDataset(x),
142-
cycle_length=32,
143-
num_parallel_calls=32)
144-
return reader_ds
119+
def default_writer_fn(dataset):
120+
# add a component with element index
121+
dataset = dataset.enumerate()
122+
# split input dataset in a round-robin fashion
123+
return dataset.split(num_splits=NUM_CORES, key_fn=lambda i, _: i % NUM_CORE
145124
```
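
Another illustrative sketch (not part of the commit): a `writer_fn` that produces more splits than cores, for example to keep individual snapshot files small. `Dataset.enumerate` is an existing tf.data method, while `split(num_splits=..., key_fn=...)` is the splitting transformation proposed by this RFC; `NUM_SHARDS` is a hypothetical constant:

```python
NUM_SHARDS = 64  # hypothetical: more shards than cores to bound file sizes

def custom_writer_fn(dataset):
  # Attach an element index so elements can be assigned to shards.
  dataset = dataset.enumerate()
  # Round-robin elements into NUM_SHARDS splits; each split is then
  # persisted by the snapshot implementation in parallel.
  return dataset.split(num_splits=NUM_SHARDS,
                       key_fn=lambda index, _: index % NUM_SHARDS)
```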

-1. `mode`: Optional. The mode at which snapshot should operate. Valid options
-   are `auto`, `read`, `write`, and `passthrough`. The default mode is `auto`,
-   where the snapshot op will automatically determine what mode to operate in.
+1. `pending_snapshot_expiry_seconds`: Optional. How long to wait (in seconds)
+   before the snapshot op considers a previously unfinished snapshot to be
+   stale and starts writing a snapshot from scratch again. Defaults to 86400
+   seconds (1 day).
+
+#### Achieving Parallelism

-1. `write` mode forces the snapshot transformation to write a new
-   materialization to disk, regardless of whether a complete and valid
-   materialization currently exists. In other words, we enter the **WRITE**
-   state immediately.
+`reader_fn` and `writer_fn` default to passing the dataset through unchanged.
+In other words, the default implementation results in single-threaded reads
+and writes on snapshots. Parallelism can be achieved in `writer_fn` by
+splitting the dataset up into multiple datasets, and in `reader_fn` by using
+`num_parallel_calls` in the `interleave` function.
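
Putting the two together (again an editorial sketch under the same assumptions as the examples above, reusing the hypothetical `custom_writer_fn` and `custom_reader_fn` from the earlier sketches):

```python
dataset = tf.data.Dataset.range(1_000_000)       # stand-in for a real input pipeline
dataset = dataset.map(lambda x: x * x)           # stand-in preprocessing
dataset = dataset.apply(
    tf.data.experimental.snapshot(
        "/tmp/snapshots/parsed",                 # hypothetical snapshot directory
        writer_fn=custom_writer_fn,              # parallel writes via multiple splits
        reader_fn=custom_reader_fn))             # parallel reads via interleave
```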

-1. `read` mode forces the snapshot transformation to read from the latest
-   version of the materialization on disk, regardless of whether the data
-   stored on disk is complete and valid. In other words, we enter the
-   **READ** state immediately.
+#### Computing Graph Fingerprints

-1. `passthrough` mode turns the snapshot transformation into a no-op. In
-   other words, we enter the **PASSTHROUGH** state immediately.
+Snapshot attempts to determine whether a run of an input pipeline is the same
+as a previous run by computing the fingerprint of the nodes within the pipeline.

-1. `auto` retains the default behavior of snapshot. See the "Standard
-   Kernel Workflow" section for the default behavior.
+However, some input pipelines might vary in insignificant ways from run to run
+that cause their fingerprints to differ. For instance, consider the
+following preprocessing function:

-1. `snapshot_name`: Optional. If set, use the supplied string as a named
-   snapshot name instead of introspecting the data pipeline and automatically
-   generating a unique identifier for the specific data pipeline.
+```python
+features_to_multiply = {"feature1", "feature2", "feature3", "feature4"}

-   1. Instead of generating a new fingerprint of the input processing graph or
-      and `run_id` (see the _Detailed Design_ section for details), we will
-      use the `snapshot_name` to uniquely identify the snapshot.
+def preprocessing_fn(value):
+  keys_to_features = {
+      "feature1": tf.FixedLenFeature([], tf.float32, 0.0),
+      "feature2": tf.FixedLenFeature([], tf.float32, 0.0),
+      "feature3": tf.FixedLenFeature([], tf.float32, 0.0),
+      "feature4": tf.FixedLenFeature([], tf.float32, 0.0)
+  }

-   1. Multiple concurrent training jobs with the same "snapshot_name" may
-      result in concurrent write collisions and a potentially invalid snapshot
-      if the jobs tries to read from and then write to the metadata file at
-      exactly the same time.
+  parsed = tf.parse_single_example(value, keys_to_features)
+  combined_feature = 1.0
+  for item in features_to_multiply:
+    combined_feature *= parsed[item]

-   The user is expected to handle these cases and explicitly specify `mode`s
-   to ensure that only one run is set to `write` mode at any point if
-   collisions are a possibility.
+  return combined_feature

-Note: `AUTO` options above indicates that snapshot will attempt to pick a
-reasonable default that is suitable for most use cases. We will eventually add
-tf.data autotuning to pick the right parameters for the best performance for
-individual workloads.
+dataset = ...
+dataset = dataset.map(preprocessing_fn)
+```
+
+In the above example, our `features_to_multiply` variable uses a `set`, which is
+not guaranteed to be ordered in Python 2. When we iterate over the set in the
+for loop within `preprocessing_fn`, we may get a different graph on each run
+(i.e. one run could have us multiplying `feature2` first, then `feature4`, and
+so on, while another run may have us multiplying `feature1` first, then
+`feature3`).
+
+In cases like these, we can ask fingerprinting to use a fixed value for the
+fingerprint of the map function with a new `set_fingerprint`
+transformation, which asks the fingerprinting function to not compute the
+fingerprint of the previous node but to use a user-specified value instead:
+
+```python
+dataset = ...
+dataset = dataset.map(preprocessing_fn)
+dataset = tf.data.set_fingerprint(dataset, fingerprint="my_fixed_fp")
+```
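
As a hedged editorial note: an alternative to pinning the fingerprint is to remove the nondeterminism itself, for example by iterating the feature set in a sorted order inside `preprocessing_fn` above, so that the same graph (and hence the same fingerprint) is produced on every run:

```python
  # Replacing the loop in preprocessing_fn above: iterate in a deterministic
  # order so the generated graph is identical across runs.
  for item in sorted(features_to_multiply):
    combined_feature *= parsed[item]
```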

 ### External API Guarantees
