Skip to content

Commit 23c8cb2

Browse files
author
Kazmer, Nagy-Betegh
committed
feat(agent): make structured output part of the agent loop
1 parent 1c7257b commit 23c8cb2

File tree

2 files changed

+188
-52
lines changed

2 files changed

+188
-52
lines changed

src/strands/agent/agent.py

Lines changed: 170 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,21 @@
1010
"""
1111

1212
import asyncio
13+
import copy
1314
import json
1415
import logging
1516
import random
1617
from concurrent.futures import ThreadPoolExecutor
18+
from contextlib import suppress
1719
from typing import Any, AsyncGenerator, AsyncIterator, Callable, Mapping, Optional, Type, TypeVar, Union, cast
1820

1921
from opentelemetry import trace as trace_api
2022
from pydantic import BaseModel
2123

24+
from strands.tools.decorator import tool
25+
2226
from ..event_loop.event_loop import event_loop_cycle, run_tool
27+
from ..experimental.hooks import AfterToolInvocationEvent
2328
from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler
2429
from ..hooks import (
2530
AfterInvocationEvent,
@@ -400,7 +405,12 @@ async def invoke_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: A
400405

401406
return cast(AgentResult, event["result"])
402407

403-
def structured_output(self, output_model: Type[T], prompt: Optional[Union[str, list[ContentBlock]]] = None) -> T:
408+
def structured_output(
409+
self,
410+
output_model: Type[T],
411+
prompt: Optional[Union[str, list[ContentBlock]]] = None,
412+
preserve_conversation: bool = False,
413+
) -> T:
404414
"""This method allows you to get structured output from the agent.
405415
406416
If you pass in a prompt, it will be used temporarily without adding it to the conversation history.
@@ -413,20 +423,33 @@ def structured_output(self, output_model: Type[T], prompt: Optional[Union[str, l
413423
output_model: The output model (a JSON schema written as a Pydantic BaseModel)
414424
that the agent will use when responding.
415425
prompt: The prompt to use for the agent (will not be added to conversation history).
426+
preserve_conversation: If False (default), restores original conversation state after execution.
427+
If True, allows structured output execution to modify conversation history.
416428
417429
Raises:
418430
ValueError: If no conversation history or prompt is provided.
419431
"""
420432

421433
def execute() -> T:
422-
return asyncio.run(self.structured_output_async(output_model, prompt))
434+
return asyncio.run(self.structured_output_async(output_model, prompt, preserve_conversation))
423435

424436
with ThreadPoolExecutor() as executor:
425437
future = executor.submit(execute)
426438
return future.result()
427439

440+
def _register_structured_output_tool(self, output_model: type[BaseModel]) -> Any:
441+
@tool
442+
def _structured_output(input: output_model) -> output_model: # type: ignore[valid-type]
443+
"""If this tool is present it MUST be used to return structured data for the user."""
444+
return input
445+
446+
return _structured_output
447+
428448
async def structured_output_async(
429-
self, output_model: Type[T], prompt: Optional[Union[str, list[ContentBlock]]] = None
449+
self,
450+
output_model: Type[T],
451+
prompt: Optional[Union[str, list[ContentBlock]]] = None,
452+
preserve_conversation: bool = False,
430453
) -> T:
431454
"""This method allows you to get structured output from the agent.
432455
@@ -440,53 +463,158 @@ async def structured_output_async(
440463
output_model: The output model (a JSON schema written as a Pydantic BaseModel)
441464
that the agent will use when responding.
442465
prompt: The prompt to use for the agent (will not be added to conversation history).
466+
preserve_conversation: If False (default), restores original conversation state after execution.
467+
If True, allows structured output execution to modify conversation history.
443468
444469
Raises:
445470
ValueError: If no conversation history or prompt is provided.
446471
"""
447472
self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self))
448-
with self.tracer.tracer.start_as_current_span(
449-
"execute_structured_output", kind=trace_api.SpanKind.CLIENT
450-
) as structured_output_span:
451-
try:
452-
if not self.messages and not prompt:
453-
raise ValueError("No conversation history or prompt provided")
454-
# Create temporary messages array if prompt is provided
455-
if prompt:
456-
content: list[ContentBlock] = [{"text": prompt}] if isinstance(prompt, str) else prompt
457-
temp_messages = self.messages + [{"role": "user", "content": content}]
458-
else:
459-
temp_messages = self.messages
460-
461-
structured_output_span.set_attributes(
462-
{
463-
"gen_ai.system": "strands-agents",
464-
"gen_ai.agent.name": self.name,
465-
"gen_ai.agent.id": self.agent_id,
466-
"gen_ai.operation.name": "execute_structured_output",
467-
}
468-
)
469-
for message in temp_messages:
470-
structured_output_span.add_event(
471-
f"gen_ai.{message['role']}.message",
472-
attributes={"role": message["role"], "content": serialize(message["content"])},
473+
474+
# Store references to what we'll add temporarily
475+
added_tool_name = None
476+
added_callback = None
477+
478+
# Save original messages if we need to restore them later
479+
original_messages = copy.deepcopy(self.messages) if not preserve_conversation else None
480+
481+
# Create and add the structured output tool
482+
structured_output_tool = self._register_structured_output_tool(output_model)
483+
self.tool_registry.register_tool(structured_output_tool)
484+
added_tool_name = structured_output_tool.tool_name
485+
486+
# Variable to capture the structured result
487+
captured_result = None
488+
489+
# Hook to capture structured output tool invocation
490+
def capture_structured_output_hook(event: AfterToolInvocationEvent) -> None:
491+
nonlocal captured_result
492+
493+
if (
494+
event.selected_tool
495+
and hasattr(event.selected_tool, "tool_name")
496+
and event.selected_tool.tool_name == "_structured_output"
497+
and event.result
498+
and event.result.get("status") == "success"
499+
):
500+
# Parse the validated Pydantic model from the tool result
501+
with suppress(Exception):
502+
content = event.result.get("content", [])
503+
if content and isinstance(content[0], dict) and "text" in content[0]:
504+
# The tool returns the model instance as string, but we need the actual instance
505+
# Since our tool returns the input directly, we can reconstruct it
506+
tool_input = event.tool_use.get("input", {}).get("input")
507+
if tool_input:
508+
captured_result = output_model(**tool_input)
509+
510+
# Add the callback temporarily (use add_callback, not add_hook)
511+
self.hooks.add_callback(AfterToolInvocationEvent, capture_structured_output_hook)
512+
added_callback = capture_structured_output_hook
513+
514+
try:
515+
with self.tracer.tracer.start_as_current_span(
516+
"execute_structured_output", kind=trace_api.SpanKind.CLIENT
517+
) as structured_output_span:
518+
try:
519+
if not self.messages and not prompt:
520+
raise ValueError("No conversation history or prompt provided")
521+
522+
# Create temporary messages array if prompt is provided
523+
message: Message
524+
if prompt:
525+
content: list[ContentBlock] = [{"text": prompt}] if isinstance(prompt, str) else prompt
526+
message = {"role": "user", "content": content}
527+
else:
528+
# Use existing conversation history
529+
message = {
530+
"role": "user",
531+
"content": [
532+
{
533+
"text": "Please provide the information from our conversation in the requested "
534+
"structured format."
535+
}
536+
],
537+
}
538+
539+
structured_output_span.set_attributes(
540+
{
541+
"gen_ai.system": "strands-agents",
542+
"gen_ai.agent.name": self.name,
543+
"gen_ai.agent.id": self.agent_id,
544+
"gen_ai.operation.name": "execute_structured_output",
545+
}
473546
)
474-
if self.system_prompt:
547+
548+
# Add tracing for messages
549+
messages_to_trace = self.messages if not prompt else self.messages + [message]
550+
for msg in messages_to_trace:
551+
structured_output_span.add_event(
552+
f"gen_ai.{msg['role']}.message",
553+
attributes={"role": msg["role"], "content": serialize(msg["content"])},
554+
)
555+
556+
if self.system_prompt:
557+
structured_output_span.add_event(
558+
"gen_ai.system.message",
559+
attributes={"role": "system", "content": serialize([{"text": self.system_prompt}])},
560+
)
561+
562+
invocation_state = {
563+
"structured_output_mode": True,
564+
"structured_output_model": output_model,
565+
}
566+
567+
# Run the event loop
568+
async for event in self._run_loop(message=message, invocation_state=invocation_state):
569+
if "stop" in event:
570+
break
571+
572+
# Return the captured structured result if we got it from the tool
573+
if captured_result:
574+
structured_output_span.add_event(
575+
"gen_ai.choice", attributes={"message": serialize(captured_result.model_dump())}
576+
)
577+
return captured_result
578+
579+
# Fallback: Use the original model.structured_output approach
580+
# This maintains backward compatibility with existing tests and implementations
581+
# Use original_messages to get clean message state, or self.messages if preserve_conversation=True
582+
base_messages = original_messages if original_messages is not None else self.messages
583+
temp_messages = base_messages if not prompt else base_messages + [message]
584+
585+
events = self.model.structured_output(output_model, temp_messages, system_prompt=self.system_prompt)
586+
async for event in events:
587+
if "callback" in event:
588+
self.callback_handler(**cast(dict, event["callback"]))
589+
475590
structured_output_span.add_event(
476-
"gen_ai.system.message",
477-
attributes={"role": "system", "content": serialize([{"text": self.system_prompt}])},
591+
"gen_ai.choice", attributes={"message": serialize(event["output"].model_dump())}
478592
)
479-
events = self.model.structured_output(output_model, temp_messages, system_prompt=self.system_prompt)
480-
async for event in events:
481-
if "callback" in event:
482-
self.callback_handler(**cast(dict, event["callback"]))
483-
structured_output_span.add_event(
484-
"gen_ai.choice", attributes={"message": serialize(event["output"].model_dump())}
485-
)
486-
return event["output"]
487-
488-
finally:
489-
self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self))
593+
return cast(T, event["output"])
594+
595+
except Exception as e:
596+
structured_output_span.record_exception(e)
597+
raise
598+
599+
finally:
600+
# Clean up what we added - remove the callback
601+
if added_callback is not None and AfterToolInvocationEvent in self.hooks._registered_callbacks:
602+
callbacks = self.hooks._registered_callbacks[AfterToolInvocationEvent]
603+
if added_callback in callbacks:
604+
callbacks.remove(added_callback)
605+
606+
# Remove the tool we added
607+
if added_tool_name:
608+
if added_tool_name in self.tool_registry.registry:
609+
del self.tool_registry.registry[added_tool_name]
610+
if added_tool_name in self.tool_registry.dynamic_tools:
611+
del self.tool_registry.dynamic_tools[added_tool_name]
612+
613+
# Restore original messages if preserve_conversation is False
614+
if original_messages is not None:
615+
self.messages = original_messages
616+
617+
self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self))
490618

491619
async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: Any) -> AsyncIterator[Any]:
492620
"""Process a natural language prompt and yield events as an async iterator.

tests/strands/agent/test_agent_hooks.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -263,30 +263,38 @@ def test_agent_structured_output_hooks(agent, hook_provider, user, agenerator):
263263
"""Verify that the correct hook events are emitted as part of structured_output."""
264264

265265
agent.model.structured_output = Mock(return_value=agenerator([{"output": user}]))
266-
agent.structured_output(type(user), "example prompt")
266+
agent.structured_output(type(user), "example prompt", preserve_conversation=True)
267267

268268
length, events = hook_provider.get_events()
269+
events_list = list(events)
269270

270-
assert length == 2
271+
# With the new tool-based implementation, we get more events from the event loop
272+
# but we should still have BeforeInvocationEvent first and AfterInvocationEvent last
273+
assert length > 2 # More events due to event loop execution
271274

272-
assert next(events) == BeforeInvocationEvent(agent=agent)
273-
assert next(events) == AfterInvocationEvent(agent=agent)
275+
assert events_list[0] == BeforeInvocationEvent(agent=agent)
276+
assert events_list[-1] == AfterInvocationEvent(agent=agent)
274277

275-
assert len(agent.messages) == 0 # no new messages added
278+
# With the new tool-based implementation, messages are added during the structured output process
279+
assert len(agent.messages) > 0 # messages added during structured output execution
276280

277281

278282
@pytest.mark.asyncio
279283
async def test_agent_structured_async_output_hooks(agent, hook_provider, user, agenerator):
280284
"""Verify that the correct hook events are emitted as part of structured_output_async."""
281285

282286
agent.model.structured_output = Mock(return_value=agenerator([{"output": user}]))
283-
await agent.structured_output_async(type(user), "example prompt")
287+
await agent.structured_output_async(type(user), "example prompt", preserve_conversation=True)
284288

285289
length, events = hook_provider.get_events()
290+
events_list = list(events)
286291

287-
assert length == 2
292+
# With the new tool-based implementation, we get more events from the event loop
293+
# but we should still have BeforeInvocationEvent first and AfterInvocationEvent last
294+
assert length > 2 # More events due to event loop execution
288295

289-
assert next(events) == BeforeInvocationEvent(agent=agent)
290-
assert next(events) == AfterInvocationEvent(agent=agent)
296+
assert events_list[0] == BeforeInvocationEvent(agent=agent)
297+
assert events_list[-1] == AfterInvocationEvent(agent=agent)
291298

292-
assert len(agent.messages) == 0 # no new messages added
299+
# With the new tool-based implementation, messages are added during the structured output process
300+
assert len(agent.messages) > 0 # messages added during structured output execution

0 commit comments

Comments
 (0)