Skip to content

Commit 147b489

Browse files
committed
Allow sample weights in HDBSCAN
1 parent c9d5a76 commit 147b489

File tree

3 files changed

+63
-17
lines changed

3 files changed

+63
-17
lines changed

fast_hdbscan/boruvka.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -247,22 +247,44 @@ def initialize_boruvka_from_knn(knn_indices, knn_distances, core_distances, disj
247247
return result[:result_idx]
248248

249249

250-
def parallel_boruvka(tree, min_samples=10):
250+
@numba.njit(parallel=True)
def sample_weight_core_distance(distances, neighbors, sample_weights, min_samples):
    """Compute per-point core distances when points carry sample weights.

    For each point, the weights of its neighbours are accumulated in column
    order until the running total reaches ``min_samples`` (or the neighbour
    list is exhausted). The core distance is the distance to the last
    neighbour that was consumed.

    Parameters
    ----------
    distances : 2-d array
        ``distances[i, j]`` is the distance from point ``i`` to its ``j``-th
        listed neighbour.
        NOTE(review): presumably sorted by increasing distance along axis 1,
        as returned by the tree query -- confirm against the caller.
    neighbors : 2-d array of int
        ``neighbors[i, j]`` is the index of that neighbour; used to look up
        its weight in ``sample_weights``.
    sample_weights : 1-d array
        Weight of every point, indexed by point id.
    min_samples : number
        Total neighbour weight that must be accumulated.

    Returns
    -------
    core_distances : 1-d float32 array, one entry per row of ``distances``.
        If the neighbour list runs out before ``min_samples`` is reached, the
        distance to the last listed neighbour is used. If there are no
        neighbour columns at all, every entry is 0.
    """
    n_points = distances.shape[0]
    n_neighbors = neighbors.shape[1]
    core_distances = np.zeros(n_points, dtype=np.float32)

    # With no neighbour columns there is nothing to read; indexing column
    # ``j - 1`` below would be out of bounds.
    if n_neighbors == 0:
        return core_distances

    for i in numba.prange(n_points):
        total_weight = 0.0
        j = 0
        while total_weight < min_samples and j < n_neighbors:
            total_weight += sample_weights[neighbors[i, j]]
            j += 1

        # If the loop never ran (min_samples <= 0), ``j - 1`` is -1 and would
        # wrap around to the *farthest* neighbour; clamp to the first column.
        core_distances[i] = distances[i, max(j - 1, 0)]

    return core_distances
263+
264+
def parallel_boruvka(tree, min_samples=10, sample_weights=None):
251265
components_disjoint_set = ds_rank_create(tree.data.shape[0])
252266
point_components = np.arange(tree.data.shape[0])
253267
node_components = np.full(tree.node_data.shape[0], -1)
254268
n_components = point_components.shape[0]
255269

256-
if min_samples > 1:
257-
distances, neighbors = parallel_tree_query(tree, tree.data, k=min_samples + 1, output_rdist=True)
258-
core_distances = distances.T[-1]
270+
if sample_weights is not None:
271+
mean_sample_weight = np.mean(sample_weights)
272+
expected_neighbors = min_samples / mean_sample_weight
273+
distances, neighbors = parallel_tree_query(tree, tree.data, k=int(2 * expected_neighbors))
274+
core_distances = sample_weight_core_distance(distances, neighbors, sample_weights, min_samples)
259275
edges = initialize_boruvka_from_knn(neighbors, distances, core_distances, components_disjoint_set)
260276
update_component_vectors(tree, components_disjoint_set, node_components, point_components)
261277
else:
262-
core_distances = np.zeros(tree.data.shape[0], dtype=np.float32)
263-
distances, neighbors = parallel_tree_query(tree, tree.data, k=2)
264-
edges = initialize_boruvka_from_knn(neighbors, distances, core_distances, components_disjoint_set)
265-
update_component_vectors(tree, components_disjoint_set, node_components, point_components)
278+
if min_samples > 1:
279+
distances, neighbors = parallel_tree_query(tree, tree.data, k=min_samples + 1, output_rdist=True)
280+
core_distances = distances.T[-1]
281+
edges = initialize_boruvka_from_knn(neighbors, distances, core_distances, components_disjoint_set)
282+
update_component_vectors(tree, components_disjoint_set, node_components, point_components)
283+
else:
284+
core_distances = np.zeros(tree.data.shape[0], dtype=np.float32)
285+
distances, neighbors = parallel_tree_query(tree, tree.data, k=2)
286+
edges = initialize_boruvka_from_knn(neighbors, distances, core_distances, components_disjoint_set)
287+
update_component_vectors(tree, components_disjoint_set, node_components, point_components)
266288

267289
while n_components > 1:
268290
candidate_distances, candidate_indices = boruvka_tree_query(tree, node_components, point_components,

fast_hdbscan/cluster_trees.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,16 @@ def create_linkage_merge_data(base_size):
1717
return LinkageMergeData(parent, size, next_parent)
1818

1919

20+
@numba.njit()
def create_linkage_merge_data_w_sample_weights(sample_weights):
    """Build the initial linkage-merge state with weighted leaf sizes.

    The ``size`` array holds the given ``sample_weights`` for the leaves,
    followed by ``n - 1`` float32 zeros for the internal nodes, where ``n``
    is the number of weights. ``parent`` has ``2 * n - 1`` entries, all set
    to -1, and ``next_parent`` is a one-element array holding ``n``.

    Parameters
    ----------
    sample_weights : 1-d array
        Weight of each leaf point; its length fixes the number of leaves.

    Returns
    -------
    LinkageMergeData
        Constructed from ``(parent, size, next_parent)`` as described above.
    """
    n_points = sample_weights.shape[0]
    n_nodes = 2 * n_points - 1

    # One-element array holding the id handed to the first merged node.
    next_parent = np.array([n_points], dtype=np.intp)

    # Every node starts without a parent.
    parent = np.full(n_nodes, -1, dtype=np.intp)

    # Leaves start at their sample weight; internal nodes start at zero.
    internal_sizes = np.zeros(n_points - 1, dtype=np.float32)
    size = np.concatenate((sample_weights, internal_sizes))

    return LinkageMergeData(parent, size, next_parent)
28+
29+
2030
@numba.njit()
2131
def linkage_merge_find(linkage_merge, node):
2232
relabel = node
@@ -43,11 +53,14 @@ def linkage_merge_join(linkage_merge, left, right):
4353

4454

4555
@numba.njit()
46-
def mst_to_linkage_tree(sorted_mst):
56+
def mst_to_linkage_tree(sorted_mst, sample_weights=None):
4757
result = np.empty((sorted_mst.shape[0], sorted_mst.shape[1] + 1))
4858

4959
n_samples = sorted_mst.shape[0] + 1
50-
linkage_merge = create_linkage_merge_data(n_samples)
60+
if sample_weights is None:
61+
linkage_merge = create_linkage_merge_data(n_samples)
62+
else:
63+
linkage_merge = create_linkage_merge_data_w_sample_weights(sample_weights)
5164

5265
for index in range(sorted_mst.shape[0]):
5366

@@ -116,7 +129,7 @@ def eliminate_branch(branch_node, parent_node, lambda_value, parents, children,
116129

117130

118131
@numba.njit(fastmath=True)
119-
def condense_tree(hierarchy, min_cluster_size=10):
132+
def condense_tree(hierarchy, min_cluster_size=10, sample_weights=None):
120133
root = 2 * hierarchy.shape[0]
121134
num_points = hierarchy.shape[0] + 1
122135
next_label = num_points + 1
@@ -133,6 +146,9 @@ def condense_tree(hierarchy, min_cluster_size=10):
133146

134147
ignore = np.zeros(root + 1, dtype=np.bool8)
135148

149+
if sample_weights is None:
150+
sample_weights = np.ones(num_points, dtype=np.float32)
151+
136152
idx = 0
137153

138154
for node in node_list:
@@ -148,8 +164,8 @@ def condense_tree(hierarchy, min_cluster_size=10):
148164
else:
149165
lambda_value = np.inf
150166

151-
left_count = np.int64(hierarchy[left - num_points, 3]) if left >= num_points else 1
152-
right_count = np.int64(hierarchy[right - num_points, 3]) if right >= num_points else 1
167+
left_count = np.int64(hierarchy[left - num_points, 3]) if left >= num_points else sample_weights[left]
168+
# A singleton right child contributes its own weight: index by `right`, not `left`.
right_count = np.int64(hierarchy[right - num_points, 3]) if right >= num_points else sample_weights[right]
153169

154170
# The logic here is in a strange order, but it has non-trivial performance gains ...
155171
# The most common case by far is a singleton on the left; and cluster on the right take care of this separately

fast_hdbscan/hdbscan.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from sklearn.base import BaseEstimator, ClusterMixin
44
from sklearn.utils import check_array
5+
from sklearn.utils.validation import check_is_fitted, _check_sample_weight
56
from sklearn.neighbors import KDTree
67

78
from warnings import warn
@@ -135,6 +136,7 @@ def fast_hdbscan(
135136
cluster_selection_method="eom",
136137
allow_single_cluster=False,
137138
cluster_selection_epsilon=0.0,
139+
sample_weights=None,
138140
return_trees=False,
139141
):
140142
data = check_array(data)
@@ -156,10 +158,10 @@ def fast_hdbscan(
156158
sklearn_tree = KDTree(data)
157159
numba_tree = kdtree_to_numba(sklearn_tree)
158160
edges = parallel_boruvka(
159-
numba_tree, min_samples=min_cluster_size if min_samples is None else min_samples
161+
numba_tree, min_samples=min_cluster_size if min_samples is None else min_samples, sample_weights=sample_weights
160162
)
161163
sorted_mst = edges[np.argsort(edges.T[2])]
162-
linkage_tree = mst_to_linkage_tree(sorted_mst)
164+
linkage_tree = mst_to_linkage_tree(sorted_mst, sample_weights=sample_weights)
163165
# Forward the weights so singleton leaves are counted by weight (condense_tree falls back to 1 per point when None).
condensed_tree = condense_tree(linkage_tree, min_cluster_size=min_cluster_size, sample_weights=sample_weights)
164166
if cluster_selection_epsilon > 0.0 or cluster_selection_method == "eom":
165167
cluster_tree = cluster_tree_from_condensed_tree(condensed_tree)
@@ -208,8 +210,10 @@ def __init__(
208210
self.allow_single_cluster = allow_single_cluster
209211
self.cluster_selection_epsilon = cluster_selection_epsilon
210212

211-
def fit(self, X, y=None, **fit_params):
213+
def fit(self, X, y=None, sample_weight=None, **fit_params):
212214
X = check_array(X, accept_sparse="csr", force_all_finite=False)
215+
if sample_weight is not None:
216+
sample_weight = _check_sample_weight(sample_weight, X, dtype=np.float32)
213217
self._raw_data = X
214218

215219
self._all_finite = np.all(np.isfinite(X))
@@ -233,7 +237,7 @@ def fit(self, X, y=None, **fit_params):
233237
self._single_linkage_tree,
234238
self._condensed_tree,
235239
self._min_spanning_tree,
236-
) = fast_hdbscan(clean_data, return_trees=True, **kwargs)
240+
) = fast_hdbscan(clean_data, return_trees=True, sample_weights=sample_weight, **kwargs)
237241

238242
self._condensed_tree = to_numpy_rec_array(self._condensed_tree)
239243

@@ -256,6 +260,7 @@ def fit(self, X, y=None, **fit_params):
256260
return self
257261

258262
def dbscan_clustering(self, epsilon):
263+
check_is_fitted(self, "_single_linkage_tree", msg="You first need to fit the HDBSCAN model before picking a DBSCAN clustering")
259264
return get_cluster_labelling_at_cut(
260265
self._single_linkage_tree,
261266
epsilon,
@@ -264,6 +269,7 @@ def dbscan_clustering(self, epsilon):
264269

265270
@property
266271
def condensed_tree_(self):
272+
check_is_fitted(self, "_condensed_tree", msg="You first need to fit the HDBSCAN model before accessing the condensed tree")
267273
if self._condensed_tree is not None:
268274
return CondensedTree(
269275
self._condensed_tree,
@@ -277,6 +283,7 @@ def condensed_tree_(self):
277283

278284
@property
279285
def single_linkage_tree_(self):
286+
check_is_fitted(self, "_single_linkage_tree", msg="You first need to fit the HDBSCAN model before accessing the single linkage tree")
280287
if self._single_linkage_tree is not None:
281288
return SingleLinkageTree(self._single_linkage_tree)
282289
else:
@@ -286,6 +293,7 @@ def single_linkage_tree_(self):
286293

287294
@property
288295
def minimum_spanning_tree_(self):
296+
check_is_fitted(self, "_min_spanning_tree", msg="You first need to fit the HDBSCAN model before accessing the minimum spanning tree")
289297
if self._min_spanning_tree is not None:
290298
if self._raw_data is not None:
291299
return MinimumSpanningTree(self._min_spanning_tree, self._raw_data)

0 commit comments

Comments
 (0)