diff --git a/temporalio/contrib/openai_agents/__init__.py b/temporalio/contrib/openai_agents/__init__.py index aedf6eb85..c9da59497 100644 --- a/temporalio/contrib/openai_agents/__init__.py +++ b/temporalio/contrib/openai_agents/__init__.py @@ -8,25 +8,19 @@ Use with caution in production environments. """ -from temporalio.contrib.openai_agents._invoke_model_activity import ModelActivity from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters -from temporalio.contrib.openai_agents._trace_interceptor import ( - OpenAIAgentsTracingInterceptor, -) from temporalio.contrib.openai_agents.temporal_openai_agents import ( + OpenAIAgentsPlugin, TestModel, TestModelProvider, - set_open_ai_agent_temporal_overrides, ) from . import workflow __all__ = [ - "ModelActivity", + "OpenAIAgentsPlugin", "ModelActivityParameters", "workflow", - "set_open_ai_agent_temporal_overrides", - "OpenAIAgentsTracingInterceptor", "TestModel", "TestModelProvider", ] diff --git a/temporalio/contrib/openai_agents/temporal_openai_agents.py b/temporalio/contrib/openai_agents/temporal_openai_agents.py index 6137a3c9d..0d7c5f968 100644 --- a/temporalio/contrib/openai_agents/temporal_openai_agents.py +++ b/temporalio/contrib/openai_agents/temporal_openai_agents.py @@ -21,11 +21,20 @@ from agents.tracing.provider import DefaultTraceProvider from openai.types.responses import ResponsePromptParam +import temporalio.client +import temporalio.worker +from temporalio.client import ClientConfig +from temporalio.contrib.openai_agents._invoke_model_activity import ModelActivity from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters from temporalio.contrib.openai_agents._openai_runner import TemporalOpenAIRunner from temporalio.contrib.openai_agents._temporal_trace_provider import ( TemporalTraceProvider, ) +from temporalio.contrib.openai_agents._trace_interceptor import ( + OpenAIAgentsTracingInterceptor, +) +from temporalio.contrib.pydantic import 
pydantic_data_converter +from temporalio.worker import Worker, WorkerConfig @contextmanager @@ -133,3 +142,121 @@ def stream_response( ) -> AsyncIterator[TResponseStreamEvent]: """Get a streamed response from the model. Unimplemented.""" raise NotImplementedError() + + +class OpenAIAgentsPlugin(temporalio.client.Plugin, temporalio.worker.Plugin): + """Temporal plugin for integrating OpenAI agents with Temporal workflows. + + .. warning:: + This class is experimental and may change in future versions. + Use with caution in production environments. + + This plugin provides seamless integration between the OpenAI Agents SDK and + Temporal workflows. It automatically configures the necessary interceptors, + activities, and data converters to enable OpenAI agents to run within + Temporal workflows with proper tracing and model execution. + + The plugin: + 1. Configures the Pydantic data converter for type-safe serialization + 2. Sets up tracing interceptors for OpenAI agent interactions + 3. Registers model execution activities + 4. Manages the OpenAI agent runtime overrides during worker execution + + Args: + model_params: Configuration parameters for Temporal activity execution + of model calls. If None, default parameters will be used. + model_provider: Optional model provider for custom model implementations. + Useful for testing or custom model integrations. + + Example: + >>> from temporalio.client import Client + >>> from temporalio.worker import Worker + >>> from temporalio.contrib.openai_agents import OpenAIAgentsPlugin, ModelActivityParameters + >>> from datetime import timedelta + >>> from temporalio.common import RetryPolicy + >>> # Configure model parameters + >>> model_params = ModelActivityParameters( + ... start_to_close_timeout=timedelta(seconds=30), + ... retry_policy=RetryPolicy(maximum_attempts=3) + ... ) + >>> + >>> # Create plugin + >>> plugin = OpenAIAgentsPlugin(model_params=model_params) + >>> + >>> # Use with client and worker + >>> client = await Client.connect( + ...
"localhost:7233", + ... plugins=[plugin] + ... ) + >>> worker = Worker( + ... client, + ... task_queue="my-task-queue", + ... workflows=[MyWorkflow], + ... ) + """ + + def __init__( + self, + model_params: Optional[ModelActivityParameters] = None, + model_provider: Optional[ModelProvider] = None, + ) -> None: + """Initialize the OpenAI agents plugin. + + Args: + model_params: Configuration parameters for Temporal activity execution + of model calls. If None, default parameters will be used. + model_provider: Optional model provider for custom model implementations. + Useful for testing or custom model integrations. + """ + self._model_params = model_params + self._model_provider = model_provider + + def configure_client(self, config: ClientConfig) -> ClientConfig: + """Configure the Temporal client for OpenAI agents integration. + + This method sets up the Pydantic data converter to enable proper + serialization of OpenAI agent objects and responses. + + Args: + config: The client configuration to modify. + + Returns: + The modified client configuration. + """ + config["data_converter"] = pydantic_data_converter + return super().configure_client(config) + + def configure_worker(self, config: WorkerConfig) -> WorkerConfig: + """Configure the Temporal worker for OpenAI agents integration. + + This method adds the necessary interceptors and activities for OpenAI + agent execution: + - Adds tracing interceptors for OpenAI agent interactions + - Registers model execution activities + + Args: + config: The worker configuration to modify. + + Returns: + The modified worker configuration. 
+ """ + config["interceptors"] = list(config.get("interceptors") or []) + [ + OpenAIAgentsTracingInterceptor() + ] + config["activities"] = list(config.get("activities") or []) + [ + ModelActivity(self._model_provider).invoke_model_activity + ] + return super().configure_worker(config) + + async def run_worker(self, worker: Worker) -> None: + """Run the worker with OpenAI agents temporal overrides. + + This method sets up the necessary runtime overrides for OpenAI agents + to work within the Temporal worker context, including custom runners + and trace providers. + + Args: + worker: The worker instance to run. + """ + with set_open_ai_agent_temporal_overrides(self._model_params): + await super().run_worker(worker) diff --git a/temporalio/contrib/openai_agents/workflow.py b/temporalio/contrib/openai_agents/workflow.py index 50eba0b9e..35d7c0311 100644 --- a/temporalio/contrib/openai_agents/workflow.py +++ b/temporalio/contrib/openai_agents/workflow.py @@ -240,4 +240,26 @@ async def run_operation(ctx: RunContextWrapper[Any], input: str) -> Any: class ToolSerializationError(TemporalError): - """Error that occurs when a tool output could not be serialized.""" + """Error that occurs when a tool output could not be serialized. + + .. warning:: + This exception is experimental and may change in future versions. + Use with caution in production environments. + + This exception is raised when a tool (created from an activity or Nexus operation) + returns a value that cannot be properly serialized for use by the OpenAI agent. + All tool outputs must be convertible to strings for the agent to process them. + + The error typically occurs when: + - A tool returns a complex object that doesn't have a meaningful string representation + - The returned object cannot be converted using str() + - Custom serialization is needed but not implemented + + Example: + >>> @activity.defn + ... def problematic_tool() -> ComplexObject: + ...
return ComplexObject() # This might cause ToolSerializationError + + To fix this error, ensure your tool returns string-convertible values or + modify the tool to return a string representation of the result. + """ diff --git a/tests/contrib/openai_agents/test_openai.py b/tests/contrib/openai_agents/test_openai.py index b6e5b3dbf..088618036 100644 --- a/tests/contrib/openai_agents/test_openai.py +++ b/tests/contrib/openai_agents/test_openai.py @@ -51,14 +51,10 @@ from temporalio.client import Client, WorkflowFailureError, WorkflowHandle from temporalio.contrib import openai_agents from temporalio.contrib.openai_agents import ( - ModelActivity, ModelActivityParameters, - OpenAIAgentsTracingInterceptor, TestModel, TestModelProvider, - set_open_ai_agent_temporal_overrides, ) -from temporalio.contrib.pydantic import pydantic_data_converter from temporalio.exceptions import CancelledError from temporalio.testing import WorkflowEnvironment from tests.contrib.openai_agents.research_agents.research_manager import ( @@ -127,26 +123,28 @@ async def test_hello_world_agent(client: Client, use_local_model: bool): if not use_local_model and not os.environ.get("OPENAI_API_KEY"): pytest.skip("No openai API key") new_config = client.config() - new_config["data_converter"] = pydantic_data_converter + new_config["plugins"] = [ + openai_agents.OpenAIAgentsPlugin( + model_params=ModelActivityParameters( + start_to_close_timeout=timedelta(seconds=30) + ), + model_provider=TestModelProvider(TestHelloModel()) + if use_local_model + else None, + ) + ] client = Client(**new_config) - model_params = ModelActivityParameters(start_to_close_timeout=timedelta(seconds=30)) - with set_open_ai_agent_temporal_overrides(model_params): - model_activity = ModelActivity( - TestModelProvider(TestHelloModel()) if use_local_model else None + async with new_worker(client, HelloWorldAgent) as worker: + result = await client.execute_workflow( + HelloWorldAgent.run, + "Tell me about recursion in programming.", 
+ id=f"hello-workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=5), ) - async with new_worker( - client, HelloWorldAgent, activities=[model_activity.invoke_model_activity] - ) as worker: - result = await client.execute_workflow( - HelloWorldAgent.run, - "Tell me about recursion in programming.", - id=f"hello-workflow-{uuid.uuid4()}", - task_queue=worker.task_queue, - execution_timeout=timedelta(seconds=5), - ) - if use_local_model: - assert result == "test" + if use_local_model: + assert result == "test" @dataclass @@ -413,116 +411,113 @@ async def test_tool_workflow(client: Client, use_local_model: bool): if not use_local_model and not os.environ.get("OPENAI_API_KEY"): pytest.skip("No openai API key") new_config = client.config() - new_config["data_converter"] = pydantic_data_converter + new_config["plugins"] = [ + openai_agents.OpenAIAgentsPlugin( + model_params=ModelActivityParameters( + start_to_close_timeout=timedelta(seconds=30) + ), + model_provider=TestModelProvider(TestWeatherModel()) + if use_local_model + else None, + ) + ] client = Client(**new_config) - model_params = ModelActivityParameters(start_to_close_timeout=timedelta(seconds=30)) - with set_open_ai_agent_temporal_overrides(model_params): - model_activity = ModelActivity( - TestModelProvider( - TestWeatherModel( # type: ignore - ) - ) - if use_local_model - else None + async with new_worker( + client, + ToolsWorkflow, + activities=[ + get_weather, + get_weather_object, + get_weather_country, + get_weather_context, + ActivityWeatherService().get_weather_method, + ], + ) as worker: + workflow_handle = await client.start_workflow( + ToolsWorkflow.run, + "What is the weather in Tokio?", + id=f"tools-workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=30), ) - async with new_worker( - client, - ToolsWorkflow, - activities=[ - model_activity.invoke_model_activity, - get_weather, - get_weather_object, - 
get_weather_country, - get_weather_context, - ActivityWeatherService().get_weather_method, - ], - interceptors=[OpenAIAgentsTracingInterceptor()], - ) as worker: - workflow_handle = await client.start_workflow( - ToolsWorkflow.run, - "What is the weather in Tokio?", - id=f"tools-workflow-{uuid.uuid4()}", - task_queue=worker.task_queue, - execution_timeout=timedelta(seconds=30), + result = await workflow_handle.result() + + if use_local_model: + assert result == "Test weather result" + + events = [] + async for e in workflow_handle.fetch_history_events(): + if e.HasField("activity_task_completed_event_attributes"): + events.append(e) + + assert len(events) == 11 + assert ( + "function_call" + in events[0] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "Sunny with wind" + in events[1] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "function_call" + in events[2] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "Sunny with wind" + in events[3] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "function_call" + in events[4] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "Sunny with wind" + in events[5] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "function_call" + in events[6] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "Stormy" + in events[7] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "function_call" + in events[8] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "Sunny with wind" + in events[9] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "Test weather result" + in 
events[10] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() ) - result = await workflow_handle.result() - - if use_local_model: - assert result == "Test weather result" - - events = [] - async for e in workflow_handle.fetch_history_events(): - if e.HasField("activity_task_completed_event_attributes"): - events.append(e) - - assert len(events) == 11 - assert ( - "function_call" - in events[0] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "Sunny with wind" - in events[1] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "function_call" - in events[2] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "Sunny with wind" - in events[3] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "function_call" - in events[4] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "Sunny with wind" - in events[5] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "function_call" - in events[6] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "Stormy" - in events[7] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "function_call" - in events[8] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "Sunny with wind" - in events[9] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "Test weather result" - in events[10] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) @pytest.mark.parametrize("use_local_model", [True, False]) @@ -536,68 +531,63 @@ async def test_nexus_tool_workflow( pytest.skip("Nexus tests don't work with time-skipping server") new_config = 
client.config() - new_config["data_converter"] = pydantic_data_converter + new_config["plugins"] = [ + openai_agents.OpenAIAgentsPlugin( + model_params=ModelActivityParameters( + start_to_close_timeout=timedelta(seconds=30) + ), + model_provider=TestModelProvider(TestNexusWeatherModel()) + if use_local_model + else None, + ) + ] client = Client(**new_config) - model_params = ModelActivityParameters(start_to_close_timeout=timedelta(seconds=30)) - with set_open_ai_agent_temporal_overrides(model_params): - model_activity = ModelActivity( - TestModelProvider( - TestNexusWeatherModel( # type: ignore - ) - ) - if use_local_model - else None + async with new_worker( + client, + NexusToolsWorkflow, + nexus_service_handlers=[WeatherServiceHandler()], + ) as worker: + await create_nexus_endpoint(worker.task_queue, client) + + workflow_handle = await client.start_workflow( + NexusToolsWorkflow.run, + "What is the weather in Tokio?", + id=f"nexus-tools-workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=30), ) - async with new_worker( - client, - NexusToolsWorkflow, - activities=[ - model_activity.invoke_model_activity, - ], - nexus_service_handlers=[WeatherServiceHandler()], - interceptors=[OpenAIAgentsTracingInterceptor()], - ) as worker: - await create_nexus_endpoint(worker.task_queue, client) - - workflow_handle = await client.start_workflow( - NexusToolsWorkflow.run, - "What is the weather in Tokio?", - id=f"nexus-tools-workflow-{uuid.uuid4()}", - task_queue=worker.task_queue, - execution_timeout=timedelta(seconds=30), + result = await workflow_handle.result() + + if use_local_model: + assert result == "Test nexus weather result" + + events = [] + async for e in workflow_handle.fetch_history_events(): + if e.HasField("activity_task_completed_event_attributes") or e.HasField( + "nexus_operation_completed_event_attributes" + ): + events.append(e) + + assert len(events) == 3 + assert ( + "function_call" + in events[0] + 
.activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "Sunny with wind" + in events[ + 1 + ].nexus_operation_completed_event_attributes.result.data.decode() + ) + assert ( + "Test nexus weather result" + in events[2] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() ) - result = await workflow_handle.result() - - if use_local_model: - assert result == "Test nexus weather result" - - events = [] - async for e in workflow_handle.fetch_history_events(): - if e.HasField( - "activity_task_completed_event_attributes" - ) or e.HasField("nexus_operation_completed_event_attributes"): - events.append(e) - - assert len(events) == 3 - assert ( - "function_call" - in events[0] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "Sunny with wind" - in events[ - 1 - ].nexus_operation_completed_event_attributes.result.data.decode() - ) - assert ( - "Test nexus weather result" - in events[2] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) @no_type_check @@ -687,63 +677,60 @@ async def test_research_workflow(client: Client, use_local_model: bool): if not use_local_model and not os.environ.get("OPENAI_API_KEY"): pytest.skip("No openai API key") new_config = client.config() - new_config["data_converter"] = pydantic_data_converter + new_config["plugins"] = [ + openai_agents.OpenAIAgentsPlugin( + model_params=ModelActivityParameters( + start_to_close_timeout=timedelta(seconds=30) + ), + model_provider=TestModelProvider(TestResearchModel()) + if use_local_model + else None, + ) + ] client = Client(**new_config) - global response_index - response_index = 0 - - model_params = ModelActivityParameters( - start_to_close_timeout=timedelta(seconds=120) - ) - with set_open_ai_agent_temporal_overrides(model_params): - model_activity = ModelActivity( - TestModelProvider(TestResearchModel()) if use_local_model else None + async with new_worker( + 
client, + ResearchWorkflow, + ) as worker: + workflow_handle = await client.start_workflow( + ResearchWorkflow.run, + "Caribbean vacation spots in April, optimizing for surfing, hiking and water sports", + id=f"research-workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=120), ) - async with new_worker( - client, - ResearchWorkflow, - activities=[model_activity.invoke_model_activity, get_weather], - interceptors=[OpenAIAgentsTracingInterceptor()], - ) as worker: - workflow_handle = await client.start_workflow( - ResearchWorkflow.run, - "Caribbean vacation spots in April, optimizing for surfing, hiking and water sports", - id=f"research-workflow-{uuid.uuid4()}", - task_queue=worker.task_queue, - execution_timeout=timedelta(seconds=120), + result = await workflow_handle.result() + + if use_local_model: + assert result == "report" + + events = [] + async for e in workflow_handle.fetch_history_events(): + if e.HasField("activity_task_completed_event_attributes"): + events.append(e) + + assert len(events) == 12 + assert ( + '"type":"output_text"' + in events[0] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() ) - result = await workflow_handle.result() - - if use_local_model: - assert result == "report" - - events = [] - async for e in workflow_handle.fetch_history_events(): - if e.HasField("activity_task_completed_event_attributes"): - events.append(e) - - assert len(events) == 12 + for i in range(1, 11): assert ( - '"type":"output_text"' - in events[0] + "web_search_call" + in events[i] .activity_task_completed_event_attributes.result.payloads[0] .data.decode() ) - for i in range(1, 11): - assert ( - "web_search_call" - in events[i] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - '"type":"output_text"' - in events[11] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) + assert ( + '"type":"output_text"' + in 
events[11] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) def orchestrator_agent() -> Agent: @@ -904,67 +891,64 @@ async def test_agents_as_tools_workflow(client: Client, use_local_model: bool): if not use_local_model and not os.environ.get("OPENAI_API_KEY"): pytest.skip("No openai API key") new_config = client.config() - new_config["data_converter"] = pydantic_data_converter + new_config["plugins"] = [ + openai_agents.OpenAIAgentsPlugin( + model_params=ModelActivityParameters( + start_to_close_timeout=timedelta(seconds=30) + ), + model_provider=TestModelProvider(AgentAsToolsModel()) + if use_local_model + else None, + ) + ] client = Client(**new_config) - model_params = ModelActivityParameters(start_to_close_timeout=timedelta(seconds=30)) - with set_open_ai_agent_temporal_overrides(model_params): - model_activity = ModelActivity( - TestModelProvider( - AgentAsToolsModel( # type: ignore - ) - ) - if use_local_model - else None + async with new_worker( + client, + AgentsAsToolsWorkflow, + ) as worker: + workflow_handle = await client.start_workflow( + AgentsAsToolsWorkflow.run, + "Translate to Spanish: 'I am full'", + id=f"agents-as-tools-workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=30), ) - async with new_worker( - client, - AgentsAsToolsWorkflow, - activities=[model_activity.invoke_model_activity], - interceptors=[OpenAIAgentsTracingInterceptor()], - ) as worker: - workflow_handle = await client.start_workflow( - AgentsAsToolsWorkflow.run, - "Translate to Spanish: 'I am full'", - id=f"agents-as-tools-workflow-{uuid.uuid4()}", - task_queue=worker.task_queue, - execution_timeout=timedelta(seconds=30), + result = await workflow_handle.result() + + if use_local_model: + assert result == 'The translation to Spanish is: "Estoy lleno."' + + events = [] + async for e in workflow_handle.fetch_history_events(): + if e.HasField("activity_task_completed_event_attributes"): + 
events.append(e) + + assert len(events) == 4 + assert ( + "function_call" + in events[0] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "Estoy lleno" + in events[1] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "The translation to Spanish is:" + in events[2] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "The translation to Spanish is:" + in events[3] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() ) - result = await workflow_handle.result() - - if use_local_model: - assert result == 'The translation to Spanish is: "Estoy lleno."' - - events = [] - async for e in workflow_handle.fetch_history_events(): - if e.HasField("activity_task_completed_event_attributes"): - events.append(e) - - assert len(events) == 4 - assert ( - "function_call" - in events[0] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "Estoy lleno" - in events[1] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "The translation to Spanish is:" - in events[2] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "The translation to Spanish is:" - in events[3] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) class AirlineAgentContext(BaseModel): @@ -1259,97 +1243,94 @@ async def test_customer_service_workflow(client: Client, use_local_model: bool): if not use_local_model and not os.environ.get("OPENAI_API_KEY"): pytest.skip("No openai API key") new_config = client.config() - new_config["data_converter"] = pydantic_data_converter + new_config["plugins"] = [ + openai_agents.OpenAIAgentsPlugin( + model_params=ModelActivityParameters( + start_to_close_timeout=timedelta(seconds=30) + ), + model_provider=TestModelProvider(CustomerServiceModel()) + if 
use_local_model + else None, + ) + ] client = Client(**new_config) questions = ["Hello", "Book me a flight to PDX", "11111", "Any window seat"] - model_params = ModelActivityParameters(start_to_close_timeout=timedelta(seconds=30)) - with set_open_ai_agent_temporal_overrides(model_params): - model_activity = ModelActivity( - TestModelProvider( - CustomerServiceModel( # type: ignore - ) - ) - if use_local_model - else None + async with new_worker( + client, + CustomerServiceWorkflow, + ) as worker: + workflow_handle = await client.start_workflow( + CustomerServiceWorkflow.run, + id=f"customer-service-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=30), ) - async with new_worker( - client, - CustomerServiceWorkflow, - activities=[model_activity.invoke_model_activity], - interceptors=[OpenAIAgentsTracingInterceptor()], - ) as worker: - workflow_handle = await client.start_workflow( - CustomerServiceWorkflow.run, - id=f"customer-service-{uuid.uuid4()}", - task_queue=worker.task_queue, - execution_timeout=timedelta(seconds=30), + history: list[Any] = [] + for q in questions: + message_input = ProcessUserMessageInput( + user_input=q, chat_length=len(history) + ) + new_history = await workflow_handle.execute_update( + CustomerServiceWorkflow.process_user_message, message_input + ) + history.extend(new_history) + print(*new_history, sep="\n") + + await workflow_handle.cancel() + + with pytest.raises(WorkflowFailureError) as err: + await workflow_handle.result() + assert isinstance(err.value.cause, CancelledError) + + if use_local_model: + events = [] + async for e in WorkflowHandle( + client, + workflow_handle.id, + run_id=workflow_handle._first_execution_run_id, + ).fetch_history_events(): + if e.HasField("activity_task_completed_event_attributes"): + events.append(e) + + assert len(events) == 6 + assert ( + "Hi there! How can I assist you today?" 
+ in events[0] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "transfer_to_seat_booking_agent" + in events[1] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "Could you please provide your confirmation number?" + in events[2] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "Thanks! What seat number would you like to change to?" + in events[3] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "update_seat" + in events[4] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "Your seat has been updated to a window seat. If there's anything else you need, feel free to let me know!" + in events[5] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() ) - history: list[Any] = [] - for q in questions: - message_input = ProcessUserMessageInput( - user_input=q, chat_length=len(history) - ) - new_history = await workflow_handle.execute_update( - CustomerServiceWorkflow.process_user_message, message_input - ) - history.extend(new_history) - print(*new_history, sep="\n") - - await workflow_handle.cancel() - - with pytest.raises(WorkflowFailureError) as err: - await workflow_handle.result() - assert isinstance(err.value.cause, CancelledError) - - if use_local_model: - events = [] - async for e in WorkflowHandle( - client, - workflow_handle.id, - run_id=workflow_handle._first_execution_run_id, - ).fetch_history_events(): - if e.HasField("activity_task_completed_event_attributes"): - events.append(e) - - assert len(events) == 6 - assert ( - "Hi there! How can I assist you today?" 
- in events[0] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "transfer_to_seat_booking_agent" - in events[1] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "Could you please provide your confirmation number?" - in events[2] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "Thanks! What seat number would you like to change to?" - in events[3] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "update_seat" - in events[4] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "Your seat has been updated to a window seat. If there's anything else you need, feel free to let me know!" - in events[5] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) guardrail_response_index: int = 0 @@ -1552,42 +1533,40 @@ async def test_input_guardrail(client: Client, use_local_model: bool): if not use_local_model and not os.environ.get("OPENAI_API_KEY"): pytest.skip("No openai API key") new_config = client.config() - new_config["data_converter"] = pydantic_data_converter - client = Client(**new_config) - - model_params = ModelActivityParameters(start_to_close_timeout=timedelta(seconds=30)) - with set_open_ai_agent_temporal_overrides(model_params): - model_activity = ModelActivity( - TestModelProvider( - InputGuardrailModel( # type: ignore - "", openai_client=AsyncOpenAI(api_key="Fake key") - ) + new_config["plugins"] = [ + openai_agents.OpenAIAgentsPlugin( + model_params=ModelActivityParameters( + start_to_close_timeout=timedelta(seconds=30) + ), + model_provider=TestModelProvider( + InputGuardrailModel("", openai_client=AsyncOpenAI(api_key="Fake key")) ) if use_local_model - else None + else None, ) - async with new_worker( - client, - InputGuardrailWorkflow, - activities=[model_activity.invoke_model_activity], - 
interceptors=[OpenAIAgentsTracingInterceptor()], - ) as worker: - workflow_handle = await client.start_workflow( - InputGuardrailWorkflow.run, - [ - "What's the capital of California?", - "Can you help me solve for x: 2x + 5 = 11", - ], - id=f"input-guardrail-{uuid.uuid4()}", - task_queue=worker.task_queue, - execution_timeout=timedelta(seconds=10), - ) - result = await workflow_handle.result() + ] + client = Client(**new_config) + + async with new_worker( + client, + InputGuardrailWorkflow, + ) as worker: + workflow_handle = await client.start_workflow( + InputGuardrailWorkflow.run, + [ + "What's the capital of California?", + "Can you help me solve for x: 2x + 5 = 11", + ], + id=f"input-guardrail-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=10), + ) + result = await workflow_handle.result() - if use_local_model: - assert len(result) == 2 - assert result[0] == "The capital of California is Sacramento." - assert result[1] == "Sorry, I can't help you with your math homework." + if use_local_model: + assert len(result) == 2 + assert result[0] == "The capital of California is Sacramento." + assert result[1] == "Sorry, I can't help you with your math homework." 
class OutputGuardrailModel(StaticTestModel): @@ -1669,35 +1648,32 @@ async def test_output_guardrail(client: Client, use_local_model: bool): if not use_local_model and not os.environ.get("OPENAI_API_KEY"): pytest.skip("No openai API key") new_config = client.config() - new_config["data_converter"] = pydantic_data_converter + new_config["plugins"] = [ + openai_agents.OpenAIAgentsPlugin( + model_params=ModelActivityParameters( + start_to_close_timeout=timedelta(seconds=30) + ), + model_provider=TestModelProvider(OutputGuardrailModel()) + if use_local_model + else None, + ) + ] client = Client(**new_config) - model_params = ModelActivityParameters(start_to_close_timeout=timedelta(seconds=30)) - with set_open_ai_agent_temporal_overrides(model_params): - model_activity = ModelActivity( - TestModelProvider( - OutputGuardrailModel( # type: ignore - ) - ) - if use_local_model - else None + async with new_worker( + client, + OutputGuardrailWorkflow, + ) as worker: + workflow_handle = await client.start_workflow( + OutputGuardrailWorkflow.run, + id=f"output-guardrail-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=10), ) - async with new_worker( - client, - OutputGuardrailWorkflow, - activities=[model_activity.invoke_model_activity], - interceptors=[OpenAIAgentsTracingInterceptor()], - ) as worker: - workflow_handle = await client.start_workflow( - OutputGuardrailWorkflow.run, - id=f"output-guardrail-{uuid.uuid4()}", - task_queue=worker.task_queue, - execution_timeout=timedelta(seconds=10), - ) - result = await workflow_handle.result() + result = await workflow_handle.result() - if use_local_model: - assert not result + if use_local_model: + assert not result class WorkflowToolModel(StaticTestModel): @@ -1760,21 +1736,24 @@ async def run_tool(self): async def test_workflow_method_tools(client: Client): new_config = client.config() - new_config["data_converter"] = pydantic_data_converter + new_config["plugins"] = [ + 
openai_agents.OpenAIAgentsPlugin( + model_params=ModelActivityParameters( + start_to_close_timeout=timedelta(seconds=30) + ), + model_provider=TestModelProvider(WorkflowToolModel()), + ) + ] client = Client(**new_config) - with set_open_ai_agent_temporal_overrides(): - model_activity = ModelActivity(TestModelProvider(WorkflowToolModel())) - async with new_worker( - client, - WorkflowToolWorkflow, - activities=[model_activity.invoke_model_activity], - interceptors=[OpenAIAgentsTracingInterceptor()], - ) as worker: - workflow_handle = await client.start_workflow( - WorkflowToolWorkflow.run, - id=f"workflow-tool-{uuid.uuid4()}", - task_queue=worker.task_queue, - execution_timeout=timedelta(seconds=10), - ) - await workflow_handle.result() + async with new_worker( + client, + WorkflowToolWorkflow, + ) as worker: + workflow_handle = await client.start_workflow( + WorkflowToolWorkflow.run, + id=f"workflow-tool-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=10), + ) + await workflow_handle.result() diff --git a/tests/contrib/openai_agents/test_openai_tracing.py b/tests/contrib/openai_agents/test_openai_tracing.py index 5a7d03785..c8ad366e6 100644 --- a/tests/contrib/openai_agents/test_openai_tracing.py +++ b/tests/contrib/openai_agents/test_openai_tracing.py @@ -1,19 +1,15 @@ -import datetime import uuid from datetime import timedelta -from typing import Any, Optional +from typing import Any from agents import Span, Trace, TracingProcessor from agents.tracing import get_trace_provider from temporalio.client import Client +from temporalio.contrib import openai_agents from temporalio.contrib.openai_agents import ( - ModelActivity, - OpenAIAgentsTracingInterceptor, TestModelProvider, - set_open_ai_agent_temporal_overrides, ) -from temporalio.contrib.pydantic import pydantic_data_converter from tests.contrib.openai_agents.test_openai import ResearchWorkflow, TestResearchModel from tests.helpers import new_worker @@ -44,108 +40,101 @@ 
def force_flush(self) -> None: async def test_tracing(client: Client): new_config = client.config() - new_config["data_converter"] = pydantic_data_converter - client = Client(**new_config) - - with set_open_ai_agent_temporal_overrides(): - provider = get_trace_provider() - - processor = MemoryTracingProcessor() - provider.set_processors([processor]) - - model_activity = ModelActivity(TestModelProvider(TestResearchModel())) - async with new_worker( - client, - ResearchWorkflow, - activities=[model_activity.invoke_model_activity], - interceptors=[OpenAIAgentsTracingInterceptor()], - ) as worker: - workflow_handle = await client.start_workflow( - ResearchWorkflow.run, - "Caribbean vacation spots in April, optimizing for surfing, hiking and water sports", - id=f"research-workflow-{uuid.uuid4()}", - task_queue=worker.task_queue, - execution_timeout=timedelta(seconds=120), - ) - result = await workflow_handle.result() - - # There is one closed root trace - assert len(processor.trace_events) == 2 - assert ( - processor.trace_events[0][0].trace_id - == processor.trace_events[1][0].trace_id - ) - assert processor.trace_events[0][1] - assert not processor.trace_events[1][1] - - def paired_span(a: tuple[Span[Any], bool], b: tuple[Span[Any], bool]) -> None: - assert a[0].trace_id == b[0].trace_id - assert a[1] - assert not b[1] - - # Initial planner spans - There are only 3 because we don't make an actual model call - paired_span(processor.span_events[0], processor.span_events[5]) - assert ( - processor.span_events[0][0].span_data.export().get("name") == "PlannerAgent" - ) - - paired_span(processor.span_events[1], processor.span_events[4]) - assert ( - processor.span_events[1][0].span_data.export().get("name") - == "temporal:startActivity" + new_config["plugins"] = [ + openai_agents.OpenAIAgentsPlugin( + model_provider=TestModelProvider(TestResearchModel()) ) + ] + client = Client(**new_config) - paired_span(processor.span_events[2], processor.span_events[3]) - assert ( - 
processor.span_events[2][0].span_data.export().get("name") - == "temporal:executeActivity" + provider = get_trace_provider() + + processor = MemoryTracingProcessor() + provider.set_processors([processor]) + + async with new_worker( + client, + ResearchWorkflow, + ) as worker: + workflow_handle = await client.start_workflow( + ResearchWorkflow.run, + "Caribbean vacation spots in April, optimizing for surfing, hiking and water sports", + id=f"research-workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=120), ) + result = await workflow_handle.result() + + # There is one closed root trace + assert len(processor.trace_events) == 2 + assert ( + processor.trace_events[0][0].trace_id == processor.trace_events[1][0].trace_id + ) + assert processor.trace_events[0][1] + assert not processor.trace_events[1][1] + + def paired_span(a: tuple[Span[Any], bool], b: tuple[Span[Any], bool]) -> None: + assert a[0].trace_id == b[0].trace_id + assert a[1] + assert not b[1] + + # Initial planner spans - There are only 3 because we don't make an actual model call + paired_span(processor.span_events[0], processor.span_events[5]) + assert processor.span_events[0][0].span_data.export().get("name") == "PlannerAgent" + + paired_span(processor.span_events[1], processor.span_events[4]) + assert ( + processor.span_events[1][0].span_data.export().get("name") + == "temporal:startActivity" + ) + + paired_span(processor.span_events[2], processor.span_events[3]) + assert ( + processor.span_events[2][0].span_data.export().get("name") + == "temporal:executeActivity" + ) + + for span, start in processor.span_events[6:-6]: + span_data = span.span_data.export() + + # All spans should be closed + if start: + assert any( + span.span_id == s.span_id and not s_start + for (s, s_start) in processor.span_events + ) - for span, start in processor.span_events[6:-6]: - span_data = span.span_data.export() - - # All spans should be closed - if start: - assert any( - 
span.span_id == s.span_id and not s_start - for (s, s_start) in processor.span_events - ) - - # Start activity is always parented to an agent - if span_data.get("name") == "temporal:startActivity": - parents = [ - s for (s, _) in processor.span_events if s.span_id == span.parent_id - ] - assert ( - len(parents) == 2 - and parents[0].span_data.export()["type"] == "agent" - ) - - # Execute is parented to start - if span_data.get("name") == "temporal:executeActivity": - parents = [ - s for (s, _) in processor.span_events if s.span_id == span.parent_id - ] - assert ( - len(parents) == 2 - and parents[0].span_data.export()["name"] - == "temporal:startActivity" - ) - - # Final writer spans - There are only 3 because we don't make an actual model call - paired_span(processor.span_events[-6], processor.span_events[-1]) - assert ( - processor.span_events[-6][0].span_data.export().get("name") == "WriterAgent" - ) + # Start activity is always parented to an agent + if span_data.get("name") == "temporal:startActivity": + parents = [ + s for (s, _) in processor.span_events if s.span_id == span.parent_id + ] + assert ( + len(parents) == 2 and parents[0].span_data.export()["type"] == "agent" + ) - paired_span(processor.span_events[-5], processor.span_events[-2]) - assert ( - processor.span_events[-5][0].span_data.export().get("name") - == "temporal:startActivity" - ) + # Execute is parented to start + if span_data.get("name") == "temporal:executeActivity": + parents = [ + s for (s, _) in processor.span_events if s.span_id == span.parent_id + ] + assert ( + len(parents) == 2 + and parents[0].span_data.export()["name"] == "temporal:startActivity" + ) - paired_span(processor.span_events[-4], processor.span_events[-3]) - assert ( - processor.span_events[-4][0].span_data.export().get("name") - == "temporal:executeActivity" - ) + # Final writer spans - There are only 3 because we don't make an actual model call + paired_span(processor.span_events[-6], processor.span_events[-1]) + 
assert processor.span_events[-6][0].span_data.export().get("name") == "WriterAgent" + + paired_span(processor.span_events[-5], processor.span_events[-2]) + assert ( + processor.span_events[-5][0].span_data.export().get("name") + == "temporal:startActivity" + ) + + paired_span(processor.span_events[-4], processor.span_events[-3]) + assert ( + processor.span_events[-4][0].span_data.export().get("name") + == "temporal:executeActivity" + )