Skip to content

Commit 8dca42e

Browse files
Julien RousselJulien Roussel
authored andcommitted
mypy passing
1 parent e831b6a commit 8dca42e

File tree

2 files changed

+4
-5
lines changed

2 files changed

+4
-5
lines changed

qolmat/imputations/rpca/rpca_noisy.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -389,8 +389,7 @@ def cost_function(
389389
tau: float,
390390
lam: float,
391391
):
392-
393-
temporal_norm = 0
392+
temporal_norm: float = 0
394393
if len(self.list_etas) > 0:
395394
# matrices for temporal correlation
396395
H = [
@@ -402,7 +401,7 @@ def cost_function(
402401
temporal_norm += eta * np.sum(np.abs(H_matrix @ low_rank))
403402
elif self.norm == "L2":
404403
for eta, H_matrix in zip(self.list_etas, H):
405-
temporal_norm += eta * np.linalg.norm(low_rank @ H_matrix, "fro")
404+
temporal_norm += eta * float(np.linalg.norm(low_rank @ H_matrix, "fro"))
406405
anomalies_norm = np.sum(np.abs(anomalies * Omega))
407406
cost = (
408407
1 / 2 * ((Omega * (observations - low_rank - anomalies)) ** 2).sum()

tests/utils/test_plot.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Tuple
1+
from typing import Any, List, Tuple
22
import matplotlib as mpl
33
import matplotlib.pyplot as plt
44
import numpy as np
@@ -45,7 +45,7 @@ def test_utils_plot_plot_matrices(list_matrices: List[np.ndarray], mocker: Mocke
4545

4646

4747
@pytest.mark.parametrize("list_signals", [list_signals])
48-
def test_utils_plot_plot_signal(list_signals: List[np.ndarray], mocker: MockerFixture) -> None:
48+
def test_utils_plot_plot_signal(list_signals: List[List[Any]], mocker: MockerFixture) -> None:
4949
mocker.patch("matplotlib.pyplot.savefig")
5050
mocker.patch("matplotlib.pyplot.show")
5151
plot.plot_signal(list_signals=list_signals, ylabel="ylabel", title="title")

0 commit comments

Comments
 (0)