2828FAST_METRICS = KDTree .valid_metrics + BallTree .valid_metrics
2929
3030
31- def _rsl_generic (X , k = 5 , alpha = 1.4142135623730951 , metric = 'minkowski' , p = 2 ):
32- if metric == 'minkowski' :
33- distance_matrix = pairwise_distances (X , metric = metric , p = p )
34- else :
35- distance_matrix = pairwise_distances (X , metric = metric )
31+ def _rsl_generic (X , k = 5 , alpha = 1.4142135623730951 , metric = 'euclidean' , ** kwargs ):
32+ distance_matrix = pairwise_distances (X , metric = metric , ** kwargs )
3633
3734 mutual_reachability_ = mutual_reachability (distance_matrix , k )
3835
@@ -45,7 +42,7 @@ def _rsl_generic(X, k=5, alpha=1.4142135623730951, metric='minkowski', p=2):
4542 return single_linkage_tree
4643
4744
48- def _rsl_prims_kdtree (X , k = 5 , alpha = 1.4142135623730951 , metric = 'minkowski ' , ** kwargs ):
45+ def _rsl_prims_kdtree (X , k = 5 , alpha = 1.4142135623730951 , metric = 'euclidean ' , ** kwargs ):
4946
5047 # The Cython routines used require contiguous arrays
5148 if not X .flags ['C_CONTIGUOUS' ]:
@@ -67,7 +64,7 @@ def _rsl_prims_kdtree(X, k=5, alpha=1.4142135623730951, metric='minkowski', **kw
6764 return single_linkage_tree
6865
6966
70- def _rsl_prims_balltree (X , k = 5 , alpha = 1.4142135623730951 , metric = 'minkowski ' , ** kwargs ):
67+ def _rsl_prims_balltree (X , k = 5 , alpha = 1.4142135623730951 , metric = 'euclidean ' , ** kwargs ):
7168
7269 # The Cython routines used require contiguous arrays
7370 if not X .flags ['C_CONTIGUOUS' ]:
@@ -90,7 +87,7 @@ def _rsl_prims_balltree(X, k=5, alpha=1.4142135623730951, metric='minkowski', **
9087
9188
9289def _rsl_boruvka_kdtree (X , k = 5 , alpha = 1.0 ,
93- metric = 'minkowski ' , leaf_size = 40 , ** kwargs ):
90+ metric = 'euclidean ' , leaf_size = 40 , ** kwargs ):
9491
9592 dim = X .shape [0 ]
9693 min_samples = min (dim - 1 , k )
@@ -107,7 +104,7 @@ def _rsl_boruvka_kdtree(X, k=5, alpha=1.0,
107104
108105
109106def _rsl_boruvka_balltree (X , k = 5 , alpha = 1.0 ,
110- metric = 'minkowski ' , leaf_size = 40 , ** kwargs ):
107+ metric = 'euclidean ' , leaf_size = 40 , ** kwargs ):
111108
112109 dim = X .shape [0 ]
113110 min_samples = min (dim - 1 , k )
@@ -342,7 +339,7 @@ class RobustSingleLinkage(BaseEstimator, ClusterMixin):
342339 """
343340
344341 def __init__ (self , cut = 0.4 , k = 5 , alpha = 1.4142135623730951 , gamma = 5 , metric = 'euclidean' ,
345- algorithm = 'best' ):
342+ algorithm = 'best' , ** kwargs ):
346343
347344 self .cut = cut
348345 self .k = k
@@ -351,6 +348,8 @@ def __init__(self, cut=0.4, k=5, alpha=1.4142135623730951, gamma=5, metric='eucl
351348 self .metric = metric
352349 self .algorithm = algorithm
353350
351+ self ._metric_kwargs = kwargs
352+
354353 self ._cluster_hierarchy_ = None
355354
356355 def fit (self , X , y = None ):
@@ -364,7 +363,11 @@ def fit(self, X, y=None):
364363 ``metric='precomputed'``.
365364 """
366365 X = check_array (X , accept_sparse = 'csr' )
367- self .labels_ , self ._cluster_hierarchy_ = robust_single_linkage (X , ** self .get_params ())
366+
367+ kwargs = self .get_params ()
368+ kwargs .update (self ._metric_kwargs )
369+
370+ self .labels_ , self ._cluster_hierarchy_ = robust_single_linkage (X , ** kwargs )
368371 return self
369372
370373 def fit_predict (self , X , y = None ):
0 commit comments