Skip to content

Commit c65f94d

Browse files
authored
Fix score computation in ES91Int4VectorsScorer (elastic#131905)
1 parent d2854f5 commit c65f94d

File tree

3 files changed

+74
-40
lines changed

3 files changed

+74
-40
lines changed

libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES91Int4VectorsScorer.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,7 @@ public float applyCorrections(
161161
float qcDist
162162
) {
163163
float ax = lowerInterval;
164-
// Here we assume `lx` is simply bit vectors, so the scaling isn't necessary
165-
float lx = upperInterval - ax;
164+
float lx = (upperInterval - ax) * FOUR_BIT_SCALE;
166165
float ay = queryLowerInterval;
167166
float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE;
168167
float y1 = queryComponentSum;

libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91Int4VectorsScorer.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ private void applyCorrectionsBulk(
354354
memorySegment,
355355
offset + 4 * BULK_SIZE + i * Float.BYTES,
356356
ByteOrder.LITTLE_ENDIAN
357-
).sub(ax);
357+
).sub(ax).mul(FOUR_BIT_SCALE);
358358
var targetComponentSums = ShortVector.fromMemorySegment(
359359
SHORT_SPECIES,
360360
memorySegment,

libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91Int4VectorScorerTests.java

Lines changed: 72 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@
2020
import org.elasticsearch.simdvec.ES91Int4VectorsScorer;
2121
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
2222

23-
import static org.hamcrest.Matchers.lessThan;
23+
import java.io.IOException;
24+
25+
import static org.hamcrest.Matchers.greaterThan;
2426

2527
public class ES91Int4VectorScorerTests extends BaseVectorizationTests {
2628

@@ -130,31 +132,59 @@ public void testInt4ScoreBulk() throws Exception {
130132
// only even dimensions are supported
131133
final int dimensions = random().nextInt(1, 1000) * 2;
132134
final int numVectors = random().nextInt(1, 10) * ES91Int4VectorsScorer.BULK_SIZE;
133-
final byte[] vector = new byte[ES91Int4VectorsScorer.BULK_SIZE * dimensions];
134-
final byte[] corrections = new byte[ES91Int4VectorsScorer.BULK_SIZE * 14];
135+
final float[][] vectors = new float[numVectors][dimensions];
136+
final int[] quantizedScratch = new int[dimensions];
137+
final byte[] quantizeVector = new byte[dimensions];
138+
final float[] centroid = new float[dimensions];
139+
VectorSimilarityFunction similarityFunction = randomFrom(VectorSimilarityFunction.values());
140+
for (int i = 0; i < dimensions; i++) {
141+
centroid[i] = random().nextFloat();
142+
}
143+
if (similarityFunction != VectorSimilarityFunction.EUCLIDEAN) {
144+
VectorUtil.l2normalize(centroid);
145+
}
146+
147+
OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(similarityFunction);
135148
try (Directory dir = new MMapDirectory(createTempDir())) {
136149
try (IndexOutput out = dir.createOutput("tests.bin", IOContext.DEFAULT)) {
150+
OptimizedScalarQuantizer.QuantizationResult[] results =
151+
new OptimizedScalarQuantizer.QuantizationResult[ES91Int4VectorsScorer.BULK_SIZE];
137152
for (int i = 0; i < numVectors; i += ES91Int4VectorsScorer.BULK_SIZE) {
138-
for (int j = 0; j < ES91Int4VectorsScorer.BULK_SIZE * dimensions; j++) {
139-
vector[j] = (byte) random().nextInt(16); // 4-bit quantization
153+
for (int j = 0; j < ES91Int4VectorsScorer.BULK_SIZE; j++) {
154+
for (int k = 0; k < dimensions; k++) {
155+
vectors[i + j][k] = random().nextFloat();
156+
}
157+
if (similarityFunction != VectorSimilarityFunction.EUCLIDEAN) {
158+
VectorUtil.l2normalize(vectors[i + j]);
159+
}
160+
results[j] = quantizer.scalarQuantize(vectors[i + j].clone(), quantizedScratch, (byte) 4, centroid);
161+
for (int k = 0; k < dimensions; k++) {
162+
quantizeVector[k] = (byte) quantizedScratch[k];
163+
}
164+
out.writeBytes(quantizeVector, 0, dimensions);
140165
}
141-
out.writeBytes(vector, 0, vector.length);
142-
random().nextBytes(corrections);
143-
out.writeBytes(corrections, 0, corrections.length);
166+
writeCorrections(results, out);
144167
}
145168
}
146-
final byte[] query = new byte[dimensions];
169+
final float[] query = new float[dimensions];
170+
final byte[] quantizeQuery = new byte[dimensions];
147171
for (int j = 0; j < dimensions; j++) {
148-
query[j] = (byte) random().nextInt(16); // 4-bit quantization
172+
query[j] = random().nextFloat();
149173
}
150-
OptimizedScalarQuantizer.QuantizationResult queryCorrections = new OptimizedScalarQuantizer.QuantizationResult(
151-
random().nextFloat(),
152-
random().nextFloat(),
153-
random().nextFloat(),
154-
Short.toUnsignedInt((short) random().nextInt())
174+
if (similarityFunction != VectorSimilarityFunction.EUCLIDEAN) {
175+
VectorUtil.l2normalize(query);
176+
}
177+
OptimizedScalarQuantizer.QuantizationResult queryCorrections = quantizer.scalarQuantize(
178+
query.clone(),
179+
quantizedScratch,
180+
(byte) 4,
181+
centroid
155182
);
156-
float centroidDp = random().nextFloat();
157-
VectorSimilarityFunction similarityFunction = randomFrom(VectorSimilarityFunction.values());
183+
for (int j = 0; j < dimensions; j++) {
184+
quantizeQuery[j] = (byte) quantizedScratch[j];
185+
}
186+
float centroidDp = VectorUtil.dotProduct(centroid, centroid);
187+
158188
try (IndexInput in = dir.openInput("tests.bin", IOContext.DEFAULT)) {
159189
// Work on a slice that has just the right number of bytes to make the test fail with an
160190
// index-out-of-bounds in case the implementation reads more than the allowed number of
@@ -166,7 +196,7 @@ public void testInt4ScoreBulk() throws Exception {
166196
float[] scoresPanama = new float[ES91Int4VectorsScorer.BULK_SIZE];
167197
for (int i = 0; i < numVectors; i += ES91Int4VectorsScorer.BULK_SIZE) {
168198
defaultScorer.scoreBulk(
169-
query,
199+
quantizeQuery,
170200
queryCorrections.lowerInterval(),
171201
queryCorrections.upperInterval(),
172202
queryCorrections.quantizedComponentSum(),
@@ -176,7 +206,7 @@ public void testInt4ScoreBulk() throws Exception {
176206
scoresDefault
177207
);
178208
panamaScorer.scoreBulk(
179-
query,
209+
quantizeQuery,
180210
queryCorrections.lowerInterval(),
181211
queryCorrections.upperInterval(),
182212
queryCorrections.quantizedComponentSum(),
@@ -186,29 +216,34 @@ public void testInt4ScoreBulk() throws Exception {
186216
scoresPanama
187217
);
188218
for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
189-
if (scoresDefault[j] == scoresPanama[j]) {
190-
continue;
191-
}
192-
if (scoresDefault[j] > (1000 * Byte.MAX_VALUE)) {
193-
float diff = Math.abs(scoresDefault[j] - scoresPanama[j]);
194-
assertThat(
195-
"defaultScores: " + scoresDefault[j] + " bulkScores: " + scoresPanama[j],
196-
diff / scoresDefault[j],
197-
lessThan(1e-5f)
198-
);
199-
assertThat(
200-
"defaultScores: " + scoresDefault[j] + " bulkScores: " + scoresPanama[j],
201-
diff / scoresPanama[j],
202-
lessThan(1e-5f)
203-
);
204-
} else {
205-
assertEquals(scoresDefault[j], scoresPanama[j], 1e-2f);
206-
}
219+
assertEquals(scoresDefault[j], scoresPanama[j], 1e-2f);
220+
float realSimilarity = similarityFunction.compare(vectors[i + j], query);
221+
float accuracy = realSimilarity > scoresDefault[j]
222+
? scoresDefault[j] / realSimilarity
223+
: realSimilarity / scoresDefault[j];
224+
assertThat(accuracy, greaterThan(0.90f));
207225
}
208226
assertEquals(in.getFilePointer(), slice.getFilePointer());
209227
}
210228
assertEquals((long) (dimensions + 14) * numVectors, in.getFilePointer());
211229
}
212230
}
213231
}
232+
233+
private static void writeCorrections(OptimizedScalarQuantizer.QuantizationResult[] corrections, IndexOutput out) throws IOException {
234+
for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
235+
out.writeInt(Float.floatToIntBits(correction.lowerInterval()));
236+
}
237+
for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
238+
out.writeInt(Float.floatToIntBits(correction.upperInterval()));
239+
}
240+
for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
241+
int targetComponentSum = correction.quantizedComponentSum();
242+
assert targetComponentSum >= 0 && targetComponentSum <= 0xffff;
243+
out.writeShort((short) targetComponentSum);
244+
}
245+
for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
246+
out.writeInt(Float.floatToIntBits(correction.additionalCorrection()));
247+
}
248+
}
214249
}

0 commit comments

Comments
 (0)