44from contextlib import contextmanager
55from datetime import timedelta
66from enum import Enum
7- from typing import Iterator , NoReturn
7+ from typing import Any , Iterator , NoReturn
88from unittest .mock import patch
99
1010import pytest
1111
12+ import temporalio .api .common .v1
1213import temporalio .api .errordetails .v1
1314import temporalio .worker
1415from temporalio import activity , workflow
1516from temporalio .client import (
1617 Client ,
18+ Interceptor ,
19+ OutboundInterceptor ,
1720 RPCError ,
21+ StartWorkflowUpdateWithStartInput ,
1822 WithStartWorkflowOperation ,
1923 WorkflowUpdateFailedError ,
24+ WorkflowUpdateHandle ,
2025 WorkflowUpdateStage ,
2126)
2227from temporalio .common import (
@@ -430,8 +435,8 @@ async def test_workflow_update_poll_loop(client: Client):
430435 )
431436 with patch .object (
432437 client .workflow_service ,
433- "update_workflow_execution " ,
434- wraps = client .workflow_service .update_workflow_execution ,
438+ "execute_multi_operation " ,
439+ wraps = client .workflow_service .execute_multi_operation ,
435440 ) as workflow_service_method :
436441 try :
437442 await client .execute_update_with_start (
@@ -440,6 +445,68 @@ async def test_workflow_update_poll_loop(client: Client):
440445 )
441446 except :
442447 print (
443- f"update_workflow_execution was called { workflow_service_method .call_count } times"
448+ f"execute_multi_operation was called { workflow_service_method .call_count } times"
444449 )
445450 raise
451+
452+
453+ class SimpleClientInterceptor (Interceptor ):
454+ def intercept_client (self , next : OutboundInterceptor ) -> OutboundInterceptor :
455+ return SimpleClientOutboundInterceptor (super ().intercept_client (next ))
456+
457+
458+ class SimpleClientOutboundInterceptor (OutboundInterceptor ):
459+ def __init__ (self , next : OutboundInterceptor ) -> None :
460+ super ().__init__ (next )
461+
462+ async def start_workflow_update_with_start (
463+ self , input : StartWorkflowUpdateWithStartInput
464+ ) -> WorkflowUpdateHandle [Any ]:
465+ input .start_workflow_input .args = ["intercepted-workflow-arg" ]
466+ input .update_workflow_input .args = ["intercepted-update-arg" ]
467+ return await super ().start_workflow_update_with_start (input )
468+
469+
470+ @workflow .defn
471+ class UpdateWithStartInterceptorWorkflow :
472+ def __init__ (self ) -> None :
473+ self .received_update = False
474+
475+ @workflow .run
476+ async def run (self , arg : str ) -> str :
477+ await workflow .wait_condition (lambda : self .received_update )
478+ return arg
479+
480+ @workflow .update
481+ async def my_update (self , arg : str ) -> str :
482+ self .received_update = True
483+ await workflow .wait_condition (lambda : self .received_update )
484+ return arg
485+
486+
487+ async def test_update_with_start_client_outbound_interceptor (
488+ client : Client ,
489+ ):
490+ interceptor = SimpleClientInterceptor ()
491+ client = Client (** {** client .config (), "interceptors" : [interceptor ]})
492+
493+ async with new_worker (
494+ client ,
495+ UpdateWithStartInterceptorWorkflow ,
496+ ) as worker :
497+ start_op = WithStartWorkflowOperation (
498+ UpdateWithStartInterceptorWorkflow .run ,
499+ "original-workflow-arg" ,
500+ id = f"wf-{ uuid .uuid4 ()} " ,
501+ task_queue = worker .task_queue ,
502+ id_conflict_policy = WorkflowIDConflictPolicy .FAIL ,
503+ )
504+ update_result = await client .execute_update_with_start (
505+ UpdateWithStartInterceptorWorkflow .my_update ,
506+ "original-update-arg" ,
507+ start_workflow_operation = start_op ,
508+ )
509+ assert update_result == "intercepted-update-arg"
510+
511+ wf_handle = await start_op .workflow_handle ()
512+ assert await wf_handle .result () == "intercepted-workflow-arg"
0 commit comments