@@ -22,6 +22,16 @@ def create_linkage_merge_data(base_size):
22
22
return LinkageMergeData (parent , size , next_parent )
23
23
24
24
25
@numba.njit()
def create_linkage_merge_data_w_sample_weights(sample_weights):
    """Build a ``LinkageMergeData`` union-find structure whose leaf sizes
    come from per-sample weights instead of unit counts.

    Parameters
    ----------
    sample_weights : 1-D array
        Weight of each of the ``n_samples`` input points.

    Returns
    -------
    LinkageMergeData
        ``parent``: length ``2 * n_samples - 1`` intp array, all ``-1``
        (no node has a parent yet);
        ``size``: the sample weights followed by ``n_samples - 1`` zeros
        for the internal merge nodes still to be created;
        ``next_parent``: one-element intp array holding the label of the
        next internal node (``n_samples``, since leaves occupy
        ``0 .. n_samples - 1``).
    """
    n_samples = sample_weights.shape[0]

    # One slot per leaf plus one per internal merge node; -1 marks "no parent".
    parents = np.full(2 * n_samples - 1, -1, dtype=np.intp)

    # Leaf sizes are the weights; internal-node sizes start at zero and are
    # accumulated as merges happen.
    sizes = np.concatenate((sample_weights, np.zeros(n_samples - 1, dtype=np.float32)))

    # Next fresh internal-node label.
    next_internal = np.array([n_samples], dtype=np.intp)

    return LinkageMergeData(parents, sizes, next_internal)
25
35
@numba .njit ()
26
36
def linkage_merge_find (linkage_merge , node ):
27
37
relabel = node
@@ -78,6 +88,36 @@ def mst_to_linkage_tree(sorted_mst):
78
88
return result
79
89
80
90
91
@numba.njit()
def mst_to_linkage_tree_w_sample_weights(sorted_mst, sample_weights):
    """Convert a weight-sorted MST edge list into a linkage-tree matrix,
    accumulating per-sample weights (rather than unit counts) as the
    merged-component sizes.

    Parameters
    ----------
    sorted_mst : 2-D array
        Edges ``(left, right, distance)`` sorted by distance.
    sample_weights : 1-D array
        Weight of each input sample; seeds the union-find leaf sizes.

    Returns
    -------
    ndarray
        One row per edge: ``(larger component label, smaller component
        label, distance, combined weight of the two components)``.
    """
    n_edges = sorted_mst.shape[0]
    result = np.empty((n_edges, sorted_mst.shape[1] + 1))

    linkage_merge = create_linkage_merge_data_w_sample_weights(sample_weights)

    for i in range(n_edges):
        node_a = np.intp(sorted_mst[i, 0])
        node_b = np.intp(sorted_mst[i, 1])
        distance = sorted_mst[i, 2]

        root_a = linkage_merge_find(linkage_merge, node_a)
        root_b = linkage_merge_find(linkage_merge, node_b)

        # Column 0 holds the larger component label, column 1 the smaller.
        if root_a > root_b:
            result[i, 0] = root_a
            result[i, 1] = root_b
        else:
            result[i, 0] = root_b
            result[i, 1] = root_a

        result[i, 2] = distance
        # Combined weight of the two components being merged; must be read
        # before the join below updates the union-find sizes.
        result[i, 3] = linkage_merge.size[root_a] + linkage_merge.size[root_b]

        linkage_merge_join(linkage_merge, root_a, root_b)

    return result
81
121
@numba .njit ()
82
122
def bfs_from_hierarchy (hierarchy , bfs_root , num_points ):
83
123
to_process = [bfs_root ]
@@ -121,7 +161,7 @@ def eliminate_branch(branch_node, parent_node, lambda_value, parents, children,
121
161
122
162
123
163
@numba .njit (fastmath = True )
124
- def condense_tree (hierarchy , min_cluster_size = 10 ):
164
+ def condense_tree (hierarchy , min_cluster_size = 10 , sample_weights = None ):
125
165
root = 2 * hierarchy .shape [0 ]
126
166
num_points = hierarchy .shape [0 ] + 1
127
167
next_label = num_points + 1
@@ -134,10 +174,13 @@ def condense_tree(hierarchy, min_cluster_size=10):
134
174
parents = np .ones (root , dtype = np .int64 )
135
175
children = np .empty (root , dtype = np .int64 )
136
176
lambdas = np .empty (root , dtype = np .float32 )
137
- sizes = np .ones (root , dtype = np .int64 )
177
+ sizes = np .ones (root , dtype = np .float32 )
138
178
139
179
ignore = np .zeros (root + 1 , dtype = np .bool_ ) # 'bool' is no longer an attribute of 'numpy'
140
180
181
+ if sample_weights is None :
182
+ sample_weights = np .ones (num_points , dtype = np .float32 )
183
+
141
184
idx = 0
142
185
143
186
for node in node_list :
@@ -153,8 +196,8 @@ def condense_tree(hierarchy, min_cluster_size=10):
153
196
else :
154
197
lambda_value = np .inf
155
198
156
- left_count = np .int64 (hierarchy [left - num_points , 3 ]) if left >= num_points else 1
157
- right_count = np .int64 (hierarchy [right - num_points , 3 ]) if right >= num_points else 1
199
+ left_count = np .float32 (hierarchy [left - num_points , 3 ]) if left >= num_points else sample_weights [ left ]
200
+ right_count = np .float32 (hierarchy [right - num_points , 3 ]) if right >= num_points else sample_weights [ right ]
158
201
159
202
# The logic here is in a strange order, but it has non-trivial performance gains ...
160
203
# The most common case by far is a singleton on the left; and cluster on the right take care of this separately
@@ -391,7 +434,7 @@ def extract_clusters_bcubed(condensed_tree, cluster_tree, label_indices, allow_v
391
434
392
435
@numba .njit ()
393
436
def score_condensed_tree_nodes (condensed_tree ):
394
- result = {0 : 0.0 for i in range (0 )}
437
+ result = {0 : np . float32 ( 0.0 ) for i in range (0 )}
395
438
396
439
for i in range (condensed_tree .parent .shape [0 ]):
397
440
parent = condensed_tree .parent [i ]
@@ -559,13 +602,16 @@ def get_cluster_labelling_at_cut(linkage_tree, cut, min_cluster_size):
559
602
def get_cluster_label_vector (
560
603
tree ,
561
604
clusters ,
562
- cluster_selection_epsilon
605
+ cluster_selection_epsilon ,
606
+ n_samples ,
563
607
):
608
+ if len (tree .parent ) == 0 :
609
+ return np .full (n_samples , - 1 , dtype = np .intp )
564
610
root_cluster = tree .parent .min ()
565
- result = np .empty ( root_cluster , dtype = np .intp )
611
+ result = np .full ( n_samples , - 1 , dtype = np .intp )
566
612
cluster_label_map = {c : n for n , c in enumerate (np .sort (clusters ))}
567
613
568
- disjoint_set = ds_rank_create (tree .parent .max () + 1 )
614
+ disjoint_set = ds_rank_create (max ( tree .parent .max () + 1 , tree . child . max () + 1 ) )
569
615
clusters = set (clusters )
570
616
571
617
for n in range (tree .parent .shape [0 ]):
0 commit comments