|
47 | 47 | import org.elasticsearch.inference.UnparsedModel; |
48 | 48 | import org.elasticsearch.license.LicenseUtils; |
49 | 49 | import org.elasticsearch.license.XPackLicenseState; |
| 50 | +import org.elasticsearch.logging.LogManager; |
| 51 | +import org.elasticsearch.logging.Logger; |
50 | 52 | import org.elasticsearch.rest.RestStatus; |
51 | 53 | import org.elasticsearch.tasks.Task; |
52 | 54 | import org.elasticsearch.xcontent.XContent; |
|
88 | 90 | * |
89 | 91 | */ |
90 | 92 | public class ShardBulkInferenceActionFilter implements MappedActionFilter { |
| 93 | + private static final Logger logger = LogManager.getLogger(ShardBulkInferenceActionFilter.class); |
| 94 | + |
91 | 95 | private static final ByteSizeValue DEFAULT_BATCH_SIZE = ByteSizeValue.ofMb(1); |
92 | 96 |
|
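The two new imports and the `logger` field above follow the standard Elasticsearch per-class static logger idiom. A minimal sketch of the pattern in isolation, with a hypothetical class name:

    import org.elasticsearch.logging.LogManager;
    import org.elasticsearch.logging.Logger;

    // One static logger per class, keyed by the class itself so each log line
    // names the component it came from.
    public class ExampleFilter {
        private static final Logger logger = LogManager.getLogger(ExampleFilter.class);

        void onFailure(Exception exc) {
            // Message and throwable are passed separately so the stack trace is preserved.
            logger.error("example failure", exc);
        }
    }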
93 | 97 | /** |
@@ -317,119 +321,124 @@ private void executeChunkedInferenceAsync( |
317 | 321 | final Releasable onFinish |
318 | 322 | ) { |
319 | 323 | if (inferenceProvider == null) { |
320 | | - ActionListener<UnparsedModel> modelLoadingListener = new ActionListener<>() { |
321 | | - @Override |
322 | | - public void onResponse(UnparsedModel unparsedModel) { |
323 | | - var service = inferenceServiceRegistry.getService(unparsedModel.service()); |
324 | | - if (service.isEmpty() == false) { |
325 | | - var provider = new InferenceProvider( |
326 | | - service.get(), |
327 | | - service.get() |
328 | | - .parsePersistedConfigWithSecrets( |
329 | | - inferenceId, |
330 | | - unparsedModel.taskType(), |
331 | | - unparsedModel.settings(), |
332 | | - unparsedModel.secrets() |
333 | | - ) |
334 | | - ); |
335 | | - executeChunkedInferenceAsync(inferenceId, provider, requests, onFinish); |
336 | | - } else { |
337 | | - try (onFinish) { |
338 | | - for (FieldInferenceRequest request : requests) { |
339 | | - inferenceResults.get(request.bulkItemIndex).failures.add( |
340 | | - new ResourceNotFoundException( |
341 | | - "Inference service [{}] not found for field [{}]", |
342 | | - unparsedModel.service(), |
343 | | - request.field |
344 | | - ) |
345 | | - ); |
346 | | - } |
347 | | - } |
348 | | - } |
349 | | - } |
350 | | - |
351 | | - @Override |
352 | | - public void onFailure(Exception exc) { |
| 324 | + ActionListener<UnparsedModel> modelLoadingListener = ActionListener.wrap(unparsedModel -> { |
| 325 | + var service = inferenceServiceRegistry.getService(unparsedModel.service()); |
| 326 | + if (service.isEmpty() == false) { |
| 327 | + var provider = new InferenceProvider( |
| 328 | + service.get(), |
| 329 | + service.get() |
| 330 | + .parsePersistedConfigWithSecrets( |
| 331 | + inferenceId, |
| 332 | + unparsedModel.taskType(), |
| 333 | + unparsedModel.settings(), |
| 334 | + unparsedModel.secrets() |
| 335 | + ) |
| 336 | + ); |
| 337 | + executeChunkedInferenceAsync(inferenceId, provider, requests, onFinish); |
| 338 | + } else { |
353 | 339 | try (onFinish) { |
354 | 340 | for (FieldInferenceRequest request : requests) { |
355 | | - Exception failure; |
356 | | - if (ExceptionsHelper.unwrap(exc, ResourceNotFoundException.class) instanceof ResourceNotFoundException) { |
357 | | - failure = new ResourceNotFoundException( |
358 | | - "Inference id [{}] not found for field [{}]", |
359 | | - inferenceId, |
| 341 | + inferenceResults.get(request.bulkItemIndex).failures.add( |
| 342 | + new ResourceNotFoundException( |
| 343 | + "Inference service [{}] not found for field [{}]", |
| 344 | + unparsedModel.service(), |
360 | 345 | request.field |
361 | | - ); |
362 | | - } else { |
363 | | - failure = new InferenceException( |
364 | | - "Error loading inference for inference id [{}] on field [{}]", |
365 | | - exc, |
366 | | - inferenceId, |
367 | | - request.field |
368 | | - ); |
369 | | - } |
370 | | - inferenceResults.get(request.bulkItemIndex).failures.add(failure); |
| 346 | + ) |
| 347 | + ); |
371 | 348 | } |
372 | 349 | } |
373 | 350 | } |
374 | | - }; |
375 | | - modelRegistry.getModelWithSecrets(inferenceId, modelLoadingListener); |
376 | | - return; |
377 | | - } |
378 | | - final List<ChunkInferenceInput> inputs = requests.stream() |
379 | | - .map(r -> new ChunkInferenceInput(r.input, r.chunkingSettings)) |
380 | | - .collect(Collectors.toList()); |
381 | | - |
382 | | - ActionListener<List<ChunkedInference>> completionListener = new ActionListener<>() { |
383 | | - @Override |
384 | | - public void onResponse(List<ChunkedInference> results) { |
| 351 | + }, exc -> { |
385 | 352 | try (onFinish) { |
386 | | - var requestsIterator = requests.iterator(); |
387 | | - for (ChunkedInference result : results) { |
388 | | - var request = requestsIterator.next(); |
389 | | - var acc = inferenceResults.get(request.bulkItemIndex); |
390 | | - if (result instanceof ChunkedInferenceError error) { |
391 | | - acc.addFailure( |
392 | | - new InferenceException( |
393 | | - "Exception when running inference id [{}] on field [{}]", |
394 | | - error.exception(), |
395 | | - inferenceProvider.model.getInferenceEntityId(), |
396 | | - request.field |
397 | | - ) |
| 353 | + for (FieldInferenceRequest request : requests) { |
| 354 | + Exception failure; |
| 355 | + if (ExceptionsHelper.unwrap(exc, ResourceNotFoundException.class) instanceof ResourceNotFoundException) { |
| 356 | + failure = new ResourceNotFoundException( |
| 357 | + "Inference id [{}] not found for field [{}]", |
| 358 | + inferenceId, |
| 359 | + request.field |
398 | 360 | ); |
399 | 361 | } else { |
400 | | - acc.addOrUpdateResponse( |
401 | | - new FieldInferenceResponse( |
402 | | - request.field(), |
403 | | - request.sourceField(), |
404 | | - useLegacyFormat ? request.input() : null, |
405 | | - request.inputOrder(), |
406 | | - request.offsetAdjustment(), |
407 | | - inferenceProvider.model, |
408 | | - result |
409 | | - ) |
| 362 | + failure = new InferenceException( |
| 363 | + "Error loading inference for inference id [{}] on field [{}]", |
| 364 | + exc, |
| 365 | + inferenceId, |
| 366 | + request.field |
410 | 367 | ); |
411 | 368 | } |
| 369 | + inferenceResults.get(request.bulkItemIndex).failures.add(failure); |
| 370 | + } |
| 371 | + |
| 372 | + if (ExceptionsHelper.status(exc).getStatus() >= 500) { |
| 373 | + List<String> fields = requests.stream().map(FieldInferenceRequest::field).distinct().toList(); |
| 374 | + logger.error("Error loading inference for inference id [" + inferenceId + "] on fields " + fields, exc); |
412 | 375 | } |
413 | 376 | } |
414 | | - } |
| 377 | + }); |
| 378 | + modelRegistry.getModelWithSecrets(inferenceId, modelLoadingListener); |
| 379 | + return; |
| 380 | + } |
| 381 | + final List<ChunkInferenceInput> inputs = requests.stream() |
| 382 | + .map(r -> new ChunkInferenceInput(r.input, r.chunkingSettings)) |
| 383 | + .collect(Collectors.toList()); |
415 | 384 |
|
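Up to this point the hunk replaces the anonymous `modelLoadingListener` with the `ActionListener.wrap(onResponse, onFailure)` factory; the rest of the hunk below applies the same change to `completionListener`. A minimal sketch of the `wrap` idiom, with hypothetical `handle`/`report` helpers. One behavioral nuance worth noting: `wrap` routes an exception thrown inside the response handler to the failure handler, which the anonymous form does not do on its own.

    import org.elasticsearch.action.ActionListener;

    class ListenerSketch {
        static void handle(String response) { /* hypothetical success path */ }
        static void report(Exception exc) { /* hypothetical failure path */ }

        // Behaviorally close to `new ActionListener<>() { onResponse / onFailure }`,
        // except that an exception thrown by handle() is also forwarded to report().
        static ActionListener<String> listener() {
            return ActionListener.wrap(ListenerSketch::handle, ListenerSketch::report);
        }
    }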
416 | | - @Override |
417 | | - public void onFailure(Exception exc) { |
418 | | - try (onFinish) { |
419 | | - for (FieldInferenceRequest request : requests) { |
420 | | - addInferenceResponseFailure( |
421 | | - request.bulkItemIndex, |
| 385 | + ActionListener<List<ChunkedInference>> completionListener = ActionListener.wrap(results -> { |
| 386 | + try (onFinish) { |
| 387 | + var requestsIterator = requests.iterator(); |
| 388 | + for (ChunkedInference result : results) { |
| 389 | + var request = requestsIterator.next(); |
| 390 | + var acc = inferenceResults.get(request.bulkItemIndex); |
| 391 | + if (result instanceof ChunkedInferenceError error) { |
| 392 | + acc.addFailure( |
422 | 393 | new InferenceException( |
423 | 394 | "Exception when running inference id [{}] on field [{}]", |
424 | | - exc, |
| 395 | + error.exception(), |
425 | 396 | inferenceProvider.model.getInferenceEntityId(), |
426 | 397 | request.field |
427 | 398 | ) |
428 | 399 | ); |
| 400 | + } else { |
| 401 | + acc.addOrUpdateResponse( |
| 402 | + new FieldInferenceResponse( |
| 403 | + request.field(), |
| 404 | + request.sourceField(), |
| 405 | + useLegacyFormat ? request.input() : null, |
| 406 | + request.inputOrder(), |
| 407 | + request.offsetAdjustment(), |
| 408 | + inferenceProvider.model, |
| 409 | + result |
| 410 | + ) |
| 411 | + ); |
429 | 412 | } |
430 | 413 | } |
431 | 414 | } |
432 | | - }; |
| 415 | + }, exc -> { |
| 416 | + try (onFinish) { |
| 417 | + for (FieldInferenceRequest request : requests) { |
| 418 | + addInferenceResponseFailure( |
| 419 | + request.bulkItemIndex, |
| 420 | + new InferenceException( |
| 421 | + "Exception when running inference id [{}] on field [{}]", |
| 422 | + exc, |
| 423 | + inferenceProvider.model.getInferenceEntityId(), |
| 424 | + request.field |
| 425 | + ) |
| 426 | + ); |
| 427 | + } |
| 428 | + |
| 429 | + if (ExceptionsHelper.status(exc).getStatus() >= 500) { |
| 430 | + List<String> fields = requests.stream().map(FieldInferenceRequest::field).distinct().toList(); |
| 431 | + logger.error( |
| 432 | + "Exception when running inference id [" |
| 433 | + + inferenceProvider.model.getInferenceEntityId() |
| 434 | + + "] on fields " |
| 435 | + + fields, |
| 436 | + exc |
| 437 | + ); |
| 438 | + } |
| 439 | + } |
| 440 | + }); |
| 441 | + |
433 | 442 | inferenceProvider.service() |
434 | 443 | .chunkedInfer( |
435 | 444 | inferenceProvider.model(), |
|
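Both new failure arms gate their logging on the HTTP status derived from the exception, so only failures mapping to a 5xx status reach the server log; 4xx failures (for example a missing inference id) are already surfaced on the individual bulk items and would only add noise. A sketch of that gate in isolation, with a hypothetical helper name:

    import org.elasticsearch.ExceptionsHelper;
    import org.elasticsearch.rest.RestStatus;

    final class StatusGateSketch {
        // Hypothetical helper mirroring the PR's check: log only server-side (5xx)
        // failures, since client errors are already reported per bulk item.
        static boolean isServerError(Exception exc) {
            RestStatus status = ExceptionsHelper.status(exc);
            return status.getStatus() >= 500;
        }
    }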