From defe0bbfcbb5c1fc6400439f3ba20f49d14b9e0f Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Wed, 9 Jul 2025 15:40:44 -0700 Subject: [PATCH 01/11] Initial rough framework for plugins --- temporalio/client.py | 46 +++++++++++++++++++++++++++++-- temporalio/worker/__init__.py | 2 ++ temporalio/worker/_worker.py | 39 +++++++++++++++++++++++++- tests/test_client.py | 29 +++++++++++++++++++ tests/worker/test_worker.py | 52 +++++++++++++++++++++++++++++++++-- 5 files changed, 163 insertions(+), 5 deletions(-) diff --git a/temporalio/client.py b/temporalio/client.py index 0aab85465..2042c9b4d 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -120,6 +120,7 @@ async def connect( runtime: Optional[temporalio.runtime.Runtime] = None, http_connect_proxy_config: Optional[HttpConnectProxyConfig] = None, header_codec_behavior: HeaderCodecBehavior = HeaderCodecBehavior.NO_CODEC, + plugins: Sequence[Plugin] = [], ) -> Client: """Connect to a Temporal server. @@ -178,13 +179,21 @@ async def connect( runtime=runtime, http_connect_proxy_config=http_connect_proxy_config, ) + + root_plugin: Plugin = _RootPlugin() + for plugin in reversed(list(plugins)): + root_plugin = plugin.init_client_plugin(root_plugin) + + service_client = await root_plugin.connect_service_client(connect_config) + return Client( - await temporalio.service.ServiceClient.connect(connect_config), + service_client, namespace=namespace, data_converter=data_converter, interceptors=interceptors, default_workflow_query_reject_condition=default_workflow_query_reject_condition, header_codec_behavior=header_codec_behavior, + plugins=plugins, ) def __init__( @@ -198,6 +207,7 @@ def __init__( temporalio.common.QueryRejectCondition ] = None, header_codec_behavior: HeaderCodecBehavior = HeaderCodecBehavior.NO_CODEC, + plugins: Sequence[Plugin] = [], ): """Create a Temporal client from a service client. @@ -209,15 +219,22 @@ def __init__( self._impl = interceptor.intercept_client(self._impl) # Store the config for tracking - self._config = ClientConfig( + config = ClientConfig( service_client=service_client, namespace=namespace, data_converter=data_converter, interceptors=interceptors, default_workflow_query_reject_condition=default_workflow_query_reject_condition, header_codec_behavior=header_codec_behavior, + plugins=plugins, ) + root_plugin: Plugin = _RootPlugin() + for plugin in reversed(list(plugins)): + root_plugin = plugin.init_client_plugin(root_plugin) + + self._config = root_plugin.on_create_client(config) + def config(self) -> ClientConfig: """Config, as a dictionary, used to create this client. @@ -1510,6 +1527,7 @@ class ClientConfig(TypedDict, total=False): Optional[temporalio.common.QueryRejectCondition] ] header_codec_behavior: Required[HeaderCodecBehavior] + plugins: Required[Sequence[Plugin]] class WorkflowHistoryEventFilterType(IntEnum): @@ -7367,3 +7385,27 @@ async def _decode_user_metadata( if not metadata.HasField("details") else (await converter.decode([metadata.details]))[0], ) + + +class Plugin: + def init_client_plugin(self, next: Plugin) -> Plugin: + self.next_client_plugin = next + return self + + def on_create_client(self, config: ClientConfig) -> ClientConfig: + return self.next_client_plugin.on_create_client(config) + + async def connect_service_client( + self, config: temporalio.service.ConnectConfig + ) -> temporalio.service.ServiceClient: + return await self.next_client_plugin.connect_service_client(config) + + +class _RootPlugin(Plugin): + def on_create_client(self, config: ClientConfig) -> ClientConfig: + return config + + async def connect_service_client( + self, config: temporalio.service.ConnectConfig + ) -> temporalio.service.ServiceClient: + return await temporalio.service.ServiceClient.connect(config) diff --git a/temporalio/worker/__init__.py b/temporalio/worker/__init__.py index 83a66e67f..6e062afcc 100644 --- a/temporalio/worker/__init__.py +++ b/temporalio/worker/__init__.py @@ -44,6 +44,7 @@ WorkflowSlotInfo, ) from ._worker import ( + Plugin, PollerBehavior, PollerBehaviorAutoscaling, PollerBehaviorSimpleMaximum, @@ -78,6 +79,7 @@ "ActivityOutboundInterceptor", "WorkflowInboundInterceptor", "WorkflowOutboundInterceptor", + "Plugin", # Interceptor input "ContinueAsNewInput", "ExecuteActivityInput", diff --git a/temporalio/worker/_worker.py b/temporalio/worker/_worker.py index 4d77e111e..3cf1d62bc 100644 --- a/temporalio/worker/_worker.py +++ b/temporalio/worker/_worker.py @@ -96,6 +96,26 @@ def _to_bridge(self) -> temporalio.bridge.worker.PollerBehavior: ] +class Plugin: + def init_worker_plugin(self, next: Plugin) -> Plugin: + self.next_worker_plugin = next + return self + + def on_create_worker(self, config: WorkerConfig) -> WorkerConfig: + return self.next_worker_plugin.on_create_worker(config) + + async def run_worker(self, worker: Worker) -> None: + await self.next_worker_plugin.run_worker(worker) + + +class _RootPlugin(Plugin): + def on_create_worker(self, config: WorkerConfig) -> WorkerConfig: + return config + + async def run_worker(self, worker: Worker) -> None: + await worker._run() + + class Worker: """Worker to process workflows and/or activities. @@ -153,6 +173,7 @@ def __init__( nexus_task_poller_behavior: PollerBehavior = PollerBehaviorSimpleMaximum( maximum=5 ), + plugins: Sequence[Plugin] = [], ) -> None: """Create a worker to process workflows and/or activities. @@ -343,11 +364,17 @@ def __init__( ) interceptors = interceptors_from_client + list(interceptors) + plugins_from_client = cast( + List[Plugin], [p for p in client_config["plugins"] if isinstance(p, Plugin)] + ) + plugins = plugins_from_client + list(plugins) + print(plugins) + # Extract the bridge service client bridge_client = _extract_bridge_client_for_worker(client) # Store the config for tracking - self._config = WorkerConfig( + config = WorkerConfig( client=client, task_queue=task_queue, activities=activities, @@ -382,6 +409,13 @@ def __init__( use_worker_versioning=use_worker_versioning, disable_safe_workflow_eviction=disable_safe_workflow_eviction, ) + + root_plugin: Plugin = _RootPlugin() + for plugin in reversed(list(plugins)): + root_plugin = plugin.init_worker_plugin(root_plugin) + self._config = root_plugin.on_create_worker(config) + self._plugin = root_plugin + self._started = False self._shutdown_event = asyncio.Event() self._shutdown_complete_event = asyncio.Event() @@ -646,6 +680,9 @@ async def run(self) -> None: also cancel the shutdown process. Therefore users are encouraged to use explicit shutdown instead. """ + await self._plugin.run_worker(self) + + async def _run(self): # Eagerly validate which will do a namespace check in Core await self._bridge_worker.validate() diff --git a/tests/test_client.py b/tests/test_client.py index 418d9ff53..0e718e6be 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -41,9 +41,11 @@ BuildIdOpPromoteSetByBuildId, CancelWorkflowInput, Client, + ClientConfig, CloudOperationsClient, Interceptor, OutboundInterceptor, + Plugin, QueryWorkflowInput, RPCError, RPCStatusCode, @@ -1499,3 +1501,30 @@ async def test_cloud_client_simple(): GetNamespaceRequest(namespace=os.environ["TEMPORAL_CLIENT_CLOUD_NAMESPACE"]) ) assert os.environ["TEMPORAL_CLIENT_CLOUD_NAMESPACE"] == result.namespace.namespace + + +class MyPlugin(Plugin): + def on_create_client(self, config: ClientConfig) -> ClientConfig: + config["namespace"] = "replaced_namespace" + return super().on_create_client(config) + + async def connect_service_client( + self, config: temporalio.service.ConnectConfig + ) -> temporalio.service.ServiceClient: + config.api_key = "replaced key" + return await super().connect_service_client(config) + + +async def test_client_plugin(client: Client, env: WorkflowEnvironment): + if env.supports_time_skipping: + pytest.skip("Client connect is only designed for local") + + config = client.config() + config["plugins"] = [MyPlugin()] + new_client = Client(**config) + assert new_client.namespace == "replaced_namespace" + + new_client = await Client.connect( + client.service_client.config.target_host, plugins=[MyPlugin()] + ) + assert new_client.service_client.config.api_key == "replaced key" diff --git a/tests/worker/test_worker.py b/tests/worker/test_worker.py index f1be74b4d..7a72729c7 100644 --- a/tests/worker/test_worker.py +++ b/tests/worker/test_worker.py @@ -2,13 +2,13 @@ import asyncio import concurrent.futures -import sys import uuid from datetime import timedelta from typing import Any, Awaitable, Callable, Optional, Sequence from urllib.request import urlopen import temporalio.api.enums.v1 +import temporalio.client import temporalio.worker._worker from temporalio import activity, workflow from temporalio.api.workflowservice.v1 import ( @@ -19,7 +19,11 @@ SetWorkerDeploymentRampingVersionRequest, SetWorkerDeploymentRampingVersionResponse, ) -from temporalio.client import BuildIdOpAddNewDefault, Client, TaskReachabilityType +from temporalio.client import ( + BuildIdOpAddNewDefault, + Client, + TaskReachabilityType, +) from temporalio.common import PinnedVersioningOverride, RawValue, VersioningBehavior from temporalio.runtime import PrometheusConfig, Runtime, TelemetryConfig from temporalio.service import RPCError @@ -38,6 +42,7 @@ SlotReleaseContext, SlotReserveContext, Worker, + WorkerConfig, WorkerDeploymentConfig, WorkerDeploymentVersion, WorkerTuner, @@ -1184,3 +1189,46 @@ def shutdown(self) -> None: if self.next_exception_task: self.next_exception_task.cancel() setattr(self.worker._bridge_worker, self.attr, self.orig_poll_call) + + +class MyCombinedPlugin(temporalio.client.Plugin, temporalio.worker.Plugin): + def on_create_worker(self, config: WorkerConfig) -> WorkerConfig: + print("Create worker combined plugin") + config["task_queue"] = "combined" + return super().on_create_worker(config) + + +class MyWorkerPlugin(temporalio.worker.Plugin): + def on_create_worker(self, config: WorkerConfig) -> WorkerConfig: + print("Create worker worker plugin") + config["task_queue"] = "replaced_queue" + return super().on_create_worker(config) + + async def run_worker(self, worker: Worker) -> None: + await super().run_worker(worker) + + +async def test_worker_plugin(client: Client) -> None: + worker = Worker( + client, + task_queue="queue", + activities=[never_run_activity], + plugins=[MyWorkerPlugin()], + ) + assert worker.config().get("task_queue") == "replaced_queue" + + # Test client plugin propagation to worker plugins + new_config = client.config() + new_config["plugins"] = [MyCombinedPlugin()] + client = Client(**new_config) + worker = Worker(client, task_queue="queue", activities=[never_run_activity]) + assert worker.config().get("task_queue") == "combined" + + # Test both. Client propagated plugins are called first, so the worker plugin overrides in this case + worker = Worker( + client, + task_queue="queue", + activities=[never_run_activity], + plugins=[MyWorkerPlugin()], + ) + assert worker.config().get("task_queue") == "replaced_queue" From 66b031c3d29f4bf60475f1ca119b72a8729b5a38 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Thu, 10 Jul 2025 08:53:21 -0700 Subject: [PATCH 02/11] Ensure plugin modification happen before any other initialization --- temporalio/client.py | 14 +-- temporalio/worker/_worker.py | 200 +++++++++++++++++++---------------- 2 files changed, 114 insertions(+), 100 deletions(-) diff --git a/temporalio/client.py b/temporalio/client.py index 2042c9b4d..145537af7 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -107,6 +107,7 @@ async def connect( namespace: str = "default", api_key: Optional[str] = None, data_converter: temporalio.converter.DataConverter = temporalio.converter.DataConverter.default, + plugins: Sequence[Plugin] = [], interceptors: Sequence[Interceptor] = [], default_workflow_query_reject_condition: Optional[ temporalio.common.QueryRejectCondition @@ -120,7 +121,6 @@ async def connect( runtime: Optional[temporalio.runtime.Runtime] = None, http_connect_proxy_config: Optional[HttpConnectProxyConfig] = None, header_codec_behavior: HeaderCodecBehavior = HeaderCodecBehavior.NO_CODEC, - plugins: Sequence[Plugin] = [], ) -> Client: """Connect to a Temporal server. @@ -202,22 +202,17 @@ def __init__( *, namespace: str = "default", data_converter: temporalio.converter.DataConverter = temporalio.converter.DataConverter.default, + plugins: Sequence[Plugin] = [], interceptors: Sequence[Interceptor] = [], default_workflow_query_reject_condition: Optional[ temporalio.common.QueryRejectCondition ] = None, header_codec_behavior: HeaderCodecBehavior = HeaderCodecBehavior.NO_CODEC, - plugins: Sequence[Plugin] = [], ): """Create a Temporal client from a service client. See :py:meth:`connect` for details on the parameters. """ - # Iterate over interceptors in reverse building the impl - self._impl: OutboundInterceptor = _ClientImpl(self) - for interceptor in reversed(list(interceptors)): - self._impl = interceptor.intercept_client(self._impl) - # Store the config for tracking config = ClientConfig( service_client=service_client, @@ -235,6 +230,11 @@ def __init__( self._config = root_plugin.on_create_client(config) + # Iterate over interceptors in reverse building the impl + self._impl: OutboundInterceptor = _ClientImpl(self) + for interceptor in reversed(list(interceptors)): + self._impl = interceptor.intercept_client(self._impl) + def config(self) -> ClientConfig: """Config, as a dictionary, used to create this client. diff --git a/temporalio/worker/_worker.py b/temporalio/worker/_worker.py index 3cf1d62bc..bc9d07353 100644 --- a/temporalio/worker/_worker.py +++ b/temporalio/worker/_worker.py @@ -138,6 +138,7 @@ def __init__( nexus_task_executor: Optional[concurrent.futures.Executor] = None, workflow_runner: WorkflowRunner = SandboxedWorkflowRunner(), unsandboxed_workflow_runner: WorkflowRunner = UnsandboxedWorkflowRunner(), + plugins: Sequence[Plugin] = [], interceptors: Sequence[Interceptor] = [], build_id: Optional[str] = None, identity: Optional[str] = None, @@ -173,7 +174,6 @@ def __init__( nexus_task_poller_behavior: PollerBehavior = PollerBehaviorSimpleMaximum( maximum=5 ), - plugins: Sequence[Plugin] = [], ) -> None: """Create a worker to process workflows and/or activities. @@ -342,42 +342,11 @@ def __init__( nexus_task_poller_behavior: Specify the behavior of Nexus task polling. Defaults to a 5-poller maximum. """ - # TODO(nexus-preview): max_concurrent_nexus_tasks / tuner support - if not (activities or nexus_service_handlers or workflows): - raise ValueError( - "At least one activity, Nexus service, or workflow must be specified" - ) - if use_worker_versioning and not build_id: - raise ValueError( - "build_id must be specified when use_worker_versioning is True" - ) - if deployment_config and (build_id or use_worker_versioning): - raise ValueError( - "deployment_config cannot be used with build_id or use_worker_versioning" - ) - - # Prepend applicable client interceptors to the given ones - client_config = client.config() - interceptors_from_client = cast( - List[Interceptor], - [i for i in client_config["interceptors"] if isinstance(i, Interceptor)], - ) - interceptors = interceptors_from_client + list(interceptors) - - plugins_from_client = cast( - List[Plugin], [p for p in client_config["plugins"] if isinstance(p, Plugin)] - ) - plugins = plugins_from_client + list(plugins) - print(plugins) - - # Extract the bridge service client - bridge_client = _extract_bridge_client_for_worker(client) - - # Store the config for tracking config = WorkerConfig( client=client, task_queue=task_queue, activities=activities, + nexus_service_handlers=nexus_service_handlers, workflows=workflows, activity_executor=activity_executor, workflow_task_executor=workflow_task_executor, @@ -391,6 +360,7 @@ def __init__( max_concurrent_workflow_tasks=max_concurrent_workflow_tasks, max_concurrent_activities=max_concurrent_activities, max_concurrent_local_activities=max_concurrent_local_activities, + tuner=tuner, max_concurrent_workflow_task_polls=max_concurrent_workflow_task_polls, nonsticky_to_sticky_poll_ratio=nonsticky_to_sticky_poll_ratio, max_concurrent_activity_task_polls=max_concurrent_activity_task_polls, @@ -408,14 +378,54 @@ def __init__( on_fatal_error=on_fatal_error, use_worker_versioning=use_worker_versioning, disable_safe_workflow_eviction=disable_safe_workflow_eviction, + deployment_config=deployment_config, + workflow_task_poller_behavior=workflow_task_poller_behavior, + activity_task_poller_behavior=activity_task_poller_behavior, + nexus_task_poller_behavior=nexus_task_poller_behavior, ) + plugins_from_client = cast( + List[Plugin], [p for p in client.config()["plugins"] if isinstance(p, Plugin)] + ) + plugins = plugins_from_client + list(plugins) + root_plugin: Plugin = _RootPlugin() for plugin in reversed(list(plugins)): root_plugin = plugin.init_worker_plugin(root_plugin) - self._config = root_plugin.on_create_worker(config) + config = root_plugin.on_create_worker(config) self._plugin = root_plugin + self._init_from_config(config) + + def _init_from_config(self, config: WorkerConfig): + """Handles post plugin initialization to ensure original arguments are not used""" + self._config = config + + # TODO(nexus-preview): max_concurrent_nexus_tasks / tuner support + if not (config["activities"] or config["nexus_service_handlers"] or config["workflows"]): + raise ValueError( + "At least one activity, Nexus service, or workflow must be specified" + ) + if config["use_worker_versioning"] and not config["build_id"]: + raise ValueError( + "build_id must be specified when use_worker_versioning is True" + ) + if config["deployment_config"] and (config["build_id"] or config["use_worker_versioning"]): + raise ValueError( + "deployment_config cannot be used with build_id or use_worker_versioning" + ) + + # Prepend applicable client interceptors to the given ones + client_config = config["client"].config() + interceptors_from_client = cast( + List[Interceptor], + [i for i in client_config["interceptors"] if isinstance(i, Interceptor)], + ) + interceptors = interceptors_from_client + list(config["interceptors"]) + + # Extract the bridge service client + bridge_client = _extract_bridge_client_for_worker(config["client"]) + self._started = False self._shutdown_event = asyncio.Event() self._shutdown_complete_event = asyncio.Event() @@ -427,14 +437,14 @@ def __init__( self._runtime = ( bridge_client.config.runtime or temporalio.runtime.Runtime.default() ) - if activities: + if config["activities"]: # Issue warning here if executor max_workers is lower than max # concurrent activities. We do this here instead of in # _ActivityWorker so the stack level is predictable. - max_workers = getattr(activity_executor, "_max_workers", None) - concurrent_activities = max_concurrent_activities - if tuner and tuner._get_activities_max(): - concurrent_activities = tuner._get_activities_max() + max_workers = getattr(config["activity_executor"], "_max_workers", None) + concurrent_activities = config["max_concurrent_activities"] + if config["tuner"] and config["tuner"]._get_activities_max(): + concurrent_activities = config["tuner"]._get_activities_max() if isinstance(max_workers, int) and max_workers < ( concurrent_activities or 0 ): @@ -446,10 +456,10 @@ def __init__( self._activity_worker = _ActivityWorker( bridge_worker=lambda: self._bridge_worker, - task_queue=task_queue, - activities=activities, - activity_executor=activity_executor, - shared_state_manager=shared_state_manager, + task_queue=config["task_queue"], + activities=config["activities"], + activity_executor=config["activity_executor"], + shared_state_manager=config["shared_state_manager"], data_converter=client_config["data_converter"], interceptors=interceptors, metric_meter=self._runtime.metric_meter, @@ -457,23 +467,23 @@ def __init__( == HeaderCodecBehavior.CODEC, ) self._nexus_worker: Optional[_NexusWorker] = None - if nexus_service_handlers: + if config["nexus_service_handlers"]: self._nexus_worker = _NexusWorker( bridge_worker=lambda: self._bridge_worker, - client=client, - task_queue=task_queue, - service_handlers=nexus_service_handlers, + client=config["client"], + task_queue=config["task_queue"], + service_handlers=config["nexus_service_handlers"], data_converter=client_config["data_converter"], interceptors=interceptors, metric_meter=self._runtime.metric_meter, - executor=nexus_task_executor, + executor=config["nexus_task_executor"], ) self._workflow_worker: Optional[_WorkflowWorker] = None - if workflows: + if config["workflows"]: should_enforce_versioning_behavior = ( - deployment_config is not None - and deployment_config.use_worker_versioning - and deployment_config.default_versioning_behavior + config["deployment_config"] is not None + and config["deployment_config"].use_worker_versioning + and config["deployment_config"].default_versioning_behavior == temporalio.common.VersioningBehavior.UNSPECIFIED ) @@ -487,32 +497,33 @@ def check_activity(activity): self._workflow_worker = _WorkflowWorker( bridge_worker=lambda: self._bridge_worker, - namespace=client.namespace, - task_queue=task_queue, - workflows=workflows, - workflow_task_executor=workflow_task_executor, - max_concurrent_workflow_tasks=max_concurrent_workflow_tasks, - workflow_runner=workflow_runner, - unsandboxed_workflow_runner=unsandboxed_workflow_runner, + namespace=config["client"].namespace, + task_queue=config["task_queue"], + workflows=config["workflows"], + workflow_task_executor=config["workflow_task_executor"], + max_concurrent_workflow_tasks=config["max_concurrent_workflow_tasks"], + workflow_runner=config["workflow_runner"], + unsandboxed_workflow_runner=config["unsandboxed_workflow_runner"], data_converter=client_config["data_converter"], interceptors=interceptors, - workflow_failure_exception_types=workflow_failure_exception_types, - debug_mode=debug_mode, - disable_eager_activity_execution=disable_eager_activity_execution, + workflow_failure_exception_types=config["workflow_failure_exception_types"], + debug_mode=config["debug_mode"], + disable_eager_activity_execution=config["disable_eager_activity_execution"], metric_meter=self._runtime.metric_meter, on_eviction_hook=None, - disable_safe_eviction=disable_safe_workflow_eviction, + disable_safe_eviction=config["disable_safe_workflow_eviction"], should_enforce_versioning_behavior=should_enforce_versioning_behavior, assert_local_activity_valid=check_activity, encode_headers=client_config["header_codec_behavior"] != HeaderCodecBehavior.NO_CODEC, ) - if tuner is not None: + tuner = config["tuner"] + if config["tuner"] is not None: if ( - max_concurrent_workflow_tasks - or max_concurrent_activities - or max_concurrent_local_activities + config["max_concurrent_workflow_tasks"] + or config["max_concurrent_activities"] + or config["max_concurrent_local_activities"] ): raise ValueError( "Cannot specify max_concurrent_workflow_tasks, max_concurrent_activities, " @@ -520,38 +531,40 @@ def check_activity(activity): ) else: tuner = WorkerTuner.create_fixed( - workflow_slots=max_concurrent_workflow_tasks, - activity_slots=max_concurrent_activities, - local_activity_slots=max_concurrent_local_activities, + workflow_slots=config["max_concurrent_workflow_tasks"], + activity_slots=config["max_concurrent_activities"], + local_activity_slots=config["max_concurrent_local_activities"], ) bridge_tuner = tuner._to_bridge_tuner() versioning_strategy: temporalio.bridge.worker.WorkerVersioningStrategy - if deployment_config: + if config["deployment_config"]: versioning_strategy = ( - deployment_config._to_bridge_worker_deployment_options() + config["deployment_config"]._to_bridge_worker_deployment_options() ) - elif use_worker_versioning: - build_id = build_id or load_default_build_id() + elif config["use_worker_versioning"]: + build_id = config["build_id"] or load_default_build_id() versioning_strategy = ( temporalio.bridge.worker.WorkerVersioningStrategyLegacyBuildIdBased( build_id_with_versioning=build_id ) ) else: - build_id = build_id or load_default_build_id() + build_id = config["build_id"] or load_default_build_id() versioning_strategy = temporalio.bridge.worker.WorkerVersioningStrategyNone( build_id_no_versioning=build_id ) - if max_concurrent_workflow_task_polls: + workflow_task_poller_behavior = config["workflow_task_poller_behavior"] + if config["max_concurrent_workflow_task_polls"]: workflow_task_poller_behavior = PollerBehaviorSimpleMaximum( - maximum=max_concurrent_workflow_task_polls + maximum=config["max_concurrent_workflow_task_polls"] ) - if max_concurrent_activity_task_polls: + activity_task_poller_behavior = config["activity_task_poller_behavior"] + if config["max_concurrent_activity_task_polls"]: activity_task_poller_behavior = PollerBehaviorSimpleMaximum( - maximum=max_concurrent_activity_task_polls + maximum=config["max_concurrent_activity_task_polls"] ) # Create bridge worker last. We have empirically observed that if it is @@ -564,29 +577,29 @@ def check_activity(activity): self._bridge_worker = temporalio.bridge.worker.Worker.create( bridge_client._bridge_client, temporalio.bridge.worker.WorkerConfig( - namespace=client.namespace, - task_queue=task_queue, - identity_override=identity, - max_cached_workflows=max_cached_workflows, + namespace=config["client"].namespace, + task_queue=config["task_queue"], + identity_override=config["identity"], + max_cached_workflows=config["max_cached_workflows"], tuner=bridge_tuner, - nonsticky_to_sticky_poll_ratio=nonsticky_to_sticky_poll_ratio, + nonsticky_to_sticky_poll_ratio=config["nonsticky_to_sticky_poll_ratio"], # We have to disable remote activities if a user asks _or_ if we # are not running an activity worker at all. Otherwise shutdown # will not proceed properly. - no_remote_activities=no_remote_activities or not activities, + no_remote_activities=config["no_remote_activities"] or not config["activities"], sticky_queue_schedule_to_start_timeout_millis=int( - 1000 * sticky_queue_schedule_to_start_timeout.total_seconds() + 1000 * config["sticky_queue_schedule_to_start_timeout"].total_seconds() ), max_heartbeat_throttle_interval_millis=int( - 1000 * max_heartbeat_throttle_interval.total_seconds() + 1000 * config["max_heartbeat_throttle_interval"].total_seconds() ), default_heartbeat_throttle_interval_millis=int( - 1000 * default_heartbeat_throttle_interval.total_seconds() + 1000 * config["default_heartbeat_throttle_interval"].total_seconds() ), - max_activities_per_second=max_activities_per_second, - max_task_queue_activities_per_second=max_task_queue_activities_per_second, + max_activities_per_second=config["max_activities_per_second"], + max_task_queue_activities_per_second=config["max_task_queue_activities_per_second"], graceful_shutdown_period_millis=int( - 1000 * graceful_shutdown_timeout.total_seconds() + 1000 * config["graceful_shutdown_timeout"].total_seconds() ), # Need to tell core whether we want to consider all # non-determinism exceptions as workflow fail, and whether we do @@ -601,7 +614,7 @@ def check_activity(activity): versioning_strategy=versioning_strategy, workflow_task_poller_behavior=workflow_task_poller_behavior._to_bridge(), activity_task_poller_behavior=activity_task_poller_behavior._to_bridge(), - nexus_task_poller_behavior=nexus_task_poller_behavior._to_bridge(), + nexus_task_poller_behavior=config["nexus_task_poller_behavior"]._to_bridge(), ), ) @@ -852,6 +865,7 @@ class WorkerConfig(TypedDict, total=False): client: temporalio.client.Client task_queue: str activities: Sequence[Callable] + nexus_service_handlers: Sequence[Any] workflows: Sequence[Type] activity_executor: Optional[concurrent.futures.Executor] workflow_task_executor: Optional[concurrent.futures.ThreadPoolExecutor] @@ -886,7 +900,7 @@ class WorkerConfig(TypedDict, total=False): deployment_config: Optional[WorkerDeploymentConfig] workflow_task_poller_behavior: PollerBehavior activity_task_poller_behavior: PollerBehavior - + nexus_task_poller_behavior: PollerBehavior @dataclass class WorkerDeploymentConfig: From ab89a650d293febf746d8c6e52e6d1c291ad1380 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Thu, 10 Jul 2025 09:36:01 -0700 Subject: [PATCH 03/11] Openai Agents plugin PoC --- temporalio/contrib/openai_agents/__init__.py | 2 + .../openai_agents/temporal_openai_agents.py | 47 +- temporalio/worker/_worker.py | 44 +- tests/contrib/openai_agents/test_openai.py | 734 +++++++++--------- 4 files changed, 432 insertions(+), 395 deletions(-) diff --git a/temporalio/contrib/openai_agents/__init__.py b/temporalio/contrib/openai_agents/__init__.py index 027ad44ad..2c20effc7 100644 --- a/temporalio/contrib/openai_agents/__init__.py +++ b/temporalio/contrib/openai_agents/__init__.py @@ -14,6 +14,7 @@ OpenAIAgentsTracingInterceptor, ) from temporalio.contrib.openai_agents.temporal_openai_agents import ( + Plugin, TestModel, TestModelProvider, set_open_ai_agent_temporal_overrides, @@ -21,6 +22,7 @@ ) __all__ = [ + "Plugin", "ModelActivity", "ModelActivityParameters", "workflow", diff --git a/temporalio/contrib/openai_agents/temporal_openai_agents.py b/temporalio/contrib/openai_agents/temporal_openai_agents.py index b9bf57499..9b574d708 100644 --- a/temporalio/contrib/openai_agents/temporal_openai_agents.py +++ b/temporalio/contrib/openai_agents/temporal_openai_agents.py @@ -3,10 +3,9 @@ import json from contextlib import contextmanager from datetime import timedelta -from typing import Any, AsyncIterator, Callable, Optional, Union, overload +from typing import Any, AsyncIterator, Callable, Optional, Union from agents import ( - Agent, AgentOutputSchemaBase, Handoff, Model, @@ -19,31 +18,34 @@ TResponseInputItem, set_trace_provider, ) -from agents.function_schema import DocstringStyle, function_schema +from agents.function_schema import function_schema from agents.items import TResponseStreamEvent from agents.run import get_default_agent_runner, set_default_agent_runner from agents.tool import ( FunctionTool, - ToolErrorFunction, - ToolFunction, - ToolParams, - default_tool_error_function, - function_tool, ) from agents.tracing import get_trace_provider from agents.tracing.provider import DefaultTraceProvider -from agents.util._types import MaybeAwaitable from openai.types.responses import ResponsePromptParam +import temporalio.client +import temporalio.worker from temporalio import activity from temporalio import workflow as temporal_workflow +from temporalio.client import ClientConfig from temporalio.common import Priority, RetryPolicy +from temporalio.contrib.openai_agents import ( + ModelActivity, + OpenAIAgentsTracingInterceptor, +) from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters from temporalio.contrib.openai_agents._openai_runner import TemporalOpenAIRunner from temporalio.contrib.openai_agents._temporal_trace_provider import ( TemporalTraceProvider, ) +from temporalio.contrib.pydantic import pydantic_data_converter from temporalio.exceptions import ApplicationError, TemporalError +from temporalio.worker import Worker, WorkerConfig from temporalio.workflow import ActivityCancellationType, VersioningIntent @@ -154,6 +156,33 @@ def stream_response( raise NotImplementedError() +class Plugin(temporalio.client.Plugin, temporalio.worker.Plugin): + def __init__( + self, + model_params: Optional[ModelActivityParameters] = None, + model_provider: Optional[ModelProvider] = None, + ) -> None: + self._model_params = model_params + self._model_provider = model_provider + + def on_create_client(self, config: ClientConfig) -> ClientConfig: + config["data_converter"] = pydantic_data_converter + return super().on_create_client(config) + + def on_create_worker(self, config: WorkerConfig) -> WorkerConfig: + config["interceptors"] = list(config.get("interceptors") or []) + [ + OpenAIAgentsTracingInterceptor() + ] + config["activities"] = list(config.get("activities") or []) + [ + ModelActivity(self._model_provider).invoke_model_activity + ] + return super().on_create_worker(config) + + async def run_worker(self, worker: Worker) -> None: + with set_open_ai_agent_temporal_overrides(self._model_params): + await super().run_worker(worker) + + class ToolSerializationError(TemporalError): """Error that occurs when a tool output could not be serialized.""" diff --git a/temporalio/worker/_worker.py b/temporalio/worker/_worker.py index bc9d07353..fca8a35a4 100644 --- a/temporalio/worker/_worker.py +++ b/temporalio/worker/_worker.py @@ -385,7 +385,8 @@ def __init__( ) plugins_from_client = cast( - List[Plugin], [p for p in client.config()["plugins"] if isinstance(p, Plugin)] + List[Plugin], + [p for p in client.config()["plugins"] if isinstance(p, Plugin)], ) plugins = plugins_from_client + list(plugins) @@ -402,7 +403,11 @@ def _init_from_config(self, config: WorkerConfig): self._config = config # TODO(nexus-preview): max_concurrent_nexus_tasks / tuner support - if not (config["activities"] or config["nexus_service_handlers"] or config["workflows"]): + if not ( + config["activities"] + or config["nexus_service_handlers"] + or config["workflows"] + ): raise ValueError( "At least one activity, Nexus service, or workflow must be specified" ) @@ -410,7 +415,9 @@ def _init_from_config(self, config: WorkerConfig): raise ValueError( "build_id must be specified when use_worker_versioning is True" ) - if config["deployment_config"] and (config["build_id"] or config["use_worker_versioning"]): + if config["deployment_config"] and ( + config["build_id"] or config["use_worker_versioning"] + ): raise ValueError( "deployment_config cannot be used with build_id or use_worker_versioning" ) @@ -506,9 +513,13 @@ def check_activity(activity): unsandboxed_workflow_runner=config["unsandboxed_workflow_runner"], data_converter=client_config["data_converter"], interceptors=interceptors, - workflow_failure_exception_types=config["workflow_failure_exception_types"], + workflow_failure_exception_types=config[ + "workflow_failure_exception_types" + ], debug_mode=config["debug_mode"], - disable_eager_activity_execution=config["disable_eager_activity_execution"], + disable_eager_activity_execution=config[ + "disable_eager_activity_execution" + ], metric_meter=self._runtime.metric_meter, on_eviction_hook=None, disable_safe_eviction=config["disable_safe_workflow_eviction"], @@ -519,7 +530,7 @@ def check_activity(activity): ) tuner = config["tuner"] - if config["tuner"] is not None: + if tuner is not None: if ( config["max_concurrent_workflow_tasks"] or config["max_concurrent_activities"] @@ -540,9 +551,9 @@ def check_activity(activity): versioning_strategy: temporalio.bridge.worker.WorkerVersioningStrategy if config["deployment_config"]: - versioning_strategy = ( - config["deployment_config"]._to_bridge_worker_deployment_options() - ) + versioning_strategy = config[ + "deployment_config" + ]._to_bridge_worker_deployment_options() elif config["use_worker_versioning"]: build_id = config["build_id"] or load_default_build_id() versioning_strategy = ( @@ -586,9 +597,11 @@ def check_activity(activity): # We have to disable remote activities if a user asks _or_ if we # are not running an activity worker at all. Otherwise shutdown # will not proceed properly. - no_remote_activities=config["no_remote_activities"] or not config["activities"], + no_remote_activities=config["no_remote_activities"] + or not config["activities"], sticky_queue_schedule_to_start_timeout_millis=int( - 1000 * config["sticky_queue_schedule_to_start_timeout"].total_seconds() + 1000 + * config["sticky_queue_schedule_to_start_timeout"].total_seconds() ), max_heartbeat_throttle_interval_millis=int( 1000 * config["max_heartbeat_throttle_interval"].total_seconds() @@ -597,7 +610,9 @@ def check_activity(activity): 1000 * config["default_heartbeat_throttle_interval"].total_seconds() ), max_activities_per_second=config["max_activities_per_second"], - max_task_queue_activities_per_second=config["max_task_queue_activities_per_second"], + max_task_queue_activities_per_second=config[ + "max_task_queue_activities_per_second" + ], graceful_shutdown_period_millis=int( 1000 * config["graceful_shutdown_timeout"].total_seconds() ), @@ -614,7 +629,9 @@ def check_activity(activity): versioning_strategy=versioning_strategy, workflow_task_poller_behavior=workflow_task_poller_behavior._to_bridge(), activity_task_poller_behavior=activity_task_poller_behavior._to_bridge(), - nexus_task_poller_behavior=config["nexus_task_poller_behavior"]._to_bridge(), + nexus_task_poller_behavior=config[ + "nexus_task_poller_behavior" + ]._to_bridge(), ), ) @@ -902,6 +919,7 @@ class WorkerConfig(TypedDict, total=False): activity_task_poller_behavior: PollerBehavior nexus_task_poller_behavior: PollerBehavior + @dataclass class WorkerDeploymentConfig: """Options for configuring the Worker Versioning feature. diff --git a/tests/contrib/openai_agents/test_openai.py b/tests/contrib/openai_agents/test_openai.py index cfc74eb6b..d2224761c 100644 --- a/tests/contrib/openai_agents/test_openai.py +++ b/tests/contrib/openai_agents/test_openai.py @@ -124,26 +124,28 @@ async def test_hello_world_agent(client: Client, use_local_model: bool): if not use_local_model and not os.environ.get("OPENAI_API_KEY"): pytest.skip("No openai API key") new_config = client.config() - new_config["data_converter"] = pydantic_data_converter + new_config["plugins"] = [ + openai_agents.Plugin( + model_params=ModelActivityParameters( + start_to_close_timeout=timedelta(seconds=30) + ), + model_provider=TestModelProvider(TestHelloModel()) + if use_local_model + else None, + ) + ] client = Client(**new_config) - model_params = ModelActivityParameters(start_to_close_timeout=timedelta(seconds=30)) - with set_open_ai_agent_temporal_overrides(model_params): - model_activity = ModelActivity( - TestModelProvider(TestHelloModel()) if use_local_model else None + async with new_worker(client, HelloWorldAgent) as worker: + result = await client.execute_workflow( + HelloWorldAgent.run, + "Tell me about recursion in programming.", + id=f"hello-workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=5), ) - async with new_worker( - client, HelloWorldAgent, activities=[model_activity.invoke_model_activity] - ) as worker: - result = await client.execute_workflow( - HelloWorldAgent.run, - "Tell me about recursion in programming.", - id=f"hello-workflow-{uuid.uuid4()}", - task_queue=worker.task_queue, - execution_timeout=timedelta(seconds=5), - ) - if use_local_model: - assert result == "test" + if use_local_model: + assert result == "test" @dataclass @@ -305,103 +307,100 @@ async def test_tool_workflow(client: Client, use_local_model: bool): if not use_local_model and not os.environ.get("OPENAI_API_KEY"): pytest.skip("No openai API key") new_config = client.config() - new_config["data_converter"] = pydantic_data_converter + new_config["plugins"] = [ + openai_agents.Plugin( + model_params=ModelActivityParameters( + start_to_close_timeout=timedelta(seconds=30) + ), + model_provider=TestModelProvider(TestWeatherModel()) + if use_local_model + else None, + ) + ] client = Client(**new_config) - model_params = ModelActivityParameters(start_to_close_timeout=timedelta(seconds=30)) - with set_open_ai_agent_temporal_overrides(model_params): - model_activity = ModelActivity( - TestModelProvider( - TestWeatherModel( # type: ignore - ) - ) - if use_local_model - else None + async with new_worker( + client, + ToolsWorkflow, + activities=[ + get_weather, + get_weather_object, + get_weather_country, + get_weather_context, + ], + ) as worker: + workflow_handle = await client.start_workflow( + ToolsWorkflow.run, + "What is the weather in Tokio?", + id=f"tools-workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=30), ) - async with new_worker( - client, - ToolsWorkflow, - activities=[ - model_activity.invoke_model_activity, - get_weather, - get_weather_object, - get_weather_country, - get_weather_context, - ], - interceptors=[OpenAIAgentsTracingInterceptor()], - ) as worker: - workflow_handle = await client.start_workflow( - ToolsWorkflow.run, - "What is the weather in Tokio?", - id=f"tools-workflow-{uuid.uuid4()}", - task_queue=worker.task_queue, - execution_timeout=timedelta(seconds=30), + result = await workflow_handle.result() + + if use_local_model: + assert result == "Test weather result" + + events = [] + async for e in workflow_handle.fetch_history_events(): + if e.HasField("activity_task_completed_event_attributes"): + events.append(e) + + assert len(events) == 9 + assert ( + "function_call" + in events[0] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "Sunny with wind" + in events[1] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "function_call" + in events[2] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "Sunny with wind" + in events[3] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "function_call" + in events[4] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "Sunny with wind" + in events[5] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "function_call" + in events[6] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "Stormy" + in events[7] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "Test weather result" + in events[8] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() ) - result = await workflow_handle.result() - - if use_local_model: - assert result == "Test weather result" - - events = [] - async for e in workflow_handle.fetch_history_events(): - if e.HasField("activity_task_completed_event_attributes"): - events.append(e) - - assert len(events) == 9 - assert ( - "function_call" - in events[0] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "Sunny with wind" - in events[1] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "function_call" - in events[2] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "Sunny with wind" - in events[3] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "function_call" - in events[4] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "Sunny with wind" - in events[5] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "function_call" - in events[6] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "Stormy" - in events[7] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "Test weather result" - in events[8] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) @no_type_check @@ -491,63 +490,60 @@ async def test_research_workflow(client: Client, use_local_model: bool): if not use_local_model and not os.environ.get("OPENAI_API_KEY"): pytest.skip("No openai API key") new_config = client.config() - new_config["data_converter"] = pydantic_data_converter + new_config["plugins"] = [ + openai_agents.Plugin( + model_params=ModelActivityParameters( + start_to_close_timeout=timedelta(seconds=30) + ), + model_provider=TestModelProvider(TestResearchModel()) + if use_local_model + else None, + ) + ] client = Client(**new_config) - global response_index - response_index = 0 - - model_params = ModelActivityParameters( - start_to_close_timeout=timedelta(seconds=120) - ) - with set_open_ai_agent_temporal_overrides(model_params): - model_activity = ModelActivity( - TestModelProvider(TestResearchModel()) if use_local_model else None + async with new_worker( + client, + ResearchWorkflow, + ) as worker: + workflow_handle = await client.start_workflow( + ResearchWorkflow.run, + "Caribbean vacation spots in April, optimizing for surfing, hiking and water sports", + id=f"research-workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=120), ) - async with new_worker( - client, - ResearchWorkflow, - activities=[model_activity.invoke_model_activity, get_weather], - interceptors=[OpenAIAgentsTracingInterceptor()], - ) as worker: - workflow_handle = await client.start_workflow( - ResearchWorkflow.run, - "Caribbean vacation spots in April, optimizing for surfing, hiking and water sports", - id=f"research-workflow-{uuid.uuid4()}", - task_queue=worker.task_queue, - execution_timeout=timedelta(seconds=120), + result = await workflow_handle.result() + + if use_local_model: + assert result == "report" + + events = [] + async for e in workflow_handle.fetch_history_events(): + if e.HasField("activity_task_completed_event_attributes"): + events.append(e) + + assert len(events) == 12 + assert ( + '"type":"output_text"' + in events[0] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() ) - result = await workflow_handle.result() - - if use_local_model: - assert result == "report" - - events = [] - async for e in workflow_handle.fetch_history_events(): - if e.HasField("activity_task_completed_event_attributes"): - events.append(e) - - assert len(events) == 12 + for i in range(1, 11): assert ( - '"type":"output_text"' - in events[0] + "web_search_call" + in events[i] .activity_task_completed_event_attributes.result.payloads[0] .data.decode() ) - for i in range(1, 11): - assert ( - "web_search_call" - in events[i] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - '"type":"output_text"' - in events[11] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) + assert ( + '"type":"output_text"' + in events[11] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) def orchestrator_agent() -> Agent: @@ -708,67 +704,64 @@ async def test_agents_as_tools_workflow(client: Client, use_local_model: bool): if not use_local_model and not os.environ.get("OPENAI_API_KEY"): pytest.skip("No openai API key") new_config = client.config() - new_config["data_converter"] = pydantic_data_converter + new_config["plugins"] = [ + openai_agents.Plugin( + model_params=ModelActivityParameters( + start_to_close_timeout=timedelta(seconds=30) + ), + model_provider=TestModelProvider(AgentAsToolsModel()) + if use_local_model + else None, + ) + ] client = Client(**new_config) - model_params = ModelActivityParameters(start_to_close_timeout=timedelta(seconds=30)) - with set_open_ai_agent_temporal_overrides(model_params): - model_activity = ModelActivity( - TestModelProvider( - AgentAsToolsModel( # type: ignore - ) - ) - if use_local_model - else None + async with new_worker( + client, + AgentsAsToolsWorkflow, + ) as worker: + workflow_handle = await client.start_workflow( + AgentsAsToolsWorkflow.run, + "Translate to Spanish: 'I am full'", + id=f"agents-as-tools-workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=30), ) - async with new_worker( - client, - AgentsAsToolsWorkflow, - activities=[model_activity.invoke_model_activity], - interceptors=[OpenAIAgentsTracingInterceptor()], - ) as worker: - workflow_handle = await client.start_workflow( - AgentsAsToolsWorkflow.run, - "Translate to Spanish: 'I am full'", - id=f"agents-as-tools-workflow-{uuid.uuid4()}", - task_queue=worker.task_queue, - execution_timeout=timedelta(seconds=30), + result = await workflow_handle.result() + + if use_local_model: + assert result == 'The translation to Spanish is: "Estoy lleno."' + + events = [] + async for e in workflow_handle.fetch_history_events(): + if e.HasField("activity_task_completed_event_attributes"): + events.append(e) + + assert len(events) == 4 + assert ( + "function_call" + in events[0] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "Estoy lleno" + in events[1] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "The translation to Spanish is:" + in events[2] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "The translation to Spanish is:" + in events[3] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() ) - result = await workflow_handle.result() - - if use_local_model: - assert result == 'The translation to Spanish is: "Estoy lleno."' - - events = [] - async for e in workflow_handle.fetch_history_events(): - if e.HasField("activity_task_completed_event_attributes"): - events.append(e) - - assert len(events) == 4 - assert ( - "function_call" - in events[0] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "Estoy lleno" - in events[1] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "The translation to Spanish is:" - in events[2] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "The translation to Spanish is:" - in events[3] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) class AirlineAgentContext(BaseModel): @@ -1063,97 +1056,94 @@ async def test_customer_service_workflow(client: Client, use_local_model: bool): if not use_local_model and not os.environ.get("OPENAI_API_KEY"): pytest.skip("No openai API key") new_config = client.config() - new_config["data_converter"] = pydantic_data_converter + new_config["plugins"] = [ + openai_agents.Plugin( + model_params=ModelActivityParameters( + start_to_close_timeout=timedelta(seconds=30) + ), + model_provider=TestModelProvider(CustomerServiceModel()) + if use_local_model + else None, + ) + ] client = Client(**new_config) questions = ["Hello", "Book me a flight to PDX", "11111", "Any window seat"] - model_params = ModelActivityParameters(start_to_close_timeout=timedelta(seconds=30)) - with set_open_ai_agent_temporal_overrides(model_params): - model_activity = ModelActivity( - TestModelProvider( - CustomerServiceModel( # type: ignore - ) - ) - if use_local_model - else None + async with new_worker( + client, + CustomerServiceWorkflow, + ) as worker: + workflow_handle = await client.start_workflow( + CustomerServiceWorkflow.run, + id=f"customer-service-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=30), ) - async with new_worker( - client, - CustomerServiceWorkflow, - activities=[model_activity.invoke_model_activity], - interceptors=[OpenAIAgentsTracingInterceptor()], - ) as worker: - workflow_handle = await client.start_workflow( - CustomerServiceWorkflow.run, - id=f"customer-service-{uuid.uuid4()}", - task_queue=worker.task_queue, - execution_timeout=timedelta(seconds=30), + history: list[Any] = [] + for q in questions: + message_input = ProcessUserMessageInput( + user_input=q, chat_length=len(history) + ) + new_history = await workflow_handle.execute_update( + CustomerServiceWorkflow.process_user_message, message_input + ) + history.extend(new_history) + print(*new_history, sep="\n") + + await workflow_handle.cancel() + + with pytest.raises(WorkflowFailureError) as err: + await workflow_handle.result() + assert isinstance(err.value.cause, CancelledError) + + if use_local_model: + events = [] + async for e in WorkflowHandle( + client, + workflow_handle.id, + run_id=workflow_handle._first_execution_run_id, + ).fetch_history_events(): + if e.HasField("activity_task_completed_event_attributes"): + events.append(e) + + assert len(events) == 6 + assert ( + "Hi there! How can I assist you today?" + in events[0] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "transfer_to_seat_booking_agent" + in events[1] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "Could you please provide your confirmation number?" + in events[2] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "Thanks! What seat number would you like to change to?" + in events[3] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "update_seat" + in events[4] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "Your seat has been updated to a window seat. If there's anything else you need, feel free to let me know!" + in events[5] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() ) - history: list[Any] = [] - for q in questions: - message_input = ProcessUserMessageInput( - user_input=q, chat_length=len(history) - ) - new_history = await workflow_handle.execute_update( - CustomerServiceWorkflow.process_user_message, message_input - ) - history.extend(new_history) - print(*new_history, sep="\n") - - await workflow_handle.cancel() - - with pytest.raises(WorkflowFailureError) as err: - await workflow_handle.result() - assert isinstance(err.value.cause, CancelledError) - - if use_local_model: - events = [] - async for e in WorkflowHandle( - client, - workflow_handle.id, - run_id=workflow_handle._first_execution_run_id, - ).fetch_history_events(): - if e.HasField("activity_task_completed_event_attributes"): - events.append(e) - - assert len(events) == 6 - assert ( - "Hi there! How can I assist you today?" - in events[0] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "transfer_to_seat_booking_agent" - in events[1] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "Could you please provide your confirmation number?" - in events[2] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "Thanks! What seat number would you like to change to?" - in events[3] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "update_seat" - in events[4] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "Your seat has been updated to a window seat. If there's anything else you need, feel free to let me know!" - in events[5] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) guardrail_response_index: int = 0 @@ -1356,42 +1346,40 @@ async def test_input_guardrail(client: Client, use_local_model: bool): if not use_local_model and not os.environ.get("OPENAI_API_KEY"): pytest.skip("No openai API key") new_config = client.config() - new_config["data_converter"] = pydantic_data_converter - client = Client(**new_config) - - model_params = ModelActivityParameters(start_to_close_timeout=timedelta(seconds=30)) - with set_open_ai_agent_temporal_overrides(model_params): - model_activity = ModelActivity( - TestModelProvider( - InputGuardrailModel( # type: ignore - "", openai_client=AsyncOpenAI(api_key="Fake key") - ) + new_config["plugins"] = [ + openai_agents.Plugin( + model_params=ModelActivityParameters( + start_to_close_timeout=timedelta(seconds=30) + ), + model_provider=TestModelProvider( + InputGuardrailModel("", openai_client=AsyncOpenAI(api_key="Fake key")) ) if use_local_model - else None + else None, ) - async with new_worker( - client, - InputGuardrailWorkflow, - activities=[model_activity.invoke_model_activity], - interceptors=[OpenAIAgentsTracingInterceptor()], - ) as worker: - workflow_handle = await client.start_workflow( - InputGuardrailWorkflow.run, - [ - "What's the capital of California?", - "Can you help me solve for x: 2x + 5 = 11", - ], - id=f"input-guardrail-{uuid.uuid4()}", - task_queue=worker.task_queue, - execution_timeout=timedelta(seconds=10), - ) - result = await workflow_handle.result() + ] + client = Client(**new_config) - if use_local_model: - assert len(result) == 2 - assert result[0] == "The capital of California is Sacramento." - assert result[1] == "Sorry, I can't help you with your math homework." + async with new_worker( + client, + InputGuardrailWorkflow, + ) as worker: + workflow_handle = await client.start_workflow( + InputGuardrailWorkflow.run, + [ + "What's the capital of California?", + "Can you help me solve for x: 2x + 5 = 11", + ], + id=f"input-guardrail-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=10), + ) + result = await workflow_handle.result() + + if use_local_model: + assert len(result) == 2 + assert result[0] == "The capital of California is Sacramento." + assert result[1] == "Sorry, I can't help you with your math homework." class OutputGuardrailModel(StaticTestModel): @@ -1473,35 +1461,32 @@ async def test_output_guardrail(client: Client, use_local_model: bool): if not use_local_model and not os.environ.get("OPENAI_API_KEY"): pytest.skip("No openai API key") new_config = client.config() - new_config["data_converter"] = pydantic_data_converter + new_config["plugins"] = [ + openai_agents.Plugin( + model_params=ModelActivityParameters( + start_to_close_timeout=timedelta(seconds=30) + ), + model_provider=TestModelProvider(OutputGuardrailModel()) + if use_local_model + else None, + ) + ] client = Client(**new_config) - model_params = ModelActivityParameters(start_to_close_timeout=timedelta(seconds=30)) - with set_open_ai_agent_temporal_overrides(model_params): - model_activity = ModelActivity( - TestModelProvider( - OutputGuardrailModel( # type: ignore - ) - ) - if use_local_model - else None + async with new_worker( + client, + OutputGuardrailWorkflow, + ) as worker: + workflow_handle = await client.start_workflow( + OutputGuardrailWorkflow.run, + id=f"output-guardrail-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=10), ) - async with new_worker( - client, - OutputGuardrailWorkflow, - activities=[model_activity.invoke_model_activity], - interceptors=[OpenAIAgentsTracingInterceptor()], - ) as worker: - workflow_handle = await client.start_workflow( - OutputGuardrailWorkflow.run, - id=f"output-guardrail-{uuid.uuid4()}", - task_queue=worker.task_queue, - execution_timeout=timedelta(seconds=10), - ) - result = await workflow_handle.result() + result = await workflow_handle.result() - if use_local_model: - assert not result + if use_local_model: + assert not result class WorkflowToolModel(StaticTestModel): @@ -1564,21 +1549,24 @@ async def run_tool(self): async def test_workflow_method_tools(client: Client): new_config = client.config() - new_config["data_converter"] = pydantic_data_converter + new_config["plugins"] = [ + openai_agents.Plugin( + model_params=ModelActivityParameters( + start_to_close_timeout=timedelta(seconds=30) + ), + model_provider=TestModelProvider(WorkflowToolModel()), + ) + ] client = Client(**new_config) - with set_open_ai_agent_temporal_overrides(): - model_activity = ModelActivity(TestModelProvider(WorkflowToolModel())) - async with new_worker( - client, - WorkflowToolWorkflow, - activities=[model_activity.invoke_model_activity], - interceptors=[OpenAIAgentsTracingInterceptor()], - ) as worker: - workflow_handle = await client.start_workflow( - WorkflowToolWorkflow.run, - id=f"workflow-tool-{uuid.uuid4()}", - task_queue=worker.task_queue, - execution_timeout=timedelta(seconds=10), - ) - await workflow_handle.result() + async with new_worker( + client, + WorkflowToolWorkflow, + ) as worker: + workflow_handle = await client.start_workflow( + WorkflowToolWorkflow.run, + id=f"workflow-tool-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=10), + ) + await workflow_handle.result() From 25e98213d7deb9385acfaf7f08614f585a738a43 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Thu, 10 Jul 2025 09:38:50 -0700 Subject: [PATCH 04/11] Remove extra import/exports --- temporalio/contrib/openai_agents/__init__.py | 8 -------- .../contrib/openai_agents/temporal_openai_agents.py | 8 ++++---- tests/contrib/openai_agents/test_openai.py | 4 ---- 3 files changed, 4 insertions(+), 16 deletions(-) diff --git a/temporalio/contrib/openai_agents/__init__.py b/temporalio/contrib/openai_agents/__init__.py index 2c20effc7..43636fa17 100644 --- a/temporalio/contrib/openai_agents/__init__.py +++ b/temporalio/contrib/openai_agents/__init__.py @@ -8,26 +8,18 @@ Use with caution in production environments. """ -from temporalio.contrib.openai_agents._invoke_model_activity import ModelActivity from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters -from temporalio.contrib.openai_agents._trace_interceptor import ( - OpenAIAgentsTracingInterceptor, -) from temporalio.contrib.openai_agents.temporal_openai_agents import ( Plugin, TestModel, TestModelProvider, - set_open_ai_agent_temporal_overrides, workflow, ) __all__ = [ "Plugin", - "ModelActivity", "ModelActivityParameters", "workflow", - "set_open_ai_agent_temporal_overrides", - "OpenAIAgentsTracingInterceptor", "TestModel", "TestModelProvider", ] diff --git a/temporalio/contrib/openai_agents/temporal_openai_agents.py b/temporalio/contrib/openai_agents/temporal_openai_agents.py index 9b574d708..04d68a5af 100644 --- a/temporalio/contrib/openai_agents/temporal_openai_agents.py +++ b/temporalio/contrib/openai_agents/temporal_openai_agents.py @@ -34,15 +34,15 @@ from temporalio import workflow as temporal_workflow from temporalio.client import ClientConfig from temporalio.common import Priority, RetryPolicy -from temporalio.contrib.openai_agents import ( - ModelActivity, - OpenAIAgentsTracingInterceptor, -) +from temporalio.contrib.openai_agents._invoke_model_activity import ModelActivity from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters from temporalio.contrib.openai_agents._openai_runner import TemporalOpenAIRunner from temporalio.contrib.openai_agents._temporal_trace_provider import ( TemporalTraceProvider, ) +from temporalio.contrib.openai_agents._trace_interceptor import ( + OpenAIAgentsTracingInterceptor, +) from temporalio.contrib.pydantic import pydantic_data_converter from temporalio.exceptions import ApplicationError, TemporalError from temporalio.worker import Worker, WorkerConfig diff --git a/tests/contrib/openai_agents/test_openai.py b/tests/contrib/openai_agents/test_openai.py index d2224761c..3fc283e1c 100644 --- a/tests/contrib/openai_agents/test_openai.py +++ b/tests/contrib/openai_agents/test_openai.py @@ -50,14 +50,10 @@ from temporalio.client import Client, WorkflowFailureError, WorkflowHandle from temporalio.contrib import openai_agents from temporalio.contrib.openai_agents import ( - ModelActivity, ModelActivityParameters, - OpenAIAgentsTracingInterceptor, TestModel, TestModelProvider, - set_open_ai_agent_temporal_overrides, ) -from temporalio.contrib.pydantic import pydantic_data_converter from temporalio.exceptions import CancelledError from tests.contrib.openai_agents.research_agents.research_manager import ( ResearchManager, From 381e0cc651aee340560f225d2ca2817bf6503ae8 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Thu, 10 Jul 2025 09:41:00 -0700 Subject: [PATCH 05/11] Format --- temporalio/worker/_worker.py | 42 +++++++++++++++++++++++++----------- 1 file changed, 30 insertions(+), 12 deletions(-) diff --git a/temporalio/worker/_worker.py b/temporalio/worker/_worker.py index bc9d07353..5b81e9186 100644 --- a/temporalio/worker/_worker.py +++ b/temporalio/worker/_worker.py @@ -385,7 +385,8 @@ def __init__( ) plugins_from_client = cast( - List[Plugin], [p for p in client.config()["plugins"] if isinstance(p, Plugin)] + List[Plugin], + [p for p in client.config()["plugins"] if isinstance(p, Plugin)], ) plugins = plugins_from_client + list(plugins) @@ -402,7 +403,11 @@ def _init_from_config(self, config: WorkerConfig): self._config = config # TODO(nexus-preview): max_concurrent_nexus_tasks / tuner support - if not (config["activities"] or config["nexus_service_handlers"] or config["workflows"]): + if not ( + config["activities"] + or config["nexus_service_handlers"] + or config["workflows"] + ): raise ValueError( "At least one activity, Nexus service, or workflow must be specified" ) @@ -410,7 +415,9 @@ def _init_from_config(self, config: WorkerConfig): raise ValueError( "build_id must be specified when use_worker_versioning is True" ) - if config["deployment_config"] and (config["build_id"] or config["use_worker_versioning"]): + if config["deployment_config"] and ( + config["build_id"] or config["use_worker_versioning"] + ): raise ValueError( "deployment_config cannot be used with build_id or use_worker_versioning" ) @@ -506,9 +513,13 @@ def check_activity(activity): unsandboxed_workflow_runner=config["unsandboxed_workflow_runner"], data_converter=client_config["data_converter"], interceptors=interceptors, - workflow_failure_exception_types=config["workflow_failure_exception_types"], + workflow_failure_exception_types=config[ + "workflow_failure_exception_types" + ], debug_mode=config["debug_mode"], - disable_eager_activity_execution=config["disable_eager_activity_execution"], + disable_eager_activity_execution=config[ + "disable_eager_activity_execution" + ], metric_meter=self._runtime.metric_meter, on_eviction_hook=None, disable_safe_eviction=config["disable_safe_workflow_eviction"], @@ -540,9 +551,9 @@ def check_activity(activity): versioning_strategy: temporalio.bridge.worker.WorkerVersioningStrategy if config["deployment_config"]: - versioning_strategy = ( - config["deployment_config"]._to_bridge_worker_deployment_options() - ) + versioning_strategy = config[ + "deployment_config" + ]._to_bridge_worker_deployment_options() elif config["use_worker_versioning"]: build_id = config["build_id"] or load_default_build_id() versioning_strategy = ( @@ -586,9 +597,11 @@ def check_activity(activity): # We have to disable remote activities if a user asks _or_ if we # are not running an activity worker at all. Otherwise shutdown # will not proceed properly. - no_remote_activities=config["no_remote_activities"] or not config["activities"], + no_remote_activities=config["no_remote_activities"] + or not config["activities"], sticky_queue_schedule_to_start_timeout_millis=int( - 1000 * config["sticky_queue_schedule_to_start_timeout"].total_seconds() + 1000 + * config["sticky_queue_schedule_to_start_timeout"].total_seconds() ), max_heartbeat_throttle_interval_millis=int( 1000 * config["max_heartbeat_throttle_interval"].total_seconds() @@ -597,7 +610,9 @@ def check_activity(activity): 1000 * config["default_heartbeat_throttle_interval"].total_seconds() ), max_activities_per_second=config["max_activities_per_second"], - max_task_queue_activities_per_second=config["max_task_queue_activities_per_second"], + max_task_queue_activities_per_second=config[ + "max_task_queue_activities_per_second" + ], graceful_shutdown_period_millis=int( 1000 * config["graceful_shutdown_timeout"].total_seconds() ), @@ -614,7 +629,9 @@ def check_activity(activity): versioning_strategy=versioning_strategy, workflow_task_poller_behavior=workflow_task_poller_behavior._to_bridge(), activity_task_poller_behavior=activity_task_poller_behavior._to_bridge(), - nexus_task_poller_behavior=config["nexus_task_poller_behavior"]._to_bridge(), + nexus_task_poller_behavior=config[ + "nexus_task_poller_behavior" + ]._to_bridge(), ), ) @@ -902,6 +919,7 @@ class WorkerConfig(TypedDict, total=False): activity_task_poller_behavior: PollerBehavior nexus_task_poller_behavior: PollerBehavior + @dataclass class WorkerDeploymentConfig: """Options for configuring the Worker Versioning feature. From 0db686d27a3075b2987e95589f5fe2e5dde972a4 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Thu, 10 Jul 2025 09:47:12 -0700 Subject: [PATCH 06/11] Use local tuner for type inference --- temporalio/worker/_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/temporalio/worker/_worker.py b/temporalio/worker/_worker.py index 5b81e9186..fca8a35a4 100644 --- a/temporalio/worker/_worker.py +++ b/temporalio/worker/_worker.py @@ -530,7 +530,7 @@ def check_activity(activity): ) tuner = config["tuner"] - if config["tuner"] is not None: + if tuner is not None: if ( config["max_concurrent_workflow_tasks"] or config["max_concurrent_activities"] From 0ace676abc775244ab5349288e654431abd6499c Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Mon, 21 Jul 2025 09:53:43 -0700 Subject: [PATCH 07/11] Add docstrings and fix merge issues --- .../openai_agents/temporal_openai_agents.py | 96 ++++++++- temporalio/contrib/openai_agents/workflow.py | 24 ++- tests/contrib/openai_agents/test_openai.py | 107 +++++----- .../openai_agents/test_openai_tracing.py | 197 ++++++++---------- tests/test_client.py | 4 +- tests/worker/test_worker.py | 8 +- 6 files changed, 267 insertions(+), 169 deletions(-) diff --git a/temporalio/contrib/openai_agents/temporal_openai_agents.py b/temporalio/contrib/openai_agents/temporal_openai_agents.py index 2003e22d6..76129c79e 100644 --- a/temporalio/contrib/openai_agents/temporal_openai_agents.py +++ b/temporalio/contrib/openai_agents/temporal_openai_agents.py @@ -36,6 +36,7 @@ from temporalio.contrib.pydantic import pydantic_data_converter from temporalio.worker import Worker, WorkerConfig + @contextmanager def set_open_ai_agent_temporal_overrides( model_params: Optional[ModelActivityParameters] = None, @@ -142,20 +143,104 @@ def stream_response( """Get a streamed response from the model. Unimplemented.""" raise NotImplementedError() + class Plugin(temporalio.client.Plugin, temporalio.worker.Plugin): + """Temporal plugin for integrating OpenAI agents with Temporal workflows. + + .. warning:: + This class is experimental and may change in future versions. + Use with caution in production environments. + + This plugin provides seamless integration between the OpenAI Agents SDK and + Temporal workflows. It automatically configures the necessary interceptors, + activities, and data converters to enable OpenAI agents to run within + Temporal workflows with proper tracing and model execution. + + The plugin: + 1. Configures the Pydantic data converter for type-safe serialization + 2. Sets up tracing interceptors for OpenAI agent interactions + 3. Registers model execution activities + 4. Manages the OpenAI agent runtime overrides during worker execution + + Args: + model_params: Configuration parameters for Temporal activity execution + of model calls. If None, default parameters will be used. + model_provider: Optional model provider for custom model implementations. + Useful for testing or custom model integrations. + + Example: + >>> from temporalio.client import Client + >>> from temporalio.worker import Worker + >>> from temporalio.contrib.openai_agents import Plugin, ModelActivityParameters + >>> from datetime import timedelta + >>> + >>> # Configure model parameters + >>> model_params = ModelActivityParameters( + ... start_to_close_timeout=timedelta(seconds=30), + ... retry_policy=RetryPolicy(maximum_attempts=3) + ... ) + >>> + >>> # Create plugin + >>> plugin = Plugin(model_params=model_params) + >>> + >>> # Use with client and worker + >>> client = await Client.connect( + ... "localhost:7233", + ... plugins=[plugin] + ... ) + >>> worker = Worker( + ... client, + ... task_queue="my-task-queue", + ... workflows=[MyWorkflow], + ... plugins=[plugin] + ... ) + """ + def __init__( self, model_params: Optional[ModelActivityParameters] = None, model_provider: Optional[ModelProvider] = None, ) -> None: + """Initialize the OpenAI agents plugin. + + Args: + model_params: Configuration parameters for Temporal activity execution + of model calls. If None, default parameters will be used. + model_provider: Optional model provider for custom model implementations. + Useful for testing or custom model integrations. + """ self._model_params = model_params self._model_provider = model_provider def configure_client(self, config: ClientConfig) -> ClientConfig: + """Configure the Temporal client for OpenAI agents integration. + + This method sets up the Pydantic data converter to enable proper + serialization of OpenAI agent objects and responses. + + Args: + config: The client configuration to modify. + + Returns: + The modified client configuration. + """ config["data_converter"] = pydantic_data_converter return super().configure_client(config) def configure_worker(self, config: WorkerConfig) -> WorkerConfig: + """Configure the Temporal worker for OpenAI agents integration. + + This method adds the necessary interceptors and activities for OpenAI + agent execution: + - Adds tracing interceptors for OpenAI agent interactions + - Registers model execution activities + + Args: + config: The worker configuration to modify. + + Returns: + The modified worker configuration. + """ config["interceptors"] = list(config.get("interceptors") or []) + [ OpenAIAgentsTracingInterceptor() ] @@ -165,5 +250,14 @@ def configure_worker(self, config: WorkerConfig) -> WorkerConfig: return super().configure_worker(config) async def run_worker(self, worker: Worker) -> None: + """Run the worker with OpenAI agents temporal overrides. + + This method sets up the necessary runtime overrides for OpenAI agents + to work within the Temporal worker context, including custom runners + and trace providers. + + Args: + worker: The worker instance to run. + """ with set_open_ai_agent_temporal_overrides(self._model_params): - await super().run_worker(worker) \ No newline at end of file + await super().run_worker(worker) diff --git a/temporalio/contrib/openai_agents/workflow.py b/temporalio/contrib/openai_agents/workflow.py index 50eba0b9e..35d7c0311 100644 --- a/temporalio/contrib/openai_agents/workflow.py +++ b/temporalio/contrib/openai_agents/workflow.py @@ -240,4 +240,26 @@ async def run_operation(ctx: RunContextWrapper[Any], input: str) -> Any: class ToolSerializationError(TemporalError): - """Error that occurs when a tool output could not be serialized.""" + """Error that occurs when a tool output could not be serialized. + + .. warning:: + This exception is experimental and may change in future versions. + Use with caution in production environments. + + This exception is raised when a tool (created from an activity or Nexus operation) + returns a value that cannot be properly serialized for use by the OpenAI agent. + All tool outputs must be convertible to strings for the agent to process them. + + The error typically occurs when: + - A tool returns a complex object that doesn't have a meaningful string representation + - The returned object cannot be converted using str() + - Custom serialization is needed but not implemented + + Example: + >>> @activity.defn + >>> def problematic_tool() -> ComplexObject: + ... return ComplexObject() # This might cause ToolSerializationError + + To fix this error, ensure your tool returns string-convertible values or + modify the tool to return a string representation of the result. + """ diff --git a/tests/contrib/openai_agents/test_openai.py b/tests/contrib/openai_agents/test_openai.py index 1282e20a3..3b0808d76 100644 --- a/tests/contrib/openai_agents/test_openai.py +++ b/tests/contrib/openai_agents/test_openai.py @@ -530,68 +530,63 @@ async def test_nexus_tool_workflow( pytest.skip("Nexus tests don't work with time-skipping server") new_config = client.config() - new_config["data_converter"] = pydantic_data_converter + new_config["plugins"] = [ + openai_agents.Plugin( + model_params=ModelActivityParameters( + start_to_close_timeout=timedelta(seconds=30) + ), + model_provider=TestModelProvider(TestResearchModel()) + if use_local_model + else None, + ) + ] client = Client(**new_config) - model_params = ModelActivityParameters(start_to_close_timeout=timedelta(seconds=30)) - with set_open_ai_agent_temporal_overrides(model_params): - model_activity = ModelActivity( - TestModelProvider( - TestNexusWeatherModel( # type: ignore - ) - ) - if use_local_model - else None + async with new_worker( + client, + NexusToolsWorkflow, + nexus_service_handlers=[WeatherServiceHandler()], + ) as worker: + await create_nexus_endpoint(worker.task_queue, client) + + workflow_handle = await client.start_workflow( + NexusToolsWorkflow.run, + "What is the weather in Tokio?", + id=f"nexus-tools-workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=30), ) - async with new_worker( - client, - NexusToolsWorkflow, - activities=[ - model_activity.invoke_model_activity, - ], - nexus_service_handlers=[WeatherServiceHandler()], - interceptors=[OpenAIAgentsTracingInterceptor()], - ) as worker: - await create_nexus_endpoint(worker.task_queue, client) - - workflow_handle = await client.start_workflow( - NexusToolsWorkflow.run, - "What is the weather in Tokio?", - id=f"nexus-tools-workflow-{uuid.uuid4()}", - task_queue=worker.task_queue, - execution_timeout=timedelta(seconds=30), - ) - result = await workflow_handle.result() + result = await workflow_handle.result() - if use_local_model: - assert result == "Test nexus weather result" + if use_local_model: + assert result == "Test nexus weather result" - events = [] - async for e in workflow_handle.fetch_history_events(): - if e.HasField( - "activity_task_completed_event_attributes" - ) or e.HasField("nexus_operation_completed_event_attributes"): - events.append(e) + events = [] + async for e in workflow_handle.fetch_history_events(): + if e.HasField("activity_task_completed_event_attributes") or e.HasField( + "nexus_operation_completed_event_attributes" + ): + events.append(e) - assert len(events) == 3 - assert ( - "function_call" - in events[0] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) - assert ( - "Sunny with wind" - in events[ - 1 - ].nexus_operation_completed_event_attributes.result.data.decode() - ) - assert ( - "Test nexus weather result" - in events[2] - .activity_task_completed_event_attributes.result.payloads[0] - .data.decode() - ) + assert len(events) == 3 + assert ( + "function_call" + in events[0] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) + assert ( + "Sunny with wind" + in events[ + 1 + ].nexus_operation_completed_event_attributes.result.data.decode() + ) + assert ( + "Test nexus weather result" + in events[2] + .activity_task_completed_event_attributes.result.payloads[0] + .data.decode() + ) @no_type_check diff --git a/tests/contrib/openai_agents/test_openai_tracing.py b/tests/contrib/openai_agents/test_openai_tracing.py index 5a7d03785..ff5619820 100644 --- a/tests/contrib/openai_agents/test_openai_tracing.py +++ b/tests/contrib/openai_agents/test_openai_tracing.py @@ -1,19 +1,15 @@ -import datetime import uuid from datetime import timedelta -from typing import Any, Optional +from typing import Any from agents import Span, Trace, TracingProcessor from agents.tracing import get_trace_provider from temporalio.client import Client +from temporalio.contrib import openai_agents from temporalio.contrib.openai_agents import ( - ModelActivity, - OpenAIAgentsTracingInterceptor, TestModelProvider, - set_open_ai_agent_temporal_overrides, ) -from temporalio.contrib.pydantic import pydantic_data_converter from tests.contrib.openai_agents.test_openai import ResearchWorkflow, TestResearchModel from tests.helpers import new_worker @@ -44,108 +40,99 @@ def force_flush(self) -> None: async def test_tracing(client: Client): new_config = client.config() - new_config["data_converter"] = pydantic_data_converter + new_config["plugins"] = [ + openai_agents.Plugin(model_provider=TestModelProvider(TestResearchModel())) + ] client = Client(**new_config) - with set_open_ai_agent_temporal_overrides(): - provider = get_trace_provider() - - processor = MemoryTracingProcessor() - provider.set_processors([processor]) - - model_activity = ModelActivity(TestModelProvider(TestResearchModel())) - async with new_worker( - client, - ResearchWorkflow, - activities=[model_activity.invoke_model_activity], - interceptors=[OpenAIAgentsTracingInterceptor()], - ) as worker: - workflow_handle = await client.start_workflow( - ResearchWorkflow.run, - "Caribbean vacation spots in April, optimizing for surfing, hiking and water sports", - id=f"research-workflow-{uuid.uuid4()}", - task_queue=worker.task_queue, - execution_timeout=timedelta(seconds=120), - ) - result = await workflow_handle.result() - - # There is one closed root trace - assert len(processor.trace_events) == 2 - assert ( - processor.trace_events[0][0].trace_id - == processor.trace_events[1][0].trace_id - ) - assert processor.trace_events[0][1] - assert not processor.trace_events[1][1] - - def paired_span(a: tuple[Span[Any], bool], b: tuple[Span[Any], bool]) -> None: - assert a[0].trace_id == b[0].trace_id - assert a[1] - assert not b[1] - - # Initial planner spans - There are only 3 because we don't make an actual model call - paired_span(processor.span_events[0], processor.span_events[5]) - assert ( - processor.span_events[0][0].span_data.export().get("name") == "PlannerAgent" - ) - - paired_span(processor.span_events[1], processor.span_events[4]) - assert ( - processor.span_events[1][0].span_data.export().get("name") - == "temporal:startActivity" - ) - - paired_span(processor.span_events[2], processor.span_events[3]) - assert ( - processor.span_events[2][0].span_data.export().get("name") - == "temporal:executeActivity" + provider = get_trace_provider() + + processor = MemoryTracingProcessor() + provider.set_processors([processor]) + + async with new_worker( + client, + ResearchWorkflow, + ) as worker: + workflow_handle = await client.start_workflow( + ResearchWorkflow.run, + "Caribbean vacation spots in April, optimizing for surfing, hiking and water sports", + id=f"research-workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=120), ) + result = await workflow_handle.result() + + # There is one closed root trace + assert len(processor.trace_events) == 2 + assert ( + processor.trace_events[0][0].trace_id == processor.trace_events[1][0].trace_id + ) + assert processor.trace_events[0][1] + assert not processor.trace_events[1][1] + + def paired_span(a: tuple[Span[Any], bool], b: tuple[Span[Any], bool]) -> None: + assert a[0].trace_id == b[0].trace_id + assert a[1] + assert not b[1] + + # Initial planner spans - There are only 3 because we don't make an actual model call + paired_span(processor.span_events[0], processor.span_events[5]) + assert processor.span_events[0][0].span_data.export().get("name") == "PlannerAgent" + + paired_span(processor.span_events[1], processor.span_events[4]) + assert ( + processor.span_events[1][0].span_data.export().get("name") + == "temporal:startActivity" + ) + + paired_span(processor.span_events[2], processor.span_events[3]) + assert ( + processor.span_events[2][0].span_data.export().get("name") + == "temporal:executeActivity" + ) + + for span, start in processor.span_events[6:-6]: + span_data = span.span_data.export() + + # All spans should be closed + if start: + assert any( + span.span_id == s.span_id and not s_start + for (s, s_start) in processor.span_events + ) - for span, start in processor.span_events[6:-6]: - span_data = span.span_data.export() - - # All spans should be closed - if start: - assert any( - span.span_id == s.span_id and not s_start - for (s, s_start) in processor.span_events - ) - - # Start activity is always parented to an agent - if span_data.get("name") == "temporal:startActivity": - parents = [ - s for (s, _) in processor.span_events if s.span_id == span.parent_id - ] - assert ( - len(parents) == 2 - and parents[0].span_data.export()["type"] == "agent" - ) - - # Execute is parented to start - if span_data.get("name") == "temporal:executeActivity": - parents = [ - s for (s, _) in processor.span_events if s.span_id == span.parent_id - ] - assert ( - len(parents) == 2 - and parents[0].span_data.export()["name"] - == "temporal:startActivity" - ) - - # Final writer spans - There are only 3 because we don't make an actual model call - paired_span(processor.span_events[-6], processor.span_events[-1]) - assert ( - processor.span_events[-6][0].span_data.export().get("name") == "WriterAgent" - ) + # Start activity is always parented to an agent + if span_data.get("name") == "temporal:startActivity": + parents = [ + s for (s, _) in processor.span_events if s.span_id == span.parent_id + ] + assert ( + len(parents) == 2 and parents[0].span_data.export()["type"] == "agent" + ) - paired_span(processor.span_events[-5], processor.span_events[-2]) - assert ( - processor.span_events[-5][0].span_data.export().get("name") - == "temporal:startActivity" - ) + # Execute is parented to start + if span_data.get("name") == "temporal:executeActivity": + parents = [ + s for (s, _) in processor.span_events if s.span_id == span.parent_id + ] + assert ( + len(parents) == 2 + and parents[0].span_data.export()["name"] == "temporal:startActivity" + ) - paired_span(processor.span_events[-4], processor.span_events[-3]) - assert ( - processor.span_events[-4][0].span_data.export().get("name") - == "temporal:executeActivity" - ) + # Final writer spans - There are only 3 because we don't make an actual model call + paired_span(processor.span_events[-6], processor.span_events[-1]) + assert processor.span_events[-6][0].span_data.export().get("name") == "WriterAgent" + + paired_span(processor.span_events[-5], processor.span_events[-2]) + assert ( + processor.span_events[-5][0].span_data.export().get("name") + == "temporal:startActivity" + ) + + paired_span(processor.span_events[-4], processor.span_events[-3]) + assert ( + processor.span_events[-4][0].span_data.export().get("name") + == "temporal:executeActivity" + ) diff --git a/tests/test_client.py b/tests/test_client.py index ef86eb950..a43b41a73 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1504,9 +1504,9 @@ async def test_cloud_client_simple(): class MyPlugin(Plugin): - def on_create_client(self, config: ClientConfig) -> ClientConfig: + def configure_client(self, config: ClientConfig) -> ClientConfig: config["namespace"] = "replaced_namespace" - return super().on_create_client(config) + return super().configure_client(config) async def connect_service_client( self, config: temporalio.service.ConnectConfig diff --git a/tests/worker/test_worker.py b/tests/worker/test_worker.py index 7a72729c7..3e3d0090e 100644 --- a/tests/worker/test_worker.py +++ b/tests/worker/test_worker.py @@ -1192,17 +1192,17 @@ def shutdown(self) -> None: class MyCombinedPlugin(temporalio.client.Plugin, temporalio.worker.Plugin): - def on_create_worker(self, config: WorkerConfig) -> WorkerConfig: + def configure_worker(self, config: WorkerConfig) -> WorkerConfig: print("Create worker combined plugin") config["task_queue"] = "combined" - return super().on_create_worker(config) + return super().configure_worker(config) class MyWorkerPlugin(temporalio.worker.Plugin): - def on_create_worker(self, config: WorkerConfig) -> WorkerConfig: + def configure_worker(self, config: WorkerConfig) -> WorkerConfig: print("Create worker worker plugin") config["task_queue"] = "replaced_queue" - return super().on_create_worker(config) + return super().configure_worker(config) async def run_worker(self, worker: Worker) -> None: await super().run_worker(worker) From 9b6a08b1f3875dd2dbf98b5bcbc0acd9a6504dce Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Mon, 21 Jul 2025 09:55:18 -0700 Subject: [PATCH 08/11] Remove tests duplicated by merge --- tests/test_client.py | 15 ------------- tests/worker/test_worker.py | 43 ------------------------------------- 2 files changed, 58 deletions(-) diff --git a/tests/test_client.py b/tests/test_client.py index a43b41a73..427aff459 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1513,18 +1513,3 @@ async def connect_service_client( ) -> temporalio.service.ServiceClient: config.api_key = "replaced key" return await super().connect_service_client(config) - - -async def test_client_plugin(client: Client, env: WorkflowEnvironment): - if env.supports_time_skipping: - pytest.skip("Client connect is only designed for local") - - config = client.config() - config["plugins"] = [MyPlugin()] - new_client = Client(**config) - assert new_client.namespace == "replaced_namespace" - - new_client = await Client.connect( - client.service_client.config.target_host, plugins=[MyPlugin()] - ) - assert new_client.service_client.config.api_key == "replaced key" diff --git a/tests/worker/test_worker.py b/tests/worker/test_worker.py index 3e3d0090e..7c08cfa37 100644 --- a/tests/worker/test_worker.py +++ b/tests/worker/test_worker.py @@ -1189,46 +1189,3 @@ def shutdown(self) -> None: if self.next_exception_task: self.next_exception_task.cancel() setattr(self.worker._bridge_worker, self.attr, self.orig_poll_call) - - -class MyCombinedPlugin(temporalio.client.Plugin, temporalio.worker.Plugin): - def configure_worker(self, config: WorkerConfig) -> WorkerConfig: - print("Create worker combined plugin") - config["task_queue"] = "combined" - return super().configure_worker(config) - - -class MyWorkerPlugin(temporalio.worker.Plugin): - def configure_worker(self, config: WorkerConfig) -> WorkerConfig: - print("Create worker worker plugin") - config["task_queue"] = "replaced_queue" - return super().configure_worker(config) - - async def run_worker(self, worker: Worker) -> None: - await super().run_worker(worker) - - -async def test_worker_plugin(client: Client) -> None: - worker = Worker( - client, - task_queue="queue", - activities=[never_run_activity], - plugins=[MyWorkerPlugin()], - ) - assert worker.config().get("task_queue") == "replaced_queue" - - # Test client plugin propagation to worker plugins - new_config = client.config() - new_config["plugins"] = [MyCombinedPlugin()] - client = Client(**new_config) - worker = Worker(client, task_queue="queue", activities=[never_run_activity]) - assert worker.config().get("task_queue") == "combined" - - # Test both. Client propagated plugins are called first, so the worker plugin overrides in this case - worker = Worker( - client, - task_queue="queue", - activities=[never_run_activity], - plugins=[MyWorkerPlugin()], - ) - assert worker.config().get("task_queue") == "replaced_queue" From e116ada42bf305d60a176ff980e162c08c5e02d8 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Mon, 21 Jul 2025 11:37:16 -0700 Subject: [PATCH 09/11] Fix tests --- tests/contrib/openai_agents/test_openai.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/contrib/openai_agents/test_openai.py b/tests/contrib/openai_agents/test_openai.py index 3b0808d76..ceb347bec 100644 --- a/tests/contrib/openai_agents/test_openai.py +++ b/tests/contrib/openai_agents/test_openai.py @@ -431,6 +431,7 @@ async def test_tool_workflow(client: Client, use_local_model: bool): get_weather_object, get_weather_country, get_weather_context, + ActivityWeatherService().get_weather_method, ], ) as worker: workflow_handle = await client.start_workflow( @@ -535,7 +536,7 @@ async def test_nexus_tool_workflow( model_params=ModelActivityParameters( start_to_close_timeout=timedelta(seconds=30) ), - model_provider=TestModelProvider(TestResearchModel()) + model_provider=TestModelProvider(TestNexusWeatherModel()) if use_local_model else None, ) From b0e10f00818b04f42a6f95ce4970c6471e499670 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Tue, 22 Jul 2025 07:53:03 -0700 Subject: [PATCH 10/11] Rename plugin + a few fixes --- temporalio/contrib/openai_agents/__init__.py | 4 ++-- .../openai_agents/temporal_openai_agents.py | 7 +++---- tests/contrib/openai_agents/test_openai.py | 18 +++++++++--------- .../openai_agents/test_openai_tracing.py | 2 +- tests/test_client.py | 12 ------------ 5 files changed, 15 insertions(+), 28 deletions(-) diff --git a/temporalio/contrib/openai_agents/__init__.py b/temporalio/contrib/openai_agents/__init__.py index 378bd6da0..c9da59497 100644 --- a/temporalio/contrib/openai_agents/__init__.py +++ b/temporalio/contrib/openai_agents/__init__.py @@ -10,7 +10,7 @@ from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters from temporalio.contrib.openai_agents.temporal_openai_agents import ( - Plugin, + OpenAIAgentsPlugin, TestModel, TestModelProvider, ) @@ -18,7 +18,7 @@ from . import workflow __all__ = [ - "Plugin", + "OpenAIAgentsPlugin", "ModelActivityParameters", "workflow", "TestModel", diff --git a/temporalio/contrib/openai_agents/temporal_openai_agents.py b/temporalio/contrib/openai_agents/temporal_openai_agents.py index 76129c79e..0d7c5f968 100644 --- a/temporalio/contrib/openai_agents/temporal_openai_agents.py +++ b/temporalio/contrib/openai_agents/temporal_openai_agents.py @@ -144,7 +144,7 @@ def stream_response( raise NotImplementedError() -class Plugin(temporalio.client.Plugin, temporalio.worker.Plugin): +class OpenAIAgentsPlugin(temporalio.client.Plugin, temporalio.worker.Plugin): """Temporal plugin for integrating OpenAI agents with Temporal workflows. .. warning:: @@ -171,7 +171,7 @@ class Plugin(temporalio.client.Plugin, temporalio.worker.Plugin): Example: >>> from temporalio.client import Client >>> from temporalio.worker import Worker - >>> from temporalio.contrib.openai_agents import Plugin, ModelActivityParameters + >>> from temporalio.contrib.openai_agents import OpenAIAgentsPlugin, ModelActivityParameters >>> from datetime import timedelta >>> >>> # Configure model parameters @@ -181,7 +181,7 @@ class Plugin(temporalio.client.Plugin, temporalio.worker.Plugin): ... ) >>> >>> # Create plugin - >>> plugin = Plugin(model_params=model_params) + >>> plugin = OpenAIAgentsPlugin(model_params=model_params) >>> >>> # Use with client and worker >>> client = await Client.connect( @@ -192,7 +192,6 @@ class Plugin(temporalio.client.Plugin, temporalio.worker.Plugin): ... client, ... task_queue="my-task-queue", ... workflows=[MyWorkflow], - ... plugins=[plugin] ... ) """ diff --git a/tests/contrib/openai_agents/test_openai.py b/tests/contrib/openai_agents/test_openai.py index ceb347bec..018fcd791 100644 --- a/tests/contrib/openai_agents/test_openai.py +++ b/tests/contrib/openai_agents/test_openai.py @@ -124,7 +124,7 @@ async def test_hello_world_agent(client: Client, use_local_model: bool): pytest.skip("No openai API key") new_config = client.config() new_config["plugins"] = [ - openai_agents.Plugin( + openai_agents.OpenAIAgentsPlugin( model_params=ModelActivityParameters( start_to_close_timeout=timedelta(seconds=30) ), @@ -412,7 +412,7 @@ async def test_tool_workflow(client: Client, use_local_model: bool): pytest.skip("No openai API key") new_config = client.config() new_config["plugins"] = [ - openai_agents.Plugin( + openai_agents.OpenAIAgentsPlugin( model_params=ModelActivityParameters( start_to_close_timeout=timedelta(seconds=30) ), @@ -532,7 +532,7 @@ async def test_nexus_tool_workflow( new_config = client.config() new_config["plugins"] = [ - openai_agents.Plugin( + openai_agents.OpenAIAgentsPlugin( model_params=ModelActivityParameters( start_to_close_timeout=timedelta(seconds=30) ), @@ -678,7 +678,7 @@ async def test_research_workflow(client: Client, use_local_model: bool): pytest.skip("No openai API key") new_config = client.config() new_config["plugins"] = [ - openai_agents.Plugin( + openai_agents.OpenAIAgentsPlugin( model_params=ModelActivityParameters( start_to_close_timeout=timedelta(seconds=30) ), @@ -892,7 +892,7 @@ async def test_agents_as_tools_workflow(client: Client, use_local_model: bool): pytest.skip("No openai API key") new_config = client.config() new_config["plugins"] = [ - openai_agents.Plugin( + openai_agents.OpenAIAgentsPlugin( model_params=ModelActivityParameters( start_to_close_timeout=timedelta(seconds=30) ), @@ -1244,7 +1244,7 @@ async def test_customer_service_workflow(client: Client, use_local_model: bool): pytest.skip("No openai API key") new_config = client.config() new_config["plugins"] = [ - openai_agents.Plugin( + openai_agents.OpenAIAgentsPlugin( model_params=ModelActivityParameters( start_to_close_timeout=timedelta(seconds=30) ), @@ -1534,7 +1534,7 @@ async def test_input_guardrail(client: Client, use_local_model: bool): pytest.skip("No openai API key") new_config = client.config() new_config["plugins"] = [ - openai_agents.Plugin( + openai_agents.OpenAIAgentsPlugin( model_params=ModelActivityParameters( start_to_close_timeout=timedelta(seconds=30) ), @@ -1649,7 +1649,7 @@ async def test_output_guardrail(client: Client, use_local_model: bool): pytest.skip("No openai API key") new_config = client.config() new_config["plugins"] = [ - openai_agents.Plugin( + openai_agents.OpenAIAgentsPlugin( model_params=ModelActivityParameters( start_to_close_timeout=timedelta(seconds=30) ), @@ -1737,7 +1737,7 @@ async def run_tool(self): async def test_workflow_method_tools(client: Client): new_config = client.config() new_config["plugins"] = [ - openai_agents.Plugin( + openai_agents.OpenAIAgentsPlugin( model_params=ModelActivityParameters( start_to_close_timeout=timedelta(seconds=30) ), diff --git a/tests/contrib/openai_agents/test_openai_tracing.py b/tests/contrib/openai_agents/test_openai_tracing.py index ff5619820..22fea5086 100644 --- a/tests/contrib/openai_agents/test_openai_tracing.py +++ b/tests/contrib/openai_agents/test_openai_tracing.py @@ -41,7 +41,7 @@ def force_flush(self) -> None: async def test_tracing(client: Client): new_config = client.config() new_config["plugins"] = [ - openai_agents.Plugin(model_provider=TestModelProvider(TestResearchModel())) + openai_agents.OpenAIAgentsPlugin(model_provider=TestModelProvider(TestResearchModel())) ] client = Client(**new_config) diff --git a/tests/test_client.py b/tests/test_client.py index 427aff459..9c33e9e1c 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1501,15 +1501,3 @@ async def test_cloud_client_simple(): GetNamespaceRequest(namespace=os.environ["TEMPORAL_CLIENT_CLOUD_NAMESPACE"]) ) assert os.environ["TEMPORAL_CLIENT_CLOUD_NAMESPACE"] == result.namespace.namespace - - -class MyPlugin(Plugin): - def configure_client(self, config: ClientConfig) -> ClientConfig: - config["namespace"] = "replaced_namespace" - return super().configure_client(config) - - async def connect_service_client( - self, config: temporalio.service.ConnectConfig - ) -> temporalio.service.ServiceClient: - config.api_key = "replaced key" - return await super().connect_service_client(config) From 3e7217fb1f831b33935dbf82b2dcf222c4228448 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Tue, 22 Jul 2025 08:21:46 -0700 Subject: [PATCH 11/11] Lint --- tests/contrib/openai_agents/test_openai_tracing.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/contrib/openai_agents/test_openai_tracing.py b/tests/contrib/openai_agents/test_openai_tracing.py index 22fea5086..c8ad366e6 100644 --- a/tests/contrib/openai_agents/test_openai_tracing.py +++ b/tests/contrib/openai_agents/test_openai_tracing.py @@ -41,7 +41,9 @@ def force_flush(self) -> None: async def test_tracing(client: Client): new_config = client.config() new_config["plugins"] = [ - openai_agents.OpenAIAgentsPlugin(model_provider=TestModelProvider(TestResearchModel())) + openai_agents.OpenAIAgentsPlugin( + model_provider=TestModelProvider(TestResearchModel()) + ) ] client = Client(**new_config)