Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ jobs:
with:
CODE_FOLDER: bellows
CACHE_VERSION: 2
PYTHON_VERSION_DEFAULT: 3.11.0
PRE_COMMIT_CACHE_PATH: ~/.cache/pre-commit
MINIMUM_COVERAGE_PERCENTAGE: 99
secrets:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
3 changes: 2 additions & 1 deletion bellows/ash.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,7 @@ def rstack_frame_received(self, frame: RStackFrame) -> None:

self._tx_seq = 0
self._rx_seq = 0
self._cancel_pending_data_frames(NcpFailure(code=frame.reset_code))
self._change_ack_timeout(T_RX_ACK_INIT)
self._ezsp_protocol.reset_received(frame.reset_code)

Expand All @@ -582,7 +583,7 @@ def error_frame_received(self, frame: ErrorFrame) -> None:
def _enter_failed_state(self, reset_code: t.NcpResetCode) -> None:
self._ncp_state = NcpState.FAILED
self._cancel_pending_data_frames(NcpFailure(code=reset_code))
self._ezsp_protocol.reset_received(reset_code)
self._ezsp_protocol.error_received(reset_code)

def _write_frame(
self,
Expand Down
24 changes: 19 additions & 5 deletions bellows/ezsp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@

from . import v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v16, v17

RESET_ATTEMPTS = 5

EZSP_LATEST = v17.EZSPv17.VERSION
LOGGER = logging.getLogger(__name__)
MTOR_MIN_INTERVAL = 60
Expand Down Expand Up @@ -130,12 +132,24 @@ async def connect(self, *, use_thread: bool = True) -> None:
assert self._gw is None
self._gw = await bellows.uart.connect(self._config, self, use_thread=use_thread)

try:
for attempt in range(RESET_ATTEMPTS):
self._protocol = v4.EZSPv4(self.handle_callback, self._gw)
await self.startup_reset()
except Exception:
await self.disconnect()
raise

try:
await self.startup_reset()
break
except Exception as exc:
if attempt + 1 < RESET_ATTEMPTS:
LOGGER.debug(
"EZSP startup/reset failed, retrying (%d/%d): %r",
attempt + 1,
RESET_ATTEMPTS,
exc,
)
continue

await self.disconnect()
raise

async def reset(self):
LOGGER.debug("Resetting EZSP")
Expand Down
13 changes: 7 additions & 6 deletions bellows/uart.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import bellows.types as t

LOGGER = logging.getLogger(__name__)
RESET_TIMEOUT = 5
RESET_TIMEOUT = 3


class Gateway(zigpy.serial.SerialProtocol):
Expand All @@ -33,21 +33,22 @@ def data_received(self, data):

def reset_received(self, code: t.NcpResetCode) -> None:
"""Reset acknowledgement frame receive handler"""
# not a reset we've requested. Signal api reset
if code is not t.NcpResetCode.RESET_SOFTWARE:
self._api.enter_failed_state(code)
return
LOGGER.debug("Received reset: %r", code)

if self._reset_future and not self._reset_future.done():
self._reset_future.set_result(True)
elif self._startup_reset_future and not self._startup_reset_future.done():
self._startup_reset_future.set_result(True)
else:
self._api.enter_failed_state(code)
LOGGER.warning("Received an unexpected reset: %r", code)

def error_received(self, code: t.NcpResetCode) -> None:
"""Error frame receive handler."""
self._api.enter_failed_state(code)
if self._reset_future is not None or self._startup_reset_future is not None:
LOGGER.debug("Ignoring spurious error during reset: %r", code)
else:
self._api.enter_failed_state(code)

async def wait_for_startup_reset(self) -> None:
"""Wait for the first reset frame on startup."""
Expand Down
37 changes: 37 additions & 0 deletions tests/test_ash.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,43 @@ async def test_ash_end_to_end(transport_cls: type[FakeTransport]) -> None:
await host.send_data(b"ncp NAKing until failure")


async def test_rstack_cancels_pending_frames() -> None:
    """Check that an RSTACK frame fails data frames still awaiting an ACK."""
    reset_code = t.NcpResetCode.RESET_POWER_ON

    # Build a host and a fake NCP, then cross-wire them via fake transports.
    host = ash.AshProtocol(MagicMock())
    ncp = AshNcpProtocol(MagicMock())

    link_to_ncp = FakeTransport(ncp)
    link_to_host = FakeTransport(host)
    host.connection_made(link_to_ncp)
    ncp.connection_made(link_to_host)

    # With the NCP-side transport paused no ACK can travel back, so whatever
    # the host sends stays pending.
    link_to_host.paused = True

    # Run the send in the background and give it time to register its frame.
    pending_send = asyncio.create_task(host.send_data(b"test data"))
    await asyncio.sleep(0.1)
    assert len(host._pending_data_frames) == 1

    # Hand the RSTACK straight to the host protocol.
    host.rstack_frame_received(ash.RStackFrame(version=2, reset_code=reset_code))

    # The background send must now fail with an NcpFailure that carries the
    # reset code from the RSTACK frame.
    with pytest.raises(ash.NcpFailure) as raised:
        await pending_send

    assert raised.value.code == reset_code


def test_ncp_failure_comparison() -> None:
exc1 = ash.NcpFailure(code=t.NcpResetCode.ERROR_EXCEEDED_MAXIMUM_ACK_TIMEOUT_COUNT)
exc2 = ash.NcpFailure(code=t.NcpResetCode.RESET_POWER_ON)
Expand Down
26 changes: 24 additions & 2 deletions tests/test_ezsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,11 +299,33 @@ async def test_ezsp_connect_failure(disconnect_mock, reset_mock, version_mock):
await ezsp.connect()

assert conn_mock.await_count == 1
assert reset_mock.await_count == 1
assert version_mock.await_count == 1
assert reset_mock.await_count == 5
assert version_mock.await_count == 5
assert disconnect_mock.call_count == 1


@pytest.mark.parametrize("failures_before_success", [1, 2, 3, 4])
@patch.object(EZSP, "disconnect", new_callable=AsyncMock)
async def test_ezsp_connect_retry_success(disconnect_mock, failures_before_success):
    """Check that `connect()` succeeds once `startup_reset` stops failing."""
    call_count = 0

    async def startup_reset_mock():
        """Fail for the first `failures_before_success` calls, then succeed."""
        nonlocal call_count
        call_count += 1

        if call_count > failures_before_success:
            return

        raise RuntimeError(f"Startup failed (attempt {call_count})")

    with patch("bellows.uart.connect"):
        ezsp = make_ezsp(version=4)

        with patch.object(ezsp, "startup_reset", side_effect=startup_reset_mock):
            await ezsp.connect()

    # One call per failure plus the final successful one, and no disconnect
    # along the way.
    assert call_count == failures_before_success + 1
    assert disconnect_mock.call_count == 0


async def test_ezsp_newer_version(ezsp_f):
"""Test newer version of ezsp."""
with patch.object(
Expand Down
36 changes: 31 additions & 5 deletions tests/test_uart.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,18 @@ def on_transport_close():
assert len(threads) == 0


async def test_wait_for_startup_reset(gw):
@pytest.mark.parametrize(
"reset_code",
[
t.NcpResetCode.RESET_SOFTWARE,
t.NcpResetCode.RESET_POWER_ON,
t.NcpResetCode.RESET_WATCHDOG,
t.NcpResetCode.RESET_EXTERNAL,
],
)
async def test_wait_for_startup_reset(gw, reset_code):
loop = asyncio.get_running_loop()
loop.call_later(0.01, gw.reset_received, t.NcpResetCode.RESET_SOFTWARE)
loop.call_later(0.01, gw.reset_received, reset_code)

assert gw._startup_reset_future is None
await gw.wait_for_startup_reset()
Expand All @@ -239,8 +248,25 @@ async def test_callbacks(gw):
]


def test_reset_propagation(gw):
gw.reset_received(t.NcpResetCode.ERROR_EXCEEDED_MAXIMUM_ACK_TIMEOUT_COUNT)
async def test_error_received_during_reset_ignored(gw):
    """Check that an error frame arriving mid-reset does not fail the API."""
    # An outstanding reset future stands in for a reset that is in progress.
    pending_reset = asyncio.get_running_loop().create_future()
    gw._reset_future = pending_reset

    # The error must be swallowed rather than forwarded as a failed state.
    gw.error_received(t.NcpResetCode.ERROR_EXCEEDED_MAXIMUM_ACK_TIMEOUT_COUNT)
    assert gw._api.enter_failed_state.call_count == 0

    # Don't leave the future pending once the test is over.
    pending_reset.cancel()


def test_unexpected_reset_triggers_failed_state(gw):
# When no reset is expected, any reset should trigger failed state
assert gw._reset_future is None
assert gw._startup_reset_future is None

gw.reset_received(t.NcpResetCode.RESET_SOFTWARE)
assert gw._api.enter_failed_state.mock_calls == [
call(t.NcpResetCode.ERROR_EXCEEDED_MAXIMUM_ACK_TIMEOUT_COUNT)
call(t.NcpResetCode.RESET_SOFTWARE)
]
Loading