Skip to content

Commit d1eff71

Browse files
author
Jelmer Bot
committed
add tests with fixes
1 parent ef782ee commit d1eff71

File tree

5 files changed

+347
-20
lines changed

5 files changed

+347
-20
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,5 @@ venv/
6565
ENV/
6666
env.bak/
6767
venv.bak/
68+
69+
*.code-workspace

multi_mst/k_mst/k_mst.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
import warnings as warn
33
from umap import UMAP
4+
from scipy.sparse import coo_matrix
45
from sklearn.neighbors import KDTree
56
from sklearn.utils import check_array
67
from sklearn.base import BaseEstimator
@@ -29,7 +30,7 @@ def validate_parameters(data, num_neighbors, min_samples, epsilon):
2930
return data, epsilon
3031

3132

32-
def kMST(data, num_neighbors=5, min_samples=1, epsilon=None, umap_kwargs=None):
33+
def kMST(data, num_neighbors=3, min_samples=1, epsilon=None, umap_kwargs=None):
3334
"""
3435
Computes a $k$-MST for the given data. Adapts the boruvka algorithm to look
3536
for $k$ candidate edges per point, of which the $k$ best per connected
@@ -43,7 +44,7 @@ def kMST(data, num_neighbors=5, min_samples=1, epsilon=None, umap_kwargs=None):
4344
data: array-like
4445
The data to construct a MST for.
4546
num_neighbors: int, optional
46-
The number of edges to connect between each fragment. Default is 5.
47+
The number of edges to connect between each fragment. Default is 3.
4748
min_samples: int, optional
4849
The number of neighbors for computing the mutual reachability distance.
4950
Value must be lower or equal to the number of neighbors. `epsilon`
@@ -98,7 +99,7 @@ class KMST(BaseEstimator):
9899
Parameters
99100
----------
100101
num_neighbors: int, optional
101-
The number of edges to connect between each fragment. Default is 5.
102+
The number of edges to connect between each fragment. Default is 3.
102103
min_samples: int, optional
103104
The number of neighbors for computing the mutual reachability distance.
104105
Value must be lower or equal to the number of neighbors. `epsilon`
@@ -130,11 +131,12 @@ class KMST(BaseEstimator):
130131
missing values. Use the graph_ and embedding_ attributes instead!
131132
"""
132133

133-
def __init__(self, *, num_neighbors=5, min_samples=1, epsilon=None, **umap_kwargs):
134+
def __init__(self, *, num_neighbors=3, min_samples=1, epsilon=None, **umap_kwargs):
134135
self.num_neighbors = num_neighbors
135136
self.min_samples = min_samples
136137
self.epsilon = epsilon
137138
self.umap_kwargs = umap_kwargs
139+
# TODO: Sklearn does not support **kwargs in __init__ for BaseEstimator...
138140

