4
4
from .disjoint_set import ds_rank_create , ds_find , ds_union_by_rank
5
5
from .numba_kdtree import parallel_tree_query , rdist , point_to_node_lower_bound_rdist
6
6
7
- @numba .njit (locals = {"i" : numba .types .int64 })
8
- def merge_components (disjoint_set , candidate_neighbors , candidate_neighbor_distances , point_components ):
9
- component_edges = {np .int64 (0 ): (np .int64 (0 ), np .int64 (1 ), np .float32 (0.0 )) for i in range (0 )}
7
+
8
+ @numba .njit (locals = {"parent" : numba .types .int32 })
9
+ def select_components (candidate_distances , candidate_neighbors , point_components ):
10
+ component_edges = {np .int64 (0 ): (np .int32 (0 ), np .int32 (1 ), np .float32 (0.0 )) for i in range (0 )}
10
11
11
12
# Find the best edges from each component
12
- for i in range (candidate_neighbors .shape [0 ]):
13
- from_component = np .int64 (point_components [i ])
13
+ for parent , (distance , neighbor , from_component ) in enumerate (
14
+ zip (candidate_distances , candidate_neighbors , point_components )
15
+ ):
14
16
if from_component in component_edges :
15
- if candidate_neighbor_distances [ i ] < component_edges [from_component ][2 ]:
16
- component_edges [from_component ] = (np . int64 ( i ), np . int64 ( candidate_neighbors [ i ]), candidate_neighbor_distances [ i ] )
17
+ if distance < component_edges [from_component ][2 ]:
18
+ component_edges [from_component ] = (parent , neighbor , distance )
17
19
else :
18
- component_edges [from_component ] = (np .int64 (i ), np .int64 (candidate_neighbors [i ]), candidate_neighbor_distances [i ])
20
+ component_edges [from_component ] = (parent , neighbor , distance )
21
+
22
+ return component_edges
23
+
19
24
25
+ @numba .njit ()
26
+ def merge_components (disjoint_set , component_edges ):
20
27
result = np .empty ((len (component_edges ), 3 ), dtype = np .float64 )
21
28
result_idx = 0
22
29
23
30
# Add the best edges to the edge set and merge the relevant components
24
31
for edge in component_edges .values ():
25
- from_component = ds_find (disjoint_set , np . int32 ( edge [0 ]) )
26
- to_component = ds_find (disjoint_set , np . int32 ( edge [1 ]) )
32
+ from_component = ds_find (disjoint_set , edge [0 ])
33
+ to_component = ds_find (disjoint_set , edge [1 ])
27
34
if from_component != to_component :
28
35
result [result_idx ] = (np .float64 (edge [0 ]), np .float64 (edge [1 ]), np .float64 (edge [2 ]))
29
36
result_idx += 1
@@ -34,10 +41,13 @@ def merge_components(disjoint_set, candidate_neighbors, candidate_neighbor_dista
34
41
35
42
36
43
@numba .njit (parallel = True )
37
- def update_component_vectors ( tree , disjoint_set , node_components , point_components ):
44
+ def update_point_components ( disjoint_set , point_components ):
38
45
for i in numba .prange (point_components .shape [0 ]):
39
46
point_components [i ] = ds_find (disjoint_set , np .int32 (i ))
40
47
48
+
49
+ @numba .njit ()
50
+ def update_node_components (tree , node_components , point_components ):
41
51
for i in range (tree .node_data .shape [0 ] - 1 , - 1 , - 1 ):
42
52
node_info = tree .node_data [i ]
43
53
@@ -272,28 +282,28 @@ def parallel_boruvka(tree, min_samples=10, sample_weights=None):
272
282
expected_neighbors = min_samples / mean_sample_weight
273
283
distances , neighbors = parallel_tree_query (tree , tree .data , k = int (2 * expected_neighbors ))
274
284
core_distances = sample_weight_core_distance (distances , neighbors , sample_weights , min_samples )
275
- edges = initialize_boruvka_from_knn (neighbors , distances , core_distances , components_disjoint_set )
276
- update_component_vectors (tree , components_disjoint_set , node_components , point_components )
277
285
else :
278
286
if min_samples > 1 :
279
287
distances , neighbors = parallel_tree_query (tree , tree .data , k = min_samples + 1 , output_rdist = True )
280
288
core_distances = distances .T [- 1 ]
281
- edges = initialize_boruvka_from_knn (neighbors , distances , core_distances , components_disjoint_set )
282
- update_component_vectors (tree , components_disjoint_set , node_components , point_components )
283
289
else :
284
290
core_distances = np .zeros (tree .data .shape [0 ], dtype = np .float32 )
285
291
distances , neighbors = parallel_tree_query (tree , tree .data , k = 2 )
286
- edges = initialize_boruvka_from_knn (neighbors , distances , core_distances , components_disjoint_set )
287
- update_component_vectors (tree , components_disjoint_set , node_components , point_components )
288
292
289
- while n_components > 1 :
293
+ edges = [np .empty ((0 , 3 ), dtype = np .float64 ) for _ in range (0 )]
294
+ new_edges = initialize_boruvka_from_knn (neighbors , distances , core_distances , components_disjoint_set )
295
+ while True :
296
+ edges .append (new_edges )
297
+ n_components -= new_edges .shape [0 ]
298
+ if n_components == 1 :
299
+ break
300
+ update_point_components (components_disjoint_set , point_components )
301
+ update_node_components (tree , node_components , point_components )
290
302
candidate_distances , candidate_indices = boruvka_tree_query (tree , node_components , point_components ,
291
303
core_distances )
292
- new_edges = merge_components (components_disjoint_set , candidate_indices , candidate_distances , point_components )
293
- update_component_vectors (tree , components_disjoint_set , node_components , point_components )
294
-
295
- edges = np .vstack ((edges , new_edges ))
296
- n_components = np .unique (point_components ).shape [0 ]
304
+ component_edges = select_components (candidate_distances , candidate_indices , point_components )
305
+ new_edges = merge_components (components_disjoint_set , component_edges )
297
306
307
+ edges = np .vstack (edges )
298
308
edges [:, 2 ] = np .sqrt (edges .T [2 ])
299
- return edges
309
+ return edges , neighbors [:, 1 :], np . sqrt ( core_distances )
0 commit comments