Skip to content

Commit aac1d48

Browse files
committed
Changes due to the comments
1 parent e067465 commit aac1d48

File tree

3 files changed

+20
-51
lines changed

3 files changed

+20
-51
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIEmbeddingsResponseEntity.java

Lines changed: 12 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
package org.elasticsearch.xpack.inference.external.response.voyageai;
1111

1212
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
13-
import org.elasticsearch.core.Nullable;
1413
import org.elasticsearch.inference.InferenceServiceResults;
1514
import org.elasticsearch.xcontent.ConstructingObjectParser;
1615
import org.elasticsearch.xcontent.ParseField;
@@ -32,7 +31,6 @@
3231
import java.util.List;
3332

3433
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
35-
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
3634
import static org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingType.toLowerCase;
3735

3836
public class VoyageAIEmbeddingsResponseEntity {
@@ -47,31 +45,29 @@ private static String supportedEmbeddingTypes() {
4745
return String.join(", ", validTypes);
4846
}
4947

50-
record EmbeddingInt8Result(List<EmbeddingInt8ResultEntry> entries, String model, String object, @Nullable Usage usage) {
48+
record EmbeddingInt8Result(List<EmbeddingInt8ResultEntry> entries) {
5149
@SuppressWarnings("unchecked")
5250
public static final ConstructingObjectParser<EmbeddingInt8Result, Void> PARSER = new ConstructingObjectParser<>(
5351
EmbeddingInt8Result.class.getSimpleName(),
54-
args -> new EmbeddingInt8Result((List<EmbeddingInt8ResultEntry>) args[0], (String) args[1], (String) args[2], (Usage) args[3])
52+
true,
53+
args -> new EmbeddingInt8Result((List<EmbeddingInt8ResultEntry>) args[0])
5554
);
5655

5756
static {
5857
PARSER.declareObjectArray(constructorArg(), EmbeddingInt8ResultEntry.PARSER::apply, new ParseField("data"));
59-
PARSER.declareString(constructorArg(), new ParseField("model"));
60-
PARSER.declareString(constructorArg(), new ParseField("object"));
61-
PARSER.declareObject(optionalConstructorArg(), Usage.PARSER::apply, new ParseField("usage"));
6258
}
6359
}
6460

65-
record EmbeddingInt8ResultEntry(String object, Integer index, List<Integer> embedding) {
61+
record EmbeddingInt8ResultEntry(Integer index, List<Integer> embedding) {
6662

6763
@SuppressWarnings("unchecked")
6864
public static final ConstructingObjectParser<EmbeddingInt8ResultEntry, Void> PARSER = new ConstructingObjectParser<>(
6965
EmbeddingInt8ResultEntry.class.getSimpleName(),
70-
args -> new EmbeddingInt8ResultEntry((String) args[0], (Integer) args[1], (List<Integer>) args[2])
66+
true,
67+
args -> new EmbeddingInt8ResultEntry((Integer) args[0], (List<Integer>) args[1])
7168
);
7269

7370
static {
74-
PARSER.declareString(constructorArg(), new ParseField("object"));
7571
PARSER.declareInt(constructorArg(), new ParseField("index"));
7672
PARSER.declareIntArray(constructorArg(), new ParseField("embedding"));
7773
}
@@ -88,31 +84,29 @@ public InferenceByteEmbedding toInferenceByteEmbedding() {
8884
}
8985
}
9086

91-
record EmbeddingFloatResult(List<EmbeddingFloatResultEntry> entries, String model, String object, @Nullable Usage usage) {
87+
record EmbeddingFloatResult(List<EmbeddingFloatResultEntry> entries) {
9288
@SuppressWarnings("unchecked")
9389
public static final ConstructingObjectParser<EmbeddingFloatResult, Void> PARSER = new ConstructingObjectParser<>(
9490
EmbeddingFloatResult.class.getSimpleName(),
95-
args -> new EmbeddingFloatResult((List<EmbeddingFloatResultEntry>) args[0], (String) args[1], (String) args[2], (Usage) args[3])
91+
true,
92+
args -> new EmbeddingFloatResult((List<EmbeddingFloatResultEntry>) args[0])
9693
);
9794

9895
static {
9996
PARSER.declareObjectArray(constructorArg(), EmbeddingFloatResultEntry.PARSER::apply, new ParseField("data"));
100-
PARSER.declareString(constructorArg(), new ParseField("model"));
101-
PARSER.declareString(constructorArg(), new ParseField("object"));
102-
PARSER.declareObject(optionalConstructorArg(), Usage.PARSER::apply, new ParseField("usage"));
10397
}
10498
}
10599

106-
record EmbeddingFloatResultEntry(String object, Integer index, List<Float> embedding) {
100+
record EmbeddingFloatResultEntry(Integer index, List<Float> embedding) {
107101

108102
@SuppressWarnings("unchecked")
109103
public static final ConstructingObjectParser<EmbeddingFloatResultEntry, Void> PARSER = new ConstructingObjectParser<>(
110104
EmbeddingFloatResultEntry.class.getSimpleName(),
111-
args -> new EmbeddingFloatResultEntry((String) args[0], (Integer) args[1], (List<Float>) args[2])
105+
true,
106+
args -> new EmbeddingFloatResultEntry((Integer) args[0], (List<Float>) args[1])
112107
);
113108

114109
static {
115-
PARSER.declareString(constructorArg(), new ParseField("object"));
116110
PARSER.declareInt(constructorArg(), new ParseField("index"));
117111
PARSER.declareFloatArray(constructorArg(), new ParseField("embedding"));
118112
}
@@ -122,18 +116,6 @@ public InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding toInferenceFlo
122116
}
123117
}
124118

125-
record Usage(Integer totalTokens) {
126-
127-
public static final ConstructingObjectParser<Usage, Void> PARSER = new ConstructingObjectParser<>(
128-
Usage.class.getSimpleName(),
129-
args -> new Usage((Integer) args[0])
130-
);
131-
132-
static {
133-
PARSER.declareInt(constructorArg(), new ParseField("total_tokens"));
134-
}
135-
}
136-
137119
/**
138120
* Parses the VoyageAI json response.
139121
* For a request like:

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIRerankResponseEntity.java

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -33,19 +33,17 @@ public class VoyageAIRerankResponseEntity {
3333

3434
private static final Logger logger = LogManager.getLogger(VoyageAIRerankResponseEntity.class);
3535

36-
record RerankResult(List<RerankResultEntry> entries, String model, String object, @Nullable Usage usage) {
36+
record RerankResult(List<RerankResultEntry> entries) {
3737

3838
@SuppressWarnings("unchecked")
3939
public static final ConstructingObjectParser<RerankResult, Void> PARSER = new ConstructingObjectParser<>(
4040
RerankResult.class.getSimpleName(),
41-
args -> new RerankResult((List<RerankResultEntry>) args[0], (String) args[1], (String) args[2], (Usage) args[3])
41+
true,
42+
args -> new RerankResult((List<RerankResultEntry>) args[0])
4243
);
4344

4445
static {
4546
PARSER.declareObjectArray(constructorArg(), RerankResultEntry.PARSER::apply, new ParseField("data"));
46-
PARSER.declareString(constructorArg(), new ParseField("model"));
47-
PARSER.declareString(constructorArg(), new ParseField("object"));
48-
PARSER.declareObject(optionalConstructorArg(), Usage.PARSER::apply, new ParseField("usage"));
4947
}
5048
}
5149

@@ -67,18 +65,6 @@ public RankedDocsResults.RankedDoc toRankedDoc() {
6765
}
6866
}
6967

70-
record Usage(Integer totalTokens) {
71-
72-
public static final ConstructingObjectParser<Usage, Void> PARSER = new ConstructingObjectParser<>(
73-
Usage.class.getSimpleName(),
74-
args -> new Usage((Integer) args[0])
75-
);
76-
77-
static {
78-
PARSER.declareInt(constructorArg(), new ParseField("total_tokens"));
79-
}
80-
}
81-
8268
/**
8369
* Parses the VoyageAI ranked response.
8470
* For a request like:

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/voyageai/VoyageAIEmbeddingsResponseEntityTests.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import java.util.List;
2121

2222
import static org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModelTests.createModel;
23+
import static org.hamcrest.Matchers.containsString;
2324
import static org.hamcrest.Matchers.instanceOf;
2425
import static org.hamcrest.Matchers.is;
2526
import static org.mockito.Mockito.mock;
@@ -139,14 +140,14 @@ public void testFromResponse_FailsWhenDataFieldIsNotPresent() {
139140
);
140141

141142
var thrownException = expectThrows(
142-
XContentParseException.class,
143+
java.lang.IllegalArgumentException.class,
143144
() -> VoyageAIEmbeddingsResponseEntity.fromResponse(
144145
request,
145146
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
146147
)
147148
);
148149

149-
assertThat(thrownException.getMessage(), is("[3:3] [EmbeddingFloatResult] unknown field [not_data]"));
150+
assertThat(thrownException.getMessage(), is("Required [data]"));
150151
}
151152

152153
public void testFromResponse_FailsWhenDataFieldNotAnArray() {
@@ -183,7 +184,7 @@ public void testFromResponse_FailsWhenDataFieldNotAnArray() {
183184
)
184185
);
185186

186-
assertThat(thrownException.getMessage(), is("[4:15] [EmbeddingFloatResult] failed to parse field [data]"));
187+
assertThat(thrownException.getMessage(), containsString("[EmbeddingFloatResult] failed to parse field [data]"));
187188
}
188189

189190
public void testFromResponse_FailsWhenEmbeddingsDoesNotExist() {
@@ -220,7 +221,7 @@ public void testFromResponse_FailsWhenEmbeddingsDoesNotExist() {
220221
)
221222
);
222223

223-
assertThat(thrownException.getMessage(), is("[7:27] [EmbeddingFloatResult] failed to parse field [data]"));
224+
assertThat(thrownException.getMessage(), containsString("[EmbeddingFloatResult] failed to parse field [data]"));
224225
}
225226

226227
public void testFromResponse_FailsWhenEmbeddingValueIsAString() {

0 commit comments

Comments
 (0)