Skip to content

Commit ff0b3a5

Browse files
committed
fix parallel dictionary craziness
1 parent 70dff00 commit ff0b3a5

File tree

1 file changed

+62
-42
lines changed

1 file changed

+62
-42
lines changed

fast_hdbscan/core_graph.py

Lines changed: 62 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -43,26 +43,6 @@ def knn_mst_union(neighbors, core_distances, min_spanning_tree, lens_values):
4343
return graph
4444

4545

46-
@numba.njit(parallel=True)
47-
def sort_by_lens(graph):
48-
for point in numba.prange(len(graph)):
49-
graph[point] = {
50-
k: v for k, v in sorted(graph[point].items(), key=lambda item: item[1][0])
51-
}
52-
return graph
53-
54-
55-
@numba.njit(parallel=True)
56-
def apply_lens(core_graph, lens_values):
57-
# Apply new lens to the graph
58-
for point in numba.prange(len(lens_values)):
59-
children = core_graph[point]
60-
point_lens = lens_values[point]
61-
for child, value in children.items():
62-
children[child] = (max(point_lens, lens_values[child]), value[1])
63-
return sort_by_lens(core_graph)
64-
65-
6646
@numba.njit()
6747
def flatten_to_csr(graph):
6848
# Count children to form indptr
@@ -88,18 +68,45 @@ def flatten_to_csr(graph):
8868
return CoreGraph(weights, distances, indices, indptr)
8969

9070

71+
@numba.njit(parallel=True)
def sort_by_lens(graph):
    """
    Sort each point's edges in a CSR-style graph by ascending weight.

    Parameters
    ----------
    graph
        Record exposing ``weights``, ``distances``, ``indices`` and ``indptr``
        arrays. The edges of point ``p`` occupy ``indptr[p]:indptr[p + 1]`` in
        the other three arrays.

    Returns
    -------
    The same ``graph`` object; its three edge arrays are reordered in place,
    row by row, and ``indptr`` is left untouched.
    """
    # Take the row count from indptr rather than len(graph): for a tuple-like
    # record of four arrays, len(graph) is the number of fields, not the number
    # of points, which would leave most rows unsorted.
    # NOTE(review): the graph type is defined outside this view -- confirm.
    n_points = len(graph.indptr) - 1
    for point in numba.prange(n_points):
        start = graph.indptr[point]
        end = graph.indptr[point + 1]
        row_weights = graph.weights[start:end]
        order = np.argsort(row_weights)
        # Indexing with `order` builds a reordered copy of each row before it
        # is assigned back over the original slice.
        graph.weights[start:end] = row_weights[order]
        graph.distances[start:end] = graph.distances[start:end][order]
        graph.indices[start:end] = graph.indices[start:end][order]
    return graph
82+
83+
84+
@numba.njit(parallel=True)
def apply_lens(core_graph, lens_values):
    """
    Recompute every edge weight from a new set of per-point lens values.

    The weight of an edge becomes the larger of the lens value of the point
    that owns the row and the lens value of the child the edge points to.
    Weights are overwritten in place, after which the graph is passed through
    ``sort_by_lens`` and that result is returned.

    Parameters
    ----------
    core_graph
        Record exposing ``weights``, ``indices`` and ``indptr`` arrays.
    lens_values
        One lens value per point; its length sets how many rows are visited.
    """
    indptr = core_graph.indptr
    indices = core_graph.indices
    weights = core_graph.weights
    for parent in numba.prange(len(lens_values)):
        parent_lens = lens_values[parent]
        # Walk the parent's row directly by edge position.
        for edge in range(indptr[parent], indptr[parent + 1]):
            weights[edge] = max(parent_lens, lens_values[indices[edge]])
    return sort_by_lens(core_graph)
94+
95+
9196
@numba.njit(locals={"parent": numba.types.int32})
92-
def select_components(graph, point_components):
97+
def select_components(distances, indices, indptr, point_components):
9398
component_edges = {
9499
np.int64(0): (np.int32(0), np.int32(1), np.float32(0.0)) for _ in range(0)
95100
}
96101

