Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions .github/workflows/build-binaries.yml
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,3 @@ jobs:
with:
name: packages-${{ matrix.package-suffix }}
path: dist

- name: Deliberately fail to prevent releasing nexus-rpc w/ GitHub link in pyproject.toml
run: |
echo "This is a deliberate failure to prevent releasing nexus-rpc with a GitHub link in pyproject.toml"
exit 1
5 changes: 1 addition & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ keywords = [
"workflow",
]
dependencies = [
"nexus-rpc>=1.1.0",
"nexus-rpc==1.1.0",
"protobuf>=3.20,<6",
"python-dateutil>=2.8.2,<3 ; python_version < '3.11'",
"types-protobuf>=3.20",
Expand Down Expand Up @@ -231,6 +231,3 @@ exclude = [
[tool.uv]
# Prevent uv commands from building the package by default
package = false

[tool.uv.sources]
nexus-rpc = { git = "https://github.com/nexus-rpc/sdk-python.git", rev = "35f574c711193a6e2560d3e6665732a5bb7ae92c" }
2 changes: 1 addition & 1 deletion temporalio/nexus/_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ async def _start(
return WorkflowRunOperationHandler(_start, input_type, output_type)

method_name = get_callable_name(start)
nexusrpc.set_operation(
nexusrpc.set_operation_definition(
operation_handler_factory,
nexusrpc.Operation(
name=name or method_name,
Expand Down
32 changes: 32 additions & 0 deletions temporalio/nexus/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
TypeVar,
)

import nexusrpc
from nexusrpc import (
InputT,
OutputT,
Expand Down Expand Up @@ -78,10 +79,19 @@ def _get_start_method_input_and_output_type_annotations(
try:
type_annotations = typing.get_type_hints(start)
except TypeError:
warnings.warn(
f"Expected decorated start method {start} to have type annotations"
)
return None, None
output_type = type_annotations.pop("return", None)

if len(type_annotations) != 2:
suffix = f": {type_annotations}" if type_annotations else ""
warnings.warn(
f"Expected decorated start method {start} to have exactly 2 "
f"type-annotated parameters (ctx and input), but it has {len(type_annotations)}"
f"{suffix}."
)
input_type = None
else:
ctx_type, input_type = type_annotations.values()
Expand All @@ -108,6 +118,28 @@ def get_callable_name(fn: Callable[..., Any]) -> str:
return method_name


# TODO(nexus-preview) Copied from nexusrpc
def get_operation_factory(
obj: Any,
) -> tuple[
Optional[Callable[[Any], Any]],
Optional[nexusrpc.Operation[Any, Any]],
]:
"""Return the :py:class:`Operation` for the object along with the factory function.

``obj`` should be a decorated operation start method.
"""
op_defn = nexusrpc.get_operation_definition(obj)
if op_defn:
factory = obj
else:
if factory := getattr(obj, "__nexus_operation_factory__", None):
op_defn = nexusrpc.get_operation_definition(factory)
if not isinstance(op_defn, nexusrpc.Operation):
return None, None
return factory, op_defn


# TODO(nexus-preview) Copied from nexusrpc
def set_operation_factory(
obj: Any,
Expand Down
8 changes: 5 additions & 3 deletions temporalio/worker/_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,14 +299,15 @@ class StartNexusOperationInput(Generic[InputT, OutputT]):
input: InputT
schedule_to_close_timeout: Optional[timedelta]
headers: Optional[Mapping[str, str]]
output_type: Optional[type[OutputT]] = None
output_type: Optional[Type[OutputT]] = None

def __post_init__(self) -> None:
"""Initialize operation-specific attributes after dataclass creation."""
if isinstance(self.operation, nexusrpc.Operation):
self.output_type = self.operation.output_type
elif callable(self.operation):
if op := nexusrpc.get_operation(self.operation):
_, op = temporalio.nexus._util.get_operation_factory(self.operation)
if isinstance(op, nexusrpc.Operation):
self.output_type = op.output_type
else:
raise ValueError(
Expand All @@ -325,7 +326,8 @@ def operation_name(self) -> str:
elif isinstance(self.operation, str):
return self.operation
elif callable(self.operation):
if op := nexusrpc.get_operation(self.operation):
_, op = temporalio.nexus._util.get_operation_factory(self.operation)
if isinstance(op, nexusrpc.Operation):
return op.name
else:
raise ValueError(
Expand Down
36 changes: 18 additions & 18 deletions temporalio/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -5145,7 +5145,7 @@ async def start_operation(
operation: nexusrpc.Operation[InputT, OutputT],
input: InputT,
*,
output_type: Optional[type[OutputT]] = None,
output_type: Optional[Type[OutputT]] = None,
schedule_to_close_timeout: Optional[timedelta] = None,
headers: Optional[Mapping[str, str]] = None,
) -> NexusOperationHandle[OutputT]: ...
Expand All @@ -5158,7 +5158,7 @@ async def start_operation(
operation: str,
input: Any,
*,
output_type: Optional[type[OutputT]] = None,
output_type: Optional[Type[OutputT]] = None,
schedule_to_close_timeout: Optional[timedelta] = None,
headers: Optional[Mapping[str, str]] = None,
) -> NexusOperationHandle[OutputT]: ...
Expand All @@ -5174,7 +5174,7 @@ async def start_operation(
],
input: InputT,
*,
output_type: Optional[type[OutputT]] = None,
output_type: Optional[Type[OutputT]] = None,
schedule_to_close_timeout: Optional[timedelta] = None,
headers: Optional[Mapping[str, str]] = None,
) -> NexusOperationHandle[OutputT]: ...
Expand All @@ -5190,7 +5190,7 @@ async def start_operation(
],
input: InputT,
*,
output_type: Optional[type[OutputT]] = None,
output_type: Optional[Type[OutputT]] = None,
schedule_to_close_timeout: Optional[timedelta] = None,
headers: Optional[Mapping[str, str]] = None,
) -> NexusOperationHandle[OutputT]: ...
Expand All @@ -5206,7 +5206,7 @@ async def start_operation(
],
input: InputT,
*,
output_type: Optional[type[OutputT]] = None,
output_type: Optional[Type[OutputT]] = None,
schedule_to_close_timeout: Optional[timedelta] = None,
headers: Optional[Mapping[str, str]] = None,
) -> NexusOperationHandle[OutputT]: ...
Expand All @@ -5217,7 +5217,7 @@ async def start_operation(
operation: Any,
input: Any,
*,
output_type: Optional[type[OutputT]] = None,
output_type: Optional[Type[OutputT]] = None,
schedule_to_close_timeout: Optional[timedelta] = None,
headers: Optional[Mapping[str, str]] = None,
) -> Any:
Expand Down Expand Up @@ -5246,7 +5246,7 @@ async def execute_operation(
operation: nexusrpc.Operation[InputT, OutputT],
input: InputT,
*,
output_type: Optional[type[OutputT]] = None,
output_type: Optional[Type[OutputT]] = None,
schedule_to_close_timeout: Optional[timedelta] = None,
headers: Optional[Mapping[str, str]] = None,
) -> OutputT: ...
Expand All @@ -5259,7 +5259,7 @@ async def execute_operation(
operation: str,
input: Any,
*,
output_type: Optional[type[OutputT]] = None,
output_type: Optional[Type[OutputT]] = None,
schedule_to_close_timeout: Optional[timedelta] = None,
headers: Optional[Mapping[str, str]] = None,
) -> OutputT: ...
Expand All @@ -5275,7 +5275,7 @@ async def execute_operation(
],
input: InputT,
*,
output_type: Optional[type[OutputT]] = None,
output_type: Optional[Type[OutputT]] = None,
schedule_to_close_timeout: Optional[timedelta] = None,
headers: Optional[Mapping[str, str]] = None,
) -> OutputT: ...
Expand All @@ -5294,7 +5294,7 @@ async def execute_operation(
],
input: InputT,
*,
output_type: Optional[type[OutputT]] = None,
output_type: Optional[Type[OutputT]] = None,
schedule_to_close_timeout: Optional[timedelta] = None,
headers: Optional[Mapping[str, str]] = None,
) -> OutputT: ...
Expand All @@ -5310,7 +5310,7 @@ async def execute_operation(
],
input: InputT,
*,
output_type: Optional[type[OutputT]] = None,
output_type: Optional[Type[OutputT]] = None,
schedule_to_close_timeout: Optional[timedelta] = None,
headers: Optional[Mapping[str, str]] = None,
) -> OutputT: ...
Expand All @@ -5321,7 +5321,7 @@ async def execute_operation(
operation: Any,
input: Any,
*,
output_type: Optional[type[OutputT]] = None,
output_type: Optional[Type[OutputT]] = None,
schedule_to_close_timeout: Optional[timedelta] = None,
headers: Optional[Mapping[str, str]] = None,
) -> Any:
Expand All @@ -5345,7 +5345,7 @@ def __init__(
self,
*,
endpoint: str,
service: Union[type[ServiceT], str],
service: Union[Type[ServiceT], str],
) -> None:
"""Create a Nexus client.

Expand All @@ -5372,7 +5372,7 @@ async def start_operation(
operation: Any,
input: Any,
*,
output_type: Optional[type] = None,
output_type: Optional[Type] = None,
schedule_to_close_timeout: Optional[timedelta] = None,
headers: Optional[Mapping[str, str]] = None,
) -> Any:
Expand All @@ -5393,7 +5393,7 @@ async def execute_operation(
operation: Any,
input: Any,
*,
output_type: Optional[type] = None,
output_type: Optional[Type] = None,
schedule_to_close_timeout: Optional[timedelta] = None,
headers: Optional[Mapping[str, str]] = None,
) -> Any:
Expand All @@ -5410,7 +5410,7 @@ async def execute_operation(
@overload
def create_nexus_client(
*,
service: type[ServiceT],
service: Type[ServiceT],
endpoint: str,
) -> NexusClient[ServiceT]: ...

Expand All @@ -5425,9 +5425,9 @@ def create_nexus_client(

def create_nexus_client(
*,
service: Union[type[ServiceT], str],
service: Union[Type[ServiceT], str],
endpoint: str,
) -> NexusClient[Any]:
) -> NexusClient[ServiceT]:
"""Create a Nexus client.

.. warning::
Expand Down
73 changes: 71 additions & 2 deletions tests/nexus/test_dynamic_creation_of_user_handler_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
import httpx
import nexusrpc.handler
import pytest
from nexusrpc.handler import sync_operation

from temporalio import nexus, workflow
from temporalio.client import Client
from temporalio.nexus._util import get_operation_factory
from temporalio.testing import WorkflowEnvironment
from temporalio.worker import Worker
from tests.helpers.nexus import ServiceClient, create_nexus_endpoint
Expand Down Expand Up @@ -76,8 +78,8 @@ async def test_run_nexus_service_from_programmatically_created_service_handler(
service_handler = nexusrpc.handler._core.ServiceHandler(
service=nexusrpc.ServiceDefinition(
name="MyService",
operation_definitions={
"increment": nexusrpc.OperationDefinition[int, int](
operations={
"increment": nexusrpc.Operation[int, int](
name="increment",
method_name="increment",
input_type=int,
Expand Down Expand Up @@ -105,3 +107,70 @@ async def test_run_nexus_service_from_programmatically_created_service_handler(
json=1,
)
assert response.status_code == 201


def make_incrementer_user_service_definition_and_service_handler_classes(
op_names: list[str],
) -> tuple[type, type]:
#
# service contract
#

ops = {name: nexusrpc.Operation[int, int] for name in op_names}
service_cls: type = nexusrpc.service(type("ServiceContract", (), ops))

#
# service handler
#
@sync_operation
async def _increment_op(
self,
ctx: nexusrpc.handler.StartOperationContext,
input: int,
) -> int:
return input + 1

op_handler_factories = {}
for name in op_names:
op_handler_factory, _ = get_operation_factory(_increment_op)
assert op_handler_factory
op_handler_factories[name] = op_handler_factory

handler_cls: type = nexusrpc.handler.service_handler(service=service_cls)(
type("ServiceImpl", (), op_handler_factories)
)

return service_cls, handler_cls


@pytest.mark.skip(
reason="Dynamic creation of service contract using type() is not supported"
)
async def test_dynamic_creation_of_user_handler_classes(
client: Client, env: WorkflowEnvironment
):
task_queue = str(uuid.uuid4())

service_cls, handler_cls = (
make_incrementer_user_service_definition_and_service_handler_classes(
["increment"]
)
)

assert (service_defn := nexusrpc.get_service_definition(service_cls))
service_name = service_defn.name

endpoint = (await create_nexus_endpoint(task_queue, client)).endpoint.id
async with Worker(
client,
task_queue=task_queue,
nexus_service_handlers=[handler_cls()],
):
server_address = ServiceClient.default_server_address(env)
async with httpx.AsyncClient() as http_client:
response = await http_client.post(
f"http://{server_address}/nexus/endpoints/{endpoint}/services/{service_name}/increment",
json=1,
)
assert response.status_code == 200
assert response.json() == 2
Loading