Skip to content

Commit 99d5cba

Browse files
author
Jelmer Bot
committed
add implementation (2)
1 parent 52320f2 commit 99d5cba

File tree

1 file changed

+19
-167
lines changed

1 file changed

+19
-167
lines changed

multi_mst/k_mst/heap.py

Lines changed: 19 additions & 167 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,9 @@
11
import numba
2-
import numpy as np
3-
from collections import namedtuple
4-
5-
NumbaKDTree = namedtuple("KDTree", ["data", "idx_array", "node_data", "node_bounds"])
6-
7-
8-
def kdtree_to_numba(sklearn_kdtree):
9-
"""Convert a scikit-learn KDTree object to a NumbaKDTree object."""
10-
data, idx_array, node_data, node_bounds = sklearn_kdtree.get_arrays()
11-
return NumbaKDTree(data, idx_array, node_data, node_bounds)
12-
13-
14-
@numba.njit(
15-
[
16-
"f4(f4[::1],f4[::1])",
17-
"f8(f8[::1],f8[::1])",
18-
"f8(f4[::1],f8[::1])",
19-
],
20-
fastmath=True,
21-
locals={
22-
"dim": numba.types.intp,
23-
"i": numba.types.uint16,
24-
},
25-
)
26-
def rdist(x, y):
27-
"""Computes the squared Euclidean distance between two points."""
28-
result = 0.0
29-
dim = x.shape[0]
30-
for i in range(dim):
31-
diff = x[i] - y[i]
32-
result += diff * diff
33-
34-
return result
35-
362

373
@numba.njit(
384
[
39-
"void(f4[::1],i4[::1],f4,i4)",
40-
"void(f8[::1],i4[::1],f8,i4)",
5+
"void(f4[::1],i4[::1],i4[::1],f4,i4,i4)",
6+
"void(f8[::1],i4[::1],i4[::1],f8,i4,i4)",
417
],
428
fastmath=True,
439
locals={
@@ -48,16 +14,14 @@ def rdist(x, y):
4814
"i_swap": numba.types.uint16,
4915
},
5016
)
51-
def simple_heap_push(priorities, indices, p, n):
52-
"""Inserts value (index) in to priority heap (distance)."""
53-
# if p >= priorities[0]:
54-
# return 0
55-
17+
def simple_edge_heap_push(priorities, sources, targets, p, s, t):
18+
"""Inserts value (source, target) into priority heap (distance)."""
5619
size = priorities.shape[0]
5720

5821
# insert val at position zero
5922
priorities[0] = p
60-
indices[0] = n
23+
sources[0] = s
24+
targets[0] = t
6125

6226
# descend the heap, swapping values until the max heap criterion is met
6327
i = 0
@@ -84,18 +48,18 @@ def simple_heap_push(priorities, indices, p, n):
8448
break
8549

8650
priorities[i] = priorities[i_swap]
87-
indices[i] = indices[i_swap]
51+
sources[i] = sources[i_swap]
52+
targets[i] = targets[i_swap]
8853

8954
i = i_swap
9055

9156
priorities[i] = p
92-
indices[i] = n
93-
94-
# return 1
57+
sources[i] = s
58+
targets[i] = t
9559

9660

9761
@numba.njit()
98-
def siftdown(heap1, heap2, elt):
62+
def siftdown(heap1, heap2, heap3, elt):
9963
"""Moves the element at index elt to its correct position in a heap."""
10064
while elt * 2 + 1 < heap1.shape[0]:
10165
left_child = elt * 2 + 1
@@ -113,133 +77,21 @@ def siftdown(heap1, heap2, elt):
11377
else:
11478
heap1[elt], heap1[swap] = heap1[swap], heap1[elt]
11579
heap2[elt], heap2[swap] = heap2[swap], heap2[elt]
80+
heap3[elt], heap3[swap] = heap3[swap], heap3[elt]
11681
elt = swap
11782

11883

11984
@numba.njit(parallel=True)
120-
def deheap_sort(distances, indices):
85+
def deheap_sort_edges(distances, sources, targets):
12186
"""Sorts the heaps and returns the sorted distances, sources, and targets."""
122-
for i in numba.prange(indices.shape[0]):
87+
for i in numba.prange(distances.shape[0]):
12388
# starting from the end of the array and moving back
124-
for j in range(indices.shape[1] - 1, 0, -1):
125-
indices[i, 0], indices[i, j] = indices[i, j], indices[i, 0]
89+
for j in range(sources.shape[1] - 1, 0, -1):
90+
sources[i, 0], sources[i, j] = sources[i, j], sources[i, 0]
91+
targets[i, 0], targets[i, j] = targets[i, j], targets[i, 0]
12692
distances[i, 0], distances[i, j] = distances[i, j], distances[i, 0]
12793

