Skip to content

Commit fcf4cd5

Browse files
committed
implement progressive call results
1 parent ed24e39 commit fcf4cd5

File tree

8 files changed

+294
-52
lines changed

8 files changed

+294
-52
lines changed

.github/workflows/main.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ jobs:
2626
run: |
2727
git clone https://github.com/xconnio/xconn-aat-setup.git
2828
cd xconn-aat-setup
29-
make build-docker-xconn
29+
make build-docker-nxt
3030
make build-docker-crossbar
3131
docker compose up -d
3232
sudo snap install wick --classic
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import asyncio
2+
3+
from xconn import run
4+
from xconn.types import Result, Invocation
5+
from xconn.async_client import connect_anonymous
6+
7+
8+
async def invocation_handler(invocation: Invocation) -> Result:
9+
file_size = 100
10+
for i in range(0, file_size + 1, 10):
11+
progress = i * 100 // file_size
12+
try:
13+
await invocation.send_progress([progress], {})
14+
except Exception as err:
15+
return Result(["wamp.error.canceled", str(err)])
16+
await asyncio.sleep(0.5)
17+
18+
return Result(["Download complete!"])
19+
20+
21+
async def main() -> None:
22+
test_procedure_progress_download = "io.xconn.progress.download"
23+
24+
# create and connect a callee client to server
25+
callee = await connect_anonymous("ws://localhost:8080/ws", "realm1")
26+
27+
await callee.register(test_procedure_progress_download, invocation_handler)
28+
print(f"Registered procedure '{test_procedure_progress_download}'")
29+
30+
31+
if __name__ == "__main__":
32+
run(main())
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from xconn import run
2+
from xconn.types import Result
3+
from xconn.async_client import connect_anonymous
4+
5+
6+
async def progress_handler(res: Result) -> None:
7+
progress = res.args[0]
8+
print(f"Download progress: {progress}%")
9+
10+
11+
async def main() -> None:
12+
test_procedure_progress_download = "io.xconn.progress.download"
13+
14+
# create and connect a callee client to server
15+
caller = await connect_anonymous("ws://localhost:8080/ws", "realm1")
16+
17+
await caller.call_progress(test_procedure_progress_download, progress_handler)
18+
19+
20+
if __name__ == "__main__":
21+
run(main())
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import sys
2+
import time
3+
import signal
4+
5+
from xconn.client import connect_anonymous
6+
from xconn.types import Result, Invocation
7+
8+
9+
def invocation_handler(invocation: Invocation) -> Result:
10+
file_size = 100
11+
for i in range(0, file_size + 1, 10):
12+
progress = i * 100 // file_size
13+
try:
14+
invocation.send_progress([progress], {})
15+
except Exception as err:
16+
return Result(["wamp.error.canceled", str(err)])
17+
time.sleep(0.5)
18+
19+
return Result(["Download complete!"])
20+
21+
22+
if __name__ == "__main__":
23+
test_procedure_progress_download = "io.xconn.progress.download"
24+
25+
# create and connect a callee client to server
26+
callee = connect_anonymous("ws://localhost:8080/ws", "realm1")
27+
28+
download_progress_registration = callee.register(test_procedure_progress_download, invocation_handler)
29+
print(f"Registered procedure '{test_procedure_progress_download}'")
30+
31+
def handle_sigint(signum, frame):
32+
print("SIGINT received. Cleaning up...")
33+
34+
# unregister procedure "io.xconn.progress.download"
35+
download_progress_registration.unregister()
36+
37+
# close connection to the server
38+
callee.leave()
39+
40+
sys.exit(0)
41+
42+
43+
# register signal handler
44+
signal.signal(signal.SIGINT, handle_sigint)
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from xconn.types import Result
2+
from xconn.client import connect_anonymous
3+
4+
5+
def progress_handler(res: Result) -> None:
6+
progress = res.args[0]
7+
print(f"Download progress: {progress}%")
8+
9+
10+
if __name__ == "__main__":
11+
test_procedure_progress_download = "io.xconn.progress.download"
12+
13+
# create and connect a callee client to server
14+
caller = connect_anonymous("ws://localhost:8080/ws", "realm1")
15+
16+
caller.call_progress(test_procedure_progress_download, progress_handler)
17+
18+
caller.leave()

xconn/async_session.py

Lines changed: 88 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import inspect
55
from dataclasses import dataclass
66
from asyncio import Future, get_event_loop
7-
from typing import Callable, Union, Awaitable, Any
7+
from typing import Callable, Awaitable, Any
88

99
from wampproto import messages, idgen, session
1010

@@ -68,17 +68,15 @@ def __init__(self, base_session: types.IAsyncBaseSession):
6868
# RPC data structures
6969
self.call_requests: dict[int, Future[types.Result]] = {}
7070
self.register_requests: dict[int, RegisterRequest] = {}
71-
self.registrations: dict[
72-
int,
73-
Union[Callable[[types.Invocation], types.Result], Callable[[types.Invocation], Awaitable[types.Result]]],
74-
] = {}
71+
self.registrations: dict[int, Callable[[types.Invocation], Awaitable[types.Result]]] = {}
7572
self.unregister_requests: dict[int, types.UnregisterRequest] = {}
7673

7774
# PubSub data structures
7875
self.publish_requests: dict[int, Future[None]] = {}
7976
self.subscribe_requests: dict[int, SubscribeRequest] = {}
8077
self.subscriptions: dict[int, Callable[[types.Event], Awaitable[None]]] = {}
8178
self.unsubscribe_requests: dict[int, types.UnsubscribeRequest] = {}
79+
self.progress_handlers: dict[int, Callable[[types.Result], Awaitable[None]]] = {}
8280

8381
self.goodbye_request = Future()
8482

