diff --git a/AGENTS.md b/AGENTS.md index 98a86eb4..92359381 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -26,7 +26,7 @@ Project structure expectations Coding standards - Add typing annotations to all functions and classes (including return types). -- Add or update docstrings (PEP 257) for functions and classes. +- Add or update docstrings for all files, classes and methods, including private methods. Method docstrings must be in NumPy format. - Preserve existing comments and keep imports at the top of files. - Follow existing repository style; run `pre-commit` and `ruff` where available. diff --git a/custom_components/opnsense/pyopnsense/__init__.py b/custom_components/opnsense/pyopnsense/__init__.py index a1159173..0684c4aa 100644 --- a/custom_components/opnsense/pyopnsense/__init__.py +++ b/custom_components/opnsense/pyopnsense/__init__.py @@ -626,6 +626,52 @@ async def _safe_list_post( result = await self._post(path=path, payload=payload) return result if isinstance(result, list) else [] + async def _get_check(self, path: str) -> bool: + """Check if the given API path is accessible. + + This method intentionally bypasses the request queue used by _get() and + _post() to provide fast, lightweight endpoint availability checks. This + is appropriate for plugin/service detection and initialization checks + where the 0.3s queue delay would harm user experience. + + Parameters + ---------- + path : str + The API path to check for accessibility. + + Returns + ------- + bool + True if the path is accessible (HTTP 2xx success), False otherwise. + + """ + # /api////[/[/...]] + self._rest_api_query_count += 1 + url: str = f"{self._url}{path}" + _LOGGER.debug("[get_check] url: %s", url) + try: + async with self._session.get( + url, + auth=aiohttp.BasicAuth(self._username, self._password), + timeout=aiohttp.ClientTimeout(total=DEFAULT_TIMEOUT), + ssl=self._verify_ssl, + ) as response: + _LOGGER.debug("[get_check] Response %s: %s", response.status, response.reason) + if response.ok: + return True + if response.status == 403: + _LOGGER.error( + "Permission Error in get_check. Path: %s. Ensure the OPNsense user connected to HA has appropriate access. Recommend full admin access", + url, + ) + return False + except aiohttp.ClientError as e: + _LOGGER.error("Client error. %s: %s", type(e).__name__, e) + if self._initial: + raise + + return False + @_log_errors async def _filter_configure(self) -> None: script: str = r""" @@ -1140,7 +1186,17 @@ async def _get_dnsmasq_leases(self) -> list: return leases async def _get_isc_dhcpv4_leases(self) -> list: - """Return IPv4 DHCP Leases by ISC.""" + """Return IPv4 DHCP Leases by ISC. + + Returns + ------- + list + A list of dictionaries representing IPv4 DHCP leases. + + """ + if not await self._get_check("/api/dhcpv4/service/status"): + _LOGGER.debug("ISC DHCPv4 plugin/service not available, skipping lease retrieval") + return [] if self._use_snake_case: response = await self._safe_dict_get("/api/dhcpv4/leases/search_lease") else: @@ -1184,7 +1240,17 @@ async def _get_isc_dhcpv4_leases(self) -> list: return leases async def _get_isc_dhcpv6_leases(self) -> list: - """Return IPv6 DHCP Leases by ISC.""" + """Return IPv6 DHCP Leases by ISC. + + Returns + ------- + list + A list of dictionaries representing IPv6 DHCP leases. + + """ + if not await self._get_check("/api/dhcpv6/service/status"): + _LOGGER.debug("ISC DHCPv6 plugin/service not available, skipping lease retrieval") + return [] if self._use_snake_case: response = await self._safe_dict_get("/api/dhcpv6/leases/search_lease") else: diff --git a/tests/test_pyopnsense.py b/tests/test_pyopnsense.py index 7fb7369b..1e79a804 100644 --- a/tests/test_pyopnsense.py +++ b/tests/test_pyopnsense.py @@ -146,6 +146,73 @@ async def test_safe_dict_post_and_list_post(monkeypatch, make_client) -> None: assert result_empty_list == [] +@pytest.mark.asyncio +async def test_get_check(make_client) -> None: + """Test _get_check method returns True for ok responses, False otherwise.""" + session = MagicMock(spec=aiohttp.ClientSession) + + # Fake response class for testing + class FakeResp: + def __init__(self, status=500, ok=False): + self.status = status + self.reason = "Test" + self.ok = ok + self.request_info = MagicMock() + self.history = [] + self.headers = {} + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + # Test successful response (ok=True) + session.get = lambda *a, **k: FakeResp(status=200, ok=True) + client = make_client(session=session) + result = await client._get_check("/api/test") + assert result is True + + # Test failed response (ok=False) + session.get = lambda *a, **k: FakeResp(status=404, ok=False) + client = make_client(session=session) + result = await client._get_check("/api/test") + assert result is False + + # Test 403 response specifically + session.get = lambda *a, **k: FakeResp(status=403, ok=False) + client = make_client(session=session) + result = await client._get_check("/api/test") + assert result is False + + +@pytest.mark.asyncio +@pytest.mark.parametrize("initial,should_raise", [(False, False), (True, True)]) +async def test_get_check_handles_client_error(make_client, initial, should_raise) -> None: + """Ensure _get_check handles aiohttp.ClientError correctly. + + When client is not in initialization mode, the method should swallow the + ClientError and return False. When the client is in initialization mode + (used during setup), it should re-raise the exception. + """ + session = MagicMock(spec=aiohttp.ClientSession) + + def _raise(*a, **k): + raise aiohttp.ClientError("boom") + + session.get = _raise + client = make_client(session=session) + # simulate the initialization flag behavior + client._initial = initial + + if should_raise: + with pytest.raises(aiohttp.ClientError): + await client._get_check("/api/test") + else: + result = await client._get_check("/api/test") + assert result is False + + @pytest.mark.asyncio async def test_get_ip_key_sorting(make_client) -> None: """Sort IP-like items using get_ip_key ordering.""" @@ -358,6 +425,26 @@ async def test_dhcp_leases_and_keep_latest_and_dnsmasq(make_client) -> None: await client.async_close() +@pytest.mark.asyncio +async def test_isc_dhcp_service_not_running(make_client) -> None: + """Test ISC DHCP lease methods return empty list when service is not running.""" + session = MagicMock(spec=aiohttp.ClientSession) + client = make_client(session=session) + try: + # Mock _get_check to return False (service not running) + client._get_check = AsyncMock(return_value=False) + + # Test DHCPv4 + leases_v4 = await client._get_isc_dhcpv4_leases() + assert leases_v4 == [] + + # Test DHCPv6 + leases_v6 = await client._get_isc_dhcpv6_leases() + assert leases_v6 == [] + finally: + await client.async_close() + + @pytest.mark.asyncio async def test_carp_and_reboot_and_wol(make_client) -> None: """Verify CARP interface discovery and system control endpoints (reboot/halt/WOL).""" @@ -1067,7 +1154,7 @@ def mkarg(pname: str): for name, func in _inspect.getmembers( pyopnsense.OPNsenseClient, predicate=_inspect.iscoroutinefunction ) - if name not in ("__init__",) + if name != "__init__" ] for name, _func in coros: @@ -2560,6 +2647,7 @@ async def test_get_isc_dhcpv4_and_v6_parsing() -> None: # v4: ends present and in future future_dt = (datetime.now() + timedelta(hours=1)).strftime("%Y/%m/%d %H:%M:%S") client._use_snake_case = False + client._get_check = AsyncMock(return_value=True) client._safe_dict_get = AsyncMock( side_effect=[ { @@ -2583,6 +2671,7 @@ async def test_get_isc_dhcpv4_and_v6_parsing() -> None: # v6: ends missing -> field passed through client._use_snake_case = True + client._get_check = AsyncMock(return_value=True) client._safe_dict_get = AsyncMock( return_value={ "rows": [