Skip to content

Commit d0f328d

Browse files
add default inference endpoint for Elastic Inference Service rerank (elastic#129681)
* add Elastic Inference Service rerank default inference endpoint * [CI] Auto commit changes from spotless * fix integ tests * update mock Elastic Inference Service authorization response * fix rerank service test --------- Co-authored-by: elasticsearchmachine <[email protected]> (cherry picked from commit cef717c) # Conflicts: # x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java # x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java
1 parent 09e312c commit d0f328d

File tree

8 files changed

+74
-8
lines changed

8 files changed

+74
-8
lines changed

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ public void testGetDefaultEndpoints() throws IOException {
3333
var allModels = getAllModels();
3434
var chatCompletionModels = getModels("_all", TaskType.CHAT_COMPLETION);
3535

36-
assertThat(allModels, hasSize(5));
36+
assertThat(allModels, hasSize(6));
3737
assertThat(chatCompletionModels, hasSize(1));
3838

3939
for (var model : chatCompletionModels) {
@@ -42,6 +42,7 @@ public void testGetDefaultEndpoints() throws IOException {
4242

4343
assertInferenceIdTaskType(allModels, ".rainbow-sprinkles-elastic", TaskType.CHAT_COMPLETION);
4444
assertInferenceIdTaskType(allModels, ".elser-v2-elastic", TaskType.SPARSE_EMBEDDING);
45+
assertInferenceIdTaskType(allModels, ".rerank-v1-elastic", TaskType.RERANK);
4546
}
4647

4748
private static void assertInferenceIdTaskType(List<Map<String, Object>> models, String inferenceId, TaskType taskType) {

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.assertStatusOkOrCreated;
2222
import static org.hamcrest.Matchers.containsInAnyOrder;
23+
import static org.hamcrest.Matchers.equalTo;
2324

2425
public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
2526

@@ -107,6 +108,11 @@ private Iterable<String> providersFor(TaskType taskType) throws IOException {
107108
}
108109

109110
public void testGetServicesWithRerankTaskType() throws IOException {
111+
List<Object> services = getServices(TaskType.RERANK);
112+
assertThat(services.size(), equalTo(10));
113+
114+
var providers = providers(services);
115+
110116
assertThat(
111117
providersFor(TaskType.RERANK),
112118
containsInAnyOrder(
@@ -120,7 +126,9 @@ public void testGetServicesWithRerankTaskType() throws IOException {
120126
"test_reranking_service",
121127
"voyageai",
122128
"hugging_face",
123-
"amazon_sagemaker"
129+
"amazon_sagemaker",
130+
"hugging_face",
131+
"elastic"
124132
).toArray()
125133
)
126134
);

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ public void enqueueAuthorizeAllModelsResponse() {
4141
{
4242
"model_name": "elser-v2",
4343
"task_types": ["embed/text/sparse"]
44+
},
45+
{
46+
"model_name": "rerank-v1",
47+
"task_types": ["rerank/text/text-similarity"]
4448
}
4549
]
4650
}

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,10 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
198198
{
199199
"model_name": "elser-v2",
200200
"task_types": ["embed/text/sparse"]
201+
},
202+
{
203+
"model_name": "rerank-v1",
204+
"task_types": ["rerank/text/text-similarity"]
201205
}
202206
]
203207
}
@@ -222,16 +226,25 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
222226
".rainbow-sprinkles-elastic",
223227
MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME),
224228
service
229+
),
230+
new InferenceService.DefaultConfigId(
231+
".rerank-v1-elastic",
232+
MinimalServiceSettings.rerank(ElasticInferenceService.NAME),
233+
service
225234
)
226235
)
227236
)
228237
);
229-
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING)));
238+
assertThat(
239+
service.supportedTaskTypes(),
240+
is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING, TaskType.RERANK))
241+
);
230242

231243
PlainActionFuture<List<Model>> listener = new PlainActionFuture<>();
232244
service.defaultConfigs(listener);
233245
assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".elser-v2-elastic"));
234246
assertThat(listener.actionGet(TIMEOUT).get(1).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));
247+
assertThat(listener.actionGet(TIMEOUT).get(2).getConfigurations().getInferenceEntityId(), is(".rerank-v1-elastic"));
235248

236249
var getModelListener = new PlainActionFuture<UnparsedModel>();
237250
// persists the default endpoints
@@ -249,6 +262,10 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
249262
{
250263
"model_name": "elser-v2",
251264
"task_types": ["embed/text/sparse"]
265+
},
266+
{
267+
"model_name": "rerank-v1",
268+
"task_types": ["rerank/text/text-similarity"]
252269
}
253270
]
254271
}
@@ -268,11 +285,16 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
268285
".elser-v2-elastic",
269286
MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME),
270287
service
288+
),
289+
new InferenceService.DefaultConfigId(
290+
".rerank-v1-elastic",
291+
MinimalServiceSettings.rerank(ElasticInferenceService.NAME),
292+
service
271293
)
272294
)
273295
)
274296
);
275-
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING)));
297+
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.RERANK)));
276298

