Commit 99cae02

🎉 initial commit of the pii detector based on a hf sequence classification model
1 parent 0bdaca8 commit 99cae02

File tree: 7 files changed, +270 −0 lines changed

detectors/Dockerfile.pii-transformer

Lines changed: 32 additions & 0 deletions
FROM registry.access.redhat.com/ubi9/ubi-minimal as base
RUN microdnf update -y && \
    microdnf install -y --nodocs \
        python-pip python-devel && \
    pip install --upgrade --no-cache-dir pip wheel && \
    microdnf clean all
RUN pip install --no-cache-dir torch

# FROM icr.io/fm-stack/ubi9-minimal-py39-torch as builder
FROM base as builder

COPY ./common/requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY ./pii_transformer/requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

FROM builder

WORKDIR /app

COPY ./common /common
COPY ./pii_transformer/app.py /app
COPY ./pii_transformer/detector.py /app
COPY ./pii_transformer/scheme.py /app

ENV PII_MODEL_PATH "h2oai/deberta_finetuned_pii"

EXPOSE 8000
CMD ["uvicorn", "app:app", "--workers", "4", "--host", "0.0.0.0", "--port", "8000", "--log-config", "/common/log_conf.yaml"]

# gunicorn main:app --workers 4 --worker-class uvicorn.workers.UvicornWorker --bind 0.0.0.0:8000
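With detectors/ as the build context (the Dockerfile copies ./common and ./pii_transformer relative to it), a plausible build-and-run sequence, with an illustrative image name, is: docker build -f detectors/Dockerfile.pii-transformer -t pii-detector detectors, followed by docker run -p 8000:8000 pii-detector.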

detectors/pii_transformer/__init__.py

Whitespace-only changes.

detectors/pii_transformer/app.py

Lines changed: 46 additions & 0 deletions
import os
import sys
from contextlib import asynccontextmanager
from typing import Annotated

from fastapi import Header

sys.path.insert(0, os.path.abspath(".."))

from common.app import DetectorBaseAPI as FastAPI
from detector import Detector
from scheme import (
    ContentAnalysisHttpRequest,
    ContentsAnalysisResponse,
    Error,
)

detector_objects = {}


@asynccontextmanager
async def lifespan(app: FastAPI):
    detector_objects["detector"] = Detector()
    yield
    # Clean up the ML models and release the resources
    detector_objects.clear()


app = FastAPI(lifespan=lifespan, dependencies=[])


@app.post(
    "/api/v1/text/contents",
    response_model=ContentsAnalysisResponse,
    description="""Detectors that work on content text, be it user prompt or generated text. \
Generally, classification-type detectors qualify for this. <br>""",
    responses={
        404: {"model": Error, "description": "Resource Not Found"},
        422: {"model": Error, "description": "Validation Error"},
    },
)
async def detector_unary_handler(
    request: ContentAnalysisHttpRequest,
    detector_id: Annotated[str, Header(example="en_syntax_slate.38m.hap")],
):
    return ContentsAnalysisResponse(root=detector_objects["detector"].run(request))
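For reference, a minimal client call against this endpoint could look like the sketch below; the host and port follow the EXPOSE/CMD in the Dockerfile, and the detector-id header mirrors the locust test further down, so both are assumptions rather than anything fixed by this commit.

import requests

# Sketch of a unary detection call. Host/port and the detector-id value are
# illustrative; the payload shape matches ContentAnalysisHttpRequest.
resp = requests.post(
    "http://localhost:8000/api/v1/text/contents",
    json={
        "contents": [
            "My name is John Doe and my social security number is 123-45-6789."
        ]
    },
    headers={"detector-id": "pii", "Content-Type": "application/json"},
)
resp.raise_for_status()
# The response body is one list of analyses per input text (List[List[...]]).
for analyses in resp.json():
    for analysis in analyses:
        print(analysis["detection"], analysis["pii_check"])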

detectors/pii_transformer/detector.py

Lines changed: 92 additions & 0 deletions
import os
import sys

sys.path.insert(0, os.path.abspath(".."))
# from common.scheme import TextDetectionHttpRequest, TextDetectionResponse

import torch
from common.app import logger
from scheme import (
    ContentAnalysisHttpRequest,
    ContentAnalysisResponse,
    ContentsAnalysisResponse,
)

# Detector imports
from transformers import AutoTokenizer, AutoModelForTokenClassification


class Detector:
    def __init__(self):
        # Initialize the detector
        model_files_path = os.environ.get("PII_MODEL_PATH")
        logger.info(model_files_path)
        # The tokenizer processes the data on the CPU
        self.tokenizer = AutoTokenizer.from_pretrained(model_files_path, use_fast=True)
        self.model = AutoModelForTokenClassification.from_pretrained(
            pretrained_model_name_or_path=model_files_path,
        )

        logger.info("TORCH.CUDA " + str(torch.cuda.is_available()))

        self.cuda_device = None

        if torch.cuda.is_available():
            # Transparently take a CUDA GPU for this actor
            self.cuda_device = torch.device("cuda")
            torch.cuda.empty_cache()
            self.model.to(self.cuda_device)
            # self.tokenizer.to(self.cuda_device)
            # AttributeError: 'RobertaTokenizerFast' object has no attribute 'to'
            os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
        logger.info("CUDA_DEVICE " + str(self.cuda_device))

    def run(self, input: ContentAnalysisHttpRequest) -> ContentsAnalysisResponse:
        # Run the token classification for each entry in the contents array
        contents_analyses = []
        for text in input.contents:
            content_analyses = []
            tokenized = self.tokenizer(
                text,
                # max_length counts tokens, not characters; cap it at the
                # model's limit so long inputs cannot overrun the position
                # embeddings.
                max_length=self.tokenizer.model_max_length,
                return_tensors="pt",
                truncation=True,
                padding=True,
            )
            if self.cuda_device:
                logger.info("adding tokenized to CUDA")
                # If we are using a GPU, the tokens need to be there too.
                tokenized = tokenized.to(self.cuda_device)

            # A BatchEncoding includes 'data', 'encodings', 'is_fast', and 'n_sequences'.
            with torch.no_grad():
                model_out = self.model(**tokenized)
            logits = model_out.logits
            # Take the class with the highest score for each token and use the
            # model's id2label mapping to convert it to a list of text labels.
            predictions = torch.argmax(logits, dim=2)
            predicted_token_class = [
                self.model.config.id2label[p] for p in predictions[0].tolist()
            ]
            # If the predicted token classes contain any element other than 'O',
            # the text contains PII.
            pii_indicator = any(p != "O" for p in predicted_token_class)

            # # A List[float] seems like a sensible way to return this
            # if hap_score >= input.parameters["threshold"]:
            content_analyses.append(
                ContentAnalysisResponse(
                    start=0,
                    end=len(text),
                    detection="has_pii",
                    detection_type="pii",
                    pii_check=pii_indicator,
                    text=text,
                    predicted_token_class=predicted_token_class,
                    evidences=[],
                )
            )
            contents_analyses.append(content_analyses)

        return contents_analyses
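Note that each ContentAnalysisResponse currently spans the whole input (start=0, end=len(text)). If per-entity character spans are wanted later, the fast tokenizer can already supply them via its offset mapping; a minimal sketch, assuming the same tokenizer and model loaded above (the helper name is hypothetical and not part of this commit):

import torch

def pii_spans(tokenizer, model, text):
    # Recover (start, end, label) character spans for non-'O' tokens using the
    # fast tokenizer's offset mapping. Sketch only.
    enc = tokenizer(
        text, return_tensors="pt", truncation=True, return_offsets_mapping=True
    )
    offsets = enc.pop("offset_mapping")[0].tolist()
    with torch.no_grad():
        logits = model(**enc).logits
    labels = [model.config.id2label[p] for p in logits.argmax(dim=2)[0].tolist()]
    # Special tokens carry empty (0, 0) spans; skip them along with 'O' labels.
    return [
        (start, end, label)
        for (start, end), label in zip(offsets, labels)
        if label != "O" and start != end
    ]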
Lines changed: 35 additions & 0 deletions
"""
Content Warning: Contains potentially offensive text dealing with racism, misogyny, and violence. Examples of input prompts are provided purely for the purposes of testing HAP (Hate, Abuse and Profanity) models.
"""

from locust import HttpUser, between, task


class WebsiteUser(HttpUser):
    wait_time = between(1, 5)

    # def on_start(self):
    #     self.client.post("/login", {
    #         "username": "test_user",
    #         "password": ""
    #     })

    @task
    def docs(self):
        self.client.get("/docs")

    @task
    def api(self):
        self.client.get("/openapi.json")

    @task
    def pii(self):
        self.client.post(
            "/api/v1/text/contents?pii_transformer",
            json={
                "contents": [
                    "My name is John Doe and my social security number is 123-45-6789."
                ]
            },
            headers={"detector-id": "pii", "Content-Type": "application/json"},
        )
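A typical invocation, assuming this file is saved as locustfile.py (the diff omits the file name), would be: locust -f locustfile.py --host http://localhost:8000.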
detectors/pii_transformer/requirements.txt

Lines changed: 1 addition & 0 deletions

transformers==4.43.4

detectors/pii_transformer/scheme.py

Lines changed: 64 additions & 0 deletions
from enum import Enum
from typing import List, Optional

from pydantic import BaseModel, Field, RootModel


class Evidence(BaseModel):
    source: str = Field(
        title="Source",
        example="https://en.wikipedia.org/wiki/IBM",
        description="Source of the evidence; it can be, e.g., the URL of the evidence",
    )


class EvidenceType(str, Enum):
    url = "url"
    title = "title"


class EvidenceObj(BaseModel):
    type: EvidenceType = Field(
        title="EvidenceType",
        example="url",
        description="Type field signifying the type of evidence provided, e.g. url or title",
    )
    evidence: Evidence = Field(
        description="Evidence object; currently only contains source, but may carry other optional fields such as an id in the future",
    )


class ContentAnalysisHttpRequest(BaseModel):
    contents: List[str] = Field(
        min_length=1,
        title="Contents",
        description="Field allowing users to provide a list of texts for analysis. Note: the results of this endpoint contain the analysis / detection of each provided text, in the order they appear in the contents field.",
        example=[
            "Martians are like crocodiles; the more you give them meat, the more they want"
        ],
    )


class ContentAnalysisResponse(BaseModel):
    start: int = Field(example=14)
    end: int = Field(example=26)
    detection: str = Field(example="has_pii")
    detection_type: str = Field(example="pii")
    pii_check: bool = Field(example=True)
    text: str = Field(example="My favourite dish is pierogi")
    predicted_token_class: List[str] = Field(examples=["O", "O", "O", "O", "O"])
    evidences: Optional[List[EvidenceObj]] = Field(
        description="Optional field providing evidence for the detection",
        default=[],
    )


class ContentsAnalysisResponse(RootModel):
    root: List[List[ContentAnalysisResponse]] = Field(
        title="Response Text Content Analysis Unary Handler Api V1 Text Content Post"
    )


class Error(BaseModel):
    code: int
    message: str
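To illustrate how these models nest, here is a small sketch that builds and serializes the double-nested response the handler returns; the field values are made up, not taken from a real model run.

# Sketch: one inner list per input text, mirroring the order of `contents`.
analysis = ContentAnalysisResponse(
    start=0,
    end=65,  # end = len(text), as detector.run sets it
    detection="has_pii",
    detection_type="pii",
    pii_check=True,
    text="My name is John Doe and my social security number is 123-45-6789.",
    predicted_token_class=["O", "O", "O", "B-NAME", "I-NAME", "O"],  # labels illustrative
    evidences=[],
)
response = ContentsAnalysisResponse(root=[[analysis]])
print(response.model_dump_json(indent=2))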
