Skip to content

Commit 623ea6e

Browse files
committed
Merge branch 'main' into nexus-metric-meter
2 parents a878297 + 90dda94 commit 623ea6e

File tree

8 files changed

+572
-101
lines changed

8 files changed

+572
-101
lines changed

.github/workflows/ci.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ jobs:
3232
- os: ubuntu-latest
3333
python: "3.10"
3434
protoCheckTarget: true
35+
- python: "3.10"
36+
pytestExtraArgs: "--reruns 3 --only-rerun \"RuntimeError: Failed validating workflow\""
3537
- os: ubuntu-arm
3638
runsOn: ubuntu-24.04-arm64-2-core
3739
- os: macos-intel

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ dev = [
5959
"openai-agents>=0.3,<0.5; python_version >= '3.14'",
6060
"openai-agents[litellm]>=0.3,<0.4; python_version < '3.14'",
6161
"googleapis-common-protos==1.70.0",
62+
"pytest-rerunfailures>=16.1",
6263
]
6364

6465
[tool.poe.tasks]

temporalio/nexus/_link_conversion.py

Lines changed: 107 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -24,26 +24,31 @@
2424
)
2525
LINK_EVENT_ID_PARAM_NAME = "eventID"
2626
LINK_EVENT_TYPE_PARAM_NAME = "eventType"
27+
LINK_REQUEST_ID_PARAM_NAME = "requestID"
28+
LINK_REFERENCE_TYPE_PARAM_NAME = "referenceType"
29+
30+
EVENT_REFERENCE_TYPE = "EventReference"
31+
REQUEST_ID_REFERENCE_TYPE = "RequestIdReference"
2732

2833

