Skip to content

Commit 0b66b8b

Browse files
benwtrent authored and smalyshev committed
Allow reading vectors where dim is in the file (elastic#130138)
This allows the configuration to specify a `dim` of `-1` in order to read vector files that carry the `dim` in the file itself. Additionally, it allows setting `numQueries` to `0` to easily skip the search phase.
1 parent e072a49 commit 0b66b8b

File tree

4 files changed

+66
-27
lines changed

4 files changed

+66
-27
lines changed

qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -262,8 +262,10 @@ public CmdLineArgs build() {
262262
if (docVectors == null) {
263263
throw new IllegalArgumentException("Document vectors path must be provided");
264264
}
265-
if (dimensions <= 0) {
266-
throw new IllegalArgumentException("dimensions must be a positive integer");
265+
if (dimensions <= 0 && dimensions != -1) {
266+
throw new IllegalArgumentException(
267+
"dimensions must be a positive integer or -1 for when dimension is available in the vector file"
268+
);
267269
}
268270
return new CmdLineArgs(
269271
docVectors,

qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -200,7 +200,7 @@ public static void main(String[] args) throws Exception {
200200
knnIndexer.numSegments(result);
201201
}
202202
}
203-
if (cmdLineArgs.queryVectors() != null) {
203+
if (cmdLineArgs.queryVectors() != null && cmdLineArgs.numQueries() > 0) {
204204
KnnSearcher knnSearcher = new KnnSearcher(indexPath, cmdLineArgs);
205205
knnSearcher.runSearch(result);
206206
}

qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexer.java

Lines changed: 32 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -64,7 +64,7 @@ class KnnIndexer {
6464
private final Path docsPath;
6565
private final Path indexPath;
6666
private final VectorEncoding vectorEncoding;
67-
private final int dim;
67+
private int dim;
6868
private final VectorSimilarityFunction similarityFunction;
6969
private final Codec codec;
7070
private final int numDocs;
@@ -106,10 +106,6 @@ void createIndex(KnnIndexTester.Results result) throws IOException, InterruptedE
106106

107107
iwc.setMaxFullFlushMergeWaitMillis(0);
108108

109-
FieldType fieldType = switch (vectorEncoding) {
110-
case BYTE -> KnnByteVectorField.createFieldType(dim, similarityFunction);
111-
case FLOAT32 -> KnnFloatVectorField.createFieldType(dim, similarityFunction);
112-
};
113109
iwc.setInfoStream(new PrintStreamInfoStream(System.out) {
114110
@Override
115111
public boolean isEnabled(String component) {
@@ -137,7 +133,26 @@ public boolean isEnabled(String component) {
137133
FileChannel in = FileChannel.open(docsPath)
138134
) {
139135
long docsPathSizeInBytes = in.size();
140-
if (docsPathSizeInBytes % ((long) dim * vectorEncoding.byteSize) != 0) {
136+
int offsetByteSize = 0;
137+
if (dim == -1) {
138+
offsetByteSize = 4;
139+
ByteBuffer preamble = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN);
140+
int bytesRead = Channels.readFromFileChannel(in, 0, preamble);
141+
if (bytesRead < 4) {
142+
throw new IllegalArgumentException(
143+
"docsPath \"" + docsPath + "\" does not contain a valid dims? size=" + docsPathSizeInBytes
144+
);
145+
}
146+
dim = preamble.getInt(0);
147+
if (dim <= 0) {
148+
throw new IllegalArgumentException("docsPath \"" + docsPath + "\" has invalid dimension: " + dim);
149+
}
150+
}
151+
FieldType fieldType = switch (vectorEncoding) {
152+
case BYTE -> KnnByteVectorField.createFieldType(dim, similarityFunction);
153+
case FLOAT32 -> KnnFloatVectorField.createFieldType(dim, similarityFunction);
154+
};
155+
if (docsPathSizeInBytes % (((long) dim * vectorEncoding.byteSize + offsetByteSize)) != 0) {
141156
throw new IllegalArgumentException(
142157
"docsPath \"" + docsPath + "\" does not contain a whole number of vectors? size=" + docsPathSizeInBytes
143158
);
@@ -150,7 +165,7 @@ public boolean isEnabled(String component) {
150165
vectorEncoding.byteSize
151166
);
152167

153-
VectorReader inReader = VectorReader.create(in, dim, vectorEncoding);
168+
VectorReader inReader = VectorReader.create(in, dim, vectorEncoding, offsetByteSize);
154169
try (ExecutorService exec = Executors.newFixedThreadPool(numIndexThreads, r -> new Thread(r, "KnnIndexer-Thread"))) {
155170
AtomicInteger numDocsIndexed = new AtomicInteger();
156171
List<Future<?>> threads = new ArrayList<>();
@@ -271,36 +286,39 @@ private void _run() throws IOException {
271286

272287
static class VectorReader {
273288
final float[] target;
289+
final int offsetByteSize;
274290
final ByteBuffer bytes;
275291
final FileChannel input;
276292
long position;
277293

278-
static VectorReader create(FileChannel input, int dim, VectorEncoding vectorEncoding) throws IOException {
294+
static VectorReader create(FileChannel input, int dim, VectorEncoding vectorEncoding, int offsetByteSize) throws IOException {
295+
// check if dim is set as preamble in the file:
279296
int bufferSize = dim * vectorEncoding.byteSize;
280-
if (input.size() % ((long) dim * vectorEncoding.byteSize) != 0) {
297+
if (input.size() % ((long) dim * vectorEncoding.byteSize + offsetByteSize) != 0) {
281298
throw new IllegalArgumentException(
282299
"vectors file \"" + input + "\" does not contain a whole number of vectors? size=" + input.size()
283300
);
284301
}
285-
return new VectorReader(input, dim, bufferSize);
302+
return new VectorReader(input, dim, bufferSize, offsetByteSize);
286303
}
287304

288-
VectorReader(FileChannel input, int dim, int bufferSize) throws IOException {
305+
VectorReader(FileChannel input, int dim, int bufferSize, int offsetByteSize) throws IOException {
306+
this.offsetByteSize = offsetByteSize;
289307
this.bytes = ByteBuffer.wrap(new byte[bufferSize]).order(ByteOrder.LITTLE_ENDIAN);
290308
this.input = input;
291309
this.target = new float[dim];
292310
reset();
293311
}
294312

295313
void reset() throws IOException {
296-
position = 0;
314+
position = offsetByteSize;
297315
input.position(position);
298316
}
299317

300318
private void readNext() throws IOException {
301319
int bytesRead = Channels.readFromFileChannel(this.input, position, bytes);
302320
if (bytesRead < bytes.capacity()) {
303-
position = 0;
321+
position = offsetByteSize;
304322
bytes.position(0);
305323
// wrap around back to the start of the file if we hit the end:
306324
logger.warn("VectorReader hit EOF when reading " + this.input + "; now wrapping around to start of file again");
@@ -312,7 +330,7 @@ private void readNext() throws IOException {
312330
);
313331
}
314332
}
315-
position += bytesRead;
333+
position += bytesRead + offsetByteSize;
316334
bytes.position(0);
317335
}
318336

qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java

Lines changed: 29 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -40,6 +40,7 @@
4040
import org.apache.lucene.store.Directory;
4141
import org.apache.lucene.store.FSDirectory;
4242
import org.apache.lucene.store.MMapDirectory;
43+
import org.elasticsearch.common.io.Channels;
4344
import org.elasticsearch.core.PathUtils;
4445
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
4546
import org.elasticsearch.search.profile.query.QueryProfiler;
@@ -87,7 +88,7 @@ class KnnSearcher {
8788
private final int efSearch;
8889
private final int nProbe;
8990
private final KnnIndexTester.IndexType indexType;
90-
private final int dim;
91+
private int dim;
9192
private final VectorSimilarityFunction similarityFunction;
9293
private final VectorEncoding vectorEncoding;
9394
private final float overSamplingFactor;
@@ -117,6 +118,7 @@ void runSearch(KnnIndexTester.Results finalResults) throws IOException {
117118
TopDocs[] results = new TopDocs[numQueryVectors];
118119
int[][] resultIds = new int[numQueryVectors][];
119120
long elapsed, totalCpuTimeMS, totalVisited = 0;
121+
int offsetByteSize = 0;
120122
try (
121123
FileChannel input = FileChannel.open(queryPath);
122124
ExecutorService executorService = Executors.newFixedThreadPool(searchThreads, r -> new Thread(r, "KnnSearcher-Thread"))
@@ -128,7 +130,19 @@ void runSearch(KnnIndexTester.Results finalResults) throws IOException {
128130
+ " bytes, assuming vector count is "
129131
+ (queryPathSizeInBytes / ((long) dim * vectorEncoding.byteSize))
130132
);
131-
KnnIndexer.VectorReader targetReader = KnnIndexer.VectorReader.create(input, dim, vectorEncoding);
133+
if (dim == -1) {
134+
offsetByteSize = 4;
135+
ByteBuffer preamble = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN);
136+
int bytesRead = Channels.readFromFileChannel(input, 0, preamble);
137+
if (bytesRead < 4) {
138+
throw new IllegalArgumentException("queryPath \"" + queryPath + "\" does not contain a valid dims?");
139+
}
140+
dim = preamble.getInt(0);
141+
if (dim <= 0) {
142+
throw new IllegalArgumentException("queryPath \"" + queryPath + "\" has invalid dimension: " + dim);
143+
}
144+
}
145+
KnnIndexer.VectorReader targetReader = KnnIndexer.VectorReader.create(input, dim, vectorEncoding, offsetByteSize);
132146
long startNS;
133147
try (MMapDirectory dir = new MMapDirectory(indexPath)) {
134148
try (DirectoryReader reader = DirectoryReader.open(dir)) {
@@ -191,7 +205,7 @@ void runSearch(KnnIndexTester.Results finalResults) throws IOException {
191205
}
192206
}
193207
logger.info("checking results");
194-
int[][] nn = getOrCalculateExactNN();
208+
int[][] nn = getOrCalculateExactNN(offsetByteSize);
195209
finalResults.avgRecall = checkResults(resultIds, nn, topK);
196210
finalResults.qps = (1000f * numQueryVectors) / elapsed;
197211
finalResults.avgLatency = (float) elapsed / numQueryVectors;
@@ -200,7 +214,7 @@ void runSearch(KnnIndexTester.Results finalResults) throws IOException {
200214
finalResults.avgCpuCount = (double) totalCpuTimeMS / elapsed;
201215
}
202216

203-
private int[][] getOrCalculateExactNN() throws IOException {
217+
private int[][] getOrCalculateExactNN(int vectorFileOffsetBytes) throws IOException {
204218
// look in working directory for cached nn file
205219
String hash = Integer.toString(
206220
Objects.hash(
@@ -228,9 +242,9 @@ private int[][] getOrCalculateExactNN() throws IOException {
228242
// checking low-precision recall
229243
int[][] nn;
230244
if (vectorEncoding.equals(VectorEncoding.BYTE)) {
231-
nn = computeExactNNByte(queryPath);
245+
nn = computeExactNNByte(queryPath, vectorFileOffsetBytes);
232246
} else {
233-
nn = computeExactNN(queryPath);
247+
nn = computeExactNN(queryPath, vectorFileOffsetBytes);
234248
}
235249
writeExactNN(nn, nnPath);
236250
long elapsedMS = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNS); // ns -> ms
@@ -356,12 +370,17 @@ private void writeExactNN(int[][] nn, Path nnPath) throws IOException {
356370
}
357371
}
358372

359-
private int[][] computeExactNN(Path queryPath) throws IOException {
373+
private int[][] computeExactNN(Path queryPath, int vectorFileOffsetBytes) throws IOException {
360374
int[][] result = new int[numQueryVectors][];
361375
try (Directory dir = FSDirectory.open(indexPath); DirectoryReader reader = DirectoryReader.open(dir)) {
362376
List<Callable<Void>> tasks = new ArrayList<>();
363377
try (FileChannel qIn = FileChannel.open(queryPath)) {
364-
KnnIndexer.VectorReader queryReader = KnnIndexer.VectorReader.create(qIn, dim, VectorEncoding.FLOAT32);
378+
KnnIndexer.VectorReader queryReader = KnnIndexer.VectorReader.create(
379+
qIn,
380+
dim,
381+
VectorEncoding.FLOAT32,
382+
vectorFileOffsetBytes
383+
);
365384
for (int i = 0; i < numQueryVectors; i++) {
366385
float[] queryVector = new float[dim];
367386
queryReader.next(queryVector);
@@ -373,12 +392,12 @@ private int[][] computeExactNN(Path queryPath) throws IOException {
373392
}
374393
}
375394

376-
private int[][] computeExactNNByte(Path queryPath) throws IOException {
395+
private int[][] computeExactNNByte(Path queryPath, int vectorFileOffsetBytes) throws IOException {
377396
int[][] result = new int[numQueryVectors][];
378397
try (Directory dir = FSDirectory.open(indexPath); DirectoryReader reader = DirectoryReader.open(dir)) {
379398
List<Callable<Void>> tasks = new ArrayList<>();
380399
try (FileChannel qIn = FileChannel.open(queryPath)) {
381-
KnnIndexer.VectorReader queryReader = KnnIndexer.VectorReader.create(qIn, dim, VectorEncoding.BYTE);
400+
KnnIndexer.VectorReader queryReader = KnnIndexer.VectorReader.create(qIn, dim, VectorEncoding.BYTE, vectorFileOffsetBytes);
382401
for (int i = 0; i < numQueryVectors; i++) {
383402
byte[] queryVector = new byte[dim];
384403
queryReader.next(queryVector);

0 commit comments

Comments
 (0)