|  | 
| 3 | 3 | import asyncio | 
| 4 | 4 | import dataclasses | 
| 5 | 5 | import functools | 
|  | 6 | +import importlib | 
| 6 | 7 | import inspect | 
| 7 | 8 | import os | 
| 8 | 9 | import sys | 
|  | 
| 11 | 12 | from dataclasses import dataclass | 
| 12 | 13 | from datetime import date, datetime, timedelta | 
| 13 | 14 | from enum import IntEnum | 
| 14 |  | -from typing import Callable, Dict, List, Optional, Sequence, Set, Type | 
|  | 15 | +from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Type | 
| 15 | 16 | 
 | 
| 16 | 17 | import pytest | 
| 17 | 18 | 
 | 
| 18 | 19 | from temporalio import activity, workflow | 
| 19 | 20 | from temporalio.client import Client, WorkflowFailureError, WorkflowHandle | 
| 20 | 21 | from temporalio.exceptions import ApplicationError | 
| 21 |  | -from temporalio.worker import Worker | 
|  | 22 | +from temporalio.worker import Worker, WorkflowInboundInterceptor | 
|  | 23 | +from temporalio.worker._interceptor import ( | 
|  | 24 | +    ExecuteWorkflowInput, | 
|  | 25 | +    Interceptor, | 
|  | 26 | +    WorkflowInterceptorClassInput, | 
|  | 27 | +) | 
| 22 | 28 | from temporalio.worker.workflow_sandbox import ( | 
| 23 | 29 |     RestrictedWorkflowAccessError, | 
| 24 | 30 |     SandboxedWorkflowRunner, | 
| 25 | 31 |     SandboxMatcher, | 
| 26 | 32 |     SandboxRestrictions, | 
|  | 33 | +    DisallowedUnintentionalPassthroughError, | 
| 27 | 34 | ) | 
| 28 | 35 | from tests.helpers import assert_eq_eventually | 
| 29 | 36 | from tests.worker.workflow_sandbox.testmodules import stateful_module | 
| @@ -483,3 +490,97 @@ def new_worker( | 
| 483 | 490 |         activities=activities, | 
| 484 | 491 |         workflow_runner=SandboxedWorkflowRunner(restrictions=restrictions), | 
| 485 | 492 |     ) | 
|  | 493 | + | 
|  | 494 | + | 
|  | 495 | +class _TestWorkflowInboundInterceptor(WorkflowInboundInterceptor): | 
|  | 496 | +    async def execute_workflow(self, input: ExecuteWorkflowInput) -> Any: | 
|  | 497 | +        # import in the interceptor to show it will be captured | 
|  | 498 | +        # applying this policy should squelch the "after initial workload" warning | 
|  | 499 | +        with workflow.unsafe.sandbox_import_policy( | 
|  | 500 | +            workflow.SandboxImportPolicy.WARN_ON_NON_PASSTHROUGH | 
|  | 501 | +        ): | 
|  | 502 | +            import opentelemetry | 
|  | 503 | + | 
|  | 504 | +        return await super().execute_workflow(input) | 
|  | 505 | + | 
|  | 506 | + | 
|  | 507 | +class _TestInterceptor(Interceptor): | 
|  | 508 | +    def workflow_interceptor_class( | 
|  | 509 | +        self, input: WorkflowInterceptorClassInput | 
|  | 510 | +    ) -> Type[_TestWorkflowInboundInterceptor]: | 
|  | 511 | +        return _TestWorkflowInboundInterceptor | 
|  | 512 | + | 
|  | 513 | + | 
|  | 514 | +@workflow.defn | 
|  | 515 | +class LazyImportWorkflow: | 
|  | 516 | +    @workflow.run | 
|  | 517 | +    async def run(self) -> None: | 
|  | 518 | +        try: | 
|  | 519 | +            import opentelemetry.version | 
|  | 520 | +        except DisallowedUnintentionalPassthroughError as err: | 
|  | 521 | +            raise ApplicationError( | 
|  | 522 | +                str(err), type="DisallowedUnintentionalPassthroughError" | 
|  | 523 | +            ) from err | 
|  | 524 | + | 
|  | 525 | + | 
|  | 526 | +async def test_workflow_sandbox_import_warnings(client: Client): | 
|  | 527 | +    restrictions = dataclasses.replace( | 
|  | 528 | +        SandboxRestrictions.default, | 
|  | 529 | +        import_policy=SandboxRestrictions.import_policy_all_warnings, | 
|  | 530 | +        # passthrough this test module to avoid a ton of noisy warnings | 
|  | 531 | +        passthrough_modules=SandboxRestrictions.passthrough_modules_default | 
|  | 532 | +        | {"tests.worker.workflow_sandbox.test_runner"}, | 
|  | 533 | +    ) | 
|  | 534 | + | 
|  | 535 | +    async with Worker( | 
|  | 536 | +        client, | 
|  | 537 | +        task_queue=str(uuid.uuid4()), | 
|  | 538 | +        workflows=[LazyImportWorkflow], | 
|  | 539 | +        interceptors=[_TestInterceptor()], | 
|  | 540 | +        workflow_runner=SandboxedWorkflowRunner(restrictions), | 
|  | 541 | +    ) as worker: | 
|  | 542 | +        with pytest.warns() as records: | 
|  | 543 | +            await client.execute_workflow( | 
|  | 544 | +                LazyImportWorkflow.run, | 
|  | 545 | +                id=f"workflow-{uuid.uuid4()}", | 
|  | 546 | +                task_queue=worker.task_queue, | 
|  | 547 | +            ) | 
|  | 548 | + | 
|  | 549 | +            expected_warnings = { | 
|  | 550 | +                "Module opentelemetry was not intentionally passed through to the sandbox.", | 
|  | 551 | +                "Module opentelemetry.version was imported after initial workflow load.", | 
|  | 552 | +                "Module opentelemetry.version was not intentionally passed through to the sandbox.", | 
|  | 553 | +            } | 
|  | 554 | +            actual_warnings = {str(w.message) for w in records} | 
|  | 555 | + | 
|  | 556 | +            assert expected_warnings <= actual_warnings | 
|  | 557 | + | 
|  | 558 | + | 
|  | 559 | +async def test_workflow_sandbox_import_errors(client: Client): | 
|  | 560 | +    restrictions = dataclasses.replace( | 
|  | 561 | +        SandboxRestrictions.default, | 
|  | 562 | +        import_policy=SandboxRestrictions.import_policy_disallow_unintentional_passthrough, | 
|  | 563 | +        # passthrough this test module to avoid a ton of noisy warnings | 
|  | 564 | +        passthrough_modules=SandboxRestrictions.passthrough_modules_default | 
|  | 565 | +        | {"tests.worker.workflow_sandbox.test_runner"}, | 
|  | 566 | +    ) | 
|  | 567 | + | 
|  | 568 | +    async with Worker( | 
|  | 569 | +        client, | 
|  | 570 | +        task_queue=str(uuid.uuid4()), | 
|  | 571 | +        workflows=[LazyImportWorkflow], | 
|  | 572 | +        workflow_runner=SandboxedWorkflowRunner(restrictions), | 
|  | 573 | +    ) as worker: | 
|  | 574 | +        with pytest.raises(WorkflowFailureError) as err: | 
|  | 575 | +            await client.execute_workflow( | 
|  | 576 | +                LazyImportWorkflow.run, | 
|  | 577 | +                id=f"workflow-{uuid.uuid4()}", | 
|  | 578 | +                task_queue=worker.task_queue, | 
|  | 579 | +            ) | 
|  | 580 | + | 
|  | 581 | +        assert isinstance(err.value.cause, ApplicationError) | 
|  | 582 | +        assert err.value.cause.type == "DisallowedUnintentionalPassthroughError" | 
|  | 583 | +        assert ( | 
|  | 584 | +            "Module opentelemetry.version was not intentionally passed through to the sandbox." | 
|  | 585 | +            == err.value.cause.message | 
|  | 586 | +        ) | 
0 commit comments