11import abc
22import numbers
33from collections .abc import Callable , Mapping , Sequence
4- from typing import Any , ClassVar , Final , Literal , Protocol , TypeAlias , TypeVar , overload , type_check_only
5- from typing_extensions import Self
4+ from typing import Any , ClassVar , Final , Literal , Protocol , TypeAlias , overload , type_check_only
5+ from typing_extensions import Self , TypeVar , override
66
77import numpy as np
88import numpy .typing as npt
9+ import optype as op
910import optype .numpy as onpt
11+ import optype .typing as opt
1012from numpy ._typing import _ArrayLikeInt
11- from optype import CanBool , CanFloat , CanIndex , CanInt , CanLen
12- from scipy ._typing import RNG , AnyBool , AnyInt , AnyReal , Seed
13+ from scipy ._typing import RNG , AnyInt , AnyReal , Seed
1314from scipy .spatial .distance import _MetricCallback , _MetricKind
1415
1516__all__ = [
@@ -27,28 +28,31 @@ __all__ = [
2728]
2829
2930_RNGT = TypeVar ("_RNGT" , bound = np .random .Generator | np .random .RandomState )
30- _SCT = TypeVar ("_SCT " , bound = np .generic )
31+ _SCT0 = TypeVar ("_SCT0 " , bound = np .generic , default = np . float64 )
3132_SCT_co = TypeVar ("_SCT_co" , covariant = True , bound = np .generic )
3233_SCT_fc = TypeVar ("_SCT_fc" , bound = np .inexact [Any ])
3334_ArrayT_f = TypeVar ("_ArrayT_f" , bound = npt .NDArray [np .floating [Any ]])
3435_N = TypeVar ("_N" , bound = int )
3536
37+ # the `__len__` ensures that scalar types like `np.generic` are excluded
3638@type_check_only
37- class _CanLenArray (CanLen , onpt .CanArray [Any , np .dtype [_SCT_co ]], Protocol [_SCT_co ]): ...
39+ class _CanLenArray (Protocol [_SCT_co ]):
40+ def __len__ (self , / ) -> int : ...
41+ def __array__ (self , / ) -> npt .NDArray [_SCT_co ]: ...
3842
3943_Scalar_f_co : TypeAlias = np .floating [Any ] | np .integer [Any ] | np .bool_
44+ _ScalarLike_f : TypeAlias = float | np .floating [Any ]
4045
41- _Array1D : TypeAlias = onpt .Array [tuple [int ], _SCT ]
42- _Array1D_f8 : TypeAlias = _Array1D [np .float64 ]
43- _Array2D : TypeAlias = onpt .Array [tuple [int , int ], _SCT ]
44- _Array2D_f8 : TypeAlias = _Array2D [np .float64 ]
46+ _Array1D : TypeAlias = onpt .Array [tuple [int ], _SCT0 ]
47+ _Array2D : TypeAlias = onpt .Array [tuple [int , int ], _SCT0 ]
48+ _Array1D_f_co : TypeAlias = _Array1D [_Scalar_f_co ]
4549
46- _Any1D_f : TypeAlias = _CanLenArray [np .floating [Any ]] | Sequence [float | np . floating [ Any ] ]
50+ _Any1D_f : TypeAlias = _CanLenArray [np .floating [Any ]] | Sequence [_ScalarLike_f ]
4751_Any1D_f_co : TypeAlias = _CanLenArray [_Scalar_f_co ] | Sequence [AnyReal ]
48- _Any2D_f : TypeAlias = _CanLenArray [np .floating [Any ]] | Sequence [Sequence [float | np . floating [ Any ] ]] | Sequence [_Any1D_f ]
52+ _Any2D_f : TypeAlias = _CanLenArray [np .floating [Any ]] | Sequence [Sequence [_ScalarLike_f ]] | Sequence [_Any1D_f ]
4953_Any2D_f_co : TypeAlias = _CanLenArray [_Scalar_f_co ] | Sequence [Sequence [AnyReal ]] | Sequence [_Any1D_f_co ]
5054
51- _MethodOptimize : TypeAlias = Literal ["random-cd" , "lloyd" ]
55+ _MethodQMC : TypeAlias = Literal ["random-cd" , "lloyd" ]
5256_MethodDisc : TypeAlias = Literal ["CD" , "WD" , "MD" , "L2-star" ]
5357_MethodDist : TypeAlias = Literal ["mindist" , "mst" ]
5458_MetricDist : TypeAlias = _MetricKind | _MetricCallback
@@ -65,20 +69,20 @@ class QMCEngine(abc.ABC):
6569 num_generated : int
6670
6771 @abc .abstractmethod
68- def __init__ (self , / , d : AnyInt , * , optimization : _MethodOptimize | None = None , seed : Seed | None = None ) -> None : ...
69- def random (self , / , n : AnyInt = 1 , * , workers : AnyInt = 1 ) -> _Array2D_f8 : ...
72+ def __init__ (self , / , d : AnyInt , * , optimization : _MethodQMC | None = None , seed : Seed | None = None ) -> None : ...
73+ def random (self , / , n : opt . AnyInt = 1 , * , workers : AnyInt = 1 ) -> _Array2D : ...
7074 def integers (
7175 self ,
7276 / ,
7377 l_bounds : _ArrayLikeInt ,
7478 * ,
7579 u_bounds : _ArrayLikeInt | None = None ,
76- n : AnyInt = 1 ,
77- endpoint : AnyBool = False ,
78- workers : AnyInt = 1 ,
80+ n : opt . AnyInt = 1 ,
81+ endpoint : op . CanBool = False ,
82+ workers : opt . AnyInt = 1 ,
7983 ) -> _Array2D [np .int64 ]: ...
8084 def reset (self , / ) -> Self : ...
81- def fast_forward (self , / , n : AnyInt ) -> Self : ...
85+ def fast_forward (self , / , n : opt . AnyInt ) -> Self : ...
8286
8387class Halton (QMCEngine ):
8488 base : list [int ]
@@ -90,13 +94,13 @@ class Halton(QMCEngine):
9094 d : AnyInt ,
9195 * ,
9296 scramble : bool = True ,
93- optimization : _MethodOptimize | None = None ,
97+ optimization : _MethodQMC | None = None ,
9498 seed : Seed | None = None ,
9599 ) -> None : ...
96100
97101class LatinHypercube (QMCEngine ):
98102 scramble : bool
99- lhs_method : Callable [[int | np .integer [Any ]], _Array2D_f8 ]
103+ lhs_method : Callable [[int | np .integer [Any ]], _Array2D ]
100104
101105 def __init__ (
102106 self ,
@@ -105,7 +109,7 @@ class LatinHypercube(QMCEngine):
105109 * ,
106110 scramble : bool = True ,
107111 strength : int = 1 ,
108- optimization : _MethodOptimize | None = None ,
112+ optimization : _MethodQMC | None = None ,
109113 seed : Seed | None = None ,
110114 ) -> None : ...
111115
@@ -121,12 +125,12 @@ class Sobol(QMCEngine):
121125 / ,
122126 d : AnyInt ,
123127 * ,
124- scramble : CanBool = True ,
128+ scramble : op . CanBool = True ,
125129 bits : AnyInt | None = None ,
130+ optimization : _MethodQMC | None = None ,
126131 seed : Seed | None = None ,
127- optimization : _MethodOptimize | None = None ,
128132 ) -> None : ...
129- def random_base2 (self , / , m : AnyInt ) -> _Array2D_f8 : ...
133+ def random_base2 (self , / , m : AnyInt ) -> _Array2D : ...
130134
131135@type_check_only
132136class _HypersphereMethod (Protocol ):
@@ -136,7 +140,7 @@ class _HypersphereMethod(Protocol):
136140 center : npt .NDArray [_Scalar_f_co ],
137141 radius : AnyReal ,
138142 candidates : AnyInt = 1 ,
139- ) -> _Array2D_f8 : ...
143+ ) -> _Array2D : ...
140144
141145class PoissonDisk (QMCEngine ):
142146 hypersphere_method : Final [_HypersphereMethod ]
@@ -147,7 +151,7 @@ class PoissonDisk(QMCEngine):
147151 cell_size : Final [np .float64 ]
148152 grid_size : Final [_Array1D [np .int_ ]]
149153
150- sample_pool : list [_Array1D_f8 ]
154+ sample_pool : list [_Array1D ]
151155 sample_grid : npt .NDArray [np .float32 ]
152156
153157 def __init__ (
@@ -158,31 +162,36 @@ class PoissonDisk(QMCEngine):
158162 radius : AnyReal = 0.05 ,
159163 hypersphere : Literal ["volume" , "surface" ] = "volume" ,
160164 ncandidates : AnyInt = 30 ,
161- optimization : _MethodOptimize | None = None ,
165+ optimization : _MethodQMC | None = None ,
162166 seed : Seed | None = None ,
163167 ) -> None : ...
164- def fill_space (self , / ) -> _Array2D_f8 : ...
168+ def fill_space (self , / ) -> _Array2D : ...
165169
166- class MultivariateNormalQMC :
167- engine : Final [QMCEngine ]
170+ @type_check_only
171+ class _QMCDistribution :
172+ engine : Final [QMCEngine ] # defaults to `Sobol`
173+ def __init__ (self , / , * , engine : QMCEngine | None = None , seed : Seed | None = None ) -> None : ...
174+ def random (self , / , n : AnyInt = 1 ) -> _Array2D : ...
175+
176+ class MultivariateNormalQMC (_QMCDistribution ):
177+ @override
168178 def __init__ (
169179 self ,
170180 / ,
171181 mean : _Any1D_f_co ,
172182 cov : _Any2D_f_co | None = None ,
173183 * ,
174184 cov_root : _Any2D_f_co | None = None ,
175- inv_transform : CanBool = True ,
185+ inv_transform : op . CanBool = True ,
176186 engine : QMCEngine | None = None ,
177187 seed : Seed | None = None ,
178188 ) -> None : ...
179- def random (self , / , n : AnyInt = 1 ) -> _Array2D_f8 : ...
180189
181- class MultinomialQMC :
190+ class MultinomialQMC ( _QMCDistribution ) :
182191 pvals : Final [_Array1D [np .floating [Any ]]]
183192 n_trials : Final [AnyInt ]
184- engine : Final [QMCEngine ]
185193
194+ @override
186195 def __init__ (
187196 self ,
188197 / ,
@@ -192,81 +201,98 @@ class MultinomialQMC:
192201 engine : QMCEngine | None = None ,
193202 seed : Seed | None = None ,
194203 ) -> None : ...
195- def random (self , / , n : AnyInt = 1 ) -> _Array2D_f8 : ...
196204
197205#
198206@overload
199207def check_random_state (seed : int | np .integer [Any ] | numbers .Integral | None = None ) -> np .random .Generator : ...
200208@overload
201209def check_random_state (seed : _RNGT ) -> _RNGT : ...
210+
211+ #
202212def scale (
203213 sample : _Any2D_f ,
204214 l_bounds : _Any1D_f_co | AnyReal ,
205215 u_bounds : _Any1D_f_co | AnyReal ,
206216 * ,
207- reverse : CanBool = False ,
208- ) -> _Array2D_f8 : ...
209- def discrepancy (sample : _Any2D_f , * , iterative : CanBool = False , method : _MethodDisc = "CD" , workers : CanInt = 1 ) -> float : ...
210- def geometric_discrepancy (sample : _Any2D_f , method : _MethodDist = "mindist" , metric : _MetricDist = "euclidean" ) -> np .float64 : ...
211- def update_discrepancy (x_new : _Any1D_f , sample : _Any2D_f , initial_disc : CanFloat ) -> float : ...
217+ reverse : op .CanBool = False ,
218+ ) -> _Array2D : ...
219+
220+ #
221+ def discrepancy (
222+ sample : _Any2D_f ,
223+ * ,
224+ iterative : op .CanBool = False ,
225+ method : _MethodDisc = "CD" ,
226+ workers : opt .AnyInt = 1 ,
227+ ) -> float | np .float64 : ...
228+
229+ #
230+ def geometric_discrepancy (
231+ sample : _Any2D_f ,
232+ method : _MethodDist = "mindist" ,
233+ metric : _MetricDist = "euclidean" ,
234+ ) -> float | np .float64 : ...
235+ def update_discrepancy (x_new : _Any1D_f , sample : _Any2D_f , initial_disc : opt .AnyFloat ) -> float : ...
212236def primes_from_2_to (n : AnyInt ) -> _Array1D [np .int_ ]: ...
213237def n_primes (n : AnyInt ) -> list [int ] | _Array1D [np .int_ ]: ...
214238
215239#
216- def _select_optimizer (optimization : _MethodOptimize | None , config : Mapping [str , object ]) -> _FuncOptimize | None : ...
240+ def _select_optimizer (optimization : _MethodQMC | None , config : Mapping [str , object ]) -> _FuncOptimize | None : ...
217241def _random_cd (best_sample : _ArrayT_f , n_iters : AnyInt , n_nochange : AnyInt , rng : RNG ) -> _ArrayT_f : ...
218- def _l1_norm (sample : _Any2D_f ) -> np .float64 : ...
242+ def _l1_norm (sample : _Any2D_f ) -> float | np .float64 : ...
219243def _lloyd_iteration (sample : _ArrayT_f , decay : AnyReal , qhull_options : str | None ) -> _ArrayT_f : ...
220244def _lloyd_centroidal_voronoi_tessellation (
221245 sample : _Any2D_f ,
222246 * ,
223247 tol : AnyReal = 1e-5 ,
224248 maxiter : AnyInt = 10 ,
225249 qhull_options : str | None = None ,
226- ) -> _Array2D_f8 : ...
250+ ) -> _Array2D : ...
251+ def _ensure_in_unit_hypercube (sample : _Any2D_f ) -> _Array2D : ...
227252
228253#
229- def _ensure_in_unit_hypercube (sample : _Any2D_f ) -> _Array2D_f8 : ...
230254@overload
231255def _perturb_discrepancy (
232256 sample : _Array2D [np .integer [Any ] | np .bool_ ],
233- i1 : CanIndex ,
234- i2 : CanIndex ,
235- k : CanIndex ,
257+ i1 : op . CanIndex ,
258+ i2 : op . CanIndex ,
259+ k : op . CanIndex ,
236260 disc : AnyReal ,
237- ) -> np .float64 : ...
261+ ) -> float | np .float64 : ...
238262@overload
239263def _perturb_discrepancy (
240264 sample : _Array2D [_SCT_fc ],
241- i1 : CanIndex ,
242- i2 : CanIndex ,
243- k : CanIndex ,
265+ i1 : op . CanIndex ,
266+ i2 : op . CanIndex ,
267+ k : op . CanIndex ,
244268 disc : AnyReal ,
245269) -> _SCT_fc : ...
270+
271+ #
246272@overload
247- def _van_der_corput_permutation (base : CanIndex , * , random_state : Seed | None = None ) -> _Array2D [np .int_ ]: ...
273+ def _van_der_corput_permutation (base : op . CanIndex , * , random_state : Seed | None = None ) -> _Array2D [np .int_ ]: ...
248274@overload
249- def _van_der_corput_permutation (base : CanFloat , * , random_state : Seed | None = None ) -> _Array2D_f8 : ...
275+ def _van_der_corput_permutation (base : op .CanFloat , * , random_state : Seed | None = None ) -> _Array2D : ...
276+
277+ #
250278def van_der_corput (
251- n : CanInt ,
279+ n : op . CanInt ,
252280 base : AnyInt = 2 ,
253281 * ,
254282 start_index : AnyInt = 0 ,
255- scramble : CanBool = False ,
283+ scramble : op . CanBool = False ,
256284 permutations : _ArrayLikeInt | None = None ,
257285 seed : Seed | None = None ,
258- workers : CanInt = 1 ,
259- ) -> _Array1D_f8 : ...
286+ workers : opt . AnyInt = 1 ,
287+ ) -> _Array1D : ...
260288
261289#
262290@overload
263- def _validate_workers (workers : CanInt [ Literal [ 1 ]] | CanIndex [ Literal [ 1 ]] | Literal [1 ] = 1 ) -> Literal [1 ]: ...
291+ def _validate_workers (workers : Literal [1 ] = 1 ) -> Literal [1 ]: ...
264292@overload
265293def _validate_workers (workers : _N ) -> _N : ...
266294@overload
267- def _validate_workers (workers : CanInt [_N ] | CanIndex [_N ]) -> _N : ...
268- def _validate_bounds (
269- l_bounds : _Any1D_f_co ,
270- u_bounds : _Any1D_f_co ,
271- d : AnyInt ,
272- ) -> tuple [_Array1D [_Scalar_f_co ], _Array1D [_Scalar_f_co ]]: ...
295+ def _validate_workers (workers : opt .AnyInt [_N ]) -> _N : ...
296+
297+ #
298+ def _validate_bounds (l_bounds : _Any1D_f_co , u_bounds : _Any1D_f_co , d : AnyInt ) -> tuple [_Array1D_f_co , _Array1D_f_co ]: ...
0 commit comments