Skip to content

Commit c0a05b3

Browse files
committed
Initial naive boruvka algorithm. Lots of work and improvements to come.
1 parent 245fff3 commit c0a05b3

File tree

1 file changed

+267
-0
lines changed

1 file changed

+267
-0
lines changed

hdbscan/_hdbscan_boruvka.pyx

Lines changed: 267 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,267 @@
1+
#cython: boundscheck=False, nonecheck=False, profile=True
2+
# Minimum spanning tree single linkage implementation for hdbscan
3+
# Authors: Leland McInnes
4+
# License: 3-clause BSD
5+
6+
cimport cython
7+
8+
import numpy as np
9+
cimport numpy as np
10+
11+
from libc.float cimport DBL_MAX
12+
13+
from scipy.spatial.distance import cdist, pdist, squareform
14+
15+
cdef points(tree, node, data, indices=False):
    """Return the rows of ``data`` stored directly in leaf ``node``.

    For a leaf, the rows are selected through
    ``tree.idx_array[idx_start:idx_end]``; with ``indices=True`` the
    selected index slice is returned as a second value.  For a non-leaf
    node the result is an empty array (or a pair of empty arrays when
    ``indices=True``).
    """
    info = tree.node_data[node]

    if info['is_leaf']:
        members = tree.idx_array[info['idx_start']:info['idx_end']]
        if indices:
            return data[members], members
        return data[members]

    # Internal nodes hold no points of their own.
    if indices:
        return np.array([]), np.array([])
    return np.array([])
32+
33+
cdef descendant_points(tree, node, data):
    """Return the rows of ``data`` covered by ``node``'s index range.

    Unlike ``points`` this does not look at ``is_leaf``: the slice
    ``tree.idx_array[idx_start:idx_end]`` is used for any node.
    """
    info = tree.node_data[node]
    first, last = info['idx_start'], info['idx_end']
    member_indices = tree.idx_array[first:last]
    return data[member_indices]
38+
39+
cdef inline list children(object tree, long long node):
    """Return the child node ids of ``node``, or ``[]`` for a leaf.

    Children are addressed with the implicit binary-heap layout:
    ``2 * node + 1`` and ``2 * node + 2``.
    """
    cdef long long left_child

    if tree.node_data[node]['is_leaf']:
        return []

    left_child = 2 * node + 1
    return [left_child, left_child + 1]
45+
46+
cdef inline double min_dist_dual(object tree1,
                                 object tree2,
                                 long long node1,
                                 long long node2,
                                 np.ndarray[double, ndim=2] centroid_dist):
    """Lower bound on the distance between two nodes.

    Takes the precomputed centroid-to-centroid distance
    ``centroid_dist[node1, node2]``, subtracts the ``radius`` field of both
    nodes, and clamps the result at zero.
    """
    radius1 = tree1.node_data[node1]['radius']
    radius2 = tree2.node_data[node2]['radius']
    gap = centroid_dist[node1, node2] - radius1 - radius2
    return max(0, gap)
54+
55+
cdef double max_child_distance(object tree, long long node, np.ndarray data):
    """Largest distance from ``tree.node_bounds[0, node]`` to a point of ``node``.

    Only the points returned by ``points`` are considered, so the result
    is ``0.0`` whenever that call yields no rows.
    """
    leaf_points = points(tree, node, data)
    if leaf_points.shape[0] <= 0:
        return 0.0

    center = tree.node_bounds[0, node]
    return np.max(cdist([center], leaf_points)[0])
63+
64+
cdef double max_descendant_distance(object tree, long long node, np.ndarray data):
    """Largest distance from ``tree.node_bounds[0, node]`` to any point in
    the index range of ``node`` (as returned by ``descendant_points``).
    """
    center = tree.node_bounds[0, node]
    spread = cdist([center], descendant_points(tree, node, data))
    return np.max(spread[0])
