44import asyncio
55import socket
66import struct
7- import time
87import typing
8+ from contextlib import suppress
99
1010import requests
1111from pytoniq_core import HashMap , Builder
@@ -68,6 +68,7 @@ def __init__(self,
6868 """########### init ###########"""
6969 self .tasks = {}
7070 self .inited = False
71+ self ._closing = False
7172 self .logger = logging .getLogger (self .__class__ .__name__ )
7273 self .timeout = timeout
7374
@@ -141,43 +142,62 @@ async def send_and_encrypt(self, data: bytes, qid: str) -> asyncio.Future:
141142
142143 async def receive (self , data_len : int ) -> bytes :
143144 try :
144- data = await self .reader .readexactly (data_len )
145- except ConnectionError :
145+ return await self .reader .readexactly (data_len )
146+ except ( ConnectionError , asyncio . IncompleteReadError ) :
146147 await self .close ()
147148 raise
148- return data
149149
150150 async def receive_and_decrypt (self , data_len : int ) -> bytes :
151151 data = self .decrypt (await self .receive (data_len ))
152152 return data
153153
154154 async def listen (self ) -> None :
155- while True :
156- while not self .tasks :
157- await asyncio .sleep (self .delta )
155+ try :
156+ while True :
157+ while not self .tasks :
158+ await asyncio .sleep (self .delta )
158159
159- data_len_encrypted = await self .receive (4 )
160- data_len = int (self .decrypt (data_len_encrypted )[::- 1 ].hex (), 16 )
160+ data_len_encrypted = await self .receive (4 )
161+ data_len = int (self .decrypt (data_len_encrypted )[::- 1 ].hex (), 16 )
161162
162- self .logger .debug (msg = f'received { data_len // 8 } bytes of data' )
163+ self .logger .debug (msg = f'received { data_len // 8 } bytes of data' )
163164
164- data_encrypted = await self .receive (data_len )
165- data_decrypted = self .decrypt (data_encrypted )
166- # check hashsum
167- assert hashlib .sha256 (data_decrypted [:- 32 ]).digest () == data_decrypted [- 32 :], 'incorrect checksum'
168- result = self .deserialize_adnl_query (data_decrypted [:- 32 ])
165+ data_encrypted = await self .receive (data_len )
166+ data_decrypted = self .decrypt (data_encrypted )
167+ # check hashsum
168+ assert hashlib .sha256 (data_decrypted [:- 32 ]).digest () == data_decrypted [- 32 :], 'incorrect checksum'
169+ result = self .deserialize_adnl_query (data_decrypted [:- 32 ])
169170
170- if not result :
171- # for handshake
172- result = {}
171+ if not result :
172+ # for handshake
173+ result = {}
173174
174- qid = result .get ('query_id' , result .get ('random_id' )) # return query_id for ordinary requests, random_id for ping-pong requests, None for handshake
175+ qid = result .get ('query_id' , result .get ('random_id' )) # return query_id for ordinary requests, random_id for ping-pong requests, None for handshake
175176
176- request : asyncio .Future = self .tasks .pop (qid )
177+ request : asyncio .Future = self .tasks .pop (qid )
178+
179+ result = result .get ('answer' , {})
180+ if not request .done ():
181+ request .set_result (result )
182+ except asyncio .CancelledError :
183+ pass
184+ except (ConnectionResetError , asyncio .IncompleteReadError , ConnectionAbortedError , TimeoutError ):
185+ return
186+ except Exception as e :
187+ self .logger .exception ('listener crashed' )
188+ asyncio .create_task (self .close ())
189+ return
190+ finally :
191+ self ._cancel_all_tasks ()
177192
178- result = result .get ('answer' , {})
179- if not request .done ():
180- request .set_result (result )
193+ def _cancel_all_tasks (self ):
194+ if self .tasks :
195+ for fut in list (self .tasks .values ()):
196+ if fut and not fut .done ():
197+ fut : asyncio .Future
198+ # fut.cancel()
199+ fut .set_exception (LiteClientError ('Connection is closed' ))
200+ self .tasks .clear ()
181201
182202 async def connect (self ) -> None :
183203 if self .inited :
@@ -190,10 +210,10 @@ async def connect(self) -> None:
190210 )
191211 future = await asyncio .wait_for (self .send (handshake , None ), self .timeout )
192212 self .listener = asyncio .create_task (self .listen ())
213+ await asyncio .wait_for (future , self .timeout )
193214 await self .update_last_blocks ()
194215 self .pinger = asyncio .create_task (self .ping ())
195216 self .updater = asyncio .create_task (self .block_updater ())
196- await future
197217 self .inited = True
198218
199219 async def reconnect (self , max_retries : int = 5 , retry_delay : int = 2 ) -> None :
@@ -210,29 +230,41 @@ async def reconnect(self, max_retries: int = 5, retry_delay: int = 2) -> None:
210230 self .logger .info ("Successfully reconnected" )
211231 return # exit if connection succeeds
212232 except Exception as e :
213- self .logger .error (f"Reconnection attempt { attempt + 1 } /{ max_retries } failed: { str (e )} " )
233+ self .logger .error (f"Reconnection attempt { attempt + 1 } /{ max_retries } failed: { type ( e ) } : { str (e )} " )
214234 await asyncio .sleep (retry_delay )
215235 raise LiteClientError ('Failed to reconnect after several attempts' )
216236
217237 async def close (self ) -> None :
218- for task in [self .pinger , self .updater , self .listener ]:
219- if task is not None and not task .done ():
220- task .cancel ()
221- if task is not None :
222- try :
223- await task
224- except (asyncio .CancelledError , Exception ):
225- pass
226- self .inited = False
227- self .tasks = {}
228- self .reader = None
229- if self .writer :
230- self .writer .close ()
231- try :
232- await self .writer .wait_closed ()
233- except ConnectionError :
234- pass
235- self .writer = None
238+ current = asyncio .current_task ()
239+ if self ._closing :
240+ return
241+ self ._closing = True
242+ self ._cancel_all_tasks ()
243+ try :
244+ for task_name in ("pinger" , "updater" , "listener" ):
245+ task = getattr (self , task_name , None )
246+ if task is None :
247+ continue
248+ if task is not current and not task .done ():
249+ task .cancel ()
250+ if task is not current :
251+ try :
252+ await asyncio .wait_for (task , timeout = 1.0 )
253+ except (asyncio .CancelledError , Exception ):
254+ pass
255+ setattr (self , task_name , None )
256+
257+ self .inited = False
258+ self .tasks .clear ()
259+ self .reader = None
260+ w = self .writer
261+ self .writer = None
262+ with suppress (Exception ):
263+ w .close ()
264+ with suppress (Exception ):
265+ await asyncio .wait_for (w .wait_closed (), timeout = 1.0 )
266+ finally :
267+ self ._closing = False
236268 self .logger .info ('client has been closed' )
237269
238270 def handshake (self ) -> bytes :
0 commit comments