Skip to content

Commit cf37b62

Browse files
committed
feat(agent): make structured output part of the agent loop
1 parent 8c63d75 commit cf37b62

File tree

2 files changed

+188
-53
lines changed

2 files changed

+188
-53
lines changed

src/strands/agent/agent.py

Lines changed: 170 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,20 @@
1010
"""
1111

1212
import asyncio
13+
import copy
1314
import json
1415
import logging
1516
import random
1617
from concurrent.futures import ThreadPoolExecutor
18+
from contextlib import suppress
1719
from typing import Any, AsyncGenerator, AsyncIterator, Callable, Mapping, Optional, Type, TypeVar, Union, cast
1820

1921
from opentelemetry import trace as trace_api
2022
from pydantic import BaseModel
2123

2224
from .. import _identifier
2325
from ..event_loop.event_loop import event_loop_cycle, run_tool
26+
from ..experimental.hooks import AfterToolInvocationEvent
2427
from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler
2528
from ..hooks import (
2629
AfterInvocationEvent,
@@ -34,7 +37,8 @@
3437
from ..models.model import Model
3538
from ..session.session_manager import SessionManager
3639
from ..telemetry.metrics import EventLoopMetrics
37-
from ..telemetry.tracer import get_tracer, serialize
40+
from ..telemetry.tracer import get_tracer
41+
from ..tools.decorator import tool
3842
from ..tools.registry import ToolRegistry
3943
from ..tools.watcher import ToolWatcher
4044
from ..types.content import ContentBlock, Message, Messages
@@ -404,7 +408,12 @@ async def invoke_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: A
404408

405409
return cast(AgentResult, event["result"])
406410

407-
def structured_output(self, output_model: Type[T], prompt: Optional[Union[str, list[ContentBlock]]] = None) -> T:
411+
def structured_output(
412+
self,
413+
output_model: Type[T],
414+
prompt: Optional[Union[str, list[ContentBlock]]] = None,
415+
preserve_conversation: bool = False,
416+
) -> T:
408417
"""This method allows you to get structured output from the agent.
409418
410419
If you pass in a prompt, it will be used temporarily without adding it to the conversation history.
@@ -417,20 +426,33 @@ def structured_output(self, output_model: Type[T], prompt: Optional[Union[str, l
417426
output_model: The output model (a JSON schema written as a Pydantic BaseModel)
418427
that the agent will use when responding.
419428
prompt: The prompt to use for the agent (will not be added to conversation history).
429+
preserve_conversation: If False (default), restores original conversation state after execution.
430+
If True, allows structured output execution to modify conversation history.
420431
421432
Raises:
422433
ValueError: If no conversation history or prompt is provided.
423434
"""
424435

425436
def execute() -> T:
426-
return asyncio.run(self.structured_output_async(output_model, prompt))
437+
return asyncio.run(self.structured_output_async(output_model, prompt, preserve_conversation))
427438

428439
with ThreadPoolExecutor() as executor:
429440
future = executor.submit(execute)
430441
return future.result()
431442

443+
def _register_structured_output_tool(self, output_model: type[BaseModel]) -> Any:
444+
@tool
445+
def _structured_output(input: output_model) -> output_model: # type: ignore[valid-type]
446+
"""If this tool is present it MUST be used to return structured data for the user."""
447+
return input
448+
449+
return _structured_output
450+
432451
async def structured_output_async(
433-
self, output_model: Type[T], prompt: Optional[Union[str, list[ContentBlock]]] = None
452+
self,
453+
output_model: Type[T],
454+
prompt: Optional[Union[str, list[ContentBlock]]] = None,
455+
preserve_conversation: bool = False,
434456
) -> T:
435457
"""This method allows you to get structured output from the agent.
436458
@@ -444,53 +466,158 @@ async def structured_output_async(
444466
output_model: The output model (a JSON schema written as a Pydantic BaseModel)
445467
that the agent will use when responding.
446468
prompt: The prompt to use for the agent (will not be added to conversation history).
469+
preserve_conversation: If False (default), restores original conversation state after execution.
470+
If True, allows structured output execution to modify conversation history.
447471
448472
Raises:
449473
ValueError: If no conversation history or prompt is provided.
450474
"""
451475
self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self))
452-
with self.tracer.tracer.start_as_current_span(
453-
"execute_structured_output", kind=trace_api.SpanKind.CLIENT
454-
) as structured_output_span:
455-
try:
456-
if not self.messages and not prompt:
457-
raise ValueError("No conversation history or prompt provided")
458-
# Create temporary messages array if prompt is provided
459-
if prompt:
460-
content: list[ContentBlock] = [{"text": prompt}] if isinstance(prompt, str) else prompt
461-
temp_messages = self.messages + [{"role": "user", "content": content}]
462-
else:
463-
temp_messages = self.messages
464-
465-
structured_output_span.set_attributes(
466-
{
467-
"gen_ai.system": "strands-agents",
468-
"gen_ai.agent.name": self.name,
469-
"gen_ai.agent.id": self.agent_id,
470-
"gen_ai.operation.name": "execute_structured_output",
471-
}
472-
)
473-
for message in temp_messages:
474-
structured_output_span.add_event(
475-
f"gen_ai.{message['role']}.message",
476-
attributes={"role": message["role"], "content": serialize(message["content"])},
476+
477+
# Store references to what we'll add temporarily
478+
added_tool_name = None
479+
added_callback = None
480+
481+
# Save original messages if we need to restore them later
482+
original_messages = copy.deepcopy(self.messages) if not preserve_conversation else None
483+
484+
# Create and add the structured output tool
485+
structured_output_tool = self._register_structured_output_tool(output_model)
486+
self.tool_registry.register_tool(structured_output_tool)
487+
added_tool_name = structured_output_tool.tool_name
488+
489+
# Variable to capture the structured result
490+
captured_result = None
491+
492+
# Hook to capture structured output tool invocation
493+
def capture_structured_output_hook(event: AfterToolInvocationEvent) -> None:
494+
nonlocal captured_result
495+
496+
if (
497+
event.selected_tool
498+
and hasattr(event.selected_tool, "tool_name")
499+
and event.selected_tool.tool_name == "_structured_output"
500+
and event.result
501+
and event.result.get("status") == "success"
502+
):
503+
# Parse the validated Pydantic model from the tool result
504+
with suppress(Exception):
505+
content = event.result.get("content", [])
506+
if content and isinstance(content[0], dict) and "text" in content[0]:
507+
# The tool returns the model instance as string, but we need the actual instance
508+
# Since our tool returns the input directly, we can reconstruct it
509+
tool_input = event.tool_use.get("input", {}).get("input")
510+
if tool_input:
511+
captured_result = output_model(**tool_input)
512+
513+
# Add the callback temporarily (use add_callback, not add_hook)
514+
self.hooks.add_callback(AfterToolInvocationEvent, capture_structured_output_hook)
515+
added_callback = capture_structured_output_hook
516+
517+
try:
518+
with self.tracer.tracer.start_as_current_span(
519+
"execute_structured_output", kind=trace_api.SpanKind.CLIENT
520+
) as structured_output_span:
521+
try:
522+
if not self.messages and not prompt:
523+
raise ValueError("No conversation history or prompt provided")
524+
525+
# Create temporary messages array if prompt is provided
526+
message: Message
527+
if prompt:
528+
content: list[ContentBlock] = [{"text": prompt}] if isinstance(prompt, str) else prompt
529+
message = {"role": "user", "content": content}
530+
else:
531+
# Use existing conversation history
532+
message = {
533+
"role": "user",
534+
"content": [
535+
{
536+
"text": "Please provide the information from our conversation in the requested "
537+
"structured format."
538+
}
539+
],
540+
}
541+
542+
structured_output_span.set_attributes(
543+
{
544+
"gen_ai.system": "strands-agents",
545+
"gen_ai.agent.name": self.name,
546+
"gen_ai.agent.id": self.agent_id,
547+
"gen_ai.operation.name": "execute_structured_output",
548+
}
477549
)
478-
if self.system_prompt:
550+
551+
# Add tracing for messages
552+
messages_to_trace = self.messages if not prompt else self.messages + [message]
553+
for msg in messages_to_trace:
554+
structured_output_span.add_event(
555+
f"gen_ai.{msg['role']}.message",
556+
attributes={"role": msg["role"], "content": serialize(msg["content"])},
557+
)
558+
559+
if self.system_prompt:
560+
structured_output_span.add_event(
561+
"gen_ai.system.message",
562+
attributes={"role": "system", "content": serialize([{"text": self.system_prompt}])},
563+
)
564+
565+
invocation_state = {
566+
"structured_output_mode": True,
567+
"structured_output_model": output_model,
568+
}
569+
570+
# Run the event loop
571+
async for event in self._run_loop(message=message, invocation_state=invocation_state):
572+
if "stop" in event:
573+
break
574+
575+
# Return the captured structured result if we got it from the tool
576+
if captured_result:
577+
structured_output_span.add_event(
578+
"gen_ai.choice", attributes={"message": serialize(captured_result.model_dump())}
579+
)
580+
return captured_result
581+
582+
# Fallback: Use the original model.structured_output approach
583+
# This maintains backward compatibility with existing tests and implementations
584+
# Use original_messages to get clean message state, or self.messages if preserve_conversation=True
585+
base_messages = original_messages if original_messages is not None else self.messages
586+
temp_messages = base_messages if not prompt else base_messages + [message]
587+
588+
events = self.model.structured_output(output_model, temp_messages, system_prompt=self.system_prompt)
589+
async for event in events:
590+
if "callback" in event:
591+
self.callback_handler(**cast(dict, event["callback"]))
592+
479593
structured_output_span.add_event(
480-
"gen_ai.system.message",
481-
attributes={"role": "system", "content": serialize([{"text": self.system_prompt}])},
594+
"gen_ai.choice", attributes={"message": serialize(event["output"].model_dump())}
482595
)
483-
events = self.model.structured_output(output_model, temp_messages, system_prompt=self.system_prompt)
484-
async for event in events:
485-
if "callback" in event:
486-
self.callback_handler(**cast(dict, event["callback"]))
487-
structured_output_span.add_event(
488-
"gen_ai.choice", attributes={"message": serialize(event["output"].model_dump())}
489-
)
490-
return event["output"]
491-
492-
finally:
493-
self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self))
596+
return cast(T, event["output"])
597+
598+
except Exception as e:
599+
structured_output_span.record_exception(e)
600+
raise
601+
602+
finally:
603+
# Clean up what we added - remove the callback
604+
if added_callback is not None and AfterToolInvocationEvent in self.hooks._registered_callbacks:
605+
callbacks = self.hooks._registered_callbacks[AfterToolInvocationEvent]
606+
if added_callback in callbacks:
607+
callbacks.remove(added_callback)
608+
609+
# Remove the tool we added
610+
if added_tool_name:
611+
if added_tool_name in self.tool_registry.registry:
612+
del self.tool_registry.registry[added_tool_name]
613+
if added_tool_name in self.tool_registry.dynamic_tools:
614+
del self.tool_registry.dynamic_tools[added_tool_name]
615+
616+
# Restore original messages if preserve_conversation is False
617+
if original_messages is not None:
618+
self.messages = original_messages
619+
620+
self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self))
494621

495622
async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: Any) -> AsyncIterator[Any]:
496623
"""Process a natural language prompt and yield events as an async iterator.

tests/strands/agent/test_agent_hooks.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -263,30 +263,38 @@ def test_agent_structured_output_hooks(agent, hook_provider, user, agenerator):
263263
"""Verify that the correct hook events are emitted as part of structured_output."""
264264

265265
agent.model.structured_output = Mock(return_value=agenerator([{"output": user}]))
266-
agent.structured_output(type(user), "example prompt")
266+
agent.structured_output(type(user), "example prompt", preserve_conversation=True)
267267

268268
length, events = hook_provider.get_events()
269+
events_list = list(events)
269270

270-
assert length == 2
271+
# With the new tool-based implementation, we get more events from the event loop
272+
# but we should still have BeforeInvocationEvent first and AfterInvocationEvent last
273+
assert length > 2 # More events due to event loop execution
271274

272-
assert next(events) == BeforeInvocationEvent(agent=agent)
273-
assert next(events) == AfterInvocationEvent(agent=agent)
275+
assert events_list[0] == BeforeInvocationEvent(agent=agent)
276+
assert events_list[-1] == AfterInvocationEvent(agent=agent)
274277

275-
assert len(agent.messages) == 0 # no new messages added
278+
# With the new tool-based implementation, messages are added during the structured output process
279+
assert len(agent.messages) > 0 # messages added during structured output execution
276280

277281

278282
@pytest.mark.asyncio
279283
async def test_agent_structured_async_output_hooks(agent, hook_provider, user, agenerator):
280284
"""Verify that the correct hook events are emitted as part of structured_output_async."""
281285

282286
agent.model.structured_output = Mock(return_value=agenerator([{"output": user}]))
283-
await agent.structured_output_async(type(user), "example prompt")
287+
await agent.structured_output_async(type(user), "example prompt", preserve_conversation=True)
284288

285289
length, events = hook_provider.get_events()
290+
events_list = list(events)
286291

287-
assert length == 2
292+
# With the new tool-based implementation, we get more events from the event loop
293+
# but we should still have BeforeInvocationEvent first and AfterInvocationEvent last
294+
assert length > 2 # More events due to event loop execution
288295

289-
assert next(events) == BeforeInvocationEvent(agent=agent)
290-
assert next(events) == AfterInvocationEvent(agent=agent)
296+
assert events_list[0] == BeforeInvocationEvent(agent=agent)
297+
assert events_list[-1] == AfterInvocationEvent(agent=agent)
291298

292-
assert len(agent.messages) == 0 # no new messages added
299+
# With the new tool-based implementation, messages are added during the structured output process
300+
assert len(agent.messages) > 0 # messages added during structured output execution

0 commit comments

Comments
 (0)