Skip to content

Commit 7f7ba0c

Browse files
authored
Metrics for stuck async requests (#471)
1 parent 37af55c commit 7f7ba0c

File tree

8 files changed

+80
-31
lines changed

8 files changed

+80
-31
lines changed
Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,19 @@
11
from typing import Sequence
22

3-
from .app import TaskVisibility, celery_app, get_all_db_indexes, get_redis_host_port, inspect_app
3+
from .app import (
4+
DEFAULT_TASK_VISIBILITY_SECONDS,
5+
TaskVisibility,
6+
celery_app,
7+
get_all_db_indexes,
8+
get_redis_host_port,
9+
inspect_app,
10+
)
411

512
__all__: Sequence[str] = (
613
"celery_app",
714
"get_all_db_indexes",
815
"get_redis_host_port",
916
"inspect_app",
1017
"TaskVisibility",
18+
"DEFAULT_TASK_VISIBILITY_SECONDS",
1119
)

model-engine/model_engine_server/core/celery/app.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@
3131
] = "model_engine_server.core.celery.abs:AzureBlockBlobBackend"
3232

3333

34+
DEFAULT_TASK_VISIBILITY_SECONDS = 86400
35+
36+
3437
@unique
3538
class TaskVisibility(IntEnum):
3639
"""

model-engine/model_engine_server/domain/gateways/monitoring_metrics_gateway.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,12 @@ class MonitoringMetricsGateway(ABC):
2424
def emit_attempted_build_metric(self):
2525
"""
2626
Service builder attempted metric
27-
2827
"""
2928

3029
@abstractmethod
3130
def emit_successful_build_metric(self):
3231
"""
3332
Service builder succeeded metric
34-
3533
"""
3634

3735
@abstractmethod
@@ -44,49 +42,42 @@ def emit_build_time_metric(self, duration_seconds: float):
4442
def emit_image_build_cache_hit_metric(self, image_type: str):
4543
"""
4644
Service builder image build cache hit metric
47-
4845
"""
4946

5047
@abstractmethod
5148
def emit_image_build_cache_miss_metric(self, image_type: str):
5249
"""
5350
Service builder image build cache miss metric
54-
5551
"""
5652

5753
@abstractmethod
5854
def emit_docker_failed_build_metric(self):
5955
"""
6056
Service builder docker build failed metric
61-
6257
"""
6358

6459
@abstractmethod
6560
def emit_database_cache_hit_metric(self):
6661
"""
6762
Successful database cache metric
68-
6963
"""
7064

7165
@abstractmethod
7266
def emit_database_cache_miss_metric(self):
7367
"""
7468
Missed database cache metric
75-
7669
"""
7770

7871
@abstractmethod
7972
def emit_route_call_metric(self, route: str, metadata: MetricMetadata):
8073
"""
8174
Route call metric
82-
8375
"""
8476
pass
8577

8678
@abstractmethod
8779
def emit_token_count_metrics(self, token_usage: TokenUsage, metadata: MetricMetadata):
8880
"""
8981
Token count metrics
90-
9182
"""
9283
pass

model-engine/model_engine_server/inference/domain/gateways/inference_monitoring_metrics_gateway.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,23 @@ def emit_successful_post_inference_hook(self, hook: str):
3030
Args:
3131
hook: The name of the hook
3232
"""
33+
34+
@abstractmethod
35+
def emit_async_task_received_metric(self, queue_name: str):
36+
"""
37+
Async task received metric
38+
39+
Args:
40+
queue_name: The name of the Celery queue
41+
"""
42+
pass
43+
44+
@abstractmethod
45+
def emit_async_task_stuck_metric(self, queue_name: str):
46+
"""
47+
Async task stuck metric
48+
49+
Args:
50+
queue_name: The name of the Celery queue
51+
"""
52+
pass

model-engine/model_engine_server/inference/forwarding/celery_forwarder.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
11
import argparse
22
import json
3+
from datetime import datetime, timedelta
34
from typing import Any, Dict, Optional, TypedDict, Union
45

