Skip to content

Commit 1007d02

Browse files
committed
Moving replay configuration to worker plugin, add replay execution hook
1 parent f61b400 commit 1007d02

File tree

7 files changed

+131
-175
lines changed

7 files changed

+131
-175
lines changed

temporalio/contrib/openai_agents/_temporal_openai_agents.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Initialize Temporal OpenAI Agents overrides."""
22

3-
from contextlib import contextmanager
3+
from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager
44
from datetime import timedelta
55
from typing import AsyncIterator, Callable, Optional, Union
66

@@ -41,7 +41,13 @@
4141
from temporalio.converter import (
4242
DataConverter,
4343
)
44-
from temporalio.worker import Worker, WorkerConfig
44+
from temporalio.worker import (
45+
Replayer,
46+
ReplayerConfig,
47+
Worker,
48+
WorkerConfig,
49+
WorkflowReplayResult,
50+
)
4551

4652

4753
@contextmanager
@@ -282,3 +288,23 @@ async def run_worker(self, worker: Worker) -> None:
282288
"""
283289
with set_open_ai_agent_temporal_overrides(self._model_params):
284290
await super().run_worker(worker)
291+
292+
def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig:
293+
"""Configure the replayer for OpenAI Agents."""
294+
config["interceptors"] = list(config.get("interceptors") or []) + [
295+
OpenAIAgentsTracingInterceptor()
296+
]
297+
config["data_converter"] = DataConverter(
298+
payload_converter_class=_OpenAIPayloadConverter
299+
)
300+
return config
301+
302+
@asynccontextmanager
303+
async def workflow_replay(
304+
self,
305+
replayer: Replayer,
306+
histories: AsyncIterator[temporalio.client.WorkflowHistory],
307+
) -> AsyncIterator[AsyncIterator[WorkflowReplayResult]]:
308+
with set_open_ai_agent_temporal_overrides(self._model_params):
309+
async with super().workflow_replay(replayer, histories) as results:
310+
yield results

temporalio/worker/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from ._replayer import (
2525
Replayer,
2626
ReplayerConfig,
27-
ReplayerPlugin,
2827
WorkflowReplayResult,
2928
WorkflowReplayResults,
3029
)
@@ -69,7 +68,6 @@
6968
"WorkerDeploymentVersion",
7069
"Replayer",
7170
"ReplayerConfig",
72-
"ReplayerPlugin",
7371
"WorkflowReplayResult",
7472
"WorkflowReplayResults",
7573
"PollerBehavior",

temporalio/worker/_replayer.py

Lines changed: 38 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55
import asyncio
66
import concurrent.futures
77
import logging
8-
from contextlib import asynccontextmanager
8+
import typing
9+
from contextlib import AbstractAsyncContextManager, asynccontextmanager
910
from dataclasses import dataclass
10-
from typing import AsyncIterator, Dict, Mapping, Optional, Sequence, Type, cast
11+
from typing import AsyncIterator, Dict, Mapping, Optional, Sequence, Type
1112

1213
from typing_extensions import TypedDict
1314

@@ -21,94 +22,30 @@
2122

2223
from ..common import HeaderCodecBehavior
2324
from ._interceptor import Interceptor
24-
from ._worker import WorkerConfig, load_default_build_id
2525
from ._workflow import _WorkflowWorker
2626
from ._workflow_instance import UnsandboxedWorkflowRunner, WorkflowRunner
2727
from .workflow_sandbox import SandboxedWorkflowRunner
2828

2929
logger = logging.getLogger(__name__)
3030

3131