139141
def fit(self, X, y=None, **fit_params):
140142
"""
@@ -170,21 +172,29 @@ def fit(self, X, y=None, **fit_params):
170172
clean_data = X
171173

172174
kwargs = self.get_params()
173-
self.mst_indices_, self.mst_distances_, self._umap = kMST(clean_data, **kwargs)
175+
print(kwargs, self.umap_kwargs)
176+
self.mst_indices_, self.mst_distances_, self._umap = kMST(
177+
clean_data, umap_kwargs=self.umap_kwargs, **kwargs
178+
)
174179
self.graph_ = self._umap.graph_.copy()
175180
self.embedding_ = (
176-
self._umap.embedding_.copy() if self._umap.embedding_ is not None else None
181+
self._umap.embedding_.copy() if hasattr(self._umap, "embedding_") else None
177182
)
178183

179184
if not self._all_finite:
180-
self.graph_.shape = (X.shape[0], X.shape[0])
185+
self.graph_ = self.graph_.tocoo()
181186
for i in range(len(self.graph_.data)):
182187
self.graph_.row[i] = internal_to_raw[self.graph_.row[i]]
183188
self.graph_.col[i] = internal_to_raw[self.graph_.col[i]]
189+
self.graph_ = coo_matrix(
190+
(self.graph_.data, (self.graph_.row, self.graph_.col)),
191+
shape=(X.shape[0], X.shape[0]),
192+
)
184193

185-
new_embedding = np.full((X.shape[0], self.num_components), np.nan)
186-
new_embedding[finite_index] = self.embedding_
187-
self.embedding_ = new_embedding
194+
if self.embedding_ is not None:
195+
new_embedding = np.full((X.shape[0], self.embedding_.shape[1]), np.nan)
196+
new_embedding[finite_index] = self.embedding_
197+
self.embedding_ = new_embedding
188198

189199
new_indices = np.full(
190200
(X.shape[0], self.mst_indices_.shape[1]), -1, dtype=np.int32

multi_mst/noisy_mst/noisy_mst.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
import warnings as warn
33
from umap import UMAP
4+
from scipy.sparse import coo_matrix
45
from sklearn.neighbors import KDTree
56
from sklearn.utils import check_array
67
from sklearn.base import BaseEstimator
@@ -29,7 +30,7 @@ def validate_parameters(data, num_trees, noise_fraction, min_samples):
2930
return data
3031

3132

32-
def noisyMST(data, num_trees=5, noise_fraction=0.1, min_samples=1, umap_kwargs=None):
33+
def noisyMST(data, num_trees=3, noise_fraction=0.1, min_samples=1, umap_kwargs=None):
3334
"""
3435
Computes a union of $k$ noisy MSTs for the given data. Adapts the boruvka
3536
algorithm to construct multiple noisy minimum spanning trees.
@@ -45,7 +46,7 @@ def noisyMST(data, num_trees=5, noise_fraction=0.1, min_samples=1, umap_kwargs=N
4546
Adds Gaussian noise with scale=noise_fraction * distance to every computed
4647
distance value.
4748
num_trees: int, optional
48-
The number of noisy MSTs to create. Default is 5.
49+
The number of noisy MSTs to create. Default is 3.
4950
min_samples: int, optional
5051
The number of neighbors for computing the mutual reachability distance.
5152
Value must be lower or equal to the number of neighbors. `epsilon`
@@ -96,7 +97,7 @@ class NoisyMST(BaseEstimator):
9697
Parameters
9798
----------
9899
num_trees: int, optional
99-
The number of minimum spanning trees created. Default is 5.
100+
The number of minimum spanning trees created. Default is 3.
100101
noise_fraction:
101102
Adds Gaussian noise with scale=noise_fraction * distance to every computed
102103
distance value.
@@ -129,12 +130,13 @@ class NoisyMST(BaseEstimator):
129130
"""
130131

131132
def __init__(
132-
self, *, num_trees=5, noise_fraction=0.1, min_samples=1, **umap_kwargs
133+
self, *, num_trees=3, noise_fraction=0.1, min_samples=1, **umap_kwargs
133134
):
134135
self.num_trees = num_trees
135136
self.noise_fraction = noise_fraction
136137
self.min_samples = min_samples
137138
self.umap_kwargs = umap_kwargs
139+
# TODO: Sklearn does not support **kwargs in __init__ for BaseEstimator...
138140

139141
def fit(self, X, y=None, **fit_params):
140142
"""
@@ -171,22 +173,27 @@ def fit(self, X, y=None, **fit_params):
171173

172174
kwargs = self.get_params()
173175
self.mst_indices_, self.mst_distances_, self._umap = noisyMST(
174-
clean_data, **kwargs
176+
clean_data, umap_kwargs=self.umap_kwargs, **kwargs
175177
)
176178
self.graph_ = self._umap.graph_.copy()
177179
self.embedding_ = (
178-
self._umap.embedding_.copy() if self._umap.embedding_ is not None else None
180+
self._umap.embedding_.copy() if hasattr(self._umap, "embedding_") else None
179181
)
180182

181183
if not self._all_finite:
182-
self.graph_.shape = (X.shape[0], X.shape[0])
184+
self.graph_ = self.graph_.tocoo()
183185
for i in range(len(self.graph_.data)):
184186
self.graph_.row[i] = internal_to_raw[self.graph_.row[i]]
185187
self.graph_.col[i] = internal_to_raw[self.graph_.col[i]]
188+
self.graph_ = coo_matrix(
189+
(self.graph_.data, (self.graph_.row, self.graph_.col)),
190+
shape=(X.shape[0], X.shape[0]),
191+
)
186192

187-
new_embedding = np.full((X.shape[0], self.num_components), np.nan)
188-
new_embedding[finite_index] = self.embedding_
189-
self.embedding_ = new_embedding
193+
if self.embedding_ is not None:
194+
new_embedding = np.full((X.shape[0], self.embedding_.shape[1]), np.nan)
195+
new_embedding[finite_index] = self.embedding_
196+
self.embedding_ = new_embedding
190197

191198
new_indices = np.full(
192199
(X.shape[0], self.mst_indices_.shape[1]), -1, dtype=np.int32

tests/test_k_mst.py

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
"""Tests for MultiMST"""
2+
import pytest
3+
import numpy as np
4+
from sklearn.datasets import make_blobs
5+
from sklearn.preprocessing import StandardScaler
6+
from sklearn.utils._testing import assert_raises
7+
from scipy.sparse.csgraph import connected_components
8+
9+
from umap import UMAP
10+
from multi_mst.k_mst import kMST, KMST
11+
12+
13+
def generate_noisy_data():
14+
blobs, yBlobs = make_blobs(
15+
n_samples=50,
16+
centers=[(-0.75, 2.25), (2.0, -0.5)],
17+
cluster_std=0.2,
18+
random_state=3,
19+
)
20+
np.random.seed(5)
21+
noise = np.random.uniform(-1.0, 3.0, (50, 2))
22+
yNoise = np.full(50, -1)
23+
return (
24+
np.vstack((blobs, noise)),
25+
np.concatenate((yBlobs, yNoise)),
26+
)
27+
28+
29+
X, y = generate_noisy_data()
30+
X = StandardScaler().fit_transform(X)
31+
32+
X_missing_data = X.copy()
33+
X_missing_data[0] = [np.nan, 1]
34+
X_missing_data[6] = [np.nan, np.nan]
35+
36+
37+
def test_badargs():
38+
"""Tests parameter validation."""
39+
assert_raises(ValueError, kMST, X, num_neighbors=1.0)
40+
assert_raises(ValueError, kMST, X, num_neighbors=-1)
41+
assert_raises(ValueError, kMST, X, epsilon=1)
42+
assert_raises(ValueError, kMST, X, epsilon=0.6)
43+
assert_raises(ValueError, kMST, X, epsilon=-0.4)
44+
assert_raises(ValueError, kMST, X, min_samples=1.0)
45+
assert_raises(ValueError, kMST, X, num_neighbors=3, min_samples=4)
46+
assert_raises(ValueError, kMST, X, min_samples=0)
47+
assert_raises(ValueError, kMST, X, min_samples=-1)
48+
49+
50+
def test_defaults():
51+
"""Tests with default parameters."""
52+
p = KMST()
53+
54+
embedding = p.fit_transform(X)
55+
assert embedding.shape[0] == X.shape[0]
56+
assert embedding.shape[1] == 2 # Default num_components
57+
assert np.issubdtype(embedding.dtype, np.floating)
58+
59+
assert p.mst_indices_.shape[0] == X.shape[0]
60+
assert p.mst_indices_.shape[1] >= 3 # Default num_neighbors
61+
assert np.issubdtype(p.mst_indices_.dtype, np.integer)
62+
assert p.mst_distances_.shape[0] == X.shape[0]
63+
assert p.mst_distances_.shape[1] >= 3 # Default num_neighbors
64+
assert np.issubdtype(p.mst_distances_.dtype, np.floating)
65+
assert p.graph_.shape[0] == X.shape[0]
66+
assert p.graph_.shape[1] == X.shape[0]
67+
assert p.embedding_.shape[0] == X.shape[0]
68+
assert p.embedding_.shape[1] == 2 # Default num_components
69+
assert np.issubdtype(p.embedding_.dtype, np.floating)
70+
assert isinstance(p._umap, UMAP)
71+
assert connected_components(p._umap.graph_, directed=False, return_labels=False) == 1
72+
73+
74+
def test_with_missing_data():
75+
"""Tests with nan data."""
76+
clean_indices = list(range(1, 6)) + list(range(7, X.shape[0]))
77+
model = KMST().fit(X_missing_data)
78+
clean_model = KMST().fit(X_missing_data[clean_indices])
79+
80+
assert np.all(model.mst_indices_[0, :] == -1)
81+
assert np.isinf(model.mst_distances_[0, :]).all()
82+
assert np.isnan(model.embedding_[0, :]).all()
83+
assert np.all(model.graph_.row != 0) & np.all(model.graph_.col != 0)
84+
85+
assert np.all(model.mst_indices_[6, :] == -1)
86+
assert np.isinf(model.mst_distances_[6, :]).all()
87+
assert np.isnan(model.embedding_[6, :]).all()
88+
assert np.all(model.graph_.row != 6) & np.all(model.graph_.col != 6)
89+
90+
assert np.allclose(clean_model.graph_.data, model.graph_.data)
91+
assert np.allclose(clean_model.mst_indices_, model.mst_indices_[clean_indices])
92+
assert np.allclose(clean_model.mst_distances_, model.mst_distances_[clean_indices])
93+
94+
95+
def test_with_missing_data_graph_mode():
96+
"""Tests with nan data in graph mode."""
97+
clean_indices = list(range(1, 6)) + list(range(7, X.shape[0]))
98+
model = KMST(transform_mode='graph').fit(X_missing_data)
99+
clean_model = KMST(transform_mode='graph').fit(X_missing_data[clean_indices])
100+
101+
assert np.all(model.mst_indices_[0, :] == -1)
102+
assert np.isinf(model.mst_distances_[0, :]).all()
103+
assert np.all(model.graph_.row != 0) & np.all(model.graph_.col != 0)
104+
105+
assert np.all(model.mst_indices_[6, :] == -1)
106+
assert np.isinf(model.mst_distances_[6, :]).all()
107+
assert np.all(model.graph_.row != 6) & np.all(model.graph_.col != 6)
108+
109+
assert np.allclose(clean_model.graph_.data, model.graph_.data)
110+
assert np.allclose(clean_model.mst_indices_, model.mst_indices_[clean_indices])
111+
assert np.allclose(clean_model.mst_distances_, model.mst_distances_[clean_indices])
112+
113+
114+
def test_graph_mode():
115+
"""Tests with transform_mode='graph'."""
116+
p = KMST(transform_mode='graph')
117+
embedding = p.fit_transform(X)
118+
119+
assert embedding is None
120+
assert p.embedding_ is None
121+
122+
123+
def test_min_samples():
124+
"""Tests with higher min_samples."""
125+
p = KMST(min_samples=3, transform_mode='graph').fit(X)
126+
127+
assert p.mst_indices_.shape[1] >= 3 # Max(min_samples, num_neighbors)
128+
assert p.mst_distances_.shape[1] >= 3 # Max(min_samples, num_neighbors)
129+
assert connected_components(p._umap.graph_, directed=False, return_labels=False) == 1
130+
131+
132+
def test_num_neighbors():
133+
"""Tests with lower num_neighbors."""
134+
p = KMST(num_neighbors=1, transform_mode='graph').fit(X)
135+
136+
assert p.mst_indices_.shape[1] >= 1 # Max(min_samples, num_neighbors)
137+
assert p.mst_distances_.shape[1] >= 1 # Max(min_samples, num_neighbors)
138+
assert connected_components(p._umap.graph_, directed=False, return_labels=False) == 1
139+
140+
141+
def test_epsilon():
142+
"""Tests with epsilon."""
143+
base = KMST(transform_mode='graph').fit(X)
144+
p = KMST(epsilon=1.1, transform_mode='graph').fit(X)
145+
146+
assert connected_components(p._umap.graph_, directed=False, return_labels=False) == 1
147+
assert p.graph_.nnz < base.graph_.nnz
148+
149+
150+
def test_num_components():
151+
"""Tests with higher num_components."""
152+
p = KMST(n_components=3)
153+
154+
embedding = p.fit_transform(X)
155+
assert embedding.shape[0] == X.shape[0]
156+
assert embedding.shape[1] == 3
157+
assert np.issubdtype(embedding.dtype, np.floating)
158+
159+
assert p.embedding_.shape[0] == X.shape[0]
160+
assert p.embedding_.shape[1] == 3
161+
assert np.issubdtype(p.embedding_.dtype, np.floating)

0 commit comments

Comments
 (0)