Commit 242e5b3

Use Wrapped Action Listeners in ShardBulkInferenceActionFilter (elastic#138505) (elastic#138534)
1 parent 43ba33e commit 242e5b3
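
For context, the change swaps anonymous ActionListener implementations for the ActionListener.wrap(...) factory and adds ERROR-level logging for 5xx failures. A minimal sketch of the before/after listener shape (illustrative only, not code from this commit), assuming the org.elasticsearch.action.ActionListener API; the ListenerStyleExample class and its methods are hypothetical:

import org.elasticsearch.action.ActionListener;

// Hypothetical example class; only the ActionListener API itself is real.
class ListenerStyleExample {

    // Before: an anonymous listener spells out onResponse/onFailure.
    static ActionListener<String> verbose() {
        return new ActionListener<>() {
            @Override
            public void onResponse(String value) {
                System.out.println("got: " + value);
            }

            @Override
            public void onFailure(Exception e) {
                System.err.println("failed: " + e.getMessage());
            }
        };
    }

    // After: ActionListener.wrap takes the success and failure handlers as
    // lambdas, which is the shape this commit moves both listeners to.
    static ActionListener<String> wrapped() {
        return ActionListener.wrap(
            value -> System.out.println("got: " + value),
            e -> System.err.println("failed: " + e.getMessage())
        );
    }

    public static void main(String[] args) {
        verbose().onResponse("hello");                      // got: hello
        wrapped().onFailure(new RuntimeException("boom"));  // failed: boom
    }
}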

File tree

1 file changed: +99 -90 lines changed


x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java

Lines changed: 99 additions & 90 deletions
@@ -47,6 +47,8 @@
 import org.elasticsearch.inference.UnparsedModel;
 import org.elasticsearch.license.LicenseUtils;
 import org.elasticsearch.license.XPackLicenseState;
+import org.elasticsearch.logging.LogManager;
+import org.elasticsearch.logging.Logger;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.xcontent.XContent;
@@ -88,6 +90,8 @@
  *
  */
 public class ShardBulkInferenceActionFilter implements MappedActionFilter {
+    private static final Logger logger = LogManager.getLogger(ShardBulkInferenceActionFilter.class);
+
     private static final ByteSizeValue DEFAULT_BATCH_SIZE = ByteSizeValue.ofMb(1);
 
     /**
@@ -317,119 +321,124 @@ private void executeChunkedInferenceAsync(
         final Releasable onFinish
     ) {
         if (inferenceProvider == null) {
-            ActionListener<UnparsedModel> modelLoadingListener = new ActionListener<>() {
-                @Override
-                public void onResponse(UnparsedModel unparsedModel) {
-                    var service = inferenceServiceRegistry.getService(unparsedModel.service());
-                    if (service.isEmpty() == false) {
-                        var provider = new InferenceProvider(
-                            service.get(),
-                            service.get()
-                                .parsePersistedConfigWithSecrets(
-                                    inferenceId,
-                                    unparsedModel.taskType(),
-                                    unparsedModel.settings(),
-                                    unparsedModel.secrets()
-                                )
-                        );
-                        executeChunkedInferenceAsync(inferenceId, provider, requests, onFinish);
-                    } else {
-                        try (onFinish) {
-                            for (FieldInferenceRequest request : requests) {
-                                inferenceResults.get(request.bulkItemIndex).failures.add(
-                                    new ResourceNotFoundException(
-                                        "Inference service [{}] not found for field [{}]",
-                                        unparsedModel.service(),
-                                        request.field
-                                    )
-                                );
-                            }
-                        }
-                    }
-                }
-
-                @Override
-                public void onFailure(Exception exc) {
+            ActionListener<UnparsedModel> modelLoadingListener = ActionListener.wrap(unparsedModel -> {
+                var service = inferenceServiceRegistry.getService(unparsedModel.service());
+                if (service.isEmpty() == false) {
+                    var provider = new InferenceProvider(
+                        service.get(),
+                        service.get()
+                            .parsePersistedConfigWithSecrets(
+                                inferenceId,
+                                unparsedModel.taskType(),
+                                unparsedModel.settings(),
+                                unparsedModel.secrets()
+                            )
+                    );
+                    executeChunkedInferenceAsync(inferenceId, provider, requests, onFinish);
+                } else {
                     try (onFinish) {
                         for (FieldInferenceRequest request : requests) {
-                            Exception failure;
-                            if (ExceptionsHelper.unwrap(exc, ResourceNotFoundException.class) instanceof ResourceNotFoundException) {
-                                failure = new ResourceNotFoundException(
-                                    "Inference id [{}] not found for field [{}]",
-                                    inferenceId,
+                            inferenceResults.get(request.bulkItemIndex).failures.add(
+                                new ResourceNotFoundException(
+                                    "Inference service [{}] not found for field [{}]",
+                                    unparsedModel.service(),
                                     request.field
-                                );
-                            } else {
-                                failure = new InferenceException(
-                                    "Error loading inference for inference id [{}] on field [{}]",
-                                    exc,
-                                    inferenceId,
-                                    request.field
-                                );
-                            }
-                            inferenceResults.get(request.bulkItemIndex).failures.add(failure);
+                                )
+                            );
                         }
                     }
                 }
-            };
-            modelRegistry.getModelWithSecrets(inferenceId, modelLoadingListener);
-            return;
-        }
-        final List<ChunkInferenceInput> inputs = requests.stream()
-            .map(r -> new ChunkInferenceInput(r.input, r.chunkingSettings))
-            .collect(Collectors.toList());
-
-        ActionListener<List<ChunkedInference>> completionListener = new ActionListener<>() {
-            @Override
-            public void onResponse(List<ChunkedInference> results) {
+            }, exc -> {
                 try (onFinish) {
-                    var requestsIterator = requests.iterator();
-                    for (ChunkedInference result : results) {
-                        var request = requestsIterator.next();
-                        var acc = inferenceResults.get(request.bulkItemIndex);
-                        if (result instanceof ChunkedInferenceError error) {
-                            acc.addFailure(
-                                new InferenceException(
-                                    "Exception when running inference id [{}] on field [{}]",
-                                    error.exception(),
-                                    inferenceProvider.model.getInferenceEntityId(),
-                                    request.field
-                                )
+                    for (FieldInferenceRequest request : requests) {
+                        Exception failure;
+                        if (ExceptionsHelper.unwrap(exc, ResourceNotFoundException.class) instanceof ResourceNotFoundException) {
+                            failure = new ResourceNotFoundException(
+                                "Inference id [{}] not found for field [{}]",
+                                inferenceId,
+                                request.field
                             );
                         } else {
-                            acc.addOrUpdateResponse(
-                                new FieldInferenceResponse(
-                                    request.field(),
-                                    request.sourceField(),
-                                    useLegacyFormat ? request.input() : null,
-                                    request.inputOrder(),
-                                    request.offsetAdjustment(),
-                                    inferenceProvider.model,
-                                    result
-                                )
+                            failure = new InferenceException(
+                                "Error loading inference for inference id [{}] on field [{}]",
+                                exc,
+                                inferenceId,
+                                request.field
                             );
                         }
+                        inferenceResults.get(request.bulkItemIndex).failures.add(failure);
+                    }
+
+                    if (ExceptionsHelper.status(exc).getStatus() >= 500) {
+                        List<String> fields = requests.stream().map(FieldInferenceRequest::field).distinct().toList();
+                        logger.error("Error loading inference for inference id [" + inferenceId + "] on fields " + fields, exc);
                     }
                 }
-            }
+            });
+            modelRegistry.getModelWithSecrets(inferenceId, modelLoadingListener);
+            return;
+        }
+        final List<ChunkInferenceInput> inputs = requests.stream()
+            .map(r -> new ChunkInferenceInput(r.input, r.chunkingSettings))
+            .collect(Collectors.toList());
 
-            @Override
-            public void onFailure(Exception exc) {
-                try (onFinish) {
-                    for (FieldInferenceRequest request : requests) {
-                        addInferenceResponseFailure(
-                            request.bulkItemIndex,
+        ActionListener<List<ChunkedInference>> completionListener = ActionListener.wrap(results -> {
+            try (onFinish) {
+                var requestsIterator = requests.iterator();
+                for (ChunkedInference result : results) {
+                    var request = requestsIterator.next();
+                    var acc = inferenceResults.get(request.bulkItemIndex);
+                    if (result instanceof ChunkedInferenceError error) {
+                        acc.addFailure(
                             new InferenceException(
                                 "Exception when running inference id [{}] on field [{}]",
-                                exc,
+                                error.exception(),
                                 inferenceProvider.model.getInferenceEntityId(),
                                 request.field
                             )
                         );
+                    } else {
+                        acc.addOrUpdateResponse(
+                            new FieldInferenceResponse(
+                                request.field(),
+                                request.sourceField(),
+                                useLegacyFormat ? request.input() : null,
+                                request.inputOrder(),
+                                request.offsetAdjustment(),
+                                inferenceProvider.model,
+                                result
+                            )
+                        );
                     }
                 }
             }
-        };
+        }, exc -> {
+            try (onFinish) {
+                for (FieldInferenceRequest request : requests) {
+                    addInferenceResponseFailure(
+                        request.bulkItemIndex,
+                        new InferenceException(
+                            "Exception when running inference id [{}] on field [{}]",
+                            exc,
+                            inferenceProvider.model.getInferenceEntityId(),
+                            request.field
+                        )
+                    );
+                }
+
+                if (ExceptionsHelper.status(exc).getStatus() >= 500) {
+                    List<String> fields = requests.stream().map(FieldInferenceRequest::field).distinct().toList();
+                    logger.error(
+                        "Exception when running inference id ["
+                            + inferenceProvider.model.getInferenceEntityId()
+                            + "] on fields "
+                            + fields,
+                        exc
+                    );
+                }
+            }
+        });
+
         inferenceProvider.service()
             .chunkedInfer(
                 inferenceProvider.model(),
0 commit comments

Comments
 (0)