diff --git a/src/agentex/lib/adk/providers/_modules/openai.py b/src/agentex/lib/adk/providers/_modules/openai.py index 7cf222ee..a0a59c04 100644 --- a/src/agentex/lib/adk/providers/_modules/openai.py +++ b/src/agentex/lib/adk/providers/_modules/openai.py @@ -88,6 +88,7 @@ async def run_agent( mcp_timeout_seconds: int | None = None, input_guardrails: list[InputGuardrail] | None = None, output_guardrails: list[OutputGuardrail] | None = None, + max_turns: int | None = None, ) -> SerializableRunResult | RunResult: """ Run an agent without streaming or TaskMessage creation. @@ -114,6 +115,7 @@ async def run_agent( mcp_timeout_seconds: Optional param to set the timeout threshold for the MCP servers. Defaults to 5 seconds. input_guardrails: Optional list of input guardrails to run on initial user input. output_guardrails: Optional list of output guardrails to run on final agent output. + max_turns: Maximum number of turns the agent can take. Uses Runner's default if None. Returns: Union[SerializableRunResult, RunResult]: SerializableRunResult when in Temporal, RunResult otherwise. @@ -136,6 +138,7 @@ async def run_agent( mcp_timeout_seconds=mcp_timeout_seconds, input_guardrails=input_guardrails, output_guardrails=output_guardrails, + max_turns=max_turns, ) return await ActivityHelpers.execute_activity( activity_name=OpenAIActivityName.RUN_AGENT, @@ -163,6 +166,7 @@ async def run_agent( mcp_timeout_seconds=mcp_timeout_seconds, input_guardrails=input_guardrails, output_guardrails=output_guardrails, + max_turns=max_turns, ) async def run_agent_auto_send( @@ -191,6 +195,7 @@ async def run_agent_auto_send( mcp_timeout_seconds: int | None = None, input_guardrails: list[InputGuardrail] | None = None, output_guardrails: list[OutputGuardrail] | None = None, + max_turns: int | None = None, ) -> SerializableRunResult | RunResult: """ Run an agent with automatic TaskMessage creation. @@ -216,6 +221,7 @@ async def run_agent_auto_send( mcp_timeout_seconds: Optional param to set the timeout threshold for the MCP servers. Defaults to 5 seconds. input_guardrails: Optional list of input guardrails to run on initial user input. output_guardrails: Optional list of output guardrails to run on final agent output. + max_turns: Maximum number of turns the agent can take. Uses Runner's default if None. Returns: Union[SerializableRunResult, RunResult]: SerializableRunResult when in Temporal, RunResult otherwise. @@ -239,6 +245,7 @@ async def run_agent_auto_send( mcp_timeout_seconds=mcp_timeout_seconds, input_guardrails=input_guardrails, output_guardrails=output_guardrails, + max_turns=max_turns, ) return await ActivityHelpers.execute_activity( activity_name=OpenAIActivityName.RUN_AGENT_AUTO_SEND, @@ -267,6 +274,7 @@ async def run_agent_auto_send( mcp_timeout_seconds=mcp_timeout_seconds, input_guardrails=input_guardrails, output_guardrails=output_guardrails, + max_turns=max_turns, ) async def run_agent_streamed( @@ -291,6 +299,7 @@ async def run_agent_streamed( mcp_timeout_seconds: int | None = None, input_guardrails: list[InputGuardrail] | None = None, output_guardrails: list[OutputGuardrail] | None = None, + max_turns: int | None = None, ) -> RunResultStreaming: """ Run an agent with streaming enabled but no TaskMessage creation. @@ -320,6 +329,7 @@ async def run_agent_streamed( mcp_timeout_seconds: Optional param to set the timeout threshold for the MCP servers. Defaults to 5 seconds. input_guardrails: Optional list of input guardrails to run on initial user input. 
output_guardrails: Optional list of output guardrails to run on final agent output. + max_turns: Maximum number of turns the agent can take. Uses Runner's default if None. Returns: RunResultStreaming: The result of the agent run with streaming. @@ -352,6 +362,7 @@ async def run_agent_streamed( mcp_timeout_seconds=mcp_timeout_seconds, input_guardrails=input_guardrails, output_guardrails=output_guardrails, + max_turns=max_turns, ) async def run_agent_streamed_auto_send( @@ -380,6 +391,7 @@ async def run_agent_streamed_auto_send( mcp_timeout_seconds: int | None = None, input_guardrails: list[InputGuardrail] | None = None, output_guardrails: list[OutputGuardrail] | None = None, + max_turns: int | None = None, ) -> SerializableRunResultStreaming | RunResultStreaming: """ Run an agent with streaming enabled and automatic TaskMessage creation. @@ -405,6 +417,7 @@ async def run_agent_streamed_auto_send( output_type: Optional output type. tool_use_behavior: Optional tool use behavior. mcp_timeout_seconds: Optional param to set the timeout threshold for the MCP servers. Defaults to 5 seconds. + max_turns: Maximum number of turns the agent can take. Uses Runner's default if None. Returns: Union[SerializableRunResultStreaming, RunResultStreaming]: SerializableRunResultStreaming when in Temporal, RunResultStreaming otherwise. @@ -428,6 +441,7 @@ async def run_agent_streamed_auto_send( mcp_timeout_seconds=mcp_timeout_seconds, input_guardrails=input_guardrails, output_guardrails=output_guardrails, + max_turns=max_turns ) return await ActivityHelpers.execute_activity( activity_name=OpenAIActivityName.RUN_AGENT_STREAMED_AUTO_SEND, @@ -456,4 +470,5 @@ async def run_agent_streamed_auto_send( mcp_timeout_seconds=mcp_timeout_seconds, input_guardrails=input_guardrails, output_guardrails=output_guardrails, + max_turns=max_turns, ) diff --git a/src/agentex/lib/core/services/adk/providers/openai.py b/src/agentex/lib/core/services/adk/providers/openai.py index 1c3c7e71..8ed6d9c2 100644 --- a/src/agentex/lib/core/services/adk/providers/openai.py +++ b/src/agentex/lib/core/services/adk/providers/openai.py @@ -1,5 +1,4 @@ # Standard library imports -import json from contextlib import AsyncExitStack, asynccontextmanager from typing import Any, Literal @@ -11,7 +10,8 @@ from mcp import StdioServerParameters from openai.types.responses import ( ResponseCompletedEvent, - ResponseFunctionToolCall, + ResponseFunctionWebSearch, + ResponseCodeInterpreterToolCall, ResponseOutputItemDoneEvent, ResponseTextDeltaEvent, ResponseReasoningSummaryTextDeltaEvent, @@ -85,6 +85,86 @@ def __init__( self.streaming_service = streaming_service self.tracer = tracer + def _extract_tool_call_info( + self, tool_call_item: Any + ) -> tuple[str, str, dict[str, Any]]: + """ + Extract call_id, tool_name, and tool_arguments from a tool call item. 
+ + Args: + tool_call_item: The tool call item to process + + Returns: + A tuple of (call_id, tool_name, tool_arguments) + """ + # Generic handling for different tool call types + # Try 'call_id' first, then 'id', then generate placeholder + if hasattr(tool_call_item, 'call_id'): + call_id = tool_call_item.call_id + elif hasattr(tool_call_item, 'id'): + call_id = tool_call_item.id + else: + call_id = f"unknown_call_{id(tool_call_item)}" + logger.warning( + f"Warning: Tool call item {type(tool_call_item)} has " + f"neither 'call_id' nor 'id' attribute, using placeholder: " + f"{call_id}" + ) + + if isinstance(tool_call_item, ResponseFunctionWebSearch): + tool_name = "web_search" + tool_arguments = { + "action": tool_call_item.action.model_dump(), + "status": tool_call_item.status + } + elif isinstance(tool_call_item, ResponseCodeInterpreterToolCall): + tool_name = "code_interpreter" + tool_arguments = { + "code": tool_call_item.code, + "status": tool_call_item.status + } + else: + # Generic handling for any tool call type + tool_name = getattr(tool_call_item, 'name', type(tool_call_item).__name__) + tool_arguments = tool_call_item.model_dump() + + return call_id, tool_name, tool_arguments + + def _extract_tool_response_info( + self, tool_call_map: dict[str, Any], tool_output_item: Any + ) -> tuple[str, str, str]: + """ + Extract call_id, tool_name, and content from a tool output item. + + Args: + tool_call_map: Map of call_ids to tool_call items + tool_output_item: The tool output item to process + + Returns: + A tuple of (call_id, tool_name, content) + """ + # Extract call_id and content from the tool_output_item + # Handle both dictionary access and attribute access + if hasattr(tool_output_item, 'get') and callable(tool_output_item.get): + # Dictionary-like access + call_id = tool_output_item["call_id"] + content = tool_output_item["output"] + else: + # Attribute access for structured objects + call_id = getattr(tool_output_item, 'call_id', None) + content = getattr(tool_output_item, 'output', None) + + # Get the name from the tool call map using generic approach + tool_call = tool_call_map[call_id] + if hasattr(tool_call, "name"): + tool_name = getattr(tool_call, "name") + elif hasattr(tool_call, "type"): + tool_name = getattr(tool_call, "type") + else: + tool_name = type(tool_call).__name__ + + return call_id, tool_name, content + async def run_agent( self, input_list: list[dict[str, Any]], @@ -107,6 +187,7 @@ async def run_agent( mcp_timeout_seconds: int | None = None, input_guardrails: list[InputGuardrail] | None = None, output_guardrails: list[OutputGuardrail] | None = None, + max_turns: int | None = None, ) -> RunResult: """ Run an agent without streaming or TaskMessage creation. @@ -131,6 +212,8 @@ async def run_agent( initial user input. output_guardrails: Optional list of output guardrails to run on final agent output. + mcp_timeout_seconds: Optional param to set the timeout threshold for the MCP servers. Defaults to 5 seconds. + max_turns: Maximum number of turns the agent can take. Uses Runner's default if None. Returns: SerializableRunResult: The result of the agent run. 
""" @@ -152,6 +235,7 @@ async def run_agent( "tools": tools, "output_type": output_type, "tool_use_behavior": tool_use_behavior, + "max_turns": max_turns, }, ) as span: heartbeat_if_in_workflow("run agent") @@ -159,7 +243,9 @@ async def run_agent( async with mcp_server_context( mcp_server_params, mcp_timeout_seconds ) as servers: - tools = [tool.to_oai_function_tool() for tool in tools] if tools else [] + tools = [ + tool.to_oai_function_tool()for tool in tools + ] if tools else [] handoffs = ( [Agent(**handoff.model_dump()) for handoff in handoffs] if handoffs @@ -189,7 +275,10 @@ async def run_agent( agent = Agent(**agent_kwargs) # Run without streaming - result = await Runner.run(starting_agent=agent, input=input_list) + if max_turns is not None: + result = await Runner.run(starting_agent=agent, input=input_list, max_turns=max_turns) + else: + result = await Runner.run(starting_agent=agent, input=input_list) if span: span.output = { @@ -227,6 +316,7 @@ async def run_agent_auto_send( mcp_timeout_seconds: int | None = None, input_guardrails: list[InputGuardrail] | None = None, output_guardrails: list[OutputGuardrail] | None = None, + max_turns: int | None = None, ) -> RunResult: """ Run an agent with automatic TaskMessage creation. @@ -249,6 +339,7 @@ async def run_agent_auto_send( mcp_timeout_seconds: Optional param to set the timeout threshold for the MCP servers. Defaults to 5 seconds. input_guardrails: Optional list of input guardrails to run on initial user input. output_guardrails: Optional list of output guardrails to run on final agent output. + max_turns: Maximum number of turns the agent can take. Uses Runner's default if None. Returns: SerializableRunResult: The result of the agent run. """ @@ -276,6 +367,7 @@ async def run_agent_auto_send( "tools": tools, "output_type": output_type, "tool_use_behavior": tool_use_behavior, + "max_turns": max_turns, }, ) as span: heartbeat_if_in_workflow("run agent auto send") @@ -312,7 +404,10 @@ async def run_agent_auto_send( agent = Agent(**agent_kwargs) # Run without streaming - result = await Runner.run(starting_agent=agent, input=input_list) + if max_turns is not None: + result = await Runner.run(starting_agent=agent, input=input_list, max_turns=max_turns) + else: + result = await Runner.run(starting_agent=agent, input=input_list) if span: span.output = { @@ -325,7 +420,7 @@ async def run_agent_auto_send( "final_output": result.final_output, } - tool_call_map: dict[str, ResponseFunctionToolCall] = {} + tool_call_map: dict[str, Any] = {} for item in result.new_items: if item.type == "message_output_item": @@ -349,13 +444,17 @@ async def run_agent_auto_send( ) elif item.type == "tool_call_item": - tool_call_map[item.raw_item.call_id] = item.raw_item + tool_call_item = item.raw_item + + # Extract tool call information using the helper method + call_id, tool_name, tool_arguments = self._extract_tool_call_info(tool_call_item) + tool_call_map[call_id] = tool_call_item tool_request_content = ToolRequestContent( author="agent", - tool_call_id=item.raw_item.call_id, - name=item.raw_item.name, - arguments=json.loads(item.raw_item.arguments), + tool_call_id=call_id, + name=tool_name, + arguments=tool_arguments, ) # Create tool request using streaming context @@ -376,11 +475,16 @@ async def run_agent_auto_send( elif item.type == "tool_call_output_item": tool_output_item = item.raw_item + # Extract tool response information using the helper method + call_id, tool_name, content = self._extract_tool_response_info( + tool_call_map, tool_output_item + ) + 
tool_response_content = ToolResponseContent( author="agent", - tool_call_id=tool_output_item["call_id"], - name=tool_call_map[tool_output_item["call_id"]].name, - content=tool_output_item["output"], + tool_call_id=call_id, + name=tool_name, + content=content, ) # Create tool response using streaming context async with ( @@ -422,6 +526,7 @@ async def run_agent_streamed( mcp_timeout_seconds: int | None = None, input_guardrails: list[InputGuardrail] | None = None, output_guardrails: list[OutputGuardrail] | None = None, + max_turns: int | None = None, ) -> RunResultStreaming: """ Run an agent with streaming enabled but no TaskMessage creation. @@ -446,6 +551,8 @@ async def run_agent_streamed( initial user input. output_guardrails: Optional list of output guardrails to run on final agent output. + mcp_timeout_seconds: Optional param to set the timeout threshold for the MCP servers. Defaults to 5 seconds. + max_turns: Maximum number of turns the agent can take. Uses Runner's default if None. Returns: RunResultStreaming: The result of the agent run with streaming. """ @@ -467,6 +574,7 @@ async def run_agent_streamed( "tools": tools, "output_type": output_type, "tool_use_behavior": tool_use_behavior, + "max_turns": max_turns, }, ) as span: heartbeat_if_in_workflow("run agent streamed") @@ -503,7 +611,10 @@ async def run_agent_streamed( agent = Agent(**agent_kwargs) # Run with streaming (but no TaskMessage creation) - result = Runner.run_streamed(starting_agent=agent, input=input_list) + if max_turns is not None: + result = Runner.run_streamed(starting_agent=agent, input=input_list, max_turns=max_turns) + else: + result = Runner.run_streamed(starting_agent=agent, input=input_list) if span: span.output = { @@ -541,6 +652,7 @@ async def run_agent_streamed_auto_send( mcp_timeout_seconds: int | None = None, input_guardrails: list[InputGuardrail] | None = None, output_guardrails: list[OutputGuardrail] | None = None, + max_turns: int | None = None, ) -> RunResultStreaming: """ Run an agent with streaming enabled and automatic TaskMessage creation. @@ -566,6 +678,8 @@ async def run_agent_streamed_auto_send( initial user input. output_guardrails: Optional list of output guardrails to run on final agent output. + mcp_timeout_seconds: Optional param to set the timeout threshold for the MCP servers. Defaults to 5 seconds. + max_turns: Maximum number of turns the agent can take. Uses Runner's default if None. Returns: RunResultStreaming: The result of the agent run with streaming. 
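Reviewer note: the max_turns plumbing is caller-opt-in end to end. A minimal sketch of the caller-side contract (values are illustrative; the same fields are exercised by the new tests at the bottom of this PR):

    from agentex.lib.core.temporal.activities.adk.providers.openai_activities import RunAgentParams

    params = RunAgentParams(
        input_list=[{"role": "user", "content": "Hello, world!"}],
        mcp_server_params=[],
        agent_name="demo_agent",
        agent_instructions="You are a helpful assistant",
        max_turns=7,  # omit, or pass None, to keep Runner's default (10)
    )

When max_turns is None, the service deliberately calls Runner.run / Runner.run_streamed without the keyword, so the SDK's own default applies rather than a hard-coded copy of it.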
@@ -575,7 +689,7 @@ async def run_agent_streamed_auto_send( if self.agentex_client is None: raise ValueError("Agentex client must be provided for auto_send methods") - tool_call_map: dict[str, ResponseFunctionToolCall] = {} + tool_call_map: dict[str, Any] = {} trace = self.tracer.trace(trace_id) redacted_params = redact_mcp_server_params(mcp_server_params) @@ -596,6 +710,7 @@ async def run_agent_streamed_auto_send( "tools": tools, "output_type": output_type, "tool_use_behavior": tool_use_behavior, + "max_turns": max_turns, }, ) as span: heartbeat_if_in_workflow("run agent streamed auto send") @@ -632,7 +747,10 @@ async def run_agent_streamed_auto_send( agent = Agent(**agent_kwargs) # Run with streaming - result = Runner.run_streamed(starting_agent=agent, input=input_list) + if max_turns is not None: + result = Runner.run_streamed(starting_agent=agent, input=input_list, max_turns=max_turns) + else: + result = Runner.run_streamed(starting_agent=agent, input=input_list) item_id_to_streaming_context: dict[ str, StreamingTaskMessageContext @@ -649,13 +767,16 @@ async def run_agent_streamed_auto_send( if event.type == "run_item_stream_event": if event.item.type == "tool_call_item": tool_call_item = event.item.raw_item - tool_call_map[tool_call_item.call_id] = tool_call_item + + # Extract tool call information using the helper method + call_id, tool_name, tool_arguments = self._extract_tool_call_info(tool_call_item) + tool_call_map[call_id] = tool_call_item tool_request_content = ToolRequestContent( author="agent", - tool_call_id=tool_call_item.call_id, - name=tool_call_item.name, - arguments=json.loads(tool_call_item.arguments), + tool_call_id=call_id, + name=tool_name, + arguments=tool_arguments, ) # Create tool request using streaming context (immediate completion) @@ -677,13 +798,16 @@ async def run_agent_streamed_auto_send( elif event.item.type == "tool_call_output_item": tool_output_item = event.item.raw_item + # Extract tool response information using the helper method + call_id, tool_name, content = self._extract_tool_response_info( + tool_call_map, tool_output_item + ) + tool_response_content = ToolResponseContent( author="agent", - tool_call_id=tool_output_item["call_id"], - name=tool_call_map[ - tool_output_item["call_id"] - ].name, - content=tool_output_item["output"], + tool_call_id=call_id, + name=tool_name, + content=content, ) # Create tool response using streaming context (immediate completion) diff --git a/src/agentex/lib/core/temporal/activities/adk/providers/openai_activities.py b/src/agentex/lib/core/temporal/activities/adk/providers/openai_activities.py index bcad6bfb..bdbc2c57 100644 --- a/src/agentex/lib/core/temporal/activities/adk/providers/openai_activities.py +++ b/src/agentex/lib/core/temporal/activities/adk/providers/openai_activities.py @@ -1,34 +1,42 @@ # Standard library imports import base64 -from collections.abc import Callable -from contextlib import AsyncExitStack, asynccontextmanager from enum import Enum from typing import Any, Literal, Optional - -from pydantic import Field, PrivateAttr +from contextlib import AsyncExitStack, asynccontextmanager +from collections.abc import Callable import cloudpickle -from agents import RunContextWrapper, RunResult, RunResultStreaming +from mcp import StdioServerParameters +from agents import RunResult, RunContextWrapper, RunResultStreaming +from pydantic import Field, PrivateAttr +from agents.mcp import MCPServerStdio, MCPServerStdioParams +from temporalio import activity +from agents.tool import ( + ComputerTool as 
OAIComputerTool, + FunctionTool as OAIFunctionTool, + WebSearchTool as OAIWebSearchTool, + FileSearchTool as OAIFileSearchTool, + LocalShellTool as OAILocalShellTool, + CodeInterpreterTool as OAICodeInterpreterTool, + ImageGenerationTool as OAIImageGenerationTool, +) from agents.guardrail import InputGuardrail, OutputGuardrail from agents.exceptions import InputGuardrailTripwireTriggered, OutputGuardrailTripwireTriggered -from agents.mcp import MCPServerStdio, MCPServerStdioParams from agents.model_settings import ModelSettings as OAIModelSettings -from agents.tool import FunctionTool as OAIFunctionTool -from mcp import StdioServerParameters -from openai.types.responses.response_includable import ResponseIncludable from openai.types.shared.reasoning import Reasoning -from temporalio import activity -from agentex.lib.core.services.adk.providers.openai import OpenAIService +from openai.types.responses.response_includable import ResponseIncludable + +from agentex.lib.utils import logging + +# Third-party imports +from agentex.lib.types.tracing import BaseModelWithTraceParams # Local imports from agentex.lib.types.agent_results import ( SerializableRunResult, SerializableRunResultStreaming, ) - -# Third-party imports -from agentex.lib.types.tracing import BaseModelWithTraceParams -from agentex.lib.utils import logging +from agentex.lib.core.services.adk.providers.openai import OpenAIService logger = logging.make_logger(__name__) @@ -42,6 +50,147 @@ class OpenAIActivityName(str, Enum): RUN_AGENT_STREAMED_AUTO_SEND = "run_agent_streamed_auto_send" +class WebSearchTool(BaseModelWithTraceParams): + """Temporal-compatible wrapper for WebSearchTool.""" + + user_location: Optional[dict[str, Any]] = None # UserLocation object + search_context_size: Optional[Literal["low", "medium", "high"]] = "medium" + + def to_oai_function_tool(self) -> OAIWebSearchTool: + kwargs = {} + if self.user_location is not None: + kwargs["user_location"] = self.user_location + if self.search_context_size is not None: + kwargs["search_context_size"] = self.search_context_size + return OAIWebSearchTool(**kwargs) + + +class FileSearchTool(BaseModelWithTraceParams): + """Temporal-compatible wrapper for FileSearchTool.""" + + vector_store_ids: list[str] + max_num_results: Optional[int] = None + include_search_results: bool = False + ranking_options: Optional[dict[str, Any]] = None + filters: Optional[dict[str, Any]] = None + + def to_oai_function_tool(self): + return OAIFileSearchTool( + vector_store_ids=self.vector_store_ids, + max_num_results=self.max_num_results, + include_search_results=self.include_search_results, + ranking_options=self.ranking_options, + filters=self.filters, + ) + + +class ComputerTool(BaseModelWithTraceParams): + """Temporal-compatible wrapper for ComputerTool.""" + + # We need to serialize the computer object and safety check function + computer_serialized: str = Field(default="", description="Serialized computer object") + on_safety_check_serialized: str = Field(default="", description="Serialized safety check function") + + _computer: Any = PrivateAttr() + _on_safety_check: Optional[Callable] = PrivateAttr() + + def __init__( + self, + *, + computer: Any = None, + on_safety_check: Optional[Callable] = None, + **data, + ): + super().__init__(**data) + if computer is not None: + self.computer_serialized = self._serialize_callable(computer) + self._computer = computer + elif self.computer_serialized: + self._computer = self._deserialize_callable(self.computer_serialized) + + if on_safety_check is not 
None: + self.on_safety_check_serialized = self._serialize_callable(on_safety_check) + self._on_safety_check = on_safety_check + elif self.on_safety_check_serialized: + self._on_safety_check = self._deserialize_callable(self.on_safety_check_serialized) + + @classmethod + def _deserialize_callable(cls, serialized: str) -> Any: + encoded = serialized.encode() + serialized_bytes = base64.b64decode(encoded) + return cloudpickle.loads(serialized_bytes) + + @classmethod + def _serialize_callable(cls, func: Any) -> str: + serialized_bytes = cloudpickle.dumps(func) + encoded = base64.b64encode(serialized_bytes) + return encoded.decode() + + def to_oai_function_tool(self): + return OAIComputerTool( + computer=self._computer, + on_safety_check=self._on_safety_check, + ) + + +class CodeInterpreterTool(BaseModelWithTraceParams): + """Temporal-compatible wrapper for CodeInterpreterTool.""" + + tool_config: dict[str, Any] = Field( + default_factory=lambda: {"type": "code_interpreter"}, description="Tool configuration dict" + ) + + def to_oai_function_tool(self): + return OAICodeInterpreterTool(tool_config=self.tool_config) + + +class ImageGenerationTool(BaseModelWithTraceParams): + """Temporal-compatible wrapper for ImageGenerationTool.""" + + tool_config: dict[str, Any] = Field( + default_factory=lambda: {"type": "image_generation"}, description="Tool configuration dict" + ) + + def to_oai_function_tool(self): + return OAIImageGenerationTool(tool_config=self.tool_config) + + +class LocalShellTool(BaseModelWithTraceParams): + """Temporal-compatible wrapper for LocalShellTool.""" + + executor_serialized: str = Field(default="", description="Serialized LocalShellExecutor object") + + _executor: Any = PrivateAttr() + + def __init__( + self, + *, + executor: Any = None, + **data, + ): + super().__init__(**data) + if executor is not None: + self.executor_serialized = self._serialize_callable(executor) + self._executor = executor + elif self.executor_serialized: + self._executor = self._deserialize_callable(self.executor_serialized) + + @classmethod + def _deserialize_callable(cls, serialized: str) -> Any: + encoded = serialized.encode() + serialized_bytes = base64.b64decode(encoded) + return cloudpickle.loads(serialized_bytes) + + @classmethod + def _serialize_callable(cls, func: Any) -> str: + serialized_bytes = cloudpickle.dumps(func) + encoded = base64.b64encode(serialized_bytes) + return encoded.decode() + + def to_oai_function_tool(self): + return OAILocalShellTool(executor=self._executor) + + class FunctionTool(BaseModelWithTraceParams): name: str description: str @@ -79,22 +228,16 @@ def __init__( super().__init__(**data) if not on_invoke_tool: if not self.on_invoke_tool_serialized: - raise ValueError( - "One of `on_invoke_tool` or `on_invoke_tool_serialized` should be set" - ) + raise ValueError("One of `on_invoke_tool` or `on_invoke_tool_serialized` should be set") else: - on_invoke_tool = self._deserialize_callable( - self.on_invoke_tool_serialized - ) + on_invoke_tool = self._deserialize_callable(self.on_invoke_tool_serialized) else: self.on_invoke_tool_serialized = self._serialize_callable(on_invoke_tool) self._on_invoke_tool = on_invoke_tool @classmethod - def _deserialize_callable( - cls, serialized: str - ) -> Callable[[RunContextWrapper, str], Any]: + def _deserialize_callable(cls, serialized: str) -> Callable[[RunContextWrapper, str], Any]: encoded = serialized.encode() serialized_bytes = base64.b64decode(encoded) return cloudpickle.loads(serialized_bytes) @@ -108,11 +251,9 @@ def 
_serialize_callable(cls, func: Callable) -> str: @property def on_invoke_tool(self) -> Callable[[RunContextWrapper, str], Any]: if self._on_invoke_tool is None and self.on_invoke_tool_serialized: - self._on_invoke_tool = self._deserialize_callable( - self.on_invoke_tool_serialized - ) + self._on_invoke_tool = self._deserialize_callable(self.on_invoke_tool_serialized) return self._on_invoke_tool - + @on_invoke_tool.setter def on_invoke_tool(self, value: Callable[[RunContextWrapper, str], Any]): self.on_invoke_tool_serialized = self._serialize_callable(value) @@ -135,8 +276,9 @@ def to_oai_function_tool(self) -> OAIFunctionTool: class TemporalInputGuardrail(BaseModelWithTraceParams): - """Temporal-compatible wrapper for InputGuardrail with function + """Temporal-compatible wrapper for InputGuardrail with function serialization.""" + name: str _guardrail_function: Callable = PrivateAttr() guardrail_function_serialized: str = Field( @@ -157,19 +299,12 @@ def __init__( super().__init__(**data) if not guardrail_function: if not self.guardrail_function_serialized: - raise ValueError( - "One of `guardrail_function` or " - "`guardrail_function_serialized` should be set" - ) + raise ValueError("One of `guardrail_function` or `guardrail_function_serialized` should be set") else: - guardrail_function = self._deserialize_callable( - self.guardrail_function_serialized - ) + guardrail_function = self._deserialize_callable(self.guardrail_function_serialized) else: - self.guardrail_function_serialized = self._serialize_callable( - guardrail_function - ) - + self.guardrail_function_serialized = self._serialize_callable(guardrail_function) + self._guardrail_function = guardrail_function @classmethod @@ -186,13 +321,10 @@ def _serialize_callable(cls, func: Callable) -> str: @property def guardrail_function(self) -> Callable: - if (self._guardrail_function is None and - self.guardrail_function_serialized): - self._guardrail_function = self._deserialize_callable( - self.guardrail_function_serialized - ) + if self._guardrail_function is None and self.guardrail_function_serialized: + self._guardrail_function = self._deserialize_callable(self.guardrail_function_serialized) return self._guardrail_function - + @guardrail_function.setter def guardrail_function(self, value: Callable): self.guardrail_function_serialized = self._serialize_callable(value) @@ -200,15 +332,13 @@ def guardrail_function(self, value: Callable): def to_oai_input_guardrail(self) -> InputGuardrail: """Convert to OpenAI InputGuardrail.""" - return InputGuardrail( - guardrail_function=self.guardrail_function, - name=self.name - ) + return InputGuardrail(guardrail_function=self.guardrail_function, name=self.name) class TemporalOutputGuardrail(BaseModelWithTraceParams): - """Temporal-compatible wrapper for OutputGuardrail with function + """Temporal-compatible wrapper for OutputGuardrail with function serialization.""" + name: str _guardrail_function: Callable = PrivateAttr() guardrail_function_serialized: str = Field( @@ -229,19 +359,12 @@ def __init__( super().__init__(**data) if not guardrail_function: if not self.guardrail_function_serialized: - raise ValueError( - "One of `guardrail_function` or " - "`guardrail_function_serialized` should be set" - ) + raise ValueError("One of `guardrail_function` or `guardrail_function_serialized` should be set") else: - guardrail_function = self._deserialize_callable( - self.guardrail_function_serialized - ) + guardrail_function = self._deserialize_callable(self.guardrail_function_serialized) else: - 
self.guardrail_function_serialized = self._serialize_callable( - guardrail_function - ) - + self.guardrail_function_serialized = self._serialize_callable(guardrail_function) + self._guardrail_function = guardrail_function @classmethod @@ -258,13 +381,10 @@ def _serialize_callable(cls, func: Callable) -> str: @property def guardrail_function(self) -> Callable: - if (self._guardrail_function is None and - self.guardrail_function_serialized): - self._guardrail_function = self._deserialize_callable( - self.guardrail_function_serialized - ) + if self._guardrail_function is None and self.guardrail_function_serialized: + self._guardrail_function = self._deserialize_callable(self.guardrail_function_serialized) return self._guardrail_function - + @guardrail_function.setter def guardrail_function(self, value: Callable): self.guardrail_function_serialized = self._serialize_callable(value) @@ -272,10 +392,7 @@ def guardrail_function(self, value: Callable): def to_oai_output_guardrail(self) -> OutputGuardrail: """Convert to OpenAI OutputGuardrail.""" - return OutputGuardrail( - guardrail_function=self.guardrail_function, - name=self.name - ) + return OutputGuardrail(guardrail_function=self.guardrail_function, name=self.name) class ModelSettings(BaseModelWithTraceParams): @@ -297,9 +414,7 @@ class ModelSettings(BaseModelWithTraceParams): extra_args: dict[str, Any] | None = None def to_oai_model_settings(self) -> OAIModelSettings: - return OAIModelSettings( - **self.model_dump(exclude=["trace_id", "parent_span_id"]) - ) + return OAIModelSettings(**self.model_dump(exclude=["trace_id", "parent_span_id"])) class RunAgentParams(BaseModelWithTraceParams): @@ -313,12 +428,24 @@ class RunAgentParams(BaseModelWithTraceParams): handoffs: list["RunAgentParams"] | None = None model: str | None = None model_settings: ModelSettings | None = None - tools: list[FunctionTool] | None = None + tools: ( + list[ + FunctionTool + | WebSearchTool + | FileSearchTool + | ComputerTool + | CodeInterpreterTool + | ImageGenerationTool + | LocalShellTool + ] + | None + ) = None output_type: Any = None tool_use_behavior: Literal["run_llm_again", "stop_on_first_tool"] = "run_llm_again" mcp_timeout_seconds: int | None = None input_guardrails: list[TemporalInputGuardrail] | None = None output_guardrails: list[TemporalOutputGuardrail] | None = None + max_turns: int | None = None class RunAgentAutoSendParams(RunAgentParams): @@ -365,11 +492,11 @@ async def run_agent(self, params: RunAgentParams) -> SerializableRunResult: input_guardrails = None if params.input_guardrails: input_guardrails = [g.to_oai_input_guardrail() for g in params.input_guardrails] - + output_guardrails = None if params.output_guardrails: output_guardrails = [g.to_oai_output_guardrail() for g in params.output_guardrails] - + result = await self._openai_service.run_agent( input_list=params.input_list, mcp_server_params=params.mcp_server_params, @@ -386,23 +513,23 @@ async def run_agent(self, params: RunAgentParams) -> SerializableRunResult: tool_use_behavior=params.tool_use_behavior, input_guardrails=input_guardrails, output_guardrails=output_guardrails, + mcp_timeout_seconds=params.mcp_timeout_seconds, + max_turns=params.max_turns, ) return self._to_serializable_run_result(result) @activity.defn(name=OpenAIActivityName.RUN_AGENT_AUTO_SEND) - async def run_agent_auto_send( - self, params: RunAgentAutoSendParams - ) -> SerializableRunResult: + async def run_agent_auto_send(self, params: RunAgentAutoSendParams) -> SerializableRunResult: """Run an agent with automatic 
TaskMessage creation.""" # Convert Temporal guardrails to OpenAI guardrails input_guardrails = None if params.input_guardrails: input_guardrails = [g.to_oai_input_guardrail() for g in params.input_guardrails] - + output_guardrails = None if params.output_guardrails: output_guardrails = [g.to_oai_output_guardrail() for g in params.output_guardrails] - + try: result = await self._openai_service.run_agent_auto_send( task_id=params.task_id, @@ -421,65 +548,60 @@ async def run_agent_auto_send( tool_use_behavior=params.tool_use_behavior, input_guardrails=input_guardrails, output_guardrails=output_guardrails, + mcp_timeout_seconds=params.mcp_timeout_seconds, + max_turns=params.max_turns, ) return self._to_serializable_run_result(result) except InputGuardrailTripwireTriggered as e: # Handle guardrail trigger gracefully - rejection_message = "I'm sorry, but I cannot process this request due to a guardrail. Please try a different question." - + rejection_message = ( + "I'm sorry, but I cannot process this request due to a guardrail. Please try a different question." + ) + # Try to extract rejection message from the guardrail result - if hasattr(e, 'guardrail_result') and hasattr(e.guardrail_result, 'output'): - output_info = getattr(e.guardrail_result.output, 'output_info', {}) - if isinstance(output_info, dict) and 'rejection_message' in output_info: - rejection_message = output_info['rejection_message'] - + if hasattr(e, "guardrail_result") and hasattr(e.guardrail_result, "output"): + output_info = getattr(e.guardrail_result.output, "output_info", {}) + if isinstance(output_info, dict) and "rejection_message" in output_info: + rejection_message = output_info["rejection_message"] + # Build the final input list with the rejection message final_input_list = list(params.input_list or []) - final_input_list.append({ - "role": "assistant", - "content": rejection_message - }) - - return SerializableRunResult( - final_output=rejection_message, - final_input_list=final_input_list - ) + final_input_list.append({"role": "assistant", "content": rejection_message}) + + return SerializableRunResult(final_output=rejection_message, final_input_list=final_input_list) except OutputGuardrailTripwireTriggered as e: # Handle output guardrail trigger gracefully - rejection_message = "I'm sorry, but I cannot provide this response due to a guardrail. Please try a different question." - + rejection_message = ( + "I'm sorry, but I cannot provide this response due to a guardrail. Please try a different question." 
+ ) + # Try to extract rejection message from the guardrail result - if hasattr(e, 'guardrail_result') and hasattr(e.guardrail_result, 'output'): - output_info = getattr(e.guardrail_result.output, 'output_info', {}) - if isinstance(output_info, dict) and 'rejection_message' in output_info: - rejection_message = output_info['rejection_message'] - + if hasattr(e, "guardrail_result") and hasattr(e.guardrail_result, "output"): + output_info = getattr(e.guardrail_result.output, "output_info", {}) + if isinstance(output_info, dict) and "rejection_message" in output_info: + rejection_message = output_info["rejection_message"] + # Build the final input list with the rejection message final_input_list = list(params.input_list or []) - final_input_list.append({ - "role": "assistant", - "content": rejection_message - }) - - return SerializableRunResult( - final_output=rejection_message, - final_input_list=final_input_list - ) + final_input_list.append({"role": "assistant", "content": rejection_message}) + + return SerializableRunResult(final_output=rejection_message, final_input_list=final_input_list) @activity.defn(name=OpenAIActivityName.RUN_AGENT_STREAMED_AUTO_SEND) async def run_agent_streamed_auto_send( self, params: RunAgentStreamedAutoSendParams ) -> SerializableRunResultStreaming: """Run an agent with streaming and automatic TaskMessage creation.""" + # Convert Temporal guardrails to OpenAI guardrails input_guardrails = None if params.input_guardrails: input_guardrails = [g.to_oai_input_guardrail() for g in params.input_guardrails] - + output_guardrails = None if params.output_guardrails: output_guardrails = [g.to_oai_output_guardrail() for g in params.output_guardrails] - + try: result = await self._openai_service.run_agent_streamed_auto_send( task_id=params.task_id, @@ -498,50 +620,44 @@ async def run_agent_streamed_auto_send( tool_use_behavior=params.tool_use_behavior, input_guardrails=input_guardrails, output_guardrails=output_guardrails, + mcp_timeout_seconds=params.mcp_timeout_seconds, + max_turns=params.max_turns, ) return self._to_serializable_run_result_streaming(result) except InputGuardrailTripwireTriggered as e: # Handle guardrail trigger gracefully - rejection_message = "I'm sorry, but I cannot process this request due to a guardrail. Please try a different question." - + rejection_message = ( + "I'm sorry, but I cannot process this request due to a guardrail. Please try a different question." 
+ ) + # Try to extract rejection message from the guardrail result - if hasattr(e, 'guardrail_result') and hasattr(e.guardrail_result, 'output'): - output_info = getattr(e.guardrail_result.output, 'output_info', {}) - if isinstance(output_info, dict) and 'rejection_message' in output_info: - rejection_message = output_info['rejection_message'] - + if hasattr(e, "guardrail_result") and hasattr(e.guardrail_result, "output"): + output_info = getattr(e.guardrail_result.output, "output_info", {}) + if isinstance(output_info, dict) and "rejection_message" in output_info: + rejection_message = output_info["rejection_message"] + # Build the final input list with the rejection message final_input_list = list(params.input_list or []) - final_input_list.append({ - "role": "assistant", - "content": rejection_message - }) - - return SerializableRunResultStreaming( - final_output=rejection_message, - final_input_list=final_input_list - ) + final_input_list.append({"role": "assistant", "content": rejection_message}) + + return SerializableRunResultStreaming(final_output=rejection_message, final_input_list=final_input_list) except OutputGuardrailTripwireTriggered as e: # Handle output guardrail trigger gracefully - rejection_message = "I'm sorry, but I cannot provide this response due to a guardrail. Please try a different question." - + rejection_message = ( + "I'm sorry, but I cannot provide this response due to a guardrail. Please try a different question." + ) + # Try to extract rejection message from the guardrail result - if hasattr(e, 'guardrail_result') and hasattr(e.guardrail_result, 'output'): - output_info = getattr(e.guardrail_result.output, 'output_info', {}) - if isinstance(output_info, dict) and 'rejection_message' in output_info: - rejection_message = output_info['rejection_message'] - + if hasattr(e, "guardrail_result") and hasattr(e.guardrail_result, "output"): + output_info = getattr(e.guardrail_result.output, "output_info", {}) + if isinstance(output_info, dict) and "rejection_message" in output_info: + rejection_message = output_info["rejection_message"] + # Build the final input list with the rejection message final_input_list = list(params.input_list or []) - final_input_list.append({ - "role": "assistant", - "content": rejection_message - }) - - return SerializableRunResultStreaming( - final_output=rejection_message, - final_input_list=final_input_list - ) + final_input_list.append({"role": "assistant", "content": rejection_message}) + + return SerializableRunResultStreaming(final_output=rejection_message, final_input_list=final_input_list) @staticmethod def _to_serializable_run_result(result: RunResult) -> SerializableRunResult: diff --git a/tests/lib/__init__.py b/tests/lib/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/lib/adk/__init__.py b/tests/lib/adk/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/lib/adk/providers/__init__.py b/tests/lib/adk/providers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/lib/adk/providers/test_openai_activities.py b/tests/lib/adk/providers/test_openai_activities.py new file mode 100644 index 00000000..830263d7 --- /dev/null +++ b/tests/lib/adk/providers/test_openai_activities.py @@ -0,0 +1,651 @@ +from unittest.mock import Mock, patch + +import pytest +from agents import RunResult, RunResultStreaming +from temporalio.testing import ActivityEnvironment +from openai.types.responses import ResponseCodeInterpreterToolCall + + +class TestOpenAIActivities: + 
@pytest.fixture + def sample_run_result(self): + """Create a sample RunResult for mocking.""" + mock_result = Mock(spec=RunResult) + mock_result.final_output = "Hello! How can I help you today?" + mock_result.to_input_list.return_value = [ + {"role": "user", "content": "Hello, world!"}, + {"role": "assistant", "content": "Hello! How can I help you today?"}, + ] + # Add new_items attribute that the OpenAIService expects + mock_result.new_items = [] + return mock_result + + @pytest.mark.parametrize( + "max_turns,should_be_passed", + [ + (None, False), + (7, True), # Test with non-default value (default is 10) + ], + ) + @patch("agents.Runner.run") + async def test_run_agent(self, mock_runner_run, max_turns, should_be_passed, sample_run_result): + """Comprehensive test for run_agent covering all major scenarios.""" + from agentex.lib.core.temporal.activities.adk.providers.openai_activities import RunAgentParams + + # Arrange + mock_runner_run.return_value = sample_run_result + mock_tracer = self._create_mock_tracer() + _, openai_activities, env = self._create_test_setup(mock_tracer) + + # Create params with or without max_turns + params = RunAgentParams( + input_list=[{"role": "user", "content": "Hello, world!"}], + mcp_server_params=[], + agent_name="test_agent", + agent_instructions="You are a helpful assistant", + max_turns=max_turns, + trace_id="test-trace-id", + parent_span_id="test-span-id", + ) + + # Act + result = await env.run(openai_activities.run_agent, params) + + # Assert - Result structure + self._assert_result_structure(result) + + # Assert - Runner call + mock_runner_run.assert_called_once() + call_args = mock_runner_run.call_args + + # Assert - Runner signature validation + self._assert_runner_call_signature(call_args) + + # Assert - Input parameter matches + assert call_args.kwargs["input"] == params.input_list + + # Assert - Starting agent parameters + starting_agent = call_args.kwargs["starting_agent"] + self._assert_starting_agent_params(starting_agent, params) + + # Assert - Max turns parameter handling + if should_be_passed: + assert "max_turns" in call_args.kwargs, f"max_turns should be passed when set to {max_turns}" + assert call_args.kwargs["max_turns"] == max_turns, f"max_turns value should be {max_turns}" + else: + assert "max_turns" not in call_args.kwargs, "max_turns should not be passed when None" + + @pytest.mark.parametrize( + "tools_case", + [ + "no_tools", + "function_tool", + "web_search_tool", + "file_search_tool", + "computer_tool", + "code_interpreter_tool", + "image_generation_tool", + "local_shell_tool", + "mixed_tools", + ], + ) + @patch("agents.Runner.run") + async def test_run_agent_tools_conversion(self, mock_runner_run, tools_case, sample_run_result): + """Test that tools are properly converted from Temporal to OpenAI agents format.""" + from agentex.lib.core.temporal.activities.adk.providers.openai_activities import ( + RunAgentParams, + ) + + # Arrange + mock_runner_run.return_value = sample_run_result + mock_tracer = self._create_mock_tracer() + _, openai_activities, env = self._create_test_setup(mock_tracer) + + # Create different tool configurations based on test case + tools = self._create_tools_for_case(tools_case) + + params = RunAgentParams( + input_list=[{"role": "user", "content": "Hello, world!"}], + mcp_server_params=[], + agent_name="test_agent", + agent_instructions="You are a helpful assistant", + tools=tools, + trace_id="test-trace-id", + parent_span_id="test-span-id", + ) + + # Act + result = await 
env.run(openai_activities.run_agent, params) + + # Assert - Result structure + self._assert_result_structure(result) + + # Assert - Runner call + mock_runner_run.assert_called_once() + call_args = mock_runner_run.call_args + + # Assert - Runner signature validation + self._assert_runner_call_signature(call_args) + + # Assert - Agent was created and tools were converted properly + starting_agent = call_args.kwargs["starting_agent"] + self._assert_tools_conversion(starting_agent, tools_case, tools) + + @patch("agents.Runner.run") + async def test_run_agent_auto_send_with_tool_responses(self, mock_runner_run): + """Test run_agent_auto_send with code interpreter tool responses.""" + from agentex.lib.core.temporal.activities.adk.providers.openai_activities import ( + CodeInterpreterTool, + RunAgentAutoSendParams, + ) + + # Arrange - Setup test environment + mock_tracer = self._create_mock_tracer() + openai_service, openai_activities, env = self._create_test_setup(mock_tracer) + mock_streaming_context = self._setup_streaming_service_mocks(openai_service) + + # Create tool call and response mocks using helpers + code_interpreter_call = self._create_code_interpreter_tool_call_mock() + mock_tool_call_item = self._create_tool_call_item_mock(code_interpreter_call) + mock_tool_output_item = self._create_tool_output_item_mock() + + # Create a mock result with tool calls that will be processed + mock_result_with_tools = Mock(spec=RunResult) + mock_result_with_tools.final_output = "Code executed successfully" + mock_result_with_tools.to_input_list.return_value = [ + {"role": "user", "content": "Run some Python code"}, + {"role": "assistant", "content": "Code executed successfully"}, + ] + mock_result_with_tools.new_items = [mock_tool_call_item, mock_tool_output_item] + mock_runner_run.return_value = mock_result_with_tools + + # Create test parameters + params = RunAgentAutoSendParams( + input_list=[{"role": "user", "content": "Run some Python code"}], + mcp_server_params=[], + agent_name="test_agent", + agent_instructions=("You are a helpful assistant with code interpreter"), + tools=[CodeInterpreterTool(tool_config={"type": "code_interpreter"})], + trace_id="test-trace-id", + parent_span_id="test-span-id", + task_id="test-task-id", + ) + + result = await env.run(openai_activities.run_agent_auto_send, params) + + assert result.final_output == "Code executed successfully" + + # Verify runner.run was called with expected signature + mock_runner_run.assert_called_once() + call_args = mock_runner_run.call_args + self._assert_runner_call_signature(call_args) + + # Verify starting agent parameters + starting_agent = call_args.kwargs["starting_agent"] + # Create a mock object with the expected attributes + expected_params = Mock() + expected_params.agent_name = "test_agent" + expected_params.agent_instructions = "You are a helpful assistant with code interpreter" + expected_params.tools = [CodeInterpreterTool(tool_config={"type": "code_interpreter"})] + self._assert_starting_agent_params(starting_agent, expected_params) + + # Verify streaming context received tool request and response updates + # Should have been called twice - once for tool request, once for response + assert mock_streaming_context.stream_update.call_count == 2 + + # First call should be tool request + first_call = mock_streaming_context.stream_update.call_args_list[0] + first_update = first_call[1]["update"] # keyword argument + assert hasattr(first_update, "content") + assert first_update.content.name == "code_interpreter" + assert 
first_update.content.tool_call_id == "code_interpreter_call_123" + + # Second call should be tool response + second_call = mock_streaming_context.stream_update.call_args_list[1] + second_update = second_call[1]["update"] # keyword argument + assert hasattr(second_update, "content") + assert second_update.content.name == "code_interpreter_call" + assert second_update.content.tool_call_id == "code_interpreter_call_123" + + @patch("agents.Runner.run_streamed") + async def test_run_agent_streamed_auto_send(self, mock_runner_run_streamed): + """Test run_agent_streamed_auto_send with streaming and tool responses.""" + from agentex.lib.core.temporal.activities.adk.providers.openai_activities import ( + CodeInterpreterTool, + RunAgentStreamedAutoSendParams, + ) + + # Create streaming result mock using helper + mock_streaming_result = self._create_streaming_result_mock() + + # Create mock streaming events + async def mock_stream_events(): + # Tool call event + tool_call_event = Mock() + tool_call_event.type = "run_item_stream_event" + tool_call_item = Mock() + tool_call_item.type = "tool_call_item" + tool_call_item.raw_item = self._create_code_interpreter_tool_call_mock() + tool_call_event.item = tool_call_item + yield tool_call_event + + # Tool response event + tool_response_event = Mock() + tool_response_event.type = "run_item_stream_event" + tool_response_item = Mock() + tool_response_item.type = "tool_call_output_item" + tool_response_item.raw_item = {"call_id": "code_interpreter_call_123", "output": "Hello from streaming"} + tool_response_event.item = tool_response_item + yield tool_response_event + + mock_streaming_result.stream_events = mock_stream_events + mock_runner_run_streamed.return_value = mock_streaming_result + + # Setup test environment + mock_tracer = self._create_mock_tracer() + openai_service, openai_activities, env = self._create_test_setup(mock_tracer) + mock_streaming_context = self._setup_streaming_service_mocks(openai_service) + + # Create test parameters + params = RunAgentStreamedAutoSendParams( + input_list=[{"role": "user", "content": "Run some Python code"}], + mcp_server_params=[], + agent_name="test_agent", + agent_instructions=("You are a helpful assistant with code interpreter"), + tools=[CodeInterpreterTool(tool_config={"type": "code_interpreter"})], + trace_id="test-trace-id", + parent_span_id="test-span-id", + task_id="test-task-id", + ) + + # Act + result = await env.run(openai_activities.run_agent_streamed_auto_send, params) + + # Assert - Result structure (expecting SerializableRunResultStreaming from activity) + from agentex.lib.types.agent_results import SerializableRunResultStreaming + + assert isinstance(result, SerializableRunResultStreaming) + assert result.final_output == "Code executed successfully" + + # Verify runner.run_streamed was called with expected signature + mock_runner_run_streamed.assert_called_once() + call_args = mock_runner_run_streamed.call_args + self._assert_runner_call_signature_streamed(call_args) + + # Verify starting agent parameters + starting_agent = call_args.kwargs["starting_agent"] + # Create a mock object with the expected attributes + expected_params = Mock() + expected_params.agent_name = "test_agent" + expected_params.agent_instructions = "You are a helpful assistant with code interpreter" + expected_params.tools = [CodeInterpreterTool(tool_config={"type": "code_interpreter"})] + self._assert_starting_agent_params(starting_agent, expected_params) + + # Verify streaming context received tool request and response updates + 
# Should have been called twice - once for tool request, once for response + assert mock_streaming_context.stream_update.call_count == 2 + + # First call should be tool request + first_call = mock_streaming_context.stream_update.call_args_list[0] + first_update = first_call[1]["update"] # keyword argument + assert hasattr(first_update, "content") + assert first_update.content.name == "code_interpreter" + assert first_update.content.tool_call_id == "code_interpreter_call_123" + + # Second call should be tool response + second_call = mock_streaming_context.stream_update.call_args_list[1] + second_update = second_call[1]["update"] # keyword argument + assert hasattr(second_update, "content") + assert second_update.content.name == "code_interpreter_call" + assert second_update.content.tool_call_id == "code_interpreter_call_123" + + def _create_mock_tracer(self): + """Helper method to create a properly mocked tracer with async context manager support.""" + mock_tracer = Mock() + mock_trace = Mock() + mock_span = Mock() + + # Setup the span context manager + async def mock_span_aenter(_): + return mock_span + + async def mock_span_aexit(_, _exc_type, _exc_val, _exc_tb): + return None + + mock_span.__aenter__ = mock_span_aenter + mock_span.__aexit__ = mock_span_aexit + mock_trace.span.return_value = mock_span + mock_tracer.trace.return_value = mock_trace + + return mock_tracer + + def _create_test_setup(self, mock_tracer): + """Helper method to create OpenAIService and OpenAIActivities instances.""" + # Import here to avoid circular imports + from agentex.lib.core.services.adk.providers.openai import OpenAIService + from agentex.lib.core.temporal.activities.adk.providers.openai_activities import OpenAIActivities + + openai_service = OpenAIService(tracer=mock_tracer) + openai_activities = OpenAIActivities(openai_service) + env = ActivityEnvironment() + + return openai_service, openai_activities, env + + def _assert_runner_call_signature(self, call_args): + """Helper method to validate Runner.run call signature.""" + actual_kwargs = set(call_args.kwargs.keys()) + + # Check that we only pass valid Runner.run parameters + valid_params = { + "starting_agent", + "input", + "context", + "max_turns", + "hooks", + "run_config", + "previous_response_id", + "session", + } + invalid_kwargs = actual_kwargs - valid_params + assert not invalid_kwargs, f"Invalid arguments passed to Runner.run: {invalid_kwargs}" + + # Verify required arguments are present + assert "starting_agent" in call_args.kwargs, "starting_agent is required for Runner.run" + assert "input" in call_args.kwargs, "input is required for Runner.run" + + # Verify starting_agent is not None (actual agent object created) + assert call_args.kwargs["starting_agent"] is not None, "starting_agent should not be None" + + def _assert_runner_call_signature_streamed(self, call_args): + """Helper method to validate Runner.run_streamed call signature.""" + actual_kwargs = set(call_args.kwargs.keys()) + + # Check that we only pass valid Runner.run_streamed parameters + valid_params = { + "starting_agent", + "input", + "context", + "max_turns", + "hooks", + "run_config", + "previous_response_id", + "session", + } + invalid_kwargs = actual_kwargs - valid_params + assert not invalid_kwargs, f"Invalid arguments passed to Runner.run_streamed: {invalid_kwargs}" + + # Verify required arguments are present + assert "starting_agent" in call_args.kwargs, "starting_agent is required for Runner.run_streamed" + assert "input" in call_args.kwargs, "input is required for 
Runner.run_streamed" + + # Verify starting_agent is not None (actual agent object created) + assert call_args.kwargs["starting_agent"] is not None, "starting_agent should not be None" + + def _assert_starting_agent_params(self, starting_agent, expected_params): + """Helper method to validate starting_agent parameters match expected values.""" + # Verify agent name and instructions match + assert starting_agent.name == expected_params.agent_name, f"Agent name should be {expected_params.agent_name}" + assert starting_agent.instructions == expected_params.agent_instructions, f"Agent instructions should match" + + # Note: Other agent parameters like tools, guardrails would be tested here + # but they require more complex inspection of the agent object + + def _assert_result_structure(self, result, expected_output="Hello! How can I help you today?"): + """Helper method to validate the result structure.""" + from agentex.lib.types.agent_results import SerializableRunResult + + assert isinstance(result, SerializableRunResult) + assert result.final_output == expected_output + assert len(result.final_input_list) == 2 + assert result.final_input_list[0]["role"] == "user" + assert result.final_input_list[1]["role"] == "assistant" + + def _create_tools_for_case(self, tools_case): + """Helper method to create tools based on test case.""" + from agentex.lib.core.temporal.activities.adk.providers.openai_activities import ( + ComputerTool, + FunctionTool, + WebSearchTool, + FileSearchTool, + LocalShellTool, + CodeInterpreterTool, + ImageGenerationTool, + ) + + def sample_tool_function(_context, args): + return f"Tool called with {args}" + + def sample_computer(): + return Mock() # Mock computer object + + def sample_safety_check(_data): + return True + + def sample_executor(): + return Mock() # Mock executor + + if tools_case == "no_tools": + return None + elif tools_case == "function_tool": + return [ + FunctionTool( + name="test_function", + description="A test function tool", + params_json_schema={"type": "object", "properties": {}}, + on_invoke_tool=sample_tool_function, + ) + ] + elif tools_case == "web_search_tool": + return [WebSearchTool()] + elif tools_case == "file_search_tool": + return [ + FileSearchTool(vector_store_ids=["store1", "store2"], max_num_results=10, include_search_results=True) + ] + elif tools_case == "computer_tool": + return [ComputerTool(computer=sample_computer(), on_safety_check=sample_safety_check)] + elif tools_case == "code_interpreter_tool": + return [ + CodeInterpreterTool( + tool_config={"type": "code_interpreter", "container": {"type": "static", "image": "python:3.11"}} + ) + ] + elif tools_case == "image_generation_tool": + return [ + ImageGenerationTool( + tool_config={ + "type": "image_generation", + "quality": "high", + "size": "1024x1024", + "output_format": "png", + } + ) + ] + elif tools_case == "local_shell_tool": + return [LocalShellTool(executor=sample_executor())] + elif tools_case == "mixed_tools": + return [ + FunctionTool( + name="calculator", + description="A calculator tool", + params_json_schema={"type": "object", "properties": {"expression": {"type": "string"}}}, + on_invoke_tool=sample_tool_function, + ), + WebSearchTool(), + FileSearchTool(vector_store_ids=["store1"], max_num_results=5), + ] + else: + raise ValueError(f"Unknown tools_case: {tools_case}") + + def _assert_tools_conversion(self, starting_agent, tools_case, _original_tools): + """Helper method to validate that tools were properly converted.""" + from agents.tool import ( + ComputerTool 
+    def _assert_tools_conversion(self, starting_agent, tools_case, _original_tools):
+        """Helper method to validate that tools were properly converted."""
+        from agents.tool import (
+            ComputerTool as OAIComputerTool,
+            FunctionTool as OAIFunctionTool,
+            WebSearchTool as OAIWebSearchTool,
+            FileSearchTool as OAIFileSearchTool,
+            LocalShellTool as OAILocalShellTool,
+            CodeInterpreterTool as OAICodeInterpreterTool,
+            ImageGenerationTool as OAIImageGenerationTool,
+        )
+
+        if tools_case == "no_tools":
+            # When no tools are provided, the agent should have an empty tools list
+            assert starting_agent.tools == [], "Agent should have empty tools list when no tools provided"
+
+        elif tools_case == "function_tool":
+            assert len(starting_agent.tools) == 1, "Agent should have 1 tool"
+            agent_tool = starting_agent.tools[0]
+            assert isinstance(agent_tool, OAIFunctionTool), "Tool should be converted to OAIFunctionTool"
+            assert agent_tool.name == "test_function", "Tool name should be preserved"
+            assert agent_tool.description == "A test function tool", "Tool description should be preserved"
+            # Check that the schema contains our expected fields (may have additional fields)
+            assert "type" in agent_tool.params_json_schema, "Tool schema should have type field"
+            assert agent_tool.params_json_schema["type"] == "object", "Tool schema type should be object"
+            assert "properties" in agent_tool.params_json_schema, "Tool schema should have properties field"
+            assert callable(agent_tool.on_invoke_tool), "Tool function should be callable"
+
+        elif tools_case == "web_search_tool":
+            assert len(starting_agent.tools) == 1, "Agent should have 1 tool"
+            agent_tool = starting_agent.tools[0]
+            assert isinstance(agent_tool, OAIWebSearchTool), "Tool should be converted to OAIWebSearchTool"
+
+        elif tools_case == "file_search_tool":
+            assert len(starting_agent.tools) == 1, "Agent should have 1 tool"
+            agent_tool = starting_agent.tools[0]
+            assert isinstance(agent_tool, OAIFileSearchTool), "Tool should be converted to OAIFileSearchTool"
+            assert agent_tool.vector_store_ids == ["store1", "store2"], "Vector store IDs should be preserved"
+            assert agent_tool.max_num_results == 10, "Max results should be preserved"
+            assert agent_tool.include_search_results, "Include search results flag should be preserved"
+
+        elif tools_case == "computer_tool":
+            assert len(starting_agent.tools) == 1, "Agent should have 1 tool"
+            agent_tool = starting_agent.tools[0]
+            assert isinstance(agent_tool, OAIComputerTool), "Tool should be converted to OAIComputerTool"
+            assert agent_tool.computer is not None, "Computer object should be present"
+            assert agent_tool.on_safety_check is not None, "Safety check function should be present"
+
+        elif tools_case == "code_interpreter_tool":
+            assert len(starting_agent.tools) == 1, "Agent should have 1 tool"
+            agent_tool = starting_agent.tools[0]
+            assert isinstance(agent_tool, OAICodeInterpreterTool), "Tool should be converted to OAICodeInterpreterTool"
+
+        elif tools_case == "image_generation_tool":
+            assert len(starting_agent.tools) == 1, "Agent should have 1 tool"
+            agent_tool = starting_agent.tools[0]
+            assert isinstance(agent_tool, OAIImageGenerationTool), "Tool should be converted to OAIImageGenerationTool"
+
+        elif tools_case == "local_shell_tool":
+            assert len(starting_agent.tools) == 1, "Agent should have 1 tool"
+            agent_tool = starting_agent.tools[0]
+            assert isinstance(agent_tool, OAILocalShellTool), "Tool should be converted to OAILocalShellTool"
+            assert agent_tool.executor is not None, "Executor should be present"
+
+        elif tools_case == "mixed_tools":
+            assert len(starting_agent.tools) == 3, "Agent should have 3 tools"
+
+            # Check first tool (FunctionTool)
+            function_tool = starting_agent.tools[0]
+            assert isinstance(function_tool, OAIFunctionTool), "First tool should be OAIFunctionTool"
+            assert function_tool.name == "calculator", "Function tool name should be preserved"
+
+            # Check second tool (WebSearchTool)
+            web_tool = starting_agent.tools[1]
+            assert isinstance(web_tool, OAIWebSearchTool), "Second tool should be OAIWebSearchTool"
+
+            # Check third tool (FileSearchTool)
+            file_tool = starting_agent.tools[2]
+            assert isinstance(file_tool, OAIFileSearchTool), "Third tool should be OAIFileSearchTool"
+
+        else:
+            raise ValueError(f"Unknown tools_case: {tools_case}")
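+    # Illustrative only (hypothetical snippet, not part of this change): the
+    # conversion assertions above are meant to run against the agent captured
+    # from the patched Runner call; `mock_runner` is an assumed fixture name:
+    #
+    #     starting_agent = mock_runner.call_args.kwargs["starting_agent"]
+    #     self._assert_tools_conversion(starting_agent, tools_case, tools)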
+    def _setup_streaming_service_mocks(self, openai_service):
+        """Helper method to set up streaming service mocks for run_agent_auto_send."""
+        from unittest.mock import AsyncMock
+
+        # Mock the streaming service and agentex client
+        mock_streaming_service = AsyncMock()
+        mock_agentex_client = AsyncMock()
+
+        # Mock streaming context manager
+        mock_streaming_context = AsyncMock()
+
+        # Create a proper TaskMessage mock that passes validation
+        from agentex.types.task_message import TaskMessage
+
+        mock_task_message = Mock(spec=TaskMessage)
+        mock_task_message.id = "test-task-message-id"
+        mock_task_message.task_id = "test-task-id"
+        mock_task_message.content = {"type": "text", "content": "test"}
+
+        mock_streaming_context.task_message = mock_task_message
+        mock_streaming_context.stream_update = AsyncMock()
+
+        # Create a proper async context manager mock
+        from contextlib import asynccontextmanager
+
+        @asynccontextmanager
+        async def mock_streaming_context_manager(*_args, **_kwargs):
+            yield mock_streaming_context
+
+        mock_streaming_service.streaming_task_message_context = mock_streaming_context_manager
+
+        openai_service.streaming_service = mock_streaming_service
+        openai_service.agentex_client = mock_agentex_client
+
+        return mock_streaming_context
+
+    def _create_code_interpreter_tool_call_mock(self, call_id="code_interpreter_call_123"):
+        """Helper to create ResponseCodeInterpreterToolCall mock objects."""
+        return ResponseCodeInterpreterToolCall(
+            id=call_id,
+            type="code_interpreter_call",
+            status="completed",
+            code="print('Hello from code interpreter')",
+            container_id="container_123",
+            outputs=[],
+        )
+
+    def _create_tool_call_item_mock(self, tool_call):
+        """Helper to create a tool call item mock."""
+        mock_tool_call_item = Mock()
+        mock_tool_call_item.type = "tool_call_item"
+        mock_tool_call_item.raw_item = tool_call
+        return mock_tool_call_item
+
+    def _create_tool_output_item_mock(self, call_id="code_interpreter_call_123", output="Hello from code interpreter"):
+        """Helper to create a tool output item mock."""
+        mock_tool_output_item = Mock()
+        mock_tool_output_item.type = "tool_call_output_item"
+        mock_tool_output_item.raw_item = {"call_id": call_id, "output": output}
+        return mock_tool_output_item
+
+    def _create_streaming_result_mock(self, final_output="Code executed successfully"):
+        """Helper to create a streaming result mock with common setup."""
+        mock_streaming_result = Mock(spec=RunResultStreaming)
+        mock_streaming_result.final_output = final_output
+        mock_streaming_result.new_items = []
+        mock_streaming_result.final_input_list = [
+            {"role": "user", "content": "Run some Python code"},
+            {"role": "assistant", "content": final_output},
+        ]
+        mock_streaming_result.to_input_list.return_value = [
+            {"role": "user", "content": "Run some Python code"},
+            {"role": "assistant", "content": final_output},
+        ]
+        return mock_streaming_result
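+    # Illustrative only (hypothetical wiring, not part of this change): the mock
+    # factories above compose into a streaming result that exercises the
+    # code-interpreter path end to end:
+    #
+    #     tool_call = self._create_code_interpreter_tool_call_mock()
+    #     result = self._create_streaming_result_mock()
+    #     result.new_items = [
+    #         self._create_tool_call_item_mock(tool_call),
+    #         self._create_tool_output_item_mock(),
+    #     ]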
+    def _create_common_agent_params(self, **overrides):
+        """Helper to create common agent parameters with defaults."""
+        defaults = {
+            "input_list": [{"role": "user", "content": "Run some Python code"}],
+            "mcp_server_params": [],
+            "agent_name": "test_agent",
+            "agent_instructions": "You are a helpful assistant with code interpreter",
+            "trace_id": "test-trace-id",
+            "parent_span_id": "test-span-id",
+            "task_id": "test-task-id",
+        }
+        defaults.update(overrides)
+        return defaults
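+    # Illustrative only (hypothetical call, not part of this change): overrides
+    # replace or extend any default key, e.g. capping the run via the new
+    # max_turns parameter introduced in this change:
+    #
+    #     params = self._create_common_agent_params(max_turns=3)
+    #     result = await openai_service.run_agent_auto_send(**params)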