|
10 | 10 |
|
11 | 11 | import numpy as np |
12 | 12 | import numpy.typing as npt |
13 | | -from scipy.sparse import sparray |
14 | 13 | from scipy.sparse.linalg import LinearOperator, aslinearoperator |
15 | 14 |
|
16 | | -from ..typing import Model |
| 15 | +from ..typing import Model, SparseArray |
17 | 16 |
|
18 | 17 |
|
19 | 18 | class Objective(ABC): |
@@ -51,7 +50,7 @@ def gradient(self, model: Model) -> npt.NDArray[np.float64]: |
51 | 50 | @abstractmethod |
52 | 51 | def hessian( |
53 | 52 | self, model: Model |
54 | | - ) -> npt.NDArray[np.float64] | sparray | LinearOperator: |
| 53 | + ) -> npt.NDArray[np.float64] | SparseArray | LinearOperator: |
55 | 54 | """ |
56 | 55 | Evaluate the hessian of the objective function for a given model. |
57 | 56 | """ |
@@ -150,7 +149,7 @@ def gradient(self, model: Model) -> npt.NDArray[np.float64]: |
150 | 149 |
|
151 | 150 | def hessian( |
152 | 151 | self, model: Model |
153 | | - ) -> npt.NDArray[np.float64] | sparray | LinearOperator: |
| 152 | + ) -> npt.NDArray[np.float64] | SparseArray | LinearOperator: |
154 | 153 | """ |
155 | 154 | Evaluate the hessian of the objective function for a given model. |
156 | 155 | """ |
@@ -243,7 +242,7 @@ def gradient(self, model: Model) -> npt.NDArray[np.float64]: |
243 | 242 |
|
244 | 243 | def hessian( |
245 | 244 | self, model: Model |
246 | | - ) -> npt.NDArray[np.float64] | sparray | LinearOperator: |
| 245 | + ) -> npt.NDArray[np.float64] | SparseArray | LinearOperator: |
247 | 246 | """ |
248 | 247 | Evaluate the hessian of the objective function for a given model. |
249 | 248 | """ |
@@ -362,8 +361,8 @@ def _get_n_params(functions: list) -> int: |
362 | 361 |
|
363 | 362 |
|
364 | 363 | def _sum( |
365 | | - operators: Iterator[npt.NDArray | sparray | LinearOperator], |
366 | | -) -> npt.NDArray | sparray | LinearOperator: |
| 364 | + operators: Iterator[npt.NDArray | SparseArray | LinearOperator], |
| 365 | +) -> npt.NDArray | SparseArray | LinearOperator: |
367 | 366 | """ |
368 | 367 | Sum objects within an iterator. |
369 | 368 |
|
|
0 commit comments