Skip to content

Commit 5a8e8b9

Browse files
authored
spatial: generic [c]KDTree and other improvements (#413)
1 parent 99abc24 commit 5a8e8b9

File tree

2 files changed

+230
-87
lines changed

2 files changed

+230
-87
lines changed

scipy-stubs/spatial/_ckdtree.pyi

Lines changed: 194 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,76 +1,119 @@
1-
from typing import Literal as L, TypeAlias, overload, type_check_only
1+
from typing import Generic, Literal as L, Protocol, TypeAlias, overload, type_check_only
2+
from typing_extensions import TypeVar, override
23

34
import numpy as np
5+
import optype as op
46
import optype.numpy as onp
7+
from scipy._typing import Falsy, Truthy
58
from scipy.sparse import coo_matrix, dok_matrix
69

710
__all__ = ["cKDTree"]
811

912
_Weights: TypeAlias = onp.ToFloatND | tuple[onp.ToFloatND, onp.ToFloatND]
13+
_Indices: TypeAlias = onp.Array1D[np.intp]
14+
_Float1D: TypeAlias = onp.Array1D[np.float64]
15+
_Float2D: TypeAlias = onp.Array2D[np.float64]
16+
17+
_NodeT_co = TypeVar("_NodeT_co", bound=_KDTreeNode | None, default=_KDTreeNode | None, covariant=True)
18+
_BoxSizeT_co = TypeVar("_BoxSizeT_co", bound=_Float2D | None, default=_Float2D | None, covariant=True)
19+
_BoxSizeDataT_co = TypeVar("_BoxSizeDataT_co", bound=_Float1D | None, default=_Float1D | None, covariant=True)
1020

1121
@type_check_only
1222
class _CythonMixin:
1323
def __setstate_cython__(self, pyx_state: object, /) -> None: ...
1424
def __reduce_cython__(self, /) -> None: ...
1525

16-
class cKDTreeNode(_CythonMixin):
17-
@property
18-
def data_points(self, /) -> onp.ArrayND[np.float64]: ...
19-
@property
20-
def indices(self, /) -> onp.ArrayND[np.intp]: ...
21-
22-
# These are read-only attributes in cython, which behave like properties
26+
# workaround for mypy's lack of cyclical TypeVar support
27+
@type_check_only
28+
class _KDTreeNode(Protocol):
2329
@property
2430
def level(self, /) -> int: ...
2531
@property
2632
def split_dim(self, /) -> int: ...
2733
@property
34+
def split(self, /) -> float: ...
35+
@property
2836
def children(self, /) -> int: ...
2937
@property
38+
def data_points(self, /) -> _Float2D: ...
39+
@property
40+
def indices(self, /) -> _Indices: ...
41+
@property
3042
def start_idx(self, /) -> int: ...
3143
@property
3244
def end_idx(self, /) -> int: ...
3345
@property
34-
def split(self, /) -> float: ...
35-
@property
36-
def lesser(self, /) -> cKDTreeNode | None: ...
46+
def lesser(self, /) -> _KDTreeNode | None: ...
3747
@property
38-
def greater(self, /) -> cKDTreeNode | None: ...
48+
def greater(self, /) -> _KDTreeNode | None: ...
49+
50+
###
3951

40-
class cKDTree(_CythonMixin):
52+
class cKDTreeNode(_CythonMixin, _KDTreeNode, Generic[_NodeT_co]):
4153
@property
42-
def n(self, /) -> int: ...
54+
@override
55+
def lesser(self, /) -> _NodeT_co: ...
4356
@property
44-
def m(self, /) -> int: ...
57+
@override
58+
def greater(self, /) -> _NodeT_co: ...
59+
60+
class cKDTree(_CythonMixin, Generic[_BoxSizeT_co, _BoxSizeDataT_co]):
61+
@property
62+
def data(self, /) -> _Float2D: ...
4563
@property
4664
def leafsize(self, /) -> int: ...
4765
@property
48-
def size(self, /) -> int: ...
66+
def m(self, /) -> int: ...
4967
@property
50-
def tree(self, /) -> cKDTreeNode: ...
51-
52-
# These are read-only attributes in cython, which behave like properties
68+
def n(self, /) -> int: ...
69+
@property
70+
def maxes(self, /) -> _Float1D: ...
5371
@property
54-
def data(self, /) -> onp.ArrayND[np.float64]: ...
72+
def mins(self, /) -> _Float1D: ...
5573
@property
56-
def maxes(self, /) -> onp.ArrayND[np.float64]: ...
74+
def tree(self, /) -> cKDTreeNode: ...
5775
@property
58-
def mins(self, /) -> onp.ArrayND[np.float64]: ...
76+
def size(self, /) -> int: ...
5977
@property
60-
def indices(self, /) -> onp.ArrayND[np.float64]: ...
78+
def indices(self, /) -> _Indices: ...
6179
@property
62-
def boxsize(self, /) -> onp.ArrayND[np.float64] | None: ...
80+
def boxsize(self, /) -> _BoxSizeT_co: ...
81+
boxsize_data: _BoxSizeDataT_co
6382

6483
#
84+
@overload
6585
def __init__(
66-
self,
86+
self: cKDTree[None, None],
6787
/,
68-
data: onp.ToComplexND,
69-
leafsize: int = ...,
70-
compact_nodes: bool = ...,
71-
copy_data: bool = ...,
72-
balanced_tree: bool = ...,
73-
boxsize: onp.ToFloat2D | None = ...,
88+
data: onp.ToFloat2D,
89+
leafsize: int = 16,
90+
compact_nodes: bool = True,
91+
copy_data: bool = False,
92+
balanced_tree: bool = True,
93+
boxsize: None = None,
94+
) -> None: ...
95+
@overload
96+
def __init__(
97+
self: cKDTree[_Float2D, _Float1D],
98+
/,
99+
data: onp.ToFloat2D,
100+
leafsize: int,
101+
compact_nodes: bool,
102+
copy_data: bool,
103+
balanced_tree: bool,
104+
boxsize: onp.ToFloat2D,
105+
) -> None: ...
106+
@overload
107+
def __init__(
108+
self: cKDTree[_Float2D, _Float1D],
109+
/,
110+
data: onp.ToFloat2D,
111+
leafsize: int = 16,
112+
compact_nodes: bool = True,
113+
copy_data: bool = False,
114+
balanced_tree: bool = True,
115+
*,
116+
boxsize: onp.ToFloat2D,
74117
) -> None: ...
75118

76119
#
@@ -79,24 +122,124 @@ class cKDTree(_CythonMixin):
79122
/,
80123
x: onp.ToFloat1D,
81124
k: onp.ToInt | onp.ToInt1D = 1,
82-
eps: onp.ToFloat = 0.0,
83-
p: onp.ToFloat = 2.0,
84-
distance_upper_bound: float = ..., # inf
85-
workers: int | None = None,
125+
eps: onp.ToFloat = ...,
126+
p: onp.ToFloat = ...,
127+
distance_upper_bound: float = float("inf"), # noqa: PYI011
128+
workers: int | None = ...,
86129
) -> tuple[float, np.intp] | tuple[onp.ArrayND[np.float64], onp.ArrayND[np.intp]]: ...
87130

