Skip to content

Commit 7f9ba0f

Browse files
Adding chunking settings parser fix and tests (elastic#135726)
1 parent fb723a3 commit 7f9ba0f

File tree

2 files changed

+106
-4
lines changed

2 files changed

+106
-4
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,12 @@ public void parseRequestConfig(
9696
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
9797
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
9898

99-
var chunkingSettings = extractChunkingSettings(config, taskType);
99+
ChunkingSettings chunkingSettings = null;
100+
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
101+
chunkingSettings = ChunkingSettingsBuilder.fromMap(
102+
removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)
103+
);
104+
}
100105

101106
CustomModel model = createModel(
102107
inferenceEntityId,
@@ -147,7 +152,14 @@ private static RequestParameters createParameters(CustomModel model) {
147152
};
148153
}
149154

150-
private static ChunkingSettings extractChunkingSettings(Map<String, Object> config, TaskType taskType) {
155+
private static ChunkingSettings extractPersistentChunkingSettings(Map<String, Object> config, TaskType taskType) {
156+
/*
157+
* There's a sutle difference between how the chunking settings are parsed for the request context vs the persistent context.
158+
* For persistent context, to support backwards compatibility, if the chunking settings are not present, removeFromMap will
159+
* return null which results in the older word boundary chunking settings being used as the default.
160+
* For request context, removeFromMapOrDefaultEmpty returns an empty map which results in the newer sentence boundary chunking
161+
* settings being used as the default.
162+
*/
151163
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
152164
return ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
153165
}
@@ -220,7 +232,7 @@ public CustomModel parsePersistedConfigWithSecrets(
220232
Map<String, Object> taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
221233
Map<String, Object> secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS);
222234

223-
var chunkingSettings = extractChunkingSettings(config, taskType);
235+
var chunkingSettings = extractPersistentChunkingSettings(config, taskType);
224236

225237
return createModelWithoutLoggingDeprecations(
226238
inferenceEntityId,
@@ -237,7 +249,7 @@ public CustomModel parsePersistedConfig(String inferenceEntityId, TaskType taskT
237249
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
238250
Map<String, Object> taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
239251

240-
var chunkingSettings = extractChunkingSettings(config, taskType);
252+
var chunkingSettings = extractPersistentChunkingSettings(config, taskType);
241253

242254
return createModelWithoutLoggingDeprecations(
243255
inferenceEntityId,

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
3232
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
3333
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
34+
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
3435
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests;
3536
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
3637
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
@@ -53,7 +54,9 @@
5354
import java.util.Map;
5455

5556
import static org.elasticsearch.xpack.inference.Utils.TIMEOUT;
57+
import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
5658
import static org.elasticsearch.xpack.inference.Utils.getRequestConfigMap;
59+
import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap;
5760
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
5861
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
5962
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
@@ -312,6 +315,93 @@ private static CustomServiceSettings.TextEmbeddingSettings getDefaultTextEmbeddi
312315
: CustomServiceSettings.TextEmbeddingSettings.NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS;
313316
}
314317

318+
public void testParseRequestConfig_CreatesAnEmbeddingsModel_WhenChunkingSettingsProvided() throws Exception {
319+
var chunkingSettingsMap = createRandomChunkingSettingsMap();
320+
321+
try (var service = createService(threadPool, clientManager)) {
322+
var config = getRequestConfigMap(
323+
createServiceSettingsMap(TaskType.TEXT_EMBEDDING),
324+
createTaskSettingsMap(),
325+
chunkingSettingsMap,
326+
createSecretSettingsMap()
327+
);
328+
329+
var listener = new PlainActionFuture<Model>();
330+
service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, listener);
331+
var model = listener.actionGet(TIMEOUT);
332+
333+
assertModel(model, TaskType.TEXT_EMBEDDING);
334+
335+
var expectedChunkingSettings = ChunkingSettingsBuilder.fromMap(chunkingSettingsMap);
336+
assertThat(model.getConfigurations().getChunkingSettings(), is(expectedChunkingSettings));
337+
}
338+
}
339+
340+
public void testParseRequestConfig_CreatesAnEmbeddingsModel_WhenChunkingSettingsNotProvided() throws Exception {
341+
try (var service = createService(threadPool, clientManager)) {
342+
var config = getRequestConfigMap(
343+
createServiceSettingsMap(TaskType.TEXT_EMBEDDING),
344+
createTaskSettingsMap(),
345+
createSecretSettingsMap()
346+
);
347+
348+
var listener = new PlainActionFuture<Model>();
349+
service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, listener);
350+
var model = listener.actionGet(TIMEOUT);
351+
352+
assertModel(model, TaskType.TEXT_EMBEDDING);
353+
354+
var expectedChunkingSettings = ChunkingSettingsBuilder.fromMap(Map.of());
355+
assertThat(model.getConfigurations().getChunkingSettings(), is(expectedChunkingSettings));
356+
}
357+
}
358+
359+
public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModel_WhenChunkingSettingsProvided() throws Exception {
360+
var chunkingSettingsMap = createRandomChunkingSettingsMap();
361+
362+
try (var service = createService(threadPool, clientManager)) {
363+
var persistedConfigMap = getPersistedConfigMap(
364+
createServiceSettingsMap(TaskType.TEXT_EMBEDDING),
365+
createTaskSettingsMap(),
366+
chunkingSettingsMap,
367+
createSecretSettingsMap()
368+
);
369+
370+
var model = service.parsePersistedConfigWithSecrets(
371+
"id",
372+
TaskType.TEXT_EMBEDDING,
373+
persistedConfigMap.config(),
374+
persistedConfigMap.secrets()
375+
);
376+
377+
assertModel(model, TaskType.TEXT_EMBEDDING);
378+
379+
var expectedChunkingSettings = ChunkingSettingsBuilder.fromMap(chunkingSettingsMap);
380+
assertThat(model.getConfigurations().getChunkingSettings(), is(expectedChunkingSettings));
381+
}
382+
}
383+
384+
public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModel_WhenChunkingSettingsNotProvided() throws Exception {
385+
try (var service = createService(threadPool, clientManager)) {
386+
var persistedConfigMap = getPersistedConfigMap(
387+
createServiceSettingsMap(TaskType.TEXT_EMBEDDING),
388+
createTaskSettingsMap(),
389+
createSecretSettingsMap()
390+
);
391+
392+
var model = service.parsePersistedConfigWithSecrets(
393+
"id",
394+
TaskType.TEXT_EMBEDDING,
395+
persistedConfigMap.config(),
396+
persistedConfigMap.secrets()
397+
);
398+
assertModel(model, TaskType.TEXT_EMBEDDING);
399+
400+
var expectedChunkingSettings = ChunkingSettingsBuilder.fromMap(null);
401+
assertThat(model.getConfigurations().getChunkingSettings(), is(expectedChunkingSettings));
402+
}
403+
}
404+
315405
public void testInfer_ReturnsAnError_WithoutParsingTheResponseBody() throws IOException {
316406
try (var service = createService(threadPool, clientManager)) {
317407
String responseJson = "error";

0 commit comments

Comments
 (0)