diff --git a/gdist.pyx b/gdist.pyx index 704b758..a6f7774 100644 --- a/gdist.pyx +++ b/gdist.pyx @@ -58,6 +58,9 @@ import scipy.sparse from libcpp cimport bool from libcpp.vector cimport vector +import tqdm +from cpython.exc cimport PyErr_CheckSignals + ################################################################################ ############ Defininitions to access the C++ geodesic distance library ######### ################################################################################ @@ -85,8 +88,8 @@ cdef extern from "geodesic_utils.h": cdef extern from "geodesic_algorithm_exact.h" namespace "geodesic": cdef cppclass GeodesicAlgorithmExact: GeodesicAlgorithmExact(Mesh*) - void propagate(vector[SurfacePoint]&, double, vector[SurfacePoint]*) - unsigned best_source(SurfacePoint&, double&) + void propagate(vector[SurfacePoint]&, double, vector[SurfacePoint]*) nogil + unsigned best_source(SurfacePoint&, double&) nogil cdef extern from "geodesic_constants_and_simple_functions.h" namespace "geodesic": double GEODESIC_INF @@ -197,6 +200,7 @@ def compute_gdist(numpy.ndarray[numpy.float64_t, ndim=2] vertices, def local_gdist_matrix(numpy.ndarray[numpy.float64_t, ndim=2] vertices, numpy.ndarray[numpy.int32_t, ndim=2] triangles, + numpy.ndarray[numpy.int32_t, ndim=1] subset, double max_distance=GEODESIC_INF, bool is_one_indexed=False): """This is the wrapper function for computing geodesic distance from every @@ -253,8 +257,10 @@ def local_gdist_matrix(numpy.ndarray[numpy.float64_t, ndim=2] vertices, cdef vector[SurfacePoint] source, targets cdef Py_ssize_t N = vertices.shape[0] - cdef Py_ssize_t k + cdef Py_ssize_t k, i, ssmin cdef Py_ssize_t kk + cdef vector[Py_ssize_t] rows_v, columns_v + cdef vector[numpy.float64_t] data_v cdef numpy.float64_t distance = 0 # Add all vertices as targets @@ -264,12 +270,15 @@ def local_gdist_matrix(numpy.ndarray[numpy.float64_t, ndim=2] vertices, rows = [] columns = [] data = [] - for k in range(N): + ssmin = subset.min() + for i in 
def _local_gdist_matrix_parallel_helper(args):
    """Unpack one work item and run ``local_gdist_matrix`` on it.

    Pool map-style calls hand the worker a single argument, so the
    parameters travel as one tuple:
    ``(vtx, tri, max_distance, is_one_indexed, subset)``.
    """
    vtx, tri, max_distance, is_one_indexed, subset = args
    return local_gdist_matrix(vtx, tri, subset, max_distance, is_one_indexed)


def local_gdist_matrix_parallel(vtx, tri, max_distance, is_one_indexed, n_jobs=None, chunksize=256):
    """Compute the local geodesic distance matrix in row chunks.

    The source vertices are split into contiguous, ascending index ranges.
    Each range is passed to ``local_gdist_matrix`` as its ``subset`` and the
    resulting row blocks are stacked back together in vertex order.

    Parameters
    ----------
    vtx : array of vertex coordinates, one row per vertex
        Forwarded unchanged to ``local_gdist_matrix`` as ``vertices``.
    tri : array of triangle vertex indices
        Forwarded unchanged to ``local_gdist_matrix`` as ``triangles``.
    max_distance : float
        Forwarded unchanged to ``local_gdist_matrix``.
    is_one_indexed : bool
        Forwarded unchanged to ``local_gdist_matrix``.
    n_jobs : int or None, optional
        ``0`` runs every chunk serially in the calling process (no pool, no
        progress bar).  Any other value is passed to ``multiprocessing.Pool``;
        ``None`` lets the pool pick its default number of workers.
    chunksize : int, optional
        Maximum number of source vertices handled by one work item.

    Returns
    -------
    scipy.sparse.csc_matrix
        Matrix of shape ``(len(vtx), len(vtx))`` whose row ``i`` holds the
        distances computed with vertex ``i`` as the source.

    Raises
    ------
    ValueError
        If ``chunksize`` is smaller than 1.
    """
    from multiprocessing import Pool

    if chunksize < 1:
        raise ValueError("chunksize must be a positive integer")

    # Chunk by vertex count (rows of ``vtx``), not by the number of scalar
    # coordinates in the array.
    n_vertices = len(vtx)
    if n_vertices == 0:
        return scipy.sparse.csc_matrix((0, 0))

    # Ceiling division: every chunk holds at most ``chunksize`` vertices and
    # there is always at least one chunk, which ``array_split`` requires.
    n_chunks = -(-n_vertices // chunksize)

    # ``local_gdist_matrix`` numbers its output rows as ``k - subset.min()``,
    # so each subset must be a contiguous ascending range; ``array_split`` of
    # an ``arange`` guarantees that.
    subsets = numpy.array_split(
        numpy.arange(n_vertices, dtype=numpy.int32), n_chunks
    )
    work = [(vtx, tri, max_distance, is_one_indexed, subset) for subset in subsets]

    if n_jobs == 0:
        parts = [_local_gdist_matrix_parallel_helper(item) for item in work]
    else:
        parts = []
        with Pool(n_jobs) as pool, tqdm.tqdm(total=len(work)) as pbar:
            # Ordered ``imap`` (not ``imap_unordered``): the row blocks must
            # come back in vertex order to be stacked correctly.
            for part in pool.imap(_local_gdist_matrix_parallel_helper, work):
                parts.append(part)
                pbar.update()

    # Each part has shape (subset.size, n_vertices): the parts are row blocks,
    # so they are stacked vertically.  CSC matches what ``local_gdist_matrix``
    # itself returns.
    return scipy.sparse.vstack(parts, format="csc")