Commit c778e84

perf: set configurable thread for PyTorch operations
1 parent d9fa4b7 commit c778e84

2 files changed: +8 -3 lines changed

README.md

Lines changed: 1 addition & 0 deletions
@@ -105,6 +105,7 @@ curl -G 'http://localhost:8000/predict' --data-urlencode 'text=test text'
 - `TOXICITY_THRESHOLD` - the level below which the text will be considered toxic. Default: `0` - the argmax function is used. This is a float value, example: `-0.2`, `-0.05`, `1`.
 - `WEB_CONCURRENCY` - Number of worker processes. Defaults to the value of this environment variable if set, otherwise 1. Note: Not compatible with `--reload` option.
 - `METRICS_PREFIX` - Prefix for Prometheus metrics names. Default: `toxicity_detector`. Allows customization of metric names to avoid conflicts in a shared Prometheus setup.
+- `TORCH_THREADS` - Number of threads to use for PyTorch operations. Defaults to the value of this environment variable if set, otherwise the number of CPU cores.
 
 # Prometheus Metrics
 This project exposes several Prometheus metrics for monitoring the toxicity detector's performance and behavior. All metric names are prefixed with the value of the `METRICS_PREFIX` environment variable (default: `toxicity_detector`). Below is a list of available metrics and what they collect:

app/model.py

Lines changed: 7 additions & 3 deletions
@@ -11,16 +11,21 @@
 
 from .utils import clear_text, measure_time
 
-loop = asyncio.get_running_loop()
-loop.set_default_executor(ThreadPoolExecutor())
+cpu_cores = cpu_count()
 
 # Environment
 load_dotenv()
 
 model_path = environ.get("MODEL_PATH", "./model")
 threshold = float(environ.get("TOXICITY_THRESHOLD", 0))
 metrics_prefix = environ.get("METRICS_PREFIX", "toxicity_detector")
+num_threads = int(environ.get("TORCH_THREADS", cpu_cores or 1))
 
+# Configuring Thread Settings
+torch.set_num_threads(num_threads)
+
+loop = asyncio.get_running_loop()
+loop.set_default_executor(ThreadPoolExecutor())
 
 # Initialize Prometheus metrics
 MODEL_ERRORS = Counter(
@@ -96,7 +101,6 @@ def wrapper(*args, **kwargs):
 # Log PyTorch backends and devices information
 logger.info("CUDA available: %s", torch.cuda.is_available())
 logger.info("Current device: %s", device)
-cpu_cores = cpu_count()
 logger.info(
     "Number of CPU cores: %s", cpu_cores if cpu_cores is not None else "Unknown"
 )
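
The fallback in `app/model.py` (use `TORCH_THREADS` if set, otherwise the CPU core count, otherwise 1) can be isolated as a small helper. The following is a minimal sketch, not the project's code: the names `resolve_torch_threads` and `apply_torch_threads` are hypothetical, and two behaviors differ from the diff and are assumptions of this sketch: a blank value is treated as unset, and values below 1 are rejected early.

```python
import os
from typing import Mapping, Optional


def resolve_torch_threads(env: Mapping[str, str], cpu_cores: Optional[int]) -> int:
    """Pick the thread count: TORCH_THREADS if set, else CPU cores, else 1.

    Hypothetical helper; it mirrors the fallback order used in the diff.
    """
    raw = env.get("TORCH_THREADS")
    # Assumption of this sketch: a blank value counts as "not set".
    if raw is None or not raw.strip():
        # cpu_cores may be None when the platform cannot report it.
        return cpu_cores or 1
    value = int(raw)  # non-numeric input raises ValueError, as in the diff
    # Assumption of this sketch: reject non-positive counts up front.
    if value < 1:
        raise ValueError("TORCH_THREADS must be a positive integer")
    return value


def apply_torch_threads(num_threads: int) -> None:
    """Apply the setting; torch is imported lazily so the helper above stays testable."""
    import torch

    torch.set_num_threads(num_threads)


# Possible usage at application start-up:
#   apply_torch_threads(resolve_torch_threads(os.environ, os.cpu_count()))
```

Keeping the resolution logic separate from the `torch` call lets the fallback be unit-tested without PyTorch installed.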
