1+ """Tests for MultiMST"""
2+ import pytest
3+ import numpy as np
4+ from sklearn .datasets import make_blobs
5+ from sklearn .preprocessing import StandardScaler
6+ from sklearn .utils ._testing import assert_raises
7+ from scipy .sparse .csgraph import connected_components
8+
9+ from umap import UMAP
10+ from multi_mst .k_mst import kMST , KMST
11+
12+
def generate_noisy_data():
    """Build a small 2-D data set: two tight blobs plus uniform noise.

    Returns
    -------
    tuple of (ndarray, ndarray)
        The stacked points (blob samples first, then noise samples) and the
        matching labels, where every noise point is labelled ``-1``.
    """
    n_blob_points = 50
    n_noise_points = 50
    blob_centers = [(-0.75, 2.25), (2.0, -0.5)]

    blob_points, blob_labels = make_blobs(
        n_samples=n_blob_points,
        centers=blob_centers,
        cluster_std=0.2,
        random_state=3,
    )

    # Seed the global NumPy generator so the noise draw is reproducible.
    np.random.seed(5)
    noise_points = np.random.uniform(-1.0, 3.0, (n_noise_points, 2))
    noise_labels = np.full(n_noise_points, -1)

    points = np.vstack((blob_points, noise_points))
    labels = np.concatenate((blob_labels, noise_labels))
    return points, labels
27+
28+
# Shared fixtures: the noisy blob data, standardised per feature.
X, y = generate_noisy_data()
X = StandardScaler().fit_transform(X)

# Variant of X with a partially missing row (0) and a fully missing row (6).
X_missing_data = X.copy()
X_missing_data[0, :] = (np.nan, 1)
X_missing_data[6, :] = np.nan
35+
36+
def test_badargs():
    """Check that kMST rejects each invalid parameter set with ValueError.

    Uses ``pytest.raises`` rather than the ``assert_raises`` helper from the
    private ``sklearn.utils._testing`` module, so the test only relies on
    public testing API.
    """
    invalid_kwargs = (
        {"num_neighbors": 1.0},
        {"num_neighbors": -1},
        {"epsilon": 1},
        {"epsilon": 0.6},
        {"epsilon": -0.4},
        {"min_samples": 1.0},
        {"num_neighbors": 3, "min_samples": 4},
        {"min_samples": 0},
        {"min_samples": -1},
    )
    for kwargs in invalid_kwargs:
        with pytest.raises(ValueError):
            kMST(X, **kwargs)
            # Only reached when kMST did not raise. pytest.raises lets
            # non-ValueError exceptions propagate, so this names the
            # offending arguments in the failure report.
            pytest.fail(f"kMST accepted invalid arguments {kwargs!r}")
48+
49+
def test_defaults():
    """Fit KMST with default parameters and check every fitted attribute."""
    n_points = X.shape[0]
    model = KMST()
    result = model.fit_transform(X)

    # Returned embedding: one floating point row per sample.
    assert result.shape[0] == n_points
    assert result.shape[1] == 2  # Default num_components
    assert np.issubdtype(result.dtype, np.floating)

    # Neighbour tables: one row per sample, at least num_neighbors columns.
    for attribute, expected_kind in (
        ("mst_indices_", np.integer),
        ("mst_distances_", np.floating),
    ):
        table = getattr(model, attribute)
        assert table.shape[0] == n_points
        assert table.shape[1] >= 3  # Default num_neighbors
        assert np.issubdtype(table.dtype, expected_kind)

    # The graph is square over the samples.
    assert model.graph_.shape[0] == n_points
    assert model.graph_.shape[1] == n_points

    # The stored embedding satisfies the same checks as the returned one.
    assert model.embedding_.shape[0] == n_points
    assert model.embedding_.shape[1] == 2  # Default num_components
    assert np.issubdtype(model.embedding_.dtype, np.floating)

    # The wrapped UMAP model holds a graph with a single connected component.
    assert isinstance(model._umap, UMAP)
    n_parts = connected_components(
        model._umap.graph_, directed=False, return_labels=False
    )
    assert n_parts == 1
