Skip to content

Commit 7651333

Browse files
author
vm-aifluence-jro
committed
Typing isse solved
1 parent 0c07a99 commit 7651333

File tree

3 files changed

+22
-12
lines changed

3 files changed

+22
-12
lines changed

qolmat/imputations/rpca/rpca.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,21 @@ def _prepare_data(self, X: NDArray) -> NDArray:
5757
else:
5858
raise ValueError("`n_rows` should not be specified when imputing 2D data.")
5959

60-
def get_shape_original(self, X: NDArray, shape: Tuple[int]) -> NDArray:
61-
# if len(shape) == 1 or shape[0] == 1:
62-
# n_values = sum(shape)
63-
# return X.reshape(1, -1)[:, :n_values]
64-
# else:
65-
# return X
66-
X = X.flatten()[: np.prod(shape)]
67-
return X.reshape(shape)
60+
def get_shape_original(self, M: NDArray, X: NDArray) -> NDArray:
61+
"""Shapes an output matrix from the RPCA algorithm into the original shape.
62+
63+
Parameters
64+
----------
65+
M : NDArray
66+
Matrix to reshape
67+
X : NDArray
68+
Matrix of the desired shape
69+
70+
Returns
71+
-------
72+
NDArray
73+
Reshaped matrix
74+
"""
75+
size = X.size
76+
M_flat = M.flatten()[:size]
77+
return M_flat.reshape(X.shape)

qolmat/imputations/rpca/rpca_noisy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,7 @@ def decompose_rpca_signal(
366366
elif self.norm == "L2":
367367
M, A, U, V, errors = self.decompose_rpca_L2(D_proj, Omega, lam, tau, rank)
368368

369-
M_final = self.get_shape_original(M, X.shape)
370-
A_final = self.get_shape_original(A, X.shape)
369+
M_final = self.get_shape_original(M, X)
370+
A_final = self.get_shape_original(A, X)
371371

372372
return M_final, A_final

qolmat/imputations/rpca/rpca_pcp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,6 @@ def decompose_rpca_signal(
100100
D = self._prepare_data(X)
101101
M, A = self.decompose_rpca(D)
102102

103-
M_final = self.get_shape_original(M, X.shape)
104-
A_final = self.get_shape_original(A, X.shape)
103+
M_final = self.get_shape_original(M, X)
104+
A_final = self.get_shape_original(A, X)
105105
return M_final, A_final

0 commit comments

Comments
 (0)