From 31bf89401d5512dc98bb28e845a545d248031149 Mon Sep 17 00:00:00 2001
From: Christina Xu <77992688+christinaexyou@users.noreply.github.com>
Date: Tue, 22 Jul 2025 13:11:20 -0400
Subject: [PATCH 1/2] feat: Add XGB detectors
---
detectors/xgb/README.md | 27 +++++++++++
detectors/xgb/build/Makefile | 4 ++
detectors/xgb/build/app.py | 52 +++++++++++++++++++++
detectors/xgb/build/detector.py | 63 ++++++++++++++++++++++++++
detectors/xgb/build/train.py | 80 +++++++++++++++++++++++++++++++++
detectors/xgb/requirements.txt | 7 +++
detectors/xgb/test_xgb.py | 29 ++++++++++++
7 files changed, 262 insertions(+)
create mode 100644 detectors/xgb/README.md
create mode 100644 detectors/xgb/build/Makefile
create mode 100644 detectors/xgb/build/app.py
create mode 100644 detectors/xgb/build/detector.py
create mode 100644 detectors/xgb/build/train.py
create mode 100644 detectors/xgb/requirements.txt
create mode 100644 detectors/xgb/test_xgb.py
diff --git a/detectors/xgb/README.md b/detectors/xgb/README.md
new file mode 100644
index 0000000..c0c284a
--- /dev/null
+++ b/detectors/xgb/README.md
@@ -0,0 +1,27 @@
+# XGB Classification Detector
+
+## Setup
+1. Train the XGB model and save the trained model artifacts
+ ```
+ cd guardrails-detectors/detectors/xgb/build
+ make all
+ ```
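+
+ `make all` runs `train.py`, which downloads the `ucirvine/sms_spam` dataset, fits a TF-IDF vectorizer plus an XGBoost classifier, and writes `model.pkl` and `vectorizer.pkl` to `build/model_artifacts/`. The trainer can also be invoked directly (`--hf_token` falls back to the `HF_TOKEN` environment variable):
+ ```
+ python3 train.py --dataset sms_spam --hf_token "$HF_TOKEN"
+ ```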
+
+2. Build the image (the Dockerfile's `COPY` paths are relative to the `detectors/` directory, so build from there and pass it as the context)
+ ```
+ cd guardrails-detectors/detectors
+ podman build --file=Dockerfile.xgb -t xgb_detector:latest .
+ ```
+
+## Detector API
+### `/api/v1/text/contents`
+* Accepts a list of text contents and returns, for each one, whether the XGBoost classifier flags it as spam (`detection_type: spamCheck`).
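+
+A request/response sketch (field names follow `ContentAnalysisHttpRequest` and `ContentAnalysisResponse` in `detectors/common/scheme.py`; the exact values here are illustrative):
+
+Request body:
+```
+{"contents": ["Congratulations! You've won a $1000 Walmart gift card."]}
+```
+
+Response body (one analysis per content item):
+```
+[{"start": 0, "end": 54, "detection": true, "detection_type": "spamCheck", "text": "Congratulations! You've won a $1000 Walmart gift card.", "evidences": []}]
+```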
+
+## Testing Locally
+Run the locally built image (the API listens on port 8000 inside the container):
+```
+podman run -p 8001:8000 --platform=linux/amd64 xgb_detector:latest
+```
+
+Wait for the server to start:
+```
+INFO:     Application startup complete.
+INFO:     Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
+```
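+
+Then send a request (a sketch; port 8001 matches the `-p 8001:8000` mapping above, and the `detector-id` header is required by the endpoint but its value is not checked):
+```
+curl -s -X POST http://localhost:8001/api/v1/text/contents \
+  -H "Content-Type: application/json" \
+  -H "detector-id: xgb" \
+  -d '{"contents": ["Congratulations! You won a free cruise!"]}'
+```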
\ No newline at end of file
diff --git a/detectors/xgb/build/Makefile b/detectors/xgb/build/Makefile
new file mode 100644
index 0000000..9dc2877
--- /dev/null
+++ b/detectors/xgb/build/Makefile
@@ -0,0 +1,4 @@
+train_pipeline:
+	python3 train.py
+
+all: train_pipeline
\ No newline at end of file
diff --git a/detectors/xgb/build/app.py b/detectors/xgb/build/app.py
new file mode 100644
index 0000000..df3b84b
--- /dev/null
+++ b/detectors/xgb/build/app.py
@@ -0,0 +1,52 @@
+import os
+import sys
+from contextlib import asynccontextmanager
+from typing import Annotated
+
+from fastapi import Header
+from prometheus_fastapi_instrumentator import Instrumentator
+
+sys.path.insert(0, os.path.abspath(".."))
+
+from detector import Detector
+
+from detectors.common.app import DetectorBaseAPI as FastAPI
+from detectors.common.scheme import (
+ ContentAnalysisHttpRequest,
+ ContentsAnalysisResponse,
+ Error,
+)
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+ app.set_detector(Detector())
+ yield
+ # Clean up the ML models and release the resources
+ detector: Detector = app.get_detector()
+ if detector and hasattr(detector, 'close'):
+ detector.close()
+ app.cleanup_detector()
+
+
+app = FastAPI(lifespan=lifespan, dependencies=[])
+Instrumentator().instrument(app).expose(app)
+
+
+@app.post(
+ "/api/v1/text/contents",
+ response_model=ContentsAnalysisResponse,
+    description="""Detectors that work on content text, be it user prompt or generated text. \
+    Generally, classification-type detectors qualify for this.
+    """,
+ responses={
+ 404: {"model": Error, "description": "Resource Not Found"},
+ 422: {"model": Error, "description": "Validation Error"},
+ },
+)
+async def detector_unary_handler(
+    request: ContentAnalysisHttpRequest,
+    detector_id: Annotated[str, Header(example="en_syntax_slate.38m.hap")],
+):
+    # The detector instance is stored on the app during lifespan startup
+    detector: Detector = app.get_detector()
+    if not detector:
+        raise RuntimeError("Detector is not initialized")
+    return ContentsAnalysisResponse(root=detector.run(request))
\ No newline at end of file
diff --git a/detectors/xgb/build/detector.py b/detectors/xgb/build/detector.py
new file mode 100644
index 0000000..9a3aa23
--- /dev/null
+++ b/detectors/xgb/build/detector.py
@@ -0,0 +1,63 @@
+import os
+import sys
+
+sys.path.insert(0, os.path.abspath(".."))
+import pathlib
+import torch
+import xgboost as xgb
+from detectors.common.scheme import (
+ ContentAnalysisHttpRequest,
+ ContentAnalysisResponse,
+ ContentAnalysisHttpResponse,
+)
+import pickle as pkl
+
+try:
+ from common.app import logger
+except ImportError:
+ sys.path.insert(0, os.path.join(pathlib.Path(__file__).parent.parent.resolve()))
+ from common.app import logger
+
+class Detector:
+ def __init__(self):
+ # initialize the detector
+ model_files_path = os.path.abspath(s.path.join(os.sep, "app", "model_artifacts"))
+ if not os.path.exists(model_files_path):
+ model_files_path = os.path.join("build", "model_artifacts")
+ logger.info(model_files_path)
+
+        with open(os.path.join(model_files_path, 'model.pkl'), 'rb') as f:
+            self.model = pkl.load(f)
+        with open(os.path.join(model_files_path, 'vectorizer.pkl'), 'rb') as f:
+            self.vectorizer = pkl.load(f)
+
+        if torch.cuda.is_available():
+            self.cuda_device = torch.device("cuda")
+            torch.cuda.empty_cache()
+            # XGBoost models are not torch modules, so select the GPU via
+            # XGBoost's `device` parameter rather than torch's `.to()`
+            self.model.set_params(device="cuda")
+            logger.info("cuda_device".upper() + " " + str(self.cuda_device))
+            self.batch_size = 1
+ else:
+ self.batch_size = 8
+ logger.info("Detector initialized.")
+
+ def run(self, request: ContentAnalysisHttpRequest) -> ContentsAnalysisHttpResponse:
+ if hasattr(request, "detection_type") and request.detection_type != "spamCheck":
+ logger.warning(f"Unsupported detection type: {request.detection_type}")
+
+        content_analyses = []
+        for batch_idx in range(0, len(request.contents), self.batch_size):
+            batch = request.contents[batch_idx:batch_idx + self.batch_size]
+            vectorized_batch = self.vectorizer.transform(batch)
+            predictions = self.model.predict(vectorized_batch)
+            # Emit one analysis per content item rather than one per batch
+            for text, prediction in zip(batch, predictions):
+                content_analyses.append(
+                    ContentAnalysisResponse(
+                        start=0,
+                        end=len(text),
+                        detection=bool(prediction == 1),
+                        detection_type="spamCheck",
+                        text=text,
+                        evidences=[],
+                    )
+                )
+        return content_analyses
diff --git a/detectors/xgb/build/train.py b/detectors/xgb/build/train.py
new file mode 100644
index 0000000..d09f494
--- /dev/null
+++ b/detectors/xgb/build/train.py
@@ -0,0 +1,80 @@
+import argparse
+import os
+import pathlib
+import pickle
+import re
+
+from datasets import load_dataset
+import pandas as pd
+import xgboost as xgb
+import nltk
+from nltk.corpus import stopwords
+from nltk.stem import PorterStemmer
+from sklearn.feature_extraction.text import TfidfVectorizer
+from sklearn.model_selection import GridSearchCV
+
+# stopwords.words() raises a LookupError unless the corpus has been downloaded
+nltk.download('stopwords', quiet=True)
+
+
+def load_data(dataset_name, **dataset_kwargs):
+ return load_dataset(dataset_name, **dataset_kwargs)
+
+def generate_training_df(data):
+ df = pd.DataFrame(data).rename(columns={"sms": "text"})
+ return df
+
+def preprocess_text(X):
+    """Lowercase, strip non-letters, drop stopwords, and stem each text."""
+    stemmer = PorterStemmer()
+    stop_words = set(stopwords.words('english'))
+    # Lowercase before filtering so capitalized stopwords are also removed
+    X['text'] = X['text'].apply(
+        lambda x: " ".join(
+            stemmer.stem(token)
+            for token in re.sub("[^a-zA-Z]", " ", x.lower()).split()
+            if token not in stop_words
+        )
+    )
+    return X
+
+# ==================================================================================================
+# === MAIN =========================================================================================
+# ==================================================================================================
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--dataset', type=str, default='sms_spam')
+    parser.add_argument('--hf_token', type=str, default=os.getenv('HF_TOKEN'))
+
+ args = parser.parse_args()
+ artifact_path = os.path.join(pathlib.Path(__file__).parent.resolve(), "model_artifacts")
+ os.makedirs(artifact_path, exist_ok=True)
+
+ if args.dataset.lower() == 'sms_spam':
+ print("Loading SMS spam dataset...")
+ data = load_data("ucirvine/sms_spam", token=args.hf_token, split="train")
+ train_df = generate_training_df(data)
+
+ print("Preprocessing data...")
+ X = train_df.drop(columns=['label'])
+ X = preprocess_text(X)
+ vectorizer = TfidfVectorizer()
+ X_vec = vectorizer.fit_transform(X['text'])
+
+ y = train_df['label']
+
+ print("Training XGBoost model...")
+ param_grid = {
+ 'max_depth': [3, 5, 7],
+ 'learning_rate': [0.1, 0.01, 0.001],
+ 'subsample': [0.5, 0.7, 1]
+ }
+ grid_search = GridSearchCV(
+ xgb.XGBClassifier(random_state=42),
+ param_grid,
+ cv=5,
+ scoring='accuracy'
+ )
+ grid_search.fit(X_vec, y)
+ clf = xgb.XGBClassifier(
+ max_depth=grid_search.best_params_['max_depth'],
+ learning_rate=grid_search.best_params_['learning_rate'],
+ subsample=grid_search.best_params_['subsample'],
+ random_state=42
+ )
+ clf.fit(X, y)
+
+ print(f"Saving training artifacts to {artifact_path}...")
+ pickle.dump(vectorizer, open(f'{artifact_path}/vectorizer.pkl', 'wb'))
+ pickle.dump(clf, open(f'{artifact_path}/model.pkl', 'wb'))
+
+ else:
+ raise NotImplementedError(f"Dataset {args.dataset} not yet supported")
diff --git a/detectors/xgb/requirements.txt b/detectors/xgb/requirements.txt
new file mode 100644
index 0000000..eef693c
--- /dev/null
+++ b/detectors/xgb/requirements.txt
@@ -0,0 +1,7 @@
+xgboost==3.0.2
+torch==2.4.0
+pandas==2.2.2
+numpy==1.26.4
+datasets
+nltk==3.9.1
+scikit-learn==1.7.0
diff --git a/detectors/xgb/test_xgb.py b/detectors/xgb/test_xgb.py
new file mode 100644
index 0000000..fee1a00
--- /dev/null
+++ b/detectors/xgb/test_xgb.py
@@ -0,0 +1,29 @@
+import pytest
+from fastapi.testclient import TestClient
+
+class TestXGBDetectors:
+ @pytest.fixture
+ def client(self):
+ from detectors.xgb.build.app import app
+ from detectors.xgb.build.detector import Detector
+
+        app.set_detector(Detector())
+ return TestClient(app)
+
+    @pytest.mark.parametrize(
+        "content,expected",
+        [
+            (["Congratulations! You've won a $1000 Walmart gift card. Click here to claim now."], True),
+            (["Don't forget to bring your notebook to class tomorrow."], False),
+        ]
+    )
+    def test_xgb_detectors(self, client, content, expected):
+        # `content` is already a list, and the request schema names the field "contents"
+        payload = {
+            "contents": content,
+        }
+        # The endpoint requires a detector-id header; its value is not checked
+        resp = client.post(
+            "/api/v1/text/contents",
+            headers={"detector-id": "xgb"},
+            json=payload,
+        )
+        assert resp.status_code == 200
+        assert len(resp.json()) > 0
+        assert resp.json()[0]['detection'] == expected
+
From 761ac8cd7224964e26b8d3dd94e4d95aeb97306b Mon Sep 17 00:00:00 2001
From: Christina Xu
Date: Tue, 22 Jul 2025 16:40:24 -0400
Subject: [PATCH 2/2] Add Dockerfile.xgb
---
detectors/Dockerfile.xgb | 32 ++++++++++++++++++++++++++++++++
detectors/xgb/build/app.py | 4 +---
detectors/xgb/build/detector.py | 6 +++---
detectors/xgb/build/train.py | 2 +-
detectors/xgb/requirements.txt | 4 ++--
tests/detectors/xgb/test_xgb.py | 29 +++++++++++++++++++++++++++++
6 files changed, 68 insertions(+), 9 deletions(-)
create mode 100644 detectors/Dockerfile.xgb
create mode 100644 tests/detectors/xgb/test_xgb.py
diff --git a/detectors/Dockerfile.xgb b/detectors/Dockerfile.xgb
new file mode 100644
index 0000000..2e99b5a
--- /dev/null
+++ b/detectors/Dockerfile.xgb
@@ -0,0 +1,32 @@
+FROM registry.access.redhat.com/ubi9/ubi-minimal as base
+RUN microdnf update -y && \
+ microdnf install -y --nodocs \
+ python-pip python-devel && \
+ pip install --upgrade --no-cache-dir pip wheel && \
+ microdnf clean all
+RUN pip install --no-cache-dir torch
+
+# FROM icr.io/fm-stack/ubi9-minimal-py39-torch as builder
+FROM base as builder
+
+COPY ./common/requirements.txt .
+RUN pip install --no-cache-dir -r requirements.txt
+
+COPY ./xgb/requirements.txt .
+RUN pip install --no-cache-dir -r requirements.txt
+
+FROM builder
+
+
+WORKDIR /app
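+# Changing CACHEBUST at build time invalidates the layer cache from this point,
+# forcing the model artifacts below to be re-copied on rebuild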
+ARG CACHEBUST=1
+RUN echo "$CACHEBUST"
+COPY xgb/build/model_artifacts /app/model_artifacts
+COPY ./common /common
+
+COPY ./xgb/build/app.py /app
+COPY ./xgb/build/detector.py /app
+
+EXPOSE 8000
+CMD ["uvicorn", "app:app", "--workers", "4", "--host", "0.0.0.0", "--port", "8000", "--log-config", "/common/log_conf.yaml"]
\ No newline at end of file
diff --git a/detectors/xgb/build/app.py b/detectors/xgb/build/app.py
index df3b84b..41331f3 100644
--- a/detectors/xgb/build/app.py
+++ b/detectors/xgb/build/app.py
@@ -5,12 +5,10 @@
from fastapi import Header
from prometheus_fastapi_instrumentator import Instrumentator
-
sys.path.insert(0, os.path.abspath(".."))
+from common.app import DetectorBaseAPI as FastAPI
from detector import Detector
-
-from detectors.common.app import DetectorBaseAPI as FastAPI
from detectors.common.scheme import (
ContentAnalysisHttpRequest,
ContentsAnalysisResponse,
diff --git a/detectors/xgb/build/detector.py b/detectors/xgb/build/detector.py
index 9a3aa23..cf01985 100644
--- a/detectors/xgb/build/detector.py
+++ b/detectors/xgb/build/detector.py
@@ -8,9 +8,9 @@
from detectors.common.scheme import (
ContentAnalysisHttpRequest,
ContentAnalysisResponse,
- ContentAnalysisHttpResponse,
)
import pickle as pkl
try:
from common.app import logger
@@ -21,7 +21,7 @@
class Detector:
def __init__(self):
# initialize the detector
- model_files_path = os.path.abspath(s.path.join(os.sep, "app", "model_artifacts"))
+ model_files_path = os.path.abspath(os.path.join(os.sep, "app", "model_artifacts"))
if not os.path.exists(model_files_path):
model_files_path = os.path.join("build", "model_artifacts")
logger.info(model_files_path)
@@ -39,7 +39,7 @@ def __init__(self):
self.batch_size = 8
logger.info("Detector initialized.")
- def run(self, request: ContentAnalysisHttpRequest) -> ContentsAnalysisHttpResponse:
+ def run(self, request: ContentAnalysisHttpRequest) -> ContentAnalysisResponse:
if hasattr(request, "detection_type") and request.detection_type != "spamCheck":
logger.warning(f"Unsupported detection type: {request.detection_type}")
diff --git a/detectors/xgb/build/train.py b/detectors/xgb/build/train.py
index d09f494..ee3bace 100644
--- a/detectors/xgb/build/train.py
+++ b/detectors/xgb/build/train.py
@@ -70,7 +70,7 @@ def preprocess_text(X):
subsample=grid_search.best_params_['subsample'],
random_state=42
)
- clf.fit(X, y)
+ clf.fit(X_vec, y)
print(f"Saving training artifacts to {artifact_path}...")
pickle.dump(vectorizer, open(f'{artifact_path}/vectorizer.pkl', 'wb'))
diff --git a/detectors/xgb/requirements.txt b/detectors/xgb/requirements.txt
index eef693c..5721445 100644
--- a/detectors/xgb/requirements.txt
+++ b/detectors/xgb/requirements.txt
@@ -1,7 +1,7 @@
-xgboost==3.0.2
+xgboost
torch==2.4.0
pandas==2.2.2
numpy==1.26.4
datasets
nltk==3.9.1
-scikit-learn==1.7.0
+scikit-learn
diff --git a/tests/detectors/xgb/test_xgb.py b/tests/detectors/xgb/test_xgb.py
new file mode 100644
index 0000000..fee1a00
--- /dev/null
+++ b/tests/detectors/xgb/test_xgb.py
@@ -0,0 +1,29 @@
+import pytest
+from fastapi.testclient import TestClient
+
+class TestXGBDetectors:
+ @pytest.fixture
+ def client(self):
+ from detectors.xgb.build.app import app
+ from detectors.xgb.build.detector import Detector
+
+        app.set_detector(Detector())
+ return TestClient(app)
+
+    @pytest.mark.parametrize(
+        "content,expected",
+        [
+            (["Congratulations! You've won a $1000 Walmart gift card. Click here to claim now."], True),
+            (["Don't forget to bring your notebook to class tomorrow."], False),
+        ]
+    )
+    def test_xgb_detectors(self, client, content, expected):
+        # `content` is already a list, and the request schema names the field "contents"
+        payload = {
+            "contents": content,
+        }
+        # The endpoint requires a detector-id header; its value is not checked
+        resp = client.post(
+            "/api/v1/text/contents",
+            headers={"detector-id": "xgb"},
+            json=payload,
+        )
+        assert resp.status_code == 200
+        assert len(resp.json()) > 0
+        assert resp.json()[0]['detection'] == expected
+