Skip to content

Commit e27f9a0

Browse files
committed
cleaner typing in scipy.stats.qmc
1 parent 1eacdc0 commit e27f9a0

File tree

2 files changed

+83
-84
lines changed

2 files changed

+83
-84
lines changed

scipy-stubs/stats/_qmc.pyi

Lines changed: 63 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import abc
22
import numbers
33
from collections.abc import Callable, Mapping, Sequence
44
from typing import Any, ClassVar, Final, Literal, Protocol, TypeAlias, overload, type_check_only
5-
from typing_extensions import Self, TypeVar
5+
from typing_extensions import Self, TypeVar, override
66

77
import numpy as np
88
import numpy.typing as npt
@@ -28,30 +28,31 @@ __all__ = [
2828
]
2929

3030
_RNGT = TypeVar("_RNGT", bound=np.random.Generator | np.random.RandomState)
31-
_SCT = TypeVar("_SCT", bound=np.generic)
31+
_SCT0 = TypeVar("_SCT0", bound=np.generic, default=np.float64)
3232
_SCT_co = TypeVar("_SCT_co", covariant=True, bound=np.generic)
3333
_SCT_fc = TypeVar("_SCT_fc", bound=np.inexact[Any])
3434
_ArrayT_f = TypeVar("_ArrayT_f", bound=npt.NDArray[np.floating[Any]])
3535
_N = TypeVar("_N", bound=int)
3636

37+
# the `__len__` ensures that scalar types like `np.generic` are excluded
3738
@type_check_only
3839
class _CanLenArray(Protocol[_SCT_co]):
3940
def __len__(self, /) -> int: ...
4041
def __array__(self, /) -> npt.NDArray[_SCT_co]: ...
4142

4243
_Scalar_f_co: TypeAlias = np.floating[Any] | np.integer[Any] | np.bool_
44+
_ScalarLike_f: TypeAlias = float | np.floating[Any]
4345

44-
_Array1D: TypeAlias = onpt.Array[tuple[int], _SCT]
45-
_Array1D_f8: TypeAlias = _Array1D[np.float64]
46-
_Array2D: TypeAlias = onpt.Array[tuple[int, int], _SCT]
47-
_Array2D_f8: TypeAlias = _Array2D[np.float64]
46+
_Array1D: TypeAlias = onpt.Array[tuple[int], _SCT0]
47+
_Array2D: TypeAlias = onpt.Array[tuple[int, int], _SCT0]
48+
_Array1D_f_co: TypeAlias = _Array1D[_Scalar_f_co]
4849

49-
_Any1D_f: TypeAlias = _CanLenArray[np.floating[Any]] | Sequence[float | np.floating[Any]]
50+
_Any1D_f: TypeAlias = _CanLenArray[np.floating[Any]] | Sequence[_ScalarLike_f]
5051
_Any1D_f_co: TypeAlias = _CanLenArray[_Scalar_f_co] | Sequence[AnyReal]
51-
_Any2D_f: TypeAlias = _CanLenArray[np.floating[Any]] | Sequence[Sequence[float | np.floating[Any]]] | Sequence[_Any1D_f]
52+
_Any2D_f: TypeAlias = _CanLenArray[np.floating[Any]] | Sequence[Sequence[_ScalarLike_f]] | Sequence[_Any1D_f]
5253
_Any2D_f_co: TypeAlias = _CanLenArray[_Scalar_f_co] | Sequence[Sequence[AnyReal]] | Sequence[_Any1D_f_co]
5354

54-
_MethodOptimize: TypeAlias = Literal["random-cd", "lloyd"]
55+
_MethodQMC: TypeAlias = Literal["random-cd", "lloyd"]
5556
_MethodDisc: TypeAlias = Literal["CD", "WD", "MD", "L2-star"]
5657
_MethodDist: TypeAlias = Literal["mindist", "mst"]
5758
_MetricDist: TypeAlias = _MetricKind | _MetricCallback
@@ -68,8 +69,8 @@ class QMCEngine(abc.ABC):
6869
num_generated: int
6970

7071
@abc.abstractmethod
71-
def __init__(self, /, d: AnyInt, *, optimization: _MethodOptimize | None = None, seed: Seed | None = None) -> None: ...
72-
def random(self, /, n: opt.AnyInt = 1, *, workers: AnyInt = 1) -> _Array2D_f8: ...
72+
def __init__(self, /, d: AnyInt, *, optimization: _MethodQMC | None = None, seed: Seed | None = None) -> None: ...
73+
def random(self, /, n: opt.AnyInt = 1, *, workers: AnyInt = 1) -> _Array2D: ...
7374
def integers(
7475
self,
7576
/,
@@ -93,13 +94,13 @@ class Halton(QMCEngine):
9394
d: AnyInt,
9495
*,
9596
scramble: bool = True,
96-
optimization: _MethodOptimize | None = None,
97+
optimization: _MethodQMC | None = None,
9798
seed: Seed | None = None,
9899
) -> None: ...
99100

