Commit 4876ecd

feat: add builtin tool of llm firewall
1 parent 43239a4 commit 4876ecd

File tree

2 files changed: +269 -0 lines changed


config.yaml.full

Lines changed: 3 additions & 0 deletions
@@ -67,6 +67,9 @@ tool:
   computer_sandbox:
     url: #mcp sse/streamable-http url
     api_key: #mcp api key
+  # [optional] for Volcengine LLM Firewall https://www.volcengine.com/product/LLM-FW
+  llm_firewall:
+    app_id:
 
 
 observability:
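
For reference, the plugin added below reads this value through veadk.config.getenv as TOOL_LLM_FIREWALL_APP_ID, alongside the Volcengine access key pair. Here is a minimal sketch of supplying these settings as environment variables instead of through config.yaml.full; the variable names come from the plugin code, the values are placeholders, and the mapping from tool.llm_firewall.app_id to TOOL_LLM_FIREWALL_APP_ID is an assumption about veadk's config flattening:

```python
import os

# Placeholder credentials from the Volcengine console and the LLM Firewall
# product page (https://www.volcengine.com/product/LLM-FW); do not hard-code
# real keys in source.
os.environ["VOLCENGINE_ACCESS_KEY"] = "<your-access-key>"
os.environ["VOLCENGINE_SECRET_KEY"] = "<your-secret-key>"
# Assumed flattened form of the new config key tool.llm_firewall.app_id.
os.environ["TOOL_LLM_FIREWALL_APP_ID"] = "<your-llm-firewall-app-id>"
```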
Lines changed: 266 additions & 0 deletions
@@ -0,0 +1,266 @@
from typing import Optional, List
from volcenginesdkllmshield import ClientV2, ModerateV2Request, MessageV2, ContentTypeV2

from google.adk.plugins import BasePlugin
from google.adk.agents.callback_context import CallbackContext
from google.adk.models import LlmRequest, LlmResponse
from google.genai import types

from veadk.config import getenv
from veadk.utils.logger import get_logger

logger = get_logger(__name__)


class LLMFirewallPlugin(BasePlugin):
    """
    LLM Firewall Plugin for content moderation and security filtering.

    This plugin integrates with Volcengine's LLM Firewall service to provide real-time
    content moderation for LLM requests. It analyzes user inputs for various risks
    including prompt injection, sensitive information, and policy violations before
    allowing requests to reach the language model.

    Examples:
        Basic usage with default settings:
        ```python
        governance = LLMFirewallPlugin()
        agent = Agent(
            before_model_callback=governance.before_model_callback
        )
        ```
    """

    def __init__(
        self, max_history: int = 5, region: str = "cn-beijing", timeout: int = 50
    ) -> None:
        """
        Initialize the LLM Firewall Plugin.

        Sets up the plugin with Volcengine LLM Firewall service configuration
        and initializes the moderation client.

        Args:
            max_history (int, optional): Maximum number of conversation turns
                to include in moderation context. Defaults to 5.
            region (str, optional): Volcengine service region.
                Defaults to "cn-beijing".
            timeout (int, optional): Request timeout in seconds.
                Defaults to 50.

        Raises:
            ValueError: If required environment variables are missing
        """
        self.name = "LLMFirewallPlugin"
        super().__init__(name=self.name)

        self.ak = getenv("VOLCENGINE_ACCESS_KEY")
        self.sk = getenv("VOLCENGINE_SECRET_KEY")
        self.appid = getenv("TOOL_LLM_FIREWALL_APP_ID")
        self.region = region
        self.llm_fw_url = (
            f"https://{self.region}.sdk.access.llm-shield.omini-shield.com"
        )
        self.timeout = timeout
        self.max_history = max_history

        self.client = ClientV2(self.llm_fw_url, self.ak, self.sk, region, self.timeout)

        self.category_map = {
            101: "Model Misuse",
            103: "Sensitive Information",
            104: "Prompt Injection",
            106: "General Topic Control",
            107: "Computational Resource Consumption",
        }

    def _get_system_instruction(self, llm_request: LlmRequest) -> str:
        """
        Extract system instruction from LLM request.

        Retrieves the system instruction from the request configuration
        to include in moderation context for better risk assessment.

        Args:
            llm_request (LlmRequest): The incoming LLM request object

        Returns:
            str: System instruction text, empty string if not found
        """
        config = getattr(llm_request, "config", None)
        if config:
            # Fall back to "" so the declared return type holds even when the
            # attribute exists but is None.
            return getattr(config, "system_instruction", "") or ""
        return ""

    def _build_history_from_contents(self, llm_request: LlmRequest) -> List[MessageV2]:
        """
        Build conversation history from LLM request contents.

        Constructs a structured conversation history for moderation context,
        including system instructions and recent user-assistant exchanges.
        This helps the firewall understand conversation context for better
        risk assessment.

        Args:
            llm_request (LlmRequest): The incoming LLM request containing
                conversation contents

        Returns:
            List[MessageV2]: Structured conversation history with messages
                formatted for LLM Firewall service. Limited to max_history
                recent exchanges plus system instruction if present.
        """
        history = []

        # Add system instruction as the first message if available
        system_instruction = self._get_system_instruction(llm_request)
        if system_instruction:
            history.append(
                MessageV2(
                    role="system",
                    content=system_instruction,
                    content_type=ContentTypeV2.TEXT,
                )
            )

        contents = getattr(llm_request, "contents", [])
        if not contents:
            return history

        # Add recent conversation history (excluding current user message)
        recent_contents = contents[:-1]
        if len(recent_contents) > self.max_history:
            recent_contents = recent_contents[-self.max_history :]

        for content in recent_contents:
            parts = getattr(content, "parts", [])
            if parts and hasattr(parts[0], "text"):
                role = getattr(content, "role", "")
                role = "user" if role == "user" else "assistant"
                text = getattr(parts[0], "text", "")
                if text:
                    history.append(
                        MessageV2(
                            role=role, content=text, content_type=ContentTypeV2.TEXT
                        )
                    )

        return history

    def before_model_callback(
        self, callback_context: CallbackContext, llm_request: LlmRequest, **kwargs
    ) -> Optional[LlmResponse]:
        """
        Callback executed before sending the request to the model.

        This is the main entry point for content moderation. It extracts the
        user's message, builds conversation context, sends it to LLM Firewall
        for analysis, and either blocks harmful content or allows safe content
        to proceed to the model.

        The moderation process:
        1. Extracts the latest user message from the request
        2. Builds conversation history for context
        3. Sends a moderation request to the LLM Firewall service
        4. Analyzes the response for risk categories
        5. Blocks the request with an informative message if risks are detected
        6. Allows the request to proceed if the content is safe

        Args:
            callback_context (CallbackContext): Callback context
            llm_request (LlmRequest): The incoming LLM request to moderate
            **kwargs: Additional keyword arguments

        Returns:
            Optional[LlmResponse]:
                - LlmResponse with a blocking message if content violates policies
                - None if content is safe and the request should proceed to the model
        """
        # Extract the last user message for moderation
        last_user_message = None
        contents = getattr(llm_request, "contents", [])

        if contents:
            last_content = contents[-1]
            last_role = getattr(last_content, "role", "")
            last_parts = getattr(last_content, "parts", [])

            if last_role == "user" and last_parts:
                last_user_message = getattr(last_parts[0], "text", "")

        # Skip moderation if message is empty
        if not last_user_message:
            return None

        # Build conversation history for context
        history = self._build_history_from_contents(llm_request)

        # Create moderation request
        moderation_request = ModerateV2Request(
            scene=self.appid,
            message=MessageV2(
                role="user", content=last_user_message, content_type=ContentTypeV2.TEXT
            ),
            history=history,
        )

        try:
            # Send request to LLM Firewall service
            response = self.client.Moderate(moderation_request)
        except Exception as e:
            logger.error(f"LLM Firewall request failed: {e}")
            return None

        # Check for API errors in the response
        response_metadata = getattr(response, "response_metadata", None)
        if response_metadata:
            error_info = getattr(response_metadata, "error", None)
            if error_info:
                error_code = getattr(error_info, "code", "Unknown")
                if error_code:
                    error_message = getattr(error_info, "message", "Unknown error")
                    logger.error(
                        f"LLM Firewall API error: {error_code} - {error_message}"
                    )
                # Fail open: let the request through if the firewall itself errored
                return None

        # Process risk detection results
        result = getattr(response, "result", None)
        if result:
            decision = getattr(result, "decision", None)
            decision_type = getattr(decision, "decision_type", None)
            risk_info = getattr(result, "risk_info", None)
            # Guard against a missing decision before casting to int
            if decision_type is not None and int(decision_type) == 2 and risk_info:
                risks = getattr(risk_info, "risks", [])
                if risks:
                    # Extract risk categories for a user-friendly error message
                    risk_reasons = set()
                    for risk in risks:
                        category = getattr(risk, "category", None)
                        if category:
                            category_name = self.category_map.get(
                                int(category), f"Category {category}"
                            )
                            risk_reasons.add(category_name)

                    risk_reasons_list = list(risk_reasons)

                    # Generate blocking response
                    reason_text = (
                        ", ".join(risk_reasons_list)
                        if risk_reasons_list
                        else "security policy violation"
                    )
                    response_text = (
                        f"Your request has been blocked due to: {reason_text}. "
                        f"Please modify your input and try again."
                    )

                    return LlmResponse(
                        content=types.Content(
                            role="model",
                            parts=[types.Part(text=response_text)],
                        )
                    )

        return None
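
Putting it together, a minimal usage sketch based on the class docstring example above; the agent name, model string, and instruction are placeholders, and passing before_model_callback is the wiring pattern the docstring shows rather than the only way to register the plugin:

```python
from google.adk.agents import Agent

# Construct the firewall plugin with its defaults (5 turns of history,
# cn-beijing region, 50-second timeout).
governance = LLMFirewallPlugin(max_history=5, region="cn-beijing", timeout=50)

# Wire the moderation callback into an agent, mirroring the docstring example.
agent = Agent(
    name="guarded_agent",            # placeholder name
    model="<your-model>",            # placeholder model identifier
    instruction="You are a helpful assistant.",
    before_model_callback=governance.before_model_callback,
)
```

Since the class also derives from BasePlugin, registering it through a runner's plugin list may be an alternative, depending on the ADK version in use.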

0 commit comments
