1+ from itertools import product
12from metric_learn .base_metric import BilinearMixin
23import numpy as np
34from numpy .testing import assert_array_almost_equal
4-
5+ import pytest
6+ from metric_learn ._util import make_context
7+ from sklearn import clone
8+ from sklearn .cluster import DBSCAN
59
610class IdentityBilinearMixin (BilinearMixin ):
7- """A simple Identity bilinear mixin that returns an identity matrix
8- M as learned. Can change M for a random matrix calling random_M.
9- Class for testing purposes.
10- """
11- def __init__ (self , preprocessor = None ):
12- super ().__init__ (preprocessor = preprocessor )
11+ """A simple Identity bilinear mixin that returns an identity matrix
12+ M as learned. Can change M for a random matrix calling random_M.
13+ Class for testing purposes.
14+ """
15+ def __init__ (self , preprocessor = None ):
16+ super ().__init__ (preprocessor = preprocessor )
1317
14- def fit (self , X , y ):
15- X , y = self ._prepare_inputs (X , y , ensure_min_samples = 2 )
16- self .d = np .shape (X [0 ])[- 1 ]
17- self .components_ = np .identity (self .d )
18- return self
18+ def fit (self , X , y ):
19+ X , y = self ._prepare_inputs (X , y , ensure_min_samples = 2 )
20+ self .d = np .shape (X [0 ])[- 1 ]
21+ self .components_ = np .identity (self .d )
22+ return self
1923
20- def random_M (self ):
21- self .components_ = np .random .rand (self .d , self .d )
24+ def random_M (self ):
25+ self .components_ = np .random .rand (self .d , self .d )
2226
2327
2428def test_same_similarity_with_two_methods ():
25- d = 100
26- u = np .random .rand (d )
27- v = np .random .rand (d )
28- mixin = IdentityBilinearMixin ()
29- mixin .fit ([u , v ], [0 , 0 ])
30- mixin .random_M () # Dummy fit
29+ d = 100
30+ u = np .random .rand (d )
31+ v = np .random .rand (d )
32+ mixin = IdentityBilinearMixin ()
33+ mixin .fit ([u , v ], [0 , 0 ])
34+ mixin .random_M () # Dummy fit
3135
32- # The distances must match, whether calc with get_metric() or score_pairs()
33- dist1 = mixin .score_pairs ([[u , v ], [v , u ]])
34- dist2 = [mixin .get_metric ()(u , v ), mixin .get_metric ()(v , u )]
36+ # The distances must match, whether calc with get_metric() or score_pairs()
37+ dist1 = mixin .score_pairs ([[u , v ], [v , u ]])
38+ dist2 = [mixin .get_metric ()(u , v ), mixin .get_metric ()(v , u )]
3539
36- assert_array_almost_equal (dist1 , dist2 )
40+ assert_array_almost_equal (dist1 , dist2 )
3741
3842
3943def test_check_correctness_similarity ():
40- d = 100
41- u = np .random .rand (d )
42- v = np .random .rand (d )
43- mixin = IdentityBilinearMixin ()
44- mixin .fit ([u , v ], [0 , 0 ]) # Identity fit
45- dist1 = mixin .score_pairs ([[u , v ], [v , u ]])
46- u_v = np .dot (np .dot (u .T , np .identity (d )), v )
47- v_u = np .dot (np .dot (v .T , np .identity (d )), u )
48- desired = [u_v , v_u ]
49- assert_array_almost_equal (dist1 , desired )
44+ d = 100
45+ u = np .random .rand (d )
46+ v = np .random .rand (d )
47+ mixin = IdentityBilinearMixin ()
48+ mixin .fit ([u , v ], [0 , 0 ]) # Identity fit
49+ dist1 = mixin .score_pairs ([[u , v ], [v , u ]])
50+ dist2 = [mixin .get_metric ()(u , v ), mixin .get_metric ()(v , u )]
5051
52+ u_v = np .dot (np .dot (u .T , np .identity (d )), v )
53+ v_u = np .dot (np .dot (v .T , np .identity (d )), u )
54+ desired = [u_v , v_u ]
55+ assert_array_almost_equal (dist1 , desired ) # score_pairs
56+ assert_array_almost_equal (dist2 , desired ) # get_metric
5157
5258def test_check_handmade_example ():
53- u = np .array ([0 , 1 , 2 ])
54- v = np .array ([3 , 4 , 5 ])
55- mixin = IdentityBilinearMixin ()
56- mixin .fit ([u , v ], [0 , 0 ]) # Identity fit
57- c = np .array ([[2 , 4 , 6 ], [6 , 4 , 2 ], [1 , 2 , 3 ]])
58- mixin .components_ = c # Force components_
59- dists = mixin .score_pairs ([[u , v ], [v , u ]])
60- assert_array_almost_equal (dists , [96 , 120 ])
59+ u = np .array ([0 , 1 , 2 ])
60+ v = np .array ([3 , 4 , 5 ])
61+ mixin = IdentityBilinearMixin ()
62+ mixin .fit ([u , v ], [0 , 0 ]) # Identity fit
63+ c = np .array ([[2 , 4 , 6 ], [6 , 4 , 2 ], [1 , 2 , 3 ]])
64+ mixin .components_ = c # Force components_
65+ dists = mixin .score_pairs ([[u , v ], [v , u ]])
66+ assert_array_almost_equal (dists , [96 , 120 ])
6167
6268
6369def test_check_handmade_symmetric_example ():
64- u = np .array ([0 , 1 , 2 ])
65- v = np .array ([3 , 4 , 5 ])
66- mixin = IdentityBilinearMixin ()
67- mixin .fit ([u , v ], [0 , 0 ]) # Identity fit
68- dists = mixin .score_pairs ([[u , v ], [v , u ]])
69- assert_array_almost_equal (dists , [14 , 14 ])
70+ u = np .array ([0 , 1 , 2 ])
71+ v = np .array ([3 , 4 , 5 ])
72+ mixin = IdentityBilinearMixin ()
73+ mixin .fit ([u , v ], [0 , 0 ]) # Identity fit
74+ dists = mixin .score_pairs ([[u , v ], [v , u ]])
75+ assert_array_almost_equal (dists , [14 , 14 ])
76+
77+
78+ def test_score_pairs_finite ():
79+ d = 100
80+ u = np .random .rand (d )
81+ v = np .random .rand (d )
82+ mixin = IdentityBilinearMixin ()
83+ mixin .fit ([u , v ], [0 , 0 ])
84+ mixin .random_M () # Dummy fit
85+ n = 100
86+ X = np .array ([np .random .rand (d ) for i in range (n )])
87+ pairs = np .array (list (product (X , X )))
88+ assert np .isfinite (mixin .score_pairs (pairs )).all ()
89+
90+
91+ def test_score_pairs_dim ():
92+ # scoring of 3D arrays should return 1D array (several tuples),
93+ # and scoring of 2D arrays (one tuple) should return an error (like
94+ # scikit-learn's error when scoring 1D arrays)
95+ d = 100
96+ u = np .random .rand (d )
97+ v = np .random .rand (d )
98+ mixin = IdentityBilinearMixin ()
99+ mixin .fit ([u , v ], [0 , 0 ])
100+ mixin .random_M () # Dummy fit
101+ n = 100
102+ X = np .array ([np .random .rand (d ) for i in range (n )])
103+ tuples = np .array (list (product (X , X )))
104+ assert mixin .score_pairs (tuples ).shape == (tuples .shape [0 ],)
105+ context = make_context (mixin )
106+ msg = ("3D array of formed tuples expected{}. Found 2D array "
107+ "instead:\n input={}. Reshape your data and/or use a preprocessor.\n "
108+ .format (context , tuples [1 ]))
109+ with pytest .raises (ValueError ) as raised_error :
110+ mixin .score_pairs (tuples [1 ])
111+ assert str (raised_error .value ) == msg
112+
113+
114+ def test_check_scikitlearn_compatibility ():
115+ d = 100
116+ u = np .random .rand (d )
117+ v = np .random .rand (d )
118+ mixin = IdentityBilinearMixin ()
119+ mixin .fit ([u , v ], [0 , 0 ])
120+ mixin .random_M () # Dummy fit
121+
122+ n = 100
123+ X = np .array ([np .random .rand (d ) for i in range (n )])
124+ clustering = DBSCAN (metric = mixin .get_metric ())
125+ clustering .fit (X )
0 commit comments