Skip to content

Commit 2cfa1d3

Browse files
committed
Fix bug in core_distances; now works correctly for clustering the sample data in the repository!
1 parent 10c3427 commit 2cfa1d3

File tree

1 file changed

+17
-7
lines changed

1 file changed

+17
-7
lines changed

hdbscan/_hdbscan_boruvka.pyx

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -64,17 +64,19 @@ cdef class BoruvkaAlgorithm (object):
6464
cdef object tree
6565
cdef object dist
6666
cdef np.ndarray _data
67+
cdef np.int64_t min_samples
68+
cdef public np.double_t[::1] core_distance
6769
cdef public np.double_t[::1] bounds
6870
cdef public np.int64_t[::1] component_of_point
6971
cdef public np.int64_t[::1] component_of_node
7072
cdef public np.int64_t[::1] candidate_neighbor
7173
cdef public np.int64_t[::1] candidate_point
7274
cdef public np.double_t[::1] candidate_distance
73-
# cdef public np.double_t[:,::1] _centroid_distances
7475
cdef object component_union_find
7576
cdef set edges
7677

7778
cdef np.ndarray components
79+
cdef np.ndarray core_distance_arr
7880
cdef np.ndarray bounds_arr
7981
cdef np.ndarray _centroid_distances
8082
cdef np.ndarray component_of_point_arr
@@ -83,13 +85,14 @@ cdef class BoruvkaAlgorithm (object):
8385
cdef np.ndarray candidate_neighbor_arr
8486
cdef np.ndarray candidate_distance_arr
8587

86-
def __init__(self, tree, metric='euclidean', **kwargs):
88+
def __init__(self, tree, min_samples=5, metric='euclidean', **kwargs):
8789

8890
cdef np.int64_t num_points = tree.data.shape[0]
8991
cdef np.int64_t num_nodes = tree.node_data.shape[0]
9092

9193
self.tree = tree
9294
self._data = np.array(tree.data)
95+
self.min_samples = min_samples
9396

9497
self.dist = dist_metrics.DistanceMetric.get_metric(metric, **kwargs)
9598

@@ -112,8 +115,6 @@ cdef class BoruvkaAlgorithm (object):
112115

113116
self._centroid_distances = self.dist.pairwise(tree.node_bounds[0])
114117

115-
# self._centroid_distances = (<np.double_t[:num_nodes, :num_nodes:1]> (<np.double_t *> self._centroid_distances.data))
116-
117118
self._compute_bounds()
118119
self._initialize_components()
119120

@@ -123,6 +124,7 @@ cdef class BoruvkaAlgorithm (object):
123124
cdef np.int64_t num_points = self.tree.data.shape[0]
124125
cdef np.int64_t num_nodes = self.tree.node_data.shape[0]
125126

127+
cdef np.ndarray[np.double_t, ndim=2] knn_dist
126128
cdef np.ndarray[np.double_t, ndim=1] nn_dist
127129
cdef np.int64_t[::1] point_indices
128130
cdef np.int64_t[::1] idx_array = self.tree.idx_array
@@ -137,14 +139,19 @@ cdef class BoruvkaAlgorithm (object):
137139
cdef NodeData_t child1_info
138140
cdef NodeData_t child2_info
139141

140-
nn_dist = self.tree.query(self.tree.data, 2)[0][:,1]
142+
knn_dist = self.tree.query(self.tree.data, max(2, self.min_samples))[0]
143+
nn_dist = knn_dist[:, 1]
144+
self.core_distance_arr = knn_dist[:, self.min_samples - 1].copy()
145+
self.core_distance = (<np.double_t [:num_points:1]> (<np.double_t *> self.core_distance_arr.data))
141146

142147
for n in range(num_nodes - 1, -1, -1):
143148
node_info = self.tree.node_data[n]
144149
if node_info.is_leaf:
145150
point_indices = idx_array[node_info.idx_start:node_info.idx_end]
146151
b1 = nn_dist[point_indices].max()
152+
# b1 = self.core_distance_arr[point_indices].max()
147153
b2 = (nn_dist[point_indices] + 2 * node_info.radius).min()
154+
# b2 = (self.core_distance_arr[point_indices] + 2 * node_info.radius).min()
148155
self.bounds[n] = min(b1, b2)
149156
else:
150157
child1 = 2 * n + 1
@@ -258,6 +265,8 @@ cdef class BoruvkaAlgorithm (object):
258265
cdef np.int64_t component1
259266
cdef np.int64_t component2
260267

268+
cdef np.double_t mr_dist
269+
261270
node_dist = min_dist_dual(node1_info.radius, node2_info.radius,
262271
node1, node2, (<np.double_t [:num_nodes, :num_nodes:1]>
263272
(<np.double_t *> self._centroid_distances.data)))
@@ -289,8 +298,9 @@ cdef class BoruvkaAlgorithm (object):
289298
component1 = component_of_point_ptr[p]
290299
component2 = component_of_point_ptr[q]
291300
if component1 != component2:
292-
if distances[i, j] < candidate_distance_ptr[component1]:
293-
candidate_distance_ptr[component1] = distances[i, j]
301+
mr_dist = max(distances[i, j], self.core_distance[p], self.core_distance[q])
302+
if mr_dist < candidate_distance_ptr[component1]:
303+
candidate_distance_ptr[component1] = mr_dist
294304
self.candidate_neighbor[component1] = q
295305
self.candidate_point[component1] = p
296306

0 commit comments

Comments
 (0)