Skip to content

Commit 3bfd391

Browse files
authored
Merge pull request #22 from John-194/leak_fixes
Fix memory leaks and performance improvements
2 parents 42e1e71 + 2f9b33d commit 3bfd391

File tree

4 files changed

+80
-21
lines changed

4 files changed

+80
-21
lines changed

include/dbscan/algo.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,14 @@ int DBSCAN(intT n, floatT* PF, double epsilon, intT minPts, bool* coreFlagOut, i
8585

8686
typedef kdTree<dim, pointT> treeT;
8787
auto trees = newA(treeT*, G->numCell());
88-
parallel_for(0, G->numCell(), [&](intT i) {trees[i] = NULL;});
88+
89+
parallel_for(0, G->numCell(), [&](intT i) {
90+
if (ccFlag[i]) {
91+
trees[i] = new treeT(G->getCell(i)->getItem(), G->getCell(i)->size(), false);
92+
} else {
93+
trees[i] = NULL;
94+
}
95+
});
8996

9097
// auto degCmp = [&](intT i, intT j) {
9198
// return G->getCell(i)->size() < G->getCell(j)->size();

include/dbscan/grid.h

Lines changed: 40 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
#pragma once
2525

26+
#include <mutex>
2627
#include "cell.h"
2728
#include "point.h"
2829
#include "shared.h"
@@ -77,6 +78,7 @@ struct grid {
7778
treeT* tree=NULL;
7879
intT totalPoints;
7980
cellBuf **nbrCache;
81+
std::mutex* cacheLocks;
8082

8183
/**
8284
* Grid constructor.
@@ -89,10 +91,12 @@ struct grid {
8991

9092
cells = newA(cellT, cellCapacity);
9193
nbrCache = newA(cellBuf*, cellCapacity);
94+
cacheLocks = (std::mutex*) malloc(cellCapacity * sizeof(std::mutex));
9295
parallel_for(0, cellCapacity, [&](intT i) {
93-
nbrCache[i] = NULL;
94-
cells[i].init();
95-
});
96+
new (&cacheLocks[i]) std::mutex();
97+
nbrCache[i] = NULL;
98+
cells[i].init();
99+
});
96100
numCells = 0;
97101

98102
myHash = new cellHashT(pMinn, r);
@@ -101,9 +105,10 @@ struct grid {
101105

102106
~grid() {
103107
free(cells);
104-
parallel_for(0, numCells, [&](intT i) {
105-
if(nbrCache[i]) delete nbrCache[i];
106-
});
108+
free(cacheLocks);
109+
parallel_for(0, cellCapacity, [&](intT i) {
110+
if(nbrCache[i]) delete nbrCache[i];
111+
});
107112
free(nbrCache);
108113
if(myHash) delete myHash;
109114
if(table) {
@@ -141,14 +146,24 @@ struct grid {
141146
}
142147
}
143148
return false;};//todo, optimize
144-
if (nbrCache[bait-cells]) {
145-
auto accum = nbrCache[bait-cells];
149+
int idx = bait - cells;
150+
if (nbrCache[idx]) {
151+
auto accum = nbrCache[idx];
146152
for (auto accum_i : *accum) {
147153
if(fWrap(accum_i)) break;
148154
}
149155
} else {
150-
floatT hop = sqrt(dim + 3) * 1.0000001;
151-
nbrCache[bait-cells] = tree->rangeNeighbor(bait, r * hop, fStop, fWrap, true, nbrCache[bait-cells]);
156+
// wait for other threads to do their thing then try again
157+
std::lock_guard<std::mutex> lock(cacheLocks[idx]);
158+
if (nbrCache[idx]) {
159+
auto accum = nbrCache[idx];
160+
for (auto accum_i : *accum) {
161+
if (fWrap(accum_i)) break;
162+
}
163+
} else {
164+
floatT hop = sqrt(dim + 3) * 1.0000001;
165+
nbrCache[idx] = tree->rangeNeighbor(bait, r * hop, fStop, fWrap, true, nbrCache[idx]);
166+
}
152167
}
153168
}
154169

@@ -160,14 +175,24 @@ struct grid {
160175
return f(cell);
161176
return false;
162177
};
163-
if (nbrCache[bait-cells]) {
164-
auto accum = nbrCache[bait-cells];
178+
int idx = bait - cells;
179+
if (nbrCache[idx]) {
180+
auto accum = nbrCache[idx];
165181
for (auto accum_i : *accum) {
166-
if(fWrap(accum_i)) break;
182+
if (fWrap(accum_i)) break;
167183
}
168184
} else {
169-
floatT hop = sqrt(dim + 3) * 1.0000001;
170-
nbrCache[bait-cells] = tree->rangeNeighbor(bait, r * hop, fStop, fWrap, true, nbrCache[bait-cells]);
185+
// wait for other threads to do their thing then try again
186+
std::lock_guard<std::mutex> lock(cacheLocks[idx]);
187+
if (nbrCache[idx]) {
188+
auto accum = nbrCache[idx];
189+
for (auto accum_i : *accum) {
190+
if (fWrap(accum_i)) break;
191+
}
192+
} else {
193+
floatT hop = sqrt(dim + 3) * 1.0000001;
194+
nbrCache[bait-cells] = tree->rangeNeighbor(bait, r * hop, fStop, fWrap, true, nbrCache[idx]);
195+
}
171196
}
172197
}
173198

setup.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@ def initialize_options(self):
5151
depends=depends,
5252
py_limited_api=True,
5353
define_macros=[
54-
('Py_LIMITED_API', '0x03020000'),
5554
('NPY_NO_DEPRECATED_API', 'NPY_1_7_API_VERSION'),
5655
# ('DBSCAN_VERSION', json.dumps(version)),
5756
]

src/dbscanmodule.cpp

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,27 @@
44
#include "dbscan/pbbs/parallel.h"
55

66

7+
static bool scheduler_initialized = false;
8+
static PyObject* scheduler_cleanup_weakref = nullptr;
9+
10+
/**
 * Stop the parlay scheduler if this module started it.
 *
 * Passed to PyCapsule_New as the capsule destructor in module init.
 * When the file-level `scheduler_initialized` flag is clear this is a
 * no-op; otherwise the scheduler is stopped and the flag is cleared so
 * that a repeated call does nothing.
 *
 * @param capsule  The capsule being destroyed; not read by this function.
 */
static void cleanup_scheduler(PyObject *capsule)
{
  // Nothing to tear down when the scheduler was never started.
  if (!scheduler_initialized) return;

  parlay::internal::stop_scheduler();
  scheduler_initialized = false;
}
18+
19+
static void ensure_scheduler_initialized()
20+
{
21+
if (!scheduler_initialized)
22+
{
23+
parlay::internal::start_scheduler();
24+
scheduler_initialized = true;
25+
}
26+
}
27+
728
static PyObject* DBSCAN_py(PyObject* self, PyObject* args, PyObject *kwargs)
829
{
930
PyObject *Xobj;
@@ -58,7 +79,7 @@ static PyObject* DBSCAN_py(PyObject* self, PyObject* args, PyObject *kwargs)
5879
PyArrayObject* core_samples = (PyArrayObject*)PyArray_SimpleNew(1, &n, NPY_BOOL);
5980
PyArrayObject* labels = (PyArrayObject*)PyArray_SimpleNew(1, &n, NPY_INT);
6081

61-
parlay::internal::start_scheduler();
82+
ensure_scheduler_initialized();
6283

6384
DBSCAN(
6485
dim,
@@ -70,9 +91,11 @@ static PyObject* DBSCAN_py(PyObject* self, PyObject* args, PyObject *kwargs)
7091
(int*)PyArray_DATA(labels)
7192
);
7293

73-
parlay::internal::stop_scheduler();
74-
75-
return PyTuple_Pack(2, labels, core_samples);
94+
PyObject* result_tuple = PyTuple_Pack(2, labels, core_samples);
95+
Py_DECREF(X);
96+
Py_DECREF(core_samples);
97+
Py_DECREF(labels);
98+
return result_tuple;
7699
}
77100

78101
PyDoc_STRVAR(doc_DBSCAN,
@@ -126,6 +149,11 @@ PyInit__dbscan(void)
126149
#endif
127150
PyModule_AddIntMacro(module, DBSCAN_MIN_DIMS);
128151
PyModule_AddIntMacro(module, DBSCAN_MAX_DIMS);
152+
PyObject *capsule = PyCapsule_New((void *)module, "dbscan.scheduler", cleanup_scheduler);
153+
if (capsule != NULL)
154+
{
155+
PyModule_AddObject(module, "_scheduler_capsule", capsule);
156+
}
129157

130158
return module;
131159
}

0 commit comments

Comments
 (0)