Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 49 additions & 7 deletions temporalio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ async def connect(
namespace: str = "default",
api_key: Optional[str] = None,
data_converter: temporalio.converter.DataConverter = temporalio.converter.DataConverter.default,
plugins: Sequence[Plugin] = [],
interceptors: Sequence[Interceptor] = [],
default_workflow_query_reject_condition: Optional[
temporalio.common.QueryRejectCondition
Expand Down Expand Up @@ -178,13 +179,21 @@ async def connect(
runtime=runtime,
http_connect_proxy_config=http_connect_proxy_config,
)

root_plugin: Plugin = _RootPlugin()
for plugin in reversed(list(plugins)):
root_plugin = plugin.init_client_plugin(root_plugin)

service_client = await root_plugin.connect_service_client(connect_config)

return Client(
await temporalio.service.ServiceClient.connect(connect_config),
service_client,
namespace=namespace,
data_converter=data_converter,
interceptors=interceptors,
default_workflow_query_reject_condition=default_workflow_query_reject_condition,
header_codec_behavior=header_codec_behavior,
plugins=plugins,
)

def __init__(
Expand All @@ -193,6 +202,7 @@ def __init__(
*,
namespace: str = "default",
data_converter: temporalio.converter.DataConverter = temporalio.converter.DataConverter.default,
plugins: Sequence[Plugin] = [],
interceptors: Sequence[Interceptor] = [],
default_workflow_query_reject_condition: Optional[
temporalio.common.QueryRejectCondition
Expand All @@ -203,21 +213,28 @@ def __init__(

See :py:meth:`connect` for details on the parameters.
"""
# Iterate over interceptors in reverse building the impl
self._impl: OutboundInterceptor = _ClientImpl(self)
for interceptor in reversed(list(interceptors)):
self._impl = interceptor.intercept_client(self._impl)

# Store the config for tracking
self._config = ClientConfig(
config = ClientConfig(
service_client=service_client,
namespace=namespace,
data_converter=data_converter,
interceptors=interceptors,
default_workflow_query_reject_condition=default_workflow_query_reject_condition,
header_codec_behavior=header_codec_behavior,
plugins=plugins,
)

root_plugin: Plugin = _RootPlugin()
for plugin in reversed(list(plugins)):
root_plugin = plugin.init_client_plugin(root_plugin)

self._config = root_plugin.on_create_client(config)

# Iterate over interceptors in reverse building the impl
self._impl: OutboundInterceptor = _ClientImpl(self)
for interceptor in reversed(list(interceptors)):
self._impl = interceptor.intercept_client(self._impl)

def config(self) -> ClientConfig:
"""Config, as a dictionary, used to create this client.

Expand Down Expand Up @@ -1510,6 +1527,7 @@ class ClientConfig(TypedDict, total=False):
Optional[temporalio.common.QueryRejectCondition]
]
header_codec_behavior: Required[HeaderCodecBehavior]
plugins: Required[Sequence[Plugin]]


class WorkflowHistoryEventFilterType(IntEnum):
Expand Down Expand Up @@ -7367,3 +7385,27 @@ async def _decode_user_metadata(
if not metadata.HasField("details")
else (await converter.decode([metadata.details]))[0],
)


class Plugin:
def init_client_plugin(self, next: Plugin) -> Plugin:
self.next_client_plugin = next
return self

def on_create_client(self, config: ClientConfig) -> ClientConfig:
return self.next_client_plugin.on_create_client(config)

async def connect_service_client(
self, config: temporalio.service.ConnectConfig
) -> temporalio.service.ServiceClient:
return await self.next_client_plugin.connect_service_client(config)


class _RootPlugin(Plugin):
def on_create_client(self, config: ClientConfig) -> ClientConfig:
return config

async def connect_service_client(
self, config: temporalio.service.ConnectConfig
) -> temporalio.service.ServiceClient:
return await temporalio.service.ServiceClient.connect(config)
10 changes: 2 additions & 8 deletions temporalio/contrib/openai_agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,18 @@
Use with caution in production environments.
"""

from temporalio.contrib.openai_agents._invoke_model_activity import ModelActivity
from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters
from temporalio.contrib.openai_agents._trace_interceptor import (
OpenAIAgentsTracingInterceptor,
)
from temporalio.contrib.openai_agents.temporal_openai_agents import (
Plugin,
TestModel,
TestModelProvider,
set_open_ai_agent_temporal_overrides,
workflow,
)

__all__ = [
"ModelActivity",
"Plugin",
"ModelActivityParameters",
"workflow",
"set_open_ai_agent_temporal_overrides",
"OpenAIAgentsTracingInterceptor",
"TestModel",
"TestModelProvider",
]
47 changes: 38 additions & 9 deletions temporalio/contrib/openai_agents/temporal_openai_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
import json
from contextlib import contextmanager
from datetime import timedelta
from typing import Any, AsyncIterator, Callable, Optional, Union, overload
from typing import Any, AsyncIterator, Callable, Optional, Union

from agents import (
Agent,
AgentOutputSchemaBase,
Handoff,
Model,
Expand All @@ -19,31 +18,34 @@
TResponseInputItem,
set_trace_provider,
)
from agents.function_schema import DocstringStyle, function_schema
from agents.function_schema import function_schema
from agents.items import TResponseStreamEvent
from agents.run import get_default_agent_runner, set_default_agent_runner
from agents.tool import (
FunctionTool,
ToolErrorFunction,
ToolFunction,
ToolParams,
default_tool_error_function,
function_tool,
)
from agents.tracing import get_trace_provider
from agents.tracing.provider import DefaultTraceProvider
from agents.util._types import MaybeAwaitable
from openai.types.responses import ResponsePromptParam

import temporalio.client
import temporalio.worker
from temporalio import activity
from temporalio import workflow as temporal_workflow
from temporalio.client import ClientConfig
from temporalio.common import Priority, RetryPolicy
from temporalio.contrib.openai_agents._invoke_model_activity import ModelActivity
from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters
from temporalio.contrib.openai_agents._openai_runner import TemporalOpenAIRunner
from temporalio.contrib.openai_agents._temporal_trace_provider import (
TemporalTraceProvider,
)
from temporalio.contrib.openai_agents._trace_interceptor import (
OpenAIAgentsTracingInterceptor,
)
from temporalio.contrib.pydantic import pydantic_data_converter
from temporalio.exceptions import ApplicationError, TemporalError
from temporalio.worker import Worker, WorkerConfig
from temporalio.workflow import ActivityCancellationType, VersioningIntent


Expand Down Expand Up @@ -154,6 +156,33 @@ def stream_response(
raise NotImplementedError()


class Plugin(temporalio.client.Plugin, temporalio.worker.Plugin):
def __init__(
self,
model_params: Optional[ModelActivityParameters] = None,
model_provider: Optional[ModelProvider] = None,
) -> None:
self._model_params = model_params
self._model_provider = model_provider

def on_create_client(self, config: ClientConfig) -> ClientConfig:
config["data_converter"] = pydantic_data_converter
return super().on_create_client(config)

def on_create_worker(self, config: WorkerConfig) -> WorkerConfig:
config["interceptors"] = list(config.get("interceptors") or []) + [
OpenAIAgentsTracingInterceptor()
]
config["activities"] = list(config.get("activities") or []) + [
ModelActivity(self._model_provider).invoke_model_activity
]
return super().on_create_worker(config)

async def run_worker(self, worker: Worker) -> None:
with set_open_ai_agent_temporal_overrides(self._model_params):
await super().run_worker(worker)


class ToolSerializationError(TemporalError):
"""Error that occurs when a tool output could not be serialized."""

Expand Down
2 changes: 2 additions & 0 deletions temporalio/worker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
WorkflowSlotInfo,
)
from ._worker import (
Plugin,
PollerBehavior,
PollerBehaviorAutoscaling,
PollerBehaviorSimpleMaximum,
Expand Down Expand Up @@ -78,6 +79,7 @@
"ActivityOutboundInterceptor",
"WorkflowInboundInterceptor",
"WorkflowOutboundInterceptor",
"Plugin",
# Interceptor input
"ContinueAsNewInput",
"ExecuteActivityInput",
Expand Down
Loading
Loading