diff --git a/temporalio/contrib/openai_agents/__init__.py b/temporalio/contrib/openai_agents/__init__.py index 998cf61eb..2d9777aa8 100644 --- a/temporalio/contrib/openai_agents/__init__.py +++ b/temporalio/contrib/openai_agents/__init__.py @@ -11,6 +11,7 @@ from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters from temporalio.contrib.openai_agents._temporal_openai_agents import ( OpenAIAgentsPlugin, + OpenAIPayloadConverter, TestModel, TestModelProvider, ) @@ -23,9 +24,10 @@ __all__ = [ "AgentsWorkflowError", - "OpenAIAgentsPlugin", "ModelActivityParameters", - "workflow", + "OpenAIAgentsPlugin", + "OpenAIPayloadConverter", "TestModel", "TestModelProvider", + "workflow", ] diff --git a/temporalio/contrib/openai_agents/_temporal_openai_agents.py b/temporalio/contrib/openai_agents/_temporal_openai_agents.py index f553725cf..f8b8999a2 100644 --- a/temporalio/contrib/openai_agents/_temporal_openai_agents.py +++ b/temporalio/contrib/openai_agents/_temporal_openai_agents.py @@ -1,5 +1,6 @@ """Initialize Temporal OpenAI Agents overrides.""" +import dataclasses from contextlib import asynccontextmanager, contextmanager from datetime import timedelta from typing import AsyncIterator, Callable, Optional, Union @@ -43,6 +44,7 @@ ) from temporalio.converter import ( DataConverter, + DefaultPayloadConverter, ) from temporalio.worker import ( Replayer, @@ -148,8 +150,11 @@ def stream_response( raise NotImplementedError() -class _OpenAIPayloadConverter(PydanticPayloadConverter): +class OpenAIPayloadConverter(PydanticPayloadConverter): + """PayloadConverter for OpenAI agents.""" + def __init__(self) -> None: + """Initialize a payload converter.""" super().__init__(ToJsonOptions(exclude_unset=True)) @@ -250,6 +255,20 @@ def init_worker_plugin(self, next: temporalio.worker.Plugin) -> None: """Set the next worker plugin""" self.next_worker_plugin = next + @staticmethod + def _data_converter(converter: Optional[DataConverter]) -> DataConverter: + if converter is None: + return DataConverter(payload_converter_class=OpenAIPayloadConverter) + elif converter.payload_converter_class is DefaultPayloadConverter: + return dataclasses.replace( + converter, payload_converter_class=OpenAIPayloadConverter + ) + elif not isinstance(converter.payload_converter, OpenAIPayloadConverter): + raise ValueError( + "The payload converter must be of type OpenAIPayloadConverter." + ) + return converter + def configure_client(self, config: ClientConfig) -> ClientConfig: """Configure the Temporal client for OpenAI agents integration. @@ -262,9 +281,7 @@ def configure_client(self, config: ClientConfig) -> ClientConfig: Returns: The modified client configuration. """ - config["data_converter"] = DataConverter( - payload_converter_class=_OpenAIPayloadConverter - ) + config["data_converter"] = self._data_converter(config["data_converter"]) return self.next_client_plugin.configure_client(config) def configure_worker(self, config: WorkerConfig) -> WorkerConfig: @@ -310,9 +327,7 @@ def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig: config["interceptors"] = list(config.get("interceptors") or []) + [ OpenAIAgentsTracingInterceptor() ] - config["data_converter"] = DataConverter( - payload_converter_class=_OpenAIPayloadConverter - ) + config["data_converter"] = self._data_converter(config.get("data_converter")) return self.next_worker_plugin.configure_replayer(config) @asynccontextmanager