Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions .github/workflows/build-binaries.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ on:
push:
branches:
- main
- dan-9999-pin-nexus
- "releases/*"

jobs:
Expand Down Expand Up @@ -74,8 +75,3 @@ jobs:
with:
name: packages-${{ matrix.package-suffix }}
path: dist

- name: Deliberately fail to prevent releasing nexus-rpc w/ GitHub link in pyproject.toml
run: |
echo "This is a deliberate failure to prevent releasing nexus-rpc with a GitHub link in pyproject.toml"
exit 1
5 changes: 1 addition & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ keywords = [
"workflow",
]
dependencies = [
"nexus-rpc>=1.1.0",
"nexus-rpc==1.1.0",
"protobuf>=3.20,<6",
"python-dateutil>=2.8.2,<3 ; python_version < '3.11'",
"types-protobuf>=3.20",
Expand Down Expand Up @@ -231,6 +231,3 @@ exclude = [
[tool.uv]
# Prevent uv commands from building the package by default
package = false

[tool.uv.sources]
nexus-rpc = { git = "https://github.com/nexus-rpc/sdk-python.git", rev = "35f574c711193a6e2560d3e6665732a5bb7ae92c" }
2 changes: 1 addition & 1 deletion temporalio/nexus/_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ async def _start(
return WorkflowRunOperationHandler(_start, input_type, output_type)

method_name = get_callable_name(start)
nexusrpc.set_operation(
nexusrpc.set_operation_definition(
operation_handler_factory,
nexusrpc.Operation(
name=name or method_name,
Expand Down
32 changes: 32 additions & 0 deletions temporalio/nexus/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
TypeVar,
)

import nexusrpc
from nexusrpc import (
InputT,
OutputT,
Expand Down Expand Up @@ -78,10 +79,19 @@ def _get_start_method_input_and_output_type_annotations(
try:
type_annotations = typing.get_type_hints(start)
except TypeError:
warnings.warn(
f"Expected decorated start method {start} to have type annotations"
)
return None, None
output_type = type_annotations.pop("return", None)

if len(type_annotations) != 2:
suffix = f": {type_annotations}" if type_annotations else ""
warnings.warn(
f"Expected decorated start method {start} to have exactly 2 "
f"type-annotated parameters (ctx and input), but it has {len(type_annotations)}"
f"{suffix}."
)
input_type = None
else:
ctx_type, input_type = type_annotations.values()
Expand All @@ -108,6 +118,28 @@ def get_callable_name(fn: Callable[..., Any]) -> str:
return method_name


# TODO(nexus-preview) Copied from nexusrpc
def get_operation_factory(
obj: Any,
) -> tuple[
Optional[Callable[[Any], Any]],
Optional[nexusrpc.Operation[Any, Any]],
]:
"""Return the :py:class:`Operation` for the object along with the factory function.

``obj`` should be a decorated operation start method.
"""
op_defn = nexusrpc.get_operation_definition(obj)
if op_defn:
factory = obj
else:
if factory := getattr(obj, "__nexus_operation_factory__", None):
op_defn = nexusrpc.get_operation_definition(factory)
if not isinstance(op_defn, nexusrpc.Operation):
return None, None
return factory, op_defn


# TODO(nexus-preview) Copied from nexusrpc
def set_operation_factory(
obj: Any,
Expand Down
8 changes: 5 additions & 3 deletions temporalio/worker/_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,14 +299,15 @@ class StartNexusOperationInput(Generic[InputT, OutputT]):
input: InputT
schedule_to_close_timeout: Optional[timedelta]
headers: Optional[Mapping[str, str]]
output_type: Optional[type[OutputT]] = None
output_type: Optional[Type[OutputT]] = None

def __post_init__(self) -> None:
"""Initialize operation-specific attributes after dataclass creation."""
if isinstance(self.operation, nexusrpc.Operation):
self.output_type = self.operation.output_type
elif callable(self.operation):
if op := nexusrpc.get_operation(self.operation):
_, op = temporalio.nexus._util.get_operation_factory(self.operation)
if isinstance(op, nexusrpc.Operation):
self.output_type = op.output_type
else:
raise ValueError(
Expand All @@ -325,7 +326,8 @@ def operation_name(self) -> str:
elif isinstance(self.operation, str):
return self.operation
elif callable(self.operation):
if op := nexusrpc.get_operation(self.operation):
_, op = temporalio.nexus._util.get_operation_factory(self.operation)
if isinstance(op, nexusrpc.Operation):
return op.name
else:
raise ValueError(
Expand Down
36 changes: 18 additions & 18 deletions temporalio/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -5145,7 +5145,7 @@ async def start_operation(
operation: nexusrpc.Operation[InputT, OutputT],
input: InputT,
*,
output_type: Optional[type[OutputT]] = None,
output_type: Optional[Type[OutputT]] = None,
schedule_to_close_timeout: Optional[timedelta] = None,
headers: Optional[Mapping[str, str]] = None,
) -> NexusOperationHandle[OutputT]: ...
Expand All @@ -5158,7 +5158,7 @@ async def start_operation(
operation: str,
input: Any,
*,
output_type: Optional[type[OutputT]] = None,
output_type: Optional[Type[OutputT]] = None,
schedule_to_close_timeout: Optional[timedelta] = None,
headers: Optional[Mapping[str, str]] = None,
) -> NexusOperationHandle[OutputT]: ...
Expand All @@ -5174,7 +5174,7 @@ async def start_operation(
],
input: InputT,
*,
output_type: Optional[type[OutputT]] = None,
output_type: Optional[Type[OutputT]] = None,
schedule_to_close_timeout: Optional[timedelta] = None,
headers: Optional[Mapping[str, str]] = None,
) -> NexusOperationHandle[OutputT]: ...
Expand All @@ -5190,7 +5190,7 @@ async def start_operation(
],
input: InputT,
*,
output_type: Optional[type[OutputT]] = None,
output_type: Optional[Type[OutputT]] = None,
schedule_to_close_timeout: Optional[timedelta] = None,
headers: Optional[Mapping[str, str]] = None,
) -> NexusOperationHandle[OutputT]: ...
Expand All @@ -5206,7 +5206,7 @@ async def start_operation(
],
input: InputT,
*,
output_type: Optional[type[OutputT]] = None,
output_type: Optional[Type[OutputT]] = None,
schedule_to_close_timeout: Optional[timedelta] = None,
headers: Optional[Mapping[str, str]] = None,
) -> NexusOperationHandle[OutputT]: ...
Expand All @@ -5217,7 +5217,7 @@ async def start_operation(
operation: Any,
input: Any,
*,
output_type: Optional[type[OutputT]] = None,
output_type: Optional[Type[OutputT]] = None,
schedule_to_close_timeout: Optional[timedelta] = None,
headers: Optional[Mapping[str, str]] = None,
) -> Any:
Expand Down Expand Up @@ -5246,7 +5246,7 @@ async def execute_operation(
operation: nexusrpc.Operation[InputT, OutputT],
input: InputT,
*,
output_type: Optional[type[OutputT]] = None,
output_type: Optional[Type[OutputT]] = None,
schedule_to_close_timeout: Optional[timedelta] = None,
headers: Optional[Mapping[str, str]] = None,
) -> OutputT: ...
Expand All @@ -5259,7 +5259,7 @@ async def execute_operation(
operation: str,
input: Any,
*,
output_type: Optional[type[OutputT]] = None,
output_type: Optional[Type[OutputT]] = None,
schedule_to_close_timeout: Optional[timedelta] = None,
headers: Optional[Mapping[str, str]] = None,
) -> OutputT: ...
Expand All @@ -5275,7 +5275,7 @@ async def execute_operation(
],
input: InputT,
*,
output_type: Optional[type[OutputT]] = None,
output_type: Optional[Type[OutputT]] = None,
schedule_to_close_timeout: Optional[timedelta] = None,
headers: Optional[Mapping[str, str]] = None,
) -> OutputT: ...
Expand All @@ -5294,7 +5294,7 @@ async def execute_operation(
],
input: InputT,
*,
output_type: Optional[type[OutputT]] = None,
output_type: Optional[Type[OutputT]] = None,
schedule_to_close_timeout: Optional[timedelta] = None,
headers: Optional[Mapping[str, str]] = None,
) -> OutputT: ...
Expand All @@ -5310,7 +5310,7 @@ async def execute_operation(
],
input: InputT,
*,
output_type: Optional[type[OutputT]] = None,
output_type: Optional[Type[OutputT]] = None,
schedule_to_close_timeout: Optional[timedelta] = None,
headers: Optional[Mapping[str, str]] = None,
) -> OutputT: ...
Expand All @@ -5321,7 +5321,7 @@ async def execute_operation(
operation: Any,
input: Any,
*,
output_type: Optional[type[OutputT]] = None,
output_type: Optional[Type[OutputT]] = None,
schedule_to_close_timeout: Optional[timedelta] = None,
headers: Optional[Mapping[str, str]] = None,
) -> Any:
Expand All @@ -5345,7 +5345,7 @@ def __init__(
self,
*,
endpoint: str,
service: Union[type[ServiceT], str],
service: Union[Type[ServiceT], str],
) -> None:
"""Create a Nexus client.

Expand All @@ -5372,7 +5372,7 @@ async def start_operation(
operation: Any,
input: Any,
*,
output_type: Optional[type] = None,
output_type: Optional[Type] = None,
schedule_to_close_timeout: Optional[timedelta] = None,
headers: Optional[Mapping[str, str]] = None,
) -> Any:
Expand All @@ -5393,7 +5393,7 @@ async def execute_operation(
operation: Any,
input: Any,
*,
output_type: Optional[type] = None,
output_type: Optional[Type] = None,
schedule_to_close_timeout: Optional[timedelta] = None,
headers: Optional[Mapping[str, str]] = None,
) -> Any:
Expand All @@ -5410,7 +5410,7 @@ async def execute_operation(
@overload
def create_nexus_client(
*,
service: type[ServiceT],
service: Type[ServiceT],
endpoint: str,
) -> NexusClient[ServiceT]: ...

Expand All @@ -5425,9 +5425,9 @@ def create_nexus_client(

def create_nexus_client(
*,
service: Union[type[ServiceT], str],
service: Union[Type[ServiceT], str],
endpoint: str,
) -> NexusClient[Any]:
) -> NexusClient[ServiceT]:
"""Create a Nexus client.

.. warning::
Expand Down
73 changes: 71 additions & 2 deletions tests/nexus/test_dynamic_creation_of_user_handler_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
import httpx
import nexusrpc.handler
import pytest
from nexusrpc.handler import sync_operation

from temporalio import nexus, workflow
from temporalio.client import Client
from temporalio.nexus._util import get_operation_factory
from temporalio.testing import WorkflowEnvironment
from temporalio.worker import Worker
from tests.helpers.nexus import ServiceClient, create_nexus_endpoint
Expand Down Expand Up @@ -76,8 +78,8 @@ async def test_run_nexus_service_from_programmatically_created_service_handler(
service_handler = nexusrpc.handler._core.ServiceHandler(
service=nexusrpc.ServiceDefinition(
name="MyService",
operation_definitions={
"increment": nexusrpc.OperationDefinition[int, int](
operations={
"increment": nexusrpc.Operation[int, int](
name="increment",
method_name="increment",
input_type=int,
Expand Down Expand Up @@ -105,3 +107,70 @@ async def test_run_nexus_service_from_programmatically_created_service_handler(
json=1,
)
assert response.status_code == 201


def make_incrementer_user_service_definition_and_service_handler_classes(
op_names: list[str],
) -> tuple[type, type]:
#
# service contract
#

ops = {name: nexusrpc.Operation[int, int] for name in op_names}
service_cls: type = nexusrpc.service(type("ServiceContract", (), ops))

#
# service handler
#
@sync_operation
async def _increment_op(
self,
ctx: nexusrpc.handler.StartOperationContext,
input: int,
) -> int:
return input + 1

op_handler_factories = {}
for name in op_names:
op_handler_factory, _ = get_operation_factory(_increment_op)
assert op_handler_factory
op_handler_factories[name] = op_handler_factory

handler_cls: type = nexusrpc.handler.service_handler(service=service_cls)(
type("ServiceImpl", (), op_handler_factories)
)

return service_cls, handler_cls


@pytest.mark.skip(
reason="Dynamic creation of service contract using type() is not supported"
)
async def test_dynamic_creation_of_user_handler_classes(
client: Client, env: WorkflowEnvironment
):
task_queue = str(uuid.uuid4())

service_cls, handler_cls = (
make_incrementer_user_service_definition_and_service_handler_classes(
["increment"]
)
)

assert (service_defn := nexusrpc.get_service_definition(service_cls))
service_name = service_defn.name

endpoint = (await create_nexus_endpoint(task_queue, client)).endpoint.id
async with Worker(
client,
task_queue=task_queue,
nexus_service_handlers=[handler_cls()],
):
server_address = ServiceClient.default_server_address(env)
async with httpx.AsyncClient() as http_client:
response = await http_client.post(
f"http://{server_address}/nexus/endpoints/{endpoint}/services/{service_name}/increment",
json=1,
)
assert response.status_code == 200
assert response.json() == 2
Loading
Loading