Skip to content

Commit 5149815

Browse files
committed
Updates to allow using Cohere binary embedding response in semantic search queries. (elastic#121827)
* wip
* wip
* [CI] Auto commit changes from spotless
* updating tests
* [CI] Auto commit changes from spotless
* Update docs/changelog/121827.yaml
* Updates after the refactor
* [CI] Auto commit changes from spotless
* Updating error message

---------

Co-authored-by: elasticsearchmachine <[email protected]>

(cherry picked from commit e843849)

# Conflicts:
#	x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingBitResults.java
1 parent 223d50f commit 5149815

File tree

9 files changed

+140
-9
lines changed

9 files changed

+140
-9
lines changed

docs/changelog/121827.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 121827
2+
summary: Updates to allow using Cohere binary embedding response in semantic search
3+
queries
4+
area: Machine Learning
5+
type: bug
6+
issues: []

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingBitResults.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@ public int getFirstEmbeddingSize() {
5555
if (embeddings.isEmpty()) {
5656
throw new IllegalStateException("Embeddings list is empty");
5757
}
58-
return embeddings.get(0).values().length;
58+
// bit embeddings are encoded as bytes so convert this to bits
59+
return Byte.SIZE * embeddings.getFirst().values().length;
5960
}
6061

6162
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -711,9 +711,12 @@ yield new SparseVectorQueryBuilder(
711711

712712
MlTextEmbeddingResults textEmbeddingResults = (MlTextEmbeddingResults) inferenceResults;
713713
float[] inference = textEmbeddingResults.getInferenceAsFloat();
714-
if (inference.length != modelSettings.dimensions()) {
714+
var inferenceLength = modelSettings.elementType() == DenseVectorFieldMapper.ElementType.BIT
715+
? inference.length * Byte.SIZE
716+
: inference.length;
717+
if (inferenceLength != modelSettings.dimensions()) {
715718
throw new IllegalArgumentException(
716-
generateDimensionCountMismatchMessage(inference.length, modelSettings.dimensions())
719+
generateDimensionCountMismatchMessage(inferenceLength, modelSettings.dimensions())
717720
);
718721
}
719722

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
import org.elasticsearch.xpack.inference.services.ServiceComponents;
4040
import org.elasticsearch.xpack.inference.services.ServiceUtils;
4141
import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModel;
42+
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType;
4243
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel;
4344
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsServiceSettings;
4445
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankModel;
@@ -313,7 +314,7 @@ public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
313314
if (model instanceof CohereEmbeddingsModel embeddingsModel) {
314315
var serviceSettings = embeddingsModel.getServiceSettings();
315316
var similarityFromModel = serviceSettings.similarity();
316-
var similarityToUse = similarityFromModel == null ? defaultSimilarity() : similarityFromModel;
317+
var similarityToUse = similarityFromModel == null ? defaultSimilarity(serviceSettings.getEmbeddingType()) : similarityFromModel;
317318

318319
var updatedServiceSettings = new CohereEmbeddingsServiceSettings(
319320
new CohereServiceSettings(
@@ -341,7 +342,11 @@ public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
341342
*
342343
* @return The default similarity.
343344
*/
344-
static SimilarityMeasure defaultSimilarity() {
345+
static SimilarityMeasure defaultSimilarity(CohereEmbeddingType embeddingType) {
346+
if (embeddingType == CohereEmbeddingType.BIT || embeddingType == CohereEmbeddingType.BINARY) {
347+
return SimilarityMeasure.L2_NORM;
348+
}
349+
345350
return SimilarityMeasure.DOT_PRODUCT;
346351
}
347352

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
1414
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
1515
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
16+
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults;
1617
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
1718
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
1819
import org.elasticsearch.xpack.core.ml.search.WeightedToken;
@@ -377,6 +378,78 @@ public void testMergingListener_Byte() {
377378
}
378379
}
379380

381+
public void testMergingListener_Bit() {
382+
int batchSize = 5;
383+
int chunkSize = 20;
384+
int overlap = 0;
385+
// passage will be chunked into batchSize + 1 parts
386+
// and spread over 2 batch requests
387+
int numberOfWordsInPassage = (chunkSize * batchSize) + 5;
388+
389+
var passageBuilder = new StringBuilder();
390+
for (int i = 0; i < numberOfWordsInPassage; i++) {
391+
passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace
392+
}
393+
List<String> inputs = List.of("1st small", passageBuilder.toString(), "2nd small", "3rd small");
394+
395+
var finalListener = testListener();
396+
var batches = new EmbeddingRequestChunker(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener);
397+
assertThat(batches, hasSize(2));
398+
399+
// 4 inputs in 2 batches
400+
{
401+
var embeddings = new ArrayList<TextEmbeddingByteResults.Embedding>();
402+
for (int i = 0; i < batchSize; i++) {
403+
embeddings.add(new TextEmbeddingByteResults.Embedding(new byte[] { randomByte() }));
404+
}
405+
batches.get(0).listener().onResponse(new TextEmbeddingBitResults(embeddings));
406+
}
407+
{
408+
var embeddings = new ArrayList<TextEmbeddingByteResults.Embedding>();
409+
for (int i = 0; i < 4; i++) { // 4 requests in the 2nd batch
410+
embeddings.add(new TextEmbeddingByteResults.Embedding(new byte[] { randomByte() }));
411+
}
412+
batches.get(1).listener().onResponse(new TextEmbeddingBitResults(embeddings));
413+
}
414+
415+
assertNotNull(finalListener.results);
416+
assertThat(finalListener.results, hasSize(4));
417+
{
418+
var chunkedResult = finalListener.results.get(0);
419+
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
420+
var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
421+
assertThat(chunkedByteResult.chunks(), hasSize(1));
422+
assertEquals("1st small", chunkedByteResult.chunks().get(0).matchedText());
423+
}
424+
{
425+
// this is the large input split in multiple chunks
426+
var chunkedResult = finalListener.results.get(1);
427+
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
428+
var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
429+
assertThat(chunkedByteResult.chunks(), hasSize(6));
430+
assertThat(chunkedByteResult.chunks().get(0).matchedText(), startsWith("passage_input0 "));
431+
assertThat(chunkedByteResult.chunks().get(1).matchedText(), startsWith(" passage_input20 "));
432+
assertThat(chunkedByteResult.chunks().get(2).matchedText(), startsWith(" passage_input40 "));
433+
assertThat(chunkedByteResult.chunks().get(3).matchedText(), startsWith(" passage_input60 "));
434+
assertThat(chunkedByteResult.chunks().get(4).matchedText(), startsWith(" passage_input80 "));
435+
assertThat(chunkedByteResult.chunks().get(5).matchedText(), startsWith(" passage_input100 "));
436+
}
437+
{
438+
var chunkedResult = finalListener.results.get(2);
439+
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
440+
var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
441+
assertThat(chunkedByteResult.chunks(), hasSize(1));
442+
assertEquals("2nd small", chunkedByteResult.chunks().get(0).matchedText());
443+
}
444+
{
445+
var chunkedResult = finalListener.results.get(3);
446+
assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class));
447+
var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult;
448+
assertThat(chunkedByteResult.chunks(), hasSize(1));
449+
assertEquals("3rd small", chunkedByteResult.chunks().get(0).matchedText());
450+
}
451+
}
452+
380453
public void testMergingListener_Sparse() {
381454
int batchSize = 4;
382455
int chunkSize = 10;

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingBitResultsTests.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,17 @@ public void testTransformToCoordinationFormat() {
105105
);
106106
}
107107

108+
public void testGetFirstEmbeddingSize() {
109+
var firstEmbeddingSize = new TextEmbeddingBitResults(
110+
List.of(
111+
new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 23, (byte) 24 }),
112+
new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 25, (byte) 26 })
113+
)
114+
).getFirstEmbeddingSize();
115+
116+
assertThat(firstEmbeddingSize, is(16));
117+
}
118+
108119
@Override
109120
protected Writeable.Reader<TextEmbeddingBitResults> instanceReader() {
110121
return TextEmbeddingBitResults::new;

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingByteResultsTests.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,17 @@ public void testTransformToCoordinationFormat() {
104104
);
105105
}
106106

