33from datetime import datetime , timedelta
44from typing import Any , Dict , Optional , TypedDict , Union
55
6+ from aiohttp import ClientConnectionError
67from celery import Celery , Task , states
8+ from gevent import monkey
79from model_engine_server .common .constants import DEFAULT_CELERY_TASK_NAME , LIRA_CELERY_TASK_NAME
810from model_engine_server .common .dtos .model_endpoints import BrokerType
911from model_engine_server .common .dtos .tasks import EndpointPredictV1Request
2325from model_engine_server .inference .infra .gateways .datadog_inference_monitoring_metrics_gateway import (
2426 DatadogInferenceMonitoringMetricsGateway ,
2527)
26- from requests import ConnectionError
28+
29+ monkey .patch_all ()
2730
2831logger = make_logger (logger_name ())
2932
@@ -132,9 +135,9 @@ def after_return(
132135 base = ErrorHandlingTask ,
133136 name = LIRA_CELERY_TASK_NAME ,
134137 track_started = True ,
135- autoretry_for = (ConnectionError ,),
138+ autoretry_for = (ClientConnectionError ,),
136139 )
137- def exec_func (payload , arrival_timestamp , * ignored_args , ** ignored_kwargs ):
140+ async def exec_func (payload , arrival_timestamp , * ignored_args , ** ignored_kwargs ):
138141 if len (ignored_args ) > 0 :
139142 logger .warning (f"Ignoring { len (ignored_args )} positional arguments: { ignored_args = } " )
140143 if len (ignored_kwargs ) > 0 :
@@ -144,7 +147,7 @@ def exec_func(payload, arrival_timestamp, *ignored_args, **ignored_kwargs):
144147 # Don't fail the celery task even if there's a status code
145148 # (otherwise we can't really control what gets put in the result attribute)
146149 # in the task (https://docs.celeryq.dev/en/stable/reference/celery.result.html#celery.result.AsyncResult.status)
147- result = forwarder (payload )
150+ result = await forwarder . forward (payload )
148151 request_duration = datetime .now () - arrival_timestamp
149152 if request_duration > timedelta (seconds = DEFAULT_TASK_VISIBILITY_SECONDS ):
150153 monitoring_metrics_gateway .emit_async_task_stuck_metric (queue_name )
@@ -177,12 +180,7 @@ def start_celery_service(
177180 concurrency = concurrency ,
178181 loglevel = "INFO" ,
179182 optimization = "fair" ,
180- # Don't use pool="solo" so we can send multiple concurrent requests over
181- # Historically, pool="solo" argument fixes the known issues of celery and some of the libraries.
182- # Particularly asyncio and torchvision transformers. This isn't relevant since celery-forwarder
183- # is quite lightweight
184- # TODO: we should probably use eventlet or gevent for the pool, since
185- # the forwarder is nearly the most extreme example of IO bound.
183+ pool = "gevent" ,
186184 )
187185 worker .start ()
188186
0 commit comments