Skip to content

Commit 8d37046

Browse files
author
vm-aifluence-jro
committed
cross_validation fit_transform removed
1 parent 437f0f3 commit 8d37046

File tree

3 files changed

+7
-28
lines changed

3 files changed

+7
-28
lines changed

qolmat/benchmark/comparator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,9 +121,10 @@ def evaluate_errors_sample(
121121
hole_generator=self.generator_holes,
122122
n_calls=self.n_calls_opt,
123123
)
124-
df_imputed = cv.fit_transform(df_corrupted)
124+
imputer.hyperparams_optim = cv.optimize_hyperparams(df_corrupted)
125125
else:
126-
df_imputed = imputer.fit_transform(df_corrupted)
126+
imputer.hyperparams_optim = {}
127+
df_imputed = imputer.fit_transform(df_corrupted)
127128
subset = self.generator_holes.subset
128129
errors = self.get_errors(df_origin[subset], df_imputed[subset], df_mask[subset])
129130
list_errors.append(errors)

qolmat/benchmark/cross_validation.py

Lines changed: 4 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def obj_func(**hyperparams_flat):
207207

208208
return obj_func
209209

210-
def optimize_hyperparams(self, df: pd.DataFrame) -> Dict[str, Union[float, int, str]]:
210+
def optimize_hyperparams(self, df: pd.DataFrame) -> Dict[str, Any]:
211211
"""Optimize hyperparamaters
212212
213213
Parameters
@@ -217,7 +217,7 @@ def optimize_hyperparams(self, df: pd.DataFrame) -> Dict[str, Union[float, int,
217217
218218
Returns
219219
-------
220-
Dict[str, Union[float,int, str]]
220+
Dict[str, Any]
221221
hyperparameters optimize flat
222222
"""
223223
list_spaces = get_search_space(self.dict_config_opti_imputer)
@@ -231,25 +231,5 @@ def optimize_hyperparams(self, df: pd.DataFrame) -> Dict[str, Union[float, int,
231231
)
232232

233233
hyperparams_flat = {space.name: val for space, val in zip(list_spaces, res["x"])}
234-
return hyperparams_flat
235-
236-
def fit_transform(self, df: pd.DataFrame) -> pd.DataFrame:
237-
"""
238-
Fit and transform estimator and impute the missing values.
239-
240-
Parameters
241-
----------
242-
df : pd.DataFrame
243-
dataframe to impute
244-
245-
Returns
246-
-------
247-
pd.DataFrame
248-
imputed dataframe
249-
"""
250-
251-
hyperparams_flat = self.optimize_hyperparams(df)
252-
self.imputer.hyperparams_optim = deflat_hyperparams(hyperparams_flat)
253-
df_imputed = self.imputer.fit_transform(df)
254-
255-
return df_imputed
234+
hyperparams = deflat_hyperparams(hyperparams_flat)
235+
return hyperparams

qolmat/imputations/imputers.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,6 @@ def __init__(
4848
random_state: Union[None, int, np.random.RandomState] = None,
4949
):
5050
self.hyperparams_user = hyperparams
51-
self.hyperparams_optim: Dict = {}
52-
self.hyperparams_local: Dict = {}
5351
self.groups = groups
5452
self.columnwise = columnwise
5553
self.shrink = shrink

0 commit comments

Comments
 (0)