1+ import abc
2+ import copy
13import sys
24from typing import Any , Dict , List , Optional , Union
35
46import sklearn .neighbors ._base
7+ from sklearn .base import BaseEstimator
58
69sys .modules ["sklearn.neighbors.base" ] = sklearn .neighbors ._base
710
2023
2124
2225class Imputer (_BaseImputer ):
23- def __init__ (self , groups : List [str ] = [], columnwise : bool = False , hyperparams : Dict = {}):
26+ def __init__ (
27+ self ,
28+ groups : List [str ] = [],
29+ columnwise : bool = False ,
30+ shrink : bool = False ,
31+ hyperparams : Dict = {},
32+ ):
2433 self .hyperparams_user = hyperparams
2534 self .hyperparams_optim = {}
2635 self .hyperparams_local = {}
2736 self .groups = groups
2837 self .columnwise = columnwise
38+ self .shrink = shrink
2939
3040 def fit_transform (self , df : pd .DataFrame ) -> pd .DataFrame :
3141 """
@@ -47,6 +57,12 @@ def fit_transform(self, df: pd.DataFrame) -> pd.DataFrame:
4757 hyperparams = self .hyperparams_user .copy ()
4858 hyperparams .update (self .hyperparams_optim )
4959 cols_with_nans = df .columns [df .isna ().any ()]
60+
61+ if self .groups == []:
62+ self .ngroups = pd .Series (0 , index = df .index ).rename ("_ngroup" )
63+ else :
64+ self .ngroups = df .groupby (self .groups ).ngroup ().rename ("_ngroup" )
65+
5066 if self .columnwise :
5167
5268 # imputed = pd.DataFrame(index=df.index, columns=df.columns)
@@ -79,16 +95,20 @@ def impute_element(self, df: pd.DataFrame) -> pd.DataFrame:
7995 df = df .copy ()
8096 if self .groups :
8197
82- groupby = utils .custom_groupby (df , self .groups )
83- imputation_values = groupby .apply (self .fit_transform_element )
98+ # groupby = utils.custom_groupby(df, self.groups)
99+ groupby = df .groupby (self .ngroups , group_keys = False )
100+ if self .shrink :
101+ imputation_values = groupby .transform (self .fit_transform_element )
102+ else :
103+ imputation_values = groupby .apply (self .fit_transform_element )
84104 else :
85105 imputation_values = self .fit_transform_element (df )
86106
87107 df = df .fillna (imputation_values )
88- # # fill na by applying imputation method without groups
89- # if df.isna().any().any():
90- # imputation_values = self.fit_transform_fallback(df)
91- # df = df.fillna(imputation_values)
108+ # fill na by applying imputation method without groups
109+ if df .isna ().any ().any ():
110+ imputation_values = self .fit_transform_fallback (df )
111+ df = df .fillna (imputation_values )
92112
93113 return df
94114
@@ -114,7 +134,7 @@ def __init__(
114134 self ,
115135 groups : List [str ] = [],
116136 ) -> None :
117- super ().__init__ (groups = groups , columnwise = True )
137+ super ().__init__ (groups = groups , columnwise = True , shrink = True )
118138 self .fit_transform_element = pd .DataFrame .mean
119139
120140
@@ -139,7 +159,7 @@ def __init__(
139159 self ,
140160 groups : List [str ] = [],
141161 ) -> None :
142- super ().__init__ (groups = groups , columnwise = True )
162+ super ().__init__ (groups = groups , columnwise = True , shrink = True )
143163 self .fit_transform_element = pd .DataFrame .median
144164
145165
@@ -164,7 +184,7 @@ def __init__(
164184 self ,
165185 groups : List [str ] = [],
166186 ) -> None :
167- super ().__init__ (groups = groups , columnwise = True )
187+ super ().__init__ (groups = groups , columnwise = True , shrink = True )
168188 self .fit_transform_element = lambda df : df .mode ().iloc [0 ]
169189
170190
@@ -509,9 +529,11 @@ class ImputerMICE(Imputer):
509529 def __init__ (
510530 self ,
511531 groups : List [str ] = [],
532+ estimator : Optional [BaseEstimator ] = None ,
512533 ** hyperparams ,
513534 ) -> None :
514535 super ().__init__ (groups = groups , columnwise = False , hyperparams = hyperparams )
536+ self .estimator = estimator
515537
516538 def fit_transform_element (self , df : pd .DataFrame ) -> pd .DataFrame :
517539 """
@@ -530,7 +552,7 @@ def fit_transform_element(self, df: pd.DataFrame) -> pd.DataFrame:
530552 if not isinstance (df , pd .DataFrame ):
531553 raise ValueError ("Input has to be a pandas.DataFrame." )
532554
533- iterative_imputer = IterativeImputer (** self .hyperparams_element )
555+ iterative_imputer = IterativeImputer (estimator = self . estimator , ** self .hyperparams_element )
534556 res = iterative_imputer .fit_transform (df .values )
535557 imputed = pd .DataFrame (columns = df .columns )
536558 for ind , col in enumerate (imputed .columns ):
@@ -564,11 +586,15 @@ class ImputerRegressor(Imputer):
564586 """
565587
566588 def __init__ (
567- self , type_model : Any , groups : List [str ] = [], fit_on_nan : bool = False , ** hyperparams
589+ self ,
590+ groups : List [str ] = [],
591+ estimator : Optional [BaseEstimator ] = None ,
592+ fit_on_nan : bool = False ,
593+ ** hyperparams ,
568594 ):
569595 super ().__init__ (groups = groups , hyperparams = hyperparams )
570596 self .columnwise = False
571- self .type_model = type_model
597+ self .estimator = estimator
572598 self .fit_on_nan = fit_on_nan
573599
574600 def fit_transform_element (self , df : pd .DataFrame ) -> pd .DataFrame :
@@ -598,7 +624,9 @@ def fit_transform_element(self, df: pd.DataFrame) -> pd.DataFrame:
598624 value = value [col ]
599625 hyperparams [hyperparam ] = value
600626
601- model = self .type_model (** hyperparams )
627+ # model = copy.deepcopy(self.estimator)
628+ # for hyperparam, value in hyperparams.items():
629+ # setattr(model, hyperparam, value)
602630
603631 if self .fit_on_nan :
604632 X = df .drop (columns = col )
@@ -609,8 +637,8 @@ def fit_transform_element(self, df: pd.DataFrame) -> pd.DataFrame:
609637 if X .empty :
610638 y_imputed = pd .Series (y .mean (), index = y .index )
611639 else :
612- model .fit (X [~ is_na ], y [~ is_na ])
613- y_imputed = model .predict (X [is_na ])
640+ self . estimator .fit (X [~ is_na ], y [~ is_na ])
641+ y_imputed = self . estimator .predict (X [is_na ])
614642 df_imputed .loc [is_na , col ] = y_imputed
615643
616644 return df_imputed
@@ -632,17 +660,19 @@ class ImputerStochasticRegressor(Imputer):
632660 >>> import pandas as pd
633661 >>> from qolmat.imputations.models import ImputeStochasticRegressor
634662 >>> from sklearn.ensemble import ExtraTreesRegressor
635- >>> imputor = ImputeStochasticRegressor(model =ExtraTreesRegressor() )
663+ >>> imputer = ImputeStochasticRegressor(estimator =ExtraTreesRegressor)
636664 >>> df = pd.DataFrame(data=[[1, 1, 1, 1],
637665 >>> [np.nan, np.nan, 2, 3],
638666 >>> [1, 2, 2, 5], [2, 2, 2, 2]],
639667 >>> columns=["var1", "var2", "var3", "var4"])
640- >>> imputor .fit_transform(df)
668+ >>> imputer .fit_transform(df)
641669 """
642670
643- def __init__ (self , type_model : str , groups : List [str ] = [], ** hyperparams ) -> None :
671+ def __init__ (
672+ self , groups : List [str ] = [], estimator : Optional [BaseEstimator ] = None , ** hyperparams
673+ ) -> None :
644674 super ().__init__ (groups = groups , hyperparams = hyperparams )
645- self .type_model = type_model
675+ self .estimator = estimator
646676
647677 def fit_transform_element (self , df : pd .DataFrame ) -> pd .Series :
648678 """
@@ -659,7 +689,6 @@ def fit_transform_element(self, df: pd.DataFrame) -> pd.Series:
659689 imputed dataframe
660690 """
661691 df_imp = df .copy ()
662- model = self .type_model (** self .hyperparams )
663692 cols_with_nans = df .columns [df .isna ().any ()]
664693 cols_without_nans = df .columns [df .notna ().all ()]
665694
@@ -670,8 +699,8 @@ def fit_transform_element(self, df: pd.DataFrame) -> pd.Series:
670699 X = df [cols_without_nans ]
671700 y = df [col ]
672701 is_na = y .isna ()
673- model .fit (X [~ is_na ], y [~ is_na ])
674- y_pred = model .predict (X )
702+ self . estimator .fit (X [~ is_na ], y [~ is_na ])
703+ y_pred = self . estimator .predict (X )
675704 std_error = (y_pred [~ is_na ] - y [~ is_na ]).std ()
676705 random_pred = np .random .normal (size = len (y ), loc = y_pred , scale = std_error )
677706 df_imp .loc [is_na , col ] = random_pred [is_na ]
@@ -696,8 +725,8 @@ class ImputerRPCA(Imputer):
696725
697726 def __init__ (
698727 self ,
699- method : str = "noisy" ,
700728 groups : List [str ] = [],
729+ method : str = "noisy" ,
701730 columnwise : bool = False ,
702731 ** hyperparams ,
703732 ) -> None :
0 commit comments