277299
var getModelListener = new PlainActionFuture<UnparsedModel>();
278300
modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
5656
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
5757
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel;
58+
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankServiceSettings;
5859
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
5960
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
6061

@@ -98,6 +99,10 @@ public class ElasticInferenceService extends SenderService {
9899
static final String DEFAULT_ELSER_MODEL_ID_V2 = "elser-v2";
99100
static final String DEFAULT_ELSER_ENDPOINT_ID_V2 = defaultEndpointId(DEFAULT_ELSER_MODEL_ID_V2);
100101

102+
// rerank-v1
103+
static final String DEFAULT_RERANK_MODEL_ID_V1 = "rerank-v1";
104+
static final String DEFAULT_RERANK_ENDPOINT_ID_V1 = defaultEndpointId(DEFAULT_RERANK_MODEL_ID_V1);
105+
101106
/**
102107
* The task types that the {@link InferenceAction.Request} can accept.
103108
*/
@@ -163,6 +168,19 @@ private static Map<String, DefaultModelConfig> initDefaultEndpoints(
163168
ChunkingSettingsBuilder.DEFAULT_SETTINGS
164169
),
165170
MinimalServiceSettings.sparseEmbedding(NAME)
171+
),
172+
DEFAULT_RERANK_MODEL_ID_V1,
173+
new DefaultModelConfig(
174+
new ElasticInferenceServiceRerankModel(
175+
DEFAULT_RERANK_ENDPOINT_ID_V1,
176+
TaskType.RERANK,
177+
NAME,
178+
new ElasticInferenceServiceRerankServiceSettings(DEFAULT_RERANK_MODEL_ID_V1, null),
179+
EmptyTaskSettings.INSTANCE,
180+
EmptySecretSettings.INSTANCE,
181+
elasticInferenceServiceComponents
182+
),
183+
MinimalServiceSettings.rerank(NAME)
166184
)
167185
);
168186
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/rerank/ElasticInferenceServiceRerankModel.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ public URI uri() {
8787
private URI createUri() throws ElasticsearchStatusException {
8888
try {
8989
// TODO, consider transforming the base URL into a URI for better error handling.
90-
return new URI(elasticInferenceServiceComponents().elasticInferenceServiceUrl() + "/api/v1/rerank");
90+
return new URI(elasticInferenceServiceComponents().elasticInferenceServiceUrl() + "/api/v1/rerank/text/text-similarity");
9191
} catch (URISyntaxException e) {
9292
throw new ElasticsearchStatusException(
9393
"Failed to create URI for service ["

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntity.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@ public class ElasticInferenceServiceAuthorizationResponseEntity implements Infer
4343
"embed/text/sparse",
4444
TaskType.SPARSE_EMBEDDING,
4545
"chat",
46-
TaskType.CHAT_COMPLETION
46+
TaskType.CHAT_COMPLETION,
47+
"rerank/text/text-similarity",
48+
TaskType.RERANK
4749
);
4850

4951
@SuppressWarnings("unchecked")

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1291,6 +1291,10 @@ public void testDefaultConfigs_Returns_DefaultEndpoints_WhenTaskTypeIsCorrect()
12911291
{
12921292
"model_name": "elser-v2",
12931293
"task_types": ["embed/text/sparse"]
1294+
},
1295+
{
1296+
"model_name": "rerank-v1",
1297+
"task_types": ["rerank/text/text-similarity"]
12941298
}
12951299
]
12961300
}
@@ -1316,18 +1320,25 @@ public void testDefaultConfigs_Returns_DefaultEndpoints_WhenTaskTypeIsCorrect()
13161320
".rainbow-sprinkles-elastic",
13171321
MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME),
13181322
service
1323+
),
1324+
new InferenceService.DefaultConfigId(
1325+
".rerank-v1-elastic",
1326+
MinimalServiceSettings.rerank(ElasticInferenceService.NAME),
1327+
service
13191328
)
13201329
)
13211330
)
13221331
);
1323-
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING)));
1332+
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING, TaskType.RERANK)));
13241333

13251334
PlainActionFuture<List<Model>> listener = new PlainActionFuture<>();
13261335
service.defaultConfigs(listener);
13271336
var models = listener.actionGet(TIMEOUT);
1328-
assertThat(models.size(), is(2));
1337+
assertThat(models.size(), is(3));
13291338
assertThat(models.get(0).getConfigurations().getInferenceEntityId(), is(".elser-v2-elastic"));
13301339
assertThat(models.get(1).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));
1340+
assertThat(models.get(2).getConfigurations().getInferenceEntityId(), is(".rerank-v1-elastic"));
1341+
13311342
}
13321343
}
13331344

0 commit comments

Comments
 (0)