4040import org .apache .lucene .store .Directory ;
4141import org .apache .lucene .store .FSDirectory ;
4242import org .apache .lucene .store .MMapDirectory ;
43+ import org .elasticsearch .common .io .Channels ;
4344import org .elasticsearch .core .PathUtils ;
4445import org .elasticsearch .index .mapper .vectors .DenseVectorFieldMapper ;
4546import org .elasticsearch .search .profile .query .QueryProfiler ;
@@ -87,7 +88,7 @@ class KnnSearcher {
8788 private final int efSearch ;
8889 private final int nProbe ;
8990 private final KnnIndexTester .IndexType indexType ;
90- private final int dim ;
91+ private int dim ;
9192 private final VectorSimilarityFunction similarityFunction ;
9293 private final VectorEncoding vectorEncoding ;
9394 private final float overSamplingFactor ;
@@ -117,6 +118,7 @@ void runSearch(KnnIndexTester.Results finalResults) throws IOException {
117118 TopDocs [] results = new TopDocs [numQueryVectors ];
118119 int [][] resultIds = new int [numQueryVectors ][];
119120 long elapsed , totalCpuTimeMS , totalVisited = 0 ;
121+ int offsetByteSize = 0 ;
120122 try (
121123 FileChannel input = FileChannel .open (queryPath );
122124 ExecutorService executorService = Executors .newFixedThreadPool (searchThreads , r -> new Thread (r , "KnnSearcher-Thread" ))
@@ -128,7 +130,19 @@ void runSearch(KnnIndexTester.Results finalResults) throws IOException {
128130 + " bytes, assuming vector count is "
129131 + (queryPathSizeInBytes / ((long ) dim * vectorEncoding .byteSize ))
130132 );
131- KnnIndexer .VectorReader targetReader = KnnIndexer .VectorReader .create (input , dim , vectorEncoding );
133+ if (dim == -1 ) {
134+ offsetByteSize = 4 ;
135+ ByteBuffer preamble = ByteBuffer .allocate (4 ).order (ByteOrder .LITTLE_ENDIAN );
136+ int bytesRead = Channels .readFromFileChannel (input , 0 , preamble );
137+ if (bytesRead < 4 ) {
138+ throw new IllegalArgumentException ("queryPath \" " + queryPath + "\" does not contain a valid dims?" );
139+ }
140+ dim = preamble .getInt (0 );
141+ if (dim <= 0 ) {
142+ throw new IllegalArgumentException ("queryPath \" " + queryPath + "\" has invalid dimension: " + dim );
143+ }
144+ }
145+ KnnIndexer .VectorReader targetReader = KnnIndexer .VectorReader .create (input , dim , vectorEncoding , offsetByteSize );
132146 long startNS ;
133147 try (MMapDirectory dir = new MMapDirectory (indexPath )) {
134148 try (DirectoryReader reader = DirectoryReader .open (dir )) {
@@ -191,7 +205,7 @@ void runSearch(KnnIndexTester.Results finalResults) throws IOException {
191205 }
192206 }
193207 logger .info ("checking results" );
194- int [][] nn = getOrCalculateExactNN ();
208+ int [][] nn = getOrCalculateExactNN (offsetByteSize );
195209 finalResults .avgRecall = checkResults (resultIds , nn , topK );
196210 finalResults .qps = (1000f * numQueryVectors ) / elapsed ;
197211 finalResults .avgLatency = (float ) elapsed / numQueryVectors ;
@@ -200,7 +214,7 @@ void runSearch(KnnIndexTester.Results finalResults) throws IOException {
200214 finalResults .avgCpuCount = (double ) totalCpuTimeMS / elapsed ;
201215 }
202216
203- private int [][] getOrCalculateExactNN () throws IOException {
217+ private int [][] getOrCalculateExactNN (int vectorFileOffsetBytes ) throws IOException {
204218 // look in working directory for cached nn file
205219 String hash = Integer .toString (
206220 Objects .hash (
@@ -228,9 +242,9 @@ private int[][] getOrCalculateExactNN() throws IOException {
228242 // checking low-precision recall
229243 int [][] nn ;
230244 if (vectorEncoding .equals (VectorEncoding .BYTE )) {
231- nn = computeExactNNByte (queryPath );
245+ nn = computeExactNNByte (queryPath , vectorFileOffsetBytes );
232246 } else {
233- nn = computeExactNN (queryPath );
247+ nn = computeExactNN (queryPath , vectorFileOffsetBytes );
234248 }
235249 writeExactNN (nn , nnPath );
236250 long elapsedMS = TimeUnit .NANOSECONDS .toMillis (System .nanoTime () - startNS ); // ns -> ms
@@ -356,12 +370,17 @@ private void writeExactNN(int[][] nn, Path nnPath) throws IOException {
356370 }
357371 }
358372
359- private int [][] computeExactNN (Path queryPath ) throws IOException {
373+ private int [][] computeExactNN (Path queryPath , int vectorFileOffsetBytes ) throws IOException {
360374 int [][] result = new int [numQueryVectors ][];
361375 try (Directory dir = FSDirectory .open (indexPath ); DirectoryReader reader = DirectoryReader .open (dir )) {
362376 List <Callable <Void >> tasks = new ArrayList <>();
363377 try (FileChannel qIn = FileChannel .open (queryPath )) {
364- KnnIndexer .VectorReader queryReader = KnnIndexer .VectorReader .create (qIn , dim , VectorEncoding .FLOAT32 );
378+ KnnIndexer .VectorReader queryReader = KnnIndexer .VectorReader .create (
379+ qIn ,
380+ dim ,
381+ VectorEncoding .FLOAT32 ,
382+ vectorFileOffsetBytes
383+ );
365384 for (int i = 0 ; i < numQueryVectors ; i ++) {
366385 float [] queryVector = new float [dim ];
367386 queryReader .next (queryVector );
@@ -373,12 +392,12 @@ private int[][] computeExactNN(Path queryPath) throws IOException {
373392 }
374393 }
375394
376- private int [][] computeExactNNByte (Path queryPath ) throws IOException {
395+ private int [][] computeExactNNByte (Path queryPath , int vectorFileOffsetBytes ) throws IOException {
377396 int [][] result = new int [numQueryVectors ][];
378397 try (Directory dir = FSDirectory .open (indexPath ); DirectoryReader reader = DirectoryReader .open (dir )) {
379398 List <Callable <Void >> tasks = new ArrayList <>();
380399 try (FileChannel qIn = FileChannel .open (queryPath )) {
381- KnnIndexer .VectorReader queryReader = KnnIndexer .VectorReader .create (qIn , dim , VectorEncoding .BYTE );
400+ KnnIndexer .VectorReader queryReader = KnnIndexer .VectorReader .create (qIn , dim , VectorEncoding .BYTE , vectorFileOffsetBytes );
382401 for (int i = 0 ; i < numQueryVectors ; i ++) {
383402 byte [] queryVector = new byte [dim ];
384403 queryReader .next (queryVector );
0 commit comments