@@ -2,7 +2,7 @@ import abc
22import numbers
33from collections .abc import Callable , Mapping , Sequence
44from typing import Any , ClassVar , Final , Literal , Protocol , TypeAlias , overload , type_check_only
5- from typing_extensions import Self , TypeVar
5+ from typing_extensions import Self , TypeVar , override
66
77import numpy as np
88import numpy .typing as npt
@@ -28,30 +28,31 @@ __all__ = [
2828]
2929
3030_RNGT = TypeVar ("_RNGT" , bound = np .random .Generator | np .random .RandomState )
31- _SCT = TypeVar ("_SCT " , bound = np .generic )
31+ _SCT0 = TypeVar ("_SCT0 " , bound = np .generic , default = np . float64 )
3232_SCT_co = TypeVar ("_SCT_co" , covariant = True , bound = np .generic )
3333_SCT_fc = TypeVar ("_SCT_fc" , bound = np .inexact [Any ])
3434_ArrayT_f = TypeVar ("_ArrayT_f" , bound = npt .NDArray [np .floating [Any ]])
3535_N = TypeVar ("_N" , bound = int )
3636
37+ # the `__len__` ensures that scalar types like `np.generic` are excluded
3738@type_check_only
3839class _CanLenArray (Protocol [_SCT_co ]):
3940 def __len__ (self , / ) -> int : ...
4041 def __array__ (self , / ) -> npt .NDArray [_SCT_co ]: ...
4142
4243_Scalar_f_co : TypeAlias = np .floating [Any ] | np .integer [Any ] | np .bool_
44+ _ScalarLike_f : TypeAlias = float | np .floating [Any ]
4345
44- _Array1D : TypeAlias = onpt .Array [tuple [int ], _SCT ]
45- _Array1D_f8 : TypeAlias = _Array1D [np .float64 ]
46- _Array2D : TypeAlias = onpt .Array [tuple [int , int ], _SCT ]
47- _Array2D_f8 : TypeAlias = _Array2D [np .float64 ]
46+ _Array1D : TypeAlias = onpt .Array [tuple [int ], _SCT0 ]
47+ _Array2D : TypeAlias = onpt .Array [tuple [int , int ], _SCT0 ]
48+ _Array1D_f_co : TypeAlias = _Array1D [_Scalar_f_co ]
4849
49- _Any1D_f : TypeAlias = _CanLenArray [np .floating [Any ]] | Sequence [float | np . floating [ Any ] ]
50+ _Any1D_f : TypeAlias = _CanLenArray [np .floating [Any ]] | Sequence [_ScalarLike_f ]
5051_Any1D_f_co : TypeAlias = _CanLenArray [_Scalar_f_co ] | Sequence [AnyReal ]
51- _Any2D_f : TypeAlias = _CanLenArray [np .floating [Any ]] | Sequence [Sequence [float | np . floating [ Any ] ]] | Sequence [_Any1D_f ]
52+ _Any2D_f : TypeAlias = _CanLenArray [np .floating [Any ]] | Sequence [Sequence [_ScalarLike_f ]] | Sequence [_Any1D_f ]
5253_Any2D_f_co : TypeAlias = _CanLenArray [_Scalar_f_co ] | Sequence [Sequence [AnyReal ]] | Sequence [_Any1D_f_co ]
5354
54- _MethodOptimize : TypeAlias = Literal ["random-cd" , "lloyd" ]
55+ _MethodQMC : TypeAlias = Literal ["random-cd" , "lloyd" ]
5556_MethodDisc : TypeAlias = Literal ["CD" , "WD" , "MD" , "L2-star" ]
5657_MethodDist : TypeAlias = Literal ["mindist" , "mst" ]
5758_MetricDist : TypeAlias = _MetricKind | _MetricCallback
@@ -68,8 +69,8 @@ class QMCEngine(abc.ABC):
6869 num_generated : int
6970
7071 @abc .abstractmethod
71- def __init__ (self , / , d : AnyInt , * , optimization : _MethodOptimize | None = None , seed : Seed | None = None ) -> None : ...
72- def random (self , / , n : opt .AnyInt = 1 , * , workers : AnyInt = 1 ) -> _Array2D_f8 : ...
72+ def __init__ (self , / , d : AnyInt , * , optimization : _MethodQMC | None = None , seed : Seed | None = None ) -> None : ...
73+ def random (self , / , n : opt .AnyInt = 1 , * , workers : AnyInt = 1 ) -> _Array2D : ...
7374 def integers (
7475 self ,
7576 / ,
@@ -93,13 +94,13 @@ class Halton(QMCEngine):
9394 d : AnyInt ,
9495 * ,
9596 scramble : bool = True ,
96- optimization : _MethodOptimize | None = None ,
97+ optimization : _MethodQMC | None = None ,
9798 seed : Seed | None = None ,
9899 ) -> None : ...
99100
100101class LatinHypercube (QMCEngine ):
101102 scramble : bool
102- lhs_method : Callable [[int | np .integer [Any ]], _Array2D_f8 ]
103+ lhs_method : Callable [[int | np .integer [Any ]], _Array2D ]
103104
104105 def __init__ (
105106 self ,
@@ -108,7 +109,7 @@ class LatinHypercube(QMCEngine):
108109 * ,
109110 scramble : bool = True ,
110111 strength : int = 1 ,
111- optimization : _MethodOptimize | None = None ,
112+ optimization : _MethodQMC | None = None ,
112113 seed : Seed | None = None ,
113114 ) -> None : ...
114115
@@ -126,10 +127,10 @@ class Sobol(QMCEngine):
126127 * ,
127128 scramble : op .CanBool = True ,
128129 bits : AnyInt | None = None ,
130+ optimization : _MethodQMC | None = None ,
129131 seed : Seed | None = None ,
130- optimization : _MethodOptimize | None = None ,
131132 ) -> None : ...
132- def random_base2 (self , / , m : AnyInt ) -> _Array2D_f8 : ...
133+ def random_base2 (self , / , m : AnyInt ) -> _Array2D : ...
133134
134135@type_check_only
135136class _HypersphereMethod (Protocol ):
@@ -139,7 +140,7 @@ class _HypersphereMethod(Protocol):
139140 center : npt .NDArray [_Scalar_f_co ],
140141 radius : AnyReal ,
141142 candidates : AnyInt = 1 ,
142- ) -> _Array2D_f8 : ...
143+ ) -> _Array2D : ...
143144
144145class PoissonDisk (QMCEngine ):
145146 hypersphere_method : Final [_HypersphereMethod ]
@@ -150,7 +151,7 @@ class PoissonDisk(QMCEngine):
150151 cell_size : Final [np .float64 ]
151152 grid_size : Final [_Array1D [np .int_ ]]
152153
153- sample_pool : list [_Array1D_f8 ]
154+ sample_pool : list [_Array1D ]
154155 sample_grid : npt .NDArray [np .float32 ]
155156
156157 def __init__ (
@@ -161,13 +162,19 @@ class PoissonDisk(QMCEngine):
161162 radius : AnyReal = 0.05 ,
162163 hypersphere : Literal ["volume" , "surface" ] = "volume" ,
163164 ncandidates : AnyInt = 30 ,
164- optimization : _MethodOptimize | None = None ,
165+ optimization : _MethodQMC | None = None ,
165166 seed : Seed | None = None ,
166167 ) -> None : ...
167- def fill_space (self , / ) -> _Array2D_f8 : ...
168+ def fill_space (self , / ) -> _Array2D : ...
168169
169- class MultivariateNormalQMC :
170- engine : Final [QMCEngine ]
170+ @type_check_only
171+ class _QMCDistribution :
172+ engine : Final [QMCEngine ] # defaults to `Sobol`
173+ def __init__ (self , / , * , engine : QMCEngine | None = None , seed : Seed | None = None ) -> None : ...
174+ def random (self , / , n : AnyInt = 1 ) -> _Array2D : ...
175+
176+ class MultivariateNormalQMC (_QMCDistribution ):
177+ @override
171178 def __init__ (
172179 self ,
173180 / ,
@@ -179,13 +186,12 @@ class MultivariateNormalQMC:
179186 engine : QMCEngine | None = None ,
180187 seed : Seed | None = None ,
181188 ) -> None : ...
182- def random (self , / , n : AnyInt = 1 ) -> _Array2D_f8 : ...
183189
184- class MultinomialQMC :
190+ class MultinomialQMC ( _QMCDistribution ) :
185191 pvals : Final [_Array1D [np .floating [Any ]]]
186192 n_trials : Final [AnyInt ]
187- engine : Final [QMCEngine ]
188193
194+ @override
189195 def __init__ (
190196 self ,
191197 / ,
@@ -195,55 +201,64 @@ class MultinomialQMC:
195201 engine : QMCEngine | None = None ,
196202 seed : Seed | None = None ,
197203 ) -> None : ...
198- def random (self , / , n : AnyInt = 1 ) -> _Array2D_f8 : ...
199204
200205#
201206@overload
202207def check_random_state (seed : int | np .integer [Any ] | numbers .Integral | None = None ) -> np .random .Generator : ...
203208@overload
204209def check_random_state (seed : _RNGT ) -> _RNGT : ...
210+
211+ #
205212def scale (
206213 sample : _Any2D_f ,
207214 l_bounds : _Any1D_f_co | AnyReal ,
208215 u_bounds : _Any1D_f_co | AnyReal ,
209216 * ,
210217 reverse : op .CanBool = False ,
211- ) -> _Array2D_f8 : ...
218+ ) -> _Array2D : ...
219+
220+ #
212221def discrepancy (
213222 sample : _Any2D_f ,
214223 * ,
215224 iterative : op .CanBool = False ,
216225 method : _MethodDisc = "CD" ,
217- workers : op .CanInt = 1 ,
218- ) -> float : ...
219- def geometric_discrepancy (sample : _Any2D_f , method : _MethodDist = "mindist" , metric : _MetricDist = "euclidean" ) -> np .float64 : ...
220- def update_discrepancy (x_new : _Any1D_f , sample : _Any2D_f , initial_disc : op .CanFloat ) -> float : ...
226+ workers : opt .AnyInt = 1 ,
227+ ) -> float | np .float64 : ...
228+
229+ #
230+ def geometric_discrepancy (
231+ sample : _Any2D_f ,
232+ method : _MethodDist = "mindist" ,
233+ metric : _MetricDist = "euclidean" ,
234+ ) -> float | np .float64 : ...
235+ def update_discrepancy (x_new : _Any1D_f , sample : _Any2D_f , initial_disc : opt .AnyFloat ) -> float : ...
221236def primes_from_2_to (n : AnyInt ) -> _Array1D [np .int_ ]: ...
222237def n_primes (n : AnyInt ) -> list [int ] | _Array1D [np .int_ ]: ...
223238
224239#
225- def _select_optimizer (optimization : _MethodOptimize | None , config : Mapping [str , object ]) -> _FuncOptimize | None : ...
240+ def _select_optimizer (optimization : _MethodQMC | None , config : Mapping [str , object ]) -> _FuncOptimize | None : ...
226241def _random_cd (best_sample : _ArrayT_f , n_iters : AnyInt , n_nochange : AnyInt , rng : RNG ) -> _ArrayT_f : ...
227- def _l1_norm (sample : _Any2D_f ) -> np .float64 : ...
242+ def _l1_norm (sample : _Any2D_f ) -> float | np .float64 : ...
228243def _lloyd_iteration (sample : _ArrayT_f , decay : AnyReal , qhull_options : str | None ) -> _ArrayT_f : ...
229244def _lloyd_centroidal_voronoi_tessellation (
230245 sample : _Any2D_f ,
231246 * ,
232247 tol : AnyReal = 1e-5 ,
233248 maxiter : AnyInt = 10 ,
234249 qhull_options : str | None = None ,
235- ) -> _Array2D_f8 : ...
250+ ) -> _Array2D : ...
251+ def _ensure_in_unit_hypercube (sample : _Any2D_f ) -> _Array2D : ...
236252
237253#
238- def _ensure_in_unit_hypercube (sample : _Any2D_f ) -> _Array2D_f8 : ...
239254@overload
240255def _perturb_discrepancy (
241256 sample : _Array2D [np .integer [Any ] | np .bool_ ],
242257 i1 : op .CanIndex ,
243258 i2 : op .CanIndex ,
244259 k : op .CanIndex ,
245260 disc : AnyReal ,
246- ) -> np .float64 : ...
261+ ) -> float | np .float64 : ...
247262@overload
248263def _perturb_discrepancy (
249264 sample : _Array2D [_SCT_fc ],
@@ -252,10 +267,14 @@ def _perturb_discrepancy(
252267 k : op .CanIndex ,
253268 disc : AnyReal ,
254269) -> _SCT_fc : ...
270+
271+ #
255272@overload
256273def _van_der_corput_permutation (base : op .CanIndex , * , random_state : Seed | None = None ) -> _Array2D [np .int_ ]: ...
257274@overload
258- def _van_der_corput_permutation (base : op .CanFloat , * , random_state : Seed | None = None ) -> _Array2D_f8 : ...
275+ def _van_der_corput_permutation (base : op .CanFloat , * , random_state : Seed | None = None ) -> _Array2D : ...
276+
277+ #
259278def van_der_corput (
260279 n : op .CanInt ,
261280 base : AnyInt = 2 ,
@@ -264,18 +283,16 @@ def van_der_corput(
264283 scramble : op .CanBool = False ,
265284 permutations : _ArrayLikeInt | None = None ,
266285 seed : Seed | None = None ,
267- workers : op . CanInt = 1 ,
268- ) -> _Array1D_f8 : ...
286+ workers : opt . AnyInt = 1 ,
287+ ) -> _Array1D : ...
269288
270289#
271290@overload
272- def _validate_workers (workers : op . CanInt [ Literal [ 1 ]] | op . CanIndex [ Literal [ 1 ]] | Literal [1 ] = 1 ) -> Literal [1 ]: ...
291+ def _validate_workers (workers : Literal [1 ] = 1 ) -> Literal [1 ]: ...
273292@overload
274293def _validate_workers (workers : _N ) -> _N : ...
275294@overload
276- def _validate_workers (workers : op .CanInt [_N ] | op .CanIndex [_N ]) -> _N : ...
277- def _validate_bounds (
278- l_bounds : _Any1D_f_co ,
279- u_bounds : _Any1D_f_co ,
280- d : AnyInt ,
281- ) -> tuple [_Array1D [_Scalar_f_co ], _Array1D [_Scalar_f_co ]]: ...
295+ def _validate_workers (workers : opt .AnyInt [_N ]) -> _N : ...
296+
297+ #
298+ def _validate_bounds (l_bounds : _Any1D_f_co , u_bounds : _Any1D_f_co , d : AnyInt ) -> tuple [_Array1D_f_co , _Array1D_f_co ]: ...
0 commit comments