Skip to content

Commit 8a5d1fc

Browse files
author
Frankie Robertson
committed
Add keep_knns to RnnDBSCAN to keep nn graph
1 parent f6bad1a commit 8a5d1fc

File tree

1 file changed

+15
-10
lines changed

1 file changed

+15
-10
lines changed

sklearn_ann/cluster/rnn_dbscan.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,13 @@ def rnn_dbscan_inner(is_core, knns, rev_knns, labels):
9999

100100

101101
class RnnDBSCAN(ClusterMixin, BaseEstimator):
102-
def __init__(self, n_neighbors=5, *, input_guarantee="none", n_jobs=None):
102+
def __init__(
103+
self, n_neighbors=5, *, input_guarantee="none", n_jobs=None, keep_knns=False
104+
):
103105
self.n_neighbors = n_neighbors
104106
self.input_guarantee = input_guarantee
105107
self.n_jobs = n_jobs
108+
self.keep_knns = keep_knns
106109

107110
def fit(self, X, y=None):
108111
X = self._validate_data(X, accept_sparse="csr")
@@ -116,7 +119,11 @@ def fit(self, X, y=None):
116119
"Expected input_guarantee to be one of 'none', 'kneighbors'"
117120
)
118121
import timeit
122+
119123
XT = X.transpose().tocsr(copy=True)
124+
if self.keep_knns:
125+
self.knns_ = X
126+
self.rev_knns = XT
120127

121128
# Initially, all samples are unclassified.
122129
labels = np.full(X.shape[0], UNCLASSIFIED, dtype=np.int32)
@@ -136,18 +143,16 @@ def fit_predict(self, X, y=None):
136143
self.fit(X, y=y)
137144
return self.labels_
138145

146+
def drop_knns(self):
147+
del self.knns_
148+
del self.rev_knns_
149+
139150

140151
def simple_rnn_dbscan_pipeline(neighbor_transformer, n_neighbors, **kwargs):
141152
from sklearn.pipeline import make_pipeline
153+
142154
n_jobs = kwargs.get("n_jobs", None)
143155
return make_pipeline(
144-
neighbor_transformer(
145-
n_neighbors=n_neighbors,
146-
**kwargs,
147-
),
148-
RnnDBSCAN(
149-
n_neighbors=n_neighbors,
150-
input_guarantee="kneighbors",
151-
n_jobs=n_jobs
152-
),
156+
neighbor_transformer(n_neighbors=n_neighbors, **kwargs,),
157+
RnnDBSCAN(n_neighbors=n_neighbors, input_guarantee="kneighbors", n_jobs=n_jobs),
153158
)

0 commit comments

Comments
 (0)