22import numpy as np
33from numpy .testing import assert_array_almost_equal
44
5+
56class IdentityBilinearMixin (BilinearMixin ):
67 """A simple Identity bilinear mixin that returns an identity matrix
78 M as learned. Can change M for a random matrix calling random_M.
@@ -15,50 +16,54 @@ def fit(self, X, y):
1516 self .d = np .shape (X [0 ])[- 1 ]
1617 self .components_ = np .identity (self .d )
1718 return self
18-
19+
1920 def random_M (self ):
2021 self .components_ = np .random .rand (self .d , self .d )
2122
23+
2224def test_same_similarity_with_two_methods ():
2325 d = 100
2426 u = np .random .rand (d )
2527 v = np .random .rand (d )
2628 mixin = IdentityBilinearMixin ()
27- mixin .fit ([u , v ], [0 , 0 ]) # Dummy fit
28- mixin .random_M ()
29+ mixin .fit ([u , v ], [0 , 0 ])
30+ mixin .random_M () # Dummy fit
2931
3032 # The distances must match, whether calc with get_metric() or score_pairs()
3133 dist1 = mixin .score_pairs ([[u , v ], [v , u ]])
3234 dist2 = [mixin .get_metric ()(u , v ), mixin .get_metric ()(v , u )]
3335
3436 assert_array_almost_equal (dist1 , dist2 )
3537
38+
3639def test_check_correctness_similarity ():
3740 d = 100
3841 u = np .random .rand (d )
3942 v = np .random .rand (d )
4043 mixin = IdentityBilinearMixin ()
41- mixin .fit ([u , v ], [0 , 0 ]) # Dummy fit
44+ mixin .fit ([u , v ], [0 , 0 ]) # Identity fit
4245 dist1 = mixin .score_pairs ([[u , v ], [v , u ]])
4346 u_v = np .dot (np .dot (u .T , np .identity (d )), v )
4447 v_u = np .dot (np .dot (v .T , np .identity (d )), u )
4548 desired = [u_v , v_u ]
4649 assert_array_almost_equal (dist1 , desired )
4750
51+
4852def test_check_handmade_example ():
4953 u = np .array ([0 , 1 , 2 ])
5054 v = np .array ([3 , 4 , 5 ])
5155 mixin = IdentityBilinearMixin ()
52- mixin .fit ([u , v ], [0 , 0 ])
56+ mixin .fit ([u , v ], [0 , 0 ]) # Identity fit
5357 c = np .array ([[2 , 4 , 6 ], [6 , 4 , 2 ], [1 , 2 , 3 ]])
54- mixin .components_ = c # Force a components_
58+ mixin .components_ = c # Force components_
5559 dists = mixin .score_pairs ([[u , v ], [v , u ]])
5660 assert_array_almost_equal (dists , [96 , 120 ])
5761
62+
5863def test_check_handmade_symmetric_example ():
5964 u = np .array ([0 , 1 , 2 ])
6065 v = np .array ([3 , 4 , 5 ])
6166 mixin = IdentityBilinearMixin ()
62- mixin .fit ([u , v ], [0 , 0 ])
67+ mixin .fit ([u , v ], [0 , 0 ]) # Identity fit
6368 dists = mixin .score_pairs ([[u , v ], [v , u ]])
64- assert_array_almost_equal (dists , [14 , 14 ])
69+ assert_array_almost_equal (dists , [14 , 14 ])
0 commit comments