# Support for Pickle, Python's serialization protocol

| Status        | Accepted |
| :------------ | :------- |
| **RFC #**     | [286](https://github.com/tensorflow/community/pull/286) |
| **Author(s)** | Adrian Garcia Badaracco ({firstname}at{firstname}gb.com), Scott Sievert ([email protected]) |
| **Sponsor**   | Mihai Maruseac ([email protected]) |
| **Updated**   | 2020-09-21 |

## Objective

Implement support for Pickle, Python's serialization protocol, within Keras.

## Motivation

> *Why is this a valuable problem to solve? What background information is
> needed to show how this design addresses the problem?*

The specific motivation for this RFC: we want to use Keras models in Dask-ML's
and Ray's hyperparameter optimization. More generally, support for serialization
with the Pickle protocol will enable:

* Using Keras with other parallelization libraries like Python's
  `multiprocessing`, Dask, Ray, or IPython Parallel.
* Saving Keras models to disk with custom serialization libraries like Joblib
  or Dill. This is common when using a Keras model as part of a Scikit-Learn
  pipeline or with Scikit-Learn's hyperparameter searches.
* Copying Keras models with Python's built-in `copy.deepcopy`.

Supporting Pickle will enable wider usage in the Python ecosystem because
Python's ecosystem of libraries depends strongly on the presence of shared
protocols. Without these protocols, each library must implement a custom
serialization method for every other library. For example, Dask Distributed has
a custom serialization method for Keras at [distributed/protocol/keras.py].
See "[Pickle isn't slow, it's a protocol]" for more detail (notably, this post
focuses on having an efficient Pickle implementation for PyTorch).

[distributed/protocol/keras.py]: https://github.com/dask/distributed/blob/73fa9bd1bd7dcb4ceed72cdbdc6dd4b92f887521/distributed/protocol/keras.py

This request is *not* advocating for use of Pickle when saving or sharing
Keras models. We believe the efficient, secure, and stable methods in TF should
be used for that. Instead, we are proposing to add a Pickle implementation to
support wider usage in the Python ecosystem.

[Pickle isn't slow, it's a protocol]: https://blog.dask.org/2018/07/23/protocols-pickle

> *Which users are affected by the problem? Why is it a problem? What data
> supports this? What related work exists?*

Users trying to use distributed systems (e.g., Ray or Dask) with Keras are
affected. In our experience, this is common in hyperparameter optimization. In
general, having Pickle support means a better experience, especially when using
Keras with other libraries. Briefly, implementation of this RFC will make the
following possible:

* Saving a Scikit-Learn pipeline to disk if it includes a Keras model.
* Using custom parallelization backends like Joblib or Dask.

More use cases and examples are given in "User Benefit."

Related work is in [SciKeras], which brings a Scikit-Learn API to Keras.
Pickle is relevant because Scikit-Learn requires that estimators be
picklable ([source][skp]). As such, SciKeras has an implementation of
`__reduce__`, which is also in [tensorflow#39609].

[dask-ml#534]: https://github.com/dask/dask-ml/issues/534
[SO#51110834]: https://stackoverflow.com/questions/51110834/cannot-pickle-dill-a-keras-object
[SO#54070845]: https://stackoverflow.com/questions/54070845/how-to-pickle-keras-custom-layer
[SO#59872509]: https://stackoverflow.com/questions/59872509/how-to-export-a-model-created-from-kerasclassifier-and-gridsearchcv-using-joblib
[SO#37984304]: https://stackoverflow.com/questions/37984304/how-to-save-a-scikit-learn-pipline-with-keras-regressor-inside-to-disk
[SO#48295661]: https://stackoverflow.com/questions/48295661/how-to-pickle-keras-model
[TF#33204]: https://github.com/tensorflow/tensorflow/issues/33204
[TF#34697]: https://github.com/tensorflow/tensorflow/issues/34697

[tensorflow#39609]: https://github.com/tensorflow/tensorflow/pull/39609
[SciKeras]: https://github.com/adriangb/scikeras
[skp]: https://github.com/scikit-learn/scikit-learn/blob/0fb307bf39bbdacd6ed713c00724f8f871d60370/sklearn/utils/estimator_checks.py#L1523-L1524

<!--
StackOverflow questions where `Model.save` would not work:

* [SO#40396042](https://stackoverflow.com/questions/40396042/how-to-save-scikit-learn-keras-model-into-a-persistence-file-pickle-hd5-json-ya)

Examples that could be resolved using `Model.save` (but the user tried pickle first):

* [SO#51878627](https://stackoverflow.com/questions/51878627/pickle-keras-ann)
-->

## User Benefit

> *How will users (or other contributors) benefit from this work? What would be
> the headline in the release notes or blog post?*

One blog post headline: "Keras models can be used with the advanced
hyperparameter optimization techniques found in Dask-ML and Ray Tune." This has
already been mentioned in "Framework support" of [a Dask blog post][dbp]
comparing Dask-ML's hyperparameter optimization with Ray's tune-sklearn.

[dbp]: https://blog.dask.org/2020/08/06/ray-tune#framework-support

Users will also benefit from easier usage; they won't run into any of these
errors:

* People try to save Scikit-Learn meta-estimators with Keras components using
  the serialization libraries Joblib or Dill. This fails because Keras models
  cannot be serialized without a custom method. Examples include [SO#59872509],
  [SO#37984304], [SO#48295661], and [SO#51110834].
* Using custom parallelization strategies requires serialization support through
  Pickle; however, many parallelization libraries (e.g., Joblib) don't
  special-case Keras models. Relevant errors are most common in hyperparameter
  optimization with Scikit-Learn's parallelization through Joblib
  ([TF#33204] and [TF#34697]) or parallelization through Dask ([dask-ml#534]).
* Lack of Pickle support can complicate saving training history, as in
  (the poorly asked) [SO#54070845].

This RFC would resolve these issues.

## Design Proposal

We propose implementing the Pickle protocol using the existing Keras
saving functionality as a backend. For example, adding pickle support to Keras
metrics is as simple as the following:

``` python
# tensorflow/python/keras/metrics.py

@keras_export('keras.metrics.Metric')  # line 80
@six.add_metaclass(abc.ABCMeta)
class Metric(base_layer.Layer):
  ...

  def __reduce__(self):
    # The deserialize/serialize functions are defined in this file.
    return deserialize, (serialize(self),)
```

This implementation adds support for the Pickle protocol, which supports
serialization to arbitrary IO, either memory or disk. The `__reduce__` special
method returns a reconstruction function along with the serialized form of the
object; Pickle calls that function on the serialized data to load the object
back into memory ([docs][reduce_docs]).

[reduce_docs]: https://docs.python.org/3/library/pickle.html#object.__reduce__

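To make the mechanics concrete, here is a minimal, self-contained sketch of the same pattern: a toy class whose `__reduce__` delegates to a serialize/deserialize pair. `ToyMetric`, `serialize`, and `deserialize` here are hypothetical stand-ins for the real Keras helpers, not the actual API.

``` python
import pickle

def serialize(metric):
    # Stand-in for the real serialize(): reduce the object to plain data.
    return {"name": metric.name, "total": metric.total}

def deserialize(config):
    # Stand-in for the real deserialize(): rebuild the object from plain data.
    m = ToyMetric(config["name"])
    m.total = config["total"]
    return m

class ToyMetric:
    def __init__(self, name):
        self.name = name
        self.total = 0.0

    def __reduce__(self):
        # Pickle stores (callable, args) and calls deserialize(serialize(self))
        # at load time to reconstruct the object.
        return deserialize, (serialize(self),)

m = ToyMetric("acc")
m.total = 3.0
restored = pickle.loads(pickle.dumps(m))  # round trip through __reduce__
```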
For `tf.keras.Model`, we can use SavedModel as the backend for `__reduce__`:

``` python
# tensorflow/python/keras/engine/training.py
...
from tensorflow.python.keras.models import load_model

class Model(base_layer.Layer, version_utils.ModelVersionSelector):  # line 131
  ...

  def __reduce__(self):
    # Save to an in-memory filesystem, then read the bytes back out.
    temp_ram_location = f"ram://tmp/saving/{id(self)}"
    self.save(temp_ram_location)
    b = tf.io.gfile.read_folder(temp_ram_location)
    return self._reconstruct_pickle, (np.asarray(memoryview(b)),)

  @classmethod
  def _reconstruct_pickle(cls, obj):
    # Write the bytes back into an in-memory filesystem and load from it.
    temp_ram_location = f"ram://tmp/saving/{id(obj)}"
    tf.io.gfile.write_folder(temp_ram_location, obj)
    return load_model(temp_ram_location)
```

This almost exactly mirrors the PyTorch
implementation of Pickle support in [pytorch#9184],
as mentioned in "[Pickle isn't slow, it's a protocol]."
In addition, small augmentations to TensorFlow's IO module will be required
(e.g., the `read_folder`/`write_folder` helpers sketched above, as discussed
in [tensorflow#39609]).

By wrapping the saved bytes in a NumPy array, pickling will support
pickle protocol 5 for zero-copy pickling. This provides an immediate
performance improvement for many use cases.

[pytorch#9184]: https://github.com/pytorch/pytorch/pull/9184

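To illustrate the zero-copy point, here is a small sketch of protocol 5's out-of-band buffers; the payload bytes are a fabricated stand-in for a saved model.

``` python
import pickle
import numpy as np

# Stand-in for the bytes produced by saving a model; wrapping them in a
# NumPy array makes them a buffer-backed object that pickle can pass along
# out-of-band.
raw = bytearray(b"fake saved-model bytes")
payload = np.frombuffer(raw, dtype=np.uint8)

# With protocol 5, the large buffer travels out-of-band: `data` holds only
# metadata, while the actual bytes land in `buffers` without an extra copy.
buffers = []
data = pickle.dumps(payload, protocol=5, buffer_callback=buffers.append)
restored = pickle.loads(data, buffers=buffers)
```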
### Alternatives Considered

One alternative is to ask users to monkey-patch Keras models themselves.
This would hold for libraries too. Clearly, this is unreasonable. Regardless,
some libraries like Dask Distributed have already implemented custom
serialization protocols ([distributed/protocol/keras.py]).

#### Other pickle implementations

The Pickle protocol supports two features:

1. In-memory copying of live objects, via Python's `copy` module. This falls
   back to (2) below.
2. Serialization to arbitrary IO (memory or disk), via Python's `pickle`
   module.

This proposal takes a conservative approach, at least initially, and
implements only (2) above, since (1) can always fall back to (2), and using
only (2) alleviates concerns around references to freed memory in the C++
portions of TF and other such bugs.
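
The fallback from (1) to (2) can be sketched with a toy class (hypothetical, not a Keras API): implementing only `__reduce__` already makes `copy.deepcopy` work, because Python's `copy` module falls back to the pickle protocol.

``` python
import copy

class Weights:
    """Toy stand-in for a model-like object; only __reduce__ is defined."""
    def __init__(self, values):
        self.values = list(values)

    def __reduce__(self):
        # Reconstruct via the class itself from plain data.
        return Weights, (self.values,)

w = Weights([1.0, 2.0])
w2 = copy.deepcopy(w)  # no __deepcopy__ needed; copy falls back to __reduce__
```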

That said, for situations where the user is making an in-memory copy of an
object, and it might even be okay to keep around references to non-Python
objects, a separate approach that optimizes (1) would be warranted. This RFC
does not seek to address that problem. Hence this RFC is generally not
concerned with:

* Issues arising from C++ references. These cannot be kept around when
  serializing to a binary file stream.
* Performance of the serialization/deserialization.

### Performance Implications

* The performance should be the same as the underlying backend that is already
  implemented in TF.
* For cases where the user was going to pickle anyway, this will be faster
  because it uses TF's methods instead of letting Python deal with it naively.
* Tests will consist of running `new_model = pickle.loads(pickle.dumps(model))`
  and then doing checks on `new_model`.

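A sketch of the shape such a test could take, using a hypothetical `FakeModel` in place of `tf.keras.Model` so the round-trip check is visible without a TensorFlow dependency:

``` python
import pickle
import numpy as np

class FakeModel:
    """Hypothetical stand-in for tf.keras.Model with pickle support."""
    def __init__(self, weights):
        self.weights = np.asarray(weights)

    def __reduce__(self):
        return FakeModel, (self.weights,)

model = FakeModel([0.0, 1.0, 2.0, 3.0])
new_model = pickle.loads(pickle.dumps(model))

# Checks on new_model: a distinct object with identical weights.
weights_match = np.array_equal(new_model.weights, model.weights)
```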
### Dependencies

> Dependencies: does this proposal add any new dependencies to TensorFlow?

No.

> Dependent projects: are there other areas of TensorFlow or things that use
> TensorFlow (TFX/pipelines, TensorBoard, etc.) that this affects?

This should not affect those libraries. It will affect libraries
further downstream like Dask-ML and Ray Tune.

### Engineering Impact

> Do you expect changes to binary size / startup time / build time / test
> times?

No.

> Who will maintain this code? Is this code in its own buildable unit? Can this
> code be tested on its own? Is visibility suitably restricted to only a small
> API surface for others to use?

This code depends on existing Keras/TF methods. This code will not break
presuming they are maintained (the new API surface area is very small).

### Platforms and Environments

> Platforms: does this work on all platforms supported by TensorFlow? If not,
> why is that ok? Will it work on embedded/mobile? Does it impact automatic
> code generation or mobile stripping tooling? Will it work with transformation
> tools?

Yes, as long as a Python backend is available.

> Execution environments (Cloud services, accelerator hardware): what impact do
> you expect and how will you confirm?

We don't see any impact.

### Best Practices

> Does this proposal change best practices for some aspect of using/developing
> TensorFlow? How will these changes be communicated/enforced?

No.

### Tutorials and Examples

There are plenty of examples of how this can and would be used within all of
the issues above, in addition to a
[linked notebook](https://colab.research.google.com/drive/14ECRN8ZQDa1McKri2dctlV_CaPkE574I?authuser=1#scrollTo=qlXDfJObNXVf),
which has end-to-end implementations and tests for all of this.

### Compatibility

> Does the design conform to the backwards & forwards compatibility
> [requirements](https://www.tensorflow.org/programmers_guide/version_compat)?

Yes.

> How will this proposal interact with other parts of the TensorFlow Ecosystem?

It should have no immediate impact on other parts of the TF ecosystem.

> How will it work with TFLite?

N/A

> How will it work with distribution strategies?

This enables use of other serialization libraries, which might enable support
for other distribution strategies.

> How will it interact with tf.function?

N/A

> Will this work on GPU/TPU?

N/A

> How will it serialize to a SavedModel?

Not applicable, and almost a circular question.

### User Impact

> What are the user-facing changes? How will this feature be rolled out?

There are no user-facing changes: this is a backend change to private methods.

Rolling out only involves testing. It will not require any documentation
changes to advertise this feature: `Model.save` should still be used by users
simply trying to save their models to disk.

## Questions and Discussion Topics

> Seed this with open questions you require feedback on from the RFC process.