2934
def workflow_execution_started_event_link_from_workflow_handle(
30-
handle: temporalio.client.WorkflowHandle[Any, Any],
35+
handle: temporalio.client.WorkflowHandle[Any, Any], request_id: str
3136
) -> temporalio.api.common.v1.Link.WorkflowEvent:
3237
"""Create a WorkflowEvent link corresponding to a started workflow"""
3338
if handle.first_execution_run_id is None:
3439
raise ValueError(
3540
f"Workflow handle {handle} has no first execution run ID. "
3641
f"Cannot create WorkflowExecutionStarted event link."
3742
)
43+
3844
return temporalio.api.common.v1.Link.WorkflowEvent(
3945
namespace=handle._client.namespace,
4046
workflow_id=handle.id,
4147
run_id=handle.first_execution_run_id,
42-
event_ref=temporalio.api.common.v1.Link.WorkflowEvent.EventReference(
43-
event_id=1,
48+
request_id_ref=temporalio.api.common.v1.Link.WorkflowEvent.RequestIdReference(
49+
request_id=request_id,
4450
event_type=temporalio.api.enums.v1.EventType.EVENT_TYPE_WORKFLOW_EXECUTION_STARTED,
4551
),
46-
# TODO(nexus-preview): RequestIdReference
4752
)
4853

4954

@@ -60,9 +65,21 @@ def workflow_event_to_nexus_link(
6065
workflow_id = urllib.parse.quote(workflow_event.workflow_id)
6166
run_id = urllib.parse.quote(workflow_event.run_id)
6267
path = f"/namespaces/{namespace}/workflows/{workflow_id}/{run_id}/history"
63-
query_params = _event_reference_to_query_params(workflow_event.event_ref)
68+
69+
query_params = None
70+
match workflow_event.WhichOneof("reference"):
71+
case "event_ref":
72+
query_params = _event_reference_to_query_params(workflow_event.event_ref)
73+
case "request_id_ref":
74+
query_params = _request_id_reference_to_query_params(
75+
workflow_event.request_id_ref
76+
)
77+
78+
# urllib will omit '//' from the url if netloc is empty so we add the scheme manually
79+
url = f"{scheme}://{urllib.parse.urlunparse(('', '', path, '', query_params, ''))}"
80+
6481
return nexusrpc.Link(
65-
url=urllib.parse.urlunparse((scheme, "", path, "", query_params, "")),
82+
url=url,
6683
type=workflow_event.DESCRIPTOR.full_name,
6784
)
6885

@@ -83,7 +100,20 @@ def nexus_link_to_workflow_event(
83100
)
84101
return None
85102
try:
86-
event_ref = _query_params_to_event_reference(url.query)
103+
query_params = urllib.parse.parse_qs(url.query)
104+
105+
request_id_ref = None
106+
event_ref = None
107+
match query_params.get(LINK_REFERENCE_TYPE_PARAM_NAME):
108+
case ["EventReference"]:
109+
event_ref = _query_params_to_event_reference(query_params)
110+
case ["RequestIdReference"]:
111+
request_id_ref = _query_params_to_request_id_reference(query_params)
112+
case _:
113+
raise ValueError(
114+
f"Invalid Nexus link: {link}. Expected {LINK_REFERENCE_TYPE_PARAM_NAME} to be '{EVENT_REFERENCE_TYPE}' or '{REQUEST_ID_REFERENCE_TYPE}'"
115+
)
116+
87117
except ValueError as err:
88118
logger.warning(
89119
f"Failed to parse event reference from Nexus link URL query parameters: {link} ({err})"
@@ -96,6 +126,7 @@ def nexus_link_to_workflow_event(
96126
workflow_id=urllib.parse.unquote(groups["workflow_id"]),
97127
run_id=urllib.parse.unquote(groups["run_id"]),
98128
event_ref=event_ref,
129+
request_id_ref=request_id_ref,
99130
)
100131

101132

@@ -109,36 +140,58 @@ def _event_reference_to_query_params(
109140
)
110141
return urllib.parse.urlencode(
111142
{
112-
"eventID": event_ref.event_id,
113-
"eventType": event_type_name,
114-
"referenceType": "EventReference",
143+
LINK_EVENT_ID_PARAM_NAME: event_ref.event_id,
144+
LINK_EVENT_TYPE_PARAM_NAME: event_type_name,
145+
LINK_REFERENCE_TYPE_PARAM_NAME: EVENT_REFERENCE_TYPE,
115146
}
116147
)
117148

118149

150+
def _request_id_reference_to_query_params(
151+
request_id_ref: temporalio.api.common.v1.Link.WorkflowEvent.RequestIdReference,
152+
) -> str:
153+
params = {
154+
LINK_REFERENCE_TYPE_PARAM_NAME: REQUEST_ID_REFERENCE_TYPE,
155+
}
156+
157+
if request_id_ref.request_id:
158+
params[LINK_REQUEST_ID_PARAM_NAME] = request_id_ref.request_id
159+
160+
event_type_name = temporalio.api.enums.v1.EventType.Name(request_id_ref.event_type)
161+
if event_type_name.startswith("EVENT_TYPE_"):
162+
event_type_name = _event_type_constant_case_to_pascal_case(
163+
event_type_name.removeprefix("EVENT_TYPE_")
164+
)
165+
params[LINK_EVENT_TYPE_PARAM_NAME] = event_type_name
166+
167+
return urllib.parse.urlencode(params)
168+
169+
119170
def _query_params_to_event_reference(
120-
raw_query_params: str,
171+
query_params: dict[str, list[str]],
121172
) -> temporalio.api.common.v1.Link.WorkflowEvent.EventReference:
122173
"""Return an EventReference from the query params or raise ValueError."""
123-
query_params = urllib.parse.parse_qs(raw_query_params)
124-
125-
[reference_type] = query_params.get("referenceType") or [""]
126-
if reference_type != "EventReference":
174+
[reference_type] = query_params.get(LINK_REFERENCE_TYPE_PARAM_NAME) or [""]
175+
if reference_type != EVENT_REFERENCE_TYPE:
127176
raise ValueError(
128177
f"Expected Nexus link URL query parameter referenceType to be EventReference but got: {reference_type}"
129178
)
179+
130180
# event type
131-
[raw_event_type_name] = query_params.get(LINK_EVENT_TYPE_PARAM_NAME) or [""]
132-
if not raw_event_type_name:
133-
raise ValueError(f"query params do not contain event type: {query_params}")
134-
if raw_event_type_name.startswith("EVENT_TYPE_"):
135-
event_type_name = raw_event_type_name
136-
elif re.match("[A-Z][a-z]", raw_event_type_name):
137-
event_type_name = "EVENT_TYPE_" + _event_type_pascal_case_to_constant_case(
138-
raw_event_type_name
139-
)
140-
else:
141-
raise ValueError(f"Invalid event type name: {raw_event_type_name}")
181+
match query_params.get(LINK_EVENT_TYPE_PARAM_NAME):
182+
case None:
183+
raise ValueError(f"query params do not contain event type: {query_params}")
184+
185+
case [raw_event_type_name] if raw_event_type_name.startswith("EVENT_TYPE_"):
186+
event_type_name = raw_event_type_name
187+
188+
case [raw_event_type_name] if re.match("[A-Z][a-z]", raw_event_type_name):
189+
event_type_name = "EVENT_TYPE_" + _event_type_pascal_case_to_constant_case(
190+
raw_event_type_name
191+
)
192+
193+
case raw_event_type_name:
194+
raise ValueError(f"Invalid event type name: {raw_event_type_name}")
142195

143196
# event id
144197
event_id = 0
@@ -155,6 +208,34 @@ def _query_params_to_event_reference(
155208
)
156209

157210

211+
def _query_params_to_request_id_reference(
212+
query_params: dict[str, list[str]],
213+
) -> temporalio.api.common.v1.Link.WorkflowEvent.RequestIdReference:
214+
"""Return an EventReference from the query params or raise ValueError."""
215+
# event type
216+
match query_params.get(LINK_EVENT_TYPE_PARAM_NAME):
217+
case None:
218+
raise ValueError(f"query params do not contain event type: {query_params}")
219+
220+
case [raw_event_type_name] if raw_event_type_name.startswith("EVENT_TYPE_"):
221+
event_type_name = raw_event_type_name
222+
223+
case [raw_event_type_name] if re.match("[A-Z][a-z]", raw_event_type_name):
224+
event_type_name = "EVENT_TYPE_" + _event_type_pascal_case_to_constant_case(
225+
raw_event_type_name
226+
)
227+
228+
case raw_event_type_name:
229+
raise ValueError(f"Invalid event type name: {raw_event_type_name}")
230+
231+
[request_id] = query_params.get(LINK_REQUEST_ID_PARAM_NAME, [""])
232+
233+
return temporalio.api.common.v1.Link.WorkflowEvent.RequestIdReference(
234+
request_id=request_id,
235+
event_type=temporalio.api.enums.v1.EventType.Value(event_type_name),
236+
)
237+
238+
158239
def _event_type_constant_case_to_pascal_case(s: str) -> str:
159240
"""Convert a CONSTANT_CASE string to PascalCase.
160241

temporalio/nexus/_operation_context.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,8 @@ def _add_outbound_links(
215215
if not wf_event_links:
216216
wf_event_links = [
217217
_link_conversion.workflow_execution_started_event_link_from_workflow_handle(
218-
workflow_handle
218+
workflow_handle,
219+
self.nexus_context.request_id,
219220
)
220221
]
221222
self.nexus_context.outbound_links.extend(

tests/conftest.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@
77
import pytest
88
import pytest_asyncio
99

10+
from temporalio.client import Client
11+
from temporalio.testing import WorkflowEnvironment
12+
from temporalio.worker import SharedStateManager
13+
from tests.helpers.worker import ExternalPythonWorker, ExternalWorker
14+
1015
from . import DEV_SERVER_DOWNLOAD_VERSION
1116

1217
# If there is an integration test environment variable set, we must remove the
@@ -38,10 +43,6 @@
3843
or protobuf_version.startswith("6.")
3944
), f"Expected protobuf 4.x/5.x/6.x, got {protobuf_version}"
4045

41-
from temporalio.client import Client
42-
from temporalio.testing import WorkflowEnvironment
43-
from tests.helpers.worker import ExternalPythonWorker, ExternalWorker
44-
4546

4647
def pytest_runtest_setup(item):
4748
"""Print a newline so that custom printed output starts on new line."""
@@ -134,6 +135,17 @@ async def env(env_type: str) -> AsyncGenerator[WorkflowEnvironment, None]:
134135
await env.shutdown()
135136

136137

138+
@pytest.fixture(scope="session")
139+
def shared_state_manager() -> Iterator[SharedStateManager]:
140+
mp_mgr = multiprocessing.Manager()
141+
mgr = SharedStateManager.create_from_multiprocessing(mp_mgr)
142+
143+
try:
144+
yield mgr
145+
finally:
146+
mp_mgr.shutdown()
147+
148+
137149
@pytest.fixture(scope="session")
138150
def mp_fork_ctx() -> Iterator[multiprocessing.context.BaseContext | None]:
139151
mp_ctx = None

0 commit comments

Comments
 (0)