Skip to content

Commit 1eacdc0

Browse files
committed
fix incorrectly rejected ndarray args in scipy.stats.qmc
1 parent 7f5eb48 commit 1eacdc0

File tree

2 files changed

+50
-28
lines changed

2 files changed

+50
-28
lines changed

scipy-stubs/stats/_qmc.pyi

Lines changed: 37 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
import abc
22
import numbers
33
from collections.abc import Callable, Mapping, Sequence
4-
from typing import Any, ClassVar, Final, Literal, Protocol, TypeAlias, TypeVar, overload, type_check_only
5-
from typing_extensions import Self
4+
from typing import Any, ClassVar, Final, Literal, Protocol, TypeAlias, overload, type_check_only
5+
from typing_extensions import Self, TypeVar
66

77
import numpy as np
88
import numpy.typing as npt
9+
import optype as op
910
import optype.numpy as onpt
11+
import optype.typing as opt
1012
from numpy._typing import _ArrayLikeInt
11-
from optype import CanBool, CanFloat, CanIndex, CanInt, CanLen
12-
from scipy._typing import RNG, AnyBool, AnyInt, AnyReal, Seed
13+
from scipy._typing import RNG, AnyInt, AnyReal, Seed
1314
from scipy.spatial.distance import _MetricCallback, _MetricKind
1415

1516
__all__ = [
@@ -34,7 +35,9 @@ _ArrayT_f = TypeVar("_ArrayT_f", bound=npt.NDArray[np.floating[Any]])
3435
_N = TypeVar("_N", bound=int)
3536

3637
@type_check_only
37-
class _CanLenArray(CanLen, onpt.CanArray[Any, np.dtype[_SCT_co]], Protocol[_SCT_co]): ...
38+
class _CanLenArray(Protocol[_SCT_co]):
39+
def __len__(self, /) -> int: ...
40+
def __array__(self, /) -> npt.NDArray[_SCT_co]: ...
3841

3942
_Scalar_f_co: TypeAlias = np.floating[Any] | np.integer[Any] | np.bool_
4043

@@ -66,19 +69,19 @@ class QMCEngine(abc.ABC):
6669

6770
@abc.abstractmethod
6871
def __init__(self, /, d: AnyInt, *, optimization: _MethodOptimize | None = None, seed: Seed | None = None) -> None: ...
69-
def random(self, /, n: AnyInt = 1, *, workers: AnyInt = 1) -> _Array2D_f8: ...
72+
def random(self, /, n: opt.AnyInt = 1, *, workers: AnyInt = 1) -> _Array2D_f8: ...
7073
def integers(
7174
self,
7275
/,
7376
l_bounds: _ArrayLikeInt,
7477
*,
7578
u_bounds: _ArrayLikeInt | None = None,
76-
n: AnyInt = 1,
77-
endpoint: AnyBool = False,
78-
workers: AnyInt = 1,
79+
n: opt.AnyInt = 1,
80+
endpoint: op.CanBool = False,
81+
workers: opt.AnyInt = 1,
7982
) -> _Array2D[np.int64]: ...
8083
def reset(self, /) -> Self: ...
81-
def fast_forward(self, /, n: AnyInt) -> Self: ...
84+
def fast_forward(self, /, n: opt.AnyInt) -> Self: ...
8285

8386
class Halton(QMCEngine):
8487
base: list[int]
@@ -121,7 +124,7 @@ class Sobol(QMCEngine):
121124
/,
122125
d: AnyInt,
123126
*,
124-
scramble: CanBool = True,
127+
scramble: op.CanBool = True,
125128
bits: AnyInt | None = None,
126129
seed: Seed | None = None,
127130
optimization: _MethodOptimize | None = None,
@@ -172,7 +175,7 @@ class MultivariateNormalQMC:
172175
cov: _Any2D_f_co | None = None,
173176
*,
174177
cov_root: _Any2D_f_co | None = None,
175-
inv_transform: CanBool = True,
178+
inv_transform: op.CanBool = True,
176179
engine: QMCEngine | None = None,
177180
seed: Seed | None = None,
178181
) -> None: ...
@@ -204,11 +207,17 @@ def scale(
204207
l_bounds: _Any1D_f_co | AnyReal,
205208
u_bounds: _Any1D_f_co | AnyReal,
206209
*,
207-
reverse: CanBool = False,
210+
reverse: op.CanBool = False,
208211
) -> _Array2D_f8: ...
209-
def discrepancy(sample: _Any2D_f, *, iterative: CanBool = False, method: _MethodDisc = "CD", workers: CanInt = 1) -> float: ...
212+
def discrepancy(
213+
sample: _Any2D_f,
214+
*,
215+
iterative: op.CanBool = False,
216+
method: _MethodDisc = "CD",
217+
workers: op.CanInt = 1,
218+
) -> float: ...
210219
def geometric_discrepancy(sample: _Any2D_f, method: _MethodDist = "mindist", metric: _MetricDist = "euclidean") -> np.float64: ...
211-
def update_discrepancy(x_new: _Any1D_f, sample: _Any2D_f, initial_disc: CanFloat) -> float: ...
220+
def update_discrepancy(x_new: _Any1D_f, sample: _Any2D_f, initial_disc: op.CanFloat) -> float: ...
212221
def primes_from_2_to(n: AnyInt) -> _Array1D[np.int_]: ...
213222
def n_primes(n: AnyInt) -> list[int] | _Array1D[np.int_]: ...
214223

