|
3 | 3 | from datetime import datetime, timedelta |
4 | 4 | from typing import Any, Dict, Optional, TypedDict, Union |
5 | 5 |
|
| 6 | +from aiohttp import ClientConnectionError |
6 | 7 | from celery import Celery, Task, states |
| 8 | +from gevent import monkey |
7 | 9 | from model_engine_server.common.constants import DEFAULT_CELERY_TASK_NAME, LIRA_CELERY_TASK_NAME |
8 | 10 | from model_engine_server.common.dtos.model_endpoints import BrokerType |
9 | 11 | from model_engine_server.common.dtos.tasks import EndpointPredictV1Request |
|
23 | 25 | from model_engine_server.inference.infra.gateways.datadog_inference_monitoring_metrics_gateway import ( |
24 | 26 | DatadogInferenceMonitoringMetricsGateway, |
25 | 27 | ) |
26 | | -from requests import ConnectionError |
| 28 | + |
| 29 | +monkey.patch_all() |
27 | 30 |
|
28 | 31 | logger = make_logger(logger_name()) |
29 | 32 |
|
@@ -132,7 +135,7 @@ def after_return( |
132 | 135 | base=ErrorHandlingTask, |
133 | 136 | name=LIRA_CELERY_TASK_NAME, |
134 | 137 | track_started=True, |
135 | | - autoretry_for=(ConnectionError,), |
| 138 | + autoretry_for=(ClientConnectionError,), |
136 | 139 | ) |
137 | 140 | def exec_func(payload, arrival_timestamp, *ignored_args, **ignored_kwargs): |
138 | 141 | if len(ignored_args) > 0: |
@@ -177,12 +180,7 @@ def start_celery_service( |
177 | 180 | concurrency=concurrency, |
178 | 181 | loglevel="INFO", |
179 | 182 | optimization="fair", |
180 | | - # Don't use pool="solo" so we can send multiple concurrent requests over |
181 | | - # Historically, pool="solo" argument fixes the known issues of celery and some of the libraries. |
182 | | - # Particularly asyncio and torchvision transformers. This isn't relevant since celery-forwarder |
183 | | - # is quite lightweight |
184 | | - # TODO: we should probably use eventlet or gevent for the pool, since |
185 | | - # the forwarder is nearly the most extreme example of IO bound. |
| 183 | + pool="gevent", |
186 | 184 | ) |
187 | 185 | worker.start() |
188 | 186 |
|
|
0 commit comments