1212import traceback
1313import warnings
1414from abc import ABC , abstractmethod
15+ from contextlib import contextmanager
1516from dataclasses import dataclass
1617from datetime import timedelta
1718from typing import (
2122 Deque ,
2223 Dict ,
2324 Generator ,
25+ Iterator ,
2426 List ,
2527 Mapping ,
2628 MutableMapping ,
@@ -193,6 +195,7 @@ def __init__(self, det: WorkflowInstanceDetails) -> None:
193195 self ._object : Any = None
194196 self ._is_replaying : bool = False
195197 self ._random = random .Random (det .randomness_seed )
198+ self ._read_only = False
196199
197200 # Patches we have been notified of and memoized patch responses
198201 self ._patches_notified : Set [str ] = set ()
@@ -421,36 +424,39 @@ async def run_query() -> None:
421424 command = self ._add_command ()
422425 command .respond_to_query .query_id = job .query_id
423426 try :
424- # Named query or dynamic
425- defn = self ._queries .get (job .query_type ) or self ._queries .get (None )
426- if not defn :
427- known_queries = sorted ([k for k in self ._queries .keys () if k ])
428- raise RuntimeError (
429- f"Query handler for '{ job .query_type } ' expected but not found, "
430- f"known queries: [{ ' ' .join (known_queries )} ]"
427+ with self ._as_read_only ():
428+ # Named query or dynamic
429+ defn = self ._queries .get (job .query_type ) or self ._queries .get (None )
430+ if not defn :
431+ known_queries = sorted ([k for k in self ._queries .keys () if k ])
432+ raise RuntimeError (
433+ f"Query handler for '{ job .query_type } ' expected but not found, "
434+ f"known queries: [{ ' ' .join (known_queries )} ]"
435+ )
436+
437+ # Create input
438+ args = self ._process_handler_args (
439+ job .query_type ,
440+ job .arguments ,
441+ defn .name ,
442+ defn .arg_types ,
443+ defn .dynamic_vararg ,
431444 )
432-
433- # Create input
434- args = self ._process_handler_args (
435- job .query_type ,
436- job .arguments ,
437- defn .name ,
438- defn .arg_types ,
439- defn .dynamic_vararg ,
440- )
441- input = HandleQueryInput (
442- id = job .query_id ,
443- query = job .query_type ,
444- args = args ,
445- headers = job .headers ,
446- )
447- success = await self ._inbound .handle_query (input )
448- result_payloads = self ._payload_converter .to_payloads ([success ])
449- if len (result_payloads ) != 1 :
450- raise ValueError (
451- f"Expected 1 result payload, got { len (result_payloads )} "
445+ input = HandleQueryInput (
446+ id = job .query_id ,
447+ query = job .query_type ,
448+ args = args ,
449+ headers = job .headers ,
450+ )
451+ success = await self ._inbound .handle_query (input )
452+ result_payloads = self ._payload_converter .to_payloads ([success ])
453+ if len (result_payloads ) != 1 :
454+ raise ValueError (
455+ f"Expected 1 result payload, got { len (result_payloads )} "
456+ )
457+ command .respond_to_query .succeeded .response .CopyFrom (
458+ result_payloads [0 ]
452459 )
453- command .respond_to_query .succeeded .response .CopyFrom (result_payloads [0 ])
454460 except Exception as err :
455461 try :
456462 self ._failure_converter .to_failure (
@@ -695,6 +701,7 @@ def workflow_continue_as_new(
695701 search_attributes : Optional [temporalio .common .SearchAttributes ],
696702 versioning_intent : Optional [temporalio .workflow .VersioningIntent ],
697703 ) -> NoReturn :
704+ self ._assert_not_read_only ("continue as new" )
698705 # Use definition if callable
699706 name : Optional [str ] = None
700707 arg_types : Optional [List [Type ]] = None
@@ -795,12 +802,20 @@ def workflow_payload_converter(self) -> temporalio.converter.PayloadConverter:
795802 return self ._payload_converter
796803
797804 def workflow_random (self ) -> random .Random :
805+ self ._assert_not_read_only ("random" )
798806 return self ._random
799807
800808 def workflow_set_query_handler (
801809 self , name : Optional [str ], handler : Optional [Callable ]
802810 ) -> None :
811+ self ._assert_not_read_only ("set query handler" )
803812 if handler :
813+ if inspect .iscoroutinefunction (handler ):
814+ warnings .warn (
815+ "Queries as async def functions are deprecated" ,
816+ DeprecationWarning ,
817+ stacklevel = 3 ,
818+ )
804819 defn = temporalio .workflow ._QueryDefinition (
805820 name = name , fn = handler , is_method = False
806821 )
@@ -817,6 +832,7 @@ def workflow_set_query_handler(
817832 def workflow_set_signal_handler (
818833 self , name : Optional [str ], handler : Optional [Callable ]
819834 ) -> None :
835+ self ._assert_not_read_only ("set signal handler" )
820836 if handler :
821837 defn = temporalio .workflow ._SignalDefinition (
822838 name = name , fn = handler , is_method = False
@@ -855,6 +871,7 @@ def workflow_start_activity(
855871 activity_id : Optional [str ],
856872 versioning_intent : Optional [temporalio .workflow .VersioningIntent ],
857873 ) -> temporalio .workflow .ActivityHandle [Any ]:
874+ self ._assert_not_read_only ("start activity" )
858875 # Get activity definition if it's callable
859876 name : str
860877 arg_types : Optional [List [Type ]] = None
@@ -1012,6 +1029,7 @@ def workflow_upsert_search_attributes(
10121029 async def workflow_wait_condition (
10131030 self , fn : Callable [[], bool ], * , timeout : Optional [float ] = None
10141031 ) -> None :
1032+ self ._assert_not_read_only ("wait condition" )
10151033 fut = self .create_future ()
10161034 self ._conditions .append ((fn , fut ))
10171035 await asyncio .wait_for (fut , timeout )
@@ -1153,8 +1171,24 @@ async def run_child() -> Any:
11531171 # These are in alphabetical order.
11541172
11551173 def _add_command (self ) -> temporalio .bridge .proto .workflow_commands .WorkflowCommand :
1174+ self ._assert_not_read_only ("add command" )
11561175 return self ._current_completion .successful .commands .add ()
11571176
1177+ @contextmanager
1178+ def _as_read_only (self ) -> Iterator [None ]:
1179+ prev_val = self ._read_only
1180+ self ._read_only = True
1181+ try :
1182+ yield None
1183+ finally :
1184+ self ._read_only = prev_val
1185+
1186+ def _assert_not_read_only (self , action_attempted : str ) -> None :
1187+ if self ._read_only :
1188+ raise temporalio .workflow .ReadOnlyContextError (
1189+ f"While in read-only function, action attempted: { action_attempted } "
1190+ )
1191+
11581192 async def _cancel_external_workflow (
11591193 self ,
11601194 # Should not have seq set
@@ -1258,6 +1292,7 @@ def _register_task(
12581292 * ,
12591293 name : Optional [str ],
12601294 ) -> None :
1295+ self ._assert_not_read_only ("create task" )
12611296 # Name not supported on older Python versions
12621297 if sys .version_info >= (3 , 8 ):
12631298 # Put the workflow info at the end of the task name
@@ -1423,6 +1458,7 @@ def call_soon(
14231458 * args : Any ,
14241459 context : Optional [contextvars .Context ] = None ,
14251460 ) -> asyncio .Handle :
1461+ self ._assert_not_read_only ("schedule task" )
14261462 handle = asyncio .Handle (callback , args , self , context )
14271463 self ._ready .append (handle )
14281464 return handle
@@ -1434,6 +1470,7 @@ def call_later(
14341470 * args : Any ,
14351471 context : Optional [contextvars .Context ] = None ,
14361472 ) -> asyncio .TimerHandle :
1473+ self ._assert_not_read_only ("schedule timer" )
14371474 # Delay must be positive
14381475 if delay < 0 :
14391476 raise RuntimeError ("Attempting to schedule timer with negative delay" )
@@ -1675,6 +1712,7 @@ def __init__(
16751712 instance ._register_task (self , name = f"activity: { input .activity } " )
16761713
16771714 def cancel (self , msg : Optional [Any ] = None ) -> bool :
1715+ self ._instance ._assert_not_read_only ("cancel activity handle" )
16781716 # We override this because if it's not yet started and not done, we need
16791717 # to send a cancel command because the async function won't run to trap
16801718 # the cancel (i.e. cancelled before started)
@@ -1821,6 +1859,7 @@ async def signal(
18211859 * ,
18221860 args : Sequence [Any ] = [],
18231861 ) -> None :
1862+ self ._instance ._assert_not_read_only ("signal child handle" )
18241863 await self ._instance ._outbound .signal_child_workflow (
18251864 SignalChildWorkflowInput (
18261865 signal = temporalio .workflow ._SignalDefinition .must_name_from_fn_or_str (
@@ -1935,6 +1974,7 @@ async def signal(
19351974 * ,
19361975 args : Sequence [Any ] = [],
19371976 ) -> None :
1977+ self ._instance ._assert_not_read_only ("signal external handle" )
19381978 await self ._instance ._outbound .signal_external_workflow (
19391979 SignalExternalWorkflowInput (
19401980 signal = temporalio .workflow ._SignalDefinition .must_name_from_fn_or_str (
@@ -1949,6 +1989,7 @@ async def signal(
19491989 )
19501990
19511991 async def cancel (self ) -> None :
1992+ self ._instance ._assert_not_read_only ("cancel external handle" )
19521993 command = self ._instance ._add_command ()
19531994 v = command .request_cancel_external_workflow_execution
19541995 v .workflow_execution .namespace = self ._instance ._info .namespace
0 commit comments