Skip to content

Commit a08250a

Browse files
skarzitconley1428
andauthored
refactor: use classmethod in factory methods (#1179)
* refactor: use classmethod in Client.connect factory method * refactor: use classmethod in Runtime.default factory method * refactor: use classmethod in WorkerTuner factory methods * refactor: use classmethod in WorkflowEnvironment.start_local factory method * refactor: use classmethod in SandboxMatcher.nested_child factory method * refactor: use classmethod in configurations factory methods * refactor: use classmethod in the remaining WorkflowEnvironment factory methods * chore(dev): apply fixers on new changes * fix: resolve typing issues with Self return type * fix: use default_factory for MappingProxyType instances defaults --------- Co-authored-by: tconley1428 <[email protected]>
1 parent 1b70f07 commit a08250a

File tree

7 files changed

+50
-37
lines changed

7 files changed

+50
-37
lines changed

temporalio/client.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,9 @@ class Client:
112112
Clients do not work across forks since runtimes do not work across forks.
113113
"""
114114

115-
@staticmethod
115+
@classmethod
116116
async def connect(
117+
cls,
117118
target_host: str,
118119
*,
119120
namespace: str = "default",
@@ -133,7 +134,7 @@ async def connect(
133134
runtime: Optional[temporalio.runtime.Runtime] = None,
134135
http_connect_proxy_config: Optional[HttpConnectProxyConfig] = None,
135136
header_codec_behavior: HeaderCodecBehavior = HeaderCodecBehavior.NO_CODEC,
136-
) -> Client:
137+
) -> Self:
137138
"""Connect to a Temporal server.
138139
139140
Args:
@@ -209,7 +210,7 @@ def make_lambda(plugin, next):
209210

210211
service_client = await next_function(connect_config)
211212

212-
return Client(
213+
return cls(
213214
service_client,
214215
namespace=namespace,
215216
data_converter=data_converter,

temporalio/envconfig.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from pathlib import Path
1111
from typing import Any, Dict, Literal, Mapping, Optional, Union, cast
1212

13-
from typing_extensions import TypeAlias, TypedDict
13+
from typing_extensions import Self, TypeAlias, TypedDict
1414

1515
import temporalio.service
1616
from temporalio.bridge.temporal_sdk_bridge import envconfig as _bridge_envconfig
@@ -148,12 +148,12 @@ def to_connect_tls_config(self) -> Union[bool, temporalio.service.TLSConfig]:
148148
client_private_key=_read_source(self.client_private_key),
149149
)
150150

151-
@staticmethod
152-
def from_dict(d: Optional[ClientConfigTLSDict]) -> Optional[ClientConfigTLS]:
151+
@classmethod
152+
def from_dict(cls, d: Optional[ClientConfigTLSDict]) -> Optional[Self]:
153153
"""Create a ClientConfigTLS from a dictionary."""
154154
if not d:
155155
return None
156-
return ClientConfigTLS(
156+
return cls(
157157
disabled=d.get("disabled"),
158158
server_name=d.get("server_name"),
159159
# Note: Bridge uses snake_case, but TOML uses kebab-case which is
@@ -202,10 +202,10 @@ class ClientConfigProfile:
202202
grpc_meta: Mapping[str, str] = field(default_factory=dict)
203203
"""gRPC metadata."""
204204

205-
@staticmethod
206-
def from_dict(d: ClientConfigProfileDict) -> ClientConfigProfile:
205+
@classmethod
206+
def from_dict(cls, d: ClientConfigProfileDict) -> Self:
207207
"""Create a ClientConfigProfile from a dictionary."""
208-
return ClientConfigProfile(
208+
return cls(
209209
address=d.get("address"),
210210
namespace=d.get("namespace"),
211211
api_key=d.get("api_key"),
@@ -318,14 +318,15 @@ def to_dict(self) -> Mapping[str, ClientConfigProfileDict]:
318318
"""Convert to a dictionary that can be used for TOML serialization."""
319319
return {k: v.to_dict() for k, v in self.profiles.items()}
320320

321-
@staticmethod
321+
@classmethod
322322
def from_dict(
323+
cls,
323324
d: Mapping[str, Mapping[str, Any]],
324-
) -> ClientConfig:
325+
) -> Self:
325326
"""Create a ClientConfig from a dictionary."""
326327
# We must cast the inner dictionary because the source is often a plain
327328
# Mapping[str, Any] from the bridge or other sources.
328-
return ClientConfig(
329+
return cls(
329330
profiles={
330331
k: ClientConfigProfile.from_dict(cast(ClientConfigProfileDict, v))
331332
for k, v in d.items()

temporalio/runtime.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ class Runtime:
3737
Runtimes do not work across forks.
3838
"""
3939

40-
@staticmethod
41-
def default() -> Runtime:
40+
@classmethod
41+
def default(cls) -> Runtime:
4242
"""Get the default runtime, creating if not already created.
4343
4444
If the default runtime needs to be different, it should be done with
@@ -49,7 +49,7 @@ def default() -> Runtime:
4949
"""
5050
global _default_runtime
5151
if not _default_runtime:
52-
_default_runtime = Runtime(telemetry=TelemetryConfig())
52+
_default_runtime = cls(telemetry=TelemetryConfig())
5353
return _default_runtime
5454

5555
@staticmethod

temporalio/testing/_workflow.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
)
2121

2222
import google.protobuf.empty_pb2
23+
from typing_extensions import Self
2324

2425
import temporalio.api.testservice.v1
2526
import temporalio.bridge.testing
@@ -54,8 +55,8 @@ class WorkflowEnvironment:
5455
to have ``assert`` failures fail the workflow with the assertion error.
5556
"""
5657

57-
@staticmethod
58-
def from_client(client: temporalio.client.Client) -> WorkflowEnvironment:
58+
@classmethod
59+
def from_client(cls, client: temporalio.client.Client) -> Self:
5960
"""Create a workflow environment from the given client.
6061
6162
:py:attr:`supports_time_skipping` will always return ``False`` for this
@@ -69,12 +70,11 @@ def from_client(client: temporalio.client.Client) -> WorkflowEnvironment:
6970
The workflow environment that runs against the given client.
7071
"""
7172
# Add the assertion interceptor
72-
return WorkflowEnvironment(
73-
_client_with_interceptors(client, _AssertionErrorInterceptor())
74-
)
73+
return cls(_client_with_interceptors(client, _AssertionErrorInterceptor()))
7574

76-
@staticmethod
75+
@classmethod
7776
async def start_local(
77+
cls,
7878
*,
7979
namespace: str = "default",
8080
data_converter: temporalio.converter.DataConverter = temporalio.converter.DataConverter.default,
@@ -234,8 +234,9 @@ async def start_local(
234234
)
235235
raise
236236

237-
@staticmethod
237+
@classmethod
238238
async def start_time_skipping(
239+
cls,
239240
*,
240241
data_converter: temporalio.converter.DataConverter = temporalio.converter.DataConverter.default,
241242
interceptors: Sequence[temporalio.client.Interceptor] = [],
@@ -357,7 +358,8 @@ async def start_time_skipping(
357358
def __init__(self, client: temporalio.client.Client) -> None:
358359
"""Create a workflow environment from a client.
359360
360-
Most users would use a static method instead.
361+
Most users would use a factory methods instead.
362+
361363
"""
362364
self._client = client
363365

temporalio/worker/_tuning.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from datetime import timedelta
88
from typing import Any, Callable, Literal, Optional, Protocol, Union, runtime_checkable
99

10-
from typing_extensions import TypeAlias
10+
from typing_extensions import Self, TypeAlias
1111

1212
import temporalio.bridge.worker
1313
from temporalio.common import WorkerDeploymentVersion
@@ -310,8 +310,9 @@ def _to_bridge_slot_supplier(
310310
class WorkerTuner(ABC):
311311
"""WorkerTuners allow for the dynamic customization of some aspects of worker configuration"""
312312

313-
@staticmethod
313+
@classmethod
314314
def create_resource_based(
315+
cls,
315316
*,
316317
target_memory_usage: float,
317318
target_cpu_usage: float,
@@ -341,8 +342,9 @@ def create_resource_based(
341342
nexus,
342343
)
343344

344-
@staticmethod
345+
@classmethod
345346
def create_fixed(
347+
cls,
346348
*,
347349
workflow_slots: Optional[int] = None,
348350
activity_slots: Optional[int] = None,
@@ -362,8 +364,9 @@ def create_fixed(
362364
FixedSizeSlotSupplier(nexus_slots if nexus_slots else 100),
363365
)
364366

365-
@staticmethod
367+
@classmethod
366368
def create_composite(
369+
cls,
367370
*,
368371
workflow_supplier: SlotSupplier,
369372
activity_supplier: SlotSupplier,

temporalio/worker/workflow_sandbox/_restrictions.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
cast,
3535
)
3636

37+
from typing_extensions import Self
38+
3739
try:
3840
import pydantic
3941
import pydantic_core
@@ -211,8 +213,8 @@ class SandboxMatcher:
211213
instances.
212214
"""
213215

214-
@staticmethod
215-
def nested_child(path: Sequence[str], child: SandboxMatcher) -> SandboxMatcher:
216+
@classmethod
217+
def nested_child(cls, path: Sequence[str], child: SandboxMatcher) -> SandboxMatcher:
216218
"""Create a matcher where the given child is put at the given path.
217219
218220
Args:
@@ -224,12 +226,12 @@ def nested_child(path: Sequence[str], child: SandboxMatcher) -> SandboxMatcher:
224226
"""
225227
ret = child
226228
for key in reversed(path):
227-
ret = SandboxMatcher(children={key: ret})
229+
ret = cls(children={key: ret})
228230
return ret
229231

230232
access: Set[str] = frozenset() # type: ignore
231233
"""Immutable set of names to match access.
232-
234+
233235
This is often only used for pass through checks and not member restrictions.
234236
If this is used for member restrictions, even importing/accessing the value
235237
will fail as opposed to :py:attr:`use` which is for when it is used.
@@ -239,7 +241,7 @@ def nested_child(path: Sequence[str], child: SandboxMatcher) -> SandboxMatcher:
239241

240242
use: Set[str] = frozenset() # type: ignore
241243
"""Immutable set of names to match use.
242-
244+
243245
This is best used for member restrictions on functions/classes because the
244246
restriction will not apply to referencing/importing the item, just when it
245247
is used.
@@ -275,7 +277,7 @@ def nested_child(path: Sequence[str], child: SandboxMatcher) -> SandboxMatcher:
275277

276278
exclude: Set[str] = frozenset() # type: ignore
277279
"""Immutable set of names to exclude.
278-
280+
279281
These override anything that may have been matched elsewhere.
280282
"""
281283

tests/nexus/test_handler.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import uuid
2121
from collections.abc import Mapping
2222
from concurrent.futures.thread import ThreadPoolExecutor
23-
from dataclasses import dataclass
23+
from dataclasses import dataclass, field
2424
from types import MappingProxyType
2525
from typing import Any, Callable, Optional, Union
2626

@@ -313,7 +313,9 @@ async def non_serializable_output(
313313
class SuccessfulResponse:
314314
status_code: int
315315
body_json: Optional[Union[dict[str, Any], Callable[[dict[str, Any]], bool]]] = None
316-
headers: Mapping[str, str] = SUCCESSFUL_RESPONSE_HEADERS
316+
headers: Mapping[str, str] = field(
317+
default_factory=lambda: SUCCESSFUL_RESPONSE_HEADERS
318+
)
317319

318320

319321
@dataclass
@@ -325,7 +327,9 @@ class UnsuccessfulResponse:
325327
# Expected value of inverse of non_retryable attribute of exception.
326328
retryable_exception: bool = True
327329
body_json: Optional[Callable[[dict[str, Any]], bool]] = None
328-
headers: Mapping[str, str] = UNSUCCESSFUL_RESPONSE_HEADERS
330+
headers: Mapping[str, str] = field(
331+
default_factory=lambda: UNSUCCESSFUL_RESPONSE_HEADERS
332+
)
329333

330334

331335
class _TestCase:

0 commit comments

Comments
 (0)