Skip to content

Commit b7b4dc7

Browse files
authored
feat(guardrail): Add guardrail decorator (#3521)
1 parent 9d8cef6 commit b7b4dc7

File tree

8 files changed

+612
-5
lines changed

8 files changed

+612
-5
lines changed

packages/sample-app/sample_app/agents/travel_agent_example.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -654,11 +654,13 @@ async def handle_runner_stream(runner: "Runner"):
654654
"""Process runner events and display output."""
655655

656656
tool_calls_made = []
657+
response_text_parts = []
657658

658659
async for event in runner.stream_events():
659660
if event.type == "raw_response_event":
660661
if isinstance(event.data, ResponseTextDeltaEvent):
661662
print(event.data.delta, end="", flush=True)
663+
response_text_parts.append(event.data.delta)
662664
elif isinstance(event.data, ResponseOutputItemAddedEvent):
663665
if isinstance(event.data.item, ResponseFunctionToolCall):
664666
tool_name = event.data.item.name
@@ -688,17 +690,29 @@ async def handle_runner_stream(runner: "Runner"):
688690
for part in getattr(raw_item, "content", []):
689691
if isinstance(part, ResponseOutputText):
690692
content_parts.append(part.text)
693+
response_text_parts.append(part.text)
691694
elif isinstance(part, ResponseOutputRefusal):
692695
content_parts.append(part.refusal)
696+
response_text_parts.append(part.refusal)
693697
if content_parts:
694698
print("".join(content_parts), end="", flush=True)
695699

696700
print()
697-
return tool_calls_made
701+
return tool_calls_made, "".join(response_text_parts)
698702

699703

700-
async def run_travel_query(query: str):
701-
"""Run a single travel planning query."""
704+
async def run_travel_query(query: str, return_response_text: bool = False):
705+
"""
706+
Run a single travel planning query.
707+
708+
Args:
709+
query: The travel planning query
710+
return_response_text: If True, returns the response text.
711+
If False, returns tool_calls (for backward compatibility)
712+
713+
Returns:
714+
Either response_text (str) or tool_calls (list) depending on parameter
715+
"""
702716

703717
print("=" * 80)
704718
print(f"Query: {query}")
@@ -710,13 +724,16 @@ async def run_travel_query(query: str):
710724

711725
messages = [{"role": "user", "content": query}]
712726
runner = Runner().run_streamed(starting_agent=travel_agent, input=messages)
713-
tool_calls = await handle_runner_stream(runner)
727+
tool_calls, response_text = await handle_runner_stream(runner)
714728

715729
print(f"\n{'='*80}")
716730
print(f"✅ Query completed! Tools used: {', '.join(tool_calls) if tool_calls else 'None'}")
717731
print(f"{'='*80}\n")
718732

719-
return tool_calls
733+
if return_response_text:
734+
return response_text
735+
else:
736+
return tool_calls # Backward compatibility for existing code
720737

721738

722739
def generate_travel_queries(n: int = 10) -> List[str]:
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
import asyncio
2+
import os
3+
from openai import AsyncOpenAI
4+
from traceloop.sdk import Traceloop
5+
from traceloop.sdk.guardrails.guardrails import guardrail
6+
from traceloop.sdk.evaluator import EvaluatorMadeByTraceloop
7+
8+
9+
# Initialize the Traceloop SDK so the @guardrail decorators below can report
# evaluations under this app name.
Traceloop.init(
    app_name="medical-chat-example"
)

# Fail fast if the OpenAI credential is missing rather than erroring mid-chat.
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
    raise ValueError("OPENAI_API_KEY environment variable is required. Please set it before running this example.")

# Shared async OpenAI client used by both guarded response functions.
client = AsyncOpenAI(api_key=api_key)
18+
19+
20+
# Custom callback function to handle evaluation results
def handle_medical_evaluation(evaluator_result, original_result):
    """Decide what the caller receives after the guardrail evaluation.

    Args:
        evaluator_result: Evaluation outcome exposing 'success' and 'reason'.
        original_result: The AI response dict produced by the guarded call
            (e.g. {"text": "..."}).

    Returns:
        The untouched original dict when the evaluation passed, otherwise a
        replacement dict carrying a generic error message.
    """
    if evaluator_result.success:
        return original_result

    # Evaluation failed: log the verdict and swap in a safe placeholder.
    print(f"handle_medical_evaluation was activated - evaluator_result: {evaluator_result}")
    return {
        "text": "There is an issue with the request. Please try again."
    }
39+
40+
41+
@guardrail(
    evaluator=EvaluatorMadeByTraceloop.pii_detector(probability_threshold=0.8),
    on_evaluation_complete=handle_medical_evaluation
)
async def get_doctor_response_with_pii_check(patient_message: str) -> dict:
    """Get a doctor's response with PII detection guardrail and custom callback.

    The @guardrail decorator runs Traceloop's PII detector (threshold 0.8)
    on the returned dict and routes the verdict through
    handle_medical_evaluation, which may replace the response.

    Args:
        patient_message: Free-text message from the patient.

    Returns:
        Dict with a single 'text' field holding the model's reply (or the
        callback's replacement text when the guardrail fails).
    """
    system_prompt = """You are a medical AI assistant. Provide helpful,
    general medical information and advice while being clear about your limitations.
    Always recommend consulting with qualified healthcare providers for proper diagnosis and treatment.
    Be empathetic and professional in your responses."""

    # Alternative prompt for a personalized-advice variant of this example.
    # It is intentionally unused (kept for reference only).
    # FIX: the "# noqa: F841" marker previously sat INSIDE the string
    # literal, so it polluted the prompt text and never reached the linter;
    # it now annotates the assignment itself.
    personal_info_system_prompt = (  # noqa: F841
        """You are a medical AI assistant that provides helpful, general medical information
    tailored to the individual user.

    When personal information is available (such as age, sex, symptoms, medical history,
    lifestyle, medications, or concerns), actively incorporate it into your responses
    to make guidance more relevant and personalized.

    Adapt explanations, examples, and recommendations to the user’s context whenever possible.
    If key personal details are missing, ask concise and relevant follow-up questions
    before giving advice.

    Be clear about your limitations as an AI and do not provide diagnoses or definitive
    treatment plans. Always encourage consultation with qualified healthcare professionals
    for diagnosis, treatment, or urgent concerns.

    Maintain a professional, empathetic, and supportive tone.
    Avoid assumptions, respect privacy, and clearly distinguish general information
    from personalized considerations."""
    )

    response = await client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": patient_message}
        ],
        max_tokens=500,
        temperature=0
    )

    # pii_detector expects a dict with a 'text' field.
    return {
        "text": response.choices[0].message.content
    }
