import sys

sys.path.insert(0, os.path.abspath(".."))
+import json
import math
import torch
from transformers import (
)
import gc

+def _parse_safe_labels_env():
+    if os.environ.get("SAFE_LABELS"):
+        try:
+            parsed = json.loads(os.environ.get("SAFE_LABELS"))
+            if isinstance(parsed, (int, str)):
+                logger.info(f"SAFE_LABELS env var: {parsed}")
+                return [parsed]
+            if isinstance(parsed, list) and all(isinstance(x, (int, str)) for x in parsed):
+                logger.info(f"SAFE_LABELS env var: {parsed}")
+                return parsed
+        except Exception as e:
+            logger.warning(f"Could not parse SAFE_LABELS env var: {e}. Defaulting to [0].")
+        return [0]
+    logger.info("SAFE_LABELS env var not set: defaulting to [0].")
+    return [0]
+
class Detector:
    risk_names = [
        "harm",
@@ -37,6 +54,7 @@ def __init__(self):
        self.model = None
        self.cuda_device = None
        self.model_name = "unknown"
+        self.safe_labels = _parse_safe_labels_env()

        model_files_path = os.environ.get("MODEL_DIR")
        if not model_files_path:
@@ -206,7 +224,9 @@ def process_sequence_classification(self, text, detector_params=None, threshold=
        detector_params = detector_params or {}
        if threshold is None:
            threshold = detector_params.get("threshold", 0.5)
-        non_trigger_labels = set(detector_params.get("non_trigger_labels", []))
+        # Merge safe_labels from env and request
+        request_safe_labels = set(detector_params.get("safe_labels", []))
+        all_safe_labels = set(self.safe_labels) | request_safe_labels
        content_analyses = []
        tokenized = self.tokenizer(
            text,
@@ -223,7 +243,12 @@ def process_sequence_classification(self, text, detector_params=None, threshold=
        probabilities = torch.softmax(logits, dim=1).detach().cpu().numpy()[0]
        for idx, prob in enumerate(probabilities):
            label = self.model.config.id2label[idx]
-            if prob >= threshold and label not in non_trigger_labels:
+            # Exclude by index or label name
+            if (
+                prob >= threshold
+                and idx not in all_safe_labels
+                and label not in all_safe_labels
+            ):
                content_analyses.append(
                    ContentAnalysisResponse(
                        start=0,
@@ -260,10 +285,8 @@ def run(self, input: ContentAnalysisHttpRequest) -> ContentsAnalysisResponse:
            contents_analyses.append(analyses)
        return contents_analyses

-
    def close(self) -> None:
        """Clean up model and tokenizer resources."""
-
        if self.model:
            if hasattr(self.model, 'to') and hasattr(self.model, 'device') and self.model.device.type != "cpu":
                self.model = self.model.to(torch.device("cpu"))
@@ -275,4 +298,4 @@ def close(self) -> None:
        gc.collect()

        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+            torch.cuda.empty_cache()
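For reference, a minimal sketch of how the new safe-label handling is intended to behave. The label names, probabilities, and id2label mapping below are made up for illustration; only the SAFE_LABELS env var, the per-request safe_labels parameter, and the merge/filter logic come from the diff above.

    import json
    import os

    # Hypothetical values for illustration only.
    os.environ["SAFE_LABELS"] = '[0, "BENIGN"]'   # JSON int, string, or list of ints/strings
    id2label = {0: "BENIGN", 1: "INJECTION"}      # made-up id2label mapping
    probabilities = [0.15, 0.85]                  # made-up softmax output
    threshold = 0.5

    # Mirrors _parse_safe_labels_env(): a single int/str becomes a one-element list,
    # a list of ints/strs is used as-is.
    parsed = json.loads(os.environ["SAFE_LABELS"])
    env_safe_labels = [parsed] if isinstance(parsed, (int, str)) else parsed

    # Mirrors the request-time merge: env labels union detector_params["safe_labels"].
    request_safe_labels = {"BENIGN"}
    all_safe_labels = set(env_safe_labels) | request_safe_labels

    # Mirrors the filter: a class triggers only if it clears the threshold and is not
    # listed as safe by index *or* by label name.
    for idx, prob in enumerate(probabilities):
        label = id2label[idx]
        if prob >= threshold and idx not in all_safe_labels and label not in all_safe_labels:
            print(f"flagged: {label} ({prob:.2f})")
        else:
            print(f"skipped: {label} ({prob:.2f})")

Note that because the helper uses json.loads, the env var must be valid JSON: a bare index can be given as SAFE_LABELS=0, but a label name needs JSON quotes, e.g. SAFE_LABELS='"BENIGN"'.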