107+
public void testGetFirstEmbeddingSize() {
108+
var firstEmbeddingSize = new TextEmbeddingByteResults(
109+
List.of(
110+
new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 23, (byte) 24 }),
111+
new TextEmbeddingByteResults.Embedding(new byte[] { (byte) 25, (byte) 26 })
112+
)
113+
).getFirstEmbeddingSize();
114+
115+
assertThat(firstEmbeddingSize, is(2));
116+
}
117+
107118
@Override
108119
protected Writeable.Reader<TextEmbeddingByteResults> instanceReader() {
109120
return TextEmbeddingByteResults::new;

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,17 @@ public void testTransformToCoordinationFormat() {
105105
);
106106
}
107107

108+
public void testGetFirstEmbeddingSize() {
109+
var firstEmbeddingSize = new TextEmbeddingFloatResults(
110+
List.of(
111+
new TextEmbeddingFloatResults.Embedding(new float[] { 0.1F, 0.2F }),
112+
new TextEmbeddingFloatResults.Embedding(new float[] { 0.3F, 0.4F })
113+
)
114+
).getFirstEmbeddingSize();
115+
116+
assertThat(firstEmbeddingSize, is(2));
117+
}
118+
108119
@Override
109120
protected Writeable.Reader<TextEmbeddingFloatResults> instanceReader() {
110121
return TextEmbeddingFloatResults::new;

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1099,20 +1099,23 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si
10991099

11001100
try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) {
11011101
var embeddingSize = randomNonNegativeInt();
1102+
var embeddingType = randomFrom(CohereEmbeddingType.values());
11021103
var model = CohereEmbeddingsModelTests.createModel(
11031104
randomAlphaOfLength(10),
11041105
randomAlphaOfLength(10),
11051106
CohereEmbeddingsTaskSettings.EMPTY_SETTINGS,
11061107
randomNonNegativeInt(),
11071108
randomNonNegativeInt(),
11081109
randomAlphaOfLength(10),
1109-
randomFrom(CohereEmbeddingType.values()),
1110+
embeddingType,
11101111
similarityMeasure
11111112
);
11121113

11131114
Model updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize);
11141115

