Skip to content

Commit e01c558

Browse files
authored
Merge pull request #172 from jorenham/fix-168
fix incorrectly rejected `ndarray` in `scipy.stats.qmc`
2 parents 7f5eb48 + e27f9a0 commit e01c558

File tree

3 files changed

+124
-103
lines changed

3 files changed

+124
-103
lines changed

scipy-stubs/stats/_qmc.pyi

Lines changed: 91 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
import abc
22
import numbers
33
from collections.abc import Callable, Mapping, Sequence
4-
from typing import Any, ClassVar, Final, Literal, Protocol, TypeAlias, TypeVar, overload, type_check_only
5-
from typing_extensions import Self
4+
from typing import Any, ClassVar, Final, Literal, Protocol, TypeAlias, overload, type_check_only
5+
from typing_extensions import Self, TypeVar, override
66

77
import numpy as np
88
import numpy.typing as npt
9+
import optype as op
910
import optype.numpy as onpt
11+
import optype.typing as opt
1012
from numpy._typing import _ArrayLikeInt
11-
from optype import CanBool, CanFloat, CanIndex, CanInt, CanLen
12-
from scipy._typing import RNG, AnyBool, AnyInt, AnyReal, Seed
13+
from scipy._typing import RNG, AnyInt, AnyReal, Seed
1314
from scipy.spatial.distance import _MetricCallback, _MetricKind
1415

1516
__all__ = [
@@ -27,28 +28,31 @@ __all__ = [
2728
]
2829

2930
_RNGT = TypeVar("_RNGT", bound=np.random.Generator | np.random.RandomState)
30-
_SCT = TypeVar("_SCT", bound=np.generic)
31+
_SCT0 = TypeVar("_SCT0", bound=np.generic, default=np.float64)
3132
_SCT_co = TypeVar("_SCT_co", covariant=True, bound=np.generic)
3233
_SCT_fc = TypeVar("_SCT_fc", bound=np.inexact[Any])
3334
_ArrayT_f = TypeVar("_ArrayT_f", bound=npt.NDArray[np.floating[Any]])
3435
_N = TypeVar("_N", bound=int)
3536

37+
# the `__len__` ensures that scalar types like `np.generic` are excluded
3638
@type_check_only
37-
class _CanLenArray(CanLen, onpt.CanArray[Any, np.dtype[_SCT_co]], Protocol[_SCT_co]): ...
39+
class _CanLenArray(Protocol[_SCT_co]):
40+
def __len__(self, /) -> int: ...
41+
def __array__(self, /) -> npt.NDArray[_SCT_co]: ...
3842

3943
_Scalar_f_co: TypeAlias = np.floating[Any] | np.integer[Any] | np.bool_
44+
_ScalarLike_f: TypeAlias = float | np.floating[Any]
4045

41-
_Array1D: TypeAlias = onpt.Array[tuple[int], _SCT]
42-
_Array1D_f8: TypeAlias = _Array1D[np.float64]
43-
_Array2D: TypeAlias = onpt.Array[tuple[int, int], _SCT]
44-
_Array2D_f8: TypeAlias = _Array2D[np.float64]
46+
_Array1D: TypeAlias = onpt.Array[tuple[int], _SCT0]
47+
_Array2D: TypeAlias = onpt.Array[tuple[int, int], _SCT0]
48+
_Array1D_f_co: TypeAlias = _Array1D[_Scalar_f_co]
4549

46-
_Any1D_f: TypeAlias = _CanLenArray[np.floating[Any]] | Sequence[float | np.floating[Any]]
50+
_Any1D_f: TypeAlias = _CanLenArray[np.floating[Any]] | Sequence[_ScalarLike_f]
4751
_Any1D_f_co: TypeAlias = _CanLenArray[_Scalar_f_co] | Sequence[AnyReal]
48-
_Any2D_f: TypeAlias = _CanLenArray[np.floating[Any]] | Sequence[Sequence[float | np.floating[Any]]] | Sequence[_Any1D_f]
52+
_Any2D_f: TypeAlias = _CanLenArray[np.floating[Any]] | Sequence[Sequence[_ScalarLike_f]] | Sequence[_Any1D_f]
4953
_Any2D_f_co: TypeAlias = _CanLenArray[_Scalar_f_co] | Sequence[Sequence[AnyReal]] | Sequence[_Any1D_f_co]
5054

51-
_MethodOptimize: TypeAlias = Literal["random-cd", "lloyd"]
55+
_MethodQMC: TypeAlias = Literal["random-cd", "lloyd"]
5256
_MethodDisc: TypeAlias = Literal["CD", "WD", "MD", "L2-star"]
5357
_MethodDist: TypeAlias = Literal["mindist", "mst"]
5458
_MetricDist: TypeAlias = _MetricKind | _MetricCallback
@@ -65,20 +69,20 @@ class QMCEngine(abc.ABC):
6569
num_generated: int
6670

6771
@abc.abstractmethod
68-
def __init__(self, /, d: AnyInt, *, optimization: _MethodOptimize | None = None, seed: Seed | None = None) -> None: ...
69-
def random(self, /, n: AnyInt = 1, *, workers: AnyInt = 1) -> _Array2D_f8: ...
72+
def __init__(self, /, d: AnyInt, *, optimization: _MethodQMC | None = None, seed: Seed | None = None) -> None: ...
73+
def random(self, /, n: opt.AnyInt = 1, *, workers: AnyInt = 1) -> _Array2D: ...
7074
def integers(
7175
self,
7276
/,
7377
l_bounds: _ArrayLikeInt,
7478
*,
7579
u_bounds: _ArrayLikeInt | None = None,
76-
n: AnyInt = 1,
77-
endpoint: AnyBool = False,
78-
workers: AnyInt = 1,
80+
n: opt.AnyInt = 1,
81+
endpoint: op.CanBool = False,
82+
workers: opt.AnyInt = 1,
7983
) -> _Array2D[np.int64]: ...
8084
def reset(self, /) -> Self: ...
81-
def fast_forward(self, /, n: AnyInt) -> Self: ...
85+
def fast_forward(self, /, n: opt.AnyInt) -> Self: ...
8286

8387
class Halton(QMCEngine):
8488
base: list[int]
@@ -90,13 +94,13 @@ class Halton(QMCEngine):
9094
d: AnyInt,
9195
*,
9296
scramble: bool = True,
93-
optimization: _MethodOptimize | None = None,
97+
optimization: _MethodQMC | None = None,
9498
seed: Seed | None = None,
9599
) -> None: ...
96100

