Skip to content

Commit 5b71ac4

Browse files
authored
Merge pull request #51 from simpeg/accept_lin_operator
Accepts Linear Operator as a valid input for solvers
2 parents fb88e36 + 5f17cf8 commit 5b71ac4

File tree

3 files changed

+22
-3
lines changed

3 files changed

+22
-3
lines changed

pymatsolver/solvers.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,17 +77,21 @@ def __init__(
7777
if is_symmetric is None:
7878
if sp.issparse(A):
7979
is_symmetric = (A.T != A).nnz == 0
80-
else:
80+
elif isinstance(A, np.ndarray):
8181
is_symmetric = issymmetric(A)
82+
else:
83+
is_symmetric = False
8284
self.is_symmetric = is_symmetric
8385
if is_hermitian is None:
8486
if self.is_real:
8587
is_hermitian = self.is_symmetric
8688
else:
8789
if sp.issparse(A):
8890
is_hermitian = (A.T.conjugate() != A).nnz == 0
89-
else:
91+
elif isinstance(A, np.ndarray):
9092
is_hermitian = ishermitian(A)
93+
else:
94+
is_hermitian = False
9195

9296
self.is_hermitian = is_hermitian
9397

tests/test_Scipy.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from pymatsolver import Solver, Diagonal, SolverCG, SolverLU
22
import scipy.sparse as sp
3+
from scipy.sparse.linalg import aslinearoperator
34
import numpy as np
45
import numpy.testing as npt
56
import pytest
@@ -57,6 +58,17 @@ def test_solver(a_matrix, n_rhs, solver):
5758

5859
npt.assert_allclose(x, b, atol=tol)
5960

61+
@pytest.mark.parametrize('dtype', [np.float64, np.complex128])
62+
def test_iterative_solver_linear_op(dtype):
63+
n = 10
64+
A = aslinearoperator(sp.eye(n).astype(dtype))
65+
66+
Ainv = SolverCG(A)
67+
68+
rhs = np.linspace(0.9, 1.1, n)
69+
70+
npt.assert_allclose(Ainv @ rhs, rhs)
71+
6072
@pytest.mark.parametrize('n_rhs', [1, 5])
6173
def test_diag_solver(n_rhs):
6274
n = 10

tests/test_Wrappers.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,14 @@ def test_wrapper_unused_kwargs(solver_class):
1414
with pytest.warns(UnusedArgumentWarning, match="Unused keyword argument.*"):
1515
solver_class(A, not_a_keyword_arg=True)
1616

17+
1718
def test_good_arg_iterative():
1819
# Ensure this doesn't throw a warning!
1920
with warnings.catch_warnings():
2021
warnings.simplefilter("error")
2122
SolverCG(sp.eye(10), rtol=1e-4)
2223

24+
2325
def test_good_arg_direct():
2426
# Ensure this doesn't throw a warning!
2527
with warnings.catch_warnings():
@@ -40,7 +42,6 @@ def __init__(self, A):
4042
WrappedClass(sp.eye(2))
4143

4244

43-
4445
def test_direct_clean_function():
4546
def direct_func(A):
4647
class Empty():
@@ -67,6 +68,7 @@ def clean(self):
6768
Ainv.clean()
6869
assert Ainv.solver.A is None
6970

71+
7072
def test_iterative_deprecations():
7173

7274
with pytest.warns(FutureWarning, match="check_accuracy and accuracy_tol were unused.*"):
@@ -75,6 +77,7 @@ def test_iterative_deprecations():
7577
with pytest.warns(FutureWarning, match="check_accuracy and accuracy_tol were unused.*"):
7678
wrap_iterative(lambda a, x: x, accuracy_tol=1E-3)
7779

80+
7881
def test_non_scipy_iterative():
7982
def iterative_solver(A, x):
8083
return x

0 commit comments

Comments
 (0)