97102
# Find the best edges from each component
98-
for parent, (children, from_component) in enumerate(zip(graph, point_components)):
99-
if len(children) == 0:
103+
for parent, from_component in enumerate(point_components):
104+
start = indptr[parent]
105+
if indices[start] == -1:
100106
continue
101-
neighbor = next(iter(children.keys()))
102-
distance = np.float32(children[neighbor][0])
107+
108+
neighbor = indices[start]
109+
distance = distances[start]
103110
if from_component in component_edges:
104111
if distance < component_edges[from_component][2]:
105112
component_edges[from_component] = (parent, neighbor, distance)
@@ -109,41 +116,52 @@ def select_components(graph, point_components):
109116
return component_edges
110117

111118

112-
@numba.njit() # enabling parallel breaks this function
113-
def update_graph_components(graph, point_components):
114-
# deleting from dictionary during iteration breaks in numba.
115-
for point in numba.prange(len(graph)):
116-
graph[point] = {
117-
child: (weight, distance)
118-
for child, (weight, distance) in graph[point].items()
119-
if point_components[child] != point_components[point]
120-
}
119+
@numba.njit(parallel=True)
def update_graph_components(distances, indices, indptr, point_components):
    """
    Drop, in place, every edge whose two endpoints share a component.

    For each point, the edges that lead to a different component are moved to
    the front of that point's row, keeping their relative order. The rest of
    the row is filled with ``-1`` in ``indices`` and ``np.inf`` in
    ``distances``. Scanning a row stops at the first ``-1`` index.

    Parameters
    ----------
    distances, indices
        Per-edge arrays; both are modified in place.
    indptr
        Row boundaries: point ``p`` owns ``indptr[p]:indptr[p + 1]``.
    point_components
        Component label of every point; its length sets the number of rows.
    """
    for point in numba.prange(len(point_components)):
        row_start = indptr[point]
        row_end = indptr[point + 1]
        own_component = point_components[point]
        kept = 0
        for edge in range(row_start, row_end):
            child = indices[edge]
            if child == -1:
                # Everything from here on was already cleared earlier.
                break
            if point_components[child] == own_component:
                continue
            target = row_start + kept
            indices[target] = child
            distances[target] = distances[edge]
            kept += 1
        # Mark the unused tail of the row as empty.
        tail = row_start + kept
        indices[tail:row_end] = -1
        distances[tail:row_end] = np.inf
121135

122136

123137
@numba.njit()
124138
def minimum_spanning_tree(graph, overwrite=False):
125139
"""
126140
Implements Boruvka on lod-style graph with multiple connected components.
127141
"""
142+
distances = graph.weights
143+
indices = graph.indices
144+
indptr = graph.indptr
128145
if not overwrite:
129-
graph = [children for children in graph]
146+
indices = indices.copy()
147+
distances = distances.copy()
130148

131-
disjoint_set = ds_rank_create(len(graph))
132-
point_components = np.arange(len(graph))
149+
disjoint_set = ds_rank_create(len(indptr) - 1)
150+
point_components = np.arange(len(indptr) - 1)
133151
n_components = len(point_components)
134152

135153
edges_list = [np.empty((0, 3), dtype=np.float64) for _ in range(0)]
136154
while n_components > 1:
137155
new_edges = merge_components(
138156
disjoint_set,
139-
select_components(graph, point_components),
157+
select_components(distances, indices, indptr, point_components),
140158
)
141159
if new_edges.shape[0] == 0:
142160
break
143161

144162
edges_list.append(new_edges)
145163
update_point_components(disjoint_set, point_components)
146-
update_graph_components(graph, point_components)
164+
update_graph_components(distances, indices, indptr, point_components)
147165
n_components -= new_edges.shape[0]
148166

149167
counter = 0
@@ -155,12 +173,14 @@ def minimum_spanning_tree(graph, overwrite=False):
155173
return n_components, point_components, result
156174

157175

158-
@numba.njit()
176+
# @numba.njit()
159177
def core_graph_spanning_tree(neighbors, core_distances, min_spanning_tree, lens):
160178
graph = sort_by_lens(
161-
knn_mst_union(neighbors, core_distances, min_spanning_tree, lens)
179+
flatten_to_csr(
180+
knn_mst_union(neighbors, core_distances, min_spanning_tree, lens)
181+
)
162182
)
163-
return (*minimum_spanning_tree(graph), flatten_to_csr(graph))
183+
return (*minimum_spanning_tree(graph), graph)
164184

165185

166186
def core_graph_clusters(

0 commit comments

Comments
 (0)