Skip to content

Commit 29f934f

Browse files
committed
initial explorations into warnings and errors for dynamic imports and non-passthroughs
1 parent 5994a45 commit 29f934f

File tree

6 files changed

+219
-3
lines changed

6 files changed

+219
-3
lines changed

temporalio/worker/workflow_sandbox/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,13 @@
5858
RestrictedWorkflowAccessError,
5959
SandboxMatcher,
6060
SandboxRestrictions,
61+
DisallowedUnintentionalPassthroughError,
6162
)
6263
from ._runner import SandboxedWorkflowRunner
6364

6465
__all__ = [
6566
"RestrictedWorkflowAccessError",
67+
"DisallowedUnintentionalPassthroughError",
6668
"SandboxedWorkflowRunner",
6769
"SandboxMatcher",
6870
"SandboxRestrictions",

temporalio/worker/workflow_sandbox/_importer.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from ._restrictions import (
4141
RestrictedModule,
4242
RestrictedWorkflowAccessError,
43+
DisallowedUnintentionalPassthroughError,
4344
RestrictionContext,
4445
SandboxRestrictions,
4546
)
@@ -200,6 +201,17 @@ def _import(
200201

201202
# Check module restrictions and passthrough modules
202203
if full_name not in sys.modules:
204+
# Issue a warning if appropriate
205+
if (
206+
self.restriction_context.in_activation
207+
and self._is_import_policy_applied(
208+
temporalio.workflow.SandboxImportPolicy.WARN_ON_DYNAMIC_IMPORT
209+
)
210+
):
211+
warnings.warn(
212+
f"Module {name} was imported after initial workflow load."
213+
)
214+
203215
# Make sure not an entirely invalid module
204216
self._assert_valid_module(full_name)
205217

@@ -282,13 +294,35 @@ def module_configured_passthrough(self, name: str) -> bool:
282294
break
283295
return True
284296

297+
def _is_import_policy_applied(
    self, policy: temporalio.workflow.SandboxImportPolicy
) -> bool:
    """Report whether ``policy`` is part of the import policy in effect.

    The policy returned by
    ``temporalio.workflow.unsafe.current_sandbox_import_policy()`` is used
    unless it equals ``SandboxImportPolicy.UNSET``; in that case the
    ``import_policy`` configured on this importer's restrictions is used.

    Args:
        policy: Flag to look for.

    Returns:
        True if ``policy`` is contained in the effective policy.
    """
    scoped = temporalio.workflow.unsafe.current_sandbox_import_policy()
    if scoped == temporalio.workflow.SandboxImportPolicy.UNSET:
        # Nothing scoped via the context manager; use the configured policy
        effective = self.restrictions.import_policy
    else:
        effective = scoped
    return policy in effective
305+
285306
def _maybe_passthrough_module(self, name: str) -> Optional[types.ModuleType]:
286307
# If imports not passed through and all modules are not passed through
287308
# and name not in passthrough modules, check parents
288309
if (
289310
not temporalio.workflow.unsafe.is_imports_passed_through()
290311
and not self.module_configured_passthrough(name)
291312
):
313+
if self._is_import_policy_applied(
314+
temporalio.workflow.SandboxImportPolicy.RAISE_ON_NON_PASSTHROUGH
315+
):
316+
# TODO(amazzeo): this is not an appropriate error type
317+
raise DisallowedUnintentionalPassthroughError(name)
318+
319+
if self._is_import_policy_applied(
320+
temporalio.workflow.SandboxImportPolicy.WARN_ON_NON_PASSTHROUGH
321+
):
322+
warnings.warn(
323+
f"Module {name} was not intentionally passed through to the sandbox."
324+
)
325+
292326
return None
293327
# Do the pass through
294328
with self._unapplied():

temporalio/worker/workflow_sandbox/_restrictions.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,23 @@ def default_message(qualified_name: str) -> str:
8282
)
8383

8484

85+
# TODO(amazzeo): is NondeterminismError appropriate as a base class?
86+
class DisallowedUnintentionalPassthroughError(temporalio.workflow.NondeterminismError):
    """Error for a module import that was not intentionally passed through.

    Attributes:
        qualified_name: Module name given to the constructor.
    """

    def __init__(
        self, qualified_name: str, *, override_message: Optional[str] = None
    ) -> None:
        """Create the error.

        Args:
            qualified_name: Name of the module that was imported.
            override_message: Message to use in place of the default one. Any
                falsy value (``None`` or an empty string) selects the default.
        """
        message = override_message
        if not message:
            message = DisallowedUnintentionalPassthroughError.default_message(
                qualified_name
            )
        super().__init__(message)
        self.qualified_name = qualified_name

    @staticmethod
    def default_message(qualified_name: str) -> str:
        """Get the default message for a module that was not passed through."""
        return f"Module {qualified_name} was not intentionally passed through to the sandbox."
100+
101+
85102
@dataclass(frozen=True)
86103
class SandboxRestrictions:
87104
"""Set of restrictions that can be applied to a sandbox."""
@@ -110,6 +127,15 @@ class methods (including __init__, etc). The check compares the against the
110127
fully qualified path to the item.
111128
"""
112129

130+
import_policy: temporalio.workflow.SandboxImportPolicy = (
131+
temporalio.workflow.SandboxImportPolicy.WARN_ON_DYNAMIC_IMPORT
132+
)
133+
134+
import_policy_all_warnings: ClassVar[temporalio.workflow.SandboxImportPolicy]
135+
import_policy_disallow_unintentional_passthrough: ClassVar[
136+
temporalio.workflow.SandboxImportPolicy
137+
]
138+
113139
passthrough_all_modules: bool = False
114140
"""
115141
Pass through all modules, do not sandbox any modules. This is the equivalent
@@ -285,6 +311,7 @@ def access_matcher(
285311
Returns:
286312
The matcher if matched.
287313
"""
314+
288315
# We prefer to avoid recursion
289316
matcher = self
290317
for v in child_path:
@@ -305,10 +332,12 @@ def access_matcher(
305332
if not child_matcher:
306333
return None
307334
matcher = child_matcher
335+
308336
if not context.is_runtime and matcher.only_runtime:
309337
return None
310338
if not matcher.match_self:
311339
return None
340+
312341
return matcher
313342

314343
def match_access(
@@ -496,6 +525,15 @@ def with_child_unrestricted(self, *child_path: str) -> SandboxMatcher:
496525
}
497526
)
498527

528+
SandboxRestrictions.import_policy_all_warnings = (
529+
temporalio.workflow.SandboxImportPolicy.WARN_ON_DYNAMIC_IMPORT
530+
| temporalio.workflow.SandboxImportPolicy.WARN_ON_NON_PASSTHROUGH
531+
)
532+
SandboxRestrictions.import_policy_disallow_unintentional_passthrough = (
533+
temporalio.workflow.SandboxImportPolicy.WARN_ON_DYNAMIC_IMPORT
534+
| temporalio.workflow.SandboxImportPolicy.RAISE_ON_NON_PASSTHROUGH
535+
)
536+
499537
# sys.stdlib_module_names is only available on 3.10+, so we hardcode here. A
500538
# test will fail if this list doesn't match the latest Python version it was
501539
# generated against, spitting out the expected list. This is a string instead
@@ -819,6 +857,7 @@ def unwrap_if_proxied(v: Any) -> Any:
819857
def __init__(self) -> None:
820858
"""Create a restriction context."""
821859
self.is_runtime = False
860+
self.in_activation = False
822861

823862

824863
@dataclass

temporalio/worker/workflow_sandbox/_runner.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ def activate(
159159
self, act: temporalio.bridge.proto.workflow_activation.WorkflowActivation
160160
) -> temporalio.bridge.proto.workflow_completion.WorkflowActivationCompletion:
161161
self.importer.restriction_context.is_runtime = True
162+
self.importer.restriction_context.in_activation = True
162163
try:
163164
self._run_code(
164165
"with __temporal_importer.applied():\n"
@@ -169,6 +170,7 @@ def activate(
169170
return self.globals_and_locals.pop("__temporal_completion") # type: ignore
170171
finally:
171172
self.importer.restriction_context.is_runtime = False
173+
self.importer.restriction_context.in_activation = False
172174

173175
def _run_code(self, code: str, **extra_globals: Any) -> None:
174176
for k, v in extra_globals.items():

temporalio/workflow.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from contextlib import contextmanager
1414
from dataclasses import dataclass
1515
from datetime import datetime, timedelta, timezone
16-
from enum import Enum, IntEnum
16+
from enum import Enum, Flag, IntEnum, auto
1717
from functools import partial
1818
from random import Random
1919
from typing import (
@@ -1412,6 +1412,24 @@ async def wait_condition(
14121412
_imports_passed_through = threading.local()
14131413

14141414

1415+
# TODO(amazzeo): this name is confusing with existing import passthrough phrasing.
1416+
class SandboxImportPolicy(Flag):
    """Behavior taken when modules are imported into the sandbox after the
    workflow is initially loaded.

    This is a :py:class:`enum.Flag`, so members can be combined with ``|`` and
    tested with ``in``.
    """

    # TODO(amazzeo): Need to establish consistent phrasing about the time this is applied

    UNSET = auto()
    """No warning is generated; every import that does not violate sandbox
    restrictions is allowed."""

    WARN_ON_DYNAMIC_IMPORT = auto()
    """Warn when an import is triggered in the sandbox after the initial
    workflow load."""

    WARN_ON_NON_PASSTHROUGH = auto()
    """Warn when an import is triggered in the sandbox for a module that was
    not passed through."""

    RAISE_ON_NON_PASSTHROUGH = auto()
    """Raise an error when an import is triggered in the sandbox for a module
    that was not passed through."""
1428+
1429+
1430+
_sandbox_import_policy = threading.local()
1431+
1432+
14151433
class unsafe:
14161434
"""Contains static methods that should not normally be called during
14171435
workflow execution except in advanced cases.
@@ -1498,6 +1516,26 @@ def imports_passed_through() -> Iterator[None]:
14981516
finally:
14991517
_imports_passed_through.value = False
15001518

1519+
@staticmethod
def current_sandbox_import_policy() -> SandboxImportPolicy:
    """Get the sandbox import policy stored for the calling thread.

    Returns:
        The thread-local policy value, or ``SandboxImportPolicy.UNSET`` if no
        value has been stored on this thread.
    """
    try:
        return _sandbox_import_policy.value
    except AttributeError:
        # Thread-local has never been assigned on this thread
        return SandboxImportPolicy.UNSET
1525+
1526+
@staticmethod
@contextmanager
def sandbox_import_policy(policy: SandboxImportPolicy) -> Iterator[None]:
    """Context manager that applies a sandbox import policy on this thread.

    While the ``with`` body runs, the thread-local policy value is
    ``policy``. On exit, including when the body raises, the previously
    stored value is restored (``SandboxImportPolicy.UNSET`` if nothing had
    been stored).

    Args:
        policy: Policy to apply for the duration of the block.
    """
    # TODO(amazzeo): the default behavior here seems inappropriate
    # Only read the previous value here; it is written back in ``finally``.
    previous_policy = getattr(
        _sandbox_import_policy, "value", SandboxImportPolicy.UNSET
    )
    _sandbox_import_policy.value = policy
    try:
        yield None
    finally:
        _sandbox_import_policy.value = previous_policy
1538+
15011539

15021540
class LoggerAdapter(logging.LoggerAdapter):
15031541
"""Adapter that adds details to the log about the running workflow.

tests/worker/workflow_sandbox/test_runner.py

Lines changed: 103 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import asyncio
44
import dataclasses
55
import functools
6+
import importlib
67
import inspect
78
import os
89
import sys
@@ -11,19 +12,25 @@
1112
from dataclasses import dataclass
1213
from datetime import date, datetime, timedelta
1314
from enum import IntEnum
14-
from typing import Callable, Dict, List, Optional, Sequence, Set, Type
15+
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Type
1516

1617
import pytest
1718

1819
from temporalio import activity, workflow
1920
from temporalio.client import Client, WorkflowFailureError, WorkflowHandle
2021
from temporalio.exceptions import ApplicationError
21-
from temporalio.worker import Worker
22+
from temporalio.worker import Worker, WorkflowInboundInterceptor
23+
from temporalio.worker._interceptor import (
24+
ExecuteWorkflowInput,
25+
Interceptor,
26+
WorkflowInterceptorClassInput,
27+
)
2228
from temporalio.worker.workflow_sandbox import (
2329
RestrictedWorkflowAccessError,
2430
SandboxedWorkflowRunner,
2531
SandboxMatcher,
2632
SandboxRestrictions,
33+
DisallowedUnintentionalPassthroughError,
2734
)
2835
from tests.helpers import assert_eq_eventually
2936
from tests.worker.workflow_sandbox.testmodules import stateful_module
@@ -483,3 +490,97 @@ def new_worker(
483490
activities=activities,
484491
workflow_runner=SandboxedWorkflowRunner(restrictions=restrictions),
485492
)
493+
494+
495+
class _TestWorkflowInboundInterceptor(WorkflowInboundInterceptor):
    """Inbound interceptor that imports ``opentelemetry`` before delegating."""

    async def execute_workflow(self, input: ExecuteWorkflowInput) -> Any:
        """Import ``opentelemetry`` under an explicit policy, then run the workflow."""
        # Import in the interceptor to show it will be captured. Applying this
        # policy should squelch the "after initial workflow load" warning.
        warn_only = workflow.SandboxImportPolicy.WARN_ON_NON_PASSTHROUGH
        with workflow.unsafe.sandbox_import_policy(warn_only):
            import opentelemetry

        return await super().execute_workflow(input)
505+
506+
507+
class _TestInterceptor(Interceptor):
    """Worker interceptor that supplies ``_TestWorkflowInboundInterceptor``."""

    def workflow_interceptor_class(
        self, input: WorkflowInterceptorClassInput
    ) -> Type[_TestWorkflowInboundInterceptor]:
        """Return the inbound interceptor class; ``input`` is not consulted."""
        inbound_cls = _TestWorkflowInboundInterceptor
        return inbound_cls
512+
513+
514+
@workflow.defn
class LazyImportWorkflow:
    """Workflow whose run method imports ``opentelemetry.version`` lazily."""

    @workflow.run
    async def run(self) -> None:
        """Import inside the run method.

        Raises:
            ApplicationError: If the import raises
                ``DisallowedUnintentionalPassthroughError``; the original
                error is chained as the cause.
        """
        try:
            import opentelemetry.version
        except DisallowedUnintentionalPassthroughError as exc:
            error_type = "DisallowedUnintentionalPassthroughError"
            raise ApplicationError(str(exc), type=error_type) from exc
524+
525+
526+
async def test_workflow_sandbox_import_warnings(client: Client):
    """Run the lazy-import workflow under the all-warnings import policy and
    check that the expected warnings were emitted."""
    # passthrough this test module to avoid a ton of noisy warnings
    passthrough = SandboxRestrictions.passthrough_modules_default | {
        "tests.worker.workflow_sandbox.test_runner"
    }
    restrictions = dataclasses.replace(
        SandboxRestrictions.default,
        import_policy=SandboxRestrictions.import_policy_all_warnings,
        passthrough_modules=passthrough,
    )

    async with Worker(
        client,
        task_queue=str(uuid.uuid4()),
        workflows=[LazyImportWorkflow],
        interceptors=[_TestInterceptor()],
        workflow_runner=SandboxedWorkflowRunner(restrictions),
    ) as worker:
        with pytest.warns() as records:
            await client.execute_workflow(
                LazyImportWorkflow.run,
                id=f"workflow-{uuid.uuid4()}",
                task_queue=worker.task_queue,
            )

        actual_warnings = {str(record.message) for record in records}
        expected_warnings = {
            "Module opentelemetry was not intentionally passed through to the sandbox.",
            "Module opentelemetry.version was imported after initial workflow load.",
            "Module opentelemetry.version was not intentionally passed through to the sandbox.",
        }

        # Other warnings may be present; only require the expected ones
        missing = expected_warnings - actual_warnings
        assert not missing
557+
558+
559+
async def test_workflow_sandbox_import_errors(client: Client):
    """Run the lazy-import workflow under the disallow-unintentional-passthrough
    policy and check that it fails with the expected application error."""
    # passthrough this test module to avoid a ton of noisy warnings
    passthrough = SandboxRestrictions.passthrough_modules_default | {
        "tests.worker.workflow_sandbox.test_runner"
    }
    restrictions = dataclasses.replace(
        SandboxRestrictions.default,
        import_policy=SandboxRestrictions.import_policy_disallow_unintentional_passthrough,
        passthrough_modules=passthrough,
    )

    async with Worker(
        client,
        task_queue=str(uuid.uuid4()),
        workflows=[LazyImportWorkflow],
        workflow_runner=SandboxedWorkflowRunner(restrictions),
    ) as worker:
        with pytest.raises(WorkflowFailureError) as err:
            await client.execute_workflow(
                LazyImportWorkflow.run,
                id=f"workflow-{uuid.uuid4()}",
                task_queue=worker.task_queue,
            )

        cause = err.value.cause
        assert isinstance(cause, ApplicationError)
        assert cause.type == "DisallowedUnintentionalPassthroughError"
        assert (
            cause.message
            == "Module opentelemetry.version was not intentionally passed through to the sandbox."
        )

0 commit comments

Comments
 (0)