Skip to content

Commit ecf8183

Browse files
jimczivaleriy42
authored and committed
Increment inference stats counter for shard bulk inference calls (elastic#129140)
This change updates the inference stats counter to include chunked inference calls performed by the shard bulk inference filter on all semantic text fields. It ensures that usage of inference on semantic text fields is properly recorded in the stats.
1 parent 411ce03 commit ecf8183

File tree

4 files changed

+115
-21
lines changed

4 files changed

+115
-21
lines changed

docs/changelog/129140.yaml

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,5 @@
1+
pr: 129140
2+
summary: Increment inference stats counter for shard bulk inference calls
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -344,22 +344,24 @@ public Collection<?> createComponents(PluginServices services) {
344344
}
345345
inferenceServiceRegistry.set(serviceRegistry);
346346

347+
var meterRegistry = services.telemetryProvider().getMeterRegistry();
348+
var inferenceStats = InferenceStats.create(meterRegistry);
349+
var inferenceStatsBinding = new PluginComponentBinding<>(InferenceStats.class, inferenceStats);
350+
347351
var actionFilter = new ShardBulkInferenceActionFilter(
348352
services.clusterService(),
349353
serviceRegistry,
350354
modelRegistry.get(),
351355
getLicenseState(),
352-
services.indexingPressure()
356+
services.indexingPressure(),
357+
inferenceStats
353358
);
354359
shardBulkInferenceActionFilter.set(actionFilter);
355360

356-
var meterRegistry = services.telemetryProvider().getMeterRegistry();
357-
var inferenceStats = new PluginComponentBinding<>(InferenceStats.class, InferenceStats.create(meterRegistry));
358-
359361
components.add(serviceRegistry);
360362
components.add(modelRegistry.get());
361363
components.add(httpClientManager);
362-
components.add(inferenceStats);
364+
components.add(inferenceStatsBinding);
363365

364366
// Only add InferenceServiceNodeLocalRateLimitCalculator (which is a ClusterStateListener) for cluster aware rate limiting,
365367
// if the rate limiting feature flags are enabled, otherwise provide noop implementation

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java

Lines changed: 22 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -63,6 +63,7 @@
6363
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
6464
import org.elasticsearch.xpack.inference.mapper.SemanticTextUtils;
6565
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
66+
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
6667

6768
import java.io.IOException;
6869
import java.util.ArrayList;
@@ -78,6 +79,8 @@
7879
import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE;
7980
import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunks;
8081
import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunksLegacy;
82+
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes;
83+
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes;
8184

8285
/**
8386
* A {@link MappedActionFilter} that intercepts {@link BulkShardRequest} to apply inference on fields specified
@@ -112,20 +115,23 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
112115
private final ModelRegistry modelRegistry;
113116
private final XPackLicenseState licenseState;
114117
private final IndexingPressure indexingPressure;
118+
private final InferenceStats inferenceStats;
115119
private volatile long batchSizeInBytes;
116120

117121
public ShardBulkInferenceActionFilter(
118122
ClusterService clusterService,
119123
InferenceServiceRegistry inferenceServiceRegistry,
120124
ModelRegistry modelRegistry,
121125
XPackLicenseState licenseState,
122-
IndexingPressure indexingPressure
126+
IndexingPressure indexingPressure,
127+
InferenceStats inferenceStats
123128
) {
124129
this.clusterService = clusterService;
125130
this.inferenceServiceRegistry = inferenceServiceRegistry;
126131
this.modelRegistry = modelRegistry;
127132
this.licenseState = licenseState;
128133
this.indexingPressure = indexingPressure;
134+
this.inferenceStats = inferenceStats;
129135
this.batchSizeInBytes = INDICES_INFERENCE_BATCH_SIZE.get(clusterService.getSettings()).getBytes();
130136
clusterService.getClusterSettings().addSettingsUpdateConsumer(INDICES_INFERENCE_BATCH_SIZE, this::setBatchSize);
131137
}
@@ -386,10 +392,12 @@ public void onFailure(Exception exc) {
386392
public void onResponse(List<ChunkedInference> results) {
387393
try (onFinish) {
388394
var requestsIterator = requests.iterator();
395+
int success = 0;
389396
for (ChunkedInference result : results) {
390397
var request = requestsIterator.next();
391398
var acc = inferenceResults.get(request.bulkItemIndex);
392399
if (result instanceof ChunkedInferenceError error) {
400+
recordRequestCountMetrics(inferenceProvider.model, 1, error.exception());
393401
acc.addFailure(
394402
new InferenceException(
395403
"Exception when running inference id [{}] on field [{}]",
@@ -399,6 +407,7 @@ public void onResponse(List<ChunkedInference> results) {
399407
)
400408
);
401409
} else {
410+
success++;
402411
acc.addOrUpdateResponse(
403412
new FieldInferenceResponse(
404413
request.field(),
@@ -412,12 +421,16 @@ public void onResponse(List<ChunkedInference> results) {
412421
);
413422
}
414423
}
424+
if (success > 0) {
425+
recordRequestCountMetrics(inferenceProvider.model, success, null);
426+
}
415427
}
416428
}
417429

418430
@Override
419431
public void onFailure(Exception exc) {
420432
try (onFinish) {
433+
recordRequestCountMetrics(inferenceProvider.model, requests.size(), exc);
421434
for (FieldInferenceRequest request : requests) {
422435
addInferenceResponseFailure(
423436
request.bulkItemIndex,
@@ -444,6 +457,14 @@ public void onFailure(Exception exc) {
444457
);
445458
}
446459

/**
 * Increments the inference request counter for calls made by this filter.
 *
 * The counter attributes are the union of {@code modelAttributes(model)} and
 * {@code responseAttributes(throwable)}, plus an {@code inference_source} entry
 * fixed to {@code "semantic_text_bulk"}.
 *
 * @param model       the model whose attributes are attached to the counter
 * @param incrementBy the amount to add to the request counter
 * @param throwable   the failure to derive response attributes from; callers in this
 *                    filter pass {@code null} when recording successful results
 */
460+
private void recordRequestCountMetrics(Model model, int incrementBy, Throwable throwable) {
461+
Map<String, Object> requestCountAttributes = new HashMap<>();
462+
requestCountAttributes.putAll(modelAttributes(model));
463+
requestCountAttributes.putAll(responseAttributes(throwable));
// Constant tag marking these increments as coming from the semantic text bulk path.
464+
requestCountAttributes.put("inference_source", "semantic_text_bulk");
465+
inferenceStats.requestCount().incrementBy(incrementBy, requestCountAttributes);
466+
}
467+
447468
/**
448469
* Adds all inference requests associated with their respective inference IDs to the given {@code requestsMap}
449470
* for the specified {@code item}.

0 commit comments

Comments
 (0)