
Commit d05094f

committed
🚧 added means of specifying the safe labels via env variables
1 parent ceed982 commit d05094f

File tree

1 file changed: +28 -5 lines changed


detectors/huggingface/detector.py

Lines changed: 28 additions & 5 deletions
@@ -2,6 +2,7 @@
 import sys
 
 sys.path.insert(0, os.path.abspath(".."))
+import json
 import math
 import torch
 from transformers import (
@@ -18,6 +19,22 @@
 )
 import gc
 
+def _parse_safe_labels_env():
+    if os.environ.get("SAFE_LABELS"):
+        try:
+            parsed = json.loads(os.environ.get("SAFE_LABELS"))
+            if isinstance(parsed, (int, str)):
+                logger.info(f"SAFE_LABELS env var: {parsed}")
+                return [parsed]
+            if isinstance(parsed, list) and all(isinstance(x, (int, str)) for x in parsed):
+                logger.info(f"SAFE_LABELS env var: {parsed}")
+                return parsed
+        except Exception as e:
+            logger.warning(f"Could not parse SAFE_LABELS env var: {e}. Defaulting to [0].")
+        return [0]
+    logger.info("SAFE_LABELS env var not set: defaulting to [0].")
+    return [0]
+
 class Detector:
     risk_names = [
         "harm",
@@ -37,6 +54,7 @@ def __init__(self):
         self.model = None
         self.cuda_device = None
         self.model_name = "unknown"
+        self.safe_labels = _parse_safe_labels_env()
 
         model_files_path = os.environ.get("MODEL_DIR")
         if not model_files_path:
@@ -206,7 +224,9 @@ def process_sequence_classification(self, text, detector_params=None, threshold=
         detector_params = detector_params or {}
         if threshold is None:
             threshold = detector_params.get("threshold", 0.5)
-        non_trigger_labels = set(detector_params.get("non_trigger_labels", []))
+        # Merge safe_labels from env and request
+        request_safe_labels = set(detector_params.get("safe_labels", []))
+        all_safe_labels = set(self.safe_labels) | request_safe_labels
         content_analyses = []
         tokenized = self.tokenizer(
             text,
@@ -223,7 +243,12 @@ def process_sequence_classification(self, text, detector_params=None, threshold=
         probabilities = torch.softmax(logits, dim=1).detach().cpu().numpy()[0]
         for idx, prob in enumerate(probabilities):
             label = self.model.config.id2label[idx]
-            if prob >= threshold and label not in non_trigger_labels:
+            # Exclude by index or label name
+            if (
+                prob >= threshold
+                and idx not in all_safe_labels
+                and label not in all_safe_labels
+            ):
                 content_analyses.append(
                     ContentAnalysisResponse(
                         start=0,
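
Since the merged set is checked against both the label index and the label name, either form can be supplied via SAFE_LABELS or via the request's detector_params. A rough illustration of the intended filtering, using a made-up id2label mapping and probabilities (none of these values come from the repository):

# Hypothetical label mapping and scores, for illustration only
id2label = {0: "safe", 1: "toxicity", 2: "harm"}
probabilities = [0.70, 0.80, 0.10]
threshold = 0.5

env_safe_labels = [0]                # e.g. parsed from SAFE_LABELS
request_safe_labels = {"toxicity"}   # e.g. detector_params["safe_labels"]
all_safe_labels = set(env_safe_labels) | request_safe_labels

triggered = []
for idx, prob in enumerate(probabilities):
    label = id2label[idx]
    # Same exclusion shape as the diff: skip by index or by label name
    if (
        prob >= threshold
        and idx not in all_safe_labels
        and label not in all_safe_labels
    ):
        triggered.append(label)

print(triggered)  # [] -> "safe" excluded by index, "toxicity" by name, "harm" below threshold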
@@ -260,10 +285,8 @@ def run(self, input: ContentAnalysisHttpRequest) -> ContentsAnalysisResponse:
             contents_analyses.append(analyses)
         return contents_analyses
 
-
     def close(self) -> None:
         """Clean up model and tokenizer resources."""
-
         if self.model:
             if hasattr(self.model, 'to') and hasattr(self.model, 'device') and self.model.device.type != "cpu":
                 self.model = self.model.to(torch.device("cpu"))
@@ -275,4 +298,4 @@ def close(self) -> None:
         gc.collect()
 
         if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+            torch.cuda.empty_cache()
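
Putting the two sources together, a deployment would export SAFE_LABELS for the environment-wide defaults (read once in Detector.__init__ via _parse_safe_labels_env) and pass any request-specific additions through detector_params. A hedged sketch; the call below assumes an already constructed Detector instance and is not taken from this diff:

import os

# Environment-wide default, must be set before the Detector is constructed
os.environ["SAFE_LABELS"] = '["safe", 0]'   # JSON-encoded list of label names and/or indices

# Per-request additions, merged with the environment set inside
# process_sequence_classification ("safe_labels" key taken from the diff)
detector_params = {"threshold": 0.5, "safe_labels": ["LABEL_2"]}
# detector.process_sequence_classification(text, detector_params=detector_params)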
