55import pytest
66from metric_learn ._util import make_context
77from sklearn .cluster import DBSCAN
8+ from sklearn .datasets import make_spd_matrix
9+ from sklearn .utils import check_random_state
810
11+ RNG = check_random_state (0 )
912
1013class IdentityBilinearMixin (BilinearMixin ):
1114 """A simple Identity bilinear mixin that returns an identity matrix
@@ -15,14 +18,17 @@ class IdentityBilinearMixin(BilinearMixin):
1518 def __init__ (self , preprocessor = None ):
1619 super ().__init__ (preprocessor = preprocessor )
1720
18- def fit (self , X , y ):
21+ def fit (self , X , y , random = False ):
1922 """
2023 Checks input's format. Sets M matrix to identity of shape (d,d)
2124 where d is the dimension of the input.
2225 """
2326 X , y = self ._prepare_inputs (X , y , ensure_min_samples = 2 )
2427 self .d = np .shape (X [0 ])[- 1 ]
25- self .components_ = np .identity (self .d )
28+ if random :
29+ self .components_ = np .random .rand (self .d , self .d )
30+ else :
31+ self .components_ = np .identity (self .d )
2632 return self
2733
2834 def random_M (self ):
@@ -32,29 +38,34 @@ def random_M(self):
3238 self .components_ = np .random .rand (self .d , self .d )
3339
3440
def identity_fit(d=100, n=100, n_pairs=None, random=False):
    """
    Build a random dataset and fit an IdentityBilinearMixin() on it.
    Testing purposes.

    Parameters
    ----------
    d : int
        Dimension of each generated array.
    n : int
        Number of arrays to generate.
    n_pairs : int or None
        Number of pairs to sample (with replacement) from the 'n' arrays.
        When None, no pairs are generated.
    random : bool
        Forwarded as-is to IdentityBilinearMixin.fit().

    Returns
    -------
    X : numpy array of shape (n, d)
        The generated arrays.
    random_pairs : list of [u, v] pairs, or None
        'n_pairs' pairs whose members are rows of X; None if n_pairs is None.
    mixin : IdentityBilinearMixin
        The mixin fitted on X.
    """
    # Draw the data from the module-level seeded RNG (the same one used for
    # pair sampling below) so the whole fixture is reproducible across runs.
    X = RNG.rand(n, d)
    mixin = IdentityBilinearMixin()
    # The labels are dummies: one zero per sample.
    mixin.fit(X, [0] * n, random=random)
    if n_pairs is None:
        return X, None, mixin
    random_pairs = [[X[RNG.randint(0, n)], X[RNG.randint(0, n)]]
                    for _ in range(n_pairs)]
    return X, random_pairs, mixin
4657
4758
def test_same_similarity_with_two_methods():
    """
    Tests that score_pairs() and get_metric() give consistent results.
    In both cases, the results must match for the same input.
    Tests it for 'n_pairs' sampled from 'n' d-dimensional arrays.
    """
    d, n, n_pairs = 100, 100, 1000
    _, random_pairs, mixin = identity_fit(d=d, n=n, n_pairs=n_pairs,
                                          random=True)
    dist1 = mixin.score_pairs(random_pairs)
    # Build the metric callable once instead of once per pair.
    metric = mixin.get_metric()
    dist2 = [metric(u, v) for u, v in random_pairs]

    assert_array_almost_equal(dist1, dist2)
6071
@@ -65,14 +76,12 @@ def test_check_correctness_similarity():
6576 get_metric(). Results are compared with the real bilinear similarity
6677 calculated in-place.
6778 """
68- d = 100
69- u , v , mixin = identity_fit (d )
70- dist1 = mixin .score_pairs ([[u , v ], [v , u ]])
71- dist2 = [mixin .get_metric ()(u , v ), mixin .get_metric ()(v , u )]
72-
73- u_v = np .dot (np .dot (u .T , np .identity (d )), v )
74- v_u = np .dot (np .dot (v .T , np .identity (d )), u )
75- desired = [u_v , v_u ]
79+ d , n , n_pairs = 100 , 100 , 1000
80+ _ , random_pairs , mixin = identity_fit (d = d , n = n , n_pairs = n_pairs , random = True )
81+ dist1 = mixin .score_pairs (random_pairs )
82+ dist2 = [mixin .get_metric ()(p [0 ], p [1 ]) for p in random_pairs ]
83+ desired = [np .dot (np .dot (p [0 ].T , mixin .components_ ), p [1 ]) for p in random_pairs ]
84+
7685 assert_array_almost_equal (dist1 , desired ) # score_pairs
7786 assert_array_almost_equal (dist2 , desired ) # get_metric
7887
@@ -98,27 +107,31 @@ def test_check_handmade_symmetric_example():
98107 between two arrays must be equal: S(u,v) = S(v,u). Also
99108 checks the random case: when the matrix is pd and symetric.
100109 """
101- u = np .array ([0 , 1 , 2 ])
102- v = np .array ([3 , 4 , 5 ])
103- mixin = IdentityBilinearMixin ()
104- mixin .fit ([u , v ], [0 , 0 ]) # Identity fit
105- dists = mixin .score_pairs ([[u , v ], [v , u ]])
106- assert_array_almost_equal (dists , [14 , 14 ])
110+ # Random pairs for M = Identity
111+ d , n , n_pairs = 100 , 100 , 1000
112+ _ , random_pairs , mixin = identity_fit (d = d , n = n , n_pairs = n_pairs )
113+ pairs_reverse = [[p [1 ], p [0 ]] for p in random_pairs ]
114+ dist1 = mixin .score_pairs (random_pairs )
115+ dist2 = mixin .score_pairs (pairs_reverse )
116+ assert_array_almost_equal (dist1 , dist2 )
107117
118+ # Random pairs for M = spd Matrix
119+ spd_matrix = make_spd_matrix (d , random_state = RNG )
120+ mixin .components_ = spd_matrix
121+ dist1 = mixin .score_pairs (random_pairs )
122+ dist2 = mixin .score_pairs (pairs_reverse )
123+ assert_array_almost_equal (dist1 , dist2 )
108124
def test_score_pairs_finite():
    """
    Checks, for 'n_pairs' pairs sampled from 'n' arrays of 'd' dimensions,
    that all similarities returned by score_pairs() are finite numbers,
    not NaN, +inf or -inf. The mixin is fitted with random=True.
    """
    d, n, n_pairs = 100, 100, 1000
    _, random_pairs, mixin = identity_fit(d=d, n=n, n_pairs=n_pairs,
                                          random=True)
    scores = mixin.score_pairs(random_pairs)
    assert np.isfinite(scores).all()
122135
123136
124137def test_score_pairs_dim ():
@@ -127,11 +140,8 @@ def test_score_pairs_dim():
127140 and scoring of 2D arrays (one tuple) should return an error (like
128141 scikit-learn's error when scoring 1D arrays)
129142 """
130- d = 100
131- u , v , mixin = identity_fit ()
132- mixin .random_M () # Dummy fit
133- n = 100
134- X = np .array ([np .random .rand (d ) for i in range (n )])
143+ d , n , n_pairs = 100 , 100 , 1000
144+ X , _ , mixin = identity_fit (d = d , n = n , n_pairs = None , random = True )
135145 tuples = np .array (list (product (X , X )))
136146 assert mixin .score_pairs (tuples ).shape == (tuples .shape [0 ],)
137147 context = make_context (mixin )
@@ -146,11 +156,7 @@ def test_score_pairs_dim():
def test_check_scikitlearn_compatibility():
    """Check that the similarity returned by get_metric() is compatible with
    scikit-learn's algorithms using a custom metric, DBSCAN for instance"""
    d, n = 100, 100
    # No pairs are needed here: DBSCAN consumes the raw points directly.
    X, _, mixin = identity_fit(d=d, n=n, n_pairs=None, random=True)
    clustering = DBSCAN(metric=mixin.get_metric())
    # The test passes as long as fitting with the custom metric does not raise.
    clustering.fit(X)
0 commit comments