88131
#
132+
@overload
133+
def query_ball_point(
134+
self,
135+
/,
136+
x: onp.ToFloatStrict1D,
137+
r: onp.ToFloat,
138+
p: onp.ToFloat = 2.0,
139+
eps: onp.ToFloat = ...,
140+
workers: op.CanIndex | None = None,
141+
return_sorted: onp.ToBool | None = None,
142+
return_length: Falsy = False,
143+
) -> list[int]: ...
144+
@overload
145+
def query_ball_point(
146+
self,
147+
/,
148+
x: onp.ToFloatStrict1D,
149+
r: onp.ToFloat,
150+
p: onp.ToFloat,
151+
eps: onp.ToFloat,
152+
workers: op.CanIndex | None,
153+
return_sorted: onp.ToBool | None,
154+
return_length: Truthy,
155+
) -> np.intp: ...
156+
@overload
157+
def query_ball_point(
158+
self,
159+
/,
160+
x: onp.ToFloatStrict1D,
161+
r: onp.ToFloat,
162+
p: onp.ToFloat = 2.0,
163+
eps: onp.ToFloat = ...,
164+
workers: op.CanIndex | None = None,
165+
return_sorted: onp.ToBool | None = None,
166+
*,
167+
return_length: Truthy,
168+
) -> np.intp: ...
169+
@overload
170+
def query_ball_point(
171+
self,
172+
/,
173+
x: onp.ToFloatND,
174+
r: onp.ToFloatND,
175+
p: onp.ToFloat = 2.0,
176+
eps: onp.ToFloat = ...,
177+
workers: op.CanIndex | None = None,
178+
return_sorted: onp.ToBool | None = None,
179+
return_length: Falsy = False,
180+
) -> onp.ArrayND[np.object_]: ...
181+
@overload
182+
def query_ball_point(
183+
self,
184+
/,
185+
x: onp.ToFloatND,
186+
r: onp.ToFloatND,
187+
p: onp.ToFloat,
188+
eps: onp.ToFloat,
189+
workers: op.CanIndex | None,
190+
return_sorted: onp.ToBool | None,
191+
return_length: Truthy,
192+
) -> onp.ArrayND[np.intp]: ...
193+
@overload
194+
def query_ball_point(
195+
self,
196+
/,
197+
x: onp.ToFloatND,
198+
r: onp.ToFloatND,
199+
p: onp.ToFloat = 2.0,
200+
eps: onp.ToFloat = ...,
201+
workers: op.CanIndex | None = None,
202+
return_sorted: onp.ToBool | None = None,
203+
*,
204+
return_length: Truthy,
205+
) -> onp.ArrayND[np.intp]: ...
206+
@overload
89207
def query_ball_point(
90208
self,
91209
/,
92210
x: onp.ToFloatND,
93211
r: onp.ToFloat | onp.ToFloatND,
94212
p: onp.ToFloat = 2.0,
95-
eps: onp.ToFloat = 0.0,
96-
workers: int | None = None,
97-
return_sorted: bool | None = None,
98-
return_length: bool = False,
213+
eps: onp.ToFloat = ...,
214+
workers: op.CanIndex | None = None,
215+
return_sorted: onp.ToBool | None = None,
216+
return_length: Falsy = False,
99217
) -> list[int] | onp.ArrayND[np.object_]: ...
218+
@overload
219+
def query_ball_point(
220+
self,
221+
/,
222+
x: onp.ToFloatND,
223+
r: onp.ToFloat | onp.ToFloatND,
224+
p: onp.ToFloat,
225+
eps: onp.ToFloat,
226+
workers: op.CanIndex | None,
227+
return_sorted: onp.ToBool | None,
228+
return_length: Truthy,
229+
) -> np.intp | onp.ArrayND[np.intp]: ...
230+
@overload
231+
def query_ball_point(
232+
self,
233+
/,
234+
x: onp.ToFloatND,
235+
r: onp.ToFloat | onp.ToFloatND,
236+
p: onp.ToFloat = 2.0,
237+
eps: onp.ToFloat = ...,
238+
workers: op.CanIndex | None = None,
239+
return_sorted: onp.ToBool | None = None,
240+
*,
241+
return_length: Truthy,
242+
) -> np.intp | onp.ArrayND[np.intp]: ...
100243

