Add embedding-based detector #2

Open · wants to merge 2 commits into main
9 changes: 8 additions & 1 deletion .gitignore
@@ -4,6 +4,9 @@
# Type checker
.DS_Store

# IDEA
**/.idea/*

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
@@ -12,4 +15,8 @@ __pycache__/
.python-version

# secrets
.env
.env

# artifacts
**/*.pkl
**/model_artifacts/*
32 changes: 32 additions & 0 deletions detectors/Dockerfile.embedding-classifier
@@ -0,0 +1,32 @@
FROM registry.access.redhat.com/ubi9/ubi-minimal as base
RUN microdnf update -y && \
    microdnf install -y --nodocs \
    python-pip python-devel && \
    pip install --upgrade --no-cache-dir pip wheel && \
    microdnf clean all
RUN pip install --no-cache-dir torch

# FROM icr.io/fm-stack/ubi9-minimal-py39-torch as builder
FROM base as builder

COPY ./common/requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY embedding_classification/requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

FROM builder

WORKDIR /app

COPY embedding_classification/build/model_artifacts /app/model_artifacts
COPY ./common /common
COPY embedding_classification/app.py /app
COPY embedding_classification/detector.py /app
COPY embedding_classification/scheme.py /app


EXPOSE 8000
CMD ["uvicorn", "app:app", "--workers", "1", "--host", "0.0.0.0", "--port", "8000", "--log-config", "/common/log_conf.yaml"]

# gunicorn main:app --workers 4 --worker-class uvicorn.workers.UvicornWorker --bind 0.0.0.0:8000
Empty file added detectors/__init__.py
Empty file.
37 changes: 37 additions & 0 deletions detectors/embedding_classification/README.md
@@ -0,0 +1,37 @@
# Embedding Classification Detector

## Setup
Collaborator: Perhaps it would be useful to state that the local Python version must match the Python version in the Containerfile? At present, Python 3.9 will be installed inside the container, which may warrant upgrading.

1) Fetch the prerequisite models, train the pipeline, and save the training artifacts:
```bash
cd guardrails-detectors/detectors/embedding_classification/build
make all
```
2) Build the image (beware: this can take a while and use a lot of VM storage during the build):
```bash
cd guardrails-detectors
podman build --file=Dockerfile.embedding-classifier -t mmlu_detector:latest
```

## Detector API
### `/api/v1/text/contents`
* `contents`: List of texts to classify
* `allowList`: List of allowed subjects: each inbound text must belong to _at least one_ of these subjects to avoid being flagged by the detector
* `blockList`: List of blocked subjects: inbound texts must not belong to _any_ of these subjects to avoid being flagged by the detector.
* `threshold`: The maximum distance a body of text can be from a subject centroid and still be classified into that subject. The default value is 0.75; a threshold greater than 10 will classify every document into every subject, so values in the range 0 < threshold < 1 are recommended. A sketch of this check follows below.
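
To make the `threshold` semantics concrete, here is a minimal sketch of the centroid-distance check it controls, assuming reduced embeddings and per-subject centroids like those produced by `build/train.py`. The actual logic lives in `detector.py` (not shown in this diff); the names below are illustrative only.

```python
import numpy as np

def subjects_for_text(reduced_embedding, centroids, threshold=0.75):
    """Return every subject whose centroid lies within `threshold` of the text's reduced embedding."""
    return [
        subject
        for subject, centroid in centroids.items()
        if np.linalg.norm(reduced_embedding - np.asarray(centroid)) <= threshold
    ]

def is_flagged(subjects, allow_list, block_list):
    """A text is flagged if it matches none of the allowed subjects, or matches any blocked subject."""
    allowed = (not allow_list) or any(s in allow_list for s in subjects)
    blocked = any(s in block_list for s in subjects)
    return (not allowed) or blocked
```

A larger threshold lets every subject match more easily, which is why values above 10 classify every document into every subject.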


## Testing Locally
```bash
podman run -p 8001:8000 --platform=linux/amd64 quay.io/rgeada/mmlu_detector:latest
```
Wait for the server to start (you should see a log message like `Uvicorn running on http://0.0.0.0:8000`), then:
```bash
curl -X POST "localhost:8001/api/v1/text/contents" -H "Content-Type: application/json" \
  -H "detector-id: mmluTopicMatch" \
  -d '{
    "contents": ["How far away is the Sun from the center of the Milky Way?", "What is the healthiest vegetable?", "The square root of 256 is 16."],
    "allowList": ["astronomy"],
    "blockList": ["nutrition"]
  }' | jq
```
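
The same request can be sent from Python, for example with the `requests` library (a sketch; the exact response schema is defined in `scheme.py` and is not reproduced here):

```python
import requests

payload = {
    "contents": [
        "How far away is the Sun from the center of the Milky Way?",
        "What is the healthiest vegetable?",
        "The square root of 256 is 16.",
    ],
    "allowList": ["astronomy"],
    "blockList": ["nutrition"],
}

response = requests.post(
    "http://localhost:8001/api/v1/text/contents",
    headers={"detector-id": "mmluTopicMatch"},
    json=payload,  # also sets Content-Type: application/json
)
response.raise_for_status()
print(response.json())
```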

44 changes: 44 additions & 0 deletions detectors/embedding_classification/app.py
@@ -0,0 +1,44 @@
import os
import sys
from contextlib import asynccontextmanager
from typing import Annotated

from fastapi import Header
sys.path.insert(0, os.path.abspath(".."))

from common.app import DetectorBaseAPI as FastAPI
from detector import Detector
from scheme import (
    ContentAnalysisHttpRequest,
    ContentsAnalysisResponse,
    Error,
)

detector_objects = {}

@asynccontextmanager
async def lifespan(app: FastAPI):
    detector_objects["detector"] = Detector()
    yield
    # Clean up the ML models and release the resources
    detector_objects.clear()


app = FastAPI(lifespan=lifespan, dependencies=[])


@app.post(
    "/api/v1/text/contents",
    response_model=ContentsAnalysisResponse,
    description="""Detectors that work on content text, be it user prompt or generated text. \
    Generally classification type detectors qualify for this. <br>""",
    responses={
        404: {"model": Error, "description": "Resource Not Found"},
        422: {"model": Error, "description": "Validation Error"},
    },
)
async def detector_unary_handler(
    request: ContentAnalysisHttpRequest,
    detector_id: Annotated[str, Header(example="en_syntax_slate.38m.hap")],
):
    return ContentsAnalysisResponse(root=detector_objects["detector"].run(request))
9 changes: 9 additions & 0 deletions detectors/embedding_classification/build/Makefile
@@ -0,0 +1,9 @@
train_pipeline:
	python3 train.py

download_embedding_model:
	huggingface-cli download dunzhang/stella_en_1.5B_v5 --local-dir model_artifacts/$(basename dunzhang/stella_en_1.5B_v5) --revision 7816d43c4efd2fead216afbb7522d2093b44b16b

all: download_embedding_model train_pipeline
@@ -0,0 +1 @@
from .mmlu_dataset_config import MMLUDatasetConfig

Review comment: Can this be moved to a shared utils folder so that other detectors that require training can use this class?

@@ -0,0 +1,13 @@
class BaseDatasetConfig():
    """Base Config for defining text and label pairs from a Huggingface text dataset"""

    def __init__(self):
        pass

    def get_text(self, docs):
        """Define a function to extract the "text" from each row of the dataset."""
        raise NotImplementedError

    def get_label(self, docs):
        """Define a function to extract the label from each row of the dataset."""
        raise NotImplementedError
@@ -0,0 +1,19 @@
from .base_dataset_config import BaseDatasetConfig


class MMLUDatasetConfig(BaseDatasetConfig):
    """Config for defining text and label pairs from MMLU"""

    def __init__(self):
        super().__init__()

    def get_text(self, docs):
        """Define a function to extract the "text" from each row of the dataset."""
        qs = docs['question']
        ans = docs['answer']
        cs = docs['choices']
        return ["{}\n\n{}".format(qs[i], cs[i][ans[i]]) for i in range(len(docs))]

    def get_label(self, docs):
        """Define a function to extract the label from each row of the dataset."""
        return docs['subject']
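
For illustration, a toy example of what this config produces, assuming it is run from the `build` directory so that `dataset_configs` is importable (the row values are made up, not real MMLU data):

```python
import datasets
from dataset_configs import MMLUDatasetConfig

# A toy dataset shaped like MMLU rows; values are illustrative only.
docs = datasets.Dataset.from_dict({
    "question": ["What is 2 + 2?"],
    "answer": [1],                        # index of the correct choice
    "choices": [["3", "4", "5", "6"]],
    "subject": ["elementary_mathematics"],
})

config = MMLUDatasetConfig()
print(config.get_text(docs))   # ['What is 2 + 2?\n\n4']
print(config.get_label(docs))  # ['elementary_mathematics']
```
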
8 changes: 8 additions & 0 deletions detectors/embedding_classification/build/requirements.txt
@@ -0,0 +1,8 @@
transformers==4.43.4
datasets==3.0.0
pandas==2.2.2
sentence-transformers==3.3.1
numpy==1.26.4
tqdm==4.66.5
torch==2.4.0
umap-learn==0.5.7
127 changes: 127 additions & 0 deletions detectors/embedding_classification/build/train.py
@@ -0,0 +1,127 @@
import argparse
import datasets
import dataset_configs
import numpy as np
import os
import pandas as pd
import pathlib
import pickle
from sentence_transformers import SentenceTransformer
import torch
from tqdm.autonotebook import tqdm
import umap

from dataset_configs.base_dataset_config import BaseDatasetConfig


# === DATA LOADING =================================================================================
def load_data(dataset_name, **dataset_kwargs):
    return datasets.load_dataset(dataset_name, **dataset_kwargs)


def generate_training_df(data, dataset_config: BaseDatasetConfig):
    df = pd.DataFrame()
    df['text'] = dataset_config.get_text(data)
    df['label'] = dataset_config.get_label(data)
    return df


# === EMBEDDING ====================================================================================
def get_torch_device():
    cuda_available = torch.cuda.is_available()
    mps_available = torch.backends.mps.is_available()
    if cuda_available:
        device = "cuda"
    elif mps_available:
        device = "mps"
    else:
        device = "cpu"
    print("Using {} backend for sentence transformer.".format(device))
    return torch.device(device)


def get_embedding_model():
    device = get_torch_device()
    return SentenceTransformer(os.path.join("model_artifacts", "dunzhang", "stella_en_1"), trust_remote_code=True).to(device)


def get_embeddings(train_df, batch_size, model):
    query_prompt_name = "s2p_query"

    nrows = len(train_df)
    embeddings = np.zeros([nrows, 1024])
    for idx in tqdm(range(0, nrows, batch_size)):
        # materialize the batch as a plain list of strings for model.encode
        text = train_df['text'].iloc[idx: idx + batch_size].tolist()
        embeddings[idx:idx + batch_size] = model.encode(text, prompt_name=query_prompt_name)
    return embeddings


def generate_embedding_df(train_df, reduced_embedding):
    embedding_df = pd.DataFrame(reduced_embedding)
    embedding_df.columns = [str(i) for i in range(reduced_embedding.shape[1])]
    embedding_df['Label'] = train_df['label']
    return embedding_df


# === CENTROIDS ====================================================================================
def get_centroids(embedding_df, reduced_embedding):
    return embedding_df.groupby("Label").agg({str(d): "mean" for d in range(reduced_embedding.shape[1])})


# ==================================================================================================
# === MAIN =========================================================================================
# ==================================================================================================
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='mmlu')

    args = parser.parse_args()
    artifact_path = os.path.join(pathlib.Path(__file__).parent.resolve(), "model_artifacts")
    os.makedirs(artifact_path, exist_ok=True)

    if args.dataset.lower() == 'mmlu':
        # load data
        print("Loading MMLU dataset...")
        data = load_data("cais/mmlu", name='all')
        train_df = generate_training_df(data['test'], dataset_configs.MMLUDatasetConfig())

        # get embeddings
        embedding_artifact_path = os.path.join(artifact_path, args.dataset.lower() + "_embeddings.npy")

        if not os.path.exists(embedding_artifact_path):
            print("Loading embedding model...")
            embedding_model = get_embedding_model()

            print("Generating embeddings for MMLU")
            embeddings = get_embeddings(train_df, batch_size=4, model=embedding_model)
            np.save(embedding_artifact_path, embeddings)
        else:
            print("Loading pre-trained embeddings...")
            embeddings = np.load(embedding_artifact_path)

        # get dimensionality reduction
        print("Fitting dimensionality reduction...")
        reducer = umap.UMAP(n_components=3)
        reduced_embedding = reducer.fit_transform(embeddings)
        embedding_df = generate_embedding_df(train_df, reduced_embedding)

        # centroids
        print("Generating centroids...")
        centroids = get_centroids(embedding_df, reduced_embedding)

        # save artifacts
        print("Saving training artifacts to {}...".format(artifact_path))
        pickle.dump(reducer, open(os.path.join(artifact_path, "umap.pkl"), "wb"))
        centroids.to_pickle(os.path.join(artifact_path, "centroids.pkl"))
        print("Training completed successfully!")
    else:
        raise NotImplementedError("Dataset {} not yet supported".format(args.dataset))
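
For reference, a sketch of how the saved artifacts might be loaded at inference time. The actual loading happens in `detector.py`, which is not part of this diff; the paths below assume the `model_artifacts` directory that `train.py` writes to, and the embedding is a placeholder.

```python
import os
import pickle

import numpy as np
import pandas as pd

artifact_path = "model_artifacts"  # directory populated by train.py

# Fitted UMAP reducer and per-subject centroids saved by train.py.
with open(os.path.join(artifact_path, "umap.pkl"), "rb") as f:
    reducer = pickle.load(f)
centroids = pd.read_pickle(os.path.join(artifact_path, "centroids.pkl"))

# A new embedding would be reduced with the same reducer and compared
# against each centroid (see the README's `threshold` parameter).
new_embedding = np.zeros((1, 1024))          # placeholder for a real sentence embedding
reduced = reducer.transform(new_embedding)   # shape (1, 3), matching n_components=3
distances = np.linalg.norm(centroids.values - reduced, axis=1)
```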




