Skip to content

Commit 1d8de22

Browse files
Julien RousselJulien Roussel
authored andcommitted
Merge branch 'dev' into angoho_CSDI_T
2 parents 61a6023 + 141e9dc commit 1d8de22

File tree

12 files changed

+463
-10
lines changed

12 files changed

+463
-10
lines changed

.github/workflows/test.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,7 @@ jobs:
4141
run: |
4242
mypy qolmat
4343
echo you should uncomment mypy qolmat and delete this line
44+
- name: Upload coverage reports to Codecov
45+
uses: codecov/codecov-action@v3
46+
env:
47+
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}

MANIFEST.in

Lines changed: 0 additions & 1 deletion
This file was deleted.

README.rst

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
.. -*- mode: rst -*-
22
3-
|GitHubActions|_ |ReadTheDocs|_ |License|_ |PythonVersion|_ |PyPi|_ |Release|_ |Commits|_
3+
|GitHubActions|_ |ReadTheDocs|_ |License|_ |PythonVersion|_ |PyPi|_ |Release|_ |Commits|_ |Codecov|_
44

55
.. |GitHubActions| image:: https://github.com/Quantmetry/qolmat/actions/workflows/test.yml/badge.svg
66
.. _GitHubActions: https://github.com/Quantmetry/qolmat/actions
@@ -23,6 +23,9 @@
2323
.. |Commits| image:: https://img.shields.io/github/commits-since/Quantmetry/qolmat/latest/main
2424
.. _Commits: https://github.com/Quantmetry/qolmat/commits/main
2525

26+
.. |Codecov| image:: https://codecov.io/gh/quantmetry/qolmat/branch/master/graph/badge.svg
27+
.. _Codecov: https://codecov.io/gh/quantmetry/qolmat
28+
2629
.. image:: https://raw.githubusercontent.com/Quantmetry/qolmat/main/docs/images/logo.png
2730
:align: center
2831

environment.doc.yml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
name: env_qolmat_doc
22
channels:
3-
- defaults
43
- conda-forge
4+
- defaults
55
dependencies:
66
- numpydoc=1.1.0
77
- python=3.8
88
- sphinx=4.3.2
99
- sphinx-gallery=0.10.1
1010
- sphinx_rtd_theme=1.0.0
11-
- sphinx_markdown_tables==0.0.17
1211
- typing_extensions=4.0.1
12+
- pip
13+
- pip:
14+
- sphinx-markdown-tables==0.0.17
File renamed without changes.
File renamed without changes.

qolmat/imputations/imputers.py

Lines changed: 117 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from qolmat.imputations import em_sampler
2020
from qolmat.imputations.rpca import rpca, rpca_noisy, rpca_pcp
21+
from qolmat.imputations import softimpute
2122
from qolmat.utils.exceptions import NotDataFrame
2223
from qolmat.utils.utils import HyperValue
2324

@@ -1772,6 +1773,120 @@ def _transform_element(
17721773
return df_imputed
17731774

17741775

1776+
class ImputerSoftImpute(_Imputer):
1777+
"""_summary_
1778+
1779+
Parameters
1780+
----------
1781+
"""
1782+
1783+
def __init__(
1784+
self,
1785+
groups: Tuple[str, ...] = (),
1786+
columnwise: bool = False,
1787+
random_state: Union[None, int, np.random.RandomState] = None,
1788+
period: int = 1,
1789+
rank: int = 2,
1790+
tolerance: float = 1e-05,
1791+
tau: float = 0,
1792+
max_iterations: int = 100,
1793+
verbose: bool = False,
1794+
projected: bool = True,
1795+
):
1796+
super().__init__(
1797+
imputer_params=(
1798+
"period",
1799+
"rank",
1800+
"tolerance",
1801+
"tau",
1802+
"max_iterations",
1803+
"verbose",
1804+
"projected",
1805+
),
1806+
groups=groups,
1807+
columnwise=columnwise,
1808+
random_state=random_state,
1809+
)
1810+
self.period = period
1811+
self.rank = rank
1812+
self.tolerance = tolerance
1813+
self.tau = tau
1814+
self.max_iterations = max_iterations
1815+
self.verbose = verbose
1816+
self.projected = projected
1817+
1818+
def _fit_element(
1819+
self, df: pd.DataFrame, col: str = "__all__", ngroup: int = 0
1820+
) -> softimpute.SoftImpute:
1821+
"""
1822+
Fits the imputer on `df`, at the group and/or column level depending on
1823+
self.groups and self.columnwise.
1824+
1825+
Parameters
1826+
----------
1827+
df : pd.DataFrame
1828+
Dataframe on which the imputer is fitted
1829+
col : str, optional
1830+
Column on which the imputer is fitted, by default "__all__"
1831+
ngroup : int, optional
1832+
Id of the group on which the method is applied
1833+
1834+
Returns
1835+
-------
1836+
Any
1837+
Return fitted SoftImpute model
1838+
1839+
Raises
1840+
------
1841+
NotDataFrame
1842+
Input has to be a pandas.DataFrame.
1843+
"""
1844+
self._check_dataframe(df)
1845+
assert col == "__all__"
1846+
hyperparams = self.get_hyperparams()
1847+
model = softimpute.SoftImpute(random_state=self._rng, **hyperparams)
1848+
model = model.fit(df.values)
1849+
return model
1850+
1851+
def _transform_element(
1852+
self, df: pd.DataFrame, col: str = "__all__", ngroup: int = 0
1853+
) -> pd.DataFrame:
1854+
"""
1855+
Transforms the fataframe `df`, at the group level depending on
1856+
self.groups
1857+
1858+
Parameters
1859+
----------
1860+
df : pd.DataFrame
1861+
Dataframe or column to impute
1862+
col : str, optional
1863+
Column transformed by the imputer, by default "__all__"
1864+
1865+
Returns
1866+
-------
1867+
pd.DataFrame
1868+
Imputed dataframe
1869+
1870+
Raises
1871+
------
1872+
NotDataFrame
1873+
Input has to be a pandas.DataFrame.
1874+
"""
1875+
self._check_dataframe(df)
1876+
assert col == "__all__"
1877+
model = self._dict_fitting["__all__"][ngroup]
1878+
X_imputed = model.transform(df.values)
1879+
return pd.DataFrame(X_imputed, index=df.index, columns=df.columns)
1880+
1881+
def _more_tags(self):
1882+
return {
1883+
"_xfail_checks": {
1884+
"check_fit2d_1sample": "This test shouldn't be running at all!",
1885+
"check_fit2d_1feature": "This test shouldn't be running at all!",
1886+
},
1887+
}
1888+
1889+
17751890
class ImputerEM(_Imputer):
17761891
"""
17771892
This class implements an imputation method based on joint modelling and an inference using a
@@ -1874,7 +1989,7 @@ def get_model(self, **hyperparams) -> em_sampler.EM:
18741989

18751990
def _fit_element(
18761991
self, df: pd.DataFrame, col: str = "__all__", ngroup: int = 0
1877-
) -> IterativeImputer:
1992+
) -> em_sampler.EM:
18781993
"""
18791994
Fits the imputer on `df`, at the group and/or column level depending onself.groups and
18801995
self.columnwise.
@@ -1891,7 +2006,7 @@ def _fit_element(
18912006
Returns
18922007
-------
18932008
Any
1894-
Return fitted KNN model
2009+
Return fitted EM model
18952010
18962011
Raises
18972012
------

0 commit comments

Comments
 (0)