Diff summary: 2 files changed (+12, −4 lines).
+ import asyncio
12import logging
3+ from concurrent .futures import ThreadPoolExecutor
24from os import cpu_count , environ
35
46import torch
911
1012from .utils import clear_text , measure_time
1113
# Make module import safe: asyncio.get_running_loop() raises RuntimeError when
# no event loop is running — and none is running at import time, so the module
# could never be imported as originally written. Fall back to creating (and
# registering) a fresh loop so synchronous code can import this module.
# NOTE(review): a module-level loop is fragile — the loop captured here is not
# necessarily the one the ASGI server runs; prefer resolving the running loop
# inside the coroutine that needs it (see async_predict).
try:
    loop = asyncio.get_running_loop()
except RuntimeError:  # imported outside a running event loop (the normal case)
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
# Dedicated thread pool so blocking model inference does not starve the loop.
loop.set_default_executor(ThreadPoolExecutor())
16+
1217# Environment
1318load_dotenv ()
1419
@@ -142,7 +147,7 @@ def call_model(text: str) -> Tensor:
142147 return outputs .logits
143148
144149
145- def predict (text : str ) -> bool :
150+ def sync_predict (text : str ) -> bool :
146151 text = clear_text (text ).lower ()
147152 if not text :
148153 return False
@@ -157,3 +162,7 @@ def predict(text: str) -> bool:
157162
158163 log_prediction (text , logits , result , execution_time )
159164 return result
165+
166+
async def async_predict(text: str) -> bool:
    """Classify *text* for toxicity without blocking the event loop.

    Offloads the blocking ``sync_predict`` call to the default thread-pool
    executor of the *currently running* loop. The loop is resolved per call
    (``asyncio.get_running_loop()``) rather than captured at import time,
    because a loop captured at import is not necessarily the one actually
    serving requests — awaiting a future attached to a different loop raises.

    Returns:
        bool: True if the text is classified as toxic, False otherwise
        (mirrors ``sync_predict``).
    """
    running_loop = asyncio.get_running_loop()
    return await running_loop.run_in_executor(None, sync_predict, text)
Original file line number Diff line number Diff line change 1- import asyncio
21import time
32
43from prometheus_client import (
1413from starlette .responses import PlainTextResponse , Response
1514from starlette .routing import Route , Router
1615
16+ from .model import async_predict as model_predict
1717from .model import metrics_prefix
18- from .model import predict as call_model
1918
2019# Initialize Prometheus metrics
2120disable_created_metrics ()
@@ -44,7 +43,7 @@ async def predict(request: Request):
4443
4544 try :
4645 text = request .query_params .get ("text" )
47- result = await asyncio . to_thread ( call_model , text ) if text else False
46+ result = await model_predict ( text ) if text else False
4847
4948 label = "toxic" if result else "non_toxic"
5049 REQUEST_COUNT .labels (endpoint = endpoint , result = label ).inc ()
You can’t perform that action at this time.
0 commit comments