Skip to content

Commit 9780f2b

Browse files
committed
Respond to upstream: do not use get_operation_factory
1 parent c4a85b7 commit 9780f2b

File tree

4 files changed

+4
-99
lines changed

4 files changed

+4
-99
lines changed

temporalio/nexus/_util.py

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
TypeVar,
1414
)
1515

16-
import nexusrpc
1716
from nexusrpc import (
1817
InputT,
1918
OutputT,
@@ -118,28 +117,6 @@ def get_callable_name(fn: Callable[..., Any]) -> str:
118117
return method_name
119118

120119

121-
# TODO(nexus-preview) Copied from nexusrpc
122-
def get_operation_factory(
123-
obj: Any,
124-
) -> tuple[
125-
Optional[Callable[[Any], Any]],
126-
Optional[nexusrpc.Operation[Any, Any]],
127-
]:
128-
"""Return the :py:class:`Operation` for the object along with the factory function.
129-
130-
``obj`` should be a decorated operation start method.
131-
"""
132-
op = nexusrpc.get_operation(obj)
133-
if op:
134-
factory = obj
135-
else:
136-
if factory := getattr(obj, "__nexus_operation_factory__", None):
137-
op = nexusrpc.get_operation(factory)
138-
if not isinstance(op, nexusrpc.Operation):
139-
return None, None
140-
return factory, op
141-
142-
143120
# TODO(nexus-preview) Copied from nexusrpc
144121
def set_operation_factory(
145122
obj: Any,

temporalio/worker/_interceptor.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -299,15 +299,14 @@ class StartNexusOperationInput(Generic[InputT, OutputT]):
299299
input: InputT
300300
schedule_to_close_timeout: Optional[timedelta]
301301
headers: Optional[Mapping[str, str]]
302-
output_type: Optional[Type[OutputT]] = None
302+
output_type: Optional[type[OutputT]] = None
303303

304304
def __post_init__(self) -> None:
305305
"""Initialize operation-specific attributes after dataclass creation."""
306306
if isinstance(self.operation, nexusrpc.Operation):
307307
self.output_type = self.operation.output_type
308308
elif callable(self.operation):
309-
_, op = temporalio.nexus._util.get_operation_factory(self.operation)
310-
if isinstance(op, nexusrpc.Operation):
309+
if op := nexusrpc.get_operation(self.operation):
311310
self.output_type = op.output_type
312311
else:
313312
raise ValueError(
@@ -326,8 +325,7 @@ def operation_name(self) -> str:
326325
elif isinstance(self.operation, str):
327326
return self.operation
328327
elif callable(self.operation):
329-
_, op = temporalio.nexus._util.get_operation_factory(self.operation)
330-
if isinstance(op, nexusrpc.Operation):
328+
if op := nexusrpc.get_operation(self.operation):
331329
return op.name
332330
else:
333331
raise ValueError(

tests/nexus/test_dynamic_creation_of_user_handler_classes.py

Lines changed: 0 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,9 @@
33
import httpx
44
import nexusrpc.handler
55
import pytest
6-
from nexusrpc.handler import sync_operation
76

87
from temporalio import nexus, workflow
98
from temporalio.client import Client
10-
from temporalio.nexus._util import get_operation_factory
119
from temporalio.testing import WorkflowEnvironment
1210
from temporalio.worker import Worker
1311
from tests.helpers.nexus import ServiceClient, create_nexus_endpoint
@@ -107,70 +105,3 @@ async def test_run_nexus_service_from_programmatically_created_service_handler(
107105
json=1,
108106
)
109107
assert response.status_code == 201
110-
111-
112-
def make_incrementer_user_service_definition_and_service_handler_classes(
113-
op_names: list[str],
114-
) -> tuple[type, type]:
115-
#
116-
# service contract
117-
#
118-
119-
ops = {name: nexusrpc.Operation[int, int] for name in op_names}
120-
service_cls: type = nexusrpc.service(type("ServiceContract", (), ops))
121-
122-
#
123-
# service handler
124-
#
125-
@sync_operation
126-
async def _increment_op(
127-
self,
128-
ctx: nexusrpc.handler.StartOperationContext,
129-
input: int,
130-
) -> int:
131-
return input + 1
132-
133-
op_handler_factories = {}
134-
for name in op_names:
135-
op_handler_factory, _ = get_operation_factory(_increment_op)
136-
assert op_handler_factory
137-
op_handler_factories[name] = op_handler_factory
138-
139-
handler_cls: type = nexusrpc.handler.service_handler(service=service_cls)(
140-
type("ServiceImpl", (), op_handler_factories)
141-
)
142-
143-
return service_cls, handler_cls
144-
145-
146-
@pytest.mark.skip(
147-
reason="Dynamic creation of service contract using type() is not supported"
148-
)
149-
async def test_dynamic_creation_of_user_handler_classes(
150-
client: Client, env: WorkflowEnvironment
151-
):
152-
task_queue = str(uuid.uuid4())
153-
154-
service_cls, handler_cls = (
155-
make_incrementer_user_service_definition_and_service_handler_classes(
156-
["increment"]
157-
)
158-
)
159-
160-
assert (service_defn := nexusrpc.get_service_definition(service_cls))
161-
service_name = service_defn.name
162-
163-
endpoint = (await create_nexus_endpoint(task_queue, client)).endpoint.id
164-
async with Worker(
165-
client,
166-
task_queue=task_queue,
167-
nexus_service_handlers=[handler_cls()],
168-
):
169-
server_address = ServiceClient.default_server_address(env)
170-
async with httpx.AsyncClient() as http_client:
171-
response = await http_client.post(
172-
f"http://{server_address}/nexus/endpoints/{endpoint}/services/{service_name}/increment",
173-
json=1,
174-
)
175-
assert response.status_code == 200
176-
assert response.json() == 2

tests/nexus/test_handler_operation_definitions.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111

1212
from temporalio import nexus
1313
from temporalio.nexus import WorkflowRunOperationContext, workflow_run_operation
14-
from temporalio.nexus._util import get_operation_factory
1514

1615

1716
@dataclass
@@ -96,7 +95,7 @@ async def test_collected_operation_names(
9695
assert isinstance(service_defn, nexusrpc.ServiceDefinition)
9796
assert service_defn.name == "Service"
9897
for method_name, expected_op in test_case.expected_operations.items():
99-
_, actual_op = get_operation_factory(getattr(test_case.Service, method_name))
98+
actual_op = nexusrpc.get_operation(getattr(test_case.Service, method_name))
10099
assert isinstance(actual_op, nexusrpc.Operation)
101100
assert actual_op.name == expected_op.name
102101
assert actual_op.input_type == expected_op.input_type

0 commit comments

Comments
 (0)