11import importlib
22import os
3+ import pathlib
4+ import typing as t
35
6+ from flama import types
47from flama .injection import Component
5- from flama .models .base import Model
6- from flama .serialize import load
7- from flama .serialize .types import Framework
8+ from flama .models .base import BaseModel
9+ from flama .serialize .serializer import Serializer
810
911__all__ = ["ModelComponent" , "ModelComponentBuilder" ]
1012
@@ -13,29 +15,38 @@ class ModelComponent(Component):
1315 def __init__ (self , model ):
1416 self .model = model
1517
16- def get_model_type (self ) -> type [Model ]:
18+ def get_model_type (self ) -> type [BaseModel ]:
1719 return self .model .__class__ # type: ignore[no-any-return]
1820
1921
2022class ModelComponentBuilder :
23+ _module_name : t .Final [str ] = "flama.models.models.{}"
24+ _class_name : t .Final [str ] = "Model"
25+ _modules : t .Final [dict [types .MLLib , str ]] = {
26+ "keras" : "tensorflow" ,
27+ "sklearn" : "sklearn" ,
28+ "tensorflow" : "tensorflow" ,
29+ "torch" : "pytorch" ,
30+ }
31+
2132 @classmethod
22- def _get_model_class (cls , framework : Framework ) -> type [Model ]:
33+ def _get_model_class (cls , lib : types . MLLib ) -> type [BaseModel ]:
2334 try :
24- module , class_name = {
25- Framework .torch : ("pytorch" , "PyTorchModel" ),
26- Framework .sklearn : ("sklearn" , "SKLearnModel" ),
27- Framework .tensorflow : ("tensorflow" , "TensorFlowModel" ),
28- Framework .keras : ("tensorflow" , "TensorFlowModel" ),
29- }[framework ]
35+ return getattr (importlib .import_module (cls ._module_name .format (cls ._modules [lib ])), cls ._class_name )
3036 except KeyError : # pragma: no cover
31- raise ValueError ("Wrong framework" )
32-
33- model_class : type [Model ] = getattr (importlib .import_module (f"flama.models.models.{ module } " ), class_name )
34- return model_class
37+ raise ValueError (f"Wrong lib '{ lib } '" )
38+ except ModuleNotFoundError : # pragma: no cover
39+ raise ValueError (f"Module not found '{ cls ._module_name .format (cls ._modules [lib ])} '" )
40+ except AttributeError : # pragma: no cover
41+ raise ValueError (
42+ f"Class '{ cls ._class_name } ' not found in module '{ cls ._module_name .format (cls ._modules [lib ])} '"
43+ )
3544
3645 @classmethod
37- def load (cls , path : str | os .PathLike ) -> ModelComponent :
38- load_model = load (path )
46+ def load (cls , path : str | os .PathLike | pathlib .Path ) -> ModelComponent :
47+ with pathlib .Path (str (path )).open ("rb" ) as f :
48+ load_model = Serializer .load (f )
49+
3950 parent = cls ._get_model_class (load_model .meta .framework .lib )
4051 model_class = type (parent .__name__ , (parent ,), {})
4152 model_obj = model_class (load_model .model , load_model .meta , load_model .artifacts )
0 commit comments