diff --git a/temporalio/worker/workflow_sandbox/__init__.py b/temporalio/worker/workflow_sandbox/__init__.py index 399e633cc..1bdc62954 100644 --- a/temporalio/worker/workflow_sandbox/__init__.py +++ b/temporalio/worker/workflow_sandbox/__init__.py @@ -58,11 +58,13 @@ RestrictedWorkflowAccessError, SandboxMatcher, SandboxRestrictions, + UnintentionalPassthroughError, ) from ._runner import SandboxedWorkflowRunner __all__ = [ "RestrictedWorkflowAccessError", + "UnintentionalPassthroughError", "SandboxedWorkflowRunner", "SandboxMatcher", "SandboxRestrictions", diff --git a/temporalio/worker/workflow_sandbox/_importer.py b/temporalio/worker/workflow_sandbox/_importer.py index 7fd4beaac..dd5bb050c 100644 --- a/temporalio/worker/workflow_sandbox/_importer.py +++ b/temporalio/worker/workflow_sandbox/_importer.py @@ -42,6 +42,7 @@ RestrictedWorkflowAccessError, RestrictionContext, SandboxRestrictions, + UnintentionalPassthroughError, ) logger = logging.getLogger(__name__) @@ -200,6 +201,17 @@ def _import( # Check module restrictions and passthrough modules if full_name not in sys.modules: + # Issue a warning if appropriate + if ( + self.restriction_context.in_activation + and self._is_import_notification_policy_applied( + temporalio.workflow.SandboxImportNotificationPolicy.WARN_ON_DYNAMIC_IMPORT + ) + ): + warnings.warn( + f"Module {full_name} was imported after initial workflow load." + ) + # Make sure not an entirely invalid module self._assert_valid_module(full_name) @@ -282,6 +294,17 @@ def module_configured_passthrough(self, name: str) -> bool: break return True + def _is_import_notification_policy_applied( + self, policy: temporalio.workflow.SandboxImportNotificationPolicy + ) -> bool: + override_policy = ( + temporalio.workflow.unsafe.current_import_notification_policy_override() + ) + if override_policy: + return policy in override_policy + + return policy in self.restrictions.import_notification_policy + def _maybe_passthrough_module(self, name: str) -> Optional[types.ModuleType]: # If imports not passed through and all modules are not passed through # and name not in passthrough modules, check parents @@ -289,6 +312,18 @@ def _maybe_passthrough_module(self, name: str) -> Optional[types.ModuleType]: not temporalio.workflow.unsafe.is_imports_passed_through() and not self.module_configured_passthrough(name) ): + if self._is_import_notification_policy_applied( + temporalio.workflow.SandboxImportNotificationPolicy.RAISE_ON_UNINTENTIONAL_PASSTHROUGH + ): + raise UnintentionalPassthroughError(name) + + if self._is_import_notification_policy_applied( + temporalio.workflow.SandboxImportNotificationPolicy.WARN_ON_UNINTENTIONAL_PASSTHROUGH + ): + warnings.warn( + f"Module {name} was not intentionally passed through to the sandbox." + ) + return None # Do the pass through with self._unapplied(): diff --git a/temporalio/worker/workflow_sandbox/_restrictions.py b/temporalio/worker/workflow_sandbox/_restrictions.py index baad22fcb..178331a56 100644 --- a/temporalio/worker/workflow_sandbox/_restrictions.py +++ b/temporalio/worker/workflow_sandbox/_restrictions.py @@ -42,6 +42,7 @@ except ImportError: HAVE_PYDANTIC = False +import temporalio.exceptions import temporalio.workflow logger = logging.getLogger(__name__) @@ -82,6 +83,21 @@ def default_message(qualified_name: str) -> str: ) +class UnintentionalPassthroughError(temporalio.exceptions.TemporalError): + """Error that occurs when a workflow unintentionally passes an import to the sandbox when + the import notification policy includes :py:attr:`temporalio.workflow.SandboxImportNotificationPolicy.RAISE_ON_NON_PASSTHROUGH`. + + Attributes: + qualified_name: Fully qualified name of what was passed through to the sandbox. + """ + + def __init__(self, qualified_name: str) -> None: + """Create an unintentional passthrough error.""" + super().__init__( + f"Module {qualified_name} was not intentionally passed through to the sandbox." + ) + + @dataclass(frozen=True) class SandboxRestrictions: """Set of restrictions that can be applied to a sandbox.""" @@ -110,6 +126,13 @@ class methods (including __init__, etc). The check compares the against the fully qualified path to the item. """ + import_notification_policy: temporalio.workflow.SandboxImportNotificationPolicy = ( + temporalio.workflow.SandboxImportNotificationPolicy.WARN_ON_DYNAMIC_IMPORT + ) + """ + The import notification policy to use when an import is triggered during workflow loading or execution. See :py:class:`temporalio.workflow.SandboxImportNotificationPolicy` for options. + """ + passthrough_all_modules: bool = False """ Pass through all modules, do not sandbox any modules. This is the equivalent @@ -170,6 +193,12 @@ def with_passthrough_all_modules(self) -> SandboxRestrictions: """ return dataclasses.replace(self, passthrough_all_modules=True) + def with_import_notification_policy( + self, policy: temporalio.workflow.SandboxImportNotificationPolicy + ) -> SandboxRestrictions: + """Create a new restriction set with the given import notification policy as the :py:attr:`import_policy`.""" + return dataclasses.replace(self, import_notification_policy=policy) + # We intentionally use specific fields instead of generic "matcher" callbacks # for optimization reasons. @@ -305,10 +334,12 @@ def access_matcher( if not child_matcher: return None matcher = child_matcher + if not context.is_runtime and matcher.only_runtime: return None if not matcher.match_self: return None + return matcher def match_access( @@ -819,6 +850,7 @@ def unwrap_if_proxied(v: Any) -> Any: def __init__(self) -> None: """Create a restriction context.""" self.is_runtime = False + self.in_activation = False @dataclass diff --git a/temporalio/worker/workflow_sandbox/_runner.py b/temporalio/worker/workflow_sandbox/_runner.py index e1a48871d..f87736f9d 100644 --- a/temporalio/worker/workflow_sandbox/_runner.py +++ b/temporalio/worker/workflow_sandbox/_runner.py @@ -159,6 +159,7 @@ def activate( self, act: temporalio.bridge.proto.workflow_activation.WorkflowActivation ) -> temporalio.bridge.proto.workflow_completion.WorkflowActivationCompletion: self.importer.restriction_context.is_runtime = True + self.importer.restriction_context.in_activation = True try: self._run_code( "with __temporal_importer.applied():\n" @@ -169,6 +170,7 @@ def activate( return self.globals_and_locals.pop("__temporal_completion") # type: ignore finally: self.importer.restriction_context.is_runtime = False + self.importer.restriction_context.in_activation = False def _run_code(self, code: str, **extra_globals: Any) -> None: for k, v in extra_globals.items(): diff --git a/temporalio/workflow.py b/temporalio/workflow.py index 9168d328c..1dc70c8b4 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -13,7 +13,7 @@ from contextlib import contextmanager from dataclasses import dataclass from datetime import datetime, timedelta, timezone -from enum import Enum, IntEnum +from enum import Enum, Flag, IntEnum, auto from functools import partial from random import Random from typing import ( @@ -1410,6 +1410,20 @@ async def wait_condition( _sandbox_unrestricted = threading.local() _in_sandbox = threading.local() _imports_passed_through = threading.local() +_sandbox_import_notification_policy_override = threading.local() + + +class SandboxImportNotificationPolicy(Flag): + """Defines the behavior taken when modules are imported into the sandbox after the workflow is initially loaded or unintentionally missing from the passthrough list.""" + + SILENT = auto() + """Allow imports that do not violate sandbox restrictions and no warnings are generated.""" + WARN_ON_DYNAMIC_IMPORT = auto() + """Allows dynamic imports that do not violate sandbox restrictions but issues a warning when an import is triggered in the sandbox after initial workflow load.""" + WARN_ON_UNINTENTIONAL_PASSTHROUGH = auto() + """Allows imports that do not violate sandbox restrictions but issues a warning when an import is triggered in the sandbox that was unintentionally passed through.""" + RAISE_ON_UNINTENTIONAL_PASSTHROUGH = auto() + """Raise an error when an import is triggered in the sandbox that was unintentionally passed through.""" class unsafe: @@ -1498,6 +1512,35 @@ def imports_passed_through() -> Iterator[None]: finally: _imports_passed_through.value = False + @staticmethod + def current_import_notification_policy_override() -> ( + Optional[SandboxImportNotificationPolicy] + ): + """Gets the current import notification policy override if one is set.""" + applied_policy = getattr( + _sandbox_import_notification_policy_override, + "value", + None, + ) + return applied_policy + + @staticmethod + @contextmanager + def sandbox_import_notification_policy( + policy: SandboxImportNotificationPolicy, + ) -> Iterator[None]: + """Context manager to apply the given import notification policy.""" + original_policy = _sandbox_import_notification_policy_override.value = getattr( + _sandbox_import_notification_policy_override, + "value", + None, + ) + _sandbox_import_notification_policy_override.value = policy + try: + yield None + finally: + _sandbox_import_notification_policy_override.value = original_policy + class LoggerAdapter(logging.LoggerAdapter): """Adapter that adds details to the log about the running workflow. diff --git a/tests/worker/workflow_sandbox/test_runner.py b/tests/worker/workflow_sandbox/test_runner.py index 73a49c420..fec42dd0c 100644 --- a/tests/worker/workflow_sandbox/test_runner.py +++ b/tests/worker/workflow_sandbox/test_runner.py @@ -11,20 +11,27 @@ from dataclasses import dataclass from datetime import date, datetime, timedelta from enum import IntEnum -from typing import Callable, Dict, List, Optional, Sequence, Set, Type +from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Type import pytest from temporalio import activity, workflow from temporalio.client import Client, WorkflowFailureError, WorkflowHandle from temporalio.exceptions import ApplicationError -from temporalio.worker import Worker +from temporalio.worker import Worker, WorkflowInboundInterceptor +from temporalio.worker._interceptor import ( + ExecuteWorkflowInput, + Interceptor, + WorkflowInterceptorClassInput, +) from temporalio.worker.workflow_sandbox import ( RestrictedWorkflowAccessError, SandboxedWorkflowRunner, SandboxMatcher, SandboxRestrictions, + UnintentionalPassthroughError, ) +from temporalio.workflow import SandboxImportNotificationPolicy from tests.helpers import assert_eq_eventually from tests.worker.workflow_sandbox.testmodules import stateful_module from tests.worker.workflow_sandbox.testmodules.proto import SomeMessage @@ -35,7 +42,7 @@ _ = os.name # This used to fail because our __init__ couldn't handle metaclass init -import zipfile +import zipfile # noqa: E402 class MyZipFile(zipfile.ZipFile): @@ -483,3 +490,179 @@ def new_worker( activities=activities, workflow_runner=SandboxedWorkflowRunner(restrictions=restrictions), ) + + +class _TestWorkflowInboundInterceptor(WorkflowInboundInterceptor): + async def execute_workflow(self, input: ExecuteWorkflowInput) -> Any: + # import in the interceptor to show it will be captured + # applying this policy should squelch the "after initial workload" warning + with workflow.unsafe.sandbox_import_notification_policy( + workflow.SandboxImportNotificationPolicy.WARN_ON_UNINTENTIONAL_PASSTHROUGH + ): + import tests.worker.workflow_sandbox.testmodules.lazy_module_interceptor # noqa: F401 + + return await super().execute_workflow(input) + + +class _TestInterceptor(Interceptor): + def workflow_interceptor_class( + self, input: WorkflowInterceptorClassInput + ) -> Type[_TestWorkflowInboundInterceptor]: + return _TestWorkflowInboundInterceptor + + +@workflow.defn +class LazyImportWorkflow: + @workflow.run + async def run(self) -> None: + try: + import tests.worker.workflow_sandbox.testmodules.lazy_module # noqa: F401 + except UnintentionalPassthroughError as err: + raise ApplicationError( + str(err), type="UnintentionalPassthroughError" + ) from err + + +async def test_workflow_sandbox_import_default_warnings(client: Client): + restrictions = dataclasses.replace( + SandboxRestrictions.default, + # passthrough this test module to avoid a ton of noisy warnings + passthrough_modules=SandboxRestrictions.passthrough_modules_default + | {"tests.worker.workflow_sandbox.test_runner"}, + ) + + async with Worker( + client, + task_queue=str(uuid.uuid4()), + workflows=[LazyImportWorkflow], + workflow_runner=SandboxedWorkflowRunner(restrictions), + ) as worker: + with pytest.warns() as recorder: + await client.execute_workflow( + LazyImportWorkflow.run, + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + _assert_expected_warnings( + recorder, + { + "Module tests.worker.workflow_sandbox.testmodules.lazy_module was imported after initial workflow load.", + }, + ) + + +async def test_workflow_sandbox_import_all_warnings(client: Client): + restrictions = dataclasses.replace( + SandboxRestrictions.default, + import_notification_policy=SandboxImportNotificationPolicy.WARN_ON_DYNAMIC_IMPORT + | SandboxImportNotificationPolicy.WARN_ON_UNINTENTIONAL_PASSTHROUGH, + # passthrough this test module to avoid a ton of noisy warnings + passthrough_modules=SandboxRestrictions.passthrough_modules_default + | {"tests.worker.workflow_sandbox.test_runner"}, + ) + + async with Worker( + client, + task_queue=str(uuid.uuid4()), + workflows=[LazyImportWorkflow], + interceptors=[_TestInterceptor()], + workflow_runner=SandboxedWorkflowRunner(restrictions), + ) as worker: + with pytest.warns() as recorder: + await client.execute_workflow( + LazyImportWorkflow.run, + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + _assert_expected_warnings( + recorder, + { + "Module tests.worker.workflow_sandbox.testmodules.lazy_module_interceptor was not intentionally passed through to the sandbox.", + "Module tests.worker.workflow_sandbox.testmodules.lazy_module was imported after initial workflow load.", + "Module tests.worker.workflow_sandbox.testmodules.lazy_module was not intentionally passed through to the sandbox.", + }, + ) + + +async def test_workflow_sandbox_import_errors(client: Client): + restrictions = dataclasses.replace( + SandboxRestrictions.default, + import_notification_policy=SandboxImportNotificationPolicy.WARN_ON_DYNAMIC_IMPORT + | SandboxImportNotificationPolicy.RAISE_ON_UNINTENTIONAL_PASSTHROUGH, + # passthrough this test module to avoid a ton of noisy warnings + passthrough_modules=SandboxRestrictions.passthrough_modules_default + | {"tests.worker.workflow_sandbox.test_runner"}, + ) + + async with Worker( + client, + task_queue=str(uuid.uuid4()), + workflows=[LazyImportWorkflow], + workflow_runner=SandboxedWorkflowRunner(restrictions), + ) as worker: + with pytest.warns() as recorder: + with pytest.raises(WorkflowFailureError) as err: + await client.execute_workflow( + LazyImportWorkflow.run, + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + assert isinstance(err.value.cause, ApplicationError) + assert err.value.cause.type == "UnintentionalPassthroughError" + assert ( + "Module tests.worker.workflow_sandbox.testmodules.lazy_module was not intentionally passed through to the sandbox." + == err.value.cause.message + ) + + _assert_expected_warnings( + recorder, + { + "Module tests.worker.workflow_sandbox.testmodules.lazy_module was imported after initial workflow load.", + }, + ) + + +@workflow.defn +class SupressWarningsLazyImportWorkflow: + @workflow.run + async def run(self) -> None: + with workflow.unsafe.sandbox_import_notification_policy( + SandboxImportNotificationPolicy.SILENT + ): + try: + import tests.worker.workflow_sandbox.testmodules.lazy_module # noqa: F401 + except UserWarning: + raise ApplicationError("No warnings were expected") + + +async def test_workflow_sandbox_import_suppress_warnings(client: Client): + restrictions = dataclasses.replace( + SandboxRestrictions.default, + # passthrough this test module to avoid a ton of noisy warnings + passthrough_modules=SandboxRestrictions.passthrough_modules_default + | {"tests.worker.workflow_sandbox.test_runner"}, + ) + + async with Worker( + client, + task_queue=str(uuid.uuid4()), + workflows=[SupressWarningsLazyImportWorkflow], + workflow_runner=SandboxedWorkflowRunner(restrictions), + ) as worker: + with pytest.warns(None) as recorder: # type:ignore + await client.execute_workflow( + SupressWarningsLazyImportWorkflow.run, + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + assert len(recorder) == 0, "Expected no warnings to be issued" + + +def _assert_expected_warnings( + recorder: pytest.WarningsRecorder, expected_warnings: Set[str] +): + actual_warnings = {str(w.message) for w in recorder} + assert expected_warnings <= actual_warnings diff --git a/tests/worker/workflow_sandbox/testmodules/lazy_module.py b/tests/worker/workflow_sandbox/testmodules/lazy_module.py new file mode 100644 index 000000000..3c378a328 --- /dev/null +++ b/tests/worker/workflow_sandbox/testmodules/lazy_module.py @@ -0,0 +1,2 @@ +# intentionally empty +# used during import warning tests diff --git a/tests/worker/workflow_sandbox/testmodules/lazy_module_interceptor.py b/tests/worker/workflow_sandbox/testmodules/lazy_module_interceptor.py new file mode 100644 index 000000000..3c378a328 --- /dev/null +++ b/tests/worker/workflow_sandbox/testmodules/lazy_module_interceptor.py @@ -0,0 +1,2 @@ +# intentionally empty +# used during import warning tests