2
2
3
3
from sklearn .base import BaseEstimator , ClusterMixin
4
4
from sklearn .utils import check_array
5
+ from sklearn .utils .validation import check_is_fitted , _check_sample_weight
5
6
from sklearn .neighbors import KDTree
6
7
7
8
from warnings import warn
@@ -135,6 +136,7 @@ def fast_hdbscan(
135
136
cluster_selection_method = "eom" ,
136
137
allow_single_cluster = False ,
137
138
cluster_selection_epsilon = 0.0 ,
139
+ sample_weights = None ,
138
140
return_trees = False ,
139
141
):
140
142
data = check_array (data )
@@ -156,10 +158,10 @@ def fast_hdbscan(
156
158
sklearn_tree = KDTree (data )
157
159
numba_tree = kdtree_to_numba (sklearn_tree )
158
160
edges = parallel_boruvka (
159
- numba_tree , min_samples = min_cluster_size if min_samples is None else min_samples
161
+ numba_tree , min_samples = min_cluster_size if min_samples is None else min_samples , sample_weights = sample_weights
160
162
)
161
163
sorted_mst = edges [np .argsort (edges .T [2 ])]
162
- linkage_tree = mst_to_linkage_tree (sorted_mst )
164
+ linkage_tree = mst_to_linkage_tree (sorted_mst , sample_weights = sample_weights )
163
165
condensed_tree = condense_tree (linkage_tree , min_cluster_size = min_cluster_size )
164
166
if cluster_selection_epsilon > 0.0 or cluster_selection_method == "eom" :
165
167
cluster_tree = cluster_tree_from_condensed_tree (condensed_tree )
@@ -208,8 +210,10 @@ def __init__(
208
210
self .allow_single_cluster = allow_single_cluster
209
211
self .cluster_selection_epsilon = cluster_selection_epsilon
210
212
211
- def fit (self , X , y = None , ** fit_params ):
213
+ def fit (self , X , y = None , sample_weight = None , ** fit_params ):
212
214
X = check_array (X , accept_sparse = "csr" , force_all_finite = False )
215
+ if sample_weight is not None :
216
+ sample_weight = _check_sample_weight (sample_weight , X , dtype = np .float32 )
213
217
self ._raw_data = X
214
218
215
219
self ._all_finite = np .all (np .isfinite (X ))
@@ -233,7 +237,7 @@ def fit(self, X, y=None, **fit_params):
233
237
self ._single_linkage_tree ,
234
238
self ._condensed_tree ,
235
239
self ._min_spanning_tree ,
236
- ) = fast_hdbscan (clean_data , return_trees = True , ** kwargs )
240
+ ) = fast_hdbscan (clean_data , return_trees = True , sample_weights = sample_weight , ** kwargs )
237
241
238
242
self ._condensed_tree = to_numpy_rec_array (self ._condensed_tree )
239
243
@@ -256,6 +260,7 @@ def fit(self, X, y=None, **fit_params):
256
260
return self
257
261
258
262
def dbscan_clustering (self , epsilon ):
263
+ check_is_fitted (self , "_single_linkage_tree" , msg = "You first need to fit the HDBSCAN model before picking a DBSCAN clustering" )
259
264
return get_cluster_labelling_at_cut (
260
265
self ._single_linkage_tree ,
261
266
epsilon ,
@@ -264,6 +269,7 @@ def dbscan_clustering(self, epsilon):
264
269
265
270
@property
266
271
def condensed_tree_ (self ):
272
+ check_is_fitted (self , "_condensed_tree" , msg = "You first need to fit the HDBSCAN model before accessing the condensed tree" )
267
273
if self ._condensed_tree is not None :
268
274
return CondensedTree (
269
275
self ._condensed_tree ,
@@ -277,6 +283,7 @@ def condensed_tree_(self):
277
283
278
284
@property
279
285
def single_linkage_tree_ (self ):
286
+ check_is_fitted (self , "_single_linkage_tree" , msg = "You first need to fit the HDBSCAN model before accessing the single linkage tree" )
280
287
if self ._single_linkage_tree is not None :
281
288
return SingleLinkageTree (self ._single_linkage_tree )
282
289
else :
@@ -286,6 +293,7 @@ def single_linkage_tree_(self):
286
293
287
294
@property
288
295
def minimum_spanning_tree_ (self ):
296
+ check_is_fitted (self , "_min_spanning_tree" , msg = "You first need to fit the HDBSCAN model before accessing the minimum spanning tree" )
289
297
if self ._min_spanning_tree is not None :
290
298
if self ._raw_data is not None :
291
299
return MinimumSpanningTree (self ._min_spanning_tree , self ._raw_data )
0 commit comments