diff --git a/temporalio/envconfig.py b/temporalio/envconfig.py index 7d19881a1..95a73740d 100644 --- a/temporalio/envconfig.py +++ b/temporalio/envconfig.py @@ -8,7 +8,7 @@ from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Literal, Mapping, Optional, Union, cast +from typing import Any, Dict, Literal, Mapping, Optional, Union, cast from typing_extensions import TypeAlias, TypedDict @@ -172,11 +172,11 @@ class ClientConnectConfig(TypedDict, total=False): Experimental API. """ - target_host: Optional[str] - namespace: Optional[str] - api_key: Optional[str] - tls: Optional[Union[bool, temporalio.service.TLSConfig]] - rpc_metadata: Optional[Mapping[str, str]] + target_host: str + namespace: str + api_key: str + tls: Union[bool, temporalio.service.TLSConfig] + rpc_metadata: Mapping[str, str] @dataclass(frozen=True) @@ -230,18 +230,26 @@ def to_dict(self) -> ClientConfigProfileDict: def to_client_connect_config(self) -> ClientConnectConfig: """Create a `ClientConnectConfig` from this profile.""" - config: ClientConnectConfig = {} - if self.address: - config["target_host"] = self.address - if self.namespace: + if not self.address: + raise ValueError( + "Configuration profile must contain an 'address' to be used for " + "client connection" + ) + + # Only include non-None values + config: Dict[str, Any] = {} + config["target_host"] = self.address + if self.namespace is not None: config["namespace"] = self.namespace - if self.api_key: + if self.api_key is not None: config["api_key"] = self.api_key - if self.tls: + if self.tls is not None: config["tls"] = self.tls.to_connect_tls_config() if self.grpc_meta: config["rpc_metadata"] = self.grpc_meta - return config + + # Cast to ClientConnectConfig - this is safe because we've only included non-None values + return cast(ClientConnectConfig, config) @staticmethod def load( diff --git a/tests/test_envconfig.py b/tests/test_envconfig.py index dc675dc77..33e2433f6 100644 --- a/tests/test_envconfig.py +++ b/tests/test_envconfig.py @@ -449,7 +449,7 @@ async def test_load_client_connect_config(client: Client, tmp_path: Path): config = ClientConfig.load_client_connect_config(config_file=str(config_file)) assert config.get("target_host") == target_host assert config.get("namespace") == namespace - new_client = await Client.connect(**config) # type: ignore + new_client = await Client.connect(**config) assert new_client.service_client.config.target_host == target_host assert new_client.namespace == namespace @@ -462,7 +462,7 @@ async def test_load_client_connect_config(client: Client, tmp_path: Path): rpc_metadata = config.get("rpc_metadata") assert rpc_metadata assert "custom-header" in rpc_metadata - new_client = await Client.connect(**config) # type: ignore + new_client = await Client.connect(**config) assert new_client.service_client.config.target_host == target_host assert new_client.namespace == "custom-namespace" assert ( @@ -476,7 +476,7 @@ async def test_load_client_connect_config(client: Client, tmp_path: Path): ) assert config.get("target_host") == target_host assert config.get("namespace") == "env-namespace-override" - new_client = await Client.connect(**config) # type: ignore + new_client = await Client.connect(**config) assert new_client.namespace == "env-namespace-override" # Test with env overrides disabled @@ -487,7 +487,7 @@ async def test_load_client_connect_config(client: Client, tmp_path: Path): ) assert config.get("target_host") == target_host assert config.get("namespace") == namespace - new_client = await Client.connect(**config) # type: ignore + new_client = await Client.connect(**config) assert new_client.namespace == namespace # Test with file loading disabled (so only env is used) @@ -500,11 +500,18 @@ async def test_load_client_connect_config(client: Client, tmp_path: Path): ) assert config.get("target_host") == target_host assert config.get("namespace") == "env-only-namespace" - new_client = await Client.connect(**config) # type: ignore + new_client = await Client.connect(**config) assert new_client.service_client.config.target_host == target_host assert new_client.namespace == "env-only-namespace" +def test_to_client_connect_config_missing_address_fails(): + """Test that to_client_connect_config raises a ValueError if address is missing.""" + profile = ClientConfigProfile() + with pytest.raises(ValueError, match="must contain an 'address'"): + profile.to_client_connect_config() + + def test_disables_raise_error(): """Test that providing both disable_file and disable_env raises an error.""" with pytest.raises(RuntimeError, match="Cannot disable both"):