Skip to content

Commit c66f5c1

Browse files
committed
fixes; consistent naming; docs;
1 parent a3a5756 commit c66f5c1

File tree

10 files changed

+821
-133
lines changed

10 files changed

+821
-133
lines changed

doc/detecting_branches.ipynb

Lines changed: 581 additions & 0 deletions
Large diffs are not rendered by default.

doc/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ User Guide
7676
basic_usage
7777
benchmarks
7878
comparable_clusterings
79+
detecting_branches
7980

8081

8182
----------

fast_hdbscan/branches.py

Lines changed: 44 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -28,41 +28,34 @@ def apply_branch_threshold(
2828
labels[pts] = running_id
2929
probabilities[pts] = cluster_probabilities[pts]
3030
running_id += 1
31-
continue
3231
else:
33-
branch_labels[pts] = np.where(
34-
branch_labels[pts] < 0, num_branches, branch_labels[pts]
35-
)
36-
labels[pts] = branch_labels[pts] + running_id
32+
labels[pts] = branch_labels[pts] + has_noise + running_id
3733
running_id += num_branches + has_noise
3834

3935

4036
def find_branch_sub_clusters(
4137
clusterer,
4238
cluster_labels=None,
4339
cluster_probabilities=None,
44-
*,
45-
min_branch_size=None,
46-
max_branch_size=None,
47-
allow_single_branch=None,
48-
branch_selection_method=None,
49-
branch_selection_epsilon=0.0,
50-
branch_selection_persistence=0.0,
5140
label_sides_as_branches=False,
52-
propagate_labels=False,
41+
min_cluster_size=None,
42+
max_cluster_size=None,
43+
allow_single_cluster=None,
44+
cluster_selection_method=None,
45+
cluster_selection_epsilon=0.0,
46+
cluster_selection_persistence=0.0,
5347
):
5448
result = find_sub_clusters(
5549
clusterer,
5650
cluster_labels,
5751
cluster_probabilities,
5852
lens_callback=compute_centrality,
59-
min_cluster_size=min_branch_size,
60-
max_cluster_size=max_branch_size,
61-
allow_single_cluster=allow_single_branch,
62-
cluster_selection_method=branch_selection_method,
63-
cluster_selection_epsilon=branch_selection_epsilon,
64-
cluster_selection_persistence=branch_selection_persistence,
65-
propagate_labels=propagate_labels,
53+
min_cluster_size=min_cluster_size,
54+
max_cluster_size=max_cluster_size,
55+
allow_single_cluster=allow_single_cluster,
56+
cluster_selection_method=cluster_selection_method,
57+
cluster_selection_epsilon=cluster_selection_epsilon,
58+
cluster_selection_persistence=cluster_selection_persistence,
6659
)
6760
apply_branch_threshold(
6861
result[0],
@@ -95,29 +88,28 @@ class BranchDetector(SubClusterDetector):
9588

9689
def __init__(
9790
self,
98-
*,
99-
min_branch_size=None,
100-
max_branch_size=None,
101-
allow_single_branch=None,
102-
branch_selection_method=None,
103-
branch_selection_epsilon=0.0,
104-
branch_selection_persistence=0.0,
105-
label_sides_as_branches=False,
91+
min_cluster_size=None,
92+
max_cluster_size=None,
93+
allow_single_cluster=None,
94+
cluster_selection_method=None,
95+
cluster_selection_epsilon=0.0,
96+
cluster_selection_persistence=0.0,
10697
propagate_labels=False,
98+
label_sides_as_branches=False,
10799
):
108100
super().__init__(
109-
min_cluster_size=min_branch_size,
110-
max_cluster_size=max_branch_size,
111-
allow_single_cluster=allow_single_branch,
112-
cluster_selection_method=branch_selection_method,
113-
cluster_selection_epsilon=branch_selection_epsilon,
114-
cluster_selection_persistence=branch_selection_persistence,
101+
min_cluster_size=min_cluster_size,
102+
max_cluster_size=max_cluster_size,
103+
allow_single_cluster=allow_single_cluster,
104+
cluster_selection_method=cluster_selection_method,
105+
cluster_selection_epsilon=cluster_selection_epsilon,
106+
cluster_selection_persistence=cluster_selection_persistence,
115107
propagate_labels=propagate_labels,
116108
)
117109
self.label_sides_as_branches = label_sides_as_branches
118110

119-
def fit(self, clusterer, labels=None, probabilities=None):
120-
super().fit(clusterer, labels, probabilities, compute_centrality)
111+
def fit(self, clusterer, labels=None, probabilities=None, sample_weight=None):
112+
super().fit(clusterer, labels, probabilities, sample_weight, compute_centrality)
121113
apply_branch_threshold(
122114
self.labels_,
123115
self.sub_cluster_labels_,
@@ -132,6 +124,22 @@ def fit(self, clusterer, labels=None, probabilities=None):
132124
self.centralities_ = self.lens_values_
133125
return self
134126

127+
def propagated_labels(self, label_sides_as_branches=None):
128+
if label_sides_as_branches is None:
129+
label_sides_as_branches = self.label_sides_as_branches
130+
131+
labels, branch_labels = super().propagated_labels()
132+
apply_branch_threshold(
133+
labels,
134+
branch_labels,
135+
np.zeros_like(self.probabilities_),
136+
np.zeros_like(self.probabilities_),
137+
self.cluster_points_,
138+
self.linkage_trees_,
139+
label_sides_as_branches=label_sides_as_branches,
140+
)
141+
return labels, branch_labels
142+
135143
@property
136144
def approximation_graph_(self):
137145
"""See :class:`~hdbscan.plots.ApproximationGraph` for documentation."""

fast_hdbscan/cluster_trees.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,21 @@ def condense_tree(hierarchy, min_cluster_size=10, max_cluster_size=np.inf, sampl
257257

258258

259259
@numba.njit()
260-
def extract_leaves(cluster_tree, n_points):
260+
def extract_leaves(condensed_tree, allow_single_cluster=True):
261+
n_nodes = condensed_tree.parent.max() + 1
262+
n_points = condensed_tree.parent.min()
263+
leaf_indicator = np.ones(n_nodes, dtype=np.bool_)
264+
leaf_indicator[:n_points] = False
265+
266+
for parent, child_size in zip(condensed_tree.parent, condensed_tree.child_size):
267+
if child_size > 1:
268+
leaf_indicator[parent] = False
269+
270+
return np.nonzero(leaf_indicator)[0]
271+
272+
273+
@numba.njit()
274+
def cluster_tree_leaves(cluster_tree, n_points):
261275
n_nodes = cluster_tree.child.max() + 1
262276
leaf_indicator = np.ones(n_nodes - n_points, dtype=np.bool_)
263277
leaf_indicator[cluster_tree.parent - n_points] = False
@@ -538,7 +552,7 @@ def simplify_hierarchy(condensed_tree, n_points, persistence_threshold):
538552
processed = {np.int64(0)}
539553
processed.clear()
540554
while cluster_tree.parent.shape[0] > 0:
541-
leaves = set(extract_leaves(cluster_tree, n_points))
555+
leaves = set(cluster_tree_leaves(cluster_tree, n_points))
542556
births = max_lambdas(condensed_tree, leaves)
543557
deaths = min_lambdas(cluster_tree, leaves)
544558

fast_hdbscan/core_graph.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -70,15 +70,22 @@ def flatten_to_csr(graph):
7070

7171
@numba.njit(parallel=True)
7272
def sort_by_lens(graph):
73-
for point in numba.prange(len(graph)):
73+
new_weights = np.empty_like(graph.weights)
74+
new_distances = np.empty_like(graph.distances)
75+
new_indices = np.empty_like(graph.indices)
76+
for point in numba.prange(len(graph.indptr) - 1):
7477
start = graph.indptr[point]
7578
end = graph.indptr[point + 1]
76-
weights = graph.weights[start:end]
77-
order = np.argsort(weights)
78-
graph.weights[start:end] = weights[order]
79-
graph.distances[start:end] = graph.distances[start:end][order]
80-
graph.indices[start:end] = graph.indices[start:end][order]
81-
return graph
79+
80+
row_weights = graph.weights[start:end]
81+
row_distances = graph.distances[start:end]
82+
row_indices = graph.indices[start:end]
83+
84+
order = np.argsort(row_weights)
85+
new_weights[start:end] = row_weights[order]
86+
new_distances[start:end] = row_distances[order]
87+
new_indices[start:end] = row_indices[order]
88+
return CoreGraph(new_weights, new_distances, new_indices, graph.indptr)
8289

8390

8491
@numba.njit(parallel=True)
@@ -173,7 +180,7 @@ def minimum_spanning_tree(graph, overwrite=False):
173180
return n_components, point_components, result
174181

175182

176-
# @numba.njit()
183+
@numba.njit()
177184
def core_graph_spanning_tree(neighbors, core_distances, min_spanning_tree, lens):
178185
graph = sort_by_lens(
179186
flatten_to_csr(

fast_hdbscan/hdbscan.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
condense_tree,
1616
simplify_hierarchy,
1717
extract_eom_clusters,
18-
extract_leaves,
18+
cluster_tree_leaves,
1919
cluster_epsilon_search,
2020
get_cluster_labelling_at_cut,
2121
get_cluster_label_vector,
@@ -266,7 +266,7 @@ def clusters_from_spanning_tree(
266266
if cluster_tree.parent.shape[0] == 0:
267267
selected_clusters = np.empty(0, dtype=np.int64)
268268
else:
269-
selected_clusters = extract_leaves(cluster_tree, n_points)
269+
selected_clusters = cluster_tree_leaves(cluster_tree, n_points)
270270
else:
271271
raise ValueError(f"Invalid cluster_selection_method {cluster_selection_method}")
272272

@@ -319,8 +319,6 @@ def fit(self, X, y=None, sample_weight=None, **fit_params):
319319

320320
if self.semi_supervised:
321321
X, y = check_X_y(X, y, accept_sparse="csr", force_all_finite=False)
322-
if sample_weight is not None:
323-
sample_weight = _check_sample_weight(sample_weight, X, dtype=np.float32)
324322
self._raw_labels = y
325323
# Replace non-finite labels with -1 labels
326324
y[~np.isfinite(y)] = -1
@@ -331,20 +329,18 @@ def fit(self, X, y=None, sample_weight=None, **fit_params):
331329
)
332330
else:
333331
X = check_array(X, accept_sparse="csr", force_all_finite=False)
334-
if sample_weight is not None:
335-
sample_weight = _check_sample_weight(sample_weight, X, dtype=np.float32)
336332
self._raw_data = X
333+
if sample_weight is not None:
334+
sample_weight = _check_sample_weight(sample_weight, X, dtype=np.float32)
337335

338336
self._all_finite = np.all(np.isfinite(X))
339337
if ~self._all_finite:
340338
# Pass only the purely finite indices into hdbscan
341339
# We will later assign all non-finite points to the background -1 cluster
342340
finite_index = np.where(np.isfinite(X).sum(axis=1) == X.shape[1])[0]
343341
clean_data = X[finite_index]
344-
clean_data_labels = y
345-
346-
if self.semi_supervised:
347-
clean_data_labels = y[finite_index]
342+
clean_data_labels = y[finite_index] if self.semi_supervised else None
343+
sample_weight = sample_weight[finite_index] if sample_weight is not None else None
348344

349345
internal_to_raw = {
350346
x: y for x, y in zip(range(len(finite_index)), finite_index)
@@ -392,10 +388,6 @@ def fit(self, X, y=None, sample_weight=None, **fit_params):
392388

393389
return self
394390

395-
def fit_predict(self, X, y=None, sample_weight=None, **fit_params):
396-
self.fit(X, y, sample_weight, **fit_params)
397-
return self.labels_
398-
399391
def dbscan_clustering(self, epsilon):
400392
check_is_fitted(
401393
self,

0 commit comments

Comments
 (0)