44from .disjoint_set import ds_rank_create , ds_find , ds_union_by_rank
55from .numba_kdtree import parallel_tree_query , rdist , point_to_node_lower_bound_rdist
66
7- @numba .njit (locals = {"i" : numba .types .int64 })
8- def merge_components (disjoint_set , candidate_neighbors , candidate_neighbor_distances , point_components ):
9- component_edges = {np .int64 (0 ): (np .int64 (0 ), np .int64 (1 ), np .float32 (0.0 )) for i in range (0 )}
7+
8+ @numba .njit (locals = {"parent" : numba .types .int32 })
9+ def select_components (candidate_distances , candidate_neighbors , point_components ):
10+ component_edges = {np .int64 (0 ): (np .int32 (0 ), np .int32 (1 ), np .float32 (0.0 )) for i in range (0 )}
1011
1112 # Find the best edges from each component
12- for i in range (candidate_neighbors .shape [0 ]):
13- from_component = np .int64 (point_components [i ])
13+ for parent , (distance , neighbor , from_component ) in enumerate (
14+ zip (candidate_distances , candidate_neighbors , point_components )
15+ ):
1416 if from_component in component_edges :
15- if candidate_neighbor_distances [ i ] < component_edges [from_component ][2 ]:
16- component_edges [from_component ] = (np . int64 ( i ), np . int64 ( candidate_neighbors [ i ]), candidate_neighbor_distances [ i ] )
17+ if distance < component_edges [from_component ][2 ]:
18+ component_edges [from_component ] = (parent , neighbor , distance )
1719 else :
18- component_edges [from_component ] = (np .int64 (i ), np .int64 (candidate_neighbors [i ]), candidate_neighbor_distances [i ])
20+ component_edges [from_component ] = (parent , neighbor , distance )
21+
22+ return component_edges
23+
1924
25+ @numba .njit ()
26+ def merge_components (disjoint_set , component_edges ):
2027 result = np .empty ((len (component_edges ), 3 ), dtype = np .float64 )
2128 result_idx = 0
2229
2330 # Add the best edges to the edge set and merge the relevant components
2431 for edge in component_edges .values ():
25- from_component = ds_find (disjoint_set , np . int32 ( edge [0 ]) )
26- to_component = ds_find (disjoint_set , np . int32 ( edge [1 ]) )
32+ from_component = ds_find (disjoint_set , edge [0 ])
33+ to_component = ds_find (disjoint_set , edge [1 ])
2734 if from_component != to_component :
2835 result [result_idx ] = (np .float64 (edge [0 ]), np .float64 (edge [1 ]), np .float64 (edge [2 ]))
2936 result_idx += 1
@@ -34,10 +41,13 @@ def merge_components(disjoint_set, candidate_neighbors, candidate_neighbor_dista
3441
3542
3643@numba .njit (parallel = True )
37- def update_component_vectors ( tree , disjoint_set , node_components , point_components ):
44+ def update_point_components ( disjoint_set , point_components ):
3845 for i in numba .prange (point_components .shape [0 ]):
3946 point_components [i ] = ds_find (disjoint_set , np .int32 (i ))
4047
48+
49+ @numba .njit ()
50+ def update_node_components (tree , node_components , point_components ):
4151 for i in range (tree .node_data .shape [0 ] - 1 , - 1 , - 1 ):
4252 node_info = tree .node_data [i ]
4353
@@ -272,28 +282,28 @@ def parallel_boruvka(tree, min_samples=10, sample_weights=None):
272282 expected_neighbors = min_samples / mean_sample_weight
273283 distances , neighbors = parallel_tree_query (tree , tree .data , k = int (2 * expected_neighbors ))
274284 core_distances = sample_weight_core_distance (distances , neighbors , sample_weights , min_samples )
275- edges = initialize_boruvka_from_knn (neighbors , distances , core_distances , components_disjoint_set )
276- update_component_vectors (tree , components_disjoint_set , node_components , point_components )
277285 else :
278286 if min_samples > 1 :
279287 distances , neighbors = parallel_tree_query (tree , tree .data , k = min_samples + 1 , output_rdist = True )
280288 core_distances = distances .T [- 1 ]
281- edges = initialize_boruvka_from_knn (neighbors , distances , core_distances , components_disjoint_set )
282- update_component_vectors (tree , components_disjoint_set , node_components , point_components )
283289 else :
284290 core_distances = np .zeros (tree .data .shape [0 ], dtype = np .float32 )
285291 distances , neighbors = parallel_tree_query (tree , tree .data , k = 2 )
286- edges = initialize_boruvka_from_knn (neighbors , distances , core_distances , components_disjoint_set )
287- update_component_vectors (tree , components_disjoint_set , node_components , point_components )
288292
289- while n_components > 1 :
293+ edges = [np .empty ((0 , 3 ), dtype = np .float64 ) for _ in range (0 )]
294+ new_edges = initialize_boruvka_from_knn (neighbors , distances , core_distances , components_disjoint_set )
295+ while True :
296+ edges .append (new_edges )
297+ n_components -= new_edges .shape [0 ]
298+ if n_components == 1 :
299+ break
300+ update_point_components (components_disjoint_set , point_components )
301+ update_node_components (tree , node_components , point_components )
290302 candidate_distances , candidate_indices = boruvka_tree_query (tree , node_components , point_components ,
291303 core_distances )
292- new_edges = merge_components (components_disjoint_set , candidate_indices , candidate_distances , point_components )
293- update_component_vectors (tree , components_disjoint_set , node_components , point_components )
294-
295- edges = np .vstack ((edges , new_edges ))
296- n_components = np .unique (point_components ).shape [0 ]
304+ component_edges = select_components (candidate_distances , candidate_indices , point_components )
305+ new_edges = merge_components (components_disjoint_set , component_edges )
297306
307+ edges = np .vstack (edges )
298308 edges [:, 2 ] = np .sqrt (edges .T [2 ])
299- return edges
309+ return edges , neighbors [:, 1 :], np . sqrt ( core_distances )
0 commit comments