1414import threading
1515import warnings
1616from abc import ABC , abstractmethod
17+ from collections .abc import Iterator , Sequence
1718from contextlib import contextmanager
1819from dataclasses import dataclass , field
1920from datetime import datetime , timedelta , timezone
2021from typing import (
2122 Any ,
2223 Callable ,
23- Dict ,
24- Iterator ,
2524 NoReturn ,
2625 Optional ,
27- Sequence ,
28- Tuple ,
29- Type ,
3026 Union ,
3127)
3228
3329import google .protobuf .duration_pb2
3430import google .protobuf .timestamp_pb2
3531
3632import temporalio .activity
37- import temporalio .api .common .v1
38- import temporalio .bridge .client
39- import temporalio .bridge .proto
40- import temporalio .bridge .proto .activity_result
41- import temporalio .bridge .proto .activity_task
42- import temporalio .bridge .proto .common
4333import temporalio .bridge .runtime
4434import temporalio .bridge .worker
4535import temporalio .client
@@ -76,7 +66,7 @@ def __init__(
7666 self ._task_queue = task_queue
7767 self ._activity_executor = activity_executor
7868 self ._shared_state_manager = shared_state_manager
79- self ._running_activities : Dict [bytes , _RunningActivity ] = {}
69+ self ._running_activities : dict [bytes , _RunningActivity ] = {}
8070 self ._data_converter = data_converter
8171 self ._interceptors = interceptors
8272 self ._metric_meter = metric_meter
@@ -90,7 +80,7 @@ def __init__(
9080 self ._client = client
9181
9282 # Validate and build activity dict
93- self ._activities : Dict [str , temporalio .activity ._Definition ] = {}
83+ self ._activities : dict [str , temporalio .activity ._Definition ] = {}
9484 self ._dynamic_activity : Optional [temporalio .activity ._Definition ] = None
9585 for activity in activities :
9686 # Get definition
@@ -180,7 +170,7 @@ async def raise_from_exception_queue() -> NoReturn:
180170 self ._handle_cancel_activity_task (task .task_token , task .cancel )
181171 else :
182172 raise RuntimeError (f"Unrecognized activity task: { task } " )
183- except temporalio .bridge .worker .PollShutdownError :
173+ except temporalio .bridge .worker .PollShutdownError : # type: ignore[reportPrivateLocalImportUsage]
184174 exception_task .cancel ()
185175 return
186176 except Exception as err :
@@ -197,12 +187,12 @@ async def drain_poll_queue(self) -> None:
197187 try :
198188 # Just take all tasks and say we can't handle them
199189 task = await self ._bridge_worker ().poll_activity_task ()
200- completion = temporalio .bridge .proto .ActivityTaskCompletion (
190+ completion = temporalio .bridge .proto .ActivityTaskCompletion ( # type: ignore[reportAttributeAccessIssue]
201191 task_token = task .task_token
202192 )
203193 completion .result .failed .failure .message = "Worker shutting down"
204194 await self ._bridge_worker ().complete_activity_task (completion )
205- except temporalio .bridge .worker .PollShutdownError :
195+ except temporalio .bridge .worker .PollShutdownError : # type: ignore[reportPrivateLocalImportUsage]
206196 return
207197
208198 # Only call this after run()/drain_poll_queue() have returned. This will not
@@ -216,7 +206,9 @@ async def wait_all_completed(self) -> None:
216206 await asyncio .gather (* running_tasks , return_exceptions = False )
217207
218208 def _handle_cancel_activity_task (
219- self , task_token : bytes , cancel : temporalio .bridge .proto .activity_task .Cancel
209+ self ,
210+ task_token : bytes ,
211+ cancel : temporalio .bridge .proto .activity_task .Cancel , # type: ignore[reportAttributeAccessIssue]
220212 ) -> None :
221213 """Request cancellation of a running activity task."""
222214 activity = self ._running_activities .get (task_token )
@@ -264,7 +256,9 @@ async def _heartbeat_async(
264256
265257 # Perform the heartbeat
266258 try :
267- heartbeat = temporalio .bridge .proto .ActivityHeartbeat (task_token = task_token )
259+ heartbeat = temporalio .bridge .proto .ActivityHeartbeat ( # type: ignore[reportAttributeAccessIssue]
260+ task_token = task_token
261+ )
268262 if details :
269263 # Convert to core payloads
270264 heartbeat .details .extend (await self ._data_converter .encode (details ))
@@ -286,7 +280,7 @@ async def _heartbeat_async(
286280 async def _handle_start_activity_task (
287281 self ,
288282 task_token : bytes ,
289- start : temporalio .bridge .proto .activity_task .Start ,
283+ start : temporalio .bridge .proto .activity_task .Start , # type: ignore[reportAttributeAccessIssue]
290284 running_activity : _RunningActivity ,
291285 ) -> None :
292286 """Handle a start activity task.
@@ -298,7 +292,7 @@ async def _handle_start_activity_task(
298292 # We choose to surround interceptor creation and activity invocation in
299293 # a try block so we can mark the workflow as failed on any error instead
300294 # of having error handling in the interceptor
301- completion = temporalio .bridge .proto .ActivityTaskCompletion (
295+ completion = temporalio .bridge .proto .ActivityTaskCompletion ( # type: ignore[reportAttributeAccessIssue]
302296 task_token = task_token
303297 )
304298 try :
@@ -415,7 +409,7 @@ async def _handle_start_activity_task(
415409
416410 async def _execute_activity (
417411 self ,
418- start : temporalio .bridge .proto .activity_task .Start ,
412+ start : temporalio .bridge .proto .activity_task .Start , # type: ignore[reportAttributeAccessIssue]
419413 running_activity : _RunningActivity ,
420414 task_token : bytes ,
421415 ) -> Any :
@@ -651,14 +645,14 @@ class _ThreadExceptionRaiser:
651645 def __init__ (self ) -> None :
652646 self ._lock = threading .Lock ()
653647 self ._thread_id : Optional [int ] = None
654- self ._pending_exception : Optional [Type [Exception ]] = None
648+ self ._pending_exception : Optional [type [Exception ]] = None
655649 self ._shield_depth = 0
656650
657651 def set_thread_id (self , thread_id : int ) -> None :
658652 with self ._lock :
659653 self ._thread_id = thread_id
660654
661- def raise_in_thread (self , exc_type : Type [Exception ]) -> None :
655+ def raise_in_thread (self , exc_type : type [Exception ]) -> None :
662656 with self ._lock :
663657 self ._pending_exception = exc_type
664658 self ._raise_in_thread_if_pending_unlocked ()
@@ -814,7 +808,7 @@ def _execute_sync_activity(
814808 cancelled_event : threading .Event ,
815809 worker_shutdown_event : threading .Event ,
816810 payload_converter_class_or_instance : Union [
817- Type [temporalio .converter .PayloadConverter ],
811+ type [temporalio .converter .PayloadConverter ],
818812 temporalio .converter .PayloadConverter ,
819813 ],
820814 runtime_metric_meter : Optional [temporalio .common .MetricMeter ],
@@ -826,13 +820,10 @@ def _execute_sync_activity(
826820 thread_id = threading .current_thread ().ident
827821 if thread_id is not None :
828822 cancel_thread_raiser .set_thread_id (thread_id )
829- heartbeat_fn : Callable [..., None ]
830823 if isinstance (heartbeat , SharedHeartbeatSender ):
831- # To make mypy happy
832- heartbeat_sender = heartbeat
833- heartbeat_fn = lambda * details : heartbeat_sender .send_heartbeat (
834- info .task_token , * details
835- )
824+
825+ def heartbeat_fn (* details : Any ) -> None :
826+ heartbeat .send_heartbeat (info .task_token , * details )
836827 else :
837828 heartbeat_fn = heartbeat
838829 temporalio .activity ._Context .set (
@@ -942,11 +933,11 @@ def __init__(
942933 self ._mgr = mgr
943934 self ._queue_poller_executor = queue_poller_executor
944935 # 1000 in-flight heartbeats should be plenty
945- self ._heartbeat_queue : queue .Queue [Tuple [bytes , Sequence [Any ]]] = mgr .Queue (
936+ self ._heartbeat_queue : queue .Queue [tuple [bytes , Sequence [Any ]]] = mgr .Queue (
946937 1000
947938 )
948- self ._heartbeats : Dict [bytes , Callable [..., None ]] = {}
949- self ._heartbeat_completions : Dict [bytes , Callable ] = {}
939+ self ._heartbeats : dict [bytes , Callable [..., None ]] = {}
940+ self ._heartbeat_completions : dict [bytes , Callable ] = {}
950941
951942 def new_event (self ) -> threading .Event :
952943 return self ._mgr .Event ()
@@ -1004,7 +995,7 @@ def _heartbeat_processor(self) -> None:
1004995
1005996class _MultiprocessingSharedHeartbeatSender (SharedHeartbeatSender ):
1006997 def __init__ (
1007- self , heartbeat_queue : queue .Queue [Tuple [bytes , Sequence [Any ]]]
998+ self , heartbeat_queue : queue .Queue [tuple [bytes , Sequence [Any ]]]
1008999 ) -> None :
10091000 super ().__init__ ()
10101001 self ._heartbeat_queue = heartbeat_queue
0 commit comments