3 changes: 3 additions & 0 deletions config.yaml.full
@@ -67,6 +67,9 @@ tool:
computer_sandbox:
url: #mcp sse/streamable-http url
api_key: #mcp api key
# [optional] for Volcengine LLM Firewall https://www.volcengine.com/product/LLM-FW
llm_firewall:
Collaborator

llm_firewall -> llm_shield

app_id:


observability:
280 changes: 280 additions & 0 deletions veadk/tools/builtin_tools/llm_firewall.py
@@ -0,0 +1,280 @@
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, List
from volcenginesdkllmshield import ClientV2, ModerateV2Request, MessageV2, ContentTypeV2

from google.adk.plugins import BasePlugin
from google.adk.agents.callback_context import CallbackContext
from google.adk.models import LlmRequest, LlmResponse
from google.genai import types

from veadk.config import getenv
from veadk.utils.logger import get_logger

logger = get_logger(__name__)


class LLMFirewallPlugin(BasePlugin):
"""
LLM Firewall Plugin for content moderation and security filtering.

This plugin integrates with Volcengine's LLM Firewall service to provide real-time
content moderation for LLM requests. It analyzes user inputs for various risks
including prompt injection, sensitive information, and policy violations before
allowing requests to reach the language model.

Examples:
Basic usage with default settings:
```python
governance = LLMFirewallPlugin()
agent = Agent(
before_model_callback=governance.before_model_callback
)
```
"""

def __init__(
self, max_history: int = 5, region: str = "cn-beijing", timeout: int = 50
) -> None:
"""
Initialize the LLM Firewall Plugin.

Sets up the plugin with Volcengine LLM Firewall service configuration
and initializes the moderation client.

Args:
max_history (int, optional): Maximum number of conversation turns
to include in moderation context. Defaults to 5.
region (str, optional): Volcengine service region.
Defaults to "cn-beijing".
timeout (int, optional): Request timeout in seconds.
Defaults to 50.

Raises:
ValueError: If required environment variables are missing
Copilot AI Oct 29, 2025

The docstring states that ValueError is raised if required environment variables are missing, but the __init__ method doesn't actually raise this exception. The getenv() function will raise ValueError if the environment variable is not found (based on the veadk.utils.misc implementation), but this happens inside getenv, not explicitly in __init__. The documentation should clarify that the ValueError comes from the getenv() calls for VOLCENGINE_ACCESS_KEY, VOLCENGINE_SECRET_KEY, and TOOL_LLM_FIREWALL_APP_ID.

Suggested change
ValueError: If required environment variables are missing
ValueError: If required environment variables are missing. This exception is raised by the
`getenv` calls for `VOLCENGINE_ACCESS_KEY`, `VOLCENGINE_SECRET_KEY`, and
`TOOL_LLM_FIREWALL_APP_ID`.

"""
self.name = "LLMFirewallPlugin"
super().__init__(name=self.name)

self.ak = getenv("VOLCENGINE_ACCESS_KEY")
self.sk = getenv("VOLCENGINE_SECRET_KEY")
self.appid = getenv("TOOL_LLM_FIREWALL_APP_ID")
self.region = region
self.llm_fw_url = (
f"https://{self.region}.sdk.access.llm-shield.omini-shield.com"
)
self.timeout = timeout
self.max_history = max_history

self.client = ClientV2(self.llm_fw_url, self.ak, self.sk, region, self.timeout)

self.category_map = {
101: "Model Misuse",
103: "Sensitive Information",
104: "Prompt Injection",
106: "General Topic Control",
107: "Computational Resource Consumption",
}

def _get_system_instruction(self, llm_request: LlmRequest) -> str:
"""
Extract system instruction from LLM request.

Retrieves the system instruction from the request configuration
to include in moderation context for better risk assessment.

Args:
llm_request (LlmRequest): The incoming LLM request object

Returns:
str: System instruction text, empty string if not found
"""
config = getattr(llm_request, "config", None)
if config:
return getattr(config, "system_instruction", "")
return ""

def _build_history_from_contents(self, llm_request: LlmRequest) -> List[MessageV2]:
"""
Build conversation history from LLM request contents.

Constructs a structured conversation history for moderation context,
including system instructions and recent user-assistant exchanges.
This helps the firewall understand conversation context for better
risk assessment.

Args:
llm_request (LlmRequest): The incoming LLM request containing
conversation contents

Returns:
List[MessageV2]: Structured conversation history with messages
formatted for LLM Firewall service. Limited to max_history
recent exchanges plus system instruction if present.
"""
history = []

# Add system instruction as the first message if available
system_instruction = self._get_system_instruction(llm_request)
if system_instruction:
history.append(
MessageV2(
role="system",
content=system_instruction,
content_type=ContentTypeV2.TEXT,
)
)

contents = getattr(llm_request, "contents", [])
if not contents:
return history

# Add recent conversation history (excluding current user message)
recent_contents = contents[:-1]
if len(recent_contents) > self.max_history:
recent_contents = recent_contents[-self.max_history :]

for content in recent_contents:
parts = getattr(content, "parts", [])
if parts and hasattr(parts[0], "text"):
Copilot AI Oct 29, 2025

Potential IndexError if parts list is empty. While line 152 checks if parts, it assumes parts[0] exists when checking hasattr(parts[0], "text"). If parts is a truthy but empty container, this will raise an IndexError. The condition should be if parts and len(parts) > 0 and hasattr(parts[0], "text"), or simplified to check the length explicitly.

Suggested change
if parts and hasattr(parts[0], "text"):
if parts and len(parts) > 0 and hasattr(parts[0], "text"):

role = getattr(content, "role", "")
role = "user" if role == "user" else "assistant"
text = getattr(parts[0], "text", "")
if text:
history.append(
MessageV2(
role=role, content=text, content_type=ContentTypeV2.TEXT
)
)

return history

def before_model_callback(
self, callback_context: CallbackContext, llm_request: LlmRequest, **kwargs
) -> Optional[LlmResponse]:
"""
Callback executed before sending request to the model.

This is the main entry point for content moderation. It extracts the
user's message, builds conversation context, sends it to LLM Firewall
for analysis, and either blocks harmful content or allows safe content
to proceed to the model.

The moderation process:
1. Extracts the latest user message from request
2. Builds conversation history for context
3. Sends moderation request to LLM Firewall service
4. Analyzes response for risk categories
5. Blocks request with informative message if risks detected
6. Allows request to proceed if content is safe

Args:
callback_context (CallbackContext): Callback context
llm_request (LlmRequest): The incoming LLM request to moderate
**kwargs: Additional keyword arguments

Returns:
Optional[LlmResponse]:
- LlmResponse with blocking message if content violates policies
- None if content is safe and request should proceed to model
"""
# Extract the last user message for moderation
last_user_message = None
contents = getattr(llm_request, "contents", [])

if contents:
last_content = contents[-1]
last_role = getattr(last_content, "role", "")
last_parts = getattr(last_content, "parts", [])

if last_role == "user" and last_parts:
last_user_message = getattr(last_parts[0], "text", "")

# Skip moderation if message is empty
if not last_user_message:
return None

# Build conversation history for context
history = self._build_history_from_contents(llm_request)

# Create moderation request
moderation_request = ModerateV2Request(
scene=self.appid,
message=MessageV2(
role="user", content=last_user_message, content_type=ContentTypeV2.TEXT
),
history=history,
)

try:
# Send request to LLM Firewall service
response = self.client.Moderate(moderation_request)
except Exception as e:
logger.error(f"LLM Firewall request failed: {e}")
return None
Copilot AI Oct 29, 2025

The broad Exception catch at line 224 makes it difficult to diagnose and handle different failure scenarios. Consider catching more specific exceptions (e.g., network errors, timeout errors, API errors) and logging them with appropriate severity levels. This would also help distinguish between transient failures and configuration issues.

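A narrower handling pattern might look roughly like the sketch below; the concrete exception classes the volcenginesdkllmshield client raises are not documented in this diff, so the ConnectionError/TimeoutError split is an assumption:

```python
# Sketch only: the SDK's concrete exception types are an assumption here.
try:
    response = self.client.Moderate(moderation_request)
except (ConnectionError, TimeoutError) as e:
    # Likely transient network or timeout failure: warn and fail open.
    logger.warning(f"LLM Firewall request failed transiently: {e}")
    return None
except Exception as e:
    # Anything else (e.g. bad credentials or app id): log as an error.
    logger.error(f"LLM Firewall request failed: {e}")
    return None
```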

# Check for API errors in response
response_metadata = getattr(response, "response_metadata", None)
if response_metadata:
error_info = getattr(response_metadata, "error", None)
if error_info:
error_code = getattr(error_info, "code", "Unknown")
if error_code:
error_message = getattr(error_info, "message", "Unknown error")
logger.error(
f"LLM Firewall API error: {error_code} - {error_message}"
)
return None

# Process risk detection results
result = getattr(response, "result", None)
if result:
decision = getattr(result, "decision", None)
decision_type = getattr(decision, "decision_type", None)
risk_info = getattr(result, "risk_info", None)
if int(decision_type) == 2 and risk_info:
Copilot AI Oct 29, 2025

Potential exception if decision_type is None. The code calls int(decision_type) without checking if decision_type is None first. If decision_type is None (which can happen when getattr returns None), this will raise a TypeError. Consider adding a null check: if decision_type is not None and int(decision_type) == 2 and risk_info:

Suggested change
if int(decision_type) == 2 and risk_info:
if decision_type is not None and int(decision_type) == 2 and risk_info:

risks = getattr(risk_info, "risks", [])
if risks:
# Extract risk categories for user-friendly error message
risk_reasons = set()
for risk in risks:
category = getattr(risk, "category", None)
if category:
category_name = self.category_map.get(
int(category), f"Category {category}"
)
risk_reasons.add(category_name)

risk_reasons_list = list(risk_reasons)

# Generate blocking response
reason_text = (
", ".join(risk_reasons_list)
if risk_reasons_list
Copilot AI Oct 29, 2025

The conversion of set to list at line 260 is unnecessary since the set is only used once for joining. You can directly join the set: reason_text = ", ".join(risk_reasons) if risk_reasons else "security policy violation". This simplifies the code and avoids the intermediate variable.

Suggested change
risk_reasons_list = list(risk_reasons)
# Generate blocking response
reason_text = (
", ".join(risk_reasons_list)
if risk_reasons_list
# Generate blocking response
reason_text = (
", ".join(risk_reasons)
if risk_reasons

else "security policy violation"
)
response_text = (
f"Your request has been blocked due to: {reason_text}. "
f"Please modify your input and try again."
)

return LlmResponse(
content=types.Content(
role="model",
parts=[types.Part(text=response_text)],
)
)

return None
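
For anyone trying the plugin locally, a minimal wiring sketch along the lines of the class docstring example; the Agent import path, agent name, and model id are placeholders/assumptions, and the three environment variables read in __init__ (VOLCENGINE_ACCESS_KEY, VOLCENGINE_SECRET_KEY, TOOL_LLM_FIREWALL_APP_ID) must be set beforehand:

```python
# Minimal sketch based on the class docstring; agent name and model id are placeholders.
from google.adk.agents import Agent

from veadk.tools.builtin_tools.llm_firewall import LLMFirewallPlugin

# Reads VOLCENGINE_ACCESS_KEY, VOLCENGINE_SECRET_KEY and TOOL_LLM_FIREWALL_APP_ID
# from the environment, as in __init__ above.
governance = LLMFirewallPlugin(max_history=5, region="cn-beijing", timeout=50)

agent = Agent(
    name="guarded_agent",
    model="<your-model-id>",  # placeholder model id
    instruction="You are a helpful assistant.",
    # Every request is moderated by the firewall before it reaches the model.
    before_model_callback=governance.before_model_callback,
)
```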