32-
class ReplayerPlugin:
33-
"""Base class for replayer plugins that can modify replayer configuration."""
34-
35-
def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig:
36-
"""Configure the replayer.
37-
38-
Default implementation applies shared configuration from worker and client plugins.
39-
40-
Args:
41-
config: The replayer configuration to modify.
42-
43-
Returns:
44-
The modified replayer configuration.
45-
"""
46-
# If this plugin is also a worker plugin, apply shared worker config
47-
if isinstance(self, temporalio.worker.Plugin):
48-
# Create a minimal worker config with shared fields
49-
worker_config = cast(
50-
WorkerConfig,
51-
{
52-
"workflows": config["workflows"],
53-
"workflow_task_executor": config["workflow_task_executor"],
54-
"workflow_runner": config["workflow_runner"],
55-
"unsandboxed_workflow_runner": config[
56-
"unsandboxed_workflow_runner"
57-
],
58-
"interceptors": config["interceptors"],
59-
"build_id": config["build_id"],
60-
"identity": config["identity"],
61-
"workflow_failure_exception_types": config[
62-
"workflow_failure_exception_types"
63-
],
64-
"debug_mode": config["debug_mode"],
65-
"disable_safe_workflow_eviction": config[
66-
"disable_safe_workflow_eviction"
67-
],
68-
},
69-
)
70-
71-
modified_worker_config = self.configure_worker(worker_config)
72-
config["workflows"] = modified_worker_config["workflows"]
73-
config["workflow_task_executor"] = modified_worker_config[
74-
"workflow_task_executor"
75-
]
76-
config["workflow_runner"] = modified_worker_config["workflow_runner"]
77-
config["unsandboxed_workflow_runner"] = modified_worker_config[
78-
"unsandboxed_workflow_runner"
79-
]
80-
config["interceptors"] = modified_worker_config["interceptors"]
81-
config["build_id"] = modified_worker_config["build_id"]
82-
config["identity"] = modified_worker_config["identity"]
83-
config["workflow_failure_exception_types"] = modified_worker_config[
84-
"workflow_failure_exception_types"
85-
]
86-
config["debug_mode"] = modified_worker_config["debug_mode"]
87-
config["disable_safe_workflow_eviction"] = modified_worker_config[
88-
"disable_safe_workflow_eviction"
89-
]
90-
91-
# If this plugin is also a client plugin, apply shared client config
92-
if isinstance(self, temporalio.client.Plugin):
93-
# Only include fields that exist in both ReplayerConfig and ClientConfig
94-
# Note: interceptors are different types between client and worker, so excluded
95-
client_config = cast(
96-
temporalio.client.ClientConfig,
97-
{
98-
"namespace": config["namespace"],
99-
"data_converter": config["data_converter"],
100-
"header_codec_behavior": config["header_codec_behavior"],
101-
},
102-
)
103-
104-
modified_client_config = self.configure_client(client_config)
105-
config["namespace"] = modified_client_config["namespace"]
106-
config["data_converter"] = modified_client_config["data_converter"]
107-
config["header_codec_behavior"] = modified_client_config[
108-
"header_codec_behavior"
109-
]
32+
class ReplayerConfig(TypedDict, total=False):
33+
"""TypedDict of config originally passed to :py:class:`Replayer`."""
11034

111-
return config
35+
workflows: Sequence[Type]
36+
workflow_task_executor: Optional[concurrent.futures.ThreadPoolExecutor]
37+
workflow_runner: WorkflowRunner
38+
unsandboxed_workflow_runner: WorkflowRunner
39+
namespace: str
40+
data_converter: temporalio.converter.DataConverter
41+
interceptors: Sequence[Interceptor]
42+
build_id: Optional[str]
43+
identity: Optional[str]
44+
workflow_failure_exception_types: Sequence[Type[BaseException]]
45+
debug_mode: bool
46+
runtime: Optional[temporalio.runtime.Runtime]
47+
disable_safe_workflow_eviction: bool
48+
header_codec_behavior: HeaderCodecBehavior
11249

11350