128-
siftdown(distances[i, :j], indices[i, :j], 0)
129-
130-
return distances, indices
94+
siftdown(distances[i, :j], sources[i, :j], targets[i, :j], 0)
13195

96+
return distances, sources, targets
13297

133-
@numba.njit(
134-
[
135-
"f4(f4[::1],f4[::1],f4[::1])",
136-
"f4(f8[::1],f8[::1],f4[::1])",
137-
"f4(f8[::1],f8[::1],f8[::1])",
138-
],
139-
fastmath=True,
140-
locals={
141-
"dim": numba.types.intp,
142-
"i": numba.types.uint16,
143-
},
144-
)
145-
def point_to_node_lower_bound_rdist(upper, lower, pt):
146-
"""
147-
Calculate the lower bound of the squared Euclidean distance between a point
148-
and a node in a KD-tree.
149-
"""
150-
result = 0.0
151-
dim = pt.shape[0]
152-
for i in range(dim):
153-
d_lo = upper[i] - pt[i] if upper[i] > pt[i] else 0.0
154-
d_hi = pt[i] - lower[i] if pt[i] > lower[i] else 0.0
155-
d = d_lo + d_hi
156-
result += d * d
157-
158-
return result
159-
160-
161-
@numba.njit(
162-
locals={
163-
"node": numba.types.intp,
164-
"left": numba.types.intp,
165-
"right": numba.types.intp,
166-
"d": numba.types.float32,
167-
"idx": numba.types.uint32,
168-
}
169-
)
170-
def tree_query_recursion(tree, node, point, heap_p, heap_i, dist_lower_bound):
171-
"""
172-
Traverses a KD-tree recursively to find $k$ nearest points. Updates heap
173-
with neighbors inplace.
174-
"""
175-
node_info = tree.node_data[node]
176-
177-
# ------------------------------------------------------------
178-
# Case 1: query point is outside node radius: trim node from the query
179-
if dist_lower_bound > heap_p[0]:
180-
return
181-
182-
# ------------------------------------------------------------
183-
# Case 2: this is a leaf node. Update set of nearby points
184-
elif node_info.is_leaf:
185-
for i in range(node_info.idx_start, node_info.idx_end):
186-
idx = tree.idx_array[i]
187-
d = rdist(point, tree.data[idx])
188-
if d < heap_p[0]:
189-
simple_heap_push(heap_p, heap_i, d, idx)
190-
191-
# ------------------------------------------------------------
192-
# Case 3: Node is not a leaf. Recursively query subnodes starting with the
193-
# closest
194-
else:
195-
left = 2 * node + 1
196-
right = left + 1
197-
dist_lower_bound_left = point_to_node_lower_bound_rdist(
198-
tree.node_bounds[0, left], tree.node_bounds[1, left], point
199-
)
200-
dist_lower_bound_right = point_to_node_lower_bound_rdist(
201-
tree.node_bounds[0, right], tree.node_bounds[1, right], point
202-
)
203-
204-
# recursively query subnodes
205-
if dist_lower_bound_left <= dist_lower_bound_right:
206-
tree_query_recursion(
207-
tree, left, point, heap_p, heap_i, dist_lower_bound_left
208-
)
209-
tree_query_recursion(
210-
tree, right, point, heap_p, heap_i, dist_lower_bound_right
211-
)
212-
else:
213-
tree_query_recursion(
214-
tree, right, point, heap_p, heap_i, dist_lower_bound_right
215-
)
216-
tree_query_recursion(
217-
tree, left, point, heap_p, heap_i, dist_lower_bound_left
218-
)
219-
return
220-
221-
222-
@numba.njit(parallel=True)
223-
def parallel_tree_query(tree, data, k=10, output_rdist=False):
224-
"""
225-
Queries the KDTree for the k nearest neighbors of the given data points in
226-
parallel.
227-
"""
228-
result = (
229-
np.full((data.shape[0], k), np.inf, dtype=np.float32),
230-
np.full((data.shape[0], k), -1, dtype=np.int32),
231-
)
232-
233-
for i in numba.prange(data.shape[0]):
234-
distance_lower_bound = point_to_node_lower_bound_rdist(
235-
tree.node_bounds[0, 0], tree.node_bounds[1, 0], data[i]
236-
)
237-
heap_priorities, heap_indices = result[0][i], result[1][i]
238-
tree_query_recursion(
239-
tree, 0, data[i], heap_priorities, heap_indices, distance_lower_bound
240-
)
241-
242-
if output_rdist:
243-
return deheap_sort(result[0], result[1])
244-
else:
245-
return deheap_sort(np.sqrt(result[0]), result[1])

0 commit comments

Comments
 (0)