33induced by those distances within the clusters found by HDBSCAN.
44"""
55
6- import numba
76import numpy as np
87from sklearn .utils .validation import check_is_fitted
98from sklearn .base import BaseEstimator , ClusterMixin
@@ -123,6 +122,48 @@ def update_labels(
123122 return labels , probabilities , sub_labels , sub_probabilities , lens_values
124123
125124
125+ def propagate_sub_cluster_labels (labels , sub_labels , graph_list , points_list ):
126+ running_id = 0
127+ for points , core_graph in zip (points_list , graph_list ):
128+ # Skip clusters with no labelled branches
129+ unique_sub_labels = np .unique (sub_labels [points ])
130+ if unique_sub_labels [0 ] != - 1 or len (unique_sub_labels ) == 1 :
131+ continue
132+
133+ # Create undirected core graph
134+ undirected = [
135+ {np .int32 (0 ): np .float32 (0.0 ) for _ in range (0 )} for _ in range (len (points ))
136+ ]
137+ for idx , (start , end ) in enumerate (
138+ zip (core_graph .indptr , core_graph .indptr [1 :])
139+ ):
140+ for i in range (start , end ):
141+ neighbor = core_graph .indices [i ]
142+ undirected [idx ][neighbor ] = 1 / core_graph .weights [i ]
143+ undirected [neighbor ][idx ] = 1 / core_graph .weights [i ]
144+
145+ # Repeat density-weighted majority votes on noise points until all are assigned
146+ while True :
147+ noise_idx = np .nonzero (sub_labels [points ] == - 1 )[0 ]
148+ if noise_idx .shape [0 ] == 0 :
149+ break
150+ for idx in noise_idx :
151+ candidates = {np .int32 (0 ): np .float32 (0.0 ) for _ in range (0 )}
152+ for neighbor_idx , weight in undirected [idx ].items ():
153+ label = sub_labels [points [neighbor_idx ]]
154+ if label == - 1 :
155+ continue
156+ candidates [label ] = candidates .get (label , 0.0 ) + weight
157+
158+ if len (candidates ) == 0 :
159+ continue
160+ sub_labels [points [idx ]] = max (candidates .items (), key = lambda x : x [1 ])[0 ]
161+
162+ labels [points ] = sub_labels [points ] + running_id
163+ running_id += len (unique_sub_labels ) - 1
164+ return labels , sub_labels
165+
166+
126167def remap_results (
127168 labels ,
128169 probabilities ,
@@ -192,6 +233,7 @@ def find_sub_clusters(
192233 cluster_selection_method = None ,
193234 cluster_selection_epsilon = 0.0 ,
194235 cluster_selection_persistence = 0.0 ,
236+ propagate_labels = False ,
195237):
196238 check_is_fitted (
197239 clusterer ,
@@ -323,6 +365,12 @@ def find_sub_clusters(
323365 data .shape [0 ],
324366 )
325367
368+ # Propagate labels if requested
369+ if propagate_labels :
370+ labels , sub_labels = propagate_sub_cluster_labels (
371+ labels , sub_labels , core_graphs , points
372+ )
373+
326374 # Reset for infinite data points
327375 if last_outlier > 0 :
328376 (
@@ -377,6 +425,7 @@ def __init__(
377425 cluster_selection_method = "eom" ,
378426 cluster_selection_epsilon = 0.0 ,
379427 cluster_selection_persistence = 0.0 ,
428+ propagate_labels = False ,
380429 ):
381430 self .lens_values = lens_values
382431 self .min_cluster_size = min_cluster_size
@@ -385,6 +434,7 @@ def __init__(
385434 self .cluster_selection_method = cluster_selection_method
386435 self .cluster_selection_epsilon = cluster_selection_epsilon
387436 self .cluster_selection_persistence = cluster_selection_persistence
437+ self .propagate_labels = propagate_labels
388438
389439 def fit (self , clusterer , labels = None , probabilities = None , lens_callback = None ):
390440 """labels and probabilities override the clusterer's values."""
@@ -413,6 +463,7 @@ def fit(self, clusterer, labels=None, probabilities=None, lens_callback=None):
413463 cluster_selection_method = self .cluster_selection_method ,
414464 cluster_selection_epsilon = self .cluster_selection_epsilon ,
415465 cluster_selection_persistence = self .cluster_selection_persistence ,
466+ propagate_labels = self .propagate_labels ,
416467 )
417468 # also store the core distances and raw data for the member functions
418469 self ._raw_data = clusterer ._raw_data
@@ -458,7 +509,6 @@ def _make_approximation_graph(self, lens_name=None, sub_cluster_name=None):
458509 raw_data = self ._raw_data ,
459510 )
460511
461-
462512 @property
463513 def condensed_trees_ (self ):
464514 """See :class:`~hdbscan.plots.CondensedTree` for documentation."""
0 commit comments