diff --git a/temporalio/contrib/openai_agents/_temporal_openai_agents.py b/temporalio/contrib/openai_agents/_temporal_openai_agents.py
index a1f71db71..2375528c6 100644
--- a/temporalio/contrib/openai_agents/_temporal_openai_agents.py
+++ b/temporalio/contrib/openai_agents/_temporal_openai_agents.py
@@ -1,6 +1,6 @@
 """Initialize Temporal OpenAI Agents overrides."""
 
-from contextlib import contextmanager
+from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager
 from datetime import timedelta
 from typing import AsyncIterator, Callable, Optional, Union
 
@@ -41,7 +41,13 @@ from temporalio.converter import (
     DataConverter,
 )
 
-from temporalio.worker import Worker, WorkerConfig
+from temporalio.worker import (
+    Replayer,
+    ReplayerConfig,
+    Worker,
+    WorkerConfig,
+    WorkflowReplayResult,
+)
 
 
 @contextmanager
@@ -270,7 +276,8 @@ def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
         ]
         return super().configure_worker(config)
 
-    async def run_worker(self, worker: Worker) -> None:
+    @asynccontextmanager
+    async def run_worker(self) -> AsyncIterator[None]:
         """Run the worker with OpenAI agents temporal overrides.
 
         This method sets up the necessary runtime overrides for OpenAI agents
@@ -281,4 +288,15 @@ async def run_worker(self, worker: Worker) -> None:
             worker: The worker instance to run.
         """
         with set_open_ai_agent_temporal_overrides(self._model_params):
-            await super().run_worker(worker)
+            async with super().run_worker():
+                yield
+
+    def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig:
+        """Configure the replayer for OpenAI Agents."""
+        config["interceptors"] = list(config.get("interceptors") or []) + [
+            OpenAIAgentsTracingInterceptor()
+        ]
+        config["data_converter"] = DataConverter(
+            payload_converter_class=_OpenAIPayloadConverter
+        )
+        return config
diff --git a/temporalio/worker/_replayer.py b/temporalio/worker/_replayer.py
index 6e9761b58..9bffa4164 100644
--- a/temporalio/worker/_replayer.py
+++ b/temporalio/worker/_replayer.py
@@ -5,7 +5,8 @@
 import asyncio
 import concurrent.futures
 import logging
-from contextlib import asynccontextmanager
+import typing
+from contextlib import AbstractAsyncContextManager, asynccontextmanager
 from dataclasses import dataclass
 from typing import AsyncIterator, Dict, Mapping, Optional, Sequence, Type
 
@@ -21,7 +22,6 @@
 
 from ..common import HeaderCodecBehavior
 from ._interceptor import Interceptor
-from ._worker import load_default_build_id
 from ._workflow import _WorkflowWorker
 from ._workflow_instance import UnsandboxedWorkflowRunner, WorkflowRunner
 from .workflow_sandbox import SandboxedWorkflowRunner
@@ -42,6 +42,7 @@ def __init__(
         namespace: str = "ReplayNamespace",
         data_converter: temporalio.converter.DataConverter = temporalio.converter.DataConverter.default,
         interceptors: Sequence[Interceptor] = [],
+        plugins: Sequence[temporalio.worker.Plugin] = [],
         build_id: Optional[str] = None,
         identity: Optional[str] = None,
         workflow_failure_exception_types: Sequence[Type[BaseException]] = [],
@@ -62,8 +63,6 @@ def __init__(
                 will be shared across all replay calls and never explicitly shut down.
                 Users are encouraged to provide their own if needing more control.
""" - if not workflows: - raise ValueError("At least one workflow must be specified") self._config = ReplayerConfig( workflows=list(workflows), workflow_task_executor=( @@ -83,6 +82,22 @@ def __init__( header_codec_behavior=header_codec_behavior, ) + from ._worker import _RootPlugin + + root_plugin: temporalio.worker.Plugin = _RootPlugin() + for plugin in reversed(plugins): + root_plugin = plugin.init_worker_plugin(root_plugin) + self._config = root_plugin.configure_replayer(self._config) + self._plugin = root_plugin + + # Apply plugin configuration + for plugin in reversed(plugins): + self._config = plugin.configure_replayer(self._config) + + # Validate workflows after plugin configuration + if not self._config["workflows"]: + raise ValueError("At least one workflow must be specified") + def config(self) -> ReplayerConfig: """Config, as a dictionary, used to create this replayer. @@ -165,166 +180,175 @@ async def workflow_replay_iterator( An async iterator that returns replayed workflow results as they are replayed. """ - try: - last_replay_failure: Optional[Exception] - last_replay_complete = asyncio.Event() - - # Create eviction hook - def on_eviction_hook( - run_id: str, - remove_job: temporalio.bridge.proto.workflow_activation.RemoveFromCache, - ) -> None: - nonlocal last_replay_failure - if ( - remove_job.reason - == temporalio.bridge.proto.workflow_activation.RemoveFromCache.EvictionReason.NONDETERMINISM - ): - last_replay_failure = temporalio.workflow.NondeterminismError( - remove_job.message - ) - elif ( - remove_job.reason - != temporalio.bridge.proto.workflow_activation.RemoveFromCache.EvictionReason.CACHE_FULL - and remove_job.reason - != temporalio.bridge.proto.workflow_activation.RemoveFromCache.EvictionReason.LANG_REQUESTED - ): - last_replay_failure = RuntimeError( - f"{remove_job.reason}: {remove_job.message}" - ) - else: - last_replay_failure = None - last_replay_complete.set() - - # Create worker referencing bridge worker - bridge_worker: temporalio.bridge.worker.Worker - task_queue = f"replay-{self._config['build_id']}" - runtime = self._config["runtime"] or temporalio.runtime.Runtime.default() - workflow_worker = _WorkflowWorker( - bridge_worker=lambda: bridge_worker, - namespace=self._config["namespace"], - task_queue=task_queue, - workflows=self._config["workflows"], - workflow_task_executor=self._config["workflow_task_executor"], - max_concurrent_workflow_tasks=5, - workflow_runner=self._config["workflow_runner"], - unsandboxed_workflow_runner=self._config["unsandboxed_workflow_runner"], - data_converter=self._config["data_converter"], - interceptors=self._config["interceptors"], - workflow_failure_exception_types=self._config[ - "workflow_failure_exception_types" - ], - debug_mode=self._config["debug_mode"], - metric_meter=runtime.metric_meter, - on_eviction_hook=on_eviction_hook, - disable_eager_activity_execution=False, - disable_safe_eviction=self._config["disable_safe_workflow_eviction"], - should_enforce_versioning_behavior=False, - assert_local_activity_valid=lambda a: None, - encode_headers=self._config["header_codec_behavior"] - != HeaderCodecBehavior.NO_CODEC, - ) - # Create bridge worker - bridge_worker, pusher = temporalio.bridge.worker.Worker.for_replay( - runtime._core_runtime, - temporalio.bridge.worker.WorkerConfig( + async with self._plugin.run_worker(): + try: + last_replay_failure: Optional[Exception] + last_replay_complete = asyncio.Event() + + # Create eviction hook + def on_eviction_hook( + run_id: str, + remove_job: 
temporalio.bridge.proto.workflow_activation.RemoveFromCache, + ) -> None: + nonlocal last_replay_failure + if ( + remove_job.reason + == temporalio.bridge.proto.workflow_activation.RemoveFromCache.EvictionReason.NONDETERMINISM + ): + last_replay_failure = temporalio.workflow.NondeterminismError( + remove_job.message + ) + elif ( + remove_job.reason + != temporalio.bridge.proto.workflow_activation.RemoveFromCache.EvictionReason.CACHE_FULL + and remove_job.reason + != temporalio.bridge.proto.workflow_activation.RemoveFromCache.EvictionReason.LANG_REQUESTED + ): + last_replay_failure = RuntimeError( + f"{remove_job.reason}: {remove_job.message}" + ) + else: + last_replay_failure = None + last_replay_complete.set() + + # Create worker referencing bridge worker + bridge_worker: temporalio.bridge.worker.Worker + task_queue = f"replay-{self._config['build_id']}" + runtime = ( + self._config["runtime"] or temporalio.runtime.Runtime.default() + ) + workflow_worker = _WorkflowWorker( + bridge_worker=lambda: bridge_worker, namespace=self._config["namespace"], task_queue=task_queue, - identity_override=self._config["identity"], - # Need to tell core whether we want to consider all - # non-determinism exceptions as workflow fail, and whether we do - # per workflow type - nondeterminism_as_workflow_fail=workflow_worker.nondeterminism_as_workflow_fail(), - nondeterminism_as_workflow_fail_for_types=workflow_worker.nondeterminism_as_workflow_fail_for_types(), - # All values below are ignored but required by Core - max_cached_workflows=2, - tuner=temporalio.bridge.worker.TunerHolder( - workflow_slot_supplier=temporalio.bridge.worker.FixedSizeSlotSupplier( + workflows=self._config["workflows"], + workflow_task_executor=self._config["workflow_task_executor"], + max_concurrent_workflow_tasks=5, + workflow_runner=self._config["workflow_runner"], + unsandboxed_workflow_runner=self._config[ + "unsandboxed_workflow_runner" + ], + data_converter=self._config["data_converter"], + interceptors=self._config["interceptors"], + workflow_failure_exception_types=self._config[ + "workflow_failure_exception_types" + ], + debug_mode=self._config["debug_mode"], + metric_meter=runtime.metric_meter, + on_eviction_hook=on_eviction_hook, + disable_eager_activity_execution=False, + disable_safe_eviction=self._config[ + "disable_safe_workflow_eviction" + ], + should_enforce_versioning_behavior=False, + assert_local_activity_valid=lambda a: None, + encode_headers=self._config["header_codec_behavior"] + != HeaderCodecBehavior.NO_CODEC, + ) + # Create bridge worker + from ._worker import load_default_build_id + + bridge_worker, pusher = temporalio.bridge.worker.Worker.for_replay( + runtime._core_runtime, + temporalio.bridge.worker.WorkerConfig( + namespace=self._config["namespace"], + task_queue=task_queue, + identity_override=self._config["identity"], + # Need to tell core whether we want to consider all + # non-determinism exceptions as workflow fail, and whether we do + # per workflow type + nondeterminism_as_workflow_fail=workflow_worker.nondeterminism_as_workflow_fail(), + nondeterminism_as_workflow_fail_for_types=workflow_worker.nondeterminism_as_workflow_fail_for_types(), + # All values below are ignored but required by Core + max_cached_workflows=2, + tuner=temporalio.bridge.worker.TunerHolder( + workflow_slot_supplier=temporalio.bridge.worker.FixedSizeSlotSupplier( + 2 + ), + activity_slot_supplier=temporalio.bridge.worker.FixedSizeSlotSupplier( + 1 + ), + 
local_activity_slot_supplier=temporalio.bridge.worker.FixedSizeSlotSupplier( + 1 + ), + ), + nonsticky_to_sticky_poll_ratio=1, + no_remote_activities=True, + sticky_queue_schedule_to_start_timeout_millis=1000, + max_heartbeat_throttle_interval_millis=1000, + default_heartbeat_throttle_interval_millis=1000, + max_activities_per_second=None, + max_task_queue_activities_per_second=None, + graceful_shutdown_period_millis=0, + versioning_strategy=temporalio.bridge.worker.WorkerVersioningStrategyNone( + build_id_no_versioning=self._config["build_id"] + or load_default_build_id(), + ), + workflow_task_poller_behavior=temporalio.bridge.worker.PollerBehaviorSimpleMaximum( 2 ), - activity_slot_supplier=temporalio.bridge.worker.FixedSizeSlotSupplier( + activity_task_poller_behavior=temporalio.bridge.worker.PollerBehaviorSimpleMaximum( 1 ), - local_activity_slot_supplier=temporalio.bridge.worker.FixedSizeSlotSupplier( + nexus_task_poller_behavior=temporalio.bridge.worker.PollerBehaviorSimpleMaximum( 1 ), ), - nonsticky_to_sticky_poll_ratio=1, - no_remote_activities=True, - sticky_queue_schedule_to_start_timeout_millis=1000, - max_heartbeat_throttle_interval_millis=1000, - default_heartbeat_throttle_interval_millis=1000, - max_activities_per_second=None, - max_task_queue_activities_per_second=None, - graceful_shutdown_period_millis=0, - versioning_strategy=temporalio.bridge.worker.WorkerVersioningStrategyNone( - build_id_no_versioning=self._config["build_id"] - or load_default_build_id(), - ), - workflow_task_poller_behavior=temporalio.bridge.worker.PollerBehaviorSimpleMaximum( - 2 - ), - activity_task_poller_behavior=temporalio.bridge.worker.PollerBehaviorSimpleMaximum( - 1 - ), - nexus_task_poller_behavior=temporalio.bridge.worker.PollerBehaviorSimpleMaximum( - 1 - ), - ), - ) - # Start worker - workflow_worker_task = asyncio.create_task(workflow_worker.run()) - - # Yield iterator - async def replay_iterator() -> AsyncIterator[WorkflowReplayResult]: - async for history in histories: - # Clear last complete and push history - last_replay_complete.clear() - await pusher.push_history( - history.workflow_id, - temporalio.api.history.v1.History( - events=history.events - ).SerializeToString(), - ) - - # Wait for worker error or last replay to complete. This - # should never take more than a few seconds due to deadlock - # detector but we cannot add timeout just in case debug mode - # is enabled. 
- await asyncio.wait( # type: ignore - [ - workflow_worker_task, - asyncio.create_task(last_replay_complete.wait()), - ], - return_when=asyncio.FIRST_COMPLETED, - ) - # If worker task complete, wait on it so it'll throw - if workflow_worker_task.done(): - await workflow_worker_task - # Should always be set if workflow worker didn't throw - assert last_replay_complete.is_set() - - yield WorkflowReplayResult( - history=history, - replay_failure=last_replay_failure, - ) - - yield replay_iterator() - finally: - # Close the pusher - pusher.close() - # If the workflow worker task is not done, wait for it - try: - if not workflow_worker_task.done(): - await workflow_worker_task - except Exception: - logger.warning("Failed to shutdown worker", exc_info=True) + ) + # Start worker + workflow_worker_task = asyncio.create_task(workflow_worker.run()) + + # Yield iterator + async def replay_iterator() -> AsyncIterator[WorkflowReplayResult]: + async for history in histories: + # Clear last complete and push history + last_replay_complete.clear() + await pusher.push_history( + history.workflow_id, + temporalio.api.history.v1.History( + events=history.events + ).SerializeToString(), + ) + + # Wait for worker error or last replay to complete. This + # should never take more than a few seconds due to deadlock + # detector but we cannot add timeout just in case debug mode + # is enabled. + await asyncio.wait( # type: ignore + [ + workflow_worker_task, + asyncio.create_task(last_replay_complete.wait()), + ], + return_when=asyncio.FIRST_COMPLETED, + ) + # If worker task complete, wait on it so it'll throw + if workflow_worker_task.done(): + await workflow_worker_task + # Should always be set if workflow worker didn't throw + assert last_replay_complete.is_set() + + yield WorkflowReplayResult( + history=history, + replay_failure=last_replay_failure, + ) + + yield replay_iterator() finally: - # We must shutdown here + # Close the pusher + pusher.close() + # If the workflow worker task is not done, wait for it try: - bridge_worker.initiate_shutdown() - await bridge_worker.finalize_shutdown() + if not workflow_worker_task.done(): + await workflow_worker_task except Exception: - logger.warning("Failed to finalize shutdown", exc_info=True) + logger.warning("Failed to shutdown worker", exc_info=True) + finally: + # We must shutdown here + try: + bridge_worker.initiate_shutdown() + await bridge_worker.finalize_shutdown() + except Exception: + logger.warning("Failed to finalize shutdown", exc_info=True) class ReplayerConfig(TypedDict, total=False): diff --git a/temporalio/worker/_worker.py b/temporalio/worker/_worker.py index 58f881c04..b66637e08 100644 --- a/temporalio/worker/_worker.py +++ b/temporalio/worker/_worker.py @@ -9,10 +9,12 @@ import logging import sys import warnings +from contextlib import AbstractAsyncContextManager, asynccontextmanager from dataclasses import dataclass from datetime import timedelta from typing import ( Any, + AsyncIterator, Awaitable, Callable, List, @@ -36,6 +38,7 @@ WorkerDeploymentVersion, ) +from . 
import Replayer, ReplayerConfig, WorkflowReplayResult
 from ._activity import SharedStateManager, _ActivityWorker
 from ._interceptor import Interceptor
 from ._nexus import _NexusWorker
@@ -137,25 +140,34 @@ def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
         """
         return self.next_worker_plugin.configure_worker(config)
 
-    async def run_worker(self, worker: Worker) -> None:
+    def run_worker(self) -> AbstractAsyncContextManager[None]:
         """Hook called when running a worker to allow interception of execution.
 
         This method is called when the worker is started and allows plugins to
         intercept or wrap the worker execution. Plugins can add monitoring,
         custom lifecycle management, or other execution-time behavior.
+        """
+        return self.next_worker_plugin.run_worker()
 
-        Args:
-            worker: The worker instance to run.
+    def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig:
+        """Hook called when creating a replayer to allow modification of configuration.
+
+        This should be used to configure anything in ReplayerConfig needed to make execution match
+        the original. This could include interceptors, DataConverter, workflows, and more.
+
+        Unlike the other hooks, this one does not rely on a chain; it is called sequentially on
+        each plugin because the replayer cannot instantiate the worker/client component.
         """
-        await self.next_worker_plugin.run_worker(worker)
+        return config
 
 
 class _RootPlugin(Plugin):
     def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
         return config
 
-    async def run_worker(self, worker: Worker) -> None:
-        await worker._run()
+    @asynccontextmanager
+    async def run_worker(self) -> AsyncIterator[None]:
+        yield
 
 
 class Worker:
@@ -765,7 +777,8 @@ async def run(self) -> None:
         also cancel the shutdown process. Therefore users are encouraged to use
         explicit shutdown instead.
""" - await self._plugin.run_worker(self) + async with self._plugin.run_worker(): + await self._run() async def _run(self): # Eagerly validate which will do a namespace check in Core diff --git a/tests/contrib/openai_agents/test_openai_replay.py b/tests/contrib/openai_agents/test_openai_replay.py index d3ac92c5e..c6ac1ea68 100644 --- a/tests/contrib/openai_agents/test_openai_replay.py +++ b/tests/contrib/openai_agents/test_openai_replay.py @@ -1,16 +1,15 @@ +from contextlib import ( + AbstractAsyncContextManager, + AbstractContextManager, + asynccontextmanager, +) from pathlib import Path +from typing import AsyncGenerator import pytest from temporalio.client import WorkflowHistory -from temporalio.contrib.openai_agents import ModelActivityParameters -from temporalio.contrib.openai_agents._temporal_openai_agents import ( - set_open_ai_agent_temporal_overrides, -) -from temporalio.contrib.openai_agents._trace_interceptor import ( - OpenAIAgentsTracingInterceptor, -) -from temporalio.contrib.pydantic import pydantic_data_converter +from temporalio.contrib.openai_agents import ModelActivityParameters, OpenAIAgentsPlugin from temporalio.worker import Replayer from tests.contrib.openai_agents.test_openai import ( AgentsAsToolsWorkflow, @@ -39,17 +38,15 @@ async def test_replay(file_name: str) -> None: with (Path(__file__).with_name("histories") / file_name).open("r") as f: history_json = f.read() - with set_open_ai_agent_temporal_overrides(ModelActivityParameters()): - await Replayer( - workflows=[ - ResearchWorkflow, - ToolsWorkflow, - CustomerServiceWorkflow, - AgentsAsToolsWorkflow, - HelloWorldAgent, - InputGuardrailWorkflow, - OutputGuardrailWorkflow, - ], - data_converter=pydantic_data_converter, - interceptors=[OpenAIAgentsTracingInterceptor()], - ).replay_workflow(WorkflowHistory.from_json("fake", history_json)) + await Replayer( + workflows=[ + ResearchWorkflow, + ToolsWorkflow, + CustomerServiceWorkflow, + AgentsAsToolsWorkflow, + HelloWorldAgent, + InputGuardrailWorkflow, + OutputGuardrailWorkflow, + ], + plugins=[OpenAIAgentsPlugin()], + ).replay_workflow(WorkflowHistory.from_json("fake", history_json)) diff --git a/tests/test_plugins.py b/tests/test_plugins.py index 4a60bba4d..868ff4115 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -1,15 +1,21 @@ import dataclasses +import uuid import warnings -from typing import cast +from contextlib import asynccontextmanager +from datetime import timedelta +from typing import AsyncIterator, cast import pytest import temporalio.client import temporalio.worker +from temporalio import workflow from temporalio.client import Client, ClientConfig, OutboundInterceptor +from temporalio.contrib.pydantic import pydantic_data_converter from temporalio.testing import WorkflowEnvironment -from temporalio.worker import Worker, WorkerConfig +from temporalio.worker import Replayer, ReplayerConfig, Worker, WorkerConfig from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner +from tests.helpers import new_worker from tests.worker.test_worker import never_run_activity @@ -64,6 +70,9 @@ def configure_worker(self, config: WorkerConfig) -> WorkerConfig: return super().configure_worker(config) +IN_CONTEXT: bool = False + + class MyWorkerPlugin(temporalio.worker.Plugin): def configure_worker(self, config: WorkerConfig) -> WorkerConfig: config["task_queue"] = "replaced_queue" @@ -75,8 +84,15 @@ def configure_worker(self, config: WorkerConfig) -> WorkerConfig: ) return super().configure_worker(config) - async def run_worker(self, 
-        await super().run_worker(worker)
+    @asynccontextmanager
+    async def run_worker(self) -> AsyncIterator[None]:
+        global IN_CONTEXT
+        try:
+            IN_CONTEXT = True
+            async with super().run_worker():
+                yield
+        finally:
+            IN_CONTEXT = False
 
 
 async def test_worker_plugin_basic_config(client: Client) -> None:
@@ -105,6 +121,30 @@ async def test_worker_plugin_basic_config(client: Client) -> None:
     assert worker.config().get("task_queue") == "replaced_queue"
 
 
+@workflow.defn(sandboxed=False)
+class CheckContextWorkflow:
+    @workflow.run
+    async def run(self) -> bool:
+        return IN_CONTEXT
+
+
+async def test_worker_plugin_run_context(client: Client) -> None:
+    async with Worker(
+        client,
+        task_queue=str(uuid.uuid4()),
+        workflows=[CheckContextWorkflow],
+        activities=[never_run_activity],
+        plugins=[MyWorkerPlugin()],
+    ) as worker:
+        result = await client.execute_workflow(
+            CheckContextWorkflow.run,
+            task_queue=worker.task_queue,
+            id=f"workflow-{uuid.uuid4()}",
+            execution_timeout=timedelta(seconds=1),
+        )
+        assert result
+
+
 async def test_worker_duplicated_plugin(client: Client) -> None:
     new_config = client.config()
     new_config["plugins"] = [MyCombinedPlugin()]
@@ -136,3 +176,46 @@ async def test_worker_sandbox_restrictions(client: Client) -> None:
             SandboxedWorkflowRunner, worker.config().get("workflow_runner")
         ).restrictions.passthrough_modules
     )
+
+
+class ReplayCheckPlugin(temporalio.client.Plugin, temporalio.worker.Plugin):
+    def configure_client(self, config: ClientConfig) -> ClientConfig:
+        config["data_converter"] = pydantic_data_converter
+        return super().configure_client(config)
+
+    def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
+        config["workflows"] = list(config.get("workflows") or []) + [HelloWorkflow]
+        return super().configure_worker(config)
+
+    def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig:
+        config["data_converter"] = pydantic_data_converter
+        config["workflows"] = list(config.get("workflows") or []) + [HelloWorkflow]
+        return config
+
+
+@workflow.defn
+class HelloWorkflow:
+    @workflow.run
+    async def run(self, name: str) -> str:
+        return f"Hello, {name}!"
+
+
+async def test_replay(client: Client) -> None:
+    plugin = ReplayCheckPlugin()
+    new_config = client.config()
+    new_config["plugins"] = [plugin]
+    client = Client(**new_config)
+
+    async with new_worker(client) as worker:
+        handle = await client.start_workflow(
+            HelloWorkflow.run,
+            "Tim",
+            id=f"workflow-{uuid.uuid4()}",
+            task_queue=worker.task_queue,
+        )
+        await handle.result()
+    replayer = Replayer(workflows=[], plugins=[plugin])
+    assert len(replayer.config().get("workflows") or []) == 1
+    assert replayer.config().get("data_converter") == pydantic_data_converter
+
+    await replayer.replay_workflow(await handle.fetch_history())
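
Usage sketch (not part of the patch): a minimal plugin exercising the hooks this change introduces, namely run_worker as an async context manager and the non-chained configure_replayer. The plugin name and the config tweaks are illustrative only; the shape mirrors MyWorkerPlugin and ReplayCheckPlugin in the tests above.

from contextlib import asynccontextmanager
from typing import AsyncIterator

import temporalio.worker
from temporalio.worker import ReplayerConfig, WorkerConfig


class LoggingPlugin(temporalio.worker.Plugin):
    def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
        # Chained hook: adjust the worker config, then delegate to the next plugin.
        config["interceptors"] = list(config.get("interceptors") or [])
        return super().configure_worker(config)

    @asynccontextmanager
    async def run_worker(self) -> AsyncIterator[None]:
        # Wrap worker execution; the worker runs while control is yielded.
        print("worker starting")
        try:
            async with super().run_worker():
                yield
        finally:
            print("worker stopped")

    def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig:
        # Non-chained hook: the Replayer calls this on each plugin directly,
        # so just return the adjusted config.
        config["interceptors"] = list(config.get("interceptors") or [])
        return config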
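
And a sketch of driving replay through the new plugins parameter on Replayer, fetching history from the client as test_replay above does. GreetingWorkflow, the server address, and the workflow ID are placeholders; LoggingPlugin is the illustrative plugin from the previous sketch.

from temporalio import workflow
from temporalio.client import Client
from temporalio.worker import Replayer


@workflow.defn
class GreetingWorkflow:
    @workflow.run
    async def run(self, name: str) -> str:
        return f"Hello, {name}!"


async def replay_one(workflow_id: str) -> None:
    client = await Client.connect("localhost:7233")
    history = await client.get_workflow_handle(workflow_id).fetch_history()
    # Plugins supply matching interceptors, data converter, and workflows so
    # the replay configuration lines up with the original worker.
    await Replayer(
        workflows=[GreetingWorkflow],
        plugins=[LoggingPlugin()],
    ).replay_workflow(history)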