44import asyncio
55import binascii
66import logging
7- from typing import Any
7+ from ctypes import cast
8+ from typing import Any , Callable , TypeVar
89from uuid import UUID
910
1011import async_timeout
1112import bleak
13+ from bleak import BleakError
1214from bleak .backends .device import BLEDevice
13- from bleak .backends .service import BleakGATTServiceCollection
15+ from bleak .backends .service import BleakGATTCharacteristic , BleakGATTServiceCollection
1416from bleak_retry_connector import (
1517 BleakClientWithServiceCache ,
1618 ble_device_has_changed ,
3133# Base key when encryption is set
3234KEY_PASSWORD_PREFIX = "571"
3335
36+ BLEAK_EXCEPTIONS = (AttributeError , BleakError , asyncio .exceptions .TimeoutError )
37+
38+ # How long to hold the connection
39+ # to wait for additional commands for
40+ # disconnecting the device.
41+ DISCONNECT_DELAY = 59
42+
3443
3544def _sb_uuid (comms_type : str = "service" ) -> UUID | str :
3645 """Return Switchbot UUID."""
@@ -60,13 +69,19 @@ def __init__(
6069 self ._scan_timeout : int = kwargs .pop ("scan_timeout" , DEFAULT_SCAN_TIMEOUT )
6170 self ._retry_count : int = kwargs .pop ("retry_count" , DEFAULT_RETRY_COUNT )
6271 self ._connect_lock = asyncio .Lock ()
72+ self ._operation_lock = asyncio .Lock ()
6373 if password is None or password == "" :
6474 self ._password_encoded = None
6575 else :
6676 self ._password_encoded = "%08x" % (
6777 binascii .crc32 (password .encode ("ascii" )) & 0xFFFFFFFF
6878 )
79+ self ._client : BleakClientWithServiceCache | None = None
6980 self ._cached_services : BleakGATTServiceCollection | None = None
81+ self ._read_char : BleakGATTCharacteristic | None = None
82+ self ._write_char : BleakGATTCharacteristic | None = None
83+ self ._disconnect_timer : asyncio .TimerHandle | None = None
84+ self .loop = asyncio .get_event_loop ()
7085
7186 def _commandkey (self , key : str ) -> str :
7287 """Add password to key if set."""
@@ -79,21 +94,30 @@ def _commandkey(self, key: str) -> str:
7994 async def _sendcommand (self , key : str , retry : int ) -> bytes :
8095 """Send command to device and read response."""
8196 command = bytearray .fromhex (self ._commandkey (key ))
82- _LOGGER .debug ("Sending command to switchbot %s" , command )
97+ _LOGGER .debug ("%s: Sending command %s" , self .name , command )
98+ if self ._operation_lock .locked ():
99+ _LOGGER .debug (
100+ "%s: Operation already in progress, waiting for it to complete." ,
101+ self .name ,
102+ )
103+
83104 max_attempts = retry + 1
84- async with self ._connect_lock :
105+ async with self ._operation_lock :
85106 for attempt in range (max_attempts ):
86107 try :
87108 return await self ._send_command_locked (key , command )
88- except ( bleak . BleakError , asyncio . exceptions . TimeoutError ) :
109+ except BLEAK_EXCEPTIONS :
89110 if attempt == retry :
90111 _LOGGER .error (
91- "Switchbot communication failed. Stopping trying" ,
112+ "%s: communication failed. Stopping trying" ,
113+ self .name ,
92114 exc_info = True ,
93115 )
94116 return b"\x00 "
95117
96- _LOGGER .debug ("Switchbot communication failed with:" , exc_info = True )
118+ _LOGGER .debug (
119+ "%s: communication failed with:" , self .name , exc_info = True
120+ )
97121
98122 raise RuntimeError ("Unreachable" )
99123
@@ -102,49 +126,91 @@ def name(self) -> str:
102126 """Return device name."""
103127 return f"{ self ._device .name } ({ self ._device .address } )"
104128
105- async def _send_command_locked (self , key : str , command : bytes ) -> bytes :
106- """Send command to device and read response."""
107- client : BleakClientWithServiceCache | None = None
108- try :
109- _LOGGER .debug ("%s: Connnecting to switchbot" , self .name )
129+ async def _ensure_connected (self ):
130+ """Ensure connection to device is established."""
131+ if self ._connect_lock .locked ():
132+ _LOGGER .debug (
133+ "%s: Connection already in progress, waiting for it to complete." ,
134+ self .name ,
135+ )
136+ if self ._client and self ._client .is_connected :
137+ self ._reset_disconnect_timer ()
138+ return
139+ async with self ._connect_lock :
140+ # Check again while holding the lock
141+ if self ._client and self ._client .is_connected :
142+ self ._reset_disconnect_timer ()
143+ return
110144 client = await establish_connection (
111145 BleakClientWithServiceCache ,
112146 self ._device ,
113147 self .name ,
114- max_attempts = 1 ,
115148 cached_services = self ._cached_services ,
116149 )
117150 self ._cached_services = client .services
118- _LOGGER .debug (
119- "%s: Connnected to switchbot: %s" , self .name , client .is_connected
120- )
121- read_char = client .services .get_characteristic (_sb_uuid (comms_type = "rx" ))
122- write_char = client .services .get_characteristic (_sb_uuid (comms_type = "tx" ))
123- future : asyncio .Future [bytearray ] = asyncio .Future ()
151+ _LOGGER .debug ("%s: Connected" , self .name )
152+ services = client .services
153+ self ._read_char = services .get_characteristic (_sb_uuid (comms_type = "rx" ))
154+ self ._write_char = services .get_characteristic (_sb_uuid (comms_type = "tx" ))
155+ self ._client = client
156+ self ._reset_disconnect_timer ()
157+
158+ def _reset_disconnect_timer (self ):
159+ """Reset disconnect timer."""
160+ if self ._disconnect_timer :
161+ self ._disconnect_timer .cancel ()
162+ self ._disconnect_timer = self .loop .call_later (
163+ DISCONNECT_DELAY , self ._disconnect
164+ )
124165
125- def _notification_handler (_sender : int , data : bytearray ) -> None :
126- """Handle notification responses."""
127- if future .done ():
128- _LOGGER .debug ("%s: Notification handler already done" , self .name )
129- return
130- future .set_result (data )
166+ def _disconnect (self ):
167+ """Disconnect from device."""
168+ self ._disconnect_timer = None
169+ asyncio .create_task (self ._execute_disconnect ())
170+
171+ async def _execute_disconnect (self ):
172+ """Execute disconnection."""
173+ _LOGGER .debug (
174+ "%s: Disconnecting after timeout of %s" ,
175+ self .name ,
176+ DISCONNECT_DELAY ,
177+ )
178+ async with self ._connect_lock :
179+ if not self ._client or not self ._client .is_connected :
180+ return
181+ await self ._client .disconnect ()
182+ self ._client = None
183+ self ._read_char = None
184+ self ._write_char = None
185+
186+ async def _send_command_locked (self , key : str , command : bytes ) -> bytes :
187+ """Send command to device and read response."""
188+ await self ._ensure_connected ()
189+ assert self ._client is not None
190+ assert self ._read_char is not None
191+ assert self ._write_char is not None
192+ future : asyncio .Future [bytearray ] = asyncio .Future ()
193+ client = self ._client
131194
132- _LOGGER .debug ("%s: Subscribe to notifications" , self .name )
133- await client .start_notify (read_char , _notification_handler )
195+ def _notification_handler (_sender : int , data : bytearray ) -> None :
196+ """Handle notification responses."""
197+ if future .done ():
198+ _LOGGER .debug ("%s: Notification handler already done" , self .name )
199+ return
200+ future .set_result (data )
134201
135- _LOGGER .debug ("%s: Sending command, %s " , self .name , key )
136- await client .write_gatt_char ( write_char , command , False )
202+ _LOGGER .debug ("%s: Subscribe to notifications " , self .name )
203+ await client .start_notify ( self . _read_char , _notification_handler )
137204
138- async with async_timeout .timeout (5 ):
139- notify_msg = await future
140- _LOGGER .info ("%s: Notification received: %s" , self .name , notify_msg )
205+ _LOGGER .debug ("%s: Sending command: %s" , self .name , key )
206+ await client .write_gatt_char (self ._write_char , command , False )
141207
142- _LOGGER .debug ("%s: UnSubscribe to notifications" , self .name )
143- await client .stop_notify (read_char )
208+ async with async_timeout .timeout (5 ):
209+ notify_msg = await future
210+ _LOGGER .debug ("%s: Notification received: %s" , self .name , notify_msg )
144211
145- finally :
146- if client :
147- await client .disconnect ()
212+ _LOGGER .debug ("%s: UnSubscribe to notifications" , self .name )
213+ await client .stop_notify (self ._read_char )
148214
149215 if notify_msg == b"\x07 " :
150216 _LOGGER .error ("Password required" )
@@ -175,7 +241,7 @@ def update_from_advertisement(self, advertisement: SwitchBotAdvertisement) -> No
175241
176242 async def get_device_data (
177243 self , retry : int = DEFAULT_RETRY_COUNT , interface : int | None = None
178- ) -> dict | None :
244+ ) -> SwitchBotAdvertisement | None :
179245 """Find switchbot devices and their advertisement data."""
180246 if interface :
181247 _interface : int = interface
@@ -191,7 +257,7 @@ async def get_device_data(
191257
192258 return self ._sb_adv_data
193259
194- async def _get_basic_info (self ) -> dict | None :
260+ async def _get_basic_info (self ) -> bytes | None :
195261 """Return basic info of device."""
196262 _data = await self ._sendcommand (
197263 key = DEVICE_GET_BASIC_SETTINGS_KEY , retry = self ._retry_count
0 commit comments