1+ from metric_learn .base_metric import BilinearMixin
2+ import numpy as np
3+ from numpy .testing import assert_array_almost_equal
4+
5+ class IdentityBilinearMixin (BilinearMixin ):
6+ """A simple Identity bilinear mixin that returns an identity matrix
7+ M as learned. Can change M for a random matrix calling random_M.
8+ Class for testing purposes.
9+ """
10+ def __init__ (self , preprocessor = None ):
11+ super ().__init__ (preprocessor = preprocessor )
12+
13+ def fit (self , X , y ):
14+ X , y = self ._prepare_inputs (X , y , ensure_min_samples = 2 )
15+ self .d = np .shape (X [0 ])[- 1 ]
16+ self .components_ = np .identity (self .d )
17+ return self
18+
19+ def random_M (self ):
20+ self .components_ = np .random .rand (self .d , self .d )
21+
22+ def test_same_similarity_with_two_methods ():
23+ d = 100
24+ u = np .random .rand (d )
25+ v = np .random .rand (d )
26+ mixin = IdentityBilinearMixin ()
27+ mixin .fit ([u , v ], [0 , 0 ]) # Dummy fit
28+ mixin .random_M ()
29+
30+ # The distances must match, whether calc with get_metric() or score_pairs()
31+ dist1 = mixin .score_pairs ([[u , v ], [v , u ]])
32+ dist2 = [mixin .get_metric ()(u , v ), mixin .get_metric ()(v , u )]
33+
34+ assert_array_almost_equal (dist1 , dist2 )
35+
36+ def test_check_correctness_similarity ():
37+ d = 100
38+ u = np .random .rand (d )
39+ v = np .random .rand (d )
40+ mixin = IdentityBilinearMixin ()
41+ mixin .fit ([u , v ], [0 , 0 ]) # Dummy fit
42+ dist1 = mixin .score_pairs ([[u , v ], [v , u ]])
43+ u_v = np .dot (np .dot (u .T , np .identity (d )), v )
44+ v_u = np .dot (np .dot (v .T , np .identity (d )), u )
45+ desired = [u_v , v_u ]
46+ assert_array_almost_equal (dist1 , desired )
47+
48+ def test_check_handmade_example ():
49+ u = np .array ([0 , 1 , 2 ])
50+ v = np .array ([3 , 4 , 5 ])
51+ mixin = IdentityBilinearMixin ()
52+ mixin .fit ([u , v ], [0 , 0 ])
53+ c = np .array ([[2 , 4 , 6 ], [6 , 4 , 2 ], [1 , 2 , 3 ]])
54+ mixin .components_ = c # Force a components_
55+ dists = mixin .score_pairs ([[u , v ], [v , u ]])
56+ assert_array_almost_equal (dists , [96 , 120 ])
57+
58+ def test_check_handmade_symmetric_example ():
59+ u = np .array ([0 , 1 , 2 ])
60+ v = np .array ([3 , 4 , 5 ])
61+ mixin = IdentityBilinearMixin ()
62+ mixin .fit ([u , v ], [0 , 0 ])
63+ dists = mixin .score_pairs ([[u , v ], [v , u ]])
64+ assert_array_almost_equal (dists , [14 , 14 ])
0 commit comments