Skip to content

Commit 9cf89a6

Browse files
[ML] Inference API removing _unified and using _stream instead (elastic#121804) (elastic#122045)
* Adding proxy action * [CI] Auto commit changes from spotless * Incrementing reference count for body content and fixing tests * [CI] Auto commit changes from spotless * Refactoring * Update x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceActionProxy.java Co-authored-by: David Kyle <[email protected]> * Addressing feedback --------- Co-authored-by: elasticsearchmachine <[email protected]> Co-authored-by: David Kyle <[email protected]> (cherry picked from commit ab48235) # Conflicts: # x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java
1 parent dbb052d commit 9cf89a6

File tree

22 files changed

+559
-245
lines changed

22 files changed

+559
-245
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
public class InferenceAction extends ActionType<InferenceAction.Response> {
4848

4949
public static final InferenceAction INSTANCE = new InferenceAction();
50-
public static final String NAME = "cluster:monitor/xpack/inference";
50+
public static final String NAME = "cluster:internal/xpack/inference";
5151

5252
public InferenceAction() {
5353
super(NAME);
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.core.inference.action;
9+
10+
import org.elasticsearch.action.ActionRequest;
11+
import org.elasticsearch.action.ActionRequestValidationException;
12+
import org.elasticsearch.action.ActionType;
13+
import org.elasticsearch.common.bytes.BytesReference;
14+
import org.elasticsearch.common.io.stream.StreamInput;
15+
import org.elasticsearch.common.io.stream.StreamOutput;
16+
import org.elasticsearch.common.xcontent.XContentHelper;
17+
import org.elasticsearch.core.TimeValue;
18+
import org.elasticsearch.inference.TaskType;
19+
import org.elasticsearch.xcontent.XContentType;
20+
21+
import java.io.IOException;
22+
import java.util.Objects;
23+
24+
/**
25+
* This action is used when making a REST request to the inference API. The transport handler
26+
* will then look at the task type in the params (or retrieve it from the persisted model if it wasn't
27+
* included in the params) to determine where this request should be routed. If the task type is chat completion
28+
* then it will be routed to the unified chat completion handler by creating the {@link UnifiedCompletionAction}.
29+
* If not, it will be passed along to {@link InferenceAction}.
30+
*/
31+
public class InferenceActionProxy extends ActionType<InferenceAction.Response> {
32+
public static final InferenceActionProxy INSTANCE = new InferenceActionProxy();
33+
public static final String NAME = "cluster:monitor/xpack/inference/post";
34+
35+
public InferenceActionProxy() {
36+
super(NAME);
37+
}
38+
39+
public static class Request extends ActionRequest {
40+
41+
private final TaskType taskType;
42+
private final String inferenceEntityId;
43+
private final BytesReference content;
44+
private final XContentType contentType;
45+
private final TimeValue timeout;
46+
private final boolean stream;
47+
48+
public Request(
49+
TaskType taskType,
50+
String inferenceEntityId,
51+
BytesReference content,
52+
XContentType contentType,
53+
TimeValue timeout,
54+
boolean stream
55+
) {
56+
this.taskType = taskType;
57+
this.inferenceEntityId = inferenceEntityId;
58+
this.content = content;
59+
this.contentType = contentType;
60+
this.timeout = timeout;
61+
this.stream = stream;
62+
}
63+
64+
public Request(StreamInput in) throws IOException {
65+
super(in);
66+
this.taskType = TaskType.fromStream(in);
67+
this.inferenceEntityId = in.readString();
68+
this.content = in.readBytesReference();
69+
this.contentType = in.readEnum(XContentType.class);
70+
this.timeout = in.readTimeValue();
71+
72+
// streaming is not supported yet for transport traffic
73+
this.stream = false;
74+
}
75+
76+
public TaskType getTaskType() {
77+
return taskType;
78+
}
79+
80+
public String getInferenceEntityId() {
81+
return inferenceEntityId;
82+
}
83+
84+
public BytesReference getContent() {
85+
return content;
86+
}
87+
88+
public XContentType getContentType() {
89+
return contentType;
90+
}
91+
92+
public TimeValue getTimeout() {
93+
return timeout;
94+
}
95+
96+
public boolean isStreaming() {
97+
return stream;
98+
}
99+
100+
@Override
101+
public ActionRequestValidationException validate() {
102+
return null;
103+
}
104+
105+
@Override
106+
public void writeTo(StreamOutput out) throws IOException {
107+
super.writeTo(out);
108+
out.writeString(inferenceEntityId);
109+
taskType.writeTo(out);
110+
out.writeBytesReference(content);
111+
XContentHelper.writeTo(out, contentType);
112+
out.writeTimeValue(timeout);
113+
}
114+
115+
@Override
116+
public boolean equals(Object o) {
117+
if (this == o) return true;
118+
if (o == null || getClass() != o.getClass()) return false;
119+
Request request = (Request) o;
120+
return taskType == request.taskType
121+
&& Objects.equals(inferenceEntityId, request.inferenceEntityId)
122+
&& Objects.equals(content, request.content)
123+
&& contentType == request.contentType
124+
&& timeout == request.timeout
125+
&& stream == request.stream;
126+
}
127+
128+
@Override
129+
public int hashCode() {
130+
return Objects.hash(taskType, inferenceEntityId, content, contentType, timeout, stream);
131+
}
132+
}
133+
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionAction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
public class UnifiedCompletionAction extends ActionType<InferenceAction.Response> {
2323
public static final UnifiedCompletionAction INSTANCE = new UnifiedCompletionAction();
24-
public static final String NAME = "cluster:monitor/xpack/inference/unified";
24+
public static final String NAME = "cluster:internal/xpack/inference/unified";
2525

2626
public UnifiedCompletionAction() {
2727
super(NAME);

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/store/ReservedRolesStoreTests.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4219,7 +4219,7 @@ public void testInferenceAdminRole() {
42194219
assertThat(roleDescriptor.getMetadata(), hasEntry("_reserved", true));
42204220

42214221
Role role = Role.buildFromRoleDescriptor(roleDescriptor, new FieldPermissionsCache(Settings.EMPTY), RESTRICTED_INDICES);
4222-
assertTrue(role.cluster().check("cluster:monitor/xpack/inference", request, authentication));
4222+
assertTrue(role.cluster().check("cluster:monitor/xpack/inference/post", request, authentication));
42234223
assertTrue(role.cluster().check("cluster:monitor/xpack/inference/get", request, authentication));
42244224
assertTrue(role.cluster().check("cluster:admin/xpack/inference/put", request, authentication));
42254225
assertTrue(role.cluster().check("cluster:admin/xpack/inference/delete", request, authentication));
@@ -4239,10 +4239,9 @@ public void testInferenceUserRole() {
42394239
assertThat(roleDescriptor.getMetadata(), hasEntry("_reserved", true));
42404240

42414241
Role role = Role.buildFromRoleDescriptor(roleDescriptor, new FieldPermissionsCache(Settings.EMPTY), RESTRICTED_INDICES);
4242-
assertTrue(role.cluster().check("cluster:monitor/xpack/inference", request, authentication));
4242+
assertTrue(role.cluster().check("cluster:monitor/xpack/inference/post", request, authentication));
42434243
assertTrue(role.cluster().check("cluster:monitor/xpack/inference/get", request, authentication));
42444244
assertFalse(role.cluster().check("cluster:admin/xpack/inference/put", request, authentication));
4245-
assertTrue(role.cluster().check("cluster:monitor/xpack/inference/unified", request, authentication));
42464245
assertFalse(role.cluster().check("cluster:admin/xpack/inference/delete", request, authentication));
42474246
assertTrue(role.cluster().check("cluster:monitor/xpack/ml/trained_models/deployment/infer", request, authentication));
42484247
assertFalse(role.cluster().check("cluster:admin/xpack/ml/trained_models/deployment/start", request, authentication));

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -360,8 +360,7 @@ protected Deque<ServerSentEvent> unifiedCompletionInferOnMockService(
360360
List<String> input,
361361
@Nullable Consumer<Response> responseConsumerCallback
362362
) throws Exception {
363-
var route = randomBoolean() ? "_stream" : "_unified"; // TODO remove unified route
364-
var endpoint = Strings.format("_inference/%s/%s/%s", taskType, modelId, route);
363+
var endpoint = Strings.format("_inference/%s/%s/_stream", taskType, modelId);
365364
return callAsyncUnified(endpoint, input, "user", responseConsumerCallback);
366365
}
367366

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
5959
import org.elasticsearch.xpack.core.inference.action.GetInferenceServicesAction;
6060
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
61+
import org.elasticsearch.xpack.core.inference.action.InferenceActionProxy;
6162
import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction;
6263
import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction;
6364
import org.elasticsearch.xpack.core.inference.action.UpdateInferenceModelAction;
@@ -67,6 +68,7 @@
6768
import org.elasticsearch.xpack.inference.action.TransportGetInferenceModelAction;
6869
import org.elasticsearch.xpack.inference.action.TransportGetInferenceServicesAction;
6970
import org.elasticsearch.xpack.inference.action.TransportInferenceAction;
71+
import org.elasticsearch.xpack.inference.action.TransportInferenceActionProxy;
7072
import org.elasticsearch.xpack.inference.action.TransportInferenceUsageAction;
7173
import org.elasticsearch.xpack.inference.action.TransportPutInferenceModelAction;
7274
import org.elasticsearch.xpack.inference.action.TransportUnifiedCompletionInferenceAction;
@@ -104,7 +106,6 @@
104106
import org.elasticsearch.xpack.inference.rest.RestInferenceAction;
105107
import org.elasticsearch.xpack.inference.rest.RestPutInferenceModelAction;
106108
import org.elasticsearch.xpack.inference.rest.RestStreamInferenceAction;
107-
import org.elasticsearch.xpack.inference.rest.RestUnifiedCompletionInferenceAction;
108109
import org.elasticsearch.xpack.inference.rest.RestUpdateInferenceModelAction;
109110
import org.elasticsearch.xpack.inference.services.ServiceComponents;
110111
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchService;
@@ -195,6 +196,7 @@ public InferencePlugin(Settings settings) {
195196
public List<ActionHandler<? extends ActionRequest, ? extends ActionResponse>> getActions() {
196197
return List.of(
197198
new ActionHandler<>(InferenceAction.INSTANCE, TransportInferenceAction.class),
199+
new ActionHandler<>(InferenceActionProxy.INSTANCE, TransportInferenceActionProxy.class),
198200
new ActionHandler<>(GetInferenceModelAction.INSTANCE, TransportGetInferenceModelAction.class),
199201
new ActionHandler<>(PutInferenceModelAction.INSTANCE, TransportPutInferenceModelAction.class),
200202
new ActionHandler<>(UpdateInferenceModelAction.INSTANCE, TransportUpdateInferenceModelAction.class),
@@ -226,8 +228,7 @@ public List<RestHandler> getRestHandlers(
226228
new RestUpdateInferenceModelAction(),
227229
new RestDeleteInferenceEndpointAction(),
228230
new RestGetInferenceDiagnosticsAction(),
229-
new RestGetInferenceServicesAction(),
230-
new RestUnifiedCompletionInferenceAction(threadPoolSetOnce)
231+
new RestGetInferenceServicesAction()
231232
);
232233
}
233234

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.action;
9+
10+
import org.elasticsearch.ElasticsearchStatusException;
11+
import org.elasticsearch.action.ActionListener;
12+
import org.elasticsearch.action.support.ActionFilters;
13+
import org.elasticsearch.action.support.HandledTransportAction;
14+
import org.elasticsearch.client.internal.Client;
15+
import org.elasticsearch.common.util.concurrent.EsExecutors;
16+
import org.elasticsearch.common.xcontent.XContentHelper;
17+
import org.elasticsearch.inference.TaskType;
18+
import org.elasticsearch.inference.UnparsedModel;
19+
import org.elasticsearch.injection.guice.Inject;
20+
import org.elasticsearch.rest.RestStatus;
21+
import org.elasticsearch.tasks.Task;
22+
import org.elasticsearch.transport.TransportService;
23+
import org.elasticsearch.xcontent.XContentParserConfiguration;
24+
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
25+
import org.elasticsearch.xpack.core.inference.action.InferenceActionProxy;
26+
import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction;
27+
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
28+
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
29+
30+
import java.io.IOException;
31+
32+
import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN;
33+
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
34+
35+
public class TransportInferenceActionProxy extends HandledTransportAction<InferenceActionProxy.Request, InferenceAction.Response> {
36+
private final ModelRegistry modelRegistry;
37+
private final Client client;
38+
39+
@Inject
40+
public TransportInferenceActionProxy(
41+
TransportService transportService,
42+
ActionFilters actionFilters,
43+
ModelRegistry modelRegistry,
44+
Client client
45+
) {
46+
super(
47+
InferenceActionProxy.NAME,
48+
transportService,
49+
actionFilters,
50+
InferenceActionProxy.Request::new,
51+
EsExecutors.DIRECT_EXECUTOR_SERVICE
52+
);
53+
54+
this.modelRegistry = modelRegistry;
55+
this.client = client;
56+
}
57+
58+
@Override
59+
protected void doExecute(Task task, InferenceActionProxy.Request request, ActionListener<InferenceAction.Response> listener) {
60+
try {
61+
ActionListener<UnparsedModel> getModelListener = listener.delegateFailureAndWrap((l, unparsedModel) -> {
62+
if (unparsedModel.taskType() == TaskType.CHAT_COMPLETION) {
63+
sendUnifiedCompletionRequest(request, l);
64+
} else {
65+
sendInferenceActionRequest(request, l);
66+
}
67+
});
68+
69+
if (request.getTaskType() == TaskType.ANY) {
70+
modelRegistry.getModelWithSecrets(request.getInferenceEntityId(), getModelListener);
71+
} else if (request.getTaskType() == TaskType.CHAT_COMPLETION) {
72+
sendUnifiedCompletionRequest(request, listener);
73+
} else {
74+
sendInferenceActionRequest(request, listener);
75+
}
76+
} catch (Exception e) {
77+
listener.onFailure(e);
78+
}
79+
}
80+
81+
private void sendUnifiedCompletionRequest(InferenceActionProxy.Request request, ActionListener<InferenceAction.Response> listener) {
82+
// format any validation exceptions from the rest -> transport path as UnifiedChatCompletionException
83+
var unifiedErrorFormatListener = listener.delegateResponse((l, e) -> l.onFailure(UnifiedChatCompletionException.fromThrowable(e)));
84+
85+
try {
86+
if (request.isStreaming() == false) {
87+
throw new ElasticsearchStatusException(
88+
"The [chat_completion] task type only supports streaming, please try again with the _stream API",
89+
RestStatus.BAD_REQUEST
90+
);
91+
}
92+
93+
UnifiedCompletionAction.Request unifiedRequest;
94+
try (
95+
var parser = XContentHelper.createParser(XContentParserConfiguration.EMPTY, request.getContent(), request.getContentType())
96+
) {
97+
unifiedRequest = UnifiedCompletionAction.Request.parseRequest(
98+
request.getInferenceEntityId(),
99+
request.getTaskType(),
100+
request.getTimeout(),
101+
parser
102+
);
103+
}
104+
105+
executeAsyncWithOrigin(client, INFERENCE_ORIGIN, UnifiedCompletionAction.INSTANCE, unifiedRequest, unifiedErrorFormatListener);
106+
} catch (Exception e) {
107+
unifiedErrorFormatListener.onFailure(e);
108+
}
109+
}
110+
111+
private void sendInferenceActionRequest(InferenceActionProxy.Request request, ActionListener<InferenceAction.Response> listener)
112+
throws IOException {
113+
InferenceAction.Request.Builder inferenceActionRequestBuilder;
114+
try (var parser = XContentHelper.createParser(XContentParserConfiguration.EMPTY, request.getContent(), request.getContentType())) {
115+
inferenceActionRequestBuilder = InferenceAction.Request.parseRequest(
116+
request.getInferenceEntityId(),
117+
request.getTaskType(),
118+
parser
119+
);
120+
inferenceActionRequestBuilder.setInferenceTimeout(request.getTimeout()).setStream(request.isStreaming());
121+
}
122+
123+
executeAsyncWithOrigin(client, INFERENCE_ORIGIN, InferenceAction.INSTANCE, inferenceActionRequestBuilder.build(), listener);
124+
}
125+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/BaseInferenceAction.java

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import org.elasticsearch.rest.RestChannel;
1616
import org.elasticsearch.rest.RestRequest;
1717
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
18+
import org.elasticsearch.xpack.core.inference.action.InferenceActionProxy;
1819

1920
import java.io.IOException;
2021

@@ -41,21 +42,22 @@ static TimeValue parseTimeout(RestRequest restRequest) {
4142
@Override
4243
protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException {
4344
var params = parseParams(restRequest);
45+
var content = restRequest.requiredContent();
46+
var inferTimeout = parseTimeout(restRequest);
4447

45-
InferenceAction.Request.Builder requestBuilder;
46-
try (var parser = restRequest.contentParser()) {
47-
requestBuilder = InferenceAction.Request.parseRequest(params.inferenceEntityId(), params.taskType(), parser);
48-
}
48+
var request = new InferenceActionProxy.Request(
49+
params.taskType(),
50+
params.inferenceEntityId(),
51+
content,
52+
restRequest.getXContentType(),
53+
inferTimeout,
54+
shouldStream()
55+
);
4956

50-
var inferTimeout = parseTimeout(restRequest);
51-
requestBuilder.setInferenceTimeout(inferTimeout);
52-
var request = prepareInferenceRequest(requestBuilder);
53-
return channel -> client.execute(InferenceAction.INSTANCE, request, listener(channel));
57+
return channel -> client.execute(InferenceActionProxy.INSTANCE, request, ActionListener.withRef(listener(channel), content));
5458
}
5559

56-
protected InferenceAction.Request prepareInferenceRequest(InferenceAction.Request.Builder builder) {
57-
return builder.build();
58-
}
60+
protected abstract boolean shouldStream();
5961

6062
protected abstract ActionListener<InferenceAction.Response> listener(RestChannel channel);
6163
}

0 commit comments

Comments
 (0)