Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions temporalio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
import temporalio.converter
import temporalio.exceptions
import temporalio.nexus
import temporalio.nexus._operation_context
import temporalio.runtime
import temporalio.service
import temporalio.workflow
Expand Down Expand Up @@ -5877,6 +5878,12 @@ async def _build_start_workflow_execution_request(
)
# Links are duplicated on request for compatibility with older server versions.
req.links.extend(links)

if temporalio.nexus._operation_context._in_nexus_backing_workflow_start_context():
req.on_conflict_options.attach_request_id = True
req.on_conflict_options.attach_completion_callbacks = True
req.on_conflict_options.attach_links = True

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Nexus Signal-with-Start Missing Conflict Options

The _build_signal_with_start_workflow_execution_request method is missing the on_conflict_options logic for Nexus operations. This logic, which sets options for attaching request IDs, callbacks, and links, is present in _build_start_workflow_execution_request. Without it, Nexus operations using signal-with-start won't properly attach to existing workflows, causing inconsistent behavior.

Fix in Cursor Fix in Web

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a bad review comment, but we don't need to support signal-with-start in this context yet.

return req

async def _build_signal_with_start_workflow_execution_request(
Expand Down Expand Up @@ -5932,6 +5939,7 @@ async def _populate_start_workflow_execution_request(
"temporalio.api.enums.v1.WorkflowIdConflictPolicy.ValueType",
int(input.id_conflict_policy),
)

if input.retry_policy is not None:
input.retry_policy.apply_to_proto(req.retry_policy)
req.cron_schedule = input.cron_schedule
Expand Down
97 changes: 50 additions & 47 deletions temporalio/nexus/_operation_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
import dataclasses
import logging
from collections.abc import Awaitable, Mapping, MutableMapping, Sequence
from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import dataclass
from datetime import timedelta
from typing import (
TYPE_CHECKING,
Any,
Callable,
Generator,
Optional,
Union,
overload,
Expand Down Expand Up @@ -47,6 +49,10 @@
ContextVar("temporal-cancel-operation-context")
)

# Flag marking that the current execution context is starting the workflow that
# backs a Nexus operation. It is set to True only inside
# _nexus_backing_workflow_start_context() and is read, with a default of False,
# by _in_nexus_backing_workflow_start_context().
_temporal_nexus_backing_workflow_start_context: ContextVar[bool] = ContextVar(
"temporal-nexus-backing-workflow-start-context"
)


@dataclass(frozen=True)
class Info:
Expand Down Expand Up @@ -96,6 +102,19 @@ def _try_temporal_context() -> (
return start_ctx or cancel_ctx


@contextmanager
def _nexus_backing_workflow_start_context() -> Generator[None, None, None]:
    """Mark the enclosed code as starting a Nexus-operation-backing workflow.

    Sets the ``_temporal_nexus_backing_workflow_start_context`` context
    variable to ``True`` for the duration of the ``with`` block and resets it
    to its prior state on exit, including when the block raises.
    """
    reset_token = _temporal_nexus_backing_workflow_start_context.set(True)
    try:
        yield None
    finally:
        # Always undo the set() so the flag never leaks past the with-block.
        _temporal_nexus_backing_workflow_start_context.reset(reset_token)


def _in_nexus_backing_workflow_start_context() -> bool:
    """Report whether the backing-workflow-start flag is set in this context.

    Returns:
        The current value of ``_temporal_nexus_backing_workflow_start_context``,
        or ``False`` when the variable has no value in the current context.
    """
    in_start_context = _temporal_nexus_backing_workflow_start_context.get(False)
    return in_start_context


@dataclass
class _TemporalStartOperationContext:
"""Context for a Nexus start operation being handled by a Temporal Nexus Worker."""
Expand Down Expand Up @@ -396,56 +415,40 @@ async def start_workflow(
Nexus caller is itself a workflow, this means that the workflow in the caller
namespace web UI will contain links to the started workflow, and vice versa.
"""
# TODO(nexus-preview): When sdk-python supports on_conflict_options, Typescript does this:
# if (workflowOptions.workflowIdConflictPolicy === 'USE_EXISTING') {
# internalOptions.onConflictOptions = {
# attachLinks: true,
# attachCompletionCallbacks: true,
# attachRequestId: true,
# };
# }
if (
id_conflict_policy
== temporalio.common.WorkflowIDConflictPolicy.USE_EXISTING
):
raise RuntimeError(
"WorkflowIDConflictPolicy.USE_EXISTING is not yet supported when starting a workflow "
"that backs a Nexus operation (Python SDK Nexus support is at Pre-release stage)."
)

# We must pass nexus_completion_callbacks, workflow_event_links, and request_id,
# but these are deliberately not exposed in overloads, hence the type-check
# violation.
wf_handle = await self._temporal_context.client.start_workflow( # type: ignore
workflow=workflow,
arg=arg,
args=args,
id=id,
task_queue=task_queue or self._temporal_context.info().task_queue,
result_type=result_type,
execution_timeout=execution_timeout,
run_timeout=run_timeout,
task_timeout=task_timeout,
id_reuse_policy=id_reuse_policy,
id_conflict_policy=id_conflict_policy,
retry_policy=retry_policy,
cron_schedule=cron_schedule,
memo=memo,
search_attributes=search_attributes,
static_summary=static_summary,
static_details=static_details,
start_delay=start_delay,
start_signal=start_signal,
start_signal_args=start_signal_args,
rpc_metadata=rpc_metadata,
rpc_timeout=rpc_timeout,
request_eager_start=request_eager_start,
priority=priority,
versioning_override=versioning_override,
callbacks=self._temporal_context._get_callbacks(),
workflow_event_links=self._temporal_context._get_workflow_event_links(),
request_id=self._temporal_context.nexus_context.request_id,
)
with _nexus_backing_workflow_start_context():
wf_handle = await self._temporal_context.client.start_workflow( # type: ignore
workflow=workflow,
arg=arg,
args=args,
id=id,
task_queue=task_queue or self._temporal_context.info().task_queue,
result_type=result_type,
execution_timeout=execution_timeout,
run_timeout=run_timeout,
task_timeout=task_timeout,
id_reuse_policy=id_reuse_policy,
id_conflict_policy=id_conflict_policy,
retry_policy=retry_policy,
cron_schedule=cron_schedule,
memo=memo,
search_attributes=search_attributes,
static_summary=static_summary,
static_details=static_details,
start_delay=start_delay,
start_signal=start_signal,
start_signal_args=start_signal_args,
rpc_metadata=rpc_metadata,
rpc_timeout=rpc_timeout,
request_eager_start=request_eager_start,
priority=priority,
versioning_override=versioning_override,
callbacks=self._temporal_context._get_callbacks(),
workflow_event_links=self._temporal_context._get_workflow_event_links(),
request_id=self._temporal_context.nexus_context.request_id,
)

self._temporal_context._add_outbound_links(wf_handle)

Expand Down
7 changes: 0 additions & 7 deletions temporalio/worker/_workflow_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -3029,13 +3029,6 @@ def operation_token(self) -> Optional[str]:
def __await__(self) -> Generator[Any, Any, OutputT]:
    """Make this handle awaitable by delegating to ``self._task.__await__()``."""
    task = self._task
    return task.__await__()

def __repr__(self) -> str:
    """Return a debugging summary of this handle's futures and task.

    The text contains, in order: ``self._start_fut``, ``self._result_fut``,
    the task's ``_state``, the task's ``_fut_waiter``, and the task's
    ``_must_cancel`` flag in parentheses.
    """
    task = self._task
    # The task fields below are private attributes, hence the type-check
    # suppression. The stray ")" that used to follow the fut_waiter value has
    # been dropped so the parentheses in the output are balanced.
    return (
        f"{self._start_fut} "
        f"{self._result_fut} "
        f"Task[{task._state}] fut_waiter = {task._fut_waiter} ({task._must_cancel})"  # type: ignore
    )

def cancel(self) -> bool:
    """Cancel the underlying task and return what ``self._task.cancel()`` returns."""
    cancel_result = self._task.cancel()
    return cancel_result

Expand Down
124 changes: 124 additions & 0 deletions tests/nexus/test_use_existing_conflict_policy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
from __future__ import annotations

import asyncio
import uuid
from dataclasses import dataclass
from typing import Optional

import pytest
from nexusrpc.handler import service_handler

from temporalio import nexus, workflow
from temporalio.client import Client
from temporalio.common import WorkflowIDConflictPolicy
from temporalio.testing import WorkflowEnvironment
from temporalio.worker import Worker
from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name


@dataclass
class OpInput:
    """Input passed to ``NexusService.workflow_backed_operation``."""

    # ID given to the handler workflow that the operation starts.
    workflow_id: str
    # Conflict policy forwarded to ``ctx.start_workflow`` as ``id_conflict_policy``.
    conflict_policy: WorkflowIDConflictPolicy


@workflow.defn
class HandlerWorkflow:
    """Workflow backing the Nexus operation; it runs until signalled with a result."""

    def __init__(self) -> None:
        # Set by the ``complete`` signal; ``None`` means no result has arrived yet.
        self.result: Optional[str] = None

    @workflow.run
    async def run(self) -> str:
        """Wait until ``complete`` has supplied a result, then return it."""
        await workflow.wait_condition(lambda: self.result is not None)
        # Narrow Optional[str] to str using the same ``is not None`` test as the
        # wait condition, so an empty-string result is returned rather than
        # tripping a truthiness assertion.
        assert self.result is not None
        return self.result

    @workflow.signal
    def complete(self, result: str) -> None:
        """Record ``result``, which satisfies the wait condition in ``run``."""
        self.result = result


@service_handler
class NexusService:
    """Nexus service exposing one operation that is backed by ``HandlerWorkflow``."""

    @nexus.workflow_run_operation
    async def workflow_backed_operation(
        self, ctx: nexus.WorkflowRunOperationContext, input: OpInput
    ) -> nexus.WorkflowHandle[str]:
        """Start ``HandlerWorkflow.run`` using the ID and conflict policy in ``input``."""
        backing_handle = await ctx.start_workflow(
            HandlerWorkflow.run,
            id=input.workflow_id,
            id_conflict_policy=input.conflict_policy,
        )
        return backing_handle


@dataclass
class CallerWorkflowInput:
    """Input passed to ``CallerWorkflow.run``."""

    # Handler workflow ID shared by every Nexus operation the caller starts.
    workflow_id: str
    # Task queue name, used to derive the Nexus endpoint name.
    task_queue: str
    # How many Nexus operations the caller starts.
    num_operations: int


@workflow.defn
class CallerWorkflow:
    """Workflow that starts several Nexus operations targeting one handler workflow ID."""

    def __init__(self) -> None:
        # Set once every Nexus operation has been started; awaited by the update.
        self._nexus_operations_have_started = asyncio.Event()

    @workflow.run
    async def run(self, input: CallerWorkflowInput) -> list[str]:
        """Start ``input.num_operations`` operations and return all their results.

        Every operation is given the same workflow ID together with the
        ``USE_EXISTING`` conflict policy.
        """
        endpoint_name = make_nexus_endpoint_name(input.task_queue)
        nexus_client = workflow.create_nexus_client(
            service=NexusService, endpoint=endpoint_name
        )
        shared_op_input = OpInput(
            workflow_id=input.workflow_id,
            conflict_policy=WorkflowIDConflictPolicy.USE_EXISTING,
        )

        # Operations are started one after another; each await completes
        # before the next start is issued.
        operation_handles = [
            await nexus_client.start_operation(
                NexusService.workflow_backed_operation, shared_op_input
            )
            for _ in range(input.num_operations)
        ]
        self._nexus_operations_have_started.set()
        return await asyncio.gather(*operation_handles)

    @workflow.update
    async def nexus_operations_have_started(self) -> None:
        """Return once ``run`` has started all of its Nexus operations."""
        await self._nexus_operations_have_started.wait()


async def test_multiple_operation_invocations_can_connect_to_same_handler_workflow(
    client: Client, env: WorkflowEnvironment
):
    """Five operations using USE_EXISTING on one workflow ID all get its result.

    The caller workflow starts the operations, the test waits (via update)
    until they have all started, signals the single handler workflow to
    complete, and then expects every operation to yield that same result.
    """
    if env.supports_time_skipping:
        pytest.skip("Nexus tests don't work with time-skipping server")

    task_queue = str(uuid.uuid4())
    workflow_id = str(uuid.uuid4())
    num_operations = 5
    expected_result = "test-result"

    worker = Worker(
        client,
        nexus_service_handlers=[NexusService()],
        workflows=[CallerWorkflow, HandlerWorkflow],
        task_queue=task_queue,
    )
    async with worker:
        await create_nexus_endpoint(task_queue, client)

        caller_input = CallerWorkflowInput(
            workflow_id=workflow_id,
            task_queue=task_queue,
            num_operations=num_operations,
        )
        caller_handle = await client.start_workflow(
            CallerWorkflow.run,
            args=[caller_input],
            id=str(uuid.uuid4()),
            task_queue=task_queue,
        )

        # Signal the handler only after every operation has been started.
        await caller_handle.execute_update(CallerWorkflow.nexus_operations_have_started)
        handler_handle = client.get_workflow_handle(workflow_id)
        await handler_handle.signal(HandlerWorkflow.complete, expected_result)

        results = await caller_handle.result()
        assert results == [expected_result] * num_operations
Loading