
Commit ceed982

✨ changed scheme.py import to be from common and added a better way to handle multiclass classifiers
1 parent 5a66b50 commit ceed982

File tree

4 files changed: +25 −114 lines

detectors/Dockerfile.hf

Lines changed: 0 additions & 1 deletion
@@ -23,7 +23,6 @@ RUN echo "$CACHEBUST"
 COPY ./common /common
 COPY ./huggingface/app.py /app
 COPY ./huggingface/detector.py /app
-COPY ./huggingface/scheme.py /app

 EXPOSE 8000
 CMD ["uvicorn", "app:app", "--workers", "4", "--host", "0.0.0.0", "--port", "8000", "--log-config", "/common/log_conf.yaml"]

detectors/huggingface/app.py

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@

 from common.app import DetectorBaseAPI as FastAPI
 from detector import Detector
-from scheme import (
+from common.scheme import (
     ContentAnalysisHttpRequest,
     ContentsAnalysisResponse,
     Error,

detectors/huggingface/detector.py

Lines changed: 24 additions & 45 deletions
@@ -11,7 +11,7 @@
     AutoModelForCausalLM,
 )
 from common.app import logger
-from scheme import (
+from common.scheme import (
     ContentAnalysisHttpRequest,
     ContentAnalysisResponse,
     ContentsAnalysisResponse,
@@ -169,15 +169,6 @@ def get_probabilities(self, logprobs, safe_token, unsafe_token):
         return probabilities

     def process_causal_lm(self, text):
-        """
-        Process text using a causal language model.
-
-        Args:
-            text (str): Input text.
-
-        Returns:
-            List[ContentAnalysisResponse]: List of content analysis results.
-        """
         messages = [{"role": "user", "content": text}]
         content_analyses = []
         for risk_name in self.risk_names:
@@ -205,26 +196,17 @@ def process_causal_lm(self, text):
                     detection=self.model_name,
                     detection_type="causal_lm",
                     score=prob_of_risk,
-                    sequence_classification=risk_name,
-                    sequence_probability=prob_of_risk,
-                    token_classifications=None,
-                    token_probabilities=None,
                     text=text,
                     evidences=[],
                 )
             )
         return content_analyses

-    def process_sequence_classification(self, text):
-        """
-        Process text using a sequence classification model.
-
-        Args:
-            text (str): Input text.
-
-        Returns:
-            List[ContentAnalysisResponse]: List of content analysis results.
-        """
+    def process_sequence_classification(self, text, detector_params=None, threshold=None):
+        detector_params = detector_params or {}
+        if threshold is None:
+            threshold = detector_params.get("threshold", 0.5)
+        non_trigger_labels = set(detector_params.get("non_trigger_labels", []))
         content_analyses = []
         tokenized = self.tokenizer(
             text,
@@ -238,26 +220,21 @@ def process_sequence_classification(self, text):

         with torch.no_grad():
             logits = self.model(**tokenized).logits
-        prediction = torch.argmax(logits, dim=1).detach().cpu().numpy().tolist()[0]
-        prediction_labels = self.model.config.id2label[prediction]
-        probability = (
-            torch.softmax(logits, dim=1).detach().cpu().numpy()[:, 1].tolist()[0]
-        )
-        content_analyses.append(
-            ContentAnalysisResponse(
-                start=0,
-                end=len(text),
-                detection=self.model_name,
-                detection_type="sequence_classification",
-                score=probability,
-                sequence_classification=prediction_labels,
-                sequence_probability=probability,
-                token_classifications=None,
-                token_probabilities=None,
-                text=text,
-                evidences=[],
-            )
-        )
+        probabilities = torch.softmax(logits, dim=1).detach().cpu().numpy()[0]
+        for idx, prob in enumerate(probabilities):
+            label = self.model.config.id2label[idx]
+            if prob >= threshold and label not in non_trigger_labels:
+                content_analyses.append(
+                    ContentAnalysisResponse(
+                        start=0,
+                        end=len(text),
+                        detection=getattr(self.model.config, "problem_type", "sequence_classification"),
+                        detection_type=label,
+                        score=prob,
+                        text=text,
+                        evidences=[],
+                    )
+                )
         return content_analyses

     def run(self, input: ContentAnalysisHttpRequest) -> ContentsAnalysisResponse:
@@ -275,7 +252,9 @@ def run(self, input: ContentAnalysisHttpRequest) -> ContentsAnalysisResponse:
            if self.is_causal_lm:
                analyses = self.process_causal_lm(text)
            elif self.is_sequence_classifier:
-                analyses = self.process_sequence_classification(text)
+                analyses = self.process_sequence_classification(
+                    text, detector_params=getattr(input, "detector_params", None)
+                )
            else:
                raise ValueError("Unsupported model type for analysis.")
            contents_analyses.append(analyses)

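For context, the sketch below reproduces the new multiclass handling as a standalone script: instead of taking only the argmax (or only column 1 of the softmax), every label is scored, and one detection is emitted per label whose probability clears the threshold and is not listed in non_trigger_labels. This is a minimal illustration, not the repository's Detector class; the checkpoint name, the classify helper, and the example inputs are placeholders chosen for the sketch.

# Minimal standalone sketch of the per-label thresholding added in this commit.
# Assumptions: the checkpoint name and example text are placeholders; any
# Hugging Face sequence-classification model with an id2label mapping works.
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_name = "distilbert-base-uncased-finetuned-sst-2-english"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

def classify(text, detector_params=None, threshold=None):
    # Same precedence as the diff: explicit argument, then detector_params, then 0.5.
    detector_params = detector_params or {}
    if threshold is None:
        threshold = detector_params.get("threshold", 0.5)
    non_trigger_labels = set(detector_params.get("non_trigger_labels", []))

    tokenized = tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        logits = model(**tokenized).logits

    # Score every label rather than only the argmax, so any label above the
    # threshold (and not explicitly suppressed) produces its own detection.
    probabilities = torch.softmax(logits, dim=1)[0].tolist()
    detections = []
    for idx, prob in enumerate(probabilities):
        label = model.config.id2label[idx]
        if prob >= threshold and label not in non_trigger_labels:
            detections.append({"label": label, "score": prob})
    return detections

# Example: report only labels scoring above 0.7, never the POSITIVE class.
print(classify("this is terrible", {"threshold": 0.7, "non_trigger_labels": ["POSITIVE"]}))

In the detector itself, run() now forwards getattr(input, "detector_params", None) into process_sequence_classification, so a request whose detector_params include, say, {"threshold": 0.7, "non_trigger_labels": ["safe"]} would control which labels are reported; the key names come from the diff above, while the example values and the "safe" label are hypothetical.
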
detectors/huggingface/scheme.py

Lines changed: 0 additions & 67 deletions
This file was deleted.
