diff --git a/detectors/Dockerfile.hf b/detectors/Dockerfile.hf
index 440f988..04424f3 100644
--- a/detectors/Dockerfile.hf
+++ b/detectors/Dockerfile.hf
@@ -96,11 +96,12 @@ FROM builder
 WORKDIR /app
 ARG CACHEBUST=1
 RUN echo "$CACHEBUST"
-COPY ./common /common
+COPY ./common /app/detectors/common
+COPY ./huggingface/detector.py /app/detectors/huggingface/
+RUN mkdir /common; cp /app/detectors/common/log_conf.yaml /common/
 COPY ./huggingface/app.py /app
-COPY ./huggingface/detector.py /app
 EXPOSE 8000
-CMD ["uvicorn", "app:app", "--workers", "4", "--host", "0.0.0.0", "--port", "8000", "--log-config", "/common/log_conf.yaml"]
+CMD ["uvicorn", "app:app", "--workers", "1", "--host", "0.0.0.0", "--port", "8000", "--log-config", "/common/log_conf.yaml"]
 # gunicorn main:app --workers 4 --worker-class uvicorn.workers.UvicornWorker --bind 0.0.0.0:8000
\ No newline at end of file
diff --git a/detectors/built_in/base_detector_registry.py b/detectors/built_in/base_detector_registry.py
index 6bc9f24..170e0d7 100644
--- a/detectors/built_in/base_detector_registry.py
+++ b/detectors/built_in/base_detector_registry.py
@@ -1,19 +1,16 @@
-import contextlib
 import logging
 from abc import ABC, abstractmethod
-import time
-from http.client import HTTPException
+from fastapi import HTTPException
 from typing import List
+from detectors.common.instrumented_detector import InstrumentedDetector
 from detectors.common.scheme import ContentAnalysisResponse
 
 
-class BaseDetectorRegistry(ABC):
+class BaseDetectorRegistry(InstrumentedDetector, ABC):
     def __init__(self, registry_name):
+        super().__init__(registry_name)
         self.registry = None
-        self.registry_name = registry_name
-        # prometheus
-        self.instruments = {}
 
     @abstractmethod
     def handle_request(self, content: str, detector_params: dict, headers: dict, **kwargs) -> List[ContentAnalysisResponse]:
@@ -22,35 +19,6 @@ def handle_request(self, content: str, detector_params: dict, headers: dict, **k
     def get_registry(self):
         return self.registry
 
-    def add_instruments(self, gauges):
-        self.instruments = gauges
-
-    def increment_detector_instruments(self, function_name: str, is_detection: bool):
-        """Increment the detection and request counters, automatically update rates"""
-        if self.instruments.get("requests"):
-            self.instruments["requests"].labels(self.registry_name, function_name).inc()
-
-        # The labels() function will initialize the counters if not already created.
-        # This prevents the counters not existing until they are first incremented
-        # If the counters have already been created, this is just a cheap dict.get() call
-        if self.instruments.get("errors"):
-            _ = self.instruments["errors"].labels(self.registry_name, function_name)
-        if self.instruments.get("runtime"):
-            _ = self.instruments["runtime"].labels(self.registry_name, function_name)
-
-        # create and/or increment the detection counter
-        if self.instruments.get("detections"):
-            detection_counter = self.instruments["detections"].labels(self.registry_name, function_name)
-            if is_detection:
-                detection_counter.inc()
-
-
-    def increment_error_instruments(self, function_name: str):
-        """Increment the error counter, update the rate gauges"""
-        if self.instruments.get("errors"):
-            self.instruments["errors"].labels(self.registry_name, function_name).inc()
-
-
     def throw_internal_detector_error(self, function_name: str, logger: logging.Logger, exception: Exception, increment_requests: bool):
         """consistent handling of internal errors within a detection function"""
         if increment_requests and self.instruments.get("requests"):
@@ -60,16 +28,6 @@ def throw_internal_detector_error(self, function_name: str, logger: logging.Logg
         raise HTTPException(status_code=500, detail="Detection error, check detector logs")
 
-    @contextlib.contextmanager
-    def instrument_runtime(self, function_name: str):
-        try:
-            start_time = time.time()
-            yield
-            if self.instruments.get("runtime"):
-                self.instruments["runtime"].labels(self.registry_name, function_name).inc(time.time() - start_time)
-        finally:
-            pass
-
     def get_detection_functions_from_params(self, params: dict):
         """Parse the request parameters to extract and normalize detection functions as iterable list"""
         if self.registry_name in params and isinstance(params[self.registry_name], (list, str)):
diff --git a/detectors/common/app.py b/detectors/common/app.py
index 5387716..5705d41 100644
--- a/detectors/common/app.py
+++ b/detectors/common/app.py
@@ -7,15 +7,12 @@
 import yaml
 from fastapi.exceptions import RequestValidationError
 from fastapi.responses import JSONResponse
-from prometheus_client import Gauge, Counter
-
-sys.path.insert(0, os.path.abspath(".."))
+from prometheus_client import Counter
 
 import logging
 
 from fastapi import FastAPI, status
 from starlette.exceptions import HTTPException as StarletteHTTPException
-from prometheus_fastapi_instrumentator import Instrumentator
 
 logger = logging.getLogger(__name__)
 uvicorn_error_logger = logging.getLogger("uvicorn.error")
@@ -39,22 +36,22 @@ def __init__(self, *args, **kwargs):
         self.state.instruments = {
             "detections": Counter(
                 "trustyai_guardrails_detections",
-                "Number of detections per built-in detector function",
+                "Number of detections per detector function",
                 ["detector_kind", "detector_name"]
             ),
             "requests": Counter(
                 "trustyai_guardrails_requests",
-                "Number of requests per built-in detector function",
+                "Number of requests per detector function",
                 ["detector_kind", "detector_name"]
             ),
             "errors": Counter(
                 "trustyai_guardrails_errors",
-                "Number of errors per built-in detector function",
+                "Number of errors per detector function",
                 ["detector_kind", "detector_name"]
             ),
             "runtime": Counter(
                 "trustyai_guardrails_runtime",
-                "Total runtime of a built-in detector function- this is the induced latency of this guardrail",
+                "Total runtime of a detector function- this is the induced latency of this guardrail",
                 ["detector_kind", "detector_name"]
             )
         }
diff --git a/detectors/common/instrumented_detector.py b/detectors/common/instrumented_detector.py
new file mode 100644
index 0000000..72f0061
--- /dev/null
+++ b/detectors/common/instrumented_detector.py
@@ -0,0 +1,45 @@
+import contextlib
+import time
+
+
+class InstrumentedDetector:
+    def __init__(self, registry_name: str = "default"):
+        self.registry_name = registry_name
+        self.instruments = {}
+
+    @contextlib.contextmanager
+    def instrument_runtime(self, function_name: str):
+        try:
+            start_time = time.time()
+            yield
+            if self.instruments.get("runtime"):
+                self.instruments["runtime"].labels(self.registry_name, function_name).inc(time.time() - start_time)
+        finally:
+            pass
+
+    def add_instruments(self, gauges):
+        self.instruments = gauges
+
+    def increment_detector_instruments(self, function_name: str, is_detection: bool):
+        """Increment the detection and request counters, automatically update rates"""
+        if self.instruments.get("requests"):
+            self.instruments["requests"].labels(self.registry_name, function_name).inc()
+
+        # The labels() function will initialize the counters if not already created.
+        # This prevents the counters not existing until they are first incremented
+        # If the counters have already been created, this is just a cheap dict.get() call
+        if self.instruments.get("errors"):
+            _ = self.instruments["errors"].labels(self.registry_name, function_name)
+        if self.instruments.get("runtime"):
+            _ = self.instruments["runtime"].labels(self.registry_name, function_name)
+
+        # create and/or increment the detection counter
+        if self.instruments.get("detections"):
+            detection_counter = self.instruments["detections"].labels(self.registry_name, function_name)
+            if is_detection:
+                detection_counter.inc()
+
+    def increment_error_instruments(self, function_name: str):
+        """Increment the error counter, update the rate gauges"""
+        if self.instruments.get("errors"):
+            self.instruments["errors"].labels(self.registry_name, function_name).inc()
diff --git a/detectors/common/requirements-dev.txt b/detectors/common/requirements-dev.txt
index 243bd3e..a50e66c 100644
--- a/detectors/common/requirements-dev.txt
+++ b/detectors/common/requirements-dev.txt
@@ -3,3 +3,4 @@ locust==2.31.1
 pre-commit==3.8.0
 pytest==8.3.2
 tls-test-tools
+protobuf==6.33.0
diff --git a/detectors/huggingface/app.py b/detectors/huggingface/app.py
index f65f28a..bc9f0e1 100644
--- a/detectors/huggingface/app.py
+++ b/detectors/huggingface/app.py
@@ -1,15 +1,11 @@
-import os
-import sys
 from contextlib import asynccontextmanager
-from typing import Annotated
+from typing import List
 
-from fastapi import Header
 from prometheus_fastapi_instrumentator import Instrumentator
 
-sys.path.insert(0, os.path.abspath(".."))
-
-from common.app import DetectorBaseAPI as FastAPI
-from detector import Detector
-from common.scheme import (
+from starlette.concurrency import run_in_threadpool
+from detectors.common.app import DetectorBaseAPI as FastAPI
+from detectors.huggingface.detector import Detector
+from detectors.common.scheme import (
     ContentAnalysisHttpRequest,
     ContentsAnalysisResponse,
     Error,
@@ -18,7 +14,9 @@
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    app.set_detector(Detector())
+    detector = Detector()
+    app.set_detector(detector, detector.model_name)
+    detector.add_instruments(app.state.instruments)
     yield
     # Clean up the ML models and release the resources
     detector: Detector = app.get_detector()
@@ -42,10 +40,11 @@ async def lifespan(app: FastAPI):
     },
 )
 async def detector_unary_handler(
-    request: ContentAnalysisHttpRequest,
-    detector_id: Annotated[str, Header(example="en_syntax_slate.38m.hap")],
+    request: ContentAnalysisHttpRequest,
 ):
-    detector: Detector = app.get_detector()
-    if not detector:
+    detectors: List[Detector] = list(app.get_all_detectors().values())
+    if not len(detectors) or not detectors[0]:
         raise RuntimeError("Detector is not initialized")
-    return ContentsAnalysisResponse(root=detector.run(request))
+    result = await run_in_threadpool(detectors[0].run, request)
+    return ContentsAnalysisResponse(root=result)
+
diff --git a/detectors/huggingface/deploy/servingruntime.yaml b/detectors/huggingface/deploy/servingruntime.yaml
index a71c067..b723694 100644
--- a/detectors/huggingface/deploy/servingruntime.yaml
+++ b/detectors/huggingface/deploy/servingruntime.yaml
@@ -9,7 +9,7 @@ metadata:
     opendatahub.io/dashboard: 'true'
 spec:
   annotations:
-    prometheus.io/port: '8080'
+    prometheus.io/port: '8000'
     prometheus.io/path: '/metrics'
   multiModel: false
   supportedModelFormats:
@@ -35,6 +35,10 @@ spec:
           value: /mnt/models
         - name: HF_HOME
          value: /tmp/hf_home
+        - name: DETECTOR_NAME
+          valueFrom:
+            fieldRef:
+              fieldPath: metadata.name
       ports:
         - containerPort: 8000
           protocol: TCP
diff --git a/detectors/huggingface/detector.py b/detectors/huggingface/detector.py
index 44e6ee6..aab4536 100644
--- a/detectors/huggingface/detector.py
+++ b/detectors/huggingface/detector.py
@@ -1,7 +1,7 @@
 import os
-import sys
 
-sys.path.insert(0, os.path.abspath(".."))
+from detectors.common.instrumented_detector import InstrumentedDetector
+
 import json
 import math
 import torch
@@ -11,14 +11,15 @@
     AutoModelForSequenceClassification,
     AutoModelForCausalLM,
 )
-from common.app import logger
-from common.scheme import (
+from detectors.common.app import logger
+from detectors.common.scheme import (
     ContentAnalysisHttpRequest,
     ContentAnalysisResponse,
     ContentsAnalysisResponse,
 )
 import gc
 
+
 def _parse_safe_labels_env():
     if os.environ.get("SAFE_LABELS"):
         try:
@@ -35,7 +36,8 @@ def _parse_safe_labels_env():
         logger.info("SAFE_LABELS env var not set: defaulting to [0].")
     return [0]
 
-class Detector:
+
+class Detector(InstrumentedDetector):
     risk_names = [
         "harm",
         "social_bias",
@@ -50,6 +52,7 @@ def __init__(self):
         """
         Initialize the Detector class by setting up the model, tokenizer, and device.
         """
+        super().__init__()
         self.tokenizer = None
         self.model = None
         self.cuda_device = None
@@ -102,6 +105,18 @@ def initialize_model(self, model_files_path):
         else:
             self.model_name = "unknown"
 
+        self.registry_name = self.model_name
+
+        # set by k8s to be the pod name
+        if os.environ.get("DETECTOR_NAME"):
+            pod_name = os.environ.get("DETECTOR_NAME")
+            if "-predictor" in pod_name:
+                # recover the original ISVC name as specified by the user
+                pod_name = pod_name.split("-predictor")[0]
+            self.function_name = pod_name
+        else:
+            self.function_name = os.path.basename(model_files_path)
+
         logger.info(f"Model type detected: {self.model_name}")
 
     def initialize_device(self):
@@ -173,7 +188,7 @@ def get_probabilities(self, logprobs, safe_token, unsafe_token):
         unsafe_token_prob = 1e-50
         for gen_token_i in logprobs:
             for logprob, index in zip(
-                gen_token_i.values.tolist()[0], gen_token_i.indices.tolist()[0]
+                gen_token_i.values.tolist()[0], gen_token_i.indices.tolist()[0]
             ):
                 decoded_token = self.tokenizer.convert_ids_to_tokens(index)
                 if decoded_token.strip().lower() == safe_token.lower():
@@ -226,7 +241,7 @@ def process_sequence_classification(self, text, detector_params=None, threshold=
             threshold = detector_params.get("threshold", 0.5)
             # Merge safe_labels from env and request
             request_safe_labels = set(detector_params.get("safe_labels", []))
-        all_safe_labels = set(self.safe_labels) | request_safe_labels
+        all_safe_labels = set(self.safe_labels) | request_safe_labels
         content_analyses = []
         tokenized = self.tokenizer(
             text,
@@ -245,9 +260,9 @@ def process_sequence_classification(self, text, detector_params=None, threshold=
             label = self.model.config.id2label[idx]
             # Exclude by index or label name
             if (
-                prob >= threshold
-                and idx not in all_safe_labels
-                and label not in all_safe_labels
+                prob >= threshold
+                and idx not in all_safe_labels
+                and label not in all_safe_labels
             ):
                 detection_value = getattr(self.model.config, "problem_type", None)
                 content_analyses.append(
@@ -274,16 +289,19 @@ def run(self, input: ContentAnalysisHttpRequest) -> ContentsAnalysisResponse:
             ContentsAnalysisResponse: The aggregated response for all input texts.
         """
         contents_analyses = []
-        for text in input.contents:
-            if self.is_causal_lm:
-                analyses = self.process_causal_lm(text)
-            elif self.is_sequence_classifier:
-                analyses = self.process_sequence_classification(
-                    text, detector_params=getattr(input, "detector_params", None)
-                )
-            else:
-                raise ValueError("Unsupported model type for analysis.")
-            contents_analyses.append(analyses)
+        with self.instrument_runtime(self.function_name):
+            for text in input.contents:
+                if self.is_causal_lm:
+                    analyses = self.process_causal_lm(text)
+                elif self.is_sequence_classifier:
+                    analyses = self.process_sequence_classification(
+                        text, detector_params=getattr(input, "detector_params", None)
+                    )
+                else:
+                    raise ValueError("Unsupported model type for analysis.")
+                contents_analyses.append(analyses)
+            is_detection = any(len(analyses) > 0 for analyses in contents_analyses)
+            self.increment_detector_instruments(self.function_name, is_detection=is_detection)
         return contents_analyses
 
     def close(self) -> None:
@@ -299,4 +317,4 @@ def close(self) -> None:
         gc.collect()
 
         if torch.cuda.is_available():
-            torch.cuda.empty_cache()
\ No newline at end of file
+            torch.cuda.empty_cache()
diff --git a/detectors/huggingface/locustfile.py b/detectors/huggingface/locustfile.py
deleted file mode 100644
index 011e62d..0000000
--- a/detectors/huggingface/locustfile.py
+++ /dev/null
@@ -1,35 +0,0 @@
-"""
-Content Warning: Contains potentially offensive text dealing with racism, misogyny, and violence. Examples of input prompts provided purely for the purposes of testing HAP (Hate, Abuse and Profanity) models.
-"""
-
-from locust import HttpUser, between, task
-
-
-class WebsiteUser(HttpUser):
-    wait_time = between(1, 5)
-
-    # def on_start(self):
-    #     self.client.post("/login", {
-    #         "username": "test_user",
-    #         "password": ""
-    #     })
-
-    @task
-    def docs(self):
-        self.client.get("/docs")
-
-    @task
-    def api(self):
-        self.client.get("/openapi.json")
-
-    @task
-    def pii(self):
-        self.client.post(
-            "/api/v1/text/contents?pii_transformer",
-            json={
-                "contents": [
-                    "My name is John Doe and my social security number is 123-45-6789."
-                ]
-            },
-            headers={"detector-id": "pii", "Content-Type": "application/json"},
-        )
diff --git a/detectors/huggingface/requirements.txt b/detectors/huggingface/requirements.txt
index 20d5940..9f929de 100644
--- a/detectors/huggingface/requirements.txt
+++ b/detectors/huggingface/requirements.txt
@@ -1 +1 @@
-transformers==4.50.0
+transformers==4.50.0
\ No newline at end of file
diff --git a/tests/detectors/builtIn/test_filetype.py b/tests/detectors/builtIn/test_filetype.py
index f9cafa7..f65785e 100644
--- a/tests/detectors/builtIn/test_filetype.py
+++ b/tests/detectors/builtIn/test_filetype.py
@@ -55,7 +55,6 @@ def test_detect_content_invalid_json(self, client: TestClient):
             "detector_params": {"file_type": ["json"]}
         }
         resp = client.post("/api/v1/text/contents", json=payload)
-        print(resp.content)
         assert resp.status_code == 200
         detections = resp.json()[0]
         assert detections[0]["detection"] == "invalid_json"
diff --git a/tests/detectors/huggingface/test_metrics.py b/tests/detectors/huggingface/test_metrics.py
new file mode 100644
index 0000000..f3b4885
--- /dev/null
+++ b/tests/detectors/huggingface/test_metrics.py
@@ -0,0 +1,104 @@
+import random
+import os
+import time
+from collections import namedtuple
+from unittest import mock
+from unittest.mock import Mock, MagicMock
+
+import pytest
+import torch
+from starlette.testclient import TestClient
+
+from detectors.huggingface.detector import Detector
+from detectors.huggingface.app import app
+
+
+def send_request(client: TestClient, detect: bool, slow: bool = False):
+    payload = {
+        "contents": ["this message is too long and should induce a detection from the model" if detect else "fine"],
+        "detector_params": {"regex": []}
+    }
+    if slow:
+        payload["contents"][0] = " ".join(payload["contents"]*1000)
+
+    expected_status_code = 200
+    response = client.post("/api/v1/text/contents", json=payload)
+    if response.status_code != expected_status_code:
+        print(response.text)
+    assert response.status_code == expected_status_code
+
+
+def get_metric_dict(client: TestClient):
+    metrics = client.get("/metrics")
+    metrics = metrics.content.decode().split("\n")
+    metric_dict = {}
+
+    for m in metrics:
+        if "trustyai" in m and "{" in m:
+            key, value = m.split(" ")
+            metric_dict[key] = float(value)
+
+    return metric_dict
+
+
+class TestMetrics:
+    @pytest.fixture
+    def client(self):
+        current_dir = os.path.dirname(__file__)
+        parent_dir = os.path.dirname(os.path.dirname(current_dir))
+        os.environ["MODEL_DIR"] = os.path.join(parent_dir, "dummy_models", "bert/BertForSequenceClassification")
+
+        detector = Detector()
+
+        # patch the model to allow for control over detections - long messages will flag
+        def detection_fn(*args, **kwargs):
+            output = Mock()
+            if kwargs["input_ids"].shape[-1] > 10:
+                output.logits = torch.tensor([[0.0, 1.0]])
+            else:
+                output.logits = torch.tensor([[1.0, 0.0]])
+
+            if kwargs["input_ids"].shape[-1] > 100:
+                time.sleep(.25)
+            return output
+
+        class ModelMock:
+            def __init__(self):
+                self.config = Mock()
+                self.config.id2label = detector.model.config.id2label
+                self.config.problem_type = detector.model.config.problem_type
+            def __call__(self, *args, **kwargs):
+                return detection_fn(*args, **kwargs)
+
+        detector.model = ModelMock()
+        app.set_detector(detector, detector.registry_name)
+        detector.add_instruments(app.state.instruments)
+        return TestClient(app)
+
+
+    def test_prometheus(self, client: TestClient):
+        for i in range(20):
+            send_request(client=client, detect=i%3==0)
+
+        expected_results = {
+            'trustyai_guardrails_detections_total{detector_kind="sequence_classifier",detector_name="BertForSequenceClassification"}': 7.0,
+            'trustyai_guardrails_errors_total{detector_kind="sequence_classifier",detector_name="BertForSequenceClassification"}': 0.0,
+            'trustyai_guardrails_requests_total{detector_kind="sequence_classifier",detector_name="BertForSequenceClassification"}': 20.0,
+        }
+
+        metric_dict = get_metric_dict(client)
+
+        for expected_key, expected_val in expected_results.items():
+            assert expected_key in metric_dict, f"expected key {expected_key} not found in metric dict"
+            assert metric_dict[expected_key] == expected_val, f"metric {expected_key} value={metric_dict[expected_key]} did not match expected value {expected_val}"
+
+
+    def test_runtime_metrics(self, client: TestClient):
+        # 8 calls of this function should induce ~ 2 seconds of latency
+        for _ in range(8):
+            send_request(client=client, detect=False, slow=True)
+        metric_dict = get_metric_dict(client)
+
+        func_runtime = metric_dict['trustyai_guardrails_runtime_total{detector_kind="sequence_classifier",detector_name="BertForSequenceClassification"}']
+        assert func_runtime > 1.8
+        assert func_runtime < 2.2
\ No newline at end of file
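
For reference, the patch above moves the Prometheus helpers out of BaseDetectorRegistry into the reusable InstrumentedDetector mixin, which both the built-in registries and the HuggingFace Detector now inherit. The following is a minimal sketch (not part of the patch) of how another detector could adopt the mixin, assuming the detectors.common.instrumented_detector module introduced here; KeywordDetector, the "keyword" registry name, and the "contains_forbidden" function name are hypothetical illustrations.

from detectors.common.instrumented_detector import InstrumentedDetector


class KeywordDetector(InstrumentedDetector):
    def __init__(self):
        # registry_name becomes the "detector_kind" label on the shared counters
        super().__init__(registry_name="keyword")

    def run(self, contents):
        results = []
        # instrument_runtime adds wall-clock time to the "runtime" counter when
        # instruments are attached; with the default empty dict it is a no-op
        with self.instrument_runtime("contains_forbidden"):
            for text in contents:
                results.append("forbidden" in text.lower())
        # bumps the "requests" counter and, if anything flagged, "detections"
        self.increment_detector_instruments("contains_forbidden", is_detection=any(results))
        return results


# At startup the shared counters from DetectorBaseAPI would be attached,
# mirroring the lifespan hook in detectors/huggingface/app.py:
#     detector = KeywordDetector()
#     detector.add_instruments(app.state.instruments)

Because self.instruments defaults to an empty dict, the same detection code runs with or without metrics wired in; attaching them is a single add_instruments() call at startup, as the HuggingFace lifespan hook and the test fixture above both do.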