Commit 2b61477

add tsne layout

1 parent a288717

File tree

6 files changed: +304 -17 lines changed

.gitignore (2 additions, 1 deletion)

```diff
@@ -45,4 +45,5 @@ ENV/
 env.bak/
 venv.bak/
 
-*.code-workspace
+*.code-workspace
+.vscode/
```

src/multi_mst/base.py (266 additions, 16 deletions)
```diff
@@ -1,10 +1,20 @@
+import math
+import numba
 import numpy as np
 import warnings as warn
 from typing import Literal, Any, Callable
 
-from scipy.sparse import csr_array
+from time import time
+from numpy.random import RandomState
+from scipy.sparse import csr_array, spmatrix
 from sklearn.utils import check_array
-from sklearn.utils.validation import check_is_fitted, _check_sample_weight
+from sklearn.decomposition import PCA
+from sklearn.manifold._t_sne import TSNE, _joint_probabilities_nn
+from sklearn.utils.validation import (
+    check_is_fitted,
+    _check_sample_weight,
+    check_random_state,
+)
 
 from umap import UMAP
 from fast_hbcc.sub_clusters import BoundaryClusterDetector
```
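Note that `TSNE` and `_joint_probabilities_nn` are imported from `sklearn.manifold._t_sne`, a private scikit-learn module; together with the `tsne._tsne` call further down, this couples the code to scikit-learn internals that may change between releases.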
```diff
@@ -229,18 +239,6 @@ def umap(
             provide easy visualization, but can reasonably be set to any integer
             value in the range 2 to 100.
 
-        metric: string or function (optional, default 'euclidean')
-            The metric to use to compute distances in output dimensional space.
-            If a string is passed it must match a valid predefined metric, see
-            UMAP's documentation for available options. If a general metric is
-            required a function that takes two 1d arrays and returns a float can
-            be provided. For performance purposes it is required that this be a
-            numba jit'd function.
-
-        metric_kwds: dict (optional, default None)
-            Keyword arguments to pass on to the metric, such as the ``p`` value
-            of Minkowski distance. If None then no arguments are passed on.
-
         n_epochs: int (optional, default None)
             The number of training epochs to be used in optimizing the low
             dimensional embedding. Larger values result in more accurate
```
```diff
@@ -434,11 +432,13 @@ def umap(
                 message=".*is not an NNDescent object.*",
             )
             umap = UMAP(
+                n_components=n_components,
                 n_neighbors=self.graph_neighbors_.shape[1],
+                metric=self.metric,
+                metric_kwds=self.metric_kwds,
                 precomputed_knn=csr_to_neighbor_list(
                     self._graph.data, self._graph.indices, self._graph.indptr
                 ),
-                n_components=n_components,
                 output_metric=output_metric,
                 output_metric_kwds=output_metric_kwds,
                 n_epochs=n_epochs,
```
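Besides reordering `n_components`, this hunk forwards the estimator's own `metric` and `metric_kwds` to UMAP, so the embedding measures input distances the same way the kMST graph was built. A minimal sketch of the effect, assuming `KMSTDescent` accepts a `metric` argument as in the multi_mst API (the data and metric value here are illustrative):

```python
import numpy as np
from multi_mst import KMSTDescent

X = np.random.default_rng(42).normal(size=(200, 8))
model = KMSTDescent(metric="cosine").fit(X)

# UMAP now inherits metric="cosine" from the estimator instead of
# silently falling back to its euclidean default.
umap = model.umap(n_components=2)
```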
```diff
@@ -471,6 +471,173 @@ def umap(
 
         return umap
 
+    def tsne(
+        self,
+        *,
+        n_components: int = 2,
+        init: np.ndarray | spmatrix | Literal["random", "pca"] = "pca",
+        learning_rate: float | Literal["auto"] = "auto",
+        early_exaggeration: float = 12.0,
+        min_grad_norm: float = 1e-7,
+        max_iter: int = 1000,
+        n_iter_without_progress: int = 300,
+        method: Literal["barnes_hut", "exact"] = "barnes_hut",
+        angle: float = 0.5,
+        random_state: RandomState | int | None = None,
+        verbose: int = 0,
+    ):
+        """Constructs and fits a TSNE model [1]_ to the kMST graph.
+
+        Unlike HDBSCAN and HBCC, TSNE does not support infinite data. To ensure
+        all of TSNE's member functions work as expected, the TSNE model is NOT
+        remapped to the infinite data after fitting. As a result, code combining
+        TSNE and HDBSCAN results needs to account for the finite index::
+
+            plt.scatter(*tsne.embedding_.T,
+                        c=hdbscan.labels_[multi_mst.finite_index])
+
+        Parameters
+        ----------
+        n_components : int, default=2
+            Dimension of the embedded space.
+
+        init : {"random", "pca"} or array of shape (n_samples, n_components), default="pca"
+            Initialization of embedding.
+
+        learning_rate : float or "auto", default="auto"
+            The learning rate for t-SNE is usually in the range [10.0, 1000.0].
+            If the learning rate is too high, the data may look like a 'ball'
+            with any point approximately equidistant from its nearest neighbors.
+            If the learning rate is too low, most points may look compressed in
+            a dense cloud with few outliers. If the cost function gets stuck in
+            a bad local minimum, increasing the learning rate may help.
+
+        early_exaggeration : float, default=12.0
+            Controls how tight natural clusters in the original space are in the
+            embedded space and how much space will be between them. For larger
+            values, the space between natural clusters will be larger in the
+            embedded space. The choice of this parameter is not very critical.
+            If the cost function increases during initial optimization, the
+            early exaggeration factor or the learning rate might be too high.
+
+        min_grad_norm : float, default=1e-7
+            If the gradient norm is below this threshold, the optimization will
+            be stopped.
+
+        max_iter : int, default=1000
+            Maximum number of iterations for the optimization. Should be at
+            least 250.
+
+        n_iter_without_progress : int, default=300
+            Maximum number of iterations without progress before we abort the
+            optimization, used after 250 initial iterations with early
+            exaggeration. Note that progress is only checked every 50 iterations,
+            so this value is rounded to the next multiple of 50.
+
+        method : {'barnes_hut', 'exact'}, default='barnes_hut'
+            By default the gradient calculation algorithm uses Barnes-Hut
+            approximation running in O(N log N) time. method='exact'
+            will run the slower, but exact, algorithm in O(N^2) time. The
+            exact algorithm should be used when nearest-neighbor errors need
+            to be better than 3%. However, the exact method cannot scale to
+            millions of examples.
+
+        angle : float, default=0.5
+            This is the trade-off between speed and accuracy for Barnes-Hut
+            T-SNE. 'angle' is the angular size of a distant node as measured
+            from a point. If this size is below 'angle' then it is used as a
+            summary node of all points contained within it. This method is not
+            very sensitive to changes in this parameter in the range of 0.2 -
+            0.8. Angles less than 0.2 quickly increase computation time and
+            angles greater than 0.8 quickly increase error.
+
+        random_state : int, RandomState instance or None, default=None
+            Determines the random number generator. Pass an int for reproducible
+            results across multiple function calls. Note that different
+            initializations might result in different local minima of the cost
+            function.
+
+        verbose : int, default=0
+            Verbosity level.
+
+        Returns
+        -------
+        tsne : TSNE
+            The fitted TSNE model.
+
+        References
+        ----------
+        .. [1] van der Maaten, L., & Hinton, G. (2008). Visualizing data using
+           t-SNE. Journal of Machine Learning Research.
+        """
+        check_is_fitted(
+            self,
+            ["_graph", "_raw_data"],
+            msg="You first need to fit the estimator before accessing member functions.",
+        )
+        if method == "barnes_hut" and n_components > 3:
+            raise ValueError(
+                "'n_components' should be inferior to 4 for the "
+                "barnes_hut algorithm as it relies on "
+                "quad-tree or oct-tree."
+            )
+
+        # Extract raw data
+        X = self._raw_data
+        if not self._all_finite:
+            X = self._raw_data[self.finite_index]
+        X = X.astype(np.float32, copy=False)
+
+        # Build the t-SNE model
+        n_samples = self._graph.shape[0]
+        random_state = check_random_state(random_state)
+        tsne = TSNE(
+            method=method,
+            metric=self.metric,
+            metric_params=self.metric_kwds,
+            n_components=n_components,
+            init=init,
+            learning_rate=learning_rate,
+            early_exaggeration=early_exaggeration,
+            min_grad_norm=min_grad_norm,
+            max_iter=max_iter,
+            n_iter_without_progress=n_iter_without_progress,
+            angle=angle,
+            random_state=random_state,
+            verbose=verbose,
+        )
+
+        # Configure parameters normally set in fit
+        if learning_rate == "auto":
+            tsne.learning_rate_ = X.shape[0] / tsne.early_exaggeration / 4
+            tsne.learning_rate_ = np.maximum(tsne.learning_rate_, 50)
+        else:
+            tsne.learning_rate_ = tsne.learning_rate
+
+        # Do the initialization
+        if isinstance(init, np.ndarray):
+            X_embedded = init
+        elif init == "pca":
+            pca = PCA(
+                n_components=n_components,
+                svd_solver="randomized",
+                random_state=random_state,
+            )
+            pca.set_output(transform="default")
+            X_embedded = pca.fit_transform(X).astype(np.float32, copy=False)
+            X_embedded = X_embedded / np.std(X_embedded[:, 0]) * 1e-4
+        elif init == "random":
+            X_embedded = 1e-4 * random_state.standard_normal(
+                size=(n_samples, n_components)
+            ).astype(np.float32)
+
+        # Fit tSNE optimizer (run on csr matrix)
+        tsne.graph_, tsne.perplexity = _joint_probabilities_csr(self._graph, verbose)
+        tsne.embedding_ = tsne._tsne(
+            tsne.graph_, max(n_components - 1, 1), n_samples, X_embedded=X_embedded
+        )
+        return tsne
+
     def hdbscan(
         self,
         data_labels=None,
```
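The docstring's finite-index caveat is easy to miss when mixing layouts and clusterings. A short usage sketch expanding the docstring's own snippet; the dataset and the injected NaN row are illustrative, not from this commit:

```python
import numpy as np
import matplotlib.pyplot as plt
from multi_mst import KMST

X = np.random.default_rng(0).normal(size=(300, 4))
X[0, 0] = np.nan  # a non-finite row, so finite_index is populated

model = KMST().fit(X)
hdbscan = model.hdbscan(min_cluster_size=5)  # labels cover all rows of X
tsne = model.tsne(random_state=42)           # fitted on finite rows only

# Index the labels with finite_index so colors line up with the embedding.
plt.scatter(*tsne.embedding_.T, c=hdbscan.labels_[model.finite_index])
plt.show()
```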
```diff
@@ -820,7 +987,7 @@ def branch_detector(
         cluster_selection_persistence: float = 0.0,
         propagate_labels: bool = False,
     ):
-        """Constructs and fits a metric-aware BranchDetector [1]_, ensuring
+        """Constructs and fits a metric-aware BranchDetector [1]_, ensuring
         valid parameter--metric combinations.
 
         Parameters
```

(The removed and added lines render identically; they appear to differ only in whitespace.)
```diff
@@ -1115,3 +1282,86 @@ def remap_hbcc(clusterer, finite_index, internal_to_raw, num_points):
     new_bc = np.zeros(num_points)
     new_bc[finite_index] = clusterer.boundary_coefficient_
     clusterer.boundary_coefficient_ = new_bc
+
+
+@numba.njit()
+def _binary_search_perplexity_csr(dists, indptr):
+    INFINITY = np.inf
+    EPSILON_DBL = 1e-8
+    PERPLEXITY_TOLERANCE = 1e-5
+    n_steps = 100
+    desired_perplexity = (np.diff(indptr).max() - 1) / 3
+    desired_entropy = math.log(desired_perplexity)
+
+    probs = np.empty_like(dists)
+    for start, end in zip(indptr[:-1], indptr[1:]):
+        beta_min = -INFINITY
+        beta_max = INFINITY
+        beta = 1.0
+
+        # Binary search for a beta that achieves the desired perplexity
+        for _ in range(n_steps):
+            # Convert distances to similarities
+            sum_Pi = 0.0
+            for idx in range(start, end):
+                probs[idx] = math.exp(-dists[idx] * beta)
+                sum_Pi += probs[idx]
+
+            if sum_Pi == 0.0:
+                sum_Pi = EPSILON_DBL
+
+            # Normalize the probabilities
+            sum_disti_Pi = 0.0
+            for idx in range(start, end):
+                probs[idx] /= sum_Pi
+                sum_disti_Pi += dists[idx] * probs[idx]
+
+            # Compute the resulting entropy
+            entropy = math.log(sum_Pi) + beta * sum_disti_Pi
+            entropy_diff = entropy - desired_entropy
+            if math.fabs(entropy_diff) <= PERPLEXITY_TOLERANCE:
+                break
+
+            # Update beta values
+            if entropy_diff > 0.0:
+                beta_min = beta
+                if beta_max == INFINITY:
+                    beta *= 2.0
+                else:
+                    beta = (beta + beta_max) / 2.0
+            else:
+                beta_max = beta
+                if beta_min == -INFINITY:
+                    beta /= 2.0
+                else:
+                    beta = (beta + beta_min) / 2.0
+
+    return probs, desired_perplexity
+
+
+def _joint_probabilities_csr(graph, verbose):
+    """Compute joint probabilities p_ij from sparse distances."""
+    t0 = time()
+    # Compute conditional probabilities such that they approximately match
+    # the desired perplexity
+    graph.sort_indices()
+    n_samples = graph.shape[0]
+    conditional_P, perplexity = _binary_search_perplexity_csr(graph.data, graph.indptr)
+    assert np.all(np.isfinite(conditional_P)), "All probabilities should be finite"
+
+    # Symmetrize the joint probability distribution using sparse operations
+    P = csr_array(
+        (conditional_P, graph.indices, graph.indptr),
+        shape=(n_samples, n_samples),
+    )
+    P = P + P.T
+
+    # Normalize the joint probability distribution
+    sum_P = np.maximum(P.sum(), np.finfo(np.double).eps)
+    P /= sum_P
+
+    assert np.all(np.abs(P.data) <= 1.0)
+    if verbose >= 2:
+        duration = time() - t0
+        print("[t-SNE] Computed conditional probabilities in {:.3f}s".format(duration))
+    return P, perplexity
```
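A note on the perplexity here: rather than taking a user-set value, `_binary_search_perplexity_csr` derives it from the graph itself as `(max row degree - 1) / 3`, mirroring scikit-learn's heuristic of using roughly `3 * perplexity` neighbors. A minimal sanity check, assuming the private helper can be imported from `multi_mst.base` and using a fabricated graph (sizes and distances are illustrative):

```python
import numpy as np
from scipy.sparse import csr_array
from multi_mst.base import _joint_probabilities_csr  # private helper, may move

rng = np.random.default_rng(0)
n, k = 20, 9  # 20 points with 9 neighbors each
indptr = np.arange(0, n * k + 1, k)
indices = np.stack([rng.permutation(n)[:k] for _ in range(n)]).ravel()
dists = rng.uniform(0.1, 1.0, size=n * k)  # positive pseudo-distances
graph = csr_array((dists, indices, indptr), shape=(n, n))

P, perplexity = _joint_probabilities_csr(graph, verbose=0)
assert np.isclose(P.sum(), 1.0)              # normalized joint distribution
assert np.isclose(perplexity, (k - 1) / 3)   # (max row degree - 1) / 3
```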

src/multi_mst/tests/test_descent.py (9 additions, 0 deletions)

```diff
@@ -6,6 +6,7 @@
 from sklearn.preprocessing import StandardScaler
 
 from umap import UMAP
+from sklearn.manifold import TSNE
 from fast_hbcc import HBCC, BoundaryClusterDetector
 from fast_hdbscan import HDBSCAN
 from multi_mst import kMSTDescent, KMSTDescent
@@ -138,6 +139,14 @@ def test_umap():
     assert umap.embedding_.shape == (X_missing_data.shape[0] - 2, 2)
 
 
+def test_tsne():
+    model = KMSTDescent().fit(X_missing_data)
+    tsne = model.tsne()
+
+    assert isinstance(tsne, TSNE)
+    assert tsne.embedding_.shape == (X_missing_data.shape[0] - 2, 2)
+
+
 def test_hdbscan():
     model = KMSTDescent().fit(X_missing_data)
     hdbscan = model.hdbscan(min_cluster_size=5)
```

src/multi_mst/tests/test_kmst.py (9 additions, 0 deletions)

```diff
@@ -8,6 +8,7 @@
 from scipy.sparse import coo_array
 
 from umap import UMAP
+from sklearn.manifold import TSNE
 from fast_hbcc import HBCC, BoundaryClusterDetector
 from fast_hdbscan import HDBSCAN
 from multi_mst import kMST, KMST
@@ -173,6 +174,14 @@ def test_umap():
     assert umap.embedding_.shape == (X_missing_data.shape[0] - 2, 2)
 
 
+def test_tsne():
+    model = KMST().fit(X_missing_data)
+    tsne = model.tsne()
+
+    assert isinstance(tsne, TSNE)
+    assert tsne.embedding_.shape == (X_missing_data.shape[0] - 2, 2)
+
+
 def test_hdbscan():
     model = KMST().fit(X_missing_data)
     hdbscan = model.hdbscan(min_cluster_size=5)
```
