Skip to content

Commit eb845e2

Browse files
committed
Add skeleton of robust single linkage clustering.
1 parent 2712292 commit eb845e2

File tree

1 file changed

+174
-0
lines changed

1 file changed

+174
-0
lines changed

hdbscan/robust_single_linkage_.py

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
Robust Single Linkage: Density based single linkage clustering.
4+
"""
5+
# Author: Leland McInnes <[email protected]>
6+
#
7+
# License: BSD 3 clause
8+
9+
import numpy as np
10+
11+
from sklearn.base import BaseEstimator, ClusterMixin
12+
from sklearn.metrics import pairwise_distances
13+
from scipy.sparse import issparse
14+
from sklearn.neighbors import KDTree
15+
16+
try:
17+
from sklearn.utils import check_array
18+
except ImportError:
19+
from sklearn.utils import check_arrays
20+
21+
check_array = check_arrays
22+
23+
from ._hdbscan_linkage import single_linkage, mst_linkage_core, mst_linkage_core_pdist, label
24+
from ._hdbscan_reachability import kdtree_pdist_mutual_reachability, kdtree_mutual_reachability, mutual_reachability
25+
from .plots import SingleLinkageTree
26+
27+
try:
28+
from fastcluster import single
29+
HAVE_FASTCLUSTER = True
30+
except ImportError:
31+
HAVE_FASTCLUSTER = False
32+
33+
def _rsl_small(X, cut, k=5, alpha=1.4142135623730951, gamma=5, metric='minkowski', p=2):
34+
35+
if metric == 'minkowski':
36+
if p is None:
37+
raise TypeError('Minkowski metric given but no p value supplied!')
38+
if p < 0:
39+
raise ValueError('Minkowski metric with negative p value is not defined!')
40+
41+
distance_matrix = pairwise_distances(X, metric=metric, p=p)
42+
else:
43+
distance_matrix = pairwise_distances(X, metric=metric)
44+
45+
mutual_reachability_ = mutual_reachability(distance_matrix, k)
46+
47+
min_spanning_tree = mst_linkage_core(mutual_reachability_)
48+
min_spanning_tree = min_spanning_tree[np.argsort(min_spanning_tree.T[2]), :]
49+
50+
single_linkage_tree = label(min_spanning_tree)
51+
single_linkage_tree = SingleLinkageTree(single_linkage_tree)
52+
53+
labels = single_linkage_tree.get_clusters(cut, gamma)
54+
55+
return labels, single_linkage_tree
56+
57+
58+
def _rsl_small_kdtree(X, cut, k=5, alpha=1.4142135623730951, gamma=5, metric='minkowski', p=2):
59+
60+
if metric == 'minkowski':
61+
if p is None:
62+
raise TypeError('Minkowski metric given but no p value supplied!')
63+
if p < 0:
64+
raise ValueError('Minkowski metric with negative p value is not defined!')
65+
66+
distance_matrix = pairwise_distances(X, metric=metric, p=p)
67+
else:
68+
distance_matrix = pairwise_distances(X, metric=metric)
69+
70+
mutual_reachability_ = kdtree_mutual_reachability(X,
71+
distance_matrix,
72+
metric,
73+
p=p,
74+
min_points=k,
75+
alpha=alpha)
76+
77+
min_spanning_tree = mst_linkage_core(mutual_reachability_)
78+
min_spanning_tree = min_spanning_tree[np.argsort(min_spanning_tree.T[2]), :]
79+
80+
single_linkage_tree = label(min_spanning_tree)
81+
single_linkage_tree = SingleLinkageTree(single_linkage_tree)
82+
83+
labels = single_linkage_tree.get_clusters(cut, gamma)
84+
85+
return labels, single_linkage_tree
86+
87+
def _rsl_large_kdtree(X, cut, k=5, alpha=1.4142135623730951, gamma=5, metric='minkowski', p=2):
88+
89+
if p is None:
90+
p = 2
91+
92+
mutual_reachability_ = kdtree_pdist_mutual_reachability(X, metric, p, k, alpha)
93+
94+
min_spanning_tree = mst_linkage_core(mutual_reachability_)
95+
min_spanning_tree = min_spanning_tree[np.argsort(min_spanning_tree.T[2]), :]
96+
97+
single_linkage_tree = label(min_spanning_tree)
98+
single_linkage_tree = SingleLinkageTree(single_linkage_tree)
99+
100+
labels = single_linkage_tree.get_clusters(cut, gamma)
101+
102+
return labels, single_linkage_tree
103+
104+
105+
def _rsl_large_kdtree_fastcluster(X, cut, k=5, alpha=1.4142135623730951, gamma=5, metric='minkowski', p=2):
106+
if p is None:
107+
p = 2
108+
109+
mutual_reachability_ = kdtree_pdist_mutual_reachability(X, metric,
110+
p, k, alpha)
111+
112+
single_linkage_tree = single(mutual_reachability_)
113+
single_linkage_tree = SingleLinkageTree(single_linkage_tree)
114+
115+
labels = single_linkage_tree.get_clusters(cut, gamma)
116+
117+
return labels, single_linkage_tree
118+
119+
120+
def robust_single_linkage(X, cut, k=5, alpha=1.4142135623730951, gamma=5, metric='minkowski', p=2, algorithm=None):
121+
122+
if type(k) is not int or k < 1:
123+
raise ValueError('k must be an integer greater than zero!')
124+
125+
if type(alpha) is not float or alpha < 1.0:
126+
raise ValueError('alpha must be a float greater than or equal to 1.0!')
127+
128+
if type(gamma) is not int or gamma < 1:
129+
raise ValueError('gamma must be an integer greater than zero!')
130+
131+
X = check_array(X, accept_sparse='csr')
132+
133+
if algorithm is not None:
134+
if algorithm == 'small':
135+
return _rsl_small(X, cut, k, alpha, gamma, metric, p)
136+
elif algorithm == 'small_kdtree':
137+
return _rsl_small_kdtree(X, cut, k, alpha, gamma, metric, p)
138+
elif algorithm == 'large_kdtree':
139+
return _rsl_large_kdtree(X, cut, k, alpha, gamma, metric, p)
140+
elif algorithm == 'large_kdtree_fastcluster':
141+
return _rsl_large_kdtree_fastcluster(X, cut, k, alpha, gamma, metric, p)
142+
else:
143+
raise TypeError('Unknown algorithm type %s specified' % algorithm)
144+
145+
if issparse(X) or metric not in KDTree.valid_metrics: # We can't do much with sparse matrices ...
146+
return _rsl_small(X, cut, k, alpha, gamma, metric, p)
147+
elif X.shape[0] < 4000:
148+
return _rsl_small_kdtree(X, cut, k, alpha, gamma, metric, p)
149+
elif HAVE_FASTCLUSTER:
150+
return _rsl_large_kdtree_fastcluster(X, cut, k, alpha, gamma, metric, p)
151+
else:
152+
return _rsl_large_kdtree(X, cut, k, alpha, gamma, metric, p)
153+
154+
155+
class RobustSingleLinkage (BaseEstimator, ClusterMixin):
156+
157+
def __init__(self, k=5, alpha=1.4142135623730951, gamma=5, metric=euclidean, p=None):
158+
159+
self.k = k
160+
self.alpha = alpha
161+
self.gamma = gamma
162+
self.metric = metric
163+
self.p = p
164+
165+
self.cluster_hierarchy_ = None
166+
167+
def fit(self, X, y=None):
168+
X = check_array(X, accept_sparse='csr')
169+
self.labels_, self.cluster_hierarchy = robust_single_linkage(X, **self.get_params())
170+
return self
171+
172+
def fit_predict(self, X, y=None):
173+
self.fit(X)
174+
return self.labels_

0 commit comments

Comments
 (0)