|
67 | 67 | from temporalio.testing import WorkflowEnvironment |
68 | 68 | from temporalio.worker import ( |
69 | 69 | UnsandboxedWorkflowRunner, |
| 70 | + Worker, |
70 | 71 | WorkflowInstance, |
71 | 72 | WorkflowInstanceDetails, |
72 | 73 | WorkflowRunner, |
@@ -2069,6 +2070,87 @@ async def query_result(handle: WorkflowHandle) -> str: |
2069 | 2070 | # await query_result(patch_handle) |
2070 | 2071 |
|
2071 | 2072 |
|
| 2073 | +@workflow.defn(name="patch-memoized") |
| 2074 | +class PatchMemoizedWorkflowUnpatched: |
| 2075 | + def __init__(self, *, should_patch: bool = False) -> None: |
| 2076 | + self.should_patch = should_patch |
| 2077 | + self._waiting_signal = True |
| 2078 | + |
| 2079 | + @workflow.run |
| 2080 | + async def run(self) -> List[str]: |
| 2081 | + results: List[str] = [] |
| 2082 | + if self.should_patch and workflow.patched("some-patch"): |
| 2083 | + results.append("pre-patch") |
| 2084 | + self._waiting_signal = True |
| 2085 | + await workflow.wait_condition(lambda: not self._waiting_signal) |
| 2086 | + results.append("some-value") |
| 2087 | + if self.should_patch and workflow.patched("some-patch"): |
| 2088 | + results.append("post-patch") |
| 2089 | + return results |
| 2090 | + |
| 2091 | + @workflow.signal |
| 2092 | + def signal(self) -> None: |
| 2093 | + self._waiting_signal = False |
| 2094 | + |
| 2095 | + @workflow.query |
| 2096 | + def waiting_signal(self) -> bool: |
| 2097 | + return self._waiting_signal |
| 2098 | + |
| 2099 | + |
| 2100 | +@workflow.defn(name="patch-memoized") |
| 2101 | +class PatchMemoizedWorkflowPatched(PatchMemoizedWorkflowUnpatched): |
| 2102 | + def __init__(self) -> None: |
| 2103 | + super().__init__(should_patch=True) |
| 2104 | + |
| 2105 | + @workflow.run |
| 2106 | + async def run(self) -> List[str]: |
| 2107 | + return await super().run() |
| 2108 | + |
| 2109 | + |
| 2110 | +async def test_workflow_patch_memoized(client: Client): |
| 2111 | + # Start a worker with the workflow unpatched and wait until halfway through |
| 2112 | + task_queue = f"tq-{uuid.uuid4()}" |
| 2113 | + async with Worker( |
| 2114 | + client, task_queue=task_queue, workflows=[PatchMemoizedWorkflowUnpatched] |
| 2115 | + ): |
| 2116 | + pre_patch_handle = await client.start_workflow( |
| 2117 | + PatchMemoizedWorkflowUnpatched.run, |
| 2118 | + id=f"workflow-{uuid.uuid4()}", |
| 2119 | + task_queue=task_queue, |
| 2120 | + ) |
| 2121 | + |
| 2122 | + # Need to wait until it has gotten halfway through |
| 2123 | + async def waiting_signal() -> bool: |
| 2124 | + return await pre_patch_handle.query( |
| 2125 | + PatchMemoizedWorkflowUnpatched.waiting_signal |
| 2126 | + ) |
| 2127 | + |
| 2128 | + await assert_eq_eventually(True, waiting_signal) |
| 2129 | + |
| 2130 | + # Now start the worker again, but this time with a patched workflow |
| 2131 | + async with Worker( |
| 2132 | + client, task_queue=task_queue, workflows=[PatchMemoizedWorkflowPatched] |
| 2133 | + ): |
| 2134 | + # Start a new workflow post patch |
| 2135 | + post_patch_handle = await client.start_workflow( |
| 2136 | + PatchMemoizedWorkflowPatched.run, |
| 2137 | + id=f"workflow-{uuid.uuid4()}", |
| 2138 | + task_queue=task_queue, |
| 2139 | + ) |
| 2140 | + |
| 2141 | + # Send signal to both and check results |
| 2142 | + await pre_patch_handle.signal(PatchMemoizedWorkflowPatched.signal) |
| 2143 | + await post_patch_handle.signal(PatchMemoizedWorkflowPatched.signal) |
| 2144 | + |
| 2145 | + # Confirm expected values |
| 2146 | + assert ["some-value"] == await pre_patch_handle.result() |
| 2147 | + assert [ |
| 2148 | + "pre-patch", |
| 2149 | + "some-value", |
| 2150 | + "post-patch", |
| 2151 | + ] == await post_patch_handle.result() |
| 2152 | + |
| 2153 | + |
2072 | 2154 | @workflow.defn |
2073 | 2155 | class UUIDWorkflow: |
2074 | 2156 | def __init__(self) -> None: |
|
0 commit comments