10
10
"""
11
11
12
12
import asyncio
13
+ import copy
13
14
import json
14
15
import logging
15
16
import random
16
17
from concurrent .futures import ThreadPoolExecutor
18
+ from contextlib import suppress
17
19
from typing import Any , AsyncGenerator , AsyncIterator , Callable , Mapping , Optional , Type , TypeVar , Union , cast
18
20
19
21
from opentelemetry import trace as trace_api
20
22
from pydantic import BaseModel
21
23
24
+ from strands .tools .decorator import tool
25
+
22
26
from ..event_loop .event_loop import event_loop_cycle , run_tool
27
+ from ..experimental .hooks import AfterToolInvocationEvent
23
28
from ..handlers .callback_handler import PrintingCallbackHandler , null_callback_handler
24
29
from ..hooks import (
25
30
AfterInvocationEvent ,
@@ -400,7 +405,12 @@ async def invoke_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: A
400
405
401
406
return cast (AgentResult , event ["result" ])
402
407
403
def structured_output(
    self,
    output_model: Type[T],
    prompt: Optional[Union[str, list[ContentBlock]]] = None,
    preserve_conversation: bool = False,
) -> T:
    """This method allows you to get structured output from the agent.

    If you pass in a prompt, it will be used temporarily without adding it to the conversation history.

    Args:
        output_model: The output model (a JSON schema written as a Pydantic BaseModel)
            that the agent will use when responding.
        prompt: The prompt to use for the agent (will not be added to conversation history).
        preserve_conversation: If False (default), restores original conversation state after execution.
            If True, allows structured output execution to modify conversation history.

    Raises:
        ValueError: If no conversation history or prompt is provided.
    """

    def _drive_coroutine() -> T:
        # asyncio.run cannot be called from a thread that already has a
        # running event loop, so the async variant is driven on a worker
        # thread with its own fresh loop.
        return asyncio.run(self.structured_output_async(output_model, prompt, preserve_conversation))

    with ThreadPoolExecutor() as pool:
        return pool.submit(_drive_coroutine).result()
427
439
440
def _register_structured_output_tool(self, output_model: type[BaseModel]) -> Any:
    """Build a temporary tool whose input schema is *output_model*.

    The generated tool simply echoes its validated input back, so when the
    model invokes it the event loop ends up holding a fully validated
    instance of the requested Pydantic model.

    Args:
        output_model: Pydantic model class describing the desired output.

    Returns:
        The decorated tool object, ready for registration.
    """

    # NOTE: the function name becomes the tool name and the docstring is the
    # tool description shown to the model — both are behavioral, keep as-is.
    @tool
    def _structured_output(input: output_model) -> output_model:  # type: ignore[valid-type]
        """If this tool is present it MUST be used to return structured data for the user."""
        return input

    return _structured_output
447
+
428
448
async def structured_output_async(
    self,
    output_model: Type[T],
    prompt: Optional[Union[str, list[ContentBlock]]] = None,
    preserve_conversation: bool = False,
) -> T:
    """This method allows you to get structured output from the agent.

    If you pass in a prompt, it will be used temporarily without adding it to the conversation history.

    Args:
        output_model: The output model (a JSON schema written as a Pydantic BaseModel)
            that the agent will use when responding.
        prompt: The prompt to use for the agent (will not be added to conversation history).
        preserve_conversation: If False (default), restores original conversation state after execution.
            If True, allows structured output execution to modify conversation history.

    Raises:
        ValueError: If no conversation history or prompt is provided, or if the
            fallback model call produces no structured output.
    """
    self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self))

    # Snapshot the conversation so it can be restored afterwards, unless the
    # caller explicitly opted in to letting this call mutate history.
    original_messages = None if preserve_conversation else copy.deepcopy(self.messages)

    # Create and register a temporary tool whose input schema *is* the output
    # model; the model calling it yields a validated instance.
    structured_output_tool = self._register_structured_output_tool(output_model)
    self.tool_registry.register_tool(structured_output_tool)
    added_tool_name = structured_output_tool.tool_name

    # Populated by the hook below when the temporary tool runs successfully.
    captured_result: Optional[T] = None

    def capture_structured_output_hook(event: AfterToolInvocationEvent) -> None:
        """Capture the validated model instance from the temporary tool's invocation."""
        nonlocal captured_result

        if (
            event.selected_tool
            # Compare against the actual registered name instead of a
            # hard-coded "_structured_output" literal, so the hook stays in
            # sync with whatever name the tool was registered under.
            and getattr(event.selected_tool, "tool_name", None) == added_tool_name
            and event.result
            and event.result.get("status") == "success"
        ):
            # Parse the validated Pydantic model from the tool result.
            with suppress(Exception):
                content = event.result.get("content", [])
                if content and isinstance(content[0], dict) and "text" in content[0]:
                    # The tool echoes its input, so reconstruct the model from
                    # the raw tool-use input rather than parsing the
                    # stringified text in the result.
                    tool_input = event.tool_use.get("input", {}).get("input")
                    # `is not None` (not truthiness): an empty dict is a valid
                    # input for a model whose fields all have defaults.
                    if tool_input is not None:
                        captured_result = output_model(**tool_input)

    # Add the callback temporarily (use add_callback, not add_hook).
    self.hooks.add_callback(AfterToolInvocationEvent, capture_structured_output_hook)
    added_callback = capture_structured_output_hook

    try:
        with self.tracer.tracer.start_as_current_span(
            "execute_structured_output", kind=trace_api.SpanKind.CLIENT
        ) as structured_output_span:
            try:
                if not self.messages and not prompt:
                    raise ValueError("No conversation history or prompt provided")

                # Build the user message: either the supplied prompt, or a
                # nudge asking the model to format the existing conversation.
                message: Message
                if prompt:
                    content: list[ContentBlock] = [{"text": prompt}] if isinstance(prompt, str) else prompt
                    message = {"role": "user", "content": content}
                else:
                    # Use existing conversation history
                    message = {
                        "role": "user",
                        "content": [
                            {
                                "text": "Please provide the information from our conversation in the requested "
                                "structured format."
                            }
                        ],
                    }

                structured_output_span.set_attributes(
                    {
                        "gen_ai.system": "strands-agents",
                        "gen_ai.agent.name": self.name,
                        "gen_ai.agent.id": self.agent_id,
                        "gen_ai.operation.name": "execute_structured_output",
                    }
                )

                # Add tracing for messages.
                messages_to_trace = self.messages if not prompt else self.messages + [message]
                for msg in messages_to_trace:
                    structured_output_span.add_event(
                        f"gen_ai.{msg['role']}.message",
                        attributes={"role": msg["role"], "content": serialize(msg["content"])},
                    )

                if self.system_prompt:
                    structured_output_span.add_event(
                        "gen_ai.system.message",
                        attributes={"role": "system", "content": serialize([{"text": self.system_prompt}])},
                    )

                invocation_state = {
                    "structured_output_mode": True,
                    "structured_output_model": output_model,
                }

                # Run the event loop; the hook captures the tool result.
                async for loop_event in self._run_loop(message=message, invocation_state=invocation_state):
                    if "stop" in loop_event:
                        break

                # Return the captured structured result if the tool path worked.
                if captured_result is not None:
                    structured_output_span.add_event(
                        "gen_ai.choice", attributes={"message": serialize(captured_result.model_dump())}
                    )
                    return captured_result

                # Fallback: use the original model.structured_output approach.
                # This maintains backward compatibility with existing tests and
                # implementations. Use original_messages for a clean message
                # state, or self.messages if preserve_conversation=True.
                base_messages = original_messages if original_messages is not None else self.messages
                temp_messages = base_messages if not prompt else base_messages + [message]

                output: Optional[T] = None
                events = self.model.structured_output(output_model, temp_messages, system_prompt=self.system_prompt)
                async for model_event in events:
                    if "callback" in model_event:
                        self.callback_handler(**cast(dict, model_event["callback"]))
                    if "output" in model_event:
                        output = model_event["output"]

                # Previously this read `event["output"]` after the loop, which
                # raised NameError/KeyError when the generator yielded nothing.
                if output is None:
                    raise ValueError("Model returned no structured output")

                structured_output_span.add_event(
                    "gen_ai.choice", attributes={"message": serialize(output.model_dump())}
                )
                return cast(T, output)

            except Exception as e:
                structured_output_span.record_exception(e)
                raise

    finally:
        # Clean up the temporary hook. NOTE(review): this reaches into the
        # registry's private `_registered_callbacks` storage — confirm whether
        # the hooks registry exposes a public removal API to use instead.
        if AfterToolInvocationEvent in self.hooks._registered_callbacks:
            callbacks = self.hooks._registered_callbacks[AfterToolInvocationEvent]
            if added_callback in callbacks:
                callbacks.remove(added_callback)

        # Remove the temporary tool from both registries (no-op if absent).
        self.tool_registry.registry.pop(added_tool_name, None)
        self.tool_registry.dynamic_tools.pop(added_tool_name, None)

        # Restore original messages if preserve_conversation is False.
        # NOTE(review): this rebinds self.messages to a new list — confirm no
        # other component holds a reference to the old list object.
        if original_messages is not None:
            self.messages = original_messages

        self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self))
490
618
491
619
async def stream_async (self , prompt : Union [str , list [ContentBlock ]], ** kwargs : Any ) -> AsyncIterator [Any ]:
492
620
"""Process a natural language prompt and yield events as an async iterator.
0 commit comments