69+
70+
cdef class BoruvkaUnionFind (object):
    """Disjoint-set (union-find) structure over the integers ``0 .. size - 1``.

    ``_data`` is a ``(size, 2)`` integer array.  Column 0 holds each
    element's parent (an element is a root when it is its own parent);
    column 1 holds the rank used to decide which root absorbs the other.
    """

    cdef np.ndarray _data

    def __init__(self, size):
        # Integer storage: parents are used as array indices, are handed back
        # to `find` (whose argument is a C ``long long``), and are exposed by
        # `components`, which is declared to return an int64 array.
        self._data = np.zeros((size, 2), dtype=np.int64)
        self._data.T[0] = np.arange(size)

    cpdef union_(self, long long x, long long y):
        """Merge the sets containing ``x`` and ``y`` (union by rank)."""
        cdef long long x_root = self.find(x)
        cdef long long y_root = self.find(y)

        # Already joined: nothing to link, and the rank must not grow.
        if x_root == y_root:
            return

        if self._data[x_root, 1] < self._data[y_root, 1]:
            self._data[x_root, 0] = y_root
        elif self._data[x_root, 1] > self._data[y_root, 1]:
            self._data[y_root, 0] = x_root
        else:
            # Equal ranks: pick x_root as the new root and bump its rank.
            self._data[y_root, 0] = x_root
            self._data[x_root, 1] += 1

        return

    cpdef find(self, long long x):
        """Return the root of the set containing ``x``.

        Every element visited on the way is re-pointed directly at the root
        (path compression).
        """
        cdef long long root = x
        cdef long long parent

        # First pass: walk up to the root.
        while self._data[root, 0] != root:
            root = self._data[root, 0]

        # Second pass: compress the path just walked.
        while x != root:
            parent = self._data[x, 0]
            self._data[x, 0] = root
            x = parent

        return root

    cpdef np.ndarray[np.int64_t, ndim=1] components(self):
        """Return the parent column of the table.

        The entries are parent pointers; they are root labels only for
        elements whose path has been fully compressed (for instance after
        `find` has been called on every element).
        """
        return self._data.T[0]
