Skip to content

Commit acea843

Browse files
author
Rima Hajou
committed
fix cols_to_impute handling
1 parent f6a6646 commit acea843

File tree

1 file changed

+24
-15
lines changed

1 file changed

+24
-15
lines changed

qolmat/imputations/models.py

Lines changed: 24 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -61,11 +61,11 @@ def fit_transform(self, df: pd.DataFrame) -> pd.DataFrame:
6161
for col in cols_with_nans:
6262
if self.groups:
6363
groupby = utils.custom_groupby(df, self.groups)
64-
imputation_values = groupby[col].transform(self.apply_imputation)
64+
self.imputation_values = groupby[col].transform(self.apply_imputation)
6565
else:
66-
imputation_values = self.apply_imputation(df[col])
66+
self.imputation_values = self.apply_imputation(df[col])
6767

68-
df_imputed[col] = df_imputed[col].fillna(imputation_values)
68+
df_imputed[col] = df_imputed[col].fillna(self.imputation_values)
6969

7070
# fill na by applying imputation method without groups
7171
if df_imputed[col].isna().any():
@@ -650,8 +650,10 @@ class ImputeRegressor(_BaseImputer):
650650
>>> imputor.fit_transform(X)
651651
"""
652652

653-
def __init__(self, model, **kwargs) -> None:
653+
def __init__(self, model, cols_to_impute=None, fit_on_nan = True, **kwargs) -> None:
654654
self.model = model
655+
self.cols_to_impute = cols_to_impute
656+
self.fit_on_nan = fit_on_nan
655657

656658
for name, value in kwargs.items():
657659
setattr(self, name, value)
@@ -673,21 +675,28 @@ def fit_transform(self, df: pd.DataFrame) -> pd.Series:
673675
if not isinstance(df, pd.DataFrame):
674676
raise ValueError("Input has to be a pandas.DataFrame.")
675677

676-
df_imputed = df.copy()
677-
678678
cols_with_nans = df.columns[df.isna().any()]
679-
cols_without_nans = df.columns[df.notna().all()]
680-
681-
if len(cols_without_nans) == 0:
682-
raise Exception("There must be at least one column without missing values.")
679+
self.cols_without_nans = df.columns[df.notna().all()]
683680

684-
for col in cols_with_nans:
685-
X = df[cols_without_nans]
681+
if self.cols_to_impute is None:
682+
self.cols_to_impute = cols_with_nans
683+
elif not set(self.cols_to_impute).issubset(set(df.columns) ):
684+
raise ValueError("Input has to have at least one column of cols_to_impute")
685+
else:
686+
self.cols_to_impute =list(set(self.cols_to_impute) & set(cols_with_nans))
687+
688+
df_imputed = df.copy()
689+
690+
self.models = {col: self.model() for col in self.cols_to_impute}
691+
for col in self.cols_to_impute:
692+
if self.fit_on_nan:
693+
X = df.drop(columns=col)
694+
else:
695+
X = df[self.cols_without_nans].drop(columns=col)
686696
y = df[col]
687697
is_na = y.isna()
688-
self.model.fit(X[~is_na], y[~is_na])
689-
df_imputed.loc[is_na, col] = self.model.predict(X[is_na])
690-
698+
self.models[col].fit(X[~is_na], y[~is_na])
699+
df_imputed.loc[is_na, col] = self.models[col].predict(X[is_na])
691700
return df_imputed
692701

693702
def get_hyperparams(self):

0 commit comments

Comments
 (0)