Skip to content

Commit 52320f2

Browse files
author
Jelmer Bot
committed
Merge to duplicate files for k mst
2 parents 35e8da2 + e46f19c commit 52320f2

File tree

1 file changed

+245
-0
lines changed

1 file changed

+245
-0
lines changed

multi_mst/k_mst/heap.py

Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
import numba
2+
import numpy as np
3+
from collections import namedtuple
4+
5+
NumbaKDTree = namedtuple("KDTree", ["data", "idx_array", "node_data", "node_bounds"])
6+
7+
8+
def kdtree_to_numba(sklearn_kdtree):
9+
"""Convert a scikit-learn KDTree object to a NumbaKDTree object."""
10+
data, idx_array, node_data, node_bounds = sklearn_kdtree.get_arrays()
11+
return NumbaKDTree(data, idx_array, node_data, node_bounds)
12+
13+
14+
@numba.njit(
15+
[
16+
"f4(f4[::1],f4[::1])",
17+
"f8(f8[::1],f8[::1])",
18+
"f8(f4[::1],f8[::1])",
19+
],
20+
fastmath=True,
21+
locals={
22+
"dim": numba.types.intp,
23+
"i": numba.types.uint16,
24+
},
25+
)
26+
def rdist(x, y):
27+
"""Computes the squared Euclidean distance between two points."""
28+
result = 0.0
29+
dim = x.shape[0]
30+
for i in range(dim):
31+
diff = x[i] - y[i]
32+
result += diff * diff
33+
34+
return result
35+
36+
37+
@numba.njit(
38+
[
39+
"void(f4[::1],i4[::1],f4,i4)",
40+
"void(f8[::1],i4[::1],f8,i4)",
41+
],
42+
fastmath=True,
43+
locals={
44+
"size": numba.types.intp,
45+
"i": numba.types.uint16,
46+
"ic1": numba.types.uint16,
47+
"ic2": numba.types.uint16,
48+
"i_swap": numba.types.uint16,
49+
},
50+
)
51+
def simple_heap_push(priorities, indices, p, n):
52+
"""Inserts value (index) in to priority heap (distance)."""
53+
# if p >= priorities[0]:
54+
# return 0
55+
56+
size = priorities.shape[0]
57+
58+
# insert val at position zero
59+
priorities[0] = p
60+
indices[0] = n
61+
62+
# descend the heap, swapping values until the max heap criterion is met
63+
i = 0
64+
while True:
65+
ic1 = 2 * i + 1
66+
ic2 = ic1 + 1
67+
68+
if ic1 >= size:
69+
break
70+
elif ic2 >= size:
71+
if priorities[ic1] > p:
72+
i_swap = ic1
73+
else:
74+
break
75+
elif priorities[ic1] >= priorities[ic2]:
76+
if p < priorities[ic1]:
77+
i_swap = ic1
78+
else:
79+
break
80+
else:
81+
if p < priorities[ic2]:
82+
i_swap = ic2
83+
else:
84+
break
85+
86+
priorities[i] = priorities[i_swap]
87+
indices[i] = indices[i_swap]
88+
89+
i = i_swap
90+
91+
priorities[i] = p
92+
indices[i] = n
93+
94+
# return 1
95+
96+
97+
@numba.njit()
98+
def siftdown(heap1, heap2, elt):
99+
"""Moves the element at index elt to its correct position in a heap."""
100+
while elt * 2 + 1 < heap1.shape[0]:
101+
left_child = elt * 2 + 1
102+
right_child = left_child + 1
103+
swap = elt
104+
105+
if heap1[swap] < heap1[left_child]:
106+
swap = left_child
107+
108+
if right_child < heap1.shape[0] and heap1[swap] < heap1[right_child]:
109+
swap = right_child
110+
111+
if swap == elt:
112+
break
113+
else:
114+
heap1[elt], heap1[swap] = heap1[swap], heap1[elt]
115+
heap2[elt], heap2[swap] = heap2[swap], heap2[elt]
116+
elt = swap
117+
118+
119+
@numba.njit(parallel=True)
120+
def deheap_sort(distances, indices):
121+
"""Sorts the heaps and returns the sorted distances and indices."""
122+
for i in numba.prange(indices.shape[0]):
123+
# starting from the end of the array and moving back
124+
for j in range(indices.shape[1] - 1, 0, -1):
125+
indices[i, 0], indices[i, j] = indices[i, j], indices[i, 0]
126+
distances[i, 0], distances[i, j] = distances[i, j], distances[i, 0]
127+
128+
siftdown(distances[i, :j], indices[i, :j], 0)
129+
130+
return distances, indices
131+
132+
133+
@numba.njit(
134+
[
135+
"f4(f4[::1],f4[::1],f4[::1])",
136+
"f4(f8[::1],f8[::1],f4[::1])",
137+
"f4(f8[::1],f8[::1],f8[::1])",
138+
],
139+
fastmath=True,
140+
locals={
141+
"dim": numba.types.intp,
142+
"i": numba.types.uint16,
143+
},
144+
)
145+
def point_to_node_lower_bound_rdist(upper, lower, pt):
146+
"""
147+
Calculate the lower bound of the squared Euclidean distance between a point
148+
and a node in a KD-tree.
149+
"""
150+
result = 0.0
151+
dim = pt.shape[0]
152+
for i in range(dim):
153+
d_lo = upper[i] - pt[i] if upper[i] > pt[i] else 0.0
154+
d_hi = pt[i] - lower[i] if pt[i] > lower[i] else 0.0
155+
d = d_lo + d_hi
156+
result += d * d
157+
158+
return result
159+
160+
161+
@numba.njit(
162+
locals={
163+
"node": numba.types.intp,
164+
"left": numba.types.intp,
165+
"right": numba.types.intp,
166+
"d": numba.types.float32,
167+
"idx": numba.types.uint32,
168+
}
169+
)
170+
def tree_query_recursion(tree, node, point, heap_p, heap_i, dist_lower_bound):
171+
"""
172+
Traverses a KD-tree recursively to find $k$ nearest points. Updates heap
173+
with neighbors inplace.
174+
"""
175+
node_info = tree.node_data[node]
176+
177+
# ------------------------------------------------------------
178+
# Case 1: query point is outside node radius: trim node from the query
179+
if dist_lower_bound > heap_p[0]:
180+
return
181+
182+
# ------------------------------------------------------------
183+
# Case 2: this is a leaf node. Update set of nearby points
184+
elif node_info.is_leaf:
185+
for i in range(node_info.idx_start, node_info.idx_end):
186+
idx = tree.idx_array[i]
187+
d = rdist(point, tree.data[idx])
188+
if d < heap_p[0]:
189+
simple_heap_push(heap_p, heap_i, d, idx)
190+
191+
# ------------------------------------------------------------
192+
# Case 3: Node is not a leaf. Recursively query subnodes starting with the
193+
# closest
194+
else:
195+
left = 2 * node + 1
196+
right = left + 1
197+
dist_lower_bound_left = point_to_node_lower_bound_rdist(
198+
tree.node_bounds[0, left], tree.node_bounds[1, left], point
199+
)
200+
dist_lower_bound_right = point_to_node_lower_bound_rdist(
201+
tree.node_bounds[0, right], tree.node_bounds[1, right], point
202+
)
203+
204+
# recursively query subnodes
205+
if dist_lower_bound_left <= dist_lower_bound_right:
206+
tree_query_recursion(
207+
tree, left, point, heap_p, heap_i, dist_lower_bound_left
208+
)
209+
tree_query_recursion(
210+
tree, right, point, heap_p, heap_i, dist_lower_bound_right
211+
)
212+
else:
213+
tree_query_recursion(
214+
tree, right, point, heap_p, heap_i, dist_lower_bound_right
215+
)
216+
tree_query_recursion(
217+
tree, left, point, heap_p, heap_i, dist_lower_bound_left
218+
)
219+
return
220+
221+
222+
@numba.njit(parallel=True)
223+
def parallel_tree_query(tree, data, k=10, output_rdist=False):
224+
"""
225+
Queries the KDTree for the k nearest neighbors of the given data points in
226+
parallel.
227+
"""
228+
result = (
229+
np.full((data.shape[0], k), np.inf, dtype=np.float32),
230+
np.full((data.shape[0], k), -1, dtype=np.int32),
231+
)
232+
233+
for i in numba.prange(data.shape[0]):
234+
distance_lower_bound = point_to_node_lower_bound_rdist(
235+
tree.node_bounds[0, 0], tree.node_bounds[1, 0], data[i]
236+
)
237+
heap_priorities, heap_indices = result[0][i], result[1][i]
238+
tree_query_recursion(
239+
tree, 0, data[i], heap_priorities, heap_indices, distance_lower_bound
240+
)
241+
242+
if output_rdist:
243+
return deheap_sort(result[0], result[1])
244+
else:
245+
return deheap_sort(np.sqrt(result[0]), result[1])

0 commit comments

Comments
 (0)