100101
class LatinHypercube(QMCEngine):
101102
scramble: bool
102-
lhs_method: Callable[[int | np.integer[Any]], _Array2D_f8]
103+
lhs_method: Callable[[int | np.integer[Any]], _Array2D]
103104

104105
def __init__(
105106
self,
@@ -108,7 +109,7 @@ class LatinHypercube(QMCEngine):
108109
*,
109110
scramble: bool = True,
110111
strength: int = 1,
111-
optimization: _MethodOptimize | None = None,
112+
optimization: _MethodQMC | None = None,
112113
seed: Seed | None = None,
113114
) -> None: ...
114115

@@ -126,10 +127,10 @@ class Sobol(QMCEngine):
126127
*,
127128
scramble: op.CanBool = True,
128129
bits: AnyInt | None = None,
130+
optimization: _MethodQMC | None = None,
129131
seed: Seed | None = None,
130-
optimization: _MethodOptimize | None = None,
131132
) -> None: ...
132-
def random_base2(self, /, m: AnyInt) -> _Array2D_f8: ...
133+
def random_base2(self, /, m: AnyInt) -> _Array2D: ...
133134

134135
@type_check_only
135136
class _HypersphereMethod(Protocol):
@@ -139,7 +140,7 @@ class _HypersphereMethod(Protocol):
139140
center: npt.NDArray[_Scalar_f_co],
140141
radius: AnyReal,
141142
candidates: AnyInt = 1,
142-
) -> _Array2D_f8: ...
143+
) -> _Array2D: ...
143144

144145
class PoissonDisk(QMCEngine):
145146
hypersphere_method: Final[_HypersphereMethod]
@@ -150,7 +151,7 @@ class PoissonDisk(QMCEngine):
150151
cell_size: Final[np.float64]
151152
grid_size: Final[_Array1D[np.int_]]
152153

153-
sample_pool: list[_Array1D_f8]
154+
sample_pool: list[_Array1D]
154155
sample_grid: npt.NDArray[np.float32]
155156

156157
def __init__(
@@ -161,13 +162,19 @@ class PoissonDisk(QMCEngine):
161162
radius: AnyReal = 0.05,
162163
hypersphere: Literal["volume", "surface"] = "volume",
163164
ncandidates: AnyInt = 30,
164-
optimization: _MethodOptimize | None = None,
165+
optimization: _MethodQMC | None = None,
165166
seed: Seed | None = None,
166167
) -> None: ...
167-
def fill_space(self, /) -> _Array2D_f8: ...
168+
def fill_space(self, /) -> _Array2D: ...
168169

169-
class MultivariateNormalQMC:
170-
engine: Final[QMCEngine]
170+
@type_check_only
171+
class _QMCDistribution:
172+
engine: Final[QMCEngine] # defaults to `Sobol`
173+
def __init__(self, /, *, engine: QMCEngine | None = None, seed: Seed | None = None) -> None: ...
174+
def random(self, /, n: AnyInt = 1) -> _Array2D: ...
175+
176+
class MultivariateNormalQMC(_QMCDistribution):
177+
@override
171178
def __init__(
172179
self,
173180
/,
@@ -179,13 +186,12 @@ class MultivariateNormalQMC:
179186
engine: QMCEngine | None = None,
180187
seed: Seed | None = None,
181188
) -> None: ...
182-
def random(self, /, n: AnyInt = 1) -> _Array2D_f8: ...
183189

184-
class MultinomialQMC:
190+
class MultinomialQMC(_QMCDistribution):
185191
pvals: Final[_Array1D[np.floating[Any]]]
186192
n_trials: Final[AnyInt]
187-
engine: Final[QMCEngine]
188193

194+
@override
189195
def __init__(
190196
self,
191197
/,
@@ -195,55 +201,64 @@ class MultinomialQMC:
195201
engine: QMCEngine | None = None,
196202
seed: Seed | None = None,
197203
) -> None: ...
198-
def random(self, /, n: AnyInt = 1) -> _Array2D_f8: ...
199204

