From 5a9268c1c6535f9e0f857a91613a1350cfda4a67 Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Mon, 10 Nov 2025 14:04:19 -0800 Subject: [PATCH 01/11] Raise an assertion error when a Runtime is used by client/worker creation/usage. --- scripts/gen_bridge_client.py | 1 + temporalio/bridge/src/client.rs | 1 + temporalio/bridge/src/client_rpc_generated.rs | 5 ++ temporalio/bridge/src/runtime.rs | 16 +++- temporalio/bridge/src/worker.rs | 8 ++ tests/conftest.py | 17 ++++- tests/helpers/fork.py | 75 +++++++++++++++++++ tests/test_client.py | 41 +++++++++- tests/worker/test_worker.py | 43 +++++++++++ 9 files changed, 204 insertions(+), 3 deletions(-) create mode 100644 tests/helpers/fork.py diff --git a/scripts/gen_bridge_client.py b/scripts/gen_bridge_client.py index 89ae54ec1..f0a4457dc 100644 --- a/scripts/gen_bridge_client.py +++ b/scripts/gen_bridge_client.py @@ -171,6 +171,7 @@ def generate_rust_service_call(service_descriptor: ServiceDescriptor) -> str: py: Python<'p>, call: RpcCall, ) -> PyResult> { + self.runtime.assert_same_process("use client")?; use temporal_client::${descriptor_name}; let mut retry_client = self.retry_client.clone(); self.runtime.future_into_py(py, async move { diff --git a/temporalio/bridge/src/client.rs b/temporalio/bridge/src/client.rs index dfbd432a1..abe4a2354 100644 --- a/temporalio/bridge/src/client.rs +++ b/temporalio/bridge/src/client.rs @@ -92,6 +92,7 @@ pub fn connect_client<'a>( config: ClientConfig, ) -> PyResult> { let opts: ClientOptions = config.try_into()?; + runtime_ref.runtime.assert_same_process("create client")?; let runtime = runtime_ref.runtime.clone(); runtime_ref.runtime.future_into_py(py, async move { Ok(ClientRef { diff --git a/temporalio/bridge/src/client_rpc_generated.rs b/temporalio/bridge/src/client_rpc_generated.rs index 659f5d8cf..0b2d2ffa8 100644 --- a/temporalio/bridge/src/client_rpc_generated.rs +++ b/temporalio/bridge/src/client_rpc_generated.rs @@ -15,6 +15,7 @@ impl ClientRef { py: Python<'p>, call: RpcCall, ) -> PyResult> { + self.runtime.assert_same_process("use client")?; use temporal_client::WorkflowService; let mut retry_client = self.retry_client.clone(); self.runtime.future_into_py(py, async move { @@ -566,6 +567,7 @@ impl ClientRef { py: Python<'p>, call: RpcCall, ) -> PyResult> { + self.runtime.assert_same_process("use client")?; use temporal_client::OperatorService; let mut retry_client = self.retry_client.clone(); self.runtime.future_into_py(py, async move { @@ -628,6 +630,7 @@ impl ClientRef { } fn call_cloud_service<'p>(&self, py: Python<'p>, call: RpcCall) -> PyResult> { + self.runtime.assert_same_process("use client")?; use temporal_client::CloudService; let mut retry_client = self.retry_client.clone(); self.runtime.future_into_py(py, async move { @@ -842,6 +845,7 @@ impl ClientRef { } fn call_test_service<'p>(&self, py: Python<'p>, call: RpcCall) -> PyResult> { + self.runtime.assert_same_process("use client")?; use temporal_client::TestService; let mut retry_client = self.retry_client.clone(); self.runtime.future_into_py(py, async move { @@ -881,6 +885,7 @@ impl ClientRef { } fn call_health_service<'p>(&self, py: Python<'p>, call: RpcCall) -> PyResult> { + self.runtime.assert_same_process("use client")?; use temporal_client::HealthService; let mut retry_client = self.retry_client.clone(); self.runtime.future_into_py(py, async move { diff --git a/temporalio/bridge/src/runtime.rs b/temporalio/bridge/src/runtime.rs index 72cc905ae..a75aeb3e3 100644 --- a/temporalio/bridge/src/runtime.rs +++ b/temporalio/bridge/src/runtime.rs @@ -1,5 +1,5 @@ use futures::channel::mpsc::Receiver; -use pyo3::exceptions::{PyRuntimeError, PyValueError}; +use pyo3::exceptions::{PyAssertionError, PyRuntimeError, PyValueError}; use pyo3::prelude::*; use pythonize::pythonize; use std::collections::HashMap; @@ -33,6 +33,7 @@ pub struct RuntimeRef { #[derive(Clone)] pub(crate) struct Runtime { + pub(crate) pid: u32, pub(crate) core: Arc, metrics_call_buffer: Option>>, log_forwarder_handle: Option>>, @@ -173,6 +174,7 @@ pub fn init_runtime(telemetry_config: TelemetryConfig) -> PyResult { Ok(RuntimeRef { runtime: Runtime { + pid: std::process::id(), core: Arc::new(core), metrics_call_buffer, log_forwarder_handle, @@ -197,6 +199,18 @@ impl Runtime { let _guard = self.core.tokio_handle().enter(); pyo3_async_runtimes::generic::future_into_py::(py, fut) } + + pub(crate) fn assert_same_process(&self, action: &'static str) -> PyResult<()> { + let current_pid = std::process::id(); + if self.pid != current_pid { + Err(PyAssertionError::new_err(format!( + "Cannot {} across forks (original runtime PID is {}, current is {})", + action, self.pid, current_pid, + ))) + } else { + Ok(()) + } + } } impl Drop for Runtime { diff --git a/temporalio/bridge/src/worker.rs b/temporalio/bridge/src/worker.rs index 92b43f356..549f4268f 100644 --- a/temporalio/bridge/src/worker.rs +++ b/temporalio/bridge/src/worker.rs @@ -474,6 +474,7 @@ pub fn new_worker( config: WorkerConfig, ) -> PyResult { enter_sync!(runtime_ref.runtime); + runtime_ref.runtime.assert_same_process("create worker")?; let event_loop_task_locals = Arc::new(OnceLock::new()); let config = convert_worker_config(config, event_loop_task_locals.clone())?; let worker = temporal_sdk_core::init_worker( @@ -495,6 +496,9 @@ pub fn new_replay_worker<'a>( config: WorkerConfig, ) -> PyResult> { enter_sync!(runtime_ref.runtime); + runtime_ref + .runtime + .assert_same_process("create replay worker")?; let event_loop_task_locals = Arc::new(OnceLock::new()); let config = convert_worker_config(config, event_loop_task_locals.clone())?; let (history_pusher, stream) = HistoryPusher::new(runtime_ref.runtime.clone()); @@ -519,6 +523,7 @@ pub fn new_replay_worker<'a>( #[pymethods] impl WorkerRef { fn validate<'p>(&self, py: Python<'p>) -> PyResult> { + self.runtime.assert_same_process("use worker")?; let worker = self.worker.as_ref().unwrap().clone(); // Set custom slot supplier task locals so they can run futures. // Event loop is assumed to be running at this point. @@ -538,6 +543,7 @@ impl WorkerRef { } fn poll_workflow_activation<'p>(&self, py: Python<'p>) -> PyResult> { + self.runtime.assert_same_process("use worker")?; let worker = self.worker.as_ref().unwrap().clone(); self.runtime.future_into_py(py, async move { let bytes = match worker.poll_workflow_activation().await { @@ -550,6 +556,7 @@ impl WorkerRef { } fn poll_activity_task<'p>(&self, py: Python<'p>) -> PyResult> { + self.runtime.assert_same_process("use worker")?; let worker = self.worker.as_ref().unwrap().clone(); self.runtime.future_into_py(py, async move { let bytes = match worker.poll_activity_task().await { @@ -562,6 +569,7 @@ impl WorkerRef { } fn poll_nexus_task<'p>(&self, py: Python<'p>) -> PyResult> { + self.runtime.assert_same_process("use worker")?; let worker = self.worker.as_ref().unwrap().clone(); self.runtime.future_into_py(py, async move { let bytes = match worker.poll_nexus_task().await { diff --git a/tests/conftest.py b/tests/conftest.py index 8ffd3a456..5bff5cd46 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,8 @@ import asyncio +import multiprocessing.context import os import sys -from typing import AsyncGenerator +from typing import AsyncGenerator, Generator import pytest import pytest_asyncio @@ -133,6 +134,20 @@ async def env(env_type: str) -> AsyncGenerator[WorkflowEnvironment, None]: await env.shutdown() +@pytest.fixture(scope="session") +def mp_fork_ctx() -> Generator[multiprocessing.context.ForkContext | None]: + # ForkContext is not available on Windows + if sys.platform == "win32": + yield None + return + + mp_ctx = multiprocessing.get_context("fork") + yield mp_ctx + for p in mp_ctx.active_children(): + p.terminate() + p.join() + + @pytest_asyncio.fixture async def client(env: WorkflowEnvironment) -> Client: return env.client diff --git a/tests/helpers/fork.py b/tests/helpers/fork.py new file mode 100644 index 000000000..7f275b7ab --- /dev/null +++ b/tests/helpers/fork.py @@ -0,0 +1,75 @@ +import asyncio +import multiprocessing +import multiprocessing.context +from dataclasses import dataclass +from typing import Any, Self + +import pytest + + +@dataclass +class _ForkTestResult: + status: str + err_name: str | None + err_msg: str | None + + def __eq__(self, value: object) -> bool: + if not isinstance(value, _ForkTestResult): + return False + + valid_err_msg = False + + if self.err_msg and value.err_msg: + valid_err_msg = ( + self.err_msg in value.err_msg or value.err_msg in self.err_msg + ) + + return ( + value.status == self.status + and value.err_name == value.err_name + and valid_err_msg + ) + + @classmethod + def assertion_error(cls, message: str) -> Self: + return cls(status="error", err_name="AssertionError", err_msg=message) + + +class _TestFork: + _expected: _ForkTestResult + + async def coro(self) -> Any: + raise NotImplementedError() + + def entry(self): + event_loop = asyncio.new_event_loop() + asyncio.set_event_loop(event_loop) + try: + event_loop.run_until_complete(self.coro()) + payload = _ForkTestResult(status="ok", err_name=None, err_msg=None) + except BaseException as err: + payload = _ForkTestResult( + status="error", err_name=err.__class__.__name__, err_msg=str(err) + ) + + self._child_conn.send(payload) + self._child_conn.close() + + def run(self, mp_fork_context: multiprocessing.context.ForkContext | None): + if not mp_fork_context: + pytest.skip("fork context not available") + + self._parent_conn, self._child_conn = mp_fork_context.Pipe(duplex=False) + # start fork + child_process = mp_fork_context.Process( + target=self.entry, args=(), daemon=False + ) + child_process.start() + # close parent's handle on child_conn + self._child_conn.close() + + # get run info from pipe + payload = self._parent_conn.recv() + self._parent_conn.close() + + assert payload == self._expected diff --git a/tests/test_client.py b/tests/test_client.py index 63dec2810..a27c7925e 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,9 +1,11 @@ import dataclasses import json +import multiprocessing +import multiprocessing.context import os import uuid from datetime import datetime, timedelta, timezone -from typing import Any, List, Mapping, Optional, Tuple, Union, cast +from typing import Any, List, Mapping, Optional, Self, Tuple, Union, cast from unittest import mock import google.protobuf.any_pb2 @@ -90,6 +92,7 @@ new_worker, worker_versioning_enabled, ) +from tests.helpers.fork import _ForkTestResult, _TestFork from tests.helpers.worker import ( ExternalWorker, KSAction, @@ -1541,3 +1544,39 @@ async def get_schedule_result() -> Tuple[int, Optional[str]]: ) await handle.delete() + + +class TestForkCreateClient(_TestFork): + async def coro(self): + await Client.connect( + self._env.client.config()["service_client"].config.target_host + ) + + def test_fork_create_client( + self, + env: WorkflowEnvironment, + mp_fork_ctx: multiprocessing.context.ForkContext | None, + ): + self._expected = _ForkTestResult.assertion_error( + "Cannot create client across forks" + ) + self._env = env + self.run(mp_fork_ctx) + + +class TestForkUseClient(_TestFork): + async def coro(self): + await self._client.start_workflow( + "some-workflow", + id=f"workflow-{uuid.uuid4()}", + task_queue=f"tq-{uuid.uuid4()}", + ) + + def test_fork_use_client( + self, client: Client, mp_fork_ctx: multiprocessing.context.ForkContext | None + ): + self._expected = _ForkTestResult.assertion_error( + "Cannot use client across forks" + ) + self._client = client + self.run(mp_fork_ctx) diff --git a/tests/worker/test_worker.py b/tests/worker/test_worker.py index 32f27f631..c60206ace 100644 --- a/tests/worker/test_worker.py +++ b/tests/worker/test_worker.py @@ -2,6 +2,8 @@ import asyncio import concurrent.futures +import multiprocessing +import multiprocessing.context import uuid from datetime import timedelta from typing import Any, Awaitable, Callable, Optional, Sequence @@ -58,6 +60,7 @@ new_worker, worker_versioning_enabled, ) +from tests.helpers.fork import _ForkTestResult, _TestFork from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name @@ -1271,3 +1274,43 @@ def shutdown(self) -> None: if self.next_exception_task: self.next_exception_task.cancel() setattr(self.worker._bridge_worker, self.attr, self.orig_poll_call) + + +class TestForkCreateWorker(_TestFork): + async def coro(self): + self._worker = Worker( + self._client, + task_queue=f"task-queue-{uuid.uuid4()}", + activities=[never_run_activity], + workflows=[], + nexus_service_handlers=[], + ) + + def test_fork_create_worker( + self, client: Client, mp_fork_ctx: multiprocessing.context.ForkContext | None + ): + self._expected = _ForkTestResult.assertion_error( + "Cannot create worker across forks" + ) + self._client = client + self.run(mp_fork_ctx) + + +class TestForkUseWorker(_TestFork): + async def coro(self): + await self._pre_fork_worker.run() + + def test_fork_use_worker( + self, client: Client, mp_fork_ctx: multiprocessing.context.ForkContext | None + ): + self._expected = _ForkTestResult.assertion_error( + "Cannot use worker across forks" + ) + self._pre_fork_worker = Worker( + client, + task_queue=f"task-queue-{uuid.uuid4()}", + activities=[never_run_activity], + workflows=[], + nexus_service_handlers=[], + ) + self.run(mp_fork_ctx) From 41213ef2c7fc81f6d46bb9d087d6fbb808bd5632 Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Mon, 10 Nov 2025 15:07:52 -0800 Subject: [PATCH 02/11] Add _RuntimeRef to encapsulate default runtime creation. Add Runtime.prevent_default to allow users to more easily enforce that a default runtime is never automatically created --- temporalio/runtime.py | 74 ++++++++++++++++++++++++++++++++++++------- tests/test_runtime.py | 34 ++++++++++++++++++++ 2 files changed, 96 insertions(+), 12 deletions(-) diff --git a/temporalio/runtime.py b/temporalio/runtime.py index 64fa12192..7ed47751a 100644 --- a/temporalio/runtime.py +++ b/temporalio/runtime.py @@ -24,22 +24,62 @@ import temporalio.bridge.runtime import temporalio.common -_default_runtime: Optional[Runtime] = None + +class _RuntimeRef: + def __init__( + self, + ) -> None: + self._default_runtime: Runtime | None = None + self._prevent_default = False + self._default_created = False + + def default(self) -> Runtime: + if not self._default_runtime: + if self._prevent_default: + raise RuntimeError( + "Cannot create default Runtime after Runtime.prevent_default has been called" + ) + self._default_runtime = Runtime(telemetry=TelemetryConfig()) + self._default_created = True + return self._default_runtime + + def prevent_default(self): + if self._default_created: + raise RuntimeError( + "Runtime.prevent_default called after default runtime has been created" + ) + self._prevent_default = True + + def set_default( + self, runtime: Runtime, *, error_if_already_set: bool = True + ) -> None: + if self._default_runtime and error_if_already_set: + raise RuntimeError("Runtime default already set") + + self._default_runtime = runtime + + +_runtime_ref: _RuntimeRef = _RuntimeRef() class Runtime: """Runtime for Temporal Python SDK. - Users are encouraged to use :py:meth:`default`. It can be set with + Most users are encouraged to use :py:meth:`default`. It can be set with :py:meth:`set_default`. Every time a new runtime is created, a new internal thread pool is created. - Runtimes do not work across forks. + Runtimes do not work across forks. Advanced users should consider using + :py:meth:`prevent_default` and `:py:meth`set_default` to ensure each + fork creates it's own runtime. + """ @classmethod def default(cls) -> Runtime: - """Get the default runtime, creating if not already created. + """Get the default runtime, creating if not already created. If :py:meth:`prevent_default` + is called before this method it will raise a RuntimeError instead of creating a default + runtime. If the default runtime needs to be different, it should be done with :py:meth:`set_default` before this is called or ever used. @@ -47,10 +87,21 @@ def default(cls) -> Runtime: Returns: The default runtime. """ - global _default_runtime - if not _default_runtime: - _default_runtime = cls(telemetry=TelemetryConfig()) - return _default_runtime + + global _runtime_ref + return _runtime_ref.default() + + @classmethod + def prevent_default(cls): + """Prevent :py:meth:`default` from lazily creating a :py:class:`Runtime`. + + Raises a RuntimeError if a default :py:class:`Runtime` has already been created. + + Explicitly setting a default runtime with :py:meth:`set_default` bypasses this setting and + future calls to :py:meth:`default` will return provided runtime. + """ + global _runtime_ref + _runtime_ref.prevent_default() @staticmethod def set_default(runtime: Runtime, *, error_if_already_set: bool = True) -> None: @@ -65,10 +116,9 @@ def set_default(runtime: Runtime, *, error_if_already_set: bool = True) -> None: error_if_already_set: If True and default is already set, this will raise a RuntimeError. """ - global _default_runtime - if _default_runtime and error_if_already_set: - raise RuntimeError("Runtime default already set") - _default_runtime = runtime + global _runtime_ref + _runtime_ref.set_default(runtime, error_if_already_set=error_if_already_set) + return def __init__(self, *, telemetry: TelemetryConfig) -> None: """Create a default runtime with the given telemetry config. diff --git a/tests/test_runtime.py b/tests/test_runtime.py index 4505ebfcf..8e190094e 100644 --- a/tests/test_runtime.py +++ b/tests/test_runtime.py @@ -7,6 +7,8 @@ from typing import List, cast from urllib.request import urlopen +import pytest + from temporalio import workflow from temporalio.client import Client from temporalio.runtime import ( @@ -16,6 +18,7 @@ Runtime, TelemetryConfig, TelemetryFilter, + _RuntimeRef, ) from temporalio.worker import Worker from tests.helpers import assert_eq_eventually, assert_eventually, find_free_port @@ -254,3 +257,34 @@ async def check_metrics() -> None: # Wait for metrics to appear and match the expected buckets await assert_eventually(check_metrics) + + +def test_runtime_ref_creates_default(): + ref = _RuntimeRef() + assert not ref._default_created + ref.default() + assert ref._default_created + + +def test_runtime_ref_prevents_default(): + ref = _RuntimeRef() + ref.prevent_default() + with pytest.raises(RuntimeError): + ref.default() + + +def test_runtime_ref_prevent_default_errors_after_default(): + ref = _RuntimeRef() + ref.default() + with pytest.raises(RuntimeError): + ref.prevent_default() + + +def test_runtime_ref_set_default_allowed(): + ref = _RuntimeRef() + ref.prevent_default() + explicit_runtime = Runtime(telemetry=TelemetryConfig()) + ref.set_default(explicit_runtime) + + new_default = ref.default() + assert new_default is explicit_runtime From 59156b241db66234f6ad38501198fa5d23c237c0 Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Mon, 10 Nov 2025 15:12:46 -0800 Subject: [PATCH 03/11] remove blank line to fix linter error --- temporalio/runtime.py | 1 - 1 file changed, 1 deletion(-) diff --git a/temporalio/runtime.py b/temporalio/runtime.py index 7ed47751a..c2c7c96f9 100644 --- a/temporalio/runtime.py +++ b/temporalio/runtime.py @@ -87,7 +87,6 @@ def default(cls) -> Runtime: Returns: The default runtime. """ - global _runtime_ref return _runtime_ref.default() From 5f93b8e6271688ae85c1a63447be4e28f1bf0ab1 Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Mon, 10 Nov 2025 16:25:40 -0800 Subject: [PATCH 04/11] fix use of Self since it's not avaiable in typing until 3.11 --- tests/helpers/fork.py | 12 ++++++++---- tests/test_client.py | 2 +- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/tests/helpers/fork.py b/tests/helpers/fork.py index 7f275b7ab..d9b8ab4d0 100644 --- a/tests/helpers/fork.py +++ b/tests/helpers/fork.py @@ -1,8 +1,10 @@ +from __future__ import annotations + import asyncio import multiprocessing import multiprocessing.context from dataclasses import dataclass -from typing import Any, Self +from typing import Any import pytest @@ -30,9 +32,11 @@ def __eq__(self, value: object) -> bool: and valid_err_msg ) - @classmethod - def assertion_error(cls, message: str) -> Self: - return cls(status="error", err_name="AssertionError", err_msg=message) + @staticmethod + def assertion_error(message: str) -> _ForkTestResult: + return _ForkTestResult( + status="error", err_name="AssertionError", err_msg=message + ) class _TestFork: diff --git a/tests/test_client.py b/tests/test_client.py index a27c7925e..e302629f9 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -5,7 +5,7 @@ import os import uuid from datetime import datetime, timedelta, timezone -from typing import Any, List, Mapping, Optional, Self, Tuple, Union, cast +from typing import Any, List, Mapping, Optional, Tuple, Union, cast from unittest import mock import google.protobuf.any_pb2 From 811aa6458c27b9876207dbc662ace4febfd5a14f Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Mon, 10 Nov 2025 16:41:28 -0800 Subject: [PATCH 05/11] remove references to ForkContext to avoid exploding in Windows --- tests/conftest.py | 25 ++++++++++++++----------- tests/helpers/fork.py | 11 ++++++----- tests/test_client.py | 4 ++-- tests/worker/test_worker.py | 4 ++-- 4 files changed, 24 insertions(+), 20 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 5bff5cd46..16fe5e4dd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -135,17 +135,20 @@ async def env(env_type: str) -> AsyncGenerator[WorkflowEnvironment, None]: @pytest.fixture(scope="session") -def mp_fork_ctx() -> Generator[multiprocessing.context.ForkContext | None]: - # ForkContext is not available on Windows - if sys.platform == "win32": - yield None - return - - mp_ctx = multiprocessing.get_context("fork") - yield mp_ctx - for p in mp_ctx.active_children(): - p.terminate() - p.join() +def mp_fork_ctx() -> Generator[multiprocessing.context.BaseContext | None]: + mp_ctx = None + try: + mp_ctx = multiprocessing.get_context("fork") + except KeyError: + pass + + try: + yield mp_ctx + finally: + if mp_ctx: + for p in mp_ctx.active_children(): + p.terminate() + p.join() @pytest_asyncio.fixture diff --git a/tests/helpers/fork.py b/tests/helpers/fork.py index d9b8ab4d0..e6d84652f 100644 --- a/tests/helpers/fork.py +++ b/tests/helpers/fork.py @@ -3,6 +3,7 @@ import asyncio import multiprocessing import multiprocessing.context +import sys from dataclasses import dataclass from typing import Any @@ -59,15 +60,15 @@ def entry(self): self._child_conn.send(payload) self._child_conn.close() - def run(self, mp_fork_context: multiprocessing.context.ForkContext | None): - if not mp_fork_context: + def run(self, mp_fork_context: multiprocessing.context.BaseContext | None): + process_factory = getattr(mp_fork_context, "Process", None) + + if not mp_fork_context or not process_factory: pytest.skip("fork context not available") self._parent_conn, self._child_conn = mp_fork_context.Pipe(duplex=False) # start fork - child_process = mp_fork_context.Process( - target=self.entry, args=(), daemon=False - ) + child_process = process_factory(target=self.entry, args=(), daemon=False) child_process.start() # close parent's handle on child_conn self._child_conn.close() diff --git a/tests/test_client.py b/tests/test_client.py index e302629f9..458492ff6 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1555,7 +1555,7 @@ async def coro(self): def test_fork_create_client( self, env: WorkflowEnvironment, - mp_fork_ctx: multiprocessing.context.ForkContext | None, + mp_fork_ctx: multiprocessing.context.BaseContext | None, ): self._expected = _ForkTestResult.assertion_error( "Cannot create client across forks" @@ -1573,7 +1573,7 @@ async def coro(self): ) def test_fork_use_client( - self, client: Client, mp_fork_ctx: multiprocessing.context.ForkContext | None + self, client: Client, mp_fork_ctx: multiprocessing.context.BaseContext | None ): self._expected = _ForkTestResult.assertion_error( "Cannot use client across forks" diff --git a/tests/worker/test_worker.py b/tests/worker/test_worker.py index c60206ace..9ad4be3af 100644 --- a/tests/worker/test_worker.py +++ b/tests/worker/test_worker.py @@ -1287,7 +1287,7 @@ async def coro(self): ) def test_fork_create_worker( - self, client: Client, mp_fork_ctx: multiprocessing.context.ForkContext | None + self, client: Client, mp_fork_ctx: multiprocessing.context.BaseContext | None ): self._expected = _ForkTestResult.assertion_error( "Cannot create worker across forks" @@ -1301,7 +1301,7 @@ async def coro(self): await self._pre_fork_worker.run() def test_fork_use_worker( - self, client: Client, mp_fork_ctx: multiprocessing.context.ForkContext | None + self, client: Client, mp_fork_ctx: multiprocessing.context.BaseContext | None ): self._expected = _ForkTestResult.assertion_error( "Cannot use worker across forks" From 2461cdda8e6c0c521c342a4741e681a6b7465962 Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Mon, 10 Nov 2025 16:50:12 -0800 Subject: [PATCH 06/11] switch type of fixture to Iterator instead of Generator --- tests/conftest.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 16fe5e4dd..5953fee12 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,7 +2,7 @@ import multiprocessing.context import os import sys -from typing import AsyncGenerator, Generator +from typing import AsyncGenerator, Iterator import pytest import pytest_asyncio @@ -12,25 +12,25 @@ # If there is an integration test environment variable set, we must remove the # first path from the sys.path so we can import the wheel instead if os.getenv("TEMPORAL_INTEGRATION_TEST"): - assert ( - sys.path[0] == os.getcwd() - ), "Expected first sys.path to be the current working dir" + assert sys.path[0] == os.getcwd(), ( + "Expected first sys.path to be the current working dir" + ) sys.path.pop(0) # Import temporalio and confirm it is prefixed with virtual env import temporalio - assert temporalio.__file__.startswith( - sys.prefix - ), f"Expected {temporalio.__file__} to be in {sys.prefix}" + assert temporalio.__file__.startswith(sys.prefix), ( + f"Expected {temporalio.__file__} to be in {sys.prefix}" + ) # Unless specifically overridden, we expect tests to run under protobuf 4.x/5.x lib import google.protobuf protobuf_version = google.protobuf.__version__ if os.getenv("TEMPORAL_TEST_PROTO3"): - assert protobuf_version.startswith( - "3." - ), f"Expected protobuf 3.x, got {protobuf_version}" + assert protobuf_version.startswith("3."), ( + f"Expected protobuf 3.x, got {protobuf_version}" + ) else: assert ( protobuf_version.startswith("4.") @@ -135,7 +135,7 @@ async def env(env_type: str) -> AsyncGenerator[WorkflowEnvironment, None]: @pytest.fixture(scope="session") -def mp_fork_ctx() -> Generator[multiprocessing.context.BaseContext | None]: +def mp_fork_ctx() -> Iterator[multiprocessing.context.BaseContext | None]: mp_ctx = None try: mp_ctx = multiprocessing.get_context("fork") From ba9a8a81dcf38033a635fb79429602fc96198fef Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Mon, 10 Nov 2025 16:51:32 -0800 Subject: [PATCH 07/11] run formatter --- tests/conftest.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 5953fee12..1c2d38d7c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,25 +12,25 @@ # If there is an integration test environment variable set, we must remove the # first path from the sys.path so we can import the wheel instead if os.getenv("TEMPORAL_INTEGRATION_TEST"): - assert sys.path[0] == os.getcwd(), ( - "Expected first sys.path to be the current working dir" - ) + assert ( + sys.path[0] == os.getcwd() + ), "Expected first sys.path to be the current working dir" sys.path.pop(0) # Import temporalio and confirm it is prefixed with virtual env import temporalio - assert temporalio.__file__.startswith(sys.prefix), ( - f"Expected {temporalio.__file__} to be in {sys.prefix}" - ) + assert temporalio.__file__.startswith( + sys.prefix + ), f"Expected {temporalio.__file__} to be in {sys.prefix}" # Unless specifically overridden, we expect tests to run under protobuf 4.x/5.x lib import google.protobuf protobuf_version = google.protobuf.__version__ if os.getenv("TEMPORAL_TEST_PROTO3"): - assert protobuf_version.startswith("3."), ( - f"Expected protobuf 3.x, got {protobuf_version}" - ) + assert protobuf_version.startswith( + "3." + ), f"Expected protobuf 3.x, got {protobuf_version}" else: assert ( protobuf_version.startswith("4.") From 9301c5107cedaf06088eae90e8eace34a9d0d5d8 Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Tue, 11 Nov 2025 11:01:14 -0800 Subject: [PATCH 08/11] except the correct error type to prevent breakage on windows --- tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 1c2d38d7c..c0f8bc5e0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -139,7 +139,7 @@ def mp_fork_ctx() -> Iterator[multiprocessing.context.BaseContext | None]: mp_ctx = None try: mp_ctx = multiprocessing.get_context("fork") - except KeyError: + except ValueError: pass try: From c350ed58f956e5c3b2d837ac6f91f533a59cdd25 Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Wed, 12 Nov 2025 13:11:39 -0800 Subject: [PATCH 09/11] Update tests to match error info. Update prevent_default test to demonstrate that you can call prevent_default and then set_default to allow future calls to default. add tests for set_default --- tests/test_runtime.py | 31 +++++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/tests/test_runtime.py b/tests/test_runtime.py index 8e190094e..24b339457 100644 --- a/tests/test_runtime.py +++ b/tests/test_runtime.py @@ -269,22 +269,41 @@ def test_runtime_ref_creates_default(): def test_runtime_ref_prevents_default(): ref = _RuntimeRef() ref.prevent_default() - with pytest.raises(RuntimeError): + with pytest.raises(RuntimeError) as exc_info: ref.default() + assert exc_info.match( + "Cannot create default Runtime after Runtime.prevent_default has been called" + ) + + # explicitly setting a default runtime will allow future calls to `default()`` + explicit_runtime = Runtime(telemetry=TelemetryConfig()) + ref.set_default(explicit_runtime) + + assert ref.default() is explicit_runtime def test_runtime_ref_prevent_default_errors_after_default(): ref = _RuntimeRef() ref.default() - with pytest.raises(RuntimeError): + with pytest.raises(RuntimeError) as exc_info: ref.prevent_default() + assert exc_info.match( + "Runtime.prevent_default called after default runtime has been created" + ) + -def test_runtime_ref_set_default_allowed(): +def test_runtime_ref_set_default(): ref = _RuntimeRef() - ref.prevent_default() explicit_runtime = Runtime(telemetry=TelemetryConfig()) ref.set_default(explicit_runtime) + assert ref.default() is explicit_runtime + + new_runtime = Runtime(telemetry=TelemetryConfig()) + + with pytest.raises(RuntimeError) as exc_info: + ref.set_default(new_runtime) + assert exc_info.match("Runtime default already set") - new_default = ref.default() - assert new_default is explicit_runtime + ref.set_default(new_runtime, error_if_already_set=False) + assert ref.default() is new_runtime From db1ede034ed74947b8d771d44f109d14672363d7 Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Wed, 12 Nov 2025 13:18:44 -0800 Subject: [PATCH 10/11] fix typo in docstring --- temporalio/runtime.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/temporalio/runtime.py b/temporalio/runtime.py index c2c7c96f9..3b1903979 100644 --- a/temporalio/runtime.py +++ b/temporalio/runtime.py @@ -97,7 +97,7 @@ def prevent_default(cls): Raises a RuntimeError if a default :py:class:`Runtime` has already been created. Explicitly setting a default runtime with :py:meth:`set_default` bypasses this setting and - future calls to :py:meth:`default` will return provided runtime. + future calls to :py:meth:`default` will return the provided runtime. """ global _runtime_ref _runtime_ref.prevent_default() From 0f9bc70243b58e544506b1613df7baa06bebe9ec Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Fri, 14 Nov 2025 09:17:03 -0800 Subject: [PATCH 11/11] remove empty return in Runtime.set_default. Remove _default_created flag in favor of using the optional nature of _default_runtime in _RuntimeRef --- temporalio/runtime.py | 6 ++---- tests/test_runtime.py | 4 ++-- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/temporalio/runtime.py b/temporalio/runtime.py index 3b1903979..345d7ca77 100644 --- a/temporalio/runtime.py +++ b/temporalio/runtime.py @@ -31,7 +31,6 @@ def __init__( ) -> None: self._default_runtime: Runtime | None = None self._prevent_default = False - self._default_created = False def default(self) -> Runtime: if not self._default_runtime: @@ -44,9 +43,9 @@ def default(self) -> Runtime: return self._default_runtime def prevent_default(self): - if self._default_created: + if self._default_runtime: raise RuntimeError( - "Runtime.prevent_default called after default runtime has been created" + "Runtime.prevent_default called after default runtime has been created or set" ) self._prevent_default = True @@ -117,7 +116,6 @@ def set_default(runtime: Runtime, *, error_if_already_set: bool = True) -> None: """ global _runtime_ref _runtime_ref.set_default(runtime, error_if_already_set=error_if_already_set) - return def __init__(self, *, telemetry: TelemetryConfig) -> None: """Create a default runtime with the given telemetry config. diff --git a/tests/test_runtime.py b/tests/test_runtime.py index 24b339457..9b318bbb7 100644 --- a/tests/test_runtime.py +++ b/tests/test_runtime.py @@ -261,9 +261,9 @@ async def check_metrics() -> None: def test_runtime_ref_creates_default(): ref = _RuntimeRef() - assert not ref._default_created + assert not ref._default_runtime ref.default() - assert ref._default_created + assert ref._default_runtime def test_runtime_ref_prevents_default():