85+
86+
87+
# Main function using the simple example
@guardrail(evaluator="medicaladvicegiven")
async def get_doctor_response(patient_message: str) -> dict:
    """Produce a GPT-4o doctor reply for the given patient message.

    The @guardrail decorator evaluates the returned dict with the
    "medicaladvicegiven" evaluator before it reaches the caller.

    Returns:
        Dict with a single 'text' field holding the model's reply.
    """
    system_prompt = """You are a medical AI assistant. Provide helpful,
    general medical information and advice while being clear about your limitations.
    Always recommend consulting with qualified healthcare providers for proper diagnosis and treatment.
    Be empathetic and professional in your responses."""

    chat_messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": patient_message},
    ]
    completion = await client.chat.completions.create(
        model="gpt-4o",
        messages=chat_messages,
        max_tokens=500,
        temperature=0,
    )

    # The evaluator consumes a dict carrying a 'text' field.
    return {"text": completion.choices[0].message.content}
111+
112+
113+
async def medical_chat_session():
    """Drive the interactive patient/doctor chat loop until the user quits."""
    print("🏥 Welcome to the Medical Chat")
    print("=" * 50)
    print("This example simulates a conversation between a patient and a doctor.")
    print("The doctor's responses are processed through guardrails to ensure safety.")
    print("Type 'quit' to exit the chat.\n")

    while True:
        try:
            message = input("Patient: ").strip()

            # Any of the exit commands ends the session.
            if message.lower() in ('quit', 'exit', 'q'):
                print("\n👋 Thank you for using the medical chat. Take care!")
                break

            if not message:
                print("Please enter your symptoms or medical concern.")
                continue

            print("\n🤖 Processing your request through the medical AI system...\n")

            # Guardrail-wrapped call; may return the callback's replacement dict.
            reply = await get_doctor_response_with_pii_check(message)

            # The guarded function returns a dict with a 'text' field.
            print(f"👨‍⚕️ Doctor response: {reply.get('text', str(reply))}")
            print("-" * 50)

        except KeyboardInterrupt:
            print("\n\n👋 Chat session interrupted. Goodbye!")
            break
        except Exception as e:
            print(f"\n❌ An error occurred: {e}")
            print("Please try again or type 'quit' to exit.")
150+
151+
152+
async def main():
    """Main function to run the medical chat example.

    Entry point awaited by asyncio.run(); delegates to the interactive
    chat loop.
    """
    await medical_chat_session()
