
Commit 8bf015f

Author: Kazmer, Nagy-Betegh (committed)

feat(agent): make structured output part of the agent loop

1 parent 1c7257b · commit 8bf015f

File tree

2 files changed: +188 -52 lines changed


src/strands/agent/agent.py

Lines changed: 170 additions & 42 deletions
@@ -19,6 +19,8 @@
 from opentelemetry import trace as trace_api
 from pydantic import BaseModel
 
+from strands.tools.decorator import tool
+
 from ..event_loop.event_loop import event_loop_cycle, run_tool
 from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler
 from ..hooks import (
@@ -400,7 +402,8 @@ async def invoke_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: A
 
         return cast(AgentResult, event["result"])
 
-    def structured_output(self, output_model: Type[T], prompt: Optional[Union[str, list[ContentBlock]]] = None) -> T:
+    def structured_output(self, output_model: Type[T], prompt: Optional[Union[str, list[ContentBlock]]] = None,
+                          preserve_conversation: bool = False) -> T:
         """This method allows you to get structured output from the agent.
 
         If you pass in a prompt, it will be used temporarily without adding it to the conversation history.
@@ -413,20 +416,35 @@ def structured_output(self, output_model: Type[T], prompt: Optional[Union[str, l
             output_model: The output model (a JSON schema written as a Pydantic BaseModel)
                 that the agent will use when responding.
             prompt: The prompt to use for the agent (will not be added to conversation history).
+            preserve_conversation: If False (default), restores original conversation state after execution.
+                If True, allows structured output execution to modify conversation history.
 
         Raises:
             ValueError: If no conversation history or prompt is provided.
         """
 
         def execute() -> T:
-            return asyncio.run(self.structured_output_async(output_model, prompt))
+            return asyncio.run(self.structured_output_async(output_model, prompt, preserve_conversation))
 
         with ThreadPoolExecutor() as executor:
             future = executor.submit(execute)
             return future.result()
 
+    def _register_structured_output_tool(self, output_model: type[BaseModel]):
+        @tool
+        def _structured_output(input: output_model) -> output_model:
+            """If this tool is present it MUST be used to return structured data for the user."""
+            return input
+
+        return _structured_output
+
+    def _get_structured_output_tool(self, output_model: Type[T]):
+        """Get or create the structured output tool for the given model."""
+        return self._register_structured_output_tool(output_model)
+
     async def structured_output_async(
-        self, output_model: Type[T], prompt: Optional[Union[str, list[ContentBlock]]] = None
+        self, output_model: Type[T], prompt: Optional[Union[str, list[ContentBlock]]] = None,
+        preserve_conversation: bool = False
     ) -> T:
         """This method allows you to get structured output from the agent.
 
@@ -440,53 +458,163 @@ async def structured_output_async(
             output_model: The output model (a JSON schema written as a Pydantic BaseModel)
                 that the agent will use when responding.
             prompt: The prompt to use for the agent (will not be added to conversation history).
+            preserve_conversation: If False (default), restores original conversation state after execution.
+                If True, allows structured output execution to modify conversation history.
 
         Raises:
             ValueError: If no conversation history or prompt is provided.
         """
         self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self))
-        with self.tracer.tracer.start_as_current_span(
-            "execute_structured_output", kind=trace_api.SpanKind.CLIENT
-        ) as structured_output_span:
-            try:
-                if not self.messages and not prompt:
-                    raise ValueError("No conversation history or prompt provided")
-                # Create temporary messages array if prompt is provided
-                if prompt:
-                    content: list[ContentBlock] = [{"text": prompt}] if isinstance(prompt, str) else prompt
-                    temp_messages = self.messages + [{"role": "user", "content": content}]
-                else:
-                    temp_messages = self.messages
-
-                structured_output_span.set_attributes(
-                    {
-                        "gen_ai.system": "strands-agents",
-                        "gen_ai.agent.name": self.name,
-                        "gen_ai.agent.id": self.agent_id,
-                        "gen_ai.operation.name": "execute_structured_output",
-                    }
-                )
-                for message in temp_messages:
-                    structured_output_span.add_event(
-                        f"gen_ai.{message['role']}.message",
-                        attributes={"role": message["role"], "content": serialize(message["content"])},
+
+        # Save original state for restoration - avoid deep copying callbacks that contain async generators
+        import copy
+        from contextlib import suppress
+        from ..experimental.hooks import AfterToolInvocationEvent
+
+        # Store references to what we'll add temporarily
+        added_tool_name = None
+        added_callback = None
+
+        # Save original messages if we need to restore them later
+        original_messages = copy.deepcopy(self.messages) if not preserve_conversation else None
+
+        # Create and add the structured output tool
+        structured_output_tool = self._register_structured_output_tool(output_model)
+        self.tool_registry.register_tool(structured_output_tool)
+        added_tool_name = structured_output_tool.tool_name
+
+        # Variable to capture the structured result
+        captured_result = None
+
+        # Hook to capture structured output tool invocation
+        def capture_structured_output_hook(event: AfterToolInvocationEvent) -> AfterToolInvocationEvent:
+            nonlocal captured_result
+
+            if (
+                event.selected_tool
+                and hasattr(event.selected_tool, "tool_name")
+                and event.selected_tool.tool_name == "_structured_output"
+                and event.result
+                and event.result.get("status") == "success"
+            ):
+                # Parse the validated Pydantic model from the tool result
+                with suppress(Exception):
+                    content = event.result.get("content", [])
+                    if content and isinstance(content[0], dict) and "text" in content[0]:
+                        # The tool returns the model instance as string, but we need the actual instance
+                        # Since our tool returns the input directly, we can reconstruct it
+                        tool_input = event.tool_use.get("input", {}).get("input")
+                        if tool_input:
+                            captured_result = output_model(**tool_input)
+
+            return event
+
+        # Add the callback temporarily (use add_callback, not add_hook)
+        self.hooks.add_callback(AfterToolInvocationEvent, capture_structured_output_hook)
+        added_callback = capture_structured_output_hook
+
+        try:
+            with self.tracer.tracer.start_as_current_span(
+                "execute_structured_output", kind=trace_api.SpanKind.CLIENT
+            ) as structured_output_span:
+                try:
+                    if not self.messages and not prompt:
+                        raise ValueError("No conversation history or prompt provided")
+
+                    # Create temporary messages array if prompt is provided
+                    if prompt:
+                        content: list[ContentBlock] = [{"text": prompt}] if isinstance(prompt, str) else prompt
+                        message = {"role": "user", "content": content}
+                    else:
+                        # Use existing conversation history
+                        message = {
+                            "role": "user",
+                            "content": [
+                                {
+                                    "text": "Please provide the information from our conversation in the requested structured format."
+                                }
+                            ],
+                        }
+
+                    structured_output_span.set_attributes(
+                        {
+                            "gen_ai.system": "strands-agents",
+                            "gen_ai.agent.name": self.name,
+                            "gen_ai.agent.id": self.agent_id,
+                            "gen_ai.operation.name": "execute_structured_output",
+                        }
                     )
-                if self.system_prompt:
+
+                    # Add tracing for messages
+                    messages_to_trace = self.messages if not prompt else self.messages + [message]
+                    for msg in messages_to_trace:
+                        structured_output_span.add_event(
+                            f"gen_ai.{msg['role']}.message",
+                            attributes={"role": msg["role"], "content": serialize(msg["content"])},
+                        )
+
+                    if self.system_prompt:
+                        structured_output_span.add_event(
+                            "gen_ai.system.message",
+                            attributes={"role": "system", "content": serialize([{"text": self.system_prompt}])},
+                        )
+
+                    invocation_state = {
+                        "structured_output_mode": True,
+                        "structured_output_model": output_model,
+                    }
+
+                    # Run the event loop
+                    async for event in self._run_loop(message=message, invocation_state=invocation_state):
+                        if "stop" in event:
+                            break
+
+                    # Return the captured structured result if we got it from the tool
+                    if captured_result:
+                        structured_output_span.add_event(
+                            "gen_ai.choice", attributes={"message": serialize(captured_result.model_dump())}
+                        )
+                        return captured_result
+
+                    # Fallback: Use the original model.structured_output approach
+                    # This maintains backward compatibility with existing tests and implementations
+                    # Use original_messages to get clean message state, or self.messages if preserve_conversation=True
+                    base_messages = original_messages if original_messages is not None else self.messages
+                    temp_messages = base_messages if not prompt else base_messages + [message]
+
+                    events = self.model.structured_output(output_model, temp_messages, system_prompt=self.system_prompt)
+                    async for event in events:
+                        if "callback" in event:
+                            self.callback_handler(**cast(dict, event["callback"]))
+
                     structured_output_span.add_event(
-                        "gen_ai.system.message",
-                        attributes={"role": "system", "content": serialize([{"text": self.system_prompt}])},
+                        "gen_ai.choice", attributes={"message": serialize(event["output"].model_dump())}
                     )
-                events = self.model.structured_output(output_model, temp_messages, system_prompt=self.system_prompt)
-                async for event in events:
-                    if "callback" in event:
-                        self.callback_handler(**cast(dict, event["callback"]))
-                structured_output_span.add_event(
-                    "gen_ai.choice", attributes={"message": serialize(event["output"].model_dump())}
-                )
-                return event["output"]
-
-            finally:
-                self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self))
+                    return event["output"]
+
+                except Exception as e:
+                    structured_output_span.record_exception(e)
+                    raise
+
+        finally:
+            # Clean up what we added - remove the callback
+            if added_callback and AfterToolInvocationEvent in self.hooks._registered_callbacks:
+                callbacks = self.hooks._registered_callbacks[AfterToolInvocationEvent]
+                if added_callback in callbacks:
+                    callbacks.remove(added_callback)
+
+            # Remove the tool we added
+            if added_tool_name:
+                if added_tool_name in self.tool_registry.registry:
+                    del self.tool_registry.registry[added_tool_name]
+                if added_tool_name in self.tool_registry.dynamic_tools:
+                    del self.tool_registry.dynamic_tools[added_tool_name]
+
+            # Restore original messages if preserve_conversation is False
+            if original_messages is not None:
+                self.messages = original_messages
+
+            self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self))
 
     async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: Any) -> AsyncIterator[Any]:
         """Process a natural language prompt and yield events as an async iterator.

tests/strands/agent/test_agent_hooks.py

Lines changed: 18 additions & 10 deletions
@@ -263,30 +263,38 @@ def test_agent_structured_output_hooks(agent, hook_provider, user, agenerator):
     """Verify that the correct hook events are emitted as part of structured_output."""
 
     agent.model.structured_output = Mock(return_value=agenerator([{"output": user}]))
-    agent.structured_output(type(user), "example prompt")
+    agent.structured_output(type(user), "example prompt", preserve_conversation=True)
 
     length, events = hook_provider.get_events()
+    events_list = list(events)
 
-    assert length == 2
+    # With the new tool-based implementation, we get more events from the event loop
+    # but we should still have BeforeInvocationEvent first and AfterInvocationEvent last
+    assert length > 2  # More events due to event loop execution
 
-    assert next(events) == BeforeInvocationEvent(agent=agent)
-    assert next(events) == AfterInvocationEvent(agent=agent)
+    assert events_list[0] == BeforeInvocationEvent(agent=agent)
+    assert events_list[-1] == AfterInvocationEvent(agent=agent)
 
-    assert len(agent.messages) == 0  # no new messages added
+    # With the new tool-based implementation, messages are added during the structured output process
+    assert len(agent.messages) > 0  # messages added during structured output execution
 
 
 @pytest.mark.asyncio
 async def test_agent_structured_async_output_hooks(agent, hook_provider, user, agenerator):
     """Verify that the correct hook events are emitted as part of structured_output_async."""
 
     agent.model.structured_output = Mock(return_value=agenerator([{"output": user}]))
-    await agent.structured_output_async(type(user), "example prompt")
+    await agent.structured_output_async(type(user), "example prompt", preserve_conversation=True)
 
     length, events = hook_provider.get_events()
+    events_list = list(events)
 
-    assert length == 2
+    # With the new tool-based implementation, we get more events from the event loop
+    # but we should still have BeforeInvocationEvent first and AfterInvocationEvent last
+    assert length > 2  # More events due to event loop execution
 
-    assert next(events) == BeforeInvocationEvent(agent=agent)
-    assert next(events) == AfterInvocationEvent(agent=agent)
+    assert events_list[0] == BeforeInvocationEvent(agent=agent)
+    assert events_list[-1] == AfterInvocationEvent(agent=agent)
 
-    assert len(agent.messages) == 0  # no new messages added
+    # With the new tool-based implementation, messages are added during the structured output process
+    assert len(agent.messages) > 0  # messages added during structured output execution
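The async path mirrors this. A sketch under the same assumptions (ContactCard and the prompt are made up for illustration; structured_output_async and preserve_conversation are taken from the diff above):

import asyncio

from pydantic import BaseModel

from strands import Agent  # assumed public import path for the Agent class


class ContactCard(BaseModel):
    """Hypothetical output model, used only for illustration."""

    name: str
    email: str


async def main() -> None:
    agent = Agent()

    # As in the updated tests, preserve_conversation=True keeps the messages
    # produced during the agent-loop run in agent.messages.
    card = await agent.structured_output_async(
        ContactCard,
        "Extract the contact: Jane Doe, jane@example.com",
        preserve_conversation=True,
    )
    print(card.name, card.email)
    print(len(agent.messages))  # messages from the structured-output run remain


asyncio.run(main())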
