Commit 72fd63a

Merge pull request #44 from m-misiura/hf-enable-multiclass
feat: better handle the multilabel classifiers by tweaking the response format
2 parents 5a66b50 + d05094f commit 72fd63a

4 files changed: +51 -117 lines changed

detectors/Dockerfile.hf

Lines changed: 0 additions & 1 deletion
@@ -23,7 +23,6 @@ RUN echo "$CACHEBUST"
 COPY ./common /common
 COPY ./huggingface/app.py /app
 COPY ./huggingface/detector.py /app
-COPY ./huggingface/scheme.py /app
 
 EXPOSE 8000
 CMD ["uvicorn", "app:app", "--workers", "4", "--host", "0.0.0.0", "--port", "8000", "--log-config", "/common/log_conf.yaml"]

detectors/huggingface/app.py

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@
 
 from common.app import DetectorBaseAPI as FastAPI
 from detector import Detector
-from scheme import (
+from common.scheme import (
     ContentAnalysisHttpRequest,
     ContentsAnalysisResponse,
     Error,

detectors/huggingface/detector.py

Lines changed: 50 additions & 48 deletions
@@ -2,6 +2,7 @@
 import sys
 
 sys.path.insert(0, os.path.abspath(".."))
+import json
 import math
 import torch
 from transformers import (
@@ -11,13 +12,29 @@
     AutoModelForCausalLM,
 )
 from common.app import logger
-from scheme import (
+from common.scheme import (
     ContentAnalysisHttpRequest,
     ContentAnalysisResponse,
     ContentsAnalysisResponse,
 )
 import gc
 
+def _parse_safe_labels_env():
+    if os.environ.get("SAFE_LABELS"):
+        try:
+            parsed = json.loads(os.environ.get("SAFE_LABELS"))
+            if isinstance(parsed, (int, str)):
+                logger.info(f"SAFE_LABELS env var: {parsed}")
+                return [parsed]
+            if isinstance(parsed, list) and all(isinstance(x, (int, str)) for x in parsed):
+                logger.info(f"SAFE_LABELS env var: {parsed}")
+                return parsed
+        except Exception as e:
+            logger.warning(f"Could not parse SAFE_LABELS env var: {e}. Defaulting to [0].")
+            return [0]
+    logger.info("SAFE_LABELS env var not set: defaulting to [0].")
+    return [0]
+
 class Detector:
     risk_names = [
         "harm",
@@ -37,6 +54,7 @@ def __init__(self):
         self.model = None
         self.cuda_device = None
         self.model_name = "unknown"
+        self.safe_labels = _parse_safe_labels_env()
 
         model_files_path = os.environ.get("MODEL_DIR")
         if not model_files_path:
@@ -169,15 +187,6 @@ def get_probabilities(self, logprobs, safe_token, unsafe_token):
         return probabilities
 
     def process_causal_lm(self, text):
-        """
-        Process text using a causal language model.
-
-        Args:
-            text (str): Input text.
-
-        Returns:
-            List[ContentAnalysisResponse]: List of content analysis results.
-        """
         messages = [{"role": "user", "content": text}]
         content_analyses = []
         for risk_name in self.risk_names:
@@ -205,26 +214,19 @@ def process_causal_lm(self, text):
                     detection=self.model_name,
                     detection_type="causal_lm",
                     score=prob_of_risk,
-                    sequence_classification=risk_name,
-                    sequence_probability=prob_of_risk,
-                    token_classifications=None,
-                    token_probabilities=None,
                     text=text,
                     evidences=[],
                 )
             )
         return content_analyses
 
-    def process_sequence_classification(self, text):
-        """
-        Process text using a sequence classification model.
-
-        Args:
-            text (str): Input text.
-
-        Returns:
-            List[ContentAnalysisResponse]: List of content analysis results.
-        """
+    def process_sequence_classification(self, text, detector_params=None, threshold=None):
+        detector_params = detector_params or {}
+        if threshold is None:
+            threshold = detector_params.get("threshold", 0.5)
+        # Merge safe_labels from env and request
+        request_safe_labels = set(detector_params.get("safe_labels", []))
+        all_safe_labels = set(self.safe_labels) | request_safe_labels
         content_analyses = []
         tokenized = self.tokenizer(
            text,
@@ -238,26 +240,26 @@ def process_sequence_classification(self, text):
 
         with torch.no_grad():
             logits = self.model(**tokenized).logits
-            prediction = torch.argmax(logits, dim=1).detach().cpu().numpy().tolist()[0]
-            prediction_labels = self.model.config.id2label[prediction]
-            probability = (
-                torch.softmax(logits, dim=1).detach().cpu().numpy()[:, 1].tolist()[0]
-            )
-            content_analyses.append(
-                ContentAnalysisResponse(
-                    start=0,
-                    end=len(text),
-                    detection=self.model_name,
-                    detection_type="sequence_classification",
-                    score=probability,
-                    sequence_classification=prediction_labels,
-                    sequence_probability=probability,
-                    token_classifications=None,
-                    token_probabilities=None,
-                    text=text,
-                    evidences=[],
-                )
-            )
+            probabilities = torch.softmax(logits, dim=1).detach().cpu().numpy()[0]
+            for idx, prob in enumerate(probabilities):
+                label = self.model.config.id2label[idx]
+                # Exclude by index or label name
+                if (
+                    prob >= threshold
+                    and idx not in all_safe_labels
+                    and label not in all_safe_labels
+                ):
+                    content_analyses.append(
+                        ContentAnalysisResponse(
+                            start=0,
+                            end=len(text),
+                            detection=getattr(self.model.config, "problem_type", "sequence_classification"),
+                            detection_type=label,
+                            score=prob,
+                            text=text,
+                            evidences=[],
+                        )
+                    )
         return content_analyses
 
     def run(self, input: ContentAnalysisHttpRequest) -> ContentsAnalysisResponse:
@@ -275,16 +277,16 @@ def run(self, input: ContentAnalysisHttpRequest) -> ContentsAnalysisResponse:
         if self.is_causal_lm:
             analyses = self.process_causal_lm(text)
         elif self.is_sequence_classifier:
-            analyses = self.process_sequence_classification(text)
+            analyses = self.process_sequence_classification(
+                text, detector_params=getattr(input, "detector_params", None)
+            )
         else:
             raise ValueError("Unsupported model type for analysis.")
         contents_analyses.append(analyses)
         return contents_analyses
 
-
     def close(self) -> None:
         """Clean up model and tokenizer resources."""
-
         if self.model:
             if hasattr(self.model, 'to') and hasattr(self.model, 'device') and self.model.device.type != "cpu":
                 self.model = self.model.to(torch.device("cpu"))
@@ -296,4 +298,4 @@ def close(self) -> None:
         gc.collect()
 
         if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+            torch.cuda.empty_cache()

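And a sketch of the per-label filtering that `process_sequence_classification` now performs, which is what changes the response format for multilabel classifiers: instead of a single argmax prediction, one detection is emitted per label that clears the threshold and is not marked safe (by index or by name). The `id2label` mapping, probabilities, and request parameters below are hypothetical:

```python
# Hypothetical classifier head; the service derives probabilities via softmax over logits.
id2label = {0: "neutral", 1: "toxic", 2: "insult"}
probabilities = [0.10, 0.85, 0.62]

env_safe_labels = [0]  # default from SAFE_LABELS
detector_params = {"threshold": 0.5, "safe_labels": ["neutral"]}  # hypothetical per-request params

threshold = detector_params.get("threshold", 0.5)
all_safe_labels = set(env_safe_labels) | set(detector_params.get("safe_labels", []))

# Mirror of the loop in the diff: keep labels above threshold that are
# excluded by neither index nor name.
detections = [
    {"detection_type": id2label[idx], "score": prob}
    for idx, prob in enumerate(probabilities)
    if prob >= threshold
    and idx not in all_safe_labels
    and id2label[idx] not in all_safe_labels
]
print(detections)
# [{'detection_type': 'toxic', 'score': 0.85}, {'detection_type': 'insult', 'score': 0.62}]
```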
detectors/huggingface/scheme.py

Lines changed: 0 additions & 67 deletions
This file was deleted.
