# tf.data Snapshot

| Status        | Accepted                                                 |
| :------------ | :------------------------------------------------------- |
| **RFC #**     | [193](https://github.com/tensorflow/community/pull/193)  |
| **Author(s)** | Frank Chen ([email protected]), Rohan Jain                |
| **Sponsor**   | Jiri Simsa ([email protected])                            |
| **Updated**   | 2020-02-10                                               |

## Objective

With ever faster accelerators available in the cloud and hyperparameter tuning
consuming larger chunks of accelerator time, TensorFlow users are increasingly
finding that they don’t have enough CPU resources to keep up with these
accelerators, leaving valuable accelerator resources idle.

To alleviate this problem, we are proposing a `snapshot` API within `tf.data`
that allows users to transparently persist the output of their preprocessing
pipeline to disk and materialize the pre-processed data in a different training
run.

This API enables repeated preprocessing steps to be consolidated and allows
already processed data to be re-used, trading off disk storage and network
bandwidth for freeing up more valuable CPU resources and accelerator compute
time.

## Motivation

Large TensorFlow users have indicated that they have complicated input
processing pipelines which saturate their CPUs before saturating their
accelerators (TPUs in particular). Since they often experiment with
hyperparameter tuning or tweaks to existing models without affecting their input
pipeline, they are asking for ways to avoid repeatedly preprocessing the same
data, by either saving a dataset or caching it to disk.

## User Benefit

Users will be able to transparently persist partially or fully processed data
from `tf.data` input pipelines to disk or Cloud storage systems, and materialize
the pre-processed data during subsequent runs of the same pipeline. This will
cut down on input pipeline processing overheads during second and subsequent
runs.

## Design Proposal

We propose adding a new `snapshot` transformation to tf.data. To illustrate its
usage, we can start with some sample code:

```python
# Shard input files across workers, then read and preprocess them.
dataset = Dataset.list_files("/raw/data/*").shard(num_workers, i)
dataset = dataset.interleave(TFRecordDataset, num_parallel_calls=AUTOTUNE)
dataset = dataset.map(my_preprocessing_fn)
# Persist the preprocessed elements (or, on later runs, read them back).
dataset = dataset.apply(tf.data.snapshot("/saved/data", options...))
dataset = dataset.repeat()

model = ...
model.fit(dataset)
```

As we can see, the end user simply has to add this transformation in order to
use this functionality. In essence, the transformation is similar to the
existing `tf.data.Dataset.cache`, with the key difference being that, unlike
`cache`, `snapshot` is intended to be re-used across different executions of the
same input pipeline.

### Proposed API

We are proposing the following API for the snapshot transformation (a usage
sketch follows the parameter descriptions below):

```python
def snapshot(path,
             compression=None,
             reader_fn=None,
             writer_fn=None,
             pending_snapshot_expiry_seconds=None):
  pass  # Implementation goes here.
```

1.  `path`: Required. A directory where we want to save our snapshots and/or
    read from a previously saved snapshot.

1.  `compression`: Optional. The type of compression to apply to the snapshot
    written to disk. This will support `GZIP`, `SNAPPY`, `AUTO`, or None.
    Defaults to `AUTO`.

1.  `reader_fn`: Optional. The input pipeline transformation specified by
    `reader_fn` is executed when the snapshot detects that there is an existing,
    valid snapshot available.

    `reader_fn` is a user specified function that accepts a single argument:
    (1) a Dataset of Datasets, each representing a "split" of elements of the
    original dataset. The cardinality of the input dataset matches the
    cardinality of the output of `writer_fn` (see below). The function should
    return a Dataset of elements of the original dataset.

    A default `reader_fn` will look like the following:

    ```python
    def default_reader_fn(datasets):
      # Shuffle the dataset splits.
      datasets = datasets.shuffle(NUM_DATASETS)
      # Read the splits in parallel and interleave their elements.
      return datasets.interleave(lambda x: x, num_parallel_calls=AUTOTUNE)
    ```

1.  `writer_fn`: Optional. The input pipeline specified by `writer_fn` is
    executed when the snapshot op detects that there are no valid snapshots
    and no other threads are currently attempting to write a snapshot.

    `writer_fn` is a user specified function that accepts a single argument:
    (1) a Dataset of elements to be written out. The function should return
    a Dataset of Datasets, each representing a "split" of elements of the
    original dataset. The tf.data snapshot implementation will then persist
    the splits in parallel.

    A default `writer_fn` will look like the following:

    ```python
    def default_writer_fn(dataset):
      # Add a component with the element index.
      dataset = dataset.enumerate()
      # Split the input dataset in a round-robin fashion.
      return dataset.split(num_splits=NUM_CORES, key_fn=lambda i, _: i % NUM_CORES)
    ```

1.  `pending_snapshot_expiry_seconds`: Optional. How long to wait (in seconds)
    before the snapshot op considers a previously unfinished snapshot to be
    stale and starts writing a snapshot from scratch again. Defaults to 86400
    seconds (1 day).

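To make these options concrete, here is a minimal usage sketch of the proposed
API. The argument values are purely illustrative, and `my_preprocessing_fn` is
assumed to be defined elsewhere:

```python
dataset = tf.data.TFRecordDataset(["/raw/data/part-00000.tfrecord"])
dataset = dataset.map(my_preprocessing_fn)

# Persist the preprocessed elements under /saved/data; subsequent runs of the
# same pipeline read the snapshot back instead of recomputing it.
dataset = dataset.apply(
    tf.data.snapshot("/saved/data",
                     compression="GZIP",
                     pending_snapshot_expiry_seconds=3600))
dataset = dataset.repeat()
```
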
#### Achieving Parallelism

By default, `reader_fn` and `writer_fn` pass the dataset through unchanged. In
other words, the default implementation results in single-threaded reads and
writes on snapshots. Parallelism can be achieved in `writer_fn` by splitting up
the dataset into multiple datasets, and by using `num_parallel_calls` in the
`interleave` function of the `reader_fn`, as in the sketch below.

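For instance, a parallel configuration might look like the following sketch. It
reuses the proposed `Dataset.split` transformation and the `NUM_CORES` and
`AUTOTUNE` placeholders from the defaults above, so it is illustrative rather
than final API:

```python
NUM_CORES = 8  # placeholder; would be tuned or autotuned in practice

def parallel_writer_fn(dataset):
  # Tag each element with its index and split the dataset round-robin into
  # NUM_CORES sub-datasets, which snapshot then writes out in parallel.
  dataset = dataset.enumerate()
  return dataset.split(num_splits=NUM_CORES, key_fn=lambda i, _: i % NUM_CORES)

def parallel_reader_fn(datasets):
  # Shuffle the order in which the written splits are consumed, then read them
  # concurrently and interleave their elements.
  datasets = datasets.shuffle(NUM_CORES)
  return datasets.interleave(lambda x: x, num_parallel_calls=AUTOTUNE)

dataset = dataset.apply(
    tf.data.snapshot("/saved/data",
                     reader_fn=parallel_reader_fn,
                     writer_fn=parallel_writer_fn))
```
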
#### Computing Graph Fingerprints

Snapshot attempts to determine whether a run of an input pipeline is the same
as a previous run by computing a fingerprint of the nodes within the pipeline.

However, some input pipelines might vary in insignificant ways from run to run,
causing their fingerprints to differ. For instance, consider the following
preprocessing function:

```python
features_to_multiply = {"feature1", "feature2", "feature3", "feature4"}

def preprocessing_fn(value):
  keys_to_features = {
      "feature1": tf.io.FixedLenFeature([], tf.float32, 0.0),
      "feature2": tf.io.FixedLenFeature([], tf.float32, 0.0),
      "feature3": tf.io.FixedLenFeature([], tf.float32, 0.0),
      "feature4": tf.io.FixedLenFeature([], tf.float32, 0.0)
  }

  parsed = tf.io.parse_single_example(value, keys_to_features)
  combined_feature = 1.0
  for item in features_to_multiply:
    combined_feature *= parsed[item]

  return combined_feature

dataset = ...
dataset = dataset.map(preprocessing_fn)
```

In the above example, our `features_to_multiply` variable uses a `set`, which is
not guaranteed to be ordered in Python. When we iterate over the set in the
`for` loop within `preprocessing_fn`, we may get a different graph on each run
(i.e. one run could have us multiplying `feature2` first, then `feature4`, and
so on, while another run may have us multiplying `feature1`, then `feature3`,
and so on).

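To make the nondeterminism concrete, the hypothetical snippet below shows the
effect; sorting the set stabilizes the traced graph for code the user controls,
while the fingerprint override introduced next covers cases where that is not
practical:

```python
features_to_multiply = {"feature1", "feature2", "feature3", "feature4"}

# The iteration order of a Python set is unspecified and can change between
# processes (string hashing is randomized by default), so the traced graph may
# multiply the features in a different order on each run.
print(list(features_to_multiply))  # e.g. ['feature3', 'feature1', ...]

# Iterating in sorted order keeps the traced graph, and hence its fingerprint,
# stable across runs.
for item in sorted(features_to_multiply):
  pass  # e.g. combined_feature *= parsed[item]
```
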
In cases like these, we can ask fingerprinting to use a fixed value for the
fingerprint of the map function with a new `with_snapshot_fingerprint`
transformation, which asks the fingerprinting function not to compute the
fingerprint of the previous node but to use a user-specified value instead:

```python
dataset = ...
dataset = dataset.map(preprocessing_fn)
dataset = tf.data.experimental.with_snapshot_fingerprint(
    dataset, fingerprint="my_fixed_fp")
```

### External API Guarantees

Externally, we guarantee that snapshots written by a particular version of
TensorFlow will be readable by that specific version of TensorFlow.

We are not currently handling the case where workers do not go through the
entire training set at least once.

### Alternatives Considered

An alternative proposal for an API would be `save()` and `load()`, where the
saving and loading of the input pipeline would be made more explicit, avoiding
some of the logic needed to determine whether to write a snapshot or read from
an existing one.

The downside here would be that the user would have to split the preprocessing
and training into potentially different files, and users would be forced to
select whether to train or preprocess on their own, adding friction to the
workflow.

### Performance Implications

Benchmarks for this feature will be included as part of Dataset microbenchmarks.

### Dependencies

No new dependencies will be introduced to TensorFlow as part of this project.
Dependent projects may be able to use this additional op, but there should be no
significant changes otherwise.

### Engineering Impact

Binary size increases slightly with the inclusion of this new op, and the code
will be maintained by the `tf.data` team.

### Platforms and Environments

This op will work on all TensorFlow-supported platforms. We do not anticipate
this to work on embedded systems as it is not useful in resource-constrained
environments.

### Best Practices, Tutorials and Examples

A user guide for snapshot will be published to guide new users in using this
feature.

### Compatibility

This introduces a new op, which will impact future backwards compatibility.

### User Impact

A new Python function and a new op are the only user-facing changes visible.

## Detailed Design

### Implementation Assumptions

The implementation is based on the following assumptions, which define the MVP
this is designed for:

1.  We assume that for at least one pipeline run, you can go through the entire
    training dataset and be able to store that data on disk. Otherwise, a
    snapshot will never get created.

2.  In cases where there are multiple workers and the dataset is sharded with
    `Dataset.shard`, we assume that the number of workers remains the same from
    the initial (writing) run through to the reading runs.

    If the number of workers changes, then the `num_shards` parameter to
    `Dataset.shard` will change, and this will result in a different graph
    fingerprint and another snapshot write will be triggered.

    If all workers use the exact same input pipeline with no sharding (e.g. all
    workers read from all the files), then snapshot will still be able to read
    from previous snapshots even if the number of workers is different.

3.  Any `repeat`s in the dataset should be moved to after the `snapshot` op, to
    avoid writing large (or infinite) amounts of data during a snapshot writing
    run, as illustrated in the sketch below.

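For example, a sketch of assumption 3 using the proposed API (the paths and the
preprocessing function are illustrative):

```python
dataset = tf.data.TFRecordDataset(["/raw/data/part-00000.tfrecord"])
dataset = dataset.map(my_preprocessing_fn)

# Good: snapshot a single, finite pass over the data, then repeat for training.
dataset = dataset.apply(tf.data.snapshot("/saved/data"))
dataset = dataset.repeat()

# Risky: repeating before the snapshot would make the writing run try to
# persist an unbounded stream of elements.
# dataset = dataset.repeat()
# dataset = dataset.apply(tf.data.snapshot("/saved/data"))
```
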
### New `SnapshotDatasetOp`

To implement the transformation, we are introducing a new `SnapshotDatasetOp`
dataset kernel that will implement all of the functionality in TensorFlow C++.
The Python code is mostly glue that passes the relevant parameters into the op
kernel.

### Internal Directory / File Structure

Given a user directory path (e.g. `/path/to/snapshot`), the directory will look
like:

*   /path/to/snapshot
    *   `fingerprint`/
        *   snapshot.metadata
        *   `run-id`/
            *   0000000.snapshot
            *   0000001.snapshot

The `fingerprint` is a hash of the input processing graph. The `run-id` is a
unique ID generated for each training run.

### Standard Kernel Workflow

_Note: This is an implementation detail, and may change in the future. It
should not be relied upon except as a reference to the current implementation._

By default, the `snapshot` operation will, upon startup, determine whether it
should be in the **WRITE**, **PASSTHROUGH**, or **READ** state using the
following algorithm (summarized in the sketch after this list):

1.  We will compute a graph fingerprint containing all the information from the
    Dataset preprocessing graph before the `snapshot` op. We’ll use the
    `AsGraphDefInternal` method on `DatasetBase` for this.

1.  We will attempt to enter the corresponding fingerprint directory. For
    instance, if the computed fingerprint is `f-abc123` and the base snapshot
    directory is `/saved/data`, then we will attempt to enter
    `/saved/data/f-abc123`.

1.  If the snapshot directory is non-existent, empty, or it doesn’t contain a
    `metadata` file, we will enter the **WRITE** state.

1.  If the snapshot directory contains a `metadata.final` file, we will read
    the final metadata file and proceed to the **READ** state.

    1.  The file contains the following fields:
        1.  A training run ID,
        1.  A boolean indicating if the snapshot is complete,
        1.  A training run start-time.

1.  If the snapshot directory contains a `metadata` file but not a
    `metadata.final` file, we will read the metadata file.

1.  If the training run start-time is more than the (configurable) training run
    timeout (set with the `pending_snapshot_expiry_seconds` parameter), we will
    enter the **WRITE** state.

1.  If the training run start-time is less than the training run timeout, but
    the snapshot is not complete, then we will enter the **PASSTHROUGH** state.

1.  If the snapshot is complete, we will enter the **READ** state.

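The decision logic above can be summarized with the following pseudocode sketch.
It is an illustration of the algorithm as described, not the actual C++ kernel;
the helpers (`file_exists`, `read_metadata`, `now`) are hypothetical stand-ins
for the filesystem and metadata plumbing:

```python
import os

def determine_mode(snapshot_dir, graph_fingerprint, expiry_seconds):
  fingerprint_dir = os.path.join(snapshot_dir, graph_fingerprint)

  if not file_exists(os.path.join(fingerprint_dir, "snapshot.metadata")):
    # Missing or empty directory, or no metadata: nothing has been written yet.
    return "WRITE"

  if file_exists(os.path.join(fingerprint_dir, "snapshot.metadata.final")):
    # A finalized metadata file marks a complete snapshot.
    return "READ"

  metadata = read_metadata(os.path.join(fingerprint_dir, "snapshot.metadata"))
  if now() - metadata.start_time > expiry_seconds:
    return "WRITE"        # The pending snapshot is stale; start over.
  if not metadata.complete:
    return "PASSTHROUGH"  # Another run is still writing; neither read nor write.
  return "READ"
```
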
#### WRITE State

1.  We generate a random training run ID.

1.  We write (possibly overwriting) the `snapshot.metadata` file.

1.  We proceed to create a subdirectory containing the training run ID, and
    start writing data asynchronously in chunks.

1.  At the end of the dataset (when `end_of_sequence == true`), we will check
    the `snapshot.metadata` file to determine whether it contains the same
    training run ID (see the sketch below).

    1.  If it does, we write a `metadata.final` file containing the
        same information as the `metadata` file but with the complete
        bit set to true.
    1.  If it does not, it means that someone else is concurrently writing the
        snapshot and we lost the race to them. We delete all data in our
        training run directory.

For the current implementation, we will store the data in chunked TFRecord
files. Eventually we may move to other, higher-performance data stores or
support additional storage systems such as Cloud Bigtable.

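A rough sketch of the end-of-stream check in step 4 (the helper names are
hypothetical; the real logic lives in the C++ kernel):

```python
import os

def finalize_or_discard(fingerprint_dir, my_run_id):
  metadata = read_metadata(os.path.join(fingerprint_dir, "snapshot.metadata"))
  if metadata.run_id == my_run_id:
    # We still own the snapshot: mark it complete by writing metadata.final.
    metadata.complete = True
    write_metadata(
        os.path.join(fingerprint_dir, "snapshot.metadata.final"), metadata)
  else:
    # Another writer overwrote the metadata and won the race; discard our data.
    delete_recursively(os.path.join(fingerprint_dir, my_run_id))
```
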
#### PASSTHROUGH State

1.  This is a no-op, where we simply pass through the tensors to the downstream
    operations.

#### READ State

1.  We will read from the snapshots contained within the subfolder with the
    correct graph fingerprint and specified training run ID.

1.  Optionally, the user may specify that the snapshots should be read back in
    shuffled order.

### Concurrency: Handling Multiple Input Workers

If input workers are sharded, then they will generate different graph
fingerprints as their shard indexes will be different. This will result in each
worker writing to a different subdirectory.

If input workers are not sharded, then this will result in a race and
potentially multiple workers writing data (still with different training run
IDs). Eventually, if each worker finishes, we will be left with one copy of the
data, as all the other workers will determine that they have lost the race and
delete their own copy of the snapshot data.

## Questions and Discussion Topics

*   Should we implement this as three ops (a control op to determine whether a
    snapshot is to be read from or written to, plus a write op and a read op to
    perform the respective operations)?
    *   Pros include:
        *   Modularizes the implementation into smaller chunks.
        *   Allows someone else to do the "control".
    *   Challenges include:
        *   Where and how does the "control" run?
        *   How do we construct the dataset graph properly?
*   How should autotuning be integrated into the snapshot transformation?
*   Are the configuration options well named? Is it possible to consolidate some
    of these options?
*   What other compression/decompression options would you like to see
    supported?
*   Any other performance / feature tuning knobs we should make available?