99+
100+
cdef class BoruvkaAlgorithm (object):
    """Naive dual-tree Boruvka construction of a spanning tree.

    ``tree`` must expose ``data``, ``idx_array``, ``node_data`` (records
    with ``is_leaf``, ``idx_start``, ``idx_end`` and ``radius`` fields),
    ``node_bounds`` and a ``query`` method.
    NOTE(review): this matches the attributes of scikit-learn's
    BallTree/KDTree; ``node_bounds[0]`` is treated as the per-node centroids
    - confirm against the caller.

    Each round (`dual_tree_traversal` followed by `update_components`)
    records, for every current component, the closest point found in a
    different component, then joins components along those candidate edges.
    """

    cdef object tree
    cdef np.ndarray _data
    # Per-node pruning bound, filled in by _compute_bounds.
    cdef np.ndarray bounds
    # point index -> id of the component currently containing it
    cdef dict component_of_point
    # node index -> component id shared by all its points, or a unique
    # negative placeholder -(node + 1) while the node is still mixed
    cdef dict component_of_node
    # component id -> best edge found so far (endpoint outside, endpoint
    # inside, and its length)
    cdef dict candidate_neighbor
    cdef dict candidate_point
    cdef dict candidate_distance
    cdef object component_union_find
    # Accepted edges as (source, sink, distance) tuples with source <= sink.
    cdef set edges

    cdef np.ndarray _centroid_distances

    def __init__(self, tree):
        self.tree = tree
        self._data = np.array(tree.data)
        self.bounds = np.zeros(tree.node_bounds[0].shape[0])
        self.component_of_point = {}
        self.component_of_node = {}
        self.candidate_neighbor = {}
        self.candidate_point = {}
        self.candidate_distance = {}
        self.component_union_find = BoruvkaUnionFind(tree.data.shape[0])
        self.edges = set()

        # Dense node-by-node distance matrix between node_bounds[0] rows.
        self._centroid_distances = squareform(pdist(tree.node_bounds[0]))
        self._compute_bounds()
        self._initialize_components()

    cdef _compute_bounds(self):
        """Fill ``self.bounds`` with a per-node pruning distance.

        Nodes are visited from the last index to the first so that child
        bounds exist before their parent is processed; a second, forward
        pass caps every node's bound by its parent's bound.
        """
        # Last column of a 2-nearest-neighbour query of the data against
        # itself.  NOTE(review): presumably the distance to the nearest
        # *other* point, the first column being the point itself - confirm.
        nn_dist = self.tree.query(self.tree.data, 2)[0][:, -1]

        for n in range(self.tree.node_data.shape[0] - 1, -1, -1):
            node_radius = max_descendant_distance(self.tree, n, self._data)
            if self.tree.node_data[n]['is_leaf']:
                node_point_indices = points(self.tree, n, self._data,
                                            indices=True)[1]
                leaf_nn_dist = nn_dist[node_point_indices]
                b1 = leaf_nn_dist.max()
                b2 = (leaf_nn_dist + 2 * node_radius).min()
                self.bounds[n] = min(b1, b2)
            else:
                child_nodes = children(self.tree, n)
                lambda_children = np.array(
                    [max_descendant_distance(self.tree, c, self._data)
                     for c in child_nodes])
                child_bounds = self.bounds[child_nodes]
                b1 = child_bounds.max()
                b2 = (child_bounds
                      + 2 * (node_radius - lambda_children)).min()
                # A non-positive b2 is ignored in favour of b1.
                if b2 > 0:
                    self.bounds[n] = min(b1, b2)
                else:
                    self.bounds[n] = b1

        for n in range(1, self.tree.node_data.shape[0]):
            # (n - 1) // 2 is the parent in the implicit binary layout.
            self.bounds[n] = min(self.bounds[n], self.bounds[(n - 1) // 2])

    cdef _initialize_components(self):
        """Start with every point in its own component and no candidates."""
        num_points = self.tree.data.shape[0]
        num_nodes = self.tree.node_data.shape[0]

        self.component_of_point = {n: n for n in range(num_points)}
        # Distinct negative ids mark nodes not yet inside a single component.
        self.component_of_node = {n: -(n + 1) for n in range(num_nodes)}
        self.candidate_neighbor = {n: None for n in range(num_points)}
        self.candidate_point = {n: None for n in range(num_points)}
        # np.inf rather than the np.infty alias, which NumPy 2.0 removed.
        self.candidate_distance = {n: np.inf for n in range(num_points)}

    cpdef score(self, node1, node2):
        """Return the node-to-node lower bound, or ``inf`` to prune the pair.

        The pair is pruned when the lower bound is not below
        ``bounds[node1]`` or when both nodes lie entirely inside the same
        component.
        """
        node_dist = min_dist_dual(self.tree, self.tree, node1, node2,
                                  self._centroid_distances)
        if node_dist < self.bounds[node1]:
            component1 = self.component_of_node[node1]
            component2 = self.component_of_node[node2]
            # Negative ids are placeholders for mixed nodes, never a match.
            if component1 >= 0 and component2 >= 0 \
                    and component1 == component2:
                return np.inf
            return node_dist
        return np.inf

    cdef base_case(self, long long p, long long q, double point_distance):
        """Offer edge ``(p, q)`` as a candidate for ``p``'s component.

        The candidate is stored only when ``q`` lies in a different
        component and the edge is shorter than the current best.  Returns
        ``point_distance`` unchanged.
        """
        cdef long long component

        component = self.component_of_point[p]
        if component != self.component_of_point[q] and \
                point_distance < self.candidate_distance[component]:
            self.candidate_distance[component] = point_distance
            self.candidate_neighbor[component] = q
            self.candidate_point[component] = p

        return point_distance

    cdef update_components(self):
        """Join components along their candidate edges.

        Returns the number of components remaining afterwards.

        Raises
        ------
        ValueError
            If a component has no candidate edge at all.
        """
        cdef long long source
        cdef long long sink

        union_find = self.component_union_find

        components = np.unique(union_find.components())
        for component in components:
            point = self.candidate_point[component]
            neighbor = self.candidate_neighbor[component]
            # Must be tested before the values are coerced to C integers:
            # a C ``long long`` can never be None.
            if point is None or neighbor is None:
                raise ValueError('Source or sink of edge is None!')

            if point <= neighbor:
                source, sink = point, neighbor
            else:
                source, sink = neighbor, point

            # Only accept edges that still join two different components.
            # An earlier edge of this round (or a stale candidate left over
            # from a previous round) may already connect the endpoints;
            # adding it again would create a cycle.
            if union_find.find(source) != union_find.find(sink):
                self.edges.add((source, sink,
                                self.candidate_distance[component]))
                union_find.union_(source, sink)

            self.candidate_distance[component] = np.inf

        # Calling find on every point also fully compresses all paths, which
        # the np.unique(components()) counts rely on.
        for n in range(self.tree.data.shape[0]):
            self.component_of_point[n] = union_find.find(n)

        for n in range(self.tree.node_data.shape[0] - 1, -1, -1):
            if self.tree.node_data[n]['is_leaf']:
                leaf_indices = points(self.tree, n, self._data,
                                      indices=True)[1]
                components_of_points = np.array(
                    [self.component_of_point[idx] for idx in leaf_indices])
                if np.all(components_of_points == components_of_points[0]):
                    self.component_of_node[n] = components_of_points[0]
            else:
                child1, child2 = children(self.tree, n)
                if self.component_of_node[child1] == \
                        self.component_of_node[child2]:
                    self.component_of_node[n] = self.component_of_node[child1]

        return np.unique(union_find.components()).shape[0]

    cdef void dual_tree_traversal(self, long long node1, long long node2):
        """Recursively visit node pairs, feeding point pairs to `base_case`.

        A pair is skipped when `score` returns infinity.  When both nodes
        are leaves, every pair of their points at non-zero distance is
        offered to `base_case`; otherwise the traversal recurses into every
        (child of node1, child of node2) pair.
        """
        cdef np.ndarray[np.double_t, ndim=2] distances

        cdef np.ndarray points1, points2

        cdef long long i
        cdef long long j

        cdef long long p
        cdef long long q

        cdef long long child1
        cdef long long child2

        if np.isinf(self.score(node1, node2)):
            return

        if self.tree.node_data[node1]['is_leaf'] and \
                self.tree.node_data[node2]['is_leaf']:
            points1, point_indices1 = points(self.tree, node1, self._data,
                                             indices=True)
            points2, point_indices2 = points(self.tree, node2, self._data,
                                             indices=True)

            distances = cdist(points1, points2)
            for i in range(point_indices1.shape[0]):
                for j in range(point_indices2.shape[0]):
                    # Zero distance covers a point paired with itself (and
                    # exact duplicates); those are never offered.
                    if distances[i, j] > 0:
                        p = point_indices1[i]
                        q = point_indices2[j]
                        self.base_case(p, q, distances[i, j])
        else:
            for child1 in children(self.tree, node1):
                for child2 in children(self.tree, node2):
                    self.dual_tree_traversal(child1, child2)

    cpdef spanning_tree(self):
        """Run Boruvka rounds until one component remains.

        Returns the accepted edges as an array built from
        ``(source, sink, distance)`` tuples.

        Raises
        ------
        ValueError
            If a round fails to merge any components.  The traversal is
            deterministic, so repeating it from the same state would never
            terminate.
        """
        num_components = self.tree.data.shape[0]
        while num_components > 1:
            self.dual_tree_traversal(0, 0)
            remaining = self.update_components()
            if remaining >= num_components:
                raise ValueError('Boruvka round merged no components; '
                                 'cannot complete the spanning tree.')
            num_components = remaining

        return np.array(list(self.edges))

0 commit comments

Comments
 (0)