Skip to content

Commit 55bc11d

Browse files
authored
More tutorials, bugfixing (#30)
* cleanup * bump version * cleanup * cleanup * cleanup * improvements * debug * bugfixing * debug * debug
1 parent 74ce6e7 commit 55bc11d

25 files changed

+298
-28
lines changed

src/hyperimpute/plugins/core/base_plugin.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ class Plugin(Serializable, metaclass=ABCMeta):
3838
def __init__(self) -> None:
3939
super().__init__()
4040

41+
self.drop_consts = []
42+
4143
@staticmethod
4244
@abstractmethod
4345
def hyperparameter_space(*args: Any, **kwargs: Any) -> List[Params]:
@@ -119,6 +121,13 @@ def fit_predict(self, X: pd.DataFrame, *args: Any, **kwargs: Any) -> pd.DataFram
119121

120122
def fit(self, X: pd.DataFrame, *args: Any, **kwargs: Any) -> Any:
121123
X = cast.to_dataframe(X)
124+
125+
for col in X.columns:
126+
if len(X.loc[X[col].notna(), col].unique()) <= 1:
127+
self.drop_consts.append(col)
128+
129+
X = X.drop(columns=self.drop_consts)
130+
self.columns = X.columns
122131
return self._fit(X, *args, **kwargs)
123132

124133
@abstractmethod
@@ -127,6 +136,7 @@ def _fit(self, X: pd.DataFrame, *args: Any, **kwargs: Any) -> "Plugin":
127136

128137
def transform(self, X: pd.DataFrame) -> pd.DataFrame:
129138
X = cast.to_dataframe(X)
139+
X = X.drop(columns=self.drop_consts)
130140
return pd.DataFrame(self._transform(X))
131141

132142
@abstractmethod
@@ -135,6 +145,7 @@ def _transform(self, X: pd.DataFrame) -> pd.DataFrame:
135145

136146
def predict(self, X: pd.DataFrame, *args: Any, **kwargs: Any) -> pd.DataFrame:
137147
X = cast.to_dataframe(X)
148+
X = X.drop(columns=self.drop_consts)
138149
return pd.DataFrame(self._predict(X, *args, *kwargs))
139150

140151
@abstractmethod

src/hyperimpute/plugins/imputers/base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
# hyperimpute absolute
99
import hyperimpute.plugins.core.base_plugin as plugin
10+
from hyperimpute.utils.distributions import enable_reproducible_results
1011

1112

1213
class ImputerPlugin(_BaseImputer, plugin.Plugin):
@@ -24,8 +25,10 @@ class ImputerPlugin(_BaseImputer, plugin.Plugin):
2425
"""
2526

2627
def __init__(self, random_state: int = 0) -> None:
27-
super().__init__()
28+
_BaseImputer.__init__(self)
29+
plugin.Plugin.__init__(self)
2830

31+
enable_reproducible_results(random_state)
2932
self.random_state = random_state
3033

3134
@staticmethod

src/hyperimpute/plugins/prediction/classifiers/base.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import hyperimpute.plugins.core.base_plugin as plugin
1010
import hyperimpute.plugins.prediction.base as prediction_base
1111
import hyperimpute.plugins.utils.cast as cast
12+
from hyperimpute.utils.distributions import enable_reproducible_results
1213
from hyperimpute.utils.tester import Eval
1314

1415

@@ -26,17 +27,23 @@ class ClassifierPlugin(
2627
If any method implementation is missing, the class constructor will fail.
2728
"""
2829

29-
def __init__(self, **kwargs: Any) -> None:
30+
def __init__(self, random_state: int = 0, **kwargs: Any) -> None:
3031
self.args = kwargs
32+
self.random_state = random_state
3133

32-
super().__init__()
34+
enable_reproducible_results(self.random_state)
35+
36+
ClassifierMixin.__init__(self)
37+
BaseEstimator.__init__(self)
38+
prediction_base.PredictionPlugin.__init__(self)
3339

3440
@staticmethod
3541
def subtype() -> str:
3642
return "classifier"
3743

3844
def fit(self, X: pd.DataFrame, *args: Any, **kwargs: Any) -> plugin.Plugin:
3945
X = cast.to_dataframe(X)
46+
enable_reproducible_results(self.random_state)
4047

4148
if len(args) == 0:
4249
raise RuntimeError("Please provide the training labels as well")

src/hyperimpute/plugins/prediction/classifiers/plugin_catboost.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def __init__(
5454
random_strength: float = 1,
5555
**kwargs: Any
5656
) -> None:
57-
super().__init__(**kwargs)
57+
super().__init__(random_state=random_state, **kwargs)
5858
if model is not None:
5959
self.model = model
6060
return

src/hyperimpute/plugins/prediction/classifiers/plugin_gpboost.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def __init__(
6565
hyperparam_search_iterations: Optional[int] = None,
6666
**kwargs: Any
6767
) -> None:
68-
super().__init__(**kwargs)
68+
super().__init__(random_state=random_state, **kwargs)
6969
if hyperparam_search_iterations:
7070
n_estimators = int(hyperparam_search_iterations)
7171

src/hyperimpute/plugins/prediction/classifiers/plugin_kneighbors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def __init__(
3636
model: Any = None,
3737
**kwargs: Any
3838
) -> None:
39-
super().__init__(**kwargs)
39+
super().__init__(random_state=random_state, **kwargs)
4040
if model is not None:
4141
self.model = model
4242
return

src/hyperimpute/plugins/prediction/classifiers/plugin_lgbm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def __init__(
6565
random_state: int = 0,
6666
**kwargs: Any
6767
) -> None:
68-
super().__init__(**kwargs)
68+
super().__init__(random_state=random_state, **kwargs)
6969
if model is not None:
7070
self.model = model
7171
return

src/hyperimpute/plugins/prediction/classifiers/plugin_logistic_regression.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def __init__(
5353
hyperparam_search_iterations: Optional[int] = None,
5454
**kwargs: Any
5555
) -> None:
56-
super().__init__(**kwargs)
56+
super().__init__(random_state=random_state, **kwargs)
5757
if model is not None:
5858
self.model = model
5959
return

src/hyperimpute/plugins/prediction/classifiers/plugin_neural_nets.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ def __init__(
251251
hyperparam_search_iterations: Optional[int] = None,
252252
**kwargs: Any,
253253
) -> None:
254-
super().__init__(**kwargs)
254+
super().__init__(random_state=random_state, **kwargs)
255255

256256
enable_reproducible_results(random_state)
257257

src/hyperimpute/plugins/prediction/classifiers/plugin_random_forest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def __init__(
5656
hyperparam_search_iterations: Optional[int] = None,
5757
**kwargs: Any
5858
) -> None:
59-
super().__init__(**kwargs)
59+
super().__init__(random_state=random_state, **kwargs)
6060
if hyperparam_search_iterations:
6161
n_estimators = int(hyperparam_search_iterations)
6262

0 commit comments

Comments
 (0)