Skip to content

Commit f6f7b67

Browse files
committed
Adding support for specifying embedding type to Jina AI service settings (elastic#121548)
* Adding embeddings type to Jina AI service settings * Update docs/changelog/121548.yaml * Setting default similarity to L2 norm for binary embedding type (cherry picked from commit 6b2e566) # Conflicts: # server/src/main/java/org/elasticsearch/TransportVersions.java
1 parent 5c34bfb commit f6f7b67

File tree

14 files changed

+906
-140
lines changed

14 files changed

+906
-140
lines changed

docs/changelog/121548.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 121548
2+
summary: Adding support for specifying embedding type to Jina AI service settings
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ static TransportVersion def(int id) {
189189
public static final TransportVersion ESQL_RETRY_ON_SHARD_LEVEL_FAILURE_BACKPORT_8_19 = def(8_841_0_03);
190190
public static final TransportVersion ESQL_SUPPORT_PARTIAL_RESULTS_BACKPORT_8_19 = def(8_841_0_04);
191191
public static final TransportVersion VOYAGE_AI_INTEGRATION_ADDED_BACKPORT_8_X = def(8_841_0_05);
192+
public static final TransportVersion JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED_BACKPORT_8_19 = def(8_841_0_06);
192193

193194
/*
194195
* STOP! READ THIS FIRST! No, really,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIEmbeddingsRequest.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import org.elasticsearch.xpack.inference.external.jinaai.JinaAIAccount;
1515
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
1616
import org.elasticsearch.xpack.inference.external.request.Request;
17+
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType;
1718
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModel;
1819
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings;
1920

@@ -30,6 +31,7 @@ public class JinaAIEmbeddingsRequest extends JinaAIRequest {
3031
private final JinaAIEmbeddingsTaskSettings taskSettings;
3132
private final String model;
3233
private final String inferenceEntityId;
34+
private final JinaAIEmbeddingType embeddingType;
3335

3436
public JinaAIEmbeddingsRequest(List<String> input, JinaAIEmbeddingsModel embeddingsModel) {
3537
Objects.requireNonNull(embeddingsModel);
@@ -38,6 +40,7 @@ public JinaAIEmbeddingsRequest(List<String> input, JinaAIEmbeddingsModel embeddi
3840
this.input = Objects.requireNonNull(input);
3941
taskSettings = embeddingsModel.getTaskSettings();
4042
model = embeddingsModel.getServiceSettings().getCommonSettings().modelId();
43+
embeddingType = embeddingsModel.getServiceSettings().getEmbeddingType();
4144
inferenceEntityId = embeddingsModel.getInferenceEntityId();
4245
}
4346

@@ -46,7 +49,7 @@ public HttpRequest createHttpRequest() {
4649
HttpPost httpPost = new HttpPost(account.uri());
4750

4851
ByteArrayEntity byteEntity = new ByteArrayEntity(
49-
Strings.toString(new JinaAIEmbeddingsRequestEntity(input, taskSettings, model)).getBytes(StandardCharsets.UTF_8)
52+
Strings.toString(new JinaAIEmbeddingsRequestEntity(input, taskSettings, model, embeddingType)).getBytes(StandardCharsets.UTF_8)
5053
);
5154
httpPost.setEntity(byteEntity);
5255

@@ -75,6 +78,10 @@ public boolean[] getTruncationInfo() {
7578
return null;
7679
}
7780

81+
public JinaAIEmbeddingType getEmbeddingType() {
82+
return embeddingType;
83+
}
84+
7885
public static URI buildDefaultUri() throws URISyntaxException {
7986
return new URIBuilder().setScheme("https")
8087
.setHost(JinaAIUtils.HOST)

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIEmbeddingsRequestEntity.java

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.elasticsearch.inference.InputType;
1212
import org.elasticsearch.xcontent.ToXContentObject;
1313
import org.elasticsearch.xcontent.XContentBuilder;
14+
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType;
1415
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings;
1516

1617
import java.io.IOException;
@@ -19,9 +20,12 @@
1920

2021
import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings.invalidInputTypeMessage;
2122

22-
public record JinaAIEmbeddingsRequestEntity(List<String> input, JinaAIEmbeddingsTaskSettings taskSettings, @Nullable String model)
23-
implements
24-
ToXContentObject {
23+
public record JinaAIEmbeddingsRequestEntity(
24+
List<String> input,
25+
JinaAIEmbeddingsTaskSettings taskSettings,
26+
@Nullable String model,
27+
@Nullable JinaAIEmbeddingType embeddingType
28+
) implements ToXContentObject {
2529

2630
private static final String SEARCH_DOCUMENT = "retrieval.passage";
2731
private static final String SEARCH_QUERY = "retrieval.query";
@@ -30,6 +34,7 @@ public record JinaAIEmbeddingsRequestEntity(List<String> input, JinaAIEmbeddings
3034
private static final String INPUT_FIELD = "input";
3135
private static final String MODEL_FIELD = "model";
3236
public static final String TASK_TYPE_FIELD = "task";
37+
static final String EMBEDDING_TYPE_FIELD = "embedding_type";
3338

3439
public JinaAIEmbeddingsRequestEntity {
3540
Objects.requireNonNull(input);
@@ -43,6 +48,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
4348
builder.field(INPUT_FIELD, input);
4449
builder.field(MODEL_FIELD, model);
4550

51+
if (embeddingType != null) {
52+
builder.field(EMBEDDING_TYPE_FIELD, embeddingType.toRequestString());
53+
}
54+
4655
if (taskSettings.getInputType() != null) {
4756
builder.field(TASK_TYPE_FIELD, convertToString(taskSettings.getInputType()));
4857
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/jinaai/JinaAIEmbeddingsResponseEntity.java

Lines changed: 88 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,28 +9,54 @@
99

1010
package org.elasticsearch.xpack.inference.external.response.jinaai;
1111

12+
import org.elasticsearch.common.Strings;
1213
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
14+
import org.elasticsearch.core.CheckedFunction;
15+
import org.elasticsearch.inference.InferenceServiceResults;
1316
import org.elasticsearch.xcontent.XContentFactory;
1417
import org.elasticsearch.xcontent.XContentParser;
1518
import org.elasticsearch.xcontent.XContentParserConfiguration;
1619
import org.elasticsearch.xcontent.XContentType;
20+
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults;
21+
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
1722
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
1823
import org.elasticsearch.xpack.inference.external.http.HttpResult;
1924
import org.elasticsearch.xpack.inference.external.request.Request;
25+
import org.elasticsearch.xpack.inference.external.request.jinaai.JinaAIEmbeddingsRequest;
2026
import org.elasticsearch.xpack.inference.external.response.XContentUtils;
27+
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType;
2128

2229
import java.io.IOException;
30+
import java.util.Arrays;
2331
import java.util.List;
32+
import java.util.Map;
2433

2534
import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
2635
import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList;
2736
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.consumeUntilObjectEnd;
2837
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;
2938
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField;
39+
import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType.toLowerCase;
3040

3141
public class JinaAIEmbeddingsResponseEntity {
3242
private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in JinaAI embeddings response";
3343

44+
private static final Map<String, CheckedFunction<XContentParser, InferenceServiceResults, IOException>> EMBEDDING_PARSERS = Map.of(
45+
toLowerCase(JinaAIEmbeddingType.FLOAT),
46+
JinaAIEmbeddingsResponseEntity::parseFloatDataObject,
47+
toLowerCase(JinaAIEmbeddingType.BIT),
48+
JinaAIEmbeddingsResponseEntity::parseBitDataObject,
49+
toLowerCase(JinaAIEmbeddingType.BINARY),
50+
JinaAIEmbeddingsResponseEntity::parseBitDataObject
51+
);
52+
private static final String VALID_EMBEDDING_TYPES_STRING = supportedEmbeddingTypes();
53+
54+
private static String supportedEmbeddingTypes() {
55+
var validTypes = EMBEDDING_PARSERS.keySet().toArray(String[]::new);
56+
Arrays.sort(validTypes);
57+
return String.join(", ", validTypes);
58+
}
59+
3460
/**
3561
* Parses the JinaAI json response.
3662
* For a request like:
@@ -73,8 +99,21 @@ public class JinaAIEmbeddingsResponseEntity {
7399
* </code>
74100
* </pre>
75101
*/
76-
public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException {
102+
public static InferenceServiceResults fromResponse(Request request, HttpResult response) throws IOException {
103+
// embeddings type is not specified anywhere in the response so grab it from the request
104+
JinaAIEmbeddingsRequest embeddingsRequest = (JinaAIEmbeddingsRequest) request;
105+
var embeddingType = embeddingsRequest.getEmbeddingType().toString();
77106
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
107+
var embeddingValueParser = EMBEDDING_PARSERS.get(embeddingType);
108+
109+
if (embeddingValueParser == null) {
110+
throw new IllegalStateException(
111+
Strings.format(
112+
"Failed to find a supported embedding type for in the Jina AI embeddings response. Supported types are [%s]",
113+
VALID_EMBEDDING_TYPES_STRING
114+
)
115+
);
116+
}
78117

79118
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
80119
moveToFirstToken(jsonParser);
@@ -84,26 +123,66 @@ public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult
84123

85124
positionParserAtTokenAfterField(jsonParser, "data", FAILED_TO_FIND_FIELD_TEMPLATE);
86125

87-
List<TextEmbeddingFloatResults.Embedding> embeddingList = parseList(
88-
jsonParser,
89-
JinaAIEmbeddingsResponseEntity::parseEmbeddingObject
90-
);
91-
92-
return new TextEmbeddingFloatResults(embeddingList);
126+
return embeddingValueParser.apply(jsonParser);
93127
}
94128
}
95129

96-
private static TextEmbeddingFloatResults.Embedding parseEmbeddingObject(XContentParser parser) throws IOException {
130+
private static InferenceServiceResults parseFloatDataObject(XContentParser jsonParser) throws IOException {
131+
List<TextEmbeddingFloatResults.Embedding> embeddingList = parseList(
132+
jsonParser,
133+
JinaAIEmbeddingsResponseEntity::parseFloatEmbeddingObject
134+
);
135+
136+
return new TextEmbeddingFloatResults(embeddingList);
137+
}
138+
139+
private static TextEmbeddingFloatResults.Embedding parseFloatEmbeddingObject(XContentParser parser) throws IOException {
97140
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
98141

99142
positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE);
100143

101-
List<Float> embeddingValuesList = parseList(parser, XContentUtils::parseFloat);
144+
var embeddingValuesList = parseList(parser, XContentUtils::parseFloat);
102145
// parse and discard the rest of the object
103146
consumeUntilObjectEnd(parser);
104147

105148
return TextEmbeddingFloatResults.Embedding.of(embeddingValuesList);
106149
}
107150

151+
private static InferenceServiceResults parseBitDataObject(XContentParser jsonParser) throws IOException {
152+
List<TextEmbeddingByteResults.Embedding> embeddingList = parseList(
153+
jsonParser,
154+
JinaAIEmbeddingsResponseEntity::parseBitEmbeddingObject
155+
);
156+
157+
return new TextEmbeddingBitResults(embeddingList);
158+
}
159+
160+
private static TextEmbeddingByteResults.Embedding parseBitEmbeddingObject(XContentParser parser) throws IOException {
161+
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
162+
163+
positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE);
164+
165+
var embeddingList = parseList(parser, JinaAIEmbeddingsResponseEntity::parseEmbeddingInt8Entry);
166+
// parse and discard the rest of the object
167+
consumeUntilObjectEnd(parser);
168+
169+
return TextEmbeddingByteResults.Embedding.of(embeddingList);
170+
}
171+
172+
private static Byte parseEmbeddingInt8Entry(XContentParser parser) throws IOException {
173+
XContentParser.Token token = parser.currentToken();
174+
ensureExpectedToken(XContentParser.Token.VALUE_NUMBER, token, parser);
175+
var parsedByte = parser.shortValue();
176+
checkByteBounds(parsedByte);
177+
178+
return (byte) parsedByte;
179+
}
180+
181+
private static void checkByteBounds(short value) {
182+
if (value < Byte.MIN_VALUE || value > Byte.MAX_VALUE) {
183+
throw new IllegalArgumentException("Value [" + value + "] is out of range for a byte");
184+
}
185+
}
186+
108187
private JinaAIEmbeddingsResponseEntity() {}
109188
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
import org.elasticsearch.xpack.inference.services.SenderService;
3939
import org.elasticsearch.xpack.inference.services.ServiceComponents;
4040
import org.elasticsearch.xpack.inference.services.ServiceUtils;
41+
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType;
4142
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModel;
4243
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsServiceSettings;
4344
import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankModel;
@@ -294,7 +295,7 @@ public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
294295
if (model instanceof JinaAIEmbeddingsModel embeddingsModel) {
295296
var serviceSettings = embeddingsModel.getServiceSettings();
296297
var similarityFromModel = serviceSettings.similarity();
297-
var similarityToUse = similarityFromModel == null ? defaultSimilarity() : similarityFromModel;
298+
var similarityToUse = similarityFromModel == null ? defaultSimilarity(serviceSettings.getEmbeddingType()) : similarityFromModel;
298299
var maxInputTokens = serviceSettings.maxInputTokens();
299300

300301
var updatedServiceSettings = new JinaAIEmbeddingsServiceSettings(
@@ -305,7 +306,8 @@ public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
305306
),
306307
similarityToUse,
307308
embeddingSize,
308-
maxInputTokens
309+
maxInputTokens,
310+
serviceSettings.getEmbeddingType()
309311
);
310312

311313
return new JinaAIEmbeddingsModel(embeddingsModel, updatedServiceSettings);
@@ -322,7 +324,10 @@ public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
322324
*
323325
* @return The default similarity.
324326
*/
325-
static SimilarityMeasure defaultSimilarity() {
327+
static SimilarityMeasure defaultSimilarity(JinaAIEmbeddingType embeddingType) {
328+
if (embeddingType == JinaAIEmbeddingType.BINARY || embeddingType == JinaAIEmbeddingType.BIT) {
329+
return SimilarityMeasure.L2_NORM;
330+
}
326331
return SimilarityMeasure.DOT_PRODUCT;
327332
}
328333

0 commit comments

Comments
 (0)