Skip to content

Commit cd4dc9e

Browse files
committed
Implement feature
1 parent 83382c8 commit cd4dc9e

File tree

3 files changed

+57
-46
lines changed

3 files changed

+57
-46
lines changed

temporalio/client.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5877,6 +5877,12 @@ async def _build_start_workflow_execution_request(
58775877
)
58785878
# Links are duplicated on request for compatibility with older server versions.
58795879
req.links.extend(links)
5880+
5881+
if temporalio.nexus._operation_context._in_nexus_backing_workflow_start_context():
5882+
req.on_conflict_options.attach_request_id = True
5883+
req.on_conflict_options.attach_completion_callbacks = True
5884+
req.on_conflict_options.attach_links = True
5885+
58805886
return req
58815887

58825888
async def _build_signal_with_start_workflow_execution_request(
@@ -5932,6 +5938,7 @@ async def _populate_start_workflow_execution_request(
59325938
"temporalio.api.enums.v1.WorkflowIdConflictPolicy.ValueType",
59335939
int(input.id_conflict_policy),
59345940
)
5941+
59355942
if input.retry_policy is not None:
59365943
input.retry_policy.apply_to_proto(req.retry_policy)
59375944
req.cron_schedule = input.cron_schedule

temporalio/nexus/_operation_context.py

Lines changed: 50 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33
import dataclasses
44
import logging
55
from collections.abc import Awaitable, Mapping, MutableMapping, Sequence
6+
from contextlib import contextmanager
67
from contextvars import ContextVar
78
from dataclasses import dataclass
89
from datetime import timedelta
910
from typing import (
1011
TYPE_CHECKING,
1112
Any,
1213
Callable,
14+
Generator,
1315
Optional,
1416
Union,
1517
overload,
@@ -47,6 +49,10 @@
4749
ContextVar("temporal-cancel-operation-context")
4850
)
4951

52+
_temporal_nexus_backing_workflow_start_context: ContextVar[bool] = ContextVar(
53+
"temporal-nexus-backing-workflow-start-context"
54+
)
55+
5056

5157
@dataclass(frozen=True)
5258
class Info:
@@ -96,6 +102,19 @@ def _try_temporal_context() -> (
96102
return start_ctx or cancel_ctx
97103

98104

105+
@contextmanager
106+
def _nexus_backing_workflow_start_context() -> Generator[None, None, None]:
107+
token = _temporal_nexus_backing_workflow_start_context.set(True)
108+
try:
109+
yield
110+
finally:
111+
_temporal_nexus_backing_workflow_start_context.reset(token)
112+
113+
114+
def _in_nexus_backing_workflow_start_context() -> bool:
115+
return _temporal_nexus_backing_workflow_start_context.get(False)
116+
117+
99118
@dataclass
100119
class _TemporalStartOperationContext:
101120
"""Context for a Nexus start operation being handled by a Temporal Nexus Worker."""
@@ -396,48 +415,40 @@ async def start_workflow(
396415
Nexus caller is itself a workflow, this means that the workflow in the caller
397416
namespace web UI will contain links to the started workflow, and vice versa.
398417
"""
399-
# TODO(nexus-preview): When sdk-python supports on_conflict_options, Typescript does this:
400-
# if (workflowOptions.workflowIdConflictPolicy === 'USE_EXISTING') {
401-
# internalOptions.onConflictOptions = {
402-
# attachLinks: true,
403-
# attachCompletionCallbacks: true,
404-
# attachRequestId: true,
405-
# };
406-
# }
407-
408418
# We must pass nexus_completion_callbacks, workflow_event_links, and request_id,
409419
# but these are deliberately not exposed in overloads, hence the type-check
410420
# violation.
411-
wf_handle = await self._temporal_context.client.start_workflow( # type: ignore
412-
workflow=workflow,
413-
arg=arg,
414-
args=args,
415-
id=id,
416-
task_queue=task_queue or self._temporal_context.info().task_queue,
417-
result_type=result_type,
418-
execution_timeout=execution_timeout,
419-
run_timeout=run_timeout,
420-
task_timeout=task_timeout,
421-
id_reuse_policy=id_reuse_policy,
422-
id_conflict_policy=id_conflict_policy,
423-
retry_policy=retry_policy,
424-
cron_schedule=cron_schedule,
425-
memo=memo,
426-
search_attributes=search_attributes,
427-
static_summary=static_summary,
428-
static_details=static_details,
429-
start_delay=start_delay,
430-
start_signal=start_signal,
431-
start_signal_args=start_signal_args,
432-
rpc_metadata=rpc_metadata,
433-
rpc_timeout=rpc_timeout,
434-
request_eager_start=request_eager_start,
435-
priority=priority,
436-
versioning_override=versioning_override,
437-
callbacks=self._temporal_context._get_callbacks(),
438-
workflow_event_links=self._temporal_context._get_workflow_event_links(),
439-
request_id=self._temporal_context.nexus_context.request_id,
440-
)
421+
with _nexus_backing_workflow_start_context():
422+
wf_handle = await self._temporal_context.client.start_workflow( # type: ignore
423+
workflow=workflow,
424+
arg=arg,
425+
args=args,
426+
id=id,
427+
task_queue=task_queue or self._temporal_context.info().task_queue,
428+
result_type=result_type,
429+
execution_timeout=execution_timeout,
430+
run_timeout=run_timeout,
431+
task_timeout=task_timeout,
432+
id_reuse_policy=id_reuse_policy,
433+
id_conflict_policy=id_conflict_policy,
434+
retry_policy=retry_policy,
435+
cron_schedule=cron_schedule,
436+
memo=memo,
437+
search_attributes=search_attributes,
438+
static_summary=static_summary,
439+
static_details=static_details,
440+
start_delay=start_delay,
441+
start_signal=start_signal,
442+
start_signal_args=start_signal_args,
443+
rpc_metadata=rpc_metadata,
444+
rpc_timeout=rpc_timeout,
445+
request_eager_start=request_eager_start,
446+
priority=priority,
447+
versioning_override=versioning_override,
448+
callbacks=self._temporal_context._get_callbacks(),
449+
workflow_event_links=self._temporal_context._get_workflow_event_links(),
450+
request_id=self._temporal_context.nexus_context.request_id,
451+
)
441452

442453
self._temporal_context._add_outbound_links(wf_handle)
443454

temporalio/worker/_workflow_instance.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3029,13 +3029,6 @@ def operation_token(self) -> Optional[str]:
30293029
def __await__(self) -> Generator[Any, Any, OutputT]:
30303030
return self._task.__await__()
30313031

3032-
def __repr__(self) -> str:
3033-
return (
3034-
f"{self._start_fut} "
3035-
f"{self._result_fut} "
3036-
f"Task[{self._task._state}] fut_waiter = {self._task._fut_waiter}) ({self._task._must_cancel})" # type: ignore
3037-
)
3038-
30393032
def cancel(self) -> bool:
30403033
return self._task.cancel()
30413034

0 commit comments

Comments
 (0)