Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ Project structure expectations

Coding standards
- Add typing annotations to all functions and classes (including return types).
- Add or update docstrings (PEP 257) for functions and classes.
- Add or update docstrings for all files, classes and methods, including private methods. Method docstrings must be in NumPy format.
- Preserve existing comments and keep imports at the top of files.
- Follow existing repository style; run `pre-commit` and `ruff` where available.

Expand Down
70 changes: 68 additions & 2 deletions custom_components/opnsense/pyopnsense/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,52 @@ async def _safe_list_post(
result = await self._post(path=path, payload=payload)
return result if isinstance(result, list) else []

async def _get_check(self, path: str) -> bool:
"""Check if the given API path is accessible.

This method intentionally bypasses the request queue used by _get() and
_post() to provide fast, lightweight endpoint availability checks. This
is appropriate for plugin/service detection and initialization checks
where the 0.3s queue delay would harm user experience.

Parameters
----------
path : str
The API path to check for accessibility.

Returns
-------
bool
True if the path is accessible (HTTP 2xx success), False otherwise.

"""
# /api/<module>/<controller>/<command>/[<param1>/[<param2>/...]]
self._rest_api_query_count += 1
url: str = f"{self._url}{path}"
_LOGGER.debug("[get_check] url: %s", url)
try:
async with self._session.get(
url,
auth=aiohttp.BasicAuth(self._username, self._password),
timeout=aiohttp.ClientTimeout(total=DEFAULT_TIMEOUT),
ssl=self._verify_ssl,
) as response:
_LOGGER.debug("[get_check] Response %s: %s", response.status, response.reason)
if response.ok:
return True
if response.status == 403:
_LOGGER.error(
"Permission Error in get_check. Path: %s. Ensure the OPNsense user connected to HA has appropriate access. Recommend full admin access",
url,
)
return False
except aiohttp.ClientError as e:
_LOGGER.error("Client error. %s: %s", type(e).__name__, e)
if self._initial:
raise

return False

@_log_errors
async def _filter_configure(self) -> None:
script: str = r"""
Expand Down Expand Up @@ -1140,7 +1186,17 @@ async def _get_dnsmasq_leases(self) -> list:
return leases

async def _get_isc_dhcpv4_leases(self) -> list:
"""Return IPv4 DHCP Leases by ISC."""
"""Return IPv4 DHCP Leases by ISC.

Returns
-------
list
A list of dictionaries representing IPv4 DHCP leases.

"""
if not await self._get_check("/api/dhcpv4/service/status"):
_LOGGER.debug("ISC DHCPv4 plugin/service not available, skipping lease retrieval")
return []
if self._use_snake_case:
response = await self._safe_dict_get("/api/dhcpv4/leases/search_lease")
else:
Expand Down Expand Up @@ -1184,7 +1240,17 @@ async def _get_isc_dhcpv4_leases(self) -> list:
return leases

async def _get_isc_dhcpv6_leases(self) -> list:
"""Return IPv6 DHCP Leases by ISC."""
"""Return IPv6 DHCP Leases by ISC.

Returns
-------
list
A list of dictionaries representing IPv6 DHCP leases.

"""
if not await self._get_check("/api/dhcpv6/service/status"):
_LOGGER.debug("ISC DHCPv6 plugin/service not available, skipping lease retrieval")
return []
if self._use_snake_case:
response = await self._safe_dict_get("/api/dhcpv6/leases/search_lease")
else:
Expand Down
91 changes: 90 additions & 1 deletion tests/test_pyopnsense.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,73 @@ async def test_safe_dict_post_and_list_post(monkeypatch, make_client) -> None:
assert result_empty_list == []


@pytest.mark.asyncio
async def test_get_check(make_client) -> None:
"""Test _get_check method returns True for ok responses, False otherwise."""
session = MagicMock(spec=aiohttp.ClientSession)

# Fake response class for testing
class FakeResp:
def __init__(self, status=500, ok=False):
self.status = status
self.reason = "Test"
self.ok = ok
self.request_info = MagicMock()
self.history = []
self.headers = {}

async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc, tb):
return False

# Test successful response (ok=True)
session.get = lambda *a, **k: FakeResp(status=200, ok=True)
client = make_client(session=session)
result = await client._get_check("/api/test")
assert result is True

# Test failed response (ok=False)
session.get = lambda *a, **k: FakeResp(status=404, ok=False)
client = make_client(session=session)
result = await client._get_check("/api/test")
assert result is False

# Test 403 response specifically
session.get = lambda *a, **k: FakeResp(status=403, ok=False)
client = make_client(session=session)
result = await client._get_check("/api/test")
assert result is False


@pytest.mark.asyncio
@pytest.mark.parametrize("initial,should_raise", [(False, False), (True, True)])
async def test_get_check_handles_client_error(make_client, initial, should_raise) -> None:
"""Ensure _get_check handles aiohttp.ClientError correctly.

When client is not in initialization mode, the method should swallow the
ClientError and return False. When the client is in initialization mode
(used during setup), it should re-raise the exception.
"""
session = MagicMock(spec=aiohttp.ClientSession)

def _raise(*a, **k):
raise aiohttp.ClientError("boom")

session.get = _raise
client = make_client(session=session)
# simulate the initialization flag behavior
client._initial = initial

if should_raise:
with pytest.raises(aiohttp.ClientError):
await client._get_check("/api/test")
else:
result = await client._get_check("/api/test")
assert result is False


@pytest.mark.asyncio
async def test_get_ip_key_sorting(make_client) -> None:
"""Sort IP-like items using get_ip_key ordering."""
Expand Down Expand Up @@ -358,6 +425,26 @@ async def test_dhcp_leases_and_keep_latest_and_dnsmasq(make_client) -> None:
await client.async_close()


@pytest.mark.asyncio
async def test_isc_dhcp_service_not_running(make_client) -> None:
"""Test ISC DHCP lease methods return empty list when service is not running."""
session = MagicMock(spec=aiohttp.ClientSession)
client = make_client(session=session)
try:
# Mock _get_check to return False (service not running)
client._get_check = AsyncMock(return_value=False)

# Test DHCPv4
leases_v4 = await client._get_isc_dhcpv4_leases()
assert leases_v4 == []

# Test DHCPv6
leases_v6 = await client._get_isc_dhcpv6_leases()
assert leases_v6 == []
finally:
await client.async_close()


@pytest.mark.asyncio
async def test_carp_and_reboot_and_wol(make_client) -> None:
"""Verify CARP interface discovery and system control endpoints (reboot/halt/WOL)."""
Expand Down Expand Up @@ -1067,7 +1154,7 @@ def mkarg(pname: str):
for name, func in _inspect.getmembers(
pyopnsense.OPNsenseClient, predicate=_inspect.iscoroutinefunction
)
if name not in ("__init__",)
if name != "__init__"
]

for name, _func in coros:
Expand Down Expand Up @@ -2560,6 +2647,7 @@ async def test_get_isc_dhcpv4_and_v6_parsing() -> None:
# v4: ends present and in future
future_dt = (datetime.now() + timedelta(hours=1)).strftime("%Y/%m/%d %H:%M:%S")
client._use_snake_case = False
client._get_check = AsyncMock(return_value=True)
client._safe_dict_get = AsyncMock(
side_effect=[
{
Expand All @@ -2583,6 +2671,7 @@ async def test_get_isc_dhcpv4_and_v6_parsing() -> None:

# v6: ends missing -> field passed through
client._use_snake_case = True
client._get_check = AsyncMock(return_value=True)
client._safe_dict_get = AsyncMock(
return_value={
"rows": [
Expand Down
Loading