Skip to content

Commit fa5d9b7

Browse files
committed
Merge branch 'main' into nexus-metric-meter
2 parents 49cd4e7 + cc19379 commit fa5d9b7

File tree

8 files changed

+117
-31
lines changed

8 files changed

+117
-31
lines changed

temporalio/client.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ async def connect(
125125
default_workflow_query_reject_condition: Optional[
126126
temporalio.common.QueryRejectCondition
127127
] = None,
128-
tls: Union[bool, TLSConfig] = False,
128+
tls: Union[bool, TLSConfig, None] = None,
129129
retry_config: Optional[RetryConfig] = None,
130130
keep_alive_config: Optional[KeepAliveConfig] = KeepAliveConfig.default,
131131
rpc_metadata: Mapping[str, Union[str, bytes]] = {},
@@ -166,9 +166,11 @@ async def connect(
166166
condition for workflow queries if not set during query. See
167167
:py:meth:`WorkflowHandle.query` for details on the rejection
168168
condition.
169-
tls: If false, the default, do not use TLS. If true, use system
170-
default TLS configuration. If TLS configuration present, that
171-
TLS configuration will be used.
169+
tls: If ``None``, the default, TLS will be enabled automatically
170+
when ``api_key`` is provided, otherwise TLS is disabled. If
171+
``False``, do not use TLS. If ``True``, use system default TLS
172+
configuration. If TLS configuration present, that TLS
173+
configuration will be used.
172174
retry_config: Retry configuration for direct service calls (when
173175
opted in) or all high-level calls made by this client (which all
174176
opt-in to retries by default). If unset, a default retry
@@ -247,6 +249,7 @@ def __init__(
247249
default_workflow_query_reject_condition=default_workflow_query_reject_condition,
248250
header_codec_behavior=header_codec_behavior,
249251
)
252+
self._initial_config = config.copy()
250253

251254
for plugin in plugins:
252255
config = plugin.configure_client(config)
@@ -261,12 +264,16 @@ def _init_from_config(self, config: ClientConfig):
261264
for interceptor in reversed(list(self._config["interceptors"])):
262265
self._impl = interceptor.intercept_client(self._impl)
263266

264-
def config(self) -> ClientConfig:
267+
def config(self, *, active_config: bool = False) -> ClientConfig:
265268
"""Config, as a dictionary, used to create this client.
266269
270+
Args:
271+
active_config: If true, return the modified configuration in use rather than the initial one
272+
provided to the client.
273+
267274
This makes a shallow copy of the config each call.
268275
"""
269-
config = self._config.copy()
276+
config = self._config.copy() if active_config else self._initial_config.copy()
270277
config["interceptors"] = list(config["interceptors"])
271278
return config
272279

@@ -4290,7 +4297,8 @@ async def _to_proto(
42904297
await _apply_headers(
42914298
self.headers,
42924299
action.start_workflow.header.fields,
4293-
client.config()["header_codec_behavior"] == HeaderCodecBehavior.CODEC
4300+
client.config(active_config=True)["header_codec_behavior"]
4301+
== HeaderCodecBehavior.CODEC
42944302
and not self._from_raw,
42954303
client.data_converter.payload_codec,
42964304
)
@@ -6917,7 +6925,8 @@ async def _apply_headers(
69176925
await _apply_headers(
69186926
source,
69196927
dest,
6920-
self._client.config()["header_codec_behavior"] == HeaderCodecBehavior.CODEC,
6928+
self._client.config(active_config=True)["header_codec_behavior"]
6929+
== HeaderCodecBehavior.CODEC,
69216930
self._client.data_converter.payload_codec,
69226931
)
69236932

temporalio/service.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ class ConnectConfig:
136136

137137
target_host: str
138138
api_key: Optional[str] = None
139-
tls: Union[bool, TLSConfig] = False
139+
tls: Union[bool, TLSConfig, None] = None
140140
retry_config: Optional[RetryConfig] = None
141141
keep_alive_config: Optional[KeepAliveConfig] = KeepAliveConfig.default
142142
rpc_metadata: Mapping[str, Union[str, bytes]] = field(default_factory=dict)
@@ -172,6 +172,10 @@ def _to_bridge_config(self) -> temporalio.bridge.client.ClientConfig:
172172
elif self.tls:
173173
target_url = f"https://{self.target_host}"
174174
tls_config = TLSConfig()._to_bridge_config()
175+
# Enable TLS by default when API key is provided and tls not explicitly set
176+
elif self.tls is None and self.api_key is not None:
177+
target_url = f"https://{self.target_host}"
178+
tls_config = TLSConfig()._to_bridge_config()
175179
else:
176180
target_url = f"http://{self.target_host}"
177181
tls_config = None

temporalio/worker/_replayer.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def __init__(
8181
disable_safe_workflow_eviction=disable_safe_workflow_eviction,
8282
header_codec_behavior=header_codec_behavior,
8383
)
84+
self._initial_config = self._config.copy()
8485

8586
# Apply plugin configuration
8687
self.plugins = plugins
@@ -91,13 +92,17 @@ def __init__(
9192
if not self._config["workflows"]:
9293
raise ValueError("At least one workflow must be specified")
9394

94-
def config(self) -> ReplayerConfig:
95+
def config(self, *, active_config: bool = False) -> ReplayerConfig:
9596
"""Config, as a dictionary, used to create this replayer.
9697
98+
Args:
99+
active_config: If true, return the modified configuration in use rather than the initial one
100+
provided to the client.
101+
97102
Returns:
98103
Configuration, shallow-copied.
99104
"""
100-
config = self._config.copy()
105+
config = self._config.copy() if active_config else self._initial_config.copy()
101106
config["workflows"] = list(config["workflows"])
102107
return config
103108

temporalio/worker/_worker.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -375,8 +375,9 @@ def __init__(
375375
f"The same plugin type {type(client_plugin)} is present from both client and worker. It may run twice and may not be the intended behavior."
376376
)
377377
plugins = plugins_from_client + list(plugins)
378+
self._initial_config = config.copy()
378379

379-
self.plugins = plugins
380+
self._plugins = plugins
380381
for plugin in plugins:
381382
config = plugin.configure_worker(config)
382383

@@ -387,7 +388,6 @@ def _init_from_config(self, client: temporalio.client.Client, config: WorkerConf
387388
Client is safe to take separately since it can't be modified by worker plugins.
388389
"""
389390
self._config = config
390-
391391
if not (
392392
config["activities"]
393393
or config["nexus_service_handlers"]
@@ -408,7 +408,7 @@ def _init_from_config(self, client: temporalio.client.Client, config: WorkerConf
408408
)
409409

410410
# Prepend applicable client interceptors to the given ones
411-
client_config = config["client"].config()
411+
client_config = config["client"].config(active_config=True)
412412
interceptors_from_client = cast(
413413
List[Interceptor],
414414
[i for i in client_config["interceptors"] if isinstance(i, Interceptor)],
@@ -554,7 +554,7 @@ def check_activity(activity):
554554
maximum=config["max_concurrent_activity_task_polls"]
555555
)
556556

557-
deduped_plugin_names = list(set([plugin.name() for plugin in self.plugins]))
557+
deduped_plugin_names = list(set([plugin.name() for plugin in self._plugins]))
558558

559559
# Create bridge worker last. We have empirically observed that if it is
560560
# created before an error is raised from the activity worker
@@ -622,13 +622,17 @@ def check_activity(activity):
622622
),
623623
)
624624

625-
def config(self) -> WorkerConfig:
625+
def config(self, *, active_config: bool = False) -> WorkerConfig:
626626
"""Config, as a dictionary, used to create this worker.
627627
628+
Args:
629+
active_config: If true, return the modified configuration in use rather than the initial one
630+
provided to the worker.
631+
628632
Returns:
629633
Configuration, shallow-copied.
630634
"""
631-
config = self._config.copy()
635+
config = self._config.copy() if active_config else self._initial_config.copy()
632636
config["activities"] = list(config.get("activities", []))
633637
config["workflows"] = list(config.get("workflows", []))
634638
return config
@@ -702,7 +706,7 @@ def make_lambda(plugin, next):
702706
return lambda w: plugin.run_worker(w, next)
703707

704708
next_function = lambda w: w._run()
705-
for plugin in reversed(self.plugins):
709+
for plugin in reversed(self._plugins):
706710
next_function = make_lambda(plugin, next_function)
707711

708712
await next_function(self)

tests/api/test_grpc_stub.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ async def test_grpc_metadata():
129129
f"localhost:{port}",
130130
api_key="my-api-key",
131131
rpc_metadata={"my-meta-key": "my-meta-val"},
132+
tls=False,
132133
)
133134
workflow_server.assert_last_metadata(
134135
{

tests/test_envconfig.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1024,7 +1024,7 @@ async def test_e2e_multi_profile_different_client_connections(client: Client):
10241024
assert dev_client.service_client.config.target_host == target_host
10251025
assert dev_client.namespace == "dev"
10261026
assert dev_client.service_client.config.api_key is None
1027-
assert dev_client.service_client.config.tls is False
1027+
assert dev_client.service_client.config.tls is None
10281028

10291029
assert prod_client.service_client.config.target_host == target_host
10301030
assert prod_client.namespace == "prod"

tests/test_plugins.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ async def connect_service_client(
5656
next: Callable[[ConnectConfig], Awaitable[ServiceClient]],
5757
) -> ServiceClient:
5858
config.api_key = "replaced key"
59+
config.tls = False
5960
return await next(config)
6061

6162

@@ -150,7 +151,7 @@ async def test_worker_plugin_basic_config(client: Client) -> None:
150151
activities=[never_run_activity],
151152
plugins=[MyWorkerPlugin()],
152153
)
153-
task_queue = worker.config().get("task_queue")
154+
task_queue = worker.config(active_config=True).get("task_queue")
154155
assert task_queue is not None and task_queue.startswith("replaced_queue")
155156

156157
# Test client plugin propagation to worker plugins
@@ -160,7 +161,7 @@ async def test_worker_plugin_basic_config(client: Client) -> None:
160161
worker = Worker(
161162
client, task_queue="queue" + str(uuid.uuid4()), activities=[never_run_activity]
162163
)
163-
task_queue = worker.config().get("task_queue")
164+
task_queue = worker.config(active_config=True).get("task_queue")
164165
assert task_queue is not None and task_queue.startswith("combined")
165166

166167
# Test both. Client propagated plugins are called first, so the worker plugin overrides in this case
@@ -170,7 +171,7 @@ async def test_worker_plugin_basic_config(client: Client) -> None:
170171
activities=[never_run_activity],
171172
plugins=[MyWorkerPlugin()],
172173
)
173-
task_queue = worker.config().get("task_queue")
174+
task_queue = worker.config(active_config=True).get("task_queue")
174175
assert task_queue is not None and task_queue.startswith("replaced_queue")
175176

176177

@@ -202,7 +203,8 @@ async def test_worker_sandbox_restrictions(client: Client) -> None:
202203
assert (
203204
"my_module"
204205
in cast(
205-
SandboxedWorkflowRunner, worker.config().get("workflow_runner")
206+
SandboxedWorkflowRunner,
207+
worker.config(active_config=True).get("workflow_runner"),
206208
).restrictions.passthrough_modules
207209
)
208210

@@ -276,8 +278,11 @@ async def test_replay(client: Client) -> None:
276278
)
277279
await handle.result()
278280
replayer = Replayer(workflows=[], plugins=[plugin])
279-
assert len(replayer.config().get("workflows") or []) == 1
280-
assert replayer.config().get("data_converter") == pydantic_data_converter
281+
assert len(replayer.config(active_config=True).get("workflows") or []) == 1
282+
assert (
283+
replayer.config(active_config=True).get("data_converter")
284+
== pydantic_data_converter
285+
)
281286

282287
await replayer.replay_workflow(await handle.fetch_history())
283288

@@ -303,19 +308,28 @@ async def test_simple_plugins(client: Client) -> None:
303308
plugins=[plugin],
304309
)
305310
# On a sequence, a value is appended
306-
assert worker.config().get("workflows") == [HelloWorkflow, HelloWorkflow2]
311+
assert worker.config(active_config=True).get("workflows") == [
312+
HelloWorkflow,
313+
HelloWorkflow2,
314+
]
307315

308316
# Test with plugin registered in client
309317
worker = Worker(
310318
new_client,
311319
task_queue="queue" + str(uuid.uuid4()),
312320
activities=[never_run_activity],
313321
)
314-
assert worker.config().get("workflows") == [HelloWorkflow2]
322+
assert worker.config(active_config=True).get("workflows") == [HelloWorkflow2]
315323

316324
replayer = Replayer(workflows=[HelloWorkflow], plugins=[plugin])
317-
assert replayer.config().get("data_converter") == pydantic_data_converter
318-
assert replayer.config().get("workflows") == [HelloWorkflow, HelloWorkflow2]
325+
assert (
326+
replayer.config(active_config=True).get("data_converter")
327+
== pydantic_data_converter
328+
)
329+
assert replayer.config(active_config=True).get("workflows") == [
330+
HelloWorkflow,
331+
HelloWorkflow2,
332+
]
319333

320334

321335
async def test_simple_plugins_callables(client: Client) -> None:
@@ -350,7 +364,7 @@ def converter(old: Optional[DataConverter]):
350364
activities=[never_run_activity],
351365
plugins=[plugin],
352366
)
353-
assert worker.config().get("workflows") == []
367+
assert worker.config(active_config=True).get("workflows") == []
354368

355369

356370
class MediumPlugin(SimplePlugin):
@@ -371,5 +385,5 @@ async def test_medium_plugin(client: Client) -> None:
371385
plugins=[plugin],
372386
workflows=[HelloWorkflow],
373387
)
374-
task_queue = worker.config().get("task_queue")
388+
task_queue = worker.config(active_config=True).get("task_queue")
375389
assert task_queue is not None and task_queue.startswith("override")

tests/test_service.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,55 @@ async def test_grpc_status(client: Client, env: WorkflowEnvironment):
171171
)
172172

173173

174+
def test_connect_config_tls_enabled_by_default_when_api_key_provided():
175+
"""Test that TLS is enabled by default when API key is provided and tls is not configured."""
176+
config = temporalio.service.ConnectConfig(
177+
target_host="localhost:7233",
178+
api_key="test-api-key",
179+
)
180+
# TLS should be auto-enabled when api_key is provided and tls not explicitly set
181+
bridge_config = config._to_bridge_config()
182+
assert bridge_config.target_url == "https://localhost:7233"
183+
assert bridge_config.tls_config is not None
184+
185+
186+
def test_connect_config_tls_can_be_explicitly_disabled_even_when_api_key_provided():
187+
"""Test that TLS can be explicitly disabled even when API key is provided."""
188+
config = temporalio.service.ConnectConfig(
189+
target_host="localhost:7233",
190+
api_key="test-api-key",
191+
tls=False,
192+
)
193+
# TLS should remain disabled when explicitly set to False
194+
assert config.tls is False
195+
196+
197+
def test_connect_config_tls_disabled_by_default_when_no_api_key():
198+
"""Test that TLS is disabled by default when no API key is provided."""
199+
config = temporalio.service.ConnectConfig(
200+
target_host="localhost:7233",
201+
)
202+
# TLS should remain disabled when no api_key is provided
203+
bridge_config = config._to_bridge_config()
204+
assert bridge_config.target_url == "http://localhost:7233"
205+
assert bridge_config.tls_config is None
206+
207+
208+
def test_connect_config_tls_explicit_config_preserved():
209+
"""Test that explicit TLS configuration is preserved regardless of API key."""
210+
tls_config = temporalio.service.TLSConfig(
211+
server_root_ca_cert=b"test-cert",
212+
domain="test-domain",
213+
)
214+
config = temporalio.service.ConnectConfig(
215+
target_host="localhost:7233",
216+
api_key="test-api-key",
217+
tls=tls_config,
218+
)
219+
# Explicit TLS config should be preserved
220+
assert config.tls == tls_config
221+
222+
174223
async def test_rpc_execution_not_unknown(client: Client):
175224
"""
176225
Execute each rpc method and expect a failure, but ensure the failure is not that the rpc method is unknown

0 commit comments

Comments
 (0)