@@ -21,6 +21,132 @@ public static void setUp() {
2121 Library .setKhivaBackend (Library .Backend .KHIVA_BACKEND_CPU );
2222 }
2323
24+ private double getSingleValueDouble (Array arr , long dim0 , long dim1 , long dim2 , long dim3 ) {
25+ double [] data = arr .getData ();
26+
27+ long [] dims4 = arr .getDims ();
28+ long offset = (dims4 [0 ] * dims4 [1 ] * dims4 [2 ]) * dim3 ;
29+ offset += (dims4 [0 ] * dims4 [1 ]) * dim2 ;
30+ offset += dims4 [0 ] * dim1 ;
31+ offset += dim0 ;
32+
33+ return data [(int ) offset ];
34+ }
35+
36+ private int getSingleValueInt (Array arr , long dim0 , long dim1 , long dim2 , long dim3 ) {
37+ int [] data = arr .getData ();
38+
39+ long [] dims4 = arr .getDims ();
40+ long offset = (dims4 [0 ] * dims4 [1 ] * dims4 [2 ]) * dim3 ;
41+ offset += (dims4 [0 ] * dims4 [1 ]) * dim2 ;
42+ offset += dims4 [0 ] * dim1 ;
43+ offset += dim0 ;
44+
45+ return data [(int ) offset ];
46+ }
47+
48+ @ Test
49+ public void testMass () throws Exception {
50+ double [] tss = {10 , 10 , 10 , 11 , 12 , 11 , 10 , 10 , 11 , 12 , 11 , 14 , 10 , 10 };
51+ long [] dimsTss = {14 , 1 , 1 , 1 };
52+
53+ double [] query = {4 , 3 , 8 };
54+ long [] dimsQuery = {3 , 1 , 1 , 1 };
55+
56+ try (
57+ Array t = new Array (tss , dimsTss );
58+ Array q = new Array (query , dimsQuery )
59+ ) {
60+
61+ double [] expectedDistance = {1.732051 , 0.328954 , 1.210135 , 3.150851 , 3.245858 , 2.822044 ,
62+ 0.328954 , 1.210135 , 3.150851 , 0.248097 , 3.30187 , 2.82205 };
63+ Array result = Matrix .mass (q , t );
64+ double [] distances = result .getData ();
65+
66+ Assert .assertArrayEquals (expectedDistance , distances , 1e-3 );
67+
68+ result .close ();
69+ }
70+
71+ }
72+
73+ @ Test
74+ public void testMassMultiple () throws Exception {
75+ double [] tss = {10 , 10 , 10 , 11 , 12 , 11 , 10 , 10 , 11 , 12 , 11 , 14 , 10 , 10 };
76+ long [] dimsTss = {7 , 2 , 1 , 1 };
77+
78+ double [] query = {10 , 10 , 11 , 11 , 10 , 11 , 10 , 10 };
79+ long [] dimsQuery = {4 , 2 , 1 , 1 };
80+
81+ try (
82+ Array t = new Array (tss , dimsTss );
83+ Array q = new Array (query , dimsQuery )
84+ ) {
85+
86+ double [] expectedDistance = {1.8388 , 0.8739 , 1.5307 , 3.6955 , 3.2660 , 3.4897 , 2.8284 , 1.2116 , 1.5307 ,
87+ 2.1758 , 2.5783 , 3.7550 , 2.8284 , 2.8284 , 3.2159 , 0.5020 };
88+ Array result = Matrix .mass (q , t );
89+ double [] distances = result .getData ();
90+
91+ Assert .assertArrayEquals (expectedDistance , distances , 1e-3 );
92+
93+ result .close ();
94+ }
95+
96+ }
97+
98+ @ Test
99+ public void testFindBestNOccurrences () throws Exception {
100+ double [] tss = {10 , 10 , 11 , 11 , 12 , 11 , 10 , 10 , 11 , 12 , 11 , 10 , 10 , 11 , 10 , 10 , 11 ,
101+ 11 , 12 , 11 , 10 , 10 , 11 , 12 , 11 , 10 , 10 , 11 };
102+ long [] dimsTss = {28 , 1 , 1 , 1 };
103+
104+ double [] query = {10 , 11 , 12 };
105+ long [] dimsQuery = {3 , 1 , 1 , 1 };
106+
107+ try (
108+ Array t = new Array (tss , dimsTss );
109+ Array q = new Array (query , dimsQuery )
110+ ) {
111+ Array [] result = Matrix .findBestNOccurrences (q , t , 1 );
112+ double [] distances = result [0 ].getData ();
113+ int [] indexes = result [1 ].getData ();
114+
115+ Assert .assertEquals (distances [0 ], 0 , DELTA );
116+ Assert .assertEquals (indexes [0 ], 7 );
117+
118+ result [0 ].close ();
119+ result [1 ].close ();
120+ }
121+
122+ }
123+
124+ @ Test
125+ public void testFindBestNOccurrencesMultipleQueries () throws Exception {
126+ double [] tss = {10 , 10 , 11 , 11 , 10 , 11 , 10 , 10 , 11 , 11 , 10 , 11 , 10 , 10 ,
127+ 11 , 10 , 10 , 11 , 10 , 11 , 11 , 10 , 11 , 11 , 14 , 10 , 11 , 10 };
128+ long [] dimsTss = {14 , 2 , 1 , 1 };
129+
130+ double [] query = {11 , 11 , 10 , 11 , 10 , 11 , 11 , 12 };
131+ long [] dimsQuery = {4 , 2 , 1 , 1 };
132+
133+ try (
134+ Array t = new Array (tss , dimsTss );
135+ Array q = new Array (query , dimsQuery )
136+ ) {
137+ Array [] result = Matrix .findBestNOccurrences (q , t , 4 );
138+
139+ double distance = getSingleValueDouble (result [0 ], 2 , 0 , 1 , 0 );
140+ Assert .assertEquals (distance , 1.83880 , 1e-3 );
141+
142+ int index = getSingleValueInt (result [1 ], 3 , 1 , 0 , 0 );
143+ Assert .assertEquals (index , 2 );
144+
145+ result [0 ].close ();
146+ result [1 ].close ();
147+ }
148+
149+ }
24150
25151 @ Test
26152 public void testStompSelfJoin () throws Exception {
0 commit comments