Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion detectors/Dockerfile.builtIn
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ WORKDIR /app
ARG CACHEBUST=1
RUN echo "$CACHEBUST"
COPY ./common /app/detectors/common
COPY ./built_in/* /app
COPY ./built_in/ /app

EXPOSE 8080

Expand Down
21 changes: 14 additions & 7 deletions detectors/built_in/app.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,47 @@
import logging

from fastapi import HTTPException
from contextlib import asynccontextmanager
from base_detector_registry import BaseDetectorRegistry
from regex_detectors import RegexDetectorRegistry
from custom_detectors_wrapper import CustomDetectorRegistry
from file_type_detectors import FileTypeDetectorRegistry

from prometheus_fastapi_instrumentator import Instrumentator
from detectors.common.scheme import ContentAnalysisHttpRequest, ContentsAnalysisResponse
from detectors.common.app import DetectorBaseAPI as FastAPI


@asynccontextmanager
async def lifespan(app: FastAPI):
app.set_detector(RegexDetectorRegistry(), "regex")
app.set_detector(FileTypeDetectorRegistry(), "file_type")
app.set_detector(CustomDetectorRegistry(), "custom")
yield

app.cleanup_detector()


app = FastAPI(lifespan=lifespan)
Instrumentator().instrument(app).expose(app)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# registry : dict[str, BaseDetectorRegistry] = {
# "regex": RegexDetectorRegistry(),
# "file_type": FileTypeDetectorRegistry(),
# }

@app.post("/api/v1/text/contents", response_model=ContentsAnalysisResponse)
def detect_content(request: ContentAnalysisHttpRequest):
logger.info(f"Request for {request.detector_params}")

detections = []
for content in request.contents:
message_detections = []
for detector_kind, detector_registry in app.get_all_detectors().items():
for detector_kind in request.detector_params:
detector_registry = app.get_all_detectors().get(detector_kind)
if detector_registry is None:
raise HTTPException(status_code=400, detail=f"Detector {detector_kind} not found")
if not isinstance(detector_registry, BaseDetectorRegistry):
raise TypeError(f"Detector {detector_kind} is not a valid BaseDetectorRegistry")
if detector_kind in request.detector_params:
else:
try:
message_detections += detector_registry.handle_request(content, request.detector_params)
except HTTPException as e:
Expand Down
Empty file.
6 changes: 6 additions & 0 deletions detectors/built_in/custom_detectors/custom_detectors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@

def over_100_characters(text: str) -> bool:
return len(text)>100

def contains_word(text: str) -> bool:
return "apple" in text.lower()
118 changes: 118 additions & 0 deletions detectors/built_in/custom_detectors_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import ast
import os

from fastapi import HTTPException
import inspect
import logging
from typing import List, Optional, Callable


from base_detector_registry import BaseDetectorRegistry
from detectors.common.scheme import ContentAnalysisResponse

logger = logging.getLogger(__name__)

def custom_func_wrapper(func: Callable, func_name: str, s: str) -> Optional[ContentAnalysisResponse]:
"""Convert a some f(text)->bool into a Detector response"""
try:
result = func(s)
except Exception as e:
logging.error(f"Error when computing custom detector function {func_name}: {e}")
raise e
if result:
if isinstance(result, bool):
return ContentAnalysisResponse(
start=0,
end=len(s),
text=s,
detection_type=func_name,
detection=func_name,
score=1.0)
elif isinstance(result, dict):
try:
return ContentAnalysisResponse(**result)
except Exception as e:
logging.error(f"Error when trying to build ContentAnalysisResponse from {func_name} response: {e}")
raise e
else:
msg = f"Unsupported result type for custom detector function {func_name}, must be bool or ContentAnalysisResponse, got: {type(result)}"
logging.error(msg)
raise TypeError(msg)
else:
return None


def static_code_analysis(module_path, forbidden_imports=None, forbidden_calls=None):
"""
Perform static code analysis on a Python module to check for forbidden imports and function calls.
Returns a list of issues found.
"""
if forbidden_imports is None:
forbidden_imports = {"os", "subprocess", "sys", "shutil"}
if forbidden_calls is None:
forbidden_calls = {"eval", "exec", "open", "compile", "input"}

issues = []
with open(module_path, "r") as f:
source = f.read()
try:
tree = ast.parse(source, filename=module_path)
except Exception as e:
issues.append(f"Failed to parse {module_path}: {e}")
return issues

for node in ast.walk(tree):
# Check for forbidden imports
if isinstance(node, ast.Import):
for alias in node.names:
if alias.name.split(".")[0] in forbidden_imports:
issues.append(f"- Forbidden import: {alias.name} (line {node.lineno})")
if isinstance(node, ast.ImportFrom):
if node.module and node.module.split(".")[0] in forbidden_imports:
issues.append(f"- Forbidden import: {node.module} (line {node.lineno})")
# Check for forbidden function calls
if isinstance(node, ast.Call):
func_name = ""
if isinstance(node.func, ast.Name):
func_name = node.func.id
elif isinstance(node.func, ast.Attribute):
func_name = f"{getattr(node.func.value, 'id', '')}.{node.func.attr}"
if func_name in forbidden_calls:
issues.append(f"- Forbidden function call: {func_name} (line {node.lineno})")
return issues


class CustomDetectorRegistry(BaseDetectorRegistry):
def __init__(self):
super().__init__()

issues = static_code_analysis(module_path = os.path.join(os.path.dirname(__file__), "custom_detectors", "custom_detectors.py"))
if issues:
logging.error(f"Detected {len(issues)} potential security issues inside the custom_detectors file: {issues}")
raise ImportError(f"Unsafe code detected in custom_detectors:\n" + "\n".join(issues))

import custom_detectors.custom_detectors as custom_detectors

self.registry = {name: obj for name, obj
in inspect.getmembers(custom_detectors, inspect.isfunction)
if not name.startswith("_")}
logger.info(f"Registered the following custom detectors: {self.registry.keys()}")

def handle_request(self, content: str, detector_params: dict) -> List[ContentAnalysisResponse]:
detections = []
if "custom" in detector_params and isinstance(detector_params["custom"], (list, str)):
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: Type checking for 'custom' parameter may miss edge cases with unexpected types.

Unexpected types for 'custom' are ignored without notice. Logging or handling these cases would improve debuggability.

custom_functions = detector_params["custom"]
custom_functions = [custom_functions] if isinstance(custom_functions, str) else custom_functions
for custom_function in custom_functions:
if self.registry.get(custom_function):
try:
result = custom_func_wrapper(self.registry[custom_function], custom_function, content)
if result is not None:
detections.append(result)
except Exception as e:
logger.error(e)
raise HTTPException(status_code=400, detail="Detection error, check detector logs")
else:
raise HTTPException(status_code=400, detail=f"Unrecognized custom function: {custom_function}")
return detections

3 changes: 2 additions & 1 deletion detectors/built_in/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
markdown==3.8.2
jsonschema==4.24.0
xmlschema==4.1.0
xmlschema==4.1.0
requests==2.32.5
99 changes: 99 additions & 0 deletions tests/detectors/builtIn/test_custom.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import importlib
import sys
from http.client import HTTPException

import pytest
import os
from fastapi.testclient import TestClient


CUSTOM_DETECTORS_PATH = os.path.join(
os.path.dirname(__file__),
"../../../detectors/built_in/custom_detectors/custom_detectors.py"
)

SAFE_CODE = """
def over_100_characters(text: str) -> bool:
return len(text)>100
def contains_word(text: str) -> bool:
return "apple" in text.lower()
"""

UNSAFE_CODE = '''
import os
def evil(text: str) -> bool:
os.system("echo haha gottem")
return True
'''


def write_code_to_custom_detectors(code: str):
with open(CUSTOM_DETECTORS_PATH, "w") as f:
f.write(code)

def restore_safe_code():
write_code_to_custom_detectors(SAFE_CODE)


class TestCustomDetectors:
@pytest.fixture
def client(self):
from detectors.built_in.app import app
from detectors.built_in.custom_detectors_wrapper import CustomDetectorRegistry
app.set_detector(CustomDetectorRegistry(), "custom")
return TestClient(app)

@pytest.fixture(autouse=True)
def cleanup_custom_detectors(self):
# Always restore safe code after test
yield
restore_safe_code()

def test_missing_detector_type(self, client):
payload = {
"contents": ["What is an apple?"],
"detector_params": {"custom1": ["contains_word"]}
}
resp = client.post("/api/v1/text/contents", json=payload)
assert resp.status_code == 400 and "Detector custom1 not found" in resp.text


def test_custom_detectors(self, client):
payload = {
"contents": ["What is an apple?"],
"detector_params": {"custom": ["contains_word"]}
}
resp = client.post("/api/v1/text/contents", json=payload)
assert resp.status_code == 200
texts = [d["text"] for d in resp.json()[0]]
assert "What is an apple?" in texts
Comment on lines +62 to +70
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Missing negative and error case tests for custom detectors.

Add tests for scenarios where the custom detector does not match any content and for invalid detector names to ensure proper error handling and coverage of negative cases.

Suggested change
def test_custom_detectors(self, client):
payload = {
"contents": ["What is an apple?"],
"detector_params": {"custom": ["contains_word"]}
}
resp = client.post("/api/v1/text/contents", json=payload)
assert resp.status_code == 200
texts = [d["text"] for d in resp.json()[0]]
assert "What is an apple?" in texts
def test_custom_detectors(self, client):
payload = {
"contents": ["What is an apple?"],
"detector_params": {"custom": ["contains_word"]}
}
resp = client.post("/api/v1/text/contents", json=payload)
assert resp.status_code == 200
texts = [d["text"] for d in resp.json()[0]]
assert "What is an apple?" in texts
def test_custom_detector_no_match(self, client):
payload = {
"contents": ["Bananas are yellow."],
"detector_params": {"custom": ["contains_word"]}
}
resp = client.post("/api/v1/text/contents", json=payload)
assert resp.status_code == 200
texts = [d["text"] for d in resp.json()[0]]
assert "Bananas are yellow." not in texts
def test_custom_detector_invalid_name(self, client):
payload = {
"contents": ["What is an apple?"],
"detector_params": {"custom": ["non_existent_detector"]}
}
resp = client.post("/api/v1/text/contents", json=payload)
assert resp.status_code == 400 or resp.status_code == 422
# Optionally check for error message in response
# assert "invalid detector" in resp.text.lower()


def test_custom_detectors_not_match(self, client):
msg = "What is an banana?"
payload = {
"contents": [msg],
"detector_params": {"custom": ["contains_word"]}
}
resp = client.post("/api/v1/text/contents", json=payload)
assert resp.status_code == 200
texts = [d["text"] for d in resp.json()[0]]
assert msg not in texts

def test_unsafe_code(self, client):
write_code_to_custom_detectors(UNSAFE_CODE)
from detectors.built_in.custom_detectors_wrapper import CustomDetectorRegistry
with pytest.raises(ImportError) as excinfo:
CustomDetectorRegistry()
assert "Unsafe code detected" in str(excinfo.value)
assert "Forbidden import: os" in str(excinfo.value) or "os.system" in str(excinfo.value)


def test_custom_detectors_func_doesnt_exist(self, client):
payload = {
"contents": ["What is an apple?"],
"detector_params": {"custom": ["abc"]}
}
resp = client.post("/api/v1/text/contents", json=payload)
assert resp.status_code == 400 and "Unrecognized custom function: abc" in resp.text

Loading