Skip to content

Commit 80766bc

Browse files
committed
WIP pyupgrade
1 parent 90dda94 commit 80766bc

File tree

97 files changed

+3921
-4059
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

97 files changed

+3921
-4059
lines changed

scripts/gen_payload_visitor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def emit_loop(
4444

4545

4646
def emit_singular(
47-
field_name: str, access_expr: str, child_method: str, presence_word: Optional[str]
47+
field_name: str, access_expr: str, child_method: str, presence_word: str | None
4848
) -> str:
4949
# Helper to emit a singular field visit with presence check and optional headers guard
5050
if presence_word:
@@ -152,7 +152,7 @@ async def _visit_payload_container(self, fs, o):
152152
""",
153153
]
154154

155-
def check_repeated(self, child_desc, field, iter_expr) -> Optional[str]:
155+
def check_repeated(self, child_desc, field, iter_expr) -> str | None:
156156
# Special case for repeated payloads, handle them directly
157157
if child_desc.full_name == Payload.DESCRIPTOR.full_name:
158158
return emit_singular(field.name, iter_expr, "payload_container", None)

scripts/gen_protos.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44
import subprocess
55
import sys
66
import tempfile
7+
from collections.abc import Mapping
78
from functools import partial
89
from pathlib import Path
9-
from typing import List, Mapping
10+
from typing import List
1011

1112
base_dir = Path(__file__).parent.parent
1213
proto_dir = (
@@ -64,7 +65,7 @@ def fix_generated_output(base_path: Path):
6465
- protoc doesn't generate the correct import paths
6566
(https://github.com/protocolbuffers/protobuf/issues/1491)
6667
"""
67-
imports: Mapping[str, List[str]] = collections.defaultdict(list)
68+
imports: Mapping[str, list[str]] = collections.defaultdict(list)
6869
for p in base_path.iterdir():
6970
if p.is_dir():
7071
fix_generated_output(p)

scripts/run_bench.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
import sys
66
import time
77
import uuid
8+
from collections.abc import AsyncIterator
89
from contextlib import asynccontextmanager
910
from datetime import timedelta
10-
from typing import AsyncIterator
1111

1212
from temporalio import activity, workflow
1313
from temporalio.testing import WorkflowEnvironment

temporalio/activity.py

Lines changed: 33 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,16 @@
1515
import inspect
1616
import logging
1717
import threading
18+
from collections.abc import Callable, Iterator, Mapping, MutableMapping, Sequence
1819
from contextlib import AbstractContextManager, contextmanager
1920
from dataclasses import dataclass
2021
from datetime import datetime, timedelta
2122
from typing import (
2223
TYPE_CHECKING,
2324
Any,
24-
Callable,
25-
Iterator,
2625
List,
27-
Mapping,
28-
MutableMapping,
2926
NoReturn,
3027
Optional,
31-
Sequence,
3228
Tuple,
3329
Type,
3430
Union,
@@ -53,7 +49,7 @@ def defn(fn: CallableType) -> CallableType: ...
5349

5450
@overload
5551
def defn(
56-
*, name: Optional[str] = None, no_thread_cancel_exception: bool = False
52+
*, name: str | None = None, no_thread_cancel_exception: bool = False
5753
) -> Callable[[CallableType], CallableType]: ...
5854

5955

@@ -64,9 +60,9 @@ def defn(
6460

6561

6662
def defn(
67-
fn: Optional[CallableType] = None,
63+
fn: CallableType | None = None, # type: ignore[reportInvalidTypeVarUse]
6864
*,
69-
name: Optional[str] = None,
65+
name: str | None = None,
7066
no_thread_cancel_exception: bool = False,
7167
dynamic: bool = False,
7268
):
@@ -111,11 +107,11 @@ class Info:
111107
attempt: int
112108
current_attempt_scheduled_time: datetime
113109
heartbeat_details: Sequence[Any]
114-
heartbeat_timeout: Optional[timedelta]
110+
heartbeat_timeout: timedelta | None
115111
is_local: bool
116-
schedule_to_close_timeout: Optional[timedelta]
112+
schedule_to_close_timeout: timedelta | None
117113
scheduled_time: datetime
118-
start_to_close_timeout: Optional[timedelta]
114+
start_to_close_timeout: timedelta | None
119115
started_time: datetime
120116
task_queue: str
121117
task_token: bytes
@@ -124,7 +120,7 @@ class Info:
124120
workflow_run_id: str
125121
workflow_type: str
126122
priority: temporalio.common.Priority
127-
retry_policy: Optional[temporalio.common.RetryPolicy]
123+
retry_policy: temporalio.common.RetryPolicy | None
128124
"""The retry policy of this activity.
129125
130126
Note that the server may have set a different policy than the one provided when scheduling the activity.
@@ -151,7 +147,7 @@ def _logger_details(self) -> Mapping[str, Any]:
151147

152148
@dataclass
153149
class _ActivityCancellationDetailsHolder:
154-
details: Optional[ActivityCancellationDetails] = None
150+
details: ActivityCancellationDetails | None = None
155151

156152

157153
@dataclass(frozen=True)
@@ -183,20 +179,20 @@ def _from_proto(
183179
class _Context:
184180
info: Callable[[], Info]
185181
# This is optional because during interceptor init it is not present
186-
heartbeat: Optional[Callable[..., None]]
182+
heartbeat: Callable[..., None] | None
187183
cancelled_event: _CompositeEvent
188184
worker_shutdown_event: _CompositeEvent
189-
shield_thread_cancel_exception: Optional[Callable[[], AbstractContextManager]]
190-
payload_converter_class_or_instance: Union[
191-
Type[temporalio.converter.PayloadConverter],
192-
temporalio.converter.PayloadConverter,
193-
]
194-
runtime_metric_meter: Optional[temporalio.common.MetricMeter]
195-
client: Optional[Client]
185+
shield_thread_cancel_exception: Callable[[], AbstractContextManager] | None
186+
payload_converter_class_or_instance: (
187+
type[temporalio.converter.PayloadConverter]
188+
| temporalio.converter.PayloadConverter
189+
)
190+
runtime_metric_meter: temporalio.common.MetricMeter | None
191+
client: Client | None
196192
cancellation_details: _ActivityCancellationDetailsHolder
197-
_logger_details: Optional[Mapping[str, Any]] = None
198-
_payload_converter: Optional[temporalio.converter.PayloadConverter] = None
199-
_metric_meter: Optional[temporalio.common.MetricMeter] = None
193+
_logger_details: Mapping[str, Any] | None = None
194+
_payload_converter: temporalio.converter.PayloadConverter | None = None
195+
_metric_meter: temporalio.common.MetricMeter | None = None
200196

201197
@staticmethod
202198
def current() -> _Context:
@@ -258,9 +254,9 @@ def metric_meter(self) -> temporalio.common.MetricMeter:
258254
@dataclass
259255
class _CompositeEvent:
260256
# This should always be present, but is sometimes lazily set internally
261-
thread_event: Optional[threading.Event]
257+
thread_event: threading.Event | None
262258
# Async event only for async activities
263-
async_event: Optional[asyncio.Event]
259+
async_event: asyncio.Event | None
264260

265261
def set(self) -> None:
266262
if not self.thread_event:
@@ -279,7 +275,7 @@ async def wait(self) -> None:
279275
raise RuntimeError("not in async activity")
280276
await self.async_event.wait()
281277

282-
def wait_sync(self, timeout: Optional[float] = None) -> None:
278+
def wait_sync(self, timeout: float | None = None) -> None:
283279
if not self.thread_event:
284280
raise RuntimeError("Missing event")
285281
self.thread_event.wait(timeout)
@@ -330,7 +326,7 @@ def info() -> Info:
330326
return _Context.current().info()
331327

332328

333-
def cancellation_details() -> Optional[ActivityCancellationDetails]:
329+
def cancellation_details() -> ActivityCancellationDetails | None:
334330
"""Cancellation details of the current activity, if any. Once set, cancellation details do not change."""
335331
return _Context.current().cancellation_details.details
336332

@@ -398,7 +394,7 @@ async def wait_for_cancelled() -> None:
398394
await _Context.current().cancelled_event.wait()
399395

400396

401-
def wait_for_cancelled_sync(timeout: Optional[Union[timedelta, float]] = None) -> None:
397+
def wait_for_cancelled_sync(timeout: timedelta | float | None = None) -> None:
402398
"""Synchronously block while waiting for a cancellation request on this
403399
activity.
404400
@@ -437,7 +433,7 @@ async def wait_for_worker_shutdown() -> None:
437433

438434

439435
def wait_for_worker_shutdown_sync(
440-
timeout: Optional[Union[timedelta, float]] = None,
436+
timeout: timedelta | float | None = None,
441437
) -> None:
442438
"""Synchronously block while waiting for shutdown to be called on the
443439
worker.
@@ -511,9 +507,7 @@ class LoggerAdapter(logging.LoggerAdapter):
511507
use by others. Default is False.
512508
"""
513509

514-
def __init__(
515-
self, logger: logging.Logger, extra: Optional[Mapping[str, Any]]
516-
) -> None:
510+
def __init__(self, logger: logging.Logger, extra: Mapping[str, Any] | None) -> None:
517511
"""Create the logger adapter."""
518512
super().__init__(logger, extra or {})
519513
self.activity_info_on_message = True
@@ -522,7 +516,7 @@ def __init__(
522516

523517
def process(
524518
self, msg: Any, kwargs: MutableMapping[str, Any]
525-
) -> Tuple[Any, MutableMapping[str, Any]]:
519+
) -> tuple[Any, MutableMapping[str, Any]]:
526520
"""Override to add activity details."""
527521
if (
528522
self.activity_info_on_message
@@ -559,16 +553,16 @@ def base_logger(self) -> logging.Logger:
559553

560554
@dataclass(frozen=True)
561555
class _Definition:
562-
name: Optional[str]
556+
name: str | None
563557
fn: Callable
564558
is_async: bool
565559
no_thread_cancel_exception: bool
566560
# Types loaded on post init if both are None
567-
arg_types: Optional[List[Type]] = None
568-
ret_type: Optional[Type] = None
561+
arg_types: list[type] | None = None
562+
ret_type: type | None = None
569563

570564
@staticmethod
571-
def from_callable(fn: Callable) -> Optional[_Definition]:
565+
def from_callable(fn: Callable) -> _Definition | None:
572566
defn = getattr(fn, "__temporal_activity_definition", None)
573567
if isinstance(defn, _Definition):
574568
# We have to replace the function with the given callable here
@@ -592,7 +586,7 @@ def must_from_callable(fn: Callable) -> _Definition:
592586
def _apply_to_callable(
593587
fn: Callable,
594588
*,
595-
activity_name: Optional[str],
589+
activity_name: str | None,
596590
no_thread_cancel_exception: bool = False,
597591
) -> None:
598592
# Validate the activity

temporalio/bridge/_visitor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# This file is generated by gen_payload_visitor.py. Changes should be made there.
22
import abc
3-
from typing import Any, MutableSequence
3+
from collections.abc import MutableSequence
4+
from typing import Any
45

56
from temporalio.api.common.v1.message_pb2 import Payload
67

temporalio/bridge/client.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55

66
from __future__ import annotations
77

8+
from collections.abc import Mapping
89
from dataclasses import dataclass
910
from datetime import timedelta
10-
from typing import Mapping, Optional, Tuple, Type, TypeVar, Union
11+
from typing import Optional, Tuple, Type, TypeVar, Union
1112

1213
import google.protobuf.message
1314

@@ -20,10 +21,10 @@
2021
class ClientTlsConfig:
2122
"""Python representation of the Rust struct for configuring TLS."""
2223

23-
server_root_ca_cert: Optional[bytes]
24-
domain: Optional[str]
25-
client_cert: Optional[bytes]
26-
client_private_key: Optional[bytes]
24+
server_root_ca_cert: bytes | None
25+
domain: str | None
26+
client_cert: bytes | None
27+
client_private_key: bytes | None
2728

2829

2930
@dataclass
@@ -34,7 +35,7 @@ class ClientRetryConfig:
3435
randomization_factor: float
3536
multiplier: float
3637
max_interval_millis: int
37-
max_elapsed_time_millis: Optional[int]
38+
max_elapsed_time_millis: int | None
3839
max_retries: int
3940

4041

@@ -51,23 +52,23 @@ class ClientHttpConnectProxyConfig:
5152
"""Python representation of the Rust struct for configuring HTTP proxy."""
5253

5354
target_host: str
54-
basic_auth: Optional[Tuple[str, str]]
55+
basic_auth: tuple[str, str] | None
5556

5657

5758
@dataclass
5859
class ClientConfig:
5960
"""Python representation of the Rust struct for configuring the client."""
6061

6162
target_url: str
62-
metadata: Mapping[str, Union[str, bytes]]
63-
api_key: Optional[str]
63+
metadata: Mapping[str, str | bytes]
64+
api_key: str | None
6465
identity: str
65-
tls_config: Optional[ClientTlsConfig]
66-
retry_config: Optional[ClientRetryConfig]
67-
keep_alive_config: Optional[ClientKeepAliveConfig]
66+
tls_config: ClientTlsConfig | None
67+
retry_config: ClientRetryConfig | None
68+
keep_alive_config: ClientKeepAliveConfig | None
6869
client_name: str
6970
client_version: str
70-
http_connect_proxy_config: Optional[ClientHttpConnectProxyConfig]
71+
http_connect_proxy_config: ClientHttpConnectProxyConfig | None
7172

7273

7374
@dataclass
@@ -77,8 +78,8 @@ class RpcCall:
7778
rpc: str
7879
req: bytes
7980
retry: bool
80-
metadata: Mapping[str, Union[str, bytes]]
81-
timeout_millis: Optional[int]
81+
metadata: Mapping[str, str | bytes]
82+
timeout_millis: int | None
8283

8384

8485
ProtoMessage = TypeVar("ProtoMessage", bound=google.protobuf.message.Message)
@@ -108,11 +109,11 @@ def __init__(
108109
self._runtime = runtime
109110
self._ref = ref
110111

111-
def update_metadata(self, metadata: Mapping[str, Union[str, bytes]]) -> None:
112+
def update_metadata(self, metadata: Mapping[str, str | bytes]) -> None:
112113
"""Update underlying metadata on Core client."""
113114
self._ref.update_metadata(metadata)
114115

115-
def update_api_key(self, api_key: Optional[str]) -> None:
116+
def update_api_key(self, api_key: str | None) -> None:
116117
"""Update underlying API key on Core client."""
117118
self._ref.update_api_key(api_key)
118119

@@ -122,10 +123,10 @@ async def call(
122123
service: str,
123124
rpc: str,
124125
req: google.protobuf.message.Message,
125-
resp_type: Type[ProtoMessage],
126+
resp_type: type[ProtoMessage],
126127
retry: bool,
127-
metadata: Mapping[str, Union[str, bytes]],
128-
timeout: Optional[timedelta],
128+
metadata: Mapping[str, str | bytes],
129+
timeout: timedelta | None,
129130
) -> ProtoMessage:
130131
"""Make RPC call using SDK Core."""
131132
# Prepare call

0 commit comments

Comments
 (0)