1818
1919from qolmat .imputations import em_sampler
2020from qolmat .imputations .rpca import rpca , rpca_noisy , rpca_pcp
21+ from qolmat .imputations import softimpute
2122from qolmat .utils .exceptions import NotDataFrame
2223from qolmat .utils .utils import HyperValue
2324
@@ -1772,6 +1773,120 @@ def _transform_element(
17721773 return df_imputed
17731774
17741775
class ImputerSoftImpute(_Imputer):
    """
    Imputer that delegates fitting and imputation to `softimpute.SoftImpute`.

    One `softimpute.SoftImpute` model is fitted per group on the raw values of
    the dataframe; the fitted model is then used to transform the same kind of
    input.

    Parameters
    ----------
    groups : Tuple[str, ...], optional
        Forwarded unchanged to the `_Imputer` constructor, by default ().
    columnwise : bool, optional
        Forwarded unchanged to the `_Imputer` constructor, by default False.
        Note that `_fit_element` and `_transform_element` only accept
        `col == "__all__"`.
    random_state : Union[None, int, np.random.RandomState], optional
        Forwarded unchanged to the `_Imputer` constructor, by default None.
    period : int, optional
        Stored on the instance and declared in `imputer_params`, by default 1.
    rank : int, optional
        Stored on the instance and declared in `imputer_params`, by default 2.
    tolerance : float, optional
        Stored on the instance and declared in `imputer_params`,
        by default 1e-05.
    tau : float, optional
        Stored on the instance and declared in `imputer_params`, by default 0.
    max_iterations : int, optional
        Stored on the instance and declared in `imputer_params`,
        by default 100.
    verbose : bool, optional
        Stored on the instance and declared in `imputer_params`,
        by default False.
    projected : bool, optional
        Stored on the instance and declared in `imputer_params`,
        by default True.

    Notes
    -----
    NOTE(review): the names declared in `imputer_params` are presumably the
    ones returned by `get_hyperparams()` and therefore passed as keyword
    arguments to `softimpute.SoftImpute` -- confirm against `_Imputer`.
    """

    def __init__(
        self,
        groups: Tuple[str, ...] = (),
        columnwise: bool = False,
        random_state: Union[None, int, np.random.RandomState] = None,
        period: int = 1,
        rank: int = 2,
        tolerance: float = 1e-05,
        tau: float = 0,
        max_iterations: int = 100,
        verbose: bool = False,
        projected: bool = True,
    ):
        # Names of the attributes declared to the base class as imputer
        # parameters.
        softimpute_param_names = (
            "period",
            "rank",
            "tolerance",
            "tau",
            "max_iterations",
            "verbose",
            "projected",
        )
        super().__init__(
            imputer_params=softimpute_param_names,
            groups=groups,
            columnwise=columnwise,
            random_state=random_state,
        )
        self.period = period
        self.rank = rank
        self.tolerance = tolerance
        self.tau = tau
        self.max_iterations = max_iterations
        self.verbose = verbose
        self.projected = projected

    def _fit_element(
        self, df: pd.DataFrame, col: str = "__all__", ngroup: int = 0
    ) -> softimpute.SoftImpute:
        """
        Fit a `softimpute.SoftImpute` model on the values of `df`.

        Parameters
        ----------
        df : pd.DataFrame
            Dataframe on which the model is fitted
        col : str, optional
            Column on which the imputer is fitted; only "__all__" is
            accepted, by default "__all__"
        ngroup : int, optional
            Id of the group on which the method is applied (not used in the
            body of this method)

        Returns
        -------
        softimpute.SoftImpute
            The object returned by `SoftImpute.fit`

        Raises
        ------
        NotDataFrame
            Input has to be a pandas.DataFrame (presumably raised by
            `_check_dataframe`).
        AssertionError
            If `col` is not "__all__".
        """
        self._check_dataframe(df)
        assert col == "__all__"
        hyperparams = self.get_hyperparams()
        estimator = softimpute.SoftImpute(random_state=self._rng, **hyperparams)
        return estimator.fit(df.values)

    def _transform_element(
        self, df: pd.DataFrame, col: str = "__all__", ngroup: int = 0
    ) -> pd.DataFrame:
        """
        Impute the dataframe `df` with the model fitted for group `ngroup`.

        Parameters
        ----------
        df : pd.DataFrame
            Dataframe to impute
        col : str, optional
            Column transformed by the imputer; only "__all__" is accepted,
            by default "__all__"
        ngroup : int, optional
            Id of the group whose fitted model is looked up in
            `self._dict_fitting["__all__"]`, by default 0

        Returns
        -------
        pd.DataFrame
            Imputed dataframe, carrying the index and columns of `df`

        Raises
        ------
        NotDataFrame
            Input has to be a pandas.DataFrame (presumably raised by
            `_check_dataframe`).
        AssertionError
            If `col` is not "__all__".
        """
        self._check_dataframe(df)
        assert col == "__all__"
        fitted_model = self._dict_fitting["__all__"][ngroup]
        imputed_values = fitted_model.transform(df.values)
        return pd.DataFrame(imputed_values, index=df.index, columns=df.columns)

    def _more_tags(self):
        """
        Return estimator tags marking two checks as expected failures.

        Returns
        -------
        dict
            Mapping with a single "_xfail_checks" entry covering
            "check_fit2d_1sample" and "check_fit2d_1feature".
        """
        reason = "This test shouldn't be running at all!"
        xfail_checks = {
            "check_fit2d_1sample": reason,
            "check_fit2d_1feature": reason,
        }
        return {"_xfail_checks": xfail_checks}
1888+
1889+
17751890class ImputerEM (_Imputer ):
17761891 """
17771892 This class implements an imputation method based on joint modelling and an inference using a
@@ -1874,7 +1989,7 @@ def get_model(self, **hyperparams) -> em_sampler.EM:
18741989
18751990 def _fit_element (
18761991 self , df : pd .DataFrame , col : str = "__all__" , ngroup : int = 0
1877- ) -> IterativeImputer :
1992+ ) -> em_sampler . EM :
18781993 """
18791994 Fits the imputer on `df`, at the group and/or column level depending onself.groups and
18801995 self.columnwise.
@@ -1891,7 +2006,7 @@ def _fit_element(
18912006 Returns
18922007 -------
18932008 Any
1894- Return fitted KNN model
2009+ Return fitted EM model
18952010
18962011 Raises
18972012 ------
0 commit comments