-
Couldn't load subscription status.
- Fork 18
Feat: Add custom detectors to the built in detector #51
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,6 @@ | ||
|
|
||
| def over_100_characters(text: str) -> bool: | ||
| return len(text)>100 | ||
|
|
||
| def contains_word(text: str) -> bool: | ||
| return "apple" in text.lower() |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,118 @@ | ||
| import ast | ||
| import os | ||
|
|
||
| from fastapi import HTTPException | ||
| import inspect | ||
| import logging | ||
| from typing import List, Optional, Callable | ||
|
|
||
|
|
||
| from base_detector_registry import BaseDetectorRegistry | ||
| from detectors.common.scheme import ContentAnalysisResponse | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| def custom_func_wrapper(func: Callable, func_name: str, s: str) -> Optional[ContentAnalysisResponse]: | ||
| """Convert a some f(text)->bool into a Detector response""" | ||
| try: | ||
| result = func(s) | ||
| except Exception as e: | ||
| logging.error(f"Error when computing custom detector function {func_name}: {e}") | ||
| raise e | ||
| if result: | ||
| if isinstance(result, bool): | ||
| return ContentAnalysisResponse( | ||
| start=0, | ||
| end=len(s), | ||
| text=s, | ||
| detection_type=func_name, | ||
| detection=func_name, | ||
| score=1.0) | ||
| elif isinstance(result, dict): | ||
| try: | ||
| return ContentAnalysisResponse(**result) | ||
| except Exception as e: | ||
| logging.error(f"Error when trying to build ContentAnalysisResponse from {func_name} response: {e}") | ||
| raise e | ||
| else: | ||
| msg = f"Unsupported result type for custom detector function {func_name}, must be bool or ContentAnalysisResponse, got: {type(result)}" | ||
| logging.error(msg) | ||
| raise TypeError(msg) | ||
| else: | ||
| return None | ||
|
|
||
|
|
||
| def static_code_analysis(module_path, forbidden_imports=None, forbidden_calls=None): | ||
| """ | ||
| Perform static code analysis on a Python module to check for forbidden imports and function calls. | ||
| Returns a list of issues found. | ||
| """ | ||
| if forbidden_imports is None: | ||
| forbidden_imports = {"os", "subprocess", "sys", "shutil"} | ||
| if forbidden_calls is None: | ||
| forbidden_calls = {"eval", "exec", "open", "compile", "input"} | ||
|
|
||
| issues = [] | ||
| with open(module_path, "r") as f: | ||
| source = f.read() | ||
| try: | ||
| tree = ast.parse(source, filename=module_path) | ||
| except Exception as e: | ||
| issues.append(f"Failed to parse {module_path}: {e}") | ||
| return issues | ||
|
|
||
| for node in ast.walk(tree): | ||
| # Check for forbidden imports | ||
| if isinstance(node, ast.Import): | ||
| for alias in node.names: | ||
| if alias.name.split(".")[0] in forbidden_imports: | ||
| issues.append(f"- Forbidden import: {alias.name} (line {node.lineno})") | ||
| if isinstance(node, ast.ImportFrom): | ||
| if node.module and node.module.split(".")[0] in forbidden_imports: | ||
| issues.append(f"- Forbidden import: {node.module} (line {node.lineno})") | ||
| # Check for forbidden function calls | ||
| if isinstance(node, ast.Call): | ||
| func_name = "" | ||
| if isinstance(node.func, ast.Name): | ||
| func_name = node.func.id | ||
| elif isinstance(node.func, ast.Attribute): | ||
| func_name = f"{getattr(node.func.value, 'id', '')}.{node.func.attr}" | ||
| if func_name in forbidden_calls: | ||
| issues.append(f"- Forbidden function call: {func_name} (line {node.lineno})") | ||
| return issues | ||
|
|
||
|
|
||
| class CustomDetectorRegistry(BaseDetectorRegistry): | ||
| def __init__(self): | ||
| super().__init__() | ||
|
|
||
| issues = static_code_analysis(module_path = os.path.join(os.path.dirname(__file__), "custom_detectors", "custom_detectors.py")) | ||
| if issues: | ||
| logging.error(f"Detected {len(issues)} potential security issues inside the custom_detectors file: {issues}") | ||
| raise ImportError(f"Unsafe code detected in custom_detectors:\n" + "\n".join(issues)) | ||
|
|
||
| import custom_detectors.custom_detectors as custom_detectors | ||
|
|
||
| self.registry = {name: obj for name, obj | ||
| in inspect.getmembers(custom_detectors, inspect.isfunction) | ||
| if not name.startswith("_")} | ||
| logger.info(f"Registered the following custom detectors: {self.registry.keys()}") | ||
|
|
||
| def handle_request(self, content: str, detector_params: dict) -> List[ContentAnalysisResponse]: | ||
| detections = [] | ||
| if "custom" in detector_params and isinstance(detector_params["custom"], (list, str)): | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. suggestion: Type checking for 'custom' parameter may miss edge cases with unexpected types. Unexpected types for 'custom' are ignored without notice. Logging or handling these cases would improve debuggability. |
||
| custom_functions = detector_params["custom"] | ||
| custom_functions = [custom_functions] if isinstance(custom_functions, str) else custom_functions | ||
| for custom_function in custom_functions: | ||
| if self.registry.get(custom_function): | ||
| try: | ||
| result = custom_func_wrapper(self.registry[custom_function], custom_function, content) | ||
| if result is not None: | ||
| detections.append(result) | ||
| except Exception as e: | ||
| logger.error(e) | ||
| raise HTTPException(status_code=400, detail="Detection error, check detector logs") | ||
| else: | ||
| raise HTTPException(status_code=400, detail=f"Unrecognized custom function: {custom_function}") | ||
| return detections | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,4 @@ | ||
| markdown==3.8.2 | ||
| jsonschema==4.24.0 | ||
| xmlschema==4.1.0 | ||
| xmlschema==4.1.0 | ||
| requests==2.32.5 |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,99 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import importlib | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import sys | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from http.client import HTTPException | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import pytest | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import os | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from fastapi.testclient import TestClient | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| CUSTOM_DETECTORS_PATH = os.path.join( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| os.path.dirname(__file__), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "../../../detectors/built_in/custom_detectors/custom_detectors.py" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| SAFE_CODE = """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def over_100_characters(text: str) -> bool: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return len(text)>100 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def contains_word(text: str) -> bool: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return "apple" in text.lower() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| UNSAFE_CODE = ''' | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import os | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def evil(text: str) -> bool: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| os.system("echo haha gottem") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return True | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ''' | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def write_code_to_custom_detectors(code: str): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| with open(CUSTOM_DETECTORS_PATH, "w") as f: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| f.write(code) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def restore_safe_code(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| write_code_to_custom_detectors(SAFE_CODE) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class TestCustomDetectors: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @pytest.fixture | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def client(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from detectors.built_in.app import app | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from detectors.built_in.custom_detectors_wrapper import CustomDetectorRegistry | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| app.set_detector(CustomDetectorRegistry(), "custom") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return TestClient(app) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @pytest.fixture(autouse=True) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def cleanup_custom_detectors(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Always restore safe code after test | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| yield | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| restore_safe_code() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def test_missing_detector_type(self, client): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| payload = { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "contents": ["What is an apple?"], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "detector_params": {"custom1": ["contains_word"]} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| resp = client.post("/api/v1/text/contents", json=payload) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert resp.status_code == 400 and "Detector custom1 not found" in resp.text | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def test_custom_detectors(self, client): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| payload = { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "contents": ["What is an apple?"], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "detector_params": {"custom": ["contains_word"]} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| resp = client.post("/api/v1/text/contents", json=payload) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert resp.status_code == 200 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| texts = [d["text"] for d in resp.json()[0]] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert "What is an apple?" in texts | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+62
to
+70
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. suggestion (testing): Missing negative and error case tests for custom detectors. Add tests for scenarios where the custom detector does not match any content and for invalid detector names to ensure proper error handling and coverage of negative cases.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def test_custom_detectors_not_match(self, client): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| msg = "What is an banana?" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| payload = { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "contents": [msg], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "detector_params": {"custom": ["contains_word"]} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| resp = client.post("/api/v1/text/contents", json=payload) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert resp.status_code == 200 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| texts = [d["text"] for d in resp.json()[0]] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert msg not in texts | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def test_unsafe_code(self, client): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| write_code_to_custom_detectors(UNSAFE_CODE) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from detectors.built_in.custom_detectors_wrapper import CustomDetectorRegistry | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| with pytest.raises(ImportError) as excinfo: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| CustomDetectorRegistry() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert "Unsafe code detected" in str(excinfo.value) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert "Forbidden import: os" in str(excinfo.value) or "os.system" in str(excinfo.value) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def test_custom_detectors_func_doesnt_exist(self, client): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| payload = { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "contents": ["What is an apple?"], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "detector_params": {"custom": ["abc"]} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| resp = client.post("/api/v1/text/contents", json=payload) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert resp.status_code == 400 and "Unrecognized custom function: abc" in resp.text | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Uh oh!
There was an error while loading. Please reload this page.