@@ -1222,9 +1222,10 @@ def fit(self, X: pd.DataFrame, y=None):
12221222 Returns self.
12231223 """
12241224 super ().fit (X )
1225+ df = self ._check_input (X )
12251226 hyperparameters = self .get_hyperparams ()
12261227 self .imputer_ = KNNImputer (metric = "nan_euclidean" , ** hyperparameters )
1227- self .imputer_ .fit (X )
1228+ self .imputer_ .fit (df )
12281229 return self
12291230
12301231 def _transform_element (self , df : pd .DataFrame , col : str = "__all__" ) -> pd .DataFrame :
@@ -1327,10 +1328,9 @@ def fit(self, X: pd.DataFrame, y=None):
13271328 """
13281329 hyperparams = self .get_hyperparams ()
13291330 super ().fit (X )
1330- if not isinstance (X , (pd .DataFrame )):
1331- X = pd .DataFrame (np .array (X ), columns = [i for i in range (np .array (X ).shape [1 ])])
1331+ df = self ._check_input (X )
13321332 self .imputer_ = IterativeImputer (estimator = self .estimator , ** hyperparams )
1333- self .imputer_ .fit (X )
1333+ self .imputer_ .fit (df )
13341334 self .n_iter_ = self .imputer_ .n_iter_
13351335 return self
13361336
@@ -1444,18 +1444,17 @@ def fit(self, X: pd.DataFrame, y: pd.DataFrame = None) -> _Imputer:
14441444 """
14451445
14461446 super ().fit (X )
1447- if not isinstance (X , (pd .DataFrame )):
1448- X = pd .DataFrame (np .array (X ), columns = [i for i in range (np .array (X ).shape [1 ])])
1447+ df = self ._check_input (X )
14491448
1450- cols_with_nans = X .columns [X .isna ().any ()]
1449+ cols_with_nans = df .columns [df .isna ().any ()]
14511450 self .estimators_ = {}
14521451 for col in cols_with_nans :
14531452 # Define the Train and Test set
1454- X_ = X .drop (columns = col , errors = "ignore" )
1455- y_ = X [col ]
1453+ X_ = df .drop (columns = col , errors = "ignore" )
1454+ y_ = df [col ]
14561455
14571456 # Selects only the valid values in the Train Set according to the chosen method
1458- is_valid = pd .Series (True , index = X .index )
1457+ is_valid = pd .Series (True , index = df .index )
14591458 if self .handler_nan == "fit" :
14601459 pass
14611460 elif self .handler_nan == "row" :
@@ -1787,30 +1786,31 @@ def fit(self, X: pd.DataFrame, y=None):
17871786 Returns self.
17881787 """
17891788 super ().fit (X )
1789+ df = self ._check_input (X )
17901790
1791- n_rows , n_cols = X .shape
1792- if n_rows == 1 :
1793- raise ValueError ("n_samples=1 is not allowed!" )
1791+ # n_rows, n_cols = df .shape
1792+ # if n_rows == 1:
1793+ # raise ValueError("n_samples=1 is not allowed!")
17941794
17951795 if self .model not in ["multinormal" , "VAR1" ]:
17961796 raise ValueError (
17971797 f"Model argument `{ self .model } ` is invalid!"
17981798 " Valid values are `multinormal`and `VAR`."
17991799 )
18001800
1801- cols_with_nans = X .columns [X .isna ().any ()]
1801+ cols_with_nans = df .columns [df .isna ().any ()]
18021802
18031803 self ._models = {}
18041804 if self .columnwise :
18051805 for col in cols_with_nans :
18061806 hyperparams = self .get_hyperparams (col = col )
18071807 model = self .get_model (random_state = self .rng_ , ** hyperparams )
1808- model .fit (X [col ].values )
1808+ model .fit (df [col ].values )
18091809 self ._models [col ] = model
18101810 else :
18111811 hyperparams = self .get_hyperparams ()
18121812 model = self .get_model (random_state = self .rng_ , ** hyperparams )
1813- model .fit (X .values .T )
1813+ model .fit (df .values .T )
18141814 self ._models ["__all__" ] = model
18151815 return self
18161816
0 commit comments