44import inspect
55from dataclasses import dataclass
66from asyncio import Future , get_event_loop
7- from typing import Callable , Union , Awaitable , Any
7+ from typing import Callable , Awaitable , Any
88
99from wampproto import messages , idgen , session
1010
@@ -68,17 +68,15 @@ def __init__(self, base_session: types.IAsyncBaseSession):
6868 # RPC data structures
6969 self .call_requests : dict [int , Future [types .Result ]] = {}
7070 self .register_requests : dict [int , RegisterRequest ] = {}
71- self .registrations : dict [
72- int ,
73- Union [Callable [[types .Invocation ], types .Result ], Callable [[types .Invocation ], Awaitable [types .Result ]]],
74- ] = {}
71+ self .registrations : dict [int , Callable [[types .Invocation ], Awaitable [types .Result ]]] = {}
7572 self .unregister_requests : dict [int , types .UnregisterRequest ] = {}
7673
7774 # PubSub data structures
7875 self .publish_requests : dict [int , Future [None ]] = {}
7976 self .subscribe_requests : dict [int , SubscribeRequest ] = {}
8077 self .subscriptions : dict [int , Callable [[types .Event ], Awaitable [None ]]] = {}
8178 self .unsubscribe_requests : dict [int , types .UnsubscribeRequest ] = {}
79+ self .progress_handlers : dict [int , Callable [[types .Result ], Awaitable [None ]]] = {}
8280
8381 self .goodbye_request = Future ()
8482
@@ -118,29 +116,68 @@ async def process_incoming_message(self, msg: messages.Message):
118116 del self .registrations [request .registration_id ]
119117 request .future .set_result (None )
120118 elif isinstance (msg , messages .Result ):
121- request = self .call_requests .pop (msg .request_id )
122- request .set_result (types .Result (msg .args , msg .kwargs , msg .options ))
119+ progress = msg .options .get ("progress" , False )
120+ if progress :
121+ progress_handler = self .progress_handlers .get (msg .request_id , None )
122+ if progress_handler is not None :
123+ try :
124+ await progress_handler (types .Result (msg .args , msg .kwargs , msg .options ))
125+ except Exception as e :
126+ # TODO: implement call canceling
127+ print (e )
128+ else :
129+ request = self .call_requests .pop (msg .request_id , None )
130+ if request is not None :
131+ request .set_result (types .Result (msg .args , msg .kwargs , msg .options ))
132+ self .progress_handlers .pop (msg .request_id )
123133 elif isinstance (msg , messages .Invocation ):
124134 try :
125135 endpoint = self .registrations [msg .registration_id ]
126- result = await endpoint (types .Invocation (msg .args , msg .kwargs , msg .details ))
127-
128- if result is None :
129- data = self .session .send_message (messages .Yield (messages .YieldFields (msg .request_id )))
130- elif isinstance (result , types .Result ):
131- data = self .session .send_message (
132- messages .Yield (messages .YieldFields (msg .request_id , result .args , result .kwargs , result .details ))
133- )
134- else :
135- message = "Endpoint returned invalid result type. Expected types.Result or None, got: " + str (
136- type (result )
137- )
138- msg_to_send = messages .Error (
139- messages .ErrorFields (msg .TYPE , msg .request_id , xconn_uris .ERROR_INTERNAL_ERROR , [message ])
140- )
141- data = self .session .send_message (msg_to_send )
142-
143- await self .base_session .send (data )
136+ invocation = types .Invocation (msg .args , msg .kwargs , msg .details )
137+ receive_progress = msg .details .get ("receive_progress" , False )
138+ if receive_progress :
139+
140+ async def _progress_func (args : list [Any ] | None , kwargs : dict [str , Any ] | None ):
141+ yield_msg = messages .Yield (
142+ messages .YieldFields (msg .request_id , args , kwargs , {"progress" : True })
143+ )
144+ data = self .session .send_message (yield_msg )
145+ await self .base_session .send (data )
146+
147+ invocation .send_progress = _progress_func
148+
149+ async def handle_endpoint_invocation ():
150+ try :
151+ result = await endpoint (invocation )
152+ if result is None :
153+ data = self .session .send_message (messages .Yield (messages .YieldFields (msg .request_id )))
154+ elif isinstance (result , types .Result ):
155+ data = self .session .send_message (
156+ messages .Yield (
157+ messages .YieldFields (msg .request_id , result .args , result .kwargs , result .details )
158+ )
159+ )
160+ else :
161+ message = (
162+ "Endpoint returned invalid result type. Expected types.Result or None, got: "
163+ + str (type (result ))
164+ )
165+ msg_to_send = messages .Error (
166+ messages .ErrorFields (
167+ msg .TYPE , msg .request_id , xconn_uris .ERROR_INTERNAL_ERROR , [message ]
168+ )
169+ )
170+ data = self .session .send_message (msg_to_send )
171+ except Exception as e :
172+ message = f"unexpected error calling endpoint { endpoint .__name__ } , error is: { e } "
173+ msg_to_send = messages .Error (
174+ messages .ErrorFields (msg .TYPE , msg .request_id , xconn_uris .ERROR_INTERNAL_ERROR , [message ])
175+ )
176+ data = self .session .send_message (msg_to_send )
177+ await self .base_session .send (data )
178+
179+ current_loop = get_event_loop ()
180+ current_loop .create_task (handle_endpoint_invocation ())
144181 except ApplicationError as e :
145182 msg_to_send = messages .Error (messages .ErrorFields (msg .TYPE , msg .request_id , e .message , e .args ))
146183 data = self .session .send_message (msg_to_send )
@@ -215,6 +252,15 @@ async def register(
215252
216253 return await f
217254
255+ async def _call (self , call_msg : messages .Call ) -> types .Result :
256+ f = Future ()
257+ self .call_requests [call_msg .request_id ] = f
258+
259+ data = self .session .send_message (call_msg )
260+ await self .base_session .send (data )
261+
262+ return await f
263+
218264 async def call (self , procedure : str , * args , ** kwargs ) -> types .Result :
219265 options = kwargs .pop ("options" , None )
220266 call = messages .Call (messages .CallFields (self .idgen .next (), procedure , args , kwargs , options = options ))
@@ -227,6 +273,23 @@ async def call(self, procedure: str, *args, **kwargs) -> types.Result:
227273
228274 return await f
229275
276+ async def call_progress (
277+ self ,
278+ procedure : str ,
279+ progress_handler : Callable [[types .Result ], Awaitable [None ]],
280+ args : list [Any ] | None = None ,
281+ kwargs : dict [str , Any ] | None = None ,
282+ options : dict [str , Any ] | None = None ,
283+ ) -> types .Result :
284+ if options is None :
285+ options = {}
286+
287+ options ["receive_progress" ] = True
288+ call_msg = messages .Call (messages .CallFields (self .idgen .next (), procedure , args , kwargs , options ))
289+ self .progress_handlers [call_msg .request_id ] = progress_handler
290+
291+ return await self ._call (call_msg )
292+
230293 async def subscribe (
231294 self , topic : str , event_handler : Callable [[types .Event ], Awaitable [None ]], options : dict | None = None
232295 ) -> Subscription :
0 commit comments