Skip to content

Commit 1694f3f

Browse files
committed
Make it easier to modify sandbox restrictions
1 parent b485c77 commit 1694f3f

File tree

2 files changed

+27
-2
lines changed

2 files changed

+27
-2
lines changed

temporalio/worker/workflow_sandbox/_runner.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from __future__ import annotations
88

99
import threading
10+
from dataclasses import dataclass
1011
from datetime import datetime, timedelta, timezone
1112
from typing import Any, Optional, Sequence, Type
1213

@@ -52,9 +53,12 @@
5253
)
5354

5455

56+
@dataclass
5557
class SandboxedWorkflowRunner(WorkflowRunner):
5658
"""Runner for workflows in a sandbox."""
5759

60+
restrictions: SandboxRestrictions = SandboxRestrictions.default
61+
5862
def __init__(
5963
self,
6064
*,
@@ -70,8 +74,8 @@ def __init__(
7074
re-imported and instantiated for *each* workflow run.
7175
"""
7276
super().__init__()
77+
self.restrictions = restrictions
7378
self._runner_class = runner_class
74-
self._restrictions = restrictions
7579
self._worker_level_failure_exception_types: Sequence[type[BaseException]] = []
7680

7781
def prepare_workflow(self, defn: temporalio.workflow._Definition) -> None:
@@ -94,7 +98,7 @@ def prepare_workflow(self, defn: temporalio.workflow._Definition) -> None:
9498

9599
def create_instance(self, det: WorkflowInstanceDetails) -> WorkflowInstance:
96100
"""Implements :py:meth:`WorkflowRunner.create_instance`."""
97-
return _Instance(det, self._runner_class, self._restrictions)
101+
return _Instance(det, self._runner_class, self.restrictions)
98102

99103
def set_worker_level_failure_exception_types(
100104
self, types: Sequence[type[BaseException]]

tests/test_plugins.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import warnings
2+
from typing import cast
23

34
import pytest
45

@@ -7,6 +8,7 @@
78
from temporalio.client import Client, ClientConfig, OutboundInterceptor
89
from temporalio.testing import WorkflowEnvironment
910
from temporalio.worker import Worker, WorkerConfig
11+
from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner
1012
from tests.worker.test_worker import never_run_activity
1113

1214

@@ -64,6 +66,9 @@ def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
6466
class MyWorkerPlugin(temporalio.worker.Plugin):
6567
def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
6668
config["task_queue"] = "replaced_queue"
69+
runner = config.get("workflow_runner")
70+
if isinstance(runner, SandboxedWorkflowRunner):
71+
runner.restrictions.passthrough_modules.add("my_module")
6772
return super().configure_worker(config)
6873

6974
async def run_worker(self, worker: Worker) -> None:
@@ -111,3 +116,19 @@ async def test_worker_duplicated_plugin(client: Client) -> None:
111116

112117
assert len(warning_list) == 1
113118
assert "The same plugin type" in str(warning_list[0].message)
119+
120+
121+
async def test_worker_sandbox_restrictions(client: Client) -> None:
122+
with warnings.catch_warnings(record=True) as warning_list:
123+
worker = Worker(
124+
client,
125+
task_queue="queue",
126+
activities=[never_run_activity],
127+
plugins=[MyWorkerPlugin()],
128+
)
129+
assert (
130+
"my_module"
131+
in cast(
132+
SandboxedWorkflowRunner, worker.config().get("workflow_runner")
133+
).restrictions.passthrough_modules
134+
)

0 commit comments

Comments
 (0)