Skip to content

Commit fef1d0a

Browse files
[8.x] feat: VoyageAI integration (elastic#122134) (elastic#123768)
* feat: VoyageAI integration (elastic#122134) * VoyageAI embeddings and rerank: - embeddings works, tested - initial rerank code What's missing: - unit and integration tests - rerank request/response mapping and verification * VoyageAI embeddings and rerank: - embeddings works, tested - rerank works, tested (https://www.elastic.co/search-labs/blog/elasticsearch-cohere-rerank) What's missing: - unit and integration tests * VoyageAI embeddings and rerank: - embeddings works, tested - rerank works, tested (https://www.elastic.co/search-labs/blog/elasticsearch-cohere-rerank) What's missing: - unit and integration tests * VoyageAI embeddings and rerank: - embeddings works, tested - rerank works, tested (https://www.elastic.co/search-labs/blog/elasticsearch-cohere-rerank) What's missing: - unit and integration tests * Adding initial tests Moving dimensions to ServiceSettings * Correcting the TransportVersions.java * Correcting due to comments * Adding BIT support * Initial tests * More tests * More tests/corrections * Removing warnings * Further tests * Transport version correction * Adding changelog and correcting TransportVersions * Spotless tests * Changes due to the comments * Changes due to the comments * Correcting QA tests * Correcting QA tests --------- Co-authored-by: Jonathan Buttner <[email protected]> Co-authored-by: Jonathan Buttner <[email protected]> (cherry picked from commit 521f855) # Conflicts: # server/src/main/java/org/elasticsearch/TransportVersions.java # x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java * Using correct transport version and fixing errors * Fixing errors and test failures --------- Co-authored-by: fzowl <[email protected]>
1 parent 122df00 commit fef1d0a

File tree

53 files changed

+8137
-4
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+8137
-4
lines changed

docs/changelog/122134.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 122134
2+
summary: Adding integration for VoyageAI embeddings and rerank models
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ static TransportVersion def(int id) {
188188
public static final TransportVersion REMOVE_ALL_APPLICABLE_SELECTOR_BACKPORT_8_19 = def(8_841_0_02);
189189
public static final TransportVersion ESQL_RETRY_ON_SHARD_LEVEL_FAILURE_BACKPORT_8_19 = def(8_841_0_03);
190190
public static final TransportVersion ESQL_SUPPORT_PARTIAL_RESULTS_BACKPORT_8_19 = def(8_841_0_04);
191+
public static final TransportVersion VOYAGE_AI_INTEGRATION_ADDED_BACKPORT_8_X = def(8_841_0_05);
191192

192193
/*
193194
* STOP! READ THIS FIRST! No, really,

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
2626
@SuppressWarnings("unchecked")
2727
public void testGetServicesWithoutTaskType() throws IOException {
2828
List<Object> services = getAllServices();
29-
assertThat(services.size(), equalTo(19));
29+
assertThat(services.size(), equalTo(20));
3030

3131
String[] providers = new String[services.size()];
3232
for (int i = 0; i < services.size(); i++) {
@@ -54,6 +54,7 @@ public void testGetServicesWithoutTaskType() throws IOException {
5454
"test_reranking_service",
5555
"test_service",
5656
"text_embedding_test_service",
57+
"voyageai",
5758
"watsonxai"
5859
).toArray(),
5960
providers
@@ -63,7 +64,7 @@ public void testGetServicesWithoutTaskType() throws IOException {
6364
@SuppressWarnings("unchecked")
6465
public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
6566
List<Object> services = getServices(TaskType.TEXT_EMBEDDING);
66-
assertThat(services.size(), equalTo(14));
67+
assertThat(services.size(), equalTo(15));
6768

6869
String[] providers = new String[services.size()];
6970
for (int i = 0; i < services.size(); i++) {
@@ -86,6 +87,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
8687
"mistral",
8788
"openai",
8889
"text_embedding_test_service",
90+
"voyageai",
8991
"watsonxai"
9092
).toArray(),
9193
providers
@@ -95,7 +97,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
9597
@SuppressWarnings("unchecked")
9698
public void testGetServicesWithRerankTaskType() throws IOException {
9799
List<Object> services = getServices(TaskType.RERANK);
98-
assertThat(services.size(), equalTo(6));
100+
assertThat(services.size(), equalTo(7));
99101

100102
String[] providers = new String[services.size()];
101103
for (int i = 0; i < services.size(); i++) {
@@ -104,7 +106,8 @@ public void testGetServicesWithRerankTaskType() throws IOException {
104106
}
105107

106108
assertArrayEquals(
107-
List.of("alibabacloud-ai-search", "cohere", "elasticsearch", "googlevertexai", "jinaai", "test_reranking_service").toArray(),
109+
List.of("alibabacloud-ai-search", "cohere", "elasticsearch", "googlevertexai", "jinaai", "test_reranking_service", "voyageai")
110+
.toArray(),
108111
providers
109112
);
110113
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,11 @@
9191
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsServiceSettings;
9292
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsTaskSettings;
9393
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
94+
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings;
95+
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings;
96+
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings;
97+
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankServiceSettings;
98+
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankTaskSettings;
9499

95100
import java.util.ArrayList;
96101
import java.util.List;
@@ -143,6 +148,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
143148
addEisNamedWriteables(namedWriteables);
144149
addAlibabaCloudSearchNamedWriteables(namedWriteables);
145150
addJinaAINamedWriteables(namedWriteables);
151+
addVoyageAINamedWriteables(namedWriteables);
146152

147153
addUnifiedNamedWriteables(namedWriteables);
148154

@@ -619,6 +625,28 @@ private static void addJinaAINamedWriteables(List<NamedWriteableRegistry.Entry>
619625
);
620626
}
621627

628+
private static void addVoyageAINamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
629+
namedWriteables.add(
630+
new NamedWriteableRegistry.Entry(ServiceSettings.class, VoyageAIServiceSettings.NAME, VoyageAIServiceSettings::new)
631+
);
632+
namedWriteables.add(
633+
new NamedWriteableRegistry.Entry(
634+
ServiceSettings.class,
635+
VoyageAIEmbeddingsServiceSettings.NAME,
636+
VoyageAIEmbeddingsServiceSettings::new
637+
)
638+
);
639+
namedWriteables.add(
640+
new NamedWriteableRegistry.Entry(TaskSettings.class, VoyageAIEmbeddingsTaskSettings.NAME, VoyageAIEmbeddingsTaskSettings::new)
641+
);
642+
namedWriteables.add(
643+
new NamedWriteableRegistry.Entry(ServiceSettings.class, VoyageAIRerankServiceSettings.NAME, VoyageAIRerankServiceSettings::new)
644+
);
645+
namedWriteables.add(
646+
new NamedWriteableRegistry.Entry(TaskSettings.class, VoyageAIRerankTaskSettings.NAME, VoyageAIRerankTaskSettings::new)
647+
);
648+
}
649+
622650
private static void addEisNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
623651
namedWriteables.add(
624652
new NamedWriteableRegistry.Entry(

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@
128128
import org.elasticsearch.xpack.inference.services.jinaai.JinaAIService;
129129
import org.elasticsearch.xpack.inference.services.mistral.MistralService;
130130
import org.elasticsearch.xpack.inference.services.openai.OpenAiService;
131+
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIService;
131132
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
132133

133134
import java.util.ArrayList;
@@ -359,6 +360,7 @@ public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
359360
context -> new AlibabaCloudSearchService(httpFactory.get(), serviceComponents.get()),
360361
context -> new IbmWatsonxService(httpFactory.get(), serviceComponents.get()),
361362
context -> new JinaAIService(httpFactory.get(), serviceComponents.get()),
363+
context -> new VoyageAIService(httpFactory.get(), serviceComponents.get()),
362364
ElasticsearchInternalService::new
363365
);
364366
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.action.voyageai;
9+
10+
import org.elasticsearch.inference.InputType;
11+
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
12+
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
13+
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
14+
import org.elasticsearch.xpack.inference.external.http.sender.VoyageAIEmbeddingsRequestManager;
15+
import org.elasticsearch.xpack.inference.external.http.sender.VoyageAIRerankRequestManager;
16+
import org.elasticsearch.xpack.inference.services.ServiceComponents;
17+
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel;
18+
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModel;
19+
20+
import java.util.Map;
21+
import java.util.Objects;
22+
23+
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
24+
25+
/**
26+
* Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the voyageai model type.
27+
*/
28+
public class VoyageAIActionCreator implements VoyageAIActionVisitor {
29+
private final Sender sender;
30+
private final ServiceComponents serviceComponents;
31+
32+
public VoyageAIActionCreator(Sender sender, ServiceComponents serviceComponents) {
33+
this.sender = Objects.requireNonNull(sender);
34+
this.serviceComponents = Objects.requireNonNull(serviceComponents);
35+
}
36+
37+
@Override
38+
public ExecutableAction create(VoyageAIEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType) {
39+
var overriddenModel = VoyageAIEmbeddingsModel.of(model, taskSettings, inputType);
40+
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("VoyageAI embeddings");
41+
var requestCreator = VoyageAIEmbeddingsRequestManager.of(overriddenModel, serviceComponents.threadPool());
42+
return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage);
43+
}
44+
45+
@Override
46+
public ExecutableAction create(VoyageAIRerankModel model, Map<String, Object> taskSettings) {
47+
var overriddenModel = VoyageAIRerankModel.of(model, taskSettings);
48+
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("VoyageAI rerank");
49+
var requestCreator = VoyageAIRerankRequestManager.of(overriddenModel, serviceComponents.threadPool());
50+
return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage);
51+
}
52+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.action.voyageai;
9+
10+
import org.elasticsearch.inference.InputType;
11+
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
12+
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel;
13+
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModel;
14+
15+
import java.util.Map;
16+
17+
public interface VoyageAIActionVisitor {
18+
ExecutableAction create(VoyageAIEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType);
19+
20+
ExecutableAction create(VoyageAIRerankModel model, Map<String, Object> taskSettings);
21+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.http.sender;
9+
10+
import org.apache.logging.log4j.LogManager;
11+
import org.apache.logging.log4j.Logger;
12+
import org.elasticsearch.action.ActionListener;
13+
import org.elasticsearch.inference.InferenceServiceResults;
14+
import org.elasticsearch.threadpool.ThreadPool;
15+
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
16+
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
17+
import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIEmbeddingsRequest;
18+
import org.elasticsearch.xpack.inference.external.response.voyageai.VoyageAIEmbeddingsResponseEntity;
19+
import org.elasticsearch.xpack.inference.external.voyageai.VoyageAIResponseHandler;
20+
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel;
21+
22+
import java.util.List;
23+
import java.util.Objects;
24+
import java.util.function.Supplier;
25+
26+
public class VoyageAIEmbeddingsRequestManager extends VoyageAIRequestManager {
27+
private static final Logger logger = LogManager.getLogger(VoyageAIEmbeddingsRequestManager.class);
28+
private static final ResponseHandler HANDLER = createEmbeddingsHandler();
29+
30+
private static ResponseHandler createEmbeddingsHandler() {
31+
return new VoyageAIResponseHandler("voyageai text embedding", VoyageAIEmbeddingsResponseEntity::fromResponse);
32+
}
33+
34+
public static VoyageAIEmbeddingsRequestManager of(VoyageAIEmbeddingsModel model, ThreadPool threadPool) {
35+
return new VoyageAIEmbeddingsRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool));
36+
}
37+
38+
private final VoyageAIEmbeddingsModel model;
39+
40+
private VoyageAIEmbeddingsRequestManager(VoyageAIEmbeddingsModel model, ThreadPool threadPool) {
41+
super(threadPool, model);
42+
this.model = Objects.requireNonNull(model);
43+
}
44+
45+
@Override
46+
public void execute(
47+
InferenceInputs inferenceInputs,
48+
RequestSender requestSender,
49+
Supplier<Boolean> hasRequestCompletedFunction,
50+
ActionListener<InferenceServiceResults> listener
51+
) {
52+
List<String> docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs();
53+
VoyageAIEmbeddingsRequest request = new VoyageAIEmbeddingsRequest(docsInput, model);
54+
55+
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
56+
}
57+
}
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.http.sender;
9+
10+
import org.elasticsearch.threadpool.ThreadPool;
11+
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIModel;
12+
13+
import java.util.Map;
14+
import java.util.Objects;
15+
16+
abstract class VoyageAIRequestManager extends BaseRequestManager {
17+
private static final String DEFAULT_MODEL_FAMILY = "default_model_family";
18+
private static final Map<String, String> MODEL_TO_MODEL_FAMILY = Map.of(
19+
"voyage-multimodal-3",
20+
"embed_multimodal",
21+
"voyage-3-large",
22+
"embed_large",
23+
"voyage-code-3",
24+
"embed_large",
25+
"voyage-3",
26+
"embed_medium",
27+
"voyage-3-lite",
28+
"embed_small",
29+
"voyage-finance-2",
30+
"embed_large",
31+
"voyage-law-2",
32+
"embed_large",
33+
"voyage-code-2",
34+
"embed_large",
35+
"rerank-2",
36+
"rerank_large",
37+
"rerank-2-lite",
38+
"rerank_small"
39+
);
40+
41+
protected VoyageAIRequestManager(ThreadPool threadPool, VoyageAIModel model) {
42+
super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings());
43+
}
44+
45+
record RateLimitGrouping(int apiKeyHash) {
46+
public static RateLimitGrouping of(VoyageAIModel model) {
47+
Objects.requireNonNull(model);
48+
String modelId = model.getServiceSettings().modelId();
49+
String modelFamily = MODEL_TO_MODEL_FAMILY.getOrDefault(modelId, DEFAULT_MODEL_FAMILY);
50+
51+
return new RateLimitGrouping(modelFamily.hashCode());
52+
}
53+
}
54+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.http.sender;
9+
10+
import org.apache.logging.log4j.LogManager;
11+
import org.apache.logging.log4j.Logger;
12+
import org.elasticsearch.action.ActionListener;
13+
import org.elasticsearch.inference.InferenceServiceResults;
14+
import org.elasticsearch.threadpool.ThreadPool;
15+
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
16+
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
17+
import org.elasticsearch.xpack.inference.external.request.voyageai.VoyageAIRerankRequest;
18+
import org.elasticsearch.xpack.inference.external.response.voyageai.VoyageAIRerankResponseEntity;
19+
import org.elasticsearch.xpack.inference.external.voyageai.VoyageAIResponseHandler;
20+
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModel;
21+
22+
import java.util.Objects;
23+
import java.util.function.Supplier;
24+
25+
public class VoyageAIRerankRequestManager extends VoyageAIRequestManager {
26+
private static final Logger logger = LogManager.getLogger(VoyageAIRerankRequestManager.class);
27+
private static final ResponseHandler HANDLER = createVoyageAIResponseHandler();
28+
29+
private static ResponseHandler createVoyageAIResponseHandler() {
30+
return new VoyageAIResponseHandler("voyageai rerank", (request, response) -> VoyageAIRerankResponseEntity.fromResponse(response));
31+
}
32+
33+
public static VoyageAIRerankRequestManager of(VoyageAIRerankModel model, ThreadPool threadPool) {
34+
return new VoyageAIRerankRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool));
35+
}
36+
37+
private final VoyageAIRerankModel model;
38+
39+
private VoyageAIRerankRequestManager(VoyageAIRerankModel model, ThreadPool threadPool) {
40+
super(threadPool, model);
41+
this.model = model;
42+
}
43+
44+
@Override
45+
public void execute(
46+
InferenceInputs inferenceInputs,
47+
RequestSender requestSender,
48+
Supplier<Boolean> hasRequestCompletedFunction,
49+
ActionListener<InferenceServiceResults> listener
50+
) {
51+
var rerankInput = QueryAndDocsInputs.of(inferenceInputs);
52+
VoyageAIRerankRequest request = new VoyageAIRerankRequest(rerankInput.getQuery(), rerankInput.getChunks(), model);
53+
54+
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
55+
}
56+
}

0 commit comments

Comments
 (0)