@@ -317,67 +317,73 @@ def minimise_loss(
317317 Ir = np .eye (rank )
318318 In = identity (n_rows )
319319
320- for _ in tqdm (
321- range ( max_iterations ) ,
320+ with tqdm (
321+ total = max_iterations ,
322322 desc = "Noisy RPCA loss minimization" ,
323+ unit = "iteration" ,
323324 disable = not verbose ,
324- ):
325- M_temp = M .copy ()
326- A_temp = A .copy ()
327- L_temp = L .copy ()
328- Q_temp = Q .copy ()
329- if norm == "L1" :
330- R_temp = R .copy ()
331- sums = np .zeros ((n_rows , n_cols ))
332- for i_period , _ in enumerate (list_periods ):
333- sums += mu * R [i_period ] - list_H [i_period ] @ Y
334-
335- M = spsolve (
336- (1 + mu ) * In + HtH ,
337- D - A + mu * L @ Q - Y + sums ,
338- )
339- else :
340- M = spsolve (
341- (1 + mu ) * In + 2 * HtH ,
342- D - A + mu * L @ Q - Y ,
343- )
344- M = M .reshape (D .shape )
345-
346- A_Omega = rpca_utils .soft_thresholding (D - M , lam )
347- A_Omega_C = D - M
348- A = np .where (Omega , A_Omega , A_Omega_C )
349- Q = scp .linalg .solve (
350- a = tau * Ir + mu * (L .T @ L ),
351- b = L .T @ (mu * M + Y ),
352- )
353-
354- L = scp .linalg .solve (
355- a = tau * Ir + mu * (Q @ Q .T ),
356- b = Q @ (mu * M .T + Y .T ),
357- ).T
358-
359- Y += mu * (M - L @ Q )
360- if norm == "L1" :
361- for i_period , _ in enumerate (list_periods ):
362- eta = list_etas [i_period ]
363- R [i_period ] = rpca_utils .soft_thresholding (
364- R [i_period ] / mu , eta / mu
325+ ) as pbar :
326+ for _ in range (max_iterations ):
327+ M_temp = M .copy ()
328+ A_temp = A .copy ()
329+ L_temp = L .copy ()
330+ Q_temp = Q .copy ()
331+ if norm == "L1" :
332+ R_temp = R .copy ()
333+ sums = np .zeros ((n_rows , n_cols ))
334+ for i_period , _ in enumerate (list_periods ):
335+ sums += mu * R [i_period ] - list_H [i_period ] @ Y
336+
337+ M = spsolve (
338+ (1 + mu ) * In + HtH ,
339+ D - A + mu * L @ Q - Y + sums ,
365340 )
341+ else :
342+ M = spsolve (
343+ (1 + mu ) * In + 2 * HtH ,
344+ D - A + mu * L @ Q - Y ,
345+ )
346+ M = M .reshape (D .shape )
347+
348+ A_Omega = rpca_utils .soft_thresholding (D - M , lam )
349+ A_Omega_C = D - M
350+ A = np .where (Omega , A_Omega , A_Omega_C )
351+ Q = scp .linalg .solve (
352+ a = tau * Ir + mu * (L .T @ L ),
353+ b = L .T @ (mu * M + Y ),
354+ )
366355
367- mu = min (mu * rho , mu_bar )
368-
369- Mc = np .linalg .norm (M - M_temp , np .inf )
370- Ac = np .linalg .norm (A - A_temp , np .inf )
371- Lc = np .linalg .norm (L - L_temp , np .inf )
372- Qc = np .linalg .norm (Q - Q_temp , np .inf )
373- error_max = max ([Mc , Ac , Lc , Qc ]) # type: ignore # noqa
374- if norm == "L1" :
375- for i_period , _ in enumerate (list_periods ):
376- Rc = np .linalg .norm (R [i_period ] - R_temp [i_period ], np .inf )
377- error_max = max (error_max , Rc ) # type: ignore # noqa
378-
379- if error_max < tolerance :
380- break
356+ L = scp .linalg .solve (
357+ a = tau * Ir + mu * (Q @ Q .T ),
358+ b = Q @ (mu * M .T + Y .T ),
359+ ).T
360+
361+ Y += mu * (M - L @ Q )
362+ if norm == "L1" :
363+ for i_period , _ in enumerate (list_periods ):
364+ eta = list_etas [i_period ]
365+ R [i_period ] = rpca_utils .soft_thresholding (
366+ R [i_period ] / mu , eta / mu
367+ )
368+
369+ mu = min (mu * rho , mu_bar )
370+
371+ Mc = np .linalg .norm (M - M_temp , np .inf )
372+ Ac = np .linalg .norm (A - A_temp , np .inf )
373+ Lc = np .linalg .norm (L - L_temp , np .inf )
374+ Qc = np .linalg .norm (Q - Q_temp , np .inf )
375+ error_max = max ([Mc , Ac , Lc , Qc ]) # type: ignore # noqa
376+ if norm == "L1" :
377+ for i_period , _ in enumerate (list_periods ):
378+ Rc = np .linalg .norm (
379+ R [i_period ] - R_temp [i_period ], np .inf
380+ )
381+ error_max = max (error_max , Rc ) # type: ignore # noqa
382+
383+ if error_max < tolerance :
384+ break
385+ pbar .set_postfix (error = f"{ error_max .item ():.4f} " )
386+ pbar .update (1 )
381387
382388 M = L @ Q
383389
0 commit comments