"""

import asyncio
+import copy
import json
import logging
import random
from concurrent.futures import ThreadPoolExecutor
+from contextlib import suppress
from typing import Any, AsyncGenerator, AsyncIterator, Callable, Mapping, Optional, Type, TypeVar, Union, cast

from opentelemetry import trace as trace_api
from pydantic import BaseModel

from .. import _identifier
from ..event_loop.event_loop import event_loop_cycle, run_tool
+from ..experimental.hooks import AfterToolInvocationEvent
from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler
from ..hooks import (
    AfterInvocationEvent,
@@ -34,7 +37,8 @@
from ..models.model import Model
from ..session.session_manager import SessionManager
from ..telemetry.metrics import EventLoopMetrics
from ..telemetry.tracer import get_tracer, serialize
+from ..tools.decorator import tool
from ..tools.registry import ToolRegistry
from ..tools.watcher import ToolWatcher
from ..types.content import ContentBlock, Message, Messages
@@ -404,7 +408,12 @@ async def invoke_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: Any) -> AgentResult:

        return cast(AgentResult, event["result"])

-    def structured_output(self, output_model: Type[T], prompt: Optional[Union[str, list[ContentBlock]]] = None) -> T:
+    def structured_output(
+        self,
+        output_model: Type[T],
+        prompt: Optional[Union[str, list[ContentBlock]]] = None,
+        preserve_conversation: bool = False,
+    ) -> T:
        """This method allows you to get structured output from the agent.

        If you pass in a prompt, it will be used temporarily without adding it to the conversation history.
@@ -417,20 +426,33 @@ def structured_output(self, output_model: Type[T], prompt: Optional[Union[str, list[ContentBlock]]] = None) -> T:
            output_model: The output model (a JSON schema written as a Pydantic BaseModel)
                that the agent will use when responding.
            prompt: The prompt to use for the agent (will not be added to conversation history).
+            preserve_conversation: If False (default), restores original conversation state after execution.
+                If True, allows structured output execution to modify conversation history.

        Raises:
            ValueError: If no conversation history or prompt is provided.
        """

        def execute() -> T:
-            return asyncio.run(self.structured_output_async(output_model, prompt))
+            return asyncio.run(self.structured_output_async(output_model, prompt, preserve_conversation))

        with ThreadPoolExecutor() as executor:
            future = executor.submit(execute)
            return future.result()

+    def _register_structured_output_tool(self, output_model: type[BaseModel]) -> Any:
+        @tool
+        def _structured_output(input: output_model) -> output_model:  # type: ignore[valid-type]
+            """If this tool is present it MUST be used to return structured data for the user."""
+            return input
+
+        return _structured_output
+
    async def structured_output_async(
-        self, output_model: Type[T], prompt: Optional[Union[str, list[ContentBlock]]] = None
+        self,
+        output_model: Type[T],
+        prompt: Optional[Union[str, list[ContentBlock]]] = None,
+        preserve_conversation: bool = False,
    ) -> T:
        """This method allows you to get structured output from the agent.

@@ -444,53 +466,158 @@ async def structured_output_async(
            output_model: The output model (a JSON schema written as a Pydantic BaseModel)
                that the agent will use when responding.
            prompt: The prompt to use for the agent (will not be added to conversation history).
+            preserve_conversation: If False (default), restores original conversation state after execution.
+                If True, allows structured output execution to modify conversation history.

        Raises:
            ValueError: If no conversation history or prompt is provided.
        """
        self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self))
-        with self.tracer.tracer.start_as_current_span(
-            "execute_structured_output", kind=trace_api.SpanKind.CLIENT
-        ) as structured_output_span:
-            try:
-                if not self.messages and not prompt:
-                    raise ValueError("No conversation history or prompt provided")
-                # Create temporary messages array if prompt is provided
-                if prompt:
-                    content: list[ContentBlock] = [{"text": prompt}] if isinstance(prompt, str) else prompt
-                    temp_messages = self.messages + [{"role": "user", "content": content}]
-                else:
-                    temp_messages = self.messages
-
-                structured_output_span.set_attributes(
-                    {
-                        "gen_ai.system": "strands-agents",
-                        "gen_ai.agent.name": self.name,
-                        "gen_ai.agent.id": self.agent_id,
-                        "gen_ai.operation.name": "execute_structured_output",
-                    }
-                )
-                for message in temp_messages:
-                    structured_output_span.add_event(
-                        f"gen_ai.{message['role']}.message",
-                        attributes={"role": message["role"], "content": serialize(message["content"])},
+
+        # Store references to what we'll add temporarily
+        added_tool_name = None
+        added_callback = None
+
+        # Save original messages if we need to restore them later
+        original_messages = copy.deepcopy(self.messages) if not preserve_conversation else None
+
+        # Create and add the structured output tool
+        structured_output_tool = self._register_structured_output_tool(output_model)
+        self.tool_registry.register_tool(structured_output_tool)
+        added_tool_name = structured_output_tool.tool_name
+
+        # Variable to capture the structured result
+        captured_result = None
+
+        # Hook to capture structured output tool invocation
+        def capture_structured_output_hook(event: AfterToolInvocationEvent) -> None:
+            nonlocal captured_result
+
+            if (
+                event.selected_tool
+                and hasattr(event.selected_tool, "tool_name")
+                and event.selected_tool.tool_name == "_structured_output"
+                and event.result
+                and event.result.get("status") == "success"
+            ):
+                # Parse the validated Pydantic model from the tool result
+                with suppress(Exception):
+                    content = event.result.get("content", [])
+                    if content and isinstance(content[0], dict) and "text" in content[0]:
+                        # The tool returns the model instance as a string, but we need the actual instance
+                        # Since our tool returns the input directly, we can reconstruct it
+                        tool_input = event.tool_use.get("input", {}).get("input")
+                        if tool_input:
+                            captured_result = output_model(**tool_input)
+
+        # Add the callback temporarily (use add_callback, not add_hook)
+        self.hooks.add_callback(AfterToolInvocationEvent, capture_structured_output_hook)
+        added_callback = capture_structured_output_hook
+
+        try:
+            with self.tracer.tracer.start_as_current_span(
+                "execute_structured_output", kind=trace_api.SpanKind.CLIENT
+            ) as structured_output_span:
+                try:
+                    if not self.messages and not prompt:
+                        raise ValueError("No conversation history or prompt provided")
+
+                    # Create temporary messages array if prompt is provided
+                    message: Message
+                    if prompt:
+                        content: list[ContentBlock] = [{"text": prompt}] if isinstance(prompt, str) else prompt
+                        message = {"role": "user", "content": content}
+                    else:
+                        # Use existing conversation history
+                        message = {
+                            "role": "user",
+                            "content": [
+                                {
+                                    "text": "Please provide the information from our conversation in the requested "
+                                    "structured format."
+                                }
+                            ],
+                        }
+
+                    structured_output_span.set_attributes(
+                        {
+                            "gen_ai.system": "strands-agents",
+                            "gen_ai.agent.name": self.name,
+                            "gen_ai.agent.id": self.agent_id,
+                            "gen_ai.operation.name": "execute_structured_output",
+                        }
                    )
-                if self.system_prompt:
+
+                    # Add tracing for messages
+                    messages_to_trace = self.messages if not prompt else self.messages + [message]
+                    for msg in messages_to_trace:
+                        structured_output_span.add_event(
+                            f"gen_ai.{msg['role']}.message",
+                            attributes={"role": msg["role"], "content": serialize(msg["content"])},
+                        )
+
+                    if self.system_prompt:
+                        structured_output_span.add_event(
+                            "gen_ai.system.message",
+                            attributes={"role": "system", "content": serialize([{"text": self.system_prompt}])},
+                        )
+
+                    invocation_state = {
+                        "structured_output_mode": True,
+                        "structured_output_model": output_model,
+                    }
+
+                    # Run the event loop
+                    async for event in self._run_loop(message=message, invocation_state=invocation_state):
+                        if "stop" in event:
+                            break
+
+                    # Return the captured structured result if we got it from the tool
+                    if captured_result:
+                        structured_output_span.add_event(
+                            "gen_ai.choice", attributes={"message": serialize(captured_result.model_dump())}
+                        )
+                        return captured_result
+
+                    # Fallback: Use the original model.structured_output approach
+                    # This maintains backward compatibility with existing tests and implementations
+                    # Use original_messages to get clean message state, or self.messages if preserve_conversation=True
+                    base_messages = original_messages if original_messages is not None else self.messages
+                    temp_messages = base_messages if not prompt else base_messages + [message]
+
+                    events = self.model.structured_output(output_model, temp_messages, system_prompt=self.system_prompt)
+                    async for event in events:
+                        if "callback" in event:
+                            self.callback_handler(**cast(dict, event["callback"]))
+
                    structured_output_span.add_event(
-                        "gen_ai.system.message",
-                        attributes={"role": "system", "content": serialize([{"text": self.system_prompt}])},
+                        "gen_ai.choice", attributes={"message": serialize(event["output"].model_dump())}
                    )
-                events = self.model.structured_output(output_model, temp_messages, system_prompt=self.system_prompt)
-                async for event in events:
-                    if "callback" in event:
-                        self.callback_handler(**cast(dict, event["callback"]))
-                    structured_output_span.add_event(
-                        "gen_ai.choice", attributes={"message": serialize(event["output"].model_dump())}
-                    )
-                return event["output"]
-
-            finally:
-                self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self))
+                    return cast(T, event["output"])
+
+                except Exception as e:
+                    structured_output_span.record_exception(e)
+                    raise
+
+        finally:
+            # Clean up what we added - remove the callback
+            if added_callback is not None and AfterToolInvocationEvent in self.hooks._registered_callbacks:
+                callbacks = self.hooks._registered_callbacks[AfterToolInvocationEvent]
+                if added_callback in callbacks:
+                    callbacks.remove(added_callback)
+
+            # Remove the tool we added
+            if added_tool_name:
+                if added_tool_name in self.tool_registry.registry:
+                    del self.tool_registry.registry[added_tool_name]
+                if added_tool_name in self.tool_registry.dynamic_tools:
+                    del self.tool_registry.dynamic_tools[added_tool_name]
+
+            # Restore original messages if preserve_conversation is False
+            if original_messages is not None:
+                self.messages = original_messages
+
+            self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self))

    async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: Any) -> AsyncIterator[Any]:
        """Process a natural language prompt and yield events as an async iterator.
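For reference, a minimal usage sketch of the `preserve_conversation` parameter introduced above (not part of the diff). The `ContactInfo` model is illustrative, and the `from strands import Agent` import and the no-argument `Agent()` constructor are assumptions about the surrounding package; only the `structured_output(...)` call shape comes from the change itself.

```python
from pydantic import BaseModel

from strands import Agent  # assumed import path for the Agent class patched above


class ContactInfo(BaseModel):
    """Illustrative output model; any Pydantic BaseModel works."""

    name: str
    email: str


agent = Agent()  # assumes a default model provider is configured

# Default: preserve_conversation=False, so the temporary prompt and the internal
# _structured_output tool call are rolled back from agent.messages afterwards.
contact = agent.structured_output(ContactInfo, "Extract: Jane Doe, jane@example.com")
print(contact.name, contact.email)

# Opt in to keeping the structured-output exchange in the conversation history.
contact = agent.structured_output(
    ContactInfo,
    "Extract: Jane Doe, jane@example.com",
    preserve_conversation=True,
)
```

The async variant, `structured_output_async`, takes the same arguments and returns the same validated model instance.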