11import numba
2- import numpy as np
3- from collections import namedtuple
4-
5- NumbaKDTree = namedtuple ("KDTree" , ["data" , "idx_array" , "node_data" , "node_bounds" ])
6-
7-
8- def kdtree_to_numba (sklearn_kdtree ):
9- """Convert a scikit-learn KDTree object to a NumbaKDTree object."""
10- data , idx_array , node_data , node_bounds = sklearn_kdtree .get_arrays ()
11- return NumbaKDTree (data , idx_array , node_data , node_bounds )
12-
13-
14- @numba .njit (
15- [
16- "f4(f4[::1],f4[::1])" ,
17- "f8(f8[::1],f8[::1])" ,
18- "f8(f4[::1],f8[::1])" ,
19- ],
20- fastmath = True ,
21- locals = {
22- "dim" : numba .types .intp ,
23- "i" : numba .types .uint16 ,
24- },
25- )
26- def rdist (x , y ):
27- """Computes the squared Euclidean distance between two points."""
28- result = 0.0
29- dim = x .shape [0 ]
30- for i in range (dim ):
31- diff = x [i ] - y [i ]
32- result += diff * diff
33-
34- return result
35-
362
373@numba .njit (
384 [
39- "void(f4[::1],i4[::1],f4 ,i4)" ,
40- "void(f8[::1],i4[::1],f8 ,i4)" ,
5+ "void(f4[::1],i4[::1],i4[::1],f4,i4 ,i4)" ,
6+ "void(f8[::1],i4[::1],i4[::1],f8,i4 ,i4)" ,
417 ],
428 fastmath = True ,
439 locals = {
@@ -48,16 +14,14 @@ def rdist(x, y):
4814 "i_swap" : numba .types .uint16 ,
4915 },
5016)
51- def simple_heap_push (priorities , indices , p , n ):
52- """Inserts value (index) in to priority heap (distance)."""
53- # if p >= priorities[0]:
54- # return 0
55-
17+ def simple_edge_heap_push (priorities , sources , targets , p , s , t ):
18+ """Inserts value (source, target) in to priority heap (distance)."""
5619 size = priorities .shape [0 ]
5720
5821 # insert val at position zero
5922 priorities [0 ] = p
60- indices [0 ] = n
23+ sources [0 ] = s
24+ targets [0 ] = t
6125
6226 # descend the heap, swapping values until the max heap criterion is met
6327 i = 0
@@ -84,18 +48,18 @@ def simple_heap_push(priorities, indices, p, n):
8448 break
8549
8650 priorities [i ] = priorities [i_swap ]
87- indices [i ] = indices [i_swap ]
51+ sources [i ] = sources [i_swap ]
52+ targets [i ] = targets [i_swap ]
8853
8954 i = i_swap
9055
9156 priorities [i ] = p
92- indices [i ] = n
93-
94- # return 1
57+ sources [i ] = s
58+ targets [i ] = t
9559
9660
9761@numba .njit ()
98- def siftdown (heap1 , heap2 , elt ):
62+ def siftdown (heap1 , heap2 , heap3 , elt ):
9963 """Moves the element at index elt to its correct position in a heap."""
10064 while elt * 2 + 1 < heap1 .shape [0 ]:
10165 left_child = elt * 2 + 1
@@ -113,133 +77,21 @@ def siftdown(heap1, heap2, elt):
11377 else :
11478 heap1 [elt ], heap1 [swap ] = heap1 [swap ], heap1 [elt ]
11579 heap2 [elt ], heap2 [swap ] = heap2 [swap ], heap2 [elt ]
80+ heap3 [elt ], heap3 [swap ] = heap3 [swap ], heap3 [elt ]
11681 elt = swap
11782
11883
11984@numba .njit (parallel = True )
120- def deheap_sort (distances , indices ):
85+ def deheap_sort_edges (distances , sources , targets ):
12186 """Sorts the heaps and returns the sorted distances and indices."""
122- for i in numba .prange (indices .shape [0 ]):
87+ for i in numba .prange (distances .shape [0 ]):
12388 # starting from the end of the array and moving back
124- for j in range (indices .shape [1 ] - 1 , 0 , - 1 ):
125- indices [i , 0 ], indices [i , j ] = indices [i , j ], indices [i , 0 ]
89+ for j in range (sources .shape [1 ] - 1 , 0 , - 1 ):
90+ sources [i , 0 ], sources [i , j ] = sources [i , j ], sources [i , 0 ]
91+ targets [i , 0 ], targets [i , j ] = targets [i , j ], targets [i , 0 ]
12692 distances [i , 0 ], distances [i , j ] = distances [i , j ], distances [i , 0 ]
12793
128- siftdown (distances [i , :j ], indices [i , :j ], 0 )
129-
130- return distances , indices
94+ siftdown (distances [i , :j ], sources [i , :j ], targets [i , :j ], 0 )
13195
96+ return distances , sources , targets
13297
133- @numba .njit (
134- [
135- "f4(f4[::1],f4[::1],f4[::1])" ,
136- "f4(f8[::1],f8[::1],f4[::1])" ,
137- "f4(f8[::1],f8[::1],f8[::1])" ,
138- ],
139- fastmath = True ,
140- locals = {
141- "dim" : numba .types .intp ,
142- "i" : numba .types .uint16 ,
143- },
144- )
145- def point_to_node_lower_bound_rdist (upper , lower , pt ):
146- """
147- Calculate the lower bound of the squared Euclidean distance between a point
148- and a node in a KD-tree.
149- """
150- result = 0.0
151- dim = pt .shape [0 ]
152- for i in range (dim ):
153- d_lo = upper [i ] - pt [i ] if upper [i ] > pt [i ] else 0.0
154- d_hi = pt [i ] - lower [i ] if pt [i ] > lower [i ] else 0.0
155- d = d_lo + d_hi
156- result += d * d
157-
158- return result
159-
160-
161- @numba .njit (
162- locals = {
163- "node" : numba .types .intp ,
164- "left" : numba .types .intp ,
165- "right" : numba .types .intp ,
166- "d" : numba .types .float32 ,
167- "idx" : numba .types .uint32 ,
168- }
169- )
170- def tree_query_recursion (tree , node , point , heap_p , heap_i , dist_lower_bound ):
171- """
172- Traverses a KD-tree recursively to find $k$ nearest points. Updates heap
173- with neighbors inplace.
174- """
175- node_info = tree .node_data [node ]
176-
177- # ------------------------------------------------------------
178- # Case 1: query point is outside node radius: trim node from the query
179- if dist_lower_bound > heap_p [0 ]:
180- return
181-
182- # ------------------------------------------------------------
183- # Case 2: this is a leaf node. Update set of nearby points
184- elif node_info .is_leaf :
185- for i in range (node_info .idx_start , node_info .idx_end ):
186- idx = tree .idx_array [i ]
187- d = rdist (point , tree .data [idx ])
188- if d < heap_p [0 ]:
189- simple_heap_push (heap_p , heap_i , d , idx )
190-
191- # ------------------------------------------------------------
192- # Case 3: Node is not a leaf. Recursively query subnodes starting with the
193- # closest
194- else :
195- left = 2 * node + 1
196- right = left + 1
197- dist_lower_bound_left = point_to_node_lower_bound_rdist (
198- tree .node_bounds [0 , left ], tree .node_bounds [1 , left ], point
199- )
200- dist_lower_bound_right = point_to_node_lower_bound_rdist (
201- tree .node_bounds [0 , right ], tree .node_bounds [1 , right ], point
202- )
203-
204- # recursively query subnodes
205- if dist_lower_bound_left <= dist_lower_bound_right :
206- tree_query_recursion (
207- tree , left , point , heap_p , heap_i , dist_lower_bound_left
208- )
209- tree_query_recursion (
210- tree , right , point , heap_p , heap_i , dist_lower_bound_right
211- )
212- else :
213- tree_query_recursion (
214- tree , right , point , heap_p , heap_i , dist_lower_bound_right
215- )
216- tree_query_recursion (
217- tree , left , point , heap_p , heap_i , dist_lower_bound_left
218- )
219- return
220-
221-
222- @numba .njit (parallel = True )
223- def parallel_tree_query (tree , data , k = 10 , output_rdist = False ):
224- """
225- Queries the KDTree for the k nearest neighbors of the given data points in
226- parallel.
227- """
228- result = (
229- np .full ((data .shape [0 ], k ), np .inf , dtype = np .float32 ),
230- np .full ((data .shape [0 ], k ), - 1 , dtype = np .int32 ),
231- )
232-
233- for i in numba .prange (data .shape [0 ]):
234- distance_lower_bound = point_to_node_lower_bound_rdist (
235- tree .node_bounds [0 , 0 ], tree .node_bounds [1 , 0 ], data [i ]
236- )
237- heap_priorities , heap_indices = result [0 ][i ], result [1 ][i ]
238- tree_query_recursion (
239- tree , 0 , data [i ], heap_priorities , heap_indices , distance_lower_bound
240- )
241-
242- if output_rdist :
243- return deheap_sort (result [0 ], result [1 ])
244- else :
245- return deheap_sort (np .sqrt (result [0 ]), result [1 ])
0 commit comments