Skip to content

Commit e067465

Browse files
committed
Changes due to the comments
1 parent f3f43ed commit e067465

File tree

12 files changed

+256
-285
lines changed

12 files changed

+256
-285
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreator.java

Lines changed: 2 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -37,21 +37,15 @@ public VoyageAIActionCreator(Sender sender, ServiceComponents serviceComponents)
3737
@Override
3838
public ExecutableAction create(VoyageAIEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType) {
3939
var overriddenModel = VoyageAIEmbeddingsModel.of(model, taskSettings, inputType);
40-
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(
41-
overriddenModel.getServiceSettings().getCommonSettings().uri(),
42-
"VoyageAI embeddings"
43-
);
40+
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(overriddenModel.uri(), "VoyageAI embeddings");
4441
var requestCreator = VoyageAIEmbeddingsRequestManager.of(overriddenModel, serviceComponents.threadPool());
4542
return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage);
4643
}
4744

4845
@Override
4946
public ExecutableAction create(VoyageAIRerankModel model, Map<String, Object> taskSettings) {
5047
var overriddenModel = VoyageAIRerankModel.of(model, taskSettings);
51-
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(
52-
overriddenModel.getServiceSettings().getCommonSettings().uri(),
53-
"VoyageAI rerank"
54-
);
48+
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(overriddenModel.uri(), "VoyageAI rerank");
5549
var requestCreator = VoyageAIRerankRequestManager.of(overriddenModel, serviceComponents.threadPool());
5650
return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage);
5751
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/VoyageAIRequestManager.java

Lines changed: 27 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -10,9 +10,33 @@
1010
import org.elasticsearch.threadpool.ThreadPool;
1111
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIModel;
1212

13+
import java.util.Map;
1314
import java.util.Objects;
1415

