@@ -64,17 +64,19 @@ cdef class BoruvkaAlgorithm (object):
6464 cdef object tree
6565 cdef object dist
6666 cdef np.ndarray _data
67+ cdef np.int64_t min_samples
68+ cdef public np.double_t[::1 ] core_distance
6769 cdef public np.double_t[::1 ] bounds
6870 cdef public np.int64_t[::1 ] component_of_point
6971 cdef public np.int64_t[::1 ] component_of_node
7072 cdef public np.int64_t[::1 ] candidate_neighbor
7173 cdef public np.int64_t[::1 ] candidate_point
7274 cdef public np.double_t[::1 ] candidate_distance
73- # cdef public np.double_t[:,::1] _centroid_distances
7475 cdef object component_union_find
7576 cdef set edges
7677
7778 cdef np.ndarray components
79+ cdef np.ndarray core_distance_arr
7880 cdef np.ndarray bounds_arr
7981 cdef np.ndarray _centroid_distances
8082 cdef np.ndarray component_of_point_arr
@@ -83,13 +85,14 @@ cdef class BoruvkaAlgorithm (object):
8385 cdef np.ndarray candidate_neighbor_arr
8486 cdef np.ndarray candidate_distance_arr
8587
86- def __init__ (self , tree , metric = ' euclidean' , **kwargs ):
88+ def __init__ (self , tree , min_samples = 5 , metric = ' euclidean' , **kwargs ):
8789
8890 cdef np.int64_t num_points = tree.data.shape[0 ]
8991 cdef np.int64_t num_nodes = tree.node_data.shape[0 ]
9092
9193 self .tree = tree
9294 self ._data = np.array(tree.data)
95+ self .min_samples = min_samples
9396
9497 self .dist = dist_metrics.DistanceMetric.get_metric(metric, ** kwargs)
9598
@@ -112,8 +115,6 @@ cdef class BoruvkaAlgorithm (object):
112115
113116 self ._centroid_distances = self .dist.pairwise(tree.node_bounds[0 ])
114117
115- # self._centroid_distances = (<np.double_t[:num_nodes, :num_nodes:1]> (<np.double_t *> self._centroid_distances.data))
116-
117118 self ._compute_bounds()
118119 self ._initialize_components()
119120
@@ -123,6 +124,7 @@ cdef class BoruvkaAlgorithm (object):
123124 cdef np.int64_t num_points = self .tree.data.shape[0 ]
124125 cdef np.int64_t num_nodes = self .tree.node_data.shape[0 ]
125126
127+ cdef np.ndarray[np.double_t, ndim= 2 ] knn_dist
126128 cdef np.ndarray[np.double_t, ndim= 1 ] nn_dist
127129 cdef np.int64_t[::1 ] point_indices
128130 cdef np.int64_t[::1 ] idx_array = self .tree.idx_array
@@ -137,14 +139,19 @@ cdef class BoruvkaAlgorithm (object):
137139 cdef NodeData_t child1_info
138140 cdef NodeData_t child2_info
139141
140- nn_dist = self .tree.query(self .tree.data, 2 )[0 ][:,1 ]
142+ knn_dist = self .tree.query(self .tree.data, max (2 , self .min_samples))[0 ]
143+ nn_dist = knn_dist[:, 1 ]
144+ self .core_distance_arr = knn_dist[:, self .min_samples - 1 ].copy()
145+ self .core_distance = (< np.double_t [:num_points:1 ]> (< np.double_t * > self .core_distance_arr.data))
141146
142147 for n in range (num_nodes - 1 , - 1 , - 1 ):
143148 node_info = self .tree.node_data[n]
144149 if node_info.is_leaf:
145150 point_indices = idx_array[node_info.idx_start:node_info.idx_end]
146151 b1 = nn_dist[point_indices].max()
152+ # b1 = self.core_distance_arr[point_indices].max()
147153 b2 = (nn_dist[point_indices] + 2 * node_info.radius).min()
154+ # b2 = (self.core_distance_arr[point_indices] + 2 * node_info.radius).min()
148155 self .bounds[n] = min (b1, b2)
149156 else :
150157 child1 = 2 * n + 1
@@ -258,6 +265,8 @@ cdef class BoruvkaAlgorithm (object):
258265 cdef np.int64_t component1
259266 cdef np.int64_t component2
260267
268+ cdef np.double_t mr_dist
269+
261270 node_dist = min_dist_dual(node1_info.radius, node2_info.radius,
262271 node1, node2, (< np.double_t [:num_nodes, :num_nodes:1 ]>
263272 (< np.double_t * > self ._centroid_distances.data)))
@@ -289,8 +298,9 @@ cdef class BoruvkaAlgorithm (object):
289298 component1 = component_of_point_ptr[p]
290299 component2 = component_of_point_ptr[q]
291300 if component1 != component2:
292- if distances[i, j] < candidate_distance_ptr[component1]:
293- candidate_distance_ptr[component1] = distances[i, j]
301+ mr_dist = max (distances[i, j], self .core_distance[p], self .core_distance[q])
302+ if mr_dist < candidate_distance_ptr[component1]:
303+ candidate_distance_ptr[component1] = mr_dist
294304 self .candidate_neighbor[component1] = q
295305 self .candidate_point[component1] = p
296306
0 commit comments