Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions temporalio/contrib/openai_agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters
from temporalio.contrib.openai_agents._temporal_openai_agents import (
OpenAIAgentsPlugin,
OpenAIPayloadConverter,
TestModel,
TestModelProvider,
)
Expand All @@ -23,9 +24,10 @@

__all__ = [
"AgentsWorkflowError",
"OpenAIAgentsPlugin",
"ModelActivityParameters",
"workflow",
"OpenAIAgentsPlugin",
"OpenAIPayloadConverter",
"TestModel",
"TestModelProvider",
"workflow",
]
29 changes: 22 additions & 7 deletions temporalio/contrib/openai_agents/_temporal_openai_agents.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Initialize Temporal OpenAI Agents overrides."""

import dataclasses
from contextlib import asynccontextmanager, contextmanager
from datetime import timedelta
from typing import AsyncIterator, Callable, Optional, Union
Expand Down Expand Up @@ -43,6 +44,7 @@
)
from temporalio.converter import (
DataConverter,
DefaultPayloadConverter,
)
from temporalio.worker import (
Replayer,
Expand Down Expand Up @@ -148,8 +150,11 @@ def stream_response(
raise NotImplementedError()


class _OpenAIPayloadConverter(PydanticPayloadConverter):
class OpenAIPayloadConverter(PydanticPayloadConverter):
"""PayloadConverter for OpenAI agents."""

def __init__(self) -> None:
"""Initialize a payload converter."""
super().__init__(ToJsonOptions(exclude_unset=True))


Expand Down Expand Up @@ -250,6 +255,20 @@ def init_worker_plugin(self, next: temporalio.worker.Plugin) -> None:
"""Set the next worker plugin"""
self.next_worker_plugin = next

@staticmethod
def _data_converter(converter: Optional[DataConverter]) -> DataConverter:
if converter is None:
return DataConverter(payload_converter_class=OpenAIPayloadConverter)
elif isinstance(converter.payload_converter, DefaultPayloadConverter):
return dataclasses.replace(
converter, payload_converter_class=OpenAIPayloadConverter
)
elif not isinstance(converter.payload_converter, OpenAIPayloadConverter):
raise ValueError(
"The payload converter must be of type OpenAIPayloadConverter."
)
return converter

def configure_client(self, config: ClientConfig) -> ClientConfig:
"""Configure the Temporal client for OpenAI agents integration.

Expand All @@ -262,9 +281,7 @@ def configure_client(self, config: ClientConfig) -> ClientConfig:
Returns:
The modified client configuration.
"""
config["data_converter"] = DataConverter(
payload_converter_class=_OpenAIPayloadConverter
)
config["data_converter"] = self._data_converter(config["data_converter"])
return self.next_client_plugin.configure_client(config)

def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
Expand Down Expand Up @@ -310,9 +327,7 @@ def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig:
config["interceptors"] = list(config.get("interceptors") or []) + [
OpenAIAgentsTracingInterceptor()
]
config["data_converter"] = DataConverter(
payload_converter_class=_OpenAIPayloadConverter
)
config["data_converter"] = self._data_converter(config.get("data_converter"))
return self.next_worker_plugin.configure_replayer(config)

@asynccontextmanager
Expand Down
Loading