2020import  org .elasticsearch .simdvec .ES91Int4VectorsScorer ;
2121import  org .elasticsearch .simdvec .ES91OSQVectorsScorer ;
2222
23- import  static  org .hamcrest .Matchers .lessThan ;
23+ import  java .io .IOException ;
24+ 
25+ import  static  org .hamcrest .Matchers .greaterThan ;
2426
2527public  class  ES91Int4VectorScorerTests  extends  BaseVectorizationTests  {
2628
@@ -130,31 +132,59 @@ public void testInt4ScoreBulk() throws Exception {
130132        // only even dimensions are supported 
131133        final  int  dimensions  = random ().nextInt (1 , 1000 ) * 2 ;
132134        final  int  numVectors  = random ().nextInt (1 , 10 ) * ES91Int4VectorsScorer .BULK_SIZE ;
133-         final  byte [] vector  = new  byte [ES91Int4VectorsScorer .BULK_SIZE  * dimensions ];
134-         final  byte [] corrections  = new  byte [ES91Int4VectorsScorer .BULK_SIZE  * 14 ];
135+         final  float [][] vectors  = new  float [numVectors ][dimensions ];
136+         final  int [] quantizedScratch  = new  int [dimensions ];
137+         final  byte [] quantizeVector  = new  byte [dimensions ];
138+         final  float [] centroid  = new  float [dimensions ];
139+         VectorSimilarityFunction  similarityFunction  = randomFrom (VectorSimilarityFunction .values ());
140+         for  (int  i  = 0 ; i  < dimensions ; i ++) {
141+             centroid [i ] = random ().nextFloat ();
142+         }
143+         if  (similarityFunction  != VectorSimilarityFunction .EUCLIDEAN ) {
144+             VectorUtil .l2normalize (centroid );
145+         }
146+ 
147+         OptimizedScalarQuantizer  quantizer  = new  OptimizedScalarQuantizer (similarityFunction );
135148        try  (Directory  dir  = new  MMapDirectory (createTempDir ())) {
136149            try  (IndexOutput  out  = dir .createOutput ("tests.bin" , IOContext .DEFAULT )) {
150+                 OptimizedScalarQuantizer .QuantizationResult [] results  =
151+                     new  OptimizedScalarQuantizer .QuantizationResult [ES91Int4VectorsScorer .BULK_SIZE ];
137152                for  (int  i  = 0 ; i  < numVectors ; i  += ES91Int4VectorsScorer .BULK_SIZE ) {
138-                     for  (int  j  = 0 ; j  < ES91Int4VectorsScorer .BULK_SIZE  * dimensions ; j ++) {
139-                         vector [j ] = (byte ) random ().nextInt (16 ); // 4-bit quantization 
153+                     for  (int  j  = 0 ; j  < ES91Int4VectorsScorer .BULK_SIZE ; j ++) {
154+                         for  (int  k  = 0 ; k  < dimensions ; k ++) {
155+                             vectors [i  + j ][k ] = random ().nextFloat ();
156+                         }
157+                         if  (similarityFunction  != VectorSimilarityFunction .EUCLIDEAN ) {
158+                             VectorUtil .l2normalize (vectors [i  + j ]);
159+                         }
160+                         results [j ] = quantizer .scalarQuantize (vectors [i  + j ].clone (), quantizedScratch , (byte ) 4 , centroid );
161+                         for  (int  k  = 0 ; k  < dimensions ; k ++) {
162+                             quantizeVector [k ] = (byte ) quantizedScratch [k ];
163+                         }
164+                         out .writeBytes (quantizeVector , 0 , dimensions );
140165                    }
141-                     out .writeBytes (vector , 0 , vector .length );
142-                     random ().nextBytes (corrections );
143-                     out .writeBytes (corrections , 0 , corrections .length );
166+                     writeCorrections (results , out );
144167                }
145168            }
146-             final  byte [] query  = new  byte [dimensions ];
169+             final  float [] query  = new  float [dimensions ];
170+             final  byte [] quantizeQuery  = new  byte [dimensions ];
147171            for  (int  j  = 0 ; j  < dimensions ; j ++) {
148-                 query [j ] = ( byte )  random ().nextInt ( 16 );  // 4-bit quantization 
172+                 query [j ] = random ().nextFloat (); 
149173            }
150-             OptimizedScalarQuantizer .QuantizationResult  queryCorrections  = new  OptimizedScalarQuantizer .QuantizationResult (
151-                 random ().nextFloat (),
152-                 random ().nextFloat (),
153-                 random ().nextFloat (),
154-                 Short .toUnsignedInt ((short ) random ().nextInt ())
174+             if  (similarityFunction  != VectorSimilarityFunction .EUCLIDEAN ) {
175+                 VectorUtil .l2normalize (query );
176+             }
177+             OptimizedScalarQuantizer .QuantizationResult  queryCorrections  = quantizer .scalarQuantize (
178+                 query .clone (),
179+                 quantizedScratch ,
180+                 (byte ) 4 ,
181+                 centroid 
155182            );
156-             float  centroidDp  = random ().nextFloat ();
157-             VectorSimilarityFunction  similarityFunction  = randomFrom (VectorSimilarityFunction .values ());
183+             for  (int  j  = 0 ; j  < dimensions ; j ++) {
184+                 quantizeQuery [j ] = (byte ) quantizedScratch [j ];
185+             }
186+             float  centroidDp  = VectorUtil .dotProduct (centroid , centroid );
187+ 
158188            try  (IndexInput  in  = dir .openInput ("tests.bin" , IOContext .DEFAULT )) {
159189                // Work on a slice that has just the right number of bytes to make the test fail with an 
160190                // index-out-of-bounds in case the implementation reads more than the allowed number of 
@@ -166,7 +196,7 @@ public void testInt4ScoreBulk() throws Exception {
166196                float [] scoresPanama  = new  float [ES91Int4VectorsScorer .BULK_SIZE ];
167197                for  (int  i  = 0 ; i  < numVectors ; i  += ES91Int4VectorsScorer .BULK_SIZE ) {
168198                    defaultScorer .scoreBulk (
169-                         query ,
199+                         quantizeQuery ,
170200                        queryCorrections .lowerInterval (),
171201                        queryCorrections .upperInterval (),
172202                        queryCorrections .quantizedComponentSum (),
@@ -176,7 +206,7 @@ public void testInt4ScoreBulk() throws Exception {
176206                        scoresDefault 
177207                    );
178208                    panamaScorer .scoreBulk (
179-                         query ,
209+                         quantizeQuery ,
180210                        queryCorrections .lowerInterval (),
181211                        queryCorrections .upperInterval (),
182212                        queryCorrections .quantizedComponentSum (),
@@ -186,29 +216,34 @@ public void testInt4ScoreBulk() throws Exception {
186216                        scoresPanama 
187217                    );
188218                    for  (int  j  = 0 ; j  < ES91OSQVectorsScorer .BULK_SIZE ; j ++) {
189-                         if  (scoresDefault [j ] == scoresPanama [j ]) {
190-                             continue ;
191-                         }
192-                         if  (scoresDefault [j ] > (1000  * Byte .MAX_VALUE )) {
193-                             float  diff  = Math .abs (scoresDefault [j ] - scoresPanama [j ]);
194-                             assertThat (
195-                                 "defaultScores: "  + scoresDefault [j ] + " bulkScores: "  + scoresPanama [j ],
196-                                 diff  / scoresDefault [j ],
197-                                 lessThan (1e-5f )
198-                             );
199-                             assertThat (
200-                                 "defaultScores: "  + scoresDefault [j ] + " bulkScores: "  + scoresPanama [j ],
201-                                 diff  / scoresPanama [j ],
202-                                 lessThan (1e-5f )
203-                             );
204-                         } else  {
205-                             assertEquals (scoresDefault [j ], scoresPanama [j ], 1e-2f );
206-                         }
219+                         assertEquals (scoresDefault [j ], scoresPanama [j ], 1e-2f );
220+                         float  realSimilarity  = similarityFunction .compare (vectors [i  + j ], query );
221+                         float  accuracy  = realSimilarity  > scoresDefault [j ]
222+                             ? scoresDefault [j ] / realSimilarity 
223+                             : realSimilarity  / scoresDefault [j ];
224+                         assertThat (accuracy , greaterThan (0.90f ));
207225                    }
208226                    assertEquals (in .getFilePointer (), slice .getFilePointer ());
209227                }
210228                assertEquals ((long ) (dimensions  + 14 ) * numVectors , in .getFilePointer ());
211229            }
212230        }
213231    }
232+ 
233+     private  static  void  writeCorrections (OptimizedScalarQuantizer .QuantizationResult [] corrections , IndexOutput  out ) throws  IOException  {
234+         for  (OptimizedScalarQuantizer .QuantizationResult  correction  : corrections ) {
235+             out .writeInt (Float .floatToIntBits (correction .lowerInterval ()));
236+         }
237+         for  (OptimizedScalarQuantizer .QuantizationResult  correction  : corrections ) {
238+             out .writeInt (Float .floatToIntBits (correction .upperInterval ()));
239+         }
240+         for  (OptimizedScalarQuantizer .QuantizationResult  correction  : corrections ) {
241+             int  targetComponentSum  = correction .quantizedComponentSum ();
242+             assert  targetComponentSum  >= 0  && targetComponentSum  <= 0xffff ;
243+             out .writeShort ((short ) targetComponentSum );
244+         }
245+         for  (OptimizedScalarQuantizer .QuantizationResult  correction  : corrections ) {
246+             out .writeInt (Float .floatToIntBits (correction .additionalCorrection ()));
247+         }
248+     }
214249}
0 commit comments