-
Notifications
You must be signed in to change notification settings - Fork 62
ENH - Gets SciKeras script working #394
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 46 commits
925c960
3500592
492a1ca
4c91a07
101b90c
08f5ba7
0491a7b
f1b93fe
bbe6b34
3b274b4
e7ab34e
a1a92cc
7b0f21e
3397b17
49a16f0
d4469f9
3b55471
b51ca0e
dd8d6e1
f68eea6
f304dc7
c131ebd
846e72e
7f8593a
208839b
28e9f15
d1d260e
ac11b46
a9b9dbf
e5eb579
0eca68a
30f8993
1b1cdff
5033112
6a8e821
f119713
ab530f3
0d8efca
07e0d5a
cc08530
cc2aead
c96a50a
91983e8
2f3ae7a
3739f1f
ce11bf0
4c47aaf
12e2108
d677476
6d10f71
e307560
bb82961
76341fd
fa6b208
83891ed
e9b2dd0
2d92168
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,63 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import io | ||
| import os | ||
| import tempfile | ||
| from typing import Sequence, Type | ||
|
|
||
| import tensorflow as tf | ||
| from scikeras.wrappers import KerasClassifier, KerasRegressor | ||
|
|
||
| from ._audit import Node | ||
| from ._protocol import PROTOCOL | ||
| from ._utils import Any, LoadContext, SaveContext, get_module | ||
|
|
||
|
|
||
| def scikeras_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: | ||
| res = { | ||
| "__class__": obj.__class__.__name__, | ||
| "__module__": get_module(type(obj)), | ||
| "__loader__": "SciKerasNode", | ||
| } | ||
|
|
||
| obj_id = save_context.memoize(obj) | ||
| f_name = f"{obj_id}.keras" | ||
|
|
||
| with tempfile.TemporaryDirectory() as temp_dir: | ||
| file_name = os.path.join(temp_dir, "model.keras") | ||
| obj.model.save(file_name) | ||
|
||
| save_context.zip_file.write(file_name, f_name) | ||
|
|
||
| res.update(type="scikeras", file=f_name) | ||
| return res | ||
|
|
||
|
|
||
| class SciKerasNode(Node): | ||
| def __init__( | ||
| self, | ||
| state: dict[str, Any], | ||
| load_context: LoadContext, | ||
| trusted: bool | Sequence[str] = False, | ||
| ) -> None: | ||
| super().__init__(state, load_context, trusted) | ||
| self.trusted = self._get_trusted(trusted, default=[]) | ||
|
|
||
| self.children = {"content": io.BytesIO(load_context.src.read(state["file"]))} | ||
|
|
||
| def _construct(self): | ||
| with tempfile.TemporaryDirectory() as temp_dir: | ||
| file_path = os.path.join(temp_dir, "model.keras") | ||
| with open(file_path, "wb") as f: | ||
| f.write(self.children["content"].getbuffer()) | ||
| model = tf.keras.models.load_model(file_path, compile=False) | ||
| return model | ||
|
|
||
|
|
||
| GET_STATE_DISPATCH_FUNCTIONS = [ | ||
| (KerasClassifier, scikeras_get_state), | ||
| (KerasRegressor, scikeras_get_state), | ||
| ] | ||
|
|
||
| NODE_TYPE_MAPPING: dict[tuple[str, int], Type[Node]] = { | ||
| ("SciKerasNode", PROTOCOL): SciKerasNode | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.