Skip to content

Commit caf94c6

Browse files
authored
Merge pull request TutteInstitute#24 from TutteInstitute/sample_weights
Add support for sample weights
2 parents 163e167 + 30e1c97 commit caf94c6

File tree

4 files changed

+182
-54
lines changed

4 files changed

+182
-54
lines changed

fast_hdbscan/boruvka.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -247,22 +247,44 @@ def initialize_boruvka_from_knn(knn_indices, knn_distances, core_distances, disj
247247
return result[:result_idx]
248248

249249

250-
def parallel_boruvka(tree, min_samples=10):
250+
@numba.njit(parallel=True)
def sample_weight_core_distance(distances, neighbors, sample_weights, min_samples):
    """Compute per-point core distances when samples carry weights.

    For each point, neighbor weights are accumulated in ascending distance
    order until at least ``min_samples`` total weight has been seen; the
    distance to the neighbor that crossed that threshold becomes the core
    distance. If the queried neighbors never supply enough weight, the
    distance to the farthest queried neighbor is used instead.
    """
    n_points = distances.shape[0]
    n_neighbors = neighbors.shape[1]
    result = np.zeros(n_points, dtype=np.float32)

    for row in numba.prange(n_points):
        accumulated = 0.0
        col = 0
        # Walk outward through the neighbor list until enough weight is seen.
        while accumulated < min_samples and col < n_neighbors:
            accumulated += sample_weights[neighbors[row, col]]
            col += 1

        # col - 1 indexes the neighbor whose weight crossed the threshold
        # (or the last queried neighbor when the weight budget was never met).
        result[row] = distances[row, col - 1]

    return result
263+
264+
def parallel_boruvka(tree, min_samples=10, sample_weights=None):
251265
components_disjoint_set = ds_rank_create(tree.data.shape[0])
252266
point_components = np.arange(tree.data.shape[0])
253267
node_components = np.full(tree.node_data.shape[0], -1)
254268
n_components = point_components.shape[0]
255269

256-
if min_samples > 1:
257-
distances, neighbors = parallel_tree_query(tree, tree.data, k=min_samples + 1, output_rdist=True)
258-
core_distances = distances.T[-1]
270+
if sample_weights is not None:
271+
mean_sample_weight = np.mean(sample_weights)
272+
expected_neighbors = min_samples / mean_sample_weight
273+
distances, neighbors = parallel_tree_query(tree, tree.data, k=int(2 * expected_neighbors))
274+
core_distances = sample_weight_core_distance(distances, neighbors, sample_weights, min_samples)
259275
edges = initialize_boruvka_from_knn(neighbors, distances, core_distances, components_disjoint_set)
260276
update_component_vectors(tree, components_disjoint_set, node_components, point_components)
261277
else:
262-
core_distances = np.zeros(tree.data.shape[0], dtype=np.float32)
263-
distances, neighbors = parallel_tree_query(tree, tree.data, k=2)
264-
edges = initialize_boruvka_from_knn(neighbors, distances, core_distances, components_disjoint_set)
265-
update_component_vectors(tree, components_disjoint_set, node_components, point_components)
278+
if min_samples > 1:
279+
distances, neighbors = parallel_tree_query(tree, tree.data, k=min_samples + 1, output_rdist=True)
280+
core_distances = distances.T[-1]
281+
edges = initialize_boruvka_from_knn(neighbors, distances, core_distances, components_disjoint_set)
282+
update_component_vectors(tree, components_disjoint_set, node_components, point_components)
283+
else:
284+
core_distances = np.zeros(tree.data.shape[0], dtype=np.float32)
285+
distances, neighbors = parallel_tree_query(tree, tree.data, k=2)
286+
edges = initialize_boruvka_from_knn(neighbors, distances, core_distances, components_disjoint_set)
287+
update_component_vectors(tree, components_disjoint_set, node_components, point_components)
266288

267289
while n_components > 1:
268290
candidate_distances, candidate_indices = boruvka_tree_query(tree, node_components, point_components,

fast_hdbscan/cluster_trees.py

Lines changed: 54 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,16 @@ def create_linkage_merge_data(base_size):
2222
return LinkageMergeData(parent, size, next_parent)
2323

2424

25+
@numba.njit()
def create_linkage_merge_data_w_sample_weights(sample_weights):
    """Build a LinkageMergeData whose leaf sizes are the sample weights.

    Mirrors ``create_linkage_merge_data`` except that each leaf starts with
    its sample weight rather than a unit size, so merged cluster "sizes"
    accumulate total weight.
    """
    n_leaves = sample_weights.shape[0]
    # A binary merge tree over n_leaves leaves has n_leaves - 1 internal nodes.
    parent = np.full(2 * n_leaves - 1, -1, dtype=np.intp)
    # Leaves carry their weights; internal-node slots start at zero.
    size = np.concatenate((sample_weights, np.zeros(n_leaves - 1, dtype=np.float32)))
    # The first internal node to be allocated comes right after the leaves.
    next_parent = np.array([n_leaves], dtype=np.intp)

    return LinkageMergeData(parent, size, next_parent)
33+
34+
2535
@numba.njit()
2636
def linkage_merge_find(linkage_merge, node):
2737
relabel = node
@@ -78,6 +88,36 @@ def mst_to_linkage_tree(sorted_mst):
7888
return result
7989

8090

91+
@numba.njit()
def mst_to_linkage_tree_w_sample_weights(sorted_mst, sample_weights):
    """Convert a distance-sorted MST edge list into a linkage tree,
    accumulating weighted cluster sizes from per-sample weights.

    Each output row holds (larger component label, smaller component label,
    merge distance, combined weighted size of the two merged components).
    """
    n_edges = sorted_mst.shape[0]
    result = np.empty((n_edges, sorted_mst.shape[1] + 1))

    linkage_merge = create_linkage_merge_data_w_sample_weights(sample_weights)

    for row in range(n_edges):

        left = np.intp(sorted_mst[row, 0])
        right = np.intp(sorted_mst[row, 1])
        delta = sorted_mst[row, 2]

        left_component = linkage_merge_find(linkage_merge, left)
        right_component = linkage_merge_find(linkage_merge, right)

        # Larger component label goes in column 0, smaller in column 1.
        if left_component > right_component:
            result[row, 0] = left_component
            result[row, 1] = right_component
        else:
            result[row, 0] = right_component
            result[row, 1] = left_component

        result[row, 2] = delta
        # Weighted size of the merged cluster; must be read *before* joining,
        # since the join mutates the disjoint-set size bookkeeping.
        result[row, 3] = linkage_merge.size[left_component] + linkage_merge.size[right_component]

        linkage_merge_join(linkage_merge, left_component, right_component)

    return result
120+
81121
@numba.njit()
82122
def bfs_from_hierarchy(hierarchy, bfs_root, num_points):
83123
to_process = [bfs_root]
@@ -121,7 +161,7 @@ def eliminate_branch(branch_node, parent_node, lambda_value, parents, children,
121161

122162

123163
@numba.njit(fastmath=True)
124-
def condense_tree(hierarchy, min_cluster_size=10):
164+
def condense_tree(hierarchy, min_cluster_size=10, sample_weights=None):
125165
root = 2 * hierarchy.shape[0]
126166
num_points = hierarchy.shape[0] + 1
127167
next_label = num_points + 1
@@ -134,10 +174,13 @@ def condense_tree(hierarchy, min_cluster_size=10):
134174
parents = np.ones(root, dtype=np.int64)
135175
children = np.empty(root, dtype=np.int64)
136176
lambdas = np.empty(root, dtype=np.float32)
137-
sizes = np.ones(root, dtype=np.int64)
177+
sizes = np.ones(root, dtype=np.float32)
138178

139179
ignore = np.zeros(root + 1, dtype=np.bool_) # 'bool' is no longer an attribute of 'numpy'
140180

181+
if sample_weights is None:
182+
sample_weights = np.ones(num_points, dtype=np.float32)
183+
141184
idx = 0
142185

143186
for node in node_list:
@@ -153,8 +196,8 @@ def condense_tree(hierarchy, min_cluster_size=10):
153196
else:
154197
lambda_value = np.inf
155198

156-
left_count = np.int64(hierarchy[left - num_points, 3]) if left >= num_points else 1
157-
right_count = np.int64(hierarchy[right - num_points, 3]) if right >= num_points else 1
199+
left_count = np.float32(hierarchy[left - num_points, 3]) if left >= num_points else sample_weights[left]
200+
right_count = np.float32(hierarchy[right - num_points, 3]) if right >= num_points else sample_weights[right]
158201

159202
# The logic here is in a strange order, but it has non-trivial performance gains ...
160203
# The most common case by far is a singleton on the left; and cluster on the right take care of this separately
@@ -391,7 +434,7 @@ def extract_clusters_bcubed(condensed_tree, cluster_tree, label_indices, allow_v
391434

392435
@numba.njit()
393436
def score_condensed_tree_nodes(condensed_tree):
394-
result = {0: 0.0 for i in range(0)}
437+
result = {0: np.float32(0.0) for i in range(0)}
395438

396439
for i in range(condensed_tree.parent.shape[0]):
397440
parent = condensed_tree.parent[i]
@@ -559,13 +602,16 @@ def get_cluster_labelling_at_cut(linkage_tree, cut, min_cluster_size):
559602
def get_cluster_label_vector(
560603
tree,
561604
clusters,
562-
cluster_selection_epsilon
605+
cluster_selection_epsilon,
606+
n_samples,
563607
):
608+
if len(tree.parent) == 0:
609+
return np.full(n_samples, -1, dtype=np.intp)
564610
root_cluster = tree.parent.min()
565-
result = np.empty(root_cluster, dtype=np.intp)
611+
result = np.full(n_samples, -1, dtype=np.intp)
566612
cluster_label_map = {c: n for n, c in enumerate(np.sort(clusters))}
567613

568-
disjoint_set = ds_rank_create(tree.parent.max() + 1)
614+
disjoint_set = ds_rank_create(max(tree.parent.max() + 1, tree.child.max() + 1))
569615
clusters = set(clusters)
570616

571617
for n in range(tree.parent.shape[0]):

0 commit comments

Comments (0)