1516
abstract class VoyageAIRequestManager extends BaseRequestManager {
17+
private static final String DEFAULT_MODEL_FAMILY = "default_model_family";
18+
private static final Map<String, String> MODEL_TO_MODEL_FAMILY = Map.of(
19+
"voyage-multimodal-3",
20+
"embed_multimodal",
21+
"voyage-3-large",
22+
"embed_large",
23+
"voyage-code-3",
24+
"embed_large",
25+
"voyage-3",
26+
"embed_medium",
27+
"voyage-3-lite",
28+
"embed_small",
29+
"voyage-finance-2",
30+
"embed_large",
31+
"voyage-law-2",
32+
"embed_large",
33+
"voyage-code-2",
34+
"embed_large",
35+
"rerank-2",
36+
"rerank_large",
37+
"rerank-2-lite",
38+
"rerank_small"
39+
);
1640

1741
protected VoyageAIRequestManager(ThreadPool threadPool, VoyageAIModel model) {
1842
super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings());
@@ -21,8 +45,10 @@ protected VoyageAIRequestManager(ThreadPool threadPool, VoyageAIModel model) {
2145
record RateLimitGrouping(int apiKeyHash) {
2246
public static RateLimitGrouping of(VoyageAIModel model) {
2347
Objects.requireNonNull(model);
48+
String modelId = model.getServiceSettings().modelId();
49+
String modelFamily = MODEL_TO_MODEL_FAMILY.getOrDefault(modelId, DEFAULT_MODEL_FAMILY);
2450

25-
return new RateLimitGrouping(model.apiKey().hashCode());
51+
return new RateLimitGrouping(modelFamily.hashCode());
2652
}
2753
}
2854
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequest.java

Lines changed: 2 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,6 @@
88
package org.elasticsearch.xpack.inference.external.request.voyageai;
99

1010
import org.apache.http.client.methods.HttpPost;
11-
import org.apache.http.client.utils.URIBuilder;
1211
import org.apache.http.entity.ByteArrayEntity;
1312
import org.elasticsearch.common.Strings;
1413
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
@@ -19,7 +18,6 @@
1918
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings;
2019

2120
import java.net.URI;
22-
import java.net.URISyntaxException;
2321
import java.nio.charset.StandardCharsets;
2422
import java.util.List;
2523
import java.util.Objects;
@@ -36,7 +34,7 @@ public class VoyageAIEmbeddingsRequest extends VoyageAIRequest {
3634
public VoyageAIEmbeddingsRequest(List<String> input, VoyageAIEmbeddingsModel embeddingsModel) {
3735
Objects.requireNonNull(embeddingsModel);
3836

39-
account = VoyageAIAccount.of(embeddingsModel, VoyageAIEmbeddingsRequest::buildDefaultUri);
37+
account = VoyageAIAccount.of(embeddingsModel);
4038
this.input = Objects.requireNonNull(input);
4139
serviceSettings = embeddingsModel.getServiceSettings();
4240
taskSettings = embeddingsModel.getTaskSettings();
@@ -54,7 +52,7 @@ public HttpRequest createHttpRequest() {
5452
);
5553
httpPost.setEntity(byteEntity);
5654

57-
decorateWithAuthHeader(httpPost, account);
55+
decorateWithHeaders(httpPost, account);
5856

5957
return new HttpRequest(httpPost, getInferenceEntityId());
6058
}
@@ -86,11 +84,4 @@ public VoyageAIEmbeddingsTaskSettings getTaskSettings() {
8684
public VoyageAIEmbeddingsServiceSettings getServiceSettings() {
8785
return serviceSettings;
8886
}
89-
90-
public static URI buildDefaultUri() throws URISyntaxException {
91-
return new URIBuilder().setScheme("https")
92-
.setHost(VoyageAIUtils.HOST)
93-
.setPathSegments(VoyageAIUtils.VERSION_1, VoyageAIUtils.EMBEDDINGS_PATH)
94-
.build();
95-
}
9687
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRequest.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@
1717

1818
public abstract class VoyageAIRequest implements Request {
1919

20-
public static void decorateWithAuthHeader(HttpPost request, VoyageAIAccount account) {
20+
public static void decorateWithHeaders(HttpPost request, VoyageAIAccount account) {
2121
request.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());
2222
request.setHeader(createAuthBearerHeader(account.apiKey()));
2323
request.setHeader(VoyageAIUtils.createRequestSourceHeader());

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequest.java

Lines changed: 2 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,6 @@
88
package org.elasticsearch.xpack.inference.external.request.voyageai;
99

1010
import org.apache.http.client.methods.HttpPost;
11-
import org.apache.http.client.utils.URIBuilder;
1211
import org.apache.http.entity.ByteArrayEntity;
1312
import org.elasticsearch.common.Strings;
1413
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
@@ -18,7 +17,6 @@
1817
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankTaskSettings;
1918

2019
import java.net.URI;
21-
import java.net.URISyntaxException;
2220
import java.nio.charset.StandardCharsets;
2321
import java.util.List;
2422
import java.util.Objects;
@@ -35,7 +33,7 @@ public class VoyageAIRerankRequest extends VoyageAIRequest {
3533
public VoyageAIRerankRequest(String query, List<String> input, VoyageAIRerankModel model) {
3634
Objects.requireNonNull(model);
3735

38-
this.account = VoyageAIAccount.of(model, VoyageAIRerankRequest::buildDefaultUri);
36+
this.account = VoyageAIAccount.of(model);
3937
this.input = Objects.requireNonNull(input);
4038
this.query = Objects.requireNonNull(query);
4139
taskSettings = model.getTaskSettings();
@@ -52,7 +50,7 @@ public HttpRequest createHttpRequest() {
5250
);
5351
httpPost.setEntity(byteEntity);
5452

55-
decorateWithAuthHeader(httpPost, account);
53+
decorateWithHeaders(httpPost, account);
5654

5755
return new HttpRequest(httpPost, getInferenceEntityId());
5856
}
@@ -76,11 +74,4 @@ public Request truncate() {
7674
public boolean[] getTruncationInfo() {
7775
return null;
7876
}
79-
80-
public static URI buildDefaultUri() throws URISyntaxException {
81-
return new URIBuilder().setScheme("https")
82-
.setHost(VoyageAIUtils.HOST)
83-
.setPathSegments(VoyageAIUtils.VERSION_1, VoyageAIUtils.RERANK_PATH)
84-
.build();
85-
}
8677
}

0 commit comments

Comments
 (0)