97101
class LatinHypercube(QMCEngine):
98102
scramble: bool
99-
lhs_method: Callable[[int | np.integer[Any]], _Array2D_f8]
103+
lhs_method: Callable[[int | np.integer[Any]], _Array2D]
100104

101105
def __init__(
102106
self,
@@ -105,7 +109,7 @@ class LatinHypercube(QMCEngine):
105109
*,
106110
scramble: bool = True,
107111
strength: int = 1,
108-
optimization: _MethodOptimize | None = None,
112+
optimization: _MethodQMC | None = None,
109113
seed: Seed | None = None,
110114
) -> None: ...
111115

@@ -121,12 +125,12 @@ class Sobol(QMCEngine):
121125
/,
122126
d: AnyInt,
123127
*,
124-
scramble: CanBool = True,
128+
scramble: op.CanBool = True,
125129
bits: AnyInt | None = None,
130+
optimization: _MethodQMC | None = None,
126131
seed: Seed | None = None,
127-
optimization: _MethodOptimize | None = None,
128132
) -> None: ...
129-
def random_base2(self, /, m: AnyInt) -> _Array2D_f8: ...
133+
def random_base2(self, /, m: AnyInt) -> _Array2D: ...
130134

131135
@type_check_only
132136
class _HypersphereMethod(Protocol):
@@ -136,7 +140,7 @@ class _HypersphereMethod(Protocol):
136140
center: npt.NDArray[_Scalar_f_co],
137141
radius: AnyReal,
138142
candidates: AnyInt = 1,
139-
) -> _Array2D_f8: ...
143+
) -> _Array2D: ...
140144

141145
class PoissonDisk(QMCEngine):
142146
hypersphere_method: Final[_HypersphereMethod]
@@ -147,7 +151,7 @@ class PoissonDisk(QMCEngine):
147151
cell_size: Final[np.float64]
148152
grid_size: Final[_Array1D[np.int_]]
149153

150-
sample_pool: list[_Array1D_f8]
154+
sample_pool: list[_Array1D]
151155
sample_grid: npt.NDArray[np.float32]
152156

