Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 28 additions & 7 deletions include/dbscan/pbbs/parallel.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,26 +139,38 @@ namespace parlay {
get_default_scheduler().stop();
}
}
inline bool sequential = false;

// Number of worker threads available to the scheduler.
// Reports 1 when the global `sequential` flag is set so callers
// size their work for a single thread.
inline size_t num_workers() {
  return sequential ? 1u : internal::get_default_scheduler().num_workers();
}

// Id of the calling worker thread; always 0 in sequential mode,
// matching num_workers() reporting a single worker.
inline size_t worker_id() {
  return sequential ? 0u : internal::get_default_scheduler().worker_id();
}

// Apply f(i) for every i in [start, end).
// When the global `sequential` flag is set, runs a plain serial loop on
// the calling thread; otherwise delegates to the scheduler's parfor with
// the given granularity / conservative hints.  An empty or inverted
// range is a no-op.
template <class F>
inline void parallel_for(size_t start, size_t end, F f,
                         size_t granularity = 0,
                         bool conservative = false) {
  if (end <= start) return;  // nothing to do
  if (sequential) {
    for (size_t i = start; i < end; ++i) f(i);
  } else {
    internal::get_default_scheduler().parfor(start, end, f,
                                             granularity, conservative);
  }
}

// Run left() and right(), potentially in parallel.
// In sequential mode both are executed in order on the calling thread;
// otherwise the scheduler's pardo decides (conservative hints respected).
template <typename Lf, typename Rf>
inline void par_do(Lf left, Rf right, bool conservative = false) {
  if (sequential) {
    left();
    right();
  } else {
    internal::get_default_scheduler().pardo(left, right, conservative);
  }
}
}

Expand All @@ -171,15 +183,24 @@ using namespace parlay;
#define par_for_1 for
#define par_for_256 for

// Legacy helper API; respects the global `sequential` flag.
// NOTE: num_workers()/worker_id() already account for `sequential`,
// so the checks here are a redundant but harmless belt-and-braces.
static int getWorkers() {return sequential ? 1 : (int)num_workers();}
static int getWorkerId() {return sequential ? 0 : (int)worker_id();}
// Thread count is fixed by the scheduler at startup; request is ignored.
static void setWorkers(int n) { }
static void printScheduler() {
  cout << "scheduler = Parlay-HomeGrown" << endl;
  cout << "num-threads = " << getWorkers() << endl;}

#else

// Fix errors:
#include <atomic>
namespace parlay {
// Mirror the home-grown scheduler's `sequential` flag so code that reads
// parlay::sequential (e.g. dbscanmodule.cpp) still compiles on this path.
inline bool sequential = false;
namespace internal {
// No home-grown scheduler on this path: start/stop are no-ops.
inline void start_scheduler() {}
inline void stop_scheduler() {}
}  // namespace internal
}  // namespace parlay

#define cilk_spawn
#define cilk_sync
#define parallel_main main
Expand Down
64 changes: 54 additions & 10 deletions src/dbscanmodule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
static bool scheduler_initialized = false;
static PyObject* scheduler_cleanup_weakref = nullptr;

// Stop the parlay scheduler if it is running.  Idempotent: safe to call
// multiple times.  Used both as a capsule destructor (capsule != nullptr)
// and called directly (default argument) when switching to sequential mode.
static void cleanup_scheduler(PyObject *capsule = nullptr)
{
    (void)capsule;  // unused: cleanup is global, not per-capsule
    if (scheduler_initialized)
    {
        parlay::internal::stop_scheduler();
        scheduler_initialized = false;
    }
}

static void ensure_scheduler_initialized()
Expand All @@ -40,7 +40,7 @@ static PyObject* DBSCAN_py(PyObject* self, PyObject* args, PyObject *kwargs)
return NULL;
}

