diff --git a/.github/workflows/build-binaries.yml b/.github/workflows/build-binaries.yml index be818370f..3c7834b5b 100644 --- a/.github/workflows/build-binaries.yml +++ b/.github/workflows/build-binaries.yml @@ -74,8 +74,3 @@ jobs: with: name: packages-${{ matrix.package-suffix }} path: dist - - - name: Deliberately fail to prevent releasing nexus-rpc w/ GitHub link in pyproject.toml - run: | - echo "This is a deliberate failure to prevent releasing nexus-rpc with a GitHub link in pyproject.toml" - exit 1 \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index a7e7826ea..2d3f96068 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ keywords = [ "workflow", ] dependencies = [ - "nexus-rpc>=1.1.0", + "nexus-rpc==1.1.0", "protobuf>=3.20,<6", "python-dateutil>=2.8.2,<3 ; python_version < '3.11'", "types-protobuf>=3.20", @@ -231,6 +231,3 @@ exclude = [ [tool.uv] # Prevent uv commands from building the package by default package = false - -[tool.uv.sources] -nexus-rpc = { git = "https://github.com/nexus-rpc/sdk-python.git", rev = "35f574c711193a6e2560d3e6665732a5bb7ae92c" } diff --git a/temporalio/nexus/_decorators.py b/temporalio/nexus/_decorators.py index d8675afdb..1266fd29e 100644 --- a/temporalio/nexus/_decorators.py +++ b/temporalio/nexus/_decorators.py @@ -123,7 +123,7 @@ async def _start( return WorkflowRunOperationHandler(_start, input_type, output_type) method_name = get_callable_name(start) - nexusrpc.set_operation( + nexusrpc.set_operation_definition( operation_handler_factory, nexusrpc.Operation( name=name or method_name, diff --git a/temporalio/nexus/_util.py b/temporalio/nexus/_util.py index 70793be64..ef005d0c4 100644 --- a/temporalio/nexus/_util.py +++ b/temporalio/nexus/_util.py @@ -13,6 +13,7 @@ TypeVar, ) +import nexusrpc from nexusrpc import ( InputT, OutputT, @@ -78,10 +79,19 @@ def _get_start_method_input_and_output_type_annotations( try: type_annotations = typing.get_type_hints(start) except TypeError: + warnings.warn( + f"Expected decorated start method {start} to have type annotations" + ) return None, None output_type = type_annotations.pop("return", None) if len(type_annotations) != 2: + suffix = f": {type_annotations}" if type_annotations else "" + warnings.warn( + f"Expected decorated start method {start} to have exactly 2 " + f"type-annotated parameters (ctx and input), but it has {len(type_annotations)}" + f"{suffix}." + ) input_type = None else: ctx_type, input_type = type_annotations.values() @@ -108,6 +118,28 @@ def get_callable_name(fn: Callable[..., Any]) -> str: return method_name +# TODO(nexus-preview) Copied from nexusrpc +def get_operation_factory( + obj: Any, +) -> tuple[ + Optional[Callable[[Any], Any]], + Optional[nexusrpc.Operation[Any, Any]], +]: + """Return the :py:class:`Operation` for the object along with the factory function. + + ``obj`` should be a decorated operation start method. + """ + op_defn = nexusrpc.get_operation_definition(obj) + if op_defn: + factory = obj + else: + if factory := getattr(obj, "__nexus_operation_factory__", None): + op_defn = nexusrpc.get_operation_definition(factory) + if not isinstance(op_defn, nexusrpc.Operation): + return None, None + return factory, op_defn + + # TODO(nexus-preview) Copied from nexusrpc def set_operation_factory( obj: Any, diff --git a/temporalio/worker/_interceptor.py b/temporalio/worker/_interceptor.py index 32ce66e0b..1b412cb7f 100644 --- a/temporalio/worker/_interceptor.py +++ b/temporalio/worker/_interceptor.py @@ -299,14 +299,15 @@ class StartNexusOperationInput(Generic[InputT, OutputT]): input: InputT schedule_to_close_timeout: Optional[timedelta] headers: Optional[Mapping[str, str]] - output_type: Optional[type[OutputT]] = None + output_type: Optional[Type[OutputT]] = None def __post_init__(self) -> None: """Initialize operation-specific attributes after dataclass creation.""" if isinstance(self.operation, nexusrpc.Operation): self.output_type = self.operation.output_type elif callable(self.operation): - if op := nexusrpc.get_operation(self.operation): + _, op = temporalio.nexus._util.get_operation_factory(self.operation) + if isinstance(op, nexusrpc.Operation): self.output_type = op.output_type else: raise ValueError( @@ -325,7 +326,8 @@ def operation_name(self) -> str: elif isinstance(self.operation, str): return self.operation elif callable(self.operation): - if op := nexusrpc.get_operation(self.operation): + _, op = temporalio.nexus._util.get_operation_factory(self.operation) + if isinstance(op, nexusrpc.Operation): return op.name else: raise ValueError( diff --git a/temporalio/workflow.py b/temporalio/workflow.py index 75cedd18c..f3a514e26 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -5145,7 +5145,7 @@ async def start_operation( operation: nexusrpc.Operation[InputT, OutputT], input: InputT, *, - output_type: Optional[type[OutputT]] = None, + output_type: Optional[Type[OutputT]] = None, schedule_to_close_timeout: Optional[timedelta] = None, headers: Optional[Mapping[str, str]] = None, ) -> NexusOperationHandle[OutputT]: ... @@ -5158,7 +5158,7 @@ async def start_operation( operation: str, input: Any, *, - output_type: Optional[type[OutputT]] = None, + output_type: Optional[Type[OutputT]] = None, schedule_to_close_timeout: Optional[timedelta] = None, headers: Optional[Mapping[str, str]] = None, ) -> NexusOperationHandle[OutputT]: ... @@ -5174,7 +5174,7 @@ async def start_operation( ], input: InputT, *, - output_type: Optional[type[OutputT]] = None, + output_type: Optional[Type[OutputT]] = None, schedule_to_close_timeout: Optional[timedelta] = None, headers: Optional[Mapping[str, str]] = None, ) -> NexusOperationHandle[OutputT]: ... @@ -5190,7 +5190,7 @@ async def start_operation( ], input: InputT, *, - output_type: Optional[type[OutputT]] = None, + output_type: Optional[Type[OutputT]] = None, schedule_to_close_timeout: Optional[timedelta] = None, headers: Optional[Mapping[str, str]] = None, ) -> NexusOperationHandle[OutputT]: ... @@ -5206,7 +5206,7 @@ async def start_operation( ], input: InputT, *, - output_type: Optional[type[OutputT]] = None, + output_type: Optional[Type[OutputT]] = None, schedule_to_close_timeout: Optional[timedelta] = None, headers: Optional[Mapping[str, str]] = None, ) -> NexusOperationHandle[OutputT]: ... @@ -5217,7 +5217,7 @@ async def start_operation( operation: Any, input: Any, *, - output_type: Optional[type[OutputT]] = None, + output_type: Optional[Type[OutputT]] = None, schedule_to_close_timeout: Optional[timedelta] = None, headers: Optional[Mapping[str, str]] = None, ) -> Any: @@ -5246,7 +5246,7 @@ async def execute_operation( operation: nexusrpc.Operation[InputT, OutputT], input: InputT, *, - output_type: Optional[type[OutputT]] = None, + output_type: Optional[Type[OutputT]] = None, schedule_to_close_timeout: Optional[timedelta] = None, headers: Optional[Mapping[str, str]] = None, ) -> OutputT: ... @@ -5259,7 +5259,7 @@ async def execute_operation( operation: str, input: Any, *, - output_type: Optional[type[OutputT]] = None, + output_type: Optional[Type[OutputT]] = None, schedule_to_close_timeout: Optional[timedelta] = None, headers: Optional[Mapping[str, str]] = None, ) -> OutputT: ... @@ -5275,7 +5275,7 @@ async def execute_operation( ], input: InputT, *, - output_type: Optional[type[OutputT]] = None, + output_type: Optional[Type[OutputT]] = None, schedule_to_close_timeout: Optional[timedelta] = None, headers: Optional[Mapping[str, str]] = None, ) -> OutputT: ... @@ -5294,7 +5294,7 @@ async def execute_operation( ], input: InputT, *, - output_type: Optional[type[OutputT]] = None, + output_type: Optional[Type[OutputT]] = None, schedule_to_close_timeout: Optional[timedelta] = None, headers: Optional[Mapping[str, str]] = None, ) -> OutputT: ... @@ -5310,7 +5310,7 @@ async def execute_operation( ], input: InputT, *, - output_type: Optional[type[OutputT]] = None, + output_type: Optional[Type[OutputT]] = None, schedule_to_close_timeout: Optional[timedelta] = None, headers: Optional[Mapping[str, str]] = None, ) -> OutputT: ... @@ -5321,7 +5321,7 @@ async def execute_operation( operation: Any, input: Any, *, - output_type: Optional[type[OutputT]] = None, + output_type: Optional[Type[OutputT]] = None, schedule_to_close_timeout: Optional[timedelta] = None, headers: Optional[Mapping[str, str]] = None, ) -> Any: @@ -5345,7 +5345,7 @@ def __init__( self, *, endpoint: str, - service: Union[type[ServiceT], str], + service: Union[Type[ServiceT], str], ) -> None: """Create a Nexus client. @@ -5372,7 +5372,7 @@ async def start_operation( operation: Any, input: Any, *, - output_type: Optional[type] = None, + output_type: Optional[Type] = None, schedule_to_close_timeout: Optional[timedelta] = None, headers: Optional[Mapping[str, str]] = None, ) -> Any: @@ -5393,7 +5393,7 @@ async def execute_operation( operation: Any, input: Any, *, - output_type: Optional[type] = None, + output_type: Optional[Type] = None, schedule_to_close_timeout: Optional[timedelta] = None, headers: Optional[Mapping[str, str]] = None, ) -> Any: @@ -5410,7 +5410,7 @@ async def execute_operation( @overload def create_nexus_client( *, - service: type[ServiceT], + service: Type[ServiceT], endpoint: str, ) -> NexusClient[ServiceT]: ... @@ -5425,9 +5425,9 @@ def create_nexus_client( def create_nexus_client( *, - service: Union[type[ServiceT], str], + service: Union[Type[ServiceT], str], endpoint: str, -) -> NexusClient[Any]: +) -> NexusClient[ServiceT]: """Create a Nexus client. .. warning:: diff --git a/tests/nexus/test_dynamic_creation_of_user_handler_classes.py b/tests/nexus/test_dynamic_creation_of_user_handler_classes.py index 3df085d01..0eef14b84 100644 --- a/tests/nexus/test_dynamic_creation_of_user_handler_classes.py +++ b/tests/nexus/test_dynamic_creation_of_user_handler_classes.py @@ -3,9 +3,11 @@ import httpx import nexusrpc.handler import pytest +from nexusrpc.handler import sync_operation from temporalio import nexus, workflow from temporalio.client import Client +from temporalio.nexus._util import get_operation_factory from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker from tests.helpers.nexus import ServiceClient, create_nexus_endpoint @@ -76,8 +78,8 @@ async def test_run_nexus_service_from_programmatically_created_service_handler( service_handler = nexusrpc.handler._core.ServiceHandler( service=nexusrpc.ServiceDefinition( name="MyService", - operation_definitions={ - "increment": nexusrpc.OperationDefinition[int, int]( + operations={ + "increment": nexusrpc.Operation[int, int]( name="increment", method_name="increment", input_type=int, @@ -105,3 +107,70 @@ async def test_run_nexus_service_from_programmatically_created_service_handler( json=1, ) assert response.status_code == 201 + + +def make_incrementer_user_service_definition_and_service_handler_classes( + op_names: list[str], +) -> tuple[type, type]: + # + # service contract + # + + ops = {name: nexusrpc.Operation[int, int] for name in op_names} + service_cls: type = nexusrpc.service(type("ServiceContract", (), ops)) + + # + # service handler + # + @sync_operation + async def _increment_op( + self, + ctx: nexusrpc.handler.StartOperationContext, + input: int, + ) -> int: + return input + 1 + + op_handler_factories = {} + for name in op_names: + op_handler_factory, _ = get_operation_factory(_increment_op) + assert op_handler_factory + op_handler_factories[name] = op_handler_factory + + handler_cls: type = nexusrpc.handler.service_handler(service=service_cls)( + type("ServiceImpl", (), op_handler_factories) + ) + + return service_cls, handler_cls + + +@pytest.mark.skip( + reason="Dynamic creation of service contract using type() is not supported" +) +async def test_dynamic_creation_of_user_handler_classes( + client: Client, env: WorkflowEnvironment +): + task_queue = str(uuid.uuid4()) + + service_cls, handler_cls = ( + make_incrementer_user_service_definition_and_service_handler_classes( + ["increment"] + ) + ) + + assert (service_defn := nexusrpc.get_service_definition(service_cls)) + service_name = service_defn.name + + endpoint = (await create_nexus_endpoint(task_queue, client)).endpoint.id + async with Worker( + client, + task_queue=task_queue, + nexus_service_handlers=[handler_cls()], + ): + server_address = ServiceClient.default_server_address(env) + async with httpx.AsyncClient() as http_client: + response = await http_client.post( + f"http://{server_address}/nexus/endpoints/{endpoint}/services/{service_name}/increment", + json=1, + ) + assert response.status_code == 200 + assert response.json() == 2 diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index d95db5731..c805a967c 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -45,7 +45,6 @@ sync_operation, ) from nexusrpc.handler._decorators import operation_handler -from typing_extensions import dataclass_transform from temporalio import nexus, workflow from temporalio.client import Client @@ -329,17 +328,12 @@ class UnsuccessfulResponse: headers: Mapping[str, str] = UNSUCCESSFUL_RESPONSE_HEADERS -@dataclass_transform() -class _BaseTestCase: - pass - - -class _TestCase(_BaseTestCase): +class _TestCase: operation: str - expected: SuccessfulResponse service_defn: str = "MyService" input: Input = Input("") headers: dict[str, str] = {} + expected: SuccessfulResponse expected_without_service_definition: Optional[SuccessfulResponse] = None skip = "" @@ -779,7 +773,10 @@ async def test_start_operation_without_type_annotations( def test_operation_without_type_annotations_without_service_definition_raises_validation_error(): - with pytest.raises(ValueError, match=r"has no input type"): + with pytest.raises( + ValueError, + match=r"has no input type.+has no output type", + ): service_handler(MyServiceHandlerWithOperationsWithoutTypeAnnotations) diff --git a/tests/nexus/test_handler_operation_definitions.py b/tests/nexus/test_handler_operation_definitions.py index 82a0682fb..8e41c1efa 100644 --- a/tests/nexus/test_handler_operation_definitions.py +++ b/tests/nexus/test_handler_operation_definitions.py @@ -11,6 +11,7 @@ from temporalio import nexus from temporalio.nexus import WorkflowRunOperationContext, workflow_run_operation +from temporalio.nexus._util import get_operation_factory @dataclass @@ -95,7 +96,7 @@ async def test_collected_operation_names( assert isinstance(service_defn, nexusrpc.ServiceDefinition) assert service_defn.name == "Service" for method_name, expected_op in test_case.expected_operations.items(): - actual_op = nexusrpc.get_operation(getattr(test_case.Service, method_name)) + _, actual_op = get_operation_factory(getattr(test_case.Service, method_name)) assert isinstance(actual_op, nexusrpc.Operation) assert actual_op.name == expected_op.name assert actual_op.input_type == expected_op.input_type diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 085febb78..c9417ef58 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -253,12 +253,11 @@ def __init__( request_cancel: bool, task_queue: str, ) -> None: - service: type[Any] = { - CallerReference.IMPL_WITH_INTERFACE: ServiceImpl, - CallerReference.INTERFACE: ServiceInterface, - }[input.op_input.caller_reference] self.nexus_client = workflow.create_nexus_client( - service=service, + service={ + CallerReference.IMPL_WITH_INTERFACE: ServiceImpl, + CallerReference.INTERFACE: ServiceInterface, + }[input.op_input.caller_reference], endpoint=make_nexus_endpoint_name(task_queue), ) self._nexus_operation_started = False @@ -884,7 +883,7 @@ async def run( task_queue: str, ) -> ServiceClassNameOutput: C, N = CallerReference, NameOverride - service_cls: type[Any] + service_cls: type if (caller_reference, name_override) == (C.INTERFACE, N.YES): service_cls = ServiceInterfaceWithNameOverride elif (caller_reference, name_override) == (C.INTERFACE, N.NO): diff --git a/uv.lock b/uv.lock index 07b4ac4bb..13cb7bed3 100644 --- a/uv.lock +++ b/uv.lock @@ -1602,10 +1602,14 @@ wheels = [ [[package]] name = "nexus-rpc" version = "1.1.0" -source = { git = "https://github.com/nexus-rpc/sdk-python.git?rev=35f574c711193a6e2560d3e6665732a5bb7ae92c#35f574c711193a6e2560d3e6665732a5bb7ae92c" } +source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "typing-extensions" }, ] +sdist = { url = "https://files.pythonhosted.org/packages/ef/66/540687556bd28cf1ec370cc6881456203dfddb9dab047b8979c6865b5984/nexus_rpc-1.1.0.tar.gz", hash = "sha256:d65ad6a2f54f14e53ebe39ee30555eaeb894102437125733fb13034a04a44553", size = 77383, upload-time = "2025-07-07T19:03:58.368Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bf/2f/9e9d0dcaa4c6ffa22b7aa31069a8a264c753ff8027b36af602cce038c92f/nexus_rpc-1.1.0-py3-none-any.whl", hash = "sha256:d1b007af2aba186a27e736f8eaae39c03aed05b488084ff6c3d1785c9ba2ad38", size = 27743, upload-time = "2025-07-07T19:03:57.556Z" }, +] [[package]] name = "nh3" @@ -2754,7 +2758,7 @@ dev = [ requires-dist = [ { name = "eval-type-backport", marker = "python_full_version < '3.10' and extra == 'openai-agents'", specifier = ">=0.2.2" }, { name = "grpcio", marker = "extra == 'grpc'", specifier = ">=1.48.2,<2" }, - { name = "nexus-rpc", git = "https://github.com/nexus-rpc/sdk-python.git?rev=35f574c711193a6e2560d3e6665732a5bb7ae92c" }, + { name = "nexus-rpc", specifier = "==1.1.0" }, { name = "openai-agents", marker = "extra == 'openai-agents'", specifier = ">=0.2.3,<0.3" }, { name = "opentelemetry-api", marker = "extra == 'opentelemetry'", specifier = ">=1.11.1,<2" }, { name = "opentelemetry-sdk", marker = "extra == 'opentelemetry'", specifier = ">=1.11.1,<2" },