@@ -118,29 +116,68 @@ async def process_incoming_message(self, msg: messages.Message):
118116
del self.registrations[request.registration_id]
119117
request.future.set_result(None)
120118
elif isinstance(msg, messages.Result):
121-
request = self.call_requests.pop(msg.request_id)
122-
request.set_result(types.Result(msg.args, msg.kwargs, msg.options))
119+
progress = msg.options.get("progress", False)
120+
if progress:
121+
progress_handler = self.progress_handlers.get(msg.request_id, None)
122+
if progress_handler is not None:
123+
try:
124+
await progress_handler(types.Result(msg.args, msg.kwargs, msg.options))
125+
except Exception as e:
126+
# TODO: implement call canceling
127+
print(e)
128+
else:
129+
request = self.call_requests.pop(msg.request_id, None)
130+
if request is not None:
131+
request.set_result(types.Result(msg.args, msg.kwargs, msg.options))
132+
self.progress_handlers.pop(msg.request_id)
123133
elif isinstance(msg, messages.Invocation):
124134
try:
125135
endpoint = self.registrations[msg.registration_id]
126-
result = await endpoint(types.Invocation(msg.args, msg.kwargs, msg.details))
127-
128-
if result is None:
129-
data = self.session.send_message(messages.Yield(messages.YieldFields(msg.request_id)))
130-
elif isinstance(result, types.Result):
131-
data = self.session.send_message(
132-
messages.Yield(messages.YieldFields(msg.request_id, result.args, result.kwargs, result.details))
133-
)
134-
else:
135-
message = "Endpoint returned invalid result type. Expected types.Result or None, got: " + str(
136-
type(result)
137-
)
138-
msg_to_send = messages.Error(
139-
messages.ErrorFields(msg.TYPE, msg.request_id, xconn_uris.ERROR_INTERNAL_ERROR, [message])
140-
)
141-
data = self.session.send_message(msg_to_send)
142-
143-
await self.base_session.send(data)
136+
invocation = types.Invocation(msg.args, msg.kwargs, msg.details)
137+
receive_progress = msg.details.get("receive_progress", False)
138+
if receive_progress:
139+
140+
async def _progress_func(args: list[Any] | None, kwargs: dict[str, Any] | None):
141+
yield_msg = messages.Yield(
142+
messages.YieldFields(msg.request_id, args, kwargs, {"progress": True})
143+
)
144+
data = self.session.send_message(yield_msg)
145+
await self.base_session.send(data)
146+
147+
invocation.send_progress = _progress_func
148+
149+
async def handle_endpoint_invocation():
150+
try:
151+
result = await endpoint(invocation)
152+
if result is None:
153+
data = self.session.send_message(messages.Yield(messages.YieldFields(msg.request_id)))
154+
elif isinstance(result, types.Result):
155+
data = self.session.send_message(
156+
messages.Yield(
157+
messages.YieldFields(msg.request_id, result.args, result.kwargs, result.details)
158+
)
159+
)
160+
else:
161+
message = (
162+
"Endpoint returned invalid result type. Expected types.Result or None, got: "
163+
+ str(type(result))
164+
)
165+
msg_to_send = messages.Error(
166+
messages.ErrorFields(
167+
msg.TYPE, msg.request_id, xconn_uris.ERROR_INTERNAL_ERROR, [message]
168+
)
169+
)
170+
data = self.session.send_message(msg_to_send)
171+
except Exception as e:
172+
message = f"unexpected error calling endpoint {endpoint.__name__}, error is: {e}"
173+
msg_to_send = messages.Error(
174+
messages.ErrorFields(msg.TYPE, msg.request_id, xconn_uris.ERROR_INTERNAL_ERROR, [message])
175+
)
176+
data = self.session.send_message(msg_to_send)
177+
await self.base_session.send(data)
178+
179+
current_loop = get_event_loop()
180+
current_loop.create_task(handle_endpoint_invocation())
144181
except ApplicationError as e:
145182
msg_to_send = messages.Error(messages.ErrorFields(msg.TYPE, msg.request_id, e.message, e.args))
146183
data = self.session.send_message(msg_to_send)
@@ -215,6 +252,15 @@ async def register(
215252

216253
return await f
217254

255+
async def _call(self, call_msg: messages.Call) -> types.Result:
256+
f = Future()
257+
self.call_requests[call_msg.request_id] = f
258+
259+
data = self.session.send_message(call_msg)
260+
await self.base_session.send(data)
261+
262+
return await f
263+
218264
async def call(self, procedure: str, *args, **kwargs) -> types.Result:
219265
options = kwargs.pop("options", None)
220266
call = messages.Call(messages.CallFields(self.idgen.next(), procedure, args, kwargs, options=options))
@@ -227,6 +273,23 @@ async def call(self, procedure: str, *args, **kwargs) -> types.Result:
227273

228274
return await f
229275

276+
async def call_progress(
277+
self,
278+
procedure: str,
279+
progress_handler: Callable[[types.Result], Awaitable[None]],
280+
args: list[Any] | None = None,
281+
kwargs: dict[str, Any] | None = None,
282+
options: dict[str, Any] | None = None,
283+
) -> types.Result:
284+
if options is None:
285+
options = {}
286+
287+
options["receive_progress"] = True
288+
call_msg = messages.Call(messages.CallFields(self.idgen.next(), procedure, args, kwargs, options))
289+
self.progress_handlers[call_msg.request_id] = progress_handler
290+
291+
return await self._call(call_msg)
292+
230293
async def subscribe(
231294
self, topic: str, event_handler: Callable[[types.Event], Awaitable[None]], options: dict | None = None
232295
) -> Subscription:

0 commit comments

Comments
 (0)