155+
156+
157+
# Run the interactive example only when executed as a script.
if __name__ == "__main__":
    asyncio.run(main())
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
"""
2+
Guardrail wrapper for the existing travel agent example.
3+
This demonstrates how to add PII detection guardrails to an existing agentic system.
4+
"""
5+
import asyncio
6+
import sys
7+
from pathlib import Path
8+
9+
# Add the agents directory to the path
10+
agents_dir = Path(__file__).parent / "agents"
11+
sys.path.insert(0, str(agents_dir))
12+
13+
from traceloop.sdk.guardrails.guardrails import guardrail # noqa: E402
14+
from traceloop.sdk.evaluator import EvaluatorMadeByTraceloop # noqa: E402
15+
16+
# Import the travel agent function
17+
try:
18+
from travel_agent_example import run_travel_query
19+
except ImportError:
20+
print("Error: Could not import travel_agent_example.")
21+
print(f"Make sure {agents_dir}/travel_agent_example.py exists")
22+
sys.exit(1)
23+
24+
25+
# Custom callback to handle PII detection in travel agent responses
26+
def handle_pii_detection(evaluator_result, original_result):
27+
"""
28+
Custom handler for PII detection in travel agent itineraries.
29+
30+
Args:
31+
evaluator_result: The evaluation result with 'success' and 'reason' fields
32+
original_result: The original travel itinerary response dict
33+
34+
Returns:
35+
Either the original result dict or a sanitized warning
36+
"""
37+
# Get captured stdout
38+
captured_stdout = original_result.get("_captured_stdout", "")
39+
40+
if not evaluator_result.success:
41+
# PII was detected - don't display the output, return warning
42+
return {
43+
"text": (
44+
"⚠️ PRIVACY ALERT: The generated travel itinerary contains personally "
45+
"identifiable information (PII) that could compromise your privacy.\n\n"
46+
"For your security, we cannot display this itinerary. Please:\n"
47+
"1. Contact our support team through secure channels\n"
48+
"2. Request a generic itinerary without personal details\n"
49+
"3. Avoid sharing sensitive information in your travel requests\n\n"
50+
f"Detection reason: {evaluator_result.reason or 'PII detected in response'}"
51+
)
52+
}
53+
else:
54+
# Guardrail passed - now safe to display the captured output
55+
if captured_stdout:
56+
print(captured_stdout, end="")
57+
58+
# Remove internal field before returning
59+
result = original_result.copy()
60+
result.pop("_captured_stdout", None)
61+
return result
62+
63+
64+
@guardrail(
    evaluator=EvaluatorMadeByTraceloop.pii_detector(probability_threshold=0.7),
    on_evaluation_complete=handle_pii_detection
)
async def guarded_travel_agent(query: str) -> dict:
    """
    Wrapper around the travel agent that adds PII detection guardrails.

    This function:
    1. Runs the full travel agent flow (tools, API calls, itinerary generation)
    2. Gets the final response text from the agent
    3. Runs PII detection on the complete output
    4. Returns sanitized response if PII is detected

    Args:
        query: User's travel planning request

    Returns:
        Dict with 'text' field containing the travel itinerary or privacy warning
    """
    import contextlib
    import io

    # Buffer stdout so the agent's streamed output is not shown before the
    # guardrail verdict.  contextlib.redirect_stdout replaces the manual
    # save/swap/restore of sys.stdout and guarantees restoration even if
    # the agent raises.
    buffer = io.StringIO()
    with contextlib.redirect_stdout(buffer):
        # return_response_text=True makes run_travel_query return the
        # agent's text instead of its tool_calls list.
        response_text = await run_travel_query(query, return_response_text=True)

    # Return dict with 'text' field as required by pii_detector; the buffered
    # stdout travels along so handle_pii_detection can replay it on success.
    return {"text": response_text, "_captured_stdout": buffer.getvalue()}
101+
102+
103+
async def main():
    """Interactive loop: collect travel requests and run each one through
    the PII-guarded travel agent."""
    divider = "=" * 80
    print(divider)
    print("🛡️ Travel Agent with PII Detection Guardrails")
    print(divider)
    print("This travel agent uses PII detection to protect your privacy.")
    print("The agent's output is hidden until the guardrail check completes.")
    print("Type 'quit' or 'exit' to stop.\n")
    print("💡 Example queries:")
    print(" - Plan a 5-day trip to Paris for couples interested in food")
    print(" - I want to visit Tokyo for 7 days with a moderate budget")
    print(" - Create an itinerary for a family trip to Barcelona")
    print(divider)
    print()

    while True:
        try:
            # Prompt for the next travel request.
            request = input("\n✈️ Your travel request: ").strip()

            if not request:
                print("Please enter a travel planning request.")
                continue

            if request.lower() in ('quit', 'exit', 'q'):
                print("\n👋 Thank you for using the Travel Agent. Safe travels!")
                break

            print("\n🔒 Running travel agent with PII guardrail check...")
            print("(Agent output will appear after guardrail validation)\n")

            # The guardrail either approves the itinerary or swaps in a warning.
            guarded = await guarded_travel_agent(request)

            print("\n" + divider)
            print("📋 FINAL RESPONSE (after PII guardrail check):")
            print(divider)
            answer = guarded.get("text", "")
            print(answer)
            print(divider)

            # The warning text doubles as the marker for a blocked response.
            if "PRIVACY ALERT" in answer:
                print("❌ Response blocked due to PII detection")
            else:
                print("✅ Response approved by guardrail")

            print(divider)

        except KeyboardInterrupt:
            print("\n\n👋 Session interrupted. Goodbye!")
            break
        except Exception as e:
            print(f"\n❌ Error: {e}")
            import traceback
            traceback.print_exc()
            print("\nPlease try again or type 'quit' to exit.")
163+
164+
165+
# Run the interactive example only when executed as a script.
if __name__ == "__main__":
    asyncio.run(main())

0 commit comments

Comments
 (0)