@@ -120,10 +120,11 @@ def compute_L1(self, proj_D, omega, lam, tau, rank) -> None:
120120
121121 if np .any (np .isnan (proj_D )):
122122 A_omega = utils .soft_thresholding (proj_D - X , lam )
123- A_omega = utils .ortho_proj (A_omega , omega , inverse = False )
123+ # A_omega = utils.ortho_proj(A_omega, omega, inverse=False)
124124 A_omega_C = proj_D - X
125- A_omega_C = utils .ortho_proj (A_omega_C , omega , inverse = True )
126- A = A_omega + A_omega_C
125+ # A_omega_C = utils.ortho_proj(A_omega_C, omega, inverse=True)
126+ # A = A_omega + A_omega_C
127+ A = np .where (omega , A_omega , A_omega_C )
127128 else :
128129 A = utils .soft_thresholding (proj_D - X , lam )
129130
@@ -169,7 +170,7 @@ def compute_L1(self, proj_D, omega, lam, tau, rank) -> None:
169170 V = Q
170171 return M , A , U , V , errors
171172
172- def compute_L2 (self , proj_D , omega , lam , tau , rank ) -> None :
173+ def compute_L2 (self , proj_D , Omega , lam , tau , rank ) -> None :
173174 """
174175 compute RPCA with possible temporal regularisations, penalised with L2 norm
175176 """
@@ -208,12 +209,13 @@ def compute_L2(self, proj_D, omega, lam, tau, rank) -> None:
208209 b = (proj_D - A + mu * L @ Q .T - Y ).T ,
209210 ).T
210211
211- if np .any (~ omega ):
212+ if np .any (~ Omega ):
212213 A_omega = utils .soft_thresholding (proj_D - X , lam )
213- A_omega = utils .ortho_proj (A_omega , omega , inverse = False )
214+ # A_omega = utils.ortho_proj(A_omega, omega, inverse=False)
214215 A_omega_C = proj_D - X
215- A_omega_C = utils .ortho_proj (A_omega_C , omega , inverse = True )
216- A = A_omega + A_omega_C
216+ # A_omega_C = utils.ortho_proj(A_omega_C, omega, inverse=True)
217+ # A = A_omega + A_omega_C
218+ A = np .where (Omega , A_omega , A_omega_C )
217219 else :
218220 A = utils .soft_thresholding (proj_D - X , lam )
219221
@@ -446,17 +448,17 @@ def get_params(self):
446448
447449 def get_params_scale_online (
448450 self ,
449- D :NDArray , Lhat : NDArray
451+ D :NDArray , M : NDArray
450452 ) -> dict [str , float ]:
451453 # D_init = self._prepare_data(signal=X)
452454 params_scale = self .get_params_scale (D )
453455 # burnin = int(D_init.shape[1] * self.burnin)
454456
455457 # super_class = TemporalRPCA(**super().get_params())
456458 # Lhat, _, _ = super_class.fit_transform(X=D_init[:, :burnin])
457- _ , sigmas_hat , _ = np .linalg .svd (Lhat )
458- online_tau = 1.0 / np .sqrt (len (Lhat )) / np .mean (sigmas_hat [: params_scale ["rank" ]])
459- online_lam = 1.0 / np .sqrt (len (Lhat ))
459+ _ , sigmas_hat , _ = np .linalg .svd (M )
460+ online_tau = 1.0 / np .sqrt (len (M )) / np .mean (sigmas_hat [: params_scale ["rank" ]])
461+ online_lam = 1.0 / np .sqrt (len (M ))
460462 params_scale ["online_tau" ] = online_tau
461463 params_scale ["online_lam" ] = online_lam
462464 return params_scale
@@ -499,34 +501,36 @@ def fit_transform(
499501 # Lhat, Shat, _, _, _ =super_class.fit_transform(X=D_init[:, :burnin])
500502
501503 proj_D = utils .impute_nans (D_init , method = "median" )
502- omega = ~ np .isnan (D_init )
504+ Omega = ~ np .isnan (D_init )
503505
504506 params_scale = self .get_params_scale (proj_D )
505507
506508 lam = params_scale ["lam" ] if self .lam is None else self .lam
507509 rank = params_scale ["rank" ] if self .rank is None else self .rank
508510 tau = params_scale ["tau" ] if self .tau is None else self .tau
509511
512+ D_burnin = proj_D [:, :burnin ]
513+ Omega_burnin = Omega [:, :burnin ]
514+
510515 if self .norm == "L1" :
511- M , A , U , V , errors = self .compute_L1 (proj_D , omega , lam , tau , rank )
516+ M , A , U , V , errors = self .compute_L1 (D_burnin , Omega_burnin , lam , tau , rank )
512517 elif self .norm == "L2" :
513- M , A , U , V , errors = self .compute_L2 (proj_D , omega , lam , tau , rank )
518+ M , A , U , V , errors = self .compute_L2 (D_burnin , Omega_burnin , lam , tau , rank )
514519
515- Lhat , Shat , _ = np .linalg .svd (M , full_matrices = False , compute_uv = True )
520+ # Lhat, Shat, _ = np.linalg.svd(M, full_matrices=False, compute_uv=True)
516521
517- params_scale = self .get_params_scale_online (proj_D , Lhat )
522+ params_scale_online = self .get_params_scale_online (proj_D , M )
518523
519- online_tau = params_scale [ "online_tau" ] if self .online_tau is None else self . online_tau
520- online_lam = params_scale ["online_lam" ] if self .online_lam is None else self .online_lam
524+ online_tau = self .online_tau or params_scale_online [ " online_tau" ]
525+ online_lam = params_scale_online ["online_lam" ] if self .online_lam is None else self .online_lam
521526
522527 if len (self .online_list_etas ) == 0 :
523528 self .online_list_etas = self .list_etas
524529
525- approx_rank = utils .approx_rank (proj_D [:, : burnin ] )
530+ approx_rank = utils .approx_rank (D_burnin )
526531
527- # TODO : is it really Lhat that should be used here?!
528532 Uhat , sigmas_hat , Vhat = randomized_svd (
529- Lhat , n_components = approx_rank , n_iter = 5 , random_state = 42
533+ M , n_components = approx_rank , n_iter = 5 , random_state = 42
530534 )
531535 U = Uhat [:, :approx_rank ]@(np .sqrt (np .diag (sigmas_hat [:approx_rank ])))
532536
0 commit comments