From 1741a6e9c8bd88526aaee4c8fb5cc2a0dad65e52 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Wed, 30 Jul 2025 08:41:46 -0700 Subject: [PATCH 1/7] POC for replayer configuration from existing plugins --- temporalio/worker/_replayer.py | 27 +++++++++++++++++---- tests/test_plugins.py | 43 ++++++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 4 deletions(-) diff --git a/temporalio/worker/_replayer.py b/temporalio/worker/_replayer.py index 6e9761b58..b17aed847 100644 --- a/temporalio/worker/_replayer.py +++ b/temporalio/worker/_replayer.py @@ -7,7 +7,7 @@ import logging from contextlib import asynccontextmanager from dataclasses import dataclass -from typing import AsyncIterator, Dict, Mapping, Optional, Sequence, Type +from typing import AsyncIterator, Dict, Mapping, Optional, Sequence, Type, Union from typing_extensions import TypedDict @@ -19,9 +19,11 @@ import temporalio.runtime import temporalio.workflow + from ..common import HeaderCodecBehavior from ._interceptor import Interceptor -from ._worker import load_default_build_id +from ._worker import load_default_build_id, WorkerConfig +from temporalio.client import ClientConfig from ._workflow import _WorkflowWorker from ._workflow_instance import UnsandboxedWorkflowRunner, WorkflowRunner from .workflow_sandbox import SandboxedWorkflowRunner @@ -42,6 +44,7 @@ def __init__( namespace: str = "ReplayNamespace", data_converter: temporalio.converter.DataConverter = temporalio.converter.DataConverter.default, interceptors: Sequence[Interceptor] = [], + plugins: Sequence[Union[temporalio.worker.Plugin, temporalio.client.Plugin]] = [], build_id: Optional[str] = None, identity: Optional[str] = None, workflow_failure_exception_types: Sequence[Type[BaseException]] = [], @@ -62,8 +65,6 @@ def __init__( will be shared across all replay calls and never explicitly shut down. Users are encouraged to provide their own if needing more control. """ - if not workflows: - raise ValueError("At least one workflow must be specified") self._config = ReplayerConfig( workflows=list(workflows), workflow_task_executor=( @@ -82,6 +83,24 @@ def __init__( disable_safe_workflow_eviction=disable_safe_workflow_eviction, header_codec_behavior=header_codec_behavior, ) + root_worker_plugin: temporalio.worker.Plugin = temporalio.worker._worker._RootPlugin() + root_client_plugin: temporalio.client.Plugin = temporalio.client._RootPlugin() + for plugin in reversed(plugins): + root_worker_plugin = plugin.init_worker_plugin(root_worker_plugin) + root_client_plugin = plugin.init_client_plugin(root_client_plugin) + + # Allow plugins to configure shared configurations with worker + worker_config = WorkerConfig(**{k: v for k, v in self._config.items() if k in WorkerConfig.__annotations__}) + worker_config = root_worker_plugin.configure_worker(worker_config) + self._config.update({k: v for k, v in worker_config.items() if k in ReplayerConfig.__annotations__}) + + # Allow plugins to configure shared configurations with client + client_config = ClientConfig(**{k: v for k, v in self._config.items() if k in ClientConfig.__annotations__}) + client_config = root_client_plugin.configure_client(client_config) + self._config.update({k: v for k, v in client_config.items() if k in ReplayerConfig.__annotations__}) + + if not self._config["workflows"]: + raise ValueError("At least one workflow must be specified") def config(self) -> ReplayerConfig: """Config, as a dictionary, used to create this replayer. diff --git a/tests/test_plugins.py b/tests/test_plugins.py index 4a60bba4d..42c221f26 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -1,4 +1,5 @@ import dataclasses +import uuid import warnings from typing import cast @@ -6,11 +7,15 @@ import temporalio.client import temporalio.worker +from temporalio import workflow from temporalio.client import Client, ClientConfig, OutboundInterceptor +from temporalio.contrib.pydantic import pydantic_data_converter from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker, WorkerConfig from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner from tests.worker.test_worker import never_run_activity +from temporalio.worker import Replayer +from tests.helpers import new_worker class TestClientInterceptor(temporalio.client.Interceptor): @@ -136,3 +141,41 @@ async def test_worker_sandbox_restrictions(client: Client) -> None: SandboxedWorkflowRunner, worker.config().get("workflow_runner") ).restrictions.passthrough_modules ) + +class ReplayCheckPlugin(temporalio.client.Plugin, temporalio.worker.Plugin): + def configure_worker(self, config: WorkerConfig) -> WorkerConfig: + config["workflows"] = list(config["workflows"]) + [HelloWorkflow] + return super().configure_worker(config) + + def configure_client(self, config: ClientConfig) -> ClientConfig: + config["data_converter"] = pydantic_data_converter + return super().configure_client(config) + +@workflow.defn +class HelloWorkflow: + @workflow.run + async def run(self, name: str) -> str: + return f"Hello, {name}!" + +async def test_replay(client: Client) -> None: + plugin = ReplayCheckPlugin() + new_config = client.config() + new_config["plugins"] = [plugin] + client = Client(**new_config) + + async with new_worker(client) as worker: + handle = await client.start_workflow( + HelloWorkflow.run, + "Tim", + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + await handle.result() + replayer = Replayer( + workflows=[], + plugins=[plugin] + ) + assert len(replayer.config()["workflows"])==1 + assert replayer.config()["data_converter"] == pydantic_data_converter + + await replayer.replay_workflow(await handle.fetch_history()) From 3c92c40b4b416b0ab4809312083522fb08ba691a Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Wed, 30 Jul 2025 09:22:25 -0700 Subject: [PATCH 2/7] Handle non-combined cases --- temporalio/worker/_replayer.py | 11 +++++++---- tests/test_plugins.py | 17 +++++++++++++++++ 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/temporalio/worker/_replayer.py b/temporalio/worker/_replayer.py index b17aed847..bd9ac06d1 100644 --- a/temporalio/worker/_replayer.py +++ b/temporalio/worker/_replayer.py @@ -83,18 +83,21 @@ def __init__( disable_safe_workflow_eviction=disable_safe_workflow_eviction, header_codec_behavior=header_codec_behavior, ) + + # Allow plugins to configure shared configurations with worker root_worker_plugin: temporalio.worker.Plugin = temporalio.worker._worker._RootPlugin() - root_client_plugin: temporalio.client.Plugin = temporalio.client._RootPlugin() - for plugin in reversed(plugins): + for plugin in reversed([plugin for plugin in plugins if isinstance(plugin, temporalio.worker.Plugin)]): root_worker_plugin = plugin.init_worker_plugin(root_worker_plugin) - root_client_plugin = plugin.init_client_plugin(root_client_plugin) - # Allow plugins to configure shared configurations with worker worker_config = WorkerConfig(**{k: v for k, v in self._config.items() if k in WorkerConfig.__annotations__}) worker_config = root_worker_plugin.configure_worker(worker_config) self._config.update({k: v for k, v in worker_config.items() if k in ReplayerConfig.__annotations__}) # Allow plugins to configure shared configurations with client + root_client_plugin: temporalio.client.Plugin = temporalio.client._RootPlugin() + for plugin in reversed([plugin for plugin in plugins if isinstance(plugin, temporalio.client.Plugin)]): + root_client_plugin = plugin.init_client_plugin(root_client_plugin) + client_config = ClientConfig(**{k: v for k, v in self._config.items() if k in ClientConfig.__annotations__}) client_config = root_client_plugin.configure_client(client_config) self._config.update({k: v for k, v in client_config.items() if k in ReplayerConfig.__annotations__}) diff --git a/tests/test_plugins.py b/tests/test_plugins.py index 42c221f26..0d62b15b6 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -179,3 +179,20 @@ async def test_replay(client: Client) -> None: assert replayer.config()["data_converter"] == pydantic_data_converter await replayer.replay_workflow(await handle.fetch_history()) + + replayer = Replayer( + workflows=[HelloWorkflow], + plugins=[MyClientPlugin()] + ) + replayer = Replayer( + workflows=[HelloWorkflow], + plugins=[MyWorkerPlugin()] + ) + replayer = Replayer( + workflows=[HelloWorkflow], + plugins=[MyClientPlugin(), MyWorkerPlugin()] + ) + replayer = Replayer( + workflows=[HelloWorkflow], + plugins=[MyWorkerPlugin(), MyClientPlugin(), MyCombinedPlugin()] + ) \ No newline at end of file From 76cacf666350cbb5bf90d002fbb4e49bdc21a209 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Wed, 30 Jul 2025 09:58:12 -0700 Subject: [PATCH 3/7] Fixing type checking --- temporalio/worker/_replayer.py | 65 +++++++++++++++++++++++++++------- tests/test_plugins.py | 34 +++++++----------- 2 files changed, 65 insertions(+), 34 deletions(-) diff --git a/temporalio/worker/_replayer.py b/temporalio/worker/_replayer.py index bd9ac06d1..0e2045e04 100644 --- a/temporalio/worker/_replayer.py +++ b/temporalio/worker/_replayer.py @@ -7,7 +7,7 @@ import logging from contextlib import asynccontextmanager from dataclasses import dataclass -from typing import AsyncIterator, Dict, Mapping, Optional, Sequence, Type, Union +from typing import AsyncIterator, Dict, Mapping, Optional, Sequence, Type, Union, cast from typing_extensions import TypedDict @@ -18,12 +18,11 @@ import temporalio.converter import temporalio.runtime import temporalio.workflow - +from temporalio.client import ClientConfig from ..common import HeaderCodecBehavior from ._interceptor import Interceptor -from ._worker import load_default_build_id, WorkerConfig -from temporalio.client import ClientConfig +from ._worker import WorkerConfig, load_default_build_id from ._workflow import _WorkflowWorker from ._workflow_instance import UnsandboxedWorkflowRunner, WorkflowRunner from .workflow_sandbox import SandboxedWorkflowRunner @@ -44,7 +43,9 @@ def __init__( namespace: str = "ReplayNamespace", data_converter: temporalio.converter.DataConverter = temporalio.converter.DataConverter.default, interceptors: Sequence[Interceptor] = [], - plugins: Sequence[Union[temporalio.worker.Plugin, temporalio.client.Plugin]] = [], + plugins: Sequence[ + Union[temporalio.worker.Plugin, temporalio.client.Plugin] + ] = [], build_id: Optional[str] = None, identity: Optional[str] = None, workflow_failure_exception_types: Sequence[Type[BaseException]] = [], @@ -86,21 +87,59 @@ def __init__( # Allow plugins to configure shared configurations with worker root_worker_plugin: temporalio.worker.Plugin = temporalio.worker._worker._RootPlugin() - for plugin in reversed([plugin for plugin in plugins if isinstance(plugin, temporalio.worker.Plugin)]): + for plugin in reversed( + [ + plugin + for plugin in plugins + if isinstance(plugin, temporalio.worker.Plugin) + ] + ): root_worker_plugin = plugin.init_worker_plugin(root_worker_plugin) - worker_config = WorkerConfig(**{k: v for k, v in self._config.items() if k in WorkerConfig.__annotations__}) + worker_config = cast( + WorkerConfig, + { + k: v + for k, v in self._config.items() + if k in WorkerConfig.__annotations__ + }, + ) + worker_config = root_worker_plugin.configure_worker(worker_config) - self._config.update({k: v for k, v in worker_config.items() if k in ReplayerConfig.__annotations__}) + self._config.update( + cast(ReplayerConfig, { + k: v + for k, v in worker_config.items() + if k in ReplayerConfig.__annotations__ + }) + ) # Allow plugins to configure shared configurations with client root_client_plugin: temporalio.client.Plugin = temporalio.client._RootPlugin() - for plugin in reversed([plugin for plugin in plugins if isinstance(plugin, temporalio.client.Plugin)]): - root_client_plugin = plugin.init_client_plugin(root_client_plugin) - - client_config = ClientConfig(**{k: v for k, v in self._config.items() if k in ClientConfig.__annotations__}) + for client_plugin in reversed( + [ + plugin + for plugin in plugins + if isinstance(plugin, temporalio.client.Plugin) + ] + ): + root_client_plugin = client_plugin.init_client_plugin(root_client_plugin) + + client_config = cast(ClientConfig, + { + k: v + for k, v in self._config.items() + if k in ClientConfig.__annotations__ + } + ) client_config = root_client_plugin.configure_client(client_config) - self._config.update({k: v for k, v in client_config.items() if k in ReplayerConfig.__annotations__}) + self._config.update( + cast(ReplayerConfig, { + k: v + for k, v in client_config.items() + if k in ReplayerConfig.__annotations__ + }) + ) if not self._config["workflows"]: raise ValueError("At least one workflow must be specified") diff --git a/tests/test_plugins.py b/tests/test_plugins.py index 0d62b15b6..a4ccd9244 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -11,11 +11,10 @@ from temporalio.client import Client, ClientConfig, OutboundInterceptor from temporalio.contrib.pydantic import pydantic_data_converter from temporalio.testing import WorkflowEnvironment -from temporalio.worker import Worker, WorkerConfig +from temporalio.worker import Replayer, Worker, WorkerConfig from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner -from tests.worker.test_worker import never_run_activity -from temporalio.worker import Replayer from tests.helpers import new_worker +from tests.worker.test_worker import never_run_activity class TestClientInterceptor(temporalio.client.Interceptor): @@ -142,21 +141,24 @@ async def test_worker_sandbox_restrictions(client: Client) -> None: ).restrictions.passthrough_modules ) + class ReplayCheckPlugin(temporalio.client.Plugin, temporalio.worker.Plugin): def configure_worker(self, config: WorkerConfig) -> WorkerConfig: - config["workflows"] = list(config["workflows"]) + [HelloWorkflow] + config["workflows"] = list(config.get("workflows") or []) + [HelloWorkflow] return super().configure_worker(config) def configure_client(self, config: ClientConfig) -> ClientConfig: config["data_converter"] = pydantic_data_converter return super().configure_client(config) + @workflow.defn class HelloWorkflow: @workflow.run async def run(self, name: str) -> str: return f"Hello, {name}!" + async def test_replay(client: Client) -> None: plugin = ReplayCheckPlugin() new_config = client.config() @@ -171,28 +173,18 @@ async def test_replay(client: Client) -> None: task_queue=worker.task_queue, ) await handle.result() - replayer = Replayer( - workflows=[], - plugins=[plugin] - ) - assert len(replayer.config()["workflows"])==1 - assert replayer.config()["data_converter"] == pydantic_data_converter + replayer = Replayer(workflows=[], plugins=[plugin]) + assert len(replayer.config().get("workflows") or []) == 1 + assert replayer.config().get("data_converter") == pydantic_data_converter await replayer.replay_workflow(await handle.fetch_history()) + replayer = Replayer(workflows=[HelloWorkflow], plugins=[MyClientPlugin()]) + replayer = Replayer(workflows=[HelloWorkflow], plugins=[MyWorkerPlugin()]) replayer = Replayer( - workflows=[HelloWorkflow], - plugins=[MyClientPlugin()] - ) - replayer = Replayer( - workflows=[HelloWorkflow], - plugins=[MyWorkerPlugin()] + workflows=[HelloWorkflow], plugins=[MyClientPlugin(), MyWorkerPlugin()] ) replayer = Replayer( workflows=[HelloWorkflow], - plugins=[MyClientPlugin(), MyWorkerPlugin()] + plugins=[MyWorkerPlugin(), MyClientPlugin(), MyCombinedPlugin()], ) - replayer = Replayer( - workflows=[HelloWorkflow], - plugins=[MyWorkerPlugin(), MyClientPlugin(), MyCombinedPlugin()] - ) \ No newline at end of file From f61b40082f537d259f9f1e536a9e32485ab4ab80 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Wed, 30 Jul 2025 12:33:30 -0700 Subject: [PATCH 4/7] Move shared configuration into plugin definition --- temporalio/worker/__init__.py | 2 + temporalio/worker/_replayer.py | 142 ++++++++++++++++++++++----------- tests/test_plugins.py | 14 ++-- 3 files changed, 108 insertions(+), 50 deletions(-) diff --git a/temporalio/worker/__init__.py b/temporalio/worker/__init__.py index 6e062afcc..4c1138fd5 100644 --- a/temporalio/worker/__init__.py +++ b/temporalio/worker/__init__.py @@ -24,6 +24,7 @@ from ._replayer import ( Replayer, ReplayerConfig, + ReplayerPlugin, WorkflowReplayResult, WorkflowReplayResults, ) @@ -68,6 +69,7 @@ "WorkerDeploymentVersion", "Replayer", "ReplayerConfig", + "ReplayerPlugin", "WorkflowReplayResult", "WorkflowReplayResults", "PollerBehavior", diff --git a/temporalio/worker/_replayer.py b/temporalio/worker/_replayer.py index 0e2045e04..e4cd1b3ad 100644 --- a/temporalio/worker/_replayer.py +++ b/temporalio/worker/_replayer.py @@ -7,7 +7,7 @@ import logging from contextlib import asynccontextmanager from dataclasses import dataclass -from typing import AsyncIterator, Dict, Mapping, Optional, Sequence, Type, Union, cast +from typing import AsyncIterator, Dict, Mapping, Optional, Sequence, Type, cast from typing_extensions import TypedDict @@ -18,7 +18,6 @@ import temporalio.converter import temporalio.runtime import temporalio.workflow -from temporalio.client import ClientConfig from ..common import HeaderCodecBehavior from ._interceptor import Interceptor @@ -30,6 +29,88 @@ logger = logging.getLogger(__name__) +class ReplayerPlugin: + """Base class for replayer plugins that can modify replayer configuration.""" + + def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig: + """Configure the replayer. + + Default implementation applies shared configuration from worker and client plugins. + + Args: + config: The replayer configuration to modify. + + Returns: + The modified replayer configuration. + """ + # If this plugin is also a worker plugin, apply shared worker config + if isinstance(self, temporalio.worker.Plugin): + # Create a minimal worker config with shared fields + worker_config = cast( + WorkerConfig, + { + "workflows": config["workflows"], + "workflow_task_executor": config["workflow_task_executor"], + "workflow_runner": config["workflow_runner"], + "unsandboxed_workflow_runner": config[ + "unsandboxed_workflow_runner" + ], + "interceptors": config["interceptors"], + "build_id": config["build_id"], + "identity": config["identity"], + "workflow_failure_exception_types": config[ + "workflow_failure_exception_types" + ], + "debug_mode": config["debug_mode"], + "disable_safe_workflow_eviction": config[ + "disable_safe_workflow_eviction" + ], + }, + ) + + modified_worker_config = self.configure_worker(worker_config) + config["workflows"] = modified_worker_config["workflows"] + config["workflow_task_executor"] = modified_worker_config[ + "workflow_task_executor" + ] + config["workflow_runner"] = modified_worker_config["workflow_runner"] + config["unsandboxed_workflow_runner"] = modified_worker_config[ + "unsandboxed_workflow_runner" + ] + config["interceptors"] = modified_worker_config["interceptors"] + config["build_id"] = modified_worker_config["build_id"] + config["identity"] = modified_worker_config["identity"] + config["workflow_failure_exception_types"] = modified_worker_config[ + "workflow_failure_exception_types" + ] + config["debug_mode"] = modified_worker_config["debug_mode"] + config["disable_safe_workflow_eviction"] = modified_worker_config[ + "disable_safe_workflow_eviction" + ] + + # If this plugin is also a client plugin, apply shared client config + if isinstance(self, temporalio.client.Plugin): + # Only include fields that exist in both ReplayerConfig and ClientConfig + # Note: interceptors are different types between client and worker, so excluded + client_config = cast( + temporalio.client.ClientConfig, + { + "namespace": config["namespace"], + "data_converter": config["data_converter"], + "header_codec_behavior": config["header_codec_behavior"], + }, + ) + + modified_client_config = self.configure_client(client_config) + config["namespace"] = modified_client_config["namespace"] + config["data_converter"] = modified_client_config["data_converter"] + config["header_codec_behavior"] = modified_client_config[ + "header_codec_behavior" + ] + + return config + + class Replayer: """Replayer to replay workflows from history.""" @@ -43,9 +124,7 @@ def __init__( namespace: str = "ReplayNamespace", data_converter: temporalio.converter.DataConverter = temporalio.converter.DataConverter.default, interceptors: Sequence[Interceptor] = [], - plugins: Sequence[ - Union[temporalio.worker.Plugin, temporalio.client.Plugin] - ] = [], + plugins: Sequence[ReplayerPlugin] = [], build_id: Optional[str] = None, identity: Optional[str] = None, workflow_failure_exception_types: Sequence[Type[BaseException]] = [], @@ -85,62 +164,35 @@ def __init__( header_codec_behavior=header_codec_behavior, ) - # Allow plugins to configure shared configurations with worker - root_worker_plugin: temporalio.worker.Plugin = temporalio.worker._worker._RootPlugin() - for plugin in reversed( + # Initialize all worker plugins + root_worker_plugin: temporalio.worker.Plugin = ( + temporalio.worker._worker._RootPlugin() + ) + for worker_plugin in reversed( [ - plugin + cast(temporalio.worker.Plugin, plugin) for plugin in plugins if isinstance(plugin, temporalio.worker.Plugin) ] ): - root_worker_plugin = plugin.init_worker_plugin(root_worker_plugin) - - worker_config = cast( - WorkerConfig, - { - k: v - for k, v in self._config.items() - if k in WorkerConfig.__annotations__ - }, - ) - - worker_config = root_worker_plugin.configure_worker(worker_config) - self._config.update( - cast(ReplayerConfig, { - k: v - for k, v in worker_config.items() - if k in ReplayerConfig.__annotations__ - }) - ) + root_worker_plugin = worker_plugin.init_worker_plugin(root_worker_plugin) - # Allow plugins to configure shared configurations with client + # Initialize all client plugins root_client_plugin: temporalio.client.Plugin = temporalio.client._RootPlugin() for client_plugin in reversed( [ - plugin + cast(temporalio.client.Plugin, plugin) for plugin in plugins if isinstance(plugin, temporalio.client.Plugin) ] ): root_client_plugin = client_plugin.init_client_plugin(root_client_plugin) - client_config = cast(ClientConfig, - { - k: v - for k, v in self._config.items() - if k in ClientConfig.__annotations__ - } - ) - client_config = root_client_plugin.configure_client(client_config) - self._config.update( - cast(ReplayerConfig, { - k: v - for k, v in client_config.items() - if k in ReplayerConfig.__annotations__ - }) - ) + # Apply plugin configuration + for plugin in plugins: + self._config = plugin.configure_replayer(self._config) + # Validate workflows after plugin configuration if not self._config["workflows"]: raise ValueError("At least one workflow must be specified") diff --git a/tests/test_plugins.py b/tests/test_plugins.py index a4ccd9244..49ddc23fa 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -11,7 +11,7 @@ from temporalio.client import Client, ClientConfig, OutboundInterceptor from temporalio.contrib.pydantic import pydantic_data_converter from temporalio.testing import WorkflowEnvironment -from temporalio.worker import Replayer, Worker, WorkerConfig +from temporalio.worker import Replayer, ReplayerConfig, Worker, WorkerConfig from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner from tests.helpers import new_worker from tests.worker.test_worker import never_run_activity @@ -26,7 +26,7 @@ def intercept_client(self, next: OutboundInterceptor) -> OutboundInterceptor: return super().intercept_client(next) -class MyClientPlugin(temporalio.client.Plugin): +class MyClientPlugin(temporalio.worker.ReplayerPlugin, temporalio.client.Plugin): def __init__(self): self.interceptor = TestClientInterceptor() @@ -62,13 +62,15 @@ async def test_client_plugin(client: Client, env: WorkflowEnvironment): assert new_client.service_client.config.api_key == "replaced key" -class MyCombinedPlugin(temporalio.client.Plugin, temporalio.worker.Plugin): +class MyCombinedPlugin( + temporalio.worker.ReplayerPlugin, temporalio.client.Plugin, temporalio.worker.Plugin +): def configure_worker(self, config: WorkerConfig) -> WorkerConfig: config["task_queue"] = "combined" return super().configure_worker(config) -class MyWorkerPlugin(temporalio.worker.Plugin): +class MyWorkerPlugin(temporalio.worker.ReplayerPlugin, temporalio.worker.Plugin): def configure_worker(self, config: WorkerConfig) -> WorkerConfig: config["task_queue"] = "replaced_queue" runner = config.get("workflow_runner") @@ -142,7 +144,9 @@ async def test_worker_sandbox_restrictions(client: Client) -> None: ) -class ReplayCheckPlugin(temporalio.client.Plugin, temporalio.worker.Plugin): +class ReplayCheckPlugin( + temporalio.worker.ReplayerPlugin, temporalio.client.Plugin, temporalio.worker.Plugin +): def configure_worker(self, config: WorkerConfig) -> WorkerConfig: config["workflows"] = list(config.get("workflows") or []) + [HelloWorkflow] return super().configure_worker(config) From 1007d02c5a3a623c41c632bba7b98fd69c7d4f84 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Fri, 1 Aug 2025 16:59:39 -0700 Subject: [PATCH 5/7] Moving replay configuration to worker plugin, add replay execution hook --- .../openai_agents/_temporal_openai_agents.py | 30 +++- temporalio/worker/__init__.py | 2 - temporalio/worker/_replayer.py | 167 ++++-------------- temporalio/worker/_worker.py | 29 +++ .../worker/workflow_sandbox/_importer.py | 6 + .../openai_agents/test_openai_replay.py | 41 ++--- tests/test_plugins.py | 31 ++-- 7 files changed, 131 insertions(+), 175 deletions(-) diff --git a/temporalio/contrib/openai_agents/_temporal_openai_agents.py b/temporalio/contrib/openai_agents/_temporal_openai_agents.py index a1f71db71..6a90e5113 100644 --- a/temporalio/contrib/openai_agents/_temporal_openai_agents.py +++ b/temporalio/contrib/openai_agents/_temporal_openai_agents.py @@ -1,6 +1,6 @@ """Initialize Temporal OpenAI Agents overrides.""" -from contextlib import contextmanager +from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager from datetime import timedelta from typing import AsyncIterator, Callable, Optional, Union @@ -41,7 +41,13 @@ from temporalio.converter import ( DataConverter, ) -from temporalio.worker import Worker, WorkerConfig +from temporalio.worker import ( + Replayer, + ReplayerConfig, + Worker, + WorkerConfig, + WorkflowReplayResult, +) @contextmanager @@ -282,3 +288,23 @@ async def run_worker(self, worker: Worker) -> None: """ with set_open_ai_agent_temporal_overrides(self._model_params): await super().run_worker(worker) + + def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig: + """Configure the replayer for OpenAI Agents.""" + config["interceptors"] = list(config.get("interceptors") or []) + [ + OpenAIAgentsTracingInterceptor() + ] + config["data_converter"] = DataConverter( + payload_converter_class=_OpenAIPayloadConverter + ) + return config + + @asynccontextmanager + async def workflow_replay( + self, + replayer: Replayer, + histories: AsyncIterator[temporalio.client.WorkflowHistory], + ) -> AsyncIterator[AsyncIterator[WorkflowReplayResult]]: + with set_open_ai_agent_temporal_overrides(self._model_params): + async with super().workflow_replay(replayer, histories) as results: + yield results diff --git a/temporalio/worker/__init__.py b/temporalio/worker/__init__.py index 4c1138fd5..6e062afcc 100644 --- a/temporalio/worker/__init__.py +++ b/temporalio/worker/__init__.py @@ -24,7 +24,6 @@ from ._replayer import ( Replayer, ReplayerConfig, - ReplayerPlugin, WorkflowReplayResult, WorkflowReplayResults, ) @@ -69,7 +68,6 @@ "WorkerDeploymentVersion", "Replayer", "ReplayerConfig", - "ReplayerPlugin", "WorkflowReplayResult", "WorkflowReplayResults", "PollerBehavior", diff --git a/temporalio/worker/_replayer.py b/temporalio/worker/_replayer.py index e4cd1b3ad..c30b5fb16 100644 --- a/temporalio/worker/_replayer.py +++ b/temporalio/worker/_replayer.py @@ -5,9 +5,10 @@ import asyncio import concurrent.futures import logging -from contextlib import asynccontextmanager +import typing +from contextlib import AbstractAsyncContextManager, asynccontextmanager from dataclasses import dataclass -from typing import AsyncIterator, Dict, Mapping, Optional, Sequence, Type, cast +from typing import AsyncIterator, Dict, Mapping, Optional, Sequence, Type from typing_extensions import TypedDict @@ -21,7 +22,6 @@ from ..common import HeaderCodecBehavior from ._interceptor import Interceptor -from ._worker import WorkerConfig, load_default_build_id from ._workflow import _WorkflowWorker from ._workflow_instance import UnsandboxedWorkflowRunner, WorkflowRunner from .workflow_sandbox import SandboxedWorkflowRunner @@ -29,86 +29,23 @@ logger = logging.getLogger(__name__) -class ReplayerPlugin: - """Base class for replayer plugins that can modify replayer configuration.""" - - def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig: - """Configure the replayer. - - Default implementation applies shared configuration from worker and client plugins. - - Args: - config: The replayer configuration to modify. - - Returns: - The modified replayer configuration. - """ - # If this plugin is also a worker plugin, apply shared worker config - if isinstance(self, temporalio.worker.Plugin): - # Create a minimal worker config with shared fields - worker_config = cast( - WorkerConfig, - { - "workflows": config["workflows"], - "workflow_task_executor": config["workflow_task_executor"], - "workflow_runner": config["workflow_runner"], - "unsandboxed_workflow_runner": config[ - "unsandboxed_workflow_runner" - ], - "interceptors": config["interceptors"], - "build_id": config["build_id"], - "identity": config["identity"], - "workflow_failure_exception_types": config[ - "workflow_failure_exception_types" - ], - "debug_mode": config["debug_mode"], - "disable_safe_workflow_eviction": config[ - "disable_safe_workflow_eviction" - ], - }, - ) - - modified_worker_config = self.configure_worker(worker_config) - config["workflows"] = modified_worker_config["workflows"] - config["workflow_task_executor"] = modified_worker_config[ - "workflow_task_executor" - ] - config["workflow_runner"] = modified_worker_config["workflow_runner"] - config["unsandboxed_workflow_runner"] = modified_worker_config[ - "unsandboxed_workflow_runner" - ] - config["interceptors"] = modified_worker_config["interceptors"] - config["build_id"] = modified_worker_config["build_id"] - config["identity"] = modified_worker_config["identity"] - config["workflow_failure_exception_types"] = modified_worker_config[ - "workflow_failure_exception_types" - ] - config["debug_mode"] = modified_worker_config["debug_mode"] - config["disable_safe_workflow_eviction"] = modified_worker_config[ - "disable_safe_workflow_eviction" - ] - - # If this plugin is also a client plugin, apply shared client config - if isinstance(self, temporalio.client.Plugin): - # Only include fields that exist in both ReplayerConfig and ClientConfig - # Note: interceptors are different types between client and worker, so excluded - client_config = cast( - temporalio.client.ClientConfig, - { - "namespace": config["namespace"], - "data_converter": config["data_converter"], - "header_codec_behavior": config["header_codec_behavior"], - }, - ) - - modified_client_config = self.configure_client(client_config) - config["namespace"] = modified_client_config["namespace"] - config["data_converter"] = modified_client_config["data_converter"] - config["header_codec_behavior"] = modified_client_config[ - "header_codec_behavior" - ] +class ReplayerConfig(TypedDict, total=False): + """TypedDict of config originally passed to :py:class:`Replayer`.""" - return config + workflows: Sequence[Type] + workflow_task_executor: Optional[concurrent.futures.ThreadPoolExecutor] + workflow_runner: WorkflowRunner + unsandboxed_workflow_runner: WorkflowRunner + namespace: str + data_converter: temporalio.converter.DataConverter + interceptors: Sequence[Interceptor] + build_id: Optional[str] + identity: Optional[str] + workflow_failure_exception_types: Sequence[Type[BaseException]] + debug_mode: bool + runtime: Optional[temporalio.runtime.Runtime] + disable_safe_workflow_eviction: bool + header_codec_behavior: HeaderCodecBehavior class Replayer: @@ -124,7 +61,7 @@ def __init__( namespace: str = "ReplayNamespace", data_converter: temporalio.converter.DataConverter = temporalio.converter.DataConverter.default, interceptors: Sequence[Interceptor] = [], - plugins: Sequence[ReplayerPlugin] = [], + plugins: Sequence[temporalio.worker.Plugin] = [], build_id: Optional[str] = None, identity: Optional[str] = None, workflow_failure_exception_types: Sequence[Type[BaseException]] = [], @@ -164,32 +101,16 @@ def __init__( header_codec_behavior=header_codec_behavior, ) - # Initialize all worker plugins - root_worker_plugin: temporalio.worker.Plugin = ( - temporalio.worker._worker._RootPlugin() - ) - for worker_plugin in reversed( - [ - cast(temporalio.worker.Plugin, plugin) - for plugin in plugins - if isinstance(plugin, temporalio.worker.Plugin) - ] - ): - root_worker_plugin = worker_plugin.init_worker_plugin(root_worker_plugin) - - # Initialize all client plugins - root_client_plugin: temporalio.client.Plugin = temporalio.client._RootPlugin() - for client_plugin in reversed( - [ - cast(temporalio.client.Plugin, plugin) - for plugin in plugins - if isinstance(plugin, temporalio.client.Plugin) - ] - ): - root_client_plugin = client_plugin.init_client_plugin(root_client_plugin) + from ._worker import _RootPlugin + + root_plugin: temporalio.worker.Plugin = _RootPlugin() + for plugin in reversed(plugins): + root_plugin = plugin.init_worker_plugin(root_plugin) + self._config = root_plugin.configure_replayer(self._config) + self._plugin = root_plugin # Apply plugin configuration - for plugin in plugins: + for plugin in reversed(plugins): self._config = plugin.configure_replayer(self._config) # Validate workflows after plugin configuration @@ -262,10 +183,9 @@ async def replay_workflows( replay_failures[result.history.run_id] = result.replay_failure return WorkflowReplayResults(replay_failures=replay_failures) - @asynccontextmanager - async def workflow_replay_iterator( + def workflow_replay_iterator( self, histories: AsyncIterator[temporalio.client.WorkflowHistory] - ) -> AsyncIterator[AsyncIterator[WorkflowReplayResult]]: + ) -> AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]]: """Replay workflows for the given histories. This is a context manager for use via ``async with``. The value is an @@ -278,6 +198,12 @@ async def workflow_replay_iterator( An async iterator that returns replayed workflow results as they are replayed. """ + return self._plugin.workflow_replay(self, histories) + + @asynccontextmanager + async def _workflow_replay_iterator( + self, histories: AsyncIterator[temporalio.client.WorkflowHistory] + ) -> AsyncIterator[AsyncIterator[WorkflowReplayResult]]: try: last_replay_failure: Optional[Exception] last_replay_complete = asyncio.Event() @@ -337,6 +263,8 @@ def on_eviction_hook( != HeaderCodecBehavior.NO_CODEC, ) # Create bridge worker + from ._worker import load_default_build_id + bridge_worker, pusher = temporalio.bridge.worker.Worker.for_replay( runtime._core_runtime, temporalio.bridge.worker.WorkerConfig( @@ -440,25 +368,6 @@ async def replay_iterator() -> AsyncIterator[WorkflowReplayResult]: logger.warning("Failed to finalize shutdown", exc_info=True) -class ReplayerConfig(TypedDict, total=False): - """TypedDict of config originally passed to :py:class:`Replayer`.""" - - workflows: Sequence[Type] - workflow_task_executor: Optional[concurrent.futures.ThreadPoolExecutor] - workflow_runner: WorkflowRunner - unsandboxed_workflow_runner: WorkflowRunner - namespace: str - data_converter: temporalio.converter.DataConverter - interceptors: Sequence[Interceptor] - build_id: Optional[str] - identity: Optional[str] - workflow_failure_exception_types: Sequence[Type[BaseException]] - debug_mode: bool - runtime: Optional[temporalio.runtime.Runtime] - disable_safe_workflow_eviction: bool - header_codec_behavior: HeaderCodecBehavior - - @dataclass(frozen=True) class WorkflowReplayResult: """Single workflow replay result.""" diff --git a/temporalio/worker/_worker.py b/temporalio/worker/_worker.py index 58f881c04..f96f4199b 100644 --- a/temporalio/worker/_worker.py +++ b/temporalio/worker/_worker.py @@ -9,10 +9,12 @@ import logging import sys import warnings +from contextlib import AbstractAsyncContextManager, asynccontextmanager from dataclasses import dataclass from datetime import timedelta from typing import ( Any, + AsyncIterator, Awaitable, Callable, List, @@ -36,6 +38,7 @@ WorkerDeploymentVersion, ) +from . import Replayer, ReplayerConfig, WorkflowReplayResult from ._activity import SharedStateManager, _ActivityWorker from ._interceptor import Interceptor from ._nexus import _NexusWorker @@ -149,6 +152,25 @@ async def run_worker(self, worker: Worker) -> None: """ await self.next_worker_plugin.run_worker(worker) + def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig: + """Hook called when creating a replayer to allow modification of configuration. + + This should be used to configure anything in ReplayerConfig needed to make execution match + the original. This could include interceptors, DataConverter, workflows, and more. + + Uniquely does not rely on a chain, and is instead called sequentially on the plugins + because the replayer cannot instantiate the worker/client component. + """ + return config + + def workflow_replay( + self, + replayer: Replayer, + histories: AsyncIterator[temporalio.client.WorkflowHistory], + ) -> AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]]: + """Hook called when running a replayer to allow interception of execution.""" + return self.next_worker_plugin.workflow_replay(replayer, histories) + class _RootPlugin(Plugin): def configure_worker(self, config: WorkerConfig) -> WorkerConfig: @@ -157,6 +179,13 @@ def configure_worker(self, config: WorkerConfig) -> WorkerConfig: async def run_worker(self, worker: Worker) -> None: await worker._run() + def workflow_replay( + self, + replayer: Replayer, + histories: AsyncIterator[temporalio.client.WorkflowHistory], + ) -> AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]]: + return replayer._workflow_replay_iterator(histories) + class Worker: """Worker to process workflows and/or activities. diff --git a/temporalio/worker/workflow_sandbox/_importer.py b/temporalio/worker/workflow_sandbox/_importer.py index 462bd44c2..944f2929a 100644 --- a/temporalio/worker/workflow_sandbox/_importer.py +++ b/temporalio/worker/workflow_sandbox/_importer.py @@ -285,6 +285,12 @@ def module_configured_passthrough(self, name: str) -> bool: def _maybe_passthrough_module(self, name: str) -> Optional[types.ModuleType]: # If imports not passed through and all modules are not passed through # and name not in passthrough modules, check parents + logger.debug( + "Check passthrough module: %s - %s", + name, + temporalio.workflow.unsafe.is_imports_passed_through() + or self.module_configured_passthrough(name), + ) if ( not temporalio.workflow.unsafe.is_imports_passed_through() and not self.module_configured_passthrough(name) diff --git a/tests/contrib/openai_agents/test_openai_replay.py b/tests/contrib/openai_agents/test_openai_replay.py index d3ac92c5e..c6ac1ea68 100644 --- a/tests/contrib/openai_agents/test_openai_replay.py +++ b/tests/contrib/openai_agents/test_openai_replay.py @@ -1,16 +1,15 @@ +from contextlib import ( + AbstractAsyncContextManager, + AbstractContextManager, + asynccontextmanager, +) from pathlib import Path +from typing import AsyncGenerator import pytest from temporalio.client import WorkflowHistory -from temporalio.contrib.openai_agents import ModelActivityParameters -from temporalio.contrib.openai_agents._temporal_openai_agents import ( - set_open_ai_agent_temporal_overrides, -) -from temporalio.contrib.openai_agents._trace_interceptor import ( - OpenAIAgentsTracingInterceptor, -) -from temporalio.contrib.pydantic import pydantic_data_converter +from temporalio.contrib.openai_agents import ModelActivityParameters, OpenAIAgentsPlugin from temporalio.worker import Replayer from tests.contrib.openai_agents.test_openai import ( AgentsAsToolsWorkflow, @@ -39,17 +38,15 @@ async def test_replay(file_name: str) -> None: with (Path(__file__).with_name("histories") / file_name).open("r") as f: history_json = f.read() - with set_open_ai_agent_temporal_overrides(ModelActivityParameters()): - await Replayer( - workflows=[ - ResearchWorkflow, - ToolsWorkflow, - CustomerServiceWorkflow, - AgentsAsToolsWorkflow, - HelloWorldAgent, - InputGuardrailWorkflow, - OutputGuardrailWorkflow, - ], - data_converter=pydantic_data_converter, - interceptors=[OpenAIAgentsTracingInterceptor()], - ).replay_workflow(WorkflowHistory.from_json("fake", history_json)) + await Replayer( + workflows=[ + ResearchWorkflow, + ToolsWorkflow, + CustomerServiceWorkflow, + AgentsAsToolsWorkflow, + HelloWorldAgent, + InputGuardrailWorkflow, + OutputGuardrailWorkflow, + ], + plugins=[OpenAIAgentsPlugin()], + ).replay_workflow(WorkflowHistory.from_json("fake", history_json)) diff --git a/tests/test_plugins.py b/tests/test_plugins.py index 49ddc23fa..ced9ba668 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -26,7 +26,7 @@ def intercept_client(self, next: OutboundInterceptor) -> OutboundInterceptor: return super().intercept_client(next) -class MyClientPlugin(temporalio.worker.ReplayerPlugin, temporalio.client.Plugin): +class MyClientPlugin(temporalio.client.Plugin): def __init__(self): self.interceptor = TestClientInterceptor() @@ -62,15 +62,13 @@ async def test_client_plugin(client: Client, env: WorkflowEnvironment): assert new_client.service_client.config.api_key == "replaced key" -class MyCombinedPlugin( - temporalio.worker.ReplayerPlugin, temporalio.client.Plugin, temporalio.worker.Plugin -): +class MyCombinedPlugin(temporalio.client.Plugin, temporalio.worker.Plugin): def configure_worker(self, config: WorkerConfig) -> WorkerConfig: config["task_queue"] = "combined" return super().configure_worker(config) -class MyWorkerPlugin(temporalio.worker.ReplayerPlugin, temporalio.worker.Plugin): +class MyWorkerPlugin(temporalio.worker.Plugin): def configure_worker(self, config: WorkerConfig) -> WorkerConfig: config["task_queue"] = "replaced_queue" runner = config.get("workflow_runner") @@ -144,16 +142,19 @@ async def test_worker_sandbox_restrictions(client: Client) -> None: ) -class ReplayCheckPlugin( - temporalio.worker.ReplayerPlugin, temporalio.client.Plugin, temporalio.worker.Plugin -): +class ReplayCheckPlugin(temporalio.client.Plugin, temporalio.worker.Plugin): + def configure_client(self, config: ClientConfig) -> ClientConfig: + config["data_converter"] = pydantic_data_converter + return super().configure_client(config) + def configure_worker(self, config: WorkerConfig) -> WorkerConfig: config["workflows"] = list(config.get("workflows") or []) + [HelloWorkflow] return super().configure_worker(config) - def configure_client(self, config: ClientConfig) -> ClientConfig: + def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig: config["data_converter"] = pydantic_data_converter - return super().configure_client(config) + config["workflows"] = list(config.get("workflows") or []) + [HelloWorkflow] + return config @workflow.defn @@ -182,13 +183,3 @@ async def test_replay(client: Client) -> None: assert replayer.config().get("data_converter") == pydantic_data_converter await replayer.replay_workflow(await handle.fetch_history()) - - replayer = Replayer(workflows=[HelloWorkflow], plugins=[MyClientPlugin()]) - replayer = Replayer(workflows=[HelloWorkflow], plugins=[MyWorkerPlugin()]) - replayer = Replayer( - workflows=[HelloWorkflow], plugins=[MyClientPlugin(), MyWorkerPlugin()] - ) - replayer = Replayer( - workflows=[HelloWorkflow], - plugins=[MyWorkerPlugin(), MyClientPlugin(), MyCombinedPlugin()], - ) From cbc5fa34411cc05d8aea9538da3594b822234c52 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Fri, 1 Aug 2025 17:31:02 -0700 Subject: [PATCH 6/7] POC for moving plugin run_worker to a context --- .../openai_agents/_temporal_openai_agents.py | 6 ++- temporalio/worker/_worker.py | 15 ++++--- .../worker/workflow_sandbox/_importer.py | 6 --- tests/test_plugins.py | 42 +++++++++++++++++-- 4 files changed, 50 insertions(+), 19 deletions(-) diff --git a/temporalio/contrib/openai_agents/_temporal_openai_agents.py b/temporalio/contrib/openai_agents/_temporal_openai_agents.py index 6a90e5113..0da4ccb32 100644 --- a/temporalio/contrib/openai_agents/_temporal_openai_agents.py +++ b/temporalio/contrib/openai_agents/_temporal_openai_agents.py @@ -276,7 +276,8 @@ def configure_worker(self, config: WorkerConfig) -> WorkerConfig: ] return super().configure_worker(config) - async def run_worker(self, worker: Worker) -> None: + @asynccontextmanager + async def run_worker(self) -> AsyncIterator[None]: """Run the worker with OpenAI agents temporal overrides. This method sets up the necessary runtime overrides for OpenAI agents @@ -287,7 +288,8 @@ async def run_worker(self, worker: Worker) -> None: worker: The worker instance to run. """ with set_open_ai_agent_temporal_overrides(self._model_params): - await super().run_worker(worker) + async with super().run_worker(): + yield def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig: """Configure the replayer for OpenAI Agents.""" diff --git a/temporalio/worker/_worker.py b/temporalio/worker/_worker.py index f96f4199b..76a0cf260 100644 --- a/temporalio/worker/_worker.py +++ b/temporalio/worker/_worker.py @@ -140,17 +140,14 @@ def configure_worker(self, config: WorkerConfig) -> WorkerConfig: """ return self.next_worker_plugin.configure_worker(config) - async def run_worker(self, worker: Worker) -> None: + def run_worker(self) -> AbstractAsyncContextManager[None]: """Hook called when running a worker to allow interception of execution. This method is called when the worker is started and allows plugins to intercept or wrap the worker execution. Plugins can add monitoring, custom lifecycle management, or other execution-time behavior. - - Args: - worker: The worker instance to run. """ - await self.next_worker_plugin.run_worker(worker) + return self.next_worker_plugin.run_worker() def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig: """Hook called when creating a replayer to allow modification of configuration. @@ -176,8 +173,9 @@ class _RootPlugin(Plugin): def configure_worker(self, config: WorkerConfig) -> WorkerConfig: return config - async def run_worker(self, worker: Worker) -> None: - await worker._run() + @asynccontextmanager + async def run_worker(self) -> AsyncIterator[None]: + yield def workflow_replay( self, @@ -794,7 +792,8 @@ async def run(self) -> None: also cancel the shutdown process. Therefore users are encouraged to use explicit shutdown instead. """ - await self._plugin.run_worker(self) + async with self._plugin.run_worker(): + await self._run() async def _run(self): # Eagerly validate which will do a namespace check in Core diff --git a/temporalio/worker/workflow_sandbox/_importer.py b/temporalio/worker/workflow_sandbox/_importer.py index 944f2929a..462bd44c2 100644 --- a/temporalio/worker/workflow_sandbox/_importer.py +++ b/temporalio/worker/workflow_sandbox/_importer.py @@ -285,12 +285,6 @@ def module_configured_passthrough(self, name: str) -> bool: def _maybe_passthrough_module(self, name: str) -> Optional[types.ModuleType]: # If imports not passed through and all modules are not passed through # and name not in passthrough modules, check parents - logger.debug( - "Check passthrough module: %s - %s", - name, - temporalio.workflow.unsafe.is_imports_passed_through() - or self.module_configured_passthrough(name), - ) if ( not temporalio.workflow.unsafe.is_imports_passed_through() and not self.module_configured_passthrough(name) diff --git a/tests/test_plugins.py b/tests/test_plugins.py index ced9ba668..868ff4115 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -1,7 +1,9 @@ import dataclasses import uuid import warnings -from typing import cast +from contextlib import asynccontextmanager +from datetime import timedelta +from typing import AsyncIterator, cast import pytest @@ -68,6 +70,9 @@ def configure_worker(self, config: WorkerConfig) -> WorkerConfig: return super().configure_worker(config) +IN_CONTEXT: bool = False + + class MyWorkerPlugin(temporalio.worker.Plugin): def configure_worker(self, config: WorkerConfig) -> WorkerConfig: config["task_queue"] = "replaced_queue" @@ -79,8 +84,15 @@ def configure_worker(self, config: WorkerConfig) -> WorkerConfig: ) return super().configure_worker(config) - async def run_worker(self, worker: Worker) -> None: - await super().run_worker(worker) + @asynccontextmanager + async def run_worker(self) -> AsyncIterator[None]: + global IN_CONTEXT + try: + IN_CONTEXT = True + async with super().run_worker(): + yield + finally: + IN_CONTEXT = False async def test_worker_plugin_basic_config(client: Client) -> None: @@ -109,6 +121,30 @@ async def test_worker_plugin_basic_config(client: Client) -> None: assert worker.config().get("task_queue") == "replaced_queue" +@workflow.defn(sandboxed=False) +class CheckContextWorkflow: + @workflow.run + async def run(self) -> bool: + return IN_CONTEXT + + +async def test_worker_plugin_run_context(client: Client) -> None: + async with Worker( + client, + task_queue=str(uuid.uuid4()), + workflows=[CheckContextWorkflow], + activities=[never_run_activity], + plugins=[MyWorkerPlugin()], + ) as worker: + result = await client.execute_workflow( + CheckContextWorkflow.run, + task_queue=worker.task_queue, + id=f"workflow-{uuid.uuid4()}", + execution_timeout=timedelta(seconds=1), + ) + assert result + + async def test_worker_duplicated_plugin(client: Client) -> None: new_config = client.config() new_config["plugins"] = [MyCombinedPlugin()] From 2cf213c1088e7917f173b3b31a9597eb548a055d Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Fri, 1 Aug 2025 17:35:44 -0700 Subject: [PATCH 7/7] Some cleanup --- .../openai_agents/_temporal_openai_agents.py | 10 - temporalio/worker/_replayer.py | 356 +++++++++--------- temporalio/worker/_worker.py | 15 - 3 files changed, 179 insertions(+), 202 deletions(-) diff --git a/temporalio/contrib/openai_agents/_temporal_openai_agents.py b/temporalio/contrib/openai_agents/_temporal_openai_agents.py index 0da4ccb32..2375528c6 100644 --- a/temporalio/contrib/openai_agents/_temporal_openai_agents.py +++ b/temporalio/contrib/openai_agents/_temporal_openai_agents.py @@ -300,13 +300,3 @@ def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig: payload_converter_class=_OpenAIPayloadConverter ) return config - - @asynccontextmanager - async def workflow_replay( - self, - replayer: Replayer, - histories: AsyncIterator[temporalio.client.WorkflowHistory], - ) -> AsyncIterator[AsyncIterator[WorkflowReplayResult]]: - with set_open_ai_agent_temporal_overrides(self._model_params): - async with super().workflow_replay(replayer, histories) as results: - yield results diff --git a/temporalio/worker/_replayer.py b/temporalio/worker/_replayer.py index c30b5fb16..9bffa4164 100644 --- a/temporalio/worker/_replayer.py +++ b/temporalio/worker/_replayer.py @@ -29,25 +29,6 @@ logger = logging.getLogger(__name__) -class ReplayerConfig(TypedDict, total=False): - """TypedDict of config originally passed to :py:class:`Replayer`.""" - - workflows: Sequence[Type] - workflow_task_executor: Optional[concurrent.futures.ThreadPoolExecutor] - workflow_runner: WorkflowRunner - unsandboxed_workflow_runner: WorkflowRunner - namespace: str - data_converter: temporalio.converter.DataConverter - interceptors: Sequence[Interceptor] - build_id: Optional[str] - identity: Optional[str] - workflow_failure_exception_types: Sequence[Type[BaseException]] - debug_mode: bool - runtime: Optional[temporalio.runtime.Runtime] - disable_safe_workflow_eviction: bool - header_codec_behavior: HeaderCodecBehavior - - class Replayer: """Replayer to replay workflows from history.""" @@ -183,9 +164,10 @@ async def replay_workflows( replay_failures[result.history.run_id] = result.replay_failure return WorkflowReplayResults(replay_failures=replay_failures) - def workflow_replay_iterator( + @asynccontextmanager + async def workflow_replay_iterator( self, histories: AsyncIterator[temporalio.client.WorkflowHistory] - ) -> AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]]: + ) -> AsyncIterator[AsyncIterator[WorkflowReplayResult]]: """Replay workflows for the given histories. This is a context manager for use via ``async with``. The value is an @@ -198,174 +180,194 @@ def workflow_replay_iterator( An async iterator that returns replayed workflow results as they are replayed. """ - return self._plugin.workflow_replay(self, histories) - - @asynccontextmanager - async def _workflow_replay_iterator( - self, histories: AsyncIterator[temporalio.client.WorkflowHistory] - ) -> AsyncIterator[AsyncIterator[WorkflowReplayResult]]: - try: - last_replay_failure: Optional[Exception] - last_replay_complete = asyncio.Event() - - # Create eviction hook - def on_eviction_hook( - run_id: str, - remove_job: temporalio.bridge.proto.workflow_activation.RemoveFromCache, - ) -> None: - nonlocal last_replay_failure - if ( - remove_job.reason - == temporalio.bridge.proto.workflow_activation.RemoveFromCache.EvictionReason.NONDETERMINISM - ): - last_replay_failure = temporalio.workflow.NondeterminismError( - remove_job.message - ) - elif ( - remove_job.reason - != temporalio.bridge.proto.workflow_activation.RemoveFromCache.EvictionReason.CACHE_FULL - and remove_job.reason - != temporalio.bridge.proto.workflow_activation.RemoveFromCache.EvictionReason.LANG_REQUESTED - ): - last_replay_failure = RuntimeError( - f"{remove_job.reason}: {remove_job.message}" - ) - else: - last_replay_failure = None - last_replay_complete.set() - - # Create worker referencing bridge worker - bridge_worker: temporalio.bridge.worker.Worker - task_queue = f"replay-{self._config['build_id']}" - runtime = self._config["runtime"] or temporalio.runtime.Runtime.default() - workflow_worker = _WorkflowWorker( - bridge_worker=lambda: bridge_worker, - namespace=self._config["namespace"], - task_queue=task_queue, - workflows=self._config["workflows"], - workflow_task_executor=self._config["workflow_task_executor"], - max_concurrent_workflow_tasks=5, - workflow_runner=self._config["workflow_runner"], - unsandboxed_workflow_runner=self._config["unsandboxed_workflow_runner"], - data_converter=self._config["data_converter"], - interceptors=self._config["interceptors"], - workflow_failure_exception_types=self._config[ - "workflow_failure_exception_types" - ], - debug_mode=self._config["debug_mode"], - metric_meter=runtime.metric_meter, - on_eviction_hook=on_eviction_hook, - disable_eager_activity_execution=False, - disable_safe_eviction=self._config["disable_safe_workflow_eviction"], - should_enforce_versioning_behavior=False, - assert_local_activity_valid=lambda a: None, - encode_headers=self._config["header_codec_behavior"] - != HeaderCodecBehavior.NO_CODEC, - ) - # Create bridge worker - from ._worker import load_default_build_id - - bridge_worker, pusher = temporalio.bridge.worker.Worker.for_replay( - runtime._core_runtime, - temporalio.bridge.worker.WorkerConfig( + async with self._plugin.run_worker(): + try: + last_replay_failure: Optional[Exception] + last_replay_complete = asyncio.Event() + + # Create eviction hook + def on_eviction_hook( + run_id: str, + remove_job: temporalio.bridge.proto.workflow_activation.RemoveFromCache, + ) -> None: + nonlocal last_replay_failure + if ( + remove_job.reason + == temporalio.bridge.proto.workflow_activation.RemoveFromCache.EvictionReason.NONDETERMINISM + ): + last_replay_failure = temporalio.workflow.NondeterminismError( + remove_job.message + ) + elif ( + remove_job.reason + != temporalio.bridge.proto.workflow_activation.RemoveFromCache.EvictionReason.CACHE_FULL + and remove_job.reason + != temporalio.bridge.proto.workflow_activation.RemoveFromCache.EvictionReason.LANG_REQUESTED + ): + last_replay_failure = RuntimeError( + f"{remove_job.reason}: {remove_job.message}" + ) + else: + last_replay_failure = None + last_replay_complete.set() + + # Create worker referencing bridge worker + bridge_worker: temporalio.bridge.worker.Worker + task_queue = f"replay-{self._config['build_id']}" + runtime = ( + self._config["runtime"] or temporalio.runtime.Runtime.default() + ) + workflow_worker = _WorkflowWorker( + bridge_worker=lambda: bridge_worker, namespace=self._config["namespace"], task_queue=task_queue, - identity_override=self._config["identity"], - # Need to tell core whether we want to consider all - # non-determinism exceptions as workflow fail, and whether we do - # per workflow type - nondeterminism_as_workflow_fail=workflow_worker.nondeterminism_as_workflow_fail(), - nondeterminism_as_workflow_fail_for_types=workflow_worker.nondeterminism_as_workflow_fail_for_types(), - # All values below are ignored but required by Core - max_cached_workflows=2, - tuner=temporalio.bridge.worker.TunerHolder( - workflow_slot_supplier=temporalio.bridge.worker.FixedSizeSlotSupplier( + workflows=self._config["workflows"], + workflow_task_executor=self._config["workflow_task_executor"], + max_concurrent_workflow_tasks=5, + workflow_runner=self._config["workflow_runner"], + unsandboxed_workflow_runner=self._config[ + "unsandboxed_workflow_runner" + ], + data_converter=self._config["data_converter"], + interceptors=self._config["interceptors"], + workflow_failure_exception_types=self._config[ + "workflow_failure_exception_types" + ], + debug_mode=self._config["debug_mode"], + metric_meter=runtime.metric_meter, + on_eviction_hook=on_eviction_hook, + disable_eager_activity_execution=False, + disable_safe_eviction=self._config[ + "disable_safe_workflow_eviction" + ], + should_enforce_versioning_behavior=False, + assert_local_activity_valid=lambda a: None, + encode_headers=self._config["header_codec_behavior"] + != HeaderCodecBehavior.NO_CODEC, + ) + # Create bridge worker + from ._worker import load_default_build_id + + bridge_worker, pusher = temporalio.bridge.worker.Worker.for_replay( + runtime._core_runtime, + temporalio.bridge.worker.WorkerConfig( + namespace=self._config["namespace"], + task_queue=task_queue, + identity_override=self._config["identity"], + # Need to tell core whether we want to consider all + # non-determinism exceptions as workflow fail, and whether we do + # per workflow type + nondeterminism_as_workflow_fail=workflow_worker.nondeterminism_as_workflow_fail(), + nondeterminism_as_workflow_fail_for_types=workflow_worker.nondeterminism_as_workflow_fail_for_types(), + # All values below are ignored but required by Core + max_cached_workflows=2, + tuner=temporalio.bridge.worker.TunerHolder( + workflow_slot_supplier=temporalio.bridge.worker.FixedSizeSlotSupplier( + 2 + ), + activity_slot_supplier=temporalio.bridge.worker.FixedSizeSlotSupplier( + 1 + ), + local_activity_slot_supplier=temporalio.bridge.worker.FixedSizeSlotSupplier( + 1 + ), + ), + nonsticky_to_sticky_poll_ratio=1, + no_remote_activities=True, + sticky_queue_schedule_to_start_timeout_millis=1000, + max_heartbeat_throttle_interval_millis=1000, + default_heartbeat_throttle_interval_millis=1000, + max_activities_per_second=None, + max_task_queue_activities_per_second=None, + graceful_shutdown_period_millis=0, + versioning_strategy=temporalio.bridge.worker.WorkerVersioningStrategyNone( + build_id_no_versioning=self._config["build_id"] + or load_default_build_id(), + ), + workflow_task_poller_behavior=temporalio.bridge.worker.PollerBehaviorSimpleMaximum( 2 ), - activity_slot_supplier=temporalio.bridge.worker.FixedSizeSlotSupplier( + activity_task_poller_behavior=temporalio.bridge.worker.PollerBehaviorSimpleMaximum( 1 ), - local_activity_slot_supplier=temporalio.bridge.worker.FixedSizeSlotSupplier( + nexus_task_poller_behavior=temporalio.bridge.worker.PollerBehaviorSimpleMaximum( 1 ), ), - nonsticky_to_sticky_poll_ratio=1, - no_remote_activities=True, - sticky_queue_schedule_to_start_timeout_millis=1000, - max_heartbeat_throttle_interval_millis=1000, - default_heartbeat_throttle_interval_millis=1000, - max_activities_per_second=None, - max_task_queue_activities_per_second=None, - graceful_shutdown_period_millis=0, - versioning_strategy=temporalio.bridge.worker.WorkerVersioningStrategyNone( - build_id_no_versioning=self._config["build_id"] - or load_default_build_id(), - ), - workflow_task_poller_behavior=temporalio.bridge.worker.PollerBehaviorSimpleMaximum( - 2 - ), - activity_task_poller_behavior=temporalio.bridge.worker.PollerBehaviorSimpleMaximum( - 1 - ), - nexus_task_poller_behavior=temporalio.bridge.worker.PollerBehaviorSimpleMaximum( - 1 - ), - ), - ) - # Start worker - workflow_worker_task = asyncio.create_task(workflow_worker.run()) - - # Yield iterator - async def replay_iterator() -> AsyncIterator[WorkflowReplayResult]: - async for history in histories: - # Clear last complete and push history - last_replay_complete.clear() - await pusher.push_history( - history.workflow_id, - temporalio.api.history.v1.History( - events=history.events - ).SerializeToString(), - ) - - # Wait for worker error or last replay to complete. This - # should never take more than a few seconds due to deadlock - # detector but we cannot add timeout just in case debug mode - # is enabled. - await asyncio.wait( # type: ignore - [ - workflow_worker_task, - asyncio.create_task(last_replay_complete.wait()), - ], - return_when=asyncio.FIRST_COMPLETED, - ) - # If worker task complete, wait on it so it'll throw - if workflow_worker_task.done(): - await workflow_worker_task - # Should always be set if workflow worker didn't throw - assert last_replay_complete.is_set() - - yield WorkflowReplayResult( - history=history, - replay_failure=last_replay_failure, - ) - - yield replay_iterator() - finally: - # Close the pusher - pusher.close() - # If the workflow worker task is not done, wait for it - try: - if not workflow_worker_task.done(): - await workflow_worker_task - except Exception: - logger.warning("Failed to shutdown worker", exc_info=True) + ) + # Start worker + workflow_worker_task = asyncio.create_task(workflow_worker.run()) + + # Yield iterator + async def replay_iterator() -> AsyncIterator[WorkflowReplayResult]: + async for history in histories: + # Clear last complete and push history + last_replay_complete.clear() + await pusher.push_history( + history.workflow_id, + temporalio.api.history.v1.History( + events=history.events + ).SerializeToString(), + ) + + # Wait for worker error or last replay to complete. This + # should never take more than a few seconds due to deadlock + # detector but we cannot add timeout just in case debug mode + # is enabled. + await asyncio.wait( # type: ignore + [ + workflow_worker_task, + asyncio.create_task(last_replay_complete.wait()), + ], + return_when=asyncio.FIRST_COMPLETED, + ) + # If worker task complete, wait on it so it'll throw + if workflow_worker_task.done(): + await workflow_worker_task + # Should always be set if workflow worker didn't throw + assert last_replay_complete.is_set() + + yield WorkflowReplayResult( + history=history, + replay_failure=last_replay_failure, + ) + + yield replay_iterator() finally: - # We must shutdown here + # Close the pusher + pusher.close() + # If the workflow worker task is not done, wait for it try: - bridge_worker.initiate_shutdown() - await bridge_worker.finalize_shutdown() + if not workflow_worker_task.done(): + await workflow_worker_task except Exception: - logger.warning("Failed to finalize shutdown", exc_info=True) + logger.warning("Failed to shutdown worker", exc_info=True) + finally: + # We must shutdown here + try: + bridge_worker.initiate_shutdown() + await bridge_worker.finalize_shutdown() + except Exception: + logger.warning("Failed to finalize shutdown", exc_info=True) + + +class ReplayerConfig(TypedDict, total=False): + """TypedDict of config originally passed to :py:class:`Replayer`.""" + + workflows: Sequence[Type] + workflow_task_executor: Optional[concurrent.futures.ThreadPoolExecutor] + workflow_runner: WorkflowRunner + unsandboxed_workflow_runner: WorkflowRunner + namespace: str + data_converter: temporalio.converter.DataConverter + interceptors: Sequence[Interceptor] + build_id: Optional[str] + identity: Optional[str] + workflow_failure_exception_types: Sequence[Type[BaseException]] + debug_mode: bool + runtime: Optional[temporalio.runtime.Runtime] + disable_safe_workflow_eviction: bool + header_codec_behavior: HeaderCodecBehavior @dataclass(frozen=True) diff --git a/temporalio/worker/_worker.py b/temporalio/worker/_worker.py index 76a0cf260..b66637e08 100644 --- a/temporalio/worker/_worker.py +++ b/temporalio/worker/_worker.py @@ -160,14 +160,6 @@ def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig: """ return config - def workflow_replay( - self, - replayer: Replayer, - histories: AsyncIterator[temporalio.client.WorkflowHistory], - ) -> AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]]: - """Hook called when running a replayer to allow interception of execution.""" - return self.next_worker_plugin.workflow_replay(replayer, histories) - class _RootPlugin(Plugin): def configure_worker(self, config: WorkerConfig) -> WorkerConfig: @@ -177,13 +169,6 @@ def configure_worker(self, config: WorkerConfig) -> WorkerConfig: async def run_worker(self) -> AsyncIterator[None]: yield - def workflow_replay( - self, - replayer: Replayer, - histories: AsyncIterator[temporalio.client.WorkflowHistory], - ) -> AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]]: - return replayer._workflow_replay_iterator(histories) - class Worker: """Worker to process workflows and/or activities.