
Commit f411d5c

add lensed clusters from knn-mst union

1 parent 821ff98

4 files changed (+328, -23 lines)

fast_hdbscan/boruvka.py (33 additions, 23 deletions)
@@ -4,26 +4,33 @@
 from .disjoint_set import ds_rank_create, ds_find, ds_union_by_rank
 from .numba_kdtree import parallel_tree_query, rdist, point_to_node_lower_bound_rdist
 
-@numba.njit(locals={"i": numba.types.int64})
-def merge_components(disjoint_set, candidate_neighbors, candidate_neighbor_distances, point_components):
-    component_edges = {np.int64(0): (np.int64(0), np.int64(1), np.float32(0.0)) for i in range(0)}
+
+@numba.njit(locals={"parent": numba.types.int32})
+def select_components(candidate_distances, candidate_neighbors, point_components):
+    component_edges = {np.int64(0): (np.int32(0), np.int32(1), np.float32(0.0)) for i in range(0)}
 
     # Find the best edges from each component
-    for i in range(candidate_neighbors.shape[0]):
-        from_component = np.int64(point_components[i])
+    for parent, (distance, neighbor, from_component) in enumerate(
+        zip(candidate_distances, candidate_neighbors, point_components)
+    ):
         if from_component in component_edges:
-            if candidate_neighbor_distances[i] < component_edges[from_component][2]:
-                component_edges[from_component] = (np.int64(i), np.int64(candidate_neighbors[i]), candidate_neighbor_distances[i])
+            if distance < component_edges[from_component][2]:
+                component_edges[from_component] = (parent, neighbor, distance)
         else:
-            component_edges[from_component] = (np.int64(i), np.int64(candidate_neighbors[i]), candidate_neighbor_distances[i])
+            component_edges[from_component] = (parent, neighbor, distance)
+
+    return component_edges
+
 
+@numba.njit()
+def merge_components(disjoint_set, component_edges):
     result = np.empty((len(component_edges), 3), dtype=np.float64)
     result_idx = 0
 
     # Add the best edges to the edge set and merge the relevant components
     for edge in component_edges.values():
-        from_component = ds_find(disjoint_set, np.int32(edge[0]))
-        to_component = ds_find(disjoint_set, np.int32(edge[1]))
+        from_component = ds_find(disjoint_set, edge[0])
+        to_component = ds_find(disjoint_set, edge[1])
         if from_component != to_component:
             result[result_idx] = (np.float64(edge[0]), np.float64(edge[1]), np.float64(edge[2]))
             result_idx += 1
@@ -34,10 +41,13 @@ def merge_components(disjoint_set, candidate_neighbors, candidate_neighbor_dista
 
 
 @numba.njit(parallel=True)
-def update_component_vectors(tree, disjoint_set, node_components, point_components):
+def update_point_components(disjoint_set, point_components):
     for i in numba.prange(point_components.shape[0]):
         point_components[i] = ds_find(disjoint_set, np.int32(i))
 
+
+@numba.njit()
+def update_node_components(tree, node_components, point_components):
     for i in range(tree.node_data.shape[0] - 1, -1, -1):
         node_info = tree.node_data[i]
 
@@ -272,28 +282,28 @@ def parallel_boruvka(tree, min_samples=10, sample_weights=None):
         expected_neighbors = min_samples / mean_sample_weight
         distances, neighbors = parallel_tree_query(tree, tree.data, k=int(2 * expected_neighbors))
         core_distances = sample_weight_core_distance(distances, neighbors, sample_weights, min_samples)
-        edges = initialize_boruvka_from_knn(neighbors, distances, core_distances, components_disjoint_set)
-        update_component_vectors(tree, components_disjoint_set, node_components, point_components)
     else:
         if min_samples > 1:
            distances, neighbors = parallel_tree_query(tree, tree.data, k=min_samples + 1, output_rdist=True)
            core_distances = distances.T[-1]
-            edges = initialize_boruvka_from_knn(neighbors, distances, core_distances, components_disjoint_set)
-            update_component_vectors(tree, components_disjoint_set, node_components, point_components)
        else:
            core_distances = np.zeros(tree.data.shape[0], dtype=np.float32)
            distances, neighbors = parallel_tree_query(tree, tree.data, k=2)
-            edges = initialize_boruvka_from_knn(neighbors, distances, core_distances, components_disjoint_set)
-            update_component_vectors(tree, components_disjoint_set, node_components, point_components)
 
-    while n_components > 1:
+    edges = [np.empty((0, 3), dtype=np.float64) for _ in range(0)]
+    new_edges = initialize_boruvka_from_knn(neighbors, distances, core_distances, components_disjoint_set)
+    while True:
+        edges.append(new_edges)
+        n_components -= new_edges.shape[0]
+        if n_components == 1:
+            break
+        update_point_components(components_disjoint_set, point_components)
+        update_node_components(tree, node_components, point_components)
         candidate_distances, candidate_indices = boruvka_tree_query(tree, node_components, point_components,
                                                                     core_distances)
-        new_edges = merge_components(components_disjoint_set, candidate_indices, candidate_distances, point_components)
-        update_component_vectors(tree, components_disjoint_set, node_components, point_components)
-
-        edges = np.vstack((edges, new_edges))
-        n_components = np.unique(point_components).shape[0]
+        component_edges = select_components(candidate_distances, candidate_indices, point_components)
+        new_edges = merge_components(components_disjoint_set, component_edges)
 
+    edges = np.vstack(edges)
     edges[:, 2] = np.sqrt(edges.T[2])
     return edges, neighbors[:, 1:], np.sqrt(core_distances)
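Taken together, the changes above split one Borůvka round into two steps: select_components keeps, for each current component, the cheapest candidate edge leaving it, and merge_components pushes those edges through the disjoint set, keeping only the ones that actually join two different components. Because every accepted edge merges exactly two components, the rewritten driver loop can decrement n_components by new_edges.shape[0] instead of recounting labels with np.unique. A minimal pure-NumPy sketch of one such round, using hypothetical names and a hand-rolled union-find in place of the library's numba disjoint-set helpers:

import numpy as np

def select_best_edges(candidate_distances, candidate_neighbors, point_components):
    # For every current component, remember the cheapest candidate edge leaving it.
    best = {}
    for parent, (distance, neighbor, component) in enumerate(
        zip(candidate_distances, candidate_neighbors, point_components)
    ):
        if component not in best or distance < best[component][2]:
            best[component] = (parent, neighbor, distance)
    return best

def merge_best_edges(parent_of, best_edges):
    # Union-find with path halving; keep only edges joining two distinct components.
    def find(x):
        while parent_of[x] != x:
            parent_of[x] = parent_of[parent_of[x]]
            x = parent_of[x]
        return x

    accepted = []
    for src, dst, dist in best_edges.values():
        root_a, root_b = find(src), find(dst)
        if root_a != root_b:
            parent_of[root_a] = root_b
            accepted.append((src, dst, dist))
    return np.array(accepted, dtype=np.float64).reshape(-1, 3)

# One round on an invented 4-point candidate set: points 0-1 form one component,
# points 2-3 another, and each point proposes its best out-of-component neighbour.
point_components = np.array([0, 0, 2, 2])
candidate_neighbors = np.array([2, 3, 1, 0])
candidate_distances = np.array([0.8, 0.3, 0.9, 0.4])
parent_of = [0, 0, 2, 2]  # disjoint-set forest consistent with point_components

best = select_best_edges(candidate_distances, candidate_neighbors, point_components)
print(merge_best_edges(parent_of, best))  # one edge, (1, 3, 0.3), merges the two components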

fast_hdbscan/cluster_trees.py (7 additions, 0 deletions)
@@ -160,6 +160,13 @@ def eliminate_branch(branch_node, parent_node, lambda_value, parents, children,
 CondensedTree = namedtuple('CondensedTree', ['parent', 'child', 'lambda_val', 'child_size'])
 
 
+@numba.njit()
+def empty_condensed_tree():
+    parents = np.empty(shape=0, dtype=np.intp)
+    others = np.empty(shape=0, dtype=np.float32)
+    return CondensedTree(parents, parents, others, others)
+
+
 @numba.njit(fastmath=True)
 def condense_tree(hierarchy, min_cluster_size=10, max_cluster_size=np.inf, sample_weights=None):
     root = 2 * hierarchy.shape[0]
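empty_condensed_tree gives callers a zero-length CondensedTree to return when there is nothing to condense; the new core_graph module uses it when the lensed graph still has more than one connected component. A small sketch of the equivalent plain-NumPy construction (just an illustration of the placeholder's shape, not a new API):

import numpy as np
from collections import namedtuple

CondensedTree = namedtuple('CondensedTree', ['parent', 'child', 'lambda_val', 'child_size'])

# Every field is a length-0 array, so downstream loops over the tree simply do nothing.
parents = np.empty(shape=0, dtype=np.intp)
others = np.empty(shape=0, dtype=np.float32)
empty_tree = CondensedTree(parents, parents, others, others)
print(empty_tree.parent.shape, empty_tree.lambda_val.dtype)  # (0,) float32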

fast_hdbscan/core_graph.py (new file: 219 additions, 0 deletions)
import numba
import numpy as np
from collections import namedtuple

from .disjoint_set import ds_rank_create
from .hdbscan import clusters_from_spanning_tree
from .cluster_trees import empty_condensed_tree
from .boruvka import merge_components, update_point_components

CoreGraph = namedtuple("CoreGraph", ["weights", "distances", "indices", "indptr"])


@numba.njit(parallel=True)
def knn_mst_union(neighbors, core_distances, min_spanning_tree, lens_values):
    # List of dictionaries of child: (weight, distance)
    graph = [
        {np.int32(0): (np.float64(0.0), np.float64(0.0)) for _ in range(0)}
        for _ in range(neighbors.shape[0])
    ]

    # Add knn edges
    for point in numba.prange(len(core_distances)):
        children = graph[point]
        parent_lens = lens_values[point]
        parent_dist = core_distances[point]
        for child in neighbors[point]:
            if child < 0:
                continue
            children[child] = (
                max(parent_lens, lens_values[child]),
                max(parent_dist, core_distances[child]),
            )

    # Add non-knn mst edges
    for parent, child, distance in min_spanning_tree:
        parent = np.int32(parent)
        child = np.int32(child)
        children = graph[parent]
        if child in children:
            continue
        children[child] = (max(lens_values[parent], lens_values[child]), distance)

    return graph


@numba.njit(parallel=True)
def sort_by_lens(graph):
    for point in numba.prange(len(graph)):
        graph[point] = {
            k: v for k, v in sorted(graph[point].items(), key=lambda item: item[1][0])
        }
    return graph


@numba.njit(parallel=True)
def apply_lens(core_graph, lens_values):
    # Apply new lens to the graph
    for point in numba.prange(len(lens_values)):
        children = core_graph[point]
        point_lens = lens_values[point]
        for child, value in children.items():
            children[child] = (max(point_lens, lens_values[child]), value[1])
    return sort_by_lens(core_graph)


@numba.njit()
def flatten_to_csr(graph):
    # Count children to form indptr
    num_points = len(graph)
    indptr = np.empty(num_points + 1, dtype=np.int32)
    indptr[0] = 0
    for i, children in enumerate(graph):
        indptr[i + 1] = indptr[i] + len(children)

    # Flatten children to form indices, weights, and distances
    weights = np.empty(indptr[-1], dtype=np.float32)
    distances = np.empty(indptr[-1], dtype=np.float32)
    indices = np.empty(indptr[-1], dtype=np.int32)
    for point in numba.prange(num_points):
        start = indptr[point]
        children = graph[point]
        for j, (child, (weight, distance)) in enumerate(children.items()):
            weights[start + j] = weight
            distances[start + j] = distance
            indices[start + j] = child

    # Return as named csr tuple
    return CoreGraph(weights, distances, indices, indptr)


@numba.njit(locals={"parent": numba.types.int32})
def select_components(graph, point_components):
    component_edges = {
        np.int64(0): (np.int32(0), np.int32(1), np.float32(0.0)) for _ in range(0)
    }

    # Find the best edges from each component
    for parent, (children, from_component) in enumerate(zip(graph, point_components)):
        if len(children) == 0:
            continue
        neighbor = next(iter(children.keys()))
        distance = np.float32(children[neighbor][0])
        if from_component in component_edges:
            if distance < component_edges[from_component][2]:
                component_edges[from_component] = (parent, neighbor, distance)
        else:
            component_edges[from_component] = (parent, neighbor, distance)

    return component_edges


@numba.njit()  # enabling parallel breaks this function
def update_graph_components(graph, point_components):
    # deleting from dictionary during iteration breaks in numba.
    for point in numba.prange(len(graph)):
        graph[point] = {
            child: (weight, distance)
            for child, (weight, distance) in graph[point].items()
            if point_components[child] != point_components[point]
        }


@numba.njit()
def minimum_spanning_tree(graph, overwrite=False):
    """
    Implements Boruvka on lod-style graph with multiple connected components.
    """
    if not overwrite:
        graph = [children for children in graph]

    disjoint_set = ds_rank_create(len(graph))
    point_components = np.arange(len(graph))
    n_components = len(point_components)

    edges_list = [np.empty((0, 3), dtype=np.float64) for _ in range(0)]
    while n_components > 1:
        new_edges = merge_components(
            disjoint_set,
            select_components(graph, point_components),
        )
        if new_edges.shape[0] == 0:
            break

        edges_list.append(new_edges)
        update_point_components(disjoint_set, point_components)
        update_graph_components(graph, point_components)
        n_components -= new_edges.shape[0]

    counter = 0
    num_edges = sum([edges.shape[0] for edges in edges_list])
    result = np.empty((num_edges, 3), dtype=np.float64)
    for edges in edges_list:
        result[counter : counter + edges.shape[0]] = edges
        counter += edges.shape[0]
    return n_components, point_components, result


@numba.njit()
def core_graph_spanning_tree(neighbors, core_distances, min_spanning_tree, lens):
    graph = sort_by_lens(
        knn_mst_union(neighbors, core_distances, min_spanning_tree, lens)
    )
    return (*minimum_spanning_tree(graph), flatten_to_csr(graph))


def core_graph_clusters(
    lens,
    neighbors,
    core_distances,
    min_spanning_tree,
    **kwargs,
):
    num_components, component_labels, lensed_mst, graph = core_graph_spanning_tree(
        neighbors, core_distances, min_spanning_tree, lens
    )
    if num_components > 1:
        for i, label in enumerate(np.unique(component_labels)):
            component_labels[component_labels == label] = i
        return (
            component_labels,
            np.ones(len(component_labels), dtype=np.float32),
            np.empty((0, 4)),
            empty_condensed_tree(),
            lensed_mst,
            graph,
        )

    return (
        *clusters_from_spanning_tree(lensed_mst, **kwargs),
        graph,
    )


def core_graph_to_rec_array(graph):
    result = np.empty(
        graph.indptr[-1],
        dtype=[
            ("parent", np.int32),
            ("child", np.int32),
            ("weight", np.float32),
            ("distance", np.float32),
        ],
    )
    result["parent"] = np.repeat(
        np.arange(len(graph.indptr) - 1), np.diff(graph.indptr)
    )
    result["child"] = graph.indices
    result["weight"] = graph.weights
    result["distance"] = graph.distances
    return result


def core_graph_to_edge_list(graph):
    result = np.empty((graph.indptr[-1], 4), dtype=np.float64)
    result[:, 0] = np.repeat(np.arange(len(graph.indptr) - 1), np.diff(graph.indptr))
    result[:, 1] = graph.indices
    result[:, 2] = graph.weights
    result[:, 3] = graph.distances
    return result
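The new module stores the knn/MST union as one dictionary of child -> (lens weight, distance) per point, runs a Borůvka pass (minimum_spanning_tree) over those dictionaries to extract a lens-weighted spanning tree, and flattens the graph into the CSR-style CoreGraph namedtuple for the caller. The sketch below, with an invented toy graph, shows how that CSR layout maps back to a flat edge list, mirroring what core_graph_to_edge_list does:

import numpy as np
from collections import namedtuple

# Same field layout as the module's CoreGraph namedtuple.
CoreGraph = namedtuple("CoreGraph", ["weights", "distances", "indices", "indptr"])

# Invented 3-point graph: point 0 has children {1, 2}, point 1 has {2}, point 2 has none.
graph = CoreGraph(
    weights=np.array([0.5, 1.25, 0.75], dtype=np.float32),
    distances=np.array([0.25, 0.5, 0.375], dtype=np.float32),
    indices=np.array([1, 2, 2], dtype=np.int32),
    indptr=np.array([0, 2, 3, 3], dtype=np.int32),
)

# indptr[i]:indptr[i + 1] slices the children of point i, so repeating each parent
# index by its row length recovers a flat (parent, child, weight, distance) table,
# the same expansion core_graph_to_edge_list performs.
parents = np.repeat(np.arange(len(graph.indptr) - 1), np.diff(graph.indptr))
edge_list = np.column_stack([parents, graph.indices, graph.weights, graph.distances])
# rows: (0, 1, 0.5, 0.25), (0, 2, 1.25, 0.5), (1, 2, 0.75, 0.375)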
