Skip to content

Commit 87d08b9

Browse files
author
Frankie Robertson
committed
Fix up keep_knns in RnnDBSCAN
1 parent 8a5d1fc commit 87d08b9

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

sklearn_ann/cluster/rnn_dbscan.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def fit(self, X, y=None):
123123
XT = X.transpose().tocsr(copy=True)
124124
if self.keep_knns:
125125
self.knns_ = X
126-
self.rev_knns = XT
126+
self.rev_knns_ = XT
127127

128128
# Initially, all samples are unclassified.
129129
labels = np.full(X.shape[0], UNCLASSIFIED, dtype=np.int32)
@@ -152,7 +152,8 @@ def simple_rnn_dbscan_pipeline(neighbor_transformer, n_neighbors, **kwargs):
152152
from sklearn.pipeline import make_pipeline
153153

154154
n_jobs = kwargs.get("n_jobs", None)
155+
keep_knns = kwargs.pop("keep_knns", None)
155156
return make_pipeline(
156-
neighbor_transformer(n_neighbors=n_neighbors, **kwargs,),
157-
RnnDBSCAN(n_neighbors=n_neighbors, input_guarantee="kneighbors", n_jobs=n_jobs),
157+
neighbor_transformer(n_neighbors=n_neighbors, **kwargs),
158+
RnnDBSCAN(n_neighbors=n_neighbors, input_guarantee="kneighbors", n_jobs=n_jobs, keep_knns=keep_knns),
158159
)

0 commit comments

Comments
 (0)