Skip to content

Commit b8e3941

Browse files
authored
Implement a retry on disconnect during transaction (#75)
1 parent 6344786 commit b8e3941

File tree

8 files changed

+125
-59
lines changed

8 files changed

+125
-59
lines changed

switchbot/adv_parsers/bulb.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
def process_color_bulb(data: bytes, mfr_data: bytes | None) -> dict[str, bool | int]:
66
"""Process WoBulb services data."""
7+
assert mfr_data is not None
78
return {
89
"sequence_number": mfr_data[6],
910
"isOn": bool(mfr_data[7] & 0b10000000),

switchbot/adv_parsers/meter.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
"""Meter parser."""
22
from __future__ import annotations
33

4+
from typing import Any
45

5-
def process_wosensorth(data: bytes, mfr_data: bytes | None) -> dict[str, object]:
6+
7+
def process_wosensorth(data: bytes, mfr_data: bytes | None) -> dict[str, Any]:
68
"""Process woSensorTH/Temp sensor services data."""
79

810
_temp_sign = 1 if data[4] & 0b10000000 else -1

switchbot/adv_parsers/plug.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
def process_woplugmini(data: bytes, mfr_data: bytes | None) -> dict[str, bool | int]:
66
"""Process plug mini."""
7+
assert mfr_data is not None
78
return {
89
"switchMode": True,
910
"isOn": mfr_data[7] == 0x80,

switchbot/devices/bot.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,9 @@ async def turn_on(self) -> bool:
3838
return True
3939

4040
if result[0] == 5:
41-
_LOGGER.debug("Bot is in press mode and doesn't have on state")
41+
_LOGGER.debug(
42+
"%s: Bot is in press mode and doesn't have on state", self.name
43+
)
4244
return True
4345

4446
return False
@@ -50,7 +52,9 @@ async def turn_off(self) -> bool:
5052
return True
5153

5254
if result[0] == 5:
53-
_LOGGER.debug("Bot is in press mode and doesn't have off state")
55+
_LOGGER.debug(
56+
"%s: Bot is in press mode and doesn't have off state", self.name
57+
)
5458
return True
5559

5660
return False
@@ -62,7 +66,7 @@ async def hand_up(self) -> bool:
6266
return True
6367

6468
if result[0] == 5:
65-
_LOGGER.debug("Bot is in press mode")
69+
_LOGGER.debug("%s: Bot is in press mode", self.name)
6670
return True
6771

6872
return False
@@ -74,7 +78,7 @@ async def hand_down(self) -> bool:
7478
return True
7579

7680
if result[0] == 5:
77-
_LOGGER.debug("Bot is in press mode")
81+
_LOGGER.debug("%s: Bot is in press mode", self.name)
7882
return True
7983

8084
return False
@@ -86,7 +90,7 @@ async def press(self) -> bool:
8690
return True
8791

8892
if result[0] == 5:
89-
_LOGGER.debug("Bot is in switch mode")
93+
_LOGGER.debug("%s: Bot is in switch mode", self.name)
9094
return True
9195

9296
return False

switchbot/devices/curtain.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ async def get_extended_info_summary(self) -> dict[str, Any] | None:
112112
)
113113

114114
if _data in (b"\x07", b"\x00"):
115-
_LOGGER.error("Unsuccessfull, please try again")
115+
_LOGGER.error("%s: Unsuccessful, please try again", self.name)
116116
return None
117117

118118
self.ext_info_sum["device0"] = {
@@ -145,7 +145,7 @@ async def get_extended_info_adv(self) -> dict[str, Any] | None:
145145
)
146146

147147
if _data in (b"\x07", b"\x00"):
148-
_LOGGER.error("Unsuccessfull, please try again")
148+
_LOGGER.error("%s: Unsuccessful, please try again", self.name)
149149
return None
150150

151151
_state_of_charge = [

switchbot/devices/device.py

Lines changed: 105 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@
44
import asyncio
55
import binascii
66
import logging
7-
from typing import Any
7+
from ctypes import cast
8+
from typing import Any, Callable, TypeVar
89
from uuid import UUID
910

1011
import async_timeout
1112
import bleak
13+
from bleak import BleakError
1214
from bleak.backends.device import BLEDevice
13-
from bleak.backends.service import BleakGATTServiceCollection
15+
from bleak.backends.service import BleakGATTCharacteristic, BleakGATTServiceCollection
1416
from bleak_retry_connector import (
1517
BleakClientWithServiceCache,
1618
ble_device_has_changed,
@@ -31,6 +33,13 @@
3133
# Base key when encryption is set
3234
KEY_PASSWORD_PREFIX = "571"
3335

36+
BLEAK_EXCEPTIONS = (AttributeError, BleakError, asyncio.exceptions.TimeoutError)
37+
38+
# How long to hold the connection
39+
# to wait for additional commands for
40+
# disconnecting the device.
41+
DISCONNECT_DELAY = 59
42+
3443

3544
def _sb_uuid(comms_type: str = "service") -> UUID | str:
3645
"""Return Switchbot UUID."""
@@ -60,13 +69,19 @@ def __init__(
6069
self._scan_timeout: int = kwargs.pop("scan_timeout", DEFAULT_SCAN_TIMEOUT)
6170
self._retry_count: int = kwargs.pop("retry_count", DEFAULT_RETRY_COUNT)
6271
self._connect_lock = asyncio.Lock()
72+
self._operation_lock = asyncio.Lock()
6373
if password is None or password == "":
6474
self._password_encoded = None
6575
else:
6676
self._password_encoded = "%08x" % (
6777
binascii.crc32(password.encode("ascii")) & 0xFFFFFFFF
6878
)
79+
self._client: BleakClientWithServiceCache | None = None
6980
self._cached_services: BleakGATTServiceCollection | None = None
81+
self._read_char: BleakGATTCharacteristic | None = None
82+
self._write_char: BleakGATTCharacteristic | None = None
83+
self._disconnect_timer: asyncio.TimerHandle | None = None
84+
self.loop = asyncio.get_event_loop()
7085

7186
def _commandkey(self, key: str) -> str:
7287
"""Add password to key if set."""
@@ -79,21 +94,30 @@ def _commandkey(self, key: str) -> str:
7994
async def _sendcommand(self, key: str, retry: int) -> bytes:
8095
"""Send command to device and read response."""
8196
command = bytearray.fromhex(self._commandkey(key))
82-
_LOGGER.debug("Sending command to switchbot %s", command)
97+
_LOGGER.debug("%s: Sending command %s", self.name, command)
98+
if self._operation_lock.locked():
99+
_LOGGER.debug(
100+
"%s: Operation already in progress, waiting for it to complete.",
101+
self.name,
102+
)
103+
83104
max_attempts = retry + 1
84-
async with self._connect_lock:
105+
async with self._operation_lock:
85106
for attempt in range(max_attempts):
86107
try:
87108
return await self._send_command_locked(key, command)
88-
except (bleak.BleakError, asyncio.exceptions.TimeoutError):
109+
except BLEAK_EXCEPTIONS:
89110
if attempt == retry:
90111
_LOGGER.error(
91-
"Switchbot communication failed. Stopping trying",
112+
"%s: communication failed. Stopping trying",
113+
self.name,
92114
exc_info=True,
93115
)
94116
return b"\x00"
95117

96-
_LOGGER.debug("Switchbot communication failed with:", exc_info=True)
118+
_LOGGER.debug(
119+
"%s: communication failed with:", self.name, exc_info=True
120+
)
97121

98122
raise RuntimeError("Unreachable")
99123

@@ -102,49 +126,91 @@ def name(self) -> str:
102126
"""Return device name."""
103127
return f"{self._device.name} ({self._device.address})"
104128

105-
async def _send_command_locked(self, key: str, command: bytes) -> bytes:
106-
"""Send command to device and read response."""
107-
client: BleakClientWithServiceCache | None = None
108-
try:
109-
_LOGGER.debug("%s: Connnecting to switchbot", self.name)
129+
async def _ensure_connected(self):
130+
"""Ensure connection to device is established."""
131+
if self._connect_lock.locked():
132+
_LOGGER.debug(
133+
"%s: Connection already in progress, waiting for it to complete.",
134+
self.name,
135+
)
136+
if self._client and self._client.is_connected:
137+
self._reset_disconnect_timer()
138+
return
139+
async with self._connect_lock:
140+
# Check again while holding the lock
141+
if self._client and self._client.is_connected:
142+
self._reset_disconnect_timer()
143+
return
110144
client = await establish_connection(
111145
BleakClientWithServiceCache,
112146
self._device,
113147
self.name,
114-
max_attempts=1,
115148
cached_services=self._cached_services,
116149
)
117150
self._cached_services = client.services
118-
_LOGGER.debug(
119-
"%s: Connnected to switchbot: %s", self.name, client.is_connected
120-
)
121-
read_char = client.services.get_characteristic(_sb_uuid(comms_type="rx"))
122-
write_char = client.services.get_characteristic(_sb_uuid(comms_type="tx"))
123-
future: asyncio.Future[bytearray] = asyncio.Future()
151+
_LOGGER.debug("%s: Connected", self.name)
152+
services = client.services
153+
self._read_char = services.get_characteristic(_sb_uuid(comms_type="rx"))
154+
self._write_char = services.get_characteristic(_sb_uuid(comms_type="tx"))
155+
self._client = client
156+
self._reset_disconnect_timer()
157+
158+
def _reset_disconnect_timer(self):
159+
"""Reset disconnect timer."""
160+
if self._disconnect_timer:
161+
self._disconnect_timer.cancel()
162+
self._disconnect_timer = self.loop.call_later(
163+
DISCONNECT_DELAY, self._disconnect
164+
)
124165

125-
def _notification_handler(_sender: int, data: bytearray) -> None:
126-
"""Handle notification responses."""
127-
if future.done():
128-
_LOGGER.debug("%s: Notification handler already done", self.name)
129-
return
130-
future.set_result(data)
166+
def _disconnect(self):
167+
"""Disconnect from device."""
168+
self._disconnect_timer = None
169+
asyncio.create_task(self._execute_disconnect())
170+
171+
async def _execute_disconnect(self):
172+
"""Execute disconnection."""
173+
_LOGGER.debug(
174+
"%s: Disconnecting after timeout of %s",
175+
self.name,
176+
DISCONNECT_DELAY,
177+
)
178+
async with self._connect_lock:
179+
if not self._client or not self._client.is_connected:
180+
return
181+
await self._client.disconnect()
182+
self._client = None
183+
self._read_char = None
184+
self._write_char = None
185+
186+
async def _send_command_locked(self, key: str, command: bytes) -> bytes:
187+
"""Send command to device and read response."""
188+
await self._ensure_connected()
189+
assert self._client is not None
190+
assert self._read_char is not None
191+
assert self._write_char is not None
192+
future: asyncio.Future[bytearray] = asyncio.Future()
193+
client = self._client
131194

132-
_LOGGER.debug("%s: Subscribe to notifications", self.name)
133-
await client.start_notify(read_char, _notification_handler)
195+
def _notification_handler(_sender: int, data: bytearray) -> None:
196+
"""Handle notification responses."""
197+
if future.done():
198+
_LOGGER.debug("%s: Notification handler already done", self.name)
199+
return
200+
future.set_result(data)
134201

135-
_LOGGER.debug("%s: Sending command, %s", self.name, key)
136-
await client.write_gatt_char(write_char, command, False)
202+
_LOGGER.debug("%s: Subscribe to notifications", self.name)
203+
await client.start_notify(self._read_char, _notification_handler)
137204

138-
async with async_timeout.timeout(5):
139-
notify_msg = await future
140-
_LOGGER.info("%s: Notification received: %s", self.name, notify_msg)
205+
_LOGGER.debug("%s: Sending command: %s", self.name, key)
206+
await client.write_gatt_char(self._write_char, command, False)
141207

142-
_LOGGER.debug("%s: UnSubscribe to notifications", self.name)
143-
await client.stop_notify(read_char)
208+
async with async_timeout.timeout(5):
209+
notify_msg = await future
210+
_LOGGER.debug("%s: Notification received: %s", self.name, notify_msg)
144211

145-
finally:
146-
if client:
147-
await client.disconnect()
212+
_LOGGER.debug("%s: UnSubscribe to notifications", self.name)
213+
await client.stop_notify(self._read_char)
148214

149215
if notify_msg == b"\x07":
150216
_LOGGER.error("Password required")
@@ -175,7 +241,7 @@ def update_from_advertisement(self, advertisement: SwitchBotAdvertisement) -> No
175241

176242
async def get_device_data(
177243
self, retry: int = DEFAULT_RETRY_COUNT, interface: int | None = None
178-
) -> dict | None:
244+
) -> SwitchBotAdvertisement | None:
179245
"""Find switchbot devices and their advertisement data."""
180246
if interface:
181247
_interface: int = interface
@@ -191,7 +257,7 @@ async def get_device_data(
191257

192258
return self._sb_adv_data
193259

194-
async def _get_basic_info(self) -> dict | None:
260+
async def _get_basic_info(self) -> bytes | None:
195261
"""Return basic info of device."""
196262
_data = await self._sendcommand(
197263
key=DEVICE_GET_BASIC_SETTINGS_KEY, retry=self._retry_count

switchbot/devices/plug.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,6 @@
1313
class SwitchbotPlugMini(SwitchbotDevice):
1414
"""Representation of a Switchbot plug mini."""
1515

16-
def __init__(self, *args: Any, **kwargs: Any) -> None:
17-
"""Switchbot plug mini constructor."""
18-
super().__init__(*args, **kwargs)
19-
self._settings: dict[str, Any] = {}
20-
2116
async def update(self, interface: int | None = None) -> None:
2217
"""Update state of device."""
2318
await self.get_device_data(retry=self._retry_count, interface=interface)
@@ -35,7 +30,4 @@ async def turn_off(self) -> bool:
3530
def is_on(self) -> Any:
3631
"""Return switch state from cache."""
3732
# To get actual position call update() first.
38-
value = self._get_adv_value("isOn")
39-
if value is None:
40-
return None
41-
return value
33+
return self._get_adv_value("isOn")

switchbot/discovery.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,8 @@ async def get_device_data(
109109
await self.discover()
110110

111111
return {
112-
device: data
113-
for device, data in self._adv_data.items()
112+
device: adv
113+
for device, adv in self._adv_data.items()
114114
# MacOS uses UUIDs instead of MAC addresses
115-
if data.get("address") == address
115+
if adv.data.get("address") == address
116116
}

0 commit comments

Comments
 (0)