Skip to content

Commit 5dfcdb9

Browse files
Julien RousselJulien Roussel
authored and committed
RPCA online not functional
1 parent c9d68de commit 5dfcdb9

File tree

3 files changed

+120
-47
lines changed

3 files changed

+120
-47
lines changed

examples/1_timeSeries.ipynb

Lines changed: 62 additions & 23 deletions
Large diffs are not rendered by default.

qolmat/imputations/rpca/pcp_rpca.py

Lines changed: 32 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -59,7 +59,8 @@ def get_params_scale(self, D):
5959
return dict_params
6060

6161
def decompose_rpca(self, D: NDArray) -> Tuple[NDArray, NDArray]:
62-
proj_D = utils.impute_nans(D, method="median")
62+
# proj_D = utils.impute_nans(D, method="median")
63+
proj_D = np.where(np.isnan(D), -1, D)
6364

6465
params_scale = self.get_params_scale(proj_D)
6566

@@ -74,20 +75,48 @@ def decompose_rpca(self, D: NDArray) -> Tuple[NDArray, NDArray]:
7475

7576
errors = np.full((self.max_iter,), fill_value=np.nan)
7677

77-
for iteration in range(self.max_iter):
78+
print("D:")
79+
print(D[:3])
80+
81+
from matplotlib import pyplot as plt
82+
tab10 = plt.get_cmap("tab10")
83+
#plt.figure(figsize=(8, 6))
7884

85+
M = proj_D - A
86+
signal = proj_D.reshape(1, -1)[0]
87+
#plt.plot(signal, color="black")
88+
i_plot = 0
89+
for iteration in range(self.max_iter):
90+
#print("iteration=", iteration)
91+
M_old = M.copy()
7992
M = utils.svd_thresholding(proj_D - A + Y/mu, 1/mu)
93+
deltaM = M - M_old
94+
signalM = M.reshape(1, -1)[0]
95+
A_old = A.copy()
8096
A = utils.soft_thresholding(proj_D - M + Y/mu, lam/mu)
8197
A[~Omega] = (proj_D - M)[~Omega]
98+
deltaA = A - A_old
99+
signalA = A.reshape(1, -1)[0]
82100
Y += mu * (proj_D - M - A)
101+
# signalY = (proj_D - M - A).reshape(1, -1)[0]
102+
# plt.plot(6 + signalY, color=tab10(iteration), ls="-.")
83103

84104
error = np.linalg.norm(D - M - A, "fro")/D_norm
85105
errors[iteration] = error
86106

107+
# if iteration % 10 == 0:
108+
# plt.plot(signalM, color=tab10(i_plot), ls="--")
109+
# plt.plot(4 + signalA, color=tab10(i_plot))
110+
111+
# i_plot += 1
112+
113+
87114
if error < self.tol:
88115
if self.verbose:
89116
print(f"Converged in {iteration} iterations")
90117
break
118+
plt.xlim(0, 30)
119+
plt.show()
91120
return M, A
92121

93122