@@ -230,41 +239,41 @@ def _ensure_in_unit_hypercube(sample: _Any2D_f) -> _Array2D_f8: ...
230239
@overload
231240
def _perturb_discrepancy(
232241
sample: _Array2D[np.integer[Any] | np.bool_],
233-
i1: CanIndex,
234-
i2: CanIndex,
235-
k: CanIndex,
242+
i1: op.CanIndex,
243+
i2: op.CanIndex,
244+
k: op.CanIndex,
236245
disc: AnyReal,
237246
) -> np.float64: ...
238247
@overload
239248
def _perturb_discrepancy(
240249
sample: _Array2D[_SCT_fc],
241-
i1: CanIndex,
242-
i2: CanIndex,
243-
k: CanIndex,
250+
i1: op.CanIndex,
251+
i2: op.CanIndex,
252+
k: op.CanIndex,
244253
disc: AnyReal,
245254
) -> _SCT_fc: ...
246255
@overload
247-
def _van_der_corput_permutation(base: CanIndex, *, random_state: Seed | None = None) -> _Array2D[np.int_]: ...
256+
def _van_der_corput_permutation(base: op.CanIndex, *, random_state: Seed | None = None) -> _Array2D[np.int_]: ...
248257
@overload
249-
def _van_der_corput_permutation(base: CanFloat, *, random_state: Seed | None = None) -> _Array2D_f8: ...
258+
def _van_der_corput_permutation(base: op.CanFloat, *, random_state: Seed | None = None) -> _Array2D_f8: ...
250259
def van_der_corput(
251-
n: CanInt,
260+
n: op.CanInt,
252261
base: AnyInt = 2,
253262
*,
254263
start_index: AnyInt = 0,
255-
scramble: CanBool = False,
264+
scramble: op.CanBool = False,
256265
permutations: _ArrayLikeInt | None = None,
257266
seed: Seed | None = None,
258-
workers: CanInt = 1,
267+
workers: op.CanInt = 1,
259268
) -> _Array1D_f8: ...
260269

261270
#
262271
@overload
263-
def _validate_workers(workers: CanInt[Literal[1]] | CanIndex[Literal[1]] | Literal[1] = 1) -> Literal[1]: ...
272+
def _validate_workers(workers: op.CanInt[Literal[1]] | op.CanIndex[Literal[1]] | Literal[1] = 1) -> Literal[1]: ...
264273
@overload
265274
def _validate_workers(workers: _N) -> _N: ...
266275
@overload
267-
def _validate_workers(workers: CanInt[_N] | CanIndex[_N]) -> _N: ...
276+
def _validate_workers(workers: op.CanInt[_N] | op.CanIndex[_N]) -> _N: ...
268277
def _validate_bounds(
269278
l_bounds: _Any1D_f_co,
270279
u_bounds: _Any1D_f_co,

tests/stats/qmc/test_scale.pyi

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from typing import Any
2+
3+
import numpy as np
4+
from scipy.stats import qmc
5+
6+
_f8_xd: np.ndarray[Any, np.dtype[np.float64]]
7+
qmc.scale(_f8_xd, 0, 1)
8+
9+
_f8_nd: np.ndarray[tuple[int, ...], np.dtype[np.float64]]
10+
qmc.scale(_f8_nd, 0, 1)
11+
12+
_f8_2d: np.ndarray[tuple[int, int], np.dtype[np.float64]]
13+
qmc.scale(_f8_2d, 0, 1)

0 commit comments

Comments
 (0)