44#include " dbscan/pbbs/parallel.h"
55
66
7+ static bool scheduler_initialized = false ;
8+ static PyObject* scheduler_cleanup_weakref = nullptr ;
9+
10+ static void cleanup_scheduler (PyObject *capsule)
11+ {
12+ if (scheduler_initialized)
13+ {
14+ parlay::internal::stop_scheduler ();
15+ scheduler_initialized = false ;
16+ }
17+ }
18+
19+ static void ensure_scheduler_initialized ()
20+ {
21+ if (!scheduler_initialized)
22+ {
23+ parlay::internal::start_scheduler ();
24+ scheduler_initialized = true ;
25+ }
26+ }
27+
728static PyObject* DBSCAN_py (PyObject* self, PyObject* args, PyObject *kwargs)
829{
930 PyObject *Xobj;
@@ -58,7 +79,7 @@ static PyObject* DBSCAN_py(PyObject* self, PyObject* args, PyObject *kwargs)
5879 PyArrayObject* core_samples = (PyArrayObject*)PyArray_SimpleNew (1 , &n, NPY_BOOL);
5980 PyArrayObject* labels = (PyArrayObject*)PyArray_SimpleNew (1 , &n, NPY_INT);
6081
61- parlay::internal::start_scheduler ();
82+ ensure_scheduler_initialized ();
6283
6384 DBSCAN (
6485 dim,
@@ -70,8 +91,6 @@ static PyObject* DBSCAN_py(PyObject* self, PyObject* args, PyObject *kwargs)
7091 (int *)PyArray_DATA (labels)
7192 );
7293
73- parlay::internal::stop_scheduler ();
74-
7594 PyObject* result_tuple = PyTuple_Pack (2 , labels, core_samples);
7695 Py_DECREF (X);
7796 Py_DECREF (core_samples);
@@ -130,6 +149,11 @@ PyInit__dbscan(void)
130149#endif
131150 PyModule_AddIntMacro (module , DBSCAN_MIN_DIMS);
132151 PyModule_AddIntMacro (module , DBSCAN_MAX_DIMS);
152+ PyObject *capsule = PyCapsule_New ((void *)module , " dbscan.scheduler" , cleanup_scheduler);
153+ if (capsule != NULL )
154+ {
155+ PyModule_AddObject (module , " _scheduler_capsule" , capsule);
156+ }
133157
134158 return module ;
135159}
0 commit comments