Skip to content

Commit b4fd72c

Browse files
committed
✨ Rebuild serialisation protocol to allow versions (#192)
1 parent 12b9e79 commit b4fd72c

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+1157
-613
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ repos:
1313
- id: name-tests-test
1414
args:
1515
- --django
16-
exclude: "asserts.py|utils.py"
16+
exclude: "_utils/.*"
1717
- id: pretty-format-json
1818
args:
1919
- --autofix

flama/compat.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import sys
22

3-
__all__ = ["Self", "NotRequired", "StrEnum", "tomllib", "get_annotations"]
3+
__all__ = ["Self", "NotRequired", "StrEnum", "tomllib", "get_annotations", "bz2", "lzma", "zlib", "zstd"]
44

55
# PORT: Remove when stop supporting 3.10
66
# Self was added in Python 3.11
@@ -54,3 +54,15 @@ def _generate_next_value_(name, start, count, last_values):
5454
from annotationlib import get_annotations
5555
else:
5656
from typing_extensions import get_annotations
57+
58+
# PORT: Remove when stop supporting 3.13
59+
# compression was added in Python 3.14
60+
# https://docs.python.org/3/library/compression.html
61+
if sys.version_info >= (3, 14):
62+
from compression import bz2, lzma, zlib, zstd
63+
else:
64+
import bz2
65+
import lzma
66+
import zlib
67+
68+
import zstd

flama/models/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
if t.TYPE_CHECKING:
55
from flama.serialize.data_structures import Artifacts, Metadata
66

7-
__all__ = ["Model"]
7+
__all__ = ["BaseModel"]
88

99

10-
class Model:
10+
class BaseModel:
1111
def __init__(self, model: t.Any, meta: "Metadata", artifacts: "Artifacts | None"):
1212
self.model = model
1313
self.meta = meta

flama/models/components.py

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import importlib
22
import os
3+
import pathlib
4+
import typing as t
35

6+
from flama import types
47
from flama.injection import Component
5-
from flama.models.base import Model
6-
from flama.serialize import load
7-
from flama.serialize.types import Framework
8+
from flama.models.base import BaseModel
9+
from flama.serialize.serializer import Serializer
810

911
__all__ = ["ModelComponent", "ModelComponentBuilder"]
1012

@@ -13,29 +15,38 @@ class ModelComponent(Component):
1315
def __init__(self, model):
1416
self.model = model
1517

16-
def get_model_type(self) -> type[Model]:
18+
def get_model_type(self) -> type[BaseModel]:
1719
return self.model.__class__ # type: ignore[no-any-return]
1820

1921

2022
class ModelComponentBuilder:
23+
_module_name: t.Final[str] = "flama.models.models.{}"
24+
_class_name: t.Final[str] = "Model"
25+
_modules: t.Final[dict[types.MLLib, str]] = {
26+
"keras": "tensorflow",
27+
"sklearn": "sklearn",
28+
"tensorflow": "tensorflow",
29+
"torch": "pytorch",
30+
}
31+
2132
@classmethod
22-
def _get_model_class(cls, framework: Framework) -> type[Model]:
33+
def _get_model_class(cls, lib: types.MLLib) -> type[BaseModel]:
2334
try:
24-
module, class_name = {
25-
Framework.torch: ("pytorch", "PyTorchModel"),
26-
Framework.sklearn: ("sklearn", "SKLearnModel"),
27-
Framework.tensorflow: ("tensorflow", "TensorFlowModel"),
28-
Framework.keras: ("tensorflow", "TensorFlowModel"),
29-
}[framework]
35+
return getattr(importlib.import_module(cls._module_name.format(cls._modules[lib])), cls._class_name)
3036
except KeyError: # pragma: no cover
31-
raise ValueError("Wrong framework")
32-
33-
model_class: type[Model] = getattr(importlib.import_module(f"flama.models.models.{module}"), class_name)
34-
return model_class
37+
raise ValueError(f"Wrong lib '{lib}'")
38+
except ModuleNotFoundError: # pragma: no cover
39+
raise ValueError(f"Module not found '{cls._module_name.format(cls._modules[lib])}'")
40+
except AttributeError: # pragma: no cover
41+
raise ValueError(
42+
f"Class '{cls._class_name}' not found in module '{cls._module_name.format(cls._modules[lib])}'"
43+
)
3544

3645
@classmethod
37-
def load(cls, path: str | os.PathLike) -> ModelComponent:
38-
load_model = load(path)
46+
def load(cls, path: str | os.PathLike | pathlib.Path) -> ModelComponent:
47+
with pathlib.Path(str(path)).open("rb") as f:
48+
load_model = Serializer.load(f)
49+
3950
parent = cls._get_model_class(load_model.meta.framework.lib)
4051
model_class = type(parent.__name__, (parent,), {})
4152
model_obj = model_class(load_model.model, load_model.meta, load_model.artifacts)

flama/models/models/pytorch.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
import typing as t
22

33
from flama import exceptions
4-
from flama.models.base import Model
4+
from flama.models.base import BaseModel
55

66
try:
77
import torch # type: ignore
88
except Exception: # pragma: no cover
99
torch = None # type: ignore
1010

11+
__all__ = ["Model"]
1112

12-
class PyTorchModel(Model):
13+
14+
class Model(BaseModel):
1315
def predict(self, x: list[list[t.Any]]) -> t.Any:
1416
if torch is None: # noqa
1517
raise exceptions.FrameworkNotInstalled("pytorch")

flama/models/models/sklearn.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
11
import typing as t
22

33
from flama import exceptions
4-
from flama.models.base import Model
4+
from flama.models.base import BaseModel
55

66
try:
77
import sklearn # type: ignore
88
except Exception: # pragma: no cover
99
sklearn = None
1010

1111

12-
class SKLearnModel(Model):
12+
__all__ = ["Model"]
13+
14+
15+
class Model(BaseModel):
1316
def predict(self, x: list[list[t.Any]]) -> t.Any:
1417
if sklearn is None: # noqa
1518
raise exceptions.FrameworkNotInstalled("scikit-learn")

flama/models/models/tensorflow.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import typing as t
22

33
from flama import exceptions
4-
from flama.models.base import Model
4+
from flama.models.base import BaseModel
55

66
try:
77
import numpy as np # type: ignore
@@ -14,7 +14,10 @@
1414
tf = None
1515

1616

17-
class TensorFlowModel(Model):
17+
__all__ = ["Model"]
18+
19+
20+
class Model(BaseModel):
1821
def predict(self, x: list[list[t.Any]]) -> t.Any:
1922
if np is None: # noqa
2023
raise exceptions.FrameworkNotInstalled("numpy")

flama/models/resource.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from flama.resources.routing import ResourceRoute
1111

1212
if t.TYPE_CHECKING:
13-
from flama.models.base import Model
13+
from flama.models.base import BaseModel
1414
from flama.models.components import ModelComponent
1515

1616
__all__ = ["BaseModelResource", "ModelResource", "InspectMixin", "PredictMixin", "ModelResourceType"]
@@ -21,7 +21,9 @@
2121

2222
class InspectMixin:
2323
@classmethod
24-
def _add_inspect(cls, name: str, verbose_name: str, model_model_type: type["Model"], **kwargs) -> dict[str, t.Any]:
24+
def _add_inspect(
25+
cls, name: str, verbose_name: str, model_model_type: type["BaseModel"], **kwargs
26+
) -> dict[str, t.Any]:
2527
@ResourceRoute.method("/", methods=["GET"], name="inspect")
2628
async def inspect(self, model: model_model_type): # type: ignore[valid-type]
2729
return model.inspect() # type: ignore[attr-defined]
@@ -44,7 +46,9 @@ async def inspect(self, model: model_model_type): # type: ignore[valid-type]
4446

4547
class PredictMixin:
4648
@classmethod
47-
def _add_predict(cls, name: str, verbose_name: str, model_model_type: type["Model"], **kwargs) -> dict[str, t.Any]:
49+
def _add_predict(
50+
cls, name: str, verbose_name: str, model_model_type: type["BaseModel"], **kwargs
51+
) -> dict[str, t.Any]:
4852
@ResourceRoute.method("/predict/", methods=["POST"], name="predict")
4953
async def predict(
5054
self,

flama/serialize/__init__.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from flama.serialize.data_structures import * # noqa
2-
from flama.serialize.dump import * # noqa
3-
from flama.serialize.load import * # noqa
1+
from flama.serialize.serializer import dump, load
42

5-
__all__ = ["dump", "load"] # noqa
3+
__all__ = ["dump", "load"]

flama/serialize/base.py

Lines changed: 0 additions & 25 deletions
This file was deleted.

0 commit comments

Comments
 (0)