// Check the number of dimensions and that we actually recieved an np.ndarray
// Check the number of dimensions and that we actually received an np.ndarray
X = (PyArrayObject*)PyArray_FROMANY(
Xobj,
NPY_DOUBLE,
Expand Down Expand Up @@ -79,7 +79,10 @@ static PyObject* DBSCAN_py(PyObject* self, PyObject* args, PyObject *kwargs)
PyArrayObject* core_samples = (PyArrayObject*)PyArray_SimpleNew(1, &n, NPY_BOOL);
PyArrayObject* labels = (PyArrayObject*)PyArray_SimpleNew(1, &n, NPY_INT);

ensure_scheduler_initialized();
if (!parlay::sequential)
{
ensure_scheduler_initialized();
}

DBSCAN(
dim,
Expand All @@ -98,9 +101,27 @@ static PyObject* DBSCAN_py(PyObject* self, PyObject* args, PyObject *kwargs)
return result_tuple;
}

// Python: set_sequential(state=True) -> None
// Toggle sequential (single-threaded) execution.  When enabling it,
// shut down any running scheduler so no worker threads linger.
static PyObject* set_sequential_py(PyObject* self, PyObject* args)
{
    (void)self;
    int state = 1;  // default: enable sequential mode
    if (!PyArg_ParseTuple(args, "|p", &state)) {
        return nullptr;
    }
    // "p" yields 0 or 1, but compare against 0 so any nonzero is truthy.
    parlay::sequential = (state != 0);
    if (parlay::sequential) {
        cleanup_scheduler();
    }
    Py_RETURN_NONE;
}

// Python: get_sequential() -> bool
// Report whether sequential (single-threaded) mode is currently active.
static PyObject* get_sequential_py(PyObject* self, PyObject* args)
{
    (void)self;
    (void)args;
    if (parlay::sequential) {
        Py_RETURN_TRUE;
    }
    Py_RETURN_FALSE;
}

PyDoc_STRVAR(doc_DBSCAN,
"DBSCAN(X, eps=0.5, min_samples=5)\n--\n\n\
Run DBSCAN on a set of n samples of dimension dim with a minimum seperation\n\
Run DBSCAN on a set of n samples of dimension dim with a minimum separation\n\
between the clusters (which must include at least min_samples) of eps. Points\n\
that do not fit in any cluster are labeled as noise (-1).\n\
\n\
Expand All @@ -113,7 +134,7 @@ Parameters\n\
X : np.ndarray[tuple[n, dim], np.float64]\n\
2-D array representing the samples.\n\
eps : float\n\
minimum seperation between the clusters.\n\
minimum separation between the clusters.\n\
min_samples : int\n\
minimum number of samples in the clusters.\n\
\n\
Expand All @@ -125,8 +146,31 @@ core_samples : np.ndarray[tuple[n], np.bool_]\n\
is each sample the core sample of its cluster\n\
\n");

PyDoc_STRVAR(doc_set_sequential,
"set_sequential(state=True)\n--\n\n\
Set whether DBSCAN runs in sequential mode (single-threaded).\n\
This mode is potentially more efficient.\n\
\n\
Parameters\n\
----------\n\
state : bool, default True\n\
If True, run sequentially. If False, allow parallel execution.\n\
");

PyDoc_STRVAR(doc_get_sequential,
"get_sequential()\n--\n\n\
Return the current state of the sequential setting.\n\
\n\
Returns\n\
-------\n\
state : bool\n\
True if running sequentially, False if in parallel mode.\n\
");

// Module method table: maps Python-visible names to their C entry points.
static struct PyMethodDef methods[] = {
{"DBSCAN", (PyCFunction)(void*)(PyCFunctionWithKeywords) DBSCAN_py, METH_VARARGS | METH_KEYWORDS, doc_DBSCAN},  // main clustering entry point
{"set_sequential", (PyCFunction)set_sequential_py, METH_VARARGS, doc_set_sequential},  // toggle single-threaded mode
{"get_sequential", (PyCFunction)get_sequential_py, METH_NOARGS, doc_get_sequential},   // query single-threaded mode
{NULL, NULL, 0, NULL}  // sentinel terminating the table
};

Expand Down