Skip to content

Commit eeeb627

Browse files
authored
Merge branch 'main' into pyright_exclusions
2 parents 00d0970 + 90dda94 commit eeeb627

File tree

14 files changed

+633
-107
lines changed

14 files changed

+633
-107
lines changed

.github/workflows/ci.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ jobs:
3232
- os: ubuntu-latest
3333
python: "3.10"
3434
protoCheckTarget: true
35+
- python: "3.10"
36+
pytestExtraArgs: "--reruns 3 --only-rerun \"RuntimeError: Failed validating workflow\""
3537
- os: ubuntu-arm
3638
runsOn: ubuntu-24.04-arm64-2-core
3739
- os: macos-intel

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ dev = [
5959
"openai-agents>=0.3,<0.5; python_version >= '3.14'",
6060
"openai-agents[litellm]>=0.3,<0.4; python_version < '3.14'",
6161
"googleapis-common-protos==1.70.0",
62+
"pytest-rerunfailures>=16.1",
6263
]
6364

6465
[tool.poe.tasks]

temporalio/client.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ async def connect(
125125
default_workflow_query_reject_condition: Optional[
126126
temporalio.common.QueryRejectCondition
127127
] = None,
128-
tls: Union[bool, TLSConfig] = False,
128+
tls: Union[bool, TLSConfig, None] = None,
129129
retry_config: Optional[RetryConfig] = None,
130130
keep_alive_config: Optional[KeepAliveConfig] = KeepAliveConfig.default,
131131
rpc_metadata: Mapping[str, Union[str, bytes]] = {},
@@ -166,9 +166,11 @@ async def connect(
166166
condition for workflow queries if not set during query. See
167167
:py:meth:`WorkflowHandle.query` for details on the rejection
168168
condition.
169-
tls: If false, the default, do not use TLS. If true, use system
170-
default TLS configuration. If TLS configuration present, that
171-
TLS configuration will be used.
169+
tls: If ``None``, the default, TLS will be enabled automatically
170+
when ``api_key`` is provided, otherwise TLS is disabled. If
171+
``False``, do not use TLS. If ``True``, use system default TLS
172+
configuration. If TLS configuration present, that TLS
173+
configuration will be used.
172174
retry_config: Retry configuration for direct service calls (when
173175
opted in) or all high-level calls made by this client (which all
174176
opt-in to retries by default). If unset, a default retry

temporalio/nexus/_link_conversion.py

Lines changed: 107 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -24,26 +24,31 @@
2424
)
2525
LINK_EVENT_ID_PARAM_NAME = "eventID"
2626
LINK_EVENT_TYPE_PARAM_NAME = "eventType"
27+
LINK_REQUEST_ID_PARAM_NAME = "requestID"
28+
LINK_REFERENCE_TYPE_PARAM_NAME = "referenceType"
29+
30+
EVENT_REFERENCE_TYPE = "EventReference"
31+
REQUEST_ID_REFERENCE_TYPE = "RequestIdReference"
2732

2833

2934
def workflow_execution_started_event_link_from_workflow_handle(
30-
handle: temporalio.client.WorkflowHandle[Any, Any],
35+
handle: temporalio.client.WorkflowHandle[Any, Any], request_id: str
3136
) -> temporalio.api.common.v1.Link.WorkflowEvent:
3237
"""Create a WorkflowEvent link corresponding to a started workflow"""
3338
if handle.first_execution_run_id is None:
3439
raise ValueError(
3540
f"Workflow handle {handle} has no first execution run ID. "
3641
f"Cannot create WorkflowExecutionStarted event link."
3742
)
43+
3844
return temporalio.api.common.v1.Link.WorkflowEvent(
3945
namespace=handle._client.namespace,
4046
workflow_id=handle.id,
4147
run_id=handle.first_execution_run_id,
42-
event_ref=temporalio.api.common.v1.Link.WorkflowEvent.EventReference(
43-
event_id=1,
48+
request_id_ref=temporalio.api.common.v1.Link.WorkflowEvent.RequestIdReference(
49+
request_id=request_id,
4450
event_type=temporalio.api.enums.v1.EventType.EVENT_TYPE_WORKFLOW_EXECUTION_STARTED,
4551
),
46-
# TODO(nexus-preview): RequestIdReference
4752
)
4853

4954

@@ -60,9 +65,21 @@ def workflow_event_to_nexus_link(
6065
workflow_id = urllib.parse.quote(workflow_event.workflow_id)
6166
run_id = urllib.parse.quote(workflow_event.run_id)
6267
path = f"/namespaces/{namespace}/workflows/{workflow_id}/{run_id}/history"
63-
query_params = _event_reference_to_query_params(workflow_event.event_ref)
68+
69+
query_params = None
70+
match workflow_event.WhichOneof("reference"):
71+
case "event_ref":
72+
query_params = _event_reference_to_query_params(workflow_event.event_ref)
73+
case "request_id_ref":
74+
query_params = _request_id_reference_to_query_params(
75+
workflow_event.request_id_ref
76+
)
77+
78+
# urllib will omit '//' from the url if netloc is empty so we add the scheme manually
79+
url = f"{scheme}://{urllib.parse.urlunparse(('', '', path, '', query_params, ''))}"
80+
6481
return nexusrpc.Link(
65-
url=urllib.parse.urlunparse((scheme, "", path, "", query_params, "")),
82+
url=url,
6683
type=workflow_event.DESCRIPTOR.full_name,
6784
)
6885

@@ -83,7 +100,20 @@ def nexus_link_to_workflow_event(
83100
)
84101
return None
85102
try:
86-
event_ref = _query_params_to_event_reference(url.query)
103+
query_params = urllib.parse.parse_qs(url.query)
104+
105+
request_id_ref = None
106+
event_ref = None
107+
match query_params.get(LINK_REFERENCE_TYPE_PARAM_NAME):
108+
case ["EventReference"]:
109+
event_ref = _query_params_to_event_reference(query_params)
110+
case ["RequestIdReference"]:
111+
request_id_ref = _query_params_to_request_id_reference(query_params)
112+
case _:
113+
raise ValueError(
114+
f"Invalid Nexus link: {link}. Expected {LINK_REFERENCE_TYPE_PARAM_NAME} to be '{EVENT_REFERENCE_TYPE}' or '{REQUEST_ID_REFERENCE_TYPE}'"
115+
)
116+
87117
except ValueError as err:
88118
logger.warning(
89119
f"Failed to parse event reference from Nexus link URL query parameters: {link} ({err})"
@@ -96,6 +126,7 @@ def nexus_link_to_workflow_event(
96126
workflow_id=urllib.parse.unquote(groups["workflow_id"]),
97127
run_id=urllib.parse.unquote(groups["run_id"]),
98128
event_ref=event_ref,
129+
request_id_ref=request_id_ref,
99130
)
100131

101132

@@ -109,36 +140,58 @@ def _event_reference_to_query_params(
109140
)
110141
return urllib.parse.urlencode(
111142
{
112-
"eventID": event_ref.event_id,
113-
"eventType": event_type_name,
114-
"referenceType": "EventReference",
143+
LINK_EVENT_ID_PARAM_NAME: event_ref.event_id,
144+
LINK_EVENT_TYPE_PARAM_NAME: event_type_name,
145+
LINK_REFERENCE_TYPE_PARAM_NAME: EVENT_REFERENCE_TYPE,
115146
}
116147
)
117148

118149

150+
def _request_id_reference_to_query_params(
151+
request_id_ref: temporalio.api.common.v1.Link.WorkflowEvent.RequestIdReference,
152+
) -> str:
153+
params = {
154+
LINK_REFERENCE_TYPE_PARAM_NAME: REQUEST_ID_REFERENCE_TYPE,
155+
}
156+
157+
if request_id_ref.request_id:
158+
params[LINK_REQUEST_ID_PARAM_NAME] = request_id_ref.request_id
159+
160+
event_type_name = temporalio.api.enums.v1.EventType.Name(request_id_ref.event_type)
161+
if event_type_name.startswith("EVENT_TYPE_"):
162+
event_type_name = _event_type_constant_case_to_pascal_case(
163+
event_type_name.removeprefix("EVENT_TYPE_")
164+
)
165+
params[LINK_EVENT_TYPE_PARAM_NAME] = event_type_name
166+
167+
return urllib.parse.urlencode(params)
168+
169+
119170
def _query_params_to_event_reference(
120-
raw_query_params: str,
171+
query_params: dict[str, list[str]],
121172
) -> temporalio.api.common.v1.Link.WorkflowEvent.EventReference:
122173
"""Return an EventReference from the query params or raise ValueError."""
123-
query_params = urllib.parse.parse_qs(raw_query_params)
124-
125-
[reference_type] = query_params.get("referenceType") or [""]
126-
if reference_type != "EventReference":
174+
[reference_type] = query_params.get(LINK_REFERENCE_TYPE_PARAM_NAME) or [""]
175+
if reference_type != EVENT_REFERENCE_TYPE:
127176
raise ValueError(
128177
f"Expected Nexus link URL query parameter referenceType to be EventReference but got: {reference_type}"
129178
)
179+
130180
# event type
131-
[raw_event_type_name] = query_params.get(LINK_EVENT_TYPE_PARAM_NAME) or [""]
132-
if not raw_event_type_name:
133-
raise ValueError(f"query params do not contain event type: {query_params}")
134-
if raw_event_type_name.startswith("EVENT_TYPE_"):
135-
event_type_name = raw_event_type_name
136-
elif re.match("[A-Z][a-z]", raw_event_type_name):
137-
event_type_name = "EVENT_TYPE_" + _event_type_pascal_case_to_constant_case(
138-
raw_event_type_name
139-
)
140-
else:
141-
raise ValueError(f"Invalid event type name: {raw_event_type_name}")
181+
match query_params.get(LINK_EVENT_TYPE_PARAM_NAME):
182+
case None:
183+
raise ValueError(f"query params do not contain event type: {query_params}")
184+
185+
case [raw_event_type_name] if raw_event_type_name.startswith("EVENT_TYPE_"):
186+
event_type_name = raw_event_type_name
187+
188+
case [raw_event_type_name] if re.match("[A-Z][a-z]", raw_event_type_name):
189+
event_type_name = "EVENT_TYPE_" + _event_type_pascal_case_to_constant_case(
190+
raw_event_type_name
191+
)
192+
193+
case raw_event_type_name:
194+
raise ValueError(f"Invalid event type name: {raw_event_type_name}")
142195

143196
# event id
144197
event_id = 0
@@ -155,6 +208,34 @@ def _query_params_to_event_reference(
155208
)
156209

157210

211+
def _query_params_to_request_id_reference(
212+
query_params: dict[str, list[str]],
213+
) -> temporalio.api.common.v1.Link.WorkflowEvent.RequestIdReference:
214+
"""Return an EventReference from the query params or raise ValueError."""
215+
# event type
216+
match query_params.get(LINK_EVENT_TYPE_PARAM_NAME):
217+
case None:
218+
raise ValueError(f"query params do not contain event type: {query_params}")
219+
220+
case [raw_event_type_name] if raw_event_type_name.startswith("EVENT_TYPE_"):
221+
event_type_name = raw_event_type_name
222+
223+
case [raw_event_type_name] if re.match("[A-Z][a-z]", raw_event_type_name):
224+
event_type_name = "EVENT_TYPE_" + _event_type_pascal_case_to_constant_case(
225+
raw_event_type_name
226+
)
227+
228+
case raw_event_type_name:
229+
raise ValueError(f"Invalid event type name: {raw_event_type_name}")
230+
231+
[request_id] = query_params.get(LINK_REQUEST_ID_PARAM_NAME, [""])
232+
233+
return temporalio.api.common.v1.Link.WorkflowEvent.RequestIdReference(
234+
request_id=request_id,
235+
event_type=temporalio.api.enums.v1.EventType.Value(event_type_name),
236+
)
237+
238+
158239
def _event_type_constant_case_to_pascal_case(s: str) -> str:
159240
"""Convert a CONSTANT_CASE string to PascalCase.
160241

temporalio/nexus/_operation_context.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,8 @@ def _add_outbound_links(
183183
if not wf_event_links:
184184
wf_event_links = [
185185
_link_conversion.workflow_execution_started_event_link_from_workflow_handle(
186-
workflow_handle
186+
workflow_handle,
187+
self.nexus_context.request_id,
187188
)
188189
]
189190
self.nexus_context.outbound_links.extend(

temporalio/service.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ class ConnectConfig:
136136

137137
target_host: str
138138
api_key: Optional[str] = None
139-
tls: Union[bool, TLSConfig] = False
139+
tls: Union[bool, TLSConfig, None] = None
140140
retry_config: Optional[RetryConfig] = None
141141
keep_alive_config: Optional[KeepAliveConfig] = KeepAliveConfig.default
142142
rpc_metadata: Mapping[str, Union[str, bytes]] = field(default_factory=dict)
@@ -172,6 +172,10 @@ def _to_bridge_config(self) -> temporalio.bridge.client.ClientConfig:
172172
elif self.tls:
173173
target_url = f"https://{self.target_host}"
174174
tls_config = TLSConfig()._to_bridge_config()
175+
# Enable TLS by default when API key is provided and tls not explicitly set
176+
elif self.tls is None and self.api_key is not None:
177+
target_url = f"https://{self.target_host}"
178+
tls_config = TLSConfig()._to_bridge_config()
175179
else:
176180
target_url = f"http://{self.target_host}"
177181
tls_config = None

tests/api/test_grpc_stub.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ async def test_grpc_metadata():
129129
f"localhost:{port}",
130130
api_key="my-api-key",
131131
rpc_metadata={"my-meta-key": "my-meta-val"},
132+
tls=False,
132133
)
133134
workflow_server.assert_last_metadata(
134135
{

tests/conftest.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@
77
import pytest
88
import pytest_asyncio
99

10+
from temporalio.client import Client
11+
from temporalio.testing import WorkflowEnvironment
12+
from temporalio.worker import SharedStateManager
13+
from tests.helpers.worker import ExternalPythonWorker, ExternalWorker
14+
1015
from . import DEV_SERVER_DOWNLOAD_VERSION
1116

1217
# If there is an integration test environment variable set, we must remove the
@@ -38,10 +43,6 @@
3843
or protobuf_version.startswith("6.")
3944
), f"Expected protobuf 4.x/5.x/6.x, got {protobuf_version}"
4045

41-
from temporalio.client import Client
42-
from temporalio.testing import WorkflowEnvironment
43-
from tests.helpers.worker import ExternalPythonWorker, ExternalWorker
44-
4546

4647
def pytest_runtest_setup(item):
4748
"""Print a newline so that custom printed output starts on new line."""
@@ -134,6 +135,17 @@ async def env(env_type: str) -> AsyncGenerator[WorkflowEnvironment, None]:
134135
await env.shutdown()
135136

136137

138+
@pytest.fixture(scope="session")
139+
def shared_state_manager() -> Iterator[SharedStateManager]:
140+
mp_mgr = multiprocessing.Manager()
141+
mgr = SharedStateManager.create_from_multiprocessing(mp_mgr)
142+
143+
try:
144+
yield mgr
145+
finally:
146+
mp_mgr.shutdown()
147+
148+
137149
@pytest.fixture(scope="session")
138150
def mp_fork_ctx() -> Iterator[multiprocessing.context.BaseContext | None]:
139151
mp_ctx = None

0 commit comments

Comments
 (0)