Skip to content

Commit 088f210

Browse files
committed
Fixes for RSL, plus more tests
1 parent b4c6815 commit 088f210

File tree

2 files changed

+43
-11
lines changed

2 files changed

+43
-11
lines changed

hdbscan/robust_single_linkage_.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,8 @@
2828
FAST_METRICS = KDTree.valid_metrics + BallTree.valid_metrics
2929

3030

31-
def _rsl_generic(X, k=5, alpha=1.4142135623730951, metric='minkowski', p=2):
32-
if metric == 'minkowski':
33-
distance_matrix = pairwise_distances(X, metric=metric, p=p)
34-
else:
35-
distance_matrix = pairwise_distances(X, metric=metric)
31+
def _rsl_generic(X, k=5, alpha=1.4142135623730951, metric='euclidean', **kwargs):
32+
distance_matrix = pairwise_distances(X, metric=metric, **kwargs)
3633

3734
mutual_reachability_ = mutual_reachability(distance_matrix, k)
3835

@@ -45,7 +42,7 @@ def _rsl_generic(X, k=5, alpha=1.4142135623730951, metric='minkowski', p=2):
4542
return single_linkage_tree
4643

4744

48-
def _rsl_prims_kdtree(X, k=5, alpha=1.4142135623730951, metric='minkowski', **kwargs):
45+
def _rsl_prims_kdtree(X, k=5, alpha=1.4142135623730951, metric='euclidean', **kwargs):
4946

5047
# The Cython routines used require contiguous arrays
5148
if not X.flags['C_CONTIGUOUS']:
@@ -67,7 +64,7 @@ def _rsl_prims_kdtree(X, k=5, alpha=1.4142135623730951, metric='minkowski', **kw
6764
return single_linkage_tree
6865

6966

70-
def _rsl_prims_balltree(X, k=5, alpha=1.4142135623730951, metric='minkowski', **kwargs):
67+
def _rsl_prims_balltree(X, k=5, alpha=1.4142135623730951, metric='euclidean', **kwargs):
7168

7269
# The Cython routines used require contiguous arrays
7370
if not X.flags['C_CONTIGUOUS']:
@@ -90,7 +87,7 @@ def _rsl_prims_balltree(X, k=5, alpha=1.4142135623730951, metric='minkowski', **
9087

9188

9289
def _rsl_boruvka_kdtree(X, k=5, alpha=1.0,
93-
metric='minkowski', leaf_size=40, **kwargs):
90+
metric='euclidean', leaf_size=40, **kwargs):
9491

9592
dim = X.shape[0]
9693
min_samples = min(dim - 1, k)
@@ -107,7 +104,7 @@ def _rsl_boruvka_kdtree(X, k=5, alpha=1.0,
107104

108105

109106
def _rsl_boruvka_balltree(X, k=5, alpha=1.0,
110-
metric='minkowski', leaf_size=40, **kwargs):
107+
metric='euclidean', leaf_size=40, **kwargs):
111108

112109
dim = X.shape[0]
113110
min_samples = min(dim - 1, k)
@@ -342,7 +339,7 @@ class RobustSingleLinkage(BaseEstimator, ClusterMixin):
342339
"""
343340

344341
def __init__(self, cut=0.4, k=5, alpha=1.4142135623730951, gamma=5, metric='euclidean',
345-
algorithm='best'):
342+
algorithm='best', **kwargs):
346343

347344
self.cut = cut
348345
self.k = k
@@ -351,6 +348,8 @@ def __init__(self, cut=0.4, k=5, alpha=1.4142135623730951, gamma=5, metric='eucl
351348
self.metric = metric
352349
self.algorithm = algorithm
353350

351+
self._metric_kwargs = kwargs
352+
354353
self._cluster_hierarchy_ = None
355354

356355
def fit(self, X, y=None):
@@ -364,7 +363,11 @@ def fit(self, X, y=None):
364363
``metric='precomputed'``.
365364
"""
366365
X = check_array(X, accept_sparse='csr')
367-
self.labels_, self._cluster_hierarchy_ = robust_single_linkage(X, **self.get_params())
366+
367+
kwargs = self.get_params()
368+
kwargs.update(self._metric_kwargs)
369+
370+
self.labels_, self._cluster_hierarchy_ = robust_single_linkage(X, **kwargs)
368371
return self
369372

370373
def fit_predict(self, X, y=None):

hdbscan/tests/test_rsl.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
# from sklearn.cluster.tests.common import generate_clustered_data
2020

2121
from sklearn import datasets
22+
import warnings
2223

2324
from sklearn.datasets import make_blobs
2425
from sklearn.utils import shuffle
@@ -96,6 +97,31 @@ def test_rsl_prims_kdtree():
9697
n_clusters_2 = len(set(labels)) - int(-1 in labels)
9798
assert_equal(n_clusters_2, n_clusters)
9899

100+
def test_rsl_unavailable_hierarchy():
101+
clusterer = RobustSingleLinkage()
102+
with warnings.catch_warnings(record=True) as w:
103+
tree = clusterer.cluster_hierarchy_
104+
assert(len(w) > 0)
105+
assert(tree is None)
106+
107+
def test_rsl_hierarchy():
108+
clusterer = RobustSingleLinkage().fit(X)
109+
assert(clusterer.cluster_hierarchy_ is not None)
110+
111+
def test_rsl_high_dimensional():
112+
H, y = make_blobs(n_samples=50, random_state=0, n_features=64)
113+
# H, y = shuffle(X, y, random_state=7)
114+
H = StandardScaler().fit_transform(H)
115+
labels, tree = robust_single_linkage(H, 5.5)
116+
n_clusters_1 = len(set(labels)) - int(-1 in labels)
117+
print(n_clusters_1)
118+
assert_equal(n_clusters_1, n_clusters)
119+
120+
labels = RobustSingleLinkage(cut=5.5, algorithm='best', metric='seuclidean', V=np.ones(H.shape[1])).fit(H).labels_
121+
n_clusters_2 = len(set(labels)) - int(-1 in labels)
122+
print n_clusters_2
123+
assert_equal(n_clusters_2, n_clusters)
124+
99125
def test_rsl_badargs():
100126
assert_raises(ValueError,
101127
robust_single_linkage,
@@ -154,6 +180,9 @@ def test_rsl_badargs():
154180
assert_raises(ValueError,
155181
robust_single_linkage,
156182
X, 0.4, leaf_size=0)
183+
assert_raises(ValueError,
184+
robust_single_linkage,
185+
X, 0.4, gamma=0)
157186

158187
def test_rsl_is_sklearn_estimator():
159188

0 commit comments

Comments
 (0)