feat: Add xgb detectors #42

Draft · wants to merge 2 commits into main
32 changes: 32 additions & 0 deletions detectors/Dockerfile.xgb
@@ -0,0 +1,32 @@
FROM registry.access.redhat.com/ubi9/ubi-minimal as base
RUN microdnf update -y && \
    microdnf install -y --nodocs \
        python-pip python-devel && \
    pip install --upgrade --no-cache-dir pip wheel && \
    microdnf clean all
RUN pip install --no-cache-dir torch

# FROM icr.io/fm-stack/ubi9-minimal-py39-torch as builder
FROM base as builder

COPY ./common/requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY ./xgb/requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

FROM builder


WORKDIR /app
ARG CACHEBUST=1
RUN echo "$CACHEBUST"
COPY xgb/build/model_artifacts /app/model_artifacts
COPY ./common /common

COPY ./xgb/build/scheme.py /app
COPY ./xgb/build/app.py /app
COPY ./xgb/build/detector.py /app

EXPOSE 8000
CMD ["uvicorn", "app:app", "--workers", "4", "--host", "0.0.0.0", "--port", "8000", "--log-config", "/common/log_conf.yaml"]
27 changes: 27 additions & 0 deletions detectors/xgb/README.md
@@ -0,0 +1,27 @@
# XGB Classification Detector

## Setup
1. Train the XGB model and save the trained model artifacts
```
cd guardrails-detectors/detectors/xgb/build
make all
```

2. Build the detector image
```
cd guardrails-detectors
podman build --file=Dockerfile.xgb -t xgb_detector:latest .
```

## Detector API
### `/api/v1/text/contents`
* Runs the spam classifier over every string in the request's `contents` list and returns one list of analyses per input text, each with `detection_type` set to `spamCheck` and a boolean `detection` flag.
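A minimal example exchange, assuming the server is reachable on port 8000 and that `contents` is the only required request field (the `detector-id` header value is arbitrary):
```
POST /api/v1/text/contents
detector-id: xgb-spam-detector

{"contents": ["You have won a free prize, click here!"]}
```
Example response (field values are illustrative):
```
[[{"start": 0, "end": 38, "text": "You have won a free prize, click here!", "detection": true, "detection_type": "spamCheck", "evidences": []}]]
```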

## Testing Locally
```
podman run -p 8001:8000 --platform=linux/amd64 quay.io/christinaexyou/xgb_detector:latest
```

Wait for the server to start before sending requests.
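For example, a quick smoke test against the mapped port (the `detector-id` value is arbitrary; the message is a sample spam-like text):
```
curl -s -X POST http://localhost:8001/api/v1/text/contents \
  -H "Content-Type: application/json" \
  -H "detector-id: xgb-spam-detector" \
  -d '{"contents": ["Congratulations! You have won a $1000 gift card."]}'
```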
4 changes: 4 additions & 0 deletions detectors/xgb/build/Makefile
@@ -0,0 +1,4 @@
train_pipeline:
	python3 train.py

all: train_pipeline
50 changes: 50 additions & 0 deletions detectors/xgb/build/app.py
@@ -0,0 +1,50 @@
import os
import sys
from contextlib import asynccontextmanager
from typing import Annotated

from fastapi import Header
from prometheus_fastapi_instrumentator import Instrumentator

sys.path.insert(0, os.path.abspath(".."))

from common.app import DetectorBaseAPI as FastAPI
from detector import Detector
from detectors.common.scheme import (
    ContentAnalysisHttpRequest,
    ContentsAnalysisResponse,
    Error,
)


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Load the model once at startup and register it on the app.
    app.set_detector(Detector())
    yield
    # Clean up the ML model and release its resources on shutdown.
    detector: Detector = app.get_detector()
    if detector and hasattr(detector, "close"):
        detector.close()
    app.cleanup_detector()


app = FastAPI(lifespan=lifespan, dependencies=[])
Instrumentator().instrument(app).expose(app)


@app.post(
    "/api/v1/text/contents",
    response_model=ContentsAnalysisResponse,
    description="""Detectors that work on content text, be it user prompt or generated text. \
Generally classification type detectors qualify for this. <br>""",
    responses={
        404: {"model": Error, "description": "Resource Not Found"},
        422: {"model": Error, "description": "Validation Error"},
    },
)
async def detector_unary_handler(
    request: ContentAnalysisHttpRequest,
    detector_id: Annotated[str, Header(example="xgb-spam-detector")],
):
    # Use the detector instance registered during startup (see lifespan above).
    return ContentsAnalysisResponse(root=app.get_detector().run(request))
63 changes: 63 additions & 0 deletions detectors/xgb/build/detector.py
@@ -0,0 +1,63 @@
import os
import sys

sys.path.insert(0, os.path.abspath(".."))
import pathlib
import pickle as pkl

import torch
import xgboost as xgb  # noqa: F401 - ensures xgboost is importable before unpickling the model
from detectors.common.scheme import (
    ContentAnalysisHttpRequest,
    ContentAnalysisResponse,
)

try:
    from common.app import logger
except ImportError:
    sys.path.insert(0, os.path.join(pathlib.Path(__file__).parent.parent.resolve()))
    from common.app import logger


class Detector:
    def __init__(self):
        # Locate the model artifacts: /app/model_artifacts inside the image,
        # build/model_artifacts when running from the repository.
        model_files_path = os.path.abspath(os.path.join(os.sep, "app", "model_artifacts"))
        if not os.path.exists(model_files_path):
            model_files_path = os.path.join("build", "model_artifacts")
        logger.info(model_files_path)

        with open(os.path.join(model_files_path, "model.pkl"), "rb") as f:
            self.model = pkl.load(f)
        with open(os.path.join(model_files_path, "vectorizer.pkl"), "rb") as f:
            self.vectorizer = pkl.load(f)

        if torch.cuda.is_available():
            self.cuda_device = torch.device("cuda")
            torch.cuda.empty_cache()
            # The XGBoost model is not a torch module, so select the GPU through
            # XGBoost's own device parameter rather than .to().
            self.model.set_params(device="cuda")
            logger.info("cuda_device".upper() + " " + str(self.cuda_device))
            self.batch_size = 1
        else:
            self.batch_size = 8
        logger.info("Detector initialized.")

    def run(self, request: ContentAnalysisHttpRequest) -> list[list[ContentAnalysisResponse]]:
        if hasattr(request, "detection_type") and request.detection_type != "spamCheck":
            logger.warning(f"Unsupported detection type: {request.detection_type}")

        content_analyses = []
        for batch_start in range(0, len(request.contents), self.batch_size):
            batch = request.contents[batch_start:batch_start + self.batch_size]
            vectorized_batch = self.vectorizer.transform(batch)
            predictions = self.model.predict(vectorized_batch)
            # Emit one list of analyses per input text, as the contents API expects,
            # rather than a single result per batch.
            for text, prediction in zip(batch, predictions):
                content_analyses.append(
                    [
                        ContentAnalysisResponse(
                            start=0,
                            end=len(text),
                            detection=bool(prediction == 1),
                            detection_type="spamCheck",
                            text=text,
                            evidences=[],
                        )
                    ]
                )
        return content_analyses
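A minimal sketch for sanity-checking the Detector class outside the container, assuming it is run from `detectors/xgb/build` after `make all` has produced `model_artifacts/`, and that `contents` is the only required field of `ContentAnalysisHttpRequest`:
```
from detector import Detector
from detectors.common.scheme import ContentAnalysisHttpRequest

detector = Detector()
request = ContentAnalysisHttpRequest(contents=["You have won a free cruise, reply now!"])
for analyses in detector.run(request):  # one list of analyses per input text
    for analysis in analyses:
        print(analysis.detection_type, analysis.detection)
```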
80 changes: 80 additions & 0 deletions detectors/xgb/build/train.py
@@ -0,0 +1,80 @@
import argparse
import os
import pathlib
import pickle
import re

import nltk
import pandas as pd
import xgboost as xgb
from datasets import load_dataset
from nltk.corpus import stopwords
from nltk.stem import PorterStemmer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import GridSearchCV


def load_data(dataset_name, **dataset_kwargs):
    return load_dataset(dataset_name, **dataset_kwargs)


def generate_training_df(data):
    df = pd.DataFrame(data).rename(columns={"sms": "text"})
    return df


def preprocess_text(X):
    # Stem each token, drop stopwords and non-alphabetic characters, and lowercase the result.
    stemmer = PorterStemmer()
    stop_words = stopwords.words('english')
    X['text'] = X['text'].apply(
        lambda x: " ".join(
            [stemmer.stem(i) for i in re.sub("[^a-zA-Z]", " ", x).split() if i not in stop_words]
        ).lower()
    )
    return X


# ==================================================================================================
# === MAIN =========================================================================================
# ==================================================================================================
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='sms_spam')
    parser.add_argument('--hf_token', type=str, default=os.getenv('HF_TOKEN', ''))

    args = parser.parse_args()
    artifact_path = os.path.join(pathlib.Path(__file__).parent.resolve(), "model_artifacts")
    os.makedirs(artifact_path, exist_ok=True)

    # The NLTK stopword list is a separate download from the nltk package itself.
    nltk.download('stopwords', quiet=True)

    if args.dataset.lower() == 'sms_spam':
        print("Loading SMS spam dataset...")
        data = load_data("ucirvine/sms_spam", token=args.hf_token, split="train")
        train_df = generate_training_df(data)

        print("Preprocessing data...")
        X = train_df.drop(columns=['label'])
        X = preprocess_text(X)
        vectorizer = TfidfVectorizer()
        X_vec = vectorizer.fit_transform(X['text'])

        y = train_df['label']

        print("Training XGBoost model...")
        param_grid = {
            'max_depth': [3, 5, 7],
            'learning_rate': [0.1, 0.01, 0.001],
            'subsample': [0.5, 0.7, 1]
        }
        grid_search = GridSearchCV(
            xgb.XGBClassifier(random_state=42),
            param_grid,
            cv=5,
            scoring='accuracy'
        )
        grid_search.fit(X_vec, y)
        # Refit a fresh classifier on the full training set with the best hyperparameters.
        clf = xgb.XGBClassifier(
            max_depth=grid_search.best_params_['max_depth'],
            learning_rate=grid_search.best_params_['learning_rate'],
            subsample=grid_search.best_params_['subsample'],
            random_state=42
        )
        clf.fit(X_vec, y)

        print(f"Saving training artifacts to {artifact_path}...")
        with open(f'{artifact_path}/vectorizer.pkl', 'wb') as f:
            pickle.dump(vectorizer, f)
        with open(f'{artifact_path}/model.pkl', 'wb') as f:
            pickle.dump(clf, f)

    else:
        raise NotImplementedError(f"Dataset {args.dataset} not yet supported")
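A short sketch of reloading the saved artifacts for a single prediction, assuming `train.py` has already been run in this directory; it mirrors the stemming step from `preprocess_text` (stopword filtering is omitted for brevity) and the sample message is illustrative:
```
import pickle
import re

from nltk.stem import PorterStemmer

# Load the artifacts written by train.py.
with open("model_artifacts/vectorizer.pkl", "rb") as f:
    vectorizer = pickle.load(f)
with open("model_artifacts/model.pkl", "rb") as f:
    model = pickle.load(f)

# Apply the same stemming as training so the text matches the vectorizer's vocabulary.
stemmer = PorterStemmer()
message = "Win a free iPhone by clicking this link"
cleaned = " ".join(stemmer.stem(w) for w in re.sub("[^a-zA-Z]", " ", message).split()).lower()

print(model.predict(vectorizer.transform([cleaned])))  # label 1 corresponds to spam in sms_spam
```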
7 changes: 7 additions & 0 deletions detectors/xgb/requirements.txt
@@ -0,0 +1,7 @@
xgboost
torch==2.4.0
pandas==2.2.2
numpy==1.26.4
datasets
nltk==3.9.1
scikit-learn
29 changes: 29 additions & 0 deletions detectors/xgb/test_xgb.py
@@ -0,0 +1,29 @@
import pytest
from fastapi.testclient import TestClient


class TestXGBDetectors:
    @pytest.fixture
    def client(self):
        from detectors.xgb.build.app import app
        from detectors.xgb.build.detector import Detector

        # TestClient does not run the lifespan here, so register the detector manually.
        app.set_detector(Detector())
        return TestClient(app)

    @pytest.mark.parametrize(
        "content,expected",
        [
            (["Congratulations! You've won a $1000 Walmart gift card. Click here to claim now."], True),
            (["Don't forget to bring your notebook to class tomorrow."], False),
        ]
    )
    def test_xgb_detectors(self, client, content, expected):
        # "content" is already a list of texts, so it maps directly onto "contents".
        payload = {
            "contents": content,
        }
        resp = client.post("/api/v1/text/contents", json=payload)
        assert resp.status_code == 200
        assert len(resp.json()[0]) > 0
        assert resp.json()[0][0]["detection"] == expected

29 changes: 29 additions & 0 deletions tests/detectors/xgb/test_xgb.py
@@ -0,0 +1,29 @@
import pytest
from fastapi.testclient import TestClient


class TestXGBDetectors:
    @pytest.fixture
    def client(self):
        from detectors.xgb.build.app import app
        from detectors.xgb.build.detector import Detector

        # TestClient does not run the lifespan here, so register the detector manually.
        app.set_detector(Detector())
        return TestClient(app)

    @pytest.mark.parametrize(
        "content,expected",
        [
            (["Congratulations! You've won a $1000 Walmart gift card. Click here to claim now."], True),
            (["Don't forget to bring your notebook to class tomorrow."], False),
        ]
    )
    def test_xgb_detectors(self, client, content, expected):
        # "content" is already a list of texts, so it maps directly onto "contents".
        payload = {
            "contents": content,
        }
        resp = client.post("/api/v1/text/contents", json=payload)
        assert resp.status_code == 200
        assert len(resp.json()[0]) > 0
        assert resp.json()[0][0]["detection"] == expected
