Skip to content

Commit f8ec5b6

Browse files
committed
[Inference API] Auto-propagate product origin for every subclass of ElasticInferenceServiceRequest (elastic#123141)
(cherry picked from commit 08aa668)
1 parent 494ec9f commit f8ec5b6

11 files changed

+125
-19
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ElasticInferenceServiceSparseEmbeddingsRequestManager.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import org.elasticsearch.action.ActionListener;
1313
import org.elasticsearch.inference.InferenceServiceResults;
1414
import org.elasticsearch.inference.InputType;
15+
import org.elasticsearch.tasks.Task;
1516
import org.elasticsearch.xpack.inference.common.Truncator;
1617
import org.elasticsearch.xpack.inference.external.elastic.ElasticInferenceServiceResponseHandler;
1718
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
@@ -43,6 +44,8 @@ public class ElasticInferenceServiceSparseEmbeddingsRequestManager extends Elast
4344

4445
private final InputType inputType;
4546

47+
private final String productOrigin;
48+
4649
private static ResponseHandler createSparseEmbeddingsHandler() {
4750
return new ElasticInferenceServiceResponseHandler(
4851
String.format(Locale.ROOT, "%s sparse embeddings", ELASTIC_INFERENCE_SERVICE_IDENTIFIER),
@@ -60,6 +63,7 @@ public ElasticInferenceServiceSparseEmbeddingsRequestManager(
6063
this.model = model;
6164
this.truncator = serviceComponents.truncator();
6265
this.traceContext = traceContext;
66+
this.productOrigin = serviceComponents.threadPool().getThreadContext().getHeader(Task.X_ELASTIC_PRODUCT_ORIGIN_HTTP_HEADER);
6367
this.inputType = inputType;
6468
}
6569

@@ -78,6 +82,7 @@ public void execute(
7882
truncatedInput,
7983
model,
8084
traceContext,
85+
productOrigin,
8186
inputType
8287
);
8388
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ElasticInferenceServiceUnifiedCompletionRequestManager.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.apache.logging.log4j.Logger;
1212
import org.elasticsearch.action.ActionListener;
1313
import org.elasticsearch.inference.InferenceServiceResults;
14+
import org.elasticsearch.tasks.Task;
1415
import org.elasticsearch.threadpool.ThreadPool;
1516
import org.elasticsearch.xpack.inference.external.elastic.ElasticInferenceServiceUnifiedChatCompletionResponseHandler;
1617
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
@@ -43,6 +44,7 @@ public static ElasticInferenceServiceUnifiedCompletionRequestManager of(
4344

4445
private final ElasticInferenceServiceCompletionModel model;
4546
private final TraceContext traceContext;
47+
private final String productOrigin;
4648

4749
private ElasticInferenceServiceUnifiedCompletionRequestManager(
4850
ElasticInferenceServiceCompletionModel model,
@@ -52,6 +54,7 @@ private ElasticInferenceServiceUnifiedCompletionRequestManager(
5254
super(threadPool, model);
5355
this.model = model;
5456
this.traceContext = traceContext;
57+
this.productOrigin = threadPool.getThreadContext().getHeader(Task.X_ELASTIC_PRODUCT_ORIGIN_HTTP_HEADER);
5558
}
5659

5760
@Override
@@ -65,7 +68,8 @@ public void execute(
6568
ElasticInferenceServiceUnifiedChatCompletionRequest request = new ElasticInferenceServiceUnifiedChatCompletionRequest(
6669
inferenceInputs.castTo(UnifiedChatInput.class),
6770
model,
68-
traceContext
71+
traceContext,
72+
productOrigin
6973
);
7074

7175
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceAuthorizationRequest.java

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
package org.elasticsearch.xpack.inference.external.request.elastic;
99

1010
import org.apache.http.client.methods.HttpGet;
11+
import org.apache.http.client.methods.HttpRequestBase;
1112
import org.elasticsearch.ElasticsearchStatusException;
1213
import org.elasticsearch.rest.RestStatus;
13-
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
1414
import org.elasticsearch.xpack.inference.external.request.Request;
1515
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
1616
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
@@ -20,12 +20,13 @@
2020
import java.net.URISyntaxException;
2121
import java.util.Objects;
2222

23-
public class ElasticInferenceServiceAuthorizationRequest implements ElasticInferenceServiceRequest {
23+
public class ElasticInferenceServiceAuthorizationRequest extends ElasticInferenceServiceRequest {
2424

2525
private final URI uri;
2626
private final TraceContextHandler traceContextHandler;
2727

28-
public ElasticInferenceServiceAuthorizationRequest(String url, TraceContext traceContext) {
28+
public ElasticInferenceServiceAuthorizationRequest(String url, TraceContext traceContext, String productOrigin) {
29+
super(productOrigin);
2930
this.uri = createUri(Objects.requireNonNull(url));
3031
this.traceContextHandler = new TraceContextHandler(traceContext);
3132
}
@@ -44,11 +45,11 @@ private URI createUri(String url) throws ElasticsearchStatusException {
4445
}
4546

4647
@Override
47-
public HttpRequest createHttpRequest() {
48+
public HttpRequestBase createHttpRequestBase() {
4849
var httpGet = new HttpGet(uri);
4950
traceContextHandler.propagateTraceContext(httpGet);
5051

51-
return new HttpRequest(httpGet, getInferenceEntityId());
52+
return httpGet;
5253
}
5354

5455
public TraceContext getTraceContext() {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRequest.java

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,30 @@
77

88
package org.elasticsearch.xpack.inference.external.request.elastic;
99

10+
import org.apache.http.client.methods.HttpRequestBase;
11+
import org.elasticsearch.tasks.Task;
12+
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
1013
import org.elasticsearch.xpack.inference.external.request.Request;
1114

12-
public interface ElasticInferenceServiceRequest extends Request {}
15+
public abstract class ElasticInferenceServiceRequest implements Request {
16+
17+
private final String productOrigin;
18+
19+
public ElasticInferenceServiceRequest(String productOrigin) {
20+
this.productOrigin = productOrigin;
21+
}
22+
23+
public String getProductOrigin() {
24+
return productOrigin;
25+
}
26+
27+
@Override
28+
public final HttpRequest createHttpRequest() {
29+
HttpRequestBase request = createHttpRequestBase();
30+
// TODO: consider moving tracing here, too
31+
request.setHeader(Task.X_ELASTIC_PRODUCT_ORIGIN_HTTP_HEADER, productOrigin);
32+
return new HttpRequest(request, getInferenceEntityId());
33+
}
34+
35+
protected abstract HttpRequestBase createHttpRequestBase();
36+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequest.java

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@
99

1010
import org.apache.http.HttpHeaders;
1111
import org.apache.http.client.methods.HttpPost;
12+
import org.apache.http.client.methods.HttpRequestBase;
1213
import org.apache.http.entity.ByteArrayEntity;
1314
import org.apache.http.message.BasicHeader;
1415
import org.elasticsearch.common.Strings;
1516
import org.elasticsearch.inference.InputType;
1617
import org.elasticsearch.xcontent.XContentType;
1718
import org.elasticsearch.xpack.inference.common.Truncator;
18-
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
1919
import org.elasticsearch.xpack.inference.external.request.Request;
2020
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModel;
2121
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceUsageContext;
@@ -26,7 +26,7 @@
2626
import java.nio.charset.StandardCharsets;
2727
import java.util.Objects;
2828

29-
public class ElasticInferenceServiceSparseEmbeddingsRequest implements ElasticInferenceServiceRequest {
29+
public class ElasticInferenceServiceSparseEmbeddingsRequest extends ElasticInferenceServiceRequest {
3030

3131
private final URI uri;
3232
private final ElasticInferenceServiceSparseEmbeddingsModel model;
@@ -40,8 +40,10 @@ public ElasticInferenceServiceSparseEmbeddingsRequest(
4040
Truncator.TruncationResult truncationResult,
4141
ElasticInferenceServiceSparseEmbeddingsModel model,
4242
TraceContext traceContext,
43+
String productOrigin,
4344
InputType inputType
4445
) {
46+
super(productOrigin);
4547
this.truncator = truncator;
4648
this.truncationResult = truncationResult;
4749
this.model = Objects.requireNonNull(model);
@@ -51,7 +53,7 @@ public ElasticInferenceServiceSparseEmbeddingsRequest(
5153
}
5254

5355
@Override
54-
public HttpRequest createHttpRequest() {
56+
public HttpRequestBase createHttpRequestBase() {
5557
var httpPost = new HttpPost(uri);
5658
var usageContext = inputTypeToUsageContext(inputType);
5759
var requestEntity = Strings.toString(
@@ -68,7 +70,7 @@ public HttpRequest createHttpRequest() {
6870
traceContextHandler.propagateTraceContext(httpPost);
6971
httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()));
7072

71-
return new HttpRequest(httpPost, getInferenceEntityId());
73+
return httpPost;
7274
}
7375

7476
public TraceContext getTraceContext() {
@@ -93,6 +95,7 @@ public Request truncate() {
9395
truncatedInput,
9496
model,
9597
traceContextHandler.traceContext(),
98+
getProductOrigin(),
9699
inputType
97100
);
98101
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceUnifiedChatCompletionRequest.java

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@
99

1010
import org.apache.http.HttpHeaders;
1111
import org.apache.http.client.methods.HttpPost;
12+
import org.apache.http.client.methods.HttpRequestBase;
1213
import org.apache.http.entity.ByteArrayEntity;
1314
import org.apache.http.message.BasicHeader;
1415
import org.elasticsearch.common.Strings;
1516
import org.elasticsearch.xcontent.XContentType;
1617
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
17-
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
1818
import org.elasticsearch.xpack.inference.external.request.Request;
1919
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
2020
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
@@ -24,7 +24,7 @@
2424
import java.nio.charset.StandardCharsets;
2525
import java.util.Objects;
2626

27-
public class ElasticInferenceServiceUnifiedChatCompletionRequest implements Request {
27+
public class ElasticInferenceServiceUnifiedChatCompletionRequest extends ElasticInferenceServiceRequest {
2828

2929
private final ElasticInferenceServiceCompletionModel model;
3030
private final UnifiedChatInput unifiedChatInput;
@@ -33,15 +33,17 @@ public class ElasticInferenceServiceUnifiedChatCompletionRequest implements Requ
3333
public ElasticInferenceServiceUnifiedChatCompletionRequest(
3434
UnifiedChatInput unifiedChatInput,
3535
ElasticInferenceServiceCompletionModel model,
36-
TraceContext traceContext
36+
TraceContext traceContext,
37+
String productOrigin
3738
) {
39+
super(productOrigin);
3840
this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput);
3941
this.model = Objects.requireNonNull(model);
4042
this.traceContextHandler = new TraceContextHandler(traceContext);
4143
}
4244

4345
@Override
44-
public HttpRequest createHttpRequest() {
46+
public HttpRequestBase createHttpRequestBase() {
4547
var httpPost = new HttpPost(model.uri());
4648
var requestEntity = Strings.toString(
4749
new ElasticInferenceServiceUnifiedChatCompletionRequestEntity(unifiedChatInput, model.getServiceSettings().modelId())
@@ -53,7 +55,7 @@ public HttpRequest createHttpRequest() {
5355
traceContextHandler.propagateTraceContext(httpPost);
5456
httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()));
5557

56-
return new HttpRequest(httpPost, getInferenceEntityId());
58+
return httpPost;
5759
}
5860

5961
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,8 @@ public void getAuthorization(ActionListener<ElasticInferenceServiceAuthorization
108108
requestCompleteLatch.countDown();
109109
});
110110

111-
var request = new ElasticInferenceServiceAuthorizationRequest(baseUrl, getCurrentTraceInfo());
111+
var productOrigin = threadPool.getThreadContext().getHeader(Task.X_ELASTIC_PRODUCT_ORIGIN_HTTP_HEADER);
112+
var request = new ElasticInferenceServiceAuthorizationRequest(baseUrl, getCurrentTraceInfo(), productOrigin);
112113

113114
sender.sendWithoutQueuing(logger, request, AUTH_RESPONSE_HANDLER, DEFAULT_AUTH_TIMEOUT, newListener);
114115
} catch (Exception e) {

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,11 @@ public void testSendWithoutQueuing_SendsRequestAndReceivesResponse() throws Exce
158158
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
159159

160160
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
161-
var request = new ElasticInferenceServiceAuthorizationRequest(getUrl(webServer), new TraceContext("", ""));
161+
var request = new ElasticInferenceServiceAuthorizationRequest(
162+
getUrl(webServer),
163+
new TraceContext("", ""),
164+
randomAlphaOfLength(10)
165+
);
162166
var responseHandler = new ElasticInferenceServiceResponseHandler(
163167
String.format(Locale.ROOT, "%s sparse embeddings", ELASTIC_INFERENCE_SERVICE_IDENTIFIER),
164168
ElasticInferenceServiceAuthorizationResponseEntity::fromResponse

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceAuthorizationRequestTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ public void testCreateUriThrowsForInvalidBaseUrl() {
3030

3131
ElasticsearchStatusException exception = assertThrows(
3232
ElasticsearchStatusException.class,
33-
() -> new ElasticInferenceServiceAuthorizationRequest(invalidUrl, traceContext)
33+
() -> new ElasticInferenceServiceAuthorizationRequest(invalidUrl, traceContext, randomAlphaOfLength(10))
3434
);
3535

3636
assertThat(exception.status(), is(RestStatus.BAD_REQUEST));
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.request.elastic;
9+
10+
import org.apache.http.client.methods.HttpGet;
11+
import org.apache.http.client.methods.HttpRequestBase;
12+
import org.elasticsearch.tasks.Task;
13+
import org.elasticsearch.test.ESTestCase;
14+
import org.elasticsearch.xpack.inference.external.request.Request;
15+
16+
import java.net.URI;
17+
18+
import static org.hamcrest.Matchers.equalTo;
19+
20+
public class ElasticInferenceServiceRequestTests extends ESTestCase {
21+
22+
public void testElasticInferenceServiceRequestSubclasses_Decorate_HttpRequest_WithProductOrigin() {
23+
var productOrigin = "elastic";
24+
var elasticInferenceServiceRequestWrapper = getDummyElasticInferenceServiceRequest(productOrigin);
25+
var httpRequest = elasticInferenceServiceRequestWrapper.createHttpRequest();
26+
var productOriginHeader = httpRequest.httpRequestBase().getFirstHeader(Task.X_ELASTIC_PRODUCT_ORIGIN_HTTP_HEADER);
27+
28+
// Make sure this header only exists once
29+
assertThat(httpRequest.httpRequestBase().getHeaders(Task.X_ELASTIC_PRODUCT_ORIGIN_HTTP_HEADER).length, equalTo(1));
30+
assertThat(productOriginHeader.getValue(), equalTo(productOrigin));
31+
}
32+
33+
private static ElasticInferenceServiceRequest getDummyElasticInferenceServiceRequest(String productOrigin) {
34+
return new ElasticInferenceServiceRequest(productOrigin) {
35+
@Override
36+
protected HttpRequestBase createHttpRequestBase() {
37+
return new HttpGet("http://localhost:8080");
38+
}
39+
40+
@Override
41+
public URI getURI() {
42+
return null;
43+
}
44+
45+
@Override
46+
public Request truncate() {
47+
return null;
48+
}
49+
50+
@Override
51+
public boolean[] getTruncationInfo() {
52+
return new boolean[0];
53+
}
54+
55+
@Override
56+
public String getInferenceEntityId() {
57+
return "";
58+
}
59+
};
60+
}
61+
}

0 commit comments

Comments
 (0)