From 31bf89401d5512dc98bb28e845a545d248031149 Mon Sep 17 00:00:00 2001
From: Christina Xu <77992688+christinaexyou@users.noreply.github.com>
Date: Tue, 22 Jul 2025 13:11:20 -0400
Subject: [PATCH 1/2] feat: Add XGB detectors

---
 detectors/xgb/README.md         | 27 +++++++++++
 detectors/xgb/build/Makefile    |  4 ++
 detectors/xgb/build/app.py      | 52 +++++++++++++++++++++
 detectors/xgb/build/detector.py | 63 ++++++++++++++++++++++++++
 detectors/xgb/build/train.py    | 80 +++++++++++++++++++++++++++++++++
 detectors/xgb/requirements.txt  |  7 +++
 detectors/xgb/test_xgb.py       | 29 ++++++++++++
 7 files changed, 262 insertions(+)
 create mode 100644 detectors/xgb/README.md
 create mode 100644 detectors/xgb/build/Makefile
 create mode 100644 detectors/xgb/build/app.py
 create mode 100644 detectors/xgb/build/detector.py
 create mode 100644 detectors/xgb/build/train.py
 create mode 100644 detectors/xgb/requirements.txt
 create mode 100644 detectors/xgb/test_xgb.py

diff --git a/detectors/xgb/README.md b/detectors/xgb/README.md
new file mode 100644
index 0000000..c0c284a
--- /dev/null
+++ b/detectors/xgb/README.md
@@ -0,0 +1,27 @@
+# XGB Classification Detector
+
+## Setup
+1. Train the XGB model and save the trained artifacts
+   ```
+   cd guardrails-detectors/detectors/xgb/build
+   make all
+   ```
+
+2. Build the image (the Dockerfile expects `detectors/` as the build context)
+   ```
+   cd guardrails-detectors/detectors
+   podman build --file=Dockerfile.xgb -t xgb_detector:latest .
+   ```
+
+## Detector API
+### `/api/v1/text/contents`
+* POST a JSON body whose `contents` field is a list of strings. Each text is
+  vectorized with the trained TF-IDF vectorizer and classified by the XGBoost
+  model, and the endpoint returns one list of analyses per input text, with
+  `detection_type` set to `spamCheck`.
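+
+Example request against the container from the Testing Locally section below
+(the `detector-id` header value is arbitrary here):
+```
+curl -X POST http://localhost:8001/api/v1/text/contents \
+  -H "Content-Type: application/json" \
+  -H "detector-id: xgb-spam" \
+  -d '{"contents": ["Congratulations! You have won a $1000 gift card."]}'
+```
+
+Illustrative response:
+```
+[
+  [
+    {
+      "start": 0,
+      "end": 48,
+      "detection": true,
+      "detection_type": "spamCheck",
+      "text": "Congratulations! You have won a $1000 gift card.",
+      "evidences": []
+    }
+  ]
+]
+```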
""", + responses={ + 404: {"model": Error, "description": "Resource Not Found"}, + 422: {"model": Error, "description": "Validation Error"}, + }, +) +async def detector_unary_handler( + request: ContentAnalysisHttpRequest, + detector_id: Annotated[str, Header(example="en_syntax_slate.38m.hap")], +): + return ContentsAnalysisResponse(root=detector_objects["detector"].run(request)) \ No newline at end of file diff --git a/detectors/xgb/build/detector.py b/detectors/xgb/build/detector.py new file mode 100644 index 0000000..9a3aa23 --- /dev/null +++ b/detectors/xgb/build/detector.py @@ -0,0 +1,63 @@ +import os +import sys + +sys.path.insert(0, os.path.abspath("..")) +import pathlib +import torch +import xgboost as xgb +from detectors.common.scheme import ( + ContentAnalysisHttpRequest, + ContentAnalysisResponse, + ContentAnalysisHttpResponse, +) +import pickle as pkl + +try: + from common.app import logger +except ImportError: + sys.path.insert(0, os.path.join(pathlib.Path(__file__).parent.parent.resolve())) + from common.app import logger + +class Detector: + def __init__(self): + # initialize the detector + model_files_path = os.path.abspath(s.path.join(os.sep, "app", "model_artifacts")) + if not os.path.exists(model_files_path): + model_files_path = os.path.join("build", "model_artifacts") + logger.info(model_files_path) + + self.model = pkl.load(open(os.path.join(model_files_path, 'model.pkl'), 'rb')) + self.vectorizer = pkl.load(open(os.path.join(model_files_path, 'vectorizer.pkl'), 'rb')) + + if torch.cuda.is_available(): + self.cuda_device = torch.device("cuda") + torch.cuda.empty_cache() + self.model.to(self.cuda_device) + logger.info("cuda_device".upper() + " " + str(self.cuda_device)) + self.batch_size = 1 + else: + self.batch_size = 8 + logger.info("Detector initialized.") + + def run(self, request: ContentAnalysisHttpRequest) -> ContentsAnalysisHttpResponse: + if hasattr(request, "detection_type") and request.detection_type != "spamCheck": + logger.warning(f"Unsupported detection type: {request.detection_type}") + + content_analyses = [] + for batch_idx in range(0, len(request.contents), self.batch_size): + text = request.contents[batch_idx:batch_idx + self.batch_size] + vectorized_text = self.vectorizer.transform(text) + predictions = self.model.predict(vectorized_text) + detections = any([True for p in predictions if p == 1]) + + content_analyses.append( + ContentAnalysisResponse( + start=0, + end=len(text), + detection=detections, + detection_type="spamCheck", + text=text, + evidences=[], + ) + ) + return content_analyses diff --git a/detectors/xgb/build/train.py b/detectors/xgb/build/train.py new file mode 100644 index 0000000..d09f494 --- /dev/null +++ b/detectors/xgb/build/train.py @@ -0,0 +1,80 @@ +import argparse +import os +import pathlib +import pickle +import re + +from datasets import load_dataset +import pandas as pd +import xgboost as xgb +from nltk.corpus import stopwords +from nltk.stem import PorterStemmer +from sklearn.feature_extraction.text import TfidfVectorizer +from sklearn.model_selection import GridSearchCV + + +def load_data(dataset_name, **dataset_kwargs): + return load_dataset(dataset_name, **dataset_kwargs) + +def generate_training_df(data): + df = pd.DataFrame(data).rename(columns={"sms": "text"}) + return df + +def preprocess_text(X): + stemmer = PorterStemmer() + stop_words = stopwords.words('english') + X['text'] = X['text'].apply(lambda x: " ".join([stemmer.stem(i) for i in re.sub("[^a-zA-Z]", " ", x).split() if i not in stop_words]).lower()) + 
+
+# ==================================================================================================
+# === MAIN =========================================================================================
+# ==================================================================================================
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--dataset', type=str, default='sms_spam')
+    parser.add_argument('--hf_token', type=str, default=os.getenv('HF_TOKEN', ''))
+
+    args = parser.parse_args()
+    artifact_path = os.path.join(pathlib.Path(__file__).parent.resolve(), "model_artifacts")
+    os.makedirs(artifact_path, exist_ok=True)
+
+    if args.dataset.lower() == 'sms_spam':
+        print("Loading SMS spam dataset...")
+        data = load_data("ucirvine/sms_spam", token=args.hf_token, split="train")
+        train_df = generate_training_df(data)
+
+        print("Preprocessing data...")
+        X = train_df.drop(columns=['label'])
+        X = preprocess_text(X)
+        vectorizer = TfidfVectorizer()
+        X_vec = vectorizer.fit_transform(X['text'])
+
+        y = train_df['label']
+
+        print("Training XGBoost model...")
+        param_grid = {
+            'max_depth': [3, 5, 7],
+            'learning_rate': [0.1, 0.01, 0.001],
+            'subsample': [0.5, 0.7, 1]
+        }
+        grid_search = GridSearchCV(
+            xgb.XGBClassifier(random_state=42),
+            param_grid,
+            cv=5,
+            scoring='accuracy'
+        )
+        grid_search.fit(X_vec, y)
+        # Refit a classifier with the best hyperparameters found above
+        clf = xgb.XGBClassifier(
+            max_depth=grid_search.best_params_['max_depth'],
+            learning_rate=grid_search.best_params_['learning_rate'],
+            subsample=grid_search.best_params_['subsample'],
+            random_state=42
+        )
+        clf.fit(X, y)

+        print(f"Saving training artifacts to {artifact_path}...")
+        pickle.dump(vectorizer, open(f'{artifact_path}/vectorizer.pkl', 'wb'))
+        pickle.dump(clf, open(f'{artifact_path}/model.pkl', 'wb'))
+
+    else:
+        raise NotImplementedError(f"Dataset {args.dataset} not yet supported")

diff --git a/detectors/xgb/requirements.txt b/detectors/xgb/requirements.txt
new file mode 100644
index 0000000..eef693c
--- /dev/null
+++ b/detectors/xgb/requirements.txt
@@ -0,0 +1,7 @@
+xgboost==3.0.2
+torch==2.4.0
+pandas==2.2.2
+numpy==1.26.4
+datasets
+nltk==3.9.1
+scikit-learn==1.7.0

diff --git a/detectors/xgb/test_xgb.py b/detectors/xgb/test_xgb.py
new file mode 100644
index 0000000..fee1a00
--- /dev/null
+++ b/detectors/xgb/test_xgb.py
@@ -0,0 +1,29 @@
+import pytest
+from fastapi.testclient import TestClient
+
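+# End-to-end tests for the XGB spam detector: they require the trained model
+# artifacts (run `make all` in detectors/xgb/build first) and exercise the
+# /api/v1/text/contents endpoint through FastAPI's TestClient.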
+class TestXGBDetectors:
+    @pytest.fixture
+    def client(self):
+        from detectors.xgb.build.app import app
+        from detectors.xgb.build.detector import Detector
+
+        app.set_detector(Detector(), "detector")
+        return TestClient(app)
+
+    @pytest.mark.parametrize(
+        "content,expected",
+        [
+            (["Congratulations! You've won a $1000 Walmart gift card. Click here to claim now."], True),
+            (["Don't forget to bring your notebook to class tomorrow."], False),
+        ]
+    )
+    def test_xgb_detectors(self, client, content, expected):
+        payload = {
+            "contents": content,
+        }
+        resp = client.post("/api/v1/text/contents", json=payload)
+        assert resp.status_code == 200
+        assert len(resp.json()[0]) > 0
+        assert resp.json()[0][0]['detection'] == expected
+

From 761ac8cd7224964e26b8d3dd94e4d95aeb97306b Mon Sep 17 00:00:00 2001
From: Christina Xu
Date: Tue, 22 Jul 2025 16:40:24 -0400
Subject: [PATCH 2/2] Add Dockerfile.xgb

---
 detectors/Dockerfile.xgb        | 32 ++++++++++++++++++++++++++++++++
 detectors/xgb/build/app.py      |  4 +---
 detectors/xgb/build/detector.py |  6 +++---
 detectors/xgb/build/train.py    |  2 +-
 detectors/xgb/requirements.txt  |  4 ++--
 tests/detectors/xgb/test_xgb.py | 29 +++++++++++++++++++++++++++++
 6 files changed, 68 insertions(+), 9 deletions(-)
 create mode 100644 detectors/Dockerfile.xgb
 create mode 100644 tests/detectors/xgb/test_xgb.py

diff --git a/detectors/Dockerfile.xgb b/detectors/Dockerfile.xgb
new file mode 100644
index 0000000..2e99b5a
--- /dev/null
+++ b/detectors/Dockerfile.xgb
@@ -0,0 +1,32 @@
+FROM registry.access.redhat.com/ubi9/ubi-minimal as base
+RUN microdnf update -y && \
+    microdnf install -y --nodocs \
+    python3-pip python3-devel && \
+    pip install --upgrade --no-cache-dir pip wheel && \
+    microdnf clean all
+RUN pip install --no-cache-dir torch
+
+# FROM icr.io/fm-stack/ubi9-minimal-py39-torch as builder
+FROM base as builder
+
+COPY ./common/requirements.txt .
+RUN pip install --no-cache-dir -r requirements.txt
+
+COPY ./xgb/requirements.txt .
+RUN pip install --no-cache-dir -r requirements.txt
+
+FROM builder
+
+WORKDIR /app
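+# CACHEBUST invalidates the build cache from this point on when overridden
+# (e.g. `podman build --build-arg CACHEBUST=$(date +%s) ...`), so freshly
+# trained model artifacts are always re-copied into the image.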
logger.info("Detector initialized.") - def run(self, request: ContentAnalysisHttpRequest) -> ContentsAnalysisHttpResponse: + def run(self, request: ContentAnalysisHttpRequest) -> ContentAnalysisResponse: if hasattr(request, "detection_type") and request.detection_type != "spamCheck": logger.warning(f"Unsupported detection type: {request.detection_type}") diff --git a/detectors/xgb/build/train.py b/detectors/xgb/build/train.py index d09f494..ee3bace 100644 --- a/detectors/xgb/build/train.py +++ b/detectors/xgb/build/train.py @@ -70,7 +70,7 @@ def preprocess_text(X): subsample=grid_search.best_params_['subsample'], random_state=42 ) - clf.fit(X, y) + clf.fit(X_vec, y) print(f"Saving training artifacts to {artifact_path}...") pickle.dump(vectorizer, open(f'{artifact_path}/vectorizer.pkl', 'wb')) diff --git a/detectors/xgb/requirements.txt b/detectors/xgb/requirements.txt index eef693c..5721445 100644 --- a/detectors/xgb/requirements.txt +++ b/detectors/xgb/requirements.txt @@ -1,7 +1,7 @@ -xgboost==3.0.2 +xgboost torch==2.4.0 pandas==2.2.2 numpy==1.26.4 datasets nltk==3.9.1 -scikit-learn==1.7.0 +scikit-learn diff --git a/tests/detectors/xgb/test_xgb.py b/tests/detectors/xgb/test_xgb.py new file mode 100644 index 0000000..fee1a00 --- /dev/null +++ b/tests/detectors/xgb/test_xgb.py @@ -0,0 +1,29 @@ +import pytest +from fastapi.testclient import TestClient + +class TestXGBDetectors: + @pytest.fixture + def client(self): + from detectors.xgb.build.app import app + from detectors.xgb.build.detector import Detector + + app.set_detector(Detector(), "detector") + return TestClient(app) + + @pytest.mark.parametrize( + "content,expected", + [ + (["Congratulations! You've won a $1000 Walmart gift card. Click here to claim now."], True), + (["Don't forget to bring your notebook to class tomorrow."], False), + ] + ) + + def test_xgb_detectors(self, client, content, expected): + payload = { + "content": [content], + } + resp = client.post("api/v1/text/contexts", json=payload) + assert resp.status_code == 200 + assert len(resp.json()[0]) > 0 + assert resp.json()[0][0]['spam_check'] == expected +