Skip to content

Commit 0728de3

Browse files
committed
Update celery forwarder to use greenlets instead of processes
1 parent a33d5c2 commit 0728de3

File tree

1 file changed

+8
-10
lines changed

1 file changed

+8
-10
lines changed

model-engine/model_engine_server/inference/forwarding/celery_forwarder.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
from datetime import datetime, timedelta
44
from typing import Any, Dict, Optional, TypedDict, Union
55

6+
from aiohttp import ClientConnectionError
67
from celery import Celery, Task, states
8+
from gevent import monkey
79
from model_engine_server.common.constants import DEFAULT_CELERY_TASK_NAME, LIRA_CELERY_TASK_NAME
810
from model_engine_server.common.dtos.model_endpoints import BrokerType
911
from model_engine_server.common.dtos.tasks import EndpointPredictV1Request
@@ -23,7 +25,8 @@
2325
from model_engine_server.inference.infra.gateways.datadog_inference_monitoring_metrics_gateway import (
2426
DatadogInferenceMonitoringMetricsGateway,
2527
)
26-
from requests import ConnectionError
28+
29+
monkey.patch_all()
2730

2831
logger = make_logger(logger_name())
2932

@@ -132,9 +135,9 @@ def after_return(
132135
base=ErrorHandlingTask,
133136
name=LIRA_CELERY_TASK_NAME,
134137
track_started=True,
135-
autoretry_for=(ConnectionError,),
138+
autoretry_for=(ClientConnectionError,),
136139
)
137-
def exec_func(payload, arrival_timestamp, *ignored_args, **ignored_kwargs):
140+
async def exec_func(payload, arrival_timestamp, *ignored_args, **ignored_kwargs):
138141
if len(ignored_args) > 0:
139142
logger.warning(f"Ignoring {len(ignored_args)} positional arguments: {ignored_args=}")
140143
if len(ignored_kwargs) > 0:
@@ -144,7 +147,7 @@ def exec_func(payload, arrival_timestamp, *ignored_args, **ignored_kwargs):
144147
# Don't fail the celery task even if there's a status code
145148
# (otherwise we can't really control what gets put in the result attribute)
146149
# in the task (https://docs.celeryq.dev/en/stable/reference/celery.result.html#celery.result.AsyncResult.status)
147-
result = forwarder(payload)
150+
result = await forwarder.forward(payload)
148151
request_duration = datetime.now() - arrival_timestamp
149152
if request_duration > timedelta(seconds=DEFAULT_TASK_VISIBILITY_SECONDS):
150153
monitoring_metrics_gateway.emit_async_task_stuck_metric(queue_name)
@@ -177,12 +180,7 @@ def start_celery_service(
177180
concurrency=concurrency,
178181
loglevel="INFO",
179182
optimization="fair",
180-
# Don't use pool="solo", so that multiple requests can be sent concurrently.
181-
# Historically, the pool="solo" argument fixed known issues between celery and some libraries.
182-
# Particularly asyncio and torchvision transformers. This isn't relevant since celery-forwarder
183-
# is quite lightweight
184-
# TODO: we should probably use eventlet or gevent for the pool, since
185-
# the forwarder is nearly the most extreme example of IO bound.
183+
pool="gevent",
186184
)
187185
worker.start()
188186

0 commit comments

Comments
 (0)