@@ -398,15 +398,13 @@ cdef class KDTreeBoruvkaAlgorithm (object):
398
398
# into four piles and query them in parallel. On multicore systems
399
399
# (most systems) this amounts to a 2x-3x wall clock improvement.
400
400
if self .tree.data.shape[0 ] > 16384 and self .n_jobs > 1 :
401
- datasets = [
402
- np.asarray(self .tree.data[0 :self .num_points// 4 ]),
403
- np.asarray(self .tree.data[self .num_points// 4 :
404
- self .num_points// 2 ]),
405
- np.asarray(self .tree.data[self .num_points// 2 :
406
- 3 * (self .num_points// 4 )]),
407
- np.asarray(self .tree.data[3 * (self .num_points// 4 ):
408
- self .num_points])
409
- ]
401
+ split_cnt = self .num_points // self .n_jobs
402
+ datasets = []
403
+ for i in range (self .n_jobs):
404
+ if i == self .n_jobs - 1 :
405
+ datasets.append(np.asarray(self .tree.data[i* split_cnt:]))
406
+ else :
407
+ datasets.append(np.asarray(self .tree.data[i* split_cnt:(i+ 1 )* split_cnt]))
410
408
411
409
knn_data = Parallel(n_jobs = self .n_jobs)(
412
410
delayed(_core_dist_query,
@@ -1003,14 +1001,13 @@ cdef class BallTreeBoruvkaAlgorithm (object):
1003
1001
cdef np.ndarray[np.intp_t, ndim= 2 ] knn_indices
1004
1002
1005
1003
if self .tree.data.shape[0 ] > 16384 and self .n_jobs > 1 :
1006
- datasets = [np.asarray(self .tree.data[0 :self .num_points// 4 ]),
1007
- np.asarray(self .tree.data[self .num_points// 4 :
1008
- self .num_points// 2 ]),
1009
- np.asarray(self .tree.data[self .num_points// 2 :
1010
- 3 * (self .num_points// 4 )]),
1011
- np.asarray(self .tree.data[3 * (self .num_points// 4 ):
1012
- self .num_points])
1013
- ]
1004
+ split_cnt = self .num_points // self .n_jobs
1005
+ datasets = []
1006
+ for i in range (self .n_jobs):
1007
+ if i == self .n_jobs - 1 :
1008
+ datasets.append(np.asarray(self .tree.data[i* split_cnt:]))
1009
+ else :
1010
+ datasets.append(np.asarray(self .tree.data[i* split_cnt:(i+ 1 )* split_cnt]))
1014
1011
1015
1012
knn_data = Parallel(n_jobs = self .n_jobs)(
1016
1013
delayed(_core_dist_query,
0 commit comments