10
10
"""
11
11
12
12
import asyncio
13
+ import copy
13
14
import json
14
15
import logging
15
16
import random
16
17
from concurrent .futures import ThreadPoolExecutor
18
+ from contextlib import suppress
17
19
from typing import Any , AsyncGenerator , AsyncIterator , Callable , Mapping , Optional , Type , TypeVar , Union , cast
18
20
19
21
from opentelemetry import trace as trace_api
20
22
from pydantic import BaseModel
21
23
24
+ from strands .tools .decorator import tool
25
+
22
26
from ..event_loop .event_loop import event_loop_cycle , run_tool
27
+ from ..experimental .hooks import AfterToolInvocationEvent
23
28
from ..handlers .callback_handler import PrintingCallbackHandler , null_callback_handler
24
29
from ..hooks import (
25
30
AfterInvocationEvent ,
@@ -400,7 +405,12 @@ async def invoke_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: A
400
405
401
406
return cast (AgentResult , event ["result" ])
402
407
403
- def structured_output (self , output_model : Type [T ], prompt : Optional [Union [str , list [ContentBlock ]]] = None ) -> T :
408
+ def structured_output (
409
+ self ,
410
+ output_model : Type [T ],
411
+ prompt : Optional [Union [str , list [ContentBlock ]]] = None ,
412
+ preserve_conversation : bool = False ,
413
+ ) -> T :
404
414
"""This method allows you to get structured output from the agent.
405
415
406
416
If you pass in a prompt, it will be used temporarily without adding it to the conversation history.
@@ -413,20 +423,33 @@ def structured_output(self, output_model: Type[T], prompt: Optional[Union[str, l
413
423
output_model: The output model (a JSON schema written as a Pydantic BaseModel)
414
424
that the agent will use when responding.
415
425
prompt: The prompt to use for the agent (will not be added to conversation history).
426
+ preserve_conversation: If False (default), restores original conversation state after execution.
427
+ If True, allows structured output execution to modify conversation history.
416
428
417
429
Raises:
418
430
ValueError: If no conversation history or prompt is provided.
419
431
"""
420
432
421
433
def execute () -> T :
422
- return asyncio .run (self .structured_output_async (output_model , prompt ))
434
+ return asyncio .run (self .structured_output_async (output_model , prompt , preserve_conversation ))
423
435
424
436
with ThreadPoolExecutor () as executor :
425
437
future = executor .submit (execute )
426
438
return future .result ()
427
439
440
+ def _register_structured_output_tool (self , output_model : type [BaseModel ]) -> Any :
441
+ @tool
442
+ def _structured_output (input : output_model ) -> output_model : # type: ignore[valid-type]
443
+ """If this tool is present it MUST be used to return structured data for the user."""
444
+ return input
445
+
446
+ return _structured_output
447
+
428
448
async def structured_output_async (
429
- self , output_model : Type [T ], prompt : Optional [Union [str , list [ContentBlock ]]] = None
449
+ self ,
450
+ output_model : Type [T ],
451
+ prompt : Optional [Union [str , list [ContentBlock ]]] = None ,
452
+ preserve_conversation : bool = False ,
430
453
) -> T :
431
454
"""This method allows you to get structured output from the agent.
432
455
@@ -440,53 +463,163 @@ async def structured_output_async(
440
463
output_model: The output model (a JSON schema written as a Pydantic BaseModel)
441
464
that the agent will use when responding.
442
465
prompt: The prompt to use for the agent (will not be added to conversation history).
466
+ preserve_conversation: If False (default), restores original conversation state after execution.
467
+ If True, allows structured output execution to modify conversation history.
443
468
444
469
Raises:
445
470
ValueError: If no conversation history or prompt is provided.
446
471
"""
447
472
self .hooks .invoke_callbacks (BeforeInvocationEvent (agent = self ))
448
- with self .tracer .tracer .start_as_current_span (
449
- "execute_structured_output" , kind = trace_api .SpanKind .CLIENT
450
- ) as structured_output_span :
451
- try :
452
- if not self .messages and not prompt :
453
- raise ValueError ("No conversation history or prompt provided" )
454
- # Create temporary messages array if prompt is provided
455
- if prompt :
456
- content : list [ContentBlock ] = [{"text" : prompt }] if isinstance (prompt , str ) else prompt
457
- temp_messages = self .messages + [{"role" : "user" , "content" : content }]
458
- else :
459
- temp_messages = self .messages
460
-
461
- structured_output_span .set_attributes (
462
- {
463
- "gen_ai.system" : "strands-agents" ,
464
- "gen_ai.agent.name" : self .name ,
465
- "gen_ai.agent.id" : self .agent_id ,
466
- "gen_ai.operation.name" : "execute_structured_output" ,
467
- }
468
- )
469
- for message in temp_messages :
470
- structured_output_span .add_event (
471
- f"gen_ai.{ message ['role' ]} .message" ,
472
- attributes = {"role" : message ["role" ], "content" : serialize (message ["content" ])},
473
+
474
+ # Save original state for restoration - avoid deep copying callbacks that contain async generators
475
+ import copy
476
+ from contextlib import suppress
477
+ from ..experimental .hooks import AfterToolInvocationEvent
478
+
479
+ # Store references to what we'll add temporarily
480
+ added_tool_name = None
481
+ added_callback = None
482
+
483
+ # Save original messages if we need to restore them later
484
+ original_messages = copy .deepcopy (self .messages ) if not preserve_conversation else None
485
+
486
+ # Create and add the structured output tool
487
+ structured_output_tool = self ._register_structured_output_tool (output_model )
488
+ self .tool_registry .register_tool (structured_output_tool )
489
+ added_tool_name = structured_output_tool .tool_name
490
+
491
+ # Variable to capture the structured result
492
+ captured_result = None
493
+
494
+ # Hook to capture structured output tool invocation
495
+ def capture_structured_output_hook (event : AfterToolInvocationEvent ) -> None :
496
+ nonlocal captured_result
497
+
498
+ if (
499
+ event .selected_tool
500
+ and hasattr (event .selected_tool , "tool_name" )
501
+ and event .selected_tool .tool_name == "_structured_output"
502
+ and event .result
503
+ and event .result .get ("status" ) == "success"
504
+ ):
505
+ # Parse the validated Pydantic model from the tool result
506
+ with suppress (Exception ):
507
+ content = event .result .get ("content" , [])
508
+ if content and isinstance (content [0 ], dict ) and "text" in content [0 ]:
509
+ # The tool returns the model instance as string, but we need the actual instance
510
+ # Since our tool returns the input directly, we can reconstruct it
511
+ tool_input = event .tool_use .get ("input" , {}).get ("input" )
512
+ if tool_input :
513
+ captured_result = output_model (** tool_input )
514
+
515
+ # Add the callback temporarily (use add_callback, not add_hook)
516
+ self .hooks .add_callback (AfterToolInvocationEvent , capture_structured_output_hook )
517
+ added_callback = capture_structured_output_hook
518
+
519
+ try :
520
+ with self .tracer .tracer .start_as_current_span (
521
+ "execute_structured_output" , kind = trace_api .SpanKind .CLIENT
522
+ ) as structured_output_span :
523
+ try :
524
+ if not self .messages and not prompt :
525
+ raise ValueError ("No conversation history or prompt provided" )
526
+
527
+ # Create temporary messages array if prompt is provided
528
+ message : Message
529
+ if prompt :
530
+ content : list [ContentBlock ] = [{"text" : prompt }] if isinstance (prompt , str ) else prompt
531
+ message = {"role" : "user" , "content" : content }
532
+ else :
533
+ # Use existing conversation history
534
+ message = {
535
+ "role" : "user" ,
536
+ "content" : [
537
+ {
538
+ "text" : "Please provide the information from our conversation in the requested "
539
+ "structured format."
540
+ }
541
+ ],
542
+ }
543
+
544
+ structured_output_span .set_attributes (
545
+ {
546
+ "gen_ai.system" : "strands-agents" ,
547
+ "gen_ai.agent.name" : self .name ,
548
+ "gen_ai.agent.id" : self .agent_id ,
549
+ "gen_ai.operation.name" : "execute_structured_output" ,
550
+ }
473
551
)
474
- if self .system_prompt :
552
+
553
+ # Add tracing for messages
554
+ messages_to_trace = self .messages if not prompt else self .messages + [message ]
555
+ for msg in messages_to_trace :
556
+ structured_output_span .add_event (
557
+ f"gen_ai.{ msg ['role' ]} .message" ,
558
+ attributes = {"role" : msg ["role" ], "content" : serialize (msg ["content" ])},
559
+ )
560
+
561
+ if self .system_prompt :
562
+ structured_output_span .add_event (
563
+ "gen_ai.system.message" ,
564
+ attributes = {"role" : "system" , "content" : serialize ([{"text" : self .system_prompt }])},
565
+ )
566
+
567
+ invocation_state = {
568
+ "structured_output_mode" : True ,
569
+ "structured_output_model" : output_model ,
570
+ }
571
+
572
+ # Run the event loop
573
+ async for event in self ._run_loop (message = message , invocation_state = invocation_state ):
574
+ if "stop" in event :
575
+ break
576
+
577
+ # Return the captured structured result if we got it from the tool
578
+ if captured_result :
579
+ structured_output_span .add_event (
580
+ "gen_ai.choice" , attributes = {"message" : serialize (captured_result .model_dump ())}
581
+ )
582
+ return captured_result
583
+
584
+ # Fallback: Use the original model.structured_output approach
585
+ # This maintains backward compatibility with existing tests and implementations
586
+ # Use original_messages to get clean message state, or self.messages if preserve_conversation=True
587
+ base_messages = original_messages if original_messages is not None else self .messages
588
+ temp_messages = base_messages if not prompt else base_messages + [message ]
589
+
590
+ events = self .model .structured_output (output_model , temp_messages , system_prompt = self .system_prompt )
591
+ async for event in events :
592
+ if "callback" in event :
593
+ self .callback_handler (** cast (dict , event ["callback" ]))
594
+
475
595
structured_output_span .add_event (
476
- "gen_ai.system.message" ,
477
- attributes = {"role" : "system" , "content" : serialize ([{"text" : self .system_prompt }])},
596
+ "gen_ai.choice" , attributes = {"message" : serialize (event ["output" ].model_dump ())}
478
597
)
479
- events = self .model .structured_output (output_model , temp_messages , system_prompt = self .system_prompt )
480
- async for event in events :
481
- if "callback" in event :
482
- self .callback_handler (** cast (dict , event ["callback" ]))
483
- structured_output_span .add_event (
484
- "gen_ai.choice" , attributes = {"message" : serialize (event ["output" ].model_dump ())}
485
- )
486
- return event ["output" ]
487
-
488
- finally :
489
- self .hooks .invoke_callbacks (AfterInvocationEvent (agent = self ))
598
+ return cast (T , event ["output" ])
599
+
600
+ except Exception as e :
601
+ structured_output_span .record_exception (e )
602
+ raise
603
+
604
+ finally :
605
+ # Clean up what we added - remove the callback
606
+ if added_callback is not None and AfterToolInvocationEvent in self .hooks ._registered_callbacks :
607
+ callbacks = self .hooks ._registered_callbacks [AfterToolInvocationEvent ]
608
+ if added_callback in callbacks :
609
+ callbacks .remove (added_callback )
610
+
611
+ # Remove the tool we added
612
+ if added_tool_name :
613
+ if added_tool_name in self .tool_registry .registry :
614
+ del self .tool_registry .registry [added_tool_name ]
615
+ if added_tool_name in self .tool_registry .dynamic_tools :
616
+ del self .tool_registry .dynamic_tools [added_tool_name ]
617
+
618
+ # Restore original messages if preserve_conversation is False
619
+ if original_messages is not None :
620
+ self .messages = original_messages
621
+
622
+ self .hooks .invoke_callbacks (AfterInvocationEvent (agent = self ))
490
623
491
624
async def stream_async (self , prompt : Union [str , list [ContentBlock ]], ** kwargs : Any ) -> AsyncIterator [Any ]:
492
625
"""Process a natural language prompt and yield events as an async iterator.
0 commit comments