1115-
SimilarityMeasure expectedSimilarityMeasure = similarityMeasure == null ? CohereService.defaultSimilarity() : similarityMeasure;
1116+
SimilarityMeasure expectedSimilarityMeasure = similarityMeasure == null
1117+
? CohereService.defaultSimilarity(embeddingType)
1118+
: similarityMeasure;
11161119
assertEquals(expectedSimilarityMeasure, updatedModel.getServiceSettings().similarity());
11171120
assertEquals(embeddingSize, updatedModel.getServiceSettings().dimensions().intValue());
11181121
}
@@ -1587,8 +1590,15 @@ public void testChunkedInfer_BatchesCalls_Bytes() throws IOException {
15871590
}
15881591
}
15891592

1590-
public void testDefaultSimilarity() {
1591-
assertEquals(SimilarityMeasure.DOT_PRODUCT, CohereService.defaultSimilarity());
1593+
public void testDefaultSimilarity_BinaryEmbedding() {
1594+
assertEquals(SimilarityMeasure.L2_NORM, CohereService.defaultSimilarity(CohereEmbeddingType.BINARY));
1595+
assertEquals(SimilarityMeasure.L2_NORM, CohereService.defaultSimilarity(CohereEmbeddingType.BIT));
1596+
}
1597+
1598+
public void testDefaultSimilarity_NotBinaryEmbedding() {
1599+
assertEquals(SimilarityMeasure.DOT_PRODUCT, CohereService.defaultSimilarity(CohereEmbeddingType.FLOAT));
1600+
assertEquals(SimilarityMeasure.DOT_PRODUCT, CohereService.defaultSimilarity(CohereEmbeddingType.BYTE));
1601+
assertEquals(SimilarityMeasure.DOT_PRODUCT, CohereService.defaultSimilarity(CohereEmbeddingType.INT8));
15921602
}
15931603

15941604
public void testInfer_StreamRequest() throws Exception {

0 commit comments

Comments
 (0)