200205
#
201206
@overload
202207
def check_random_state(seed: int | np.integer[Any] | numbers.Integral | None = None) -> np.random.Generator: ...
203208
@overload
204209
def check_random_state(seed: _RNGT) -> _RNGT: ...
210+
211+
#
205212
def scale(
206213
sample: _Any2D_f,
207214
l_bounds: _Any1D_f_co | AnyReal,
208215
u_bounds: _Any1D_f_co | AnyReal,
209216
*,
210217
reverse: op.CanBool = False,
211-
) -> _Array2D_f8: ...
218+
) -> _Array2D: ...
219+
220+
#
212221
def discrepancy(
213222
sample: _Any2D_f,
214223
*,
215224
iterative: op.CanBool = False,
216225
method: _MethodDisc = "CD",
217-
workers: op.CanInt = 1,
218-
) -> float: ...
219-
def geometric_discrepancy(sample: _Any2D_f, method: _MethodDist = "mindist", metric: _MetricDist = "euclidean") -> np.float64: ...
220-
def update_discrepancy(x_new: _Any1D_f, sample: _Any2D_f, initial_disc: op.CanFloat) -> float: ...
226+
workers: opt.AnyInt = 1,
227+
) -> float | np.float64: ...
228+
229+
#
230+
def geometric_discrepancy(
231+
sample: _Any2D_f,
232+
method: _MethodDist = "mindist",
233+
metric: _MetricDist = "euclidean",
234+
) -> float | np.float64: ...
235+
def update_discrepancy(x_new: _Any1D_f, sample: _Any2D_f, initial_disc: opt.AnyFloat) -> float: ...
221236
def primes_from_2_to(n: AnyInt) -> _Array1D[np.int_]: ...
222237
def n_primes(n: AnyInt) -> list[int] | _Array1D[np.int_]: ...
223238

224239
#
225-
def _select_optimizer(optimization: _MethodOptimize | None, config: Mapping[str, object]) -> _FuncOptimize | None: ...
240+
def _select_optimizer(optimization: _MethodQMC | None, config: Mapping[str, object]) -> _FuncOptimize | None: ...
226241
def _random_cd(best_sample: _ArrayT_f, n_iters: AnyInt, n_nochange: AnyInt, rng: RNG) -> _ArrayT_f: ...
227-
def _l1_norm(sample: _Any2D_f) -> np.float64: ...
242+
def _l1_norm(sample: _Any2D_f) -> float | np.float64: ...
228243
def _lloyd_iteration(sample: _ArrayT_f, decay: AnyReal, qhull_options: str | None) -> _ArrayT_f: ...
229244
def _lloyd_centroidal_voronoi_tessellation(
230245
sample: _Any2D_f,
231246
*,
232247
tol: AnyReal = 1e-5,
233248
maxiter: AnyInt = 10,
234249
qhull_options: str | None = None,
235-
) -> _Array2D_f8: ...
250+
) -> _Array2D: ...
251+
def _ensure_in_unit_hypercube(sample: _Any2D_f) -> _Array2D: ...
236252

