From f994f01ae224b7df6f88aef2d0edbc990da6ce48 Mon Sep 17 00:00:00 2001 From: Snuffy2 Date: Sun, 1 Feb 2026 17:44:14 -0500 Subject: [PATCH 01/15] Move existing firewall to legacy --- custom_components/opnsense/coordinator.py | 2 +- .../opnsense/pyopnsense/__init__.py | 190 ++++++++++++++- custom_components/opnsense/switch.py | 217 ++++++++++-------- tests/test_pyopnsense.py | 10 +- tests/test_switch.py | 8 +- 5 files changed, 310 insertions(+), 117 deletions(-) diff --git a/custom_components/opnsense/coordinator.py b/custom_components/opnsense/coordinator.py index 61a2710e..61652c7f 100644 --- a/custom_components/opnsense/coordinator.py +++ b/custom_components/opnsense/coordinator.py @@ -150,7 +150,7 @@ def _build_categories(self) -> list[MutableMapping[str, str]]: if config.get(CONF_SYNC_NOTICES, DEFAULT_SYNC_OPTION_VALUE): categories.append({"function": "get_notices", "state_key": "notices"}) if config.get(CONF_SYNC_FILTERS_AND_NAT, DEFAULT_SYNC_OPTION_VALUE): - categories.append({"function": "get_config", "state_key": "config"}) + categories.append({"function": "get_firewall", "state_key": "firewall"}) if config.get(CONF_SYNC_UNBOUND, DEFAULT_SYNC_OPTION_VALUE): categories.append( { diff --git a/custom_components/opnsense/pyopnsense/__init__.py b/custom_components/opnsense/pyopnsense/__init__.py index a1159173..208a3027 100644 --- a/custom_components/opnsense/pyopnsense/__init__.py +++ b/custom_components/opnsense/pyopnsense/__init__.py @@ -367,6 +367,15 @@ async def _get(self, path: str) -> MutableMapping[str, Any] | list | None: await self._request_queue.put(("get", path, None, future, caller)) return await future + async def _get_raw(self, path: str) -> str | None: + try: + caller = inspect.stack()[1].function + except (IndexError, AttributeError): + caller = "Unknown" + future = self._loop.create_future() + await self._request_queue.put(("get_raw", path, None, future, caller)) + return await future + async def _post( self, path: str, payload: MutableMapping[str, Any] | None = None ) -> MutableMapping[str, Any] | list | None: @@ -383,15 +392,17 @@ async def _process_queue(self) -> None: method, path, payload, future, caller = await self._request_queue.get() try: if method == "get_from_stream": - result: MutableMapping[str, Any] | list | None = await self._do_get_from_stream( - path, caller - ) + result: Any = await self._do_get_from_stream(path, caller) if future is not None and not future.done(): future.set_result(result) elif method == "get": result = await self._do_get(path, caller) if future is not None and not future.done(): future.set_result(result) + elif method == "get_raw": + result = await self._do_get_raw(path, caller) + if future is not None and not future.done(): + future.set_result(result) elif method == "post": result = await self._do_post(path, payload, caller) if future is not None and not future.done(): @@ -552,6 +563,50 @@ async def _do_get( return None + async def _do_get_raw(self, path: str, caller: str = "Unknown") -> str | None: + # /api////[/[/...]] + self._rest_api_query_count += 1 + url: str = f"{self._url}{path}" + _LOGGER.debug("[get_raw] url: %s", url) + try: + async with self._session.get( + url, + auth=aiohttp.BasicAuth(self._username, self._password), + timeout=aiohttp.ClientTimeout(total=DEFAULT_TIMEOUT), + ssl=self._verify_ssl, + ) as response: + _LOGGER.debug("[get_raw] Response %s: %s", response.status, response.reason) + if response.ok: + return await response.text() + if response.status == 403: + _LOGGER.error( + "Permission Error in do_get_raw (called by %s). Path: %s. Ensure the OPNsense user connected to HA has appropriate access. Recommend full admin access", + caller, + url, + ) + else: + _LOGGER.error( + "Error in do_get_raw (called by %s). Path: %s. Response %s: %s", + caller, + url, + response.status, + response.reason, + ) + if self._initial: + raise aiohttp.ClientResponseError( + request_info=response.request_info, + history=response.history, + status=response.status, + message=f"HTTP Status Error: {response.status} {response.reason}", + headers=response.headers, + ) + except aiohttp.ClientError as e: + _LOGGER.error("Client error. %s: %s", type(e).__name__, e) + if self._initial: + raise + + return None + async def _safe_dict_get(self, path: str) -> MutableMapping[str, Any]: """Fetch data from the given path, ensuring the result is a dict.""" result = await self._get(path=path) @@ -808,7 +863,7 @@ async def disable_filter_rule_by_created_time(self, created_time: str) -> None: # use created_time as a unique_id since none other exists @_log_errors - async def enable_nat_port_forward_rule_by_created_time(self, created_time: str) -> None: + async def enable_nat_port_forward_rule_by_created_time_legacy(self, created_time: str) -> None: """Enable a NAT Port Forward rule.""" config: MutableMapping[str, Any] = await self.get_config() for rule in config.get("nat", {}).get("rule", []): @@ -826,7 +881,7 @@ async def enable_nat_port_forward_rule_by_created_time(self, created_time: str) # use created_time as a unique_id since none other exists @_log_errors - async def disable_nat_port_forward_rule_by_created_time(self, created_time: str) -> None: + async def disable_nat_port_forward_rule_by_created_time_legacy(self, created_time: str) -> None: """Disable a NAT Port Forward rule.""" config: MutableMapping[str, Any] = await self.get_config() for rule in config.get("nat", {}).get("rule", []): @@ -844,7 +899,7 @@ async def disable_nat_port_forward_rule_by_created_time(self, created_time: str) # use created_time as a unique_id since none other exists @_log_errors - async def enable_nat_outbound_rule_by_created_time(self, created_time: str) -> None: + async def enable_nat_outbound_rule_by_created_time_legacy(self, created_time: str) -> None: """Enable NAT Outbound rule.""" config: MutableMapping[str, Any] = await self.get_config() for rule in config.get("nat", {}).get("outbound", {}).get("rule", []): @@ -862,7 +917,7 @@ async def enable_nat_outbound_rule_by_created_time(self, created_time: str) -> N # use created_time as a unique_id since none other exists @_log_errors - async def disable_nat_outbound_rule_by_created_time(self, created_time: str) -> None: + async def disable_nat_outbound_rule_by_created_time_legacy(self, created_time: str) -> None: """Disable NAT Outbound Rule.""" config: MutableMapping[str, Any] = await self.get_config() for rule in config.get("nat", {}).get("outbound", {}).get("rule", []): @@ -874,6 +929,127 @@ async def disable_nat_outbound_rule_by_created_time(self, created_time: str) -> await self._restore_config_section("nat", config["nat"]) await self._filter_configure() + ##################### + @_log_errors + async def get_firewall(self) -> dict[str, Any]: + """Retrieve all firewall and NAT rules from OPNsense. + + Returns + ------- + dict + A dictionary representing the firewall and NAT rules. + + """ + if self._firmware_version is None: + await self.get_host_firmware_version() + + try: + if awesomeversion.AwesomeVersion( + self._firmware_version + ) < awesomeversion.AwesomeVersion("26.1"): + _LOGGER.debug("Using legacy plugin for firewall filters for OPNsense < 26.1") + return {"config": await self.get_config()} + except awesomeversion.exceptions.AwesomeVersionCompareException: + _LOGGER.warning("Error comparing firmware version. Skipping get_firewall.") + return {} + firewall: dict[str, Any] = {"nat": {}} + if await self.is_plugin_installed(): + firewall["config"] = await self.get_config() + firewall["rules"] = await self.get_firewall_rules() + firewall["nat"]["destination_rules"] = await self.get_firewall_nat_destination_rules() + firewall["nat"]["one_to_one_rules"] = await self.get_firewall_nat_one_to_one_rules() + firewall["nat"]["source_rules"] = await self.get_firewall_nat_source_rules() + _LOGGER.debug("[get_firewall] firewall: %s", firewall) + return firewall + + @_log_errors + async def get_firewall_rules(self) -> list: + """Retrieve firewall rules from OPNsense. + + Returns + ------- + list of dict + A list of dictionaries, each representing a firewall rule parsed + from CSV format. Dictionary keys correspond to CSV headers such + as '@uuid', 'enabled', 'action', etc. + + """ + response = await self._get_raw("/api/firewall/filter/download_rules") + _LOGGER.debug("[get_firewall_rules] response: %s", response) + if not response or not isinstance(response, str): + return [] + lines = response.strip().split("\n") + if len(lines) < 2: + return [] + headers = lines[0].split(",") + rules = [] + for line in lines[1:]: + if line.strip(): + values = line.split(",") + rule = dict(zip(headers, values, strict=True)) + rules.append(rule) + _LOGGER.debug("[get_firewall_rules] rules: %s", rules) + return rules + + @_log_errors + async def get_firewall_nat_destination_rules(self) -> list: + """Retrieve NAT destination rules from OPNsense. + + Returns + ------- + list of dict + A list of dictionaries representing NAT destination rules. + + """ + request_body: MutableMapping[str, Any] = {"current": 1, "sort": {}} + response = await self._safe_dict_post( + "/api/firewall/d_nat/search_rule", payload=request_body + ) + _LOGGER.debug("[get_firewall_nat_destination_rules] response: %s", response) + rules: list = response.get("rows", []) + _LOGGER.debug("[get_firewall_nat_destination_rules] rules: %s", rules) + return rules + + @_log_errors + async def get_firewall_nat_one_to_one_rules(self) -> list: + """Retrieve NAT one-to-one rules from OPNsense. + + Returns + ------- + list of dict + A list of dictionaries representing NAT one-to-one rules. + + """ + request_body: MutableMapping[str, Any] = {"current": 1, "sort": {}} + response = await self._safe_dict_post( + "/api/firewall/one_to_one/search_rule", payload=request_body + ) + _LOGGER.debug("[get_firewall_nat_one_to_one_rules] response: %s", response) + rules: list = response.get("rows", []) + _LOGGER.debug("[get_firewall_nat_one_to_one_rules] rules: %s", rules) + return rules + + @_log_errors + async def get_firewall_nat_source_rules(self) -> list: + """Retrieve NAT source rules from OPNsense. + + Returns + ------- + list of dict + A list of dictionaries representing NAT source rules. + + """ + request_body: MutableMapping[str, Any] = {"current": 1, "sort": {}} + response = await self._safe_dict_post( + "/api/firewall/source_nat/search_rule", payload=request_body + ) + _LOGGER.debug("[get_firewall_nat_source_rules] response: %s", response) + rules: list = response.get("rows", []) + _LOGGER.debug("[get_firewall_nat_source_rules] rules: %s", rules) + return rules + + ##################### + @_log_errors async def get_arp_table(self, resolve_hostnames: bool = False) -> list: """Return the active ARP table.""" diff --git a/custom_components/opnsense/switch.py b/custom_components/opnsense/switch.py index fbe0b547..a8b968b8 100644 --- a/custom_components/opnsense/switch.py +++ b/custom_components/opnsense/switch.py @@ -30,128 +30,135 @@ _LOGGER: logging.Logger = logging.getLogger(__name__) -async def _compile_filter_switches( +async def _compile_filter_switches_legacy( config_entry: ConfigEntry, coordinator: OPNsenseDataUpdateCoordinator, state: MutableMapping[str, Any], ) -> list: - if not isinstance(state, MutableMapping) or not isinstance(state.get("config"), MutableMapping): + if not isinstance(state, MutableMapping) or not isinstance( + state.get("firewall", {}).get("config", {}).get("filter", {}).get("rule"), list + ): return [] entities: list = [] # filter rules - if "filter" in state.get("config", {}): - rules = dict_get(state, "config.filter.rule") - if isinstance(rules, list): - for rule in rules: - if not isinstance(rule, dict): - continue + rules = dict_get(state, "firewall.config.filter.rule") + if isinstance(rules, list): + for rule in rules: + if not isinstance(rule, dict): + continue - # do NOT add rules that are NAT rules - if "associated-rule-id" in rule: - continue + # do NOT add rules that are NAT rules + if "associated-rule-id" in rule: + continue - # not possible to disable these rules - if rule.get("descr", "") == "Anti-Lockout Rule": - continue + # not possible to disable these rules + if rule.get("descr", "") == "Anti-Lockout Rule": + continue - tracker = dict_get(rule, "created.time") - # we use tracker as the unique id - if tracker is None or len(tracker) < 1: - continue + tracker = dict_get(rule, "created.time") + # we use tracker as the unique id + if tracker is None or len(tracker) < 1: + continue - entities.append( - OPNsenseFilterSwitch( - config_entry=config_entry, - coordinator=coordinator, - entity_description=SwitchEntityDescription( - key=f"filter.{tracker}", - name=f"Filter Rule {tracker} ({rule.get('descr', '')})", - icon="mdi:play-network-outline", - # entity_category=entity_category, - device_class=SwitchDeviceClass.SWITCH, - entity_registry_enabled_default=False, - ), - ) + entities.append( + OPNsenseFilterSwitchLegacy( + config_entry=config_entry, + coordinator=coordinator, + entity_description=SwitchEntityDescription( + key=f"filter.{tracker}", + name=f"Filter Rule {tracker} ({rule.get('descr', '')})", + icon="mdi:play-network-outline", + # entity_category=entity_category, + device_class=SwitchDeviceClass.SWITCH, + entity_registry_enabled_default=False, + ), ) + ) + _LOGGER.debug("[compile_filter_switches_legacy] entities: %s", len(entities)) return entities -async def _compile_port_forward_switches( +async def _compile_port_forward_switches_legacy( config_entry: ConfigEntry, coordinator: OPNsenseDataUpdateCoordinator, state: MutableMapping[str, Any], ) -> list: - if not isinstance(state, MutableMapping) or not isinstance(state.get("config"), MutableMapping): + if not isinstance(state, MutableMapping) or not isinstance( + state.get("firewall", {}).get("config", {}).get("nat", {}).get("rule"), list + ): return [] entities: list = [] # nat port forward rules - if "nat" in state.get("config", {}): - rules = dict_get(state, "config.nat.rule") - if isinstance(rules, list): - for rule in rules: - if not isinstance(rule, dict): - continue + rules = dict_get(state, "firewall.config.nat.rule") + if isinstance(rules, list): + for rule in rules: + if not isinstance(rule, dict): + continue - tracker = dict_get(rule, "created.time") - # we use tracker as the unique id - if tracker is None or len(tracker) < 1: - continue + tracker = dict_get(rule, "created.time") + # we use tracker as the unique id + if tracker is None or len(tracker) < 1: + continue - entity = OPNsenseNatSwitch( - config_entry=config_entry, - coordinator=coordinator, - entity_description=SwitchEntityDescription( - key=f"nat_port_forward.{tracker}", - name=f"NAT Port Forward Rule {tracker} ({rule.get('descr', '')})", - icon="mdi:network-outline", - # entity_category=ENTITY_CATEGORY_CONFIG, - device_class=SwitchDeviceClass.SWITCH, - entity_registry_enabled_default=False, - ), - ) - entities.append(entity) + entity = OPNsenseNatSwitchLegacy( + config_entry=config_entry, + coordinator=coordinator, + entity_description=SwitchEntityDescription( + key=f"nat_port_forward.{tracker}", + name=f"NAT Port Forward Rule {tracker} ({rule.get('descr', '')})", + icon="mdi:network-outline", + # entity_category=ENTITY_CATEGORY_CONFIG, + device_class=SwitchDeviceClass.SWITCH, + entity_registry_enabled_default=False, + ), + ) + entities.append(entity) + _LOGGER.debug("[compile_port_forward_switches_legacy] entities: %s", len(entities)) return entities -async def _compile_nat_outbound_switches( +async def _compile_nat_outbound_switches_legacy( config_entry: ConfigEntry, coordinator: OPNsenseDataUpdateCoordinator, state: MutableMapping[str, Any], ) -> list: - if not isinstance(state, MutableMapping) or not isinstance(state.get("config"), MutableMapping): + if not isinstance(state, MutableMapping) or not isinstance( + state.get("firewall", {}).get("config", {}).get("nat", {}).get("outbound", {}).get("rule"), + list, + ): return [] entities: list = [] # nat outbound rules - if "nat" in state.get("config", {}): - # to actually be applicable mode must by "hybrid" or "advanced" - rules = dict_get(state, "config.nat.outbound.rule") - if isinstance(rules, list): - for rule in rules: - if not isinstance(rule, dict): - continue + # to actually be applicable, mode must by "hybrid" or "advanced" + rules = dict_get(state, "firewall.config.nat.outbound.rule") + if isinstance(rules, list): + for rule in rules: + if not isinstance(rule, dict): + continue - tracker = dict_get(rule, "created.time") - # we use tracker as the unique id - if tracker is None or len(tracker) < 1: - continue + tracker = dict_get(rule, "created.time") + # we use tracker as the unique id + if tracker is None or len(tracker) < 1: + continue - if "Auto created rule" in rule.get("descr", ""): - continue + if "Auto created rule" in rule.get("descr", ""): + continue - entity = OPNsenseNatSwitch( - config_entry=config_entry, - coordinator=coordinator, - entity_description=SwitchEntityDescription( - key=f"nat_outbound.{tracker}", - name=f"NAT Outbound Rule {tracker} ({rule.get('descr', '')})", - icon="mdi:network-outline", - # entity_category=ENTITY_CATEGORY_CONFIG, - device_class=SwitchDeviceClass.SWITCH, - entity_registry_enabled_default=False, - ), - ) - entities.append(entity) + entity = OPNsenseNatSwitchLegacy( + config_entry=config_entry, + coordinator=coordinator, + entity_description=SwitchEntityDescription( + key=f"nat_outbound.{tracker}", + name=f"NAT Outbound Rule {tracker} ({rule.get('descr', '')})", + icon="mdi:network-outline", + # entity_category=ENTITY_CATEGORY_CONFIG, + device_class=SwitchDeviceClass.SWITCH, + entity_registry_enabled_default=False, + ), + ) + entities.append(entity) + _LOGGER.debug("[compile_nat_outbound_switches_legacy] entities: %s", len(entities)) return entities @@ -286,9 +293,13 @@ async def async_setup_entry( entities: list = [] if config.get(CONF_SYNC_FILTERS_AND_NAT, DEFAULT_SYNC_OPTION_VALUE): - entities.extend(await _compile_filter_switches(config_entry, coordinator, state)) - entities.extend(await _compile_port_forward_switches(config_entry, coordinator, state)) - entities.extend(await _compile_nat_outbound_switches(config_entry, coordinator, state)) + entities.extend(await _compile_filter_switches_legacy(config_entry, coordinator, state)) + entities.extend( + await _compile_port_forward_switches_legacy(config_entry, coordinator, state) + ) + entities.extend( + await _compile_nat_outbound_switches_legacy(config_entry, coordinator, state) + ) if config.get(CONF_SYNC_SERVICES, DEFAULT_SYNC_OPTION_VALUE): entities.extend(await _compile_service_switches(config_entry, coordinator, state)) if config.get(CONF_SYNC_VPN, DEFAULT_SYNC_OPTION_VALUE): @@ -384,7 +395,7 @@ def _clear(_: Any) -> None: ) -class OPNsenseFilterSwitch(OPNsenseSwitch): +class OPNsenseFilterSwitchLegacy(OPNsenseSwitch): """Class for OPNsense Filter Switch entities.""" def __init__( @@ -401,7 +412,7 @@ def __init__( ) self._tracker: str = self._opnsense_get_tracker() self._rule: MutableMapping[str, Any] | None = None - # _LOGGER.debug(f"[OPNsenseFilterSwitch init] Name: {self.name}, tracker: {self._tracker}") + # _LOGGER.debug(f"[OPNsenseFilterSwitchLegacy init] Name: {self.name}, tracker: {self._tracker}") def _opnsense_get_tracker(self) -> str: parts = self.entity_description.key.split(".") @@ -413,7 +424,7 @@ def _opnsense_get_rule(self) -> MutableMapping[str, Any] | None: tracker: str = self._opnsense_get_tracker() if not isinstance(state, MutableMapping): return None - for rule in state.get("config", {}).get("filter", {}).get("rule", {}): + for rule in state.get("firewall", {}).get("config", {}).get("filter", {}).get("rule", {}): if dict_get(rule, "created.time") == tracker: return rule return None @@ -438,7 +449,7 @@ def _handle_coordinator_update(self) -> None: return self._available = True self.async_write_ha_state() - # _LOGGER.debug(f"[OPNsenseFilterSwitch handle_coordinator_update] Name: {self.name}, available: {self.available}, is_on: {self.is_on}, extra_state_attributes: {self.extra_state_attributes}") + # _LOGGER.debug(f"[OPNsenseFilterSwitchLegacy handle_coordinator_update] Name: {self.name}, available: {self.available}, is_on: {self.is_on}, extra_state_attributes: {self.extra_state_attributes}") async def async_turn_on(self, **kwargs: Any) -> None: """Turn the entity on.""" @@ -468,7 +479,7 @@ def icon(self) -> str | None: return super().icon -class OPNsenseNatSwitch(OPNsenseSwitch): +class OPNsenseNatSwitchLegacy(OPNsenseSwitch): """Class for OPNsense NAT Switch entities.""" def __init__( @@ -486,7 +497,7 @@ def __init__( self._rule_type: str = self._opnsense_get_rule_type() self._tracker: str = self._opnsense_get_tracker() self._rule: MutableMapping[str, Any] | None = None - # _LOGGER.debug(f"[OPNsenseNatSwitch init] Name: {self.name}, tracker: {self._tracker}, rule_type: {self._rule_type}") + # _LOGGER.debug(f"[OPNsenseNatSwitchLegacy init] Name: {self.name}, tracker: {self._tracker}, rule_type: {self._rule_type}") def _opnsense_get_rule_type(self) -> str: return self.entity_description.key.split(".")[0] @@ -502,9 +513,15 @@ def _opnsense_get_rule(self) -> MutableMapping[str, Any] | None: return None rules: list = [] if self._rule_type == ATTR_NAT_PORT_FORWARD: - rules = state.get("config", {}).get("nat", {}).get("rule", []) + rules = state.get("firewall", {}).get("config", {}).get("nat", {}).get("rule", []) if self._rule_type == ATTR_NAT_OUTBOUND: - rules = state.get("config", {}).get("nat", {}).get("outbound", {}).get("rule", []) + rules = ( + state.get("firewall", {}) + .get("config", {}) + .get("nat", {}) + .get("outbound", {}) + .get("rule", []) + ) for rule in rules: if dict_get(rule, "created.time") == self._tracker: @@ -527,16 +544,16 @@ def _handle_coordinator_update(self) -> None: return self._available = True self.async_write_ha_state() - # _LOGGER.debug(f"[OPNsenseNatSwitch handle_coordinator_update] Name: {self.name}, available: {self.available}, is_on: {self.is_on}, extra_state_attributes: {self.extra_state_attributes}") + # _LOGGER.debug(f"[OPNsenseNatSwitchLegacy handle_coordinator_update] Name: {self.name}, available: {self.available}, is_on: {self.is_on}, extra_state_attributes: {self.extra_state_attributes}") async def async_turn_on(self, **kwargs: Any) -> None: """Turn the entity on.""" if not isinstance(self._rule, MutableMapping) or not self._client: return if self._rule_type == ATTR_NAT_PORT_FORWARD: - method = self._client.enable_nat_port_forward_rule_by_created_time + method = self._client.enable_nat_port_forward_rule_by_created_time_legacy elif self._rule_type == ATTR_NAT_OUTBOUND: - method = self._client.enable_nat_outbound_rule_by_created_time + method = self._client.enable_nat_outbound_rule_by_created_time_legacy else: return await method(self._tracker) @@ -550,9 +567,9 @@ async def async_turn_off(self, **kwargs: Any) -> None: if not isinstance(self._rule, MutableMapping) or not self._client: return if self._rule_type == ATTR_NAT_PORT_FORWARD: - method = self._client.disable_nat_port_forward_rule_by_created_time + method = self._client.disable_nat_port_forward_rule_by_created_time_legacy elif self._rule_type == ATTR_NAT_OUTBOUND: - method = self._client.disable_nat_outbound_rule_by_created_time + method = self._client.disable_nat_outbound_rule_by_created_time_legacy else: return await method(self._tracker) diff --git a/tests/test_pyopnsense.py b/tests/test_pyopnsense.py index 7fb7369b..111ccd02 100644 --- a/tests/test_pyopnsense.py +++ b/tests/test_pyopnsense.py @@ -1331,12 +1331,12 @@ async def test_enable_and_disable_filter_rules_and_nat_port_forward(make_client) client._restore_config_section.assert_called() client._filter_configure.assert_awaited() - # enable_nat_port_forward_rule_by_created_time: similar flow under 'nat' section + # enable_nat_port_forward_rule_by_created_time_legacy: similar flow under 'nat' section cfg_nat = {"nat": {"rule": [{"created": {"time": "t-nat"}, "disabled": "1"}]}} client.get_config = AsyncMock(return_value=cfg_nat) client._restore_config_section = AsyncMock() client._filter_configure = AsyncMock() - await client.enable_nat_port_forward_rule_by_created_time("t-nat") + await client.enable_nat_port_forward_rule_by_created_time_legacy("t-nat") client._restore_config_section.assert_called() client._filter_configure.assert_awaited() @@ -2701,7 +2701,7 @@ async def test_get_config_and_rule_enable_disable_branches() -> None: # patch _restore_config_section and _filter_configure to be no-ops client._restore_config_section = AsyncMock() client._filter_configure = AsyncMock() - await client.disable_nat_port_forward_rule_by_created_time("n1") + await client.disable_nat_port_forward_rule_by_created_time_legacy("n1") await client.async_close() @@ -3053,7 +3053,7 @@ async def test_enable_disable_nat_outbound_rules(make_client) -> None: client.get_config = AsyncMock(return_value=cfg_enable) client._restore_config_section = AsyncMock() client._filter_configure = AsyncMock() - await client.enable_nat_outbound_rule_by_created_time("t1") + await client.enable_nat_outbound_rule_by_created_time_legacy("t1") client._restore_config_section.assert_awaited() client._filter_configure.assert_awaited() @@ -3062,7 +3062,7 @@ async def test_enable_disable_nat_outbound_rules(make_client) -> None: client.get_config = AsyncMock(return_value=cfg_disable) client._restore_config_section = AsyncMock() client._filter_configure = AsyncMock() - await client.disable_nat_outbound_rule_by_created_time("t2") + await client.disable_nat_outbound_rule_by_created_time_legacy("t2") client._restore_config_section.assert_awaited() client._filter_configure.assert_awaited() await client.async_close() diff --git a/tests/test_switch.py b/tests/test_switch.py index 43093fc0..bbf99706 100644 --- a/tests/test_switch.py +++ b/tests/test_switch.py @@ -60,8 +60,8 @@ def make_coord(data): _compile_port_forward_switches, {"config": {"nat": {"rule": [{"descr": "PF", "created": {"time": "p1"}}]}}}, ( - "enable_nat_port_forward_rule_by_created_time", - "disable_nat_port_forward_rule_by_created_time", + "enable_nat_port_forward_rule_by_created_time_legacy", + "disable_nat_port_forward_rule_by_created_time_legacy", ), ), ( @@ -72,8 +72,8 @@ def make_coord(data): } }, ( - "enable_nat_outbound_rule_by_created_time", - "disable_nat_outbound_rule_by_created_time", + "enable_nat_outbound_rule_by_created_time_legacy", + "disable_nat_outbound_rule_by_created_time_legacy", ), ), ( From 67bf4fb3618056c468b015e89c08167c1c723890 Mon Sep 17 00:00:00 2001 From: Snuffy2 Date: Sun, 1 Feb 2026 20:27:19 -0500 Subject: [PATCH 02/15] Create new firewall switches --- custom_components/opnsense/const.py | 8 +- custom_components/opnsense/coordinator.py | 4 +- .../opnsense/pyopnsense/__init__.py | 200 +++++++-- custom_components/opnsense/switch.py | 388 +++++++++++++++++- tests/test_coordinator.py | 4 +- tests/test_switch.py | 68 +-- 6 files changed, 591 insertions(+), 81 deletions(-) diff --git a/custom_components/opnsense/const.py b/custom_components/opnsense/const.py index 59a68a8c..7aa8ac62 100644 --- a/custom_components/opnsense/const.py +++ b/custom_components/opnsense/const.py @@ -55,7 +55,7 @@ CONF_SYNC_GATEWAYS = "sync_gateways" CONF_SYNC_SERVICES = "sync_services" CONF_SYNC_NOTICES = "sync_notices" -CONF_SYNC_FILTERS_AND_NAT = "sync_filters_and_nat" +CONF_SYNC_FIREWALL_AND_NAT = "sync_filters_and_nat" CONF_SYNC_UNBOUND = "sync_unbound" CONF_SYNC_INTERFACES = "sync_interfaces" CONF_SYNC_CERTIFICATES = "sync_certificates" @@ -63,7 +63,7 @@ DEFAULT_GRANULAR_SYNC_OPTIONS = False DEFAULT_SYNC_OPTION_VALUE = True -SYNC_ITEMS_REQUIRING_PLUGIN = (CONF_SYNC_FILTERS_AND_NAT,) +SYNC_ITEMS_REQUIRING_PLUGIN = (CONF_SYNC_FIREWALL_AND_NAT,) GRANULAR_SYNC_ITEMS = ( CONF_SYNC_TELEMETRY, CONF_SYNC_GATEWAYS, @@ -72,7 +72,7 @@ CONF_SYNC_NOTICES, CONF_SYNC_FIRMWARE_UPDATES, CONF_SYNC_CARP, - CONF_SYNC_FILTERS_AND_NAT, + CONF_SYNC_FIREWALL_AND_NAT, CONF_SYNC_SERVICES, CONF_SYNC_VPN, CONF_SYNC_CERTIFICATES, @@ -90,7 +90,7 @@ CONF_SYNC_UNBOUND: ["unbound"], CONF_SYNC_INTERFACES: ["interface"], CONF_SYNC_CERTIFICATES: ["certificates"], - CONF_SYNC_FILTERS_AND_NAT: ["filter", "nat"], + CONF_SYNC_FIREWALL_AND_NAT: ["filter", "nat"], } CONF_DEVICES = "devices" CONF_MANUAL_DEVICES = "manual_devices" diff --git a/custom_components/opnsense/coordinator.py b/custom_components/opnsense/coordinator.py index 61652c7f..c8936654 100644 --- a/custom_components/opnsense/coordinator.py +++ b/custom_components/opnsense/coordinator.py @@ -17,7 +17,7 @@ CONF_SYNC_CARP, CONF_SYNC_CERTIFICATES, CONF_SYNC_DHCP_LEASES, - CONF_SYNC_FILTERS_AND_NAT, + CONF_SYNC_FIREWALL_AND_NAT, CONF_SYNC_FIRMWARE_UPDATES, CONF_SYNC_GATEWAYS, CONF_SYNC_INTERFACES, @@ -149,7 +149,7 @@ def _build_categories(self) -> list[MutableMapping[str, str]]: categories.append({"function": "get_services", "state_key": "services"}) if config.get(CONF_SYNC_NOTICES, DEFAULT_SYNC_OPTION_VALUE): categories.append({"function": "get_notices", "state_key": "notices"}) - if config.get(CONF_SYNC_FILTERS_AND_NAT, DEFAULT_SYNC_OPTION_VALUE): + if config.get(CONF_SYNC_FIREWALL_AND_NAT, DEFAULT_SYNC_OPTION_VALUE): categories.append({"function": "get_firewall", "state_key": "firewall"}) if config.get(CONF_SYNC_UNBOUND, DEFAULT_SYNC_OPTION_VALUE): categories.append( diff --git a/custom_components/opnsense/pyopnsense/__init__.py b/custom_components/opnsense/pyopnsense/__init__.py index 208a3027..c1bc4987 100644 --- a/custom_components/opnsense/pyopnsense/__init__.py +++ b/custom_components/opnsense/pyopnsense/__init__.py @@ -827,7 +827,7 @@ async def get_config(self) -> MutableMapping[str, Any]: return ret_data @_log_errors - async def enable_filter_rule_by_created_time(self, created_time: str) -> None: + async def enable_filter_rule_by_created_time_legacy(self, created_time: str) -> None: """Enable a filter rule.""" config = await self.get_config() for rule in config["filter"]["rule"]: @@ -844,7 +844,7 @@ async def enable_filter_rule_by_created_time(self, created_time: str) -> None: await self._filter_configure() @_log_errors - async def disable_filter_rule_by_created_time(self, created_time: str) -> None: + async def disable_filter_rule_by_created_time_legacy(self, created_time: str) -> None: """Disable a filter rule.""" config: MutableMapping[str, Any] = await self.get_config() @@ -955,15 +955,40 @@ async def get_firewall(self) -> dict[str, Any]: firewall: dict[str, Any] = {"nat": {}} if await self.is_plugin_installed(): firewall["config"] = await self.get_config() - firewall["rules"] = await self.get_firewall_rules() - firewall["nat"]["destination_rules"] = await self.get_firewall_nat_destination_rules() - firewall["nat"]["one_to_one_rules"] = await self.get_firewall_nat_one_to_one_rules() - firewall["nat"]["source_rules"] = await self.get_firewall_nat_source_rules() - _LOGGER.debug("[get_firewall] firewall: %s", firewall) + interface_map = await self._get_interface_firewall_map() + firewall["rules"] = await self._get_firewall_rules(interface_map=interface_map) + firewall["nat"]["d_nat"] = await self._get_nat_destination_rules() + firewall["nat"]["one_to_one"] = await self._get_nat_one_to_one_rules() + firewall["nat"]["source_nat"] = await self._get_nat_source_rules() + firewall["nat"]["npt"] = await self._get_nat_npt_rules() + # _LOGGER.debug("[get_firewall] firewall: %s", firewall) return firewall @_log_errors - async def get_firewall_rules(self) -> list: + async def _get_interface_firewall_map(self) -> dict[str, Any]: + """Retrieve a mapping of interface names to firewall interface names. + + Returns + ------- + dict + A dictionary mapping interface names to firewall interface names. + + """ + interfaces = await self._safe_dict_get("/api/firewall/filter/get_interface_list") + interface_map: dict[str, Any] = {} + + if isinstance(interfaces, MutableMapping): + for section in interfaces.values(): + if isinstance(section, MutableMapping) and "items" in section: + for item in section["items"]: + if isinstance(item, MutableMapping) and "value" in item and "label" in item: + interface_map[item["value"]] = item["label"] + + _LOGGER.debug("[get_interface_firewall_map] interface_map: %s", interface_map) + return interface_map + + @_log_errors + async def _get_firewall_rules(self, interface_map: dict[str, Any]) -> dict[str, Any]: """Retrieve firewall rules from OPNsense. Returns @@ -975,12 +1000,12 @@ async def get_firewall_rules(self) -> list: """ response = await self._get_raw("/api/firewall/filter/download_rules") - _LOGGER.debug("[get_firewall_rules] response: %s", response) + # _LOGGER.debug("[get_firewall_rules] response: %s", response) if not response or not isinstance(response, str): - return [] + return {} lines = response.strip().split("\n") if len(lines) < 2: - return [] + return {} headers = lines[0].split(",") rules = [] for line in lines[1:]: @@ -988,11 +1013,20 @@ async def get_firewall_rules(self) -> list: values = line.split(",") rule = dict(zip(headers, values, strict=True)) rules.append(rule) - _LOGGER.debug("[get_firewall_rules] rules: %s", rules) - return rules + # _LOGGER.debug("[get_firewall_rules] rules: %s", rules) + rules_dict: dict[str, Any] = {} + for rule in rules: + new_rule = rule.copy() + new_rule["uuid"] = new_rule.pop("@uuid", "") + new_rule["%interface"] = interface_map.get( + new_rule.get("interface", ""), new_rule.get("interface", "") + ) + rules_dict[new_rule["uuid"]] = new_rule + _LOGGER.debug("[get_firewall_rules] rules_dict: %s", rules_dict) + return rules_dict @_log_errors - async def get_firewall_nat_destination_rules(self) -> list: + async def _get_nat_destination_rules(self) -> dict[str, Any]: """Retrieve NAT destination rules from OPNsense. Returns @@ -1005,13 +1039,22 @@ async def get_firewall_nat_destination_rules(self) -> list: response = await self._safe_dict_post( "/api/firewall/d_nat/search_rule", payload=request_body ) - _LOGGER.debug("[get_firewall_nat_destination_rules] response: %s", response) + # _LOGGER.debug("[get_nat_destination_rules] response: %s", response) rules: list = response.get("rows", []) - _LOGGER.debug("[get_firewall_nat_destination_rules] rules: %s", rules) - return rules + # _LOGGER.debug("[get_nat_destination_rules] rules: %s", rules) + rules_dict: dict[str, Any] = {} + for rule in rules: + if not rule.get("uuid") or "lockout" in rule.get("uuid"): + continue # skip lockout rules + new_rule = rule.copy() + new_rule["description"] = new_rule.pop("descr", "") + new_rule["enabled"] = "1" if new_rule.pop("disabled", "0") == "0" else "0" + rules_dict[new_rule["uuid"]] = new_rule + _LOGGER.debug("[get_nat_destination_rules] rules_dict: %s", rules_dict) + return rules_dict @_log_errors - async def get_firewall_nat_one_to_one_rules(self) -> list: + async def _get_nat_one_to_one_rules(self) -> dict[str, Any]: """Retrieve NAT one-to-one rules from OPNsense. Returns @@ -1024,13 +1067,21 @@ async def get_firewall_nat_one_to_one_rules(self) -> list: response = await self._safe_dict_post( "/api/firewall/one_to_one/search_rule", payload=request_body ) - _LOGGER.debug("[get_firewall_nat_one_to_one_rules] response: %s", response) + # _LOGGER.debug("[get_nat_one_to_one_rules] response: %s", response) rules: list = response.get("rows", []) - _LOGGER.debug("[get_firewall_nat_one_to_one_rules] rules: %s", rules) - return rules + _LOGGER.debug("[get_nat_one_to_one_rules] rules: %s", rules) + rules_dict: dict[str, Any] = {} + for rule in rules: + if not rule.get("uuid") or "lockout" in rule.get("uuid"): + continue + new_rule = rule.copy() + # Add any transforms here + rules_dict[new_rule["uuid"]] = new_rule + _LOGGER.debug("[get_nat_one_to_one_rules] rules_dict: %s", rules_dict) + return rules_dict @_log_errors - async def get_firewall_nat_source_rules(self) -> list: + async def _get_nat_source_rules(self) -> dict[str, Any]: """Retrieve NAT source rules from OPNsense. Returns @@ -1043,10 +1094,107 @@ async def get_firewall_nat_source_rules(self) -> list: response = await self._safe_dict_post( "/api/firewall/source_nat/search_rule", payload=request_body ) - _LOGGER.debug("[get_firewall_nat_source_rules] response: %s", response) + # _LOGGER.debug("[get_nat_source_rules] response: %s", response) rules: list = response.get("rows", []) - _LOGGER.debug("[get_firewall_nat_source_rules] rules: %s", rules) - return rules + # _LOGGER.debug("[get_nat_source_rules] rules: %s", rules) + rules_dict: dict[str, Any] = {} + for rule in rules: + if not rule.get("uuid") or "lockout" in rule.get("uuid"): + continue + new_rule = rule.copy() + # Add any transforms here + rules_dict[new_rule["uuid"]] = new_rule + _LOGGER.debug("[get_nat_source_rules] rules_dict: %s", rules_dict) + return rules_dict + + @_log_errors + async def _get_nat_npt_rules(self) -> dict[str, Any]: + """Retrieve NAT NPT rules from OPNsense. + + Returns + ------- + list of dict + A list of dictionaries representing NAT NPT rules. + + """ + request_body: MutableMapping[str, Any] = {"current": 1, "sort": {}} + response = await self._safe_dict_post("/api/firewall/npt/search_rule", payload=request_body) + # _LOGGER.debug("[get_nat_npt_rules] response: %s", response) + rules: list = response.get("rows", []) + # _LOGGER.debug("[get_nat_npt_rules] rules: %s", rules) + rules_dict: dict[str, Any] = {} + for rule in rules: + if not rule.get("uuid") or "lockout" in rule.get("uuid"): + continue + new_rule = rule.copy() + # Add any transforms here + rules_dict[new_rule["uuid"]] = new_rule + _LOGGER.debug("[get_nat_npt_rules] rules_dict: %s", rules_dict) + return rules_dict + + async def toggle_firewall_rule(self, uuid: str, toggle_on_off: str | None = None) -> bool: + """Toggle Firewall Rule on and off.""" + payload: MutableMapping[str, Any] = {} + url = f"/api/firewall/filter/toggle_rule/{uuid}" + if toggle_on_off == "on": + url = f"{url}/1" + elif toggle_on_off == "off": + url = f"{url}/0" + response = await self._safe_dict_post( + url, + payload=payload, + ) + _LOGGER.debug( + "[toggle_firewall_rule] uuid: %s, action: %s, url: %s, response: %s", + uuid, + toggle_on_off, + url, + response, + ) + if response.get("result") == "failed": + return False + + apply_resp = await self._safe_dict_post("/api/firewall/filter/apply") + if apply_resp.get("status") != "OK\n\n": + return False + + return True + + async def toggle_nat_rule( + self, nat_rule_type: str, uuid: str, toggle_on_off: str | None = None + ) -> bool: + """Toggle NAT Rule on and off.""" + payload: MutableMapping[str, Any] = {} + url = f"/api/firewall/{nat_rule_type}/toggle_rule/{uuid}" + # d_nat uses opposite logic for on/off + if nat_rule_type == "d_nat" and toggle_on_off is not None: + if toggle_on_off == "on": + url = f"{url}/0" + elif toggle_on_off == "off": + url = f"{url}/1" + elif toggle_on_off == "on": + url = f"{url}/1" + elif toggle_on_off == "off": + url = f"{url}/0" + response = await self._safe_dict_post( + url, + payload=payload, + ) + _LOGGER.debug( + "[toggle_nat_rule] uuid: %s, action: %s, url: %s, response: %s", + uuid, + toggle_on_off, + url, + response, + ) + if response.get("result") == "failed": + return False + + apply_resp = await self._safe_dict_post(f"/api/firewall/{nat_rule_type}/apply") + if apply_resp.get("status") != "OK\n\n": + return False + + return True ##################### @@ -2527,7 +2675,7 @@ async def kill_states(self, ip_addr: str) -> MutableMapping[str, Any]: "dropped_states": response.get("dropped_states", 0), } - async def toggle_alias(self, alias: str, toggle_on_off: str) -> bool: + async def toggle_alias(self, alias: str, toggle_on_off: str | None = None) -> bool: """Toggle alias on and off.""" if self._use_snake_case: alias_list_resp = await self._safe_dict_get("/api/firewall/alias/search_item") diff --git a/custom_components/opnsense/switch.py b/custom_components/opnsense/switch.py index a8b968b8..9736ab9d 100644 --- a/custom_components/opnsense/switch.py +++ b/custom_components/opnsense/switch.py @@ -16,7 +16,7 @@ ATTR_NAT_OUTBOUND, ATTR_NAT_PORT_FORWARD, ATTR_UNBOUND_BLOCKLIST, - CONF_SYNC_FILTERS_AND_NAT, + CONF_SYNC_FIREWALL_AND_NAT, CONF_SYNC_SERVICES, CONF_SYNC_UNBOUND, CONF_SYNC_VPN, @@ -52,7 +52,7 @@ async def _compile_filter_switches_legacy( continue # not possible to disable these rules - if rule.get("descr", "") == "Anti-Lockout Rule": + if rule.get("description", "") == "Anti-Lockout Rule": continue tracker = dict_get(rule, "created.time") @@ -142,7 +142,7 @@ async def _compile_nat_outbound_switches_legacy( if tracker is None or len(tracker) < 1: continue - if "Auto created rule" in rule.get("descr", ""): + if "Auto created rule" in rule.get("description", ""): continue entity = OPNsenseNatSwitchLegacy( @@ -277,6 +277,141 @@ async def _compile_unbound_switches( return entities +async def _compile_firewall_rules_switches( + config_entry: ConfigEntry, + coordinator: OPNsenseDataUpdateCoordinator, + state: MutableMapping[str, Any], +) -> list: + if not isinstance(state, MutableMapping) or not isinstance( + state.get("firewall", {}).get("rules"), dict + ): + return [] + + entities: list = [] + for rule in state.get("firewall", {}).get("rules", {}).values(): + entity = OPNsenseFirewallRuleSwitch( + config_entry=config_entry, + coordinator=coordinator, + entity_description=SwitchEntityDescription( + key=f"firewall.rule.{rule.get('uuid', 'unknown')}", + name=f"Firewall Rule: {rule.get('%interface', '')}: {rule.get('description', 'unknown')}", + icon="mdi:play-network-outline", + device_class=SwitchDeviceClass.SWITCH, + entity_registry_enabled_default=False, + ), + ) + entities.append(entity) + return entities + + +async def _compile_nat_source_rules_switches( + config_entry: ConfigEntry, + coordinator: OPNsenseDataUpdateCoordinator, + state: MutableMapping[str, Any], +) -> list: + if not isinstance(state, MutableMapping) or not isinstance( + state.get("firewall", {}).get("nat", {}).get("source_nat"), dict + ): + return [] + + entities: list = [] + for rule in state.get("firewall", {}).get("nat", {}).get("source_nat", {}).values(): + entity = OPNsenseNATRuleSwitch( + config_entry=config_entry, + coordinator=coordinator, + entity_description=SwitchEntityDescription( + key=f"firewall.nat.source_nat.{rule.get('uuid', 'unknown')}", + name=f"NAT Source Rule: {rule.get('%interface', '')}: {rule.get('description', 'unknown')}", + icon="mdi:network-outline", + device_class=SwitchDeviceClass.SWITCH, + entity_registry_enabled_default=False, + ), + ) + entities.append(entity) + return entities + + +async def _compile_nat_destination_rules_switches( + config_entry: ConfigEntry, + coordinator: OPNsenseDataUpdateCoordinator, + state: MutableMapping[str, Any], +) -> list: + if not isinstance(state, MutableMapping) or not isinstance( + state.get("firewall", {}).get("nat", {}).get("d_nat"), dict + ): + return [] + + entities: list = [] + for rule in state.get("firewall", {}).get("nat", {}).get("d_nat", {}).values(): + entity = OPNsenseNATRuleSwitch( + config_entry=config_entry, + coordinator=coordinator, + entity_description=SwitchEntityDescription( + key=f"firewall.nat.d_nat.{rule.get('uuid', 'unknown')}", + name=f"NAT Destination Rule: {rule.get('%interface', '')}: {rule.get('description', 'unknown')}", + icon="mdi:network-outline", + device_class=SwitchDeviceClass.SWITCH, + entity_registry_enabled_default=False, + ), + ) + entities.append(entity) + return entities + + +async def _compile_nat_one_to_one_rules_switches( + config_entry: ConfigEntry, + coordinator: OPNsenseDataUpdateCoordinator, + state: MutableMapping[str, Any], +) -> list: + if not isinstance(state, MutableMapping) or not isinstance( + state.get("firewall", {}).get("nat", {}).get("one_to_one"), dict + ): + return [] + + entities: list = [] + for rule in state.get("firewall", {}).get("nat", {}).get("one_to_one", {}).values(): + entity = OPNsenseNATRuleSwitch( + config_entry=config_entry, + coordinator=coordinator, + entity_description=SwitchEntityDescription( + key=f"firewall.nat.one_to_one.{rule.get('uuid', 'unknown')}", + name=f"NAT One to One Rule: {rule.get('%interface', '')}: {rule.get('description', 'unknown')}", + icon="mdi:network-outline", + device_class=SwitchDeviceClass.SWITCH, + entity_registry_enabled_default=False, + ), + ) + entities.append(entity) + return entities + + +async def _compile_nat_npt_rules_switches( + config_entry: ConfigEntry, + coordinator: OPNsenseDataUpdateCoordinator, + state: MutableMapping[str, Any], +) -> list: + if not isinstance(state, MutableMapping) or not isinstance( + state.get("firewall", {}).get("nat", {}).get("npt"), dict + ): + return [] + + entities: list = [] + for rule in state.get("firewall", {}).get("nat", {}).get("npt", {}).values(): + entity = OPNsenseNATRuleSwitch( + config_entry=config_entry, + coordinator=coordinator, + entity_description=SwitchEntityDescription( + key=f"firewall.nat.npt.{rule.get('uuid', 'unknown')}", + name=f"NAT NPT Rule: {rule.get('%interface', '')}: {rule.get('description', 'unknown')}", + icon="mdi:network-outline", + device_class=SwitchDeviceClass.SWITCH, + entity_registry_enabled_default=False, + ), + ) + entities.append(entity) + return entities + + async def async_setup_entry( hass: HomeAssistant, config_entry: ConfigEntry, @@ -292,14 +427,59 @@ async def async_setup_entry( entities: list = [] - if config.get(CONF_SYNC_FILTERS_AND_NAT, DEFAULT_SYNC_OPTION_VALUE): - entities.extend(await _compile_filter_switches_legacy(config_entry, coordinator, state)) - entities.extend( - await _compile_port_forward_switches_legacy(config_entry, coordinator, state) - ) - entities.extend( - await _compile_nat_outbound_switches_legacy(config_entry, coordinator, state) - ) + if config.get(CONF_SYNC_FIREWALL_AND_NAT, DEFAULT_SYNC_OPTION_VALUE): + firmware = state.get("host_firmware_version", None) + if firmware: + try: + if awesomeversion.AwesomeVersion(firmware) < awesomeversion.AwesomeVersion("26.1"): + entities.extend( + await _compile_filter_switches_legacy(config_entry, coordinator, state) + ) + entities.extend( + await _compile_port_forward_switches_legacy( + config_entry, coordinator, state + ) + ) + entities.extend( + await _compile_nat_outbound_switches_legacy( + config_entry, coordinator, state + ) + ) + else: + entities.extend( + await _compile_filter_switches_legacy(config_entry, coordinator, state) + ) + entities.extend( + await _compile_firewall_rules_switches(config_entry, coordinator, state) + ) + entities.extend( + await _compile_nat_source_rules_switches(config_entry, coordinator, state) + ) + entities.extend( + await _compile_nat_destination_rules_switches( + config_entry, coordinator, state + ) + ) + entities.extend( + await _compile_nat_one_to_one_rules_switches( + config_entry, coordinator, state + ) + ) + entities.extend( + await _compile_nat_npt_rules_switches(config_entry, coordinator, state) + ) + + except ( + awesomeversion.exceptions.AwesomeVersionCompareException, + TypeError, + ValueError, + ) as e: + _LOGGER.error( + "Error comparing firmware version %s when determining creating Unbound Blocklist switches. %s: %s", + firmware, + type(e).__name__, + e, + ) if config.get(CONF_SYNC_SERVICES, DEFAULT_SYNC_OPTION_VALUE): entities.extend(await _compile_service_switches(config_entry, coordinator, state)) if config.get(CONF_SYNC_VPN, DEFAULT_SYNC_OPTION_VALUE): @@ -395,6 +575,188 @@ def _clear(_: Any) -> None: ) +class OPNsenseFirewallRuleSwitch(OPNsenseSwitch): + """Class for OPNsense Firewall Rule Switch entities.""" + + def __init__( + self, + config_entry: ConfigEntry, + coordinator: OPNsenseDataUpdateCoordinator, + entity_description: SwitchEntityDescription, + ) -> None: + """Initialize switch entity.""" + super().__init__( + config_entry=config_entry, + coordinator=coordinator, + entity_description=entity_description, + ) + self._rule_id: str = self._opnsense_get_rule_id() + _LOGGER.debug( + "[OPNsenseFirewallRuleSwitch init] Name: %s, rule_id: %s", self.name, self._rule_id + ) + + def _opnsense_get_rule_id(self) -> str: + return self.entity_description.key.split(".")[-1] + + def _opnsense_get_rule(self) -> MutableMapping[str, Any] | None: + state: MutableMapping[str, Any] = self.coordinator.data + if not isinstance(state, MutableMapping): + return None + return state.get("firewall", {}).get("rules", {}).get(self._rule_id, None) + + @callback + def _handle_coordinator_update(self) -> None: + if self.delay_update: + _LOGGER.debug( + "Skipping coordinator update for firewall rule switch %s due to delay", self.name + ) + return + rule = self._opnsense_get_rule() + if not rule: + self._available = False + self.async_write_ha_state() + return + try: + self._attr_is_on = bool(rule.get("enabled", "1") == "1") + except (TypeError, KeyError, AttributeError): + self._available = False + self.async_write_ha_state() + return + self._available = True + self.async_write_ha_state() + _LOGGER.debug( + "[OPNsenseFirewallRuleSwitch handle_coordinator_update] Name: %s, available: %s, is_on: %s, extra_state_attributes: %s", + self.name, + self.available, + self.is_on, + self.extra_state_attributes, + ) + + async def async_turn_on(self, **kwargs: Any) -> None: + """Turn the entity on.""" + if self._rule_id is None or not self._client: + return + await self._client.toggle_firewall_rule(self._rule_id, "on") + _LOGGER.info("Turned on firewall rule: %s", self.name) + self._attr_is_on = True + self.async_write_ha_state() + self.delay_update = True + + async def async_turn_off(self, **kwargs: Any) -> None: + """Turn the entity off.""" + if self._rule_id is None or not self._client: + return + await self._client.toggle_firewall_rule(self._rule_id, "off") + _LOGGER.info("Turned off firewall rule: %s", self.name) + self._attr_is_on = False + self.async_write_ha_state() + self.delay_update = True + + @property + def icon(self) -> str | None: + """Return the icon for the entity.""" + if self.available and self.is_on: + return "mdi:play-network" + return super().icon + + +class OPNsenseNATRuleSwitch(OPNsenseSwitch): + """Class for OPNsense NAT Rule Switch entities.""" + + def __init__( + self, + config_entry: ConfigEntry, + coordinator: OPNsenseDataUpdateCoordinator, + entity_description: SwitchEntityDescription, + ) -> None: + """Initialize switch entity.""" + super().__init__( + config_entry=config_entry, + coordinator=coordinator, + entity_description=entity_description, + ) + self._rule_id: str = self._opnsense_get_rule_id() + self._nat_rule_type: str = self._get_nat_rule_type() + _LOGGER.debug( + "[OPNsenseNATRuleSwitch init] Name: %s, key: %s, rule_id: %s, rule_type: %s", + self.name, + self.entity_description.key, + self._rule_id, + self._nat_rule_type, + ) + + def _get_nat_rule_type(self) -> str: + return self.entity_description.key.split(".")[2] + + def _opnsense_get_rule_id(self) -> str: + return self.entity_description.key.split(".")[-1] + + def _opnsense_get_rule(self) -> MutableMapping[str, Any] | None: + state: MutableMapping[str, Any] = self.coordinator.data + if not isinstance(state, MutableMapping): + return None + return ( + state.get("firewall", {}) + .get("nat", {}) + .get(self._nat_rule_type, {}) + .get(self._rule_id, None) + ) + + @callback + def _handle_coordinator_update(self) -> None: + if self.delay_update: + _LOGGER.debug("Skipping coordinator update for NAT switch %s due to delay", self.name) + return + rule = self._opnsense_get_rule() + _LOGGER.debug("[OPNsenseNATRuleSwitch handle_coordinator_update] fetched rule: %s", rule) + if not rule: + self._available = False + self.async_write_ha_state() + return + try: + self._attr_is_on = bool(rule.get("enabled", "1") == "1") + except (TypeError, KeyError, AttributeError): + self._available = False + self.async_write_ha_state() + return + self._available = True + self.async_write_ha_state() + _LOGGER.debug( + "[OPNsenseNATRuleSwitch handle_coordinator_update] Name: %s, available: %s, is_on: %s, extra_state_attributes: %s", + self.name, + self.available, + self.is_on, + self.extra_state_attributes, + ) + + async def async_turn_on(self, **kwargs: Any) -> None: + """Turn the entity on.""" + if self._rule_id is None or not self._client: + return + await self._client.toggle_nat_rule(self._nat_rule_type, self._rule_id, "on") + _LOGGER.info("Turned on NAT rule: %s", self.name) + self._attr_is_on = True + self.async_write_ha_state() + self.delay_update = True + + async def async_turn_off(self, **kwargs: Any) -> None: + """Turn the entity off.""" + if self._rule_id is None or not self._client: + return + await self._client.toggle_nat_rule(self._nat_rule_type, self._rule_id, "off") + _LOGGER.info("Turned off NAT rule: %s", self.name) + self._attr_is_on = False + self.async_write_ha_state() + self.delay_update = True + + @property + def icon(self) -> str | None: + """Return the icon for the entity.""" + if self.available and self.is_on: + return "mdi:network" + return super().icon + + class OPNsenseFilterSwitchLegacy(OPNsenseSwitch): """Class for OPNsense Filter Switch entities.""" @@ -455,7 +817,7 @@ async def async_turn_on(self, **kwargs: Any) -> None: """Turn the entity on.""" if self._rule is None or not self._client: return - await self._client.enable_filter_rule_by_created_time(self._tracker) + await self._client.enable_filter_rule_by_created_time_legacy(self._tracker) _LOGGER.info("Turned on filter rule: %s", self.name) self._attr_is_on = True self.async_write_ha_state() @@ -465,7 +827,7 @@ async def async_turn_off(self, **kwargs: Any) -> None: """Turn the entity off.""" if self._rule is None or not self._client: return - await self._client.disable_filter_rule_by_created_time(self._tracker) + await self._client.disable_filter_rule_by_created_time_legacy(self._tracker) _LOGGER.info("Turned off filter rule: %s", self.name) self._attr_is_on = False self.async_write_ha_state() diff --git a/tests/test_coordinator.py b/tests/test_coordinator.py index 5639fd79..a7292c25 100644 --- a/tests/test_coordinator.py +++ b/tests/test_coordinator.py @@ -18,7 +18,7 @@ CONF_SYNC_CARP, CONF_SYNC_CERTIFICATES, CONF_SYNC_DHCP_LEASES, - CONF_SYNC_FILTERS_AND_NAT, + CONF_SYNC_FIREWALL_AND_NAT, CONF_SYNC_FIRMWARE_UPDATES, CONF_SYNC_GATEWAYS, CONF_SYNC_INTERFACES, @@ -403,7 +403,7 @@ def test_build_categories_returns_empty_when_no_config(make_config_entry, fake_c (CONF_SYNC_GATEWAYS, ["gateways"]), (CONF_SYNC_SERVICES, ["services"]), (CONF_SYNC_NOTICES, ["notices"]), - (CONF_SYNC_FILTERS_AND_NAT, ["config"]), + (CONF_SYNC_FIREWALL_AND_NAT, ["config"]), (CONF_SYNC_UNBOUND, [ATTR_UNBOUND_BLOCKLIST]), (CONF_SYNC_INTERFACES, ["interfaces"]), (CONF_SYNC_CERTIFICATES, ["certificates"]), diff --git a/tests/test_switch.py b/tests/test_switch.py index bbf99706..2adf0927 100644 --- a/tests/test_switch.py +++ b/tests/test_switch.py @@ -16,7 +16,7 @@ ATTR_NAT_OUTBOUND, ATTR_NAT_PORT_FORWARD, CONF_DEVICE_UNIQUE_ID, - CONF_SYNC_FILTERS_AND_NAT, + CONF_SYNC_FIREWALL_AND_NAT, CONF_SYNC_SERVICES, CONF_SYNC_UNBOUND, CONF_SYNC_VPN, @@ -24,13 +24,13 @@ ) from custom_components.opnsense.coordinator import OPNsenseDataUpdateCoordinator from custom_components.opnsense.switch import ( - OPNsenseFilterSwitch, - OPNsenseNatSwitch, + OPNsenseFilterSwitchLegacy, + OPNsenseNatSwitchLegacy, OPNsenseServiceSwitch, OPNsenseVPNSwitch, - _compile_filter_switches, - _compile_nat_outbound_switches, - _compile_port_forward_switches, + _compile_filter_switches_legacy, + _compile_nat_outbound_switches_legacy, + _compile_port_forward_switches_legacy, _compile_service_switches, _compile_static_unbound_switch_legacy, _compile_unbound_switches, @@ -52,12 +52,12 @@ def make_coord(data): "compile_fn,state,client_methods", [ ( - _compile_filter_switches, + _compile_filter_switches_legacy, {"config": {"filter": {"rule": [{"descr": "Allow LAN", "created": {"time": "t1"}}]}}}, ("enable_filter_rule_by_created_time", "disable_filter_rule_by_created_time"), ), ( - _compile_port_forward_switches, + _compile_port_forward_switches_legacy, {"config": {"nat": {"rule": [{"descr": "PF", "created": {"time": "p1"}}]}}}, ( "enable_nat_port_forward_rule_by_created_time_legacy", @@ -65,7 +65,7 @@ def make_coord(data): ), ), ( - _compile_nat_outbound_switches, + _compile_nat_outbound_switches_legacy, { "config": { "nat": {"outbound": {"rule": [{"descr": "OB", "created": {"time": "o1"}}]}} @@ -159,7 +159,7 @@ async def test_compile_port_forward_skips_non_dict(coordinator, make_config_entr "config": {"nat": {"rule": ["not-a-dict", {"descr": "PF", "created": {"time": "p2"}}]}} } coordinator.data = state - ents = await _compile_port_forward_switches(config_entry, coordinator, state) + ents = await _compile_port_forward_switches_legacy(config_entry, coordinator, state) assert len(ents) == 1 assert ents[0].entity_description.key.endswith(".p2") @@ -192,7 +192,7 @@ def fake_add_entities(entities): config_entry = make_config_entry( data={ CONF_DEVICE_UNIQUE_ID: "dev1", - CONF_SYNC_FILTERS_AND_NAT: True, + CONF_SYNC_FIREWALL_AND_NAT: True, CONF_SYNC_SERVICES: True, CONF_SYNC_VPN: True, CONF_SYNC_UNBOUND: True, @@ -291,7 +291,7 @@ def add_entities(ents): for e in created if ( ( - not isinstance(e, OPNsenseFilterSwitch) + not isinstance(e, OPNsenseFilterSwitchLegacy) and getattr(e, "_attr_unique_id", "").endswith("unbound") ) or ( @@ -342,7 +342,7 @@ def test_delay_update_setter(monkeypatch, coordinator, make_config_entry): title="OPNsenseTest", ) setattr(config_entry.runtime_data, COORDINATOR, coordinator) - ent = OPNsenseFilterSwitch( + ent = OPNsenseFilterSwitchLegacy( config_entry=config_entry, coordinator=coordinator, entity_description=desc ) # synchronous test: use a plain hass-like object with a dedicated loop @@ -597,7 +597,7 @@ async def test_filter_disabled_and_missing(coordinator, ph_hass, make_config_ent # missing rules -> compile returns [] state = {"config": {"filter": {"rule": []}}} coordinator.data = state - entities = await _compile_filter_switches(config_entry, coordinator, state) + entities = await _compile_filter_switches_legacy(config_entry, coordinator, state) assert entities == [] # disabled rule -> is_on False @@ -605,7 +605,7 @@ async def test_filter_disabled_and_missing(coordinator, ph_hass, make_config_ent "config": {"filter": {"rule": [{"descr": "x", "created": {"time": "t2"}, "disabled": "1"}]}} } coordinator.data = state - entities = await _compile_filter_switches(config_entry, coordinator, state) + entities = await _compile_filter_switches_legacy(config_entry, coordinator, state) ent = entities[0] # use PHCC-provided hass fixture hass = ph_hass @@ -675,7 +675,7 @@ async def test_unbound_skips_update_when_delay_set(coordinator, ph_hass, make_co ), ( "filter", - _compile_filter_switches, + _compile_filter_switches_legacy, { "config": { "filter": { @@ -687,7 +687,7 @@ async def test_unbound_skips_update_when_delay_set(coordinator, ph_hass, make_co ), ( "nat", - _compile_port_forward_switches, + _compile_port_forward_switches_legacy, {"config": {"nat": {"rule": [{"descr": "PF", "created": {"time": "pdelay"}}]}}}, "first", ), @@ -758,9 +758,9 @@ async def test_compile_helpers_bad_input(coordinator, make_config_entry): config_entry = make_config_entry({CONF_DEVICE_UNIQUE_ID: "dev1"}) setattr(config_entry.runtime_data, COORDINATOR, coordinator) # non-mapping state - assert await _compile_filter_switches(config_entry, coordinator, None) == [] - assert await _compile_port_forward_switches(config_entry, coordinator, None) == [] - assert await _compile_nat_outbound_switches(config_entry, coordinator, None) == [] + assert await _compile_filter_switches_legacy(config_entry, coordinator, None) == [] + assert await _compile_port_forward_switches_legacy(config_entry, coordinator, None) == [] + assert await _compile_nat_outbound_switches_legacy(config_entry, coordinator, None) == [] @pytest.mark.asyncio @@ -801,7 +801,7 @@ async def test_switch_handle_error_sets_unavailable( setattr(config_entry.runtime_data, COORDINATOR, coordinator) state = {"config": {"filter": {"rule": [{"descr": "Good", "created": {"time": "t1"}}]}}} coordinator.data = state - ent = (await _compile_filter_switches(config_entry, coordinator, state))[0] + ent = (await _compile_filter_switches_legacy(config_entry, coordinator, state))[0] ent.hass = hass_local ent.coordinator = make_coord(state) ent.entity_id = f"switch.{ent._attr_unique_id}" @@ -815,7 +815,7 @@ def _fake_get_rule_filter() -> Any: desc = SwitchEntityDescription(key="nat_port_forward.abc", name="NAT") config_entry = make_config_entry({CONF_DEVICE_UNIQUE_ID: "dev1"}) setattr(config_entry.runtime_data, COORDINATOR, coordinator) - ent = OPNsenseNatSwitch( + ent = OPNsenseNatSwitchLegacy( config_entry=config_entry, coordinator=coordinator, entity_description=desc, @@ -864,7 +864,7 @@ def test_entity_icons(make_config_entry): f_desc = SwitchEntityDescription(key="filter.t1", name="Filter") config_entry = make_config_entry({CONF_DEVICE_UNIQUE_ID: "dev1"}) setattr(config_entry.runtime_data, COORDINATOR, None) - f_ent = OPNsenseFilterSwitch( + f_ent = OPNsenseFilterSwitchLegacy( config_entry=config_entry, coordinator=make_coord({}), entity_description=f_desc, @@ -877,7 +877,7 @@ def test_entity_icons(make_config_entry): n_desc = SwitchEntityDescription(key="nat_port_forward.t1", name="NAT") config_entry = make_config_entry({CONF_DEVICE_UNIQUE_ID: "dev1"}) setattr(config_entry.runtime_data, COORDINATOR, None) - n_ent = OPNsenseNatSwitch( + n_ent = OPNsenseNatSwitchLegacy( config_entry=config_entry, coordinator=make_coord({}), entity_description=n_desc, @@ -1028,7 +1028,7 @@ def test_reset_delay_calls_existing_remover(monkeypatch, make_config_entry): desc = SwitchEntityDescription(key="filter.t1", name="Filter") config_entry = make_config_entry({CONF_DEVICE_UNIQUE_ID: "dev1"}) setattr(config_entry.runtime_data, COORDINATOR, None) - ent = OPNsenseFilterSwitch( + ent = OPNsenseFilterSwitchLegacy( config_entry=config_entry, coordinator=make_coord({}), entity_description=desc, @@ -1073,7 +1073,7 @@ async def test_compile_filter_skip_and_invalid_rules(coordinator, make_config_en } } coordinator.data = state - entities = await _compile_filter_switches(config_entry, coordinator, state) + entities = await _compile_filter_switches_legacy(config_entry, coordinator, state) # only the valid rule should be compiled assert len(entities) == 1 @@ -1096,7 +1096,7 @@ async def test_compile_nat_outbound_skips_auto_created(coordinator, make_config_ } } coordinator.data = state - ents = await _compile_nat_outbound_switches(config_entry, coordinator, state) + ents = await _compile_nat_outbound_switches_legacy(config_entry, coordinator, state) assert len(ents) == 1 @@ -1131,7 +1131,7 @@ def fake_add_entities(entities): config_entry = make_config_entry( data={ CONF_DEVICE_UNIQUE_ID: "dev1", - CONF_SYNC_FILTERS_AND_NAT: False, + CONF_SYNC_FIREWALL_AND_NAT: False, CONF_SYNC_SERVICES: False, CONF_SYNC_VPN: False, CONF_SYNC_UNBOUND: True, @@ -1215,7 +1215,7 @@ async def test_nat_handle_missing_rule_returns_none(coordinator, ph_hass, make_c # create a nat switch with rule type that doesn't exist in state desc = SwitchEntityDescription(key="nat_outbound.missing", name="Missing") config_entry = make_config_entry({CONF_DEVICE_UNIQUE_ID: "dev1"}) - ent = OPNsenseNatSwitch( + ent = OPNsenseNatSwitchLegacy( config_entry=config_entry, coordinator=coordinator, entity_description=desc, @@ -1478,7 +1478,7 @@ def test_nat_rule_type_and_tracker_methods(coordinator, make_config_entry): config_entry = make_config_entry({CONF_DEVICE_UNIQUE_ID: "dev1"}) setattr(config_entry.runtime_data, COORDINATOR, coordinator) - pf = OPNsenseNatSwitch( + pf = OPNsenseNatSwitchLegacy( config_entry=config_entry, coordinator=coordinator, entity_description=desc_pf, @@ -1486,7 +1486,7 @@ def test_nat_rule_type_and_tracker_methods(coordinator, make_config_entry): # create a separate config_entry for the outbound test config_entry_ob = make_config_entry({CONF_DEVICE_UNIQUE_ID: "dev1"}) setattr(config_entry_ob.runtime_data, COORDINATOR, coordinator) - ob = OPNsenseNatSwitch( + ob = OPNsenseNatSwitchLegacy( config_entry=config_entry_ob, coordinator=coordinator, entity_description=desc_ob, @@ -1519,7 +1519,7 @@ async def test_compile_port_forward_with_missing_rules(coordinator, make_config_ config_entry = make_config_entry({CONF_DEVICE_UNIQUE_ID: "dev1"}) setattr(config_entry.runtime_data, COORDINATOR, coordinator) coordinator.data = {"config": {}} - res = await _compile_port_forward_switches(config_entry, coordinator, coordinator.data) + res = await _compile_port_forward_switches_legacy(config_entry, coordinator, coordinator.data) assert res == [] @@ -1548,7 +1548,7 @@ def test_filter_handle_exceptions_sets_unavailable( config_entry = make_config_entry({CONF_DEVICE_UNIQUE_ID: "dev1"}) setattr(config_entry.runtime_data, COORDINATOR, coordinator) - ent = OPNsenseFilterSwitch( + ent = OPNsenseFilterSwitchLegacy( config_entry=config_entry, coordinator=coordinator, entity_description=desc ) ent.hass = MagicMock(spec=HomeAssistant) @@ -1582,7 +1582,7 @@ def test_nat_handle_exceptions_sets_unavailable(exc_type, coordinator, make_conf config_entry = make_config_entry({CONF_DEVICE_UNIQUE_ID: "dev1"}) setattr(config_entry.runtime_data, COORDINATOR, coordinator) - ent = OPNsenseNatSwitch( + ent = OPNsenseNatSwitchLegacy( config_entry=config_entry, coordinator=coordinator, entity_description=desc ) ent.hass = MagicMock(spec=HomeAssistant) From 64d6b02970b96f9a5ea2a10dff4bbf12e774a8d1 Mon Sep 17 00:00:00 2001 From: Snuffy2 Date: Mon, 2 Feb 2026 21:11:04 -0500 Subject: [PATCH 03/15] Add docstrings and fix tests --- .../opnsense/pyopnsense/__init__.py | 64 ++- custom_components/opnsense/switch.py | 467 +++++++++++++++++- tests/test_coordinator.py | 2 +- tests/test_pyopnsense.py | 37 +- tests/test_switch.py | 140 ++++-- 5 files changed, 624 insertions(+), 86 deletions(-) diff --git a/custom_components/opnsense/pyopnsense/__init__.py b/custom_components/opnsense/pyopnsense/__init__.py index c1bc4987..1051989d 100644 --- a/custom_components/opnsense/pyopnsense/__init__.py +++ b/custom_components/opnsense/pyopnsense/__init__.py @@ -929,7 +929,6 @@ async def disable_nat_outbound_rule_by_created_time_legacy(self, created_time: s await self._restore_config_section("nat", config["nat"]) await self._filter_configure() - ##################### @_log_errors async def get_firewall(self) -> dict[str, Any]: """Retrieve all firewall and NAT rules from OPNsense. @@ -991,12 +990,17 @@ async def _get_interface_firewall_map(self) -> dict[str, Any]: async def _get_firewall_rules(self, interface_map: dict[str, Any]) -> dict[str, Any]: """Retrieve firewall rules from OPNsense. + Parameters + ---------- + interface_map : dict[str, Any] + A mapping of interface names to firewall interface names. + Returns ------- - list of dict - A list of dictionaries, each representing a firewall rule parsed + dict[str, Any] + A dictionary of firewall rules, keyed by UUID, each representing a firewall rule parsed from CSV format. Dictionary keys correspond to CSV headers such - as '@uuid', 'enabled', 'action', etc. + as 'uuid', 'enabled', 'action', etc. """ response = await self._get_raw("/api/firewall/filter/download_rules") @@ -1031,8 +1035,8 @@ async def _get_nat_destination_rules(self) -> dict[str, Any]: Returns ------- - list of dict - A list of dictionaries representing NAT destination rules. + dict[str, Any] + A dictionary of NAT destination rules, keyed by UUID. """ request_body: MutableMapping[str, Any] = {"current": 1, "sort": {}} @@ -1059,8 +1063,8 @@ async def _get_nat_one_to_one_rules(self) -> dict[str, Any]: Returns ------- - list of dict - A list of dictionaries representing NAT one-to-one rules. + dict[str, Any] + A dictionary of NAT one-to-one rules, keyed by UUID. """ request_body: MutableMapping[str, Any] = {"current": 1, "sort": {}} @@ -1086,8 +1090,8 @@ async def _get_nat_source_rules(self) -> dict[str, Any]: Returns ------- - list of dict - A list of dictionaries representing NAT source rules. + dict[str, Any] + A dictionary of NAT source rules, keyed by UUID. """ request_body: MutableMapping[str, Any] = {"current": 1, "sort": {}} @@ -1113,8 +1117,8 @@ async def _get_nat_npt_rules(self) -> dict[str, Any]: Returns ------- - list of dict - A list of dictionaries representing NAT NPT rules. + dict[str, Any] + A dictionary of NAT NPT rules, keyed by UUID. """ request_body: MutableMapping[str, Any] = {"current": 1, "sort": {}} @@ -1133,7 +1137,21 @@ async def _get_nat_npt_rules(self) -> dict[str, Any]: return rules_dict async def toggle_firewall_rule(self, uuid: str, toggle_on_off: str | None = None) -> bool: - """Toggle Firewall Rule on and off.""" + """Toggle Firewall Rule on and off. + + Parameters + ---------- + uuid : str + The UUID of the firewall rule to toggle. + toggle_on_off : str | None, optional + The action to perform: 'on' to enable, 'off' to disable, or None to toggle. + + Returns + ------- + bool + True if the operation was successful, False otherwise. + + """ payload: MutableMapping[str, Any] = {} url = f"/api/firewall/filter/toggle_rule/{uuid}" if toggle_on_off == "on": @@ -1163,7 +1181,23 @@ async def toggle_firewall_rule(self, uuid: str, toggle_on_off: str | None = None async def toggle_nat_rule( self, nat_rule_type: str, uuid: str, toggle_on_off: str | None = None ) -> bool: - """Toggle NAT Rule on and off.""" + """Toggle NAT Rule on and off. + + Parameters + ---------- + nat_rule_type : str + The type of NAT rule (e.g., 'd_nat', 'source_nat'). + uuid : str + The UUID of the NAT rule to toggle. + toggle_on_off : str | None, optional + The action to perform: 'on' to enable, 'off' to disable, or None to toggle. + + Returns + ------- + bool + True if the operation was successful, False otherwise. + + """ payload: MutableMapping[str, Any] = {} url = f"/api/firewall/{nat_rule_type}/toggle_rule/{uuid}" # d_nat uses opposite logic for on/off @@ -1196,8 +1230,6 @@ async def toggle_nat_rule( return True - ##################### - @_log_errors async def get_arp_table(self, resolve_hostnames: bool = False) -> list: """Return the active ARP table.""" diff --git a/custom_components/opnsense/switch.py b/custom_components/opnsense/switch.py index 9736ab9d..bb2f0d6f 100644 --- a/custom_components/opnsense/switch.py +++ b/custom_components/opnsense/switch.py @@ -35,6 +35,23 @@ async def _compile_filter_switches_legacy( coordinator: OPNsenseDataUpdateCoordinator, state: MutableMapping[str, Any], ) -> list: + """Compile legacy filter rule switches from OPNsense state. + + Parameters + ---------- + config_entry : ConfigEntry + The Home Assistant config entry. + coordinator : OPNsenseDataUpdateCoordinator + The data update coordinator. + state : MutableMapping[str, Any] + The current state data from OPNsense. + + Returns + ------- + list + A list of OPNsenseFilterSwitchLegacy entities. + + """ if not isinstance(state, MutableMapping) or not isinstance( state.get("firewall", {}).get("config", {}).get("filter", {}).get("rule"), list ): @@ -83,6 +100,23 @@ async def _compile_port_forward_switches_legacy( coordinator: OPNsenseDataUpdateCoordinator, state: MutableMapping[str, Any], ) -> list: + """Compile legacy NAT port forward rule switches from OPNsense state. + + Parameters + ---------- + config_entry : ConfigEntry + The Home Assistant config entry. + coordinator : OPNsenseDataUpdateCoordinator + The data update coordinator. + state : MutableMapping[str, Any] + The current state data from OPNsense. + + Returns + ------- + list + A list of OPNsenseNatSwitchLegacy entities for port forward rules. + + """ if not isinstance(state, MutableMapping) or not isinstance( state.get("firewall", {}).get("config", {}).get("nat", {}).get("rule"), list ): @@ -123,6 +157,23 @@ async def _compile_nat_outbound_switches_legacy( coordinator: OPNsenseDataUpdateCoordinator, state: MutableMapping[str, Any], ) -> list: + """Compile legacy NAT outbound rule switches from OPNsense state. + + Parameters + ---------- + config_entry : ConfigEntry + The Home Assistant config entry. + coordinator : OPNsenseDataUpdateCoordinator + The data update coordinator. + state : MutableMapping[str, Any] + The current state data from OPNsense. + + Returns + ------- + list + A list of OPNsenseNatSwitchLegacy entities for outbound rules. + + """ if not isinstance(state, MutableMapping) or not isinstance( state.get("firewall", {}).get("config", {}).get("nat", {}).get("outbound", {}).get("rule"), list, @@ -167,6 +218,23 @@ async def _compile_service_switches( coordinator: OPNsenseDataUpdateCoordinator, state: MutableMapping[str, Any], ) -> list: + """Compile service switches from OPNsense state. + + Parameters + ---------- + config_entry : ConfigEntry + The Home Assistant config entry. + coordinator : OPNsenseDataUpdateCoordinator + The data update coordinator. + state : MutableMapping[str, Any] + The current state data from OPNsense. + + Returns + ------- + list + A list of OPNsenseServiceSwitch entities. + + """ if not isinstance(state, MutableMapping) or not isinstance(state.get("services"), list): return [] @@ -197,6 +265,23 @@ async def _compile_vpn_switches( coordinator: OPNsenseDataUpdateCoordinator, state: MutableMapping[str, Any], ) -> list: + """Compile VPN switches from OPNsense state. + + Parameters + ---------- + config_entry : ConfigEntry + The Home Assistant config entry. + coordinator : OPNsenseDataUpdateCoordinator + The data update coordinator. + state : MutableMapping[str, Any] + The current state data from OPNsense. + + Returns + ------- + list + A list of OPNsenseVPNSwitch entities. + + """ entities: list = [] for vpn_type in ("openvpn", "wireguard"): for clients_servers in ("clients", "servers"): @@ -230,6 +315,23 @@ async def _compile_static_unbound_switch_legacy( coordinator: OPNsenseDataUpdateCoordinator, state: MutableMapping[str, Any], ) -> list: + """Compile legacy static Unbound blocklist switch from OPNsense state. + + Parameters + ---------- + config_entry : ConfigEntry + The Home Assistant config entry. + coordinator : OPNsenseDataUpdateCoordinator + The data update coordinator. + state : MutableMapping[str, Any] + The current state data from OPNsense. + + Returns + ------- + list + A list containing a single OPNsenseUnboundBlocklistSwitchLegacy entity. + + """ entities: list = [] entity = OPNsenseUnboundBlocklistSwitchLegacy( config_entry=config_entry, @@ -253,6 +355,23 @@ async def _compile_unbound_switches( coordinator: OPNsenseDataUpdateCoordinator, state: MutableMapping[str, Any], ) -> list: + """Compile Unbound blocklist switches from OPNsense state. + + Parameters + ---------- + config_entry : ConfigEntry + The Home Assistant config entry. + coordinator : OPNsenseDataUpdateCoordinator + The data update coordinator. + state : MutableMapping[str, Any] + The current state data from OPNsense. + + Returns + ------- + list + A list of OPNsenseUnboundBlocklistSwitch entities. + + """ if not isinstance(state, MutableMapping): return [] entities: list = [] @@ -282,6 +401,23 @@ async def _compile_firewall_rules_switches( coordinator: OPNsenseDataUpdateCoordinator, state: MutableMapping[str, Any], ) -> list: + """Compile firewall rule switches from OPNsense state. + + Parameters + ---------- + config_entry : ConfigEntry + The Home Assistant config entry. + coordinator : OPNsenseDataUpdateCoordinator + The data update coordinator. + state : MutableMapping[str, Any] + The current state data from OPNsense. + + Returns + ------- + list + A list of OPNsenseFirewallRuleSwitch entities. + + """ if not isinstance(state, MutableMapping) or not isinstance( state.get("firewall", {}).get("rules"), dict ): @@ -309,6 +445,23 @@ async def _compile_nat_source_rules_switches( coordinator: OPNsenseDataUpdateCoordinator, state: MutableMapping[str, Any], ) -> list: + """Compile NAT source rule switches from OPNsense state. + + Parameters + ---------- + config_entry : ConfigEntry + The Home Assistant config entry. + coordinator : OPNsenseDataUpdateCoordinator + The data update coordinator. + state : MutableMapping[str, Any] + The current state data from OPNsense. + + Returns + ------- + list + A list of OPNsenseNATRuleSwitch entities for source NAT rules. + + """ if not isinstance(state, MutableMapping) or not isinstance( state.get("firewall", {}).get("nat", {}).get("source_nat"), dict ): @@ -336,6 +489,23 @@ async def _compile_nat_destination_rules_switches( coordinator: OPNsenseDataUpdateCoordinator, state: MutableMapping[str, Any], ) -> list: + """Compile NAT destination rule switches from OPNsense state. + + Parameters + ---------- + config_entry : ConfigEntry + The Home Assistant config entry. + coordinator : OPNsenseDataUpdateCoordinator + The data update coordinator. + state : MutableMapping[str, Any] + The current state data from OPNsense. + + Returns + ------- + list + A list of OPNsenseNATRuleSwitch entities for destination NAT rules. + + """ if not isinstance(state, MutableMapping) or not isinstance( state.get("firewall", {}).get("nat", {}).get("d_nat"), dict ): @@ -363,6 +533,23 @@ async def _compile_nat_one_to_one_rules_switches( coordinator: OPNsenseDataUpdateCoordinator, state: MutableMapping[str, Any], ) -> list: + """Compile NAT one-to-one rule switches from OPNsense state. + + Parameters + ---------- + config_entry : ConfigEntry + The Home Assistant config entry. + coordinator : OPNsenseDataUpdateCoordinator + The data update coordinator. + state : MutableMapping[str, Any] + The current state data from OPNsense. + + Returns + ------- + list + A list of OPNsenseNATRuleSwitch entities for one-to-one NAT rules. + + """ if not isinstance(state, MutableMapping) or not isinstance( state.get("firewall", {}).get("nat", {}).get("one_to_one"), dict ): @@ -390,6 +577,23 @@ async def _compile_nat_npt_rules_switches( coordinator: OPNsenseDataUpdateCoordinator, state: MutableMapping[str, Any], ) -> list: + """Compile NAT NPT rule switches from OPNsense state. + + Parameters + ---------- + config_entry : ConfigEntry + The Home Assistant config entry. + coordinator : OPNsenseDataUpdateCoordinator + The data update coordinator. + state : MutableMapping[str, Any] + The current state data from OPNsense. + + Returns + ------- + list + A list of OPNsenseNATRuleSwitch entities for NPT NAT rules. + + """ if not isinstance(state, MutableMapping) or not isinstance( state.get("firewall", {}).get("nat", {}).get("npt"), dict ): @@ -417,7 +621,18 @@ async def async_setup_entry( config_entry: ConfigEntry, async_add_entities: AddEntitiesCallback, ) -> None: - """Set up the OPNsense switches.""" + """Set up the OPNsense switches. + + Parameters + ---------- + hass : HomeAssistant + The Home Assistant instance. + config_entry : ConfigEntry + The config entry for this integration. + async_add_entities : AddEntitiesCallback + Callback to add entities to Home Assistant. + + """ coordinator: OPNsenseDataUpdateCoordinator = getattr(config_entry.runtime_data, COORDINATOR) state: MutableMapping[str, Any] = coordinator.data if not isinstance(state, MutableMapping): @@ -527,7 +742,18 @@ def __init__( coordinator: OPNsenseDataUpdateCoordinator, entity_description: SwitchEntityDescription, ) -> None: - """Initialize OPNsense Switch entities.""" + """Initialize OPNsense Switch entities. + + Parameters + ---------- + config_entry : ConfigEntry + The Home Assistant config entry. + coordinator : OPNsenseDataUpdateCoordinator + The data update coordinator. + entity_description : SwitchEntityDescription + The entity description. + + """ name_suffix: str | None = ( entity_description.name if isinstance(entity_description.name, str) else None ) @@ -548,11 +774,26 @@ def __init__( @property def delay_update(self) -> bool: - """Return whether to process the coordinator update or not.""" + """Return whether to process the coordinator update or not. + + Returns + ------- + bool + True if updates should be delayed, False otherwise. + + """ return self._delay_update @delay_update.setter def delay_update(self, value: bool) -> None: + """Set whether to delay coordinator updates. + + Parameters + ---------- + value : bool + True to delay updates, False to allow them. + + """ if value and not self._delay_update: self._delay_update = True self._reset_delay() @@ -563,6 +804,7 @@ def delay_update(self, value: bool) -> None: self._delay_update_remove = None def _reset_delay(self) -> None: + """Reset the delay timer for coordinator updates.""" if self._delay_update_remove: self._delay_update_remove() @@ -584,7 +826,18 @@ def __init__( coordinator: OPNsenseDataUpdateCoordinator, entity_description: SwitchEntityDescription, ) -> None: - """Initialize switch entity.""" + """Initialize switch entity. + + Parameters + ---------- + config_entry : ConfigEntry + The Home Assistant config entry. + coordinator : OPNsenseDataUpdateCoordinator + The data update coordinator. + entity_description : SwitchEntityDescription + The entity description. + + """ super().__init__( config_entry=config_entry, coordinator=coordinator, @@ -596,9 +849,25 @@ def __init__( ) def _opnsense_get_rule_id(self) -> str: + """Get the rule ID from the entity description. + + Returns + ------- + str + The rule ID. + + """ return self.entity_description.key.split(".")[-1] def _opnsense_get_rule(self) -> MutableMapping[str, Any] | None: + """Get the firewall rule data from the coordinator. + + Returns + ------- + MutableMapping[str, Any] | None + The rule data if available, None otherwise. + + """ state: MutableMapping[str, Any] = self.coordinator.data if not isinstance(state, MutableMapping): return None @@ -606,6 +875,7 @@ def _opnsense_get_rule(self) -> MutableMapping[str, Any] | None: @callback def _handle_coordinator_update(self) -> None: + """Handle coordinator update for the firewall rule switch.""" if self.delay_update: _LOGGER.debug( "Skipping coordinator update for firewall rule switch %s due to delay", self.name @@ -669,7 +939,18 @@ def __init__( coordinator: OPNsenseDataUpdateCoordinator, entity_description: SwitchEntityDescription, ) -> None: - """Initialize switch entity.""" + """Initialize switch entity. + + Parameters + ---------- + config_entry : ConfigEntry + The Home Assistant config entry. + coordinator : OPNsenseDataUpdateCoordinator + The data update coordinator. + entity_description : SwitchEntityDescription + The entity description. + + """ super().__init__( config_entry=config_entry, coordinator=coordinator, @@ -686,12 +967,36 @@ def __init__( ) def _get_nat_rule_type(self) -> str: + """Get the NAT rule type from the entity description. + + Returns + ------- + str + The NAT rule type. + + """ return self.entity_description.key.split(".")[2] def _opnsense_get_rule_id(self) -> str: + """Get the rule ID from the entity description. + + Returns + ------- + str + The rule ID. + + """ return self.entity_description.key.split(".")[-1] def _opnsense_get_rule(self) -> MutableMapping[str, Any] | None: + """Get the NAT rule data from the coordinator. + + Returns + ------- + MutableMapping[str, Any] | None + The rule data if available, None otherwise. + + """ state: MutableMapping[str, Any] = self.coordinator.data if not isinstance(state, MutableMapping): return None @@ -704,6 +1009,7 @@ def _opnsense_get_rule(self) -> MutableMapping[str, Any] | None: @callback def _handle_coordinator_update(self) -> None: + """Handle coordinator update for the NAT rule switch.""" if self.delay_update: _LOGGER.debug("Skipping coordinator update for NAT switch %s due to delay", self.name) return @@ -766,7 +1072,18 @@ def __init__( coordinator: OPNsenseDataUpdateCoordinator, entity_description: SwitchEntityDescription, ) -> None: - """Initialize switch entity.""" + """Initialize switch entity. + + Parameters + ---------- + config_entry : ConfigEntry + The Home Assistant config entry. + coordinator : OPNsenseDataUpdateCoordinator + The data update coordinator. + entity_description : SwitchEntityDescription + The entity description. + + """ super().__init__( config_entry=config_entry, coordinator=coordinator, @@ -777,11 +1094,27 @@ def __init__( # _LOGGER.debug(f"[OPNsenseFilterSwitchLegacy init] Name: {self.name}, tracker: {self._tracker}") def _opnsense_get_tracker(self) -> str: + """Get the tracker from the entity description. + + Returns + ------- + str + The tracker string. + + """ parts = self.entity_description.key.split(".") parts.pop(0) return ".".join(parts) def _opnsense_get_rule(self) -> MutableMapping[str, Any] | None: + """Get the filter rule data from the coordinator. + + Returns + ------- + MutableMapping[str, Any] | None + The rule data if available, None otherwise. + + """ state: MutableMapping[str, Any] = self.coordinator.data tracker: str = self._opnsense_get_tracker() if not isinstance(state, MutableMapping): @@ -793,6 +1126,7 @@ def _opnsense_get_rule(self) -> MutableMapping[str, Any] | None: @callback def _handle_coordinator_update(self) -> None: + """Handle coordinator update for the filter switch.""" if self.delay_update: _LOGGER.debug( "Skipping coordinator update for filter switch %s due to delay", self.name @@ -850,7 +1184,18 @@ def __init__( coordinator: OPNsenseDataUpdateCoordinator, entity_description: SwitchEntityDescription, ) -> None: - """Initialize switch entity.""" + """Initialize switch entity. + + Parameters + ---------- + config_entry : ConfigEntry + The Home Assistant config entry. + coordinator : OPNsenseDataUpdateCoordinator + The data update coordinator. + entity_description : SwitchEntityDescription + The entity description. + + """ super().__init__( config_entry=config_entry, coordinator=coordinator, @@ -862,14 +1207,38 @@ def __init__( # _LOGGER.debug(f"[OPNsenseNatSwitchLegacy init] Name: {self.name}, tracker: {self._tracker}, rule_type: {self._rule_type}") def _opnsense_get_rule_type(self) -> str: + """Get the rule type from the entity description. + + Returns + ------- + str + The rule type. + + """ return self.entity_description.key.split(".")[0] def _opnsense_get_tracker(self) -> str: + """Get the tracker from the entity description. + + Returns + ------- + str + The tracker string. + + """ parts = self.entity_description.key.split(".") parts.pop(0) return ".".join(parts) def _opnsense_get_rule(self) -> MutableMapping[str, Any] | None: + """Get the NAT rule data from the coordinator. + + Returns + ------- + MutableMapping[str, Any] | None + The rule data if available, None otherwise. + + """ state: MutableMapping[str, Any] = self.coordinator.data if not isinstance(state, MutableMapping): return None @@ -892,6 +1261,7 @@ def _opnsense_get_rule(self) -> MutableMapping[str, Any] | None: @callback def _handle_coordinator_update(self) -> None: + """Handle coordinator update for the NAT switch.""" if self.delay_update: _LOGGER.debug("Skipping coordinator update for NAT switch %s due to delay", self.name) return @@ -957,7 +1327,18 @@ def __init__( coordinator: OPNsenseDataUpdateCoordinator, entity_description: SwitchEntityDescription, ) -> None: - """Initialize switch entity.""" + """Initialize switch entity. + + Parameters + ---------- + config_entry : ConfigEntry + The Home Assistant config entry. + coordinator : OPNsenseDataUpdateCoordinator + The data update coordinator. + entity_description : SwitchEntityDescription + The entity description. + + """ super().__init__( config_entry=config_entry, coordinator=coordinator, @@ -968,12 +1349,36 @@ def __init__( # _LOGGER.debug(f"[OPNsenseServiceSwitch init] Name: {self.name}, prop_name: {self._prop_name}") def _opnsense_get_property_name(self) -> str: + """Get the property name from the entity description. + + Returns + ------- + str + The property name. + + """ return self.entity_description.key.split(".")[2] def _opnsense_get_service_id(self) -> str: + """Get the service ID from the entity description. + + Returns + ------- + str + The service ID. + + """ return self.entity_description.key.split(".")[1] def _opnsense_get_service(self) -> MutableMapping[str, Any] | None: + """Get the service data from the coordinator. + + Returns + ------- + MutableMapping[str, Any] | None + The service data if available, None otherwise. + + """ state: MutableMapping[str, Any] = self.coordinator.data if not isinstance(state, MutableMapping): return None @@ -985,6 +1390,7 @@ def _opnsense_get_service(self) -> MutableMapping[str, Any] | None: @callback def _handle_coordinator_update(self) -> None: + """Handle coordinator update for the service switch.""" if self.delay_update: _LOGGER.debug( "Skipping coordinator update for service switch %s due to delay", self.name @@ -1114,7 +1520,18 @@ def __init__( coordinator: OPNsenseDataUpdateCoordinator, entity_description: SwitchEntityDescription, ) -> None: - """Initialize switch entity.""" + """Initialize switch entity. + + Parameters + ---------- + config_entry : ConfigEntry + The Home Assistant config entry. + coordinator : OPNsenseDataUpdateCoordinator + The data update coordinator. + entity_description : SwitchEntityDescription + The entity description. + + """ super().__init__( config_entry=config_entry, coordinator=coordinator, @@ -1194,7 +1611,18 @@ def __init__( coordinator: OPNsenseDataUpdateCoordinator, entity_description: SwitchEntityDescription, ) -> None: - """Initialize switch entity.""" + """Initialize switch entity. + + Parameters + ---------- + config_entry : ConfigEntry + The Home Assistant config entry. + coordinator : OPNsenseDataUpdateCoordinator + The data update coordinator. + entity_description : SwitchEntityDescription + The entity description. + + """ super().__init__( config_entry=config_entry, coordinator=coordinator, @@ -1207,6 +1635,7 @@ def __init__( @callback def _handle_coordinator_update(self) -> None: + """Handle coordinator update for the VPN switch.""" if self.delay_update: _LOGGER.debug("Skipping coordinator update for VPN switch %s due to delay", self.name) return @@ -1267,7 +1696,14 @@ def _handle_coordinator_update(self) -> None: # _LOGGER.debug(f"[OPNsenseVPNSwitch handle_coordinator_update] Name: {self.name}, available: {self.available}, is_on: {self.is_on}, extra_state_attributes: {self.extra_state_attributes}") async def async_turn_on(self, **kwargs: Any) -> None: - """Turn the entity on.""" + """Turn on the VPN switch. + + Parameters + ---------- + **kwargs : Any + Additional keyword arguments. + + """ if self.is_on or not self._client: return @@ -1284,7 +1720,14 @@ async def async_turn_on(self, **kwargs: Any) -> None: _LOGGER.error("Failed to turn on VPN: %s", self.name) async def async_turn_off(self, **kwargs: Any) -> None: - """Turn the entity off.""" + """Turn off the VPN switch. + + Parameters + ---------- + **kwargs : Any + Additional keyword arguments. + + """ if not self.is_on or not self._client: return diff --git a/tests/test_coordinator.py b/tests/test_coordinator.py index a7292c25..6268f7f0 100644 --- a/tests/test_coordinator.py +++ b/tests/test_coordinator.py @@ -403,7 +403,7 @@ def test_build_categories_returns_empty_when_no_config(make_config_entry, fake_c (CONF_SYNC_GATEWAYS, ["gateways"]), (CONF_SYNC_SERVICES, ["services"]), (CONF_SYNC_NOTICES, ["notices"]), - (CONF_SYNC_FIREWALL_AND_NAT, ["config"]), + (CONF_SYNC_FIREWALL_AND_NAT, ["firewall"]), (CONF_SYNC_UNBOUND, [ATTR_UNBOUND_BLOCKLIST]), (CONF_SYNC_INTERFACES, ["interfaces"]), (CONF_SYNC_CERTIFICATES, ["certificates"]), diff --git a/tests/test_pyopnsense.py b/tests/test_pyopnsense.py index 111ccd02..465a26ba 100644 --- a/tests/test_pyopnsense.py +++ b/tests/test_pyopnsense.py @@ -1313,21 +1313,21 @@ async def test_enable_and_disable_filter_rules_and_nat_port_forward(make_client) url="http://localhost", username="u", password="p", session=session ) - # enable_filter_rule_by_created_time: rule has 'disabled' -> should remove and call restore+configure + # enable_filter_rule_by_created_time_legacy: rule has 'disabled' -> should remove and call restore+configure cfg_enable = {"filter": {"rule": [{"created": {"time": "t-enable"}, "disabled": "1"}]}} client.get_config = AsyncMock(return_value=cfg_enable) client._restore_config_section = AsyncMock() client._filter_configure = AsyncMock() - await client.enable_filter_rule_by_created_time("t-enable") + await client.enable_filter_rule_by_created_time_legacy("t-enable") client._restore_config_section.assert_called() client._filter_configure.assert_awaited() - # disable_filter_rule_by_created_time: rule missing 'disabled' -> should add it and call restore+configure + # disable_filter_rule_by_created_time_legacy: rule missing 'disabled' -> should add it and call restore+configure cfg_disable = {"filter": {"rule": [{"created": {"time": "t-disable"}}]}} client.get_config = AsyncMock(return_value=cfg_disable) client._restore_config_section = AsyncMock() client._filter_configure = AsyncMock() - await client.disable_filter_rule_by_created_time("t-disable") + await client.disable_filter_rule_by_created_time_legacy("t-disable") client._restore_config_section.assert_called() client._filter_configure.assert_awaited() @@ -1500,15 +1500,22 @@ def add_entities(ents): # filter switches: validate via public switch setup path as a smoke test state2 = { - "config": { - "filter": { - "rule": [ - {"descr": "Anti-Lockout Rule", "created": {"time": "t1"}}, - {"descr": "Normal", "created": {"time": "t2"}, "associated-rule-id": "r1"}, - {"descr": "Ok", "created": {"time": "t3"}}, - ] + "host_firmware_version": "25.7.8", + "firewall": { + "config": { + "filter": { + "rule": [ + {"description": "Anti-Lockout Rule", "created": {"time": "t1"}}, + { + "description": "Normal", + "created": {"time": "t2"}, + "associated-rule-id": "r1", + }, + {"description": "Ok", "created": {"time": "t3"}}, + ] + } } - } + }, } # prepare a switch config entry with filter sync enabled switch_cfg = make_config_entry( @@ -2692,7 +2699,7 @@ async def test_get_config_and_rule_enable_disable_branches() -> None: client._exec_php = AsyncMock(return_value=fake_config) # calling enable should remove 'disabled' and call restore/configure (no exception) - await client.enable_filter_rule_by_created_time("t1") + await client.enable_filter_rule_by_created_time_legacy("t1") # disable_nat_port_forward: add a rule without 'disabled' and expect it to set 'disabled' client._exec_php = AsyncMock( @@ -2890,7 +2897,7 @@ async def test_reset_and_get_query_counts(): ), ], ) -async def test_enable_filter_rule_by_created_time( +async def test_enable_filter_rule_by_created_time_legacy( make_client, rules, created_time, should_call ) -> None: """Ensure enabling a filter rule removes 'disabled' and triggers restore/configure only when appropriate. @@ -2906,7 +2913,7 @@ async def test_enable_filter_rule_by_created_time( client._restore_config_section = AsyncMock() client._filter_configure = AsyncMock() - await client.enable_filter_rule_by_created_time(created_time) + await client.enable_filter_rule_by_created_time_legacy(created_time) if should_call: client._restore_config_section.assert_awaited() diff --git a/tests/test_switch.py b/tests/test_switch.py index 2adf0927..72a76b4b 100644 --- a/tests/test_switch.py +++ b/tests/test_switch.py @@ -53,12 +53,25 @@ def make_coord(data): [ ( _compile_filter_switches_legacy, - {"config": {"filter": {"rule": [{"descr": "Allow LAN", "created": {"time": "t1"}}]}}}, - ("enable_filter_rule_by_created_time", "disable_filter_rule_by_created_time"), + { + "firewall": { + "config": { + "filter": {"rule": [{"descr": "Allow LAN", "created": {"time": "t1"}}]} + } + } + }, + ( + "enable_filter_rule_by_created_time_legacy", + "disable_filter_rule_by_created_time_legacy", + ), ), ( _compile_port_forward_switches_legacy, - {"config": {"nat": {"rule": [{"descr": "PF", "created": {"time": "p1"}}]}}}, + { + "firewall": { + "config": {"nat": {"rule": [{"descr": "PF", "created": {"time": "p1"}}]}} + } + }, ( "enable_nat_port_forward_rule_by_created_time_legacy", "disable_nat_port_forward_rule_by_created_time_legacy", @@ -67,8 +80,10 @@ def make_coord(data): ( _compile_nat_outbound_switches_legacy, { - "config": { - "nat": {"outbound": {"rule": [{"descr": "OB", "created": {"time": "o1"}}]}} + "firewall": { + "config": { + "nat": {"outbound": {"rule": [{"descr": "OB", "created": {"time": "o1"}}]}} + } } }, ( @@ -151,12 +166,15 @@ async def test_compile_port_forward_skips_non_dict(coordinator, make_config_entr """Port forward compilation should skip non-dict rule entries.""" config_entry = make_config_entry( data={CONF_DEVICE_UNIQUE_ID: "dev1", "url": "http://example"}, + options={CONF_SYNC_FIREWALL_AND_NAT: True}, title="OPNsenseTest", ) setattr(config_entry.runtime_data, COORDINATOR, coordinator) # include a non-dict in nat.rule which should be skipped state = { - "config": {"nat": {"rule": ["not-a-dict", {"descr": "PF", "created": {"time": "p2"}}]}} + "firewall": { + "config": {"nat": {"rule": ["not-a-dict", {"descr": "PF", "created": {"time": "p2"}}]}} + } } coordinator.data = state ents = await _compile_port_forward_switches_legacy(config_entry, coordinator, state) @@ -174,11 +192,13 @@ def fake_add_entities(entities): # create a state that contains one of each entity type state = { - "config": { - "filter": {"rule": [{"descr": "Allow", "created": {"time": "f1"}}]}, - "nat": { - "rule": [{"descr": "PF", "created": {"time": "p1"}}], - "outbound": {"rule": [{"descr": "OB", "created": {"time": "o1"}}]}, + "firewall": { + "config": { + "filter": {"rule": [{"descr": "Allow", "created": {"time": "f1"}}]}, + "nat": { + "rule": [{"descr": "PF", "created": {"time": "p1"}}], + "outbound": {"rule": [{"descr": "OB", "created": {"time": "o1"}}]}, + }, }, }, "services": [{"id": "s1", "name": "svc", "locked": 0, "status": True}], @@ -206,7 +226,7 @@ def fake_add_entities(entities): # compute expected counts from coordinator.data to avoid brittle hard-coded value expected = 0 - cfg = coordinator.data.get("config", {}) + cfg = coordinator.data.get("firewall", {}).get("config", {}) # filter rules expected += len(cfg.get("filter", {}).get("rule", []) or []) # port forward rules @@ -591,18 +611,23 @@ async def test_filter_disabled_and_missing(coordinator, ph_hass, make_config_ent """Filter compilation handles missing and disabled rules correctly.""" config_entry = make_config_entry( data={CONF_DEVICE_UNIQUE_ID: "dev1", "url": "http://example"}, + options={CONF_SYNC_FIREWALL_AND_NAT: True}, title="OPNsenseTest", ) setattr(config_entry.runtime_data, COORDINATOR, coordinator) # missing rules -> compile returns [] - state = {"config": {"filter": {"rule": []}}} + state = {"firewall": {"config": {"filter": {"rule": []}}}} coordinator.data = state entities = await _compile_filter_switches_legacy(config_entry, coordinator, state) assert entities == [] # disabled rule -> is_on False state = { - "config": {"filter": {"rule": [{"descr": "x", "created": {"time": "t2"}, "disabled": "1"}]}} + "firewall": { + "config": { + "filter": {"rule": [{"descr": "x", "created": {"time": "t2"}, "disabled": "1"}]} + } + } } coordinator.data = state entities = await _compile_filter_switches_legacy(config_entry, coordinator, state) @@ -665,31 +690,42 @@ async def test_unbound_skips_update_when_delay_set(coordinator, ph_hass, make_co @pytest.mark.asyncio @pytest.mark.parametrize( - "kind,compile_fn,state,selector", + "kind,compile_fn,state,selector,options", [ ( "unbound", _compile_static_unbound_switch_legacy, {"unbound_blocklist": {"legacy": {"enabled": "1"}}}, "first", + {CONF_SYNC_UNBOUND: True}, ), ( "filter", _compile_filter_switches_legacy, { - "config": { - "filter": { - "rule": [{"descr": "Allow", "created": {"time": "fdelay"}, "disabled": "0"}] + "firewall": { + "config": { + "filter": { + "rule": [ + {"descr": "Allow", "created": {"time": "fdelay"}, "disabled": "0"} + ] + } } } }, "first", + {CONF_SYNC_FIREWALL_AND_NAT: True}, ), ( "nat", _compile_port_forward_switches_legacy, - {"config": {"nat": {"rule": [{"descr": "PF", "created": {"time": "pdelay"}}]}}}, + { + "firewall": { + "config": {"nat": {"rule": [{"descr": "PF", "created": {"time": "pdelay"}}]}} + } + }, "first", + {CONF_SYNC_FIREWALL_AND_NAT: True}, ), ( "service", @@ -706,6 +742,7 @@ async def test_unbound_skips_update_when_delay_set(coordinator, ph_hass, make_co ] }, "first", + {CONF_SYNC_SERVICES: True}, ), ( "vpn", @@ -715,14 +752,17 @@ async def test_unbound_skips_update_when_delay_set(coordinator, ph_hass, make_co "wireguard": {"clients": {}, "servers": {}}, }, "endswith:v1", + {CONF_SYNC_VPN: True}, ), ], ) async def test_delay_skips_update_parametrized( - kind, compile_fn, state, selector, coordinator, ph_hass, make_config_entry + kind, compile_fn, state, selector, options, coordinator, ph_hass, make_config_entry ): """Parametrized test asserting handlers return early when delay_update is set.""" - config_entry = make_config_entry({CONF_DEVICE_UNIQUE_ID: "dev1", "url": "http://example"}) + config_entry = make_config_entry( + data={CONF_DEVICE_UNIQUE_ID: "dev1", "url": "http://example"}, options=options + ) setattr(config_entry.runtime_data, COORDINATOR, coordinator) coordinator.data = state @@ -797,9 +837,15 @@ async def test_switch_handle_error_sets_unavailable( if kind == "filter": # compile one valid filter entity then monkeypatch to produce error - config_entry = make_config_entry({CONF_DEVICE_UNIQUE_ID: "dev1"}) + config_entry = make_config_entry( + data={CONF_DEVICE_UNIQUE_ID: "dev1"}, options={CONF_SYNC_FIREWALL_AND_NAT: True} + ) setattr(config_entry.runtime_data, COORDINATOR, coordinator) - state = {"config": {"filter": {"rule": [{"descr": "Good", "created": {"time": "t1"}}]}}} + state = { + "firewall": { + "config": {"filter": {"rule": [{"descr": "Good", "created": {"time": "t1"}}]}} + } + } coordinator.data = state ent = (await _compile_filter_switches_legacy(config_entry, coordinator, state))[0] ent.hass = hass_local @@ -1056,19 +1102,24 @@ def new_remover(): @pytest.mark.asyncio async def test_compile_filter_skip_and_invalid_rules(coordinator, make_config_entry): """Filter compilation skips invalid and non-dict rule entries.""" - config_entry = make_config_entry({CONF_DEVICE_UNIQUE_ID: "dev1", "url": "http://example"}) + config_entry = make_config_entry( + data={CONF_DEVICE_UNIQUE_ID: "dev1", "url": "http://example"}, + options={CONF_SYNC_FIREWALL_AND_NAT: True}, + ) setattr(config_entry.runtime_data, COORDINATOR, coordinator) # include various rules that should be skipped state = { - "config": { - "filter": { - "rule": [ - {"descr": "Anti-Lockout Rule", "created": {"time": "a1"}}, - {"associated-rule-id": "x", "created": {"time": "a2"}}, - {"descr": "No tracker"}, - ["not", "a", "dict"], - {"descr": "Good", "created": {"time": "g1"}}, - ] + "firewall": { + "config": { + "filter": { + "rule": [ + {"description": "Anti-Lockout Rule", "created": {"time": "a1"}}, + {"associated-rule-id": "x", "created": {"time": "a2"}}, + {"description": "No tracker"}, + ["not", "a", "dict"], + {"description": "Good", "created": {"time": "g1"}}, + ] + } } } } @@ -1081,16 +1132,21 @@ async def test_compile_filter_skip_and_invalid_rules(coordinator, make_config_en @pytest.mark.asyncio async def test_compile_nat_outbound_skips_auto_created(coordinator, make_config_entry): """Outbound NAT compilation ignores auto-created rules.""" - config_entry = make_config_entry({CONF_DEVICE_UNIQUE_ID: "dev1", "url": "http://example"}) + config_entry = make_config_entry( + data={CONF_DEVICE_UNIQUE_ID: "dev1", "url": "http://example"}, + options={CONF_SYNC_FIREWALL_AND_NAT: True}, + ) setattr(config_entry.runtime_data, COORDINATOR, coordinator) state = { - "config": { - "nat": { - "outbound": { - "rule": [ - {"descr": "Auto created rule", "created": {"time": "x1"}}, - {"descr": "Manual", "created": {"time": "x2"}}, - ] + "firewall": { + "config": { + "nat": { + "outbound": { + "rule": [ + {"description": "Auto created rule", "created": {"time": "x1"}}, + {"description": "Manual", "created": {"time": "x2"}}, + ] + } } } } @@ -1518,7 +1574,7 @@ async def test_compile_port_forward_with_missing_rules(coordinator, make_config_ # port forward compile should return [] when nat not present or rules missing config_entry = make_config_entry({CONF_DEVICE_UNIQUE_ID: "dev1"}) setattr(config_entry.runtime_data, COORDINATOR, coordinator) - coordinator.data = {"config": {}} + coordinator.data = {"firewall": {"config": {}}} res = await _compile_port_forward_switches_legacy(config_entry, coordinator, coordinator.data) assert res == [] From 14f68b54a6114df4cd8b49a2b99f0fc3a4589715 Mon Sep 17 00:00:00 2001 From: Snuffy2 Date: Mon, 2 Feb 2026 21:33:06 -0500 Subject: [PATCH 04/15] Improve test coverage --- .pre-commit-config.yaml | 4 +- tests/test_switch.py | 344 +++++++++++++++++++++++++++++++++++++++- 2 files changed, 343 insertions(+), 5 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fc24c081..5cca521a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,7 +12,7 @@ repos: - id: check-toml - id: check-added-large-files - repo: https://github.com/rhysd/actionlint - rev: v1.7.9 + rev: v1.7.10 hooks: - id: actionlint # Note: shellcheck cannot directly parse YAML; actionlint extracts workflow @@ -31,7 +31,7 @@ repos: additional_dependencies: - tomli - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.14.8 + rev: v0.14.14 hooks: # Run the linter. - id: ruff-check diff --git a/tests/test_switch.py b/tests/test_switch.py index 72a76b4b..55d20a92 100644 --- a/tests/test_switch.py +++ b/tests/test_switch.py @@ -25,11 +25,18 @@ from custom_components.opnsense.coordinator import OPNsenseDataUpdateCoordinator from custom_components.opnsense.switch import ( OPNsenseFilterSwitchLegacy, + OPNsenseFirewallRuleSwitch, + OPNsenseNATRuleSwitch, OPNsenseNatSwitchLegacy, OPNsenseServiceSwitch, OPNsenseVPNSwitch, _compile_filter_switches_legacy, + _compile_firewall_rules_switches, + _compile_nat_destination_rules_switches, + _compile_nat_npt_rules_switches, + _compile_nat_one_to_one_rules_switches, _compile_nat_outbound_switches_legacy, + _compile_nat_source_rules_switches, _compile_port_forward_switches_legacy, _compile_service_switches, _compile_static_unbound_switch_legacy, @@ -111,6 +118,94 @@ def make_coord(data): {"unbound_blocklist": {"legacy": {"enabled": "1"}}}, ("enable_unbound_blocklist", "disable_unbound_blocklist"), ), + ( + _compile_firewall_rules_switches, + { + "firewall": { + "rules": { + "rule1": { + "uuid": "rule1", + "description": "Test Firewall Rule", + "%interface": "wan", + "enabled": "1", + } + } + } + }, + ("toggle_firewall_rule", "toggle_firewall_rule"), + ), + ( + _compile_nat_source_rules_switches, + { + "firewall": { + "nat": { + "source_nat": { + "nat1": { + "uuid": "nat1", + "description": "Source NAT Rule", + "%interface": "wan", + "enabled": "1", + } + } + } + } + }, + ("toggle_nat_rule", "toggle_nat_rule"), + ), + ( + _compile_nat_destination_rules_switches, + { + "firewall": { + "nat": { + "d_nat": { + "dnat1": { + "uuid": "dnat1", + "description": "Destination NAT Rule", + "%interface": "wan", + "enabled": "1", + } + } + } + } + }, + ("toggle_nat_rule", "toggle_nat_rule"), + ), + ( + _compile_nat_one_to_one_rules_switches, + { + "firewall": { + "nat": { + "one_to_one": { + "oto1": { + "uuid": "oto1", + "description": "One-to-One NAT Rule", + "%interface": "wan", + "enabled": "1", + } + } + } + } + }, + ("toggle_nat_rule", "toggle_nat_rule"), + ), + ( + _compile_nat_npt_rules_switches, + { + "firewall": { + "nat": { + "npt": { + "npt1": { + "uuid": "npt1", + "description": "NPT NAT Rule", + "%interface": "wan", + "enabled": "1", + } + } + } + } + }, + ("toggle_nat_rule", "toggle_nat_rule"), + ), ], ) async def test_switch_toggle_variants( @@ -151,15 +246,33 @@ async def test_switch_toggle_variants( # call turn_on/turn_off and assert client methods called await ent.async_turn_on() - # ensure the async client coroutine was actually awaited - getattr(ent._client, client_methods[0]).assert_awaited_once() # turning on should set delay_update for entities that perform delayed updates assert ent.delay_update is True + await ent.async_turn_off() - getattr(ent._client, client_methods[1]).assert_awaited_once() # turning off should also set delay_update assert ent.delay_update is True + # Check that the correct client methods were called + if client_methods[0] == client_methods[1]: + # Same method for on/off with different parameters + if "firewall_rule" in client_methods[0]: + # toggle_firewall_rule(rule_id, action) + getattr(ent._client, client_methods[0]).assert_any_call(ent._rule_id, "on") + getattr(ent._client, client_methods[1]).assert_any_call(ent._rule_id, "off") + else: + # toggle_nat_rule(rule_type, rule_id, action) + getattr(ent._client, client_methods[0]).assert_any_call( + ent._nat_rule_type, ent._rule_id, "on" + ) + getattr(ent._client, client_methods[1]).assert_any_call( + ent._nat_rule_type, ent._rule_id, "off" + ) + else: + # Different methods for on/off + getattr(ent._client, client_methods[0]).assert_awaited_once() + getattr(ent._client, client_methods[1]).assert_awaited_once() + @pytest.mark.asyncio async def test_compile_port_forward_skips_non_dict(coordinator, make_config_entry): @@ -247,6 +360,84 @@ def fake_add_entities(entities): assert calls.get("len") == expected +@pytest.mark.asyncio +async def test_async_setup_entry_new_firewall_api(coordinator, ph_hass, make_config_entry): + """Async setup should create entities for new firewall API (>= 26.1).""" + calls = {} + + def fake_add_entities(entities): + calls["len"] = len(entities) + + # create a state that contains new firewall API structure + state = { + "firewall": { + "rules": { + "rule1": { + "uuid": "rule1", + "description": "Test Firewall Rule", + "%interface": "wan", + "enabled": "1", + } + }, + "nat": { + "source_nat": { + "nat1": { + "uuid": "nat1", + "description": "Test Source NAT", + "%interface": "wan", + "enabled": "1", + } + }, + "d_nat": { + "dnat1": { + "uuid": "dnat1", + "description": "Test Destination NAT", + "%interface": "wan", + "enabled": "1", + } + }, + "one_to_one": { + "oto1": { + "uuid": "oto1", + "description": "Test One-to-One NAT", + "%interface": "wan", + "enabled": "1", + } + }, + "npt": { + "npt1": { + "uuid": "npt1", + "description": "Test NPT NAT", + "%interface": "wan", + "enabled": "1", + } + }, + }, + }, + "host_firmware_version": "26.1.0", + } + coordinator.data = state + + config_entry = make_config_entry( + data={ + CONF_DEVICE_UNIQUE_ID: "dev1", + CONF_SYNC_FIREWALL_AND_NAT: True, + CONF_SYNC_SERVICES: False, + CONF_SYNC_VPN: False, + CONF_SYNC_UNBOUND: False, + }, + title="OPNsenseTest", + ) + setattr(config_entry.runtime_data, COORDINATOR, coordinator) + + hass = ph_hass + await switch_mod.async_setup_entry(hass, config_entry, fake_add_entities) + + # Should create entities for each rule type in new API + expected = 5 # 1 firewall rule + 4 NAT rules + assert calls.get("len") == expected + + def test_vpn_icon_property(make_config_entry): """VPN switch exposes the expected icon when available and on.""" desc = SwitchEntityDescription(key="openvpn.clients.c1", name="VPNC") @@ -1795,3 +1986,150 @@ async def test_unbound_extended_switch_toggle_failures(coordinator, ph_hass, mak await ent.async_turn_off() assert ent.is_on is True # Should remain on assert ent.delay_update is False # Should not set delay + + +@pytest.mark.asyncio +async def test_compile_firewall_rules_switches(coordinator, make_config_entry): + """Test compilation of firewall rule switches for new API.""" + config_entry = make_config_entry({CONF_DEVICE_UNIQUE_ID: "dev1"}) + state = { + "firewall": { + "rules": { + "rule1": { + "uuid": "rule1", + "description": "Test Rule 1", + "%interface": "wan", + "enabled": "1", + }, + "rule2": { + "uuid": "rule2", + "description": "Test Rule 2", + "%interface": "lan", + "enabled": "0", + }, + } + } + } + ents = await _compile_firewall_rules_switches(config_entry, coordinator, state) + assert len(ents) == 2 + assert isinstance(ents[0], OPNsenseFirewallRuleSwitch) + assert isinstance(ents[1], OPNsenseFirewallRuleSwitch) + assert ents[0].entity_description.key == "firewall.rule.rule1" + assert ents[1].entity_description.key == "firewall.rule.rule2" + + +@pytest.mark.asyncio +async def test_compile_nat_source_rules_switches(coordinator, make_config_entry): + """Test compilation of NAT source rule switches.""" + config_entry = make_config_entry({CONF_DEVICE_UNIQUE_ID: "dev1"}) + state = { + "firewall": { + "nat": { + "source_nat": { + "nat1": { + "uuid": "nat1", + "description": "Source NAT Rule", + "%interface": "wan", + "enabled": "1", + } + } + } + } + } + ents = await _compile_nat_source_rules_switches(config_entry, coordinator, state) + assert len(ents) == 1 + assert isinstance(ents[0], OPNsenseNATRuleSwitch) + assert ents[0].entity_description.key == "firewall.nat.source_nat.nat1" + + +@pytest.mark.asyncio +async def test_compile_nat_destination_rules_switches(coordinator, make_config_entry): + """Test compilation of NAT destination rule switches.""" + config_entry = make_config_entry({CONF_DEVICE_UNIQUE_ID: "dev1"}) + state = { + "firewall": { + "nat": { + "d_nat": { + "dnat1": { + "uuid": "dnat1", + "description": "Destination NAT Rule", + "%interface": "wan", + "enabled": "1", + } + } + } + } + } + ents = await _compile_nat_destination_rules_switches(config_entry, coordinator, state) + assert len(ents) == 1 + assert isinstance(ents[0], OPNsenseNATRuleSwitch) + assert ents[0].entity_description.key == "firewall.nat.d_nat.dnat1" + + +@pytest.mark.asyncio +async def test_compile_nat_one_to_one_rules_switches(coordinator, make_config_entry): + """Test compilation of NAT one-to-one rule switches.""" + config_entry = make_config_entry({CONF_DEVICE_UNIQUE_ID: "dev1"}) + state = { + "firewall": { + "nat": { + "one_to_one": { + "oto1": { + "uuid": "oto1", + "description": "One-to-One NAT Rule", + "%interface": "wan", + "enabled": "1", + } + } + } + } + } + ents = await _compile_nat_one_to_one_rules_switches(config_entry, coordinator, state) + assert len(ents) == 1 + assert isinstance(ents[0], OPNsenseNATRuleSwitch) + assert ents[0].entity_description.key == "firewall.nat.one_to_one.oto1" + + +@pytest.mark.asyncio +async def test_compile_nat_npt_rules_switches(coordinator, make_config_entry): + """Test compilation of NAT NPT rule switches.""" + config_entry = make_config_entry({CONF_DEVICE_UNIQUE_ID: "dev1"}) + state = { + "firewall": { + "nat": { + "npt": { + "npt1": { + "uuid": "npt1", + "description": "NPT NAT Rule", + "%interface": "wan", + "enabled": "1", + } + } + } + } + } + ents = await _compile_nat_npt_rules_switches(config_entry, coordinator, state) + assert len(ents) == 1 + assert isinstance(ents[0], OPNsenseNATRuleSwitch) + assert ents[0].entity_description.key == "firewall.nat.npt.npt1" + + +@pytest.mark.asyncio +async def test_compile_new_api_empty_state(coordinator, make_config_entry): + """Test compilation functions handle empty/missing state gracefully.""" + config_entry = make_config_entry({CONF_DEVICE_UNIQUE_ID: "dev1"}) + + # Test empty state + ents = await _compile_firewall_rules_switches(config_entry, coordinator, {}) + assert ents == [] + + ents = await _compile_nat_source_rules_switches(config_entry, coordinator, {}) + assert ents == [] + + # Test state with firewall but no rules + state = {"firewall": {}} + ents = await _compile_firewall_rules_switches(config_entry, coordinator, state) + assert ents == [] + + ents = await _compile_nat_source_rules_switches(config_entry, coordinator, state) + assert ents == [] From 02f17997b390e918222dcae851258af93b7cced6 Mon Sep 17 00:00:00 2001 From: Snuffy2 Date: Mon, 2 Feb 2026 21:50:08 -0500 Subject: [PATCH 05/15] Refinements --- .../opnsense/pyopnsense/__init__.py | 22 ++++--- custom_components/opnsense/switch.py | 66 +++++++++++++------ tests/test_pyopnsense.py | 14 ++-- tests/test_switch.py | 20 ++++-- 4 files changed, 80 insertions(+), 42 deletions(-) diff --git a/custom_components/opnsense/pyopnsense/__init__.py b/custom_components/opnsense/pyopnsense/__init__.py index 1051989d..6dced9b1 100644 --- a/custom_components/opnsense/pyopnsense/__init__.py +++ b/custom_components/opnsense/pyopnsense/__init__.py @@ -3,9 +3,11 @@ from abc import ABC import asyncio from collections.abc import Callable, MutableMapping +import csv from datetime import datetime, timedelta, timezone from functools import partial import inspect +from io import StringIO import ipaddress import json import logging @@ -921,6 +923,8 @@ async def disable_nat_outbound_rule_by_created_time_legacy(self, created_time: s """Disable NAT Outbound Rule.""" config: MutableMapping[str, Any] = await self.get_config() for rule in config.get("nat", {}).get("outbound", {}).get("rule", []): + if "created" not in rule or "time" not in rule["created"]: + continue if rule["created"]["time"] != created_time: continue @@ -1007,16 +1011,16 @@ async def _get_firewall_rules(self, interface_map: dict[str, Any]) -> dict[str, # _LOGGER.debug("[get_firewall_rules] response: %s", response) if not response or not isinstance(response, str): return {} - lines = response.strip().split("\n") - if len(lines) < 2: + + try: + reader = csv.DictReader(StringIO(response)) + except (csv.Error, ValueError) as e: + _LOGGER.error("Failed to parse firewall rules CSV: %s", e) return {} - headers = lines[0].split(",") - rules = [] - for line in lines[1:]: - if line.strip(): - values = line.split(",") - rule = dict(zip(headers, values, strict=True)) - rules.append(rule) + if not reader.fieldnames: + return {} + rules = [row for row in reader if row] + # _LOGGER.debug("[get_firewall_rules] rules: %s", rules) rules_dict: dict[str, Any] = {} for rule in rules: diff --git a/custom_components/opnsense/switch.py b/custom_components/opnsense/switch.py index bb2f0d6f..f7329fbb 100644 --- a/custom_components/opnsense/switch.py +++ b/custom_components/opnsense/switch.py @@ -425,6 +425,8 @@ async def _compile_firewall_rules_switches( entities: list = [] for rule in state.get("firewall", {}).get("rules", {}).values(): + if not isinstance(rule, dict): + continue entity = OPNsenseFirewallRuleSwitch( config_entry=config_entry, coordinator=coordinator, @@ -469,6 +471,8 @@ async def _compile_nat_source_rules_switches( entities: list = [] for rule in state.get("firewall", {}).get("nat", {}).get("source_nat", {}).values(): + if not isinstance(rule, dict): + continue entity = OPNsenseNATRuleSwitch( config_entry=config_entry, coordinator=coordinator, @@ -513,6 +517,8 @@ async def _compile_nat_destination_rules_switches( entities: list = [] for rule in state.get("firewall", {}).get("nat", {}).get("d_nat", {}).values(): + if not isinstance(rule, dict): + continue entity = OPNsenseNATRuleSwitch( config_entry=config_entry, coordinator=coordinator, @@ -557,6 +563,8 @@ async def _compile_nat_one_to_one_rules_switches( entities: list = [] for rule in state.get("firewall", {}).get("nat", {}).get("one_to_one", {}).values(): + if not isinstance(rule, dict): + continue entity = OPNsenseNATRuleSwitch( config_entry=config_entry, coordinator=coordinator, @@ -601,6 +609,8 @@ async def _compile_nat_npt_rules_switches( entities: list = [] for rule in state.get("firewall", {}).get("nat", {}).get("npt", {}).values(): + if not isinstance(rule, dict): + continue entity = OPNsenseNATRuleSwitch( config_entry=config_entry, coordinator=coordinator, @@ -690,7 +700,7 @@ async def async_setup_entry( ValueError, ) as e: _LOGGER.error( - "Error comparing firmware version %s when determining creating Unbound Blocklist switches. %s: %s", + "Error comparing firewall/NAT firmware version %s: %s: %s", firmware, type(e).__name__, e, @@ -906,21 +916,27 @@ async def async_turn_on(self, **kwargs: Any) -> None: """Turn the entity on.""" if self._rule_id is None or not self._client: return - await self._client.toggle_firewall_rule(self._rule_id, "on") - _LOGGER.info("Turned on firewall rule: %s", self.name) - self._attr_is_on = True - self.async_write_ha_state() - self.delay_update = True + result = await self._client.toggle_firewall_rule(self._rule_id, "on") + if result: + _LOGGER.info("Turned on firewall rule: %s", self.name) + self._attr_is_on = True + self.async_write_ha_state() + self.delay_update = True + else: + _LOGGER.error("Failed to turn on firewall rule: %s", self.name) async def async_turn_off(self, **kwargs: Any) -> None: """Turn the entity off.""" if self._rule_id is None or not self._client: return - await self._client.toggle_firewall_rule(self._rule_id, "off") - _LOGGER.info("Turned off firewall rule: %s", self.name) - self._attr_is_on = False - self.async_write_ha_state() - self.delay_update = True + result = await self._client.toggle_firewall_rule(self._rule_id, "off") + if result: + _LOGGER.info("Turned off firewall rule: %s", self.name) + self._attr_is_on = False + self.async_write_ha_state() + self.delay_update = True + else: + _LOGGER.error("Failed to turn off firewall rule: %s", self.name) @property def icon(self) -> str | None: @@ -1039,21 +1055,27 @@ async def async_turn_on(self, **kwargs: Any) -> None: """Turn the entity on.""" if self._rule_id is None or not self._client: return - await self._client.toggle_nat_rule(self._nat_rule_type, self._rule_id, "on") - _LOGGER.info("Turned on NAT rule: %s", self.name) - self._attr_is_on = True - self.async_write_ha_state() - self.delay_update = True + result = await self._client.toggle_nat_rule(self._nat_rule_type, self._rule_id, "on") + if result: + _LOGGER.info("Turned on NAT rule: %s", self.name) + self._attr_is_on = True + self.async_write_ha_state() + self.delay_update = True + else: + _LOGGER.error("Failed to turn on NAT rule: %s", self.name) async def async_turn_off(self, **kwargs: Any) -> None: """Turn the entity off.""" if self._rule_id is None or not self._client: return - await self._client.toggle_nat_rule(self._nat_rule_type, self._rule_id, "off") - _LOGGER.info("Turned off NAT rule: %s", self.name) - self._attr_is_on = False - self.async_write_ha_state() - self.delay_update = True + result = await self._client.toggle_nat_rule(self._nat_rule_type, self._rule_id, "off") + if result: + _LOGGER.info("Turned off NAT rule: %s", self.name) + self._attr_is_on = False + self.async_write_ha_state() + self.delay_update = True + else: + _LOGGER.error("Failed to turn off NAT rule: %s", self.name) @property def icon(self) -> str | None: @@ -1267,6 +1289,8 @@ def _handle_coordinator_update(self) -> None: return self._rule = self._opnsense_get_rule() if not isinstance(self._rule, MutableMapping): + self._available = False + self.async_write_ha_state() return try: self._attr_is_on = "disabled" not in self._rule diff --git a/tests/test_pyopnsense.py b/tests/test_pyopnsense.py index 465a26ba..a71b7869 100644 --- a/tests/test_pyopnsense.py +++ b/tests/test_pyopnsense.py @@ -25,6 +25,7 @@ sensor as sensor_mod, switch as switch_mod, ) +from custom_components.opnsense.const import CONF_SYNC_FIREWALL_AND_NAT def test_human_friendly_duration() -> None: @@ -1319,7 +1320,7 @@ async def test_enable_and_disable_filter_rules_and_nat_port_forward(make_client) client._restore_config_section = AsyncMock() client._filter_configure = AsyncMock() await client.enable_filter_rule_by_created_time_legacy("t-enable") - client._restore_config_section.assert_called() + client._restore_config_section.assert_awaited() client._filter_configure.assert_awaited() # disable_filter_rule_by_created_time_legacy: rule missing 'disabled' -> should add it and call restore+configure @@ -1328,7 +1329,7 @@ async def test_enable_and_disable_filter_rules_and_nat_port_forward(make_client) client._restore_config_section = AsyncMock() client._filter_configure = AsyncMock() await client.disable_filter_rule_by_created_time_legacy("t-disable") - client._restore_config_section.assert_called() + client._restore_config_section.assert_awaited() client._filter_configure.assert_awaited() # enable_nat_port_forward_rule_by_created_time_legacy: similar flow under 'nat' section @@ -1337,7 +1338,7 @@ async def test_enable_and_disable_filter_rules_and_nat_port_forward(make_client) client._restore_config_section = AsyncMock() client._filter_configure = AsyncMock() await client.enable_nat_port_forward_rule_by_created_time_legacy("t-nat") - client._restore_config_section.assert_called() + client._restore_config_section.assert_awaited() client._filter_configure.assert_awaited() @@ -1521,7 +1522,7 @@ def add_entities(ents): switch_cfg = make_config_entry( data={ "device_unique_id": "dev1", - "sync_filters_and_nat": True, + CONF_SYNC_FIREWALL_AND_NAT: True, "sync_unbound": False, "sync_vpn": False, "sync_services": False, @@ -2697,9 +2698,12 @@ async def test_get_config_and_rule_enable_disable_branches() -> None: } client._exec_php = AsyncMock(return_value=fake_config) - + client._restore_config_section = AsyncMock() + client._filter_configure = AsyncMock() # calling enable should remove 'disabled' and call restore/configure (no exception) await client.enable_filter_rule_by_created_time_legacy("t1") + client._restore_config_section.assert_awaited() + client._filter_configure.assert_awaited() # disable_nat_port_forward: add a rule without 'disabled' and expect it to set 'disabled' client._exec_php = AsyncMock( diff --git a/tests/test_switch.py b/tests/test_switch.py index 55d20a92..e4ce9d9e 100644 --- a/tests/test_switch.py +++ b/tests/test_switch.py @@ -258,15 +258,21 @@ async def test_switch_toggle_variants( # Same method for on/off with different parameters if "firewall_rule" in client_methods[0]: # toggle_firewall_rule(rule_id, action) - getattr(ent._client, client_methods[0]).assert_any_call(ent._rule_id, "on") - getattr(ent._client, client_methods[1]).assert_any_call(ent._rule_id, "off") + getattr(ent._client, client_methods[0]).assert_has_awaits( + [ + ((ent._rule_id, "on"), {}), + ((ent._rule_id, "off"), {}), + ], + any_order=True, + ) else: # toggle_nat_rule(rule_type, rule_id, action) - getattr(ent._client, client_methods[0]).assert_any_call( - ent._nat_rule_type, ent._rule_id, "on" - ) - getattr(ent._client, client_methods[1]).assert_any_call( - ent._nat_rule_type, ent._rule_id, "off" + getattr(ent._client, client_methods[0]).assert_has_awaits( + [ + ((ent._nat_rule_type, ent._rule_id, "on"), {}), + ((ent._nat_rule_type, ent._rule_id, "off"), {}), + ], + any_order=True, ) else: # Different methods for on/off From 347b0d875344386a73b116ec9a24e8361ff8c9fc Mon Sep 17 00:00:00 2001 From: Snuffy2 Date: Mon, 2 Feb 2026 22:15:03 -0500 Subject: [PATCH 06/15] Improve pytest coverage --- tests/test_pyopnsense.py | 392 ++++++++++++++++++++++++++++++++++++++- tests/test_switch.py | 44 +++++ 2 files changed, 434 insertions(+), 2 deletions(-) diff --git a/tests/test_pyopnsense.py b/tests/test_pyopnsense.py index a71b7869..65c12b7d 100644 --- a/tests/test_pyopnsense.py +++ b/tests/test_pyopnsense.py @@ -16,6 +16,7 @@ from xmlrpc.client import Fault import aiohttp +import awesomeversion import pytest from yarl import URL @@ -210,6 +211,26 @@ async def test_get_host_firmware_set_use_snake_case_and_plugin_installed(make_cl await client.set_use_snake_case() assert client._use_snake_case is True + # set use snake case should detect <25.7 + client._firmware_version = "25.1.0" + await client.set_use_snake_case() + assert client._use_snake_case is False + + # test AwesomeVersionCompareException handling + original_compare = awesomeversion.AwesomeVersion.__lt__ + + def mock_compare(self, other): + raise awesomeversion.exceptions.AwesomeVersionCompareException("test exception") + + awesomeversion.AwesomeVersion.__lt__ = mock_compare + try: + client._firmware_version = "25.8.0" + await client.set_use_snake_case() + # Should default to True on exception + assert client._use_snake_case is True + finally: + awesomeversion.AwesomeVersion.__lt__ = original_compare + # invalid semver -> fallback to product_series client._safe_dict_get = AsyncMock( return_value={"product": {"product_version": "weird", "product_series": "seriesX"}} @@ -928,12 +949,13 @@ async def test_exec_php_error_paths(exc_factory, initial: bool, make_client) -> [ ("_do_get", "get", ("/api/x",), {"caller": "tst"}), ("_do_post", "post", ("/api/x",), {"payload": {}}), + ("_do_get_raw", "get", ("/api/x",), {"caller": "tst"}), ], ) -async def test_do_get_post_error_initial_behavior( +async def test_do_get_post_get_raw_error_initial_behavior( method_name, session_method, args, kwargs, make_client ) -> None: - """When client._initial is True, non-ok responses should raise ClientResponseError for _do_get/_do_post.""" + """When client._initial is True, non-ok responses should raise ClientResponseError for _do_get/_do_post/_do_get_raw.""" session = MagicMock(spec=aiohttp.ClientSession) # create a fake response context manager @@ -960,6 +982,9 @@ async def __aexit__(self, exc_type, exc, tb): async def json(self, content_type=None): return {"x": 1} + async def text(self): + return "raw response text" + @property def content(self): class C: @@ -982,6 +1007,36 @@ async def iter_chunked(self, n): await client.async_close() +@pytest.mark.asyncio +async def test_do_get_raw_client_error_initial_behavior(make_client) -> None: + """When client._initial is True, aiohttp.ClientError should be re-raised for _do_get_raw.""" + session = MagicMock(spec=aiohttp.ClientSession) + session.get.side_effect = aiohttp.ClientError("Connection failed") + + client = make_client(session=session) + client._initial = True + try: + with pytest.raises(aiohttp.ClientError): + await client._do_get_raw("/api/x", caller="tst") + finally: + await client.async_close() + + +@pytest.mark.asyncio +async def test_do_get_raw_client_error_non_initial_behavior(make_client) -> None: + """When client._initial is False, aiohttp.ClientError should be logged and None returned for _do_get_raw.""" + session = MagicMock(spec=aiohttp.ClientSession) + session.get.side_effect = aiohttp.ClientError("Connection failed") + + client = make_client(session=session) + client._initial = False + try: + result = await client._do_get_raw("/api/x", caller="tst") + assert result is None + finally: + await client.async_close() + + @pytest.mark.asyncio async def test_get_from_stream_parsing(make_client, fake_stream_response_factory) -> None: """Simulate SSE-like stream with two messages and assert parsing returns dict.""" @@ -3182,3 +3237,336 @@ async def test_get_device_unique_id_no_mac(make_client) -> None: client._safe_list_get = AsyncMock(return_value=[{"is_physical": False}]) assert await client.get_device_unique_id() is None await client.async_close() + + +@pytest.mark.asyncio +async def test_get_firewall_legacy_fallback(make_client) -> None: + """get_firewall falls back to legacy config for OPNsense < 26.1.""" + session = MagicMock(spec=aiohttp.ClientSession) + client = make_client(session=session) + client._firmware_version = "25.7.0" + + # Mock get_config for legacy fallback + client.get_config = AsyncMock(return_value={"filter": {"rule": []}}) + + result = await client.get_firewall() + assert result == {"config": {"filter": {"rule": []}}} + client.get_config.assert_awaited_once() + await client.async_close() + + +@pytest.mark.asyncio +async def test_get_firewall_new_api(make_client) -> None: + """get_firewall uses new API for OPNsense >= 26.1.""" + session = MagicMock(spec=aiohttp.ClientSession) + client = make_client(session=session) + client._firmware_version = "26.1.0" + + # Mock all the methods called in the new API path + client.is_plugin_installed = AsyncMock(return_value=True) + client.get_config = AsyncMock(return_value={"filter": {"rule": []}}) + client._get_interface_firewall_map = AsyncMock(return_value={"lan": "LAN"}) + client._get_firewall_rules = AsyncMock(return_value={"rule1": {"uuid": "rule1"}}) + client._get_nat_destination_rules = AsyncMock(return_value={"nat1": {"uuid": "nat1"}}) + client._get_nat_one_to_one_rules = AsyncMock(return_value={"one1": {"uuid": "one1"}}) + client._get_nat_source_rules = AsyncMock(return_value={"src1": {"uuid": "src1"}}) + client._get_nat_npt_rules = AsyncMock(return_value={"npt1": {"uuid": "npt1"}}) + + result = await client.get_firewall() + expected = { + "config": {"filter": {"rule": []}}, + "rules": {"rule1": {"uuid": "rule1"}}, + "nat": { + "d_nat": {"nat1": {"uuid": "nat1"}}, + "one_to_one": {"one1": {"uuid": "one1"}}, + "source_nat": {"src1": {"uuid": "src1"}}, + "npt": {"npt1": {"uuid": "npt1"}}, + }, + } + assert result == expected + client.is_plugin_installed.assert_awaited_once() + client.get_config.assert_awaited_once() + client._get_interface_firewall_map.assert_awaited_once() + client._get_firewall_rules.assert_awaited_once() + client._get_nat_destination_rules.assert_awaited_once() + client._get_nat_one_to_one_rules.assert_awaited_once() + client._get_nat_source_rules.assert_awaited_once() + client._get_nat_npt_rules.assert_awaited_once() + await client.async_close() + + +@pytest.mark.asyncio +async def test_get_firewall_version_compare_exception(make_client) -> None: + """get_firewall handles AwesomeVersionCompareException gracefully.""" + session = MagicMock(spec=aiohttp.ClientSession) + client = make_client(session=session) + client._firmware_version = "invalid" + + result = await client.get_firewall() + assert result == {} + await client.async_close() + + +@pytest.mark.asyncio +async def test_get_firewall_rules_successful_parsing(make_client) -> None: + """_get_firewall_rules successfully parses valid CSV data.""" + session = MagicMock(spec=aiohttp.ClientSession) + client = make_client(session=session) + + # Mock CSV response with valid firewall rules + csv_data = """@uuid,enabled,action,interface,descr +rule1,1,pass,lan,Allow HTTP +rule2,0,block,wan,Block traffic +""" + client._get_raw = AsyncMock(return_value=csv_data) + + interface_map = {"lan": "LAN", "wan": "WAN"} + result = await client._get_firewall_rules(interface_map) + + expected = { + "rule1": { + "uuid": "rule1", + "enabled": "1", + "action": "pass", + "interface": "lan", + "%interface": "LAN", + "descr": "Allow HTTP", + }, + "rule2": { + "uuid": "rule2", + "enabled": "0", + "action": "block", + "interface": "wan", + "%interface": "WAN", + "descr": "Block traffic", + }, + } + assert result == expected + client._get_raw.assert_awaited_once_with("/api/firewall/filter/download_rules") + await client.async_close() + + +@pytest.mark.asyncio +async def test_get_firewall_rules_none_response(make_client) -> None: + """_get_firewall_rules returns empty dict when response is None.""" + session = MagicMock(spec=aiohttp.ClientSession) + client = make_client(session=session) + + client._get_raw = AsyncMock(return_value=None) + + result = await client._get_firewall_rules({}) + assert result == {} + await client.async_close() + + +@pytest.mark.asyncio +async def test_get_firewall_rules_non_string_response(make_client) -> None: + """_get_firewall_rules returns empty dict when response is not a string.""" + session = MagicMock(spec=aiohttp.ClientSession) + client = make_client(session=session) + + client._get_raw = AsyncMock(return_value=123) # Non-string response + + result = await client._get_firewall_rules({}) + assert result == {} + await client.async_close() + + +@pytest.mark.asyncio +async def test_get_firewall_rules_csv_parsing_error(make_client) -> None: + """_get_firewall_rules returns empty dict when CSV parsing fails.""" + session = MagicMock(spec=aiohttp.ClientSession) + client = make_client(session=session) + + # Invalid CSV that will cause parsing error - unterminated quote + client._get_raw = AsyncMock(return_value='"unterminated,quote\nvalue1,value2') + + result = await client._get_firewall_rules({}) + assert result == {} + await client.async_close() + + +@pytest.mark.asyncio +async def test_get_firewall_rules_empty_csv(make_client) -> None: + """_get_firewall_rules returns empty dict when CSV has no data rows.""" + session = MagicMock(spec=aiohttp.ClientSession) + client = make_client(session=session) + + # CSV with only headers, no data rows + client._get_raw = AsyncMock(return_value="@uuid,enabled,action\n") + + result = await client._get_firewall_rules({}) + assert result == {} + await client.async_close() + + +@pytest.mark.asyncio +async def test_get_firewall_rules_interface_mapping(make_client) -> None: + """_get_firewall_rules handles interface mapping correctly.""" + session = MagicMock(spec=aiohttp.ClientSession) + client = make_client(session=session) + + # CSV with interfaces, some in map, some not + csv_data = """@uuid,enabled,interface,descr +rule1,1,lan,LAN rule +rule2,1,opt1,OPT1 rule +rule3,1,unknown,Unknown interface +""" + client._get_raw = AsyncMock(return_value=csv_data) + + interface_map = {"lan": "LAN", "opt1": "OPT1"} # opt1 in map, unknown not in map + result = await client._get_firewall_rules(interface_map) + + assert result["rule1"]["%interface"] == "LAN" # Mapped + assert result["rule2"]["%interface"] == "OPT1" # Mapped + assert result["rule3"]["%interface"] == "unknown" # Not mapped, uses original + await client.async_close() + + +@pytest.mark.asyncio +async def test_get_firewall_rules_missing_uuid(make_client) -> None: + """_get_firewall_rules handles rules without @uuid field.""" + session = MagicMock(spec=aiohttp.ClientSession) + client = make_client(session=session) + + # CSV with missing @uuid field + csv_data = """enabled,action,descr +1,pass,Rule without UUID +""" + client._get_raw = AsyncMock(return_value=csv_data) + + result = await client._get_firewall_rules({}) + + # Should have empty string as UUID + assert result[""]["uuid"] == "" + assert result[""]["enabled"] == "1" + assert result[""]["action"] == "pass" + await client.async_close() + assert result[""]["uuid"] == "" + assert result[""]["enabled"] == "1" + assert result[""]["action"] == "pass" + await client.async_close() + + +@pytest.mark.parametrize( + ("method_name", "api_endpoint", "has_transformations"), + [ + ("_get_nat_destination_rules", "/api/firewall/d_nat/search_rule", True), + ("_get_nat_one_to_one_rules", "/api/firewall/one_to_one/search_rule", False), + ("_get_nat_source_rules", "/api/firewall/source_nat/search_rule", False), + ("_get_nat_npt_rules", "/api/firewall/npt/search_rule", False), + ], +) +@pytest.mark.parametrize( + ("test_case", "mock_response", "expected_result"), + [ + ( + "successful_parsing", + { + "rows": [ + { + "uuid": "test-rule-1", + "descr": "Test rule 1", + "disabled": "0", + "interface": "wan", + "protocol": "tcp", + }, + { + "uuid": "test-rule-2", + "descr": "Test rule 2", + "disabled": "1", + "interface": "lan", + "protocol": "udp", + }, + ] + }, + { + "test-rule-1": { + "uuid": "test-rule-1", + "description": "Test rule 1", # transformed + "enabled": "1", # transformed + "interface": "wan", + "protocol": "tcp", + }, + "test-rule-2": { + "uuid": "test-rule-2", + "description": "Test rule 2", # transformed + "enabled": "0", # transformed + "interface": "lan", + "protocol": "udp", + }, + }, + ), + ( + "filters_lockout_rules", + { + "rows": [ + {"uuid": "normal-rule", "descr": "Normal rule", "disabled": "0"}, + { + "uuid": "lockout-rule", + "descr": "Lockout rule", + "disabled": "0", + }, # Should be filtered + { + "uuid": "another-lockout", + "descr": "Another lockout", + "disabled": "0", + }, # Should be filtered + {"uuid": None, "descr": "No UUID rule", "disabled": "0"}, # Should be filtered + ] + }, + { + "normal-rule": { + "uuid": "normal-rule", + "description": "Normal rule", # transformed + "enabled": "1", # transformed + } + }, + ), + ("empty_response", {}, {}), + ("response_without_rows", {"some_other_key": "value"}, {}), + ], +) +@pytest.mark.asyncio +async def test_nat_rules_parsing( + make_client, + method_name, + api_endpoint, + has_transformations, + test_case, + mock_response, + expected_result, +) -> None: + """Test NAT rules parsing for all NAT rule types.""" + session = MagicMock(spec=aiohttp.ClientSession) + client = make_client(session=session) + + client._safe_dict_post = AsyncMock(return_value=mock_response) + + # Call the appropriate method + method = getattr(client, method_name) + result = await method() + + # For non-transformed methods, adjust expected result + if not has_transformations and test_case == "successful_parsing": + # Remove transformations from expected result + for rule in expected_result.values(): + if "description" in rule: + rule["descr"] = rule.pop("description") + if "enabled" in rule: + rule.pop("enabled") + rule["disabled"] = "0" if rule.get("uuid") == "test-rule-1" else "1" + + if not has_transformations and test_case == "filters_lockout_rules": + # Remove transformations from expected result + for rule in expected_result.values(): + if "description" in rule: + rule["descr"] = rule.pop("description") + if "enabled" in rule: + rule.pop("enabled") + rule["disabled"] = "0" + + assert result == expected_result + + # Verify the correct API endpoint was called + client._safe_dict_post.assert_called_with(api_endpoint, payload={"current": 1, "sort": {}}) diff --git a/tests/test_switch.py b/tests/test_switch.py index e4ce9d9e..6dd6802a 100644 --- a/tests/test_switch.py +++ b/tests/test_switch.py @@ -280,6 +280,50 @@ async def test_switch_toggle_variants( getattr(ent._client, client_methods[1]).assert_awaited_once() +@pytest.mark.asyncio +async def test_compile_filter_switches_legacy_skip_conditions(make_config_entry) -> None: + """Test that _compile_filter_switches_legacy properly skips invalid rules.""" + + config_entry = make_config_entry( + data={CONF_DEVICE_UNIQUE_ID: "dev1", "url": "http://example"}, + title="OPNsenseTest", + ) + coordinator = make_coord({}) + + # Test with non-dict rules + state_with_mixed_rules = { + "firewall": { + "config": { + "filter": { + "rule": [ + {"descr": "Valid Rule", "created": {"time": "t1"}}, + "invalid_string_rule", # should be skipped + { + "descr": "NAT Rule", + "associated-rule-id": "nat1", + "created": {"time": "t2"}, + }, # should be skipped + { + "description": "Anti-Lockout Rule", + "created": {"time": "t3"}, + }, # should be skipped + {"descr": "No Tracker"}, # should be skipped + {"descr": "Empty Tracker", "created": {"time": ""}}, # should be skipped + ] + } + } + } + } + + entities = await _compile_filter_switches_legacy( + config_entry, coordinator, state_with_mixed_rules + ) + # Only the first valid rule should be included + assert len(entities) == 1 + assert entities[0].entity_description.key == "filter.t1" + assert "Valid Rule" in entities[0].entity_description.name + + @pytest.mark.asyncio async def test_compile_port_forward_skips_non_dict(coordinator, make_config_entry): """Port forward compilation should skip non-dict rule entries.""" From 9c15168f400560a3e8e08a41932f085bffecb9c9 Mon Sep 17 00:00:00 2001 From: Snuffy2 Date: Mon, 2 Feb 2026 22:51:52 -0500 Subject: [PATCH 07/15] Refinements --- .../opnsense/pyopnsense/__init__.py | 6 +- tests/test_pyopnsense.py | 152 ++++++++---------- 2 files changed, 74 insertions(+), 84 deletions(-) diff --git a/custom_components/opnsense/pyopnsense/__init__.py b/custom_components/opnsense/pyopnsense/__init__.py index 6dced9b1..603ff803 100644 --- a/custom_components/opnsense/pyopnsense/__init__.py +++ b/custom_components/opnsense/pyopnsense/__init__.py @@ -1026,6 +1026,8 @@ async def _get_firewall_rules(self, interface_map: dict[str, Any]) -> dict[str, for rule in rules: new_rule = rule.copy() new_rule["uuid"] = new_rule.pop("@uuid", "") + if not new_rule["uuid"]: + continue new_rule["%interface"] = interface_map.get( new_rule.get("interface", ""), new_rule.get("interface", "") ) @@ -1177,7 +1179,7 @@ async def toggle_firewall_rule(self, uuid: str, toggle_on_off: str | None = None return False apply_resp = await self._safe_dict_post("/api/firewall/filter/apply") - if apply_resp.get("status") != "OK\n\n": + if apply_resp.get("status", "").strip() != "OK": return False return True @@ -1229,7 +1231,7 @@ async def toggle_nat_rule( return False apply_resp = await self._safe_dict_post(f"/api/firewall/{nat_rule_type}/apply") - if apply_resp.get("status") != "OK\n\n": + if apply_resp.get("status", "").strip() != "OK": return False return True diff --git a/tests/test_pyopnsense.py b/tests/test_pyopnsense.py index 65c12b7d..96c5f8f5 100644 --- a/tests/test_pyopnsense.py +++ b/tests/test_pyopnsense.py @@ -7,10 +7,12 @@ import asyncio import contextlib +import copy from datetime import datetime, timedelta import inspect as _inspect import socket from ssl import SSLError +from typing import Any from unittest.mock import AsyncMock, MagicMock import xmlrpc.client as xc from xmlrpc.client import Fault @@ -195,7 +197,9 @@ async def test_opnsenseclient_async_close(make_client) -> None: @pytest.mark.asyncio -async def test_get_host_firmware_set_use_snake_case_and_plugin_installed(make_client) -> None: +async def test_get_host_firmware_set_use_snake_case_and_plugin_installed( + monkeypatch, make_client +) -> None: """Ensure firmware parsing, snake_case detection and plugin detection work.""" # create client/session for this test session = MagicMock(spec=aiohttp.ClientSession) @@ -217,19 +221,14 @@ async def test_get_host_firmware_set_use_snake_case_and_plugin_installed(make_cl assert client._use_snake_case is False # test AwesomeVersionCompareException handling - original_compare = awesomeversion.AwesomeVersion.__lt__ - def mock_compare(self, other): raise awesomeversion.exceptions.AwesomeVersionCompareException("test exception") - awesomeversion.AwesomeVersion.__lt__ = mock_compare - try: - client._firmware_version = "25.8.0" - await client.set_use_snake_case() - # Should default to True on exception - assert client._use_snake_case is True - finally: - awesomeversion.AwesomeVersion.__lt__ = original_compare + monkeypatch.setattr(awesomeversion.AwesomeVersion, "__lt__", mock_compare) + client._firmware_version = "25.8.0" + await client.set_use_snake_case() + # Should default to True on exception + assert client._use_snake_case is True # invalid semver -> fallback to product_series client._safe_dict_get = AsyncMock( @@ -3437,14 +3436,8 @@ async def test_get_firewall_rules_missing_uuid(make_client) -> None: result = await client._get_firewall_rules({}) - # Should have empty string as UUID - assert result[""]["uuid"] == "" - assert result[""]["enabled"] == "1" - assert result[""]["action"] == "pass" - await client.async_close() - assert result[""]["uuid"] == "" - assert result[""]["enabled"] == "1" - assert result[""]["action"] == "pass" + # Rules without @uuid are skipped, so result should be empty + assert result == {} await client.async_close() @@ -3458,40 +3451,22 @@ async def test_get_firewall_rules_missing_uuid(make_client) -> None: ], ) @pytest.mark.parametrize( - ("test_case", "mock_response", "expected_result"), + ("test_case", "expected_result"), [ ( "successful_parsing", - { - "rows": [ - { - "uuid": "test-rule-1", - "descr": "Test rule 1", - "disabled": "0", - "interface": "wan", - "protocol": "tcp", - }, - { - "uuid": "test-rule-2", - "descr": "Test rule 2", - "disabled": "1", - "interface": "lan", - "protocol": "udp", - }, - ] - }, { "test-rule-1": { "uuid": "test-rule-1", - "description": "Test rule 1", # transformed - "enabled": "1", # transformed + "description": "Test rule 1", + "enabled": "1", "interface": "wan", "protocol": "tcp", }, "test-rule-2": { "uuid": "test-rule-2", - "description": "Test rule 2", # transformed - "enabled": "0", # transformed + "description": "Test rule 2", + "enabled": "0", "interface": "lan", "protocol": "udp", }, @@ -3499,32 +3474,16 @@ async def test_get_firewall_rules_missing_uuid(make_client) -> None: ), ( "filters_lockout_rules", - { - "rows": [ - {"uuid": "normal-rule", "descr": "Normal rule", "disabled": "0"}, - { - "uuid": "lockout-rule", - "descr": "Lockout rule", - "disabled": "0", - }, # Should be filtered - { - "uuid": "another-lockout", - "descr": "Another lockout", - "disabled": "0", - }, # Should be filtered - {"uuid": None, "descr": "No UUID rule", "disabled": "0"}, # Should be filtered - ] - }, { "normal-rule": { "uuid": "normal-rule", - "description": "Normal rule", # transformed - "enabled": "1", # transformed + "description": "Normal rule", + "enabled": "1", } }, ), - ("empty_response", {}, {}), - ("response_without_rows", {"some_other_key": "value"}, {}), + ("empty_response", {}), + ("response_without_rows", {}), ], ) @pytest.mark.asyncio @@ -3534,39 +3493,68 @@ async def test_nat_rules_parsing( api_endpoint, has_transformations, test_case, - mock_response, expected_result, ) -> None: """Test NAT rules parsing for all NAT rule types.""" session = MagicMock(spec=aiohttp.ClientSession) client = make_client(session=session) + # Build API-style mock response depending on whether the endpoint uses + # transformations (d_nat-like endpoints use 'descr'/'disabled'). + mock_response: dict[str, Any] + if test_case == "empty_response": + mock_response = {} + elif test_case == "response_without_rows": + mock_response = {"some_other_key": "value"} + else: + normalized_rows: list[dict[str, Any]] = [] + extra_rows: list[dict[str, Any]] = [] + if test_case == "successful_parsing": + for uid, info in expected_result.items(): + row = {"uuid": uid} + row["description"] = info.get("description") + row["enabled"] = info.get("enabled") + if "interface" in info: + row["interface"] = info.get("interface") + if "protocol" in info: + row["protocol"] = info.get("protocol") + normalized_rows.append(row) + elif test_case == "filters_lockout_rules": + normalized_rows = [ + {"uuid": "normal-rule", "description": "Normal rule", "enabled": "1"} + ] + extra_rows = [ + {"uuid": "lockout-rule", "description": "Lockout rule", "enabled": "1"}, + {"uuid": "another-lockout", "description": "Another lockout", "enabled": "1"}, + {"uuid": None, "description": "No UUID rule", "enabled": "1"}, + ] + + api_rows: list[dict[str, Any]] = [] + for row in normalized_rows + extra_rows: + if has_transformations: + new_row = row.copy() + if "description" in new_row: + new_row["descr"] = new_row.pop("description") + if "enabled" in new_row: + new_row["disabled"] = "0" if new_row.pop("enabled") == "1" else "1" + api_rows.append(new_row) + else: + api_rows.append(row.copy()) + + mock_response = {"rows": api_rows} + client._safe_dict_post = AsyncMock(return_value=mock_response) # Call the appropriate method method = getattr(client, method_name) result = await method() - # For non-transformed methods, adjust expected result - if not has_transformations and test_case == "successful_parsing": - # Remove transformations from expected result - for rule in expected_result.values(): - if "description" in rule: - rule["descr"] = rule.pop("description") - if "enabled" in rule: - rule.pop("enabled") - rule["disabled"] = "0" if rule.get("uuid") == "test-rule-1" else "1" - - if not has_transformations and test_case == "filters_lockout_rules": - # Remove transformations from expected result - for rule in expected_result.values(): - if "description" in rule: - rule["descr"] = rule.pop("description") - if "enabled" in rule: - rule.pop("enabled") - rule["disabled"] = "0" + # Make a deep copy of expected_result so we don't mutate the shared fixture + expected = copy.deepcopy(expected_result) - assert result == expected_result + assert result == expected # Verify the correct API endpoint was called client._safe_dict_post.assert_called_with(api_endpoint, payload={"current": 1, "sort": {}}) + + await client.async_close() From 460c4c224c677be4972ae42f2f6b9675411b3123 Mon Sep 17 00:00:00 2001 From: Snuffy2 Date: Tue, 3 Feb 2026 17:41:36 -0500 Subject: [PATCH 08/15] Check for plugin --- custom_components/opnsense/coordinator.py | 1 + custom_components/opnsense/switch.py | 8 +++++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/custom_components/opnsense/coordinator.py b/custom_components/opnsense/coordinator.py index c8936654..8bacadde 100644 --- a/custom_components/opnsense/coordinator.py +++ b/custom_components/opnsense/coordinator.py @@ -114,6 +114,7 @@ def _build_categories(self) -> list[MutableMapping[str, str]]: "function": "get_host_firmware_version", "state_key": "host_firmware_version", }, + {"function": "is_plugin_installed", "state_key": "plugin_installed"}, ] if config.get(CONF_SYNC_TELEMETRY, DEFAULT_SYNC_OPTION_VALUE): diff --git a/custom_components/opnsense/switch.py b/custom_components/opnsense/switch.py index f7329fbb..12203203 100644 --- a/custom_components/opnsense/switch.py +++ b/custom_components/opnsense/switch.py @@ -671,9 +671,11 @@ async def async_setup_entry( ) ) else: - entities.extend( - await _compile_filter_switches_legacy(config_entry, coordinator, state) - ) + # TODO: Also disable once OPNsense 26.1.x drops support for the plugin + if state.get("plugin_installed", False) is True: + entities.extend( + await _compile_filter_switches_legacy(config_entry, coordinator, state) + ) entities.extend( await _compile_firewall_rules_switches(config_entry, coordinator, state) ) From 3cd6ea48876406e1e2df41e7669aed1c1ee53974 Mon Sep 17 00:00:00 2001 From: Snuffy2 Date: Tue, 3 Feb 2026 22:00:52 -0500 Subject: [PATCH 09/15] Implement plugin cleanup --- custom_components/opnsense/__init__.py | 71 ++++++++++++++++++- .../opnsense/translations/en.json | 10 ++- 2 files changed, 79 insertions(+), 2 deletions(-) diff --git a/custom_components/opnsense/__init__.py b/custom_components/opnsense/__init__.py index 9af50965..6cfe52b9 100644 --- a/custom_components/opnsense/__init__.py +++ b/custom_components/opnsense/__init__.py @@ -229,6 +229,14 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: except awesomeversion.exceptions.AwesomeVersionCompareException: _LOGGER.warning("Unable to confirm OPNsense Firmware version") + try: + if awesomeversion.AwesomeVersion(firmware) > awesomeversion.AwesomeVersion( + "25.10" + ) and awesomeversion.AwesomeVersion(firmware) < awesomeversion.AwesomeVersion("26.7"): + await _deprecated_plugin_cleanup_26_1(hass=hass, client=client, entry_id=entry.entry_id) + except awesomeversion.exceptions.AwesomeVersionCompareException: + _LOGGER.warning("Unable to confirm OPNsense Firmware version") + await coordinator.async_config_entry_first_refresh() platforms: list[Platform] = PLATFORMS.copy() @@ -272,6 +280,67 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: return True +async def _deprecated_plugin_cleanup_26_1( + hass: HomeAssistant, client: OPNsenseClient, entry_id: str +) -> bool: + _LOGGER.debug("Starting OPNsense 26.1 and Plugin cleanup") + entity_registry = er.async_get(hass) + plugin_installed: bool = await client.is_plugin_installed() + cleanup_started: bool = False + + for ent in er.async_entries_for_config_entry(entity_registry, entry_id): + platform = ent.entity_id.split(".")[0] + if platform != Platform.SWITCH: + continue + # _LOGGER.debug("[deprecated_plugin_cleanup] ent: %s", ent) + if ( + (not plugin_installed and "_filter_" in ent.unique_id) + or "_nat_port_forward_" in ent.unique_id + or "_nat_outbound_" in ent.unique_id + ): + cleanup_started = True + try: + entity_registry.async_remove(ent.entity_id) + _LOGGER.debug("[deprecated_plugin_cleanup] removed entity_id: %s", ent.entity_id) + except (KeyError, ValueError) as e: + _LOGGER.error( + "Error removing entity: %s. %s: %s", + ent.entity_id, + type(e).__name__, + e, + ) + if cleanup_started: + if plugin_installed: + _LOGGER.info( + "OPNsense 26.1 and Plugin cleanup partially completed. Plugin is still installed. NAT Outbound and NAT Port Forward rules removed. Firewall Filter rules will be removed once the plugin is removed." + ) + ir.async_create_issue( + hass, + DOMAIN, + "plugin_cleanup_partial", + is_fixable=False, + is_persistent=False, + issue_domain=DOMAIN, + severity=ir.IssueSeverity.WARNING, + translation_key="plugin_cleanup_partial", + ) + else: + _LOGGER.info( + "OPNsense 26.1 and Plugin cleanup completed. NAT Outbound, NAT Port Forward, and Firewall Filter rules removed." + ) + ir.async_create_issue( + hass, + DOMAIN, + "plugin_cleanup_done", + is_fixable=False, + is_persistent=False, + issue_domain=DOMAIN, + severity=ir.IssueSeverity.WARNING, + translation_key="plugin_cleanup_done", + ) + return True + + async def async_remove_config_entry_device( hass: HomeAssistant, config_entry: ConfigEntry, device_entry: dr.DeviceEntry ) -> bool: @@ -454,7 +523,7 @@ async def _migrate_3_to_4(hass: HomeAssistant, config_entry: ConfigEntry) -> boo for ent in er.async_entries_for_config_entry(entity_registry, config_entry.entry_id): platform = ent.entity_id.split(".")[0] if platform == Platform.SENSOR: - # _LOGGER.debug(f"[migrate_3_to_4] ent: {ent}") + # _LOGGER.debug("[migrate_3_to_4] ent: %s", ent) if "_telemetry_interface_" in ent.unique_id: new_unique_id: str | None = ent.unique_id.replace( "_telemetry_interface_", "_interface_" diff --git a/custom_components/opnsense/translations/en.json b/custom_components/opnsense/translations/en.json index 4f83ac67..229214b5 100644 --- a/custom_components/opnsense/translations/en.json +++ b/custom_components/opnsense/translations/en.json @@ -117,6 +117,14 @@ "device_id_mismatched": { "title": "OPNsense Hardware Has Changed", "description": "OPNsense Device ID has changed which indicates new or changed hardware. In order to accommodate this, hass-opnsense needs to be removed and reinstalled for this router. hass-opnsense is shutting down." + }, + "plugin_cleanup_partial": { + "title": "OPNsense 26.1 and Plugin cleanup partially completed", + "description": "OPNsense Plugin is still installed. NAT Outbound and NAT Port Forward rules removed. Firewall Filter rules will be removed once the plugin is removed." + }, + "plugin_cleanup_done": { + "title": "OPNsense 26.1 and Plugin cleanup completed", + "description": "NAT Outbound, NAT Port Forward, and Firewall Filter rules removed." } }, "selector": { @@ -315,7 +323,7 @@ }, "reload_interface": { "name": "Reload an Interface", - "description": "Reload or restart an OPNSense interface", + "description": "Reload or restart an OPNsense interface", "fields": { "interface": { "name": "Interface Name", From c7fd82f8a2114be63c87a8bf7240838561d134f1 Mon Sep 17 00:00:00 2001 From: Snuffy2 Date: Tue, 3 Feb 2026 22:11:09 -0500 Subject: [PATCH 10/15] Add tests and Update docstrings --- custom_components/opnsense/__init__.py | 205 +++++++++++++++++++++++-- tests/test_init.py | 115 ++++++++++++++ 2 files changed, 310 insertions(+), 10 deletions(-) diff --git a/custom_components/opnsense/__init__.py b/custom_components/opnsense/__init__.py index 6cfe52b9..3a48925d 100644 --- a/custom_components/opnsense/__init__.py +++ b/custom_components/opnsense/__init__.py @@ -1,4 +1,9 @@ -"""Support for OPNsense.""" +"""Home Assistant integration for OPNsense firewalls. + +This integration provides monitoring and control of OPNsense firewall devices, +including system information, network interfaces, firewall rules, DHCP leases, +and various other OPNsense features through the Home Assistant interface. +""" from collections.abc import Mapping, MutableMapping from datetime import timedelta @@ -59,7 +64,21 @@ async def _async_update_listener(hass: HomeAssistant, entry: ConfigEntry) -> None: - """Handle options update.""" + """Handle options update for the OPNsense integration. + + This function is called when the configuration entry options are updated. + It handles reloading the integration if necessary, removing entities that + are no longer enabled based on granular sync options, and cleaning up + device tracker devices if device tracking is disabled. + + Parameters + ---------- + hass : HomeAssistant + The Home Assistant instance. + entry : ConfigEntry + The configuration entry for the OPNsense integration. + + """ # _LOGGER.debug("[async_update_listener] entry: %s", entry.as_dict()) if getattr(entry.runtime_data, SHOULD_RELOAD, True): _LOGGER.info("[async_update_listener] Reloading") @@ -104,13 +123,55 @@ async def _async_update_listener(hass: HomeAssistant, entry: ConfigEntry) -> Non async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: - """Call the method to setup the integration-level services.""" + """Set up the OPNsense integration at the domain level. + + This function is called during Home Assistant startup to initialize + integration-level services for the OPNsense domain. + + Parameters + ---------- + hass : HomeAssistant + The Home Assistant instance. + config : ConfigType + The configuration dictionary (unused for config entry only integrations). + + Returns + ------- + bool + Always returns True to indicate successful setup. + + """ await async_setup_services(hass) return True async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: - """Set up OPNsense from a config entry.""" + """Set up the OPNsense integration from a configuration entry. + + This function initializes the OPNsense client, coordinators, and platforms + based on the provided configuration entry. It performs firmware version + checks, handles device ID validation, and sets up data coordinators for + state updates and device tracking. + + Parameters + ---------- + hass : HomeAssistant + The Home Assistant instance. + entry : ConfigEntry + The configuration entry containing OPNsense connection details. + + Returns + ------- + bool + True if setup was successful, False otherwise. + + Raises + ------ + Various exceptions may be raised during client initialization or firmware + checks, but they are handled internally with appropriate logging and issue + creation. + + """ config: Mapping[str, Any] = entry.data options: Mapping[str, Any] = entry.options # _LOGGER.debug("[async_setup_entry] entry: %s", entry.as_dict()) @@ -282,7 +343,24 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def _deprecated_plugin_cleanup_26_1( hass: HomeAssistant, client: OPNsenseClient, entry_id: str -) -> bool: +) -> None: + """Clean up deprecated entities for OPNsense 26.1 and plugin compatibility. + + This function removes switch entities that are no longer supported in + OPNsense 26.1, specifically firewall filter rules (when plugin not installed) + and NAT port forward/outbound rules. It creates appropriate issues to + inform the user about the cleanup. + + Parameters + ---------- + hass : HomeAssistant + The Home Assistant instance. + client : OPNsenseClient + The OPNsense client instance. + entry_id : str + The configuration entry ID. + + """ _LOGGER.debug("Starting OPNsense 26.1 and Plugin cleanup") entity_registry = er.async_get(hass) plugin_installed: bool = await client.is_plugin_installed() @@ -338,14 +416,31 @@ async def _deprecated_plugin_cleanup_26_1( severity=ir.IssueSeverity.WARNING, translation_key="plugin_cleanup_done", ) - return True async def async_remove_config_entry_device( hass: HomeAssistant, config_entry: ConfigEntry, device_entry: dr.DeviceEntry ) -> bool: - """Remove OPNsense Devices that aren't Device Tracker Devices and without any linked entities.""" - + """Remove OPNsense devices that are not device tracker devices and have no linked entities. + + This function checks if an OPNsense device can be safely removed. It prevents + removal of device tracker devices and devices that still have linked entities. + + Parameters + ---------- + hass : HomeAssistant + The Home Assistant instance. + config_entry : ConfigEntry + The configuration entry for the OPNsense integration. + device_entry : dr.DeviceEntry + The device entry to be removed. + + Returns + ------- + bool + True if the device can be removed, False otherwise. + + """ if device_entry.via_device_id: _LOGGER.error("Remove OPNsense Device Tracker Devices via the Integration Configuration") return False @@ -358,7 +453,24 @@ async def async_remove_config_entry_device( async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: - """Unload a config entry.""" + """Unload the OPNsense integration configuration entry. + + This function unloads all platforms associated with the configuration entry, + closes the OPNsense client connection, and cleans up the entry data. + + Parameters + ---------- + hass : HomeAssistant + The Home Assistant instance. + entry : ConfigEntry + The configuration entry to unload. + + Returns + ------- + bool + True if unloading was successful, False otherwise. + + """ _LOGGER.info("Unloading: %s", entry.as_dict()) platforms: list[Platform] = getattr(entry.runtime_data, LOADED_PLATFORMS) client: OPNsenseClient = getattr(entry.runtime_data, OPNSENSE_CLIENT) @@ -373,6 +485,24 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def _migrate_1_to_2(hass: HomeAssistant, config_entry: ConfigEntry) -> bool: + """Migrate configuration entry from version 1 to version 2. + + This migration replaces the deprecated 'tls_insecure' option with + 'verify_ssl' for SSL certificate verification. + + Parameters + ---------- + hass : HomeAssistant + The Home Assistant instance. + config_entry : ConfigEntry + The configuration entry to migrate. + + Returns + ------- + bool + Always returns True. + + """ tls_insecure = config_entry.data.get(CONF_TLS_INSECURE, DEFAULT_TLS_INSECURE) data: MutableMapping[str, Any] = dict(config_entry.data) @@ -389,6 +519,25 @@ async def _migrate_1_to_2(hass: HomeAssistant, config_entry: ConfigEntry) -> boo async def _migrate_2_to_3(hass: HomeAssistant, config_entry: ConfigEntry) -> bool: + """Migrate configuration entry from version 2 to version 3. + + This migration updates device unique IDs to use the lowest MAC address + and updates entity unique IDs accordingly. It also updates device + identifiers in the device registry. + + Parameters + ---------- + hass : HomeAssistant + The Home Assistant instance. + config_entry : ConfigEntry + The configuration entry to migrate. + + Returns + ------- + bool + True if migration was successful, False otherwise. + + """ _LOGGER.debug("[migrate_2_to_3] Initial Version: %s", config_entry.version) entity_registry = er.async_get(hass) device_registry = dr.async_get(hass) @@ -498,6 +647,25 @@ async def _migrate_2_to_3(hass: HomeAssistant, config_entry: ConfigEntry) -> boo async def _migrate_3_to_4(hass: HomeAssistant, config_entry: ConfigEntry) -> bool: + """Migrate configuration entry from version 3 to version 4. + + This migration moves telemetry-based entities (interfaces, gateways, openvpn) + out of the telemetry namespace and updates their unique IDs. It also removes + deprecated connected client count sensors. + + Parameters + ---------- + hass : HomeAssistant + The Home Assistant instance. + config_entry : ConfigEntry + The configuration entry to migrate. + + Returns + ------- + bool + True if migration was successful, False otherwise. + + """ _LOGGER.debug("[migrate_3_to_4] Initial Version: %s", config_entry.version) entity_registry = er.async_get(hass) @@ -601,7 +769,24 @@ async def _migrate_3_to_4(hass: HomeAssistant, config_entry: ConfigEntry) -> boo async def async_migrate_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool: - """Migrate an old config entry.""" + """Migrate an old configuration entry to the latest version. + + This function handles migration of configuration entries from older versions + to the current version by applying sequential migration steps. + + Parameters + ---------- + hass : HomeAssistant + The Home Assistant instance. + config_entry : ConfigEntry + The configuration entry to migrate. + + Returns + ------- + bool + True if migration was successful, False otherwise. + + """ version = config_entry.version if version > 4: diff --git a/tests/test_init.py b/tests/test_init.py index c7b604d6..be574de4 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -496,6 +496,121 @@ async def test_async_setup_entry_firmware_between_min_and_ltd( assert call_args[1].get("severity") == init_mod.ir.IssueSeverity.WARNING +@pytest.mark.asyncio +async def test_async_setup_entry_firmware_triggers_plugin_cleanup( + monkeypatch, ph_hass, coordinator_capture, fake_client, fake_coordinator, make_config_entry +): + """async_setup_entry calls _deprecated_plugin_cleanup_26_1 for firmware >25.10 and <26.7.""" + monkeypatch.setattr(init_mod, "OPNsenseClient", fake_client(firmware_version="26.0")) + monkeypatch.setattr( + init_mod, "OPNsenseDataUpdateCoordinator", coordinator_capture.factory(fake_coordinator) + ) + # Mock the cleanup function to track if it's called + cleanup_mock = AsyncMock(return_value=True) + monkeypatch.setattr(init_mod, "_deprecated_plugin_cleanup_26_1", cleanup_mock) + + entry = make_config_entry( + data={ + init_mod.CONF_URL: "http://1.2.3.4", + init_mod.CONF_USERNAME: "u", + init_mod.CONF_PASSWORD: "p", + init_mod.CONF_DEVICE_UNIQUE_ID: "dev1", + } + ) + hass = ph_hass + hass.config_entries.async_forward_entry_setups = AsyncMock(return_value=True) + hass.config_entries.async_reload = AsyncMock() + hass.data.setdefault("aiohttp_connector", {}) + + res = await init_mod.async_setup_entry(hass, entry) + assert res is True + # Verify cleanup was called with correct args + cleanup_mock.assert_called_once_with(hass=hass, client=ANY, entry_id=entry.entry_id) + + +@pytest.mark.asyncio +async def test_deprecated_plugin_cleanup_26_1_plugin_not_installed(monkeypatch): + """_deprecated_plugin_cleanup_26_1 removes filter entities when plugin not installed.""" + hass = MagicMock(spec=HomeAssistant) + client = MagicMock() + client.is_plugin_installed = AsyncMock(return_value=False) + entry_id = "test_entry_id" + + # Mock entity registry + entity_registry = MagicMock() + monkeypatch.setattr(init_mod.er, "async_get", lambda hass: entity_registry) + + # Mock entities: one filter entity, one normal switch + filter_entity = MagicMock() + filter_entity.entity_id = "switch.opnsense_filter_rule" + filter_entity.unique_id = "dev1_filter_rule1" + normal_entity = MagicMock() + normal_entity.entity_id = "switch.opnsense_normal_rule" + normal_entity.unique_id = "dev1_normal_rule1" + monkeypatch.setattr( + init_mod.er, + "async_entries_for_config_entry", + MagicMock(return_value=[filter_entity, normal_entity]), + ) + + # Mock issue registry + create_issue_mock = MagicMock() + monkeypatch.setattr(init_mod.ir, "async_create_issue", create_issue_mock) + + await init_mod._deprecated_plugin_cleanup_26_1(hass, client, entry_id) + + # Verify filter entity was removed, normal was not + entity_registry.async_remove.assert_called_once_with("switch.opnsense_filter_rule") + # Verify issue created for cleanup done + create_issue_mock.assert_called_once() + call_args = create_issue_mock.call_args + assert call_args[0][2] == "plugin_cleanup_done" + + +@pytest.mark.asyncio +async def test_deprecated_plugin_cleanup_26_1_plugin_installed(monkeypatch): + """_deprecated_plugin_cleanup_26_1 removes NAT entities when plugin is installed.""" + hass = MagicMock(spec=HomeAssistant) + client = MagicMock() + client.is_plugin_installed = AsyncMock(return_value=True) + entry_id = "test_entry_id" + + # Mock entity registry + entity_registry = MagicMock() + monkeypatch.setattr(init_mod.er, "async_get", lambda hass: entity_registry) + + # Mock entities: NAT port forward, NAT outbound, normal switch + nat_pf_entity = MagicMock() + nat_pf_entity.entity_id = "switch.opnsense_nat_port_forward_rule" + nat_pf_entity.unique_id = "dev1_nat_port_forward_rule1" + nat_out_entity = MagicMock() + nat_out_entity.entity_id = "switch.opnsense_nat_outbound_rule" + nat_out_entity.unique_id = "dev1_nat_outbound_rule1" + normal_entity = MagicMock() + normal_entity.entity_id = "switch.opnsense_normal_rule" + normal_entity.unique_id = "dev1_normal_rule1" + monkeypatch.setattr( + init_mod.er, + "async_entries_for_config_entry", + MagicMock(return_value=[nat_pf_entity, nat_out_entity, normal_entity]), + ) + + # Mock issue registry + create_issue_mock = MagicMock() + monkeypatch.setattr(init_mod.ir, "async_create_issue", create_issue_mock) + + await init_mod._deprecated_plugin_cleanup_26_1(hass, client, entry_id) + + # Verify NAT entities were removed, normal was not + assert entity_registry.async_remove.call_count == 2 + entity_registry.async_remove.assert_any_call("switch.opnsense_nat_port_forward_rule") + entity_registry.async_remove.assert_any_call("switch.opnsense_nat_outbound_rule") + # Verify issue created for partial cleanup + create_issue_mock.assert_called_once() + call_args = create_issue_mock.call_args + assert call_args[0][2] == "plugin_cleanup_partial" + + @pytest.mark.asyncio async def test_migrate_2_to_3_missing_device_id(monkeypatch, fake_client): """_migrate_2_to_3 returns False when the client provides no device id.""" From 31ffedbce245fcea577f66393cde82971b2d520f Mon Sep 17 00:00:00 2001 From: Snuffy2 Date: Tue, 3 Feb 2026 22:28:25 -0500 Subject: [PATCH 11/15] Remove plugin check for 26.1+ in config flow --- custom_components/opnsense/config_flow.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/custom_components/opnsense/config_flow.py b/custom_components/opnsense/config_flow.py index ae4d11d9..ef4a51d6 100644 --- a/custom_components/opnsense/config_flow.py +++ b/custom_components/opnsense/config_flow.py @@ -277,13 +277,20 @@ async def _handle_user_input( # Plugin check not required for config step of user. Otherwise, plugin check is required if # granular sync options is enabled - require_plugin_check = ( - config_step != "user" - and user_input.get(CONF_GRANULAR_SYNC_OPTIONS, DEFAULT_GRANULAR_SYNC_OPTIONS) - and any( - user_input.get(item, DEFAULT_SYNC_OPTION_VALUE) for item in SYNC_ITEMS_REQUIRING_PLUGIN + try: + require_plugin_check = ( + config_step != "user" + and user_input.get(CONF_GRANULAR_SYNC_OPTIONS, DEFAULT_GRANULAR_SYNC_OPTIONS) + and any( + user_input.get(item, DEFAULT_SYNC_OPTION_VALUE) + for item in SYNC_ITEMS_REQUIRING_PLUGIN + ) + and awesomeversion.AwesomeVersion(user_input[CONF_FIRMWARE_VERSION]) + < awesomeversion.AwesomeVersion("26.1") ) - ) + except awesomeversion.exceptions.AwesomeVersionCompareException as e: + raise UnknownFirmware from e + _LOGGER.debug( "[handle_user_input] config_step: %s, require_plugin_check: %s", config_step, From b5034109abd9653eb922dd314ef73f542ece1d36 Mon Sep 17 00:00:00 2001 From: Snuffy2 Date: Wed, 4 Feb 2026 20:51:54 -0500 Subject: [PATCH 12/15] Change to require 26.1.1 and cleanup code --- custom_components/opnsense/__init__.py | 8 +- custom_components/opnsense/config_flow.py | 2 +- .../opnsense/pyopnsense/__init__.py | 120 ++----------- custom_components/opnsense/switch.py | 4 +- tests/test_init.py | 22 +-- tests/test_pyopnsense.py | 167 ++++-------------- tests/test_switch.py | 4 +- 7 files changed, 69 insertions(+), 258 deletions(-) diff --git a/custom_components/opnsense/__init__.py b/custom_components/opnsense/__init__.py index 3a48925d..80545fed 100644 --- a/custom_components/opnsense/__init__.py +++ b/custom_components/opnsense/__init__.py @@ -292,9 +292,11 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: try: if awesomeversion.AwesomeVersion(firmware) > awesomeversion.AwesomeVersion( - "25.10" + "26.1" ) and awesomeversion.AwesomeVersion(firmware) < awesomeversion.AwesomeVersion("26.7"): - await _deprecated_plugin_cleanup_26_1(hass=hass, client=client, entry_id=entry.entry_id) + await _deprecated_plugin_cleanup_26_1_1( + hass=hass, client=client, entry_id=entry.entry_id + ) except awesomeversion.exceptions.AwesomeVersionCompareException: _LOGGER.warning("Unable to confirm OPNsense Firmware version") @@ -341,7 +343,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: return True -async def _deprecated_plugin_cleanup_26_1( +async def _deprecated_plugin_cleanup_26_1_1( hass: HomeAssistant, client: OPNsenseClient, entry_id: str ) -> None: """Clean up deprecated entities for OPNsense 26.1 and plugin compatibility. diff --git a/custom_components/opnsense/config_flow.py b/custom_components/opnsense/config_flow.py index ef4a51d6..eff63dda 100644 --- a/custom_components/opnsense/config_flow.py +++ b/custom_components/opnsense/config_flow.py @@ -286,7 +286,7 @@ async def _handle_user_input( for item in SYNC_ITEMS_REQUIRING_PLUGIN ) and awesomeversion.AwesomeVersion(user_input[CONF_FIRMWARE_VERSION]) - < awesomeversion.AwesomeVersion("26.1") + < awesomeversion.AwesomeVersion("26.1.1") ) except awesomeversion.exceptions.AwesomeVersionCompareException as e: raise UnknownFirmware from e diff --git a/custom_components/opnsense/pyopnsense/__init__.py b/custom_components/opnsense/pyopnsense/__init__.py index 603ff803..7036680b 100644 --- a/custom_components/opnsense/pyopnsense/__init__.py +++ b/custom_components/opnsense/pyopnsense/__init__.py @@ -3,11 +3,9 @@ from abc import ABC import asyncio from collections.abc import Callable, MutableMapping -import csv from datetime import datetime, timedelta, timezone from functools import partial import inspect -from io import StringIO import ipaddress import json import logging @@ -369,15 +367,6 @@ async def _get(self, path: str) -> MutableMapping[str, Any] | list | None: await self._request_queue.put(("get", path, None, future, caller)) return await future - async def _get_raw(self, path: str) -> str | None: - try: - caller = inspect.stack()[1].function - except (IndexError, AttributeError): - caller = "Unknown" - future = self._loop.create_future() - await self._request_queue.put(("get_raw", path, None, future, caller)) - return await future - async def _post( self, path: str, payload: MutableMapping[str, Any] | None = None ) -> MutableMapping[str, Any] | list | None: @@ -401,10 +390,6 @@ async def _process_queue(self) -> None: result = await self._do_get(path, caller) if future is not None and not future.done(): future.set_result(result) - elif method == "get_raw": - result = await self._do_get_raw(path, caller) - if future is not None and not future.done(): - future.set_result(result) elif method == "post": result = await self._do_post(path, payload, caller) if future is not None and not future.done(): @@ -565,50 +550,6 @@ async def _do_get( return None - async def _do_get_raw(self, path: str, caller: str = "Unknown") -> str | None: - # /api////[/[/...]] - self._rest_api_query_count += 1 - url: str = f"{self._url}{path}" - _LOGGER.debug("[get_raw] url: %s", url) - try: - async with self._session.get( - url, - auth=aiohttp.BasicAuth(self._username, self._password), - timeout=aiohttp.ClientTimeout(total=DEFAULT_TIMEOUT), - ssl=self._verify_ssl, - ) as response: - _LOGGER.debug("[get_raw] Response %s: %s", response.status, response.reason) - if response.ok: - return await response.text() - if response.status == 403: - _LOGGER.error( - "Permission Error in do_get_raw (called by %s). Path: %s. Ensure the OPNsense user connected to HA has appropriate access. Recommend full admin access", - caller, - url, - ) - else: - _LOGGER.error( - "Error in do_get_raw (called by %s). Path: %s. Response %s: %s", - caller, - url, - response.status, - response.reason, - ) - if self._initial: - raise aiohttp.ClientResponseError( - request_info=response.request_info, - history=response.history, - status=response.status, - message=f"HTTP Status Error: {response.status} {response.reason}", - headers=response.headers, - ) - except aiohttp.ClientError as e: - _LOGGER.error("Client error. %s: %s", type(e).__name__, e) - if self._initial: - raise - - return None - async def _safe_dict_get(self, path: str) -> MutableMapping[str, Any]: """Fetch data from the given path, ensuring the result is a dict.""" result = await self._get(path=path) @@ -949,8 +890,8 @@ async def get_firewall(self) -> dict[str, Any]: try: if awesomeversion.AwesomeVersion( self._firmware_version - ) < awesomeversion.AwesomeVersion("26.1"): - _LOGGER.debug("Using legacy plugin for firewall filters for OPNsense < 26.1") + ) < awesomeversion.AwesomeVersion("26.1.1"): + _LOGGER.debug("Using legacy plugin for firewall filters for OPNsense < 26.1.1") return {"config": await self.get_config()} except awesomeversion.exceptions.AwesomeVersionCompareException: _LOGGER.warning("Error comparing firmware version. Skipping get_firewall.") @@ -958,8 +899,7 @@ async def get_firewall(self) -> dict[str, Any]: firewall: dict[str, Any] = {"nat": {}} if await self.is_plugin_installed(): firewall["config"] = await self.get_config() - interface_map = await self._get_interface_firewall_map() - firewall["rules"] = await self._get_firewall_rules(interface_map=interface_map) + firewall["rules"] = await self._get_firewall_rules() firewall["nat"]["d_nat"] = await self._get_nat_destination_rules() firewall["nat"]["one_to_one"] = await self._get_nat_one_to_one_rules() firewall["nat"]["source_nat"] = await self._get_nat_source_rules() @@ -968,30 +908,7 @@ async def get_firewall(self) -> dict[str, Any]: return firewall @_log_errors - async def _get_interface_firewall_map(self) -> dict[str, Any]: - """Retrieve a mapping of interface names to firewall interface names. - - Returns - ------- - dict - A dictionary mapping interface names to firewall interface names. - - """ - interfaces = await self._safe_dict_get("/api/firewall/filter/get_interface_list") - interface_map: dict[str, Any] = {} - - if isinstance(interfaces, MutableMapping): - for section in interfaces.values(): - if isinstance(section, MutableMapping) and "items" in section: - for item in section["items"]: - if isinstance(item, MutableMapping) and "value" in item and "label" in item: - interface_map[item["value"]] = item["label"] - - _LOGGER.debug("[get_interface_firewall_map] interface_map: %s", interface_map) - return interface_map - - @_log_errors - async def _get_firewall_rules(self, interface_map: dict[str, Any]) -> dict[str, Any]: + async def _get_firewall_rules(self) -> dict[str, Any]: """Retrieve firewall rules from OPNsense. Parameters @@ -1007,30 +924,19 @@ async def _get_firewall_rules(self, interface_map: dict[str, Any]) -> dict[str, as 'uuid', 'enabled', 'action', etc. """ - response = await self._get_raw("/api/firewall/filter/download_rules") + request_body: MutableMapping[str, Any] = {"current": 1, "sort": {}} + response = await self._safe_dict_post( + "/api/firewall/filter/search_rule", payload=request_body + ) # _LOGGER.debug("[get_firewall_rules] response: %s", response) - if not response or not isinstance(response, str): - return {} - - try: - reader = csv.DictReader(StringIO(response)) - except (csv.Error, ValueError) as e: - _LOGGER.error("Failed to parse firewall rules CSV: %s", e) - return {} - if not reader.fieldnames: - return {} - rules = [row for row in reader if row] - - # _LOGGER.debug("[get_firewall_rules] rules: %s", rules) + rules: list = response.get("rows", []) + _LOGGER.debug("[get_firewall_rules] rules: %s", rules) rules_dict: dict[str, Any] = {} for rule in rules: - new_rule = rule.copy() - new_rule["uuid"] = new_rule.pop("@uuid", "") - if not new_rule["uuid"]: + if not rule.get("uuid") or "lockout" in rule.get("uuid"): continue - new_rule["%interface"] = interface_map.get( - new_rule.get("interface", ""), new_rule.get("interface", "") - ) + new_rule = rule.copy() + # Add any transforms here rules_dict[new_rule["uuid"]] = new_rule _LOGGER.debug("[get_firewall_rules] rules_dict: %s", rules_dict) return rules_dict diff --git a/custom_components/opnsense/switch.py b/custom_components/opnsense/switch.py index 12203203..6d2c7686 100644 --- a/custom_components/opnsense/switch.py +++ b/custom_components/opnsense/switch.py @@ -656,7 +656,9 @@ async def async_setup_entry( firmware = state.get("host_firmware_version", None) if firmware: try: - if awesomeversion.AwesomeVersion(firmware) < awesomeversion.AwesomeVersion("26.1"): + if awesomeversion.AwesomeVersion(firmware) < awesomeversion.AwesomeVersion( + "26.1.1" + ): entities.extend( await _compile_filter_switches_legacy(config_entry, coordinator, state) ) diff --git a/tests/test_init.py b/tests/test_init.py index be574de4..67dc84b9 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -500,14 +500,14 @@ async def test_async_setup_entry_firmware_between_min_and_ltd( async def test_async_setup_entry_firmware_triggers_plugin_cleanup( monkeypatch, ph_hass, coordinator_capture, fake_client, fake_coordinator, make_config_entry ): - """async_setup_entry calls _deprecated_plugin_cleanup_26_1 for firmware >25.10 and <26.7.""" - monkeypatch.setattr(init_mod, "OPNsenseClient", fake_client(firmware_version="26.0")) + """async_setup_entry calls _deprecated_plugin_cleanup_26_1_1 for firmware >25.10 and <26.7.""" + monkeypatch.setattr(init_mod, "OPNsenseClient", fake_client(firmware_version="26.2")) monkeypatch.setattr( init_mod, "OPNsenseDataUpdateCoordinator", coordinator_capture.factory(fake_coordinator) ) # Mock the cleanup function to track if it's called cleanup_mock = AsyncMock(return_value=True) - monkeypatch.setattr(init_mod, "_deprecated_plugin_cleanup_26_1", cleanup_mock) + monkeypatch.setattr(init_mod, "_deprecated_plugin_cleanup_26_1_1", cleanup_mock) entry = make_config_entry( data={ @@ -524,13 +524,13 @@ async def test_async_setup_entry_firmware_triggers_plugin_cleanup( res = await init_mod.async_setup_entry(hass, entry) assert res is True - # Verify cleanup was called with correct args - cleanup_mock.assert_called_once_with(hass=hass, client=ANY, entry_id=entry.entry_id) + # Verify cleanup was called with correct args and awaited + cleanup_mock.assert_awaited_once_with(hass=hass, client=ANY, entry_id=entry.entry_id) @pytest.mark.asyncio -async def test_deprecated_plugin_cleanup_26_1_plugin_not_installed(monkeypatch): - """_deprecated_plugin_cleanup_26_1 removes filter entities when plugin not installed.""" +async def test_deprecated_plugin_cleanup_26_1_1_plugin_not_installed(monkeypatch): + """_deprecated_plugin_cleanup_26_1_1 removes filter entities when plugin not installed.""" hass = MagicMock(spec=HomeAssistant) client = MagicMock() client.is_plugin_installed = AsyncMock(return_value=False) @@ -557,7 +557,7 @@ async def test_deprecated_plugin_cleanup_26_1_plugin_not_installed(monkeypatch): create_issue_mock = MagicMock() monkeypatch.setattr(init_mod.ir, "async_create_issue", create_issue_mock) - await init_mod._deprecated_plugin_cleanup_26_1(hass, client, entry_id) + await init_mod._deprecated_plugin_cleanup_26_1_1(hass, client, entry_id) # Verify filter entity was removed, normal was not entity_registry.async_remove.assert_called_once_with("switch.opnsense_filter_rule") @@ -568,8 +568,8 @@ async def test_deprecated_plugin_cleanup_26_1_plugin_not_installed(monkeypatch): @pytest.mark.asyncio -async def test_deprecated_plugin_cleanup_26_1_plugin_installed(monkeypatch): - """_deprecated_plugin_cleanup_26_1 removes NAT entities when plugin is installed.""" +async def test_deprecated_plugin_cleanup_26_1_1_plugin_installed(monkeypatch): + """_deprecated_plugin_cleanup_26_1_1 removes NAT entities when plugin is installed.""" hass = MagicMock(spec=HomeAssistant) client = MagicMock() client.is_plugin_installed = AsyncMock(return_value=True) @@ -599,7 +599,7 @@ async def test_deprecated_plugin_cleanup_26_1_plugin_installed(monkeypatch): create_issue_mock = MagicMock() monkeypatch.setattr(init_mod.ir, "async_create_issue", create_issue_mock) - await init_mod._deprecated_plugin_cleanup_26_1(hass, client, entry_id) + await init_mod._deprecated_plugin_cleanup_26_1_1(hass, client, entry_id) # Verify NAT entities were removed, normal was not assert entity_registry.async_remove.call_count == 2 diff --git a/tests/test_pyopnsense.py b/tests/test_pyopnsense.py index 96c5f8f5..0459ea8e 100644 --- a/tests/test_pyopnsense.py +++ b/tests/test_pyopnsense.py @@ -948,13 +948,12 @@ async def test_exec_php_error_paths(exc_factory, initial: bool, make_client) -> [ ("_do_get", "get", ("/api/x",), {"caller": "tst"}), ("_do_post", "post", ("/api/x",), {"payload": {}}), - ("_do_get_raw", "get", ("/api/x",), {"caller": "tst"}), ], ) -async def test_do_get_post_get_raw_error_initial_behavior( +async def test_do_get_post_error_initial_behavior( method_name, session_method, args, kwargs, make_client ) -> None: - """When client._initial is True, non-ok responses should raise ClientResponseError for _do_get/_do_post/_do_get_raw.""" + """When client._initial is True, non-ok responses should raise ClientResponseError for _do_get/_do_post.""" session = MagicMock(spec=aiohttp.ClientSession) # create a fake response context manager @@ -1006,36 +1005,6 @@ async def iter_chunked(self, n): await client.async_close() -@pytest.mark.asyncio -async def test_do_get_raw_client_error_initial_behavior(make_client) -> None: - """When client._initial is True, aiohttp.ClientError should be re-raised for _do_get_raw.""" - session = MagicMock(spec=aiohttp.ClientSession) - session.get.side_effect = aiohttp.ClientError("Connection failed") - - client = make_client(session=session) - client._initial = True - try: - with pytest.raises(aiohttp.ClientError): - await client._do_get_raw("/api/x", caller="tst") - finally: - await client.async_close() - - -@pytest.mark.asyncio -async def test_do_get_raw_client_error_non_initial_behavior(make_client) -> None: - """When client._initial is False, aiohttp.ClientError should be logged and None returned for _do_get_raw.""" - session = MagicMock(spec=aiohttp.ClientSession) - session.get.side_effect = aiohttp.ClientError("Connection failed") - - client = make_client(session=session) - client._initial = False - try: - result = await client._do_get_raw("/api/x", caller="tst") - assert result is None - finally: - await client.async_close() - - @pytest.mark.asyncio async def test_get_from_stream_parsing(make_client, fake_stream_response_factory) -> None: """Simulate SSE-like stream with two messages and assert parsing returns dict.""" @@ -3240,7 +3209,7 @@ async def test_get_device_unique_id_no_mac(make_client) -> None: @pytest.mark.asyncio async def test_get_firewall_legacy_fallback(make_client) -> None: - """get_firewall falls back to legacy config for OPNsense < 26.1.""" + """get_firewall falls back to legacy config for OPNsense < 26.1.1.""" session = MagicMock(spec=aiohttp.ClientSession) client = make_client(session=session) client._firmware_version = "25.7.0" @@ -3256,15 +3225,14 @@ async def test_get_firewall_legacy_fallback(make_client) -> None: @pytest.mark.asyncio async def test_get_firewall_new_api(make_client) -> None: - """get_firewall uses new API for OPNsense >= 26.1.""" + """get_firewall uses new API for OPNsense >= 26.1.1.""" session = MagicMock(spec=aiohttp.ClientSession) client = make_client(session=session) - client._firmware_version = "26.1.0" + client._firmware_version = "26.1.1" # Mock all the methods called in the new API path client.is_plugin_installed = AsyncMock(return_value=True) client.get_config = AsyncMock(return_value={"filter": {"rule": []}}) - client._get_interface_firewall_map = AsyncMock(return_value={"lan": "LAN"}) client._get_firewall_rules = AsyncMock(return_value={"rule1": {"uuid": "rule1"}}) client._get_nat_destination_rules = AsyncMock(return_value={"nat1": {"uuid": "nat1"}}) client._get_nat_one_to_one_rules = AsyncMock(return_value={"one1": {"uuid": "one1"}}) @@ -3285,7 +3253,6 @@ async def test_get_firewall_new_api(make_client) -> None: assert result == expected client.is_plugin_installed.assert_awaited_once() client.get_config.assert_awaited_once() - client._get_interface_firewall_map.assert_awaited_once() client._get_firewall_rules.assert_awaited_once() client._get_nat_destination_rules.assert_awaited_once() client._get_nat_one_to_one_rules.assert_awaited_once() @@ -3308,136 +3275,68 @@ async def test_get_firewall_version_compare_exception(make_client) -> None: @pytest.mark.asyncio async def test_get_firewall_rules_successful_parsing(make_client) -> None: - """_get_firewall_rules successfully parses valid CSV data.""" + """_get_firewall_rules successfully parses rows returned from the REST API.""" session = MagicMock(spec=aiohttp.ClientSession) client = make_client(session=session) - # Mock CSV response with valid firewall rules - csv_data = """@uuid,enabled,action,interface,descr -rule1,1,pass,lan,Allow HTTP -rule2,0,block,wan,Block traffic -""" - client._get_raw = AsyncMock(return_value=csv_data) - - interface_map = {"lan": "LAN", "wan": "WAN"} - result = await client._get_firewall_rules(interface_map) - - expected = { - "rule1": { + rows = [ + { "uuid": "rule1", "enabled": "1", "action": "pass", "interface": "lan", - "%interface": "LAN", "descr": "Allow HTTP", }, - "rule2": { + { "uuid": "rule2", "enabled": "0", "action": "block", "interface": "wan", - "%interface": "WAN", "descr": "Block traffic", }, - } - assert result == expected - client._get_raw.assert_awaited_once_with("/api/firewall/filter/download_rules") - await client.async_close() - - -@pytest.mark.asyncio -async def test_get_firewall_rules_none_response(make_client) -> None: - """_get_firewall_rules returns empty dict when response is None.""" - session = MagicMock(spec=aiohttp.ClientSession) - client = make_client(session=session) - - client._get_raw = AsyncMock(return_value=None) - - result = await client._get_firewall_rules({}) - assert result == {} - await client.async_close() - - -@pytest.mark.asyncio -async def test_get_firewall_rules_non_string_response(make_client) -> None: - """_get_firewall_rules returns empty dict when response is not a string.""" - session = MagicMock(spec=aiohttp.ClientSession) - client = make_client(session=session) - - client._get_raw = AsyncMock(return_value=123) # Non-string response - - result = await client._get_firewall_rules({}) - assert result == {} - await client.async_close() - + ] -@pytest.mark.asyncio -async def test_get_firewall_rules_csv_parsing_error(make_client) -> None: - """_get_firewall_rules returns empty dict when CSV parsing fails.""" - session = MagicMock(spec=aiohttp.ClientSession) - client = make_client(session=session) + client._safe_dict_post = AsyncMock(return_value={"rows": rows}) - # Invalid CSV that will cause parsing error - unterminated quote - client._get_raw = AsyncMock(return_value='"unterminated,quote\nvalue1,value2') + result = await client._get_firewall_rules() - result = await client._get_firewall_rules({}) - assert result == {} + expected = {r["uuid"]: r.copy() for r in rows} + assert result == expected + client._safe_dict_post.assert_awaited_once_with( + "/api/firewall/filter/search_rule", payload={"current": 1, "sort": {}} + ) await client.async_close() @pytest.mark.asyncio -async def test_get_firewall_rules_empty_csv(make_client) -> None: - """_get_firewall_rules returns empty dict when CSV has no data rows.""" +async def test_get_firewall_rules_empty_response(make_client) -> None: + """_get_firewall_rules returns empty dict when API response has no rows.""" session = MagicMock(spec=aiohttp.ClientSession) client = make_client(session=session) - # CSV with only headers, no data rows - client._get_raw = AsyncMock(return_value="@uuid,enabled,action\n") + client._safe_dict_post = AsyncMock(return_value={}) - result = await client._get_firewall_rules({}) + result = await client._get_firewall_rules() assert result == {} await client.async_close() @pytest.mark.asyncio -async def test_get_firewall_rules_interface_mapping(make_client) -> None: - """_get_firewall_rules handles interface mapping correctly.""" +async def test_get_firewall_rules_skips_invalid_rows(make_client) -> None: + """_get_firewall_rules skips rules without uuid and lockout rules.""" session = MagicMock(spec=aiohttp.ClientSession) client = make_client(session=session) - # CSV with interfaces, some in map, some not - csv_data = """@uuid,enabled,interface,descr -rule1,1,lan,LAN rule -rule2,1,opt1,OPT1 rule -rule3,1,unknown,Unknown interface -""" - client._get_raw = AsyncMock(return_value=csv_data) - - interface_map = {"lan": "LAN", "opt1": "OPT1"} # opt1 in map, unknown not in map - result = await client._get_firewall_rules(interface_map) - - assert result["rule1"]["%interface"] == "LAN" # Mapped - assert result["rule2"]["%interface"] == "OPT1" # Mapped - assert result["rule3"]["%interface"] == "unknown" # Not mapped, uses original - await client.async_close() - - -@pytest.mark.asyncio -async def test_get_firewall_rules_missing_uuid(make_client) -> None: - """_get_firewall_rules handles rules without @uuid field.""" - session = MagicMock(spec=aiohttp.ClientSession) - client = make_client(session=session) - - # CSV with missing @uuid field - csv_data = """enabled,action,descr -1,pass,Rule without UUID -""" - client._get_raw = AsyncMock(return_value=csv_data) + rows = [ + {"enabled": "1", "action": "pass"}, # missing uuid + {"uuid": "lockout-1", "enabled": "1"}, # lockout rule + {"uuid": "rule-ok", "enabled": "1"}, # valid + ] - result = await client._get_firewall_rules({}) + client._safe_dict_post = AsyncMock(return_value={"rows": rows}) - # Rules without @uuid are skipped, so result should be empty - assert result == {} + result = await client._get_firewall_rules() + assert list(result.keys()) == ["rule-ok"] await client.async_close() @@ -3555,6 +3454,8 @@ async def test_nat_rules_parsing( assert result == expected # Verify the correct API endpoint was called - client._safe_dict_post.assert_called_with(api_endpoint, payload={"current": 1, "sort": {}}) + client._safe_dict_post.assert_awaited_once_with( + api_endpoint, payload={"current": 1, "sort": {}} + ) await client.async_close() diff --git a/tests/test_switch.py b/tests/test_switch.py index 6dd6802a..7b75aa16 100644 --- a/tests/test_switch.py +++ b/tests/test_switch.py @@ -412,7 +412,7 @@ def fake_add_entities(entities): @pytest.mark.asyncio async def test_async_setup_entry_new_firewall_api(coordinator, ph_hass, make_config_entry): - """Async setup should create entities for new firewall API (>= 26.1).""" + """Async setup should create entities for new firewall API (>= 26.1.1).""" calls = {} def fake_add_entities(entities): @@ -464,7 +464,7 @@ def fake_add_entities(entities): }, }, }, - "host_firmware_version": "26.1.0", + "host_firmware_version": "26.1.1", } coordinator.data = state From 1113df34d42cfa92499d5597db583573c66a72d2 Mon Sep 17 00:00:00 2001 From: Snuffy2 Date: Fri, 6 Feb 2026 16:09:33 -0500 Subject: [PATCH 13/15] Refine fimrware checks and Cleanup issues --- custom_components/opnsense/__init__.py | 22 +++++++------- custom_components/opnsense/config_flow.py | 4 +-- .../opnsense/pyopnsense/__init__.py | 16 +++++++--- tests/test_init.py | 29 ++++++++++++------- 4 files changed, 44 insertions(+), 27 deletions(-) diff --git a/custom_components/opnsense/__init__.py b/custom_components/opnsense/__init__.py index 80545fed..fd835eca 100644 --- a/custom_components/opnsense/__init__.py +++ b/custom_components/opnsense/__init__.py @@ -244,7 +244,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: ir.async_create_issue( hass, DOMAIN, - f"opnsense_{firmware}_below_min_firmware_{OPNSENSE_MIN_FIRMWARE}", + f"{config_device_id}_opnsense_below_min_firmware_{OPNSENSE_MIN_FIRMWARE}", is_fixable=False, is_persistent=False, issue_domain=DOMAIN, @@ -264,7 +264,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: ir.async_create_issue( hass, DOMAIN, - f"opnsense_{firmware}_below_ltd_firmware_{OPNSENSE_LTD_FIRMWARE}", + f"{config_device_id}_opnsense_below_ltd_firmware_{OPNSENSE_LTD_FIRMWARE}", is_fixable=False, is_persistent=False, issue_domain=DOMAIN, @@ -280,14 +280,14 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: ir.async_delete_issue( hass, DOMAIN, - f"opnsense_{firmware}_below_min_firmware_{OPNSENSE_MIN_FIRMWARE}", + f"{config_device_id}_opnsense_below_min_firmware_{OPNSENSE_MIN_FIRMWARE}", ) ir.async_delete_issue( hass, DOMAIN, - f"opnsense_{firmware}_below_ltd_firmware_{OPNSENSE_LTD_FIRMWARE}", + f"{config_device_id}_opnsense_below_ltd_firmware_{OPNSENSE_LTD_FIRMWARE}", ) - except awesomeversion.exceptions.AwesomeVersionCompareException: + except (awesomeversion.exceptions.AwesomeVersionCompareException, TypeError, ValueError): _LOGGER.warning("Unable to confirm OPNsense Firmware version") try: @@ -295,9 +295,9 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: "26.1" ) and awesomeversion.AwesomeVersion(firmware) < awesomeversion.AwesomeVersion("26.7"): await _deprecated_plugin_cleanup_26_1_1( - hass=hass, client=client, entry_id=entry.entry_id + hass=hass, client=client, entry_id=entry.entry_id, config_device_id=config_device_id ) - except awesomeversion.exceptions.AwesomeVersionCompareException: + except (awesomeversion.exceptions.AwesomeVersionCompareException, TypeError, ValueError): _LOGGER.warning("Unable to confirm OPNsense Firmware version") await coordinator.async_config_entry_first_refresh() @@ -344,7 +344,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def _deprecated_plugin_cleanup_26_1_1( - hass: HomeAssistant, client: OPNsenseClient, entry_id: str + hass: HomeAssistant, client: OPNsenseClient, entry_id: str, config_device_id: str ) -> None: """Clean up deprecated entities for OPNsense 26.1 and plugin compatibility. @@ -361,6 +361,8 @@ async def _deprecated_plugin_cleanup_26_1_1( The OPNsense client instance. entry_id : str The configuration entry ID. + config_device_id : str + The device unique ID from the configuration. """ _LOGGER.debug("Starting OPNsense 26.1 and Plugin cleanup") @@ -397,7 +399,7 @@ async def _deprecated_plugin_cleanup_26_1_1( ir.async_create_issue( hass, DOMAIN, - "plugin_cleanup_partial", + f"{config_device_id}_plugin_cleanup_partial", is_fixable=False, is_persistent=False, issue_domain=DOMAIN, @@ -411,7 +413,7 @@ async def _deprecated_plugin_cleanup_26_1_1( ir.async_create_issue( hass, DOMAIN, - "plugin_cleanup_done", + f"{config_device_id}_plugin_cleanup_done", is_fixable=False, is_persistent=False, issue_domain=DOMAIN, diff --git a/custom_components/opnsense/config_flow.py b/custom_components/opnsense/config_flow.py index eff63dda..1edf3bcc 100644 --- a/custom_components/opnsense/config_flow.py +++ b/custom_components/opnsense/config_flow.py @@ -270,7 +270,7 @@ async def _handle_user_input( try: _validate_firmware_version(user_input[CONF_FIRMWARE_VERSION]) - except awesomeversion.exceptions.AwesomeVersionCompareException as e: + except (awesomeversion.exceptions.AwesomeVersionCompareException, TypeError, ValueError) as e: raise UnknownFirmware from e await client.set_use_snake_case(initial=True) @@ -288,7 +288,7 @@ async def _handle_user_input( and awesomeversion.AwesomeVersion(user_input[CONF_FIRMWARE_VERSION]) < awesomeversion.AwesomeVersion("26.1.1") ) - except awesomeversion.exceptions.AwesomeVersionCompareException as e: + except (awesomeversion.exceptions.AwesomeVersionCompareException, TypeError, ValueError) as e: raise UnknownFirmware from e _LOGGER.debug( diff --git a/custom_components/opnsense/pyopnsense/__init__.py b/custom_components/opnsense/pyopnsense/__init__.py index 7036680b..5f257d7a 100644 --- a/custom_components/opnsense/pyopnsense/__init__.py +++ b/custom_components/opnsense/pyopnsense/__init__.py @@ -331,7 +331,11 @@ async def set_use_snake_case(self, initial: bool = False) -> None: self._use_snake_case = False else: _LOGGER.debug("Using snake_case for OPNsense >= 25.7") - except awesomeversion.exceptions.AwesomeVersionCompareException as e: + except ( + awesomeversion.exceptions.AwesomeVersionCompareException, + TypeError, + ValueError, + ) as e: _LOGGER.error( "Error comparing firmware version %s. Using snake_case by default", self._firmware_version, @@ -696,7 +700,11 @@ async def get_firmware_update_info(self) -> MutableMapping[str, Any]: ): _LOGGER.debug("Update available but missing details") update_needs_info = True - except awesomeversion.exceptions.AwesomeVersionCompareException as e: + except ( + awesomeversion.exceptions.AwesomeVersionCompareException, + TypeError, + ValueError, + ) as e: _LOGGER.debug("Error checking firmware versions. %s: %s", type(e).__name__, e) update_needs_info = True @@ -893,7 +901,7 @@ async def get_firewall(self) -> dict[str, Any]: ) < awesomeversion.AwesomeVersion("26.1.1"): _LOGGER.debug("Using legacy plugin for firewall filters for OPNsense < 26.1.1") return {"config": await self.get_config()} - except awesomeversion.exceptions.AwesomeVersionCompareException: + except (awesomeversion.exceptions.AwesomeVersionCompareException, TypeError, ValueError): _LOGGER.warning("Error comparing firmware version. Skipping get_firewall.") return {} firewall: dict[str, Any] = {"nat": {}} @@ -1357,7 +1365,7 @@ async def _get_dnsmasq_leases(self) -> list: ) < awesomeversion.AwesomeVersion("25.1.7"): _LOGGER.debug("Skipping get_dnsmasq_leases for OPNsense < 25.1.7") return [] - except awesomeversion.exceptions.AwesomeVersionCompareException: + except (awesomeversion.exceptions.AwesomeVersionCompareException, TypeError, ValueError): pass response = await self._safe_dict_get("/api/dnsmasq/leases/search") diff --git a/tests/test_init.py b/tests/test_init.py index 67dc84b9..62ca8b08 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -491,7 +491,7 @@ async def test_async_setup_entry_firmware_between_min_and_ltd( call_args = create_issue_mock.call_args # args: (hass, domain, issue_id, ...) assert call_args[0][1] == init_mod.DOMAIN - expected_issue_id = f"opnsense_25.0_below_ltd_firmware_{init_mod.OPNSENSE_LTD_FIRMWARE}" + expected_issue_id = f"{entry.data[init_mod.CONF_DEVICE_UNIQUE_ID]}_opnsense_below_ltd_firmware_{init_mod.OPNSENSE_LTD_FIRMWARE}" assert call_args[0][2] == expected_issue_id assert call_args[1].get("severity") == init_mod.ir.IssueSeverity.WARNING @@ -525,7 +525,12 @@ async def test_async_setup_entry_firmware_triggers_plugin_cleanup( res = await init_mod.async_setup_entry(hass, entry) assert res is True # Verify cleanup was called with correct args and awaited - cleanup_mock.assert_awaited_once_with(hass=hass, client=ANY, entry_id=entry.entry_id) + cleanup_mock.assert_awaited_once_with( + hass=hass, + client=ANY, + entry_id=entry.entry_id, + config_device_id=entry.data[init_mod.CONF_DEVICE_UNIQUE_ID], + ) @pytest.mark.asyncio @@ -557,14 +562,15 @@ async def test_deprecated_plugin_cleanup_26_1_1_plugin_not_installed(monkeypatch create_issue_mock = MagicMock() monkeypatch.setattr(init_mod.ir, "async_create_issue", create_issue_mock) - await init_mod._deprecated_plugin_cleanup_26_1_1(hass, client, entry_id) + config_device_id = "dev1" + await init_mod._deprecated_plugin_cleanup_26_1_1(hass, client, entry_id, config_device_id) # Verify filter entity was removed, normal was not entity_registry.async_remove.assert_called_once_with("switch.opnsense_filter_rule") # Verify issue created for cleanup done create_issue_mock.assert_called_once() call_args = create_issue_mock.call_args - assert call_args[0][2] == "plugin_cleanup_done" + assert call_args[0][2] == f"{config_device_id}_plugin_cleanup_done" @pytest.mark.asyncio @@ -599,7 +605,8 @@ async def test_deprecated_plugin_cleanup_26_1_1_plugin_installed(monkeypatch): create_issue_mock = MagicMock() monkeypatch.setattr(init_mod.ir, "async_create_issue", create_issue_mock) - await init_mod._deprecated_plugin_cleanup_26_1_1(hass, client, entry_id) + config_device_id = "dev1" + await init_mod._deprecated_plugin_cleanup_26_1_1(hass, client, entry_id, config_device_id) # Verify NAT entities were removed, normal was not assert entity_registry.async_remove.call_count == 2 @@ -608,7 +615,7 @@ async def test_deprecated_plugin_cleanup_26_1_1_plugin_installed(monkeypatch): # Verify issue created for partial cleanup create_issue_mock.assert_called_once() call_args = create_issue_mock.call_args - assert call_args[0][2] == "plugin_cleanup_partial" + assert call_args[0][2] == f"{config_device_id}_plugin_cleanup_partial" @pytest.mark.asyncio @@ -1003,8 +1010,8 @@ async def test_async_setup_entry_firmware_at_or_above_ltd_deletes_previous_issue # Expect delete_issue to be called for the previous below-min and below-ltd issue ids assert delete_issue_mock.called, "async_delete_issue should have been called" called_issue_ids = [call[0][2] for call in delete_issue_mock.call_args_list if len(call[0]) > 2] - expected_min = f"opnsense_{init_mod.OPNSENSE_LTD_FIRMWARE}_below_min_firmware_{init_mod.OPNSENSE_MIN_FIRMWARE}" - expected_ltd = f"opnsense_{init_mod.OPNSENSE_LTD_FIRMWARE}_below_ltd_firmware_{init_mod.OPNSENSE_LTD_FIRMWARE}" + expected_min = f"{entry.data[init_mod.CONF_DEVICE_UNIQUE_ID]}_opnsense_below_min_firmware_{init_mod.OPNSENSE_MIN_FIRMWARE}" + expected_ltd = f"{entry.data[init_mod.CONF_DEVICE_UNIQUE_ID]}_opnsense_below_ltd_firmware_{init_mod.OPNSENSE_LTD_FIRMWARE}" assert expected_min in called_issue_ids assert expected_ltd in called_issue_ids @@ -1037,9 +1044,9 @@ async def test_async_setup_entry_delete_uses_actual_firmware_string( res = await init_mod.async_setup_entry(hass, entry) assert res is True - # Confirm delete_issue was called with the firmware returned by the client - expected_min = f"opnsense_{firmware_str}_below_min_firmware_{init_mod.OPNSENSE_MIN_FIRMWARE}" - expected_ltd = f"opnsense_{firmware_str}_below_ltd_firmware_{init_mod.OPNSENSE_LTD_FIRMWARE}" + # Confirm delete_issue was called for the expected issue ids + expected_min = f"{entry.data[init_mod.CONF_DEVICE_UNIQUE_ID]}_opnsense_below_min_firmware_{init_mod.OPNSENSE_MIN_FIRMWARE}" + expected_ltd = f"{entry.data[init_mod.CONF_DEVICE_UNIQUE_ID]}_opnsense_below_ltd_firmware_{init_mod.OPNSENSE_LTD_FIRMWARE}" assert calls.called, "async_delete_issue should have been called" issue_ids = [call[0][2] for call in calls.call_args_list if len(call[0]) > 2] assert expected_min in issue_ids From 40bd87bd55e40aa59490359f4845d030991eee46 Mon Sep 17 00:00:00 2001 From: Snuffy2 Date: Fri, 6 Feb 2026 16:49:51 -0500 Subject: [PATCH 14/15] ruff update and fixes --- .pre-commit-config.yaml | 2 +- tests/conftest.py | 2 +- tests/test_device_tracker.py | 8 ++++---- tests/test_integration.py | 8 ++++---- tests/test_pyopnsense.py | 2 +- tests/test_switch.py | 4 ++-- 6 files changed, 13 insertions(+), 13 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5cca521a..9660faab 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -31,7 +31,7 @@ repos: additional_dependencies: - tomli - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.14.14 + rev: v0.15.0 hooks: # Run the linter. - id: ruff-check diff --git a/tests/conftest.py b/tests/conftest.py index 71dadddd..97d3dc6b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -74,7 +74,7 @@ def _ensure_async_create_task_mock(real, side_effect): object.__setattr__( real, "async_create_task", - MagicMock(side_effect=lambda coro, *a, **k: orig(coro, *a, **k)), + MagicMock(side_effect=orig), ) diff --git a/tests/test_device_tracker.py b/tests/test_device_tracker.py index a377bdcc..b1d4b6d1 100644 --- a/tests/test_device_tracker.py +++ b/tests/test_device_tracker.py @@ -31,7 +31,7 @@ async def test_async_setup_entry_configured_devices( ) # attach coordinator into runtime_data under the expected attribute name setattr(entry.runtime_data, dt_mod.DEVICE_TRACKER_COORDINATOR, coordinator) - entry.add_update_listener = lambda f: (lambda: None) + entry.add_update_listener = lambda f: lambda: None entry.async_on_unload = lambda x: None hass = ph_hass hass.config_entries.async_update_entry = MagicMock() @@ -95,7 +95,7 @@ async def test_async_setup_entry_removes_nonmatching_tracked_macs( entry_id="eid_remove", ) setattr(entry.runtime_data, dt_mod.DEVICE_TRACKER_COORDINATOR, coordinator) - entry.add_update_listener = lambda f: (lambda: None) + entry.add_update_listener = lambda f: lambda: None entry.async_on_unload = lambda x: None hass = ph_hass @@ -307,7 +307,7 @@ async def test_async_setup_entry_state_not_mapping(ph_hass, coordinator, make_co hass.data = {} hass.config_entries.async_update_entry = MagicMock() - await dt_mod.async_setup_entry(hass, entry, lambda x: added.extend(x)) + await dt_mod.async_setup_entry(hass, entry, added.extend) assert len(added) == 0 assert not hass.config_entries.async_update_entry.called @@ -437,7 +437,7 @@ async def test_async_setup_entry_from_arp_entries( added = [] - await dt_mod.async_setup_entry(hass, entry, lambda ents: added.extend(ents)) + await dt_mod.async_setup_entry(hass, entry, added.extend) assert len(added) == 2 assert all(isinstance(e, dt_mod.OPNsenseScannerEntity) for e in added) assert {e.unique_id for e in added} == {"dev1_mac_m1", "dev1_mac_m2"} diff --git a/tests/test_integration.py b/tests/test_integration.py index c185ef59..f0049bb7 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -173,7 +173,7 @@ async def async_reload(self, entry_id): # pragma: no cover - reload path not as return None hass.config_entries = _Cfg() - hass.async_create_task = MagicMock(side_effect=lambda coro: asyncio.create_task(coro)) + hass.async_create_task = MagicMock(side_effect=asyncio.create_task) return hass @@ -223,7 +223,7 @@ async def _noop_unique_id(*a, **k): options={}, ) # Provide stubs expected by integration (update listener registration returns unsubscribe) - entry.add_update_listener = lambda f: (lambda: None) + entry.add_update_listener = lambda f: lambda: None entry.async_on_unload = lambda x: None hass.data = {} @@ -286,7 +286,7 @@ async def _noop_unique_id(*a, **k): # redefined for this test context entry_id="entry_gran", options={}, ) - entry.add_update_listener = lambda f: (lambda: None) + entry.add_update_listener = lambda f: lambda: None entry.async_on_unload = lambda x: None # Add to fake hass store so options flow update calls can mutate it @@ -396,7 +396,7 @@ async def _noop_unique_id(*a, **k): entry_id="entry_rel", options={}, ) - entry.add_update_listener = lambda f: (lambda: None) + entry.add_update_listener = lambda f: lambda: None entry.async_on_unload = lambda x: None # Setup diff --git a/tests/test_pyopnsense.py b/tests/test_pyopnsense.py index 0459ea8e..02035267 100644 --- a/tests/test_pyopnsense.py +++ b/tests/test_pyopnsense.py @@ -1091,7 +1091,7 @@ def mkarg(pname: str): for name, func in _inspect.getmembers( pyopnsense.OPNsenseClient, predicate=_inspect.iscoroutinefunction ) - if name not in ("__init__",) + if name != "__init__" ] for name, _func in coros: diff --git a/tests/test_switch.py b/tests/test_switch.py index 7b75aa16..84638937 100644 --- a/tests/test_switch.py +++ b/tests/test_switch.py @@ -741,7 +741,7 @@ async def test_vpn_turn_on_off_noops_when_preconditions_fail( # replace async_call_later to avoid scheduling monkeypatch.setattr( "custom_components.opnsense.switch.async_call_later", - lambda hass, delay, action: (lambda: None), + lambda hass, delay, action: lambda: None, ) desc = SwitchEntityDescription(key="openvpn.clients.c2", name="VPNC2") @@ -794,7 +794,7 @@ async def test_vpn_async_turn_off_variations( # avoid scheduling real async_call_later during tests monkeypatch.setattr( "custom_components.opnsense.switch.async_call_later", - lambda hass, delay, action: (lambda: None), + lambda hass, delay, action: lambda: None, ) desc = SwitchEntityDescription(key="openvpn.clients.cx", name="VPNCX") From 8d5cd7a3e029b1898fc9fd28851704f12a2e3f74 Mon Sep 17 00:00:00 2001 From: Snuffy2 Date: Fri, 6 Feb 2026 17:09:50 -0500 Subject: [PATCH 15/15] Refinements --- custom_components/opnsense/__init__.py | 8 ++-- .../opnsense/translations/en.json | 4 +- tests/test_init.py | 14 +++++-- tests/test_pyopnsense.py | 37 +++++++++++++++++++ 4 files changed, 54 insertions(+), 9 deletions(-) diff --git a/custom_components/opnsense/__init__.py b/custom_components/opnsense/__init__.py index fd835eca..fe373c89 100644 --- a/custom_components/opnsense/__init__.py +++ b/custom_components/opnsense/__init__.py @@ -346,7 +346,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def _deprecated_plugin_cleanup_26_1_1( hass: HomeAssistant, client: OPNsenseClient, entry_id: str, config_device_id: str ) -> None: - """Clean up deprecated entities for OPNsense 26.1 and plugin compatibility. + """Clean up deprecated entities for OPNsense 26.1.1 and plugin compatibility. This function removes switch entities that are no longer supported in OPNsense 26.1, specifically firewall filter rules (when plugin not installed) @@ -365,7 +365,7 @@ async def _deprecated_plugin_cleanup_26_1_1( The device unique ID from the configuration. """ - _LOGGER.debug("Starting OPNsense 26.1 and Plugin cleanup") + _LOGGER.debug("Starting OPNsense 26.1.1 and Plugin cleanup") entity_registry = er.async_get(hass) plugin_installed: bool = await client.is_plugin_installed() cleanup_started: bool = False @@ -394,7 +394,7 @@ async def _deprecated_plugin_cleanup_26_1_1( if cleanup_started: if plugin_installed: _LOGGER.info( - "OPNsense 26.1 and Plugin cleanup partially completed. Plugin is still installed. NAT Outbound and NAT Port Forward rules removed. Firewall Filter rules will be removed once the plugin is removed." + "OPNsense 26.1.1 and Plugin cleanup partially completed. Plugin is still installed. NAT Outbound and NAT Port Forward rules removed. Firewall Filter rules will be removed once the plugin is removed." ) ir.async_create_issue( hass, @@ -408,7 +408,7 @@ async def _deprecated_plugin_cleanup_26_1_1( ) else: _LOGGER.info( - "OPNsense 26.1 and Plugin cleanup completed. NAT Outbound, NAT Port Forward, and Firewall Filter rules removed." + "OPNsense 26.1.1 and Plugin cleanup completed. NAT Outbound, NAT Port Forward, and Firewall Filter rules removed." ) ir.async_create_issue( hass, diff --git a/custom_components/opnsense/translations/en.json b/custom_components/opnsense/translations/en.json index 229214b5..8218e09a 100644 --- a/custom_components/opnsense/translations/en.json +++ b/custom_components/opnsense/translations/en.json @@ -119,11 +119,11 @@ "description": "OPNsense Device ID has changed which indicates new or changed hardware. In order to accommodate this, hass-opnsense needs to be removed and reinstalled for this router. hass-opnsense is shutting down." }, "plugin_cleanup_partial": { - "title": "OPNsense 26.1 and Plugin cleanup partially completed", + "title": "OPNsense 26.1.1 and Plugin cleanup partially completed", "description": "OPNsense Plugin is still installed. NAT Outbound and NAT Port Forward rules removed. Firewall Filter rules will be removed once the plugin is removed." }, "plugin_cleanup_done": { - "title": "OPNsense 26.1 and Plugin cleanup completed", + "title": "OPNsense 26.1.1 and Plugin cleanup completed", "description": "NAT Outbound, NAT Port Forward, and Firewall Filter rules removed." } }, diff --git a/tests/test_init.py b/tests/test_init.py index 62ca8b08..4f860e33 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -585,20 +585,23 @@ async def test_deprecated_plugin_cleanup_26_1_1_plugin_installed(monkeypatch): entity_registry = MagicMock() monkeypatch.setattr(init_mod.er, "async_get", lambda hass: entity_registry) - # Mock entities: NAT port forward, NAT outbound, normal switch + # Mock entities: NAT port forward, NAT outbound, filter, normal switch nat_pf_entity = MagicMock() nat_pf_entity.entity_id = "switch.opnsense_nat_port_forward_rule" nat_pf_entity.unique_id = "dev1_nat_port_forward_rule1" nat_out_entity = MagicMock() nat_out_entity.entity_id = "switch.opnsense_nat_outbound_rule" nat_out_entity.unique_id = "dev1_nat_outbound_rule1" + filter_entity = MagicMock() + filter_entity.entity_id = "switch.opnsense_filter_rule" + filter_entity.unique_id = "dev1_filter_rule1" normal_entity = MagicMock() normal_entity.entity_id = "switch.opnsense_normal_rule" normal_entity.unique_id = "dev1_normal_rule1" monkeypatch.setattr( init_mod.er, "async_entries_for_config_entry", - MagicMock(return_value=[nat_pf_entity, nat_out_entity, normal_entity]), + MagicMock(return_value=[nat_pf_entity, nat_out_entity, filter_entity, normal_entity]), ) # Mock issue registry @@ -608,10 +611,15 @@ async def test_deprecated_plugin_cleanup_26_1_1_plugin_installed(monkeypatch): config_device_id = "dev1" await init_mod._deprecated_plugin_cleanup_26_1_1(hass, client, entry_id, config_device_id) - # Verify NAT entities were removed, normal was not + # Verify NAT entities were removed, filter and normal were not assert entity_registry.async_remove.call_count == 2 entity_registry.async_remove.assert_any_call("switch.opnsense_nat_port_forward_rule") entity_registry.async_remove.assert_any_call("switch.opnsense_nat_outbound_rule") + # Ensure filter entity was preserved when plugin is installed + assert not entity_registry.async_remove.mock_calls or all( + call.args[0] != "switch.opnsense_filter_rule" + for call in entity_registry.async_remove.mock_calls + ) # Verify issue created for partial cleanup create_issue_mock.assert_called_once() call_args = create_issue_mock.call_args diff --git a/tests/test_pyopnsense.py b/tests/test_pyopnsense.py index 02035267..25558b55 100644 --- a/tests/test_pyopnsense.py +++ b/tests/test_pyopnsense.py @@ -3261,6 +3261,43 @@ async def test_get_firewall_new_api(make_client) -> None: await client.async_close() +@pytest.mark.asyncio +async def test_get_firewall_new_api_plugin_not_installed(make_client) -> None: + """get_firewall uses new API for OPNsense >= 26.1.1 but when plugin not installed it should skip config.""" + session = MagicMock(spec=aiohttp.ClientSession) + client = make_client(session=session) + client._firmware_version = "26.1.1" + + # Plugin not installed: shouldn't call get_config + client.is_plugin_installed = AsyncMock(return_value=False) + client.get_config = AsyncMock(return_value={"filter": {"rule": []}}) + client._get_firewall_rules = AsyncMock(return_value={"rule1": {"uuid": "rule1"}}) + client._get_nat_destination_rules = AsyncMock(return_value={"nat1": {"uuid": "nat1"}}) + client._get_nat_one_to_one_rules = AsyncMock(return_value={"one1": {"uuid": "one1"}}) + client._get_nat_source_rules = AsyncMock(return_value={"src1": {"uuid": "src1"}}) + client._get_nat_npt_rules = AsyncMock(return_value={"npt1": {"uuid": "npt1"}}) + + result = await client.get_firewall() + expected = { + "rules": {"rule1": {"uuid": "rule1"}}, + "nat": { + "d_nat": {"nat1": {"uuid": "nat1"}}, + "one_to_one": {"one1": {"uuid": "one1"}}, + "source_nat": {"src1": {"uuid": "src1"}}, + "npt": {"npt1": {"uuid": "npt1"}}, + }, + } + assert result == expected + client.is_plugin_installed.assert_awaited_once() + client.get_config.assert_not_awaited() + client._get_firewall_rules.assert_awaited_once() + client._get_nat_destination_rules.assert_awaited_once() + client._get_nat_one_to_one_rules.assert_awaited_once() + client._get_nat_source_rules.assert_awaited_once() + client._get_nat_npt_rules.assert_awaited_once() + await client.async_close() + + @pytest.mark.asyncio async def test_get_firewall_version_compare_exception(make_client) -> None: """get_firewall handles AwesomeVersionCompareException gracefully."""