Commit 960af33

Propose support for Python's pickle protocol within Keras
1 parent cf6faa2 commit 960af33

File tree

1 file changed: +311 −0 lines

rfcs/20200921-pickle-for-keras.md

# Support for Pickle, Python's serialization protocol

| Status        | Proposed |
| :------------ | :------- |
| **RFC #**     | [286](https://github.com/tensorflow/community/pull/286) |
| **Author(s)** | Adrian Garcia Badaracco ({firstname}at{firstname}gb.com), Scott Sievert ([email protected]) |
| **Sponsor**   | Mihai Maruseac ([email protected]) |
| **Updated**   | 2020-09-21 |

## Objective

Implement support for Pickle, Python's serialization protocol, within Keras.

## Motivation

> *Why is this a valuable problem to solve? What background information is
> needed to show how this design addresses the problem?*

The specific motivation for this RFC: we want to use Keras models in Dask-ML's
and Ray's hyperparameter optimization. More generally, support for serialization
with the Pickle protocol will enable:

* Using Keras with other parallelization libraries like Python's
  `multiprocessing`, Dask, Ray or IPython Parallel.
* Saving Keras models to disk with custom serialization libraries like Joblib
  or Dill. This is common when using a Keras model as part of a Scikit-Learn
  pipeline or with their hyperparameter searches.
* Copying Keras models with Python's built-in `copy.deepcopy`.

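
For illustration, here is a minimal sketch of the usage this RFC would enable
once implemented (the toy model below is a stand-in for any Keras model):

``` python
import copy
import pickle

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])

# Serialization to bytes, as used by multiprocessing, Dask, Ray, Joblib, etc.
restored = pickle.loads(pickle.dumps(model))

# In-memory copying via the same protocol.
copied = copy.deepcopy(model)
```
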
Supporting Pickle will enable wider usage in the Python ecosystem because
Python's ecosystem of libraries depends strongly on the presence of shared
protocols. Without these protocols, each library would need to implement a
custom serialization method for every other library. For example, Dask
Distributed has a custom serialization method for Keras at
[distributed/protocol/keras.py]. See "[Pickle isn't slow, it's a protocol]"
for more detail (notably, this post focuses on having an efficient Pickle
implementation for PyTorch).

[distributed/protocol/keras.py]:https://github.com/dask/distributed/blob/73fa9bd1bd7dcb4ceed72cdbdc6dd4b92f887521/distributed/protocol/keras.py

This request is *not* advocating for the use of Pickle to save or share
Keras models. We believe the efficient, secure and stable methods in TF should
be used for that. Instead, we are proposing to add a Pickle implementation to
support wider usage in the Python ecosystem.

[Pickle isn't slow, it's a protocol]:https://blog.dask.org/2018/07/23/protocols-pickle

> *Which users are affected by the problem? Why is it a problem? What data
> supports this? What related work exists?*

Users trying to use distributed systems (e.g., Ray or Dask) with Keras are
affected. In our experience, this is common in hyperparameter optimization. In
general, having Pickle support means a better experience, especially when using
Keras with other libraries. Briefly, implementation of this RFC will make the
following possible:

* Saving a Scikit-Learn pipeline to disk if it includes a Keras model.
* Using custom parallelization backends like Joblib or Dask.

More use cases and examples are given in "User Benefit."

Related work is in [SciKeras], which brings a Scikit-Learn API
to Keras. Pickle is relevant because Scikit-Learn requires that estimators be
picklable ([source][skp]). As such, SciKeras has an implementation of
`__reduce__`, which is also in [tensorflow#39609].

[dask-ml#534]:https://github.com/dask/dask-ml/issues/534
[SO#51110834]:https://stackoverflow.com/questions/51110834/cannot-pickle-dill-a-keras-object
[SO#54070845]:https://stackoverflow.com/questions/54070845/how-to-pickle-keras-custom-layer
[SO#59872509]:https://stackoverflow.com/questions/59872509/how-to-export-a-model-created-from-kerasclassifier-and-gridsearchcv-using-joblib
[SO#37984304]:https://stackoverflow.com/questions/37984304/how-to-save-a-scikit-learn-pipline-with-keras-regressor-inside-to-disk
[SO#48295661]:https://stackoverflow.com/questions/48295661/how-to-pickle-keras-model
[skper]:https://scikit-learn.org/stable/modules/model_persistence.html#persistence-example
[TF#33204]:https://github.com/tensorflow/tensorflow/issues/33204
[TF#34697]:https://github.com/tensorflow/tensorflow/issues/34697

[tensorflow#39609]:https://github.com/tensorflow/tensorflow/pull/39609
[SciKeras]:https://github.com/adriangb/scikeras
[skp]:https://github.com/scikit-learn/scikit-learn/blob/0fb307bf39bbdacd6ed713c00724f8f871d60370/sklearn/utils/estimator_checks.py#L1523-L1524

<!--
StackOverflow questions where `Model.save` would not work:

* [SO#40396042](https://stackoverflow.com/questions/40396042/how-to-save-scikit-learn-keras-model-into-a-persistence-file-pickle-hd5-json-ya)

Examples that could be resolved using `Model.save` (but the user tried pickle first):

* [SO#51878627](https://stackoverflow.com/questions/51878627/pickle-keras-ann)
-->

## User Benefit

> How will users (or other contributors) benefit from this work? What would be the headline in the release notes or blog post?

One blog post headline: "Keras models can be used with the advanced
hyperparameter optimization techniques found in Dask-ML and Ray Tune." This has
already been mentioned in "Framework support" of [a Dask blog post][dbp]
comparing Dask-ML's hyperparameter optimization with Ray's tune-sklearn.

[dbp]:https://blog.dask.org/2020/08/06/ray-tune#framework-support

Users will also benefit from easier usage; they won't run into any of these
errors:

* People try to save Scikit-Learn meta-estimators with Keras components using
  the serialization libraries Joblib or Dill.
  This fails because Keras models cannot be serialized without a custom
  method. Examples include [SO#59872509], [SO#37984304], [SO#48295661], and
  [SO#51110834].
* Using custom parallelization strategies requires serialization support through
  Pickle; however, many parallelization libraries (e.g., Joblib) don't
  special-case Keras models. Relevant errors are most common in hyperparameter
  optimization with Scikit-Learn's parallelization through Joblib
  ([TF#33204] and [TF#34697]) or parallelization through Dask ([dask-ml#534]).
* Lack of Pickle support can complicate saving training history, as in
  the (poorly asked) [SO#54070845].

This RFC would resolve these issues.

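
For example, here is a sketch of the first use case, assuming this RFC is
implemented (the dataset and hyperparameters are arbitrary choices for
illustration):

``` python
import joblib
import tensorflow as tf
from sklearn.datasets import load_iris
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from tensorflow.keras.wrappers.scikit_learn import KerasClassifier

def build_model():
    model = tf.keras.Sequential(
        [tf.keras.layers.Dense(3, activation="softmax", input_shape=(4,))]
    )
    model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
    return model

X, y = load_iris(return_X_y=True)
pipeline = make_pipeline(
    StandardScaler(), KerasClassifier(build_fn=build_model, epochs=1, verbose=0)
)
pipeline.fit(X, y)

# Today this fails because the fitted Keras model is not picklable;
# with this RFC it would simply work.
joblib.dump(pipeline, "pipeline.joblib")
```
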
## Design Proposal

We propose implementing the Pickle protocol using the existing Keras
saving functionality as a backend. For example, adding pickle support to TF Metrics
is as simple as the following:

``` python
# tensorflow/python/keras/metrics.py

@keras_export('keras.metrics.Metric')  # line 80
@six.add_metaclass(abc.ABCMeta)
class Metric(base_layer.Layer):
  ...

  def __reduce__(self):
    # The deserialize/serialize functions are defined in this file.
    return deserialize, (serialize(self),)
```

This implementation adds support for the Pickle protocol, which supports serialization
to arbitrary IO, either memory or disk. The `__reduce__` special method returns
a reconstruction function and its arguments: here, the function that loads a
serialized representation back into memory, and that representation itself
([docs][reduce_docs]).

[reduce_docs]:https://docs.python.org/3/library/pickle.html#object.__reduce__

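
For illustration, here is a minimal, self-contained sketch of the `__reduce__`
contract (the `Accuracy` class and the serialize/deserialize helpers below are
hypothetical stand-ins, not the actual Keras functions):

``` python
import pickle

def serialize(obj):
    # Produce a picklable description of the object.
    return {"name": obj.name}

def deserialize(config):
    # Rebuild an equivalent object from that description.
    return Accuracy(name=config["name"])

class Accuracy:
    def __init__(self, name="accuracy"):
        self.name = name

    def __reduce__(self):
        # pickle.loads will call deserialize(serialize(self)).
        return deserialize, (serialize(self),)

restored = pickle.loads(pickle.dumps(Accuracy()))
assert restored.name == "accuracy"
```
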
For `tf.keras.Model`, we can use `SavedModel` as the backend for `__reduce__`:

``` python
# tensorflow/python/keras/engine/training.py
...
from tensorflow.python.keras.models import load_model

class Model(base_layer.Layer, version_utils.ModelVersionSelector):  # line 131
  ...

  def __reduce__(self):
    temp_ram_location = f"ram://tmp/saving/{id(self)}"
    self.save(temp_ram_location)
    # read_folder is one of the proposed additions to TF's IO module.
    b = tf.io.gfile.read_folder(temp_ram_location)
    return self._reconstruct_pickle, (np.asarray(memoryview(b)),)

  @classmethod
  def _reconstruct_pickle(cls, obj):
    temp_ram_location = f"ram://tmp/saving/{id(obj)}"
    # write_folder is likewise a proposed addition to TF's IO module.
    tf.io.gfile.write_folder(temp_ram_location, obj)
    return load_model(temp_ram_location)
```

This almost exactly mirrors the PyTorch
implementation of Pickle support in [pytorch#9184],
as mentioned in "[Pickle isn't slow, it's a protocol]."
In addition, small augmentations to TensorFlow's IO module (the `read_folder`
and `write_folder` helpers above) will be required, as discussed in
[tensorflow#39609].

By wrapping the serialized bytes in a NumPy array, pickling will support
pickle protocol 5's out-of-band buffers for zero-copy pickling. This provides an
immediate performance improvement for many use cases.

[pytorch#9184]:https://github.com/pytorch/pytorch/pull/9184

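For illustration, here is a minimal sketch of protocol 5's out-of-band buffers,
which NumPy arrays support (the payload below is a stand-in for the saved
model's bytes):

``` python
import pickle

import numpy as np

payload = np.frombuffer(b"stand-in for saved model bytes", dtype=np.uint8)

# With protocol 5, large buffers are handed to the caller out-of-band
# instead of being copied into the pickle stream.
buffers = []
data = pickle.dumps(payload, protocol=5, buffer_callback=buffers.append)
restored = pickle.loads(data, buffers=buffers)
assert restored.tobytes() == payload.tobytes()
```
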
### Alternatives Considered

One alternative is to ask users to monkey-patch Keras models themselves;
the same would hold for libraries. Clearly, this is unreasonable. Regardless,
some libraries like Dask Distributed have already implemented custom serialization
protocols ([distributed/protocol/keras.py]).

#### Other pickle implementations

The Pickle protocol supports two features:

1. In-memory copying of live objects: via Python's `copy` module. This falls
   back to (2) below; a sketch of the fallback is shown after this list.
2. Serialization to arbitrary IO (memory or disk): via Python's `pickle` module.

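As a minimal sketch of that fallback (the `Widget` class below is a
hypothetical stand-in): when a class defines `__reduce__` but no
`__copy__`/`__deepcopy__`, Python's `copy` module reuses the pickle machinery.

``` python
import copy

class Widget:
    def __init__(self, state):
        self.state = state

    def __reduce__(self):
        # Used by pickle, and by copy.deepcopy as a fallback.
        return Widget, (self.state,)

w = Widget({"a": 1})
w2 = copy.deepcopy(w)  # invokes __reduce__ under the hood
assert w2.state == w.state and w2.state is not w.state
```
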
This proposal seeks to take the conservative approach, at least initially, and
only implement (2) above, since (1) can always fall back to (2), and using only
(2) alleviates any concerns around references to freed memory in the C++
portions of TF and other such bugs.

That said, for situations where the user is making an in-memory copy of an object and it might
even be okay to keep around references to non-Python objects, a separate approach that optimizes
(1) would be warranted. This RFC does not seek to address that problem. Hence
this RFC is generally not concerned with:

* Issues arising from C++ references. These cannot be kept around when
  serializing to a binary file stream.
* Performance of the serialization/deserialization.

### Performance Implications

* The performance should be the same as the underlying backend that is already
  implemented in TF.
* For cases where the user was going to pickle anyway, this will be faster
  because it uses TF's methods instead of letting Python deal with it naively.
* Tests will consist of running `new_model = pickle.loads(pickle.dumps(model))`
  and then doing checks on `new_model`; a sketch of such a test follows this list.

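A minimal sketch of such a test, assuming this RFC is implemented (the model
architecture is an arbitrary choice for illustration):

``` python
import pickle

import numpy as np
import tensorflow as tf

def test_pickle_roundtrip():
    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
    model.compile(optimizer="sgd", loss="mse")

    new_model = pickle.loads(pickle.dumps(model))

    # The restored model should produce identical predictions.
    x = np.random.rand(8, 4).astype("float32")
    np.testing.assert_allclose(model.predict(x), new_model.predict(x))
```
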
### Dependencies

> Dependencies: does this proposal add any new dependencies to TensorFlow?

No

> Dependent projects: are there other areas of TensorFlow or things that use
> TensorFlow (TFX/pipelines, TensorBoard, etc.) that this affects?

This should not affect those libraries. It will affect libraries
further downstream like Dask-ML and Ray Tune.

### Engineering Impact

> Do you expect changes to binary size / startup time / build time / test
> times?

No

> Who will maintain this code? Is this code in its own buildable unit? Can this
> code be tested on its own? Is visibility suitably restricted to only a small
> API surface for others to use?

This code depends on existing Keras/TF methods. This code will not break
presuming they are maintained (the new API surface area is very small).

### Platforms and Environments

> Platforms: does this work on all platforms supported by TensorFlow? If not,
> why is that ok? Will it work on embedded/mobile? Does it impact automatic
> code generation or mobile stripping tooling? Will it work with transformation
> tools?

Yes, as long as a Python runtime is available.

> Execution environments (Cloud services, accelerator hardware): what impact do
> you expect and how will you confirm?

We don't see any impact.

### Best Practices

> Does this proposal change best practices for some aspect of using/developing
> TensorFlow? How will these changes be communicated/enforced?

No

### Tutorials and Examples

There are plenty of examples of how this can and would be used within the
issues above, in addition to the linked Colab notebook
([link again](https://colab.research.google.com/drive/14ECRN8ZQDa1McKri2dctlV_CaPkE574I?authuser=1#scrollTo=qlXDfJObNXVf)),
which has end-to-end implementations and tests for all of this.

### Compatibility

> Does the design conform to the backwards & forwards compatibility
> [requirements](https://www.tensorflow.org/programmers_guide/version_compat)?

Yes

> How will this proposal interact with other parts of the TensorFlow Ecosystem?

It should have no immediate impact on other parts of the TF ecosystem.

> How will it work with TFLite?

N/A

> How will it work with distribution strategies?

This enables use of other serialization libraries, which might enable support for other distribution strategies.

> How will it interact with tf.function?

N/A

> Will this work on GPU/TPU?

N/A

> How will it serialize to a SavedModel?

Not applicable, and almost a circular question.

### User Impact

> What are the user-facing changes? How will this feature be rolled out?

There are no user-facing changes: this is a backend change to private methods.

Rolling out only involves testing. It will not require any documentation
changes to advertise this feature: `Model.save` should still be used by users
simply trying to save their model to disk.

## Questions and Discussion Topics

> Seed this with open questions you require feedback on from the RFC process.