Skip to content

Commit 684f161

Browse files
authored
Use the Model type alias on more places (#61)
1 parent 7322421 commit 684f161

File tree

9 files changed

+66
-57
lines changed

9 files changed

+66
-57
lines changed

src/inversion_ideas/base/conditions.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
from rich.panel import Panel
1313
from rich.tree import Tree
1414

15+
from ..typing import Model
16+
1517

1618
def _get_info_title(condition, model) -> str:
1719
"""
@@ -30,9 +32,9 @@ class Condition(ABC):
3032
"""
3133

3234
@abstractmethod
33-
def __call__(self, model) -> bool: ...
35+
def __call__(self, model: Model) -> bool: ...
3436

35-
def update(self, model): # noqa: B027
37+
def update(self, model: Model): # noqa: B027
3638
"""
3739
Update the condition.
3840
"""
@@ -48,7 +50,7 @@ def initialize(self): # noqa: B027
4850
# necessary. The base class implements it to provide a common interface, even
4951
# for those children that don't implement it.
5052

51-
def info(self, model) -> Tree:
53+
def info(self, model: Model) -> Tree:
5254
"""
5355
Display information about the condition for a given model.
5456
"""
@@ -86,17 +88,17 @@ def __init__(self, condition_a, condition_b):
8688
self.condition_b = condition_b
8789

8890
@abstractmethod
89-
def __call__(self, model) -> bool: ...
91+
def __call__(self, model: Model) -> bool: ...
9092

91-
def update(self, model):
93+
def update(self, model: Model):
9294
"""
9395
Update the underlying conditions.
9496
"""
9597
for condition in (self.condition_a, self.condition_b):
9698
if hasattr(condition, "update"):
9799
condition.update(model)
98100

99-
def info(self, model) -> Tree:
101+
def info(self, model: Model) -> Tree:
100102
status = self(model)
101103
checkbox = "x" if status else " "
102104
color = "green" if status else "red"
@@ -128,7 +130,7 @@ class LogicalAnd(_Mixin, Condition):
128130
Mixin condition for the AND operation between two other conditions.
129131
"""
130132

131-
def __call__(self, model) -> bool:
133+
def __call__(self, model: Model) -> bool:
132134
return self.condition_a(model) and self.condition_b(model)
133135

134136

@@ -137,7 +139,7 @@ class LogicalOr(_Mixin, Condition):
137139
Mixin condition for the OR operation between two other conditions.
138140
"""
139141

140-
def __call__(self, model) -> bool:
142+
def __call__(self, model: Model) -> bool:
141143
return self.condition_a(model) or self.condition_b(model)
142144

143145

@@ -146,5 +148,5 @@ class LogicalXor(_Mixin, Condition):
146148
Mixin condition for the XOR operation between two other conditions.
147149
"""
148150

149-
def __call__(self, model) -> bool:
151+
def __call__(self, model: Model) -> bool:
150152
return self.condition_a(model) ^ self.condition_b(model)

src/inversion_ideas/base/objective_function.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from abc import ABC, abstractmethod
66
from collections.abc import Iterable, Iterator
77
from copy import copy
8+
from numbers import Real
89
from typing import Self
910

1011
import numpy as np
@@ -56,7 +57,7 @@ def hessian(
5657
"""
5758

5859
@abstractmethod
59-
def hessian_diagonal(self, model) -> npt.NDArray[np.float64]:
60+
def hessian_diagonal(self, model: Model) -> npt.NDArray[np.float64]:
6061
"""
6162
Diagonal of the Hessian.
6263
"""
@@ -135,7 +136,7 @@ def n_params(self) -> int:
135136
"""
136137
return self.function.n_params
137138

138-
def __call__(self, model):
139+
def __call__(self, model: Model):
139140
"""
140141
Evaluate the objective function.
141142
"""
@@ -155,7 +156,7 @@ def hessian(
155156
"""
156157
return self.multiplier * self.function.hessian(model)
157158

158-
def hessian_diagonal(self, model) -> npt.NDArray[np.float64]:
159+
def hessian_diagonal(self, model: Model) -> npt.NDArray[np.float64]:
159160
"""
160161
Diagonal of the Hessian.
161162
"""
@@ -187,11 +188,11 @@ def _repr_latex_(self):
187188
phi_str = f"[ {phi_str} ]"
188189
return rf"${multiplier_str} \, {phi_str}$"
189190

190-
def __imul__(self, value) -> Self:
191+
def __imul__(self, value: Real) -> Self:
191192
self.multiplier *= value
192193
return self
193194

194-
def __itruediv__(self, value) -> Self:
195+
def __itruediv__(self, value: Real) -> Self:
195196
self.multiplier /= value
196197
return self
197198

@@ -201,7 +202,7 @@ class Combo(Objective):
201202
Sum of objective functions.
202203
"""
203204

204-
def __init__(self, functions):
205+
def __init__(self, functions: list[Objective]):
205206
_get_n_params(functions) # check if functions have the same n_params
206207
self._functions = functions
207208

@@ -228,7 +229,7 @@ def n_params(self) -> int:
228229
"""
229230
return _get_n_params(self.functions)
230231

231-
def __call__(self, model):
232+
def __call__(self, model: Model):
232233
"""
233234
Evaluate the objective function.
234235
"""
@@ -248,7 +249,7 @@ def hessian(
248249
"""
249250
return _sum(f.hessian(model) for f in self.functions)
250251

251-
def hessian_diagonal(self, model) -> npt.NDArray[np.float64]:
252+
def hessian_diagonal(self, model: Model) -> npt.NDArray[np.float64]:
252253
"""
253254
Diagonal of the Hessian.
254255
"""

src/inversion_ideas/base/simulation.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from numpy.typing import NDArray
99
from scipy.sparse.linalg import LinearOperator
1010

11+
from ..typing import Model
12+
1113

1214
class Simulation(ABC):
1315
"""
@@ -33,13 +35,13 @@ def n_data(self) -> int:
3335
"""
3436

3537
@abstractmethod
36-
def __call__(self, model) -> NDArray[np.float64]:
38+
def __call__(self, model: Model) -> NDArray[np.float64]:
3739
"""
3840
Evaluate simulation for a given model.
3941
"""
4042

4143
@abstractmethod
42-
def jacobian(self, model) -> NDArray[np.float64] | LinearOperator:
44+
def jacobian(self, model: Model) -> NDArray[np.float64] | LinearOperator:
4345
"""
4446
Jacobian matrix for a given model.
4547
"""

src/inversion_ideas/conditions.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class CustomCondition(Condition):
3030
def __init__(self, func: Callable[[Model], bool]):
3131
self.func = func
3232

33-
def __call__(self, model) -> bool:
33+
def __call__(self, model: Model) -> bool:
3434
return self.func(model)
3535

3636
@classmethod
@@ -87,14 +87,14 @@ def __init__(self, data_misfit: DataMisfit, chi_target=1.0):
8787
self.data_misfit = data_misfit
8888
self.chi_target = chi_target
8989

90-
def __call__(self, model) -> bool:
90+
def __call__(self, model: Model) -> bool:
9191
"""
9292
Check if condition has been met.
9393
"""
9494
chi = self.data_misfit.chi_factor(model)
9595
return float(chi) < self.chi_target
9696

97-
def info(self, model) -> Tree:
97+
def info(self, model: Model) -> Tree:
9898
tree = super().info(model)
9999
tree.add("Condition: chi < chi_target")
100100
tree.add(f"chi = {self.data_misfit.chi_factor(model):.2e}")
@@ -148,20 +148,20 @@ def __init__(self, rtol: float = 1e-3, atol: float = 0.0):
148148
self.rtol = rtol
149149
self.atol = atol
150150

151-
def __call__(self, model) -> bool:
151+
def __call__(self, model: Model) -> bool:
152152
if not hasattr(self, "previous"):
153153
return False
154154
diff = float(np.linalg.norm(model - self.previous))
155155
previous = float(np.linalg.norm(self.previous))
156156
return diff <= max(previous * self.rtol, self.atol)
157157

158-
def update(self, model):
158+
def update(self, model: Model):
159159
"""
160160
Cache model as the ``previous`` one.
161161
"""
162162
self.previous = model
163163

164-
def info(self, model) -> Tree:
164+
def info(self, model: Model) -> Tree:
165165
tree = super().info(model)
166166
diff = float(np.linalg.norm(model - self.previous))
167167
previous = float(np.linalg.norm(self.previous))
@@ -229,20 +229,20 @@ def __init__(self, objective_function: Objective, rtol: float = 1e-3, atol=0.0):
229229
self.rtol = rtol
230230
self.atol = atol
231231

232-
def __call__(self, model) -> bool:
232+
def __call__(self, model: Model) -> bool:
233233
if not hasattr(self, "previous"):
234234
return False
235235
diff = abs(self.objective_function(model) - self.previous)
236236
previous = abs(self.previous)
237237
return diff <= max(previous * self.rtol, self.atol)
238238

239-
def update(self, model):
239+
def update(self, model: Model):
240240
"""
241241
Cache value of objective function with model as the ``previous`` one.
242242
"""
243243
self.previous: float = float(self.objective_function(model))
244244

245-
def info(self, model) -> Tree:
245+
def info(self, model: Model) -> Tree:
246246
tree = super().info(model)
247247
diff = abs(self.objective_function(model) - self.previous)
248248
previous = abs(self.previous)
@@ -253,7 +253,7 @@ def info(self, model) -> Tree:
253253
tree.add(f"atol = {self.atol:.2e}")
254254
return tree
255255

256-
def ratio(self, model) -> float:
256+
def ratio(self, model: Model) -> float:
257257
"""
258258
Ratio ``|φ(m) - φ(m_prev)|/|φ(m_prev)|``.
259259
"""

src/inversion_ideas/data_misfit.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from scipy.sparse.linalg import LinearOperator, aslinearoperator
99

1010
from .base import Objective
11+
from .typing import Model
1112
from .utils import cache_on_model
1213

1314

@@ -92,23 +93,25 @@ def __init__(
9293
self.set_name("d")
9394

9495
@cache_on_model
95-
def __call__(self, model) -> float:
96+
def __call__(self, model: Model) -> float:
9697
# TODO:
9798
# Cache invalidation: we should clean the cache if data or uncertainties change.
9899
# Or they should be immutable.
99100
residual = self.residual(model)
100101
weights_matrix = self.weights_matrix
101102
return residual.T @ weights_matrix.T @ weights_matrix @ residual
102103

103-
def gradient(self, model) -> npt.NDArray[np.float64]:
104+
def gradient(self, model: Model) -> npt.NDArray[np.float64]:
104105
"""
105106
Gradient vector.
106107
"""
107108
jac = self.simulation.jacobian(model)
108109
weights_matrix = self.weights_matrix
109110
return -2 * jac.T @ (weights_matrix.T @ weights_matrix @ self.residual(model))
110111

111-
def hessian(self, model) -> npt.NDArray[np.float64] | sparray | LinearOperator:
112+
def hessian(
113+
self, model: Model
114+
) -> npt.NDArray[np.float64] | sparray | LinearOperator:
112115
"""
113116
Hessian matrix.
114117
"""
@@ -118,7 +121,7 @@ def hessian(self, model) -> npt.NDArray[np.float64] | sparray | LinearOperator:
118121
weights_matrix = aslinearoperator(self.weights_matrix)
119122
return 2 * jac.T @ weights_matrix.T @ weights_matrix @ jac
120123

121-
def hessian_diagonal(self, model) -> npt.NDArray[np.float64]:
124+
def hessian_diagonal(self, model: Model) -> npt.NDArray[np.float64]:
122125
"""
123126
Diagonal of the Hessian.
124127
"""
@@ -146,7 +149,7 @@ def n_data(self):
146149
"""
147150
return self.data.size
148151

149-
def residual(self, model):
152+
def residual(self, model: Model):
150153
r"""
151154
Residual vector.
152155
@@ -187,7 +190,7 @@ def weights_matrix(self) -> dia_array:
187190
"""
188191
return diags_array(1 / self.uncertainty)
189192

190-
def chi_factor(self, model):
193+
def chi_factor(self, model: Model):
191194
"""
192195
Compute chi factor.
193196

src/inversion_ideas/minimize/_minimizers.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from typing import Any
88

99
import numpy as np
10-
import numpy.typing as npt
1110
from scipy.sparse.linalg import cg
1211

1312
from ..base import Condition, Minimizer, Objective
@@ -66,7 +65,7 @@ def __call__(
6665
preconditioner: Preconditioner
6766
| Callable[[Model], Preconditioner]
6867
| None = None,
69-
) -> Generator[npt.NDArray[np.float64]]:
68+
) -> Generator[Model]:
7069
"""
7170
Create iterator over Gauss-Newton minimization.
7271
"""

src/inversion_ideas/regularization/_general.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from .._utils import prod_arrays
1010
from ..base import Objective
11+
from ..typing import Model
1112

1213

1314
class TikhonovZero(Objective):
@@ -62,7 +63,7 @@ def __init__(
6263
)
6364
self.set_name("0")
6465

65-
def __call__(self, model) -> float:
66+
def __call__(self, model: Model) -> float:
6667
"""
6768
Evaluate the regularization on a given model.
6869
@@ -75,7 +76,7 @@ def __call__(self, model) -> float:
7576
weights_matrix = self.weights_matrix
7677
return model_diff.T @ weights_matrix.T @ weights_matrix @ model_diff
7778

78-
def gradient(self, model):
79+
def gradient(self, model: Model):
7980
"""
8081
Gradient vector.
8182
@@ -88,7 +89,7 @@ def gradient(self, model):
8889
weights_matrix = self.weights_matrix
8990
return 2 * weights_matrix.T @ weights_matrix @ model_diff
9091

91-
def hessian(self, model): # noqa: ARG002
92+
def hessian(self, model: Model): # noqa: ARG002
9293
"""
9394
Hessian matrix.
9495
@@ -100,7 +101,7 @@ def hessian(self, model): # noqa: ARG002
100101
weights_matrix = self.weights_matrix
101102
return 2 * weights_matrix.T @ weights_matrix
102103

103-
def hessian_diagonal(self, model) -> npt.NDArray[np.float64]:
104+
def hessian_diagonal(self, model: Model) -> npt.NDArray[np.float64]:
104105
"""
105106
Diagonal of the Hessian.
106107

0 commit comments

Comments
 (0)