diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py
index 5150060c6..52aea2683 100644
--- a/src/strands/agent/agent.py
+++ b/src/strands/agent/agent.py
@@ -354,14 +354,21 @@ def tool_names(self) -> list[str]:
         all_tools = self.tool_registry.get_all_tools_config()
         return list(all_tools.keys())
 
-    def __call__(self, prompt: Union[str, list[ContentBlock]], **kwargs: Any) -> AgentResult:
+    def __call__(self, prompt: str | list[ContentBlock] | Messages | None = None, **kwargs: Any) -> AgentResult:
         """Process a natural language prompt through the agent's event loop.
 
-        This method implements the conversational interface (e.g., `agent("hello!")`). It adds the user's prompt to
-        the conversation history, processes it through the model, executes any tool calls, and returns the final result.
+        This method implements the conversational interface with multiple input patterns:
+        - String input: `agent("hello!")`
+        - ContentBlock list: `agent([{"text": "hello"}, {"image": {...}}])`
+        - Message list: `agent([{"role": "user", "content": [{"text": "hello"}]}])`
+        - No input: `agent()` - uses existing conversation history
 
         Args:
-            prompt: User input as text or list of ContentBlock objects for multi-modal content.
+            prompt: User input in various formats:
+                - str: Simple text input
+                - list[ContentBlock]: Multi-modal content blocks
+                - list[Message]: Complete messages with roles
+                - None: Use existing conversation history
             **kwargs: Additional parameters to pass through the event loop.
 
         Returns:
@@ -380,14 +387,23 @@ def execute() -> AgentResult:
             future = executor.submit(execute)
             return future.result()
 
-    async def invoke_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: Any) -> AgentResult:
+    async def invoke_async(
+        self, prompt: str | list[ContentBlock] | Messages | None = None, **kwargs: Any
+    ) -> AgentResult:
         """Process a natural language prompt through the agent's event loop.
 
-        This method implements the conversational interface (e.g., `agent("hello!")`). It adds the user's prompt to
-        the conversation history, processes it through the model, executes any tool calls, and returns the final result.
+        This method implements the conversational interface with multiple input patterns:
+        - String input: Simple text input
+        - ContentBlock list: Multi-modal content blocks
+        - Message list: Complete messages with roles
+        - No input: Use existing conversation history
 
         Args:
-            prompt: User input as text or list of ContentBlock objects for multi-modal content.
+            prompt: User input in various formats:
+                - str: Simple text input
+                - list[ContentBlock]: Multi-modal content blocks
+                - list[Message]: Complete messages with roles
+                - None: Use existing conversation history
             **kwargs: Additional parameters to pass through the event loop.
 
         Returns:
@@ -404,7 +420,7 @@ async def invoke_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: A
 
         return cast(AgentResult, event["result"])
 
-    def structured_output(self, output_model: Type[T], prompt: Optional[Union[str, list[ContentBlock]]] = None) -> T:
+    def structured_output(self, output_model: Type[T], prompt: str | list[ContentBlock] | Messages | None = None) -> T:
         """This method allows you to get structured output from the agent.
 
         If you pass in a prompt, it will be used temporarily without adding it to the conversation history.
@@ -416,7 +432,11 @@ def structured_output(self, output_model: Type[T], prompt: Optional[Union[str, l
         Args:
             output_model: The output model (a JSON schema written as a Pydantic BaseModel) that the agent will use
                 when responding.
-            prompt: The prompt to use for the agent (will not be added to conversation history).
+            prompt: The prompt to use for the agent in various formats:
+                - str: Simple text input
+                - list[ContentBlock]: Multi-modal content blocks
+                - list[Message]: Complete messages with roles
+                - None: Use existing conversation history
 
         Raises:
             ValueError: If no conversation history or prompt is provided.
@@ -430,7 +450,7 @@ def execute() -> T:
             return future.result()
 
     async def structured_output_async(
-        self, output_model: Type[T], prompt: Optional[Union[str, list[ContentBlock]]] = None
+        self, output_model: Type[T], prompt: str | list[ContentBlock] | Messages | None = None
     ) -> T:
         """This method allows you to get structured output from the agent.
 
@@ -455,12 +475,8 @@ async def structured_output_async(
         try:
             if not self.messages and not prompt:
                 raise ValueError("No conversation history or prompt provided")
-            # Create temporary messages array if prompt is provided
-            if prompt:
-                content: list[ContentBlock] = [{"text": prompt}] if isinstance(prompt, str) else prompt
-                temp_messages = self.messages + [{"role": "user", "content": content}]
-            else:
-                temp_messages = self.messages
+
+            temp_messages: Messages = self.messages + self._convert_prompt_to_messages(prompt)
 
             structured_output_span.set_attributes(
                 {
@@ -492,16 +508,25 @@ async def structured_output_async(
         finally:
             self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self))
 
-    async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: Any) -> AsyncIterator[Any]:
+    async def stream_async(
+        self,
+        prompt: str | list[ContentBlock] | Messages | None = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[Any]:
         """Process a natural language prompt and yield events as an async iterator.
 
-        This method provides an asynchronous interface for streaming agent events, allowing
-        consumers to process stream events programmatically through an async iterator pattern
-        rather than callback functions. This is particularly useful for web servers and other
-        async environments.
+        This method provides an asynchronous interface for streaming agent events with multiple input patterns:
+        - String input: Simple text input
+        - ContentBlock list: Multi-modal content blocks
+        - Message list: Complete messages with roles
+        - No input: Use existing conversation history
 
         Args:
-            prompt: User input as text or list of ContentBlock objects for multi-modal content.
+            prompt: User input in various formats:
+                - str: Simple text input
+                - list[ContentBlock]: Multi-modal content blocks
+                - list[Message]: Complete messages with roles
+                - None: Use existing conversation history
             **kwargs: Additional parameters to pass to the event loop.
 
         Yields:
@@ -525,13 +550,15 @@ async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: A
         """
         callback_handler = kwargs.get("callback_handler", self.callback_handler)
 
-        content: list[ContentBlock] = [{"text": prompt}] if isinstance(prompt, str) else prompt
-        message: Message = {"role": "user", "content": content}
+        # Process input and get messages to add (if any)
+        messages = self._convert_prompt_to_messages(prompt)
+
+        self.trace_span = self._start_agent_trace_span(messages)
 
-        self.trace_span = self._start_agent_trace_span(message)
         with trace_api.use_span(self.trace_span):
             try:
-                events = self._run_loop(message, invocation_state=kwargs)
+                events = self._run_loop(messages, invocation_state=kwargs)
+
                 async for event in events:
                     if "callback" in event:
                         callback_handler(**event["callback"])
@@ -548,12 +575,12 @@ async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: A
                 raise
 
     async def _run_loop(
-        self, message: Message, invocation_state: dict[str, Any]
+        self, messages: Messages, invocation_state: dict[str, Any]
    ) -> AsyncGenerator[dict[str, Any], None]:
        """Execute the agent's event loop with the given message and parameters.
 
        Args:
-            message: The user message to add to the conversation.
+            messages: The input messages to add to the conversation.
            invocation_state: Additional parameters to pass to the event loop.
 
        Yields:
@@ -564,7 +591,8 @@ async def _run_loop(
        try:
            yield {"callback": {"init_event_loop": True, **invocation_state}}
 
-            self._append_message(message)
+            for message in messages:
+                self._append_message(message)
 
            # Execute the event loop cycle with retry logic for context limits
            events = self._execute_event_loop_cycle(invocation_state)
@@ -622,6 +650,34 @@ async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> A
        async for event in events:
            yield event
 
+    def _convert_prompt_to_messages(self, prompt: str | list[ContentBlock] | Messages | None) -> Messages:
+        messages: Messages | None = None
+        if prompt is not None:
+            if isinstance(prompt, str):
+                # String input - convert to user message
+                messages = [{"role": "user", "content": [{"text": prompt}]}]
+            elif isinstance(prompt, list):
+                if len(prompt) == 0:
+                    # Empty list
+                    messages = []
+                # Check if all items in the input list are dictionaries
+                elif all(isinstance(item, dict) for item in prompt):
+                    # Check if all items are messages
+                    if all(all(key in item for key in Message.__annotations__.keys()) for item in prompt):
+                        # Messages input - add all messages to conversation
+                        messages = cast(Messages, prompt)
+
+                    # Check if all items are content blocks
+                    elif all(any(key in ContentBlock.__annotations__.keys() for key in item) for item in prompt):
+                        # Treat as list[ContentBlock] input - convert to user message
+                        # This allows invalid structures to be passed through to the model
+                        messages = [{"role": "user", "content": cast(list[ContentBlock], prompt)}]
+        else:
+            messages = []
+        if messages is None:
+            raise ValueError("Input prompt must be of type: `str | list[ContentBlock] | Messages | None`.")
+        return messages
+
    def _record_tool_execution(
        self,
        tool: ToolUse,
@@ -687,15 +743,15 @@ def _record_tool_execution(
        self._append_message(tool_result_msg)
        self._append_message(assistant_msg)
 
-    def _start_agent_trace_span(self, message: Message) -> trace_api.Span:
+    def _start_agent_trace_span(self, messages: Messages) -> trace_api.Span:
        """Starts a trace span for the agent.
 
        Args:
-            message: The user message.
+            messages: The input messages.
""" model_id = self.model.config.get("model_id") if hasattr(self.model, "config") else None return self.tracer.start_agent_span( - message=message, + messages=messages, agent_name=self.name, model_id=model_id, tools=self.tool_names, diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index 802865189..6b429393d 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -408,7 +408,7 @@ def end_event_loop_cycle_span( def start_agent_span( self, - message: Message, + messages: Messages, agent_name: str, model_id: Optional[str] = None, tools: Optional[list] = None, @@ -418,7 +418,7 @@ def start_agent_span( """Start a new span for an agent invocation. Args: - message: The user message being sent to the agent. + messages: List of messages being sent to the agent. agent_name: Name of the agent. model_id: Optional model identifier. tools: Optional list of tools being used. @@ -451,13 +451,12 @@ def start_agent_span( span = self._start_span( f"invoke_agent {agent_name}", attributes=attributes, span_kind=trace_api.SpanKind.CLIENT ) - self._add_event( - span, - "gen_ai.user.message", - event_attributes={ - "content": serialize(message["content"]), - }, - ) + for message in messages: + self._add_event( + span, + f"gen_ai.{message['role']}.message", + {"content": serialize(message["content"])}, + ) return span diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 7e769c6d7..ae951c196 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -1360,12 +1360,12 @@ def test_agent_call_creates_and_ends_span_on_success(mock_get_tracer, mock_model # Verify span was created mock_tracer.start_agent_span.assert_called_once_with( + messages=[{"content": [{"text": "test prompt"}], "role": "user"}], agent_name="Strands Agents", - custom_trace_attributes=agent.trace_attributes, - message={"content": [{"text": "test prompt"}], "role": "user"}, model_id=unittest.mock.ANY, - system_prompt=agent.system_prompt, tools=agent.tool_names, + system_prompt=agent.system_prompt, + custom_trace_attributes=agent.trace_attributes, ) # Verify span was ended with the result @@ -1394,12 +1394,12 @@ async def test_event_loop(*args, **kwargs): # Verify span was created mock_tracer.start_agent_span.assert_called_once_with( - custom_trace_attributes=agent.trace_attributes, + messages=[{"content": [{"text": "test prompt"}], "role": "user"}], agent_name="Strands Agents", - message={"content": [{"text": "test prompt"}], "role": "user"}, model_id=unittest.mock.ANY, - system_prompt=agent.system_prompt, tools=agent.tool_names, + system_prompt=agent.system_prompt, + custom_trace_attributes=agent.trace_attributes, ) expected_response = AgentResult( @@ -1432,12 +1432,12 @@ def test_agent_call_creates_and_ends_span_on_exception(mock_get_tracer, mock_mod # Verify span was created mock_tracer.start_agent_span.assert_called_once_with( - custom_trace_attributes=agent.trace_attributes, + messages=[{"content": [{"text": "test prompt"}], "role": "user"}], agent_name="Strands Agents", - message={"content": [{"text": "test prompt"}], "role": "user"}, model_id=unittest.mock.ANY, - system_prompt=agent.system_prompt, tools=agent.tool_names, + system_prompt=agent.system_prompt, + custom_trace_attributes=agent.trace_attributes, ) # Verify span was ended with the exception @@ -1468,12 +1468,12 @@ async def test_agent_stream_async_creates_and_ends_span_on_exception(mock_get_tr # Verify span was created 
    mock_tracer.start_agent_span.assert_called_once_with(
+        messages=[{"content": [{"text": "test prompt"}], "role": "user"}],
        agent_name="Strands Agents",
-        custom_trace_attributes=agent.trace_attributes,
-        message={"content": [{"text": "test prompt"}], "role": "user"},
        model_id=unittest.mock.ANY,
-        system_prompt=agent.system_prompt,
        tools=agent.tool_names,
+        system_prompt=agent.system_prompt,
+        custom_trace_attributes=agent.trace_attributes,
    )
 
    # Verify span was ended with the exception
@@ -1801,6 +1801,63 @@ def test_agent_tool_record_direct_tool_call_disabled_with_non_serializable(agent
    assert len(agent.messages) == 0
 
 
+def test_agent_empty_invoke():
+    model = MockedModelProvider([{"role": "assistant", "content": [{"text": "hello!"}]}])
+    agent = Agent(model=model, messages=[{"role": "user", "content": [{"text": "hello!"}]}])
+    result = agent()
+    assert str(result) == "hello!\n"
+    assert len(agent.messages) == 2
+
+
+def test_agent_empty_list_invoke():
+    model = MockedModelProvider([{"role": "assistant", "content": [{"text": "hello!"}]}])
+    agent = Agent(model=model, messages=[{"role": "user", "content": [{"text": "hello!"}]}])
+    result = agent([])
+    assert str(result) == "hello!\n"
+    assert len(agent.messages) == 2
+
+
+def test_agent_with_assistant_role_message():
+    model = MockedModelProvider([{"role": "assistant", "content": [{"text": "world!"}]}])
+    agent = Agent(model=model)
+    assistant_message = [{"role": "assistant", "content": [{"text": "hello..."}]}]
+    result = agent(assistant_message)
+    assert str(result) == "world!\n"
+    assert len(agent.messages) == 2
+
+
+def test_agent_with_multiple_messages_on_invoke():
+    model = MockedModelProvider([{"role": "assistant", "content": [{"text": "world!"}]}])
+    agent = Agent(model=model)
+    input_messages = [
+        {"role": "user", "content": [{"text": "hello"}]},
+        {"role": "assistant", "content": [{"text": "..."}]},
+    ]
+    result = agent(input_messages)
+    assert str(result) == "world!\n"
+    assert len(agent.messages) == 3
+
+
+def test_agent_with_invalid_input():
+    model = MockedModelProvider([{"role": "assistant", "content": [{"text": "world!"}]}])
+    agent = Agent(model=model)
+    with pytest.raises(ValueError, match="Input prompt must be of type: `str | list[ContentBlock] | Messages | None`."):
+        agent({"invalid": "input"})
+
+
+def test_agent_with_invalid_input_list():
+    model = MockedModelProvider([{"role": "assistant", "content": [{"text": "world!"}]}])
+    agent = Agent(model=model)
+    with pytest.raises(ValueError, match="Input prompt must be of type: `str | list[ContentBlock] | Messages | None`."):
+        agent([{"invalid": "input"}])
+
+
+def test_agent_with_list_of_message_and_content_block():
+    model = MockedModelProvider([{"role": "assistant", "content": [{"text": "world!"}]}])
+    agent = Agent(model=model)
+    with pytest.raises(ValueError, match="Input prompt must be of type: `str | list[ContentBlock] | Messages | None`."):
+        agent([{"role": "user", "content": [{"text": "hello"}]}, {"text", "hello"}])
+
+
 def test_agent_tool_call_parameter_filtering_integration(mock_randint):
    """Test that tool calls properly filter parameters in message recording."""
    mock_randint.return_value = 42
@@ -1832,3 +1889,4 @@ def test_tool(action: str) -> str:
    assert '"action": "test_value"' in tool_call_text
    assert '"agent"' not in tool_call_text
    assert '"extra_param"' not in tool_call_text
+
diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py
index dcfce1211..586911bef 100644
--- a/tests/strands/telemetry/test_tracer.py
+++ b/tests/strands/telemetry/test_tracer.py
@@ -369,7 +369,7 @@ def test_start_agent_span(mock_tracer):
    span = tracer.start_agent_span(
        custom_trace_attributes=custom_attrs,
        agent_name="WeatherAgent",
-        message={"content": content, "role": "user"},
+        messages=[{"content": content, "role": "user"}],
        model_id=model_id,
        tools=tools,
    )
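
For readers skimming the patch, a small usage sketch (not part of the diff) of the four prompt forms the widened signature accepts. The import paths, the `Agent()` defaults, and the image block shape are assumptions about the SDK's public API rather than something this diff defines.

```python
# Illustrative only; assumes `strands.Agent` and the content types exported by the SDK.
from strands import Agent
from strands.types.content import ContentBlock, Messages

agent = Agent()

agent("hello!")  # 1. plain string -> wrapped into a single user message

blocks: list[ContentBlock] = [  # 2. content blocks -> one multi-modal user message
    {"text": "describe this image"},
    {"image": {"format": "png", "source": {"bytes": b"<png bytes>"}}},
]
agent(blocks)

history: Messages = [  # 3. full messages (roles included), appended to history as-is
    {"role": "user", "content": [{"text": "hello"}]},
    {"role": "assistant", "content": [{"text": "hi! how can I help?"}]},
]
agent(history)

agent()  # 4. no prompt -> continue from the existing conversation history
```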
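The interesting part of `_convert_prompt_to_messages` is how it tells a `Messages` list apart from a `list[ContentBlock]`: every dict must contain all of `Message`'s required keys to count as messages, while content blocks only need at least one known `ContentBlock` key. Below is a standalone sketch of that check using local stand-in TypedDicts, not the SDK's own types.

```python
from typing import TypedDict


class Message(TypedDict):
    role: str
    content: list


class ContentBlock(TypedDict, total=False):
    text: str
    image: dict


def classify(prompt: object) -> str:
    """Mirror the dict-key checks the new helper uses (simplified)."""
    if not isinstance(prompt, list):
        return "invalid"
    if len(prompt) == 0:
        return "empty"
    if not all(isinstance(item, dict) for item in prompt):
        return "invalid"
    # Messages: every item carries all required Message keys ("role" and "content").
    if all(all(key in item for key in Message.__annotations__) for item in prompt):
        return "messages"
    # Content blocks: every item has at least one known ContentBlock key.
    if all(any(key in ContentBlock.__annotations__ for key in item) for item in prompt):
        return "content blocks"
    return "invalid"


print(classify([{"role": "user", "content": [{"text": "hi"}]}]))  # messages
print(classify([{"text": "hi"}, {"image": {}}]))                  # content blocks
print(classify([{"invalid": "input"}]))                           # invalid
```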
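On the telemetry side, the span now records one `gen_ai.<role>.message` event per input message instead of a single hard-coded `gen_ai.user.message` event. A minimal sketch of the resulting event name/attribute pairs, with `json.dumps` standing in for the SDK's `serialize` helper:

```python
import json


def agent_span_events(messages: list[dict]) -> list[tuple[str, dict]]:
    """Event name/attribute pairs the updated start_agent_span would emit (illustrative)."""
    return [
        (f"gen_ai.{message['role']}.message", {"content": json.dumps(message["content"])})
        for message in messages
    ]


events = agent_span_events(
    [
        {"role": "user", "content": [{"text": "hello"}]},
        {"role": "assistant", "content": [{"text": "hi!"}]},
    ]
)
print([name for name, _ in events])  # ['gen_ai.user.message', 'gen_ai.assistant.message']
```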