@@ -61,11 +61,11 @@ def fit_transform(self, df: pd.DataFrame) -> pd.DataFrame:
6161 for col in cols_with_nans :
6262 if self .groups :
6363 groupby = utils .custom_groupby (df , self .groups )
64- imputation_values = groupby [col ].transform (self .apply_imputation )
64+ self . imputation_values = groupby [col ].transform (self .apply_imputation )
6565 else :
66- imputation_values = self .apply_imputation (df [col ])
66+ self . imputation_values = self .apply_imputation (df [col ])
6767
68- df_imputed [col ] = df_imputed [col ].fillna (imputation_values )
68+ df_imputed [col ] = df_imputed [col ].fillna (self . imputation_values )
6969
7070 # fill na by applying imputation method without groups
7171 if df_imputed [col ].isna ().any ():
@@ -650,8 +650,10 @@ class ImputeRegressor(_BaseImputer):
650650 >>> imputor.fit_transform(X)
651651 """
652652
653- def __init__ (self , model , ** kwargs ) -> None :
653+ def __init__ (self , model , cols_to_impute = None , fit_on_nan = True , ** kwargs ) -> None :
654654 self .model = model
655+ self .cols_to_impute = cols_to_impute
656+ self .fit_on_nan = fit_on_nan
655657
656658 for name , value in kwargs .items ():
657659 setattr (self , name , value )
@@ -673,21 +675,28 @@ def fit_transform(self, df: pd.DataFrame) -> pd.Series:
673675 if not isinstance (df , pd .DataFrame ):
674676 raise ValueError ("Input has to be a pandas.DataFrame." )
675677
676- df_imputed = df .copy ()
677-
678678 cols_with_nans = df .columns [df .isna ().any ()]
679- cols_without_nans = df .columns [df .notna ().all ()]
680-
681- if len (cols_without_nans ) == 0 :
682- raise Exception ("There must be at least one column without missing values." )
679+ self .cols_without_nans = df .columns [df .notna ().all ()]
683680
684- for col in cols_with_nans :
685- X = df [cols_without_nans ]
681+ if self .cols_to_impute is None :
682+ self .cols_to_impute = cols_with_nans
683+ elif not set (self .cols_to_impute ).issubset (set (df .columns ) ):
684+ raise ValueError ("Input has to have at least one column of cols_to_impute" )
685+ else :
686+ self .cols_to_impute = list (set (self .cols_to_impute ) & set (cols_with_nans ))
687+
688+ df_imputed = df .copy ()
689+
690+ self .models = {col : self .model () for col in self .cols_to_impute }
691+ for col in self .cols_to_impute :
692+ if self .fit_on_nan :
693+ X = df .drop (columns = col )
694+ else :
695+ X = df [self .cols_without_nans ].drop (columns = col )
686696 y = df [col ]
687697 is_na = y .isna ()
688- self .model .fit (X [~ is_na ], y [~ is_na ])
689- df_imputed .loc [is_na , col ] = self .model .predict (X [is_na ])
690-
698+ self .models [col ].fit (X [~ is_na ], y [~ is_na ])
699+ df_imputed .loc [is_na , col ] = self .models [col ].predict (X [is_na ])
691700 return df_imputed
692701
693702 def get_hyperparams (self ):
0 commit comments