11451
class Replayer:
@@ -124,7 +61,7 @@ def __init__(
12461
namespace: str = "ReplayNamespace",
12562
data_converter: temporalio.converter.DataConverter = temporalio.converter.DataConverter.default,
12663
interceptors: Sequence[Interceptor] = [],
127-
plugins: Sequence[ReplayerPlugin] = [],
64+
plugins: Sequence[temporalio.worker.Plugin] = [],
12865
build_id: Optional[str] = None,
12966
identity: Optional[str] = None,
13067
workflow_failure_exception_types: Sequence[Type[BaseException]] = [],
@@ -164,32 +101,16 @@ def __init__(
164101
header_codec_behavior=header_codec_behavior,
165102
)
166103

167-
# Initialize all worker plugins
168-
root_worker_plugin: temporalio.worker.Plugin = (
169-
temporalio.worker._worker._RootPlugin()
170-
)
171-
for worker_plugin in reversed(
172-
[
173-
cast(temporalio.worker.Plugin, plugin)
174-
for plugin in plugins
175-
if isinstance(plugin, temporalio.worker.Plugin)
176-
]
177-
):
178-
root_worker_plugin = worker_plugin.init_worker_plugin(root_worker_plugin)
179-
180-
# Initialize all client plugins
181-
root_client_plugin: temporalio.client.Plugin = temporalio.client._RootPlugin()
182-
for client_plugin in reversed(
183-
[
184-
cast(temporalio.client.Plugin, plugin)
185-
for plugin in plugins
186-
if isinstance(plugin, temporalio.client.Plugin)
187-
]
188-
):
189-
root_client_plugin = client_plugin.init_client_plugin(root_client_plugin)
104+
from ._worker import _RootPlugin
105+
106+
root_plugin: temporalio.worker.Plugin = _RootPlugin()
107+
for plugin in reversed(plugins):
108+
root_plugin = plugin.init_worker_plugin(root_plugin)
109+
self._config = root_plugin.configure_replayer(self._config)
110+
self._plugin = root_plugin
190111

191112
# Apply plugin configuration
192-
for plugin in plugins:
113+
for plugin in reversed(plugins):
193114
self._config = plugin.configure_replayer(self._config)
194115

195116
# Validate workflows after plugin configuration
@@ -262,10 +183,9 @@ async def replay_workflows(
262183
replay_failures[result.history.run_id] = result.replay_failure
263184
return WorkflowReplayResults(replay_failures=replay_failures)
264185

265-
@asynccontextmanager
266-
async def workflow_replay_iterator(
186+
def workflow_replay_iterator(
267187
self, histories: AsyncIterator[temporalio.client.WorkflowHistory]
268-
) -> AsyncIterator[AsyncIterator[WorkflowReplayResult]]:
188+
) -> AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]]:
269189
"""Replay workflows for the given histories.
270190
271191
This is a context manager for use via ``async with``. The value is an
@@ -278,6 +198,12 @@ async def workflow_replay_iterator(
278198
An async iterator that returns replayed workflow results as they are
279199
replayed.
280200
"""
201+
return self._plugin.workflow_replay(self, histories)
202+
203+
@asynccontextmanager
204+
async def _workflow_replay_iterator(
205+
self, histories: AsyncIterator[temporalio.client.WorkflowHistory]
206+
) -> AsyncIterator[AsyncIterator[WorkflowReplayResult]]:
281207
try:
282208
last_replay_failure: Optional[Exception]
283209
last_replay_complete = asyncio.Event()
@@ -337,6 +263,8 @@ def on_eviction_hook(
337263
!= HeaderCodecBehavior.NO_CODEC,
338264
)
339265
# Create bridge worker
266+
from ._worker import load_default_build_id
267+
340268
bridge_worker, pusher = temporalio.bridge.worker.Worker.for_replay(
341269
runtime._core_runtime,
342270
temporalio.bridge.worker.WorkerConfig(
@@ -440,25 +368,6 @@ async def replay_iterator() -> AsyncIterator[WorkflowReplayResult]:
440368
logger.warning("Failed to finalize shutdown", exc_info=True)
441369

442370

443-
class ReplayerConfig(TypedDict, total=False):
444-
"""TypedDict of config originally passed to :py:class:`Replayer`."""
445-
446-
workflows: Sequence[Type]
447-
workflow_task_executor: Optional[concurrent.futures.ThreadPoolExecutor]
448-
workflow_runner: WorkflowRunner
449-
unsandboxed_workflow_runner: WorkflowRunner
450-
namespace: str
451-
data_converter: temporalio.converter.DataConverter
452-
interceptors: Sequence[Interceptor]
453-
build_id: Optional[str]
454-
identity: Optional[str]
455-
workflow_failure_exception_types: Sequence[Type[BaseException]]
456-
debug_mode: bool
457-
runtime: Optional[temporalio.runtime.Runtime]
458-
disable_safe_workflow_eviction: bool
459-
header_codec_behavior: HeaderCodecBehavior
460-
461-
462371
@dataclass(frozen=True)
463372
class WorkflowReplayResult:
464373
"""Single workflow replay result."""

temporalio/worker/_worker.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@
99
import logging
1010
import sys
1111
import warnings
12+
from contextlib import AbstractAsyncContextManager, asynccontextmanager
1213
from dataclasses import dataclass
1314
from datetime import timedelta
1415
from typing import (
1516
Any,
17+
AsyncIterator,
1618
Awaitable,
1719
Callable,
1820
List,
@@ -36,6 +38,7 @@
3638
WorkerDeploymentVersion,
3739
)
3840

41+
from . import Replayer, ReplayerConfig, WorkflowReplayResult
3942
from ._activity import SharedStateManager, _ActivityWorker
4043
from ._interceptor import Interceptor
4144
from ._nexus import _NexusWorker
@@ -149,6 +152,25 @@ async def run_worker(self, worker: Worker) -> None:
149152
"""
150153
await self.next_worker_plugin.run_worker(worker)
151154

155+
def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig:
156+
"""Hook called when creating a replayer to allow modification of configuration.
157+
158+
This should be used to configure anything in ReplayerConfig needed to make execution match
159+
the original. This could include interceptors, DataConverter, workflows, and more.
160+
161+
Uniquely does not rely on a chain, and is instead called sequentially on the plugins
162+
because the replayer cannot instantiate the worker/client component.
163+
"""
164+
return config
165+
166+
def workflow_replay(
167+
self,
168+
replayer: Replayer,
169+
histories: AsyncIterator[temporalio.client.WorkflowHistory],
170+
) -> AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]]:
171+
"""Hook called when running a replayer to allow interception of execution."""
172+
return self.next_worker_plugin.workflow_replay(replayer, histories)
173+
152174

153175
class _RootPlugin(Plugin):
154176
def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
@@ -157,6 +179,13 @@ def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
157179
async def run_worker(self, worker: Worker) -> None:
158180
await worker._run()
159181

182+
def workflow_replay(
183+
self,
184+
replayer: Replayer,
185+
histories: AsyncIterator[temporalio.client.WorkflowHistory],
186+
) -> AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]]:
187+
return replayer._workflow_replay_iterator(histories)
188+
160189

161190
class Worker:
162191
"""Worker to process workflows and/or activities.

temporalio/worker/workflow_sandbox/_importer.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,12 @@ def module_configured_passthrough(self, name: str) -> bool:
285285
def _maybe_passthrough_module(self, name: str) -> Optional[types.ModuleType]:
286286
# If imports not passed through and all modules are not passed through
287287
# and name not in passthrough modules, check parents
288+
logger.debug(
289+
"Check passthrough module: %s - %s",
290+
name,
291+
temporalio.workflow.unsafe.is_imports_passed_through()
292+
or self.module_configured_passthrough(name),
293+
)
288294
if (
289295
not temporalio.workflow.unsafe.is_imports_passed_through()
290296
and not self.module_configured_passthrough(name)

0 commit comments

Comments
 (0)