Skip to content

Commit 825635d

Browse files
authored
Serialization bugfixing (#35)
* serde error
* more tests
* improvements
* drop lgbm
* cleanup
1 parent ead5c4b commit 825635d

File tree

12 files changed

+75
-466
lines changed

12 files changed

+75
-466
lines changed

src/hyperimpute/plugins/core/base_plugin.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,6 @@ class Plugin(Serializable, metaclass=ABCMeta):
3838
def __init__(self) -> None:
3939
super().__init__()
4040

41-
self.drop_consts = []
42-
4341
@staticmethod
4442
@abstractmethod
4543
def hyperparameter_space(*args: Any, **kwargs: Any) -> List[Params]:
@@ -122,11 +120,6 @@ def fit_predict(self, X: pd.DataFrame, *args: Any, **kwargs: Any) -> pd.DataFram
122120
def fit(self, X: pd.DataFrame, *args: Any, **kwargs: Any) -> Any:
123121
X = cast.to_dataframe(X)
124122

125-
for col in X.columns:
126-
if len(X.loc[X[col].notna(), col].unique()) <= 1:
127-
self.drop_consts.append(col)
128-
129-
X = X.drop(columns=self.drop_consts)
130123
self.columns = X.columns
131124
return self._fit(X, *args, **kwargs)
132125

@@ -136,7 +129,6 @@ def _fit(self, X: pd.DataFrame, *args: Any, **kwargs: Any) -> "Plugin":
136129

137130
def transform(self, X: pd.DataFrame) -> pd.DataFrame:
138131
X = cast.to_dataframe(X)
139-
X = X.drop(columns=self.drop_consts)
140132
return pd.DataFrame(self._transform(X))
141133

142134
@abstractmethod
@@ -145,7 +137,6 @@ def _transform(self, X: pd.DataFrame) -> pd.DataFrame:
145137

146138
def predict(self, X: pd.DataFrame, *args: Any, **kwargs: Any) -> pd.DataFrame:
    """Run the plugin's ``_predict`` on ``X`` and wrap the result in a DataFrame.

    Args:
        X: input data; it is passed through ``cast.to_dataframe`` first.
        *args: extra positional arguments forwarded to ``_predict``.
        **kwargs: extra keyword arguments forwarded to ``_predict``.

    Returns:
        The output of ``_predict`` wrapped in a ``pd.DataFrame``.
    """
    X = cast.to_dataframe(X)
    # Keyword arguments must be expanded with ``**``. A single ``*`` would
    # pass the dict's keys as extra positional arguments and drop the values.
    return pd.DataFrame(self._predict(X, *args, **kwargs))
150141

151142
@abstractmethod

src/hyperimpute/plugins/prediction/classifiers/plugin_lgbm.py

Lines changed: 0 additions & 127 deletions
This file was deleted.

src/hyperimpute/plugins/prediction/classifiers/plugin_random_forest.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class RandomForestPlugin(base.ClassifierPlugin):
4141
"""
4242

4343
criterions = ["gini", "entropy"]
44-
features = ["auto", "sqrt", "log2"]
44+
features = ["sqrt", "log2", None]
4545

4646
def __init__(
4747
self,
@@ -97,11 +97,13 @@ def _fit(self, X: pd.DataFrame, *args: Any, **kwargs: Any) -> "RandomForestPlugi
9797
return self
9898

9999
def _predict(self, X: pd.DataFrame, *args: Any, **kwargs: Any) -> pd.DataFrame:
100+
X = np.asarray(X)
100101
return self.model.predict(X, *args, **kwargs)
101102

102103
def _predict_proba(
103104
self, X: pd.DataFrame, *args: Any, **kwargs: Any
104105
) -> pd.DataFrame:
106+
X = np.asarray(X)
105107
return self.model.predict_proba(X, *args, **kwargs)
106108

107109

src/hyperimpute/plugins/prediction/regression/plugin_lgbm_regressor.py

Lines changed: 0 additions & 112 deletions
This file was deleted.

src/hyperimpute/plugins/prediction/regression/plugin_random_forest_regressor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class RandomForestRegressionPlugin(base.RegressionPlugin):
4141
"""
4242

4343
criterions = ["squared_error", "absolute_error", "poisson"]
44-
features = ["auto", "sqrt", "log2"]
44+
features = ["sqrt", "log2", None]
4545

4646
def __init__(
4747
self,
@@ -103,6 +103,8 @@ def _fit(
103103
return self
104104

105105
def _predict(self, X: pd.DataFrame, *args: Any, **kwargs: Any) -> pd.DataFrame:
106+
X = np.asarray(X)
107+
106108
return self.model.predict(X, *args, **kwargs)
107109

108110

src/hyperimpute/utils/serialization.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,19 +138,42 @@ def version() -> str:
138138
return MAJOR_VERSION
139139

140140

141+
def _add_version(obj: Any) -> Any:
    """Stamp ``obj`` with the current serialization major version.

    The value of ``MAJOR_VERSION`` is stored on the object itself under the
    ``_serde_version`` attribute, and the same object is handed back.
    """
    setattr(obj, "_serde_version", MAJOR_VERSION)
    return obj
144+
145+
146+
def _check_version(obj: Any) -> Any:
147+
local_version = obj._serde_version
148+
149+
if not hasattr(obj, "_serde_version"):
150+
raise RuntimeError("Missing serialization version")
151+
152+
if local_version != MAJOR_VERSION:
153+
raise ValueError(
154+
f"Serialized object mismatch. Current major version is {MAJOR_VERSION}, but the serialized object has version {local_version}."
155+
)
156+
157+
141158
def save(model: Any) -> bytes:
159+
_add_version(model)
142160
return cloudpickle.dumps(model)
143161

144162

145163
def load(buff: bytes) -> Any:
146-
return cloudpickle.loads(buff)
164+
obj = cloudpickle.loads(buff)
165+
_check_version(obj)
166+
return obj
147167

148168

149169
def save_to_file(path: Union[str, Path], model: Any) -> Any:
170+
_add_version(model)
150171
with open(path, "wb") as f:
151172
return cloudpickle.dump(model, f)
152173

153174

154175
def load_from_file(path: Union[str, Path]) -> Any:
155176
with open(path, "rb") as f:
156-
return cloudpickle.load(f)
177+
obj = cloudpickle.load(f)
178+
_check_version(obj)
179+
return obj

src/hyperimpute/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "0.1.14"
1+
__version__ = "0.1.15"
22

33
MAJOR_VERSION = ".".join(__version__.split(".")[:-1])
44
MINOR_VERSION = __version__.split(".")[-1]

0 commit comments

Comments
 (0)