153157
def __init__(
@@ -158,31 +162,36 @@ class PoissonDisk(QMCEngine):
158162
radius: AnyReal = 0.05,
159163
hypersphere: Literal["volume", "surface"] = "volume",
160164
ncandidates: AnyInt = 30,
161-
optimization: _MethodOptimize | None = None,
165+
optimization: _MethodQMC | None = None,
162166
seed: Seed | None = None,
163167
) -> None: ...
164-
def fill_space(self, /) -> _Array2D_f8: ...
168+
def fill_space(self, /) -> _Array2D: ...
165169

166-
class MultivariateNormalQMC:
167-
engine: Final[QMCEngine]
170+
@type_check_only
171+
class _QMCDistribution:
172+
engine: Final[QMCEngine] # defaults to `Sobol`
173+
def __init__(self, /, *, engine: QMCEngine | None = None, seed: Seed | None = None) -> None: ...
174+
def random(self, /, n: AnyInt = 1) -> _Array2D: ...
175+
176+
class MultivariateNormalQMC(_QMCDistribution):
177+
@override
168178
def __init__(
169179
self,
170180
/,
171181
mean: _Any1D_f_co,
172182
cov: _Any2D_f_co | None = None,
173183
*,
174184
cov_root: _Any2D_f_co | None = None,
175-
inv_transform: CanBool = True,
185+
inv_transform: op.CanBool = True,
176186
engine: QMCEngine | None = None,
177187
seed: Seed | None = None,
178188
) -> None: ...
179-
def random(self, /, n: AnyInt = 1) -> _Array2D_f8: ...
180189

181-
class MultinomialQMC:
190+
class MultinomialQMC(_QMCDistribution):
182191
pvals: Final[_Array1D[np.floating[Any]]]
183192
n_trials: Final[AnyInt]
184-
engine: Final[QMCEngine]
185193

194+
@override
186195
def __init__(
187196
self,
188197
/,
@@ -192,81 +201,98 @@ class MultinomialQMC:
192201
engine: QMCEngine | None = None,
193202
seed: Seed | None = None,
194203
) -> None: ...
195-
def random(self, /, n: AnyInt = 1) -> _Array2D_f8: ...
196204

197205
#
198206
@overload
199207
def check_random_state(seed: int | np.integer[Any] | numbers.Integral | None = None) -> np.random.Generator: ...
200208
@overload
201209
def check_random_state(seed: _RNGT) -> _RNGT: ...
210+
211+
#
202212
def scale(
203213
sample: _Any2D_f,
204214
l_bounds: _Any1D_f_co | AnyReal,
205215
u_bounds: _Any1D_f_co | AnyReal,
206216
*,
207-
reverse: CanBool = False,
208-
) -> _Array2D_f8: ...
209-
def discrepancy(sample: _Any2D_f, *, iterative: CanBool = False, method: _MethodDisc = "CD", workers: CanInt = 1) -> float: ...
210-
def geometric_discrepancy(sample: _Any2D_f, method: _MethodDist = "mindist", metric: _MetricDist = "euclidean") -> np.float64: ...
211-
def update_discrepancy(x_new: _Any1D_f, sample: _Any2D_f, initial_disc: CanFloat) -> float: ...
217+
reverse: op.CanBool = False,
218+
) -> _Array2D: ...
219+
220+
#
221+
def discrepancy(
222+
sample: _Any2D_f,
223+
*,
224+
iterative: op.CanBool = False,
225+
method: _MethodDisc = "CD",
226+
workers: opt.AnyInt = 1,
227+
) -> float | np.float64: ...
228+
229+
#
230+
def geometric_discrepancy(
231+
sample: _Any2D_f,
232+
method: _MethodDist = "mindist",
233+
metric: _MetricDist = "euclidean",
234+
) -> float | np.float64: ...
235+
def update_discrepancy(x_new: _Any1D_f, sample: _Any2D_f, initial_disc: opt.AnyFloat) -> float: ...
212236
def primes_from_2_to(n: AnyInt) -> _Array1D[np.int_]: ...
213237
def n_primes(n: AnyInt) -> list[int] | _Array1D[np.int_]: ...
214238

