Skip to content

Commit cbc5fa3

Browse files
committed
POC for moving plugin run_worker to a context
1 parent 1007d02 commit cbc5fa3

File tree

4 files changed

+50
-19
lines changed

4 files changed

+50
-19
lines changed

temporalio/contrib/openai_agents/_temporal_openai_agents.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,8 @@ def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
276276
]
277277
return super().configure_worker(config)
278278

279-
async def run_worker(self, worker: Worker) -> None:
279+
@asynccontextmanager
280+
async def run_worker(self) -> AsyncIterator[None]:
280281
"""Run the worker with OpenAI agents temporal overrides.
281282
282283
This method sets up the necessary runtime overrides for OpenAI agents
@@ -287,7 +288,8 @@ async def run_worker(self, worker: Worker) -> None:
287288
worker: The worker instance to run.
288289
"""
289290
with set_open_ai_agent_temporal_overrides(self._model_params):
290-
await super().run_worker(worker)
291+
async with super().run_worker():
292+
yield
291293

292294
def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig:
293295
"""Configure the replayer for OpenAI Agents."""

temporalio/worker/_worker.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -140,17 +140,14 @@ def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
140140
"""
141141
return self.next_worker_plugin.configure_worker(config)
142142

143-
async def run_worker(self, worker: Worker) -> None:
143+
def run_worker(self) -> AbstractAsyncContextManager[None]:
144144
"""Hook called when running a worker to allow interception of execution.
145145
146146
This method is called when the worker is started and allows plugins to
147147
intercept or wrap the worker execution. Plugins can add monitoring,
148148
custom lifecycle management, or other execution-time behavior.
149-
150-
Args:
151-
worker: The worker instance to run.
152149
"""
153-
await self.next_worker_plugin.run_worker(worker)
150+
return self.next_worker_plugin.run_worker()
154151

155152
def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig:
156153
"""Hook called when creating a replayer to allow modification of configuration.
@@ -176,8 +173,9 @@ class _RootPlugin(Plugin):
176173
def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
177174
return config
178175

179-
async def run_worker(self, worker: Worker) -> None:
180-
await worker._run()
176+
@asynccontextmanager
177+
async def run_worker(self) -> AsyncIterator[None]:
178+
yield
181179

182180
def workflow_replay(
183181
self,
@@ -794,7 +792,8 @@ async def run(self) -> None:
794792
also cancel the shutdown process. Therefore users are encouraged to use
795793
explicit shutdown instead.
796794
"""
797-
await self._plugin.run_worker(self)
795+
async with self._plugin.run_worker():
796+
await self._run()
798797

799798
async def _run(self):
800799
# Eagerly validate which will do a namespace check in Core

temporalio/worker/workflow_sandbox/_importer.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -285,12 +285,6 @@ def module_configured_passthrough(self, name: str) -> bool:
285285
def _maybe_passthrough_module(self, name: str) -> Optional[types.ModuleType]:
286286
# If imports not passed through and all modules are not passed through
287287
# and name not in passthrough modules, check parents
288-
logger.debug(
289-
"Check passthrough module: %s - %s",
290-
name,
291-
temporalio.workflow.unsafe.is_imports_passed_through()
292-
or self.module_configured_passthrough(name),
293-
)
294288
if (
295289
not temporalio.workflow.unsafe.is_imports_passed_through()
296290
and not self.module_configured_passthrough(name)

tests/test_plugins.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import dataclasses
22
import uuid
33
import warnings
4-
from typing import cast
4+
from contextlib import asynccontextmanager
5+
from datetime import timedelta
6+
from typing import AsyncIterator, cast
57

68
import pytest
79

@@ -68,6 +70,9 @@ def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
6870
return super().configure_worker(config)
6971

7072

73+
IN_CONTEXT: bool = False
74+
75+
7176
class MyWorkerPlugin(temporalio.worker.Plugin):
7277
def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
7378
config["task_queue"] = "replaced_queue"
@@ -79,8 +84,15 @@ def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
7984
)
8085
return super().configure_worker(config)
8186

82-
async def run_worker(self, worker: Worker) -> None:
83-
await super().run_worker(worker)
87+
@asynccontextmanager
88+
async def run_worker(self) -> AsyncIterator[None]:
89+
global IN_CONTEXT
90+
try:
91+
IN_CONTEXT = True
92+
async with super().run_worker():
93+
yield
94+
finally:
95+
IN_CONTEXT = False
8496

8597

8698
async def test_worker_plugin_basic_config(client: Client) -> None:
@@ -109,6 +121,30 @@ async def test_worker_plugin_basic_config(client: Client) -> None:
109121
assert worker.config().get("task_queue") == "replaced_queue"
110122

111123

124+
@workflow.defn(sandboxed=False)
125+
class CheckContextWorkflow:
126+
@workflow.run
127+
async def run(self) -> bool:
128+
return IN_CONTEXT
129+
130+
131+
async def test_worker_plugin_run_context(client: Client) -> None:
132+
async with Worker(
133+
client,
134+
task_queue=str(uuid.uuid4()),
135+
workflows=[CheckContextWorkflow],
136+
activities=[never_run_activity],
137+
plugins=[MyWorkerPlugin()],
138+
) as worker:
139+
result = await client.execute_workflow(
140+
CheckContextWorkflow.run,
141+
task_queue=worker.task_queue,
142+
id=f"workflow-{uuid.uuid4()}",
143+
execution_timeout=timedelta(seconds=1),
144+
)
145+
assert result
146+
147+
112148
async def test_worker_duplicated_plugin(client: Client) -> None:
113149
new_config = client.config()
114150
new_config["plugins"] = [MyCombinedPlugin()]

0 commit comments

Comments
 (0)