Skip to content

Commit 6f1d62b

Browse files
fix typing when passing loaded config into Client.connect, raise error if 'address' not provided by config (#998)
Co-authored-by: tconley1428 <[email protected]>
1 parent c8bc329 commit 6f1d62b

File tree

2 files changed

+33
-18
lines changed

2 files changed

+33
-18
lines changed

temporalio/envconfig.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from dataclasses import dataclass, field
1010
from pathlib import Path
11-
from typing import Any, Literal, Mapping, Optional, Union, cast
11+
from typing import Any, Dict, Literal, Mapping, Optional, Union, cast
1212

1313
from typing_extensions import TypeAlias, TypedDict
1414

@@ -172,11 +172,11 @@ class ClientConnectConfig(TypedDict, total=False):
172172
Experimental API.
173173
"""
174174

175-
target_host: Optional[str]
176-
namespace: Optional[str]
177-
api_key: Optional[str]
178-
tls: Optional[Union[bool, temporalio.service.TLSConfig]]
179-
rpc_metadata: Optional[Mapping[str, str]]
175+
target_host: str
176+
namespace: str
177+
api_key: str
178+
tls: Union[bool, temporalio.service.TLSConfig]
179+
rpc_metadata: Mapping[str, str]
180180

181181

182182
@dataclass(frozen=True)
@@ -230,18 +230,26 @@ def to_dict(self) -> ClientConfigProfileDict:
230230

231231
def to_client_connect_config(self) -> ClientConnectConfig:
232232
"""Create a `ClientConnectConfig` from this profile."""
233-
config: ClientConnectConfig = {}
234-
if self.address:
235-
config["target_host"] = self.address
236-
if self.namespace:
233+
if not self.address:
234+
raise ValueError(
235+
"Configuration profile must contain an 'address' to be used for "
236+
"client connection"
237+
)
238+
239+
# Only include non-None values
240+
config: Dict[str, Any] = {}
241+
config["target_host"] = self.address
242+
if self.namespace is not None:
237243
config["namespace"] = self.namespace
238-
if self.api_key:
244+
if self.api_key is not None:
239245
config["api_key"] = self.api_key
240-
if self.tls:
246+
if self.tls is not None:
241247
config["tls"] = self.tls.to_connect_tls_config()
242248
if self.grpc_meta:
243249
config["rpc_metadata"] = self.grpc_meta
244-
return config
250+
251+
# Cast to ClientConnectConfig - this is safe because we've only included non-None values
252+
return cast(ClientConnectConfig, config)
245253

246254
@staticmethod
247255
def load(

tests/test_envconfig.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,7 @@ async def test_load_client_connect_config(client: Client, tmp_path: Path):
449449
config = ClientConfig.load_client_connect_config(config_file=str(config_file))
450450
assert config.get("target_host") == target_host
451451
assert config.get("namespace") == namespace
452-
new_client = await Client.connect(**config) # type: ignore
452+
new_client = await Client.connect(**config)
453453
assert new_client.service_client.config.target_host == target_host
454454
assert new_client.namespace == namespace
455455

@@ -462,7 +462,7 @@ async def test_load_client_connect_config(client: Client, tmp_path: Path):
462462
rpc_metadata = config.get("rpc_metadata")
463463
assert rpc_metadata
464464
assert "custom-header" in rpc_metadata
465-
new_client = await Client.connect(**config) # type: ignore
465+
new_client = await Client.connect(**config)
466466
assert new_client.service_client.config.target_host == target_host
467467
assert new_client.namespace == "custom-namespace"
468468
assert (
@@ -476,7 +476,7 @@ async def test_load_client_connect_config(client: Client, tmp_path: Path):
476476
)
477477
assert config.get("target_host") == target_host
478478
assert config.get("namespace") == "env-namespace-override"
479-
new_client = await Client.connect(**config) # type: ignore
479+
new_client = await Client.connect(**config)
480480
assert new_client.namespace == "env-namespace-override"
481481

482482
# Test with env overrides disabled
@@ -487,7 +487,7 @@ async def test_load_client_connect_config(client: Client, tmp_path: Path):
487487
)
488488
assert config.get("target_host") == target_host
489489
assert config.get("namespace") == namespace
490-
new_client = await Client.connect(**config) # type: ignore
490+
new_client = await Client.connect(**config)
491491
assert new_client.namespace == namespace
492492

493493
# Test with file loading disabled (so only env is used)
@@ -500,11 +500,18 @@ async def test_load_client_connect_config(client: Client, tmp_path: Path):
500500
)
501501
assert config.get("target_host") == target_host
502502
assert config.get("namespace") == "env-only-namespace"
503-
new_client = await Client.connect(**config) # type: ignore
503+
new_client = await Client.connect(**config)
504504
assert new_client.service_client.config.target_host == target_host
505505
assert new_client.namespace == "env-only-namespace"
506506

507507

508+
def test_to_client_connect_config_missing_address_fails():
509+
"""Test that to_client_connect_config raises a ValueError if address is missing."""
510+
profile = ClientConfigProfile()
511+
with pytest.raises(ValueError, match="must contain an 'address'"):
512+
profile.to_client_connect_config()
513+
514+
508515
def test_disables_raise_error():
509516
"""Test that providing both disable_file and disable_env raises an error."""
510517
with pytest.raises(RuntimeError, match="Cannot disable both"):

0 commit comments

Comments
 (0)