Skip to content

Commit 713d681

Browse files
author
Kazmer, Nagy-Betegh
committed
feat(agent): make structured output part of the agent loop
1 parent 1c7257b commit 713d681

File tree

1 file changed

+136
-41
lines changed

1 file changed

+136
-41
lines changed

src/strands/agent/agent.py

Lines changed: 136 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from opentelemetry import trace as trace_api
2020
from pydantic import BaseModel
2121

22+
from strands.tools.decorator import tool
23+
2224
from ..event_loop.event_loop import event_loop_cycle, run_tool
2325
from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler
2426
from ..hooks import (
@@ -425,6 +427,18 @@ def execute() -> T:
425427
future = executor.submit(execute)
426428
return future.result()
427429

430+
def _register_structured_output_tool(self, output_model: type[BaseModel]):
431+
@tool
432+
def _structured_output(input: output_model) -> output_model:
433+
"""If this tool is present it MUST be used to return structured data for the user."""
434+
return input
435+
436+
return _structured_output
437+
438+
def _get_structured_output_tool(self, output_model: Type[T]):
439+
"""Get or create the structured output tool for the given model."""
440+
return self._register_structured_output_tool(output_model)
441+
428442
async def structured_output_async(
429443
self, output_model: Type[T], prompt: Optional[Union[str, list[ContentBlock]]] = None
430444
) -> T:
@@ -445,48 +459,129 @@ async def structured_output_async(
445459
ValueError: If no conversation history or prompt is provided.
446460
"""
447461
self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self))
448-
with self.tracer.tracer.start_as_current_span(
449-
"execute_structured_output", kind=trace_api.SpanKind.CLIENT
450-
) as structured_output_span:
451-
try:
452-
if not self.messages and not prompt:
453-
raise ValueError("No conversation history or prompt provided")
454-
# Create temporary messages array if prompt is provided
455-
if prompt:
456-
content: list[ContentBlock] = [{"text": prompt}] if isinstance(prompt, str) else prompt
457-
temp_messages = self.messages + [{"role": "user", "content": content}]
458-
else:
459-
temp_messages = self.messages
460-
461-
structured_output_span.set_attributes(
462-
{
463-
"gen_ai.system": "strands-agents",
464-
"gen_ai.agent.name": self.name,
465-
"gen_ai.agent.id": self.agent_id,
466-
"gen_ai.operation.name": "execute_structured_output",
467-
}
468-
)
469-
for message in temp_messages:
470-
structured_output_span.add_event(
471-
f"gen_ai.{message['role']}.message",
472-
attributes={"role": message["role"], "content": serialize(message["content"])},
473-
)
474-
if self.system_prompt:
475-
structured_output_span.add_event(
476-
"gen_ai.system.message",
477-
attributes={"role": "system", "content": serialize([{"text": self.system_prompt}])},
462+
463+
# Save original state for restoration BEFORE making any changes
464+
import copy
465+
466+
original_hooks_callbacks = copy.deepcopy(self.hooks._registered_callbacks)
467+
original_tool_registry = copy.deepcopy(self.tool_registry.registry)
468+
original_dynamic_tools = copy.deepcopy(self.tool_registry.dynamic_tools)
469+
470+
# Create and add the structured output tool
471+
structured_output_tool = self._register_structured_output_tool(output_model)
472+
self.tool_registry.register_tool(structured_output_tool)
473+
474+
# Variable to capture the structured result
475+
captured_result = None
476+
477+
# Import here to avoid circular imports
478+
from ..experimental.hooks import AfterToolInvocationEvent
479+
480+
# Hook to capture structured output tool invocation
481+
def capture_structured_output_hook(event: AfterToolInvocationEvent) -> AfterToolInvocationEvent:
482+
nonlocal captured_result
483+
484+
if (
485+
event.selected_tool
486+
and hasattr(event.selected_tool, "tool_name")
487+
and event.selected_tool.tool_name == "_structured_output"
488+
and event.result
489+
and event.result.get("status") == "success"
490+
):
491+
# Parse the validated Pydantic model from the tool result
492+
try:
493+
content = event.result.get("content", [])
494+
if content and isinstance(content[0], dict) and "text" in content[0]:
495+
# The tool returns the model instance as string, but we need the actual instance
496+
# Since our tool returns the input directly, we can reconstruct it
497+
tool_input = event.tool_use.get("input", {}).get("input")
498+
if tool_input:
499+
captured_result = output_model(**tool_input)
500+
except Exception:
501+
# Fallback: the tool should have returned the validated model
502+
pass
503+
504+
return event
505+
506+
# Add the callback temporarily (use add_callback, not add_hook)
507+
self.hooks.add_callback(AfterToolInvocationEvent, capture_structured_output_hook)
508+
509+
try:
510+
with self.tracer.tracer.start_as_current_span(
511+
"execute_structured_output", kind=trace_api.SpanKind.CLIENT
512+
) as structured_output_span:
513+
try:
514+
if not self.messages and not prompt:
515+
raise ValueError("No conversation history or prompt provided")
516+
517+
# Create temporary messages array if prompt is provided
518+
if prompt:
519+
content: list[ContentBlock] = [{"text": prompt}] if isinstance(prompt, str) else prompt
520+
message = {"role": "user", "content": content}
521+
else:
522+
# Use existing conversation history
523+
message = {
524+
"role": "user",
525+
"content": [
526+
{
527+
"text": "Please provide the information from our conversation in the requested structured format."
528+
}
529+
],
530+
}
531+
532+
structured_output_span.set_attributes(
533+
{
534+
"gen_ai.system": "strands-agents",
535+
"gen_ai.agent.name": self.name,
536+
"gen_ai.agent.id": self.agent_id,
537+
"gen_ai.operation.name": "execute_structured_output",
538+
}
478539
)
479-
events = self.model.structured_output(output_model, temp_messages, system_prompt=self.system_prompt)
480-
async for event in events:
481-
if "callback" in event:
482-
self.callback_handler(**cast(dict, event["callback"]))
483-
structured_output_span.add_event(
484-
"gen_ai.choice", attributes={"message": serialize(event["output"].model_dump())}
485-
)
486-
return event["output"]
487-
488-
finally:
489-
self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self))
540+
541+
# Add tracing for messages
542+
messages_to_trace = self.messages if not prompt else self.messages + [message]
543+
for msg in messages_to_trace:
544+
structured_output_span.add_event(
545+
f"gen_ai.{msg['role']}.message",
546+
attributes={"role": msg["role"], "content": serialize(msg["content"])},
547+
)
548+
549+
if self.system_prompt:
550+
structured_output_span.add_event(
551+
"gen_ai.system.message",
552+
attributes={"role": "system", "content": serialize([{"text": self.system_prompt}])},
553+
)
554+
555+
invocation_state = {
556+
"structured_output_mode": True,
557+
"structured_output_model": output_model,
558+
}
559+
560+
# Run the event loop
561+
async for event in self._run_loop(message=message, invocation_state=invocation_state):
562+
if "stop" in event:
563+
break
564+
565+
# Return the captured structured result
566+
if captured_result:
567+
structured_output_span.add_event(
568+
"gen_ai.choice", attributes={"message": serialize(captured_result.model_dump())}
569+
)
570+
return captured_result
571+
else:
572+
raise ValueError("Failed to capture structured output from agent")
573+
574+
except Exception as e:
575+
structured_output_span.record_exception(e)
576+
raise
577+
578+
finally:
579+
# Restore original state
580+
self.hooks._registered_callbacks = original_hooks_callbacks
581+
self.tool_registry.registry = original_tool_registry
582+
self.tool_registry.dynamic_tools = original_dynamic_tools
583+
584+
self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self))
490585

491586
async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: Any) -> AsyncIterator[Any]:
492587
"""Process a natural language prompt and yield events as an async iterator.

0 commit comments

Comments
 (0)