Skip to content

Commit 4427d92

Browse files
benwtrent ywangd
authored and committed
Adds new unexposed and experimental IVF format (elastic#127528)
1 parent 3c1dacb commit 4427d92

File tree

23 files changed

+2582
-3
lines changed

23 files changed

+2582
-3
lines changed

benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OSQScorerBenchmark.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import org.apache.lucene.util.VectorUtil;
1818
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
1919
import org.elasticsearch.common.logging.LogConfigurator;
20-
import org.elasticsearch.simdvec.internal.vectorization.ES91OSQVectorsScorer;
20+
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
2121
import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;
2222
import org.openjdk.jmh.annotations.Benchmark;
2323
import org.openjdk.jmh.annotations.BenchmarkMode;
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
* your election, the "Elastic License 2.0", the "GNU Affero General Public
77
* License v3.0 only", or the "Server Side Public License, v 1".
88
*/
9-
package org.elasticsearch.simdvec.internal.vectorization;
9+
package org.elasticsearch.simdvec;
1010

1111
import org.apache.lucene.index.VectorSimilarityFunction;
1212
import org.apache.lucene.store.IndexInput;

libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@
99

1010
package org.elasticsearch.simdvec;
1111

12+
import org.apache.lucene.store.IndexInput;
1213
import org.apache.lucene.util.BitUtil;
1314
import org.apache.lucene.util.Constants;
1415
import org.elasticsearch.simdvec.internal.vectorization.ESVectorUtilSupport;
1516
import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;
1617

18+
import java.io.IOException;
1719
import java.lang.invoke.MethodHandle;
1820
import java.lang.invoke.MethodHandles;
1921
import java.lang.invoke.MethodType;
@@ -41,6 +43,10 @@ public class ESVectorUtil {
4143

4244
private static final ESVectorUtilSupport IMPL = ESVectorizationProvider.getInstance().getVectorUtilSupport();
4345

46+
public static ES91OSQVectorsScorer getES91OSQVectorsScorer(IndexInput input, int dimension) throws IOException {
47+
return ESVectorizationProvider.getInstance().newES91OSQVectorsScorer(input, dimension);
48+
}
49+
4450
public static long ipByteBinByte(byte[] q, byte[] d) {
4551
if (q.length != d.length * B_QUERY) {
4652
throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= " + B_QUERY + " x " + d.length);
@@ -211,4 +217,40 @@ public static void centerAndCalculateOSQStatsDp(float[] target, float[] centroid
211217
assert stats.length == 6;
212218
IMPL.centerAndCalculateOSQStatsDp(target, centroid, centered, stats);
213219
}
220+
221+
/**
222+
* Calculates the difference between two vectors and stores the result in a third vector.
223+
* @param v1 the first vector
224+
* @param v2 the second vector
225+
* @param result the result vector, must be the same length as the input vectors
226+
*/
227+
public static void subtract(float[] v1, float[] v2, float[] result) {
228+
if (v1.length != v2.length) {
229+
throw new IllegalArgumentException("vector dimensions differ: " + v1.length + "!=" + v2.length);
230+
}
231+
if (result.length != v1.length) {
232+
throw new IllegalArgumentException("vector dimensions differ: " + result.length + "!=" + v1.length);
233+
}
234+
for (int i = 0; i < v1.length; i++) {
235+
result[i] = v1[i] - v2[i];
236+
}
237+
}
238+
239+
/**
240+
* calculates the spill-over score for a vector and a centroid, given its residual with
241+
* its actually nearest centroid
242+
* @param v1 the vector
243+
* @param centroid the centroid
244+
* @param originalResidual the residual with the actually nearest centroid
245+
* @return the spill-over score (soar)
246+
*/
247+
public static float soarResidual(float[] v1, float[] centroid, float[] originalResidual) {
248+
if (v1.length != centroid.length) {
249+
throw new IllegalArgumentException("vector dimensions differ: " + v1.length + "!=" + centroid.length);
250+
}
251+
if (originalResidual.length != v1.length) {
252+
throw new IllegalArgumentException("vector dimensions differ: " + originalResidual.length + "!=" + v1.length);
253+
}
254+
return IMPL.soarResidual(v1, centroid, originalResidual);
255+
}
214256
}

libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,18 @@ public void centerAndCalculateOSQStatsDp(float[] target, float[] centroid, float
138138
stats[5] = centroidDot;
139139
}
140140

141+
@Override
142+
public float soarResidual(float[] v1, float[] centroid, float[] originalResidual) {
143+
assert v1.length == centroid.length;
144+
assert v1.length == originalResidual.length;
145+
float proj = 0;
146+
for (int i = 0; i < v1.length; i++) {
147+
float djk = v1[i] - centroid[i];
148+
proj = fma(djk, originalResidual[i], proj);
149+
}
150+
return proj;
151+
}
152+
141153
public static int ipByteBitImpl(byte[] q, byte[] d) {
142154
return ipByteBitImpl(q, d, 0);
143155
}

libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorizationProvider.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
package org.elasticsearch.simdvec.internal.vectorization;
1111

1212
import org.apache.lucene.store.IndexInput;
13+
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
1314

1415
import java.io.IOException;
1516

libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,7 @@ public interface ESVectorUtilSupport {
2828
void centerAndCalculateOSQStatsEuclidean(float[] target, float[] centroid, float[] centered, float[] stats);
2929

3030
void centerAndCalculateOSQStatsDp(float[] target, float[] centroid, float[] centered, float[] stats);
31+
32+
float soarResidual(float[] v1, float[] centroid, float[] originalResidual);
33+
3134
}

libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
package org.elasticsearch.simdvec.internal.vectorization;
1111

1212
import org.apache.lucene.store.IndexInput;
13+
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
1314

1415
import java.io.IOException;
1516
import java.util.Objects;

libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import org.apache.lucene.util.Constants;
1414
import org.elasticsearch.logging.LogManager;
1515
import org.elasticsearch.logging.Logger;
16+
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
1617

1718
import java.io.IOException;
1819
import java.util.Locale;

libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import org.apache.lucene.store.IndexInput;
2121
import org.apache.lucene.util.VectorUtil;
2222
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
23+
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
2324

2425
import java.io.IOException;
2526
import java.lang.foreign.MemorySegment;

libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,49 @@ public float calculateOSQLoss(float[] target, float[] interval, float step, floa
367367
return (1f - lambda) * xe * xe / norm2 + lambda * e;
368368
}
369369

370+
@Override
371+
public float soarResidual(float[] v1, float[] centroid, float[] originalResidual) {
372+
assert v1.length == centroid.length;
373+
assert v1.length == originalResidual.length;
374+
float proj = 0;
375+
int i = 0;
376+
if (v1.length > 2 * FLOAT_SPECIES.length()) {
377+
FloatVector projVec1 = FloatVector.zero(FLOAT_SPECIES);
378+
FloatVector projVec2 = FloatVector.zero(FLOAT_SPECIES);
379+
int unrolledLimit = FLOAT_SPECIES.loopBound(v1.length) - FLOAT_SPECIES.length();
380+
for (; i < unrolledLimit; i += 2 * FLOAT_SPECIES.length()) {
381+
// one
382+
FloatVector v1Vec0 = FloatVector.fromArray(FLOAT_SPECIES, v1, i);
383+
FloatVector centroidVec0 = FloatVector.fromArray(FLOAT_SPECIES, centroid, i);
384+
FloatVector originalResidualVec0 = FloatVector.fromArray(FLOAT_SPECIES, originalResidual, i);
385+
FloatVector djkVec0 = v1Vec0.sub(centroidVec0);
386+
projVec1 = fma(djkVec0, originalResidualVec0, projVec1);
387+
388+
// two
389+
FloatVector v1Vec1 = FloatVector.fromArray(FLOAT_SPECIES, v1, i + FLOAT_SPECIES.length());
390+
FloatVector centroidVec1 = FloatVector.fromArray(FLOAT_SPECIES, centroid, i + FLOAT_SPECIES.length());
391+
FloatVector originalResidualVec1 = FloatVector.fromArray(FLOAT_SPECIES, originalResidual, i + FLOAT_SPECIES.length());
392+
FloatVector djkVec1 = v1Vec1.sub(centroidVec1);
393+
projVec2 = fma(djkVec1, originalResidualVec1, projVec2);
394+
}
395+
// vector tail
396+
for (; i < FLOAT_SPECIES.loopBound(v1.length); i += FLOAT_SPECIES.length()) {
397+
FloatVector v1Vec = FloatVector.fromArray(FLOAT_SPECIES, v1, i);
398+
FloatVector centroidVec = FloatVector.fromArray(FLOAT_SPECIES, centroid, i);
399+
FloatVector originalResidualVec = FloatVector.fromArray(FLOAT_SPECIES, originalResidual, i);
400+
FloatVector djkVec = v1Vec.sub(centroidVec);
401+
projVec1 = fma(djkVec, originalResidualVec, projVec1);
402+
}
403+
proj += projVec1.add(projVec2).reduceLanes(ADD);
404+
}
405+
// tail
406+
for (; i < v1.length; i++) {
407+
float djk = v1[i] - centroid[i];
408+
proj = fma(djk, originalResidual[i], proj);
409+
}
410+
return proj;
411+
}
412+
370413
private static final VectorSpecies<Byte> BYTE_SPECIES_128 = ByteVector.SPECIES_128;
371414
private static final VectorSpecies<Byte> BYTE_SPECIES_256 = ByteVector.SPECIES_256;
372415

0 commit comments

Comments
 (0)