diff --git a/.gitignore b/.gitignore
index ec8155a..151a015 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,6 +4,9 @@
 # Type checker
 .DS_Store
 
+# IDEA
+**/.idea/*
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
@@ -12,4 +15,8 @@ __pycache__/
 .python-version
 
 # secrets
-.env
\ No newline at end of file
+.env
+
+# artifacts
+**/*.pkl
+**/model_artifacts/*
\ No newline at end of file
diff --git a/detectors/Dockerfile.embedding-classifier b/detectors/Dockerfile.embedding-classifier
new file mode 100644
index 0000000..5a595db
--- /dev/null
+++ b/detectors/Dockerfile.embedding-classifier
@@ -0,0 +1,32 @@
+FROM registry.access.redhat.com/ubi9/ubi-minimal as base
+RUN microdnf update -y && \
+    microdnf install -y --nodocs \
+    python-pip python-devel && \
+    pip install --upgrade --no-cache-dir pip wheel && \
+    microdnf clean all
+RUN pip install --no-cache-dir torch
+
+# FROM icr.io/fm-stack/ubi9-minimal-py39-torch as builder
+FROM base as builder
+
+COPY ./common/requirements.txt .
+RUN pip install --no-cache-dir -r requirements.txt
+
+COPY embedding_classification/requirements.txt .
+RUN pip install --no-cache-dir -r requirements.txt
+
+FROM builder
+
+WORKDIR /app
+
+COPY embedding_classification/build/model_artifacts /app/model_artifacts
+COPY ./common /common
+COPY embedding_classification/app.py /app
+COPY embedding_classification/detector.py /app
+COPY embedding_classification/scheme.py /app
+
+
+EXPOSE 8000
+CMD ["uvicorn", "app:app", "--workers", "1", "--host", "0.0.0.0", "--port", "8000", "--log-config", "/common/log_conf.yaml"]
+
+# gunicorn main:app --workers 4 --worker-class uvicorn.workers.UvicornWorker --bind 0.0.0.0:8000
\ No newline at end of file
diff --git a/detectors/__init__.py b/detectors/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/detectors/embedding_classification/README.md b/detectors/embedding_classification/README.md
new file mode 100644
index 0000000..1ca647c
--- /dev/null
+++ b/detectors/embedding_classification/README.md
@@ -0,0 +1,37 @@
+# Embedding Classification Detector
+
+## Setup
+1) Fetch the prerequisite models, train the pipeline, and save the training artifacts:
+   ```bash
+   cd guardrails-detectors/detectors/embedding_classification/build
+   make all
+   ```
+2) Build the image (beware: this can take a while and use a lot of VM storage during the build):
+   ```bash
+   cd guardrails-detectors
+   podman build --file=Dockerfile.embedding-classifier -t mmlu_detector:latest
+   ```
+
+## Detector API
+### `/api/v1/text/contents`
+* `contents`: List of texts to classify.
+* `allowList`: Allowed list of subjects: each inbound text must belong to _at least one_ of these subjects to avoid being flagged by the detector.
+* `blockList`: Blocked list of subjects: each inbound text must not belong to _any_ of these subjects to avoid being flagged by the detector.
+* `threshold`: Defines the maximum distance a body of text can be from a subject centroid and still be classified into that subject. The default value is 0.75, while a threshold of >10 will classify every document into every subject. As such, values 0
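For illustration, a request against a locally running detector could look like the sketch below. The host/port, the `detector-id` header value, and the use of `requests` are assumptions made for this example rather than part of the change itself:

```python
import requests

# hypothetical local endpoint; the container listens on port 8000 per the Dockerfile's EXPOSE/CMD
payload = {
    "contents": ["How far away is the Sun from the center of the Milky Way?"],
    "allowList": ["astronomy"],
    "blockList": ["nutrition"],
    "threshold": 0.75,
}

response = requests.post(
    "http://localhost:8000/api/v1/text/contents",
    headers={"detector-id": "mmlu_detector"},  # header name and value assumed for illustration
    json=payload,
)
print(response.json())
```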
diff --git a/detectors/embedding_classification/app.py b/detectors/embedding_classification/app.py
+    """,
+    responses={
+        404: {"model": Error, "description": "Resource Not Found"},
+        422: {"model": Error, "description": "Validation Error"},
+    },
+)
+async def detector_unary_handler(
+    request: ContentAnalysisHttpRequest,
+    detector_id: Annotated[str, Header(example="en_syntax_slate.38m.hap")],
+):
+    return ContentsAnalysisResponse(root=detector_objects["detector"].run(request))
diff --git a/detectors/embedding_classification/build/Makefile b/detectors/embedding_classification/build/Makefile
new file mode 100644
index 0000000..89805f2
--- /dev/null
+++ b/detectors/embedding_classification/build/Makefile
@@ -0,0 +1,9 @@
+train_pipeline:
+	python3 train.py
+
+download_embedding_model:
+	huggingface-cli download dunzhang/stella_en_1.5B_v5 --local-dir model_artifacts/$(basename dunzhang/stella_en_1.5B_v5) --revision 7816d43c4efd2fead216afbb7522d2093b44b16b
+
+all:
+	$(MAKE) download_embedding_model
+	$(MAKE) train_pipeline
\ No newline at end of file
diff --git a/detectors/embedding_classification/build/dataset_configs/__init__.py b/detectors/embedding_classification/build/dataset_configs/__init__.py
new file mode 100644
index 0000000..5eb4bb1
--- /dev/null
+++ b/detectors/embedding_classification/build/dataset_configs/__init__.py
@@ -0,0 +1 @@
+from .mmlu_dataset_config import MMLUDatasetConfig
\ No newline at end of file
diff --git a/detectors/embedding_classification/build/dataset_configs/base_dataset_config.py b/detectors/embedding_classification/build/dataset_configs/base_dataset_config.py
new file mode 100644
index 0000000..709d3fa
--- /dev/null
+++ b/detectors/embedding_classification/build/dataset_configs/base_dataset_config.py
@@ -0,0 +1,13 @@
+class BaseDatasetConfig():
+    """Base config for defining text and label pairs from a Huggingface text dataset."""
+
+    def __init__(self):
+        pass
+
+    def get_text(self, docs):
+        """Extract the "text" from each row of the dataset."""
+        raise NotImplementedError
+
+    def get_label(self, docs):
+        """Extract the label from each row of the dataset."""
+        raise NotImplementedError
\ No newline at end of file
diff --git a/detectors/embedding_classification/build/dataset_configs/mmlu_dataset_config.py b/detectors/embedding_classification/build/dataset_configs/mmlu_dataset_config.py
new file mode 100644
index 0000000..07c4626
--- /dev/null
+++ b/detectors/embedding_classification/build/dataset_configs/mmlu_dataset_config.py
@@ -0,0 +1,19 @@
+from .base_dataset_config import BaseDatasetConfig
+
+
+class MMLUDatasetConfig(BaseDatasetConfig):
+    """Config for defining text and label pairs from MMLU."""
+
+    def __init__(self):
+        super().__init__()
+
+    def get_text(self, docs):
+        """Extract the "text" from each row: the question plus its correct answer choice."""
+        qs = docs['question']
+        ans = docs['answer']
+        cs = docs['choices']
+        return ["{}\n\n{}".format(qs[i], cs[i][ans[i]]) for i in range(len(docs))]
+
+    def get_label(self, docs):
+        """Extract the label (MMLU subject) from each row of the dataset."""
+        return docs['subject']
\ No newline at end of file
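Only the MMLU config is wired in so far; a new dataset would be added the same way, by subclassing `BaseDatasetConfig` in `dataset_configs/`. The following is a purely illustrative sketch: the `AGNewsDatasetConfig` class and the column names it assumes are hypothetical and not part of this change.

```python
from .base_dataset_config import BaseDatasetConfig


class AGNewsDatasetConfig(BaseDatasetConfig):
    """Hypothetical config mapping an AG News-style dataset to text/label pairs."""

    def get_text(self, docs):
        # assumes the dataset exposes a plain "text" column
        return docs["text"]

    def get_label(self, docs):
        # assumes an integer "label" column; map it to a readable subject name
        label_names = ["world", "sports", "business", "sci/tech"]
        return [label_names[label] for label in docs["label"]]
```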
diff --git a/detectors/embedding_classification/build/requirements.txt b/detectors/embedding_classification/build/requirements.txt
new file mode 100644
index 0000000..f98eca2
--- /dev/null
+++ b/detectors/embedding_classification/build/requirements.txt
@@ -0,0 +1,8 @@
+transformers==4.43.4
+datasets==3.0.0
+pandas==2.2.2
+sentence-transformers==3.3.1
+numpy==1.26.4
+tqdm==4.66.5
+torch==2.4.0
+umap-learn==0.5.7
\ No newline at end of file
diff --git a/detectors/embedding_classification/build/train.py b/detectors/embedding_classification/build/train.py
new file mode 100644
index 0000000..5ea0f5c
--- /dev/null
+++ b/detectors/embedding_classification/build/train.py
@@ -0,0 +1,127 @@
+import argparse
+import datasets
+import dataset_configs
+import matplotlib.pyplot as plt
+import numpy as np
+import os
+import pandas as pd
+import pathlib
+import pickle
+from sentence_transformers import SentenceTransformer
+import torch
+from tqdm.autonotebook import tqdm
+import umap
+
+from dataset_configs.base_dataset_config import BaseDatasetConfig
+
+
+plt.style.use('https://raw.githubusercontent.com/RobGeada/stylelibs/main/material_rh.mplstyle')
+
+
+# === DATA LOADING =================================================================================
+def load_data(dataset_name, **dataset_kwargs):
+    return datasets.load_dataset(dataset_name, **dataset_kwargs)
+
+
+def generate_training_df(data, dataset_config: BaseDatasetConfig):
+    df = pd.DataFrame()
+    df['text'] = dataset_config.get_text(data)
+    df['label'] = dataset_config.get_label(data)
+    return df
+
+
+# === EMBEDDING ====================================================================================
+def get_torch_device():
+    cuda_available = torch.cuda.is_available()
+    mps_available = torch.backends.mps.is_available()
+    if cuda_available:
+        device = "cuda"
+    elif mps_available:
+        device = "mps"
+    else:
+        device = "cpu"
+    print("Using {} backend for sentence transformer.".format(device))
+    return torch.device(device)
+
+
+def get_embedding_model():
+    device = get_torch_device()
+    return SentenceTransformer(os.path.join("model_artifacts", "dunzhang", "stella_en_1"), trust_remote_code=True).to(device)
+
+
+def get_embeddings(train_df, batch_size, model):
+    query_prompt_name = "s2p_query"
+
+    nrows = len(train_df)
+    embeddings = np.zeros([nrows, 1024])
+    for idx in tqdm(range(0, nrows, batch_size)):
+        text = train_df['text'].iloc[idx: idx+batch_size].tolist()
+        embeddings[idx:idx+batch_size] = model.encode(text, prompt_name=query_prompt_name)
+    return embeddings
+
+
+def generate_embedding_df(train_df, reduced_embedding):
+    embedding_df = pd.DataFrame(reduced_embedding)
+    embedding_df.columns = [str(i) for i in range(reduced_embedding.shape[1])]
+    embedding_df['Label'] = train_df['label']
+    return embedding_df
+
+
+# === CENTROIDS ====================================================================================
+def get_centroids(embedding_df, reduced_embedding):
+    return embedding_df.groupby("Label").agg({str(d): "mean" for d in range(reduced_embedding.shape[1])})
+
+
+# ==================================================================================================
+# === MAIN =========================================================================================
+# ==================================================================================================
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--dataset', type=str, default='mmlu')
+
+    args = parser.parse_args()
+    artifact_path = os.path.join(pathlib.Path(__file__).parent.resolve(), "model_artifacts")
+    os.makedirs(artifact_path, exist_ok=True)
+
+    if args.dataset.lower() == 'mmlu':
+        # load data
+        print("Loading MMLU dataset...")
+        data = load_data("cais/mmlu", name='all')
+        train_df = generate_training_df(data['test'], dataset_configs.MMLUDatasetConfig())
+
+        # get embeddings
+        embedding_artifact_path = os.path.join(artifact_path, args.dataset.lower() + "_embeddings.npy")
+
+        if not os.path.exists(embedding_artifact_path):
+            print("Loading embedding model...")
+            embedding_model = get_embedding_model()
+
+            print("Generating embeddings for MMLU")
+            embeddings = get_embeddings(train_df, batch_size=4, model=embedding_model)
+            np.save(embedding_artifact_path, embeddings)
+        else:
+            print("Loading pre-trained embeddings...")
+            embeddings = np.load(embedding_artifact_path)
+
+        # get dimensionality reduction
+        print("Fitting dimensionality reduction...")
+        reducer = umap.UMAP(n_components=3)
+        reduced_embedding = reducer.fit_transform(embeddings)
+        embedding_df = generate_embedding_df(train_df, reduced_embedding)
+
+        # centroids
+        print("Generating centroids...")
+        centroids = get_centroids(embedding_df, reduced_embedding)
+
+        # save artifacts
+        print("Saving training artifacts to {}...".format(artifact_path))
+        pickle.dump(reducer, open(os.path.join(artifact_path, "umap.pkl"), "wb"))
+        centroids.to_pickle(os.path.join(artifact_path, "centroids.pkl"))
+        print("Training completed successfully!")
+    else:
+        raise NotImplementedError("Dataset {} not yet supported".format(args.dataset))
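Before moving on to the serving-side code, the following minimal sketch shows how the two saved artifacts are meant to be consumed together with the downloaded embedding model (the relative paths and the example prompt are assumptions; `detector.py` below is the actual consumer):

```python
import os
import pickle

import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer

artifact_path = "model_artifacts"  # assumed: running from the build/ directory after `make all`

# reducer: the fitted 3-component UMAP model; centroids: one mean vector per MMLU subject
reducer = pickle.load(open(os.path.join(artifact_path, "umap.pkl"), "rb"))
centroids = pd.read_pickle(os.path.join(artifact_path, "centroids.pkl"))
model = SentenceTransformer(os.path.join(artifact_path, "dunzhang", "stella_en_1"), trust_remote_code=True)

# embed a prompt, project it into the reduced space, and rank subjects by centroid distance
embedding = model.encode(["How far away is the Sun from the center of the Milky Way?"], prompt_name="s2p_query")
point = reducer.transform(embedding)[0]
distances = centroids.apply(lambda c: np.linalg.norm(point - c), axis=1).sort_values()
print(distances.head(5))  # closest subjects first; 'astronomy' should rank near the top
```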
diff --git a/detectors/embedding_classification/detector.py b/detectors/embedding_classification/detector.py
new file mode 100644
index 0000000..056b075
--- /dev/null
+++ b/detectors/embedding_classification/detector.py
@@ -0,0 +1,118 @@
+import os
+import sys
+
+import pandas as pd
+
+sys.path.insert(0, os.path.abspath(""))
+# from common.scheme import TextDetectionHttpRequest, TextDetectionResponse
+import os
+import pathlib
+import pickle as pkl
+import numpy as np
+import torch
+import torch.nn
+
+
+from scheme import (
+    ContentAnalysisHttpRequest,
+    ContentAnalysisResponse,
+    ContentsAnalysisResponse,
+    EvidenceObj
+)
+
+# Detector imports
+from sentence_transformers import SentenceTransformer
+
+try:
+    from common.app import logger
+except ImportError:
+    sys.path.insert(0, os.path.join(pathlib.Path(__file__).parent.parent.resolve()))
+    from common.app import logger
+
+
+class Detector:
+    def __init__(self):
+        # initialize the detector
+        model_files_path = os.path.abspath(os.path.join(os.sep, "app", "model_artifacts"))
+        if not os.path.exists(model_files_path):
+            model_files_path = os.path.join("build", "model_artifacts")
+        logger.info(model_files_path)
+
+        self.model = SentenceTransformer(os.path.join(model_files_path, "dunzhang", "stella_en_1"), trust_remote_code=True)
+        self.reducer = pkl.load(open(os.path.join(model_files_path, "umap.pkl"), "rb"))
+        self.centroids = pd.read_pickle(os.path.join(model_files_path, "centroids.pkl"))
+
+        if torch.cuda.is_available():
+            # transparently taking a cuda gpu for an actor
+            self.cuda_device = torch.device("cuda")
+            torch.cuda.empty_cache()
+            self.model.to(self.cuda_device)
+            # self.tokenizer.to(self.cuda_device)
+            # AttributeError: 'RobertaTokenizerFast' object has no attribute 'to'
+            os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
+            logger.info("cuda_device".upper() + " " + str(self.cuda_device))
+            self.batch_size = 1
+        else:
+            self.batch_size = 8
+
+        logger.info("Detector initialized.")
+
+    def get_distance_to_centroids(self, point):
+        dists = self.centroids.apply(lambda x: np.linalg.norm(point - x), 1).sort_values().iloc[:10]
+        return {k: np.round(v / 1, 2) for k, v in dists.to_dict().items()}
+
+    def run(self, request: ContentAnalysisHttpRequest) -> ContentsAnalysisResponse:
+        # run the classification for each entry on contents array
+        # logger.info(tokenizer_parameters)
+        contents_analyses = []
+        for batch_idx in range(0, len(request.contents), self.batch_size):
+            texts = request.contents[batch_idx:batch_idx + self.batch_size]
+            embedding = self.model.encode(texts, prompt_name="s2p_query")
+            umapped = self.reducer.transform(embedding)
+
+            for idx in range(len(umapped)):
+                topics = self.get_distance_to_centroids(umapped[idx])
+
+                logger.debug(topics)
+                matched_topics = {k: v for k, v in topics.items() if v < request.threshold}
+
+                isDetection = False
+                violationDescription = []
+                if request.allowList:
+                    isDetection = not any([allowedTopic in matched_topics.keys() for allowedTopic in request.allowList])
+                    if isDetection:
+                        violationDescription.append("Text matched none of the allowed topics: {}".format(request.allowList))
+
+                if request.blockList:
+                    blockMatches = [blockedTopic for blockedTopic in request.blockList if blockedTopic in matched_topics.keys()]
+                    if blockMatches:
+                        isDetection = True
+                        violationDescription.append("Text matched the following blocked topic(s): {}".format(blockMatches))
+
+                contents_analyses.append(
+                    ContentAnalysisResponse(
+                        start=0,
+                        end=len(texts[idx]),
+                        detection="mmluTopicMatch",
+                        detection_type="mmluTopicMatch",
+                        topics=matched_topics,
+                        violation=isDetection,
+                        violationDescription=violationDescription,
+                        text=texts[idx],
+                        evidences=[],
+                    )
+                )
+        return contents_analyses
+
+
+### local testing
+if __name__ == "__main__":
+    detector = Detector()
+    request = ContentAnalysisHttpRequest(
+        contents=["How far away is the Sun from the center of the Milky Way?", "What is the healthiest vegetable?", "What is the square root of 256?"],
+        allowList=['astronomy'],
+        blockList=['nutrition'],
+    )
+
+    analyses = detector.run(request)
+    print(analyses)
\ No newline at end of file
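To make the allow/block semantics in `run()` concrete, here is a small worked example; the topic distances are invented for illustration:

```python
# nearest-centroid distances for one prompt, checked against the default threshold of 0.75
topics = {"astronomy": 0.41, "college_physics": 0.68, "nutrition": 1.9}
matched_topics = {k: v for k, v in topics.items() if v < 0.75}
# matched_topics == {"astronomy": 0.41, "college_physics": 0.68}

# allowList=["astronomy"]: at least one allowed topic matched, so the allow check passes.
# blockList=["nutrition"]: "nutrition" is not within the threshold, so the block check also passes.
# Had the prompt fallen within 0.75 of the "nutrition" centroid, the block check would flag a
# violation even though the allow check passed, since a block match always sets the detection flag.
```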
diff --git a/detectors/embedding_classification/requirements.txt b/detectors/embedding_classification/requirements.txt
new file mode 100644
index 0000000..8c89656
--- /dev/null
+++ b/detectors/embedding_classification/requirements.txt
@@ -0,0 +1,5 @@
+sentence-transformers==3.3.1
+umap-learn==0.5.7
+torch==2.4.0
+pandas==2.2.2
+numpy==1.26.4
\ No newline at end of file
diff --git a/detectors/embedding_classification/scheme.py b/detectors/embedding_classification/scheme.py
new file mode 100644
index 0000000..7a59126
--- /dev/null
+++ b/detectors/embedding_classification/scheme.py
@@ -0,0 +1,79 @@
+from enum import Enum
+from typing import List, Optional, Dict
+
+from pydantic import BaseModel, Field, RootModel
+
+
+class Evidence(BaseModel):
+    source: str = Field(
+        title="Source",
+        example="https://en.wikipedia.org/wiki/IBM",
+        description="Source of the evidence; it can be the URL of the evidence, etc.",
+    )
+
+
+class EvidenceType(str, Enum):
+    url = "url"
+    title = "title"
+
+
+class EvidenceObj(BaseModel):
+    type: EvidenceType = Field(
+        title="EvidenceType",
+        example="url",
+        description="Type field signifying the type of evidence provided, e.g. url or title.",
+    )
+    evidence: Evidence = Field(
+        description="Evidence object; currently only contains source, but may contain other optional arguments (e.g. id) in the future.",
+    )
+
+
+class ContentAnalysisHttpRequest(BaseModel):
+    contents: List[str] = Field(
+        min_length=1,
+        title="Contents",
+        description="Field allowing users to provide a list of texts for analysis. Note: the results of this endpoint will contain an analysis / detection for each of the provided texts, in the order they are present in the contents object.",
+        example=[
+            "Martians are like crocodiles; the more you give them meat, the more they want"
+        ],
+    )
+    threshold: float = Field(
+        default=0.75,
+        description="Determines the maximum distance between a prompt embedding and a topic centroid for the prompt to still be considered part of that topic. E.g., all prompts within $threshold of the `astronomy` centroid are given the 'astronomy' label.",
+    )
+    allowList: List[str] = Field(
+        description="If an allowList is provided, only prompts that match at least one of the provided topics will be allowed.",
+        example=['anatomy', 'global_facts', 'high_school_mathematics'],
+    )
+    blockList: List[str] = Field(
+        description="If a blockList is provided, only prompts that match none of the provided topics will be allowed.",
+        example=['anatomy', 'global_facts', 'high_school_mathematics'],
+    )
+
+
+class ContentAnalysisResponse(BaseModel):
+    start: int = Field(example=14)
+    end: int = Field(example=26)
+    detection: str = Field(example="mmluTopicMatch")
+    detection_type: str = Field(example="mmluTopicMatch")
+    text: str = Field(example="My favourite dish is pierogi")
+    topics: Dict[str, float] = Field(example={'nutrition': 0.53})
+    violation: bool = Field(example=True)
+    violationDescription: List[str] = Field(
+        description="Description of any possible violations for allowed or blocked topics present in the text.",
+        example=["Text matched none of the allowed topics: ['astronomy', 'anatomy']", "Text matched blocked topic(s): ['nutrition']"])
+    evidences: Optional[List[EvidenceObj]] = Field(
+        description="Optional field providing evidences for the provided detection",
+        default=[],
+    )
+
+
+class ContentsAnalysisResponse(RootModel):
+    root: List[List[ContentAnalysisResponse]] = Field(
+        title="Response Text Content Analysis Unary Handler Api V1 Text Content Post"
+    )
+
+
+class Error(BaseModel):
+    code: int
+    message: str
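For reference, here is a minimal sketch of how these models fit together end to end; the field values below are illustrative rather than taken from a real run:

```python
from scheme import (
    ContentAnalysisHttpRequest,
    ContentAnalysisResponse,
    ContentsAnalysisResponse,
)

request = ContentAnalysisHttpRequest(
    contents=["How far away is the Sun from the center of the Milky Way?"],
    allowList=["astronomy"],
    blockList=["nutrition"],
    threshold=0.75,
)

# shape of a single per-text analysis (the topic distance is invented for illustration)
analysis = ContentAnalysisResponse(
    start=0,
    end=len(request.contents[0]),
    detection="mmluTopicMatch",
    detection_type="mmluTopicMatch",
    text=request.contents[0],
    topics={"astronomy": 0.41},
    violation=False,
    violationDescription=[],
    evidences=[],
)

# the unary endpoint wraps results as one list of analyses per input text
response = ContentsAnalysisResponse(root=[[analysis]])
print(response.model_dump_json(indent=2))
```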