@@ -28,41 +28,34 @@ def apply_branch_threshold(
2828 labels [pts ] = running_id
2929 probabilities [pts ] = cluster_probabilities [pts ]
3030 running_id += 1
31- continue
3231 else :
33- branch_labels [pts ] = np .where (
34- branch_labels [pts ] < 0 , num_branches , branch_labels [pts ]
35- )
36- labels [pts ] = branch_labels [pts ] + running_id
32+ labels [pts ] = branch_labels [pts ] + has_noise + running_id
3733 running_id += num_branches + has_noise
3834
3935
4036def find_branch_sub_clusters (
4137 clusterer ,
4238 cluster_labels = None ,
4339 cluster_probabilities = None ,
44- * ,
45- min_branch_size = None ,
46- max_branch_size = None ,
47- allow_single_branch = None ,
48- branch_selection_method = None ,
49- branch_selection_epsilon = 0.0 ,
50- branch_selection_persistence = 0.0 ,
5140 label_sides_as_branches = False ,
52- propagate_labels = False ,
41+ min_cluster_size = None ,
42+ max_cluster_size = None ,
43+ allow_single_cluster = None ,
44+ cluster_selection_method = None ,
45+ cluster_selection_epsilon = 0.0 ,
46+ cluster_selection_persistence = 0.0 ,
5347):
5448 result = find_sub_clusters (
5549 clusterer ,
5650 cluster_labels ,
5751 cluster_probabilities ,
5852 lens_callback = compute_centrality ,
59- min_cluster_size = min_branch_size ,
60- max_cluster_size = max_branch_size ,
61- allow_single_cluster = allow_single_branch ,
62- cluster_selection_method = branch_selection_method ,
63- cluster_selection_epsilon = branch_selection_epsilon ,
64- cluster_selection_persistence = branch_selection_persistence ,
65- propagate_labels = propagate_labels ,
53+ min_cluster_size = min_cluster_size ,
54+ max_cluster_size = max_cluster_size ,
55+ allow_single_cluster = allow_single_cluster ,
56+ cluster_selection_method = cluster_selection_method ,
57+ cluster_selection_epsilon = cluster_selection_epsilon ,
58+ cluster_selection_persistence = cluster_selection_persistence ,
6659 )
6760 apply_branch_threshold (
6861 result [0 ],
@@ -95,29 +88,28 @@ class BranchDetector(SubClusterDetector):
9588
9689 def __init__ (
9790 self ,
98- * ,
99- min_branch_size = None ,
100- max_branch_size = None ,
101- allow_single_branch = None ,
102- branch_selection_method = None ,
103- branch_selection_epsilon = 0.0 ,
104- branch_selection_persistence = 0.0 ,
105- label_sides_as_branches = False ,
91+ min_cluster_size = None ,
92+ max_cluster_size = None ,
93+ allow_single_cluster = None ,
94+ cluster_selection_method = None ,
95+ cluster_selection_epsilon = 0.0 ,
96+ cluster_selection_persistence = 0.0 ,
10697 propagate_labels = False ,
98+ label_sides_as_branches = False ,
10799 ):
108100 super ().__init__ (
109- min_cluster_size = min_branch_size ,
110- max_cluster_size = max_branch_size ,
111- allow_single_cluster = allow_single_branch ,
112- cluster_selection_method = branch_selection_method ,
113- cluster_selection_epsilon = branch_selection_epsilon ,
114- cluster_selection_persistence = branch_selection_persistence ,
101+ min_cluster_size = min_cluster_size ,
102+ max_cluster_size = max_cluster_size ,
103+ allow_single_cluster = allow_single_cluster ,
104+ cluster_selection_method = cluster_selection_method ,
105+ cluster_selection_epsilon = cluster_selection_epsilon ,
106+ cluster_selection_persistence = cluster_selection_persistence ,
115107 propagate_labels = propagate_labels ,
116108 )
117109 self .label_sides_as_branches = label_sides_as_branches
118110
119- def fit (self , clusterer , labels = None , probabilities = None ):
120- super ().fit (clusterer , labels , probabilities , compute_centrality )
111+ def fit (self , clusterer , labels = None , probabilities = None , sample_weight = None ):
112+ super ().fit (clusterer , labels , probabilities , sample_weight , compute_centrality )
121113 apply_branch_threshold (
122114 self .labels_ ,
123115 self .sub_cluster_labels_ ,
@@ -132,6 +124,22 @@ def fit(self, clusterer, labels=None, probabilities=None):
132124 self .centralities_ = self .lens_values_
133125 return self
134126
127+ def propagated_labels (self , label_sides_as_branches = None ):
128+ if label_sides_as_branches is None :
129+ label_sides_as_branches = self .label_sides_as_branches
130+
131+ labels , branch_labels = super ().propagated_labels ()
132+ apply_branch_threshold (
133+ labels ,
134+ branch_labels ,
135+ np .zeros_like (self .probabilities_ ),
136+ np .zeros_like (self .probabilities_ ),
137+ self .cluster_points_ ,
138+ self .linkage_trees_ ,
139+ label_sides_as_branches = label_sides_as_branches ,
140+ )
141+ return labels , branch_labels
142+
135143 @property
136144 def approximation_graph_ (self ):
137145 """See :class:`~hdbscan.plots.ApproximationGraph` for documentation."""
0 commit comments