Skip to content

Commit b60c0f1

Browse files
committed
On start callback, test interceptor
1 parent f4f446e commit b60c0f1

File tree

2 files changed

+95
-11
lines changed

2 files changed

+95
-11
lines changed

temporalio/client.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -992,16 +992,22 @@ async def _start_update_with_start(
992992
update_workflow_input=update_input,
993993
)
994994

995+
def on_start(
996+
start_response: temporalio.api.workflowservice.v1.StartWorkflowExecutionResponse,
997+
):
998+
start_workflow_operation._workflow_handle.set_result(
999+
WorkflowHandle(
1000+
self,
1001+
start_workflow_operation._start_workflow_input.id,
1002+
first_execution_run_id=start_response.run_id,
1003+
)
1004+
)
1005+
1006+
setattr(input, "_on_start_callback", on_start)
1007+
9951008
update_handle = await self._impl.start_workflow_update_with_start(input)
9961009
# TODO https://github.com/temporalio/sdk-python/issues/682
9971010
assert update_handle.workflow_run_id, "In Client.start_workflow_update why don't we use the run ID from the update response?"
998-
start_workflow_operation._workflow_handle.set_result(
999-
WorkflowHandle(
1000-
self,
1001-
update_handle.workflow_id,
1002-
first_execution_run_id=update_handle.workflow_run_id,
1003-
)
1004-
)
10051011
return update_handle
10061012

10071013
def list_workflows(
@@ -5878,6 +5884,14 @@ async def start_workflow_update_with_start(
58785884
),
58795885
],
58805886
)
5887+
5888+
on_start = getattr(input, "_on_start_callback", None)
5889+
if not on_start:
5890+
raise RuntimeError(
5891+
"Missing on_start callback. Please report this as a bug at https://github.com/temporalio/sdk-python/issues."
5892+
)
5893+
seen_start = False
5894+
58815895
# Repeatedly try to invoke ExecuteMultiOperation until the update
58825896
# reaches user-provided wait stage or is at least ACCEPTED (as of the
58835897
# time of this writing, the user cannot specify sooner than ACCEPTED)
@@ -5943,6 +5957,9 @@ async def start_workflow_update_with_start(
59435957
raise RuntimeError("Invalid ExecuteMultiOperationResponse")
59445958
[start_response] = start_responses
59455959
[update_response] = update_responses
5960+
if not seen_start:
5961+
on_start(start_response)
5962+
seen_start = True
59465963
known_outcome = (
59475964
update_response.outcome if update_response.HasField("outcome") else None
59485965
)

tests/worker/test_update_with_start.py

Lines changed: 71 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,24 @@
44
from contextlib import contextmanager
55
from datetime import timedelta
66
from enum import Enum
7-
from typing import Iterator, NoReturn
7+
from typing import Any, Iterator, NoReturn
88
from unittest.mock import patch
99

1010
import pytest
1111

12+
import temporalio.api.common.v1
1213
import temporalio.api.errordetails.v1
1314
import temporalio.worker
1415
from temporalio import activity, workflow
1516
from temporalio.client import (
1617
Client,
18+
Interceptor,
19+
OutboundInterceptor,
1720
RPCError,
21+
StartWorkflowUpdateWithStartInput,
1822
WithStartWorkflowOperation,
1923
WorkflowUpdateFailedError,
24+
WorkflowUpdateHandle,
2025
WorkflowUpdateStage,
2126
)
2227
from temporalio.common import (
@@ -430,8 +435,8 @@ async def test_workflow_update_poll_loop(client: Client):
430435
)
431436
with patch.object(
432437
client.workflow_service,
433-
"update_workflow_execution",
434-
wraps=client.workflow_service.update_workflow_execution,
438+
"execute_multi_operation",
439+
wraps=client.workflow_service.execute_multi_operation,
435440
) as workflow_service_method:
436441
try:
437442
await client.execute_update_with_start(
@@ -440,6 +445,68 @@ async def test_workflow_update_poll_loop(client: Client):
440445
)
441446
except:
442447
print(
443-
f"update_workflow_execution was called {workflow_service_method.call_count} times"
448+
f"execute_multi_operation was called {workflow_service_method.call_count} times"
444449
)
445450
raise
451+
452+
453+
class SimpleClientInterceptor(Interceptor):
454+
def intercept_client(self, next: OutboundInterceptor) -> OutboundInterceptor:
455+
return SimpleClientOutboundInterceptor(super().intercept_client(next))
456+
457+
458+
class SimpleClientOutboundInterceptor(OutboundInterceptor):
459+
def __init__(self, next: OutboundInterceptor) -> None:
460+
super().__init__(next)
461+
462+
async def start_workflow_update_with_start(
463+
self, input: StartWorkflowUpdateWithStartInput
464+
) -> WorkflowUpdateHandle[Any]:
465+
input.start_workflow_input.args = ["intercepted-workflow-arg"]
466+
input.update_workflow_input.args = ["intercepted-update-arg"]
467+
return await super().start_workflow_update_with_start(input)
468+
469+
470+
@workflow.defn
471+
class UpdateWithStartInterceptorWorkflow:
472+
def __init__(self) -> None:
473+
self.received_update = False
474+
475+
@workflow.run
476+
async def run(self, arg: str) -> str:
477+
await workflow.wait_condition(lambda: self.received_update)
478+
return arg
479+
480+
@workflow.update
481+
async def my_update(self, arg: str) -> str:
482+
self.received_update = True
483+
await workflow.wait_condition(lambda: self.received_update)
484+
return arg
485+
486+
487+
async def test_update_with_start_client_outbound_interceptor(
488+
client: Client,
489+
):
490+
interceptor = SimpleClientInterceptor()
491+
client = Client(**{**client.config(), "interceptors": [interceptor]})
492+
493+
async with new_worker(
494+
client,
495+
UpdateWithStartInterceptorWorkflow,
496+
) as worker:
497+
start_op = WithStartWorkflowOperation(
498+
UpdateWithStartInterceptorWorkflow.run,
499+
"original-workflow-arg",
500+
id=f"wf-{uuid.uuid4()}",
501+
task_queue=worker.task_queue,
502+
id_conflict_policy=WorkflowIDConflictPolicy.FAIL,
503+
)
504+
update_result = await client.execute_update_with_start(
505+
UpdateWithStartInterceptorWorkflow.my_update,
506+
"original-update-arg",
507+
start_workflow_operation=start_op,
508+
)
509+
assert update_result == "intercepted-update-arg"
510+
511+
wf_handle = await start_op.workflow_handle()
512+
assert await wf_handle.result() == "intercepted-workflow-arg"

0 commit comments

Comments
 (0)