@@ -3,7 +3,7 @@
 
 sys.path.insert(0, os.path.abspath(".."))
 # from common.scheme import TextDetectionHttpRequest, TextDetectionResponse
-
+import math
 import torch.nn
 from common.app import logger
 from scheme import (
@@ -13,10 +13,26 @@
 )
 
 # Detector imports
-from transformers import (AutoConfig, AutoTokenizer, AutoModelForTokenClassification, AutoModelForSequenceClassification)
+from transformers import (
+    AutoConfig,
+    AutoTokenizer,
+    AutoModelForTokenClassification,
+    AutoModelForSequenceClassification,
+    AutoModelForCausalLM,
+)
 
 
 class Detector:
+    risk_names = [
+        "harm",
+        "social_bias",
+        "jailbreak",
+        "profanity",
+        "unethical_behavior",
+        "sexual_content",
+        "violence",
+    ]
+
     def __init__(self):
         # initialize the detector
         model_files_path = os.environ.get("MODEL_DIR")
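
For context, the detector reads its checkpoint location from the `MODEL_DIR` environment variable before the loading logic below runs. A minimal local smoke test might look like this sketch (the checkpoint path and the `detector` module name are illustrative assumptions, not part of this diff):

```python
import os

# Hypothetical checkpoint location; MODEL_DIR must point at a directory that
# contains config.json, the tokenizer files, and the model weights.
os.environ["MODEL_DIR"] = "/mnt/models/granite-guardian"

from detector import Detector  # assumed module name for the file in this diff

detector = Detector()  # reads MODEL_DIR, then loads config, tokenizer, and the matching model
```
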
@@ -26,15 +42,27 @@ def __init__(self):
         config = AutoConfig.from_pretrained(model_files_path)
         logger.info("Config: {}".format(config))
 
-
         self.is_token_classifier = False
-
-        if self.is_token_classifier:
-            pass
-        else:
+        self.is_causal_lm = False
+        self.is_sequence_classifier = False
+
+        if any("ForTokenClassification" in arch for arch in config.architectures):
+            self.is_token_classifier = True
+            logger.error("Token classification models are not supported.")
+            raise ValueError("Token classification models are not supported.")
+        elif any("GraniteForCausalLM" in arch for arch in config.architectures):
+            self.is_causal_lm = True
+            self.model = AutoModelForCausalLM.from_pretrained(
+                pretrained_model_name_or_path=model_files_path,
+            )
+        elif any("ForSequenceClassification" in arch for arch in config.architectures):
+            self.is_sequence_classifier = True
             self.model = AutoModelForSequenceClassification.from_pretrained(
                 pretrained_model_name_or_path=model_files_path,
            )
+        else:
+            logger.error("Unsupported model architecture.")
+            raise ValueError("Unsupported model architecture.")
 
         logger.info("torch.cuda?".upper() + " " + str(torch.cuda.is_available()))
 
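
The branching above keys entirely on `config.architectures`. A quick way to check which branch a given checkpoint will take, shown here as an illustrative snippet (the path is hypothetical):

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("/mnt/models/granite-guardian")  # hypothetical path
print(config.architectures)
# A Granite Guardian checkpoint lists "GraniteForCausalLM" and is loaded with
# AutoModelForCausalLM; anything containing "ForSequenceClassification" is loaded
# with AutoModelForSequenceClassification; "ForTokenClassification" (and any other
# architecture) now raises ValueError.
```
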
@@ -43,6 +71,8 @@ def __init__(self):
         except Exception:
             if self.is_token_classifier:
                 self.model_name = "token_classifier"
+            elif self.is_causal_lm:
+                self.model_name = "causal_lm"
             else:
                 self.model_name = "sequence_classifier"
         self.cuda_device = None
@@ -57,56 +87,146 @@ def __init__(self):
             os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
         logger.info("cuda_device".upper() + " " + str(self.cuda_device))
 
+    def parse_output(self, output, input_len, nlogprobs, safe_token, unsafe_token):
+        label, prob_of_risk = None, None
+        if nlogprobs > 0:
+            list_index_logprobs_i = [
+                torch.topk(token_i, k=nlogprobs, largest=True, sorted=True)
+                for token_i in list(output.scores)[:-1]
+            ]
+            if list_index_logprobs_i is not None:
+                prob = self.get_probabilities(
+                    list_index_logprobs_i, safe_token, unsafe_token
+                )
+                prob_of_risk = prob[1]
+        res = self.tokenizer.decode(
+            output.sequences[:, input_len:][0], skip_special_tokens=True
+        ).strip()
+        if unsafe_token.lower() == res.lower():
+            label = unsafe_token
+        elif safe_token.lower() == res.lower():
+            label = safe_token
+        else:
+            label = "failed"
+        return label, prob_of_risk.item()
+
+    def get_probabilities(self, logprobs, safe_token, unsafe_token):
+        safe_token_prob = 1e-50
+        unsafe_token_prob = 1e-50
+        for gen_token_i in logprobs:
+            for logprob, index in zip(
+                gen_token_i.values.tolist()[0], gen_token_i.indices.tolist()[0]
+            ):
+                decoded_token = self.tokenizer.convert_ids_to_tokens(index)
+                if decoded_token.strip().lower() == safe_token.lower():
+                    safe_token_prob += math.exp(logprob)
+                if decoded_token.strip().lower() == unsafe_token.lower():
+                    unsafe_token_prob += math.exp(logprob)
+        probabilities = torch.softmax(
+            torch.tensor([math.log(safe_token_prob), math.log(unsafe_token_prob)]),
+            dim=0,
+        )
+        return probabilities
+
     def run(self, input: ContentAnalysisHttpRequest) -> ContentsAnalysisResponse:
         # run the classification for each entry on contents array
         # logger.info(tokenizer_parameters)
         contents_analyses = []
         for text in input.contents:
             content_analyses = []
-            tokenized = self.tokenizer(
-                text,
-                max_length=len(text),
-                return_tensors="pt",
-                truncation=True,
-                padding=True,
-            )
-            if self.cuda_device:
-                logger.info("adding tokenized to CUDA")
-                # If we are using a GPU, the tokens need to be there.
-                tokenized = tokenized.to(self.cuda_device)
-            # print (tokenized)
-
-            # A BatchEncoding includes 'data', 'encodings', 'is_fast', and 'n_sequences'.
-            with torch.no_grad():
-                logger.info("tokens: {}".format(tokenized))
-                model_out = self.model(**tokenized)
-                # logger.info(model_out)
-                # return logits
-                logits = model_out.logits
-
-            if self.is_token_classifier:
-                pass
-            else:
-                # Get the class with the highest probability, and use the model’s id2label mapping to convert it to a text label list
-                prediction = torch.argmax(logits, dim=1).detach().numpy().tolist()[0]
-                prediction_labels = self.model.config.id2label[prediction]
-                probability = torch.softmax(logits, dim=1).detach().numpy()[:, 1].tolist()[0]
-
+            if self.is_causal_lm:
+                messages = [{"role": "user", "content": text}]
+                for risk_name in self.risk_names:
+                    guardian_config = {"risk_name": risk_name}
+                    input_ids = self.tokenizer.apply_chat_template(
+                        messages,
+                        guardian_config=guardian_config,
+                        add_generation_prompt=True,
+                        return_tensors="pt",
+                    ).to(self.model.device)
+                    input_len = input_ids.shape[1]
+                    with torch.no_grad():
+                        output = self.model.generate(
+                            input_ids,
+                            do_sample=False,
+                            max_new_tokens=20,
+                            return_dict_in_generate=True,
+                            output_scores=True,
+                        )
+                    nlogprobs = 20
+                    safe_token = "No"
+                    unsafe_token = "Yes"
+                    label, prob_of_risk = self.parse_output(
+                        output, input_len, nlogprobs, safe_token, unsafe_token
+                    )
                     content_analyses.append(
                         ContentAnalysisResponse(
                             start=0,
                             end=len(text),
                             detection=self.model_name,
-                            detection_type="sequence_classification",
-                            score=probability,
-                            sequence_classification=prediction_labels,
-                            sequence_probability=probability,
+                            detection_type="causal_lm",
+                            score=prob_of_risk,
+                            sequence_classification=risk_name,
+                            sequence_probability=prob_of_risk,
                             token_classifications=None,
                             token_probabilities=None,
                             text=text,
                             evidences=[],
                         )
                     )
+
+            else:
+                tokenized = self.tokenizer(
+                    text,
+                    max_length=len(text),
+                    return_tensors="pt",
+                    truncation=True,
+                    padding=True,
+                )
+                if self.cuda_device:
+                    logger.info("adding tokenized to CUDA")
+                    # If we are using a GPU, the tokens need to be there.
+                    tokenized = tokenized.to(self.cuda_device)
+                # print (tokenized)
+
+                # A BatchEncoding includes 'data', 'encodings', 'is_fast', and 'n_sequences'.
+                with torch.no_grad():
+                    logger.info("tokens: {}".format(tokenized))
+                    model_out = self.model(**tokenized)
+                    # logger.info(model_out)
+                    # return logits
+                    logits = model_out.logits
+
+                if self.is_token_classifier:
+                    pass
+                else:
+                    # Get the class with the highest probability, and use the model’s id2label mapping to convert it to a text label list
+                    prediction = (
+                        torch.argmax(logits, dim=1).detach().numpy().tolist()[0]
+                    )
+                    prediction_labels = self.model.config.id2label[prediction]
+                    probability = (
+                        torch.softmax(logits, dim=1)
+                        .detach()
+                        .numpy()[:, 1]
+                        .tolist()[0]
+                    )
+
+                    content_analyses.append(
+                        ContentAnalysisResponse(
+                            start=0,
+                            end=len(text),
+                            detection=self.model_name,
+                            detection_type="sequence_classification",
+                            score=probability,
+                            sequence_classification=prediction_labels,
+                            sequence_probability=probability,
+                            token_classifications=None,
+                            token_probabilities=None,
+                            text=text,
+                            evidences=[],
+                        )
+                    )
         contents_analyses.append(content_analyses)
 
         return contents_analyses
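
For reference, the causal-LM branch of `run()` reduces to the following standalone sketch of a single risk check; it mirrors the `apply_chat_template` / `generate` calls above, while the checkpoint path and the example prompt are illustrative assumptions:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "/mnt/models/granite-guardian"  # hypothetical path
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)

messages = [{"role": "user", "content": "How do I pick a lock?"}]
input_ids = tokenizer.apply_chat_template(
    messages,
    guardian_config={"risk_name": "harm"},  # one entry from Detector.risk_names
    add_generation_prompt=True,
    return_tensors="pt",
).to(model.device)

with torch.no_grad():
    output = model.generate(
        input_ids,
        do_sample=False,
        max_new_tokens=20,
        return_dict_in_generate=True,
        output_scores=True,
    )

# The completion is expected to be "Yes" (risky) or "No" (safe); parse_output()
# maps it to a label plus a probability of risk derived from output.scores.
completion = tokenizer.decode(
    output.sequences[:, input_ids.shape[1]:][0], skip_special_tokens=True
).strip()
print(completion)
```
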
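And a minimal, self-contained illustration of what `get_probabilities()` computes: the exponentiated log-probabilities of the "No" and "Yes" candidates are accumulated over the top-k scores of each generated position, and a softmax over their logs gives the two-way probability. The numbers below are made up for the example:

```python
import math
import torch

safe_token_prob = 1e-50    # accumulated mass for "No"
unsafe_token_prob = 1e-50  # accumulated mass for "Yes"

# Suppose one generated position had "No" among its top-k with logprob -0.2
# and "Yes" with logprob -1.8:
safe_token_prob += math.exp(-0.2)
unsafe_token_prob += math.exp(-1.8)

probabilities = torch.softmax(
    torch.tensor([math.log(safe_token_prob), math.log(unsafe_token_prob)]), dim=0
)
print(probabilities)  # ~tensor([0.8320, 0.1680]); index 1 is prob_of_risk
```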