215239
#
216-
def _select_optimizer(optimization: _MethodOptimize | None, config: Mapping[str, object]) -> _FuncOptimize | None: ...
240+
def _select_optimizer(optimization: _MethodQMC | None, config: Mapping[str, object]) -> _FuncOptimize | None: ...
217241
def _random_cd(best_sample: _ArrayT_f, n_iters: AnyInt, n_nochange: AnyInt, rng: RNG) -> _ArrayT_f: ...
218-
def _l1_norm(sample: _Any2D_f) -> np.float64: ...
242+
def _l1_norm(sample: _Any2D_f) -> float | np.float64: ...
219243
def _lloyd_iteration(sample: _ArrayT_f, decay: AnyReal, qhull_options: str | None) -> _ArrayT_f: ...
220244
def _lloyd_centroidal_voronoi_tessellation(
221245
sample: _Any2D_f,
222246
*,
223247
tol: AnyReal = 1e-5,
224248
maxiter: AnyInt = 10,
225249
qhull_options: str | None = None,
226-
) -> _Array2D_f8: ...
250+
) -> _Array2D: ...
251+
def _ensure_in_unit_hypercube(sample: _Any2D_f) -> _Array2D: ...
227252

228253
#
229-
def _ensure_in_unit_hypercube(sample: _Any2D_f) -> _Array2D_f8: ...
230254
@overload
231255
def _perturb_discrepancy(
232256
sample: _Array2D[np.integer[Any] | np.bool_],
233-
i1: CanIndex,
234-
i2: CanIndex,
235-
k: CanIndex,
257+
i1: op.CanIndex,
258+
i2: op.CanIndex,
259+
k: op.CanIndex,
236260
disc: AnyReal,
237-
) -> np.float64: ...
261+
) -> float | np.float64: ...
238262
@overload
239263
def _perturb_discrepancy(
240264
sample: _Array2D[_SCT_fc],
241-
i1: CanIndex,
242-
i2: CanIndex,
243-
k: CanIndex,
265+
i1: op.CanIndex,
266+
i2: op.CanIndex,
267+
k: op.CanIndex,
244268
disc: AnyReal,
245269
) -> _SCT_fc: ...
270+
271+
#
246272
@overload
247-
def _van_der_corput_permutation(base: CanIndex, *, random_state: Seed | None = None) -> _Array2D[np.int_]: ...
273+
def _van_der_corput_permutation(base: op.CanIndex, *, random_state: Seed | None = None) -> _Array2D[np.int_]: ...
248274
@overload
249-
def _van_der_corput_permutation(base: CanFloat, *, random_state: Seed | None = None) -> _Array2D_f8: ...
275+
def _van_der_corput_permutation(base: op.CanFloat, *, random_state: Seed | None = None) -> _Array2D: ...
276+
277+
#
250278
def van_der_corput(
251-
n: CanInt,
279+
n: op.CanInt,
252280
base: AnyInt = 2,
253281
*,
254282
start_index: AnyInt = 0,
255-
scramble: CanBool = False,
283+
scramble: op.CanBool = False,
256284
permutations: _ArrayLikeInt | None = None,
257285
seed: Seed | None = None,
258-
workers: CanInt = 1,
259-
) -> _Array1D_f8: ...
286+
workers: opt.AnyInt = 1,
287+
) -> _Array1D: ...
260288

261289
#
262290
@overload
263-
def _validate_workers(workers: CanInt[Literal[1]] | CanIndex[Literal[1]] | Literal[1] = 1) -> Literal[1]: ...
291+
def _validate_workers(workers: Literal[1] = 1) -> Literal[1]: ...
264292
@overload
265293
def _validate_workers(workers: _N) -> _N: ...
266294
@overload
267-
def _validate_workers(workers: CanInt[_N] | CanIndex[_N]) -> _N: ...
268-
def _validate_bounds(
269-
l_bounds: _Any1D_f_co,
270-
u_bounds: _Any1D_f_co,
271-
d: AnyInt,
272-
) -> tuple[_Array1D[_Scalar_f_co], _Array1D[_Scalar_f_co]]: ...
295+
def _validate_workers(workers: opt.AnyInt[_N]) -> _N: ...
296+
297+
#
298+
def _validate_bounds(l_bounds: _Any1D_f_co, u_bounds: _Any1D_f_co, d: AnyInt) -> tuple[_Array1D_f_co, _Array1D_f_co]: ...