@@ -116,6 +145,7 @@ def fit_transform(
116145
errors: NDArray
117146
Array of iterative errors
118147
"""
148+
print("coucou")
119149
X = X.copy().T
120150
D = self._prepare_data(X)
121151
M, A = self.decompose_rpca(D)

qolmat/imputations/rpca/temporal_rpca.py

Lines changed: 26 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -120,10 +120,11 @@ def compute_L1(self, proj_D, omega, lam, tau, rank) -> None:
120120

121121
if np.any(np.isnan(proj_D)):
122122
A_omega = utils.soft_thresholding(proj_D - X, lam)
123-
A_omega = utils.ortho_proj(A_omega, omega, inverse=False)
123+
# A_omega = utils.ortho_proj(A_omega, omega, inverse=False)
124124
A_omega_C = proj_D - X
125-
A_omega_C = utils.ortho_proj(A_omega_C, omega, inverse=True)
126-
A = A_omega + A_omega_C
125+
# A_omega_C = utils.ortho_proj(A_omega_C, omega, inverse=True)
126+
# A = A_omega + A_omega_C
127+
A = np.where(omega, A_omega, A_omega_C)
127128
else:
128129
A = utils.soft_thresholding(proj_D - X, lam)
129130

@@ -169,7 +170,7 @@ def compute_L1(self, proj_D, omega, lam, tau, rank) -> None:
169170
V = Q
170171
return M, A, U, V, errors
171172

172-
def compute_L2(self, proj_D, omega, lam, tau, rank) -> None:
173+
def compute_L2(self, proj_D, Omega, lam, tau, rank) -> None:
173174
"""
174175
compute RPCA with possible temporal regularisations, penalised with L2 norm
175176
"""
@@ -208,12 +209,13 @@ def compute_L2(self, proj_D, omega, lam, tau, rank) -> None:
208209
b=(proj_D - A + mu * L @ Q.T - Y).T,
209210
).T
210211

211-
if np.any(~omega):
212+
if np.any(~Omega):
212213
A_omega = utils.soft_thresholding(proj_D - X, lam)
213-
A_omega = utils.ortho_proj(A_omega, omega, inverse=False)
214+
# A_omega = utils.ortho_proj(A_omega, omega, inverse=False)
214215
A_omega_C = proj_D - X
215-
A_omega_C = utils.ortho_proj(A_omega_C, omega, inverse=True)
216-
A = A_omega + A_omega_C
216+
# A_omega_C = utils.ortho_proj(A_omega_C, omega, inverse=True)
217+
# A = A_omega + A_omega_C
218+
A = np.where(Omega, A_omega, A_omega_C)
217219
else:
218220
A = utils.soft_thresholding(proj_D - X, lam)
219221

@@ -446,17 +448,17 @@ def get_params(self):
446448

447449
def get_params_scale_online(
448450
self,
449-
D:NDArray, Lhat: NDArray
451+
D:NDArray, M: NDArray
450452
) -> dict[str, float]:
451453
# D_init = self._prepare_data(signal=X)
452454
params_scale = self.get_params_scale(D)
453455
# burnin = int(D_init.shape[1] * self.burnin)
454456

455457
# super_class = TemporalRPCA(**super().get_params())
456458
# Lhat, _, _ = super_class.fit_transform(X=D_init[:, :burnin])
457-
_, sigmas_hat, _ = np.linalg.svd(Lhat)
458-
online_tau = 1.0 / np.sqrt(len(Lhat)) / np.mean(sigmas_hat[: params_scale["rank"]])
459-
online_lam = 1.0 / np.sqrt(len(Lhat))
459+
_, sigmas_hat, _ = np.linalg.svd(M)
460+
online_tau = 1.0 / np.sqrt(len(M)) / np.mean(sigmas_hat[: params_scale["rank"]])
461+
online_lam = 1.0 / np.sqrt(len(M))
460462
params_scale["online_tau"] = online_tau
461463
params_scale["online_lam"] = online_lam
462464
return params_scale
@@ -499,34 +501,36 @@ def fit_transform(
499501
# Lhat, Shat, _, _, _ =super_class.fit_transform(X=D_init[:, :burnin])
500502

501503
proj_D = utils.impute_nans(D_init, method="median")
502-
omega = ~np.isnan(D_init)
504+
Omega = ~np.isnan(D_init)
503505

504506
params_scale = self.get_params_scale(proj_D)
505507

506508
lam = params_scale["lam"] if self.lam is None else self.lam
507509
rank = params_scale["rank"] if self.rank is None else self.rank
508510
tau = params_scale["tau"] if self.tau is None else self.tau
509511

512+
D_burnin = proj_D[:, :burnin]
513+
Omega_burnin = Omega[:, :burnin]
514+
510515
if self.norm == "L1":
511-
M, A, U, V, errors = self.compute_L1(proj_D, omega, lam, tau, rank)
516+
M, A, U, V, errors = self.compute_L1(D_burnin, Omega_burnin, lam, tau, rank)
512517
elif self.norm == "L2":
513-
M, A, U, V, errors = self.compute_L2(proj_D, omega, lam, tau, rank)
518+
M, A, U, V, errors = self.compute_L2(D_burnin, Omega_burnin, lam, tau, rank)
514519

515-
Lhat, Shat, _ = np.linalg.svd(M, full_matrices=False, compute_uv=True)
520+
# Lhat, Shat, _ = np.linalg.svd(M, full_matrices=False, compute_uv=True)
516521

517-
params_scale = self.get_params_scale_online(proj_D, Lhat)
522+
params_scale_online = self.get_params_scale_online(proj_D, M)
518523

519-
online_tau = params_scale["online_tau"] if self.online_tau is None else self.online_tau
520-
online_lam = params_scale["online_lam"] if self.online_lam is None else self.online_lam
524+
online_tau = self.online_tau or params_scale_online["online_tau"]
525+
online_lam = params_scale_online["online_lam"] if self.online_lam is None else self.online_lam
521526

522527
if len(self.online_list_etas) == 0:
523528
self.online_list_etas = self.list_etas
524529

525-
approx_rank = utils.approx_rank(proj_D[:, :burnin])
530+
approx_rank = utils.approx_rank(D_burnin)
526531

527-
# TODO : is it really Lhat that should be used here?!
528532
Uhat, sigmas_hat, Vhat = randomized_svd(
529-
Lhat, n_components=approx_rank, n_iter=5, random_state=42
533+
M, n_components=approx_rank, n_iter=5, random_state=42
530534
)
531535
U = Uhat[:, :approx_rank]@(np.sqrt(np.diag(sigmas_hat[:approx_rank])))
532536

0 commit comments

Comments
 (0)