101244
#
102245
def query_ball_tree(
@@ -105,7 +248,7 @@ class cKDTree(_CythonMixin):
105248
other: cKDTree,
106249
r: onp.ToFloat,
107250
p: onp.ToFloat = 2.0,
108-
eps: onp.ToFloat = 0.0,
251+
eps: onp.ToFloat = ..., # defaults to `0.0`, but is overridden in `KDTree` with `0` as default
109252
) -> list[list[int]]: ...
110253

111254
#
@@ -144,7 +287,7 @@ class cKDTree(_CythonMixin):
144287
self,
145288
/,
146289
other: cKDTree,
147-
r: onp.ToScalar,
290+
r: onp.ToFloat,
148291
p: onp.ToFloat = 2.0,
149292
weights: tuple[None, None] | None = None,
150293
cumulative: bool = True,
@@ -154,7 +297,7 @@ class cKDTree(_CythonMixin):
154297
self,
155298
/,
156299
other: cKDTree,
157-
r: onp.ToScalar,
300+
r: onp.ToFloat,
158301
p: onp.ToFloat,
159302
weights: _Weights,
160303
cumulative: bool = True,
@@ -164,7 +307,7 @@ class cKDTree(_CythonMixin):
164307
self,
165308
/,
166309
other: cKDTree,
167-
r: onp.ToScalar,
310+
r: onp.ToFloat,
168311
p: onp.ToFloat = 2.0,
169312
*,
170313
weights: _Weights,
@@ -175,32 +318,32 @@ class cKDTree(_CythonMixin):
175318
self,
176319
/,
177320
other: cKDTree,
178-
r: onp.ToFloat | onp.ToFloatND,
321+
r: onp.ToFloat | onp.ToFloat1D,
179322
p: onp.ToFloat = 2.0,
180323
weights: tuple[None, None] | None = ...,
181324
cumulative: bool = True,
182-
) -> np.float64 | np.intp | onp.ArrayND[np.intp]: ...
325+
) -> np.intp | onp.Array1D[np.intp]: ...
183326
@overload
184327
def count_neighbors(
185328
self,
186329
/,
187330
other: cKDTree,
188-
r: onp.ToFloat | onp.ToFloatND,
331+
r: onp.ToFloat | onp.ToFloat1D,
189332
p: onp.ToFloat,
190333
weights: _Weights,
191334
cumulative: bool = True,
192-
) -> np.float64 | np.intp | onp.ArrayND[np.float64]: ...
335+
) -> np.float64 | onp.Array1D[np.float64]: ...
193336
@overload
194337
def count_neighbors(
195338
self,
196339
/,
197340
other: cKDTree,
198-
r: onp.ToFloat | onp.ToFloatND,
341+
r: onp.ToFloat | onp.ToFloat1D,
199342
p: onp.ToFloat = 2.0,
200343
*,
201344
weights: _Weights,
202345
cumulative: bool = True,
203-
) -> np.float64 | np.intp | onp.ArrayND[np.float64]: ...
346+
) -> np.float64 | onp.Array1D[np.float64]: ...
204347

205348
#
206349
@overload
@@ -211,7 +354,7 @@ class cKDTree(_CythonMixin):
211354
max_distance: onp.ToFloat,
212355
p: onp.ToFloat = 2.0,
213356
output_type: L["dok_matrix"] = ...,
214-
) -> dok_matrix: ...
357+
) -> dok_matrix[np.float64]: ...
215358
@overload
216359
def sparse_distance_matrix(
217360
self,
@@ -221,7 +364,7 @@ class cKDTree(_CythonMixin):
221364
p: onp.ToFloat = 2.0,
222365
*,
223366
output_type: L["coo_matrix"],
224-
) -> coo_matrix: ...
367+
) -> coo_matrix[np.float64]: ...
225368
@overload
226369
def sparse_distance_matrix(
227370
self,

0 commit comments

Comments
 (0)