Skip to content

Commit d9fa4b7

Browse files
committed
perf: improve throughput with threaded model execution
Use threading to run multiple predictions in parallel, increasing request processing throughput.
1 parent 5c5abfa commit d9fa4b7

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

app/model.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import asyncio
12
import logging
3+
from concurrent.futures import ThreadPoolExecutor
24
from os import cpu_count, environ
35

46
import torch
@@ -9,6 +11,9 @@
911

1012
from .utils import clear_text, measure_time
1113

14+ loop = asyncio.get_running_loop()
15+ loop.set_default_executor(ThreadPoolExecutor())
16+
1217
# Environment
1318
load_dotenv()
1419

@@ -142,7 +147,7 @@ def call_model(text: str) -> Tensor:
142147
return outputs.logits
143148

144149

145-
def predict(text: str) -> bool:
150+
def sync_predict(text: str) -> bool:
146151
text = clear_text(text).lower()
147152
if not text:
148153
return False
@@ -157,3 +162,7 @@ def predict(text: str) -> bool:
157162

158163
log_prediction(text, logits, result, execution_time)
159164
return result
165+
166+
167+ async def async_predict(text: str) -> bool:
168+     return await loop.run_in_executor(None, sync_predict, text)

app/server.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import asyncio
21
import time
32

43
from prometheus_client import (
@@ -14,8 +13,8 @@
1413
from starlette.responses import PlainTextResponse, Response
1514
from starlette.routing import Route, Router
1615

16+
from .model import async_predict as model_predict
1717
from .model import metrics_prefix
18-
from .model import predict as call_model
1918

2019
# Initialize Prometheus metrics
2120
disable_created_metrics()
@@ -44,7 +43,7 @@ async def predict(request: Request):
4443

4544
try:
4645
text = request.query_params.get("text")
47-
result = await asyncio.to_thread(call_model, text) if text else False
46+
result = await model_predict(text) if text else False
4847

4948
label = "toxic" if result else "non_toxic"
5049
REQUEST_COUNT.labels(endpoint=endpoint, result=label).inc()

0 commit comments

Comments (0)