@@ -43,26 +43,6 @@ def knn_mst_union(neighbors, core_distances, min_spanning_tree, lens_values):
4343 return graph
4444
4545
46- @numba .njit (parallel = True )
47- def sort_by_lens (graph ):
48- for point in numba .prange (len (graph )):
49- graph [point ] = {
50- k : v for k , v in sorted (graph [point ].items (), key = lambda item : item [1 ][0 ])
51- }
52- return graph
53-
54-
55- @numba .njit (parallel = True )
56- def apply_lens (core_graph , lens_values ):
57- # Apply new lens to the graph
58- for point in numba .prange (len (lens_values )):
59- children = core_graph [point ]
60- point_lens = lens_values [point ]
61- for child , value in children .items ():
62- children [child ] = (max (point_lens , lens_values [child ]), value [1 ])
63- return sort_by_lens (core_graph )
64-
65-
6646@numba .njit ()
6747def flatten_to_csr (graph ):
6848 # Count children to form indptr
@@ -88,18 +68,45 @@ def flatten_to_csr(graph):
8868 return CoreGraph (weights , distances , indices , indptr )
8969
9070
71+ @numba .njit (parallel = True )
72+ def sort_by_lens (graph ):
73+ for point in numba .prange (len (graph )):
74+ start = graph .indptr [point ]
75+ end = graph .indptr [point + 1 ]
76+ weights = graph .weights [start :end ]
77+ order = np .argsort (weights )
78+ graph .weights [start :end ] = weights [order ]
79+ graph .distances [start :end ] = graph .distances [start :end ][order ]
80+ graph .indices [start :end ] = graph .indices [start :end ][order ]
81+ return graph
82+
83+
84+ @numba .njit (parallel = True )
85+ def apply_lens (core_graph , lens_values ):
86+ # Apply new lens to the graph
87+ for point in numba .prange (len (lens_values )):
88+ point_lens = lens_values [point ]
89+ start = core_graph .indptr [point ]
90+ end = core_graph .indptr [point + 1 ]
91+ for idx , child in enumerate (core_graph .indices [start :end ]):
92+ core_graph .weights [start + idx ] = max (point_lens , lens_values [child ])
93+ return sort_by_lens (core_graph )
94+
95+
9196@numba .njit (locals = {"parent" : numba .types .int32 })
92- def select_components (graph , point_components ):
97+ def select_components (distances , indices , indptr , point_components ):
9398 component_edges = {
9499 np .int64 (0 ): (np .int32 (0 ), np .int32 (1 ), np .float32 (0.0 )) for _ in range (0 )
95100 }
96101
97102 # Find the best edges from each component
98- for parent , (children , from_component ) in enumerate (zip (graph , point_components )):
99- if len (children ) == 0 :
103+ for parent , from_component in enumerate (point_components ):
104+ start = indptr [parent ]
105+ if indices [start ] == - 1 :
100106 continue
101- neighbor = next (iter (children .keys ()))
102- distance = np .float32 (children [neighbor ][0 ])
107+
108+ neighbor = indices [start ]
109+ distance = distances [start ]
103110 if from_component in component_edges :
104111 if distance < component_edges [from_component ][2 ]:
105112 component_edges [from_component ] = (parent , neighbor , distance )
@@ -109,41 +116,52 @@ def select_components(graph, point_components):
109116 return component_edges
110117
111118
112- @numba .njit () # enabling parallel breaks this function
113- def update_graph_components (graph , point_components ):
114- # deleting from dictionary during iteration breaks in numba.
115- for point in numba .prange (len (graph )):
116- graph [point ] = {
117- child : (weight , distance )
118- for child , (weight , distance ) in graph [point ].items ()
119- if point_components [child ] != point_components [point ]
120- }
119+ @numba .njit (parallel = True )
120+ def update_graph_components (distances , indices , indptr , point_components ):
121+ for point in numba .prange (len (point_components )):
122+ counter = 0
123+ start = indptr [point ]
124+ end = indptr [point + 1 ]
125+ for idx in range (start , end ):
126+ child = indices [idx ]
127+ if child == - 1 :
128+ break
129+ if point_components [child ] != point_components [point ]:
130+ indices [start + counter ] = indices [idx ]
131+ distances [start + counter ] = distances [idx ]
132+ counter += 1
133+ indices [start + counter : end ] = - 1
134+ distances [start + counter : end ] = np .inf
121135
122136
123137@numba .njit ()
124138def minimum_spanning_tree (graph , overwrite = False ):
125139 """
126140 Implements Boruvka on lod-style graph with multiple connected components.
127141 """
142+ distances = graph .weights
143+ indices = graph .indices
144+ indptr = graph .indptr
128145 if not overwrite :
129- graph = [children for children in graph ]
146+ indices = indices .copy ()
147+ distances = distances .copy ()
130148
131- disjoint_set = ds_rank_create (len (graph ) )
132- point_components = np .arange (len (graph ) )
149+ disjoint_set = ds_rank_create (len (indptr ) - 1 )
150+ point_components = np .arange (len (indptr ) - 1 )
133151 n_components = len (point_components )
134152
135153 edges_list = [np .empty ((0 , 3 ), dtype = np .float64 ) for _ in range (0 )]
136154 while n_components > 1 :
137155 new_edges = merge_components (
138156 disjoint_set ,
139- select_components (graph , point_components ),
157+ select_components (distances , indices , indptr , point_components ),
140158 )
141159 if new_edges .shape [0 ] == 0 :
142160 break
143161
144162 edges_list .append (new_edges )
145163 update_point_components (disjoint_set , point_components )
146- update_graph_components (graph , point_components )
164+ update_graph_components (distances , indices , indptr , point_components )
147165 n_components -= new_edges .shape [0 ]
148166
149167 counter = 0
@@ -155,12 +173,14 @@ def minimum_spanning_tree(graph, overwrite=False):
155173 return n_components , point_components , result
156174
157175
158- @numba .njit ()
176+ # @numba.njit()
159177def core_graph_spanning_tree (neighbors , core_distances , min_spanning_tree , lens ):
160178 graph = sort_by_lens (
161- knn_mst_union (neighbors , core_distances , min_spanning_tree , lens )
179+ flatten_to_csr (
180+ knn_mst_union (neighbors , core_distances , min_spanning_tree , lens )
181+ )
162182 )
163- return (* minimum_spanning_tree (graph ), flatten_to_csr ( graph ) )
183+ return (* minimum_spanning_tree (graph ), graph )
164184
165185
166186def core_graph_clusters (
0 commit comments