     AutoModelForCausalLM,
 )
 from common.app import logger
-from scheme import (
+from common.scheme import (
     ContentAnalysisHttpRequest,
     ContentAnalysisResponse,
     ContentsAnalysisResponse,
@@ -169,15 +169,6 @@ def get_probabilities(self, logprobs, safe_token, unsafe_token):
         return probabilities

     def process_causal_lm(self, text):
-        """
-        Process text using a causal language model.
-
-        Args:
-            text (str): Input text.
-
-        Returns:
-            List[ContentAnalysisResponse]: List of content analysis results.
-        """
         messages = [{"role": "user", "content": text}]
         content_analyses = []
         for risk_name in self.risk_names:
@@ -205,26 +196,17 @@ def process_causal_lm(self, text):
                     detection=self.model_name,
                     detection_type="causal_lm",
                     score=prob_of_risk,
-                    sequence_classification=risk_name,
-                    sequence_probability=prob_of_risk,
-                    token_classifications=None,
-                    token_probabilities=None,
                     text=text,
                     evidences=[],
                 )
             )
         return content_analyses

-    def process_sequence_classification(self, text):
-        """
-        Process text using a sequence classification model.
-
-        Args:
-            text (str): Input text.
-
-        Returns:
-            List[ContentAnalysisResponse]: List of content analysis results.
-        """
+    def process_sequence_classification(self, text, detector_params=None, threshold=None):
+        detector_params = detector_params or {}
+        if threshold is None:
+            threshold = detector_params.get("threshold", 0.5)
+        non_trigger_labels = set(detector_params.get("non_trigger_labels", []))
         content_analyses = []
         tokenized = self.tokenizer(
             text,
@@ -238,26 +220,21 @@ def process_sequence_classification(self, text):

         with torch.no_grad():
             logits = self.model(**tokenized).logits
-            prediction = torch.argmax(logits, dim=1).detach().cpu().numpy().tolist()[0]
-            prediction_labels = self.model.config.id2label[prediction]
-            probability = (
-                torch.softmax(logits, dim=1).detach().cpu().numpy()[:, 1].tolist()[0]
-            )
-            content_analyses.append(
-                ContentAnalysisResponse(
-                    start=0,
-                    end=len(text),
-                    detection=self.model_name,
-                    detection_type="sequence_classification",
-                    score=probability,
-                    sequence_classification=prediction_labels,
-                    sequence_probability=probability,
-                    token_classifications=None,
-                    token_probabilities=None,
-                    text=text,
-                    evidences=[],
-                )
-            )
+            probabilities = torch.softmax(logits, dim=1).detach().cpu().numpy()[0]
+            for idx, prob in enumerate(probabilities):
+                label = self.model.config.id2label[idx]
+                if prob >= threshold and label not in non_trigger_labels:
+                    content_analyses.append(
+                        ContentAnalysisResponse(
+                            start=0,
+                            end=len(text),
+                            detection=getattr(self.model.config, "problem_type", "sequence_classification"),
+                            detection_type=label,
+                            score=prob,
+                            text=text,
+                            evidences=[],
+                        )
+                    )
         return content_analyses

     def run(self, input: ContentAnalysisHttpRequest) -> ContentsAnalysisResponse:
@@ -275,7 +252,9 @@ def run(self, input: ContentAnalysisHttpRequest) -> ContentsAnalysisResponse:
         if self.is_causal_lm:
             analyses = self.process_causal_lm(text)
         elif self.is_sequence_classifier:
-            analyses = self.process_sequence_classification(text)
+            analyses = self.process_sequence_classification(
+                text, detector_params=getattr(input, "detector_params", None)
+            )
         else:
             raise ValueError("Unsupported model type for analysis.")
         contents_analyses.append(analyses)
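
For reviewers, a minimal sketch of what the rewritten classification loop does: instead of reporting only the argmax label, it softmaxes the logits once and emits one detection per label whose probability clears the threshold, skipping any label listed in non_trigger_labels. The id2label mapping, logits, and parameter values below are hypothetical stand-ins, not values taken from this PR.

import torch

# Hypothetical stand-ins for self.model.config.id2label and the request params.
id2label = {0: "no_risk", 1: "risk"}
detector_params = {"threshold": 0.8, "non_trigger_labels": ["no_risk"]}

threshold = detector_params.get("threshold", 0.5)
non_trigger_labels = set(detector_params.get("non_trigger_labels", []))

# Fake logits for a single sequence, shape (1, num_labels).
logits = torch.tensor([[0.2, 2.5]])
probabilities = torch.softmax(logits, dim=1).detach().cpu().numpy()[0]

for idx, prob in enumerate(probabilities):
    label = id2label[idx]
    if prob >= threshold and label not in non_trigger_labels:
        print(f"detection_type={label}, score={prob:.3f}")

With these numbers only "risk" (p ≈ 0.91) produces a detection; "no_risk" would be suppressed by non_trigger_labels even if its probability cleared the threshold.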
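The getattr fallback in run also keeps requests without the new field working: detector_params comes back as None, and the method then falls back to the 0.5 default threshold with no suppressed labels. A sketch of that default path, using SimpleNamespace as a hypothetical stand-in for a legacy request object:

from types import SimpleNamespace

# Legacy request without the new attribute; getattr falls back to None.
legacy_request = SimpleNamespace()
params = getattr(legacy_request, "detector_params", None)

# Mirrors the defaulting logic at the top of process_sequence_classification.
detector_params = params or {}
threshold = detector_params.get("threshold", 0.5)
non_trigger_labels = set(detector_params.get("non_trigger_labels", []))
print(threshold, non_trigger_labels)  # 0.5 set()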