Skip to content

Commit e5248ad

Browse files
committed
Adding initial tests
Moving dimensions to ServiceSettings
1 parent 19b2c0c commit e5248ad

File tree

10 files changed

+2973
-34
lines changed

10 files changed

+2973
-34
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIEmbeddingsRequestEntity.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
5656
builder.field(TRUNCATION_FIELD, taskSettings.getTruncation());
5757
}
5858

59-
if(taskSettings.getOutputDimension() != null) {
60-
builder.field(OUTPUT_DIMENSION, taskSettings.getOutputDimension());
59+
if(serviceSettings.dimensions() != null) {
60+
builder.field(OUTPUT_DIMENSION, serviceSettings.dimensions());
6161
}
6262

6363
if(serviceSettings.getEmbeddingType() != null) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/embeddings/VoyageAIEmbeddingsTaskSettings.java

Lines changed: 8 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
public class VoyageAIEmbeddingsTaskSettings implements TaskSettings {
4242

4343
public static final String NAME = "voyageai_embeddings_task_settings";
44-
public static final VoyageAIEmbeddingsTaskSettings EMPTY_SETTINGS = new VoyageAIEmbeddingsTaskSettings(null, null, null);
44+
public static final VoyageAIEmbeddingsTaskSettings EMPTY_SETTINGS = new VoyageAIEmbeddingsTaskSettings(null, null);
4545
static final String INPUT_TYPE = "input_type";
4646
static final EnumSet<InputType> VALID_REQUEST_VALUES = EnumSet.of(
4747
InputType.INGEST,
@@ -79,7 +79,7 @@ public static VoyageAIEmbeddingsTaskSettings fromMap(Map<String, Object> map) {
7979
throw validationException;
8080
}
8181

82-
return new VoyageAIEmbeddingsTaskSettings(inputType, truncation, outputDimension);
82+
return new VoyageAIEmbeddingsTaskSettings(inputType, truncation);
8383
}
8484

8585
/**
@@ -101,9 +101,8 @@ public static VoyageAIEmbeddingsTaskSettings of(
101101
) {
102102
var inputTypeToUse = getValidInputType(originalSettings, requestTaskSettings, requestInputType);
103103
var truncationToUse = getValidTruncation(originalSettings, requestTaskSettings);
104-
var outputDimension = getValidOutputDimension(originalSettings, requestTaskSettings);
105104

106-
return new VoyageAIEmbeddingsTaskSettings(inputTypeToUse, truncationToUse, outputDimension);
105+
return new VoyageAIEmbeddingsTaskSettings(inputTypeToUse, truncationToUse);
107106
}
108107

109108
private static InputType getValidInputType(
@@ -129,36 +128,23 @@ private static Boolean getValidTruncation(
129128
return requestTaskSettings.getTruncation() == null ? originalSettings.truncation : requestTaskSettings.getTruncation();
130129
}
131130

132-
private static Integer getValidOutputDimension(
133-
VoyageAIEmbeddingsTaskSettings originalSettings,
134-
VoyageAIEmbeddingsTaskSettings requestTaskSettings
135-
) {
136-
return requestTaskSettings.getOutputDimension() == null
137-
? originalSettings.outputDimension
138-
: requestTaskSettings.getOutputDimension();
139-
}
140-
141131
private final InputType inputType;
142132
private final Boolean truncation;
143-
private final Integer outputDimension;
144133

145134
public VoyageAIEmbeddingsTaskSettings(StreamInput in) throws IOException {
146135
this(
147136
in.readOptionalEnum(InputType.class),
148-
in.readOptionalBoolean(),
149-
in.readOptionalInt()
137+
in.readOptionalBoolean()
150138
);
151139
}
152140

153141
public VoyageAIEmbeddingsTaskSettings(
154142
@Nullable InputType inputType,
155-
@Nullable Boolean truncation,
156-
@Nullable Integer outputDimension
143+
@Nullable Boolean truncation
157144
) {
158145
validateInputType(inputType);
159146
this.inputType = inputType;
160147
this.truncation = truncation;
161-
this.outputDimension = outputDimension;
162148
}
163149

164150
private static void validateInputType(InputType inputType) {
@@ -171,7 +157,7 @@ private static void validateInputType(InputType inputType) {
171157

172158
@Override
173159
public boolean isEmpty() {
174-
return inputType == null && truncation == null && outputDimension == null;
160+
return inputType == null && truncation == null;
175161
}
176162

177163
@Override
@@ -185,10 +171,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
185171
builder.field(TRUNCATION, truncation);
186172
}
187173

188-
if (outputDimension != null) {
189-
builder.field(OUTPUT_DIMENSION, outputDimension);
190-
}
191-
192174
builder.endObject();
193175
return builder;
194176
}
@@ -201,10 +183,6 @@ public Boolean getTruncation() {
201183
return truncation;
202184
}
203185

204-
public Integer getOutputDimension() {
205-
return outputDimension;
206-
}
207-
208186
@Override
209187
public String getWriteableName() {
210188
return NAME;
@@ -219,7 +197,6 @@ public TransportVersion getMinimalSupportedVersion() {
219197
public void writeTo(StreamOutput out) throws IOException {
220198
out.writeOptionalEnum(inputType);
221199
out.writeOptionalBoolean(truncation);
222-
out.writeOptionalInt(outputDimension);
223200
}
224201

225202
@Override
@@ -228,13 +205,12 @@ public boolean equals(Object o) {
228205
if (o == null || getClass() != o.getClass()) return false;
229206
VoyageAIEmbeddingsTaskSettings that = (VoyageAIEmbeddingsTaskSettings) o;
230207
return Objects.equals(inputType, that.inputType) &&
231-
Objects.equals(truncation, that.truncation) &&
232-
Objects.equals(outputDimension, that.outputDimension);
208+
Objects.equals(truncation, that.truncation);
233209
}
234210

235211
@Override
236212
public int hashCode() {
237-
return Objects.hash(inputType, truncation, outputDimension);
213+
return Objects.hash(inputType, truncation);
238214
}
239215

240216
public static String invalidInputTypeMessage(InputType inputType) {
Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.voyageai;
9+
10+
import org.elasticsearch.common.Strings;
11+
import org.elasticsearch.common.ValidationException;
12+
import org.elasticsearch.common.io.stream.Writeable;
13+
import org.elasticsearch.core.Nullable;
14+
import org.elasticsearch.test.AbstractWireSerializingTestCase;
15+
import org.elasticsearch.xcontent.XContentBuilder;
16+
import org.elasticsearch.xcontent.XContentFactory;
17+
import org.elasticsearch.xcontent.XContentType;
18+
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
19+
import org.elasticsearch.xpack.inference.services.ServiceFields;
20+
import org.elasticsearch.xpack.inference.services.ServiceUtils;
21+
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings;
22+
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
23+
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
24+
import org.hamcrest.MatcherAssert;
25+
26+
import java.io.IOException;
27+
import java.util.HashMap;
28+
import java.util.Map;
29+
30+
import static org.hamcrest.Matchers.containsString;
31+
import static org.hamcrest.Matchers.is;
32+
33+
public class VoyageAIServiceSettingsTests extends AbstractWireSerializingTestCase<VoyageAIServiceSettings> {
34+
35+
public static VoyageAIServiceSettings createRandomWithNonNullUrl() {
36+
return createRandom(randomAlphaOfLength(15));
37+
}
38+
39+
/**
40+
* The created settings can have a url set to null.
41+
*/
42+
public static VoyageAIServiceSettings createRandom() {
43+
var url = randomBoolean() ? randomAlphaOfLength(15) : null;
44+
return createRandom(url);
45+
}
46+
47+
private static VoyageAIServiceSettings createRandom(String url) {
48+
var model = randomAlphaOfLength(15);
49+
50+
return new VoyageAIServiceSettings(ServiceUtils.createOptionalUri(url), model, RateLimitSettingsTests.createRandom());
51+
}
52+
53+
public void testFromMap() {
54+
var url = "https://www.abc.com";
55+
var model = "model";
56+
var serviceSettings = VoyageAIServiceSettings.fromMap(
57+
new HashMap<>(Map.of(ServiceFields.URL, url, VoyageAIServiceSettings.MODEL_ID, model)),
58+
ConfigurationParseContext.REQUEST
59+
);
60+
61+
MatcherAssert.assertThat(serviceSettings, is(new VoyageAIServiceSettings(ServiceUtils.createUri(url), model, null)));
62+
}
63+
64+
public void testFromMap_WithRateLimit() {
65+
var url = "https://www.abc.com";
66+
var model = "model";
67+
var serviceSettings = VoyageAIServiceSettings.fromMap(
68+
new HashMap<>(
69+
Map.of(
70+
ServiceFields.URL,
71+
url,
72+
VoyageAIServiceSettings.MODEL_ID,
73+
model,
74+
RateLimitSettings.FIELD_NAME,
75+
new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, 3))
76+
)
77+
),
78+
ConfigurationParseContext.REQUEST
79+
);
80+
81+
MatcherAssert.assertThat(
82+
serviceSettings,
83+
is(new VoyageAIServiceSettings(ServiceUtils.createUri(url), model, new RateLimitSettings(3)))
84+
);
85+
}
86+
87+
public void testFromMap_WhenUsingModelId() {
88+
var url = "https://www.abc.com";
89+
var model = "model";
90+
var serviceSettings = VoyageAIServiceSettings.fromMap(
91+
new HashMap<>(Map.of(ServiceFields.URL, url, VoyageAIServiceSettings.MODEL_ID, model)),
92+
ConfigurationParseContext.PERSISTENT
93+
);
94+
95+
MatcherAssert.assertThat(serviceSettings, is(new VoyageAIServiceSettings(ServiceUtils.createUri(url), model, null)));
96+
}
97+
98+
public void testFromMap_MissingUrl_DoesNotThrowException() {
99+
var serviceSettings = VoyageAIServiceSettings.fromMap(
100+
new HashMap<>(Map.of(VoyageAIServiceSettings.MODEL_ID, "model")),
101+
ConfigurationParseContext.PERSISTENT
102+
);
103+
assertNull(serviceSettings.uri());
104+
}
105+
106+
public void testFromMap_EmptyUrl_ThrowsError() {
107+
var thrownException = expectThrows(
108+
ValidationException.class,
109+
() -> VoyageAIServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, "")), ConfigurationParseContext.PERSISTENT)
110+
);
111+
112+
MatcherAssert.assertThat(
113+
thrownException.getMessage(),
114+
containsString(
115+
Strings.format(
116+
"Validation Failed: 1: [service_settings] Invalid value empty string. [%s] must be a non-empty string;",
117+
ServiceFields.URL
118+
)
119+
)
120+
);
121+
}
122+
123+
public void testFromMap_InvalidUrl_ThrowsError() {
124+
var url = "https://www.abc^.com";
125+
var thrownException = expectThrows(
126+
ValidationException.class,
127+
() -> VoyageAIServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, url)), ConfigurationParseContext.PERSISTENT)
128+
);
129+
130+
MatcherAssert.assertThat(
131+
thrownException.getMessage(),
132+
containsString(
133+
Strings.format("Validation Failed: 1: [service_settings] Invalid url [%s] received for field [%s]", url, ServiceFields.URL)
134+
)
135+
);
136+
}
137+
138+
public void testXContent_WritesModelId() throws IOException {
139+
var entity = new VoyageAIServiceSettings((String) null, "model", new RateLimitSettings(1));
140+
141+
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
142+
entity.toXContent(builder, null);
143+
String xContentResult = Strings.toString(builder);
144+
145+
assertThat(xContentResult, is("""
146+
{"model_id":"model","rate_limit":{"requests_per_minute":1}}"""));
147+
}
148+
149+
@Override
150+
protected Writeable.Reader<VoyageAIServiceSettings> instanceReader() {
151+
return VoyageAIServiceSettings::new;
152+
}
153+
154+
@Override
155+
protected VoyageAIServiceSettings createTestInstance() {
156+
return createRandomWithNonNullUrl();
157+
}
158+
159+
@Override
160+
protected VoyageAIServiceSettings mutateInstance(VoyageAIServiceSettings instance) throws IOException {
161+
return randomValueOtherThan(instance, VoyageAIServiceSettingsTests::createRandom);
162+
}
163+
164+
public static Map<String, Object> getServiceSettingsMap(@Nullable String url, String model) {
165+
var map = new HashMap<String, Object>();
166+
167+
if (url != null) {
168+
map.put(ServiceFields.URL, url);
169+
}
170+
171+
map.put(VoyageAIServiceSettings.MODEL_ID, model);
172+
173+
return map;
174+
}
175+
}

0 commit comments

Comments
 (0)