
Commit c7598db

🚧 extended the Detector class to enable some compatibility with the granite-guardian CausalLM model
1 parent d1aeb2c commit c7598db

2 files changed: 161 additions and 41 deletions

detectors/huggingface/detector.py

Lines changed: 160 additions & 40 deletions
@@ -3,7 +3,7 @@

 sys.path.insert(0, os.path.abspath(".."))
 # from common.scheme import TextDetectionHttpRequest, TextDetectionResponse
-
+import math
 import torch.nn
 from common.app import logger
 from scheme import (
@@ -13,10 +13,26 @@
 )

 # Detector imports
-from transformers import (AutoConfig, AutoTokenizer, AutoModelForTokenClassification, AutoModelForSequenceClassification)
+from transformers import (
+    AutoConfig,
+    AutoTokenizer,
+    AutoModelForTokenClassification,
+    AutoModelForSequenceClassification,
+    AutoModelForCausalLM,
+)


 class Detector:
+    risk_names = [
+        "harm",
+        "social_bias",
+        "jailbreak",
+        "profanity",
+        "unethical_behavior",
+        "sexual_content",
+        "violence",
+    ]
+
     def __init__(self):
         # initialize the detector
         model_files_path = os.environ.get("MODEL_DIR")
@@ -26,15 +42,27 @@ def __init__(self):
         config = AutoConfig.from_pretrained(model_files_path)
         logger.info("Config: {}".format(config))

-
         self.is_token_classifier = False
-
-        if self.is_token_classifier:
-            pass
-        else:
+        self.is_causal_lm = False
+        self.is_sequence_classifier = False
+
+        if any("ForTokenClassification" in arch for arch in config.architectures):
+            self.is_token_classifier = True
+            logger.error("Token classification models are not supported.")
+            raise ValueError("Token classification models are not supported.")
+        elif any("GraniteForCausalLM" in arch for arch in config.architectures):
+            self.is_causal_lm = True
+            self.model = AutoModelForCausalLM.from_pretrained(
+                pretrained_model_name_or_path=model_files_path,
+            )
+        elif any("ForSequenceClassification" in arch for arch in config.architectures):
+            self.is_sequence_classifier = True
             self.model = AutoModelForSequenceClassification.from_pretrained(
                 pretrained_model_name_or_path=model_files_path,
             )
+        else:
+            logger.error("Unsupported model architecture.")
+            raise ValueError("Unsupported model architecture.")

         logger.info("torch.cuda?".upper() + " " + str(torch.cuda.is_available()))

@@ -43,6 +71,8 @@ def __init__(self):
         except Exception:
             if self.is_token_classifier:
                 self.model_name = "token_classifier"
+            elif self.is_causal_lm:
+                self.model_name = "causal_lm"
             else:
                 self.model_name = "sequence_classifier"
         self.cuda_device = None
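
Reviewer note: a minimal standalone sketch of the architecture dispatch introduced in `__init__` above. The checkpoint path is hypothetical; `config.architectures` is the list of model classes declared in the checkpoint's `config.json` (it can be `None`, which the `or []` guard covers in this sketch but not in the diff).

```python
# Sketch of the architecture check; the path below is a placeholder.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("/path/to/model")  # hypothetical checkpoint dir
architectures = config.architectures or []

if any("ForTokenClassification" in arch for arch in architectures):
    kind = "token_classifier"       # rejected by the detector above
elif any("GraniteForCausalLM" in arch for arch in architectures):
    kind = "causal_lm"              # granite-guardian style guardrail model
elif any("ForSequenceClassification" in arch for arch in architectures):
    kind = "sequence_classifier"
else:
    kind = "unsupported"
print(kind)
```

Dispatching on `config.architectures` keeps the detector agnostic to how the checkpoint was exported, at the cost of failing at load time for anything outside the three handled families.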
@@ -57,56 +87,146 @@ def __init__(self):
             os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
         logger.info("cuda_device".upper() + " " + str(self.cuda_device))

+    def parse_output(self, output, input_len, nlogprobs, safe_token, unsafe_token):
+        label, prob_of_risk = None, None
+        if nlogprobs > 0:
+            list_index_logprobs_i = [
+                torch.topk(token_i, k=nlogprobs, largest=True, sorted=True)
+                for token_i in list(output.scores)[:-1]
+            ]
+            if list_index_logprobs_i is not None:
+                prob = self.get_probabilities(
+                    list_index_logprobs_i, safe_token, unsafe_token
+                )
+                prob_of_risk = prob[1]
+        res = self.tokenizer.decode(
+            output.sequences[:, input_len:][0], skip_special_tokens=True
+        ).strip()
+        if unsafe_token.lower() == res.lower():
+            label = unsafe_token
+        elif safe_token.lower() == res.lower():
+            label = safe_token
+        else:
+            label = "failed"
+        return label, prob_of_risk.item()
+
+    def get_probabilities(self, logprobs, safe_token, unsafe_token):
+        safe_token_prob = 1e-50
+        unsafe_token_prob = 1e-50
+        for gen_token_i in logprobs:
+            for logprob, index in zip(
+                gen_token_i.values.tolist()[0], gen_token_i.indices.tolist()[0]
+            ):
+                decoded_token = self.tokenizer.convert_ids_to_tokens(index)
+                if decoded_token.strip().lower() == safe_token.lower():
+                    safe_token_prob += math.exp(logprob)
+                if decoded_token.strip().lower() == unsafe_token.lower():
+                    unsafe_token_prob += math.exp(logprob)
+        probabilities = torch.softmax(
+            torch.tensor([math.log(safe_token_prob), math.log(unsafe_token_prob)]),
+            dim=0,
+        )
+        return probabilities
+
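
A self-contained sketch of the math in `get_probabilities`, with made-up top-k log-probabilities standing in for real `output.scores`: every top-k candidate that matches the safe/unsafe token contributes `exp(logprob)` to its bucket, and the two accumulated masses are renormalised by a softmax over their logs.

```python
import math
import torch

# Made-up (token, logprob) top-k candidates for one generated position;
# in the detector these come from torch.topk over output.scores.
candidates = [("No", -0.05), ("Yes", -3.10), ("no", -4.70), ("maybe", -6.00)]

safe_token, unsafe_token = "No", "Yes"
safe_prob, unsafe_prob = 1e-50, 1e-50  # tiny floor keeps the logs finite

for token, logprob in candidates:
    if token.strip().lower() == safe_token.lower():
        safe_prob += math.exp(logprob)
    if token.strip().lower() == unsafe_token.lower():
        unsafe_prob += math.exp(logprob)

probabilities = torch.softmax(
    torch.tensor([math.log(safe_prob), math.log(unsafe_prob)]), dim=0
)
print(probabilities)  # probabilities[1] is the probability of risk
```

The `1e-50` floor only matters when a token never appears among the top-k candidates; it keeps `math.log` from receiving zero.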
     def run(self, input: ContentAnalysisHttpRequest) -> ContentsAnalysisResponse:
         # run the classification for each entry on contents array
         # logger.info(tokenizer_parameters)
         contents_analyses = []
         for text in input.contents:
             content_analyses = []
-            tokenized = self.tokenizer(
-                text,
-                max_length=len(text),
-                return_tensors="pt",
-                truncation=True,
-                padding=True,
-            )
-            if self.cuda_device:
-                logger.info("adding tokenized to CUDA")
-                # If we are using a GPU, the tokens need to be there.
-                tokenized = tokenized.to(self.cuda_device)
-            # print (tokenized)
-
-            # A BatchEncoding includes 'data', 'encodings', 'is_fast', and 'n_sequences'.
-            with torch.no_grad():
-                logger.info("tokens: {}".format(tokenized))
-                model_out = self.model(**tokenized)
-                # logger.info(model_out)
-                # return logits
-                logits = model_out.logits
-
-            if self.is_token_classifier:
-                pass
-            else:
-                # Get the class with the highest probability, and use the model's id2label mapping to convert it to a text label list
-                prediction = torch.argmax(logits, dim=1).detach().numpy().tolist()[0]
-                prediction_labels = self.model.config.id2label[prediction]
-                probability = torch.softmax(logits, dim=1).detach().numpy()[:,1].tolist()[0]
-
+            if self.is_causal_lm:
+                messages = [{"role": "user", "content": text}]
+                for risk_name in self.risk_names:
+                    guardian_config = {"risk_name": risk_name}
+                    input_ids = self.tokenizer.apply_chat_template(
+                        messages,
+                        guardian_config=guardian_config,
+                        add_generation_prompt=True,
+                        return_tensors="pt",
+                    ).to(self.model.device)
+                    input_len = input_ids.shape[1]
+                    with torch.no_grad():
+                        output = self.model.generate(
+                            input_ids,
+                            do_sample=False,
+                            max_new_tokens=20,
+                            return_dict_in_generate=True,
+                            output_scores=True,
+                        )
+                    nlogprobs = 20
+                    safe_token = "No"
+                    unsafe_token = "Yes"
+                    label, prob_of_risk = self.parse_output(
+                        output, input_len, nlogprobs, safe_token, unsafe_token
+                    )
                     content_analyses.append(
                         ContentAnalysisResponse(
                             start=0,
                             end=len(text),
                             detection=self.model_name,
-                            detection_type="sequence_classification",
-                            score=probability,
-                            sequence_classification=prediction_labels,
-                            sequence_probability=probability,
+                            detection_type="causal_lm",
+                            score=prob_of_risk,
+                            sequence_classification=risk_name,
+                            sequence_probability=prob_of_risk,
                             token_classifications=None,
                             token_probabilities=None,
                             text=text,
                             evidences=[],
                         )
                     )
+
+            else:
+                tokenized = self.tokenizer(
+                    text,
+                    max_length=len(text),
+                    return_tensors="pt",
+                    truncation=True,
+                    padding=True,
+                )
+                if self.cuda_device:
+                    logger.info("adding tokenized to CUDA")
+                    # If we are using a GPU, the tokens need to be there.
+                    tokenized = tokenized.to(self.cuda_device)
+                # print (tokenized)
+
+                # A BatchEncoding includes 'data', 'encodings', 'is_fast', and 'n_sequences'.
+                with torch.no_grad():
+                    logger.info("tokens: {}".format(tokenized))
+                    model_out = self.model(**tokenized)
+                    # logger.info(model_out)
+                    # return logits
+                    logits = model_out.logits
+
+                if self.is_token_classifier:
+                    pass
+                else:
+                    # Get the class with the highest probability, and use the model's id2label mapping to convert it to a text label list
+                    prediction = (
+                        torch.argmax(logits, dim=1).detach().numpy().tolist()[0]
+                    )
+                    prediction_labels = self.model.config.id2label[prediction]
+                    probability = (
+                        torch.softmax(logits, dim=1)
+                        .detach()
+                        .numpy()[:, 1]
+                        .tolist()[0]
+                    )
+
+                    content_analyses.append(
+                        ContentAnalysisResponse(
+                            start=0,
+                            end=len(text),
+                            detection=self.model_name,
+                            detection_type="sequence_classification",
+                            score=probability,
+                            sequence_classification=prediction_labels,
+                            sequence_probability=probability,
+                            token_classifications=None,
+                            token_probabilities=None,
+                            text=text,
+                            evidences=[],
+                        )
+                    )
         contents_analyses.append(content_analyses)

         return contents_analyses
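
Outside the detector, the causal-LM branch above reduces to the following standalone sketch. The checkpoint path and input text are placeholders; the calls mirror the diff, so the decoded answer should come back as "Yes" or "No" exactly as `parse_output` expects.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "/path/to/granite-guardian"  # hypothetical local checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)

messages = [{"role": "user", "content": "Example text to screen."}]
input_ids = tokenizer.apply_chat_template(
    messages,
    guardian_config={"risk_name": "harm"},  # one entry of Detector.risk_names
    add_generation_prompt=True,
    return_tensors="pt",
).to(model.device)

with torch.no_grad():
    output = model.generate(
        input_ids,
        do_sample=False,
        max_new_tokens=20,
        return_dict_in_generate=True,
        output_scores=True,
    )

verdict = tokenizer.decode(
    output.sequences[:, input_ids.shape[1]:][0], skip_special_tokens=True
).strip()
print(verdict)  # "Yes" (risk detected) or "No" (safe)
```

In `run`, this generation is repeated once per entry in `risk_names`, so each input text yields one `ContentAnalysisResponse` per risk.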
Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-transformers==4.43.4
+transformers==4.47.0
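
A trivial, purely illustrative check that an environment matches the new pin, which presumably tracks what the granite-guardian chat-template path expects:

```python
# Illustrative check that the installed transformers matches the new pin.
import transformers

assert transformers.__version__ == "4.47.0", transformers.__version__
```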
