Skip to content

Commit 0291862

Browse files
authored
Merge pull request #58 from yungwine/safe_close
Safe close
2 parents ad7cfa1 + 6b8e915 commit 0291862

File tree

1 file changed

+75
-43
lines changed

1 file changed

+75
-43
lines changed

pytoniq/liteclient/client.py

Lines changed: 75 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
import asyncio
55
import socket
66
import struct
7-
import time
87
import typing
8+
from contextlib import suppress
99

1010
import requests
1111
from pytoniq_core import HashMap, Builder
@@ -68,6 +68,7 @@ def __init__(self,
6868
"""########### init ###########"""
6969
self.tasks = {}
7070
self.inited = False
71+
self._closing = False
7172
self.logger = logging.getLogger(self.__class__.__name__)
7273
self.timeout = timeout
7374

@@ -141,43 +142,62 @@ async def send_and_encrypt(self, data: bytes, qid: str) -> asyncio.Future:
141142

142143
async def receive(self, data_len: int) -> bytes:
143144
try:
144-
data = await self.reader.readexactly(data_len)
145-
except ConnectionError:
145+
return await self.reader.readexactly(data_len)
146+
except (ConnectionError, asyncio.IncompleteReadError):
146147
await self.close()
147148
raise
148-
return data
149149

150150
async def receive_and_decrypt(self, data_len: int) -> bytes:
151151
data = self.decrypt(await self.receive(data_len))
152152
return data
153153

154154
async def listen(self) -> None:
155-
while True:
156-
while not self.tasks:
157-
await asyncio.sleep(self.delta)
155+
try:
156+
while True:
157+
while not self.tasks:
158+
await asyncio.sleep(self.delta)
158159

159-
data_len_encrypted = await self.receive(4)
160-
data_len = int(self.decrypt(data_len_encrypted)[::-1].hex(), 16)
160+
data_len_encrypted = await self.receive(4)
161+
data_len = int(self.decrypt(data_len_encrypted)[::-1].hex(), 16)
161162

162-
self.logger.debug(msg=f'received {data_len // 8} bytes of data')
163+
self.logger.debug(msg=f'received {data_len // 8} bytes of data')
163164

164-
data_encrypted = await self.receive(data_len)
165-
data_decrypted = self.decrypt(data_encrypted)
166-
# check hashsum
167-
assert hashlib.sha256(data_decrypted[:-32]).digest() == data_decrypted[-32:], 'incorrect checksum'
168-
result = self.deserialize_adnl_query(data_decrypted[:-32])
165+
data_encrypted = await self.receive(data_len)
166+
data_decrypted = self.decrypt(data_encrypted)
167+
# check hashsum
168+
assert hashlib.sha256(data_decrypted[:-32]).digest() == data_decrypted[-32:], 'incorrect checksum'
169+
result = self.deserialize_adnl_query(data_decrypted[:-32])
169170

170-
if not result:
171-
# for handshake
172-
result = {}
171+
if not result:
172+
# for handshake
173+
result = {}
173174

174-
qid = result.get('query_id', result.get('random_id')) # return query_id for ordinary requests, random_id for ping-pong requests, None for handshake
175+
qid = result.get('query_id', result.get('random_id')) # return query_id for ordinary requests, random_id for ping-pong requests, None for handshake
175176

176-
request: asyncio.Future = self.tasks.pop(qid)
177+
request: asyncio.Future = self.tasks.pop(qid)
178+
179+
result = result.get('answer', {})
180+
if not request.done():
181+
request.set_result(result)
182+
except asyncio.CancelledError:
183+
pass
184+
except (ConnectionResetError, asyncio.IncompleteReadError, ConnectionAbortedError, TimeoutError):
185+
return
186+
except Exception as e:
187+
self.logger.exception('listener crashed')
188+
asyncio.create_task(self.close())
189+
return
190+
finally:
191+
self._cancel_all_tasks()
177192

178-
result = result.get('answer', {})
179-
if not request.done():
180-
request.set_result(result)
193+
def _cancel_all_tasks(self):
194+
if self.tasks:
195+
for fut in list(self.tasks.values()):
196+
if fut and not fut.done():
197+
fut: asyncio.Future
198+
# fut.cancel()
199+
fut.set_exception(LiteClientError('Connection is closed'))
200+
self.tasks.clear()
181201

182202
async def connect(self) -> None:
183203
if self.inited:
@@ -190,10 +210,10 @@ async def connect(self) -> None:
190210
)
191211
future = await asyncio.wait_for(self.send(handshake, None), self.timeout)
192212
self.listener = asyncio.create_task(self.listen())
213+
await asyncio.wait_for(future, self.timeout)
193214
await self.update_last_blocks()
194215
self.pinger = asyncio.create_task(self.ping())
195216
self.updater = asyncio.create_task(self.block_updater())
196-
await future
197217
self.inited = True
198218

199219
async def reconnect(self, max_retries: int = 5, retry_delay: int = 2) -> None:
@@ -210,29 +230,41 @@ async def reconnect(self, max_retries: int = 5, retry_delay: int = 2) -> None:
210230
self.logger.info("Successfully reconnected")
211231
return # exit if connection succeeds
212232
except Exception as e:
213-
self.logger.error(f"Reconnection attempt {attempt + 1}/{max_retries} failed: {str(e)}")
233+
self.logger.error(f"Reconnection attempt {attempt + 1}/{max_retries} failed: {type(e)}: {str(e)}")
214234
await asyncio.sleep(retry_delay)
215235
raise LiteClientError('Failed to reconnect after several attempts')
216236

217237
async def close(self) -> None:
218-
for task in [self.pinger, self.updater, self.listener]:
219-
if task is not None and not task.done():
220-
task.cancel()
221-
if task is not None:
222-
try:
223-
await task
224-
except (asyncio.CancelledError, Exception):
225-
pass
226-
self.inited = False
227-
self.tasks = {}
228-
self.reader = None
229-
if self.writer:
230-
self.writer.close()
231-
try:
232-
await self.writer.wait_closed()
233-
except ConnectionError:
234-
pass
235-
self.writer = None
238+
current = asyncio.current_task()
239+
if self._closing:
240+
return
241+
self._closing = True
242+
self._cancel_all_tasks()
243+
try:
244+
for task_name in ("pinger", "updater", "listener"):
245+
task = getattr(self, task_name, None)
246+
if task is None:
247+
continue
248+
if task is not current and not task.done():
249+
task.cancel()
250+
if task is not current:
251+
try:
252+
await asyncio.wait_for(task, timeout=1.0)
253+
except (asyncio.CancelledError, Exception):
254+
pass
255+
setattr(self, task_name, None)
256+
257+
self.inited = False
258+
self.tasks.clear()
259+
self.reader = None
260+
w = self.writer
261+
self.writer = None
262+
with suppress(Exception):
263+
w.close()
264+
with suppress(Exception):
265+
await asyncio.wait_for(w.wait_closed(), timeout=1.0)
266+
finally:
267+
self._closing = False
236268
self.logger.info('client has been closed')
237269

238270
def handshake(self) -> bytes:

0 commit comments

Comments
 (0)