From a86f66a3cdf06e6a02821f68a8ec71886648f0d2 Mon Sep 17 00:00:00 2001 From: Rob Geada Date: Wed, 8 Oct 2025 10:04:14 +0100 Subject: [PATCH] Add custom detectors to the built in detector, with code scanning --- detectors/Dockerfile.builtIn | 2 +- detectors/built_in/app.py | 21 ++-- .../built_in/custom_detectors/__init__.py | 0 .../custom_detectors/custom_detectors.py | 6 + .../built_in/custom_detectors_wrapper.py | 118 ++++++++++++++++++ detectors/built_in/requirements.txt | 3 +- tests/detectors/builtIn/test_custom.py | 99 +++++++++++++++ 7 files changed, 240 insertions(+), 9 deletions(-) create mode 100644 detectors/built_in/custom_detectors/__init__.py create mode 100644 detectors/built_in/custom_detectors/custom_detectors.py create mode 100644 detectors/built_in/custom_detectors_wrapper.py create mode 100644 tests/detectors/builtIn/test_custom.py diff --git a/detectors/Dockerfile.builtIn b/detectors/Dockerfile.builtIn index 939bdeb..d1d502d 100644 --- a/detectors/Dockerfile.builtIn +++ b/detectors/Dockerfile.builtIn @@ -20,7 +20,7 @@ WORKDIR /app ARG CACHEBUST=1 RUN echo "$CACHEBUST" COPY ./common /app/detectors/common -COPY ./built_in/* /app +COPY ./built_in/ /app EXPOSE 8080 diff --git a/detectors/built_in/app.py b/detectors/built_in/app.py index 9ec7fca..73f595d 100644 --- a/detectors/built_in/app.py +++ b/detectors/built_in/app.py @@ -1,17 +1,22 @@ +import logging + from fastapi import HTTPException from contextlib import asynccontextmanager from base_detector_registry import BaseDetectorRegistry from regex_detectors import RegexDetectorRegistry +from custom_detectors_wrapper import CustomDetectorRegistry from file_type_detectors import FileTypeDetectorRegistry from prometheus_fastapi_instrumentator import Instrumentator from detectors.common.scheme import ContentAnalysisHttpRequest, ContentsAnalysisResponse from detectors.common.app import DetectorBaseAPI as FastAPI + @asynccontextmanager async def lifespan(app: FastAPI): app.set_detector(RegexDetectorRegistry(), "regex") app.set_detector(FileTypeDetectorRegistry(), "file_type") + app.set_detector(CustomDetectorRegistry(), "custom") yield app.cleanup_detector() @@ -19,22 +24,24 @@ async def lifespan(app: FastAPI): app = FastAPI(lifespan=lifespan) Instrumentator().instrument(app).expose(app) +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) -# registry : dict[str, BaseDetectorRegistry] = { -# "regex": RegexDetectorRegistry(), -# "file_type": FileTypeDetectorRegistry(), -# } - @app.post("/api/v1/text/contents", response_model=ContentsAnalysisResponse) def detect_content(request: ContentAnalysisHttpRequest): + logger.info(f"Request for {request.detector_params}") + detections = [] for content in request.contents: message_detections = [] - for detector_kind, detector_registry in app.get_all_detectors().items(): + for detector_kind in request.detector_params: + detector_registry = app.get_all_detectors().get(detector_kind) + if detector_registry is None: + raise HTTPException(status_code=400, detail=f"Detector {detector_kind} not found") if not isinstance(detector_registry, BaseDetectorRegistry): raise TypeError(f"Detector {detector_kind} is not a valid BaseDetectorRegistry") - if detector_kind in request.detector_params: + else: try: message_detections += detector_registry.handle_request(content, request.detector_params) except HTTPException as e: diff --git a/detectors/built_in/custom_detectors/__init__.py b/detectors/built_in/custom_detectors/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/detectors/built_in/custom_detectors/custom_detectors.py b/detectors/built_in/custom_detectors/custom_detectors.py new file mode 100644 index 0000000..84844eb --- /dev/null +++ b/detectors/built_in/custom_detectors/custom_detectors.py @@ -0,0 +1,6 @@ + +def over_100_characters(text: str) -> bool: + return len(text)>100 + +def contains_word(text: str) -> bool: + return "apple" in text.lower() diff --git a/detectors/built_in/custom_detectors_wrapper.py b/detectors/built_in/custom_detectors_wrapper.py new file mode 100644 index 0000000..6a00c3e --- /dev/null +++ b/detectors/built_in/custom_detectors_wrapper.py @@ -0,0 +1,118 @@ +import ast +import os + +from fastapi import HTTPException +import inspect +import logging +from typing import List, Optional, Callable + + +from base_detector_registry import BaseDetectorRegistry +from detectors.common.scheme import ContentAnalysisResponse + +logger = logging.getLogger(__name__) + +def custom_func_wrapper(func: Callable, func_name: str, s: str) -> Optional[ContentAnalysisResponse]: + """Convert a some f(text)->bool into a Detector response""" + try: + result = func(s) + except Exception as e: + logging.error(f"Error when computing custom detector function {func_name}: {e}") + raise e + if result: + if isinstance(result, bool): + return ContentAnalysisResponse( + start=0, + end=len(s), + text=s, + detection_type=func_name, + detection=func_name, + score=1.0) + elif isinstance(result, dict): + try: + return ContentAnalysisResponse(**result) + except Exception as e: + logging.error(f"Error when trying to build ContentAnalysisResponse from {func_name} response: {e}") + raise e + else: + msg = f"Unsupported result type for custom detector function {func_name}, must be bool or ContentAnalysisResponse, got: {type(result)}" + logging.error(msg) + raise TypeError(msg) + else: + return None + + +def static_code_analysis(module_path, forbidden_imports=None, forbidden_calls=None): + """ + Perform static code analysis on a Python module to check for forbidden imports and function calls. + Returns a list of issues found. + """ + if forbidden_imports is None: + forbidden_imports = {"os", "subprocess", "sys", "shutil"} + if forbidden_calls is None: + forbidden_calls = {"eval", "exec", "open", "compile", "input"} + + issues = [] + with open(module_path, "r") as f: + source = f.read() + try: + tree = ast.parse(source, filename=module_path) + except Exception as e: + issues.append(f"Failed to parse {module_path}: {e}") + return issues + + for node in ast.walk(tree): + # Check for forbidden imports + if isinstance(node, ast.Import): + for alias in node.names: + if alias.name.split(".")[0] in forbidden_imports: + issues.append(f"- Forbidden import: {alias.name} (line {node.lineno})") + if isinstance(node, ast.ImportFrom): + if node.module and node.module.split(".")[0] in forbidden_imports: + issues.append(f"- Forbidden import: {node.module} (line {node.lineno})") + # Check for forbidden function calls + if isinstance(node, ast.Call): + func_name = "" + if isinstance(node.func, ast.Name): + func_name = node.func.id + elif isinstance(node.func, ast.Attribute): + func_name = f"{getattr(node.func.value, 'id', '')}.{node.func.attr}" + if func_name in forbidden_calls: + issues.append(f"- Forbidden function call: {func_name} (line {node.lineno})") + return issues + + +class CustomDetectorRegistry(BaseDetectorRegistry): + def __init__(self): + super().__init__() + + issues = static_code_analysis(module_path = os.path.join(os.path.dirname(__file__), "custom_detectors", "custom_detectors.py")) + if issues: + logging.error(f"Detected {len(issues)} potential security issues inside the custom_detectors file: {issues}") + raise ImportError(f"Unsafe code detected in custom_detectors:\n" + "\n".join(issues)) + + import custom_detectors.custom_detectors as custom_detectors + + self.registry = {name: obj for name, obj + in inspect.getmembers(custom_detectors, inspect.isfunction) + if not name.startswith("_")} + logger.info(f"Registered the following custom detectors: {self.registry.keys()}") + + def handle_request(self, content: str, detector_params: dict) -> List[ContentAnalysisResponse]: + detections = [] + if "custom" in detector_params and isinstance(detector_params["custom"], (list, str)): + custom_functions = detector_params["custom"] + custom_functions = [custom_functions] if isinstance(custom_functions, str) else custom_functions + for custom_function in custom_functions: + if self.registry.get(custom_function): + try: + result = custom_func_wrapper(self.registry[custom_function], custom_function, content) + if result is not None: + detections.append(result) + except Exception as e: + logger.error(e) + raise HTTPException(status_code=400, detail="Detection error, check detector logs") + else: + raise HTTPException(status_code=400, detail=f"Unrecognized custom function: {custom_function}") + return detections + diff --git a/detectors/built_in/requirements.txt b/detectors/built_in/requirements.txt index 31144c9..985781e 100644 --- a/detectors/built_in/requirements.txt +++ b/detectors/built_in/requirements.txt @@ -1,3 +1,4 @@ markdown==3.8.2 jsonschema==4.24.0 -xmlschema==4.1.0 \ No newline at end of file +xmlschema==4.1.0 +requests==2.32.5 diff --git a/tests/detectors/builtIn/test_custom.py b/tests/detectors/builtIn/test_custom.py new file mode 100644 index 0000000..9d8b8b4 --- /dev/null +++ b/tests/detectors/builtIn/test_custom.py @@ -0,0 +1,99 @@ +import importlib +import sys +from http.client import HTTPException + +import pytest +import os +from fastapi.testclient import TestClient + + +CUSTOM_DETECTORS_PATH = os.path.join( + os.path.dirname(__file__), + "../../../detectors/built_in/custom_detectors/custom_detectors.py" +) + +SAFE_CODE = """ +def over_100_characters(text: str) -> bool: + return len(text)>100 + +def contains_word(text: str) -> bool: + return "apple" in text.lower() +""" + +UNSAFE_CODE = ''' +import os +def evil(text: str) -> bool: + os.system("echo haha gottem") + return True +''' + + +def write_code_to_custom_detectors(code: str): + with open(CUSTOM_DETECTORS_PATH, "w") as f: + f.write(code) + +def restore_safe_code(): + write_code_to_custom_detectors(SAFE_CODE) + + +class TestCustomDetectors: + @pytest.fixture + def client(self): + from detectors.built_in.app import app + from detectors.built_in.custom_detectors_wrapper import CustomDetectorRegistry + app.set_detector(CustomDetectorRegistry(), "custom") + return TestClient(app) + + @pytest.fixture(autouse=True) + def cleanup_custom_detectors(self): + # Always restore safe code after test + yield + restore_safe_code() + + def test_missing_detector_type(self, client): + payload = { + "contents": ["What is an apple?"], + "detector_params": {"custom1": ["contains_word"]} + } + resp = client.post("/api/v1/text/contents", json=payload) + assert resp.status_code == 400 and "Detector custom1 not found" in resp.text + + + def test_custom_detectors(self, client): + payload = { + "contents": ["What is an apple?"], + "detector_params": {"custom": ["contains_word"]} + } + resp = client.post("/api/v1/text/contents", json=payload) + assert resp.status_code == 200 + texts = [d["text"] for d in resp.json()[0]] + assert "What is an apple?" in texts + + def test_custom_detectors_not_match(self, client): + msg = "What is an banana?" + payload = { + "contents": [msg], + "detector_params": {"custom": ["contains_word"]} + } + resp = client.post("/api/v1/text/contents", json=payload) + assert resp.status_code == 200 + texts = [d["text"] for d in resp.json()[0]] + assert msg not in texts + + def test_unsafe_code(self, client): + write_code_to_custom_detectors(UNSAFE_CODE) + from detectors.built_in.custom_detectors_wrapper import CustomDetectorRegistry + with pytest.raises(ImportError) as excinfo: + CustomDetectorRegistry() + assert "Unsafe code detected" in str(excinfo.value) + assert "Forbidden import: os" in str(excinfo.value) or "os.system" in str(excinfo.value) + + + def test_custom_detectors_func_doesnt_exist(self, client): + payload = { + "contents": ["What is an apple?"], + "detector_params": {"custom": ["abc"]} + } + resp = client.post("/api/v1/text/contents", json=payload) + assert resp.status_code == 400 and "Unrecognized custom function: abc" in resp.text +