 import sys
 
 sys.path.insert(0, os.path.abspath(".."))
+import json
 import math
 import torch
 from transformers import (
     AutoModelForCausalLM,
 )
 from common.app import logger
-from scheme import (
+from common.scheme import (
     ContentAnalysisHttpRequest,
     ContentAnalysisResponse,
     ContentsAnalysisResponse,
 )
 import gc
 
+def _parse_safe_labels_env():
+    if os.environ.get("SAFE_LABELS"):
+        try:
+            parsed = json.loads(os.environ.get("SAFE_LABELS"))
+            if isinstance(parsed, (int, str)):
+                logger.info(f"SAFE_LABELS env var: {parsed}")
+                return [parsed]
+            if isinstance(parsed, list) and all(isinstance(x, (int, str)) for x in parsed):
+                logger.info(f"SAFE_LABELS env var: {parsed}")
+                return parsed
+        except Exception as e:
+            logger.warning(f"Could not parse SAFE_LABELS env var: {e}. Defaulting to [0].")
+        return [0]
+    logger.info("SAFE_LABELS env var not set: defaulting to [0].")
+    return [0]
+
 class Detector:
     risk_names = [
         "harm",
@@ -37,6 +54,7 @@ def __init__(self):
         self.model = None
         self.cuda_device = None
         self.model_name = "unknown"
+        self.safe_labels = _parse_safe_labels_env()
 
         model_files_path = os.environ.get("MODEL_DIR")
         if not model_files_path:
@@ -169,15 +187,6 @@ def get_probabilities(self, logprobs, safe_token, unsafe_token):
         return probabilities
 
     def process_causal_lm(self, text):
-        """
-        Process text using a causal language model.
-
-        Args:
-            text (str): Input text.
-
-        Returns:
-            List[ContentAnalysisResponse]: List of content analysis results.
-        """
         messages = [{"role": "user", "content": text}]
         content_analyses = []
         for risk_name in self.risk_names:
@@ -205,26 +214,19 @@ def process_causal_lm(self, text):
                     detection=self.model_name,
                     detection_type="causal_lm",
                     score=prob_of_risk,
-                    sequence_classification=risk_name,
-                    sequence_probability=prob_of_risk,
-                    token_classifications=None,
-                    token_probabilities=None,
                     text=text,
                     evidences=[],
                 )
             )
         return content_analyses
 
-    def process_sequence_classification(self, text):
-        """
-        Process text using a sequence classification model.
-
-        Args:
-            text (str): Input text.
-
-        Returns:
-            List[ContentAnalysisResponse]: List of content analysis results.
-        """
+    def process_sequence_classification(self, text, detector_params=None, threshold=None):
+        detector_params = detector_params or {}
+        if threshold is None:
+            threshold = detector_params.get("threshold", 0.5)
+        # Merge safe_labels from env and request
+        request_safe_labels = set(detector_params.get("safe_labels", []))
+        all_safe_labels = set(self.safe_labels) | request_safe_labels
         content_analyses = []
         tokenized = self.tokenizer(
             text,
@@ -238,26 +240,26 @@ def process_sequence_classification(self, text):
 
         with torch.no_grad():
             logits = self.model(**tokenized).logits
-        prediction = torch.argmax(logits, dim=1).detach().cpu().numpy().tolist()[0]
-        prediction_labels = self.model.config.id2label[prediction]
-        probability = (
-            torch.softmax(logits, dim=1).detach().cpu().numpy()[:, 1].tolist()[0]
-        )
-        content_analyses.append(
-            ContentAnalysisResponse(
-                start=0,
-                end=len(text),
-                detection=self.model_name,
-                detection_type="sequence_classification",
-                score=probability,
-                sequence_classification=prediction_labels,
-                sequence_probability=probability,
-                token_classifications=None,
-                token_probabilities=None,
-                text=text,
-                evidences=[],
-            )
-        )
+        probabilities = torch.softmax(logits, dim=1).detach().cpu().numpy()[0]
+        for idx, prob in enumerate(probabilities):
+            label = self.model.config.id2label[idx]
+            # Exclude by index or label name
+            if (
+                prob >= threshold
+                and idx not in all_safe_labels
+                and label not in all_safe_labels
+            ):
+                content_analyses.append(
+                    ContentAnalysisResponse(
+                        start=0,
+                        end=len(text),
+                        detection=getattr(self.model.config, "problem_type", "sequence_classification"),
+                        detection_type=label,
+                        score=prob,
+                        text=text,
+                        evidences=[],
+                    )
+                )
         return content_analyses
 
     def run(self, input: ContentAnalysisHttpRequest) -> ContentsAnalysisResponse:
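
The filtering rule in the rewritten loop, shown in isolation: every label whose softmax probability clears the threshold is reported unless its index or its name appears in the merged safe-label set. A self-contained sketch with a hypothetical id2label mapping and made-up scores (not the module's code):

    import torch

    id2label = {0: "no_risk", 1: "toxicity", 2: "profanity"}   # hypothetical mapping
    logits = torch.tensor([[2.0, 1.5, -1.0]])                  # made-up scores
    probabilities = torch.softmax(logits, dim=1)[0]

    threshold = 0.3
    all_safe_labels = {0, "no_risk"}                           # env default merged with request

    detections = [
        (id2label[idx], float(prob))
        for idx, prob in enumerate(probabilities)
        if prob >= threshold
        and idx not in all_safe_labels
        and id2label[idx] not in all_safe_labels
    ]
    print(detections)   # [('toxicity', 0.366...)] with these numbers
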
@@ -275,16 +277,16 @@ def run(self, input: ContentAnalysisHttpRequest) -> ContentsAnalysisResponse:
             if self.is_causal_lm:
                 analyses = self.process_causal_lm(text)
             elif self.is_sequence_classifier:
-                analyses = self.process_sequence_classification(text)
+                analyses = self.process_sequence_classification(
+                    text, detector_params=getattr(input, "detector_params", None)
+                )
             else:
                 raise ValueError("Unsupported model type for analysis.")
             contents_analyses.append(analyses)
         return contents_analyses
 
-
     def close(self) -> None:
         """Clean up model and tokenizer resources."""
-
         if self.model:
             if hasattr(self.model, 'to') and hasattr(self.model, 'device') and self.model.device.type != "cpu":
                 self.model = self.model.to(torch.device("cpu"))
@@ -296,4 +298,4 @@ def close(self) -> None:
         gc.collect()
 
         if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+            torch.cuda.empty_cache()
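
End to end, run() now forwards the request's detector_params into the sequence-classification path, so a caller can raise the threshold or extend the safe labels per request. A sketch of the payload shape this expects; the "contents" field name is assumed from the request schema, and the values are illustrative:

    # Illustrative request body: detector_params is forwarded by run() to
    # process_sequence_classification(), where "threshold" and "safe_labels"
    # override or extend the SAFE_LABELS env default.
    request_payload = {
        "contents": ["text to analyze"],   # field name assumed, not confirmed by the diff
        "detector_params": {"threshold": 0.7, "safe_labels": ["no_risk"]},
    }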