Skip to content

Commit 43e8102

Browse files
Julien RousselJulien Roussel
authored andcommitted
EM with single sample managed
1 parent f43cbb2 commit 43e8102

File tree

3 files changed

+39
-17
lines changed

3 files changed

+39
-17
lines changed

examples/benchmark.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,18 @@ In Qolmat, a few data imputation methods are implemented as well as a way to eva
1919

2020
First, import some useful librairies
2121

22+
```python
23+
X= np.array([[0], [1], [2]])
24+
```
25+
26+
```python
27+
np.cov(X)
28+
```
29+
30+
```python
31+
32+
```
33+
2234
```python
2335
import warnings
2436
# warnings.filterwarnings('error')

qolmat/imputations/em_sampler.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,10 @@ def fit(self, X: NDArray):
179179

180180
# first imputation
181181
X_sample_last = utils.linear_interpolation(X)
182+
print("X_sample_last")
183+
print(X_sample_last)
184+
print("x")
185+
print(X)
182186
self.fit_distribution(X_sample_last)
183187

184188
for iter_em in range(self.max_iter_em):
@@ -300,7 +304,13 @@ def __init__(
300304

301305
def fit_distribution(self, X):
302306
self.means = np.mean(X, axis=1)
303-
self.cov = np.cov(X).reshape(len(X), -1)
307+
n_rows, n_cols = X.shape
308+
if n_cols == 1:
309+
self.cov = np.eye(n_rows)
310+
else:
311+
self.cov = np.cov(X).reshape(n_rows, -1)
312+
print("cov")
313+
print(self.cov)
304314
self.cov_inv = np.linalg.pinv(self.cov, rcond=1e-2)
305315

306316
def get_loglikelihood(self, X: NDArray) -> float:

qolmat/imputations/imputers.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1222,9 +1222,10 @@ def fit(self, X: pd.DataFrame, y=None):
12221222
Returns self.
12231223
"""
12241224
super().fit(X)
1225+
df = self._check_input(X)
12251226
hyperparameters = self.get_hyperparams()
12261227
self.imputer_ = KNNImputer(metric="nan_euclidean", **hyperparameters)
1227-
self.imputer_.fit(X)
1228+
self.imputer_.fit(df)
12281229
return self
12291230

12301231
def _transform_element(self, df: pd.DataFrame, col: str = "__all__") -> pd.DataFrame:
@@ -1327,10 +1328,9 @@ def fit(self, X: pd.DataFrame, y=None):
13271328
"""
13281329
hyperparams = self.get_hyperparams()
13291330
super().fit(X)
1330-
if not isinstance(X, (pd.DataFrame)):
1331-
X = pd.DataFrame(np.array(X), columns=[i for i in range(np.array(X).shape[1])])
1331+
df = self._check_input(X)
13321332
self.imputer_ = IterativeImputer(estimator=self.estimator, **hyperparams)
1333-
self.imputer_.fit(X)
1333+
self.imputer_.fit(df)
13341334
self.n_iter_ = self.imputer_.n_iter_
13351335
return self
13361336

@@ -1444,18 +1444,17 @@ def fit(self, X: pd.DataFrame, y: pd.DataFrame = None) -> _Imputer:
14441444
"""
14451445

14461446
super().fit(X)
1447-
if not isinstance(X, (pd.DataFrame)):
1448-
X = pd.DataFrame(np.array(X), columns=[i for i in range(np.array(X).shape[1])])
1447+
df = self._check_input(X)
14491448

1450-
cols_with_nans = X.columns[X.isna().any()]
1449+
cols_with_nans = df.columns[df.isna().any()]
14511450
self.estimators_ = {}
14521451
for col in cols_with_nans:
14531452
# Define the Train and Test set
1454-
X_ = X.drop(columns=col, errors="ignore")
1455-
y_ = X[col]
1453+
X_ = df.drop(columns=col, errors="ignore")
1454+
y_ = df[col]
14561455

14571456
# Selects only the valid values in the Train Set according to the chosen method
1458-
is_valid = pd.Series(True, index=X.index)
1457+
is_valid = pd.Series(True, index=df.index)
14591458
if self.handler_nan == "fit":
14601459
pass
14611460
elif self.handler_nan == "row":
@@ -1787,30 +1786,31 @@ def fit(self, X: pd.DataFrame, y=None):
17871786
Returns self.
17881787
"""
17891788
super().fit(X)
1789+
df = self._check_input(X)
17901790

1791-
n_rows, n_cols = X.shape
1792-
if n_rows == 1:
1793-
raise ValueError("n_samples=1 is not allowed!")
1791+
# n_rows, n_cols = df.shape
1792+
# if n_rows == 1:
1793+
# raise ValueError("n_samples=1 is not allowed!")
17941794

17951795
if self.model not in ["multinormal", "VAR1"]:
17961796
raise ValueError(
17971797
f"Model argument `{self.model}` is invalid!"
17981798
" Valid values are `multinormal`and `VAR`."
17991799
)
18001800

1801-
cols_with_nans = X.columns[X.isna().any()]
1801+
cols_with_nans = df.columns[df.isna().any()]
18021802

18031803
self._models = {}
18041804
if self.columnwise:
18051805
for col in cols_with_nans:
18061806
hyperparams = self.get_hyperparams(col=col)
18071807
model = self.get_model(random_state=self.rng_, **hyperparams)
1808-
model.fit(X[col].values)
1808+
model.fit(df[col].values)
18091809
self._models[col] = model
18101810
else:
18111811
hyperparams = self.get_hyperparams()
18121812
model = self.get_model(random_state=self.rng_, **hyperparams)
1813-
model.fit(X.values.T)
1813+
model.fit(df.values.T)
18141814
self._models["__all__"] = model
18151815
return self
18161816

0 commit comments

Comments
 (0)