scipy-stubs/stats/_qmc_cy.pyi

Lines changed: 20 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,24 @@
1+
from typing import TypeAlias
2+
13
import numpy as np
4+
import optype as op
25
import optype.numpy as onpt
3-
from optype import CanBool, CanFloat, CanInt
6+
import optype.typing as opt
7+
8+
_Vector_i8: TypeAlias = onpt.Array[tuple[int, int], np.int64]
9+
_Vector_f8: TypeAlias = onpt.Array[tuple[int], np.float64]
10+
_Matrix_f8: TypeAlias = onpt.Array[tuple[int, int], np.float64]
411

5-
def _cy_wrapper_centered_discrepancy(
6-
sample: onpt.Array[tuple[int, int], np.float64],
7-
iterative: CanBool,
8-
workers: CanInt,
9-
) -> float: ...
10-
def _cy_wrapper_wrap_around_discrepancy(
11-
sample: onpt.Array[tuple[int, int], np.float64],
12-
iterative: CanBool,
13-
workers: CanInt,
14-
) -> float: ...
15-
def _cy_wrapper_mixture_discrepancy(
16-
sample: onpt.Array[tuple[int, int], np.float64],
17-
iterative: CanBool,
18-
workers: CanInt,
19-
) -> float: ...
20-
def _cy_wrapper_l2_star_discrepancy(
21-
sample: onpt.Array[tuple[int, int], np.float64],
22-
iterative: CanBool,
23-
workers: CanInt,
24-
) -> float: ...
25-
def _cy_wrapper_update_discrepancy(
26-
x_new_view: onpt.Array[tuple[int], np.float64],
27-
sample_view: onpt.Array[tuple[int, int], np.float64],
28-
initial_disc: CanFloat,
29-
) -> float: ...
30-
def _cy_van_der_corput(
31-
n: CanInt,
32-
base: CanInt,
33-
start_index: CanInt,
34-
workers: CanInt,
35-
) -> onpt.Array[tuple[int], np.float64]: ...
12+
def _cy_wrapper_centered_discrepancy(sample: _Matrix_f8, iterative: op.CanBool, workers: opt.AnyInt) -> float: ...
13+
def _cy_wrapper_wrap_around_discrepancy(sample: _Matrix_f8, iterative: op.CanBool, workers: opt.AnyInt) -> float: ...
14+
def _cy_wrapper_mixture_discrepancy(sample: _Matrix_f8, iterative: op.CanBool, workers: opt.AnyInt) -> float: ...
15+
def _cy_wrapper_l2_star_discrepancy(sample: _Matrix_f8, iterative: op.CanBool, workers: opt.AnyInt) -> float: ...
16+
def _cy_wrapper_update_discrepancy(x_new_view: _Vector_f8, sample_view: _Matrix_f8, initial_disc: opt.AnyFloat) -> float: ...
17+
def _cy_van_der_corput(n: opt.AnyInt, base: opt.AnyInt, start_index: opt.AnyInt, workers: opt.AnyInt) -> _Vector_f8: ...
3618
def _cy_van_der_corput_scrambled(
37-
n: CanInt,
38-
base: CanInt,
39-
start_index: CanInt,
40-
permutations: onpt.Array[tuple[int, int], np.int64],
41-
workers: CanInt,
42-
) -> onpt.Array[tuple[int], np.float64]: ...
19+
n: opt.AnyInt,
20+
base: opt.AnyInt,
21+
start_index: opt.AnyInt,
22+
permutations: _Vector_i8,
23+
workers: opt.AnyInt,
24+
) -> _Vector_f8: ...

tests/stats/qmc/test_scale.pyi

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from typing import Any
2+
3+
import numpy as np
4+
from scipy.stats import qmc
5+
6+
_f8_xd: np.ndarray[Any, np.dtype[np.float64]]
7+
qmc.scale(_f8_xd, 0, 1)
8+
9+
_f8_nd: np.ndarray[tuple[int, ...], np.dtype[np.float64]]
10+
qmc.scale(_f8_nd, 0, 1)
11+
12+
_f8_2d: np.ndarray[tuple[int, int], np.dtype[np.float64]]
13+
qmc.scale(_f8_2d, 0, 1)

0 commit comments

Comments
 (0)