237253
#
238-
def _ensure_in_unit_hypercube(sample: _Any2D_f) -> _Array2D_f8: ...
239254
@overload
240255
def _perturb_discrepancy(
241256
sample: _Array2D[np.integer[Any] | np.bool_],
242257
i1: op.CanIndex,
243258
i2: op.CanIndex,
244259
k: op.CanIndex,
245260
disc: AnyReal,
246-
) -> np.float64: ...
261+
) -> float | np.float64: ...
247262
@overload
248263
def _perturb_discrepancy(
249264
sample: _Array2D[_SCT_fc],
@@ -252,10 +267,14 @@ def _perturb_discrepancy(
252267
k: op.CanIndex,
253268
disc: AnyReal,
254269
) -> _SCT_fc: ...
270+
271+
#
255272
@overload
256273
def _van_der_corput_permutation(base: op.CanIndex, *, random_state: Seed | None = None) -> _Array2D[np.int_]: ...
257274
@overload
258-
def _van_der_corput_permutation(base: op.CanFloat, *, random_state: Seed | None = None) -> _Array2D_f8: ...
275+
def _van_der_corput_permutation(base: op.CanFloat, *, random_state: Seed | None = None) -> _Array2D: ...
276+
277+
#
259278
def van_der_corput(
260279
n: op.CanInt,
261280
base: AnyInt = 2,
@@ -264,18 +283,16 @@ def van_der_corput(
264283
scramble: op.CanBool = False,
265284
permutations: _ArrayLikeInt | None = None,
266285
seed: Seed | None = None,
267-
workers: op.CanInt = 1,
268-
) -> _Array1D_f8: ...
286+
workers: opt.AnyInt = 1,
287+
) -> _Array1D: ...
269288

270289
#
271290
@overload
272-
def _validate_workers(workers: op.CanInt[Literal[1]] | op.CanIndex[Literal[1]] | Literal[1] = 1) -> Literal[1]: ...
291+
def _validate_workers(workers: Literal[1] = 1) -> Literal[1]: ...
273292
@overload
274293
def _validate_workers(workers: _N) -> _N: ...
275294
@overload
276-
def _validate_workers(workers: op.CanInt[_N] | op.CanIndex[_N]) -> _N: ...
277-
def _validate_bounds(
278-
l_bounds: _Any1D_f_co,
279-
u_bounds: _Any1D_f_co,
280-
d: AnyInt,
281-
) -> tuple[_Array1D[_Scalar_f_co], _Array1D[_Scalar_f_co]]: ...
295+
def _validate_workers(workers: opt.AnyInt[_N]) -> _N: ...
296+
297+
#
298+
def _validate_bounds(l_bounds: _Any1D_f_co, u_bounds: _Any1D_f_co, d: AnyInt) -> tuple[_Array1D_f_co, _Array1D_f_co]: ...

scipy-stubs/stats/_qmc_cy.pyi

Lines changed: 20 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,24 @@
1+
from typing import TypeAlias
2+
13
import numpy as np
4+
import optype as op
25
import optype.numpy as onpt
3-
from optype import CanBool, CanFloat, CanInt
6+
import optype.typing as opt
7+
8+
_Vector_i8: TypeAlias = onpt.Array[tuple[int, int], np.int64]
9+
_Vector_f8: TypeAlias = onpt.Array[tuple[int], np.float64]
10+
_Matrix_f8: TypeAlias = onpt.Array[tuple[int, int], np.float64]
411

5-
def _cy_wrapper_centered_discrepancy(
6-
sample: onpt.Array[tuple[int, int], np.float64],
7-
iterative: CanBool,
8-
workers: CanInt,
9-
) -> float: ...
10-
def _cy_wrapper_wrap_around_discrepancy(
11-
sample: onpt.Array[tuple[int, int], np.float64],
12-
iterative: CanBool,
13-
workers: CanInt,
14-
) -> float: ...
15-
def _cy_wrapper_mixture_discrepancy(
16-
sample: onpt.Array[tuple[int, int], np.float64],
17-
iterative: CanBool,
18-
workers: CanInt,
19-
) -> float: ...
20-
def _cy_wrapper_l2_star_discrepancy(
21-
sample: onpt.Array[tuple[int, int], np.float64],
22-
iterative: CanBool,
23-
workers: CanInt,
24-
) -> float: ...
25-
def _cy_wrapper_update_discrepancy(
26-
x_new_view: onpt.Array[tuple[int], np.float64],
27-
sample_view: onpt.Array[tuple[int, int], np.float64],
28-
initial_disc: CanFloat,
29-
) -> float: ...
30-
def _cy_van_der_corput(
31-
n: CanInt,
32-
base: CanInt,
33-
start_index: CanInt,
34-
workers: CanInt,
35-
) -> onpt.Array[tuple[int], np.float64]: ...
12+
def _cy_wrapper_centered_discrepancy(sample: _Matrix_f8, iterative: op.CanBool, workers: opt.AnyInt) -> float: ...
13+
def _cy_wrapper_wrap_around_discrepancy(sample: _Matrix_f8, iterative: op.CanBool, workers: opt.AnyInt) -> float: ...
14+
def _cy_wrapper_mixture_discrepancy(sample: _Matrix_f8, iterative: op.CanBool, workers: opt.AnyInt) -> float: ...
15+
def _cy_wrapper_l2_star_discrepancy(sample: _Matrix_f8, iterative: op.CanBool, workers: opt.AnyInt) -> float: ...
16+
def _cy_wrapper_update_discrepancy(x_new_view: _Vector_f8, sample_view: _Matrix_f8, initial_disc: opt.AnyFloat) -> float: ...
17+
def _cy_van_der_corput(n: opt.AnyInt, base: opt.AnyInt, start_index: opt.AnyInt, workers: opt.AnyInt) -> _Vector_f8: ...
3618
def _cy_van_der_corput_scrambled(
37-
n: CanInt,
38-
base: CanInt,
39-
start_index: CanInt,
40-
permutations: onpt.Array[tuple[int, int], np.int64],
41-
workers: CanInt,
42-
) -> onpt.Array[tuple[int], np.float64]: ...
19+
n: opt.AnyInt,
20+
base: opt.AnyInt,
21+
start_index: opt.AnyInt,
22+
permutations: _Vector_i8,
23+
workers: opt.AnyInt,
24+
) -> _Vector_f8: ...

0 commit comments

Comments
 (0)