diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fc24c081..9660faab 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,7 +12,7 @@ repos: - id: check-toml - id: check-added-large-files - repo: https://github.com/rhysd/actionlint - rev: v1.7.9 + rev: v1.7.10 hooks: - id: actionlint # Note: shellcheck cannot directly parse YAML; actionlint extracts workflow @@ -31,7 +31,7 @@ repos: additional_dependencies: - tomli - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.14.8 + rev: v0.15.0 hooks: # Run the linter. - id: ruff-check diff --git a/custom_components/opnsense/__init__.py b/custom_components/opnsense/__init__.py index 9af50965..fe373c89 100644 --- a/custom_components/opnsense/__init__.py +++ b/custom_components/opnsense/__init__.py @@ -1,4 +1,9 @@ -"""Support for OPNsense.""" +"""Home Assistant integration for OPNsense firewalls. + +This integration provides monitoring and control of OPNsense firewall devices, +including system information, network interfaces, firewall rules, DHCP leases, +and various other OPNsense features through the Home Assistant interface. +""" from collections.abc import Mapping, MutableMapping from datetime import timedelta @@ -59,7 +64,21 @@ async def _async_update_listener(hass: HomeAssistant, entry: ConfigEntry) -> None: - """Handle options update.""" + """Handle options update for the OPNsense integration. + + This function is called when the configuration entry options are updated. + It handles reloading the integration if necessary, removing entities that + are no longer enabled based on granular sync options, and cleaning up + device tracker devices if device tracking is disabled. + + Parameters + ---------- + hass : HomeAssistant + The Home Assistant instance. + entry : ConfigEntry + The configuration entry for the OPNsense integration. + + """ # _LOGGER.debug("[async_update_listener] entry: %s", entry.as_dict()) if getattr(entry.runtime_data, SHOULD_RELOAD, True): _LOGGER.info("[async_update_listener] Reloading") @@ -104,13 +123,55 @@ async def _async_update_listener(hass: HomeAssistant, entry: ConfigEntry) -> Non async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: - """Call the method to setup the integration-level services.""" + """Set up the OPNsense integration at the domain level. + + This function is called during Home Assistant startup to initialize + integration-level services for the OPNsense domain. + + Parameters + ---------- + hass : HomeAssistant + The Home Assistant instance. + config : ConfigType + The configuration dictionary (unused for config entry only integrations). + + Returns + ------- + bool + Always returns True to indicate successful setup. + + """ await async_setup_services(hass) return True async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: - """Set up OPNsense from a config entry.""" + """Set up the OPNsense integration from a configuration entry. + + This function initializes the OPNsense client, coordinators, and platforms + based on the provided configuration entry. It performs firmware version + checks, handles device ID validation, and sets up data coordinators for + state updates and device tracking. + + Parameters + ---------- + hass : HomeAssistant + The Home Assistant instance. + entry : ConfigEntry + The configuration entry containing OPNsense connection details. + + Returns + ------- + bool + True if setup was successful, False otherwise. + + Raises + ------ + Various exceptions may be raised during client initialization or firmware + checks, but they are handled internally with appropriate logging and issue + creation. + + """ config: Mapping[str, Any] = entry.data options: Mapping[str, Any] = entry.options # _LOGGER.debug("[async_setup_entry] entry: %s", entry.as_dict()) @@ -183,7 +244,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: ir.async_create_issue( hass, DOMAIN, - f"opnsense_{firmware}_below_min_firmware_{OPNSENSE_MIN_FIRMWARE}", + f"{config_device_id}_opnsense_below_min_firmware_{OPNSENSE_MIN_FIRMWARE}", is_fixable=False, is_persistent=False, issue_domain=DOMAIN, @@ -203,7 +264,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: ir.async_create_issue( hass, DOMAIN, - f"opnsense_{firmware}_below_ltd_firmware_{OPNSENSE_LTD_FIRMWARE}", + f"{config_device_id}_opnsense_below_ltd_firmware_{OPNSENSE_LTD_FIRMWARE}", is_fixable=False, is_persistent=False, issue_domain=DOMAIN, @@ -219,14 +280,24 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: ir.async_delete_issue( hass, DOMAIN, - f"opnsense_{firmware}_below_min_firmware_{OPNSENSE_MIN_FIRMWARE}", + f"{config_device_id}_opnsense_below_min_firmware_{OPNSENSE_MIN_FIRMWARE}", ) ir.async_delete_issue( hass, DOMAIN, - f"opnsense_{firmware}_below_ltd_firmware_{OPNSENSE_LTD_FIRMWARE}", + f"{config_device_id}_opnsense_below_ltd_firmware_{OPNSENSE_LTD_FIRMWARE}", ) - except awesomeversion.exceptions.AwesomeVersionCompareException: + except (awesomeversion.exceptions.AwesomeVersionCompareException, TypeError, ValueError): + _LOGGER.warning("Unable to confirm OPNsense Firmware version") + + try: + if awesomeversion.AwesomeVersion(firmware) > awesomeversion.AwesomeVersion( + "26.1" + ) and awesomeversion.AwesomeVersion(firmware) < awesomeversion.AwesomeVersion("26.7"): + await _deprecated_plugin_cleanup_26_1_1( + hass=hass, client=client, entry_id=entry.entry_id, config_device_id=config_device_id + ) + except (awesomeversion.exceptions.AwesomeVersionCompareException, TypeError, ValueError): _LOGGER.warning("Unable to confirm OPNsense Firmware version") await coordinator.async_config_entry_first_refresh() @@ -272,11 +343,108 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: return True +async def _deprecated_plugin_cleanup_26_1_1( + hass: HomeAssistant, client: OPNsenseClient, entry_id: str, config_device_id: str +) -> None: + """Clean up deprecated entities for OPNsense 26.1.1 and plugin compatibility. + + This function removes switch entities that are no longer supported in + OPNsense 26.1, specifically firewall filter rules (when plugin not installed) + and NAT port forward/outbound rules. It creates appropriate issues to + inform the user about the cleanup. + + Parameters + ---------- + hass : HomeAssistant + The Home Assistant instance. + client : OPNsenseClient + The OPNsense client instance. + entry_id : str + The configuration entry ID. + config_device_id : str + The device unique ID from the configuration. + + """ + _LOGGER.debug("Starting OPNsense 26.1.1 and Plugin cleanup") + entity_registry = er.async_get(hass) + plugin_installed: bool = await client.is_plugin_installed() + cleanup_started: bool = False + + for ent in er.async_entries_for_config_entry(entity_registry, entry_id): + platform = ent.entity_id.split(".")[0] + if platform != Platform.SWITCH: + continue + # _LOGGER.debug("[deprecated_plugin_cleanup] ent: %s", ent) + if ( + (not plugin_installed and "_filter_" in ent.unique_id) + or "_nat_port_forward_" in ent.unique_id + or "_nat_outbound_" in ent.unique_id + ): + cleanup_started = True + try: + entity_registry.async_remove(ent.entity_id) + _LOGGER.debug("[deprecated_plugin_cleanup] removed entity_id: %s", ent.entity_id) + except (KeyError, ValueError) as e: + _LOGGER.error( + "Error removing entity: %s. %s: %s", + ent.entity_id, + type(e).__name__, + e, + ) + if cleanup_started: + if plugin_installed: + _LOGGER.info( + "OPNsense 26.1.1 and Plugin cleanup partially completed. Plugin is still installed. NAT Outbound and NAT Port Forward rules removed. Firewall Filter rules will be removed once the plugin is removed." + ) + ir.async_create_issue( + hass, + DOMAIN, + f"{config_device_id}_plugin_cleanup_partial", + is_fixable=False, + is_persistent=False, + issue_domain=DOMAIN, + severity=ir.IssueSeverity.WARNING, + translation_key="plugin_cleanup_partial", + ) + else: + _LOGGER.info( + "OPNsense 26.1.1 and Plugin cleanup completed. NAT Outbound, NAT Port Forward, and Firewall Filter rules removed." + ) + ir.async_create_issue( + hass, + DOMAIN, + f"{config_device_id}_plugin_cleanup_done", + is_fixable=False, + is_persistent=False, + issue_domain=DOMAIN, + severity=ir.IssueSeverity.WARNING, + translation_key="plugin_cleanup_done", + ) + + async def async_remove_config_entry_device( hass: HomeAssistant, config_entry: ConfigEntry, device_entry: dr.DeviceEntry ) -> bool: - """Remove OPNsense Devices that aren't Device Tracker Devices and without any linked entities.""" - + """Remove OPNsense devices that are not device tracker devices and have no linked entities. + + This function checks if an OPNsense device can be safely removed. It prevents + removal of device tracker devices and devices that still have linked entities. + + Parameters + ---------- + hass : HomeAssistant + The Home Assistant instance. + config_entry : ConfigEntry + The configuration entry for the OPNsense integration. + device_entry : dr.DeviceEntry + The device entry to be removed. + + Returns + ------- + bool + True if the device can be removed, False otherwise. + + """ if device_entry.via_device_id: _LOGGER.error("Remove OPNsense Device Tracker Devices via the Integration Configuration") return False @@ -289,7 +457,24 @@ async def async_remove_config_entry_device( async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: - """Unload a config entry.""" + """Unload the OPNsense integration configuration entry. + + This function unloads all platforms associated with the configuration entry, + closes the OPNsense client connection, and cleans up the entry data. + + Parameters + ---------- + hass : HomeAssistant + The Home Assistant instance. + entry : ConfigEntry + The configuration entry to unload. + + Returns + ------- + bool + True if unloading was successful, False otherwise. + + """ _LOGGER.info("Unloading: %s", entry.as_dict()) platforms: list[Platform] = getattr(entry.runtime_data, LOADED_PLATFORMS) client: OPNsenseClient = getattr(entry.runtime_data, OPNSENSE_CLIENT) @@ -304,6 +489,24 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def _migrate_1_to_2(hass: HomeAssistant, config_entry: ConfigEntry) -> bool: + """Migrate configuration entry from version 1 to version 2. + + This migration replaces the deprecated 'tls_insecure' option with + 'verify_ssl' for SSL certificate verification. + + Parameters + ---------- + hass : HomeAssistant + The Home Assistant instance. + config_entry : ConfigEntry + The configuration entry to migrate. + + Returns + ------- + bool + Always returns True. + + """ tls_insecure = config_entry.data.get(CONF_TLS_INSECURE, DEFAULT_TLS_INSECURE) data: MutableMapping[str, Any] = dict(config_entry.data) @@ -320,6 +523,25 @@ async def _migrate_1_to_2(hass: HomeAssistant, config_entry: ConfigEntry) -> boo async def _migrate_2_to_3(hass: HomeAssistant, config_entry: ConfigEntry) -> bool: + """Migrate configuration entry from version 2 to version 3. + + This migration updates device unique IDs to use the lowest MAC address + and updates entity unique IDs accordingly. It also updates device + identifiers in the device registry. + + Parameters + ---------- + hass : HomeAssistant + The Home Assistant instance. + config_entry : ConfigEntry + The configuration entry to migrate. + + Returns + ------- + bool + True if migration was successful, False otherwise. + + """ _LOGGER.debug("[migrate_2_to_3] Initial Version: %s", config_entry.version) entity_registry = er.async_get(hass) device_registry = dr.async_get(hass) @@ -429,6 +651,25 @@ async def _migrate_2_to_3(hass: HomeAssistant, config_entry: ConfigEntry) -> boo async def _migrate_3_to_4(hass: HomeAssistant, config_entry: ConfigEntry) -> bool: + """Migrate configuration entry from version 3 to version 4. + + This migration moves telemetry-based entities (interfaces, gateways, openvpn) + out of the telemetry namespace and updates their unique IDs. It also removes + deprecated connected client count sensors. + + Parameters + ---------- + hass : HomeAssistant + The Home Assistant instance. + config_entry : ConfigEntry + The configuration entry to migrate. + + Returns + ------- + bool + True if migration was successful, False otherwise. + + """ _LOGGER.debug("[migrate_3_to_4] Initial Version: %s", config_entry.version) entity_registry = er.async_get(hass) @@ -454,7 +695,7 @@ async def _migrate_3_to_4(hass: HomeAssistant, config_entry: ConfigEntry) -> boo for ent in er.async_entries_for_config_entry(entity_registry, config_entry.entry_id): platform = ent.entity_id.split(".")[0] if platform == Platform.SENSOR: - # _LOGGER.debug(f"[migrate_3_to_4] ent: {ent}") + # _LOGGER.debug("[migrate_3_to_4] ent: %s", ent) if "_telemetry_interface_" in ent.unique_id: new_unique_id: str | None = ent.unique_id.replace( "_telemetry_interface_", "_interface_" @@ -532,7 +773,24 @@ async def _migrate_3_to_4(hass: HomeAssistant, config_entry: ConfigEntry) -> boo async def async_migrate_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool: - """Migrate an old config entry.""" + """Migrate an old configuration entry to the latest version. + + This function handles migration of configuration entries from older versions + to the current version by applying sequential migration steps. + + Parameters + ---------- + hass : HomeAssistant + The Home Assistant instance. + config_entry : ConfigEntry + The configuration entry to migrate. + + Returns + ------- + bool + True if migration was successful, False otherwise. + + """ version = config_entry.version if version > 4: diff --git a/custom_components/opnsense/config_flow.py b/custom_components/opnsense/config_flow.py index ae4d11d9..1edf3bcc 100644 --- a/custom_components/opnsense/config_flow.py +++ b/custom_components/opnsense/config_flow.py @@ -270,20 +270,27 @@ async def _handle_user_input( try: _validate_firmware_version(user_input[CONF_FIRMWARE_VERSION]) - except awesomeversion.exceptions.AwesomeVersionCompareException as e: + except (awesomeversion.exceptions.AwesomeVersionCompareException, TypeError, ValueError) as e: raise UnknownFirmware from e await client.set_use_snake_case(initial=True) # Plugin check not required for config step of user. Otherwise, plugin check is required if # granular sync options is enabled - require_plugin_check = ( - config_step != "user" - and user_input.get(CONF_GRANULAR_SYNC_OPTIONS, DEFAULT_GRANULAR_SYNC_OPTIONS) - and any( - user_input.get(item, DEFAULT_SYNC_OPTION_VALUE) for item in SYNC_ITEMS_REQUIRING_PLUGIN + try: + require_plugin_check = ( + config_step != "user" + and user_input.get(CONF_GRANULAR_SYNC_OPTIONS, DEFAULT_GRANULAR_SYNC_OPTIONS) + and any( + user_input.get(item, DEFAULT_SYNC_OPTION_VALUE) + for item in SYNC_ITEMS_REQUIRING_PLUGIN + ) + and awesomeversion.AwesomeVersion(user_input[CONF_FIRMWARE_VERSION]) + < awesomeversion.AwesomeVersion("26.1.1") ) - ) + except (awesomeversion.exceptions.AwesomeVersionCompareException, TypeError, ValueError) as e: + raise UnknownFirmware from e + _LOGGER.debug( "[handle_user_input] config_step: %s, require_plugin_check: %s", config_step, diff --git a/custom_components/opnsense/const.py b/custom_components/opnsense/const.py index 59a68a8c..7aa8ac62 100644 --- a/custom_components/opnsense/const.py +++ b/custom_components/opnsense/const.py @@ -55,7 +55,7 @@ CONF_SYNC_GATEWAYS = "sync_gateways" CONF_SYNC_SERVICES = "sync_services" CONF_SYNC_NOTICES = "sync_notices" -CONF_SYNC_FILTERS_AND_NAT = "sync_filters_and_nat" +CONF_SYNC_FIREWALL_AND_NAT = "sync_filters_and_nat" CONF_SYNC_UNBOUND = "sync_unbound" CONF_SYNC_INTERFACES = "sync_interfaces" CONF_SYNC_CERTIFICATES = "sync_certificates" @@ -63,7 +63,7 @@ DEFAULT_GRANULAR_SYNC_OPTIONS = False DEFAULT_SYNC_OPTION_VALUE = True -SYNC_ITEMS_REQUIRING_PLUGIN = (CONF_SYNC_FILTERS_AND_NAT,) +SYNC_ITEMS_REQUIRING_PLUGIN = (CONF_SYNC_FIREWALL_AND_NAT,) GRANULAR_SYNC_ITEMS = ( CONF_SYNC_TELEMETRY, CONF_SYNC_GATEWAYS, @@ -72,7 +72,7 @@ CONF_SYNC_NOTICES, CONF_SYNC_FIRMWARE_UPDATES, CONF_SYNC_CARP, - CONF_SYNC_FILTERS_AND_NAT, + CONF_SYNC_FIREWALL_AND_NAT, CONF_SYNC_SERVICES, CONF_SYNC_VPN, CONF_SYNC_CERTIFICATES, @@ -90,7 +90,7 @@ CONF_SYNC_UNBOUND: ["unbound"], CONF_SYNC_INTERFACES: ["interface"], CONF_SYNC_CERTIFICATES: ["certificates"], - CONF_SYNC_FILTERS_AND_NAT: ["filter", "nat"], + CONF_SYNC_FIREWALL_AND_NAT: ["filter", "nat"], } CONF_DEVICES = "devices" CONF_MANUAL_DEVICES = "manual_devices" diff --git a/custom_components/opnsense/coordinator.py b/custom_components/opnsense/coordinator.py index 61a2710e..8bacadde 100644 --- a/custom_components/opnsense/coordinator.py +++ b/custom_components/opnsense/coordinator.py @@ -17,7 +17,7 @@ CONF_SYNC_CARP, CONF_SYNC_CERTIFICATES, CONF_SYNC_DHCP_LEASES, - CONF_SYNC_FILTERS_AND_NAT, + CONF_SYNC_FIREWALL_AND_NAT, CONF_SYNC_FIRMWARE_UPDATES, CONF_SYNC_GATEWAYS, CONF_SYNC_INTERFACES, @@ -114,6 +114,7 @@ def _build_categories(self) -> list[MutableMapping[str, str]]: "function": "get_host_firmware_version", "state_key": "host_firmware_version", }, + {"function": "is_plugin_installed", "state_key": "plugin_installed"}, ] if config.get(CONF_SYNC_TELEMETRY, DEFAULT_SYNC_OPTION_VALUE): @@ -149,8 +150,8 @@ def _build_categories(self) -> list[MutableMapping[str, str]]: categories.append({"function": "get_services", "state_key": "services"}) if config.get(CONF_SYNC_NOTICES, DEFAULT_SYNC_OPTION_VALUE): categories.append({"function": "get_notices", "state_key": "notices"}) - if config.get(CONF_SYNC_FILTERS_AND_NAT, DEFAULT_SYNC_OPTION_VALUE): - categories.append({"function": "get_config", "state_key": "config"}) + if config.get(CONF_SYNC_FIREWALL_AND_NAT, DEFAULT_SYNC_OPTION_VALUE): + categories.append({"function": "get_firewall", "state_key": "firewall"}) if config.get(CONF_SYNC_UNBOUND, DEFAULT_SYNC_OPTION_VALUE): categories.append( { diff --git a/custom_components/opnsense/pyopnsense/__init__.py b/custom_components/opnsense/pyopnsense/__init__.py index a1159173..5f257d7a 100644 --- a/custom_components/opnsense/pyopnsense/__init__.py +++ b/custom_components/opnsense/pyopnsense/__init__.py @@ -331,7 +331,11 @@ async def set_use_snake_case(self, initial: bool = False) -> None: self._use_snake_case = False else: _LOGGER.debug("Using snake_case for OPNsense >= 25.7") - except awesomeversion.exceptions.AwesomeVersionCompareException as e: + except ( + awesomeversion.exceptions.AwesomeVersionCompareException, + TypeError, + ValueError, + ) as e: _LOGGER.error( "Error comparing firmware version %s. Using snake_case by default", self._firmware_version, @@ -383,9 +387,7 @@ async def _process_queue(self) -> None: method, path, payload, future, caller = await self._request_queue.get() try: if method == "get_from_stream": - result: MutableMapping[str, Any] | list | None = await self._do_get_from_stream( - path, caller - ) + result: Any = await self._do_get_from_stream(path, caller) if future is not None and not future.done(): future.set_result(result) elif method == "get": @@ -698,7 +700,11 @@ async def get_firmware_update_info(self) -> MutableMapping[str, Any]: ): _LOGGER.debug("Update available but missing details") update_needs_info = True - except awesomeversion.exceptions.AwesomeVersionCompareException as e: + except ( + awesomeversion.exceptions.AwesomeVersionCompareException, + TypeError, + ValueError, + ) as e: _LOGGER.debug("Error checking firmware versions. %s: %s", type(e).__name__, e) update_needs_info = True @@ -772,7 +778,7 @@ async def get_config(self) -> MutableMapping[str, Any]: return ret_data @_log_errors - async def enable_filter_rule_by_created_time(self, created_time: str) -> None: + async def enable_filter_rule_by_created_time_legacy(self, created_time: str) -> None: """Enable a filter rule.""" config = await self.get_config() for rule in config["filter"]["rule"]: @@ -789,7 +795,7 @@ async def enable_filter_rule_by_created_time(self, created_time: str) -> None: await self._filter_configure() @_log_errors - async def disable_filter_rule_by_created_time(self, created_time: str) -> None: + async def disable_filter_rule_by_created_time_legacy(self, created_time: str) -> None: """Disable a filter rule.""" config: MutableMapping[str, Any] = await self.get_config() @@ -808,7 +814,7 @@ async def disable_filter_rule_by_created_time(self, created_time: str) -> None: # use created_time as a unique_id since none other exists @_log_errors - async def enable_nat_port_forward_rule_by_created_time(self, created_time: str) -> None: + async def enable_nat_port_forward_rule_by_created_time_legacy(self, created_time: str) -> None: """Enable a NAT Port Forward rule.""" config: MutableMapping[str, Any] = await self.get_config() for rule in config.get("nat", {}).get("rule", []): @@ -826,7 +832,7 @@ async def enable_nat_port_forward_rule_by_created_time(self, created_time: str) # use created_time as a unique_id since none other exists @_log_errors - async def disable_nat_port_forward_rule_by_created_time(self, created_time: str) -> None: + async def disable_nat_port_forward_rule_by_created_time_legacy(self, created_time: str) -> None: """Disable a NAT Port Forward rule.""" config: MutableMapping[str, Any] = await self.get_config() for rule in config.get("nat", {}).get("rule", []): @@ -844,7 +850,7 @@ async def disable_nat_port_forward_rule_by_created_time(self, created_time: str) # use created_time as a unique_id since none other exists @_log_errors - async def enable_nat_outbound_rule_by_created_time(self, created_time: str) -> None: + async def enable_nat_outbound_rule_by_created_time_legacy(self, created_time: str) -> None: """Enable NAT Outbound rule.""" config: MutableMapping[str, Any] = await self.get_config() for rule in config.get("nat", {}).get("outbound", {}).get("rule", []): @@ -862,10 +868,12 @@ async def enable_nat_outbound_rule_by_created_time(self, created_time: str) -> N # use created_time as a unique_id since none other exists @_log_errors - async def disable_nat_outbound_rule_by_created_time(self, created_time: str) -> None: + async def disable_nat_outbound_rule_by_created_time_legacy(self, created_time: str) -> None: """Disable NAT Outbound Rule.""" config: MutableMapping[str, Any] = await self.get_config() for rule in config.get("nat", {}).get("outbound", {}).get("rule", []): + if "created" not in rule or "time" not in rule["created"]: + continue if rule["created"]["time"] != created_time: continue @@ -874,6 +882,274 @@ async def disable_nat_outbound_rule_by_created_time(self, created_time: str) -> await self._restore_config_section("nat", config["nat"]) await self._filter_configure() + @_log_errors + async def get_firewall(self) -> dict[str, Any]: + """Retrieve all firewall and NAT rules from OPNsense. + + Returns + ------- + dict + A dictionary representing the firewall and NAT rules. + + """ + if self._firmware_version is None: + await self.get_host_firmware_version() + + try: + if awesomeversion.AwesomeVersion( + self._firmware_version + ) < awesomeversion.AwesomeVersion("26.1.1"): + _LOGGER.debug("Using legacy plugin for firewall filters for OPNsense < 26.1.1") + return {"config": await self.get_config()} + except (awesomeversion.exceptions.AwesomeVersionCompareException, TypeError, ValueError): + _LOGGER.warning("Error comparing firmware version. Skipping get_firewall.") + return {} + firewall: dict[str, Any] = {"nat": {}} + if await self.is_plugin_installed(): + firewall["config"] = await self.get_config() + firewall["rules"] = await self._get_firewall_rules() + firewall["nat"]["d_nat"] = await self._get_nat_destination_rules() + firewall["nat"]["one_to_one"] = await self._get_nat_one_to_one_rules() + firewall["nat"]["source_nat"] = await self._get_nat_source_rules() + firewall["nat"]["npt"] = await self._get_nat_npt_rules() + # _LOGGER.debug("[get_firewall] firewall: %s", firewall) + return firewall + + @_log_errors + async def _get_firewall_rules(self) -> dict[str, Any]: + """Retrieve firewall rules from OPNsense. + + Parameters + ---------- + interface_map : dict[str, Any] + A mapping of interface names to firewall interface names. + + Returns + ------- + dict[str, Any] + A dictionary of firewall rules, keyed by UUID, each representing a firewall rule parsed + from CSV format. Dictionary keys correspond to CSV headers such + as 'uuid', 'enabled', 'action', etc. + + """ + request_body: MutableMapping[str, Any] = {"current": 1, "sort": {}} + response = await self._safe_dict_post( + "/api/firewall/filter/search_rule", payload=request_body + ) + # _LOGGER.debug("[get_firewall_rules] response: %s", response) + rules: list = response.get("rows", []) + _LOGGER.debug("[get_firewall_rules] rules: %s", rules) + rules_dict: dict[str, Any] = {} + for rule in rules: + if not rule.get("uuid") or "lockout" in rule.get("uuid"): + continue + new_rule = rule.copy() + # Add any transforms here + rules_dict[new_rule["uuid"]] = new_rule + _LOGGER.debug("[get_firewall_rules] rules_dict: %s", rules_dict) + return rules_dict + + @_log_errors + async def _get_nat_destination_rules(self) -> dict[str, Any]: + """Retrieve NAT destination rules from OPNsense. + + Returns + ------- + dict[str, Any] + A dictionary of NAT destination rules, keyed by UUID. + + """ + request_body: MutableMapping[str, Any] = {"current": 1, "sort": {}} + response = await self._safe_dict_post( + "/api/firewall/d_nat/search_rule", payload=request_body + ) + # _LOGGER.debug("[get_nat_destination_rules] response: %s", response) + rules: list = response.get("rows", []) + # _LOGGER.debug("[get_nat_destination_rules] rules: %s", rules) + rules_dict: dict[str, Any] = {} + for rule in rules: + if not rule.get("uuid") or "lockout" in rule.get("uuid"): + continue # skip lockout rules + new_rule = rule.copy() + new_rule["description"] = new_rule.pop("descr", "") + new_rule["enabled"] = "1" if new_rule.pop("disabled", "0") == "0" else "0" + rules_dict[new_rule["uuid"]] = new_rule + _LOGGER.debug("[get_nat_destination_rules] rules_dict: %s", rules_dict) + return rules_dict + + @_log_errors + async def _get_nat_one_to_one_rules(self) -> dict[str, Any]: + """Retrieve NAT one-to-one rules from OPNsense. + + Returns + ------- + dict[str, Any] + A dictionary of NAT one-to-one rules, keyed by UUID. + + """ + request_body: MutableMapping[str, Any] = {"current": 1, "sort": {}} + response = await self._safe_dict_post( + "/api/firewall/one_to_one/search_rule", payload=request_body + ) + # _LOGGER.debug("[get_nat_one_to_one_rules] response: %s", response) + rules: list = response.get("rows", []) + _LOGGER.debug("[get_nat_one_to_one_rules] rules: %s", rules) + rules_dict: dict[str, Any] = {} + for rule in rules: + if not rule.get("uuid") or "lockout" in rule.get("uuid"): + continue + new_rule = rule.copy() + # Add any transforms here + rules_dict[new_rule["uuid"]] = new_rule + _LOGGER.debug("[get_nat_one_to_one_rules] rules_dict: %s", rules_dict) + return rules_dict + + @_log_errors + async def _get_nat_source_rules(self) -> dict[str, Any]: + """Retrieve NAT source rules from OPNsense. + + Returns + ------- + dict[str, Any] + A dictionary of NAT source rules, keyed by UUID. + + """ + request_body: MutableMapping[str, Any] = {"current": 1, "sort": {}} + response = await self._safe_dict_post( + "/api/firewall/source_nat/search_rule", payload=request_body + ) + # _LOGGER.debug("[get_nat_source_rules] response: %s", response) + rules: list = response.get("rows", []) + # _LOGGER.debug("[get_nat_source_rules] rules: %s", rules) + rules_dict: dict[str, Any] = {} + for rule in rules: + if not rule.get("uuid") or "lockout" in rule.get("uuid"): + continue + new_rule = rule.copy() + # Add any transforms here + rules_dict[new_rule["uuid"]] = new_rule + _LOGGER.debug("[get_nat_source_rules] rules_dict: %s", rules_dict) + return rules_dict + + @_log_errors + async def _get_nat_npt_rules(self) -> dict[str, Any]: + """Retrieve NAT NPT rules from OPNsense. + + Returns + ------- + dict[str, Any] + A dictionary of NAT NPT rules, keyed by UUID. + + """ + request_body: MutableMapping[str, Any] = {"current": 1, "sort": {}} + response = await self._safe_dict_post("/api/firewall/npt/search_rule", payload=request_body) + # _LOGGER.debug("[get_nat_npt_rules] response: %s", response) + rules: list = response.get("rows", []) + # _LOGGER.debug("[get_nat_npt_rules] rules: %s", rules) + rules_dict: dict[str, Any] = {} + for rule in rules: + if not rule.get("uuid") or "lockout" in rule.get("uuid"): + continue + new_rule = rule.copy() + # Add any transforms here + rules_dict[new_rule["uuid"]] = new_rule + _LOGGER.debug("[get_nat_npt_rules] rules_dict: %s", rules_dict) + return rules_dict + + async def toggle_firewall_rule(self, uuid: str, toggle_on_off: str | None = None) -> bool: + """Toggle Firewall Rule on and off. + + Parameters + ---------- + uuid : str + The UUID of the firewall rule to toggle. + toggle_on_off : str | None, optional + The action to perform: 'on' to enable, 'off' to disable, or None to toggle. + + Returns + ------- + bool + True if the operation was successful, False otherwise. + + """ + payload: MutableMapping[str, Any] = {} + url = f"/api/firewall/filter/toggle_rule/{uuid}" + if toggle_on_off == "on": + url = f"{url}/1" + elif toggle_on_off == "off": + url = f"{url}/0" + response = await self._safe_dict_post( + url, + payload=payload, + ) + _LOGGER.debug( + "[toggle_firewall_rule] uuid: %s, action: %s, url: %s, response: %s", + uuid, + toggle_on_off, + url, + response, + ) + if response.get("result") == "failed": + return False + + apply_resp = await self._safe_dict_post("/api/firewall/filter/apply") + if apply_resp.get("status", "").strip() != "OK": + return False + + return True + + async def toggle_nat_rule( + self, nat_rule_type: str, uuid: str, toggle_on_off: str | None = None + ) -> bool: + """Toggle NAT Rule on and off. + + Parameters + ---------- + nat_rule_type : str + The type of NAT rule (e.g., 'd_nat', 'source_nat'). + uuid : str + The UUID of the NAT rule to toggle. + toggle_on_off : str | None, optional + The action to perform: 'on' to enable, 'off' to disable, or None to toggle. + + Returns + ------- + bool + True if the operation was successful, False otherwise. + + """ + payload: MutableMapping[str, Any] = {} + url = f"/api/firewall/{nat_rule_type}/toggle_rule/{uuid}" + # d_nat uses opposite logic for on/off + if nat_rule_type == "d_nat" and toggle_on_off is not None: + if toggle_on_off == "on": + url = f"{url}/0" + elif toggle_on_off == "off": + url = f"{url}/1" + elif toggle_on_off == "on": + url = f"{url}/1" + elif toggle_on_off == "off": + url = f"{url}/0" + response = await self._safe_dict_post( + url, + payload=payload, + ) + _LOGGER.debug( + "[toggle_nat_rule] uuid: %s, action: %s, url: %s, response: %s", + uuid, + toggle_on_off, + url, + response, + ) + if response.get("result") == "failed": + return False + + apply_resp = await self._safe_dict_post(f"/api/firewall/{nat_rule_type}/apply") + if apply_resp.get("status", "").strip() != "OK": + return False + + return True + @_log_errors async def get_arp_table(self, resolve_hostnames: bool = False) -> list: """Return the active ARP table.""" @@ -1089,7 +1365,7 @@ async def _get_dnsmasq_leases(self) -> list: ) < awesomeversion.AwesomeVersion("25.1.7"): _LOGGER.debug("Skipping get_dnsmasq_leases for OPNsense < 25.1.7") return [] - except awesomeversion.exceptions.AwesomeVersionCompareException: + except (awesomeversion.exceptions.AwesomeVersionCompareException, TypeError, ValueError): pass response = await self._safe_dict_get("/api/dnsmasq/leases/search") @@ -2351,7 +2627,7 @@ async def kill_states(self, ip_addr: str) -> MutableMapping[str, Any]: "dropped_states": response.get("dropped_states", 0), } - async def toggle_alias(self, alias: str, toggle_on_off: str) -> bool: + async def toggle_alias(self, alias: str, toggle_on_off: str | None = None) -> bool: """Toggle alias on and off.""" if self._use_snake_case: alias_list_resp = await self._safe_dict_get("/api/firewall/alias/search_item") diff --git a/custom_components/opnsense/switch.py b/custom_components/opnsense/switch.py index fbe0b547..6d2c7686 100644 --- a/custom_components/opnsense/switch.py +++ b/custom_components/opnsense/switch.py @@ -16,7 +16,7 @@ ATTR_NAT_OUTBOUND, ATTR_NAT_PORT_FORWARD, ATTR_UNBOUND_BLOCKLIST, - CONF_SYNC_FILTERS_AND_NAT, + CONF_SYNC_FIREWALL_AND_NAT, CONF_SYNC_SERVICES, CONF_SYNC_UNBOUND, CONF_SYNC_VPN, @@ -30,128 +30,186 @@ _LOGGER: logging.Logger = logging.getLogger(__name__) -async def _compile_filter_switches( +async def _compile_filter_switches_legacy( config_entry: ConfigEntry, coordinator: OPNsenseDataUpdateCoordinator, state: MutableMapping[str, Any], ) -> list: - if not isinstance(state, MutableMapping) or not isinstance(state.get("config"), MutableMapping): + """Compile legacy filter rule switches from OPNsense state. + + Parameters + ---------- + config_entry : ConfigEntry + The Home Assistant config entry. + coordinator : OPNsenseDataUpdateCoordinator + The data update coordinator. + state : MutableMapping[str, Any] + The current state data from OPNsense. + + Returns + ------- + list + A list of OPNsenseFilterSwitchLegacy entities. + + """ + if not isinstance(state, MutableMapping) or not isinstance( + state.get("firewall", {}).get("config", {}).get("filter", {}).get("rule"), list + ): return [] entities: list = [] # filter rules - if "filter" in state.get("config", {}): - rules = dict_get(state, "config.filter.rule") - if isinstance(rules, list): - for rule in rules: - if not isinstance(rule, dict): - continue + rules = dict_get(state, "firewall.config.filter.rule") + if isinstance(rules, list): + for rule in rules: + if not isinstance(rule, dict): + continue - # do NOT add rules that are NAT rules - if "associated-rule-id" in rule: - continue + # do NOT add rules that are NAT rules + if "associated-rule-id" in rule: + continue - # not possible to disable these rules - if rule.get("descr", "") == "Anti-Lockout Rule": - continue + # not possible to disable these rules + if rule.get("description", "") == "Anti-Lockout Rule": + continue - tracker = dict_get(rule, "created.time") - # we use tracker as the unique id - if tracker is None or len(tracker) < 1: - continue + tracker = dict_get(rule, "created.time") + # we use tracker as the unique id + if tracker is None or len(tracker) < 1: + continue - entities.append( - OPNsenseFilterSwitch( - config_entry=config_entry, - coordinator=coordinator, - entity_description=SwitchEntityDescription( - key=f"filter.{tracker}", - name=f"Filter Rule {tracker} ({rule.get('descr', '')})", - icon="mdi:play-network-outline", - # entity_category=entity_category, - device_class=SwitchDeviceClass.SWITCH, - entity_registry_enabled_default=False, - ), - ) + entities.append( + OPNsenseFilterSwitchLegacy( + config_entry=config_entry, + coordinator=coordinator, + entity_description=SwitchEntityDescription( + key=f"filter.{tracker}", + name=f"Filter Rule {tracker} ({rule.get('descr', '')})", + icon="mdi:play-network-outline", + # entity_category=entity_category, + device_class=SwitchDeviceClass.SWITCH, + entity_registry_enabled_default=False, + ), ) + ) + _LOGGER.debug("[compile_filter_switches_legacy] entities: %s", len(entities)) return entities -async def _compile_port_forward_switches( +async def _compile_port_forward_switches_legacy( config_entry: ConfigEntry, coordinator: OPNsenseDataUpdateCoordinator, state: MutableMapping[str, Any], ) -> list: - if not isinstance(state, MutableMapping) or not isinstance(state.get("config"), MutableMapping): + """Compile legacy NAT port forward rule switches from OPNsense state. + + Parameters + ---------- + config_entry : ConfigEntry + The Home Assistant config entry. + coordinator : OPNsenseDataUpdateCoordinator + The data update coordinator. + state : MutableMapping[str, Any] + The current state data from OPNsense. + + Returns + ------- + list + A list of OPNsenseNatSwitchLegacy entities for port forward rules. + + """ + if not isinstance(state, MutableMapping) or not isinstance( + state.get("firewall", {}).get("config", {}).get("nat", {}).get("rule"), list + ): return [] entities: list = [] # nat port forward rules - if "nat" in state.get("config", {}): - rules = dict_get(state, "config.nat.rule") - if isinstance(rules, list): - for rule in rules: - if not isinstance(rule, dict): - continue + rules = dict_get(state, "firewall.config.nat.rule") + if isinstance(rules, list): + for rule in rules: + if not isinstance(rule, dict): + continue - tracker = dict_get(rule, "created.time") - # we use tracker as the unique id - if tracker is None or len(tracker) < 1: - continue + tracker = dict_get(rule, "created.time") + # we use tracker as the unique id + if tracker is None or len(tracker) < 1: + continue - entity = OPNsenseNatSwitch( - config_entry=config_entry, - coordinator=coordinator, - entity_description=SwitchEntityDescription( - key=f"nat_port_forward.{tracker}", - name=f"NAT Port Forward Rule {tracker} ({rule.get('descr', '')})", - icon="mdi:network-outline", - # entity_category=ENTITY_CATEGORY_CONFIG, - device_class=SwitchDeviceClass.SWITCH, - entity_registry_enabled_default=False, - ), - ) - entities.append(entity) + entity = OPNsenseNatSwitchLegacy( + config_entry=config_entry, + coordinator=coordinator, + entity_description=SwitchEntityDescription( + key=f"nat_port_forward.{tracker}", + name=f"NAT Port Forward Rule {tracker} ({rule.get('descr', '')})", + icon="mdi:network-outline", + # entity_category=ENTITY_CATEGORY_CONFIG, + device_class=SwitchDeviceClass.SWITCH, + entity_registry_enabled_default=False, + ), + ) + entities.append(entity) + _LOGGER.debug("[compile_port_forward_switches_legacy] entities: %s", len(entities)) return entities -async def _compile_nat_outbound_switches( +async def _compile_nat_outbound_switches_legacy( config_entry: ConfigEntry, coordinator: OPNsenseDataUpdateCoordinator, state: MutableMapping[str, Any], ) -> list: - if not isinstance(state, MutableMapping) or not isinstance(state.get("config"), MutableMapping): + """Compile legacy NAT outbound rule switches from OPNsense state. + + Parameters + ---------- + config_entry : ConfigEntry + The Home Assistant config entry. + coordinator : OPNsenseDataUpdateCoordinator + The data update coordinator. + state : MutableMapping[str, Any] + The current state data from OPNsense. + + Returns + ------- + list + A list of OPNsenseNatSwitchLegacy entities for outbound rules. + + """ + if not isinstance(state, MutableMapping) or not isinstance( + state.get("firewall", {}).get("config", {}).get("nat", {}).get("outbound", {}).get("rule"), + list, + ): return [] entities: list = [] # nat outbound rules - if "nat" in state.get("config", {}): - # to actually be applicable mode must by "hybrid" or "advanced" - rules = dict_get(state, "config.nat.outbound.rule") - if isinstance(rules, list): - for rule in rules: - if not isinstance(rule, dict): - continue + # to actually be applicable, mode must by "hybrid" or "advanced" + rules = dict_get(state, "firewall.config.nat.outbound.rule") + if isinstance(rules, list): + for rule in rules: + if not isinstance(rule, dict): + continue - tracker = dict_get(rule, "created.time") - # we use tracker as the unique id - if tracker is None or len(tracker) < 1: - continue + tracker = dict_get(rule, "created.time") + # we use tracker as the unique id + if tracker is None or len(tracker) < 1: + continue - if "Auto created rule" in rule.get("descr", ""): - continue + if "Auto created rule" in rule.get("description", ""): + continue - entity = OPNsenseNatSwitch( - config_entry=config_entry, - coordinator=coordinator, - entity_description=SwitchEntityDescription( - key=f"nat_outbound.{tracker}", - name=f"NAT Outbound Rule {tracker} ({rule.get('descr', '')})", - icon="mdi:network-outline", - # entity_category=ENTITY_CATEGORY_CONFIG, - device_class=SwitchDeviceClass.SWITCH, - entity_registry_enabled_default=False, - ), - ) - entities.append(entity) + entity = OPNsenseNatSwitchLegacy( + config_entry=config_entry, + coordinator=coordinator, + entity_description=SwitchEntityDescription( + key=f"nat_outbound.{tracker}", + name=f"NAT Outbound Rule {tracker} ({rule.get('descr', '')})", + icon="mdi:network-outline", + # entity_category=ENTITY_CATEGORY_CONFIG, + device_class=SwitchDeviceClass.SWITCH, + entity_registry_enabled_default=False, + ), + ) + entities.append(entity) + _LOGGER.debug("[compile_nat_outbound_switches_legacy] entities: %s", len(entities)) return entities @@ -160,6 +218,23 @@ async def _compile_service_switches( coordinator: OPNsenseDataUpdateCoordinator, state: MutableMapping[str, Any], ) -> list: + """Compile service switches from OPNsense state. + + Parameters + ---------- + config_entry : ConfigEntry + The Home Assistant config entry. + coordinator : OPNsenseDataUpdateCoordinator + The data update coordinator. + state : MutableMapping[str, Any] + The current state data from OPNsense. + + Returns + ------- + list + A list of OPNsenseServiceSwitch entities. + + """ if not isinstance(state, MutableMapping) or not isinstance(state.get("services"), list): return [] @@ -190,6 +265,23 @@ async def _compile_vpn_switches( coordinator: OPNsenseDataUpdateCoordinator, state: MutableMapping[str, Any], ) -> list: + """Compile VPN switches from OPNsense state. + + Parameters + ---------- + config_entry : ConfigEntry + The Home Assistant config entry. + coordinator : OPNsenseDataUpdateCoordinator + The data update coordinator. + state : MutableMapping[str, Any] + The current state data from OPNsense. + + Returns + ------- + list + A list of OPNsenseVPNSwitch entities. + + """ entities: list = [] for vpn_type in ("openvpn", "wireguard"): for clients_servers in ("clients", "servers"): @@ -223,6 +315,23 @@ async def _compile_static_unbound_switch_legacy( coordinator: OPNsenseDataUpdateCoordinator, state: MutableMapping[str, Any], ) -> list: + """Compile legacy static Unbound blocklist switch from OPNsense state. + + Parameters + ---------- + config_entry : ConfigEntry + The Home Assistant config entry. + coordinator : OPNsenseDataUpdateCoordinator + The data update coordinator. + state : MutableMapping[str, Any] + The current state data from OPNsense. + + Returns + ------- + list + A list containing a single OPNsenseUnboundBlocklistSwitchLegacy entity. + + """ entities: list = [] entity = OPNsenseUnboundBlocklistSwitchLegacy( config_entry=config_entry, @@ -246,6 +355,23 @@ async def _compile_unbound_switches( coordinator: OPNsenseDataUpdateCoordinator, state: MutableMapping[str, Any], ) -> list: + """Compile Unbound blocklist switches from OPNsense state. + + Parameters + ---------- + config_entry : ConfigEntry + The Home Assistant config entry. + coordinator : OPNsenseDataUpdateCoordinator + The data update coordinator. + state : MutableMapping[str, Any] + The current state data from OPNsense. + + Returns + ------- + list + A list of OPNsenseUnboundBlocklistSwitch entities. + + """ if not isinstance(state, MutableMapping): return [] entities: list = [] @@ -270,12 +396,253 @@ async def _compile_unbound_switches( return entities +async def _compile_firewall_rules_switches( + config_entry: ConfigEntry, + coordinator: OPNsenseDataUpdateCoordinator, + state: MutableMapping[str, Any], +) -> list: + """Compile firewall rule switches from OPNsense state. + + Parameters + ---------- + config_entry : ConfigEntry + The Home Assistant config entry. + coordinator : OPNsenseDataUpdateCoordinator + The data update coordinator. + state : MutableMapping[str, Any] + The current state data from OPNsense. + + Returns + ------- + list + A list of OPNsenseFirewallRuleSwitch entities. + + """ + if not isinstance(state, MutableMapping) or not isinstance( + state.get("firewall", {}).get("rules"), dict + ): + return [] + + entities: list = [] + for rule in state.get("firewall", {}).get("rules", {}).values(): + if not isinstance(rule, dict): + continue + entity = OPNsenseFirewallRuleSwitch( + config_entry=config_entry, + coordinator=coordinator, + entity_description=SwitchEntityDescription( + key=f"firewall.rule.{rule.get('uuid', 'unknown')}", + name=f"Firewall Rule: {rule.get('%interface', '')}: {rule.get('description', 'unknown')}", + icon="mdi:play-network-outline", + device_class=SwitchDeviceClass.SWITCH, + entity_registry_enabled_default=False, + ), + ) + entities.append(entity) + return entities + + +async def _compile_nat_source_rules_switches( + config_entry: ConfigEntry, + coordinator: OPNsenseDataUpdateCoordinator, + state: MutableMapping[str, Any], +) -> list: + """Compile NAT source rule switches from OPNsense state. + + Parameters + ---------- + config_entry : ConfigEntry + The Home Assistant config entry. + coordinator : OPNsenseDataUpdateCoordinator + The data update coordinator. + state : MutableMapping[str, Any] + The current state data from OPNsense. + + Returns + ------- + list + A list of OPNsenseNATRuleSwitch entities for source NAT rules. + + """ + if not isinstance(state, MutableMapping) or not isinstance( + state.get("firewall", {}).get("nat", {}).get("source_nat"), dict + ): + return [] + + entities: list = [] + for rule in state.get("firewall", {}).get("nat", {}).get("source_nat", {}).values(): + if not isinstance(rule, dict): + continue + entity = OPNsenseNATRuleSwitch( + config_entry=config_entry, + coordinator=coordinator, + entity_description=SwitchEntityDescription( + key=f"firewall.nat.source_nat.{rule.get('uuid', 'unknown')}", + name=f"NAT Source Rule: {rule.get('%interface', '')}: {rule.get('description', 'unknown')}", + icon="mdi:network-outline", + device_class=SwitchDeviceClass.SWITCH, + entity_registry_enabled_default=False, + ), + ) + entities.append(entity) + return entities + + +async def _compile_nat_destination_rules_switches( + config_entry: ConfigEntry, + coordinator: OPNsenseDataUpdateCoordinator, + state: MutableMapping[str, Any], +) -> list: + """Compile NAT destination rule switches from OPNsense state. + + Parameters + ---------- + config_entry : ConfigEntry + The Home Assistant config entry. + coordinator : OPNsenseDataUpdateCoordinator + The data update coordinator. + state : MutableMapping[str, Any] + The current state data from OPNsense. + + Returns + ------- + list + A list of OPNsenseNATRuleSwitch entities for destination NAT rules. + + """ + if not isinstance(state, MutableMapping) or not isinstance( + state.get("firewall", {}).get("nat", {}).get("d_nat"), dict + ): + return [] + + entities: list = [] + for rule in state.get("firewall", {}).get("nat", {}).get("d_nat", {}).values(): + if not isinstance(rule, dict): + continue + entity = OPNsenseNATRuleSwitch( + config_entry=config_entry, + coordinator=coordinator, + entity_description=SwitchEntityDescription( + key=f"firewall.nat.d_nat.{rule.get('uuid', 'unknown')}", + name=f"NAT Destination Rule: {rule.get('%interface', '')}: {rule.get('description', 'unknown')}", + icon="mdi:network-outline", + device_class=SwitchDeviceClass.SWITCH, + entity_registry_enabled_default=False, + ), + ) + entities.append(entity) + return entities + + +async def _compile_nat_one_to_one_rules_switches( + config_entry: ConfigEntry, + coordinator: OPNsenseDataUpdateCoordinator, + state: MutableMapping[str, Any], +) -> list: + """Compile NAT one-to-one rule switches from OPNsense state. + + Parameters + ---------- + config_entry : ConfigEntry + The Home Assistant config entry. + coordinator : OPNsenseDataUpdateCoordinator + The data update coordinator. + state : MutableMapping[str, Any] + The current state data from OPNsense. + + Returns + ------- + list + A list of OPNsenseNATRuleSwitch entities for one-to-one NAT rules. + + """ + if not isinstance(state, MutableMapping) or not isinstance( + state.get("firewall", {}).get("nat", {}).get("one_to_one"), dict + ): + return [] + + entities: list = [] + for rule in state.get("firewall", {}).get("nat", {}).get("one_to_one", {}).values(): + if not isinstance(rule, dict): + continue + entity = OPNsenseNATRuleSwitch( + config_entry=config_entry, + coordinator=coordinator, + entity_description=SwitchEntityDescription( + key=f"firewall.nat.one_to_one.{rule.get('uuid', 'unknown')}", + name=f"NAT One to One Rule: {rule.get('%interface', '')}: {rule.get('description', 'unknown')}", + icon="mdi:network-outline", + device_class=SwitchDeviceClass.SWITCH, + entity_registry_enabled_default=False, + ), + ) + entities.append(entity) + return entities + + +async def _compile_nat_npt_rules_switches( + config_entry: ConfigEntry, + coordinator: OPNsenseDataUpdateCoordinator, + state: MutableMapping[str, Any], +) -> list: + """Compile NAT NPT rule switches from OPNsense state. + + Parameters + ---------- + config_entry : ConfigEntry + The Home Assistant config entry. + coordinator : OPNsenseDataUpdateCoordinator + The data update coordinator. + state : MutableMapping[str, Any] + The current state data from OPNsense. + + Returns + ------- + list + A list of OPNsenseNATRuleSwitch entities for NPT NAT rules. + + """ + if not isinstance(state, MutableMapping) or not isinstance( + state.get("firewall", {}).get("nat", {}).get("npt"), dict + ): + return [] + + entities: list = [] + for rule in state.get("firewall", {}).get("nat", {}).get("npt", {}).values(): + if not isinstance(rule, dict): + continue + entity = OPNsenseNATRuleSwitch( + config_entry=config_entry, + coordinator=coordinator, + entity_description=SwitchEntityDescription( + key=f"firewall.nat.npt.{rule.get('uuid', 'unknown')}", + name=f"NAT NPT Rule: {rule.get('%interface', '')}: {rule.get('description', 'unknown')}", + icon="mdi:network-outline", + device_class=SwitchDeviceClass.SWITCH, + entity_registry_enabled_default=False, + ), + ) + entities.append(entity) + return entities + + async def async_setup_entry( hass: HomeAssistant, config_entry: ConfigEntry, async_add_entities: AddEntitiesCallback, ) -> None: - """Set up the OPNsense switches.""" + """Set up the OPNsense switches. + + Parameters + ---------- + hass : HomeAssistant + The Home Assistant instance. + config_entry : ConfigEntry + The config entry for this integration. + async_add_entities : AddEntitiesCallback + Callback to add entities to Home Assistant. + + """ coordinator: OPNsenseDataUpdateCoordinator = getattr(config_entry.runtime_data, COORDINATOR) state: MutableMapping[str, Any] = coordinator.data if not isinstance(state, MutableMapping): @@ -285,10 +652,63 @@ async def async_setup_entry( entities: list = [] - if config.get(CONF_SYNC_FILTERS_AND_NAT, DEFAULT_SYNC_OPTION_VALUE): - entities.extend(await _compile_filter_switches(config_entry, coordinator, state)) - entities.extend(await _compile_port_forward_switches(config_entry, coordinator, state)) - entities.extend(await _compile_nat_outbound_switches(config_entry, coordinator, state)) + if config.get(CONF_SYNC_FIREWALL_AND_NAT, DEFAULT_SYNC_OPTION_VALUE): + firmware = state.get("host_firmware_version", None) + if firmware: + try: + if awesomeversion.AwesomeVersion(firmware) < awesomeversion.AwesomeVersion( + "26.1.1" + ): + entities.extend( + await _compile_filter_switches_legacy(config_entry, coordinator, state) + ) + entities.extend( + await _compile_port_forward_switches_legacy( + config_entry, coordinator, state + ) + ) + entities.extend( + await _compile_nat_outbound_switches_legacy( + config_entry, coordinator, state + ) + ) + else: + # TODO: Also disable once OPNsense 26.1.x drops support for the plugin + if state.get("plugin_installed", False) is True: + entities.extend( + await _compile_filter_switches_legacy(config_entry, coordinator, state) + ) + entities.extend( + await _compile_firewall_rules_switches(config_entry, coordinator, state) + ) + entities.extend( + await _compile_nat_source_rules_switches(config_entry, coordinator, state) + ) + entities.extend( + await _compile_nat_destination_rules_switches( + config_entry, coordinator, state + ) + ) + entities.extend( + await _compile_nat_one_to_one_rules_switches( + config_entry, coordinator, state + ) + ) + entities.extend( + await _compile_nat_npt_rules_switches(config_entry, coordinator, state) + ) + + except ( + awesomeversion.exceptions.AwesomeVersionCompareException, + TypeError, + ValueError, + ) as e: + _LOGGER.error( + "Error comparing firewall/NAT firmware version %s: %s: %s", + firmware, + type(e).__name__, + e, + ) if config.get(CONF_SYNC_SERVICES, DEFAULT_SYNC_OPTION_VALUE): entities.extend(await _compile_service_switches(config_entry, coordinator, state)) if config.get(CONF_SYNC_VPN, DEFAULT_SYNC_OPTION_VALUE): @@ -336,7 +756,18 @@ def __init__( coordinator: OPNsenseDataUpdateCoordinator, entity_description: SwitchEntityDescription, ) -> None: - """Initialize OPNsense Switch entities.""" + """Initialize OPNsense Switch entities. + + Parameters + ---------- + config_entry : ConfigEntry + The Home Assistant config entry. + coordinator : OPNsenseDataUpdateCoordinator + The data update coordinator. + entity_description : SwitchEntityDescription + The entity description. + + """ name_suffix: str | None = ( entity_description.name if isinstance(entity_description.name, str) else None ) @@ -357,11 +788,26 @@ def __init__( @property def delay_update(self) -> bool: - """Return whether to process the coordinator update or not.""" + """Return whether to process the coordinator update or not. + + Returns + ------- + bool + True if updates should be delayed, False otherwise. + + """ return self._delay_update @delay_update.setter def delay_update(self, value: bool) -> None: + """Set whether to delay coordinator updates. + + Parameters + ---------- + value : bool + True to delay updates, False to allow them. + + """ if value and not self._delay_update: self._delay_update = True self._reset_delay() @@ -372,6 +818,7 @@ def delay_update(self, value: bool) -> None: self._delay_update_remove = None def _reset_delay(self) -> None: + """Reset the delay timer for coordinator updates.""" if self._delay_update_remove: self._delay_update_remove() @@ -384,7 +831,265 @@ def _clear(_: Any) -> None: ) -class OPNsenseFilterSwitch(OPNsenseSwitch): +class OPNsenseFirewallRuleSwitch(OPNsenseSwitch): + """Class for OPNsense Firewall Rule Switch entities.""" + + def __init__( + self, + config_entry: ConfigEntry, + coordinator: OPNsenseDataUpdateCoordinator, + entity_description: SwitchEntityDescription, + ) -> None: + """Initialize switch entity. + + Parameters + ---------- + config_entry : ConfigEntry + The Home Assistant config entry. + coordinator : OPNsenseDataUpdateCoordinator + The data update coordinator. + entity_description : SwitchEntityDescription + The entity description. + + """ + super().__init__( + config_entry=config_entry, + coordinator=coordinator, + entity_description=entity_description, + ) + self._rule_id: str = self._opnsense_get_rule_id() + _LOGGER.debug( + "[OPNsenseFirewallRuleSwitch init] Name: %s, rule_id: %s", self.name, self._rule_id + ) + + def _opnsense_get_rule_id(self) -> str: + """Get the rule ID from the entity description. + + Returns + ------- + str + The rule ID. + + """ + return self.entity_description.key.split(".")[-1] + + def _opnsense_get_rule(self) -> MutableMapping[str, Any] | None: + """Get the firewall rule data from the coordinator. + + Returns + ------- + MutableMapping[str, Any] | None + The rule data if available, None otherwise. + + """ + state: MutableMapping[str, Any] = self.coordinator.data + if not isinstance(state, MutableMapping): + return None + return state.get("firewall", {}).get("rules", {}).get(self._rule_id, None) + + @callback + def _handle_coordinator_update(self) -> None: + """Handle coordinator update for the firewall rule switch.""" + if self.delay_update: + _LOGGER.debug( + "Skipping coordinator update for firewall rule switch %s due to delay", self.name + ) + return + rule = self._opnsense_get_rule() + if not rule: + self._available = False + self.async_write_ha_state() + return + try: + self._attr_is_on = bool(rule.get("enabled", "1") == "1") + except (TypeError, KeyError, AttributeError): + self._available = False + self.async_write_ha_state() + return + self._available = True + self.async_write_ha_state() + _LOGGER.debug( + "[OPNsenseFirewallRuleSwitch handle_coordinator_update] Name: %s, available: %s, is_on: %s, extra_state_attributes: %s", + self.name, + self.available, + self.is_on, + self.extra_state_attributes, + ) + + async def async_turn_on(self, **kwargs: Any) -> None: + """Turn the entity on.""" + if self._rule_id is None or not self._client: + return + result = await self._client.toggle_firewall_rule(self._rule_id, "on") + if result: + _LOGGER.info("Turned on firewall rule: %s", self.name) + self._attr_is_on = True + self.async_write_ha_state() + self.delay_update = True + else: + _LOGGER.error("Failed to turn on firewall rule: %s", self.name) + + async def async_turn_off(self, **kwargs: Any) -> None: + """Turn the entity off.""" + if self._rule_id is None or not self._client: + return + result = await self._client.toggle_firewall_rule(self._rule_id, "off") + if result: + _LOGGER.info("Turned off firewall rule: %s", self.name) + self._attr_is_on = False + self.async_write_ha_state() + self.delay_update = True + else: + _LOGGER.error("Failed to turn off firewall rule: %s", self.name) + + @property + def icon(self) -> str | None: + """Return the icon for the entity.""" + if self.available and self.is_on: + return "mdi:play-network" + return super().icon + + +class OPNsenseNATRuleSwitch(OPNsenseSwitch): + """Class for OPNsense NAT Rule Switch entities.""" + + def __init__( + self, + config_entry: ConfigEntry, + coordinator: OPNsenseDataUpdateCoordinator, + entity_description: SwitchEntityDescription, + ) -> None: + """Initialize switch entity. + + Parameters + ---------- + config_entry : ConfigEntry + The Home Assistant config entry. + coordinator : OPNsenseDataUpdateCoordinator + The data update coordinator. + entity_description : SwitchEntityDescription + The entity description. + + """ + super().__init__( + config_entry=config_entry, + coordinator=coordinator, + entity_description=entity_description, + ) + self._rule_id: str = self._opnsense_get_rule_id() + self._nat_rule_type: str = self._get_nat_rule_type() + _LOGGER.debug( + "[OPNsenseNATRuleSwitch init] Name: %s, key: %s, rule_id: %s, rule_type: %s", + self.name, + self.entity_description.key, + self._rule_id, + self._nat_rule_type, + ) + + def _get_nat_rule_type(self) -> str: + """Get the NAT rule type from the entity description. + + Returns + ------- + str + The NAT rule type. + + """ + return self.entity_description.key.split(".")[2] + + def _opnsense_get_rule_id(self) -> str: + """Get the rule ID from the entity description. + + Returns + ------- + str + The rule ID. + + """ + return self.entity_description.key.split(".")[-1] + + def _opnsense_get_rule(self) -> MutableMapping[str, Any] | None: + """Get the NAT rule data from the coordinator. + + Returns + ------- + MutableMapping[str, Any] | None + The rule data if available, None otherwise. + + """ + state: MutableMapping[str, Any] = self.coordinator.data + if not isinstance(state, MutableMapping): + return None + return ( + state.get("firewall", {}) + .get("nat", {}) + .get(self._nat_rule_type, {}) + .get(self._rule_id, None) + ) + + @callback + def _handle_coordinator_update(self) -> None: + """Handle coordinator update for the NAT rule switch.""" + if self.delay_update: + _LOGGER.debug("Skipping coordinator update for NAT switch %s due to delay", self.name) + return + rule = self._opnsense_get_rule() + _LOGGER.debug("[OPNsenseNATRuleSwitch handle_coordinator_update] fetched rule: %s", rule) + if not rule: + self._available = False + self.async_write_ha_state() + return + try: + self._attr_is_on = bool(rule.get("enabled", "1") == "1") + except (TypeError, KeyError, AttributeError): + self._available = False + self.async_write_ha_state() + return + self._available = True + self.async_write_ha_state() + _LOGGER.debug( + "[OPNsenseNATRuleSwitch handle_coordinator_update] Name: %s, available: %s, is_on: %s, extra_state_attributes: %s", + self.name, + self.available, + self.is_on, + self.extra_state_attributes, + ) + + async def async_turn_on(self, **kwargs: Any) -> None: + """Turn the entity on.""" + if self._rule_id is None or not self._client: + return + result = await self._client.toggle_nat_rule(self._nat_rule_type, self._rule_id, "on") + if result: + _LOGGER.info("Turned on NAT rule: %s", self.name) + self._attr_is_on = True + self.async_write_ha_state() + self.delay_update = True + else: + _LOGGER.error("Failed to turn on NAT rule: %s", self.name) + + async def async_turn_off(self, **kwargs: Any) -> None: + """Turn the entity off.""" + if self._rule_id is None or not self._client: + return + result = await self._client.toggle_nat_rule(self._nat_rule_type, self._rule_id, "off") + if result: + _LOGGER.info("Turned off NAT rule: %s", self.name) + self._attr_is_on = False + self.async_write_ha_state() + self.delay_update = True + else: + _LOGGER.error("Failed to turn off NAT rule: %s", self.name) + + @property + def icon(self) -> str | None: + """Return the icon for the entity.""" + if self.available and self.is_on: + return "mdi:network" + return super().icon + + +class OPNsenseFilterSwitchLegacy(OPNsenseSwitch): """Class for OPNsense Filter Switch entities.""" def __init__( @@ -393,7 +1098,18 @@ def __init__( coordinator: OPNsenseDataUpdateCoordinator, entity_description: SwitchEntityDescription, ) -> None: - """Initialize switch entity.""" + """Initialize switch entity. + + Parameters + ---------- + config_entry : ConfigEntry + The Home Assistant config entry. + coordinator : OPNsenseDataUpdateCoordinator + The data update coordinator. + entity_description : SwitchEntityDescription + The entity description. + + """ super().__init__( config_entry=config_entry, coordinator=coordinator, @@ -401,25 +1117,42 @@ def __init__( ) self._tracker: str = self._opnsense_get_tracker() self._rule: MutableMapping[str, Any] | None = None - # _LOGGER.debug(f"[OPNsenseFilterSwitch init] Name: {self.name}, tracker: {self._tracker}") + # _LOGGER.debug(f"[OPNsenseFilterSwitchLegacy init] Name: {self.name}, tracker: {self._tracker}") def _opnsense_get_tracker(self) -> str: + """Get the tracker from the entity description. + + Returns + ------- + str + The tracker string. + + """ parts = self.entity_description.key.split(".") parts.pop(0) return ".".join(parts) def _opnsense_get_rule(self) -> MutableMapping[str, Any] | None: + """Get the filter rule data from the coordinator. + + Returns + ------- + MutableMapping[str, Any] | None + The rule data if available, None otherwise. + + """ state: MutableMapping[str, Any] = self.coordinator.data tracker: str = self._opnsense_get_tracker() if not isinstance(state, MutableMapping): return None - for rule in state.get("config", {}).get("filter", {}).get("rule", {}): + for rule in state.get("firewall", {}).get("config", {}).get("filter", {}).get("rule", {}): if dict_get(rule, "created.time") == tracker: return rule return None @callback def _handle_coordinator_update(self) -> None: + """Handle coordinator update for the filter switch.""" if self.delay_update: _LOGGER.debug( "Skipping coordinator update for filter switch %s due to delay", self.name @@ -438,13 +1171,13 @@ def _handle_coordinator_update(self) -> None: return self._available = True self.async_write_ha_state() - # _LOGGER.debug(f"[OPNsenseFilterSwitch handle_coordinator_update] Name: {self.name}, available: {self.available}, is_on: {self.is_on}, extra_state_attributes: {self.extra_state_attributes}") + # _LOGGER.debug(f"[OPNsenseFilterSwitchLegacy handle_coordinator_update] Name: {self.name}, available: {self.available}, is_on: {self.is_on}, extra_state_attributes: {self.extra_state_attributes}") async def async_turn_on(self, **kwargs: Any) -> None: """Turn the entity on.""" if self._rule is None or not self._client: return - await self._client.enable_filter_rule_by_created_time(self._tracker) + await self._client.enable_filter_rule_by_created_time_legacy(self._tracker) _LOGGER.info("Turned on filter rule: %s", self.name) self._attr_is_on = True self.async_write_ha_state() @@ -454,7 +1187,7 @@ async def async_turn_off(self, **kwargs: Any) -> None: """Turn the entity off.""" if self._rule is None or not self._client: return - await self._client.disable_filter_rule_by_created_time(self._tracker) + await self._client.disable_filter_rule_by_created_time_legacy(self._tracker) _LOGGER.info("Turned off filter rule: %s", self.name) self._attr_is_on = False self.async_write_ha_state() @@ -468,7 +1201,7 @@ def icon(self) -> str | None: return super().icon -class OPNsenseNatSwitch(OPNsenseSwitch): +class OPNsenseNatSwitchLegacy(OPNsenseSwitch): """Class for OPNsense NAT Switch entities.""" def __init__( @@ -477,7 +1210,18 @@ def __init__( coordinator: OPNsenseDataUpdateCoordinator, entity_description: SwitchEntityDescription, ) -> None: - """Initialize switch entity.""" + """Initialize switch entity. + + Parameters + ---------- + config_entry : ConfigEntry + The Home Assistant config entry. + coordinator : OPNsenseDataUpdateCoordinator + The data update coordinator. + entity_description : SwitchEntityDescription + The entity description. + + """ super().__init__( config_entry=config_entry, coordinator=coordinator, @@ -486,25 +1230,55 @@ def __init__( self._rule_type: str = self._opnsense_get_rule_type() self._tracker: str = self._opnsense_get_tracker() self._rule: MutableMapping[str, Any] | None = None - # _LOGGER.debug(f"[OPNsenseNatSwitch init] Name: {self.name}, tracker: {self._tracker}, rule_type: {self._rule_type}") + # _LOGGER.debug(f"[OPNsenseNatSwitchLegacy init] Name: {self.name}, tracker: {self._tracker}, rule_type: {self._rule_type}") def _opnsense_get_rule_type(self) -> str: + """Get the rule type from the entity description. + + Returns + ------- + str + The rule type. + + """ return self.entity_description.key.split(".")[0] def _opnsense_get_tracker(self) -> str: + """Get the tracker from the entity description. + + Returns + ------- + str + The tracker string. + + """ parts = self.entity_description.key.split(".") parts.pop(0) return ".".join(parts) def _opnsense_get_rule(self) -> MutableMapping[str, Any] | None: + """Get the NAT rule data from the coordinator. + + Returns + ------- + MutableMapping[str, Any] | None + The rule data if available, None otherwise. + + """ state: MutableMapping[str, Any] = self.coordinator.data if not isinstance(state, MutableMapping): return None rules: list = [] if self._rule_type == ATTR_NAT_PORT_FORWARD: - rules = state.get("config", {}).get("nat", {}).get("rule", []) + rules = state.get("firewall", {}).get("config", {}).get("nat", {}).get("rule", []) if self._rule_type == ATTR_NAT_OUTBOUND: - rules = state.get("config", {}).get("nat", {}).get("outbound", {}).get("rule", []) + rules = ( + state.get("firewall", {}) + .get("config", {}) + .get("nat", {}) + .get("outbound", {}) + .get("rule", []) + ) for rule in rules: if dict_get(rule, "created.time") == self._tracker: @@ -513,11 +1287,14 @@ def _opnsense_get_rule(self) -> MutableMapping[str, Any] | None: @callback def _handle_coordinator_update(self) -> None: + """Handle coordinator update for the NAT switch.""" if self.delay_update: _LOGGER.debug("Skipping coordinator update for NAT switch %s due to delay", self.name) return self._rule = self._opnsense_get_rule() if not isinstance(self._rule, MutableMapping): + self._available = False + self.async_write_ha_state() return try: self._attr_is_on = "disabled" not in self._rule @@ -527,16 +1304,16 @@ def _handle_coordinator_update(self) -> None: return self._available = True self.async_write_ha_state() - # _LOGGER.debug(f"[OPNsenseNatSwitch handle_coordinator_update] Name: {self.name}, available: {self.available}, is_on: {self.is_on}, extra_state_attributes: {self.extra_state_attributes}") + # _LOGGER.debug(f"[OPNsenseNatSwitchLegacy handle_coordinator_update] Name: {self.name}, available: {self.available}, is_on: {self.is_on}, extra_state_attributes: {self.extra_state_attributes}") async def async_turn_on(self, **kwargs: Any) -> None: """Turn the entity on.""" if not isinstance(self._rule, MutableMapping) or not self._client: return if self._rule_type == ATTR_NAT_PORT_FORWARD: - method = self._client.enable_nat_port_forward_rule_by_created_time + method = self._client.enable_nat_port_forward_rule_by_created_time_legacy elif self._rule_type == ATTR_NAT_OUTBOUND: - method = self._client.enable_nat_outbound_rule_by_created_time + method = self._client.enable_nat_outbound_rule_by_created_time_legacy else: return await method(self._tracker) @@ -550,9 +1327,9 @@ async def async_turn_off(self, **kwargs: Any) -> None: if not isinstance(self._rule, MutableMapping) or not self._client: return if self._rule_type == ATTR_NAT_PORT_FORWARD: - method = self._client.disable_nat_port_forward_rule_by_created_time + method = self._client.disable_nat_port_forward_rule_by_created_time_legacy elif self._rule_type == ATTR_NAT_OUTBOUND: - method = self._client.disable_nat_outbound_rule_by_created_time + method = self._client.disable_nat_outbound_rule_by_created_time_legacy else: return await method(self._tracker) @@ -578,7 +1355,18 @@ def __init__( coordinator: OPNsenseDataUpdateCoordinator, entity_description: SwitchEntityDescription, ) -> None: - """Initialize switch entity.""" + """Initialize switch entity. + + Parameters + ---------- + config_entry : ConfigEntry + The Home Assistant config entry. + coordinator : OPNsenseDataUpdateCoordinator + The data update coordinator. + entity_description : SwitchEntityDescription + The entity description. + + """ super().__init__( config_entry=config_entry, coordinator=coordinator, @@ -589,12 +1377,36 @@ def __init__( # _LOGGER.debug(f"[OPNsenseServiceSwitch init] Name: {self.name}, prop_name: {self._prop_name}") def _opnsense_get_property_name(self) -> str: + """Get the property name from the entity description. + + Returns + ------- + str + The property name. + + """ return self.entity_description.key.split(".")[2] def _opnsense_get_service_id(self) -> str: + """Get the service ID from the entity description. + + Returns + ------- + str + The service ID. + + """ return self.entity_description.key.split(".")[1] def _opnsense_get_service(self) -> MutableMapping[str, Any] | None: + """Get the service data from the coordinator. + + Returns + ------- + MutableMapping[str, Any] | None + The service data if available, None otherwise. + + """ state: MutableMapping[str, Any] = self.coordinator.data if not isinstance(state, MutableMapping): return None @@ -606,6 +1418,7 @@ def _opnsense_get_service(self) -> MutableMapping[str, Any] | None: @callback def _handle_coordinator_update(self) -> None: + """Handle coordinator update for the service switch.""" if self.delay_update: _LOGGER.debug( "Skipping coordinator update for service switch %s due to delay", self.name @@ -735,7 +1548,18 @@ def __init__( coordinator: OPNsenseDataUpdateCoordinator, entity_description: SwitchEntityDescription, ) -> None: - """Initialize switch entity.""" + """Initialize switch entity. + + Parameters + ---------- + config_entry : ConfigEntry + The Home Assistant config entry. + coordinator : OPNsenseDataUpdateCoordinator + The data update coordinator. + entity_description : SwitchEntityDescription + The entity description. + + """ super().__init__( config_entry=config_entry, coordinator=coordinator, @@ -815,7 +1639,18 @@ def __init__( coordinator: OPNsenseDataUpdateCoordinator, entity_description: SwitchEntityDescription, ) -> None: - """Initialize switch entity.""" + """Initialize switch entity. + + Parameters + ---------- + config_entry : ConfigEntry + The Home Assistant config entry. + coordinator : OPNsenseDataUpdateCoordinator + The data update coordinator. + entity_description : SwitchEntityDescription + The entity description. + + """ super().__init__( config_entry=config_entry, coordinator=coordinator, @@ -828,6 +1663,7 @@ def __init__( @callback def _handle_coordinator_update(self) -> None: + """Handle coordinator update for the VPN switch.""" if self.delay_update: _LOGGER.debug("Skipping coordinator update for VPN switch %s due to delay", self.name) return @@ -888,7 +1724,14 @@ def _handle_coordinator_update(self) -> None: # _LOGGER.debug(f"[OPNsenseVPNSwitch handle_coordinator_update] Name: {self.name}, available: {self.available}, is_on: {self.is_on}, extra_state_attributes: {self.extra_state_attributes}") async def async_turn_on(self, **kwargs: Any) -> None: - """Turn the entity on.""" + """Turn on the VPN switch. + + Parameters + ---------- + **kwargs : Any + Additional keyword arguments. + + """ if self.is_on or not self._client: return @@ -905,7 +1748,14 @@ async def async_turn_on(self, **kwargs: Any) -> None: _LOGGER.error("Failed to turn on VPN: %s", self.name) async def async_turn_off(self, **kwargs: Any) -> None: - """Turn the entity off.""" + """Turn off the VPN switch. + + Parameters + ---------- + **kwargs : Any + Additional keyword arguments. + + """ if not self.is_on or not self._client: return diff --git a/custom_components/opnsense/translations/en.json b/custom_components/opnsense/translations/en.json index 4f83ac67..8218e09a 100644 --- a/custom_components/opnsense/translations/en.json +++ b/custom_components/opnsense/translations/en.json @@ -117,6 +117,14 @@ "device_id_mismatched": { "title": "OPNsense Hardware Has Changed", "description": "OPNsense Device ID has changed which indicates new or changed hardware. In order to accommodate this, hass-opnsense needs to be removed and reinstalled for this router. hass-opnsense is shutting down." + }, + "plugin_cleanup_partial": { + "title": "OPNsense 26.1.1 and Plugin cleanup partially completed", + "description": "OPNsense Plugin is still installed. NAT Outbound and NAT Port Forward rules removed. Firewall Filter rules will be removed once the plugin is removed." + }, + "plugin_cleanup_done": { + "title": "OPNsense 26.1.1 and Plugin cleanup completed", + "description": "NAT Outbound, NAT Port Forward, and Firewall Filter rules removed." } }, "selector": { @@ -315,7 +323,7 @@ }, "reload_interface": { "name": "Reload an Interface", - "description": "Reload or restart an OPNSense interface", + "description": "Reload or restart an OPNsense interface", "fields": { "interface": { "name": "Interface Name", diff --git a/tests/conftest.py b/tests/conftest.py index 71dadddd..97d3dc6b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -74,7 +74,7 @@ def _ensure_async_create_task_mock(real, side_effect): object.__setattr__( real, "async_create_task", - MagicMock(side_effect=lambda coro, *a, **k: orig(coro, *a, **k)), + MagicMock(side_effect=orig), ) diff --git a/tests/test_coordinator.py b/tests/test_coordinator.py index 5639fd79..6268f7f0 100644 --- a/tests/test_coordinator.py +++ b/tests/test_coordinator.py @@ -18,7 +18,7 @@ CONF_SYNC_CARP, CONF_SYNC_CERTIFICATES, CONF_SYNC_DHCP_LEASES, - CONF_SYNC_FILTERS_AND_NAT, + CONF_SYNC_FIREWALL_AND_NAT, CONF_SYNC_FIRMWARE_UPDATES, CONF_SYNC_GATEWAYS, CONF_SYNC_INTERFACES, @@ -403,7 +403,7 @@ def test_build_categories_returns_empty_when_no_config(make_config_entry, fake_c (CONF_SYNC_GATEWAYS, ["gateways"]), (CONF_SYNC_SERVICES, ["services"]), (CONF_SYNC_NOTICES, ["notices"]), - (CONF_SYNC_FILTERS_AND_NAT, ["config"]), + (CONF_SYNC_FIREWALL_AND_NAT, ["firewall"]), (CONF_SYNC_UNBOUND, [ATTR_UNBOUND_BLOCKLIST]), (CONF_SYNC_INTERFACES, ["interfaces"]), (CONF_SYNC_CERTIFICATES, ["certificates"]), diff --git a/tests/test_device_tracker.py b/tests/test_device_tracker.py index a377bdcc..b1d4b6d1 100644 --- a/tests/test_device_tracker.py +++ b/tests/test_device_tracker.py @@ -31,7 +31,7 @@ async def test_async_setup_entry_configured_devices( ) # attach coordinator into runtime_data under the expected attribute name setattr(entry.runtime_data, dt_mod.DEVICE_TRACKER_COORDINATOR, coordinator) - entry.add_update_listener = lambda f: (lambda: None) + entry.add_update_listener = lambda f: lambda: None entry.async_on_unload = lambda x: None hass = ph_hass hass.config_entries.async_update_entry = MagicMock() @@ -95,7 +95,7 @@ async def test_async_setup_entry_removes_nonmatching_tracked_macs( entry_id="eid_remove", ) setattr(entry.runtime_data, dt_mod.DEVICE_TRACKER_COORDINATOR, coordinator) - entry.add_update_listener = lambda f: (lambda: None) + entry.add_update_listener = lambda f: lambda: None entry.async_on_unload = lambda x: None hass = ph_hass @@ -307,7 +307,7 @@ async def test_async_setup_entry_state_not_mapping(ph_hass, coordinator, make_co hass.data = {} hass.config_entries.async_update_entry = MagicMock() - await dt_mod.async_setup_entry(hass, entry, lambda x: added.extend(x)) + await dt_mod.async_setup_entry(hass, entry, added.extend) assert len(added) == 0 assert not hass.config_entries.async_update_entry.called @@ -437,7 +437,7 @@ async def test_async_setup_entry_from_arp_entries( added = [] - await dt_mod.async_setup_entry(hass, entry, lambda ents: added.extend(ents)) + await dt_mod.async_setup_entry(hass, entry, added.extend) assert len(added) == 2 assert all(isinstance(e, dt_mod.OPNsenseScannerEntity) for e in added) assert {e.unique_id for e in added} == {"dev1_mac_m1", "dev1_mac_m2"} diff --git a/tests/test_init.py b/tests/test_init.py index c7b604d6..4f860e33 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -491,11 +491,141 @@ async def test_async_setup_entry_firmware_between_min_and_ltd( call_args = create_issue_mock.call_args # args: (hass, domain, issue_id, ...) assert call_args[0][1] == init_mod.DOMAIN - expected_issue_id = f"opnsense_25.0_below_ltd_firmware_{init_mod.OPNSENSE_LTD_FIRMWARE}" + expected_issue_id = f"{entry.data[init_mod.CONF_DEVICE_UNIQUE_ID]}_opnsense_below_ltd_firmware_{init_mod.OPNSENSE_LTD_FIRMWARE}" assert call_args[0][2] == expected_issue_id assert call_args[1].get("severity") == init_mod.ir.IssueSeverity.WARNING +@pytest.mark.asyncio +async def test_async_setup_entry_firmware_triggers_plugin_cleanup( + monkeypatch, ph_hass, coordinator_capture, fake_client, fake_coordinator, make_config_entry +): + """async_setup_entry calls _deprecated_plugin_cleanup_26_1_1 for firmware >25.10 and <26.7.""" + monkeypatch.setattr(init_mod, "OPNsenseClient", fake_client(firmware_version="26.2")) + monkeypatch.setattr( + init_mod, "OPNsenseDataUpdateCoordinator", coordinator_capture.factory(fake_coordinator) + ) + # Mock the cleanup function to track if it's called + cleanup_mock = AsyncMock(return_value=True) + monkeypatch.setattr(init_mod, "_deprecated_plugin_cleanup_26_1_1", cleanup_mock) + + entry = make_config_entry( + data={ + init_mod.CONF_URL: "http://1.2.3.4", + init_mod.CONF_USERNAME: "u", + init_mod.CONF_PASSWORD: "p", + init_mod.CONF_DEVICE_UNIQUE_ID: "dev1", + } + ) + hass = ph_hass + hass.config_entries.async_forward_entry_setups = AsyncMock(return_value=True) + hass.config_entries.async_reload = AsyncMock() + hass.data.setdefault("aiohttp_connector", {}) + + res = await init_mod.async_setup_entry(hass, entry) + assert res is True + # Verify cleanup was called with correct args and awaited + cleanup_mock.assert_awaited_once_with( + hass=hass, + client=ANY, + entry_id=entry.entry_id, + config_device_id=entry.data[init_mod.CONF_DEVICE_UNIQUE_ID], + ) + + +@pytest.mark.asyncio +async def test_deprecated_plugin_cleanup_26_1_1_plugin_not_installed(monkeypatch): + """_deprecated_plugin_cleanup_26_1_1 removes filter entities when plugin not installed.""" + hass = MagicMock(spec=HomeAssistant) + client = MagicMock() + client.is_plugin_installed = AsyncMock(return_value=False) + entry_id = "test_entry_id" + + # Mock entity registry + entity_registry = MagicMock() + monkeypatch.setattr(init_mod.er, "async_get", lambda hass: entity_registry) + + # Mock entities: one filter entity, one normal switch + filter_entity = MagicMock() + filter_entity.entity_id = "switch.opnsense_filter_rule" + filter_entity.unique_id = "dev1_filter_rule1" + normal_entity = MagicMock() + normal_entity.entity_id = "switch.opnsense_normal_rule" + normal_entity.unique_id = "dev1_normal_rule1" + monkeypatch.setattr( + init_mod.er, + "async_entries_for_config_entry", + MagicMock(return_value=[filter_entity, normal_entity]), + ) + + # Mock issue registry + create_issue_mock = MagicMock() + monkeypatch.setattr(init_mod.ir, "async_create_issue", create_issue_mock) + + config_device_id = "dev1" + await init_mod._deprecated_plugin_cleanup_26_1_1(hass, client, entry_id, config_device_id) + + # Verify filter entity was removed, normal was not + entity_registry.async_remove.assert_called_once_with("switch.opnsense_filter_rule") + # Verify issue created for cleanup done + create_issue_mock.assert_called_once() + call_args = create_issue_mock.call_args + assert call_args[0][2] == f"{config_device_id}_plugin_cleanup_done" + + +@pytest.mark.asyncio +async def test_deprecated_plugin_cleanup_26_1_1_plugin_installed(monkeypatch): + """_deprecated_plugin_cleanup_26_1_1 removes NAT entities when plugin is installed.""" + hass = MagicMock(spec=HomeAssistant) + client = MagicMock() + client.is_plugin_installed = AsyncMock(return_value=True) + entry_id = "test_entry_id" + + # Mock entity registry + entity_registry = MagicMock() + monkeypatch.setattr(init_mod.er, "async_get", lambda hass: entity_registry) + + # Mock entities: NAT port forward, NAT outbound, filter, normal switch + nat_pf_entity = MagicMock() + nat_pf_entity.entity_id = "switch.opnsense_nat_port_forward_rule" + nat_pf_entity.unique_id = "dev1_nat_port_forward_rule1" + nat_out_entity = MagicMock() + nat_out_entity.entity_id = "switch.opnsense_nat_outbound_rule" + nat_out_entity.unique_id = "dev1_nat_outbound_rule1" + filter_entity = MagicMock() + filter_entity.entity_id = "switch.opnsense_filter_rule" + filter_entity.unique_id = "dev1_filter_rule1" + normal_entity = MagicMock() + normal_entity.entity_id = "switch.opnsense_normal_rule" + normal_entity.unique_id = "dev1_normal_rule1" + monkeypatch.setattr( + init_mod.er, + "async_entries_for_config_entry", + MagicMock(return_value=[nat_pf_entity, nat_out_entity, filter_entity, normal_entity]), + ) + + # Mock issue registry + create_issue_mock = MagicMock() + monkeypatch.setattr(init_mod.ir, "async_create_issue", create_issue_mock) + + config_device_id = "dev1" + await init_mod._deprecated_plugin_cleanup_26_1_1(hass, client, entry_id, config_device_id) + + # Verify NAT entities were removed, filter and normal were not + assert entity_registry.async_remove.call_count == 2 + entity_registry.async_remove.assert_any_call("switch.opnsense_nat_port_forward_rule") + entity_registry.async_remove.assert_any_call("switch.opnsense_nat_outbound_rule") + # Ensure filter entity was preserved when plugin is installed + assert not entity_registry.async_remove.mock_calls or all( + call.args[0] != "switch.opnsense_filter_rule" + for call in entity_registry.async_remove.mock_calls + ) + # Verify issue created for partial cleanup + create_issue_mock.assert_called_once() + call_args = create_issue_mock.call_args + assert call_args[0][2] == f"{config_device_id}_plugin_cleanup_partial" + + @pytest.mark.asyncio async def test_migrate_2_to_3_missing_device_id(monkeypatch, fake_client): """_migrate_2_to_3 returns False when the client provides no device id.""" @@ -888,8 +1018,8 @@ async def test_async_setup_entry_firmware_at_or_above_ltd_deletes_previous_issue # Expect delete_issue to be called for the previous below-min and below-ltd issue ids assert delete_issue_mock.called, "async_delete_issue should have been called" called_issue_ids = [call[0][2] for call in delete_issue_mock.call_args_list if len(call[0]) > 2] - expected_min = f"opnsense_{init_mod.OPNSENSE_LTD_FIRMWARE}_below_min_firmware_{init_mod.OPNSENSE_MIN_FIRMWARE}" - expected_ltd = f"opnsense_{init_mod.OPNSENSE_LTD_FIRMWARE}_below_ltd_firmware_{init_mod.OPNSENSE_LTD_FIRMWARE}" + expected_min = f"{entry.data[init_mod.CONF_DEVICE_UNIQUE_ID]}_opnsense_below_min_firmware_{init_mod.OPNSENSE_MIN_FIRMWARE}" + expected_ltd = f"{entry.data[init_mod.CONF_DEVICE_UNIQUE_ID]}_opnsense_below_ltd_firmware_{init_mod.OPNSENSE_LTD_FIRMWARE}" assert expected_min in called_issue_ids assert expected_ltd in called_issue_ids @@ -922,9 +1052,9 @@ async def test_async_setup_entry_delete_uses_actual_firmware_string( res = await init_mod.async_setup_entry(hass, entry) assert res is True - # Confirm delete_issue was called with the firmware returned by the client - expected_min = f"opnsense_{firmware_str}_below_min_firmware_{init_mod.OPNSENSE_MIN_FIRMWARE}" - expected_ltd = f"opnsense_{firmware_str}_below_ltd_firmware_{init_mod.OPNSENSE_LTD_FIRMWARE}" + # Confirm delete_issue was called for the expected issue ids + expected_min = f"{entry.data[init_mod.CONF_DEVICE_UNIQUE_ID]}_opnsense_below_min_firmware_{init_mod.OPNSENSE_MIN_FIRMWARE}" + expected_ltd = f"{entry.data[init_mod.CONF_DEVICE_UNIQUE_ID]}_opnsense_below_ltd_firmware_{init_mod.OPNSENSE_LTD_FIRMWARE}" assert calls.called, "async_delete_issue should have been called" issue_ids = [call[0][2] for call in calls.call_args_list if len(call[0]) > 2] assert expected_min in issue_ids diff --git a/tests/test_integration.py b/tests/test_integration.py index c185ef59..f0049bb7 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -173,7 +173,7 @@ async def async_reload(self, entry_id): # pragma: no cover - reload path not as return None hass.config_entries = _Cfg() - hass.async_create_task = MagicMock(side_effect=lambda coro: asyncio.create_task(coro)) + hass.async_create_task = MagicMock(side_effect=asyncio.create_task) return hass @@ -223,7 +223,7 @@ async def _noop_unique_id(*a, **k): options={}, ) # Provide stubs expected by integration (update listener registration returns unsubscribe) - entry.add_update_listener = lambda f: (lambda: None) + entry.add_update_listener = lambda f: lambda: None entry.async_on_unload = lambda x: None hass.data = {} @@ -286,7 +286,7 @@ async def _noop_unique_id(*a, **k): # redefined for this test context entry_id="entry_gran", options={}, ) - entry.add_update_listener = lambda f: (lambda: None) + entry.add_update_listener = lambda f: lambda: None entry.async_on_unload = lambda x: None # Add to fake hass store so options flow update calls can mutate it @@ -396,7 +396,7 @@ async def _noop_unique_id(*a, **k): entry_id="entry_rel", options={}, ) - entry.add_update_listener = lambda f: (lambda: None) + entry.add_update_listener = lambda f: lambda: None entry.async_on_unload = lambda x: None # Setup diff --git a/tests/test_pyopnsense.py b/tests/test_pyopnsense.py index 7fb7369b..25558b55 100644 --- a/tests/test_pyopnsense.py +++ b/tests/test_pyopnsense.py @@ -7,15 +7,18 @@ import asyncio import contextlib +import copy from datetime import datetime, timedelta import inspect as _inspect import socket from ssl import SSLError +from typing import Any from unittest.mock import AsyncMock, MagicMock import xmlrpc.client as xc from xmlrpc.client import Fault import aiohttp +import awesomeversion import pytest from yarl import URL @@ -25,6 +28,7 @@ sensor as sensor_mod, switch as switch_mod, ) +from custom_components.opnsense.const import CONF_SYNC_FIREWALL_AND_NAT def test_human_friendly_duration() -> None: @@ -193,7 +197,9 @@ async def test_opnsenseclient_async_close(make_client) -> None: @pytest.mark.asyncio -async def test_get_host_firmware_set_use_snake_case_and_plugin_installed(make_client) -> None: +async def test_get_host_firmware_set_use_snake_case_and_plugin_installed( + monkeypatch, make_client +) -> None: """Ensure firmware parsing, snake_case detection and plugin detection work.""" # create client/session for this test session = MagicMock(spec=aiohttp.ClientSession) @@ -209,6 +215,21 @@ async def test_get_host_firmware_set_use_snake_case_and_plugin_installed(make_cl await client.set_use_snake_case() assert client._use_snake_case is True + # set use snake case should detect <25.7 + client._firmware_version = "25.1.0" + await client.set_use_snake_case() + assert client._use_snake_case is False + + # test AwesomeVersionCompareException handling + def mock_compare(self, other): + raise awesomeversion.exceptions.AwesomeVersionCompareException("test exception") + + monkeypatch.setattr(awesomeversion.AwesomeVersion, "__lt__", mock_compare) + client._firmware_version = "25.8.0" + await client.set_use_snake_case() + # Should default to True on exception + assert client._use_snake_case is True + # invalid semver -> fallback to product_series client._safe_dict_get = AsyncMock( return_value={"product": {"product_version": "weird", "product_series": "seriesX"}} @@ -959,6 +980,9 @@ async def __aexit__(self, exc_type, exc, tb): async def json(self, content_type=None): return {"x": 1} + async def text(self): + return "raw response text" + @property def content(self): class C: @@ -1067,7 +1091,7 @@ def mkarg(pname: str): for name, func in _inspect.getmembers( pyopnsense.OPNsenseClient, predicate=_inspect.iscoroutinefunction ) - if name not in ("__init__",) + if name != "__init__" ] for name, _func in coros: @@ -1313,31 +1337,31 @@ async def test_enable_and_disable_filter_rules_and_nat_port_forward(make_client) url="http://localhost", username="u", password="p", session=session ) - # enable_filter_rule_by_created_time: rule has 'disabled' -> should remove and call restore+configure + # enable_filter_rule_by_created_time_legacy: rule has 'disabled' -> should remove and call restore+configure cfg_enable = {"filter": {"rule": [{"created": {"time": "t-enable"}, "disabled": "1"}]}} client.get_config = AsyncMock(return_value=cfg_enable) client._restore_config_section = AsyncMock() client._filter_configure = AsyncMock() - await client.enable_filter_rule_by_created_time("t-enable") - client._restore_config_section.assert_called() + await client.enable_filter_rule_by_created_time_legacy("t-enable") + client._restore_config_section.assert_awaited() client._filter_configure.assert_awaited() - # disable_filter_rule_by_created_time: rule missing 'disabled' -> should add it and call restore+configure + # disable_filter_rule_by_created_time_legacy: rule missing 'disabled' -> should add it and call restore+configure cfg_disable = {"filter": {"rule": [{"created": {"time": "t-disable"}}]}} client.get_config = AsyncMock(return_value=cfg_disable) client._restore_config_section = AsyncMock() client._filter_configure = AsyncMock() - await client.disable_filter_rule_by_created_time("t-disable") - client._restore_config_section.assert_called() + await client.disable_filter_rule_by_created_time_legacy("t-disable") + client._restore_config_section.assert_awaited() client._filter_configure.assert_awaited() - # enable_nat_port_forward_rule_by_created_time: similar flow under 'nat' section + # enable_nat_port_forward_rule_by_created_time_legacy: similar flow under 'nat' section cfg_nat = {"nat": {"rule": [{"created": {"time": "t-nat"}, "disabled": "1"}]}} client.get_config = AsyncMock(return_value=cfg_nat) client._restore_config_section = AsyncMock() client._filter_configure = AsyncMock() - await client.enable_nat_port_forward_rule_by_created_time("t-nat") - client._restore_config_section.assert_called() + await client.enable_nat_port_forward_rule_by_created_time_legacy("t-nat") + client._restore_config_section.assert_awaited() client._filter_configure.assert_awaited() @@ -1500,21 +1524,28 @@ def add_entities(ents): # filter switches: validate via public switch setup path as a smoke test state2 = { - "config": { - "filter": { - "rule": [ - {"descr": "Anti-Lockout Rule", "created": {"time": "t1"}}, - {"descr": "Normal", "created": {"time": "t2"}, "associated-rule-id": "r1"}, - {"descr": "Ok", "created": {"time": "t3"}}, - ] + "host_firmware_version": "25.7.8", + "firewall": { + "config": { + "filter": { + "rule": [ + {"description": "Anti-Lockout Rule", "created": {"time": "t1"}}, + { + "description": "Normal", + "created": {"time": "t2"}, + "associated-rule-id": "r1", + }, + {"description": "Ok", "created": {"time": "t3"}}, + ] + } } - } + }, } # prepare a switch config entry with filter sync enabled switch_cfg = make_config_entry( data={ "device_unique_id": "dev1", - "sync_filters_and_nat": True, + CONF_SYNC_FIREWALL_AND_NAT: True, "sync_unbound": False, "sync_vpn": False, "sync_services": False, @@ -2690,9 +2721,12 @@ async def test_get_config_and_rule_enable_disable_branches() -> None: } client._exec_php = AsyncMock(return_value=fake_config) - + client._restore_config_section = AsyncMock() + client._filter_configure = AsyncMock() # calling enable should remove 'disabled' and call restore/configure (no exception) - await client.enable_filter_rule_by_created_time("t1") + await client.enable_filter_rule_by_created_time_legacy("t1") + client._restore_config_section.assert_awaited() + client._filter_configure.assert_awaited() # disable_nat_port_forward: add a rule without 'disabled' and expect it to set 'disabled' client._exec_php = AsyncMock( @@ -2701,7 +2735,7 @@ async def test_get_config_and_rule_enable_disable_branches() -> None: # patch _restore_config_section and _filter_configure to be no-ops client._restore_config_section = AsyncMock() client._filter_configure = AsyncMock() - await client.disable_nat_port_forward_rule_by_created_time("n1") + await client.disable_nat_port_forward_rule_by_created_time_legacy("n1") await client.async_close() @@ -2890,7 +2924,7 @@ async def test_reset_and_get_query_counts(): ), ], ) -async def test_enable_filter_rule_by_created_time( +async def test_enable_filter_rule_by_created_time_legacy( make_client, rules, created_time, should_call ) -> None: """Ensure enabling a filter rule removes 'disabled' and triggers restore/configure only when appropriate. @@ -2906,7 +2940,7 @@ async def test_enable_filter_rule_by_created_time( client._restore_config_section = AsyncMock() client._filter_configure = AsyncMock() - await client.enable_filter_rule_by_created_time(created_time) + await client.enable_filter_rule_by_created_time_legacy(created_time) if should_call: client._restore_config_section.assert_awaited() @@ -3053,7 +3087,7 @@ async def test_enable_disable_nat_outbound_rules(make_client) -> None: client.get_config = AsyncMock(return_value=cfg_enable) client._restore_config_section = AsyncMock() client._filter_configure = AsyncMock() - await client.enable_nat_outbound_rule_by_created_time("t1") + await client.enable_nat_outbound_rule_by_created_time_legacy("t1") client._restore_config_section.assert_awaited() client._filter_configure.assert_awaited() @@ -3062,7 +3096,7 @@ async def test_enable_disable_nat_outbound_rules(make_client) -> None: client.get_config = AsyncMock(return_value=cfg_disable) client._restore_config_section = AsyncMock() client._filter_configure = AsyncMock() - await client.disable_nat_outbound_rule_by_created_time("t2") + await client.disable_nat_outbound_rule_by_created_time_legacy("t2") client._restore_config_section.assert_awaited() client._filter_configure.assert_awaited() await client.async_close() @@ -3171,3 +3205,294 @@ async def test_get_device_unique_id_no_mac(make_client) -> None: client._safe_list_get = AsyncMock(return_value=[{"is_physical": False}]) assert await client.get_device_unique_id() is None await client.async_close() + + +@pytest.mark.asyncio +async def test_get_firewall_legacy_fallback(make_client) -> None: + """get_firewall falls back to legacy config for OPNsense < 26.1.1.""" + session = MagicMock(spec=aiohttp.ClientSession) + client = make_client(session=session) + client._firmware_version = "25.7.0" + + # Mock get_config for legacy fallback + client.get_config = AsyncMock(return_value={"filter": {"rule": []}}) + + result = await client.get_firewall() + assert result == {"config": {"filter": {"rule": []}}} + client.get_config.assert_awaited_once() + await client.async_close() + + +@pytest.mark.asyncio +async def test_get_firewall_new_api(make_client) -> None: + """get_firewall uses new API for OPNsense >= 26.1.1.""" + session = MagicMock(spec=aiohttp.ClientSession) + client = make_client(session=session) + client._firmware_version = "26.1.1" + + # Mock all the methods called in the new API path + client.is_plugin_installed = AsyncMock(return_value=True) + client.get_config = AsyncMock(return_value={"filter": {"rule": []}}) + client._get_firewall_rules = AsyncMock(return_value={"rule1": {"uuid": "rule1"}}) + client._get_nat_destination_rules = AsyncMock(return_value={"nat1": {"uuid": "nat1"}}) + client._get_nat_one_to_one_rules = AsyncMock(return_value={"one1": {"uuid": "one1"}}) + client._get_nat_source_rules = AsyncMock(return_value={"src1": {"uuid": "src1"}}) + client._get_nat_npt_rules = AsyncMock(return_value={"npt1": {"uuid": "npt1"}}) + + result = await client.get_firewall() + expected = { + "config": {"filter": {"rule": []}}, + "rules": {"rule1": {"uuid": "rule1"}}, + "nat": { + "d_nat": {"nat1": {"uuid": "nat1"}}, + "one_to_one": {"one1": {"uuid": "one1"}}, + "source_nat": {"src1": {"uuid": "src1"}}, + "npt": {"npt1": {"uuid": "npt1"}}, + }, + } + assert result == expected + client.is_plugin_installed.assert_awaited_once() + client.get_config.assert_awaited_once() + client._get_firewall_rules.assert_awaited_once() + client._get_nat_destination_rules.assert_awaited_once() + client._get_nat_one_to_one_rules.assert_awaited_once() + client._get_nat_source_rules.assert_awaited_once() + client._get_nat_npt_rules.assert_awaited_once() + await client.async_close() + + +@pytest.mark.asyncio +async def test_get_firewall_new_api_plugin_not_installed(make_client) -> None: + """get_firewall uses new API for OPNsense >= 26.1.1 but when plugin not installed it should skip config.""" + session = MagicMock(spec=aiohttp.ClientSession) + client = make_client(session=session) + client._firmware_version = "26.1.1" + + # Plugin not installed: shouldn't call get_config + client.is_plugin_installed = AsyncMock(return_value=False) + client.get_config = AsyncMock(return_value={"filter": {"rule": []}}) + client._get_firewall_rules = AsyncMock(return_value={"rule1": {"uuid": "rule1"}}) + client._get_nat_destination_rules = AsyncMock(return_value={"nat1": {"uuid": "nat1"}}) + client._get_nat_one_to_one_rules = AsyncMock(return_value={"one1": {"uuid": "one1"}}) + client._get_nat_source_rules = AsyncMock(return_value={"src1": {"uuid": "src1"}}) + client._get_nat_npt_rules = AsyncMock(return_value={"npt1": {"uuid": "npt1"}}) + + result = await client.get_firewall() + expected = { + "rules": {"rule1": {"uuid": "rule1"}}, + "nat": { + "d_nat": {"nat1": {"uuid": "nat1"}}, + "one_to_one": {"one1": {"uuid": "one1"}}, + "source_nat": {"src1": {"uuid": "src1"}}, + "npt": {"npt1": {"uuid": "npt1"}}, + }, + } + assert result == expected + client.is_plugin_installed.assert_awaited_once() + client.get_config.assert_not_awaited() + client._get_firewall_rules.assert_awaited_once() + client._get_nat_destination_rules.assert_awaited_once() + client._get_nat_one_to_one_rules.assert_awaited_once() + client._get_nat_source_rules.assert_awaited_once() + client._get_nat_npt_rules.assert_awaited_once() + await client.async_close() + + +@pytest.mark.asyncio +async def test_get_firewall_version_compare_exception(make_client) -> None: + """get_firewall handles AwesomeVersionCompareException gracefully.""" + session = MagicMock(spec=aiohttp.ClientSession) + client = make_client(session=session) + client._firmware_version = "invalid" + + result = await client.get_firewall() + assert result == {} + await client.async_close() + + +@pytest.mark.asyncio +async def test_get_firewall_rules_successful_parsing(make_client) -> None: + """_get_firewall_rules successfully parses rows returned from the REST API.""" + session = MagicMock(spec=aiohttp.ClientSession) + client = make_client(session=session) + + rows = [ + { + "uuid": "rule1", + "enabled": "1", + "action": "pass", + "interface": "lan", + "descr": "Allow HTTP", + }, + { + "uuid": "rule2", + "enabled": "0", + "action": "block", + "interface": "wan", + "descr": "Block traffic", + }, + ] + + client._safe_dict_post = AsyncMock(return_value={"rows": rows}) + + result = await client._get_firewall_rules() + + expected = {r["uuid"]: r.copy() for r in rows} + assert result == expected + client._safe_dict_post.assert_awaited_once_with( + "/api/firewall/filter/search_rule", payload={"current": 1, "sort": {}} + ) + await client.async_close() + + +@pytest.mark.asyncio +async def test_get_firewall_rules_empty_response(make_client) -> None: + """_get_firewall_rules returns empty dict when API response has no rows.""" + session = MagicMock(spec=aiohttp.ClientSession) + client = make_client(session=session) + + client._safe_dict_post = AsyncMock(return_value={}) + + result = await client._get_firewall_rules() + assert result == {} + await client.async_close() + + +@pytest.mark.asyncio +async def test_get_firewall_rules_skips_invalid_rows(make_client) -> None: + """_get_firewall_rules skips rules without uuid and lockout rules.""" + session = MagicMock(spec=aiohttp.ClientSession) + client = make_client(session=session) + + rows = [ + {"enabled": "1", "action": "pass"}, # missing uuid + {"uuid": "lockout-1", "enabled": "1"}, # lockout rule + {"uuid": "rule-ok", "enabled": "1"}, # valid + ] + + client._safe_dict_post = AsyncMock(return_value={"rows": rows}) + + result = await client._get_firewall_rules() + assert list(result.keys()) == ["rule-ok"] + await client.async_close() + + +@pytest.mark.parametrize( + ("method_name", "api_endpoint", "has_transformations"), + [ + ("_get_nat_destination_rules", "/api/firewall/d_nat/search_rule", True), + ("_get_nat_one_to_one_rules", "/api/firewall/one_to_one/search_rule", False), + ("_get_nat_source_rules", "/api/firewall/source_nat/search_rule", False), + ("_get_nat_npt_rules", "/api/firewall/npt/search_rule", False), + ], +) +@pytest.mark.parametrize( + ("test_case", "expected_result"), + [ + ( + "successful_parsing", + { + "test-rule-1": { + "uuid": "test-rule-1", + "description": "Test rule 1", + "enabled": "1", + "interface": "wan", + "protocol": "tcp", + }, + "test-rule-2": { + "uuid": "test-rule-2", + "description": "Test rule 2", + "enabled": "0", + "interface": "lan", + "protocol": "udp", + }, + }, + ), + ( + "filters_lockout_rules", + { + "normal-rule": { + "uuid": "normal-rule", + "description": "Normal rule", + "enabled": "1", + } + }, + ), + ("empty_response", {}), + ("response_without_rows", {}), + ], +) +@pytest.mark.asyncio +async def test_nat_rules_parsing( + make_client, + method_name, + api_endpoint, + has_transformations, + test_case, + expected_result, +) -> None: + """Test NAT rules parsing for all NAT rule types.""" + session = MagicMock(spec=aiohttp.ClientSession) + client = make_client(session=session) + + # Build API-style mock response depending on whether the endpoint uses + # transformations (d_nat-like endpoints use 'descr'/'disabled'). + mock_response: dict[str, Any] + if test_case == "empty_response": + mock_response = {} + elif test_case == "response_without_rows": + mock_response = {"some_other_key": "value"} + else: + normalized_rows: list[dict[str, Any]] = [] + extra_rows: list[dict[str, Any]] = [] + if test_case == "successful_parsing": + for uid, info in expected_result.items(): + row = {"uuid": uid} + row["description"] = info.get("description") + row["enabled"] = info.get("enabled") + if "interface" in info: + row["interface"] = info.get("interface") + if "protocol" in info: + row["protocol"] = info.get("protocol") + normalized_rows.append(row) + elif test_case == "filters_lockout_rules": + normalized_rows = [ + {"uuid": "normal-rule", "description": "Normal rule", "enabled": "1"} + ] + extra_rows = [ + {"uuid": "lockout-rule", "description": "Lockout rule", "enabled": "1"}, + {"uuid": "another-lockout", "description": "Another lockout", "enabled": "1"}, + {"uuid": None, "description": "No UUID rule", "enabled": "1"}, + ] + + api_rows: list[dict[str, Any]] = [] + for row in normalized_rows + extra_rows: + if has_transformations: + new_row = row.copy() + if "description" in new_row: + new_row["descr"] = new_row.pop("description") + if "enabled" in new_row: + new_row["disabled"] = "0" if new_row.pop("enabled") == "1" else "1" + api_rows.append(new_row) + else: + api_rows.append(row.copy()) + + mock_response = {"rows": api_rows} + + client._safe_dict_post = AsyncMock(return_value=mock_response) + + # Call the appropriate method + method = getattr(client, method_name) + result = await method() + + # Make a deep copy of expected_result so we don't mutate the shared fixture + expected = copy.deepcopy(expected_result) + + assert result == expected + + # Verify the correct API endpoint was called + client._safe_dict_post.assert_awaited_once_with( + api_endpoint, payload={"current": 1, "sort": {}} + ) + + await client.async_close() diff --git a/tests/test_switch.py b/tests/test_switch.py index 43093fc0..84638937 100644 --- a/tests/test_switch.py +++ b/tests/test_switch.py @@ -16,7 +16,7 @@ ATTR_NAT_OUTBOUND, ATTR_NAT_PORT_FORWARD, CONF_DEVICE_UNIQUE_ID, - CONF_SYNC_FILTERS_AND_NAT, + CONF_SYNC_FIREWALL_AND_NAT, CONF_SYNC_SERVICES, CONF_SYNC_UNBOUND, CONF_SYNC_VPN, @@ -24,13 +24,20 @@ ) from custom_components.opnsense.coordinator import OPNsenseDataUpdateCoordinator from custom_components.opnsense.switch import ( - OPNsenseFilterSwitch, - OPNsenseNatSwitch, + OPNsenseFilterSwitchLegacy, + OPNsenseFirewallRuleSwitch, + OPNsenseNATRuleSwitch, + OPNsenseNatSwitchLegacy, OPNsenseServiceSwitch, OPNsenseVPNSwitch, - _compile_filter_switches, - _compile_nat_outbound_switches, - _compile_port_forward_switches, + _compile_filter_switches_legacy, + _compile_firewall_rules_switches, + _compile_nat_destination_rules_switches, + _compile_nat_npt_rules_switches, + _compile_nat_one_to_one_rules_switches, + _compile_nat_outbound_switches_legacy, + _compile_nat_source_rules_switches, + _compile_port_forward_switches_legacy, _compile_service_switches, _compile_static_unbound_switch_legacy, _compile_unbound_switches, @@ -52,28 +59,43 @@ def make_coord(data): "compile_fn,state,client_methods", [ ( - _compile_filter_switches, - {"config": {"filter": {"rule": [{"descr": "Allow LAN", "created": {"time": "t1"}}]}}}, - ("enable_filter_rule_by_created_time", "disable_filter_rule_by_created_time"), + _compile_filter_switches_legacy, + { + "firewall": { + "config": { + "filter": {"rule": [{"descr": "Allow LAN", "created": {"time": "t1"}}]} + } + } + }, + ( + "enable_filter_rule_by_created_time_legacy", + "disable_filter_rule_by_created_time_legacy", + ), ), ( - _compile_port_forward_switches, - {"config": {"nat": {"rule": [{"descr": "PF", "created": {"time": "p1"}}]}}}, + _compile_port_forward_switches_legacy, + { + "firewall": { + "config": {"nat": {"rule": [{"descr": "PF", "created": {"time": "p1"}}]}} + } + }, ( - "enable_nat_port_forward_rule_by_created_time", - "disable_nat_port_forward_rule_by_created_time", + "enable_nat_port_forward_rule_by_created_time_legacy", + "disable_nat_port_forward_rule_by_created_time_legacy", ), ), ( - _compile_nat_outbound_switches, + _compile_nat_outbound_switches_legacy, { - "config": { - "nat": {"outbound": {"rule": [{"descr": "OB", "created": {"time": "o1"}}]}} + "firewall": { + "config": { + "nat": {"outbound": {"rule": [{"descr": "OB", "created": {"time": "o1"}}]}} + } } }, ( - "enable_nat_outbound_rule_by_created_time", - "disable_nat_outbound_rule_by_created_time", + "enable_nat_outbound_rule_by_created_time_legacy", + "disable_nat_outbound_rule_by_created_time_legacy", ), ), ( @@ -96,6 +118,94 @@ def make_coord(data): {"unbound_blocklist": {"legacy": {"enabled": "1"}}}, ("enable_unbound_blocklist", "disable_unbound_blocklist"), ), + ( + _compile_firewall_rules_switches, + { + "firewall": { + "rules": { + "rule1": { + "uuid": "rule1", + "description": "Test Firewall Rule", + "%interface": "wan", + "enabled": "1", + } + } + } + }, + ("toggle_firewall_rule", "toggle_firewall_rule"), + ), + ( + _compile_nat_source_rules_switches, + { + "firewall": { + "nat": { + "source_nat": { + "nat1": { + "uuid": "nat1", + "description": "Source NAT Rule", + "%interface": "wan", + "enabled": "1", + } + } + } + } + }, + ("toggle_nat_rule", "toggle_nat_rule"), + ), + ( + _compile_nat_destination_rules_switches, + { + "firewall": { + "nat": { + "d_nat": { + "dnat1": { + "uuid": "dnat1", + "description": "Destination NAT Rule", + "%interface": "wan", + "enabled": "1", + } + } + } + } + }, + ("toggle_nat_rule", "toggle_nat_rule"), + ), + ( + _compile_nat_one_to_one_rules_switches, + { + "firewall": { + "nat": { + "one_to_one": { + "oto1": { + "uuid": "oto1", + "description": "One-to-One NAT Rule", + "%interface": "wan", + "enabled": "1", + } + } + } + } + }, + ("toggle_nat_rule", "toggle_nat_rule"), + ), + ( + _compile_nat_npt_rules_switches, + { + "firewall": { + "nat": { + "npt": { + "npt1": { + "uuid": "npt1", + "description": "NPT NAT Rule", + "%interface": "wan", + "enabled": "1", + } + } + } + } + }, + ("toggle_nat_rule", "toggle_nat_rule"), + ), ], ) async def test_switch_toggle_variants( @@ -136,30 +246,101 @@ async def test_switch_toggle_variants( # call turn_on/turn_off and assert client methods called await ent.async_turn_on() - # ensure the async client coroutine was actually awaited - getattr(ent._client, client_methods[0]).assert_awaited_once() # turning on should set delay_update for entities that perform delayed updates assert ent.delay_update is True + await ent.async_turn_off() - getattr(ent._client, client_methods[1]).assert_awaited_once() # turning off should also set delay_update assert ent.delay_update is True + # Check that the correct client methods were called + if client_methods[0] == client_methods[1]: + # Same method for on/off with different parameters + if "firewall_rule" in client_methods[0]: + # toggle_firewall_rule(rule_id, action) + getattr(ent._client, client_methods[0]).assert_has_awaits( + [ + ((ent._rule_id, "on"), {}), + ((ent._rule_id, "off"), {}), + ], + any_order=True, + ) + else: + # toggle_nat_rule(rule_type, rule_id, action) + getattr(ent._client, client_methods[0]).assert_has_awaits( + [ + ((ent._nat_rule_type, ent._rule_id, "on"), {}), + ((ent._nat_rule_type, ent._rule_id, "off"), {}), + ], + any_order=True, + ) + else: + # Different methods for on/off + getattr(ent._client, client_methods[0]).assert_awaited_once() + getattr(ent._client, client_methods[1]).assert_awaited_once() + + +@pytest.mark.asyncio +async def test_compile_filter_switches_legacy_skip_conditions(make_config_entry) -> None: + """Test that _compile_filter_switches_legacy properly skips invalid rules.""" + + config_entry = make_config_entry( + data={CONF_DEVICE_UNIQUE_ID: "dev1", "url": "http://example"}, + title="OPNsenseTest", + ) + coordinator = make_coord({}) + + # Test with non-dict rules + state_with_mixed_rules = { + "firewall": { + "config": { + "filter": { + "rule": [ + {"descr": "Valid Rule", "created": {"time": "t1"}}, + "invalid_string_rule", # should be skipped + { + "descr": "NAT Rule", + "associated-rule-id": "nat1", + "created": {"time": "t2"}, + }, # should be skipped + { + "description": "Anti-Lockout Rule", + "created": {"time": "t3"}, + }, # should be skipped + {"descr": "No Tracker"}, # should be skipped + {"descr": "Empty Tracker", "created": {"time": ""}}, # should be skipped + ] + } + } + } + } + + entities = await _compile_filter_switches_legacy( + config_entry, coordinator, state_with_mixed_rules + ) + # Only the first valid rule should be included + assert len(entities) == 1 + assert entities[0].entity_description.key == "filter.t1" + assert "Valid Rule" in entities[0].entity_description.name + @pytest.mark.asyncio async def test_compile_port_forward_skips_non_dict(coordinator, make_config_entry): """Port forward compilation should skip non-dict rule entries.""" config_entry = make_config_entry( data={CONF_DEVICE_UNIQUE_ID: "dev1", "url": "http://example"}, + options={CONF_SYNC_FIREWALL_AND_NAT: True}, title="OPNsenseTest", ) setattr(config_entry.runtime_data, COORDINATOR, coordinator) # include a non-dict in nat.rule which should be skipped state = { - "config": {"nat": {"rule": ["not-a-dict", {"descr": "PF", "created": {"time": "p2"}}]}} + "firewall": { + "config": {"nat": {"rule": ["not-a-dict", {"descr": "PF", "created": {"time": "p2"}}]}} + } } coordinator.data = state - ents = await _compile_port_forward_switches(config_entry, coordinator, state) + ents = await _compile_port_forward_switches_legacy(config_entry, coordinator, state) assert len(ents) == 1 assert ents[0].entity_description.key.endswith(".p2") @@ -174,11 +355,13 @@ def fake_add_entities(entities): # create a state that contains one of each entity type state = { - "config": { - "filter": {"rule": [{"descr": "Allow", "created": {"time": "f1"}}]}, - "nat": { - "rule": [{"descr": "PF", "created": {"time": "p1"}}], - "outbound": {"rule": [{"descr": "OB", "created": {"time": "o1"}}]}, + "firewall": { + "config": { + "filter": {"rule": [{"descr": "Allow", "created": {"time": "f1"}}]}, + "nat": { + "rule": [{"descr": "PF", "created": {"time": "p1"}}], + "outbound": {"rule": [{"descr": "OB", "created": {"time": "o1"}}]}, + }, }, }, "services": [{"id": "s1", "name": "svc", "locked": 0, "status": True}], @@ -192,7 +375,7 @@ def fake_add_entities(entities): config_entry = make_config_entry( data={ CONF_DEVICE_UNIQUE_ID: "dev1", - CONF_SYNC_FILTERS_AND_NAT: True, + CONF_SYNC_FIREWALL_AND_NAT: True, CONF_SYNC_SERVICES: True, CONF_SYNC_VPN: True, CONF_SYNC_UNBOUND: True, @@ -206,7 +389,7 @@ def fake_add_entities(entities): # compute expected counts from coordinator.data to avoid brittle hard-coded value expected = 0 - cfg = coordinator.data.get("config", {}) + cfg = coordinator.data.get("firewall", {}).get("config", {}) # filter rules expected += len(cfg.get("filter", {}).get("rule", []) or []) # port forward rules @@ -227,6 +410,84 @@ def fake_add_entities(entities): assert calls.get("len") == expected +@pytest.mark.asyncio +async def test_async_setup_entry_new_firewall_api(coordinator, ph_hass, make_config_entry): + """Async setup should create entities for new firewall API (>= 26.1.1).""" + calls = {} + + def fake_add_entities(entities): + calls["len"] = len(entities) + + # create a state that contains new firewall API structure + state = { + "firewall": { + "rules": { + "rule1": { + "uuid": "rule1", + "description": "Test Firewall Rule", + "%interface": "wan", + "enabled": "1", + } + }, + "nat": { + "source_nat": { + "nat1": { + "uuid": "nat1", + "description": "Test Source NAT", + "%interface": "wan", + "enabled": "1", + } + }, + "d_nat": { + "dnat1": { + "uuid": "dnat1", + "description": "Test Destination NAT", + "%interface": "wan", + "enabled": "1", + } + }, + "one_to_one": { + "oto1": { + "uuid": "oto1", + "description": "Test One-to-One NAT", + "%interface": "wan", + "enabled": "1", + } + }, + "npt": { + "npt1": { + "uuid": "npt1", + "description": "Test NPT NAT", + "%interface": "wan", + "enabled": "1", + } + }, + }, + }, + "host_firmware_version": "26.1.1", + } + coordinator.data = state + + config_entry = make_config_entry( + data={ + CONF_DEVICE_UNIQUE_ID: "dev1", + CONF_SYNC_FIREWALL_AND_NAT: True, + CONF_SYNC_SERVICES: False, + CONF_SYNC_VPN: False, + CONF_SYNC_UNBOUND: False, + }, + title="OPNsenseTest", + ) + setattr(config_entry.runtime_data, COORDINATOR, coordinator) + + hass = ph_hass + await switch_mod.async_setup_entry(hass, config_entry, fake_add_entities) + + # Should create entities for each rule type in new API + expected = 5 # 1 firewall rule + 4 NAT rules + assert calls.get("len") == expected + + def test_vpn_icon_property(make_config_entry): """VPN switch exposes the expected icon when available and on.""" desc = SwitchEntityDescription(key="openvpn.clients.c1", name="VPNC") @@ -291,7 +552,7 @@ def add_entities(ents): for e in created if ( ( - not isinstance(e, OPNsenseFilterSwitch) + not isinstance(e, OPNsenseFilterSwitchLegacy) and getattr(e, "_attr_unique_id", "").endswith("unbound") ) or ( @@ -342,7 +603,7 @@ def test_delay_update_setter(monkeypatch, coordinator, make_config_entry): title="OPNsenseTest", ) setattr(config_entry.runtime_data, COORDINATOR, coordinator) - ent = OPNsenseFilterSwitch( + ent = OPNsenseFilterSwitchLegacy( config_entry=config_entry, coordinator=coordinator, entity_description=desc ) # synchronous test: use a plain hass-like object with a dedicated loop @@ -480,7 +741,7 @@ async def test_vpn_turn_on_off_noops_when_preconditions_fail( # replace async_call_later to avoid scheduling monkeypatch.setattr( "custom_components.opnsense.switch.async_call_later", - lambda hass, delay, action: (lambda: None), + lambda hass, delay, action: lambda: None, ) desc = SwitchEntityDescription(key="openvpn.clients.c2", name="VPNC2") @@ -533,7 +794,7 @@ async def test_vpn_async_turn_off_variations( # avoid scheduling real async_call_later during tests monkeypatch.setattr( "custom_components.opnsense.switch.async_call_later", - lambda hass, delay, action: (lambda: None), + lambda hass, delay, action: lambda: None, ) desc = SwitchEntityDescription(key="openvpn.clients.cx", name="VPNCX") @@ -591,21 +852,26 @@ async def test_filter_disabled_and_missing(coordinator, ph_hass, make_config_ent """Filter compilation handles missing and disabled rules correctly.""" config_entry = make_config_entry( data={CONF_DEVICE_UNIQUE_ID: "dev1", "url": "http://example"}, + options={CONF_SYNC_FIREWALL_AND_NAT: True}, title="OPNsenseTest", ) setattr(config_entry.runtime_data, COORDINATOR, coordinator) # missing rules -> compile returns [] - state = {"config": {"filter": {"rule": []}}} + state = {"firewall": {"config": {"filter": {"rule": []}}}} coordinator.data = state - entities = await _compile_filter_switches(config_entry, coordinator, state) + entities = await _compile_filter_switches_legacy(config_entry, coordinator, state) assert entities == [] # disabled rule -> is_on False state = { - "config": {"filter": {"rule": [{"descr": "x", "created": {"time": "t2"}, "disabled": "1"}]}} + "firewall": { + "config": { + "filter": {"rule": [{"descr": "x", "created": {"time": "t2"}, "disabled": "1"}]} + } + } } coordinator.data = state - entities = await _compile_filter_switches(config_entry, coordinator, state) + entities = await _compile_filter_switches_legacy(config_entry, coordinator, state) ent = entities[0] # use PHCC-provided hass fixture hass = ph_hass @@ -665,31 +931,42 @@ async def test_unbound_skips_update_when_delay_set(coordinator, ph_hass, make_co @pytest.mark.asyncio @pytest.mark.parametrize( - "kind,compile_fn,state,selector", + "kind,compile_fn,state,selector,options", [ ( "unbound", _compile_static_unbound_switch_legacy, {"unbound_blocklist": {"legacy": {"enabled": "1"}}}, "first", + {CONF_SYNC_UNBOUND: True}, ), ( "filter", - _compile_filter_switches, + _compile_filter_switches_legacy, { - "config": { - "filter": { - "rule": [{"descr": "Allow", "created": {"time": "fdelay"}, "disabled": "0"}] + "firewall": { + "config": { + "filter": { + "rule": [ + {"descr": "Allow", "created": {"time": "fdelay"}, "disabled": "0"} + ] + } } } }, "first", + {CONF_SYNC_FIREWALL_AND_NAT: True}, ), ( "nat", - _compile_port_forward_switches, - {"config": {"nat": {"rule": [{"descr": "PF", "created": {"time": "pdelay"}}]}}}, + _compile_port_forward_switches_legacy, + { + "firewall": { + "config": {"nat": {"rule": [{"descr": "PF", "created": {"time": "pdelay"}}]}} + } + }, "first", + {CONF_SYNC_FIREWALL_AND_NAT: True}, ), ( "service", @@ -706,6 +983,7 @@ async def test_unbound_skips_update_when_delay_set(coordinator, ph_hass, make_co ] }, "first", + {CONF_SYNC_SERVICES: True}, ), ( "vpn", @@ -715,14 +993,17 @@ async def test_unbound_skips_update_when_delay_set(coordinator, ph_hass, make_co "wireguard": {"clients": {}, "servers": {}}, }, "endswith:v1", + {CONF_SYNC_VPN: True}, ), ], ) async def test_delay_skips_update_parametrized( - kind, compile_fn, state, selector, coordinator, ph_hass, make_config_entry + kind, compile_fn, state, selector, options, coordinator, ph_hass, make_config_entry ): """Parametrized test asserting handlers return early when delay_update is set.""" - config_entry = make_config_entry({CONF_DEVICE_UNIQUE_ID: "dev1", "url": "http://example"}) + config_entry = make_config_entry( + data={CONF_DEVICE_UNIQUE_ID: "dev1", "url": "http://example"}, options=options + ) setattr(config_entry.runtime_data, COORDINATOR, coordinator) coordinator.data = state @@ -758,9 +1039,9 @@ async def test_compile_helpers_bad_input(coordinator, make_config_entry): config_entry = make_config_entry({CONF_DEVICE_UNIQUE_ID: "dev1"}) setattr(config_entry.runtime_data, COORDINATOR, coordinator) # non-mapping state - assert await _compile_filter_switches(config_entry, coordinator, None) == [] - assert await _compile_port_forward_switches(config_entry, coordinator, None) == [] - assert await _compile_nat_outbound_switches(config_entry, coordinator, None) == [] + assert await _compile_filter_switches_legacy(config_entry, coordinator, None) == [] + assert await _compile_port_forward_switches_legacy(config_entry, coordinator, None) == [] + assert await _compile_nat_outbound_switches_legacy(config_entry, coordinator, None) == [] @pytest.mark.asyncio @@ -797,11 +1078,17 @@ async def test_switch_handle_error_sets_unavailable( if kind == "filter": # compile one valid filter entity then monkeypatch to produce error - config_entry = make_config_entry({CONF_DEVICE_UNIQUE_ID: "dev1"}) + config_entry = make_config_entry( + data={CONF_DEVICE_UNIQUE_ID: "dev1"}, options={CONF_SYNC_FIREWALL_AND_NAT: True} + ) setattr(config_entry.runtime_data, COORDINATOR, coordinator) - state = {"config": {"filter": {"rule": [{"descr": "Good", "created": {"time": "t1"}}]}}} + state = { + "firewall": { + "config": {"filter": {"rule": [{"descr": "Good", "created": {"time": "t1"}}]}} + } + } coordinator.data = state - ent = (await _compile_filter_switches(config_entry, coordinator, state))[0] + ent = (await _compile_filter_switches_legacy(config_entry, coordinator, state))[0] ent.hass = hass_local ent.coordinator = make_coord(state) ent.entity_id = f"switch.{ent._attr_unique_id}" @@ -815,7 +1102,7 @@ def _fake_get_rule_filter() -> Any: desc = SwitchEntityDescription(key="nat_port_forward.abc", name="NAT") config_entry = make_config_entry({CONF_DEVICE_UNIQUE_ID: "dev1"}) setattr(config_entry.runtime_data, COORDINATOR, coordinator) - ent = OPNsenseNatSwitch( + ent = OPNsenseNatSwitchLegacy( config_entry=config_entry, coordinator=coordinator, entity_description=desc, @@ -864,7 +1151,7 @@ def test_entity_icons(make_config_entry): f_desc = SwitchEntityDescription(key="filter.t1", name="Filter") config_entry = make_config_entry({CONF_DEVICE_UNIQUE_ID: "dev1"}) setattr(config_entry.runtime_data, COORDINATOR, None) - f_ent = OPNsenseFilterSwitch( + f_ent = OPNsenseFilterSwitchLegacy( config_entry=config_entry, coordinator=make_coord({}), entity_description=f_desc, @@ -877,7 +1164,7 @@ def test_entity_icons(make_config_entry): n_desc = SwitchEntityDescription(key="nat_port_forward.t1", name="NAT") config_entry = make_config_entry({CONF_DEVICE_UNIQUE_ID: "dev1"}) setattr(config_entry.runtime_data, COORDINATOR, None) - n_ent = OPNsenseNatSwitch( + n_ent = OPNsenseNatSwitchLegacy( config_entry=config_entry, coordinator=make_coord({}), entity_description=n_desc, @@ -1028,7 +1315,7 @@ def test_reset_delay_calls_existing_remover(monkeypatch, make_config_entry): desc = SwitchEntityDescription(key="filter.t1", name="Filter") config_entry = make_config_entry({CONF_DEVICE_UNIQUE_ID: "dev1"}) setattr(config_entry.runtime_data, COORDINATOR, None) - ent = OPNsenseFilterSwitch( + ent = OPNsenseFilterSwitchLegacy( config_entry=config_entry, coordinator=make_coord({}), entity_description=desc, @@ -1056,24 +1343,29 @@ def new_remover(): @pytest.mark.asyncio async def test_compile_filter_skip_and_invalid_rules(coordinator, make_config_entry): """Filter compilation skips invalid and non-dict rule entries.""" - config_entry = make_config_entry({CONF_DEVICE_UNIQUE_ID: "dev1", "url": "http://example"}) + config_entry = make_config_entry( + data={CONF_DEVICE_UNIQUE_ID: "dev1", "url": "http://example"}, + options={CONF_SYNC_FIREWALL_AND_NAT: True}, + ) setattr(config_entry.runtime_data, COORDINATOR, coordinator) # include various rules that should be skipped state = { - "config": { - "filter": { - "rule": [ - {"descr": "Anti-Lockout Rule", "created": {"time": "a1"}}, - {"associated-rule-id": "x", "created": {"time": "a2"}}, - {"descr": "No tracker"}, - ["not", "a", "dict"], - {"descr": "Good", "created": {"time": "g1"}}, - ] + "firewall": { + "config": { + "filter": { + "rule": [ + {"description": "Anti-Lockout Rule", "created": {"time": "a1"}}, + {"associated-rule-id": "x", "created": {"time": "a2"}}, + {"description": "No tracker"}, + ["not", "a", "dict"], + {"description": "Good", "created": {"time": "g1"}}, + ] + } } } } coordinator.data = state - entities = await _compile_filter_switches(config_entry, coordinator, state) + entities = await _compile_filter_switches_legacy(config_entry, coordinator, state) # only the valid rule should be compiled assert len(entities) == 1 @@ -1081,22 +1373,27 @@ async def test_compile_filter_skip_and_invalid_rules(coordinator, make_config_en @pytest.mark.asyncio async def test_compile_nat_outbound_skips_auto_created(coordinator, make_config_entry): """Outbound NAT compilation ignores auto-created rules.""" - config_entry = make_config_entry({CONF_DEVICE_UNIQUE_ID: "dev1", "url": "http://example"}) + config_entry = make_config_entry( + data={CONF_DEVICE_UNIQUE_ID: "dev1", "url": "http://example"}, + options={CONF_SYNC_FIREWALL_AND_NAT: True}, + ) setattr(config_entry.runtime_data, COORDINATOR, coordinator) state = { - "config": { - "nat": { - "outbound": { - "rule": [ - {"descr": "Auto created rule", "created": {"time": "x1"}}, - {"descr": "Manual", "created": {"time": "x2"}}, - ] + "firewall": { + "config": { + "nat": { + "outbound": { + "rule": [ + {"description": "Auto created rule", "created": {"time": "x1"}}, + {"description": "Manual", "created": {"time": "x2"}}, + ] + } } } } } coordinator.data = state - ents = await _compile_nat_outbound_switches(config_entry, coordinator, state) + ents = await _compile_nat_outbound_switches_legacy(config_entry, coordinator, state) assert len(ents) == 1 @@ -1131,7 +1428,7 @@ def fake_add_entities(entities): config_entry = make_config_entry( data={ CONF_DEVICE_UNIQUE_ID: "dev1", - CONF_SYNC_FILTERS_AND_NAT: False, + CONF_SYNC_FIREWALL_AND_NAT: False, CONF_SYNC_SERVICES: False, CONF_SYNC_VPN: False, CONF_SYNC_UNBOUND: True, @@ -1215,7 +1512,7 @@ async def test_nat_handle_missing_rule_returns_none(coordinator, ph_hass, make_c # create a nat switch with rule type that doesn't exist in state desc = SwitchEntityDescription(key="nat_outbound.missing", name="Missing") config_entry = make_config_entry({CONF_DEVICE_UNIQUE_ID: "dev1"}) - ent = OPNsenseNatSwitch( + ent = OPNsenseNatSwitchLegacy( config_entry=config_entry, coordinator=coordinator, entity_description=desc, @@ -1478,7 +1775,7 @@ def test_nat_rule_type_and_tracker_methods(coordinator, make_config_entry): config_entry = make_config_entry({CONF_DEVICE_UNIQUE_ID: "dev1"}) setattr(config_entry.runtime_data, COORDINATOR, coordinator) - pf = OPNsenseNatSwitch( + pf = OPNsenseNatSwitchLegacy( config_entry=config_entry, coordinator=coordinator, entity_description=desc_pf, @@ -1486,7 +1783,7 @@ def test_nat_rule_type_and_tracker_methods(coordinator, make_config_entry): # create a separate config_entry for the outbound test config_entry_ob = make_config_entry({CONF_DEVICE_UNIQUE_ID: "dev1"}) setattr(config_entry_ob.runtime_data, COORDINATOR, coordinator) - ob = OPNsenseNatSwitch( + ob = OPNsenseNatSwitchLegacy( config_entry=config_entry_ob, coordinator=coordinator, entity_description=desc_ob, @@ -1518,8 +1815,8 @@ async def test_compile_port_forward_with_missing_rules(coordinator, make_config_ # port forward compile should return [] when nat not present or rules missing config_entry = make_config_entry({CONF_DEVICE_UNIQUE_ID: "dev1"}) setattr(config_entry.runtime_data, COORDINATOR, coordinator) - coordinator.data = {"config": {}} - res = await _compile_port_forward_switches(config_entry, coordinator, coordinator.data) + coordinator.data = {"firewall": {"config": {}}} + res = await _compile_port_forward_switches_legacy(config_entry, coordinator, coordinator.data) assert res == [] @@ -1548,7 +1845,7 @@ def test_filter_handle_exceptions_sets_unavailable( config_entry = make_config_entry({CONF_DEVICE_UNIQUE_ID: "dev1"}) setattr(config_entry.runtime_data, COORDINATOR, coordinator) - ent = OPNsenseFilterSwitch( + ent = OPNsenseFilterSwitchLegacy( config_entry=config_entry, coordinator=coordinator, entity_description=desc ) ent.hass = MagicMock(spec=HomeAssistant) @@ -1582,7 +1879,7 @@ def test_nat_handle_exceptions_sets_unavailable(exc_type, coordinator, make_conf config_entry = make_config_entry({CONF_DEVICE_UNIQUE_ID: "dev1"}) setattr(config_entry.runtime_data, COORDINATOR, coordinator) - ent = OPNsenseNatSwitch( + ent = OPNsenseNatSwitchLegacy( config_entry=config_entry, coordinator=coordinator, entity_description=desc ) ent.hass = MagicMock(spec=HomeAssistant) @@ -1739,3 +2036,150 @@ async def test_unbound_extended_switch_toggle_failures(coordinator, ph_hass, mak await ent.async_turn_off() assert ent.is_on is True # Should remain on assert ent.delay_update is False # Should not set delay + + +@pytest.mark.asyncio +async def test_compile_firewall_rules_switches(coordinator, make_config_entry): + """Test compilation of firewall rule switches for new API.""" + config_entry = make_config_entry({CONF_DEVICE_UNIQUE_ID: "dev1"}) + state = { + "firewall": { + "rules": { + "rule1": { + "uuid": "rule1", + "description": "Test Rule 1", + "%interface": "wan", + "enabled": "1", + }, + "rule2": { + "uuid": "rule2", + "description": "Test Rule 2", + "%interface": "lan", + "enabled": "0", + }, + } + } + } + ents = await _compile_firewall_rules_switches(config_entry, coordinator, state) + assert len(ents) == 2 + assert isinstance(ents[0], OPNsenseFirewallRuleSwitch) + assert isinstance(ents[1], OPNsenseFirewallRuleSwitch) + assert ents[0].entity_description.key == "firewall.rule.rule1" + assert ents[1].entity_description.key == "firewall.rule.rule2" + + +@pytest.mark.asyncio +async def test_compile_nat_source_rules_switches(coordinator, make_config_entry): + """Test compilation of NAT source rule switches.""" + config_entry = make_config_entry({CONF_DEVICE_UNIQUE_ID: "dev1"}) + state = { + "firewall": { + "nat": { + "source_nat": { + "nat1": { + "uuid": "nat1", + "description": "Source NAT Rule", + "%interface": "wan", + "enabled": "1", + } + } + } + } + } + ents = await _compile_nat_source_rules_switches(config_entry, coordinator, state) + assert len(ents) == 1 + assert isinstance(ents[0], OPNsenseNATRuleSwitch) + assert ents[0].entity_description.key == "firewall.nat.source_nat.nat1" + + +@pytest.mark.asyncio +async def test_compile_nat_destination_rules_switches(coordinator, make_config_entry): + """Test compilation of NAT destination rule switches.""" + config_entry = make_config_entry({CONF_DEVICE_UNIQUE_ID: "dev1"}) + state = { + "firewall": { + "nat": { + "d_nat": { + "dnat1": { + "uuid": "dnat1", + "description": "Destination NAT Rule", + "%interface": "wan", + "enabled": "1", + } + } + } + } + } + ents = await _compile_nat_destination_rules_switches(config_entry, coordinator, state) + assert len(ents) == 1 + assert isinstance(ents[0], OPNsenseNATRuleSwitch) + assert ents[0].entity_description.key == "firewall.nat.d_nat.dnat1" + + +@pytest.mark.asyncio +async def test_compile_nat_one_to_one_rules_switches(coordinator, make_config_entry): + """Test compilation of NAT one-to-one rule switches.""" + config_entry = make_config_entry({CONF_DEVICE_UNIQUE_ID: "dev1"}) + state = { + "firewall": { + "nat": { + "one_to_one": { + "oto1": { + "uuid": "oto1", + "description": "One-to-One NAT Rule", + "%interface": "wan", + "enabled": "1", + } + } + } + } + } + ents = await _compile_nat_one_to_one_rules_switches(config_entry, coordinator, state) + assert len(ents) == 1 + assert isinstance(ents[0], OPNsenseNATRuleSwitch) + assert ents[0].entity_description.key == "firewall.nat.one_to_one.oto1" + + +@pytest.mark.asyncio +async def test_compile_nat_npt_rules_switches(coordinator, make_config_entry): + """Test compilation of NAT NPT rule switches.""" + config_entry = make_config_entry({CONF_DEVICE_UNIQUE_ID: "dev1"}) + state = { + "firewall": { + "nat": { + "npt": { + "npt1": { + "uuid": "npt1", + "description": "NPT NAT Rule", + "%interface": "wan", + "enabled": "1", + } + } + } + } + } + ents = await _compile_nat_npt_rules_switches(config_entry, coordinator, state) + assert len(ents) == 1 + assert isinstance(ents[0], OPNsenseNATRuleSwitch) + assert ents[0].entity_description.key == "firewall.nat.npt.npt1" + + +@pytest.mark.asyncio +async def test_compile_new_api_empty_state(coordinator, make_config_entry): + """Test compilation functions handle empty/missing state gracefully.""" + config_entry = make_config_entry({CONF_DEVICE_UNIQUE_ID: "dev1"}) + + # Test empty state + ents = await _compile_firewall_rules_switches(config_entry, coordinator, {}) + assert ents == [] + + ents = await _compile_nat_source_rules_switches(config_entry, coordinator, {}) + assert ents == [] + + # Test state with firewall but no rules + state = {"firewall": {}} + ents = await _compile_firewall_rules_switches(config_entry, coordinator, state) + assert ents == [] + + ents = await _compile_nat_source_rules_switches(config_entry, coordinator, state) + assert ents == []