72+
73+
def test_with_missing_data():
    """Check handling of rows containing NaN in the default transform mode.

    Rows with missing values must get ``-1`` indices, infinite distances, a
    NaN embedding and no entries in ``graph_.row`` / ``graph_.col``. The
    remaining rows must match a model fitted on the clean rows only.
    """
    missing_rows = (0, 6)
    clean_indices = [i for i in range(X.shape[0]) if i not in missing_rows]

    model = KMST().fit(X_missing_data)
    clean_model = KMST().fit(X_missing_data[clean_indices])

    for row in missing_rows:
        assert np.all(model.mst_indices_[row, :] == -1)
        assert np.isinf(model.mst_distances_[row, :]).all()
        assert np.isnan(model.embedding_[row, :]).all()
        assert np.all(model.graph_.row != row)
        assert np.all(model.graph_.col != row)

    assert np.allclose(clean_model.graph_.data, model.graph_.data)
    assert np.allclose(
        clean_model.mst_indices_, model.mst_indices_[clean_indices]
    )
    assert np.allclose(
        clean_model.mst_distances_, model.mst_distances_[clean_indices]
    )
93+
94+
def test_with_missing_data_graph_mode():
    """Check handling of rows containing NaN with transform_mode='graph'.

    Rows with missing values must get ``-1`` indices, infinite distances and
    no entries in ``graph_.row`` / ``graph_.col``. The remaining rows must
    match a model fitted on the clean rows only.
    """
    missing_rows = (0, 6)
    clean_indices = [i for i in range(X.shape[0]) if i not in missing_rows]

    model = KMST(transform_mode="graph").fit(X_missing_data)
    clean_model = KMST(transform_mode="graph").fit(
        X_missing_data[clean_indices]
    )

    for row in missing_rows:
        assert np.all(model.mst_indices_[row, :] == -1)
        assert np.isinf(model.mst_distances_[row, :]).all()
        assert np.all(model.graph_.row != row)
        assert np.all(model.graph_.col != row)

    assert np.allclose(clean_model.graph_.data, model.graph_.data)
    assert np.allclose(
        clean_model.mst_indices_, model.mst_indices_[clean_indices]
    )
    assert np.allclose(
        clean_model.mst_distances_, model.mst_distances_[clean_indices]
    )
112+
113+
def test_graph_mode():
    """Check that transform_mode='graph' yields no embedding."""
    model = KMST(transform_mode="graph")
    result = model.fit_transform(X)

    # Neither the return value nor the fitted attribute holds an embedding.
    assert result is None
    assert model.embedding_ is None
121+
122+
def test_min_samples():
    """Fit in graph mode with min_samples=3 and check the neighbour tables."""
    model = KMST(min_samples=3, transform_mode="graph").fit(X)

    # Column count is at least max(min_samples, num_neighbors).
    for table in (model.mst_indices_, model.mst_distances_):
        assert table.shape[1] >= 3

    n_parts = connected_components(
        model._umap.graph_, directed=False, return_labels=False
    )
    assert n_parts == 1
130+
131+
def test_num_neighbors():
    """Fit in graph mode with num_neighbors=1 and check the neighbour tables."""
    model = KMST(num_neighbors=1, transform_mode="graph").fit(X)

    # Column count is at least max(min_samples, num_neighbors).
    for table in (model.mst_indices_, model.mst_distances_):
        assert table.shape[1] >= 1

    n_parts = connected_components(
        model._umap.graph_, directed=False, return_labels=False
    )
    assert n_parts == 1
139+
140+
def test_epsilon():
    """Check that epsilon=1.1 gives a sparser, still connected, graph."""
    reference = KMST(transform_mode="graph").fit(X)
    pruned = KMST(epsilon=1.1, transform_mode="graph").fit(X)

    n_parts = connected_components(
        pruned._umap.graph_, directed=False, return_labels=False
    )
    assert n_parts == 1

    # Fewer stored entries than the model fitted with the default epsilon.
    assert pruned.graph_.nnz < reference.graph_.nnz
148+
149+
def test_num_components():
    """Check that n_components=3 produces a three-column embedding."""
    n_points = X.shape[0]
    model = KMST(n_components=3)
    result = model.fit_transform(X)

    # The returned array first, then the fitted attribute: same expectations.
    assert result.shape[0] == n_points
    assert result.shape[1] == 3
    assert np.issubdtype(result.dtype, np.floating)

    stored = model.embedding_
    assert stored.shape[0] == n_points
    assert stored.shape[1] == 3
    assert np.issubdtype(stored.dtype, np.floating)
0 commit comments