@@ -28,41 +28,34 @@ def apply_branch_threshold(
28
28
labels [pts ] = running_id
29
29
probabilities [pts ] = cluster_probabilities [pts ]
30
30
running_id += 1
31
- continue
32
31
else :
33
- branch_labels [pts ] = np .where (
34
- branch_labels [pts ] < 0 , num_branches , branch_labels [pts ]
35
- )
36
- labels [pts ] = branch_labels [pts ] + running_id
32
+ labels [pts ] = branch_labels [pts ] + has_noise + running_id
37
33
running_id += num_branches + has_noise
38
34
39
35
40
36
def find_branch_sub_clusters (
41
37
clusterer ,
42
38
cluster_labels = None ,
43
39
cluster_probabilities = None ,
44
- * ,
45
- min_branch_size = None ,
46
- max_branch_size = None ,
47
- allow_single_branch = None ,
48
- branch_selection_method = None ,
49
- branch_selection_epsilon = 0.0 ,
50
- branch_selection_persistence = 0.0 ,
51
40
label_sides_as_branches = False ,
52
- propagate_labels = False ,
41
+ min_cluster_size = None ,
42
+ max_cluster_size = None ,
43
+ allow_single_cluster = None ,
44
+ cluster_selection_method = None ,
45
+ cluster_selection_epsilon = 0.0 ,
46
+ cluster_selection_persistence = 0.0 ,
53
47
):
54
48
result = find_sub_clusters (
55
49
clusterer ,
56
50
cluster_labels ,
57
51
cluster_probabilities ,
58
52
lens_callback = compute_centrality ,
59
- min_cluster_size = min_branch_size ,
60
- max_cluster_size = max_branch_size ,
61
- allow_single_cluster = allow_single_branch ,
62
- cluster_selection_method = branch_selection_method ,
63
- cluster_selection_epsilon = branch_selection_epsilon ,
64
- cluster_selection_persistence = branch_selection_persistence ,
65
- propagate_labels = propagate_labels ,
53
+ min_cluster_size = min_cluster_size ,
54
+ max_cluster_size = max_cluster_size ,
55
+ allow_single_cluster = allow_single_cluster ,
56
+ cluster_selection_method = cluster_selection_method ,
57
+ cluster_selection_epsilon = cluster_selection_epsilon ,
58
+ cluster_selection_persistence = cluster_selection_persistence ,
66
59
)
67
60
apply_branch_threshold (
68
61
result [0 ],
@@ -95,29 +88,28 @@ class BranchDetector(SubClusterDetector):
95
88
96
89
def __init__ (
97
90
self ,
98
- * ,
99
- min_branch_size = None ,
100
- max_branch_size = None ,
101
- allow_single_branch = None ,
102
- branch_selection_method = None ,
103
- branch_selection_epsilon = 0.0 ,
104
- branch_selection_persistence = 0.0 ,
105
- label_sides_as_branches = False ,
91
+ min_cluster_size = None ,
92
+ max_cluster_size = None ,
93
+ allow_single_cluster = None ,
94
+ cluster_selection_method = None ,
95
+ cluster_selection_epsilon = 0.0 ,
96
+ cluster_selection_persistence = 0.0 ,
106
97
propagate_labels = False ,
98
+ label_sides_as_branches = False ,
107
99
):
108
100
super ().__init__ (
109
- min_cluster_size = min_branch_size ,
110
- max_cluster_size = max_branch_size ,
111
- allow_single_cluster = allow_single_branch ,
112
- cluster_selection_method = branch_selection_method ,
113
- cluster_selection_epsilon = branch_selection_epsilon ,
114
- cluster_selection_persistence = branch_selection_persistence ,
101
+ min_cluster_size = min_cluster_size ,
102
+ max_cluster_size = max_cluster_size ,
103
+ allow_single_cluster = allow_single_cluster ,
104
+ cluster_selection_method = cluster_selection_method ,
105
+ cluster_selection_epsilon = cluster_selection_epsilon ,
106
+ cluster_selection_persistence = cluster_selection_persistence ,
115
107
propagate_labels = propagate_labels ,
116
108
)
117
109
self .label_sides_as_branches = label_sides_as_branches
118
110
119
- def fit (self , clusterer , labels = None , probabilities = None ):
120
- super ().fit (clusterer , labels , probabilities , compute_centrality )
111
+ def fit (self , clusterer , labels = None , probabilities = None , sample_weight = None ):
112
+ super ().fit (clusterer , labels , probabilities , sample_weight , compute_centrality )
121
113
apply_branch_threshold (
122
114
self .labels_ ,
123
115
self .sub_cluster_labels_ ,
@@ -132,6 +124,22 @@ def fit(self, clusterer, labels=None, probabilities=None):
132
124
self .centralities_ = self .lens_values_
133
125
return self
134
126
127
+ def propagated_labels (self , label_sides_as_branches = None ):
128
+ if label_sides_as_branches is None :
129
+ label_sides_as_branches = self .label_sides_as_branches
130
+
131
+ labels , branch_labels = super ().propagated_labels ()
132
+ apply_branch_threshold (
133
+ labels ,
134
+ branch_labels ,
135
+ np .zeros_like (self .probabilities_ ),
136
+ np .zeros_like (self .probabilities_ ),
137
+ self .cluster_points_ ,
138
+ self .linkage_trees_ ,
139
+ label_sides_as_branches = label_sides_as_branches ,
140
+ )
141
+ return labels , branch_labels
142
+
135
143
@property
136
144
def approximation_graph_ (self ):
137
145
"""See :class:`~hdbscan.plots.ApproximationGraph` for documentation."""
0 commit comments