56
from celery import Celery, Task, states
67
from model_engine_server.common.constants import DEFAULT_CELERY_TASK_NAME, LIRA_CELERY_TASK_NAME
78
from model_engine_server.common.dtos.model_endpoints import BrokerType
89
from model_engine_server.common.dtos.tasks import EndpointPredictV1Request
9-
from model_engine_server.core.celery import TaskVisibility, celery_app
10+
from model_engine_server.core.celery import (
11+
DEFAULT_TASK_VISIBILITY_SECONDS,
12+
TaskVisibility,
13+
celery_app,
14+
)
1015
from model_engine_server.core.config import infra_config
1116
from model_engine_server.core.loggers import logger_name, make_logger
1217
from model_engine_server.core.utils.format import format_stacktrace
@@ -15,6 +20,9 @@
1520
LoadForwarder,
1621
load_named_config,
1722
)
23+
from model_engine_server.inference.infra.gateways.datadog_inference_monitoring_metrics_gateway import (
24+
DatadogInferenceMonitoringMetricsGateway,
25+
)
1826

1927
logger = make_logger(logger_name())
2028

@@ -68,6 +76,8 @@ def create_celery_service(
6876
backend_protocol=backend_protocol,
6977
)
7078

79+
monitoring_metrics_gateway = DatadogInferenceMonitoringMetricsGateway()
80+
7181
class ErrorHandlingTask(Task):
7282
"""Sets a 'custom' field with error in the Task response for FAILURE.
7383
@@ -112,13 +122,18 @@ def after_return(
112122
# See documentation for options:
113123
# https://docs.celeryproject.org/en/stable/userguide/tasks.html#list-of-options
114124
@app.task(base=ErrorHandlingTask, name=LIRA_CELERY_TASK_NAME, track_started=True)
115-
def exec_func(payload, *ignored_args, **ignored_kwargs):
125+
def exec_func(payload, arrival_timestamp, *ignored_args, **ignored_kwargs):
116126
if len(ignored_args) > 0:
117127
logger.warning(f"Ignoring {len(ignored_args)} positional arguments: {ignored_args=}")
118128
if len(ignored_kwargs) > 0:
119129
logger.warning(f"Ignoring {len(ignored_kwargs)} keyword arguments: {ignored_kwargs=}")
120130
try:
121-
return forwarder(payload)
131+
monitoring_metrics_gateway.emit_async_task_received_metric(queue_name)
132+
result = forwarder(payload)
133+
request_duration = datetime.now() - arrival_timestamp
134+
if request_duration > timedelta(seconds=DEFAULT_TASK_VISIBILITY_SECONDS):
135+
monitoring_metrics_gateway.emit_async_task_stuck_metric(queue_name)
136+
return result
122137
except Exception:
123138
logger.exception("Celery service failed to respond to request.")
124139
raise
@@ -131,8 +146,8 @@ def exec_func(payload, *ignored_args, **ignored_kwargs):
131146
name=DEFAULT_CELERY_TASK_NAME,
132147
track_started=True,
133148
)
134-
def exec_func_pre_lira(payload, *ignored_args, **ignored_kwargs):
135-
return exec_func(payload, *ignored_args, **ignored_kwargs)
149+
def exec_func_pre_lira(payload, arrival_timestamp, *ignored_args, **ignored_kwargs):
150+
return exec_func(payload, arrival_timestamp, *ignored_args, **ignored_kwargs)
136151

137152
return app
138153

model-engine/model_engine_server/inference/infra/gateways/datadog_inference_monitoring_metrics_gateway.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,11 @@ def emit_attempted_post_inference_hook(self, hook: str):
1010

1111
def emit_successful_post_inference_hook(self, hook: str):
1212
statsd.increment(f"scale_launch.post_inference_hook.{hook}.success")
13+
14+
def emit_async_task_received_metric(self, queue_name: str):
15+
statsd.increment(
16+
"scale_launch.async_task.received.count", tags=[f"queue_name:{queue_name}"]
17+
) # pragma: no cover
18+
19+
def emit_async_task_stuck_metric(self, queue_name: str):
20+
statsd.increment("scale_launch.async_task.stuck.count", tags=[f"queue_name:{queue_name}"])

model-engine/model_engine_server/infra/gateways/live_async_model_endpoint_inference_gateway.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
from datetime import datetime
23

34
from model_engine_server.common.constants import DEFAULT_CELERY_TASK_NAME
45
from model_engine_server.common.dtos.tasks import (
@@ -37,7 +38,7 @@ def create_task(
3738
send_task_response = self.task_queue_gateway.send_task(
3839
task_name=task_name,
3940
queue_name=topic,
40-
args=[predict_args, predict_request.return_pickled],
41+
args=[predict_args, datetime.now(), predict_request.return_pickled],
4142
expires=task_timeout_seconds,
4243
)
4344
return CreateAsyncTaskV1Response(task_id=send_task_response.task_id)

model-engine/tests/unit/infra/gateways/test_live_async_model_inference_gateway.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
from datetime import datetime, timedelta
23
from typing import Any
34

45
import pytest
@@ -22,10 +23,11 @@ def test_task_create_get_url(
2223
task_id = create_response.task_id
2324
task_queue_gateway: Any = fake_live_async_model_inference_gateway.task_queue_gateway
2425
assert len(task_queue_gateway.queue) == 1
25-
assert task_queue_gateway.queue[task_id]["args"] == [
26-
endpoint_predict_request_1[0].dict(),
27-
endpoint_predict_request_1[0].return_pickled,
28-
]
26+
assert task_queue_gateway.queue[task_id]["args"][0] == endpoint_predict_request_1[0].dict()
27+
assert (datetime.now() - task_queue_gateway.queue[task_id]["args"][1]) < timedelta(seconds=1)
28+
assert (
29+
task_queue_gateway.queue[task_id]["args"][2] == endpoint_predict_request_1[0].return_pickled
30+
)
2931

3032
get_response_1 = fake_live_async_model_inference_gateway.get_task(task_id)
3133
assert get_response_1 == GetAsyncTaskV1Response(task_id=task_id, status=TaskStatus.PENDING)
@@ -49,17 +51,18 @@ def test_task_create_get_args_callback(
4951
task_id = create_response.task_id
5052
task_queue_gateway: Any = fake_live_async_model_inference_gateway.task_queue_gateway
5153
assert len(task_queue_gateway.queue) == 1
52-
assert task_queue_gateway.queue[task_id]["args"] == [
53-
{
54-
"args": endpoint_predict_request_2[0].args.__root__,
55-
"url": None,
56-
"cloudpickle": None,
57-
"callback_auth": json.loads(endpoint_predict_request_2[0].callback_auth.json()),
58-
"callback_url": endpoint_predict_request_2[0].callback_url,
59-
"return_pickled": endpoint_predict_request_2[0].return_pickled,
60-
},
61-
endpoint_predict_request_2[0].return_pickled,
62-
]
54+
assert task_queue_gateway.queue[task_id]["args"][0] == {
55+
"args": endpoint_predict_request_2[0].args.__root__,
56+
"url": None,
57+
"cloudpickle": None,
58+
"callback_auth": json.loads(endpoint_predict_request_2[0].callback_auth.json()),
59+
"callback_url": endpoint_predict_request_2[0].callback_url,
60+
"return_pickled": endpoint_predict_request_2[0].return_pickled,
61+
}
62+
assert (datetime.now() - task_queue_gateway.queue[task_id]["args"][1]) < timedelta(seconds=1)
63+
assert (
64+
task_queue_gateway.queue[task_id]["args"][2] == endpoint_predict_request_2[0].return_pickled
65+
)
6366

6467
get_response_1 = fake_live_async_model_inference_gateway.get_task(task_id)
6568
assert get_response_1 == GetAsyncTaskV1Response(task_id=task_id, status=TaskStatus.PENDING)

0 commit comments

Comments
 (0)