1- from typing import Literal as L , TypeAlias , overload , type_check_only
1+ from typing import Generic , Literal as L , Protocol , TypeAlias , overload , type_check_only
2+ from typing_extensions import TypeVar , override
23
34import numpy as np
5+ import optype as op
46import optype .numpy as onp
7+ from scipy ._typing import Falsy , Truthy
58from scipy .sparse import coo_matrix , dok_matrix
69
710__all__ = ["cKDTree" ]
811
912_Weights : TypeAlias = onp .ToFloatND | tuple [onp .ToFloatND , onp .ToFloatND ]
13+ _Indices : TypeAlias = onp .Array1D [np .intp ]
14+ _Float1D : TypeAlias = onp .Array1D [np .float64 ]
15+ _Float2D : TypeAlias = onp .Array2D [np .float64 ]
16+
17+ _NodeT_co = TypeVar ("_NodeT_co" , bound = _KDTreeNode | None , default = _KDTreeNode | None , covariant = True )
18+ _BoxSizeT_co = TypeVar ("_BoxSizeT_co" , bound = _Float2D | None , default = _Float2D | None , covariant = True )
19+ _BoxSizeDataT_co = TypeVar ("_BoxSizeDataT_co" , bound = _Float1D | None , default = _Float1D | None , covariant = True )
1020
1121@type_check_only
1222class _CythonMixin :
1323 def __setstate_cython__ (self , pyx_state : object , / ) -> None : ...
1424 def __reduce_cython__ (self , / ) -> None : ...
1525
16- class cKDTreeNode (_CythonMixin ):
17- @property
18- def data_points (self , / ) -> onp .ArrayND [np .float64 ]: ...
19- @property
20- def indices (self , / ) -> onp .ArrayND [np .intp ]: ...
21-
22- # These are read-only attributes in cython, which behave like properties
26+ # workaround for mypy's lack of cyclical TypeVar support
27+ @type_check_only
28+ class _KDTreeNode (Protocol ):
2329 @property
2430 def level (self , / ) -> int : ...
2531 @property
2632 def split_dim (self , / ) -> int : ...
2733 @property
34+ def split (self , / ) -> float : ...
35+ @property
2836 def children (self , / ) -> int : ...
2937 @property
38+ def data_points (self , / ) -> _Float2D : ...
39+ @property
40+ def indices (self , / ) -> _Indices : ...
41+ @property
3042 def start_idx (self , / ) -> int : ...
3143 @property
3244 def end_idx (self , / ) -> int : ...
3345 @property
34- def split (self , / ) -> float : ...
35- @property
36- def lesser (self , / ) -> cKDTreeNode | None : ...
46+ def lesser (self , / ) -> _KDTreeNode | None : ...
3747 @property
38- def greater (self , / ) -> cKDTreeNode | None : ...
48+ def greater (self , / ) -> _KDTreeNode | None : ...
49+
50+ ###
3951
40- class cKDTree (_CythonMixin ):
52+ class cKDTreeNode (_CythonMixin , _KDTreeNode , Generic [ _NodeT_co ] ):
4153 @property
42- def n (self , / ) -> int : ...
54+ @override
55+ def lesser (self , / ) -> _NodeT_co : ...
4356 @property
44- def m (self , / ) -> int : ...
57+ @override
58+ def greater (self , / ) -> _NodeT_co : ...
59+
60+ class cKDTree (_CythonMixin , Generic [_BoxSizeT_co , _BoxSizeDataT_co ]):
61+ @property
62+ def data (self , / ) -> _Float2D : ...
4563 @property
4664 def leafsize (self , / ) -> int : ...
4765 @property
48- def size (self , / ) -> int : ...
66+ def m (self , / ) -> int : ...
4967 @property
50- def tree (self , / ) -> cKDTreeNode : ...
51-
52- # These are read-only attributes in cython, which behave like properties
68+ def n (self , / ) -> int : ...
69+ @ property
70+ def maxes ( self , / ) -> _Float1D : ...
5371 @property
54- def data (self , / ) -> onp . ArrayND [ np . float64 ] : ...
72+ def mins (self , / ) -> _Float1D : ...
5573 @property
56- def maxes (self , / ) -> onp . ArrayND [ np . float64 ] : ...
74+ def tree (self , / ) -> cKDTreeNode : ...
5775 @property
58- def mins (self , / ) -> onp . ArrayND [ np . float64 ] : ...
76+ def size (self , / ) -> int : ...
5977 @property
60- def indices (self , / ) -> onp . ArrayND [ np . float64 ] : ...
78+ def indices (self , / ) -> _Indices : ...
6179 @property
62- def boxsize (self , / ) -> onp .ArrayND [np .float64 ] | None : ...
80+ def boxsize (self , / ) -> _BoxSizeT_co : ...
81+ boxsize_data : _BoxSizeDataT_co
6382
6483 #
84+ @overload
6585 def __init__ (
66- self ,
86+ self : cKDTree [ None , None ] ,
6787 / ,
68- data : onp .ToComplexND ,
69- leafsize : int = ...,
70- compact_nodes : bool = ...,
71- copy_data : bool = ...,
72- balanced_tree : bool = ...,
73- boxsize : onp .ToFloat2D | None = ...,
88+ data : onp .ToFloat2D ,
89+ leafsize : int = 16 ,
90+ compact_nodes : bool = True ,
91+ copy_data : bool = False ,
92+ balanced_tree : bool = True ,
93+ boxsize : None = None ,
94+ ) -> None : ...
95+ @overload
96+ def __init__ (
97+ self : cKDTree [_Float2D , _Float1D ],
98+ / ,
99+ data : onp .ToFloat2D ,
100+ leafsize : int ,
101+ compact_nodes : bool ,
102+ copy_data : bool ,
103+ balanced_tree : bool ,
104+ boxsize : onp .ToFloat2D ,
105+ ) -> None : ...
106+ @overload
107+ def __init__ (
108+ self : cKDTree [_Float2D , _Float1D ],
109+ / ,
110+ data : onp .ToFloat2D ,
111+ leafsize : int = 16 ,
112+ compact_nodes : bool = True ,
113+ copy_data : bool = False ,
114+ balanced_tree : bool = True ,
115+ * ,
116+ boxsize : onp .ToFloat2D ,
74117 ) -> None : ...
75118
76119 #
@@ -79,24 +122,124 @@ class cKDTree(_CythonMixin):
79122 / ,
80123 x : onp .ToFloat1D ,
81124 k : onp .ToInt | onp .ToInt1D = 1 ,
82- eps : onp .ToFloat = 0.0 ,
83- p : onp .ToFloat = 2.0 ,
84- distance_upper_bound : float = ... , # inf
85- workers : int | None = None ,
125+ eps : onp .ToFloat = ... ,
126+ p : onp .ToFloat = ... ,
127+ distance_upper_bound : float = float ( "inf" ) , # noqa: PYI011
128+ workers : int | None = ... ,
86129 ) -> tuple [float , np .intp ] | tuple [onp .ArrayND [np .float64 ], onp .ArrayND [np .intp ]]: ...
87130
88131 #
132+ @overload
133+ def query_ball_point (
134+ self ,
135+ / ,
136+ x : onp .ToFloatStrict1D ,
137+ r : onp .ToFloat ,
138+ p : onp .ToFloat = 2.0 ,
139+ eps : onp .ToFloat = ...,
140+ workers : op .CanIndex | None = None ,
141+ return_sorted : onp .ToBool | None = None ,
142+ return_length : Falsy = False ,
143+ ) -> list [int ]: ...
144+ @overload
145+ def query_ball_point (
146+ self ,
147+ / ,
148+ x : onp .ToFloatStrict1D ,
149+ r : onp .ToFloat ,
150+ p : onp .ToFloat ,
151+ eps : onp .ToFloat ,
152+ workers : op .CanIndex | None ,
153+ return_sorted : onp .ToBool | None ,
154+ return_length : Truthy ,
155+ ) -> np .intp : ...
156+ @overload
157+ def query_ball_point (
158+ self ,
159+ / ,
160+ x : onp .ToFloatStrict1D ,
161+ r : onp .ToFloat ,
162+ p : onp .ToFloat = 2.0 ,
163+ eps : onp .ToFloat = ...,
164+ workers : op .CanIndex | None = None ,
165+ return_sorted : onp .ToBool | None = None ,
166+ * ,
167+ return_length : Truthy ,
168+ ) -> np .intp : ...
169+ @overload
170+ def query_ball_point (
171+ self ,
172+ / ,
173+ x : onp .ToFloatND ,
174+ r : onp .ToFloatND ,
175+ p : onp .ToFloat = 2.0 ,
176+ eps : onp .ToFloat = ...,
177+ workers : op .CanIndex | None = None ,
178+ return_sorted : onp .ToBool | None = None ,
179+ return_length : Falsy = False ,
180+ ) -> onp .ArrayND [np .object_ ]: ...
181+ @overload
182+ def query_ball_point (
183+ self ,
184+ / ,
185+ x : onp .ToFloatND ,
186+ r : onp .ToFloatND ,
187+ p : onp .ToFloat ,
188+ eps : onp .ToFloat ,
189+ workers : op .CanIndex | None ,
190+ return_sorted : onp .ToBool | None ,
191+ return_length : Truthy ,
192+ ) -> onp .ArrayND [np .intp ]: ...
193+ @overload
194+ def query_ball_point (
195+ self ,
196+ / ,
197+ x : onp .ToFloatND ,
198+ r : onp .ToFloatND ,
199+ p : onp .ToFloat = 2.0 ,
200+ eps : onp .ToFloat = ...,
201+ workers : op .CanIndex | None = None ,
202+ return_sorted : onp .ToBool | None = None ,
203+ * ,
204+ return_length : Truthy ,
205+ ) -> onp .ArrayND [np .intp ]: ...
206+ @overload
89207 def query_ball_point (
90208 self ,
91209 / ,
92210 x : onp .ToFloatND ,
93211 r : onp .ToFloat | onp .ToFloatND ,
94212 p : onp .ToFloat = 2.0 ,
95- eps : onp .ToFloat = 0.0 ,
96- workers : int | None = None ,
97- return_sorted : bool | None = None ,
98- return_length : bool = False ,
213+ eps : onp .ToFloat = ... ,
214+ workers : op . CanIndex | None = None ,
215+ return_sorted : onp . ToBool | None = None ,
216+ return_length : Falsy = False ,
99217 ) -> list [int ] | onp .ArrayND [np .object_ ]: ...
218+ @overload
219+ def query_ball_point (
220+ self ,
221+ / ,
222+ x : onp .ToFloatND ,
223+ r : onp .ToFloat | onp .ToFloatND ,
224+ p : onp .ToFloat ,
225+ eps : onp .ToFloat ,
226+ workers : op .CanIndex | None ,
227+ return_sorted : onp .ToBool | None ,
228+ return_length : Truthy ,
229+ ) -> np .intp | onp .ArrayND [np .intp ]: ...
230+ @overload
231+ def query_ball_point (
232+ self ,
233+ / ,
234+ x : onp .ToFloatND ,
235+ r : onp .ToFloat | onp .ToFloatND ,
236+ p : onp .ToFloat = 2.0 ,
237+ eps : onp .ToFloat = ...,
238+ workers : op .CanIndex | None = None ,
239+ return_sorted : onp .ToBool | None = None ,
240+ * ,
241+ return_length : Truthy ,
242+ ) -> np .intp | onp .ArrayND [np .intp ]: ...
100243
101244 #
102245 def query_ball_tree (
@@ -105,7 +248,7 @@ class cKDTree(_CythonMixin):
105248 other : cKDTree ,
106249 r : onp .ToFloat ,
107250 p : onp .ToFloat = 2.0 ,
108- eps : onp .ToFloat = 0.0 ,
251+ eps : onp .ToFloat = ..., # defaults to ` 0.0`, but is overridden in `KDTree` with `0` as default
109252 ) -> list [list [int ]]: ...
110253
111254 #
@@ -144,7 +287,7 @@ class cKDTree(_CythonMixin):
144287 self ,
145288 / ,
146289 other : cKDTree ,
147- r : onp .ToScalar ,
290+ r : onp .ToFloat ,
148291 p : onp .ToFloat = 2.0 ,
149292 weights : tuple [None , None ] | None = None ,
150293 cumulative : bool = True ,
@@ -154,7 +297,7 @@ class cKDTree(_CythonMixin):
154297 self ,
155298 / ,
156299 other : cKDTree ,
157- r : onp .ToScalar ,
300+ r : onp .ToFloat ,
158301 p : onp .ToFloat ,
159302 weights : _Weights ,
160303 cumulative : bool = True ,
@@ -164,7 +307,7 @@ class cKDTree(_CythonMixin):
164307 self ,
165308 / ,
166309 other : cKDTree ,
167- r : onp .ToScalar ,
310+ r : onp .ToFloat ,
168311 p : onp .ToFloat = 2.0 ,
169312 * ,
170313 weights : _Weights ,
@@ -175,32 +318,32 @@ class cKDTree(_CythonMixin):
175318 self ,
176319 / ,
177320 other : cKDTree ,
178- r : onp .ToFloat | onp .ToFloatND ,
321+ r : onp .ToFloat | onp .ToFloat1D ,
179322 p : onp .ToFloat = 2.0 ,
180323 weights : tuple [None , None ] | None = ...,
181324 cumulative : bool = True ,
182- ) -> np .float64 | np . intp | onp .ArrayND [np .intp ]: ...
325+ ) -> np .intp | onp .Array1D [np .intp ]: ...
183326 @overload
184327 def count_neighbors (
185328 self ,
186329 / ,
187330 other : cKDTree ,
188- r : onp .ToFloat | onp .ToFloatND ,
331+ r : onp .ToFloat | onp .ToFloat1D ,
189332 p : onp .ToFloat ,
190333 weights : _Weights ,
191334 cumulative : bool = True ,
192- ) -> np .float64 | np . intp | onp .ArrayND [np .float64 ]: ...
335+ ) -> np .float64 | onp .Array1D [np .float64 ]: ...
193336 @overload
194337 def count_neighbors (
195338 self ,
196339 / ,
197340 other : cKDTree ,
198- r : onp .ToFloat | onp .ToFloatND ,
341+ r : onp .ToFloat | onp .ToFloat1D ,
199342 p : onp .ToFloat = 2.0 ,
200343 * ,
201344 weights : _Weights ,
202345 cumulative : bool = True ,
203- ) -> np .float64 | np . intp | onp .ArrayND [np .float64 ]: ...
346+ ) -> np .float64 | onp .Array1D [np .float64 ]: ...
204347
205348 #
206349 @overload
@@ -211,7 +354,7 @@ class cKDTree(_CythonMixin):
211354 max_distance : onp .ToFloat ,
212355 p : onp .ToFloat = 2.0 ,
213356 output_type : L ["dok_matrix" ] = ...,
214- ) -> dok_matrix : ...
357+ ) -> dok_matrix [ np . float64 ] : ...
215358 @overload
216359 def sparse_distance_matrix (
217360 self ,
@@ -221,7 +364,7 @@ class cKDTree(_CythonMixin):
221364 p : onp .ToFloat = 2.0 ,
222365 * ,
223366 output_type : L ["coo_matrix" ],
224- ) -> coo_matrix : ...
367+ ) -> coo_matrix [ np . float64 ] : ...
225368 @overload
226369 def sparse_distance_matrix (
227370 self ,
0 commit comments