
Commit 33bc6e8

Author: Kazmer, Nagy-Betegh
Commit message: feat(agent): make structured output part of the agent loop
Parent: 1c7257b

File tree: 2 files changed, +193 -52 lines

src/strands/agent/agent.py

Lines changed: 175 additions & 42 deletions
@@ -10,16 +10,21 @@
 """
 
 import asyncio
+import copy
 import json
 import logging
 import random
 from concurrent.futures import ThreadPoolExecutor
+from contextlib import suppress
 from typing import Any, AsyncGenerator, AsyncIterator, Callable, Mapping, Optional, Type, TypeVar, Union, cast
 
 from opentelemetry import trace as trace_api
 from pydantic import BaseModel
 
+from strands.tools.decorator import tool
+
 from ..event_loop.event_loop import event_loop_cycle, run_tool
+from ..experimental.hooks import AfterToolInvocationEvent
 from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler
 from ..hooks import (
     AfterInvocationEvent,
@@ -400,7 +405,12 @@ async def invoke_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: A
 
         return cast(AgentResult, event["result"])
 
-    def structured_output(self, output_model: Type[T], prompt: Optional[Union[str, list[ContentBlock]]] = None) -> T:
+    def structured_output(
+        self,
+        output_model: Type[T],
+        prompt: Optional[Union[str, list[ContentBlock]]] = None,
+        preserve_conversation: bool = False,
+    ) -> T:
         """This method allows you to get structured output from the agent.
 
         If you pass in a prompt, it will be used temporarily without adding it to the conversation history.
@@ -413,20 +423,33 @@ def structured_output(self, output_model: Type[T], prompt: Optional[Union[str, l
             output_model: The output model (a JSON schema written as a Pydantic BaseModel)
                 that the agent will use when responding.
             prompt: The prompt to use for the agent (will not be added to conversation history).
+            preserve_conversation: If False (default), restores original conversation state after execution.
+                If True, allows structured output execution to modify conversation history.
 
         Raises:
             ValueError: If no conversation history or prompt is provided.
         """
 
         def execute() -> T:
-            return asyncio.run(self.structured_output_async(output_model, prompt))
+            return asyncio.run(self.structured_output_async(output_model, prompt, preserve_conversation))
 
         with ThreadPoolExecutor() as executor:
             future = executor.submit(execute)
             return future.result()
 
+    def _register_structured_output_tool(self, output_model: type[BaseModel]) -> Any:
+        @tool
+        def _structured_output(input: output_model) -> output_model:  # type: ignore[valid-type]
+            """If this tool is present it MUST be used to return structured data for the user."""
+            return input
+
+        return _structured_output
+
     async def structured_output_async(
-        self, output_model: Type[T], prompt: Optional[Union[str, list[ContentBlock]]] = None
+        self,
+        output_model: Type[T],
+        prompt: Optional[Union[str, list[ContentBlock]]] = None,
+        preserve_conversation: bool = False,
     ) -> T:
         """This method allows you to get structured output from the agent.
 
@@ -440,53 +463,163 @@ async def structured_output_async(
             output_model: The output model (a JSON schema written as a Pydantic BaseModel)
                 that the agent will use when responding.
             prompt: The prompt to use for the agent (will not be added to conversation history).
+            preserve_conversation: If False (default), restores original conversation state after execution.
+                If True, allows structured output execution to modify conversation history.
 
         Raises:
             ValueError: If no conversation history or prompt is provided.
         """
         self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self))
-        with self.tracer.tracer.start_as_current_span(
-            "execute_structured_output", kind=trace_api.SpanKind.CLIENT
-        ) as structured_output_span:
-            try:
-                if not self.messages and not prompt:
-                    raise ValueError("No conversation history or prompt provided")
-                # Create temporary messages array if prompt is provided
-                if prompt:
-                    content: list[ContentBlock] = [{"text": prompt}] if isinstance(prompt, str) else prompt
-                    temp_messages = self.messages + [{"role": "user", "content": content}]
-                else:
-                    temp_messages = self.messages
-
-                structured_output_span.set_attributes(
-                    {
-                        "gen_ai.system": "strands-agents",
-                        "gen_ai.agent.name": self.name,
-                        "gen_ai.agent.id": self.agent_id,
-                        "gen_ai.operation.name": "execute_structured_output",
-                    }
-                )
-                for message in temp_messages:
-                    structured_output_span.add_event(
-                        f"gen_ai.{message['role']}.message",
-                        attributes={"role": message["role"], "content": serialize(message["content"])},
+
+        # Save original state for restoration - avoid deep copying callbacks that contain async generators
+        import copy
+        from contextlib import suppress
+        from ..experimental.hooks import AfterToolInvocationEvent
+
+        # Store references to what we'll add temporarily
+        added_tool_name = None
+        added_callback = None
+
+        # Save original messages if we need to restore them later
+        original_messages = copy.deepcopy(self.messages) if not preserve_conversation else None
+
+        # Create and add the structured output tool
+        structured_output_tool = self._register_structured_output_tool(output_model)
+        self.tool_registry.register_tool(structured_output_tool)
+        added_tool_name = structured_output_tool.tool_name
+
+        # Variable to capture the structured result
+        captured_result = None
+
+        # Hook to capture structured output tool invocation
+        def capture_structured_output_hook(event: AfterToolInvocationEvent) -> None:
+            nonlocal captured_result
+
+            if (
+                event.selected_tool
+                and hasattr(event.selected_tool, "tool_name")
+                and event.selected_tool.tool_name == "_structured_output"
+                and event.result
+                and event.result.get("status") == "success"
+            ):
+                # Parse the validated Pydantic model from the tool result
+                with suppress(Exception):
+                    content = event.result.get("content", [])
+                    if content and isinstance(content[0], dict) and "text" in content[0]:
+                        # The tool returns the model instance as string, but we need the actual instance
+                        # Since our tool returns the input directly, we can reconstruct it
+                        tool_input = event.tool_use.get("input", {}).get("input")
+                        if tool_input:
+                            captured_result = output_model(**tool_input)
+
+        # Add the callback temporarily (use add_callback, not add_hook)
+        self.hooks.add_callback(AfterToolInvocationEvent, capture_structured_output_hook)
+        added_callback = capture_structured_output_hook
+
+        try:
+            with self.tracer.tracer.start_as_current_span(
+                "execute_structured_output", kind=trace_api.SpanKind.CLIENT
+            ) as structured_output_span:
+                try:
+                    if not self.messages and not prompt:
+                        raise ValueError("No conversation history or prompt provided")
+
+                    # Create temporary messages array if prompt is provided
+                    message: Message
+                    if prompt:
+                        content: list[ContentBlock] = [{"text": prompt}] if isinstance(prompt, str) else prompt
+                        message = {"role": "user", "content": content}
+                    else:
+                        # Use existing conversation history
+                        message = {
+                            "role": "user",
+                            "content": [
+                                {
+                                    "text": "Please provide the information from our conversation in the requested "
+                                    "structured format."
+                                }
+                            ],
+                        }
+
+                    structured_output_span.set_attributes(
+                        {
+                            "gen_ai.system": "strands-agents",
+                            "gen_ai.agent.name": self.name,
+                            "gen_ai.agent.id": self.agent_id,
+                            "gen_ai.operation.name": "execute_structured_output",
+                        }
                     )
-                if self.system_prompt:
+
+                    # Add tracing for messages
+                    messages_to_trace = self.messages if not prompt else self.messages + [message]
+                    for msg in messages_to_trace:
+                        structured_output_span.add_event(
+                            f"gen_ai.{msg['role']}.message",
+                            attributes={"role": msg["role"], "content": serialize(msg["content"])},
+                        )
+
+                    if self.system_prompt:
+                        structured_output_span.add_event(
+                            "gen_ai.system.message",
+                            attributes={"role": "system", "content": serialize([{"text": self.system_prompt}])},
+                        )
+
+                    invocation_state = {
+                        "structured_output_mode": True,
+                        "structured_output_model": output_model,
+                    }
+
+                    # Run the event loop
+                    async for event in self._run_loop(message=message, invocation_state=invocation_state):
+                        if "stop" in event:
+                            break
+
+                    # Return the captured structured result if we got it from the tool
+                    if captured_result:
+                        structured_output_span.add_event(
+                            "gen_ai.choice", attributes={"message": serialize(captured_result.model_dump())}
+                        )
+                        return captured_result
+
+                    # Fallback: Use the original model.structured_output approach
+                    # This maintains backward compatibility with existing tests and implementations
+                    # Use original_messages to get clean message state, or self.messages if preserve_conversation=True
+                    base_messages = original_messages if original_messages is not None else self.messages
+                    temp_messages = base_messages if not prompt else base_messages + [message]
+
+                    events = self.model.structured_output(output_model, temp_messages, system_prompt=self.system_prompt)
+                    async for event in events:
+                        if "callback" in event:
+                            self.callback_handler(**cast(dict, event["callback"]))
+
                     structured_output_span.add_event(
-                        "gen_ai.system.message",
-                        attributes={"role": "system", "content": serialize([{"text": self.system_prompt}])},
+                        "gen_ai.choice", attributes={"message": serialize(event["output"].model_dump())}
                     )
-                events = self.model.structured_output(output_model, temp_messages, system_prompt=self.system_prompt)
-                async for event in events:
-                    if "callback" in event:
-                        self.callback_handler(**cast(dict, event["callback"]))
-                structured_output_span.add_event(
-                    "gen_ai.choice", attributes={"message": serialize(event["output"].model_dump())}
-                )
-                return event["output"]
-
-            finally:
-                self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self))
+                    return cast(T, event["output"])
+
+                except Exception as e:
+                    structured_output_span.record_exception(e)
+                    raise
+
+        finally:
+            # Clean up what we added - remove the callback
+            if added_callback is not None and AfterToolInvocationEvent in self.hooks._registered_callbacks:
+                callbacks = self.hooks._registered_callbacks[AfterToolInvocationEvent]
+                if added_callback in callbacks:
+                    callbacks.remove(added_callback)
+
+            # Remove the tool we added
+            if added_tool_name:
+                if added_tool_name in self.tool_registry.registry:
+                    del self.tool_registry.registry[added_tool_name]
+                if added_tool_name in self.tool_registry.dynamic_tools:
+                    del self.tool_registry.dynamic_tools[added_tool_name]
+
+            # Restore original messages if preserve_conversation is False
+            if original_messages is not None:
+                self.messages = original_messages
+
+            self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self))
 
     async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: Any) -> AsyncIterator[Any]:
         """Process a natural language prompt and yield events as an async iterator.

tests/strands/agent/test_agent_hooks.py

Lines changed: 18 additions & 10 deletions
@@ -263,30 +263,38 @@ def test_agent_structured_output_hooks(agent, hook_provider, user, agenerator):
     """Verify that the correct hook events are emitted as part of structured_output."""
 
     agent.model.structured_output = Mock(return_value=agenerator([{"output": user}]))
-    agent.structured_output(type(user), "example prompt")
+    agent.structured_output(type(user), "example prompt", preserve_conversation=True)
 
     length, events = hook_provider.get_events()
+    events_list = list(events)
 
-    assert length == 2
+    # With the new tool-based implementation, we get more events from the event loop
+    # but we should still have BeforeInvocationEvent first and AfterInvocationEvent last
+    assert length > 2  # More events due to event loop execution
 
-    assert next(events) == BeforeInvocationEvent(agent=agent)
-    assert next(events) == AfterInvocationEvent(agent=agent)
+    assert events_list[0] == BeforeInvocationEvent(agent=agent)
+    assert events_list[-1] == AfterInvocationEvent(agent=agent)
 
-    assert len(agent.messages) == 0  # no new messages added
+    # With the new tool-based implementation, messages are added during the structured output process
+    assert len(agent.messages) > 0  # messages added during structured output execution
 
 
 @pytest.mark.asyncio
 async def test_agent_structured_async_output_hooks(agent, hook_provider, user, agenerator):
     """Verify that the correct hook events are emitted as part of structured_output_async."""
 
     agent.model.structured_output = Mock(return_value=agenerator([{"output": user}]))
-    await agent.structured_output_async(type(user), "example prompt")
+    await agent.structured_output_async(type(user), "example prompt", preserve_conversation=True)
 
     length, events = hook_provider.get_events()
+    events_list = list(events)
 
-    assert length == 2
+    # With the new tool-based implementation, we get more events from the event loop
+    # but we should still have BeforeInvocationEvent first and AfterInvocationEvent last
+    assert length > 2  # More events due to event loop execution
 
-    assert next(events) == BeforeInvocationEvent(agent=agent)
-    assert next(events) == AfterInvocationEvent(agent=agent)
+    assert events_list[0] == BeforeInvocationEvent(agent=agent)
+    assert events_list[-1] == AfterInvocationEvent(agent=agent)
 
-    assert len(agent.messages) == 0  # no new messages added
+    # With the new tool-based implementation, messages are added during the structured output process
+    assert len(agent.messages) > 0  # messages added during structured output execution
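
As a standalone sketch of the capture step that `capture_structured_output_hook` performs in the diff above (re-validating the nested tool input as the requested Pydantic model), using a hypothetical tool-use payload; the model and values are invented for illustration.

from pydantic import BaseModel


class ContactInfo(BaseModel):
    """Hypothetical output model used only for illustration."""

    name: str
    email: str


# Shape of a tool-use record for the temporary _structured_output tool;
# the payload values here are made up for the example.
tool_use = {
    "name": "_structured_output",
    "input": {"input": {"name": "John Doe", "email": "john@example.com"}},
}

# The hook pulls the nested "input" field and re-validates it with the output model,
# which is what structured_output_async ultimately returns as captured_result.
tool_input = tool_use.get("input", {}).get("input")
if tool_input:
    captured_result = ContactInfo(**tool_input)
    print(captured_result)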

0 commit comments
