@@ -99,10 +99,13 @@ def rnn_dbscan_inner(is_core, knns, rev_knns, labels):
99
99
100
100
101
101
class RnnDBSCAN (ClusterMixin , BaseEstimator ):
102
- def __init__ (self , n_neighbors = 5 , * , input_guarantee = "none" , n_jobs = None ):
102
+ def __init__ (
103
+ self , n_neighbors = 5 , * , input_guarantee = "none" , n_jobs = None , keep_knns = False
104
+ ):
103
105
self .n_neighbors = n_neighbors
104
106
self .input_guarantee = input_guarantee
105
107
self .n_jobs = n_jobs
108
+ self .keep_knns = keep_knns
106
109
107
110
def fit (self , X , y = None ):
108
111
X = self ._validate_data (X , accept_sparse = "csr" )
@@ -116,7 +119,11 @@ def fit(self, X, y=None):
116
119
"Expected input_guarantee to be one of 'none', 'kneighbors'"
117
120
)
118
121
import timeit
122
+
119
123
XT = X .transpose ().tocsr (copy = True )
124
+ if self .keep_knns :
125
+ self .knns_ = X
126
+ self .rev_knns = XT
120
127
121
128
# Initially, all samples are unclassified.
122
129
labels = np .full (X .shape [0 ], UNCLASSIFIED , dtype = np .int32 )
@@ -136,18 +143,16 @@ def fit_predict(self, X, y=None):
136
143
self .fit (X , y = y )
137
144
return self .labels_
138
145
146
+ def drop_knns (self ):
147
+ del self .knns_
148
+ del self .rev_knns_
149
+
139
150
140
151
def simple_rnn_dbscan_pipeline (neighbor_transformer , n_neighbors , ** kwargs ):
141
152
from sklearn .pipeline import make_pipeline
153
+
142
154
n_jobs = kwargs .get ("n_jobs" , None )
143
155
return make_pipeline (
144
- neighbor_transformer (
145
- n_neighbors = n_neighbors ,
146
- ** kwargs ,
147
- ),
148
- RnnDBSCAN (
149
- n_neighbors = n_neighbors ,
150
- input_guarantee = "kneighbors" ,
151
- n_jobs = n_jobs
152
- ),
156
+ neighbor_transformer (n_neighbors = n_neighbors , ** kwargs ,),
157
+ RnnDBSCAN (n_neighbors = n_neighbors , input_guarantee = "kneighbors" , n_jobs = n_jobs ),
153
158
)
0 commit comments