7 changes: 4 additions & 3 deletions detectors/Dockerfile.hf
@@ -96,11 +96,12 @@ FROM builder
WORKDIR /app
ARG CACHEBUST=1
RUN echo "$CACHEBUST"
COPY ./common /common
COPY ./common /app/detectors/common
COPY ./huggingface/detector.py /app/detectors/huggingface/
RUN mkdir /common; cp /app/detectors/common/log_conf.yaml /common/
COPY ./huggingface/app.py /app
COPY ./huggingface/detector.py /app

EXPOSE 8000
CMD ["uvicorn", "app:app", "--workers", "4", "--host", "0.0.0.0", "--port", "8000", "--log-config", "/common/log_conf.yaml"]
CMD ["uvicorn", "app:app", "--workers", "1", "--host", "0.0.0.0", "--port", "8000", "--log-config", "/common/log_conf.yaml"]

# gunicorn main:app --workers 4 --worker-class uvicorn.workers.UvicornWorker --bind 0.0.0.0:8000
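A plausible reading of the `--workers 4` to `--workers 1` change is that it keeps the in-process Prometheus counters introduced elsewhere in this PR coherent: with several uvicorn workers, each process holds its own default registry, so successive `/metrics` scrapes would hit different workers and return different counts. A minimal sketch of the alternative, prometheus_client's multiprocess mode, assuming a writable `PROMETHEUS_MULTIPROC_DIR` (this is not what the PR does):

```python
# Sketch only: aggregating prometheus_client metrics across several uvicorn
# workers via multiprocess mode. The PR sidesteps this by running one worker.
# The env var must be set before any metric objects are created.
import os

os.environ.setdefault("PROMETHEUS_MULTIPROC_DIR", "/tmp/prom_metrics")
os.makedirs(os.environ["PROMETHEUS_MULTIPROC_DIR"], exist_ok=True)

from prometheus_client import CollectorRegistry, generate_latest, multiprocess


def metrics_payload() -> bytes:
    # Merge the per-process metric files into one registry and render them
    # in the Prometheus text exposition format.
    registry = CollectorRegistry()
    multiprocess.MultiProcessCollector(registry)
    return generate_latest(registry)


print(metrics_payload().decode())
```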
50 changes: 4 additions & 46 deletions detectors/built_in/base_detector_registry.py
@@ -1,19 +1,16 @@
import contextlib
import logging
from abc import ABC, abstractmethod
import time
from http.client import HTTPException
from fastapi import HTTPException
from typing import List

from detectors.common.instrumented_detector import InstrumentedDetector
from detectors.common.scheme import ContentAnalysisResponse

class BaseDetectorRegistry(ABC):
class BaseDetectorRegistry(InstrumentedDetector, ABC):
def __init__(self, registry_name):
super().__init__(registry_name)
self.registry = None
self.registry_name = registry_name

# prometheus
self.instruments = {}

@abstractmethod
def handle_request(self, content: str, detector_params: dict, headers: dict, **kwargs) -> List[ContentAnalysisResponse]:
@@ -22,35 +19,6 @@ def handle_request(self, content: str, detector_params: dict, headers: dict, **k
def get_registry(self):
return self.registry

def add_instruments(self, gauges):
self.instruments = gauges

def increment_detector_instruments(self, function_name: str, is_detection: bool):
"""Increment the detection and request counters, automatically update rates"""
if self.instruments.get("requests"):
self.instruments["requests"].labels(self.registry_name, function_name).inc()

# The labels() function will initialize the counters if not already created.
# This prevents the counters not existing until they are first incremented
# If the counters have already been created, this is just a cheap dict.get() call
if self.instruments.get("errors"):
_ = self.instruments["errors"].labels(self.registry_name, function_name)
if self.instruments.get("runtime"):
_ = self.instruments["runtime"].labels(self.registry_name, function_name)

# create and/or increment the detection counter
if self.instruments.get("detections"):
detection_counter = self.instruments["detections"].labels(self.registry_name, function_name)
if is_detection:
detection_counter.inc()


def increment_error_instruments(self, function_name: str):
"""Increment the error counter, update the rate gauges"""
if self.instruments.get("errors"):
self.instruments["errors"].labels(self.registry_name, function_name).inc()


def throw_internal_detector_error(self, function_name: str, logger: logging.Logger, exception: Exception, increment_requests: bool):
"""consistent handling of internal errors within a detection function"""
if increment_requests and self.instruments.get("requests"):
@@ -60,16 +28,6 @@ def throw_internal_detector_error(self, function_name: str, logger: logging.Logg
raise HTTPException(status_code=500, detail="Detection error, check detector logs")


@contextlib.contextmanager
def instrument_runtime(self, function_name: str):
try:
start_time = time.time()
yield
if self.instruments.get("runtime"):
self.instruments["runtime"].labels(self.registry_name, function_name).inc(time.time() - start_time)
finally:
pass

def get_detection_functions_from_params(self, params: dict):
"""Parse the request parameters to extract and normalize detection functions as iterable list"""
if self.registry_name in params and isinstance(params[self.registry_name], (list, str)):
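For orientation, here is a hypothetical concrete registry written against the refactored base class. The class name, detection function, and `ContentAnalysisResponse` field names are illustrative assumptions, not code from this repository:

```python
# Hypothetical subclass sketch (not part of the PR). Response field names
# are assumptions about detectors.common.scheme.ContentAnalysisResponse.
from typing import List

from detectors.built_in.base_detector_registry import BaseDetectorRegistry
from detectors.common.scheme import ContentAnalysisResponse


class ExampleRegistry(BaseDetectorRegistry):
    def __init__(self):
        # registry_name becomes the "detector_kind" Prometheus label
        # via InstrumentedDetector.__init__.
        super().__init__("example")
        self.registry = {"shout": lambda text: text.isupper()}

    def handle_request(self, content: str, detector_params: dict, headers: dict, **kwargs) -> List[ContentAnalysisResponse]:
        analyses = []
        for name in self.get_detection_functions_from_params(detector_params):
            with self.instrument_runtime(name):  # inherited timing counter
                hit = self.registry[name](content)
            self.increment_detector_instruments(name, is_detection=hit)
            if hit:
                analyses.append(ContentAnalysisResponse(
                    start=0, end=len(content), text=content, detection=name,
                    detection_type="example", score=1.0,
                ))
        return analyses
```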
13 changes: 5 additions & 8 deletions detectors/common/app.py
@@ -7,15 +7,12 @@
import yaml
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from prometheus_client import Gauge, Counter

sys.path.insert(0, os.path.abspath(".."))
from prometheus_client import Counter

import logging

from fastapi import FastAPI, status
from starlette.exceptions import HTTPException as StarletteHTTPException
from prometheus_fastapi_instrumentator import Instrumentator

logger = logging.getLogger(__name__)
uvicorn_error_logger = logging.getLogger("uvicorn.error")
@@ -39,22 +36,22 @@ def __init__(self, *args, **kwargs):
self.state.instruments = {
"detections": Counter(
"trustyai_guardrails_detections",
"Number of detections per built-in detector function",
"Number of detections per detector function",
["detector_kind", "detector_name"]
),
"requests": Counter(
"trustyai_guardrails_requests",
"Number of requests per built-in detector function",
"Number of requests per detector function",
["detector_kind", "detector_name"]
),
"errors": Counter(
"trustyai_guardrails_errors",
"Number of errors per built-in detector function",
"Number of errors per detector function",
["detector_kind", "detector_name"]
),
"runtime": Counter(
"trustyai_guardrails_runtime",
"Total runtime of a built-in detector function- this is the induced latency of this guardrail",
"Total runtime of a detector function- this is the induced latency of this guardrail",
["detector_kind", "detector_name"]
)
}
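All four metrics are plain `Counter`s keyed by `(detector_kind, detector_name)`. As a standalone illustration of how labelled children behave (this mirrors, but is not, the PR code):

```python
# Standalone illustration of the labelled-Counter pattern used above.
from prometheus_client import Counter

requests = Counter(
    "trustyai_guardrails_requests",
    "Number of requests per detector function",
    ["detector_kind", "detector_name"],
)

# .labels() lazily creates (or looks up) the child series for one label
# combination; calling it without .inc() pre-registers the series at zero.
# increment_detector_instruments relies on exactly this so that the error
# and runtime series exist before their first increment.
requests.labels("regex", "email_address")        # series now exported as 0
requests.labels("regex", "email_address").inc()  # now 1
```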
45 changes: 45 additions & 0 deletions detectors/common/instrumented_detector.py
@@ -0,0 +1,45 @@
import contextlib
import time


class InstrumentedDetector:
def __init__(self, registry_name: str = "default"):
self.registry_name = registry_name
self.instruments = {}

@contextlib.contextmanager
def instrument_runtime(self, function_name: str):
try:
start_time = time.time()
yield
if self.instruments.get("runtime"):
self.instruments["runtime"].labels(self.registry_name, function_name).inc(time.time() - start_time)
finally:
pass

def add_instruments(self, gauges):
self.instruments = gauges

def increment_detector_instruments(self, function_name: str, is_detection: bool):
"""Increment the detection and request counters, automatically update rates"""
if self.instruments.get("requests"):
self.instruments["requests"].labels(self.registry_name, function_name).inc()

# The labels() function will initialize the counters if not already created.
# This prevents the counters not existing until they are first incremented
# If the counters have already been created, this is just a cheap dict.get() call
if self.instruments.get("errors"):
_ = self.instruments["errors"].labels(self.registry_name, function_name)
if self.instruments.get("runtime"):
_ = self.instruments["runtime"].labels(self.registry_name, function_name)

# create and/or increment the detection counter
if self.instruments.get("detections"):
detection_counter = self.instruments["detections"].labels(self.registry_name, function_name)
if is_detection:
detection_counter.inc()

def increment_error_instruments(self, function_name: str):
"""Increment the error counter, update the rate gauges"""
if self.instruments.get("errors"):
self.instruments["errors"].labels(self.registry_name, function_name).inc()
1 change: 1 addition & 0 deletions detectors/common/requirements-dev.txt
@@ -3,3 +3,4 @@ locust==2.31.1
pre-commit==3.8.0
pytest==8.3.2
tls-test-tools
protobuf==6.33.0
29 changes: 14 additions & 15 deletions detectors/huggingface/app.py
@@ -1,15 +1,11 @@
import os
import sys
from contextlib import asynccontextmanager
from typing import Annotated
from typing import List

from fastapi import Header
from prometheus_fastapi_instrumentator import Instrumentator
sys.path.insert(0, os.path.abspath(".."))

from common.app import DetectorBaseAPI as FastAPI
from detector import Detector
from common.scheme import (
from starlette.concurrency import run_in_threadpool
from detectors.common.app import DetectorBaseAPI as FastAPI
from detectors.huggingface.detector import Detector
from detectors.common.scheme import (
ContentAnalysisHttpRequest,
ContentsAnalysisResponse,
Error,
@@ -18,7 +14,9 @@

@asynccontextmanager
async def lifespan(app: FastAPI):
app.set_detector(Detector())
detector = Detector()
app.set_detector(detector, detector.model_name)
detector.add_instruments(app.state.instruments)
yield
# Clean up the ML models and release the resources
detector: Detector = app.get_detector()
@@ -42,10 +40,11 @@ async def lifespan(app: FastAPI):
},
)
async def detector_unary_handler(
request: ContentAnalysisHttpRequest,
detector_id: Annotated[str, Header(example="en_syntax_slate.38m.hap")],
request: ContentAnalysisHttpRequest,
):
detector: Detector = app.get_detector()
if not detector:
detectors: List[Detector] = list(app.get_all_detectors().values())
if not len(detectors) or not detectors[0]:
raise RuntimeError("Detector is not initialized")
return ContentsAnalysisResponse(root=detector.run(request))
result = await run_in_threadpool(detectors[0].run, request)
return ContentsAnalysisResponse(root=result)
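Model inference is synchronous and CPU/GPU-bound, so the handler now hands `detector.run` to a worker thread; with the image down to a single uvicorn worker, a blocking call here would otherwise stall every in-flight request. A self-contained sketch of the pattern (function names are illustrative):

```python
# Standalone sketch of the run_in_threadpool pattern adopted above.
import asyncio
import time

from starlette.concurrency import run_in_threadpool


def blocking_inference(text: str) -> str:
    time.sleep(0.5)  # stands in for tokenize + model forward pass
    return text.upper()


async def handler(text: str) -> str:
    # Offload to Starlette's threadpool so the event loop keeps serving
    # other requests while the model call blocks a worker thread.
    return await run_in_threadpool(blocking_inference, text)


print(asyncio.run(handler("hello")))  # HELLO
```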

6 changes: 5 additions & 1 deletion detectors/huggingface/deploy/servingruntime.yaml
@@ -9,7 +9,7 @@ metadata:
opendatahub.io/dashboard: 'true'
spec:
annotations:
prometheus.io/port: '8080'
prometheus.io/port: '8000'
prometheus.io/path: '/metrics'
multiModel: false
supportedModelFormats:
@@ -35,6 +35,10 @@ spec:
value: /mnt/models
- name: HF_HOME
value: /tmp/hf_home
- name: DETECTOR_NAME
valueFrom:
fieldRef:
fieldPath: metadata.name
ports:
- containerPort: 8000
protocol: TCP
60 changes: 39 additions & 21 deletions detectors/huggingface/detector.py
@@ -1,7 +1,7 @@
import os
import sys

sys.path.insert(0, os.path.abspath(".."))
from detectors.common.instrumented_detector import InstrumentedDetector

import json
import math
import torch
@@ -11,14 +11,15 @@
AutoModelForSequenceClassification,
AutoModelForCausalLM,
)
from common.app import logger
from common.scheme import (
from detectors.common.app import logger
from detectors.common.scheme import (
ContentAnalysisHttpRequest,
ContentAnalysisResponse,
ContentsAnalysisResponse,
)
import gc


def _parse_safe_labels_env():
if os.environ.get("SAFE_LABELS"):
try:
@@ -35,7 +36,8 @@ def _parse_safe_labels_env():
logger.info("SAFE_LABELS env var not set: defaulting to [0].")
return [0]

class Detector:

class Detector(InstrumentedDetector):
risk_names = [
"harm",
"social_bias",
@@ -50,6 +52,7 @@ def __init__(self):
"""
Initialize the Detector class by setting up the model, tokenizer, and device.
"""
super().__init__()
self.tokenizer = None
self.model = None
self.cuda_device = None
@@ -102,6 +105,18 @@ def initialize_model(self, model_files_path):
else:
self.model_name = "unknown"

self.registry_name = self.model_name

# set by k8s to be the pod name
if os.environ.get("DETECTOR_NAME"):
pod_name = os.environ.get("DETECTOR_NAME")
if "-predictor" in pod_name:
# recover the original ISVC name as specified by the user
pod_name = pod_name.split("-predictor")[0]
self.function_name = pod_name
else:
self.function_name = os.path.basename(model_files_path)

logger.info(f"Model type detected: {self.model_name}")

def initialize_device(self):
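KServe creates the pods of an InferenceService named `foo` as `foo-predictor-<suffix>`, so splitting on `-predictor` recovers the user-facing name injected via `DETECTOR_NAME`. A tiny standalone check of that logic (illustrative, not repository code):

```python
# Standalone check of the "-predictor" stripping logic above (illustrative).
def isvc_name_from_pod(pod_name: str) -> str:
    # KServe names predictor pods "<isvc>-predictor-<hash>...", so the text
    # before "-predictor" is the InferenceService name the user chose.
    if "-predictor" in pod_name:
        return pod_name.split("-predictor")[0]
    return pod_name


assert isvc_name_from_pod("hap-detector-predictor-7c9f8-abc12") == "hap-detector"
assert isvc_name_from_pod("local-model-dir") == "local-model-dir"
```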
@@ -173,7 +188,7 @@ def get_probabilities(self, logprobs, safe_token, unsafe_token):
unsafe_token_prob = 1e-50
for gen_token_i in logprobs:
for logprob, index in zip(
gen_token_i.values.tolist()[0], gen_token_i.indices.tolist()[0]
gen_token_i.values.tolist()[0], gen_token_i.indices.tolist()[0]
):
decoded_token = self.tokenizer.convert_ids_to_tokens(index)
if decoded_token.strip().lower() == safe_token.lower():
@@ -226,7 +241,7 @@ def process_sequence_classification(self, text, detector_params=None, threshold=
threshold = detector_params.get("threshold", 0.5)
# Merge safe_labels from env and request
request_safe_labels = set(detector_params.get("safe_labels", []))
all_safe_labels = set(self.safe_labels) | request_safe_labels
all_safe_labels = set(self.safe_labels) | request_safe_labels
content_analyses = []
tokenized = self.tokenizer(
text,
@@ -245,9 +260,9 @@ def process_sequence_classification(self, text, detector_params=None, threshold=
label = self.model.config.id2label[idx]
# Exclude by index or label name
if (
prob >= threshold
and idx not in all_safe_labels
and label not in all_safe_labels
prob >= threshold
and idx not in all_safe_labels
and label not in all_safe_labels
):
detection_value = getattr(self.model.config, "problem_type", None)
content_analyses.append(
Expand All @@ -274,16 +289,19 @@ def run(self, input: ContentAnalysisHttpRequest) -> ContentsAnalysisResponse:
ContentsAnalysisResponse: The aggregated response for all input texts.
"""
contents_analyses = []
for text in input.contents:
if self.is_causal_lm:
analyses = self.process_causal_lm(text)
elif self.is_sequence_classifier:
analyses = self.process_sequence_classification(
text, detector_params=getattr(input, "detector_params", None)
)
else:
raise ValueError("Unsupported model type for analysis.")
contents_analyses.append(analyses)
with self.instrument_runtime(self.function_name):
for text in input.contents:
if self.is_causal_lm:
analyses = self.process_causal_lm(text)
elif self.is_sequence_classifier:
analyses = self.process_sequence_classification(
text, detector_params=getattr(input, "detector_params", None)
)
else:
raise ValueError("Unsupported model type for analysis.")
contents_analyses.append(analyses)
is_detection = any(len(analyses) > 0 for analyses in contents_analyses)
self.increment_detector_instruments(self.function_name, is_detection=is_detection)
return contents_analyses

def close(self) -> None:
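Two metric subtleties of the instrumented `run()`: the whole batch is timed under one `function_name`, and `is_detection` is computed across all contents, so the detections counter advances at most once per request rather than once per flagged string. Mean induced latency then falls out of the two counters; a back-of-envelope example with hypothetical samples:

```python
# Hypothetical counter samples (illustrative arithmetic only). In PromQL the
# equivalent would be rate(trustyai_guardrails_runtime[5m])
#                   / rate(trustyai_guardrails_requests[5m]).
runtime_total = 12.8    # seconds accumulated by the runtime counter
requests_total = 640    # requests seen by the requests counter

mean_induced_latency = runtime_total / requests_total
print(f"{mean_induced_latency * 1000:.1f} ms per request")  # 20.0 ms
```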
@@ -299,4 +317,4 @@ def close(self) -> None:
gc.collect()

if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.empty_cache()