diff --git a/.vscode/settings.json b/.vscode/settings.json index e09242343..edb237f64 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -4,5 +4,6 @@ ], "python.testing.pytestEnabled": true, "editor.formatOnSave": true, - "python.testing.unittestEnabled": false -} + "python.testing.unittestEnabled": false, + "debugpy.debugJustMyCode": false, +} \ No newline at end of file diff --git a/examples/server_config.json b/examples/server_config.json new file mode 100644 index 000000000..c3bf37459 --- /dev/null +++ b/examples/server_config.json @@ -0,0 +1,66 @@ +{ + "ws_server_config": { + "host": "localhost", + "port": 8001, + "network_auto_start": false + }, + "ws_client_config": { + "host": "localhost", + "port": 8001, + "aiohttp_session": null + }, + "zha_config": { + "coordinator_configuration": { + "path": "/dev/cu.wchusbserial971207DO", + "baudrate": 115200, + "flow_control": "hardware", + "radio_type": "ezsp" + }, + "quirks_configuration": { + "enabled": true, + "custom_quirks_path": "/Users/davidmulcahey/.homeassistant/quirks" + }, + "device_overrides": {}, + "light_options": { + "default_light_transition": 0.0, + "enable_enhanced_light_transition": false, + "enable_light_transitioning_flag": true, + "always_prefer_xy_color_mode": true, + "group_members_assume_state": true + }, + "device_options": { + "enable_identify_on_join": true, + "consider_unavailable_mains": 7200, + "consider_unavailable_battery": 21600, + "enable_mains_startup_polling": true + }, + "alarm_control_panel_options": { + "master_code": "1234", + "failed_tries": 3, + "arm_requires_code": false + } + }, + "zigpy_config": { + "startup_energy_scan": false, + "handle_unknown_devices": true, + "source_routing": true, + "max_concurrent_requests": 128, + "ezsp_config": { + "CONFIG_PACKET_BUFFER_COUNT": 255, + "CONFIG_MTORR_FLOW_CONTROL": 1, + "CONFIG_KEY_TABLE_SIZE": 12, + "CONFIG_ROUTE_TABLE_SIZE": 200 + }, + "ota": { + "otau_directory": "/Users/davidmulcahey/.homeassistant/zigpy_ota", + "inovelli_provider": false, + "thirdreality_provider": true + }, + "database_path": "/Users/davidmulcahey/.homeassistant/zigbee.db", + "device": { + "baudrate": 115200, + "flow_control": "hardware", + "path": "/dev/cu.wchusbserial971207DO" + } + } +} diff --git a/pyproject.toml b/pyproject.toml index 57155496d..c61358fe2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,9 @@ dependencies = [ "zha-quirks==0.0.124", "pyserial==3.5", "pyserial-asyncio-fast", + "pydantic==2.9.2", + "websockets<14.0", + "aiohttp" ] [tool.setuptools.packages.find] @@ -231,4 +234,7 @@ show_missing = true exclude_also = [ "if TYPE_CHECKING:", "raise NotImplementedError", +] +omit =[ + "*/__main__.py", ] \ No newline at end of file diff --git a/script/run_websocket_server b/script/run_websocket_server new file mode 100755 index 000000000..746f25605 --- /dev/null +++ b/script/run_websocket_server @@ -0,0 +1,10 @@ +#!/usr/bin/env bash + +# Stop on errors +set -e + +cd "$(dirname "$0")/.." + +source venv/bin/activate + +python -m zha.websocket.server --config=./examples/server_config.json \ No newline at end of file diff --git a/tests/common.py b/tests/common.py index bff7c862e..503f0222a 100644 --- a/tests/common.py +++ b/tests/common.py @@ -238,7 +238,10 @@ def get_entity( if not isinstance(entity, entity_type): continue - if exact_entity_type is not None and type(entity) is not exact_entity_type: + if ( + exact_entity_type is not None + and entity.info_object.class_name != exact_entity_type.__name__ + ): continue if qualifier is not None and qualifier not in entity.info_object.unique_id: @@ -262,45 +265,45 @@ async def group_entity_availability_test( assert entity.state["available"] is True device_1.on_network = False - await asyncio.sleep(0.1) + await asyncio.sleep(1) await zha_gateway.async_block_till_done() assert entity.state["available"] is True device_2.on_network = False - await asyncio.sleep(0.1) + await asyncio.sleep(1) await zha_gateway.async_block_till_done() assert entity.state["available"] is False device_1.on_network = True - await asyncio.sleep(0.1) + await asyncio.sleep(1) await zha_gateway.async_block_till_done() assert entity.state["available"] is True device_2.on_network = True - await asyncio.sleep(0.1) + await asyncio.sleep(1) await zha_gateway.async_block_till_done() assert entity.state["available"] is True - device_1.available = False - await asyncio.sleep(0.1) + device_1.update_available(available=False, on_network=device_1.on_network) + await asyncio.sleep(1) await zha_gateway.async_block_till_done() assert entity.state["available"] is True - device_2.available = False - await asyncio.sleep(0.1) + device_2.update_available(available=False, on_network=device_2.on_network) + await asyncio.sleep(1) await zha_gateway.async_block_till_done() assert entity.state["available"] is False - device_1.available = True - await asyncio.sleep(0.1) + device_1.update_available(available=True, on_network=device_1.on_network) + await asyncio.sleep(1) await zha_gateway.async_block_till_done() assert entity.state["available"] is True - device_2.available = True - await asyncio.sleep(0.1) + device_2.update_available(available=True, on_network=device_2.on_network) + await asyncio.sleep(1) await zha_gateway.async_block_till_done() assert entity.state["available"] is True @@ -504,6 +507,9 @@ def create_mock_zigpy_device( descriptor_capability_field=zdo_t.NodeDescriptor.DescriptorCapability.NONE, ) + if isinstance(node_descriptor, bytes): + node_descriptor = zdo_t.NodeDescriptor.deserialize(node_descriptor)[0] + device.node_desc = node_descriptor device.last_seen = time.time() diff --git a/tests/conftest.py b/tests/conftest.py index e2c45bb17..7fd3190df 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,7 @@ """Test configuration for the ZHA component.""" import asyncio -from collections.abc import Callable, Generator +from collections.abc import AsyncGenerator, Callable, Generator from contextlib import contextmanager import logging import os @@ -10,6 +10,7 @@ from types import TracebackType from unittest.mock import AsyncMock, MagicMock, patch +import aiohttp.test_utils import pytest import zigpy from zigpy.application import ControllerApplication @@ -23,15 +24,23 @@ import zigpy.zdo.types as zdo_t from zha.application import Platform -from zha.application.gateway import Gateway -from zha.application.helpers import ( +from zha.application.gateway import ( + Gateway, + WebSocketClientGateway, + WebSocketServerGateway, +) +from zha.application.model import ( AlarmControlPanelOptions, CoordinatorConfiguration, LightOptions, + WebsocketClientConfiguration, + WebsocketServerConfiguration, ZHAConfiguration, ZHAData, ) -from zha.async_ import ZHAJob +from zha.async_ import ZHAJob, cancelling +from zha.zigbee.group import WebSocketClientGroup +from zha.zigbee.model import GroupMemberReference FIXTURE_GRP_ID = 0x1001 FIXTURE_GRP_NAME = "fixture group" @@ -216,7 +225,7 @@ async def zigpy_app_controller_fixture(): app.groups.add_group(FIXTURE_GRP_ID, FIXTURE_GRP_NAME, suppress_event=True) - app.state.node_info.nwk = 0x0000 + app.state.node_info.nwk = zigpy.types.NWK(0x0000) app.state.node_info.ieee = zigpy.types.EUI64.convert("00:15:8d:00:02:32:4f:32") app.state.network_info.pan_id = 0x1234 app.state.network_info.extended_pan_id = app.state.node_info.ieee @@ -229,7 +238,21 @@ async def zigpy_app_controller_fixture(): # Create a fake coordinator device dev = app.add_device(nwk=app.state.node_info.nwk, ieee=app.state.node_info.ieee) - dev.node_desc = zdo_t.NodeDescriptor() + dev.node_desc = zdo_t.NodeDescriptor( + logical_type=zdo_t.LogicalType.Coordinator, + complex_descriptor_available=0, + user_descriptor_available=0, + reserved=0, + aps_flags=0, + frequency_band=zdo_t.NodeDescriptor.FrequencyBand.Freq2400MHz, + mac_capability_flags=zdo_t.NodeDescriptor.MACCapabilityFlags.AllocateAddress, + manufacturer_code=0x1234, + maximum_buffer_size=127, + maximum_incoming_transfer_size=100, + server_mask=10752, + maximum_outgoing_transfer_size=100, + descriptor_capability_field=zdo_t.NodeDescriptor.DescriptorCapability.NONE, + ) dev.node_desc.logical_type = zdo_t.LogicalType.Coordinator dev.manufacturer = "Coordinator Manufacturer" dev.model = "Coordinator Model" @@ -252,7 +275,7 @@ def caplog_fixture(caplog: pytest.LogCaptureFixture) -> pytest.LogCaptureFixture @pytest.fixture(name="zha_data") def zha_data_fixture() -> ZHAData: """Fixture representing zha configuration data.""" - + port = aiohttp.test_utils.unused_port() return ZHAData( config=ZHAConfiguration( coordinator_configuration=CoordinatorConfiguration( @@ -268,7 +291,15 @@ def zha_data_fixture() -> ZHAData: master_code="4321", failed_tries=2, ), - ) + ), + ws_server_config=WebsocketServerConfiguration( + host="localhost", + port=port, + network_auto_start=False, + ), + ws_client_config=WebsocketClientConfiguration( + host="localhost", port=port, aiohttp_session=None + ), ) @@ -298,14 +329,124 @@ async def __aexit__( await asyncio.sleep(0) +class CombinedWebsocketGateways: + """Combine multiple gateways into a single one.""" + + def __init__( + self, + zha_data: ZHAData, + ws_gateway: WebSocketServerGateway, + client_gateway: WebSocketClientGateway, + ): + """Initialize the CombinedWebsocketGateways class.""" + self.zha_data = zha_data + self.ws_gateway: WebSocketServerGateway = ws_gateway + self.client_gateway: WebSocketClientGateway = client_gateway + + @property + def application_controller(self) -> ControllerApplication: + """Return the Zigpy application controller.""" + return self.ws_gateway.application_controller + + @property + def config(self) -> ZHAData: + """Return the ZHA configuration.""" + return self.ws_gateway.config + + async def async_block_till_done(self, wait_background_tasks=False) -> None: + """Block until all gateways are done.""" + await asyncio.sleep(0.005) + await self.ws_gateway.async_block_till_done( + wait_background_tasks=wait_background_tasks + ) + await asyncio.sleep(0.001) + if self.client_gateway._tasks: + current_task = asyncio.current_task() + while tasks := [ + task + for task in self.client_gateway._tasks + if task is not current_task and not cancelling(task) + ]: + await self.ws_gateway._await_and_log_pending_tasks(tasks) + + async def async_device_initialized(self, device: zigpy.device.Device) -> None: + """Handle device joined and basic information discovered (async).""" + await self.ws_gateway.async_device_initialized(device) + + def get_device(self, ieee: zigpy.types.EUI64): + """Return Device for given ieee.""" + return self.client_gateway.get_device(ieee) + + async def async_remove_zigpy_group(self, group_id: int) -> None: + """Remove a Zigbee group from Zigpy.""" + await self.client_gateway.async_remove_zigpy_group(group_id) + + async def async_create_zigpy_group( + self, + name: str, + members: list[GroupMemberReference] | None, + group_id: int | None = None, + ) -> WebSocketClientGroup | None: + """Create a new Zigpy Zigbee group.""" + return await self.client_gateway.async_create_zigpy_group( + name, members, group_id + ) + + +class CombinedWebsocketGatewaysContextManager: + """Combine multiple gateways into a single one.""" + + def __init__( + self, + zha_data: ZHAData, + ): + """Initialize the CombinedWebsocketGateways class.""" + self.zha_data = zha_data + self.combined_gateways: CombinedWebsocketGateways + + async def __aenter__(self) -> CombinedWebsocketGateways: + """Start the ZHA gateway.""" + ws_gateway = await WebSocketServerGateway.async_from_config(self.zha_data) + await ws_gateway.start_server() + await ws_gateway.async_initialize() + await ws_gateway.async_block_till_done() + await ws_gateway.async_initialize_devices_and_entities() + await ws_gateway.async_block_till_done(wait_background_tasks=True) + + client_gateway = WebSocketClientGateway(self.zha_data) + await client_gateway.connect() + await client_gateway.clients.listen() + await ws_gateway.async_block_till_done() + await client_gateway.async_initialize() + assert client_gateway.state is not None + + self.combined_gateways = CombinedWebsocketGateways( + self.zha_data, ws_gateway, client_gateway + ) + INSTANCES.append(self.combined_gateways) + + return self.combined_gateways + + async def __aexit__( + self, exc_type: Exception, exc_value: str, traceback: TracebackType + ) -> None: + """Shutdown the ZHA gateway.""" + + await self.combined_gateways.client_gateway.disconnect() + await self.combined_gateways.ws_gateway.async_block_till_done() + await self.combined_gateways.ws_gateway.shutdown() + await asyncio.sleep(0) + INSTANCES.remove(self.combined_gateways) + + @pytest.fixture async def zha_gateway( zha_data: ZHAData, zigpy_app_controller, + request, caplog, # pylint: disable=unused-argument -): +) -> AsyncGenerator[Gateway | CombinedWebsocketGateways, None]: """Set up ZHA component.""" - with ( patch( "bellows.zigbee.application.ControllerApplication.new", @@ -316,8 +457,12 @@ async def zha_gateway( return_value=zigpy_app_controller, ), ): - async with TestGateway(zha_data) as gateway: - yield gateway + if hasattr(request, "param") and request.param == "ws_gateways": + async with CombinedWebsocketGatewaysContextManager(zha_data) as gateway: + yield gateway + else: + async with TestGateway(zha_data) as gateway: + yield gateway @pytest.fixture(scope="session", autouse=True) @@ -367,8 +512,11 @@ def cluster_handler_factory( return cluster_handler_factory +# https://github.com/nolar/looptime arg docs are here def pytest_collection_modifyitems(config, items): """Add the looptime marker to all tests except the test_async.py file.""" for item in items: if "test_async_.py" not in item.nodeid: - item.add_marker(pytest.mark.looptime) + item.add_marker( + pytest.mark.looptime.with_args(noop_cycles=100, idle_step=0.000001) + ) diff --git a/tests/data/serialization_data/centralite-3320-l-extended-device-info.json b/tests/data/serialization_data/centralite-3320-l-extended-device-info.json new file mode 100644 index 000000000..039ac99fa --- /dev/null +++ b/tests/data/serialization_data/centralite-3320-l-extended-device-info.json @@ -0,0 +1 @@ +{"ieee":"00:0d:6f:00:0f:3a:e3:69","nwk":"0x970A","manufacturer":"CentraLite","model":"3320-L","name":"CentraLite 3320-L","quirk_applied":true,"quirk_class":"zhaquirks.centralite.ias.CentraLiteIASSensor","quirk_id":null,"manufacturer_code":49887,"power_source":"Battery or Unknown","lqi":null,"rssi":null,"available":true,"on_network":true,"is_groupable":false,"device_type":"EndDevice","signature":{"node_descriptor":{"logical_type":2,"complex_descriptor_available":0,"user_descriptor_available":0,"reserved":0,"aps_flags":0,"frequency_band":8,"mac_capability_flags":128,"manufacturer_code":49887,"maximum_buffer_size":82,"maximum_incoming_transfer_size":82,"server_mask":0,"maximum_outgoing_transfer_size":82,"descriptor_capability_field":0},"endpoints":{"1":{"profile_id":"0x0104","device_type":"0x0402","input_clusters":["0x0000","0x0001","0x0003","0x0020","0x0402","0x0500","0x0b05"],"output_clusters":["0x0019"]},"2":{"profile_id":"0xc2df","device_type":"0x000c","input_clusters":["0x0000","0x0003","0x0b05","0xfc0f"],"output_clusters":["0x0003"]}},"manufacturer":"CentraLite","model":"3320-L"},"sw_version":null,"active_coordinator":false,"entities":{"binary_sensor,00:0d:6f:00:0f:3a:e3:69-1-1280":{"platform":"binary_sensor","unique_id":"00:0d:6f:00:0f:3a:e3:69-1-1280","class_name":"IASZone","translation_key":null,"device_class":"opening","state_class":null,"entity_category":null,"entity_registry_enabled_default":true,"enabled":true,"fallback_name":null,"state":{"available":true,"state":false,"class_name":"IASZone","model_class_name":"EntityState"},"cluster_handlers":[{"class_name":"IASZoneClusterHandler","generic_id":"cluster_handler_0x0500","endpoint_id":1,"cluster":{"id":1280,"name":"IAS Zone","type":"server","endpoint_id":1,"endpoint_attribute":"ias_zone"},"id":"1:0x0500","unique_id":"00:0d:6f:00:0f:3a:e3:69:1:0x0500","status":"initialized","value_attribute":null}],"device_ieee":"00:0d:6f:00:0f:3a:e3:69","endpoint_id":1,"available":true,"group_id":null,"attribute_name":"zone_status","model_class_name":"BinarySensorEntityInfo"},"button,00:0d:6f:00:0f:3a:e3:69-1-3":{"platform":"button","unique_id":"00:0d:6f:00:0f:3a:e3:69-1-3","class_name":"IdentifyButton","translation_key":null,"device_class":"identify","state_class":null,"entity_category":"diagnostic","entity_registry_enabled_default":true,"enabled":true,"fallback_name":null,"state":{"available":true,"state":null,"class_name":"IdentifyButton","model_class_name":"EntityState"},"cluster_handlers":[{"class_name":"IdentifyClusterHandler","generic_id":"cluster_handler_0x0003","endpoint_id":1,"cluster":{"id":3,"name":"Identify","type":"server","endpoint_id":1,"endpoint_attribute":"identify"},"id":"1:0x0003","unique_id":"00:0d:6f:00:0f:3a:e3:69:1:0x0003","status":"initialized","value_attribute":null}],"device_ieee":"00:0d:6f:00:0f:3a:e3:69","endpoint_id":1,"available":true,"group_id":null,"command":"identify","args":[5],"kwargs":{},"model_class_name":"CommandButtonEntityInfo"},"sensor,00:0d:6f:00:0f:3a:e3:69-1-1":{"platform":"sensor","unique_id":"00:0d:6f:00:0f:3a:e3:69-1-1","class_name":"Battery","translation_key":null,"device_class":"battery","state_class":"measurement","entity_category":"diagnostic","entity_registry_enabled_default":true,"enabled":true,"fallback_name":null,"state":{"state":100,"battery_size":"Other","battery_quantity":1,"battery_voltage":2.8,"available":true,"class_name":"Battery","model_class_name":"BatteryState"},"cluster_handlers":[{"class_name":"PowerConfigurationClusterHandler","generic_id":"cluster_handler_0x0001","endpoint_id":1,"cluster":{"id":1,"name":"Power Configuration","type":"server","endpoint_id":1,"endpoint_attribute":"power"},"id":"1:0x0001","unique_id":"00:0d:6f:00:0f:3a:e3:69:1:0x0001","status":"initialized","value_attribute":"battery_voltage"}],"device_ieee":"00:0d:6f:00:0f:3a:e3:69","endpoint_id":1,"available":true,"group_id":null,"attribute":"battery_percentage_remaining","decimals":1,"divisor":1,"multiplier":1,"unit":"%","extra_state_attribute_names":["battery_voltage","battery_size","battery_quantity"],"model_class_name":"BatteryEntityInfo"},"sensor,00:0d:6f:00:0f:3a:e3:69-1-1026":{"platform":"sensor","unique_id":"00:0d:6f:00:0f:3a:e3:69-1-1026","class_name":"Temperature","translation_key":null,"device_class":"temperature","state_class":"measurement","entity_category":null,"entity_registry_enabled_default":true,"enabled":true,"fallback_name":null,"state":{"available":true,"state":20.2,"class_name":"Temperature","model_class_name":"EntityState"},"cluster_handlers":[{"class_name":"TemperatureMeasurementClusterHandler","generic_id":"cluster_handler_0x0402","endpoint_id":1,"cluster":{"id":1026,"name":"Temperature Measurement","type":"server","endpoint_id":1,"endpoint_attribute":"temperature"},"id":"1:0x0402","unique_id":"00:0d:6f:00:0f:3a:e3:69:1:0x0402","status":"initialized","value_attribute":"measured_value"}],"device_ieee":"00:0d:6f:00:0f:3a:e3:69","endpoint_id":1,"available":true,"group_id":null,"attribute":"measured_value","decimals":1,"divisor":100,"multiplier":1,"unit":"°C","extra_state_attribute_names":null,"entity_desctiption":null,"model_class_name":"SensorEntityInfo"},"sensor,00:0d:6f:00:0f:3a:e3:69-1-0-rssi":{"platform":"sensor","unique_id":"00:0d:6f:00:0f:3a:e3:69-1-0-rssi","class_name":"RSSISensor","translation_key":"rssi","device_class":"signal_strength","state_class":"measurement","entity_category":"diagnostic","entity_registry_enabled_default":false,"enabled":true,"fallback_name":null,"state":{"available":true,"state":null,"class_name":"RSSISensor","model_class_name":"EntityState"},"cluster_handlers":[{"class_name":"BasicClusterHandler","generic_id":"cluster_handler_0x0000","endpoint_id":1,"cluster":{"id":0,"name":"Basic","type":"server","endpoint_id":1,"endpoint_attribute":"basic"},"id":"1:0x0000","unique_id":"00:0d:6f:00:0f:3a:e3:69:1:0x0000","status":"initialized","value_attribute":null}],"device_ieee":"00:0d:6f:00:0f:3a:e3:69","endpoint_id":1,"available":true,"group_id":null,"attribute":null,"decimals":1,"divisor":1,"multiplier":1,"unit":"dBm","extra_state_attribute_names":null,"entity_desctiption":null,"model_class_name":"SensorEntityInfo"},"sensor,00:0d:6f:00:0f:3a:e3:69-1-0-lqi":{"platform":"sensor","unique_id":"00:0d:6f:00:0f:3a:e3:69-1-0-lqi","class_name":"LQISensor","translation_key":"lqi","device_class":null,"state_class":"measurement","entity_category":"diagnostic","entity_registry_enabled_default":false,"enabled":true,"fallback_name":null,"state":{"available":true,"state":null,"class_name":"LQISensor","model_class_name":"EntityState"},"cluster_handlers":[{"class_name":"BasicClusterHandler","generic_id":"cluster_handler_0x0000","endpoint_id":1,"cluster":{"id":0,"name":"Basic","type":"server","endpoint_id":1,"endpoint_attribute":"basic"},"id":"1:0x0000","unique_id":"00:0d:6f:00:0f:3a:e3:69:1:0x0000","status":"initialized","value_attribute":null}],"device_ieee":"00:0d:6f:00:0f:3a:e3:69","endpoint_id":1,"available":true,"group_id":null,"attribute":null,"decimals":1,"divisor":1,"multiplier":1,"unit":null,"extra_state_attribute_names":null,"entity_desctiption":null,"model_class_name":"SensorEntityInfo"},"update,00:0d:6f:00:0f:3a:e3:69-1-25-firmware_update":{"platform":"update","unique_id":"00:0d:6f:00:0f:3a:e3:69-1-25-firmware_update","class_name":"FirmwareUpdateEntity","translation_key":null,"device_class":"firmware","state_class":null,"entity_category":"config","entity_registry_enabled_default":true,"enabled":true,"fallback_name":null,"state":{"available":true,"installed_version":null,"in_progress":false,"progress":null,"latest_version":null,"release_summary":null,"release_notes":null,"release_url":null,"class_name":"FirmwareUpdateEntity","update_percentage":null,"model_class_name":"FirmwareUpdateState"},"cluster_handlers":[{"class_name":"OtaClientClusterHandler","generic_id":"cluster_handler_0x0019","endpoint_id":1,"cluster":{"id":25,"name":"Ota","type":"client","endpoint_id":1,"endpoint_attribute":"ota"},"id":"1:0x0019","unique_id":"00:0d:6f:00:0f:3a:e3:69:1:0x0019","status":"initialized","value_attribute":null}],"device_ieee":"00:0d:6f:00:0f:3a:e3:69","endpoint_id":1,"available":true,"group_id":null,"supported_features":7,"model_class_name":"FirmwareUpdateEntityInfo"}},"neighbors":[],"routes":[],"endpoint_names":[{"name":"IAS_ZONE"},{"name":"unknown 12 device_type of 0xc2df profile id"}],"device_automation_triggers":{"device_offline,device_offline":{"device_event_type":"device_offline"}}} \ No newline at end of file diff --git a/tests/test_alarm_control_panel.py b/tests/test_alarm_control_panel.py index 5ca55aa47..89c3000e2 100644 --- a/tests/test_alarm_control_panel.py +++ b/tests/test_alarm_control_panel.py @@ -20,8 +20,15 @@ ) from zha.application import Platform from zha.application.gateway import Gateway -from zha.application.platforms.alarm_control_panel import AlarmControlPanel -from zha.application.platforms.alarm_control_panel.const import AlarmState +from zha.application.platforms.alarm_control_panel import ( + AlarmControlPanel, + WebSocketClientAlarmControlPanel, +) +from zha.application.platforms.alarm_control_panel.const import ( + AlarmControlPanelEntityFeature, + AlarmState, + CodeFormat, +) from zha.zigbee.device import Device _LOGGER = logging.getLogger(__name__) @@ -37,6 +44,14 @@ } +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) @patch( "zigpy.zcl.clusters.security.IasAce.client_command", new=AsyncMock(return_value=[sentinel.data, zcl_f.Status.SUCCESS]), @@ -75,7 +90,21 @@ async def test_alarm_control_panel( (Platform.ALARM_CONTROL_PANEL, "00:0d:6f:00:0a:90:69:e7-1") ) assert alarm_entity is not None - assert isinstance(alarm_entity, AlarmControlPanel) + assert isinstance( + alarm_entity, + AlarmControlPanel + if not hasattr(zha_gateway, "ws_gateway") + else WebSocketClientAlarmControlPanel, + ) + + assert alarm_entity.code_format == CodeFormat.NUMBER + assert alarm_entity.code_arm_required is False + assert alarm_entity.supported_features == ( + AlarmControlPanelEntityFeature.ARM_HOME + | AlarmControlPanelEntityFeature.ARM_AWAY + | AlarmControlPanelEntityFeature.ARM_NIGHT + | AlarmControlPanelEntityFeature.TRIGGER + ) # test that the state is STATE_ALARM_DISARMED assert alarm_entity.state["state"] == AlarmState.DISARMED @@ -248,7 +277,12 @@ async def test_alarm_control_panel( await reset_alarm_panel(zha_gateway, cluster, alarm_entity) assert alarm_entity.state["state"] == AlarmState.DISARMED - alarm_entity._cluster_handler.code_required_arm_actions = True + if isinstance(alarm_entity, WebSocketClientAlarmControlPanel): + zha_gateway.ws_gateway.devices[zha_device.ieee].platform_entities[ + (alarm_entity.PLATFORM, alarm_entity.unique_id) + ]._cluster_handler.code_required_arm_actions = True + else: + alarm_entity._cluster_handler.code_required_arm_actions = True await alarm_entity.async_alarm_arm_away() await zha_gateway.async_block_till_done() assert alarm_entity.state["state"] == AlarmState.DISARMED diff --git a/tests/test_binary_sensor.py b/tests/test_binary_sensor.py index 5f45d5b66..39638c7f1 100644 --- a/tests/test_binary_sensor.py +++ b/tests/test_binary_sensor.py @@ -22,7 +22,12 @@ from zha.application import Platform from zha.application.gateway import Gateway from zha.application.platforms import PlatformEntity -from zha.application.platforms.binary_sensor import Accelerometer, IASZone, Occupancy +from zha.application.platforms.binary_sensor import ( + Accelerometer, + IASZone, + Occupancy, + WebSocketClientBinarySensor, +) from zha.zigbee.cluster_handlers.const import SMARTTHINGS_ACCELERATION_CLUSTER DEVICE_IAS = { @@ -128,6 +133,14 @@ async def async_test_iaszone_on_off( assert entity.is_on +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) @pytest.mark.parametrize( "device, on_off_test, cluster_name, entity_type, plugs", [ @@ -161,7 +174,12 @@ async def test_binary_sensor( entity: PlatformEntity = find_entity(zha_device, Platform.BINARY_SENSOR) assert entity is not None - assert isinstance(entity, entity_type) + assert isinstance( + entity, + entity_type + if not hasattr(zha_gateway, "ws_gateway") + else WebSocketClientBinarySensor, + ) assert entity.PLATFORM == Platform.BINARY_SENSOR assert entity.is_on is False @@ -170,24 +188,42 @@ async def test_binary_sensor( await on_off_test(zha_gateway, cluster, entity, plugs) -async def test_smarttthings_multi( - zha_gateway: Gateway, -) -> None: +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) +async def test_smarttthings_multi(zha_gateway: Gateway) -> None: """Test smartthings multi.""" zigpy_device = create_mock_zigpy_device( zha_gateway, DEVICE_SMARTTHINGS_MULTI, manufacturer="Samjin", model="multi" ) zha_device = await join_zigpy_device(zha_gateway, zigpy_device) + entity_type = ( + Accelerometer + if not hasattr(zha_gateway, "ws_gateway") + else WebSocketClientBinarySensor + ) entity: PlatformEntity = get_entity( - zha_device, Platform.BINARY_SENSOR, entity_type=Accelerometer + zha_device, Platform.BINARY_SENSOR, entity_type=entity_type ) assert entity is not None - assert isinstance(entity, Accelerometer) + assert isinstance(entity, entity_type) assert entity.PLATFORM == Platform.BINARY_SENSOR assert entity.is_on is False - st_ch = zha_device.endpoints[1].all_cluster_handlers["1:0xfc02"] + if isinstance(entity, WebSocketClientBinarySensor): + st_ch = ( + zha_gateway.ws_gateway.devices[zha_device.ieee] + .endpoints[1] + .all_cluster_handlers["1:0xfc02"] + ) + else: + st_ch = zha_device.endpoints[1].all_cluster_handlers["1:0xfc02"] assert st_ch is not None st_ch.emit_zha_event = MagicMock(wraps=st_ch.emit_zha_event) diff --git a/tests/test_button.py b/tests/test_button.py index cd62a0f67..01dd9e089 100644 --- a/tests/test_button.py +++ b/tests/test_button.py @@ -35,24 +35,15 @@ from zha.application import Platform from zha.application.gateway import Gateway from zha.application.platforms import EntityCategory, PlatformEntity -from zha.application.platforms.button import Button, WriteAttributeButton +from zha.application.platforms.button import ( + Button, + WebSocketClientButtonEntity, + WriteAttributeButton, +) from zha.application.platforms.button.const import ButtonDeviceClass from zha.exceptions import ZHAException from zha.zigbee.device import Device -ZIGPY_DEVICE = { - 1: { - SIG_EP_INPUT: [ - general.Basic.cluster_id, - general.Identify.cluster_id, - security.IasZone.cluster_id, - ], - SIG_EP_OUTPUT: [], - SIG_EP_TYPE: zha.DeviceType.IAS_ZONE, - SIG_EP_PROFILE: zha.PROFILE_ID, - } -} - class FrostLockQuirk(CustomDevice): """Quirk with frost lock attribute.""" @@ -77,38 +68,50 @@ class TuyaManufCluster(CustomCluster, ManufacturerSpecificCluster): } -TUYA_WATER_VALVE = { - 1: { - PROFILE_ID: zha.PROFILE_ID, - DEVICE_TYPE: zha.DeviceType.ON_OFF_SWITCH, - INPUT_CLUSTERS: [ - general.Basic.cluster_id, - general.Identify.cluster_id, - general.Groups.cluster_id, - general.Scenes.cluster_id, - general.OnOff.cluster_id, - ParksideTuyaValveManufCluster.cluster_id, - ], - OUTPUT_CLUSTERS: [general.Time.cluster_id, general.Ota.cluster_id], - }, -} - - +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_button( zha_gateway: Gateway, ) -> None: """Test zha button platform.""" + zigpy_device = create_mock_zigpy_device( zha_gateway, - ZIGPY_DEVICE, + { + 1: { + SIG_EP_INPUT: [ + general.Basic.cluster_id, + general.Identify.cluster_id, + security.IasZone.cluster_id, + ], + SIG_EP_OUTPUT: [], + SIG_EP_TYPE: zha.DeviceType.IAS_ZONE, + SIG_EP_PROFILE: zha.PROFILE_ID, + } + }, ) + zha_device: Device = await join_zigpy_device(zha_gateway, zigpy_device) cluster = zigpy_device.endpoints[1].identify assert cluster is not None entity: PlatformEntity = get_entity(zha_device, Platform.BUTTON) - assert isinstance(entity, Button) + assert isinstance( + entity, + Button + if not hasattr(zha_gateway, "ws_gateway") + else WebSocketClientButtonEntity, + ) assert entity.PLATFORM == Platform.BUTTON + assert entity.args == [5] + assert entity.kwargs == {} + with patch( "zigpy.zcl.Cluster.request", return_value=mock_coro([0x00, zcl_f.Status.SUCCESS]), @@ -121,23 +124,55 @@ async def test_button( assert cluster.request.call_args[0][3] == 5 # duration in seconds +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_frost_unlock( zha_gateway: Gateway, ) -> None: """Test custom frost unlock ZHA button.""" + zigpy_device = create_mock_zigpy_device( zha_gateway, - TUYA_WATER_VALVE, + { + 1: { + PROFILE_ID: zha.PROFILE_ID, + DEVICE_TYPE: zha.DeviceType.ON_OFF_SWITCH, + INPUT_CLUSTERS: [ + general.Basic.cluster_id, + general.Identify.cluster_id, + general.Groups.cluster_id, + general.Scenes.cluster_id, + general.OnOff.cluster_id, + ParksideTuyaValveManufCluster.cluster_id, + ], + OUTPUT_CLUSTERS: [general.Time.cluster_id, general.Ota.cluster_id], + }, + }, manufacturer="_TZE200_htnnfasr", model="TS0601", ) + zha_device = await join_zigpy_device(zha_gateway, zigpy_device) cluster = zigpy_device.endpoints[1].tuya_manufacturer assert cluster is not None + entity_type = ( + WriteAttributeButton + if not hasattr(zha_gateway, "ws_gateway") + else WebSocketClientButtonEntity + ) entity: PlatformEntity = get_entity( - zha_device, platform=Platform.BUTTON, entity_type=WriteAttributeButton + zha_device, + platform=Platform.BUTTON, + entity_type=entity_type, + qualifier="reset_frost_lock", ) - assert isinstance(entity, WriteAttributeButton) + assert isinstance(entity, entity_type) assert entity._attr_device_class == ButtonDeviceClass.RESTART assert entity._attr_entity_category == EntityCategory.CONFIG @@ -204,9 +239,18 @@ class ServerCommandDefs(zcl_f.BaseCommandDefs): ) -async def custom_button_device(zha_gateway: Gateway): - """Button device fixture for quirks button tests.""" - +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) +async def test_quirks_command_button( + zha_gateway: Gateway, +) -> None: + """Test ZHA button platform.""" zigpy_device = create_mock_zigpy_device( zha_gateway, { @@ -229,14 +273,7 @@ async def custom_button_device(zha_gateway: Gateway): } update_attribute_cache(zigpy_device.endpoints[1].mfg_identify) zha_device = await join_zigpy_device(zha_gateway, zigpy_device) - return zha_device, zigpy_device.endpoints[1].mfg_identify - - -async def test_quirks_command_button( - zha_gateway: Gateway, -) -> None: - """Test ZHA button platform.""" - zha_device, cluster = await custom_button_device(zha_gateway) + cluster = zigpy_device.endpoints[1].mfg_identify assert cluster is not None entity: PlatformEntity = get_entity(zha_device, platform=Platform.BUTTON) @@ -252,14 +289,50 @@ async def test_quirks_command_button( assert cluster.request.call_args[0][3] == 5 # duration in seconds +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_quirks_write_attr_button( zha_gateway: Gateway, ) -> None: """Test ZHA button platform.""" - zha_device, cluster = await custom_button_device(zha_gateway) + + entity_type = ( + WriteAttributeButton + if not hasattr(zha_gateway, "ws_gateway") + else WebSocketClientButtonEntity + ) + zigpy_device = create_mock_zigpy_device( + zha_gateway, + { + 1: { + SIG_EP_INPUT: [ + general.Basic.cluster_id, + FakeManufacturerCluster.cluster_id, + ], + SIG_EP_OUTPUT: [], + SIG_EP_TYPE: zha.DeviceType.REMOTE_CONTROL, + SIG_EP_PROFILE: zha.PROFILE_ID, + } + }, + manufacturer="Fake_Model", + model="Fake_Manufacturer", + ) + + zigpy_device.endpoints[1].mfg_identify.PLUGGED_ATTR_READS = { + FakeManufacturerCluster.AttributeDefs.feed.name: 0, + } + update_attribute_cache(zigpy_device.endpoints[1].mfg_identify) + zha_device = await join_zigpy_device(zha_gateway, zigpy_device) + cluster = zigpy_device.endpoints[1].mfg_identify assert cluster is not None entity: PlatformEntity = get_entity( - zha_device, platform=Platform.BUTTON, entity_type=WriteAttributeButton + zha_device, platform=Platform.BUTTON, entity_type=entity_type, qualifier="feed" ) assert cluster.get(cluster.AttributeDefs.feed.name) == 0 diff --git a/tests/test_climate.py b/tests/test_climate.py index c8b859aee..a7eff54cc 100644 --- a/tests/test_climate.py +++ b/tests/test_climate.py @@ -3,6 +3,7 @@ # pylint: disable=redefined-outer-name,too-many-lines import asyncio +from collections.abc import Awaitable import logging from typing import Any from unittest.mock import AsyncMock, MagicMock, call, patch @@ -13,6 +14,7 @@ import zhaquirks.sinope.thermostat from zhaquirks.sinope.thermostat import SinopeTechnologiesThermostatCluster import zhaquirks.tuya.ts0601_trv +from zigpy.device import Device as ZigpyDevice import zigpy.profiles import zigpy.quirks import zigpy.zcl.clusters @@ -38,16 +40,22 @@ PRESET_TEMP_MANUAL, ) from zha.application.gateway import Gateway +from zha.application.platforms import WebSocketClientEntity from zha.application.platforms.climate import ( HVAC_MODE_2_SYSTEM, SEQ_OF_OPERATION, Thermostat as ThermostatEntity, ) -from zha.application.platforms.climate.const import FanState +from zha.application.platforms.climate.const import ( + ClimateEntityFeature, + FanState, + HVACMode, +) from zha.application.platforms.sensor import ( Sensor, SinopeHVACAction, ThermostatHVACAction, + WebSocketClientSensorEntity, ) from zha.const import STATE_CHANGED from zha.exceptions import ZHAException @@ -204,7 +212,7 @@ async def device_climate_mock( plug: dict[str, Any] | None = None, manuf: str | None = None, quirk: type[zigpy.quirks.CustomDevice] | None = None, -) -> Device: +) -> tuple[ZigpyDevice, Device]: """Test regular thermostat device.""" plugged_attrs = ZCL_ATTR_PLUG if plug is None else {**ZCL_ATTR_PLUG, **plug} @@ -214,7 +222,7 @@ async def device_climate_mock( zigpy_device.node_desc.mac_capability_flags |= 0b_0000_0100 zigpy_device.endpoints[1].thermostat.PLUGGED_ATTR_READS = plugged_attrs zha_device = await join_zigpy_device(zha_gateway, zigpy_device) - return zha_device + return zigpy_device, zha_device @patch.object( @@ -222,7 +230,7 @@ async def device_climate_mock( "ep_attribute", "sinope_manufacturer_specific", ) -async def device_climate_sinope(zha_gateway: Gateway): +async def device_climate_sinope(zha_gateway: Gateway) -> tuple[ZigpyDevice, Device]: """Sinope thermostat.""" return await device_climate_mock( @@ -242,31 +250,91 @@ def test_sequence_mappings(): assert Thermostat.SystemMode(HVAC_MODE_2_SYSTEM[hvac_mode]) is not None +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) +async def test_climate_entity_properties( + zha_gateway: Gateway, +) -> None: + """Test climate entity properties.""" + zigpy_device, device_climate = await device_climate_mock(zha_gateway, CLIMATE) + thrm_cluster = zigpy_device.endpoints[1].thermostat + entity: ThermostatEntity = get_entity(device_climate, platform=Platform.CLIMATE) + await send_attributes_report(zha_gateway, thrm_cluster, {0: 2100}) + + assert entity.current_temperature == 21.0 + assert entity.target_temperature is None + assert entity.target_temperature_low is None + assert entity.target_temperature_high is None + assert entity.outdoor_temperature is None + assert entity.min_temp == 7 + assert entity.max_temp == 39 + assert entity.hvac_mode == "off" + assert entity.hvac_action is None + assert entity.fan_mode == "auto" + assert entity.preset_mode == PRESET_NONE + assert ( + entity.supported_features + == ClimateEntityFeature.TARGET_TEMPERATURE + | ClimateEntityFeature.TARGET_TEMPERATURE_RANGE + | ClimateEntityFeature.TURN_OFF + | ClimateEntityFeature.TURN_ON + ) + assert entity.hvac_modes == [ + HVACMode.OFF, + HVACMode.HEAT_COOL, + HVACMode.COOL, + HVACMode.HEAT, + ] + assert entity.fan_modes is None + assert entity.preset_modes == [] + + +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_climate_local_temperature( zha_gateway: Gateway, ) -> None: """Test local temperature.""" - device_climate = await device_climate_mock(zha_gateway, CLIMATE) - thrm_cluster = device_climate.device.endpoints[1].thermostat - entity: ThermostatEntity = get_entity( - device_climate, platform=Platform.CLIMATE, entity_type=ThermostatEntity - ) + zigpy_device, device_climate = await device_climate_mock(zha_gateway, CLIMATE) + thrm_cluster = zigpy_device.endpoints[1].thermostat + entity: ThermostatEntity = get_entity(device_climate, platform=Platform.CLIMATE) assert entity.state["current_temperature"] is None + assert entity.current_temperature is None await send_attributes_report(zha_gateway, thrm_cluster, {0: 2100}) assert entity.state["current_temperature"] == 21.0 + assert entity.current_temperature == 21.0 +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_climate_outdoor_temperature( zha_gateway: Gateway, ) -> None: """Test outdoor temperature.""" - device_climate = await device_climate_mock(zha_gateway, CLIMATE) - thrm_cluster = device_climate.device.endpoints[1].thermostat - entity: ThermostatEntity = get_entity( - device_climate, platform=Platform.CLIMATE, entity_type=ThermostatEntity - ) + zigpy_device, device_climate = await device_climate_mock(zha_gateway, CLIMATE) + thrm_cluster = zigpy_device.endpoints[1].thermostat + entity: ThermostatEntity = get_entity(device_climate, platform=Platform.CLIMATE) assert entity.state["outdoor_temperature"] is None + assert entity.outdoor_temperature is None await send_attributes_report( zha_gateway, @@ -274,20 +342,28 @@ async def test_climate_outdoor_temperature( {Thermostat.AttributeDefs.outdoor_temperature.id: 2150}, ) assert entity.state["outdoor_temperature"] == 21.5 + assert entity.outdoor_temperature == 21.5 +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_climate_hvac_action_running_state( zha_gateway: Gateway, ): """Test hvac action via running state.""" - dev_climate_sinope = await device_climate_sinope(zha_gateway) - thrm_cluster = dev_climate_sinope.device.endpoints[1].thermostat - entity: ThermostatEntity = get_entity( - dev_climate_sinope, platform=Platform.CLIMATE, entity_type=ThermostatEntity - ) + zigpy_device, dev_climate_sinope = await device_climate_sinope(zha_gateway) + thrm_cluster = zigpy_device.endpoints[1].thermostat + + entity: ThermostatEntity = get_entity(dev_climate_sinope, platform=Platform.CLIMATE) sensor_entity: SinopeHVACAction = get_entity( - dev_climate_sinope, platform=Platform.SENSOR, entity_type=SinopeHVACAction + dev_climate_sinope, platform=Platform.SENSOR, qualifier="hvac_action" ) subscriber = MagicMock() @@ -295,67 +371,100 @@ async def test_climate_hvac_action_running_state( sensor_entity.on_event(STATE_CHANGED, subscriber) assert entity.state["hvac_action"] == "off" + assert entity.hvac_action == "off" assert sensor_entity.state["state"] == "off" + # the state isn't actually changing here... on the WS impl side we are getting + # the correct call count... we are getting the wrong call count on the normal impl + # TODO look into why this is the case... await send_attributes_report( zha_gateway, thrm_cluster, {0x001E: Thermostat.RunningMode.Off} ) assert entity.state["hvac_action"] == "off" + assert entity.hvac_action == "off" assert sensor_entity.state["state"] == "off" await send_attributes_report( zha_gateway, thrm_cluster, {0x001C: Thermostat.SystemMode.Auto} ) assert entity.state["hvac_action"] == "idle" + assert entity.hvac_action == "idle" assert sensor_entity.state["state"] == "idle" await send_attributes_report( zha_gateway, thrm_cluster, {0x001E: Thermostat.RunningMode.Cool} ) assert entity.state["hvac_action"] == "cooling" + assert entity.hvac_action == "cooling" assert sensor_entity.state["state"] == "cooling" await send_attributes_report( zha_gateway, thrm_cluster, {0x001E: Thermostat.RunningMode.Heat} ) assert entity.state["hvac_action"] == "heating" + assert entity.hvac_action == "heating" assert sensor_entity.state["state"] == "heating" await send_attributes_report( zha_gateway, thrm_cluster, {0x001E: Thermostat.RunningMode.Off} ) assert entity.state["hvac_action"] == "idle" + assert entity.hvac_action == "idle" assert sensor_entity.state["state"] == "idle" await send_attributes_report( zha_gateway, thrm_cluster, {0x0029: Thermostat.RunningState.Fan_State_On} ) assert entity.state["hvac_action"] == "fan" + assert entity.hvac_action == "fan" assert sensor_entity.state["state"] == "fan" # Both entities are updated! - assert len(subscriber.mock_calls) == 2 * 6 + assert ( + len(subscriber.mock_calls) == 2 * 6 + if not hasattr(zha_gateway, "ws_gateway") + else 2 * 5 + ) +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_sinope_time( zha_gateway: Gateway, ): """Test hvac action via running state.""" - dev_climate_sinope = await device_climate_sinope(zha_gateway) - mfg_cluster = dev_climate_sinope.device.endpoints[1].sinope_manufacturer_specific + zigpy_device, dev_climate_sinope = await device_climate_sinope(zha_gateway) + mfg_cluster = zigpy_device.endpoints[1].sinope_manufacturer_specific assert mfg_cluster is not None - entity: ThermostatEntity = get_entity( - dev_climate_sinope, platform=Platform.CLIMATE, entity_type=ThermostatEntity - ) + entity: ThermostatEntity = get_entity(dev_climate_sinope, platform=Platform.CLIMATE) - entity._async_update_time = AsyncMock(wraps=entity._async_update_time) + if isinstance(entity, WebSocketClientEntity): + server_entity = get_entity( + zha_gateway.ws_gateway.devices[dev_climate_sinope.ieee], + platform=Platform.CLIMATE, + ) + original_async_update_time: Awaitable = server_entity._async_update_time + server_entity._async_update_time = AsyncMock( + wraps=server_entity._async_update_time + ) + async_update_time_mock = server_entity._async_update_time + else: + original_async_update_time = entity._async_update_time + entity._async_update_time = AsyncMock(wraps=entity._async_update_time) + async_update_time_mock = entity._async_update_time await asyncio.sleep(4600) write_attributes = mfg_cluster.write_attributes - assert entity._async_update_time.await_count == 1 + assert async_update_time_mock.await_count == 1 assert write_attributes.await_count == 1 assert "secs_since_2k" in write_attributes.mock_calls[0].args[0] @@ -363,7 +472,7 @@ async def test_sinope_time( # Default time zone of UTC with freeze_time("2000-01-02 00:00:00"): - await entity._async_update_time() + await async_update_time_mock() secs_since_2k = write_attributes.mock_calls[0].args[0]["secs_since_2k"] assert secs_since_2k == pytest.approx(60 * 60 * 24) @@ -373,146 +482,200 @@ async def test_sinope_time( zha_gateway.config.local_timezone = zoneinfo.ZoneInfo("America/New_York") with freeze_time("2000-01-02 00:00:00"): - await entity._async_update_time() + await async_update_time_mock() secs_since_2k = write_attributes.mock_calls[0].args[0]["secs_since_2k"] assert secs_since_2k == pytest.approx(60 * 60 * 24 - 5 * 60 * 60) write_attributes.reset_mock() - entity._async_update_time.reset_mock() + async_update_time_mock.reset_mock() entity.disable() + await zha_gateway.async_block_till_done() assert entity.enabled is False await asyncio.sleep(4600) - assert entity._async_update_time.await_count == 0 + assert async_update_time_mock.await_count == 0 assert mfg_cluster.write_attributes.await_count == 0 entity.enable() + await zha_gateway.async_block_till_done() assert entity.enabled is True await asyncio.sleep(4600) - assert entity._async_update_time.await_count == 1 + assert async_update_time_mock.await_count == 1 assert mfg_cluster.write_attributes.await_count == 1 write_attributes.reset_mock() - entity._async_update_time.reset_mock() + async_update_time_mock.reset_mock() + + if isinstance(entity, WebSocketClientEntity): + server_entity = get_entity( + zha_gateway.ws_gateway.devices[dev_climate_sinope.ieee], + platform=Platform.CLIMATE, + ) + server_entity._async_update_time = original_async_update_time + else: + entity._async_update_time = original_async_update_time +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_climate_hvac_action_running_state_zen( zha_gateway: Gateway, ): """Test Zen hvac action via running state.""" - device_climate_zen = await device_climate_mock( + + zigpy_device, device_climate_zen = await device_climate_mock( zha_gateway, CLIMATE_ZEN, manuf=MANUF_ZEN ) - thrm_cluster = device_climate_zen.device.endpoints[1].thermostat + thrm_cluster = zigpy_device.endpoints[1].thermostat - entity: ThermostatEntity = get_entity( - device_climate_zen, platform=Platform.CLIMATE, entity_type=ThermostatEntity - ) + entity: ThermostatEntity = get_entity(device_climate_zen, platform=Platform.CLIMATE) sensor_entity: Sensor = get_entity( - device_climate_zen, platform=Platform.SENSOR, entity_type=ThermostatHVACAction + device_climate_zen, platform=Platform.SENSOR, qualifier="hvac_action" + ) + assert isinstance( + sensor_entity, + ThermostatHVACAction + if not hasattr(zha_gateway, "ws_gateway") + else WebSocketClientSensorEntity, ) - assert isinstance(sensor_entity, ThermostatHVACAction) assert entity.state["hvac_action"] is None + assert entity.hvac_action is None assert sensor_entity.state["state"] is None await send_attributes_report( zha_gateway, thrm_cluster, {0x0029: Thermostat.RunningState.Cool_2nd_Stage_On} ) assert entity.state["hvac_action"] == "cooling" + assert entity.hvac_action == "cooling" assert sensor_entity.state["state"] == "cooling" await send_attributes_report( zha_gateway, thrm_cluster, {0x0029: Thermostat.RunningState.Fan_State_On} ) assert entity.state["hvac_action"] == "fan" + assert entity.hvac_action == "fan" assert sensor_entity.state["state"] == "fan" await send_attributes_report( zha_gateway, thrm_cluster, {0x0029: Thermostat.RunningState.Heat_2nd_Stage_On} ) assert entity.state["hvac_action"] == "heating" + assert entity.hvac_action == "heating" assert sensor_entity.state["state"] == "heating" await send_attributes_report( zha_gateway, thrm_cluster, {0x0029: Thermostat.RunningState.Fan_2nd_Stage_On} ) assert entity.state["hvac_action"] == "fan" + assert entity.hvac_action == "fan" assert sensor_entity.state["state"] == "fan" await send_attributes_report( zha_gateway, thrm_cluster, {0x0029: Thermostat.RunningState.Cool_State_On} ) assert entity.state["hvac_action"] == "cooling" + assert entity.hvac_action == "cooling" assert sensor_entity.state["state"] == "cooling" await send_attributes_report( zha_gateway, thrm_cluster, {0x0029: Thermostat.RunningState.Fan_3rd_Stage_On} ) assert entity.state["hvac_action"] == "fan" + assert entity.hvac_action == "fan" assert sensor_entity.state["state"] == "fan" await send_attributes_report( zha_gateway, thrm_cluster, {0x0029: Thermostat.RunningState.Heat_State_On} ) assert entity.state["hvac_action"] == "heating" + assert entity.hvac_action == "heating" assert sensor_entity.state["state"] == "heating" await send_attributes_report( zha_gateway, thrm_cluster, {0x0029: Thermostat.RunningState.Idle} ) assert entity.state["hvac_action"] == "off" + assert entity.hvac_action == "off" assert sensor_entity.state["state"] == "off" await send_attributes_report( zha_gateway, thrm_cluster, {0x001C: Thermostat.SystemMode.Heat} ) assert entity.state["hvac_action"] == "idle" + assert entity.hvac_action == "idle" assert sensor_entity.state["state"] == "idle" +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_climate_hvac_action_pi_demand( zha_gateway: Gateway, ): """Test hvac action based on pi_heating/cooling_demand attrs.""" - device_climate = await device_climate_mock(zha_gateway, CLIMATE) - thrm_cluster = device_climate.device.endpoints[1].thermostat - entity: ThermostatEntity = get_entity( - device_climate, platform=Platform.CLIMATE, entity_type=ThermostatEntity - ) + + zigpy_device, device_climate = await device_climate_mock(zha_gateway, CLIMATE) + thrm_cluster = zigpy_device.endpoints[1].thermostat + entity: ThermostatEntity = get_entity(device_climate, platform=Platform.CLIMATE) assert entity.state["hvac_action"] is None + assert entity.hvac_action is None await send_attributes_report(zha_gateway, thrm_cluster, {0x0007: 10}) assert entity.state["hvac_action"] == "cooling" + assert entity.hvac_action == "cooling" await send_attributes_report(zha_gateway, thrm_cluster, {0x0008: 20}) assert entity.state["hvac_action"] == "heating" + assert entity.hvac_action == "heating" await send_attributes_report(zha_gateway, thrm_cluster, {0x0007: 0}) await send_attributes_report(zha_gateway, thrm_cluster, {0x0008: 0}) assert entity.state["hvac_action"] == "off" + assert entity.hvac_action == "off" await send_attributes_report( zha_gateway, thrm_cluster, {0x001C: Thermostat.SystemMode.Heat} ) assert entity.state["hvac_action"] == "idle" + assert entity.hvac_action == "idle" await send_attributes_report( zha_gateway, thrm_cluster, {0x001C: Thermostat.SystemMode.Cool} ) assert entity.state["hvac_action"] == "idle" + assert entity.hvac_action == "idle" +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) @pytest.mark.parametrize( "sys_mode, hvac_mode", ( @@ -530,26 +693,36 @@ async def test_hvac_mode( hvac_mode, ): """Test HVAC mode.""" - device_climate = await device_climate_mock(zha_gateway, CLIMATE) - thrm_cluster = device_climate.device.endpoints[1].thermostat - entity: ThermostatEntity = get_entity( - device_climate, platform=Platform.CLIMATE, entity_type=ThermostatEntity - ) + zigpy_device, device_climate = await device_climate_mock(zha_gateway, CLIMATE) + thrm_cluster = zigpy_device.endpoints[1].thermostat + entity: ThermostatEntity = get_entity(device_climate, platform=Platform.CLIMATE) assert entity.state["hvac_mode"] == "off" + assert entity.hvac_mode == "off" await send_attributes_report(zha_gateway, thrm_cluster, {0x001C: sys_mode}) assert entity.state["hvac_mode"] == hvac_mode + assert entity.hvac_mode == hvac_mode await send_attributes_report( zha_gateway, thrm_cluster, {0x001C: Thermostat.SystemMode.Off} ) assert entity.state["hvac_mode"] == "off" + assert entity.hvac_mode == "off" await send_attributes_report(zha_gateway, thrm_cluster, {0x001C: 0xFF}) assert entity.state["hvac_mode"] is None + assert entity.hvac_mode is None +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) @pytest.mark.parametrize( "seq_of_op, modes", ( @@ -569,15 +742,21 @@ async def test_hvac_modes( # pylint: disable=unused-argument ): """Test HVAC modes from sequence of operations.""" - dev_climate = await device_climate_mock( + _, dev_climate = await device_climate_mock( zha_gateway, CLIMATE, {"ctrl_sequence_of_oper": seq_of_op} ) - entity: ThermostatEntity = get_entity( - dev_climate, platform=Platform.CLIMATE, entity_type=ThermostatEntity - ) + entity: ThermostatEntity = get_entity(dev_climate, platform=Platform.CLIMATE) assert set(entity.hvac_modes) == modes +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) @pytest.mark.parametrize( "sys_mode, preset, target_temp", ( @@ -595,7 +774,7 @@ async def test_target_temperature( ): """Test target temperature property.""" - dev_climate = await device_climate_mock( + _, dev_climate = await device_climate_mock( zha_gateway, CLIMATE_SINOPE, { @@ -608,9 +787,7 @@ async def test_target_temperature( manuf=MANUF_SINOPE, quirk=zhaquirks.sinope.thermostat.SinopeTechnologiesThermostat, ) - entity: ThermostatEntity = get_entity( - dev_climate, platform=Platform.CLIMATE, entity_type=ThermostatEntity - ) + entity: ThermostatEntity = get_entity(dev_climate, platform=Platform.CLIMATE) if preset: await entity.async_set_preset_mode(preset) await zha_gateway.async_block_till_done() @@ -618,6 +795,14 @@ async def test_target_temperature( assert entity.state["target_temperature"] == target_temp +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) @pytest.mark.parametrize( "preset, unoccupied, target_temp", ( @@ -634,7 +819,7 @@ async def test_target_temperature_high( ): """Test target temperature high property.""" - dev_climate = await device_climate_mock( + _, dev_climate = await device_climate_mock( zha_gateway, CLIMATE_SINOPE, { @@ -645,16 +830,23 @@ async def test_target_temperature_high( manuf=MANUF_SINOPE, quirk=zhaquirks.sinope.thermostat.SinopeTechnologiesThermostat, ) - entity: ThermostatEntity = get_entity( - dev_climate, platform=Platform.CLIMATE, entity_type=ThermostatEntity - ) + entity: ThermostatEntity = get_entity(dev_climate, platform=Platform.CLIMATE) if preset: await entity.async_set_preset_mode(preset) await zha_gateway.async_block_till_done() assert entity.state["target_temperature_high"] == target_temp + assert entity.target_temperature_high == target_temp +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) @pytest.mark.parametrize( "preset, unoccupied, target_temp", ( @@ -671,7 +863,7 @@ async def test_target_temperature_low( ): """Test target temperature low property.""" - dev_climate = await device_climate_mock( + _, dev_climate = await device_climate_mock( zha_gateway, CLIMATE_SINOPE, { @@ -682,16 +874,23 @@ async def test_target_temperature_low( manuf=MANUF_SINOPE, quirk=zhaquirks.sinope.thermostat.SinopeTechnologiesThermostat, ) - entity: ThermostatEntity = get_entity( - dev_climate, platform=Platform.CLIMATE, entity_type=ThermostatEntity - ) + entity: ThermostatEntity = get_entity(dev_climate, platform=Platform.CLIMATE) if preset: await entity.async_set_preset_mode(preset) await zha_gateway.async_block_till_done() assert entity.state["target_temperature_low"] == target_temp + assert entity.target_temperature_low == target_temp +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) @pytest.mark.parametrize( "hvac_mode, sys_mode", ( @@ -710,19 +909,19 @@ async def test_set_hvac_mode( ): """Test setting hvac mode.""" - device_climate = await device_climate_mock(zha_gateway, CLIMATE) - thrm_cluster = device_climate.device.endpoints[1].thermostat - entity: ThermostatEntity = get_entity( - device_climate, platform=Platform.CLIMATE, entity_type=ThermostatEntity - ) + zigpy_device, device_climate = await device_climate_mock(zha_gateway, CLIMATE) + thrm_cluster = zigpy_device.endpoints[1].thermostat + entity: ThermostatEntity = get_entity(device_climate, platform=Platform.CLIMATE) assert entity.state["hvac_mode"] == "off" + assert entity.hvac_mode == "off" await entity.async_set_hvac_mode(hvac_mode) await zha_gateway.async_block_till_done() if sys_mode is not None: assert entity.state["hvac_mode"] == hvac_mode + assert entity.hvac_mode == hvac_mode assert thrm_cluster.write_attributes.call_count == 1 assert thrm_cluster.write_attributes.call_args[0][0] == { "system_mode": sys_mode @@ -730,6 +929,7 @@ async def test_set_hvac_mode( else: assert thrm_cluster.write_attributes.call_count == 0 assert entity.state["hvac_mode"] == "off" + assert entity.hvac_mode == "off" # turn off thrm_cluster.write_attributes.reset_mock() @@ -737,23 +937,32 @@ async def test_set_hvac_mode( await zha_gateway.async_block_till_done() assert entity.state["hvac_mode"] == "off" + assert entity.hvac_mode == "off" assert thrm_cluster.write_attributes.call_count == 1 assert thrm_cluster.write_attributes.call_args[0][0] == { "system_mode": Thermostat.SystemMode.Off } +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_preset_setting( zha_gateway: Gateway, ): """Test preset setting.""" - dev_climate_sinope = await device_climate_sinope(zha_gateway) - thrm_cluster = dev_climate_sinope.device.endpoints[1].thermostat - entity: ThermostatEntity = get_entity( - dev_climate_sinope, platform=Platform.CLIMATE, entity_type=ThermostatEntity - ) + + zigpy_device, dev_climate_sinope = await device_climate_sinope(zha_gateway) + thrm_cluster = zigpy_device.endpoints[1].thermostat + entity: ThermostatEntity = get_entity(dev_climate_sinope, platform=Platform.CLIMATE) assert entity.state["preset_mode"] == "none" + assert entity.preset_mode == "none" # unsuccessful occupancy change thrm_cluster.write_attributes.return_value = [ @@ -772,6 +981,7 @@ async def test_preset_setting( await zha_gateway.async_block_till_done() assert entity.state["preset_mode"] == "none" + assert entity.preset_mode == "none" assert thrm_cluster.write_attributes.call_count == 1 assert thrm_cluster.write_attributes.call_args[0][0] == {"set_occupancy": 0} @@ -784,6 +994,7 @@ async def test_preset_setting( await zha_gateway.async_block_till_done() assert entity.state["preset_mode"] == "away" + assert entity.preset_mode == "away" assert thrm_cluster.write_attributes.call_count == 1 assert thrm_cluster.write_attributes.call_args[0][0] == {"set_occupancy": 0} @@ -806,6 +1017,7 @@ async def test_preset_setting( await zha_gateway.async_block_till_done() assert entity.state["preset_mode"] == "away" + assert entity.preset_mode == "away" assert thrm_cluster.write_attributes.call_count == 1 assert thrm_cluster.write_attributes.call_args[0][0] == {"set_occupancy": 1} @@ -819,56 +1031,82 @@ async def test_preset_setting( await zha_gateway.async_block_till_done() assert entity.state["preset_mode"] == "none" + assert entity.preset_mode == "none" assert thrm_cluster.write_attributes.call_count == 1 assert thrm_cluster.write_attributes.call_args[0][0] == {"set_occupancy": 1} +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_preset_setting_invalid( zha_gateway: Gateway, ): """Test invalid preset setting.""" - dev_climate_sinope = await device_climate_sinope(zha_gateway) - thrm_cluster = dev_climate_sinope.device.endpoints[1].thermostat - entity: ThermostatEntity = get_entity( - dev_climate_sinope, platform=Platform.CLIMATE, entity_type=ThermostatEntity - ) + + zigpy_device, dev_climate_sinope = await device_climate_sinope(zha_gateway) + thrm_cluster = zigpy_device.endpoints[1].thermostat + entity: ThermostatEntity = get_entity(dev_climate_sinope, platform=Platform.CLIMATE) assert entity.state["preset_mode"] == "none" + assert entity.preset_mode == "none" await entity.async_set_preset_mode("invalid_preset") await zha_gateway.async_block_till_done() assert entity.state["preset_mode"] == "none" + assert entity.preset_mode == "none" assert thrm_cluster.write_attributes.call_count == 0 +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_set_temperature_hvac_mode( zha_gateway: Gateway, ): """Test setting HVAC mode in temperature service call.""" - device_climate = await device_climate_mock(zha_gateway, CLIMATE) - thrm_cluster = device_climate.device.endpoints[1].thermostat - entity: ThermostatEntity = get_entity( - device_climate, platform=Platform.CLIMATE, entity_type=ThermostatEntity - ) + zigpy_device, device_climate = await device_climate_mock(zha_gateway, CLIMATE) + thrm_cluster = zigpy_device.endpoints[1].thermostat + entity: ThermostatEntity = get_entity(device_climate, platform=Platform.CLIMATE) assert entity.state["hvac_mode"] == "off" + assert entity.hvac_mode == "off" await entity.async_set_temperature(hvac_mode="heat_cool", temperature=20) await zha_gateway.async_block_till_done() assert entity.state["hvac_mode"] == "heat_cool" + assert entity.hvac_mode == "heat_cool" assert thrm_cluster.write_attributes.await_count == 1 assert thrm_cluster.write_attributes.call_args[0][0] == { "system_mode": Thermostat.SystemMode.Auto } +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_set_temperature_heat_cool( zha_gateway: Gateway, ): """Test setting temperature service call in heating/cooling HVAC mode.""" - device_climate = await device_climate_mock( + zigpy_device, device_climate = await device_climate_mock( zha_gateway, CLIMATE_SINOPE, { @@ -881,25 +1119,28 @@ async def test_set_temperature_heat_cool( manuf=MANUF_SINOPE, quirk=zhaquirks.sinope.thermostat.SinopeTechnologiesThermostat, ) - thrm_cluster = device_climate.device.endpoints[1].thermostat - entity: ThermostatEntity = get_entity( - device_climate, platform=Platform.CLIMATE, entity_type=ThermostatEntity - ) + thrm_cluster = zigpy_device.endpoints[1].thermostat + entity: ThermostatEntity = get_entity(device_climate, platform=Platform.CLIMATE) assert entity.state["hvac_mode"] == "heat_cool" + assert entity.hvac_mode == "heat_cool" await entity.async_set_temperature(temperature=20) await zha_gateway.async_block_till_done() assert entity.state["target_temperature_low"] == 20.0 + assert entity.target_temperature_low == 20.0 assert entity.state["target_temperature_high"] == 25.0 + assert entity.target_temperature_high == 25.0 assert thrm_cluster.write_attributes.await_count == 0 await entity.async_set_temperature(target_temp_high=26, target_temp_low=19) await zha_gateway.async_block_till_done() assert entity.state["target_temperature_low"] == 19.0 + assert entity.target_temperature_low == 19.0 assert entity.state["target_temperature_high"] == 26.0 + assert entity.target_temperature_high == 26.0 assert thrm_cluster.write_attributes.await_count == 2 assert thrm_cluster.write_attributes.call_args_list[0][0][0] == { "occupied_heating_setpoint": 1900 @@ -916,7 +1157,9 @@ async def test_set_temperature_heat_cool( await zha_gateway.async_block_till_done() assert entity.state["target_temperature_low"] == 15.0 + assert entity.target_temperature_low == 15.0 assert entity.state["target_temperature_high"] == 30.0 + assert entity.target_temperature_high == 30.0 assert thrm_cluster.write_attributes.await_count == 2 assert thrm_cluster.write_attributes.call_args_list[0][0][0] == { "unoccupied_heating_setpoint": 1500 @@ -926,12 +1169,20 @@ async def test_set_temperature_heat_cool( } +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_set_temperature_heat( zha_gateway: Gateway, ): """Test setting temperature service call in heating HVAC mode.""" - device_climate = await device_climate_mock( + zigpy_device, device_climate = await device_climate_mock( zha_gateway, CLIMATE_SINOPE, { @@ -944,27 +1195,32 @@ async def test_set_temperature_heat( manuf=MANUF_SINOPE, quirk=zhaquirks.sinope.thermostat.SinopeTechnologiesThermostat, ) - thrm_cluster = device_climate.device.endpoints[1].thermostat - entity: ThermostatEntity = get_entity( - device_climate, platform=Platform.CLIMATE, entity_type=ThermostatEntity - ) + thrm_cluster = zigpy_device.endpoints[1].thermostat + entity: ThermostatEntity = get_entity(device_climate, platform=Platform.CLIMATE) assert entity.state["hvac_mode"] == "heat" + assert entity.hvac_mode == "heat" await entity.async_set_temperature(target_temp_high=30, target_temp_low=15) await zha_gateway.async_block_till_done() assert entity.state["target_temperature_low"] is None + assert entity.target_temperature_low is None assert entity.state["target_temperature_high"] is None + assert entity.target_temperature_high is None assert entity.state["target_temperature"] == 20.0 + assert entity.target_temperature == 20.0 assert thrm_cluster.write_attributes.await_count == 0 await entity.async_set_temperature(temperature=21) await zha_gateway.async_block_till_done() assert entity.state["target_temperature_low"] is None + assert entity.target_temperature_low is None assert entity.state["target_temperature_high"] is None + assert entity.target_temperature_high is None assert entity.state["target_temperature"] == 21.0 + assert entity.target_temperature == 21.0 assert thrm_cluster.write_attributes.await_count == 1 assert thrm_cluster.write_attributes.call_args_list[0][0][0] == { "occupied_heating_setpoint": 2100 @@ -978,20 +1234,31 @@ async def test_set_temperature_heat( await zha_gateway.async_block_till_done() assert entity.state["target_temperature_low"] is None + assert entity.target_temperature_low is None assert entity.state["target_temperature_high"] is None + assert entity.target_temperature_high is None assert entity.state["target_temperature"] == 22.0 + assert entity.target_temperature == 22.0 assert thrm_cluster.write_attributes.await_count == 1 assert thrm_cluster.write_attributes.call_args_list[0][0][0] == { "unoccupied_heating_setpoint": 2200 } +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_set_temperature_cool( zha_gateway: Gateway, ): """Test setting temperature service call in cooling HVAC mode.""" - device_climate = await device_climate_mock( + zigpy_device, device_climate = await device_climate_mock( zha_gateway, CLIMATE_SINOPE, { @@ -1004,27 +1271,32 @@ async def test_set_temperature_cool( manuf=MANUF_SINOPE, quirk=zhaquirks.sinope.thermostat.SinopeTechnologiesThermostat, ) - thrm_cluster = device_climate.device.endpoints[1].thermostat - entity: ThermostatEntity = get_entity( - device_climate, platform=Platform.CLIMATE, entity_type=ThermostatEntity - ) + thrm_cluster = zigpy_device.endpoints[1].thermostat + entity: ThermostatEntity = get_entity(device_climate, platform=Platform.CLIMATE) assert entity.state["hvac_mode"] == "cool" + assert entity.hvac_mode == "cool" await entity.async_set_temperature(target_temp_high=30, target_temp_low=15) await zha_gateway.async_block_till_done() assert entity.state["target_temperature_low"] is None + assert entity.target_temperature_low is None assert entity.state["target_temperature_high"] is None + assert entity.target_temperature_high is None assert entity.state["target_temperature"] == 25.0 + assert entity.target_temperature == 25.0 assert thrm_cluster.write_attributes.await_count == 0 await entity.async_set_temperature(temperature=21) await zha_gateway.async_block_till_done() assert entity.state["target_temperature_low"] is None + assert entity.target_temperature_low is None assert entity.state["target_temperature_high"] is None + assert entity.target_temperature_high is None assert entity.state["target_temperature"] == 21.0 + assert entity.target_temperature == 21.0 assert thrm_cluster.write_attributes.await_count == 1 assert thrm_cluster.write_attributes.call_args_list[0][0][0] == { "occupied_cooling_setpoint": 2100 @@ -1038,14 +1310,25 @@ async def test_set_temperature_cool( await zha_gateway.async_block_till_done() assert entity.state["target_temperature_low"] is None + assert entity.target_temperature_low is None assert entity.state["target_temperature_high"] is None + assert entity.target_temperature_high is None assert entity.state["target_temperature"] == 22.0 + assert entity.target_temperature == 22.0 assert thrm_cluster.write_attributes.await_count == 1 assert thrm_cluster.write_attributes.call_args_list[0][0][0] == { "unoccupied_cooling_setpoint": 2200 } +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_set_temperature_wrong_mode( zha_gateway: Gateway, ): @@ -1056,7 +1339,7 @@ async def test_set_temperature_wrong_mode( "ep_attribute", "sinope_manufacturer_specific", ): - device_climate = await device_climate_mock( + zigpy_device, device_climate = await device_climate_mock( zha_gateway, CLIMATE_SINOPE, { @@ -1068,39 +1351,50 @@ async def test_set_temperature_wrong_mode( }, manuf=MANUF_SINOPE, ) - thrm_cluster = device_climate.device.endpoints[1].thermostat - entity: ThermostatEntity = get_entity( - device_climate, platform=Platform.CLIMATE, entity_type=ThermostatEntity - ) + thrm_cluster = zigpy_device.endpoints[1].thermostat + entity: ThermostatEntity = get_entity(device_climate, platform=Platform.CLIMATE) assert entity.state["hvac_mode"] == "dry" + assert entity.hvac_mode == "dry" await entity.async_set_temperature(temperature=24) await zha_gateway.async_block_till_done() assert entity.state["target_temperature_low"] is None + assert entity.target_temperature_low is None assert entity.state["target_temperature_high"] is None + assert entity.target_temperature_high is None assert entity.state["target_temperature"] is None + assert entity.target_temperature is None assert thrm_cluster.write_attributes.await_count == 0 +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_occupancy_reset( zha_gateway: Gateway, ): """Test away preset reset.""" - dev_climate_sinope = await device_climate_sinope(zha_gateway) - thrm_cluster = dev_climate_sinope.device.endpoints[1].thermostat - entity: ThermostatEntity = get_entity( - dev_climate_sinope, platform=Platform.CLIMATE, entity_type=ThermostatEntity - ) + + zigpy_device, dev_climate_sinope = await device_climate_sinope(zha_gateway) + thrm_cluster = zigpy_device.endpoints[1].thermostat + entity: ThermostatEntity = get_entity(dev_climate_sinope, platform=Platform.CLIMATE) assert entity.state["preset_mode"] == "none" + assert entity.preset_mode == "none" await entity.async_set_preset_mode("away") await zha_gateway.async_block_till_done() thrm_cluster.write_attributes.reset_mock() assert entity.state["preset_mode"] == "away" + assert entity.preset_mode == "away" await send_attributes_report( zha_gateway, @@ -1108,20 +1402,31 @@ async def test_occupancy_reset( {"occupied_heating_setpoint": zigpy.types.uint16_t(1950)}, ) assert entity.state["preset_mode"] == "none" + assert entity.preset_mode == "none" +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_fan_mode( zha_gateway: Gateway, ): """Test fan mode.""" - device_climate_fan = await device_climate_mock(zha_gateway, CLIMATE_FAN) - thrm_cluster = device_climate_fan.device.endpoints[1].thermostat - entity: ThermostatEntity = get_entity( - device_climate_fan, platform=Platform.CLIMATE, entity_type=ThermostatEntity + + zigpy_device, device_climate_fan = await device_climate_mock( + zha_gateway, CLIMATE_FAN ) + thrm_cluster = zigpy_device.endpoints[1].thermostat + entity: ThermostatEntity = get_entity(device_climate_fan, platform=Platform.CLIMATE) assert set(entity.fan_modes) == {FanState.AUTO, FanState.ON} assert entity.state["fan_mode"] == FanState.AUTO + assert entity.fan_mode == FanState.AUTO await send_attributes_report( zha_gateway, @@ -1129,11 +1434,13 @@ async def test_fan_mode( {"running_state": Thermostat.RunningState.Fan_State_On}, ) assert entity.state["fan_mode"] == FanState.ON + assert entity.fan_mode == FanState.ON await send_attributes_report( zha_gateway, thrm_cluster, {"running_state": Thermostat.RunningState.Idle} ) assert entity.state["fan_mode"] == FanState.AUTO + assert entity.fan_mode == FanState.AUTO await send_attributes_report( zha_gateway, @@ -1141,34 +1448,54 @@ async def test_fan_mode( {"running_state": Thermostat.RunningState.Fan_2nd_Stage_On}, ) assert entity.state["fan_mode"] == FanState.ON + assert entity.fan_mode == FanState.ON +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_set_fan_mode_not_supported( zha_gateway: Gateway, ): """Test fan setting unsupported mode.""" - device_climate_fan = await device_climate_mock(zha_gateway, CLIMATE_FAN) - fan_cluster = device_climate_fan.device.endpoints[1].fan - entity: ThermostatEntity = get_entity( - device_climate_fan, platform=Platform.CLIMATE, entity_type=ThermostatEntity + + zigpy_device, device_climate_fan = await device_climate_mock( + zha_gateway, CLIMATE_FAN ) + fan_cluster = zigpy_device.endpoints[1].fan + entity: ThermostatEntity = get_entity(device_climate_fan, platform=Platform.CLIMATE) await entity.async_set_fan_mode(FanState.LOW) await zha_gateway.async_block_till_done() assert fan_cluster.write_attributes.await_count == 0 +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_set_fan_mode( zha_gateway: Gateway, ): """Test fan mode setting.""" - device_climate_fan = await device_climate_mock(zha_gateway, CLIMATE_FAN) - fan_cluster = device_climate_fan.device.endpoints[1].fan - entity: ThermostatEntity = get_entity( - device_climate_fan, platform=Platform.CLIMATE, entity_type=ThermostatEntity + + zigpy_device, device_climate_fan = await device_climate_mock( + zha_gateway, CLIMATE_FAN ) + fan_cluster = zigpy_device.endpoints[1].fan + entity: ThermostatEntity = get_entity(device_climate_fan, platform=Platform.CLIMATE) assert entity.state["fan_mode"] == FanState.AUTO + assert entity.fan_mode == FanState.AUTO await entity.async_set_fan_mode(FanState.ON) await zha_gateway.async_block_till_done() @@ -1183,21 +1510,32 @@ async def test_set_fan_mode( assert fan_cluster.write_attributes.call_args[0][0] == {"fan_mode": 5} -async def test_set_moes_preset(zha_gateway: Gateway): +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) +async def test_set_moes_preset( + zha_gateway: Gateway, +): """Test setting preset for moes trv.""" - device_climate_moes = await device_climate_mock( + zigpy_device, device_climate_moes = await device_climate_mock( zha_gateway, CLIMATE_MOES, manuf=MANUF_MOES, quirk=zhaquirks.tuya.ts0601_trv.MoesHY368_Type1, ) - thrm_cluster = device_climate_moes.device.endpoints[1].thermostat + thrm_cluster = zigpy_device.endpoints[1].thermostat entity: ThermostatEntity = get_entity( - device_climate_moes, platform=Platform.CLIMATE, entity_type=ThermostatEntity + device_climate_moes, platform=Platform.CLIMATE ) assert entity.state["preset_mode"] == "none" + assert entity.preset_mode == "none" await entity.async_set_preset_mode("away") await zha_gateway.async_block_till_done() @@ -1277,52 +1615,78 @@ async def test_set_moes_preset(zha_gateway: Gateway): } -async def test_set_moes_operation_mode(zha_gateway: Gateway): +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) +async def test_set_moes_operation_mode( + zha_gateway: Gateway, +): """Test setting preset for moes trv.""" - device_climate_moes = await device_climate_mock( + + zigpy_device, device_climate_moes = await device_climate_mock( zha_gateway, CLIMATE_MOES, manuf=MANUF_MOES, quirk=zhaquirks.tuya.ts0601_trv.MoesHY368_Type1, ) - thrm_cluster = device_climate_moes.device.endpoints[1].thermostat + thrm_cluster = zigpy_device.endpoints[1].thermostat entity: ThermostatEntity = get_entity( - device_climate_moes, platform=Platform.CLIMATE, entity_type=ThermostatEntity + device_climate_moes, platform=Platform.CLIMATE ) await send_attributes_report(zha_gateway, thrm_cluster, {"operation_preset": 0}) assert entity.state["preset_mode"] == "away" + assert entity.preset_mode == "away" await send_attributes_report(zha_gateway, thrm_cluster, {"operation_preset": 1}) assert entity.state["preset_mode"] == "Schedule" + assert entity.preset_mode == "Schedule" await send_attributes_report(zha_gateway, thrm_cluster, {"operation_preset": 2}) assert entity.state["preset_mode"] == "none" + assert entity.preset_mode == "none" await send_attributes_report(zha_gateway, thrm_cluster, {"operation_preset": 3}) assert entity.state["preset_mode"] == "comfort" + assert entity.preset_mode == "comfort" await send_attributes_report(zha_gateway, thrm_cluster, {"operation_preset": 4}) assert entity.state["preset_mode"] == "eco" + assert entity.preset_mode == "eco" await send_attributes_report(zha_gateway, thrm_cluster, {"operation_preset": 5}) assert entity.state["preset_mode"] == "boost" + assert entity.preset_mode == "boost" await send_attributes_report(zha_gateway, thrm_cluster, {"operation_preset": 6}) assert entity.state["preset_mode"] == "Complex" + assert entity.preset_mode == "Complex" # Device is running an energy-saving mode PRESET_ECO = "eco" +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) @pytest.mark.parametrize( ("preset_attr", "preset_mode"), [ @@ -1340,15 +1704,15 @@ async def test_beca_operation_mode_update( preset_mode: str, ) -> None: """Test beca trv operation mode attribute update.""" - device_climate_beca = await device_climate_mock( + zigpy_device, device_climate_beca = await device_climate_mock( zha_gateway, CLIMATE_BECA, manuf=MANUF_BECA, quirk=zhaquirks.tuya.ts0601_trv.MoesHY368_Type1new, ) - thrm_cluster = device_climate_beca.device.endpoints[1].thermostat + thrm_cluster = zigpy_device.endpoints[1].thermostat entity: ThermostatEntity = get_entity( - device_climate_beca, platform=Platform.CLIMATE, entity_type=ThermostatEntity + device_climate_beca, platform=Platform.CLIMATE ) # Test sending an attribute report @@ -1357,6 +1721,7 @@ async def test_beca_operation_mode_update( ) assert entity.state[ATTR_PRESET_MODE] == preset_mode + assert entity.preset_mode == preset_mode await entity.async_set_preset_mode(preset_mode) await zha_gateway.async_block_till_done() @@ -1369,22 +1734,33 @@ async def test_beca_operation_mode_update( ] -async def test_set_zonnsmart_preset(zha_gateway: Gateway) -> None: +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) +async def test_set_zonnsmart_preset( + zha_gateway: Gateway, +) -> None: """Test setting preset from homeassistant for zonnsmart trv.""" - device_climate_zonnsmart = await device_climate_mock( + + zigpy_device, device_climate_zonnsmart = await device_climate_mock( zha_gateway, CLIMATE_ZONNSMART, manuf=MANUF_ZONNSMART, quirk=zhaquirks.tuya.ts0601_trv.ZonnsmartTV01_ZG, ) - thrm_cluster = device_climate_zonnsmart.device.endpoints[1].thermostat + thrm_cluster = zigpy_device.endpoints[1].thermostat entity: ThermostatEntity = get_entity( device_climate_zonnsmart, platform=Platform.CLIMATE, - entity_type=ThermostatEntity, ) assert entity.state[ATTR_PRESET_MODE] == PRESET_NONE + assert entity.preset_mode == PRESET_NONE await entity.async_set_preset_mode(PRESET_SCHEDULE) await zha_gateway.async_block_till_done() @@ -1429,37 +1805,52 @@ async def test_set_zonnsmart_preset(zha_gateway: Gateway) -> None: } -async def test_set_zonnsmart_operation_mode(zha_gateway: Gateway) -> None: +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) +async def test_set_zonnsmart_operation_mode( + zha_gateway: Gateway, +) -> None: """Test setting preset from trv for zonnsmart trv.""" - device_climate_zonnsmart = await device_climate_mock( + + zigpy_device, device_climate_zonnsmart = await device_climate_mock( zha_gateway, CLIMATE_ZONNSMART, manuf=MANUF_ZONNSMART, quirk=zhaquirks.tuya.ts0601_trv.ZonnsmartTV01_ZG, ) - thrm_cluster = device_climate_zonnsmart.device.endpoints[1].thermostat + thrm_cluster = zigpy_device.endpoints[1].thermostat entity: ThermostatEntity = get_entity( device_climate_zonnsmart, platform=Platform.CLIMATE, - entity_type=ThermostatEntity, ) await send_attributes_report(zha_gateway, thrm_cluster, {"operation_preset": 0}) assert entity.state[ATTR_PRESET_MODE] == PRESET_SCHEDULE + assert entity.preset_mode == PRESET_SCHEDULE await send_attributes_report(zha_gateway, thrm_cluster, {"operation_preset": 1}) assert entity.state[ATTR_PRESET_MODE] == PRESET_NONE + assert entity.preset_mode == PRESET_NONE await send_attributes_report(zha_gateway, thrm_cluster, {"operation_preset": 2}) assert entity.state[ATTR_PRESET_MODE] == "holiday" + assert entity.preset_mode == "holiday" await send_attributes_report(zha_gateway, thrm_cluster, {"operation_preset": 3}) assert entity.state[ATTR_PRESET_MODE] == "holiday" + assert entity.preset_mode == "holiday" await send_attributes_report(zha_gateway, thrm_cluster, {"operation_preset": 4}) assert entity.state[ATTR_PRESET_MODE] == "frost protect" + assert entity.preset_mode == "frost protect" diff --git a/tests/test_cover.py b/tests/test_cover.py index 5e7a66ea1..2de36cf3c 100644 --- a/tests/test_cover.py +++ b/tests/test_cover.py @@ -26,17 +26,15 @@ from zha.application import Platform from zha.application.const import ATTR_COMMAND from zha.application.gateway import Gateway -from zha.application.platforms.cover import ( - ATTR_CURRENT_POSITION, - STATE_CLOSED, - STATE_OPEN, -) +from zha.application.platforms.cover import STATE_CLOSED, STATE_OPEN from zha.application.platforms.cover.const import ( + ATTR_CURRENT_POSITION, STATE_CLOSING, STATE_OPENING, CoverEntityFeature, ) from zha.exceptions import ZHAException +from zha.zigbee.device import WebSocketClientDevice Default_Response = zcl_f.GENERAL_COMMANDS[zcl_f.GeneralCommand.Default_Response].schema @@ -91,11 +89,20 @@ WCCS = closures.WindowCovering.ConfigStatus +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_cover_non_tilt_initial_state( # pylint: disable=unused-argument zha_gateway: Gateway, ) -> None: """Test ZHA cover platform.""" + zigpy_cover_device = create_mock_zigpy_device(zha_gateway, ZIGPY_COVER_DEVICE) # load up cover domain zigpy_cover_device = create_mock_zigpy_device(zha_gateway, ZIGPY_COVER_DEVICE) cluster = zigpy_cover_device.endpoints[1].window_covering @@ -106,11 +113,19 @@ async def test_cover_non_tilt_initial_state( # pylint: disable=unused-argument } update_attribute_cache(cluster) zha_device = await join_zigpy_device(zha_gateway, zigpy_cover_device) - assert ( - not zha_device.endpoints[1] - .all_cluster_handlers[f"1:0x{cluster.cluster_id:04x}"] - .inverted - ) + + if isinstance(zha_device, WebSocketClientDevice): + ch = ( + zha_gateway.ws_gateway.devices[zha_device.ieee] + .endpoints[1] + .all_cluster_handlers[f"1:0x{cluster.cluster_id:04x}"] + ) + else: + ch = zha_device.endpoints[1].all_cluster_handlers[ + f"1:0x{cluster.cluster_id:04x}" + ] + assert not ch.inverted + assert cluster.read_attributes.call_count == 3 assert ( WCAttrs.current_position_lift_percentage.name @@ -141,6 +156,14 @@ async def test_cover_non_tilt_initial_state( # pylint: disable=unused-argument assert entity.state[ATTR_CURRENT_POSITION] == 0 +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_cover( zha_gateway: Gateway, ) -> None: @@ -157,11 +180,17 @@ async def test_cover( update_attribute_cache(cluster) zha_device = await join_zigpy_device(zha_gateway, zigpy_cover_device) - assert ( - not zha_device.endpoints[1] - .all_cluster_handlers[f"1:0x{cluster.cluster_id:04x}"] - .inverted - ) + if isinstance(zha_device, WebSocketClientDevice): + ch = ( + zha_gateway.ws_gateway.devices[zha_device.ieee] + .endpoints[1] + .all_cluster_handlers[f"1:0x{cluster.cluster_id:04x}"] + ) + else: + ch = zha_device.endpoints[1].all_cluster_handlers[ + f"1:0x{cluster.cluster_id:04x}" + ] + assert not ch.inverted assert cluster.read_attributes.call_count == 3 @@ -191,12 +220,15 @@ async def test_cover( zha_gateway, cluster, {WCAttrs.current_position_lift_percentage.id: 100} ) assert entity.state["state"] == STATE_CLOSED + assert entity.current_cover_position == 0 + assert entity.is_closed is True # test to see if it opens await send_attributes_report( zha_gateway, cluster, {WCAttrs.current_position_lift_percentage.id: 0} ) assert entity.state["state"] == STATE_OPEN + assert entity.is_closed is False # test that the state remains after tilting to 100% await send_attributes_report( @@ -209,6 +241,7 @@ async def test_cover( zha_gateway, cluster, {WCAttrs.current_position_tilt_percentage.id: 0} ) assert entity.state["state"] == STATE_OPEN + assert entity.current_cover_tilt_position == 100 cluster.PLUGGED_ATTR_READS = {1: 100} update_attribute_cache(cluster) @@ -267,6 +300,7 @@ async def test_cover( assert cluster.request.call_args[1]["expect_reply"] is True assert entity.state["state"] == STATE_OPENING + assert entity.is_opening is True await send_attributes_report( zha_gateway, cluster, {WCAttrs.current_position_lift_percentage.id: 0} @@ -370,9 +404,20 @@ async def test_cover( assert cluster.request.call_args[1]["expect_reply"] is True -async def test_cover_failures(zha_gateway: Gateway) -> None: +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) +async def test_cover_failures( + zha_gateway: Gateway, +) -> None: """Test ZHA cover platform failure cases.""" + zigpy_cover_device = create_mock_zigpy_device(zha_gateway, ZIGPY_COVER_DEVICE) # load up cover domain zigpy_cover_device = create_mock_zigpy_device(zha_gateway, ZIGPY_COVER_DEVICE) cluster = zigpy_cover_device.endpoints[1].window_covering @@ -392,6 +437,11 @@ async def test_cover_failures(zha_gateway: Gateway) -> None: assert entity.state["state"] == STATE_OPEN + exception_string = ( + r"Failed to close cover" + if isinstance(zha_gateway, Gateway) + else "(5, 'PLATFORM_ENTITY_ACTION_ERROR')" + ) # close from UI with patch( "zigpy.zcl.Cluster.request", @@ -400,7 +450,7 @@ async def test_cover_failures(zha_gateway: Gateway) -> None: status=zcl_f.Status.UNSUP_CLUSTER_COMMAND, ), ): - with pytest.raises(ZHAException, match=r"Failed to close cover"): + with pytest.raises(ZHAException, match=exception_string): await entity.async_close_cover() await zha_gateway.async_block_till_done() assert cluster.request.call_count == 1 @@ -410,6 +460,11 @@ async def test_cover_failures(zha_gateway: Gateway) -> None: ) assert entity.state["state"] == STATE_OPEN + exception_string = ( + r"Failed to close cover tilt" + if isinstance(zha_gateway, Gateway) + else "(6, 'PLATFORM_ENTITY_ACTION_ERROR')" + ) with patch( "zigpy.zcl.Cluster.request", return_value=Default_Response( @@ -417,7 +472,7 @@ async def test_cover_failures(zha_gateway: Gateway) -> None: status=zcl_f.Status.UNSUP_CLUSTER_COMMAND, ), ): - with pytest.raises(ZHAException, match=r"Failed to close cover tilt"): + with pytest.raises(ZHAException, match=exception_string): await entity.async_close_cover_tilt() await zha_gateway.async_block_till_done() assert cluster.request.call_count == 1 @@ -426,6 +481,11 @@ async def test_cover_failures(zha_gateway: Gateway) -> None: == closures.WindowCovering.ServerCommandDefs.go_to_tilt_percentage.id ) + exception_string = ( + r"Failed to open cover" + if isinstance(zha_gateway, Gateway) + else "(7, 'PLATFORM_ENTITY_ACTION_ERROR')" + ) # open from UI with patch( "zigpy.zcl.Cluster.request", @@ -434,7 +494,7 @@ async def test_cover_failures(zha_gateway: Gateway) -> None: status=zcl_f.Status.UNSUP_CLUSTER_COMMAND, ), ): - with pytest.raises(ZHAException, match=r"Failed to open cover"): + with pytest.raises(ZHAException, match=exception_string): await entity.async_open_cover() await zha_gateway.async_block_till_done() assert cluster.request.call_count == 1 @@ -443,6 +503,11 @@ async def test_cover_failures(zha_gateway: Gateway) -> None: == closures.WindowCovering.ServerCommandDefs.up_open.id ) + exception_string = ( + r"Failed to open cover tilt" + if isinstance(zha_gateway, Gateway) + else "(8, 'PLATFORM_ENTITY_ACTION_ERROR')" + ) with patch( "zigpy.zcl.Cluster.request", return_value=Default_Response( @@ -450,7 +515,7 @@ async def test_cover_failures(zha_gateway: Gateway) -> None: status=zcl_f.Status.UNSUP_CLUSTER_COMMAND, ), ): - with pytest.raises(ZHAException, match=r"Failed to open cover tilt"): + with pytest.raises(ZHAException, match=exception_string): await entity.async_open_cover_tilt() await zha_gateway.async_block_till_done() assert cluster.request.call_count == 1 @@ -459,6 +524,11 @@ async def test_cover_failures(zha_gateway: Gateway) -> None: == closures.WindowCovering.ServerCommandDefs.go_to_tilt_percentage.id ) + exception_string = ( + r"Failed to set cover position" + if isinstance(zha_gateway, Gateway) + else "(9, 'PLATFORM_ENTITY_ACTION_ERROR')" + ) # set position UI with patch( "zigpy.zcl.Cluster.request", @@ -467,7 +537,7 @@ async def test_cover_failures(zha_gateway: Gateway) -> None: status=zcl_f.Status.UNSUP_CLUSTER_COMMAND, ), ): - with pytest.raises(ZHAException, match=r"Failed to set cover position"): + with pytest.raises(ZHAException, match=exception_string): await entity.async_set_cover_position(position=47) await zha_gateway.async_block_till_done() @@ -477,6 +547,11 @@ async def test_cover_failures(zha_gateway: Gateway) -> None: == closures.WindowCovering.ServerCommandDefs.go_to_lift_percentage.id ) + exception_string = ( + r"Failed to set cover tilt position" + if isinstance(zha_gateway, Gateway) + else "(10, 'PLATFORM_ENTITY_ACTION_ERROR')" + ) with patch( "zigpy.zcl.Cluster.request", return_value=Default_Response( @@ -484,7 +559,7 @@ async def test_cover_failures(zha_gateway: Gateway) -> None: status=zcl_f.Status.UNSUP_CLUSTER_COMMAND, ), ): - with pytest.raises(ZHAException, match=r"Failed to set cover tilt position"): + with pytest.raises(ZHAException, match=exception_string): await entity.async_set_cover_tilt_position(tilt_position=47) await zha_gateway.async_block_till_done() assert cluster.request.call_count == 1 @@ -493,6 +568,11 @@ async def test_cover_failures(zha_gateway: Gateway) -> None: == closures.WindowCovering.ServerCommandDefs.go_to_tilt_percentage.id ) + exception_string = ( + r"Failed to stop cover" + if isinstance(zha_gateway, Gateway) + else "(11, 'PLATFORM_ENTITY_ACTION_ERROR')" + ) # stop from UI with patch( "zigpy.zcl.Cluster.request", @@ -501,7 +581,7 @@ async def test_cover_failures(zha_gateway: Gateway) -> None: status=zcl_f.Status.UNSUP_CLUSTER_COMMAND, ), ): - with pytest.raises(ZHAException, match=r"Failed to stop cover"): + with pytest.raises(ZHAException, match=exception_string): await entity.async_stop_cover() await zha_gateway.async_block_till_done() assert cluster.request.call_count == 1 @@ -510,6 +590,11 @@ async def test_cover_failures(zha_gateway: Gateway) -> None: == closures.WindowCovering.ServerCommandDefs.stop.id ) + exception_string = ( + r"Failed to stop cover" + if isinstance(zha_gateway, Gateway) + else "(12, 'PLATFORM_ENTITY_ACTION_ERROR')" + ) # stop tilt from UI with patch( "zigpy.zcl.Cluster.request", @@ -518,7 +603,7 @@ async def test_cover_failures(zha_gateway: Gateway) -> None: status=zcl_f.Status.UNSUP_CLUSTER_COMMAND, ), ): - with pytest.raises(ZHAException, match=r"Failed to stop cover"): + with pytest.raises(ZHAException, match=exception_string): await entity.async_stop_cover_tilt() await zha_gateway.async_block_till_done() assert cluster.request.call_count == 1 @@ -528,6 +613,14 @@ async def test_cover_failures(zha_gateway: Gateway) -> None: ) +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_shade( zha_gateway: Gateway, ) -> None: @@ -566,6 +659,11 @@ async def test_shade( await zha_gateway.async_block_till_done() assert entity.state["state"] == STATE_OPEN + exception_string = ( + r"Failed to close cover" + if isinstance(zha_gateway, Gateway) + else "(6, 'PLATFORM_ENTITY_ACTION_ERROR')" + ) # close from client command fails with ( patch( @@ -575,7 +673,7 @@ async def test_shade( status=zcl_f.Status.UNSUP_CLUSTER_COMMAND, ), ), - pytest.raises(ZHAException, match="Failed to close cover"), + pytest.raises(ZHAException, match=exception_string), ): await entity.async_close_cover() await zha_gateway.async_block_till_done() @@ -598,6 +696,11 @@ async def test_shade( await send_attributes_report(zha_gateway, cluster_level, {0: 0}) assert entity.state["state"] == STATE_CLOSED + exception_string = ( + r"Failed to open cover" + if isinstance(zha_gateway, Gateway) + else "(8, 'PLATFORM_ENTITY_ACTION_ERROR')" + ) with ( patch( "zigpy.zcl.Cluster.request", @@ -606,7 +709,7 @@ async def test_shade( status=zcl_f.Status.UNSUP_CLUSTER_COMMAND, ), ), - pytest.raises(ZHAException, match="Failed to open cover"), + pytest.raises(ZHAException, match=exception_string), ): await entity.async_open_cover() await zha_gateway.async_block_till_done() @@ -626,6 +729,11 @@ async def test_shade( assert cluster_on_off.request.call_args[0][1] == 0x0001 assert entity.state["state"] == STATE_OPEN + exception_string = ( + r"Failed to set cover position" + if isinstance(zha_gateway, Gateway) + else "(10, 'PLATFORM_ENTITY_ACTION_ERROR')" + ) # set position UI command fails with ( patch( @@ -635,7 +743,7 @@ async def test_shade( status=zcl_f.Status.UNSUP_CLUSTER_COMMAND, ), ), - pytest.raises(ZHAException, match="Failed to set cover position"), + pytest.raises(ZHAException, match=exception_string), ): await entity.async_set_cover_position(position=47) await zha_gateway.async_block_till_done() @@ -661,6 +769,11 @@ async def test_shade( await send_attributes_report(zha_gateway, cluster_level, {8: 0, 0: 100, 1: 1}) assert entity.state["current_position"] == int(100 * 100 / 255) + exception_string = ( + r"Failed to stop cover" + if isinstance(zha_gateway, Gateway) + else "(12, 'PLATFORM_ENTITY_ACTION_ERROR')" + ) # stop command fails with ( patch( @@ -670,7 +783,7 @@ async def test_shade( status=zcl_f.Status.UNSUP_CLUSTER_COMMAND, ), ), - pytest.raises(ZHAException, match="Failed to stop cover"), + pytest.raises(ZHAException, match=exception_string), ): await entity.async_stop_cover() await zha_gateway.async_block_till_done() @@ -689,6 +802,14 @@ async def test_shade( assert cluster_level.request.call_args[0][1] in (0x0003, 0x0007) +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_keen_vent( zha_gateway: Gateway, ) -> None: @@ -724,12 +845,15 @@ async def test_keen_vent( await zha_gateway.async_block_till_done() assert entity.state["state"] == STATE_CLOSED + exception_string = ( + r"Failed to send request: device did not respond" + if isinstance(zha_gateway, Gateway) + else "(6, 'PLATFORM_ENTITY_ACTION_ERROR')" + ) # open from client command fails p1 = patch.object(cluster_on_off, "request", side_effect=asyncio.TimeoutError) p2 = patch.object(cluster_level, "request", AsyncMock(return_value=[4, 0])) - p3 = pytest.raises( - ZHAException, match="Failed to send request: device did not respond" - ) + p3 = pytest.raises(ZHAException, match=exception_string) with p1, p2, p3: await entity.async_open_cover() @@ -755,45 +879,71 @@ async def test_keen_vent( assert entity.state["current_position"] == 100 -async def test_cover_remote(zha_gateway: Gateway) -> None: +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) +async def test_cover_remote( + zha_gateway: Gateway, +) -> None: """Test ZHA cover remote.""" + zigpy_cover_remote = create_mock_zigpy_device(zha_gateway, ZIGPY_COVER_REMOTE) # load up cover domain zigpy_cover_remote = create_mock_zigpy_device(zha_gateway, ZIGPY_COVER_REMOTE) zha_device = await join_zigpy_device(zha_gateway, zigpy_cover_remote) - zha_device.emit_zha_event = MagicMock(wraps=zha_device.emit_zha_event) + + if isinstance(zha_gateway, Gateway): + zha_device.emit_zha_event = MagicMock(wraps=zha_device.emit_zha_event) + device = zha_device + else: + device = zha_gateway.ws_gateway.devices[zha_device.ieee] + device.emit_zha_event = MagicMock(wraps=device.emit_zha_event) cluster = zigpy_cover_remote.endpoints[1].out_clusters[ closures.WindowCovering.cluster_id ] - zha_device.emit_zha_event.reset_mock() + device.emit_zha_event.reset_mock() # up command hdr = make_zcl_header(0, global_command=False) cluster.handle_message(hdr, []) await zha_gateway.async_block_till_done() - assert zha_device.emit_zha_event.call_count == 1 - assert ATTR_COMMAND in zha_device.emit_zha_event.call_args[0][0] - assert zha_device.emit_zha_event.call_args[0][0][ATTR_COMMAND] == "up_open" + assert device.emit_zha_event.call_count == 1 + assert ATTR_COMMAND in device.emit_zha_event.call_args[0][0] + assert device.emit_zha_event.call_args[0][0][ATTR_COMMAND] == "up_open" - zha_device.emit_zha_event.reset_mock() + device.emit_zha_event.reset_mock() # down command hdr = make_zcl_header(1, global_command=False) cluster.handle_message(hdr, []) await zha_gateway.async_block_till_done() - assert zha_device.emit_zha_event.call_count == 1 - assert ATTR_COMMAND in zha_device.emit_zha_event.call_args[0][0] - assert zha_device.emit_zha_event.call_args[0][0][ATTR_COMMAND] == "down_close" + assert device.emit_zha_event.call_count == 1 + assert ATTR_COMMAND in device.emit_zha_event.call_args[0][0] + assert device.emit_zha_event.call_args[0][0][ATTR_COMMAND] == "down_close" +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_cover_state_restoration( zha_gateway: Gateway, ) -> None: """Test the cover state restoration.""" + zigpy_cover_device = create_mock_zigpy_device(zha_gateway, ZIGPY_COVER_DEVICE) zha_device = await join_zigpy_device(zha_gateway, zigpy_cover_device) entity = get_entity(zha_device, platform=Platform.COVER) @@ -808,6 +958,11 @@ async def test_cover_state_restoration( target_tilt_position=34, ) + # ws impl needs a round trip to get the state back to the client + # maybe we make this optimistic, set the state manually on the client + # and avoid the round trip refresh call? + await zha_gateway.async_block_till_done() + assert entity.state["state"] == STATE_CLOSED assert entity.state["target_lift_position"] == 12 assert entity.state["target_tilt_position"] == 34 diff --git a/tests/test_device.py b/tests/test_device.py index ef52b3e85..0c50504f0 100644 --- a/tests/test_device.py +++ b/tests/test_device.py @@ -1,6 +1,8 @@ """Test ZHA device switch.""" import asyncio +from collections.abc import Mapping, Sequence +import json import logging import time from unittest import mock @@ -37,8 +39,14 @@ from zha.application.platforms.sensor import LQISensor, RSSISensor from zha.application.platforms.switch import Switch from zha.exceptions import ZHAException -from zha.zigbee.device import ClusterBinding, get_device_automation_triggers +from zha.zigbee.device import ( + ClusterBinding, + NeighborInfo, + RouteInfo, + get_device_automation_triggers, +) from zha.zigbee.group import Group +from zha.zigbee.model import ExtendedDeviceInfo def zigpy_device( @@ -106,6 +114,14 @@ async def _send_time_changed(zha_gateway: Gateway, seconds: int): "zha.zigbee.cluster_handlers.general.BasicClusterHandler.async_initialize", new=mock.AsyncMock(), ) +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_check_available_success( zha_gateway: Gateway, caplog: pytest.LogCaptureFixture, @@ -116,6 +132,12 @@ async def test_check_available_success( ) zha_device = await join_zigpy_device(zha_gateway, device_with_basic_cluster_handler) basic_ch = device_with_basic_cluster_handler.endpoints[3].basic + if hasattr(zha_gateway, "ws_gateway"): + server_device = zha_gateway.ws_gateway.devices[zha_device.ieee] + server_gateway = zha_gateway.ws_gateway + else: + server_device = zha_device + server_gateway = zha_gateway assert not zha_device.is_coordinator assert not zha_device.is_active_coordinator @@ -123,12 +145,15 @@ async def test_check_available_success( basic_ch.read_attributes.reset_mock() device_with_basic_cluster_handler.last_seen = None assert zha_device.available is True - await _send_time_changed(zha_gateway, zha_device.consider_unavailable_time + 2) + await _send_time_changed(zha_gateway, server_device.consider_unavailable_time + 2) assert zha_device.available is False assert basic_ch.read_attributes.await_count == 0 + for entity in server_device.platform_entities.values(): + assert not entity.available + device_with_basic_cluster_handler.last_seen = ( - time.time() - zha_device.consider_unavailable_time - 100 + time.time() - server_device.consider_unavailable_time - 100 ) _seens = [time.time(), device_with_basic_cluster_handler.last_seen] @@ -138,63 +163,82 @@ def _update_last_seen(*args, **kwargs): # pylint: disable=unused-argument basic_ch.read_attributes.side_effect = _update_last_seen - for entity in zha_device.platform_entities.values(): + for entity in server_device.platform_entities.values(): entity.emit = mock.MagicMock(wraps=entity.emit) # we want to test the device availability handling alone - zha_gateway.global_updater.stop() + server_gateway.global_updater.stop() # successfully ping zigpy device, but zha_device is not yet available await _send_time_changed( - zha_gateway, zha_gateway._device_availability_checker.__polling_interval + 1 + zha_gateway, server_gateway._device_availability_checker.__polling_interval + 1 ) assert basic_ch.read_attributes.await_count == 1 assert basic_ch.read_attributes.await_args[0][0] == ["manufacturer"] assert zha_device.available is False - for entity in zha_device.platform_entities.values(): + for entity in server_device.platform_entities.values(): entity.emit.assert_not_called() assert not entity.available + if server_device != zha_device: + assert not zha_device.platform_entities[ + (entity.PLATFORM, entity.unique_id) + ].available entity.emit.reset_mock() # There was traffic from the device: pings, but not yet available await _send_time_changed( - zha_gateway, zha_gateway._device_availability_checker.__polling_interval + 1 + zha_gateway, server_gateway._device_availability_checker.__polling_interval + 1 ) assert basic_ch.read_attributes.await_count == 2 assert basic_ch.read_attributes.await_args[0][0] == ["manufacturer"] assert zha_device.available is False - for entity in zha_device.platform_entities.values(): + for entity in server_device.platform_entities.values(): entity.emit.assert_not_called() assert not entity.available + if server_device != zha_device: + assert not zha_device.platform_entities[ + (entity.PLATFORM, entity.unique_id) + ].available entity.emit.reset_mock() # There was traffic from the device: don't try to ping, marked as available await _send_time_changed( - zha_gateway, zha_gateway._device_availability_checker.__polling_interval + 1 + zha_gateway, server_gateway._device_availability_checker.__polling_interval + 1 ) assert basic_ch.read_attributes.await_count == 2 assert basic_ch.read_attributes.await_args[0][0] == ["manufacturer"] assert zha_device.available is True assert zha_device.on_network is True - for entity in zha_device.platform_entities.values(): + for entity in server_device.platform_entities.values(): entity.emit.assert_called() + if server_device != zha_device: + assert zha_device.platform_entities[ + (entity.PLATFORM, entity.unique_id) + ].available assert entity.available entity.emit.reset_mock() assert "Device is not on the network, marking unavailable" not in caplog.text - zha_device.on_network = False + server_gateway._device_availability_checker.stop() + + server_device.on_network = False + await zha_gateway.async_block_till_done(wait_background_tasks=True) assert zha_device.available is False assert zha_device.on_network is False assert "Device is not on the network, marking unavailable" in caplog.text - for entity in zha_device.platform_entities.values(): + for entity in server_device.platform_entities.values(): entity.emit.assert_called() assert not entity.available + if server_device != zha_device: + assert not zha_device.platform_entities[ + (entity.PLATFORM, entity.unique_id) + ].available entity.emit.reset_mock() @@ -202,6 +246,14 @@ def _update_last_seen(*args, **kwargs): # pylint: disable=unused-argument "zha.zigbee.cluster_handlers.general.BasicClusterHandler.async_initialize", new=mock.AsyncMock(), ) +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_check_available_unsuccessful( zha_gateway: Gateway, ) -> None: @@ -213,59 +265,78 @@ async def test_check_available_unsuccessful( zha_device = await join_zigpy_device(zha_gateway, device_with_basic_cluster_handler) basic_ch = device_with_basic_cluster_handler.endpoints[3].basic + if hasattr(zha_gateway, "ws_gateway"): + server_device = zha_gateway.ws_gateway.devices[zha_device.ieee] + server_gateway = zha_gateway.ws_gateway + else: + server_device = zha_device + server_gateway = zha_gateway + assert zha_device.available is True assert basic_ch.read_attributes.await_count == 0 device_with_basic_cluster_handler.last_seen = ( - time.time() - zha_device.consider_unavailable_time - 2 + time.time() - server_device.consider_unavailable_time - 2 ) - for entity in zha_device.platform_entities.values(): + for entity in server_device.platform_entities.values(): entity.emit = mock.MagicMock(wraps=entity.emit) # we want to test the device availability handling alone - zha_gateway.global_updater.stop() + server_gateway.global_updater.stop() # unsuccessfully ping zigpy device, but zha_device is still available await _send_time_changed( - zha_gateway, zha_gateway._device_availability_checker.__polling_interval + 1 + zha_gateway, server_gateway._device_availability_checker.__polling_interval + 1 ) assert basic_ch.read_attributes.await_count == 1 assert basic_ch.read_attributes.await_args[0][0] == ["manufacturer"] assert zha_device.available is True - for entity in zha_device.platform_entities.values(): + for entity in server_device.platform_entities.values(): entity.emit.assert_not_called() assert entity.available + if server_device != zha_device: + assert zha_device.platform_entities[ + (entity.PLATFORM, entity.unique_id) + ].available entity.emit.reset_mock() # still no traffic, but zha_device is still available await _send_time_changed( - zha_gateway, zha_gateway._device_availability_checker.__polling_interval + 1 + zha_gateway, server_gateway._device_availability_checker.__polling_interval + 1 ) assert basic_ch.read_attributes.await_count == 2 assert basic_ch.read_attributes.await_args[0][0] == ["manufacturer"] assert zha_device.available is True - for entity in zha_device.platform_entities.values(): + for entity in server_device.platform_entities.values(): entity.emit.assert_not_called() assert entity.available + if server_device != zha_device: + assert zha_device.platform_entities[ + (entity.PLATFORM, entity.unique_id) + ].available entity.emit.reset_mock() # not even trying to update, device is unavailable await _send_time_changed( - zha_gateway, zha_gateway._device_availability_checker.__polling_interval + 1 + zha_gateway, server_gateway._device_availability_checker.__polling_interval + 1 ) assert basic_ch.read_attributes.await_count == 2 assert basic_ch.read_attributes.await_args[0][0] == ["manufacturer"] assert zha_device.available is False - for entity in zha_device.platform_entities.values(): + for entity in server_device.platform_entities.values(): entity.emit.assert_called() assert not entity.available + if server_device != zha_device: + assert not zha_device.platform_entities[ + (entity.PLATFORM, entity.unique_id) + ].available entity.emit.reset_mock() @@ -705,9 +776,15 @@ async def test_device_automation_triggers( } -async def test_device_properties( - zha_gateway: Gateway, -) -> None: +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) +async def test_device_properties(zha_gateway: Gateway) -> None: """Test device properties.""" zigpy_dev = zigpy_device(zha_gateway, with_basic_cluster_handler=True) zha_device = await join_zigpy_device(zha_gateway, zigpy_dev) @@ -731,10 +808,107 @@ async def test_device_properties( assert zha_device.manufacturer == "FakeManufacturer" assert zha_device.model == "FakeModel" assert zha_device.is_groupable is False + assert zha_device.quirk_applied is False + assert zha_device.quirk_class == "zigpy.device.Device" + assert zha_device.quirk_id is None - assert zha_device.power_configuration_ch is None - assert zha_device.basic_ch is not None + assert zha_device.device_automation_commands == {} + assert zha_device.device_automation_triggers == { + ("device_offline", "device_offline"): {"device_event_type": "device_offline"} + } assert zha_device.sw_version is None + assert isinstance(zha_device.extended_device_info, ExtendedDeviceInfo) + assert zha_device.extended_device_info.manufacturer == "FakeManufacturer" + assert zha_device.extended_device_info.model == "FakeModel" + assert zha_device.extended_device_info.power_source == "Battery or Unknown" + assert zha_device.extended_device_info.device_type == "EndDevice" + assert zha_device.extended_device_info.ieee == zigpy_dev.ieee + assert zha_device.extended_device_info.nwk == zigpy_dev.nwk + assert zha_device.extended_device_info.manufacturer_code == 0x1037 + assert zha_device.extended_device_info.name == "FakeManufacturer FakeModel" + assert zha_device.extended_device_info.is_groupable is False + assert zha_device.extended_device_info.on_network is True + assert zha_device.extended_device_info.last_seen is not None + assert zha_device.extended_device_info.last_seen < time.time() + assert zha_device.extended_device_info.quirk_applied is False + assert zha_device.extended_device_info.quirk_class == "zigpy.device.Device" + assert zha_device.extended_device_info.quirk_id is None + assert zha_device.extended_device_info.sw_version is None + assert zha_device.extended_device_info.device_type == "EndDevice" + assert zha_device.extended_device_info.power_source == "Battery or Unknown" + assert zha_device.extended_device_info.last_seen_time is not None + assert zha_device.extended_device_info.available is True + assert zha_device.extended_device_info.lqi is None + assert zha_device.extended_device_info.rssi is None + + # TODO this needs to be fixed + if not hasattr(zha_gateway, "ws_gateway"): + assert zha_device.zigbee_signature == { + "endpoints": { + 3: { + "device_type": "0x0000", + "input_clusters": [ + "0x0000", + "0x0006", + ], + "output_clusters": [], + "profile_id": "", + }, + }, + "manufacturer": "FakeManufacturer", + "model": "FakeModel", + "node_descriptor": zdo_t.NodeDescriptor( + logical_type=zdo_t.LogicalType.EndDevice, + complex_descriptor_available=0, + user_descriptor_available=0, + reserved=0, + aps_flags=0, + frequency_band=zdo_t._NodeDescriptorEnums.FrequencyBand.Freq2400MHz, + mac_capability_flags=zdo_t._NodeDescriptorEnums.MACCapabilityFlags.AllocateAddress, + manufacturer_code=4151, + maximum_buffer_size=127, + maximum_incoming_transfer_size=100, + server_mask=10752, + maximum_outgoing_transfer_size=100, + descriptor_capability_field=zdo_t._NodeDescriptorEnums.DescriptorCapability.NONE, + ), + } + else: + assert zha_device.zigbee_signature == { + "endpoints": { + "3": { + "device_type": "0x0000", + "input_clusters": [ + "0x0000", + "0x0006", + ], + "output_clusters": [], + "profile_id": "", + }, + }, + "manufacturer": "FakeManufacturer", + "model": "FakeModel", + "node_descriptor": { + "aps_flags": 0, + "complex_descriptor_available": 0, + "descriptor_capability_field": 0, + "frequency_band": 8, + "logical_type": 2, + "mac_capability_flags": 128, + "manufacturer_code": 4151, + "maximum_buffer_size": 127, + "maximum_incoming_transfer_size": 100, + "maximum_outgoing_transfer_size": 100, + "reserved": 0, + "server_mask": 10752, + "user_descriptor_available": 0, + }, + } + + if not hasattr(zha_gateway, "ws_gateway"): + assert zha_device.power_configuration_ch is None + assert zha_device.basic_ch is not None + assert zha_device.sw_version is None assert len(zha_device.platform_entities) == 3 assert ( @@ -750,54 +924,58 @@ async def test_device_properties( "00:0d:6f:00:0a:90:69:e7-3-6", ) in zha_device.platform_entities - assert isinstance( - zha_device.platform_entities[ - (Platform.SENSOR, "00:0d:6f:00:0a:90:69:e7-3-0-lqi") - ], - LQISensor, - ) - assert isinstance( - zha_device.platform_entities[ - (Platform.SENSOR, "00:0d:6f:00:0a:90:69:e7-3-0-rssi") - ], - RSSISensor, - ) - assert isinstance( - zha_device.platform_entities[(Platform.SWITCH, "00:0d:6f:00:0a:90:69:e7-3-6")], - Switch, - ) + if not hasattr(zha_gateway, "ws_gateway"): + assert isinstance( + zha_device.platform_entities[ + (Platform.SENSOR, "00:0d:6f:00:0a:90:69:e7-3-0-lqi") + ], + LQISensor, + ) + assert isinstance( + zha_device.platform_entities[ + (Platform.SENSOR, "00:0d:6f:00:0a:90:69:e7-3-0-rssi") + ], + RSSISensor, + ) + assert isinstance( + zha_device.platform_entities[ + (Platform.SWITCH, "00:0d:6f:00:0a:90:69:e7-3-6") + ], + Switch, + ) - assert ( - zha_device.get_platform_entity( - Platform.SENSOR, "00:0d:6f:00:0a:90:69:e7-3-0-lqi" + assert ( + zha_device.get_platform_entity( + Platform.SENSOR, "00:0d:6f:00:0a:90:69:e7-3-0-lqi" + ) + is not None + ) + assert isinstance( + zha_device.get_platform_entity( + Platform.SENSOR, "00:0d:6f:00:0a:90:69:e7-3-0-lqi" + ), + LQISensor, ) - is not None - ) - assert isinstance( - zha_device.get_platform_entity( - Platform.SENSOR, "00:0d:6f:00:0a:90:69:e7-3-0-lqi" - ), - LQISensor, - ) - with pytest.raises(KeyError, match="Entity foo not found"): + with pytest.raises(KeyError, match="('bar', 'foo')"): zha_device.get_platform_entity("bar", "foo") - # test things are none when they aren't returned by Zigpy - zigpy_dev.node_desc = None - delattr(zha_device, "manufacturer_code") - delattr(zha_device, "is_mains_powered") - delattr(zha_device, "device_type") - delattr(zha_device, "is_router") - delattr(zha_device, "is_end_device") - delattr(zha_device, "is_coordinator") + if not hasattr(zha_gateway, "ws_gateway"): + # test things are none when they aren't returned by Zigpy + zigpy_dev.node_desc = None + delattr(zha_device, "manufacturer_code") + delattr(zha_device, "is_mains_powered") + delattr(zha_device, "device_type") + delattr(zha_device, "is_router") + delattr(zha_device, "is_end_device") + delattr(zha_device, "is_coordinator") - assert zha_device.manufacturer_code is None - assert zha_device.is_mains_powered is None - assert zha_device.device_type is UNKNOWN - assert zha_device.is_router is None - assert zha_device.is_end_device is None - assert zha_device.is_coordinator is None + assert zha_device.manufacturer_code is None + assert zha_device.is_mains_powered is None + assert zha_device.device_type is UNKNOWN + assert zha_device.is_router is None + assert zha_device.is_end_device is None + assert zha_device.is_coordinator is None async def test_quirks_v2_device_renaming(zha_gateway: Gateway) -> None: @@ -820,3 +998,139 @@ async def test_quirks_v2_device_renaming(zha_gateway: Gateway) -> None: zha_device = await join_zigpy_device(zha_gateway, zigpy_dev) assert zha_device.model == "IRIS Keypad V2" assert zha_device.manufacturer == "Lowe's" + + +def test_neighbor_info_ser_deser() -> None: + """Test the serialization and deserialization of the neighbor info.""" + + neighbor_info = NeighborInfo( + ieee="00:0d:6f:00:0a:90:69:e7", + nwk="0x1234", + extended_pan_id="00:0d:6f:00:0a:90:69:e7", + lqi=255, + relationship=zdo_t._NeighborEnums.Relationship.Child.name, + depth=0, + device_type=zdo_t._NeighborEnums.DeviceType.Router.name, + rx_on_when_idle=zdo_t._NeighborEnums.RxOnWhenIdle.On.name, + permit_joining=zdo_t._NeighborEnums.PermitJoins.Accepting.name, + ) + + assert isinstance(neighbor_info.ieee, zigpy.types.EUI64) + assert isinstance(neighbor_info.nwk, zigpy.types.NWK) + assert isinstance(neighbor_info.extended_pan_id, zigpy.types.EUI64) + assert isinstance(neighbor_info.relationship, zdo_t._NeighborEnums.Relationship) + assert isinstance(neighbor_info.device_type, zdo_t._NeighborEnums.DeviceType) + assert isinstance(neighbor_info.rx_on_when_idle, zdo_t._NeighborEnums.RxOnWhenIdle) + assert isinstance(neighbor_info.permit_joining, zdo_t._NeighborEnums.PermitJoins) + + assert neighbor_info.model_dump() == { + "ieee": "00:0d:6f:00:0a:90:69:e7", + "nwk": 0x1234, + "extended_pan_id": "00:0d:6f:00:0a:90:69:e7", + "lqi": 255, + "relationship": zdo_t._NeighborEnums.Relationship.Child.name, + "depth": 0, + "device_type": zdo_t._NeighborEnums.DeviceType.Router.name, + "rx_on_when_idle": zdo_t._NeighborEnums.RxOnWhenIdle.On.name, + "permit_joining": zdo_t._NeighborEnums.PermitJoins.Accepting.name, + } + + assert neighbor_info.model_dump_json() == ( + '{"device_type":"Router","rx_on_when_idle":"On","relationship":"Child",' + '"extended_pan_id":"00:0d:6f:00:0a:90:69:e7","ieee":"00:0d:6f:00:0a:90:69:e7","nwk":"0x1234",' + '"permit_joining":"Accepting","depth":0,"lqi":255}' + ) + + +def test_route_info_ser_deser() -> None: + """Test the serialization and deserialization of the route info.""" + + route_info = RouteInfo( + dest_nwk=0x1234, + next_hop=0x5678, + route_status=zdo_t.RouteStatus.Active.name, + memory_constrained=0, + many_to_one=1, + route_record_required=1, + ) + + assert isinstance(route_info.dest_nwk, zigpy.types.NWK) + assert isinstance(route_info.next_hop, zigpy.types.NWK) + assert isinstance(route_info.route_status, zdo_t.RouteStatus) + + assert route_info.model_dump() == { + "dest_nwk": 0x1234, + "next_hop": 0x5678, + "route_status": zdo_t.RouteStatus.Active.name, + "memory_constrained": 0, + "many_to_one": 1, + "route_record_required": 1, + } + + assert route_info.model_dump_json() == ( + '{"dest_nwk":"0x1234","route_status":"Active","memory_constrained":0,"many_to_one":1,' + '"route_record_required":1,"next_hop":"0x5678"}' + ) + + +def test_convert_extended_pan_id() -> None: + """Test conversion of extended panid.""" + + extended_pan_id = zigpy.types.ExtendedPanId.convert("00:0d:6f:00:0a:90:69:e7") + + assert NeighborInfo.convert_extended_pan_id(extended_pan_id) == extended_pan_id + + converted_extended_pan_id = NeighborInfo.convert_extended_pan_id( + "00:0d:6f:00:0a:90:69:e7" + ) + assert isinstance(converted_extended_pan_id, zigpy.types.ExtendedPanId) + assert converted_extended_pan_id == extended_pan_id + + +async def test_extended_device_info_ser_deser(zha_gateway: Gateway) -> None: + """Test the serialization and deserialization of the extended device info.""" + + zigpy_dev = await zigpy_device_from_json( + zha_gateway.application_controller, "tests/data/devices/centralite-3320-l.json" + ) + zha_device = await join_zigpy_device(zha_gateway, zigpy_dev) + assert zha_device is not None + + assert isinstance(zha_device.extended_device_info.ieee, zigpy.types.EUI64) + assert isinstance(zha_device.extended_device_info.nwk, zigpy.types.NWK) + + # last_seen changes so we exclude it from the comparison + json_string = zha_device.extended_device_info.model_dump_json( + exclude=["last_seen", "last_seen_time"] + ) + + # load the json from a file as string + with open( + "tests/data/serialization_data/centralite-3320-l-extended-device-info.json", + encoding="UTF-8", + ) as file: + expected_json = file.read() + + assert deep_compare(json.loads(json_string), json.loads(expected_json)) + + +def deep_compare(obj1, obj2): + """Recursively compare two objects.""" + if isinstance(obj1, Mapping) and isinstance(obj2, Mapping): + # Compare dictionaries (order of keys doesn't matter) + if obj1.keys() != obj2.keys(): + return False + return all(deep_compare(obj1[key], obj2[key]) for key in obj1) + + elif ( + isinstance(obj1, Sequence) + and isinstance(obj2, Sequence) + and not isinstance(obj1, str) + ): + # Compare lists or other sequences as sets, ignoring order + return len(obj1) == len(obj2) and all( + any(deep_compare(item1, item2) for item2 in obj2) for item1 in obj1 + ) + + # Base case: compare values directly + return obj1 == obj2 diff --git a/tests/test_device_tracker.py b/tests/test_device_tracker.py index 1612f937a..5b9c39f12 100644 --- a/tests/test_device_tracker.py +++ b/tests/test_device_tracker.py @@ -4,6 +4,7 @@ import time from unittest.mock import AsyncMock +import pytest import zigpy.profiles.zha from zigpy.zcl.clusters import general @@ -19,10 +20,19 @@ ) from zha.application import Platform from zha.application.gateway import Gateway +from zha.application.platforms import WebSocketClientEntity from zha.application.platforms.device_tracker import SourceType from zha.application.registries import SMARTTHINGS_ARRIVAL_SENSOR_DEVICE_TYPE +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_device_tracker( zha_gateway: Gateway, ) -> None: @@ -55,11 +65,27 @@ async def test_device_tracker( zha_gateway, cluster, {0x0000: 0, 0x0020: 23, 0x0021: 200, 0x0001: 2} ) - entity.async_update = AsyncMock(wraps=entity.async_update) + if isinstance(entity, WebSocketClientEntity): + server_entity = get_entity( + zha_gateway.ws_gateway.devices[zha_device.ieee], + platform=Platform.DEVICE_TRACKER, + ) + original_async_update = server_entity.async_update + server_entity.async_update = AsyncMock(wraps=server_entity.async_update) + async_update_mock = server_entity.async_update + else: + entity.async_update = AsyncMock(wraps=entity.async_update) + async_update_mock = entity.async_update + + async_update_mock.reset_mock() zigpy_device_dt.last_seen = time.time() + 10 await asyncio.sleep(48) await zha_gateway.async_block_till_done() - assert entity.async_update.await_count == 1 + assert async_update_mock.await_count == 1 + + # this is because of the argspec stuff w/ WS calls... Look for a better solution + if isinstance(entity, WebSocketClientEntity): + server_entity.async_update = original_async_update assert entity.state["connected"] is True assert entity.is_connected is True diff --git a/tests/test_discover.py b/tests/test_discover.py index 7455849e0..83c31c65b 100644 --- a/tests/test_discover.py +++ b/tests/test_discover.py @@ -52,7 +52,7 @@ from zha.application import Platform, discovery from zha.application.discovery import ENDPOINT_PROBE, EndpointProbe from zha.application.gateway import Gateway -from zha.application.helpers import DeviceOverridesConfiguration +from zha.application.model import DeviceOverridesConfiguration from zha.application.platforms import binary_sensor, sensor from zha.application.registries import SINGLE_INPUT_CLUSTER_DEVICE_CLASS from zha.zigbee.cluster_handlers import ClusterHandler diff --git a/tests/test_fan.py b/tests/test_fan.py index 8bc96c5f5..b4c08d0aa 100644 --- a/tests/test_fan.py +++ b/tests/test_fan.py @@ -30,7 +30,7 @@ ) from zha.application import Platform from zha.application.gateway import Gateway -from zha.application.platforms import GroupEntity, PlatformEntity +from zha.application.platforms import GroupEntity, PlatformEntity, WebSocketClientEntity from zha.application.platforms.fan.const import ( ATTR_PERCENTAGE, ATTR_PRESET_MODE, @@ -41,6 +41,7 @@ SPEED_LOW, SPEED_MEDIUM, SPEED_OFF, + FanEntityFeature, ) from zha.application.platforms.fan.helpers import NotValidPresetModeError from zha.exceptions import ZHAException @@ -135,6 +136,14 @@ async def device_fan_2_mock(zha_gateway: Gateway) -> Device: return zha_device +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_fan( zha_gateway: Gateway, ) -> None: @@ -146,14 +155,36 @@ async def test_fan( entity = get_entity(zha_device, platform=Platform.FAN) assert entity.state["is_on"] is False + assert entity.is_on is False + + assert entity.preset_modes == [PRESET_MODE_ON, PRESET_MODE_AUTO, PRESET_MODE_SMART] + assert entity.speed_list == [ + SPEED_OFF, + SPEED_LOW, + SPEED_MEDIUM, + SPEED_HIGH, + PRESET_MODE_ON, + PRESET_MODE_AUTO, + PRESET_MODE_SMART, + ] + assert entity.speed_count == 3 + assert entity.default_on_percentage == 50 + assert ( + entity.supported_features + == FanEntityFeature.SET_SPEED + | FanEntityFeature.TURN_OFF + | FanEntityFeature.TURN_ON + ) # turn on at fan await send_attributes_report(zha_gateway, cluster, {1: 2, 0: 1, 2: 3}) assert entity.state["is_on"] is True + assert entity.is_on is True # turn off at fan await send_attributes_report(zha_gateway, cluster, {1: 1, 0: 0, 2: 2}) assert entity.state["is_on"] is False + assert entity.is_on is False # turn on from client cluster.write_attributes.reset_mock() @@ -163,6 +194,7 @@ async def test_fan( {"fan_mode": 2}, manufacturer=None ) assert entity.state["is_on"] is True + assert entity.is_on is True # turn off from client cluster.write_attributes.reset_mock() @@ -182,6 +214,7 @@ async def test_fan( ) assert entity.state["is_on"] is True assert entity.state["speed"] == SPEED_HIGH + assert entity.speed == SPEED_HIGH # change preset_mode from client cluster.write_attributes.reset_mock() @@ -192,6 +225,7 @@ async def test_fan( ) assert entity.state["is_on"] is True assert entity.state["preset_mode"] == PRESET_MODE_ON + assert entity.preset_mode == PRESET_MODE_ON # test set percentage from client cluster.write_attributes.reset_mock() @@ -203,12 +237,18 @@ async def test_fan( ) # this is converted to a ranged value assert entity.state["percentage"] == 66 + assert entity.percentage == 66 assert entity.state["is_on"] is True # set invalid preset_mode from client cluster.write_attributes.reset_mock() - with pytest.raises(NotValidPresetModeError): + exception = ( + ZHAException + if isinstance(entity, WebSocketClientEntity) + else NotValidPresetModeError + ) + with pytest.raises(exception): await entity.async_set_preset_mode("invalid") assert len(cluster.write_attributes.mock_calls) == 0 @@ -274,10 +314,19 @@ async def async_set_preset_mode( "zigpy.zcl.clusters.hvac.Fan.write_attributes", new=AsyncMock(return_value=zcl_f.WriteAttributesResponse.deserialize(b"\x00")[0]), ) +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_zha_group_fan_entity( zha_gateway: Gateway, ): """Test the fan entity for a ZHAWS group.""" + device_fan_1 = await device_fan_1_mock(zha_gateway) device_fan_2 = await device_fan_2_mock(zha_gateway) member_ieee_addresses = [device_fan_1.ieee, device_fan_2.ieee] @@ -286,8 +335,7 @@ async def test_zha_group_fan_entity( GroupMemberReference(ieee=device_fan_2.ieee, endpoint_id=1), ] - # test creating a group with 2 members - zha_group: Group = await zha_gateway.async_create_zigpy_group("Test Group", members) + zha_group = await zha_gateway.async_create_zigpy_group("Test Group", members) await zha_gateway.async_block_till_done() assert zha_group is not None @@ -295,18 +343,37 @@ async def test_zha_group_fan_entity( for member in zha_group.members: assert member.device.ieee in member_ieee_addresses assert member.group == zha_group - assert member.endpoint is not None + if not hasattr( + zha_gateway, "ws_gateway" + ): # we only have / need this on the server side + assert member.endpoint is not None + assert member.endpoint_id == 1 entity: GroupEntity = get_group_entity(zha_group, platform=Platform.FAN) assert entity.group_id == zha_group.group_id - assert isinstance(entity, GroupEntity) + assert isinstance( + entity, + GroupEntity + if not hasattr(zha_gateway, "ws_gateway") + else WebSocketClientEntity, + ) assert entity.info_object.fallback_name == zha_group.name - group_fan_cluster = zha_group.zigpy_group.endpoint[hvac.Fan.cluster_id] - - dev1_fan_cluster = device_fan_1.device.endpoints[1].fan - dev2_fan_cluster = device_fan_2.device.endpoints[1].fan + if not hasattr(zha_gateway, "ws_gateway"): + group_fan_cluster = zha_group.zigpy_group.endpoint[hvac.Fan.cluster_id] + dev1_fan_cluster = device_fan_1.device.endpoints[1].fan + dev2_fan_cluster = device_fan_2.device.endpoints[1].fan + else: + group_fan_cluster = zha_gateway.ws_gateway.groups[ + zha_group.group_id + ].zigpy_group.endpoint[hvac.Fan.cluster_id] + dev1_fan_cluster = ( + zha_gateway.ws_gateway.devices[device_fan_1.ieee].device.endpoints[1].fan + ) + dev2_fan_cluster = ( + zha_gateway.ws_gateway.devices[device_fan_2.ieee].device.endpoints[1].fan + ) # test that the fan group entity was created and is off assert entity.state["is_on"] is False @@ -380,11 +447,27 @@ async def test_zha_group_fan_entity( # test that group fan is now off assert entity.state["is_on"] is False - await group_entity_availability_test( - zha_gateway, device_fan_1, device_fan_2, entity - ) + if not hasattr(zha_gateway, "ws_gateway"): + await group_entity_availability_test( + zha_gateway, device_fan_1, device_fan_2, entity + ) + else: + await group_entity_availability_test( + zha_gateway, + zha_gateway.ws_gateway.devices[device_fan_1.ieee], + zha_gateway.ws_gateway.devices[device_fan_2.ieee], + entity, + ) +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) @patch( "zigpy.zcl.clusters.hvac.Fan.write_attributes", new=AsyncMock(side_effect=ZigbeeException), @@ -412,19 +495,32 @@ async def test_zha_group_fan_entity_failure_state( for member in zha_group.members: assert member.device.ieee in member_ieee_addresses assert member.group == zha_group - assert member.endpoint is not None + if not hasattr( + zha_gateway, "ws_gateway" + ): # we only have / need this on the server side + assert member.endpoint is not None entity: GroupEntity = get_group_entity(zha_group, platform=Platform.FAN) assert entity.group_id == zha_group.group_id - group_fan_cluster = zha_group.zigpy_group.endpoint[hvac.Fan.cluster_id] + if not hasattr(zha_gateway, "ws_gateway"): + group_fan_cluster = zha_group.zigpy_group.endpoint[hvac.Fan.cluster_id] + else: + group_fan_cluster = zha_gateway.ws_gateway.groups[ + zha_group.group_id + ].zigpy_group.endpoint[hvac.Fan.cluster_id] # test that the fan group entity was created and is off assert entity.state["is_on"] is False # turn on from client group_fan_cluster.write_attributes.reset_mock() - with pytest.raises(ZHAException, match="Failed to send request"): + with pytest.raises( + ZHAException, + match="Failed to send request" + if not hasattr(zha_gateway, "ws_gateway") + else "(6, 'PLATFORM_ENTITY_ACTION_ERROR')", + ): await async_turn_on(zha_gateway, entity) await zha_gateway.async_block_till_done() assert len(group_fan_cluster.write_attributes.mock_calls) == 1 @@ -432,6 +528,14 @@ async def test_zha_group_fan_entity_failure_state( assert "Could not set fan mode" in caplog.text +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) @pytest.mark.parametrize( "plug_read, expected_state, expected_speed, expected_percentage", ( @@ -464,6 +568,14 @@ async def test_fan_init( assert entity.state["preset_mode"] is None +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_fan_update_entity( zha_gateway: Gateway, ): @@ -544,10 +656,19 @@ def zigpy_device_ikea_mock(zha_gateway: Gateway) -> ZigpyDevice: ) +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_fan_ikea( zha_gateway: Gateway, ) -> None: """Test ZHA fan Ikea platform.""" + zigpy_device_ikea = zigpy_device_ikea_mock(zha_gateway) zha_device = await join_zigpy_device(zha_gateway, zigpy_device_ikea) cluster = zigpy_device_ikea.endpoints.get(1).ikea_airpurifier @@ -607,7 +728,12 @@ async def test_fan_ikea( # set invalid preset_mode from HA cluster.write_attributes.reset_mock() - with pytest.raises(NotValidPresetModeError): + exception = ( + ZHAException + if isinstance(entity, WebSocketClientEntity) + else NotValidPresetModeError + ) + with pytest.raises(exception): await async_set_preset_mode( zha_gateway, entity, @@ -616,6 +742,14 @@ async def test_fan_ikea( assert len(cluster.write_attributes.mock_calls) == 0 +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) @pytest.mark.parametrize( ( "ikea_plug_read", @@ -646,6 +780,7 @@ async def test_fan_ikea_init( zha_gateway: Gateway, ) -> None: """Test ZHA fan platform.""" + zigpy_device_ikea = zigpy_device_ikea_mock(zha_gateway) cluster = zigpy_device_ikea.endpoints.get(1).ikea_airpurifier cluster.PLUGGED_ATTR_READS = ikea_plug_read @@ -657,10 +792,19 @@ async def test_fan_ikea_init( assert entity.state["preset_mode"] == ikea_preset_mode +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_fan_ikea_update_entity( zha_gateway: Gateway, ) -> None: """Test ZHA fan platform.""" + zigpy_device_ikea = zigpy_device_ikea_mock(zha_gateway) cluster = zigpy_device_ikea.endpoints.get(1).ikea_airpurifier cluster.PLUGGED_ATTR_READS = {"fan_mode": 0, "fan_speed": 0} @@ -680,7 +824,7 @@ async def test_fan_ikea_update_entity( assert entity.state["is_on"] is True assert entity.state[ATTR_PERCENTAGE] == 60 - assert entity.state[ATTR_PRESET_MODE] is PRESET_MODE_AUTO + assert entity.state[ATTR_PRESET_MODE] == PRESET_MODE_AUTO assert entity.percentage_step == 100 / 10 @@ -728,10 +872,19 @@ def zigpy_device_kof_mock(zha_gateway: Gateway) -> ZigpyDevice: ) +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_fan_kof( zha_gateway: Gateway, ) -> None: """Test ZHA fan platform for King of Fans.""" + zigpy_device_kof = zigpy_device_kof_mock(zha_gateway) zha_device = await join_zigpy_device(zha_gateway, zigpy_device_kof) cluster = zigpy_device_kof.endpoints.get(1).fan @@ -777,13 +930,31 @@ async def test_fan_kof( # set invalid preset_mode from HA cluster.write_attributes.reset_mock() - with pytest.raises(NotValidPresetModeError): + exception = ( + ZHAException + if isinstance(entity, WebSocketClientEntity) + else NotValidPresetModeError + ) + with pytest.raises(exception): await async_set_preset_mode(zha_gateway, entity, preset_mode=PRESET_MODE_AUTO) assert len(cluster.write_attributes.mock_calls) == 0 @pytest.mark.parametrize( - ("plug_read", "expected_state", "expected_percentage", "expected_preset"), + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) +@pytest.mark.parametrize( + ( + "plug_read", + "expected_state", + "expected_percentage", + "expected_preset", + ), [ (None, False, None, None), ({"fan_mode": 0}, False, 0, None), @@ -815,6 +986,14 @@ async def test_fan_kof_init( assert entity.state[ATTR_PRESET_MODE] == expected_preset +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_fan_kof_update_entity( zha_gateway: Gateway, ) -> None: diff --git a/tests/test_gateway.py b/tests/test_gateway.py index 63ad41988..d0c010f8d 100644 --- a/tests/test_gateway.py +++ b/tests/test_gateway.py @@ -23,12 +23,7 @@ join_zigpy_device, ) from zha.application import Platform -from zha.application.const import ( - CONF_USE_THREAD, - ZHA_GW_MSG, - ZHA_GW_MSG_CONNECTION_LOST, - RadioType, -) +from zha.application.const import CONF_USE_THREAD, ZHA_GW_MSG_CONNECTION_LOST, RadioType from zha.application.gateway import ( ConnectionLostEvent, DeviceJoinedDeviceInfo, @@ -38,7 +33,7 @@ RawDeviceInitializedDeviceInfo, RawDeviceInitializedEvent, ) -from zha.application.helpers import ZHAData +from zha.application.model import ZHAData from zha.application.platforms import GroupEntity from zha.application.platforms.light.const import EFFECT_OFF, LightEntityFeature from zha.zigbee.device import Device @@ -72,7 +67,7 @@ async def coordinator_mock(zha_gateway: Gateway) -> Device: } }, ieee="00:15:8d:00:02:32:4f:32", - nwk=0x0000, + nwk=zigpy.types.NWK(0x0000), node_descriptor=zdo_t.NodeDescriptor( logical_type=zdo_t.LogicalType.Coordinator, complex_descriptor_available=0, @@ -507,7 +502,7 @@ async def test_startup_concurrency_limit( } }, ieee=f"11:22:33:44:{i:08x}", - nwk=0x1234 + i, + nwk=zigpy.types.NWK(0x1234 + i), ) zigpy_dev.node_desc.mac_capability_flags |= ( zigpy.zdo.types.NodeDescriptor.MACCapabilityFlags.MainsPowered @@ -615,14 +610,14 @@ def test_gateway_raw_device_initialized( RawDeviceInitializedEvent( device_info=RawDeviceInitializedDeviceInfo( ieee=zigpy.types.EUI64.convert("00:0d:6f:00:0a:90:69:e7"), - nwk=0xB79C, + nwk=zigpy.types.NWK(0xB79C), pairing_status=DevicePairingStatus.INTERVIEW_COMPLETE, model="FakeModel", manufacturer="FakeManufacturer", signature={ "manufacturer": "FakeManufacturer", "model": "FakeModel", - "node_desc": { + "node_descriptor": { "logical_type": LogicalType.EndDevice, "complex_descriptor_available": 0, "user_descriptor_available": 0, @@ -639,16 +634,14 @@ def test_gateway_raw_device_initialized( }, "endpoints": { 1: { - "profile_id": 260, - "device_type": zha.DeviceType.ON_OFF_SWITCH, - "input_clusters": [0], + "profile_id": "0x0104", + "device_type": "0x0000", + "input_clusters": ["0x0000"], "output_clusters": [], } }, }, - ), - event_type="zha_gateway_message", - event="raw_device_initialized", + ) ), ) @@ -668,7 +661,7 @@ def test_gateway_device_joined( DeviceJoinedEvent( device_info=DeviceJoinedDeviceInfo( ieee=zigpy.types.EUI64.convert("00:0d:6f:00:0a:90:69:e7"), - nwk=0xB79C, + nwk=zigpy.types.NWK(0xB79C), pairing_status=DevicePairingStatus.PAIRED, ) ), @@ -687,8 +680,6 @@ def test_gateway_connection_lost(zha_gateway: Gateway) -> None: ZHA_GW_MSG_CONNECTION_LOST, ConnectionLostEvent( exception=exception, - event=ZHA_GW_MSG_CONNECTION_LOST, - event_type=ZHA_GW_MSG, ), ) diff --git a/tests/test_light.py b/tests/test_light.py index 952994238..55fac086a 100644 --- a/tests/test_light.py +++ b/tests/test_light.py @@ -30,15 +30,18 @@ ) from zha.application import Platform from zha.application.gateway import Gateway -from zha.application.platforms import GroupEntity, PlatformEntity +from zha.application.platforms import GroupEntity, PlatformEntity, WebSocketClientEntity from zha.application.platforms.light.const import ( + EFFECT_COLORLOOP, + EFFECT_OFF, FLASH_EFFECTS, FLASH_LONG, FLASH_SHORT, ColorMode, + LightEntityFeature, ) from zha.zigbee.device import Device -from zha.zigbee.group import Group, GroupMemberReference +from zha.zigbee.group import GroupMemberReference ON = 1 OFF = 0 @@ -276,10 +279,19 @@ async def eWeLink_light_mock( return zha_device +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_light_refresh( zha_gateway: Gateway, ): """Test zha light platform refresh.""" + zigpy_device = create_mock_zigpy_device(zha_gateway, LIGHT_ON_OFF) on_off_cluster = zigpy_device.endpoints[1].on_off on_off_cluster.PLUGGED_ATTR_READS = {"on_off": 0} @@ -296,6 +308,7 @@ async def test_light_refresh( assert on_off_cluster.read_attributes.call_count == 0 assert on_off_cluster.read_attributes.await_count == 0 assert bool(entity.state["on"]) is False + assert entity.is_on is False # 1 interval - at least 1 call on_off_cluster.PLUGGED_ATTR_READS = {"on_off": 1} @@ -317,6 +330,7 @@ async def test_light_refresh( read_await_count = on_off_cluster.read_attributes.await_count entity.disable() + await zha_gateway.async_block_till_done() assert entity.enabled is False @@ -328,6 +342,7 @@ async def test_light_refresh( assert bool(entity.state["on"]) is False entity.enable() + await zha_gateway.async_block_till_done() assert entity.enabled is True @@ -338,6 +353,14 @@ async def test_light_refresh( assert bool(entity.state["on"]) is True +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) # TODO reporting is not checked @patch( "zigpy.zcl.clusters.lighting.Color.request", @@ -357,7 +380,11 @@ async def test_light_refresh( ) @pytest.mark.parametrize( "device, reporting", - [(LIGHT_ON_OFF, (1, 0, 0)), (LIGHT_LEVEL, (1, 1, 0)), (LIGHT_COLOR, (1, 1, 3))], + [ + (LIGHT_ON_OFF, (1, 0, 0)), + (LIGHT_LEVEL, (1, 1, 0)), + (LIGHT_COLOR, (1, 1, 3)), + ], ) async def test_light( zha_gateway: Gateway, @@ -462,12 +489,12 @@ async def test_light( cluster_color.request.reset_mock() # test color xy from the client - assert entity.state["xy_color"] != [13369, 18087] - await entity.async_turn_on(brightness=50, xy_color=[13369, 18087]) + assert entity.state["xy_color"] != (13369, 18087) + await entity.async_turn_on(brightness=50, xy_color=(13369, 18087)) await zha_gateway.async_block_till_done() assert entity.state["color_mode"] == ColorMode.XY assert entity.state["brightness"] == 50 - assert entity.state["xy_color"] == [13369, 18087] + assert entity.state["xy_color"] == (13369, 18087) assert cluster_color.request.call_count == 1 assert cluster_color.request.await_count == 1 assert cluster_color.request.call_args == call( @@ -496,9 +523,12 @@ async def async_test_on_off_from_light( await zha_gateway.async_block_till_done() # group member updates are debounced - if isinstance(entity, GroupEntity): + if isinstance(entity, GroupEntity) or ( + isinstance(entity, WebSocketClientEntity) + and "Group" in entity.info_object.class_name + ): assert bool(entity.state["on"]) is False - await asyncio.sleep(0.1) + await asyncio.sleep(1) await zha_gateway.async_block_till_done() assert bool(entity.state["on"]) is True @@ -508,9 +538,12 @@ async def async_test_on_off_from_light( await zha_gateway.async_block_till_done() # group member updates are debounced - if isinstance(entity, GroupEntity): + if isinstance(entity, GroupEntity) or ( + isinstance(entity, WebSocketClientEntity) + and "Group" in entity.info_object.class_name + ): assert bool(entity.state["on"]) is True - await asyncio.sleep(0.1) + await asyncio.sleep(1) await zha_gateway.async_block_till_done() assert bool(entity.state["on"]) is False @@ -529,9 +562,12 @@ async def async_test_on_from_light( await zha_gateway.async_block_till_done() # group member updates are debounced - if isinstance(entity, GroupEntity): + if isinstance(entity, GroupEntity) or ( + isinstance(entity, WebSocketClientEntity) + and "Group" in entity.info_object.class_name + ): assert bool(entity.state["on"]) is False - await asyncio.sleep(0.1) + await asyncio.sleep(1) await zha_gateway.async_block_till_done() assert bool(entity.state["on"]) is True @@ -548,6 +584,7 @@ async def async_test_on_off_from_client( await entity.async_turn_on() await zha_gateway.async_block_till_done() assert bool(entity.state["on"]) is True + assert entity.is_on assert cluster.request.call_count == 1 assert cluster.request.await_count == 1 assert cluster.request.call_args == call( @@ -693,12 +730,17 @@ async def async_test_dimmer_from_light( # hass uses None for brightness of 0 in state attributes if level == 0: assert entity.state["brightness"] is None + assert entity.brightness is None else: # group member updates are debounced - if isinstance(entity, GroupEntity): - await asyncio.sleep(0.1) + if isinstance(entity, GroupEntity) or ( + isinstance(entity, WebSocketClientEntity) + and "Group" in entity.info_object.class_name + ): + await asyncio.sleep(1) await zha_gateway.async_block_till_done() assert entity.state["brightness"] == level + assert entity.brightness == level async def async_test_flash_from_client( @@ -743,6 +785,14 @@ async def async_test_flash_from_client( "zigpy.zcl.clusters.general.OnOff.request", new=AsyncMock(return_value=[sentinel.data, zcl_f.Status.SUCCESS]), ) +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_zha_group_light_entity( zha_gateway: Gateway, ) -> None: @@ -760,7 +810,7 @@ async def test_zha_group_light_entity( ] # test creating a group with 2 members - zha_group: Group = await zha_gateway.async_create_zigpy_group("Test Group", members) + zha_group = await zha_gateway.async_create_zigpy_group("Test Group", members) await zha_gateway.async_block_till_done() assert zha_group is not None @@ -768,7 +818,9 @@ async def test_zha_group_light_entity( for member in zha_group.members: assert member.device.ieee in member_ieee_addresses assert member.group == zha_group - assert member.endpoint is not None + if not hasattr(zha_gateway, "ws_gateway"): + assert member.endpoint is not None + assert member.endpoint_id == 1 entity: GroupEntity = get_group_entity(zha_group, platform=Platform.LIGHT) assert entity.group_id == zha_group.group_id @@ -785,18 +837,53 @@ async def test_zha_group_light_entity( assert device_2_light_entity.unique_id in zha_group.all_member_entity_unique_ids assert device_3_light_entity.unique_id not in zha_group.all_member_entity_unique_ids - group_cluster_on_off = zha_group.zigpy_group.endpoint[general.OnOff.cluster_id] - group_cluster_level = zha_group.zigpy_group.endpoint[ - general.LevelControl.cluster_id - ] - group_cluster_identify = zha_group.zigpy_group.endpoint[general.Identify.cluster_id] - assert group_cluster_identify is not None + if not hasattr(zha_gateway, "ws_gateway"): + group_cluster_on_off = zha_group.zigpy_group.endpoint[general.OnOff.cluster_id] + group_cluster_level = zha_group.zigpy_group.endpoint[ + general.LevelControl.cluster_id + ] + group_cluster_identify = zha_group.zigpy_group.endpoint[ + general.Identify.cluster_id + ] + assert group_cluster_identify is not None - dev1_cluster_on_off = device_light_1.device.endpoints[1].on_off - dev2_cluster_on_off = device_light_2.device.endpoints[1].on_off - dev3_cluster_on_off = device_light_3.device.endpoints[1].on_off + dev1_cluster_on_off = device_light_1.device.endpoints[1].on_off + dev1_cluster_level = device_light_1.device.endpoints[1].level + + dev2_cluster_on_off = device_light_2.device.endpoints[1].on_off + dev3_cluster_on_off = device_light_3.device.endpoints[1].on_off + else: + group_cluster_on_off = zha_gateway.ws_gateway.groups[ + zha_group.group_id + ].endpoint[general.OnOff.cluster_id] + group_cluster_level = zha_gateway.ws_gateway.groups[ + zha_group.group_id + ].endpoint[general.LevelControl.cluster_id] + group_cluster_identify = zha_gateway.ws_gateway.groups[ + zha_group.group_id + ].endpoint[general.Identify.cluster_id] + assert group_cluster_identify is not None + dev1_cluster_on_off = ( + zha_gateway.ws_gateway.devices[device_light_1.ieee] + .device.endpoints[1] + .on_off + ) + dev1_cluster_level = ( + zha_gateway.ws_gateway.devices[device_light_1.ieee] + .device.endpoints[1] + .level + ) - dev1_cluster_level = device_light_1.device.endpoints[1].level + dev2_cluster_on_off = ( + zha_gateway.ws_gateway.devices[device_light_2.ieee] + .device.endpoints[1] + .on_off + ) + dev3_cluster_on_off = ( + zha_gateway.ws_gateway.devices[device_light_3.ieee] + .device.endpoints[1] + .on_off + ) # test that the lights were created and are off assert bool(entity.state["on"]) is False @@ -813,6 +900,7 @@ async def test_zha_group_light_entity( color_mode=ColorMode.XY, effect="colorloop", ) + await zha_gateway.async_block_till_done() assert bool(entity.state["on"]) is False assert bool(entity.state["off_with_transition"]) is False @@ -881,7 +969,7 @@ async def test_zha_group_light_entity( # group member updates are debounced assert bool(entity.state["on"]) is True - await asyncio.sleep(0.1) + await asyncio.sleep(1) await zha_gateway.async_block_till_done() assert bool(entity.state["on"]) is False @@ -893,13 +981,21 @@ async def test_zha_group_light_entity( assert device_2_light_entity.state["on"] is False # group member updates are debounced assert bool(entity.state["on"]) is False - await asyncio.sleep(0.1) + await asyncio.sleep(1) await zha_gateway.async_block_till_done() assert bool(entity.state["on"]) is True - await group_entity_availability_test( - zha_gateway, device_light_1, device_light_2, entity - ) + if not hasattr(zha_gateway, "ws_gateway"): + await group_entity_availability_test( + zha_gateway, device_light_1, device_light_2, entity + ) + else: + await group_entity_availability_test( + zha_gateway, + zha_gateway.ws_gateway.devices[device_light_1.ieee], + zha_gateway.ws_gateway.devices[device_light_2.ieee], + entity, + ) # turn it off to test a new member add being tracked await send_attributes_report(zha_gateway, dev1_cluster_on_off, {0: 0}) @@ -908,7 +1004,7 @@ async def test_zha_group_light_entity( assert device_2_light_entity.state["on"] is False # group member updates are debounced assert bool(entity.state["on"]) is True - await asyncio.sleep(0.1) + await asyncio.sleep(1) await zha_gateway.async_block_till_done() assert bool(entity.state["on"]) is False @@ -929,7 +1025,7 @@ async def test_zha_group_light_entity( assert device_3_light_entity.state["on"] is True # group member updates are debounced assert bool(entity.state["on"]) is False - await asyncio.sleep(0.1) + await asyncio.sleep(1) await zha_gateway.async_block_till_done() assert bool(entity.state["on"]) is True @@ -969,7 +1065,7 @@ async def test_zha_group_light_entity( await zha_gateway.async_block_till_done() # group member updates are debounced assert bool(entity.state["on"]) is True - await asyncio.sleep(0.1) + await asyncio.sleep(1) await zha_gateway.async_block_till_done() assert bool(entity.state["on"]) is False @@ -984,11 +1080,12 @@ async def test_zha_group_light_entity( assert len(zha_group.members) == 4 entity = get_group_entity(zha_group, platform=Platform.LIGHT) assert entity is not None + assert bool(entity.state["on"]) is False await send_attributes_report(zha_gateway, dev2_cluster_on_off, {0: 1}) await zha_gateway.async_block_till_done() # group member updates are debounced assert bool(entity.state["on"]) is False - await asyncio.sleep(0.1) + await asyncio.sleep(1) await zha_gateway.async_block_till_done() assert bool(entity.state["on"]) is True @@ -1009,6 +1106,14 @@ async def test_zha_group_light_entity( get_group_entity(zha_group, platform=Platform.LIGHT) +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) @pytest.mark.parametrize( ("plugged_attr_reads", "config_override", "expected_state"), [ @@ -1099,6 +1204,14 @@ async def test_light_initialization( "zigpy.zcl.clusters.general.OnOff.request", new=AsyncMock(return_value=[sentinel.data, zcl_f.Status.SUCCESS]), ) +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_transitions( zha_gateway: Gateway, ) -> None: @@ -1114,8 +1227,14 @@ async def test_transitions( ] # test creating a group with 2 members - zha_group: Group = await zha_gateway.async_create_zigpy_group("Test Group", members) - await zha_gateway.async_block_till_done() + if not hasattr(zha_gateway, "ws_gateway"): + zha_group = await zha_gateway.async_create_zigpy_group("Test Group", members) + await zha_gateway.async_block_till_done() + else: + zha_group = await zha_gateway.ws_gateway.async_create_zigpy_group( + "Test Group", members + ) + await zha_gateway.async_block_till_done() assert zha_group is not None assert len(zha_group.members) == 2 @@ -1139,17 +1258,64 @@ async def test_transitions( assert device_2_light_entity.unique_id in zha_group.all_member_entity_unique_ids assert eWeLink_light_entity.unique_id not in zha_group.all_member_entity_unique_ids - dev1_cluster_on_off = device_light_1.device.endpoints[1].on_off - dev2_cluster_on_off = device_light_2.device.endpoints[1].on_off - eWeLink_cluster_on_off = eWeLink_light.device.endpoints[1].on_off + if not hasattr(zha_gateway, "ws_gateway"): + dev1_cluster_on_off = device_light_1.device.endpoints[1].on_off + dev1_cluster_level = device_light_1.device.endpoints[1].level + dev1_cluster_color = device_light_1.device.endpoints[1].light_color + + dev2_cluster_on_off = device_light_2.device.endpoints[1].on_off + dev2_cluster_level = device_light_2.device.endpoints[1].level + dev2_cluster_color = device_light_2.device.endpoints[1].light_color + + eWeLink_cluster_on_off = eWeLink_light.device.endpoints[1].on_off + eWeLink_cluster_level = eWeLink_light.device.endpoints[1].level + eWeLink_cluster_color = eWeLink_light.device.endpoints[1].light_color + else: + dev1_cluster_on_off = ( + zha_gateway.ws_gateway.devices[device_light_1.ieee] + .device.endpoints[1] + .on_off + ) + dev1_cluster_level = ( + zha_gateway.ws_gateway.devices[device_light_1.ieee] + .device.endpoints[1] + .level + ) + dev1_cluster_color = ( + zha_gateway.ws_gateway.devices[device_light_1.ieee] + .device.endpoints[1] + .light_color + ) - dev1_cluster_level = device_light_1.device.endpoints[1].level - dev2_cluster_level = device_light_2.device.endpoints[1].level - eWeLink_cluster_level = eWeLink_light.device.endpoints[1].level + dev2_cluster_on_off = ( + zha_gateway.ws_gateway.devices[device_light_2.ieee] + .device.endpoints[1] + .on_off + ) + dev2_cluster_level = ( + zha_gateway.ws_gateway.devices[device_light_2.ieee] + .device.endpoints[1] + .level + ) + dev2_cluster_color = ( + zha_gateway.ws_gateway.devices[device_light_2.ieee] + .device.endpoints[1] + .light_color + ) - dev1_cluster_color = device_light_1.device.endpoints[1].light_color - dev2_cluster_color = device_light_2.device.endpoints[1].light_color - eWeLink_cluster_color = eWeLink_light.device.endpoints[1].light_color + eWeLink_cluster_on_off = ( + zha_gateway.ws_gateway.devices[eWeLink_light.ieee] + .device.endpoints[1] + .on_off + ) + eWeLink_cluster_level = ( + zha_gateway.ws_gateway.devices[eWeLink_light.ieee].device.endpoints[1].level + ) + eWeLink_cluster_color = ( + zha_gateway.ws_gateway.devices[eWeLink_light.ieee] + .device.endpoints[1] + .light_color + ) # test that the lights were created and are off assert bool(entity.state["on"]) is False @@ -1159,6 +1325,7 @@ async def test_transitions( # first test 0 length transition with no color and no brightness provided dev1_cluster_on_off.request.reset_mock() dev1_cluster_level.request.reset_mock() + dev1_cluster_color.request.reset_mock() await device_1_light_entity.async_turn_on(transition=0) await zha_gateway.async_block_till_done() assert dev1_cluster_on_off.request.call_count == 0 @@ -1760,12 +1927,39 @@ async def test_transitions( "zigpy.zcl.clusters.general.OnOff.request", new=AsyncMock(return_value=[sentinel.data, zcl_f.Status.SUCCESS]), ) +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_on_with_off_color(zha_gateway: Gateway) -> None: """Test turning on the light and sending color commands before on/level commands for supporting lights.""" + device_light_1 = await device_light_1_mock(zha_gateway) - dev1_cluster_on_off = device_light_1.device.endpoints[1].on_off - dev1_cluster_level = device_light_1.device.endpoints[1].level - dev1_cluster_color = device_light_1.device.endpoints[1].light_color + + if not hasattr(zha_gateway, "ws_gateway"): + dev1_cluster_on_off = device_light_1.device.endpoints[1].on_off + dev1_cluster_level = device_light_1.device.endpoints[1].level + dev1_cluster_color = device_light_1.device.endpoints[1].light_color + else: + dev1_cluster_on_off = ( + zha_gateway.ws_gateway.devices[device_light_1.ieee] + .device.endpoints[1] + .on_off + ) + dev1_cluster_level = ( + zha_gateway.ws_gateway.devices[device_light_1.ieee] + .device.endpoints[1] + .level + ) + dev1_cluster_color = ( + zha_gateway.ws_gateway.devices[device_light_1.ieee] + .device.endpoints[1] + .light_color + ) entity = get_entity(device_light_1, platform=Platform.LIGHT) @@ -1812,12 +2006,15 @@ async def test_on_with_off_color(zha_gateway: Gateway) -> None: assert entity.state["color_temp"] == 235 assert entity.state["color_mode"] == ColorMode.COLOR_TEMP assert entity.supported_color_modes == {ColorMode.COLOR_TEMP, ColorMode.XY} - assert entity._supported_color_modes == { - ColorMode.COLOR_TEMP, - ColorMode.XY, - ColorMode.ONOFF, - ColorMode.BRIGHTNESS, - } + + # TODO what do we do here... + if not hasattr(zha_gateway, "ws_gateway"): + assert entity._supported_color_modes == { + ColorMode.COLOR_TEMP, + ColorMode.XY, + ColorMode.ONOFF, + ColorMode.BRIGHTNESS, + } # now let's turn off the Execute_if_off option and see if the old behavior is restored dev1_cluster_color.PLUGGED_ATTR_READS = {"options": 0} @@ -1886,6 +2083,14 @@ async def test_on_with_off_color(zha_gateway: Gateway) -> None: "zigpy.zcl.clusters.general.LevelControl.request", new=AsyncMock(return_value=[sentinel.data, zcl_f.Status.SUCCESS]), ) +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_group_member_assume_state(zha_gateway: Gateway) -> None: """Test the group members assume state function.""" @@ -1907,8 +2112,14 @@ async def test_group_member_assume_state(zha_gateway: Gateway) -> None: ] # test creating a group with 2 members - zha_group: Group = await zha_gateway.async_create_zigpy_group("Test Group", members) - await zha_gateway.async_block_till_done() + if not hasattr(zha_gateway, "ws_gateway"): + zha_group = await zha_gateway.async_create_zigpy_group("Test Group", members) + await zha_gateway.async_block_till_done() + else: + zha_group = await zha_gateway.ws_gateway.async_create_zigpy_group( + "Test Group", members + ) + await zha_gateway.async_block_till_done() assert zha_group is not None assert len(zha_group.members) == 2 @@ -1978,8 +2189,17 @@ async def test_group_member_assume_state(zha_gateway: Gateway) -> None: assert device_2_light_entity.state["brightness"] == 100 +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_light_state_restoration(zha_gateway: Gateway) -> None: """Test the light state restoration function.""" + device_light_3 = await device_light_3_mock(zha_gateway) entity = get_entity(device_light_3, platform=Platform.LIGHT) entity.restore_external_state_attributes( @@ -1992,6 +2212,7 @@ async def test_light_state_restoration(zha_gateway: Gateway) -> None: color_mode=ColorMode.XY, effect="colorloop", ) + await zha_gateway.async_block_till_done() assert entity.state["on"] is True assert entity.state["brightness"] == 34 @@ -2010,10 +2231,24 @@ async def test_light_state_restoration(zha_gateway: Gateway) -> None: color_mode=None, effect=None, ) + await zha_gateway.async_block_till_done() assert entity.state["on"] is True + assert entity.is_on assert entity.state["brightness"] == 34 + assert entity.brightness == 34 assert entity.state["color_temp"] == 500 + assert entity.color_temp == 500 assert entity.state["xy_color"] == (1, 2) + assert entity.xy_color == (1, 2) assert entity.state["color_mode"] == ColorMode.XY + assert entity.color_mode == ColorMode.XY assert entity.state["effect"] == "colorloop" + assert entity.effect == "colorloop" + assert entity.effect_list == [EFFECT_OFF, EFFECT_COLORLOOP] + assert ( + entity.supported_features + == LightEntityFeature.EFFECT + | LightEntityFeature.FLASH + | LightEntityFeature.TRANSITION + ) diff --git a/tests/test_lock.py b/tests/test_lock.py index 570e77863..f7979f72e 100644 --- a/tests/test_lock.py +++ b/tests/test_lock.py @@ -2,6 +2,7 @@ from unittest.mock import patch +import pytest import zigpy.profiles.zha from zigpy.zcl.clusters import closures, general import zigpy.zcl.foundation as zcl_f @@ -39,6 +40,14 @@ } +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_lock(zha_gateway: Gateway) -> None: """Test zha lock platform.""" @@ -205,17 +214,29 @@ async def async_disable_user_code( assert cluster.request.call_args[0][4] == closures.DoorLock.UserStatus.Disabled +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_lock_state_restoration(zha_gateway: Gateway) -> None: """Test the lock state restoration.""" + zigpy_device = create_mock_zigpy_device(zha_gateway, ZIGPY_LOCK) zha_device = await join_zigpy_device(zha_gateway, zigpy_device) entity = get_entity(zha_device, platform=Platform.LOCK) assert entity.state["is_locked"] is False + assert entity.is_locked is False entity.restore_external_state_attributes(state=STATE_LOCKED) + await zha_gateway.async_block_till_done() # needed for WS commands assert entity.state["is_locked"] is True entity.restore_external_state_attributes(state=STATE_UNLOCKED) + await zha_gateway.async_block_till_done() # needed for WS commands assert entity.state["is_locked"] is False diff --git a/tests/test_model.py b/tests/test_model.py new file mode 100644 index 000000000..9cdc6b22b --- /dev/null +++ b/tests/test_model.py @@ -0,0 +1,115 @@ +"""Tests for the ZHA model module.""" + +from collections.abc import Callable +from enum import Enum + +from zigpy.types import NWK +from zigpy.types.named import EUI64 + +from zha.model import convert_enum +from zha.zigbee.device import DeviceInfo, ZHAEvent + + +def test_ser_deser_zha_event(): + """Test serializing and deserializing ZHA events.""" + + zha_event = ZHAEvent( + device_ieee="00:00:00:00:00:00:00:00", + unique_id="00:00:00:00:00:00:00:00", + data={"key": "value"}, + ) + + assert isinstance(zha_event.device_ieee, EUI64) + assert zha_event.device_ieee == EUI64.convert("00:00:00:00:00:00:00:00") + assert zha_event.unique_id == "00:00:00:00:00:00:00:00" + assert zha_event.data == {"key": "value"} + + assert zha_event.model_dump() == { + "message_type": "event", + "event_type": "device_event", + "event": "zha_event", + "device_ieee": "00:00:00:00:00:00:00:00", + "unique_id": "00:00:00:00:00:00:00:00", + "data": {"key": "value"}, + "model_class_name": "ZHAEvent", + } + + assert ( + zha_event.model_dump_json() + == '{"message_type":"event","event_type":"device_event","event":"zha_event",' + '"device_ieee":"00:00:00:00:00:00:00:00","unique_id":"00:00:00:00:00:00:00:00",' + '"data":{"key":"value"},"model_class_name":"ZHAEvent"}' + ) + + device_info = DeviceInfo( + ieee="00:00:00:00:00:00:00:00", + nwk="0x0000", + manufacturer="test", + model="test", + name="test", + quirk_applied=True, + quirk_class="test", + quirk_id="test", + manufacturer_code=0x0000, + power_source="test", + lqi=1, + rssi=2, + last_seen=123456789.0, + last_seen_time=None, + available=True, + on_network=True, + is_groupable=True, + device_type="test", + signature={"foo": "bar"}, + ) + + assert isinstance(device_info.ieee, EUI64) + assert device_info.ieee == EUI64.convert("00:00:00:00:00:00:00:00") + assert isinstance(device_info.nwk, NWK) + + assert device_info.model_dump() == { + "ieee": "00:00:00:00:00:00:00:00", + "nwk": 0x0000, + "manufacturer": "test", + "model": "test", + "name": "test", + "quirk_applied": True, + "quirk_class": "test", + "quirk_id": "test", + "manufacturer_code": 0, + "power_source": "test", + "lqi": 1, + "rssi": 2, + "last_seen": 123456789.0, + "last_seen_time": None, + "available": True, + "on_network": True, + "is_groupable": True, + "device_type": "test", + "signature": {"foo": "bar"}, + "sw_version": None, + } + + assert device_info.model_dump_json() == ( + '{"ieee":"00:00:00:00:00:00:00:00","nwk":"0x0000",' + '"manufacturer":"test","model":"test","name":"test","quirk_applied":true,' + '"quirk_class":"test","quirk_id":"test","manufacturer_code":0,"power_source":"test",' + '"lqi":1,"rssi":2,"last_seen":123456789.0,"last_seen_time":null,"available":true,' + '"on_network":true,"is_groupable":true,"device_type":"test","signature":{"foo":"bar"},' + '"sw_version":null}' + ) + + +def test_convert_enum() -> None: + """Test the convert enum method.""" + + class TestEnum(Enum): + """Test enum.""" + + VALUE = 1 + + convert_test_enum: Callable[[str | Enum], Enum] = convert_enum(TestEnum) + + assert convert_test_enum(TestEnum.VALUE.name) == TestEnum.VALUE + assert isinstance(convert_test_enum(TestEnum.VALUE.name), TestEnum) + assert convert_test_enum(TestEnum.VALUE) == TestEnum.VALUE diff --git a/tests/test_number.py b/tests/test_number.py index 7408534d4..1c389db3b 100644 --- a/tests/test_number.py +++ b/tests/test_number.py @@ -23,7 +23,11 @@ ) from zha.application import Platform from zha.application.gateway import Gateway -from zha.application.platforms import EntityCategory, PlatformEntity +from zha.application.platforms import ( + EntityCategory, + PlatformEntity, + WebSocketClientEntity, +) from zha.application.platforms.number.const import NumberMode from zha.exceptions import ZHAException @@ -79,6 +83,14 @@ async def light_mock(zha_gateway: Gateway) -> ZigpyDevice: return zigpy_device +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_number( zha_gateway: Gateway, ) -> None: @@ -115,8 +127,13 @@ async def test_number( assert "engineering_units" in attr_reads assert "application_type" in attr_reads + entity_type = ( + PlatformEntity + if not hasattr(zha_gateway, "ws_gateway") + else WebSocketClientEntity + ) entity: PlatformEntity = get_entity(zha_device, platform=Platform.NUMBER) - assert isinstance(entity, PlatformEntity) + assert isinstance(entity, entity_type) assert cluster.read_attributes.call_count == 3 @@ -124,6 +141,7 @@ async def test_number( # test that the state is 15.0 assert entity.state["state"] == 15.0 + assert entity.native_value == 15.0 # test attributes assert entity.info_object.min_value == 1.0 @@ -179,6 +197,14 @@ async def test_number( assert entity.state["state"] == 30.0 +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) @pytest.mark.parametrize( ("attr", "initial_value", "new_value", "max_value"), ( @@ -310,6 +336,14 @@ async def test_level_control_number( assert entity.state["state"] == initial_value +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) @pytest.mark.parametrize( ("attr", "initial_value", "new_value"), (("start_up_color_temperature", 500, 350),), diff --git a/tests/test_select.py b/tests/test_select.py index 57b1411f8..c005f3cf6 100644 --- a/tests/test_select.py +++ b/tests/test_select.py @@ -2,6 +2,7 @@ from unittest.mock import call, patch +import pytest from zhaquirks import ( DEVICE_TYPE, ENDPOINTS, @@ -29,12 +30,21 @@ ) from zha.application import Platform from zha.application.gateway import Gateway -from zha.application.platforms import EntityCategory +from zha.application.platforms import EntityCategory, PlatformEntity from zha.application.platforms.select import AqaraMotionSensitivities +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_select(zha_gateway: Gateway) -> None: """Test zha select platform.""" + zigpy_device = create_mock_zigpy_device( zha_gateway, { @@ -63,7 +73,9 @@ async def test_select(zha_gateway: Gateway) -> None: "Fire Panic", "Emergency Panic", ] - assert entity._enum == security.IasWd.Warning.WarningMode + + if isinstance(entity, PlatformEntity): + assert entity._enum == security.IasWd.Warning.WarningMode # change value from client await entity.async_select_option(security.IasWd.Warning.WarningMode.Burglar.name) @@ -107,6 +119,14 @@ def __init__(self, *args, **kwargs): } +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_on_off_select_attribute_report(zha_gateway: Gateway) -> None: """Test ZHA attribute report parsing for select platform.""" @@ -126,7 +146,7 @@ async def test_on_off_select_attribute_report(zha_gateway: Gateway) -> None: zigpy_device = get_device(zigpy_device) aqara_sensor = await join_zigpy_device(zha_gateway, zigpy_device) - cluster = aqara_sensor.device.endpoints.get(1).opple_cluster + cluster = zigpy_device.endpoints.get(1).opple_cluster entity = get_entity(aqara_sensor, platform=Platform.SELECT) assert entity.state["state"] == AqaraMotionSensitivities.Medium.name @@ -160,9 +180,15 @@ async def test_on_off_select_attribute_report(zha_gateway: Gateway) -> None: ) -async def test_on_off_select_attribute_report_v2( - zha_gateway: Gateway, -) -> None: +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) +async def test_on_off_select_attribute_report_v2(zha_gateway: Gateway) -> None: """Test ZHA attribute report parsing for select platform.""" zigpy_device = create_mock_zigpy_device( @@ -184,7 +210,7 @@ async def test_on_off_select_attribute_report_v2( zha_device = await join_zigpy_device(zha_gateway, zigpy_device) cluster = zigpy_device.endpoints[1].opple_cluster - assert isinstance(zha_device.device, CustomDeviceV2) + assert isinstance(zigpy_device, CustomDeviceV2) entity = get_entity(zha_device, platform=Platform.SELECT) @@ -228,8 +254,17 @@ async def test_on_off_select_attribute_report_v2( ) +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_non_zcl_select_state_restoration(zha_gateway: Gateway) -> None: """Test the non-ZCL select state restoration.""" + zigpy_device = create_mock_zigpy_device( zha_gateway, { @@ -251,9 +286,11 @@ async def test_non_zcl_select_state_restoration(zha_gateway: Gateway) -> None: entity.restore_external_state_attributes( state=security.IasWd.Warning.WarningMode.Burglar.name ) + await zha_gateway.async_block_till_done() # needed for WS operations assert entity.state["state"] == security.IasWd.Warning.WarningMode.Burglar.name entity.restore_external_state_attributes( state=security.IasWd.Warning.WarningMode.Fire.name ) + await zha_gateway.async_block_till_done() # needed for WS operations assert entity.state["state"] == security.IasWd.Warning.WarningMode.Fire.name diff --git a/tests/test_sensor.py b/tests/test_sensor.py index 25bbb4a36..9fa289ee4 100644 --- a/tests/test_sensor.py +++ b/tests/test_sensor.py @@ -36,7 +36,7 @@ from zha.application.gateway import Gateway from zha.application.platforms import PlatformEntity, sensor from zha.application.platforms.sensor import DanfossSoftwareErrorCode, UnitOfMass -from zha.application.platforms.sensor.const import SensorDeviceClass +from zha.application.platforms.sensor.const import SensorDeviceClass, SensorStateClass from zha.units import PERCENTAGE, UnitOfEnergy, UnitOfPressure, UnitOfVolume from zha.zigbee.device import Device @@ -126,6 +126,8 @@ async def async_test_illuminance( await send_attributes_report(zha_gateway, cluster, {0: 0xFFFF}) assert_state(entity, None, "lx") + assert entity.extra_state_attribute_names is None + async def async_test_metering( zha_gateway: Gateway, cluster: Cluster, entity: PlatformEntity @@ -578,6 +580,14 @@ async def async_test_change_source_timestamp( ), ), ) +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_sensor( zha_gateway: Gateway, cluster_id: int, @@ -615,8 +625,7 @@ async def test_sensor( entity = get_entity( zha_device, platform=Platform.SENSOR, exact_entity_type=entity_type ) - - await zha_gateway.async_block_till_done() + assert entity.available is True # test sensor associated logic await test_func(zha_gateway, cluster, entity) @@ -631,6 +640,14 @@ def assert_state(entity: PlatformEntity, state: Any, unit_of_measurement: str) - assert entity.info_object.unit == unit_of_measurement +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_electrical_measurement_init( zha_gateway: Gateway, caplog: pytest.LogCaptureFixture, @@ -666,9 +683,20 @@ async def test_electrical_measurement_init( ) assert entity.state["state"] == 100 - cluster_handler = list(zha_device._endpoints.values())[0].all_cluster_handlers[ - "1:0x0b04" - ] + if isinstance(entity, sensor.WebSocketClientSensorEntity): + server_device = zha_gateway.ws_gateway.devices[zha_device.ieee] + cluster_handler = list(server_device._endpoints.values())[ + 0 + ].all_cluster_handlers["1:0x0b04"] + polling_interval = server_device.platform_entities[ + (entity.PLATFORM, entity.unique_id) + ].__polling_interval + else: + cluster_handler = list(zha_device._endpoints.values())[0].all_cluster_handlers[ + "1:0x0b04" + ] + polling_interval = entity.__polling_interval + assert cluster_handler.ac_power_divisor == 1 assert cluster_handler.ac_power_multiplier == 1 @@ -678,26 +706,34 @@ async def test_electrical_measurement_init( cluster, {EMAttrs.active_power.id: 20, EMAttrs.power_divisor.id: 5}, ) + await asyncio.sleep(polling_interval + 1) assert cluster_handler.ac_power_divisor == 5 assert cluster_handler.ac_power_multiplier == 1 assert entity.state["state"] == 4.0 - zha_device.on_network = False + if isinstance(entity, sensor.WebSocketClientSensorEntity): + zha_gateway.ws_gateway.devices[zha_device.ieee].on_network = False + else: + zha_device.on_network = False - await asyncio.sleep(entity.__polling_interval + 1) + await asyncio.sleep(polling_interval + 1) await zha_gateway.async_block_till_done(wait_background_tasks=True) assert ( "1-2820: skipping polling for updated state, available: False, allow polled requests: True" in caplog.text ) - zha_device.on_network = True + if isinstance(entity, sensor.WebSocketClientSensorEntity): + zha_gateway.ws_gateway.devices[zha_device.ieee].on_network = True + else: + zha_device.on_network = True await send_attributes_report( zha_gateway, cluster, {EMAttrs.active_power.id: 30, EMAttrs.ac_power_divisor.id: 10}, ) + await asyncio.sleep(polling_interval + 1) assert cluster_handler.ac_power_divisor == 10 assert cluster_handler.ac_power_multiplier == 1 assert entity.state["state"] == 3.0 @@ -708,6 +744,7 @@ async def test_electrical_measurement_init( cluster, {EMAttrs.active_power.id: 20, EMAttrs.power_multiplier.id: 6}, ) + await asyncio.sleep(polling_interval + 1) assert cluster_handler.ac_power_divisor == 10 assert cluster_handler.ac_power_multiplier == 6 assert entity.state["state"] == 12.0 @@ -717,31 +754,42 @@ async def test_electrical_measurement_init( cluster, {EMAttrs.active_power.id: 30, EMAttrs.ac_power_multiplier.id: 20}, ) + await asyncio.sleep(polling_interval + 1) assert cluster_handler.ac_power_divisor == 10 assert cluster_handler.ac_power_multiplier == 20 assert entity.state["state"] == 60.0 - entity._refresh = AsyncMock(wraps=entity._refresh) + if isinstance(entity, sensor.WebSocketClientSensorEntity): + server_entity = zha_gateway.ws_gateway.devices[ + zha_device.ieee + ].platform_entities[(entity.PLATFORM, entity.unique_id)] + server_entity._refresh = AsyncMock(wraps=server_entity._refresh) + refresh_mock = server_entity._refresh + else: + entity._refresh = AsyncMock(wraps=entity._refresh) + refresh_mock = entity._refresh - assert entity._refresh.await_count == 0 + assert refresh_mock.await_count == 0 entity.disable() + await zha_gateway.async_block_till_done() assert entity.enabled is False - await asyncio.sleep(entity.__polling_interval + 1) + await asyncio.sleep(polling_interval + 1) await zha_gateway.async_block_till_done(wait_background_tasks=True) - assert entity._refresh.await_count == 0 + assert refresh_mock.await_count == 0 entity.enable() + await zha_gateway.async_block_till_done() assert entity.enabled is True - await asyncio.sleep(entity.__polling_interval + 1) + await asyncio.sleep(polling_interval + 1) await zha_gateway.async_block_till_done(wait_background_tasks=True) - assert entity._refresh.await_count == 1 + assert refresh_mock.await_count == 1 @pytest.mark.parametrize( @@ -760,14 +808,14 @@ async def test_electrical_measurement_init( "rms_current", }, { - sensor.PolledElectricalMeasurement, - sensor.ElectricalMeasurementFrequency, - sensor.ElectricalMeasurementPowerFactor, + sensor.PolledElectricalMeasurement.__name__, + sensor.ElectricalMeasurementFrequency.__name__, + sensor.ElectricalMeasurementPowerFactor.__name__, }, { - sensor.ElectricalMeasurementApparentPower, - sensor.ElectricalMeasurementRMSVoltage, - sensor.ElectricalMeasurementRMSCurrent, + sensor.ElectricalMeasurementApparentPower.__name__, + sensor.ElectricalMeasurementRMSVoltage.__name__, + sensor.ElectricalMeasurementRMSCurrent.__name__, }, ), ( @@ -779,26 +827,26 @@ async def test_electrical_measurement_init( "power_factor", }, { - sensor.ElectricalMeasurementRMSVoltage, - sensor.PolledElectricalMeasurement, + sensor.ElectricalMeasurementRMSVoltage.__name__, + sensor.PolledElectricalMeasurement.__name__, }, { - sensor.ElectricalMeasurementApparentPower, - sensor.ElectricalMeasurementRMSCurrent, - sensor.ElectricalMeasurementFrequency, - sensor.ElectricalMeasurementPowerFactor, + sensor.ElectricalMeasurementApparentPower.__name__, + sensor.ElectricalMeasurementRMSCurrent.__name__, + sensor.ElectricalMeasurementFrequency.__name__, + sensor.ElectricalMeasurementPowerFactor.__name__, }, ), ( homeautomation.ElectricalMeasurement.cluster_id, set(), { - sensor.ElectricalMeasurementRMSVoltage, - sensor.PolledElectricalMeasurement, - sensor.ElectricalMeasurementApparentPower, - sensor.ElectricalMeasurementRMSCurrent, - sensor.ElectricalMeasurementFrequency, - sensor.ElectricalMeasurementPowerFactor, + sensor.ElectricalMeasurementRMSVoltage.__name__, + sensor.PolledElectricalMeasurement.__name__, + sensor.ElectricalMeasurementApparentPower.__name__, + sensor.ElectricalMeasurementRMSCurrent.__name__, + sensor.ElectricalMeasurementFrequency.__name__, + sensor.ElectricalMeasurementPowerFactor.__name__, }, set(), ), @@ -808,10 +856,10 @@ async def test_electrical_measurement_init( "instantaneous_demand", }, { - sensor.SmartEnergySummation, + sensor.SmartEnergySummation.__name__, }, { - sensor.SmartEnergyMetering, + sensor.SmartEnergyMetering.__name__, }, ), ( @@ -822,21 +870,29 @@ async def test_electrical_measurement_init( }, set(), { - sensor.SmartEnergyMetering, - sensor.SmartEnergySummation, + sensor.SmartEnergyMetering.__name__, + sensor.SmartEnergySummation.__name__, }, ), ( smartenergy.Metering.cluster_id, set(), { - sensor.SmartEnergyMetering, - sensor.SmartEnergySummation, + sensor.SmartEnergyMetering.__name__, + sensor.SmartEnergySummation.__name__, }, set(), ), ), ) +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_unsupported_attributes_sensor( zha_gateway: Gateway, cluster_id: int, @@ -867,7 +923,7 @@ async def test_unsupported_attributes_sensor( zha_device = await join_zigpy_device(zha_gateway, zigpy_device) present_entity_types = { - type(e) + e.info_object.class_name for e in zha_device.platform_entities.values() if e.PLATFORM == Platform.SENSOR and ("lqi" not in e.unique_id and "rssi" not in e.unique_id) @@ -878,100 +934,138 @@ async def test_unsupported_attributes_sensor( @pytest.mark.parametrize( - "raw_uom, raw_value, expected_state, expected_uom", + "raw_uom, raw_value, expected_state, expected_uom, expected_device_class, expected_state_class", ( ( 1, 12320, 1.23, UnitOfVolume.CUBIC_METERS, + SensorDeviceClass.VOLUME, + SensorStateClass.TOTAL_INCREASING, ), ( 1, 1232000, 123.2, UnitOfVolume.CUBIC_METERS, + SensorDeviceClass.VOLUME, + SensorStateClass.TOTAL_INCREASING, ), ( 3, 2340, 0.23, UnitOfVolume.CUBIC_FEET, + SensorDeviceClass.VOLUME, + SensorStateClass.TOTAL_INCREASING, ), ( 3, 2360, 0.24, UnitOfVolume.CUBIC_FEET, + SensorDeviceClass.VOLUME, + SensorStateClass.TOTAL_INCREASING, ), ( 8, 23660, 2.37, UnitOfPressure.KPA, + SensorDeviceClass.PRESSURE, + SensorStateClass.MEASUREMENT, ), ( 0, 9366, 0.937, UnitOfEnergy.KILO_WATT_HOUR, + SensorDeviceClass.ENERGY, + SensorStateClass.TOTAL_INCREASING, ), ( 0, 999, 0.1, UnitOfEnergy.KILO_WATT_HOUR, + SensorDeviceClass.ENERGY, + SensorStateClass.TOTAL_INCREASING, ), ( 0, 10091, 1.009, UnitOfEnergy.KILO_WATT_HOUR, + SensorDeviceClass.ENERGY, + SensorStateClass.TOTAL_INCREASING, ), ( 0, 10099, 1.01, UnitOfEnergy.KILO_WATT_HOUR, + SensorDeviceClass.ENERGY, + SensorStateClass.TOTAL_INCREASING, ), ( 0, 100999, 10.1, UnitOfEnergy.KILO_WATT_HOUR, + SensorDeviceClass.ENERGY, + SensorStateClass.TOTAL_INCREASING, ), ( 0, 100023, 10.002, UnitOfEnergy.KILO_WATT_HOUR, + SensorDeviceClass.ENERGY, + SensorStateClass.TOTAL_INCREASING, ), ( 0, 102456, 10.246, UnitOfEnergy.KILO_WATT_HOUR, + SensorDeviceClass.ENERGY, + SensorStateClass.TOTAL_INCREASING, ), ( 5, 102456, 10.25, "IMP gal", + None, + SensorStateClass.TOTAL_INCREASING, ), ( 7, 50124, 5.01, UnitOfVolume.LITERS, + SensorDeviceClass.VOLUME, + SensorStateClass.TOTAL_INCREASING, ), ), ) +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_se_summation_uom( zha_gateway: Gateway, raw_uom: int, raw_value: int, expected_state: str, expected_uom: str, + expected_device_class: SensorDeviceClass, + expected_state_class: SensorStateClass, ) -> None: """Test zha smart energy summation.""" @@ -1010,6 +1104,15 @@ async def test_se_summation_uom( zha_device, platform=Platform.SENSOR, qualifier="summation_delivered" ) + assert entity.device_class == expected_device_class + assert entity.state_class == expected_state_class + assert entity.extra_state_attribute_names == { + "device_type", + "status", + "zcl_unit_of_measurement", + } + assert entity.native_value == expected_state + assert_state(entity, expected_state, expected_uom) @@ -1025,6 +1128,14 @@ async def test_se_summation_uom( ), ), ) +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_elec_measurement_sensor_type( raw_measurement_type: int, expected_type: str, @@ -1043,6 +1154,14 @@ async def test_elec_measurement_sensor_type( assert entity.state["measurement_type"] == expected_type +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_elec_measurement_sensor_polling(zha_gateway: Gateway) -> None: """Test ZHA electrical measurement sensor polling.""" @@ -1106,6 +1225,14 @@ async def test_elec_measurement_sensor_polling(zha_gateway: Gateway) -> None: }, ), ) +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_elec_measurement_skip_unsupported_attribute( zha_gateway: Gateway, supported_attributes: set[str], @@ -1115,7 +1242,7 @@ async def test_elec_measurement_skip_unsupported_attribute( elec_measurement_zigpy_dev = elec_measurement_zigpy_device_mock(zha_gateway) zha_dev = await join_zigpy_device(zha_gateway, elec_measurement_zigpy_dev) - cluster = zha_dev.device.endpoints[1].electrical_measurement + cluster = elec_measurement_zigpy_dev.endpoints[1].electrical_measurement all_attrs = { "active_power", @@ -1139,7 +1266,7 @@ async def test_elec_measurement_skip_unsupported_attribute( exact_entity_type=sensor.PolledElectricalMeasurement, ) await entity.async_update() - await zha_dev.gateway.async_block_till_done() + await zha_gateway.async_block_till_done() assert cluster.read_attributes.call_count == math.ceil( len(supported_attributes) / ZHA_CLUSTER_HANDLER_READS_PER_REQ ) @@ -1206,6 +1333,7 @@ async def zigpy_device_timestamp_sensor_v2_mock( return zha_device, zigpy_device.endpoints[1].time_test_cluster +# TODO figure out how to support this in the websocket gateway async def test_timestamp_sensor_v2(zha_gateway: Gateway) -> None: """Test quirks defined sensor.""" @@ -1276,11 +1404,24 @@ async def zigpy_device_aqara_sensor_v2_mock( return zha_device, zigpy_device.endpoints[1].opple_cluster +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_last_feeding_size_sensor_v2(zha_gateway: Gateway) -> None: """Test quirks defined sensor.""" zha_device, cluster = await zigpy_device_aqara_sensor_v2_mock(zha_gateway) - assert isinstance(zha_device.device, CustomDeviceV2) + if hasattr(zha_gateway, "ws_gateway"): + assert isinstance( + zha_gateway.ws_gateway.devices[zha_device.ieee].device, CustomDeviceV2 + ) + else: + assert isinstance(zha_device.device, CustomDeviceV2) entity = get_entity( zha_device, platform=Platform.SENSOR, qualifier="last_feeding_size" ) @@ -1292,10 +1433,24 @@ async def test_last_feeding_size_sensor_v2(zha_gateway: Gateway) -> None: assert_state(entity, 5.0, "g") +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_device_counter_sensors(zha_gateway: Gateway) -> None: """Test coordinator counter sensor.""" - coordinator = zha_gateway.coordinator_zha_device + if hasattr(zha_gateway, "ws_gateway"): + coordinator = zha_gateway.ws_gateway.coordinator_zha_device + server_gateway = zha_gateway.ws_gateway + else: + coordinator = zha_gateway.coordinator_zha_device + server_gateway = zha_gateway + assert coordinator.is_coordinator entity = get_entity(coordinator, platform=Platform.SENSOR) @@ -1306,32 +1461,40 @@ async def test_device_counter_sensors(zha_gateway: Gateway) -> None: "counter_1" ].increment() - await asyncio.sleep(zha_gateway.global_updater.__polling_interval + 2) + await asyncio.sleep(server_gateway.global_updater.__polling_interval + 2) await zha_gateway.async_block_till_done(wait_background_tasks=True) assert entity.state["state"] == 2 # test disabling the entity disables it and removes it from the updater - assert len(zha_gateway.global_updater._update_listeners) == 3 + assert len(server_gateway.global_updater._update_listeners) == 3 assert entity.enabled is True entity.disable() assert entity.enabled is False - assert len(zha_gateway.global_updater._update_listeners) == 2 + assert len(server_gateway.global_updater._update_listeners) == 2 # test enabling the entity enables it and adds it to the updater entity.enable() assert entity.enabled is True - assert len(zha_gateway.global_updater._update_listeners) == 3 + assert len(server_gateway.global_updater._update_listeners) == 3 # make sure we don't get multiple listeners for the same entity in the updater entity.enable() - assert len(zha_gateway.global_updater._update_listeners) == 3 + assert len(server_gateway.global_updater._update_listeners) == 3 +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_device_unavailable_or_disabled_skips_entity_polling( zha_gateway: Gateway, caplog: pytest.LogCaptureFixture, @@ -1344,6 +1507,13 @@ async def test_device_unavailable_or_disabled_skips_entity_polling( ) assert not elec_measurement_zha_dev.is_coordinator assert not elec_measurement_zha_dev.is_active_coordinator + if hasattr(zha_gateway, "ws_gateway"): + server_device = zha_gateway.ws_gateway.devices[elec_measurement_zha_dev.ieee] + server_gateway = zha_gateway.ws_gateway + else: + server_device = elec_measurement_zha_dev + server_gateway = zha_gateway + entity = get_entity( elec_measurement_zha_dev, platform=Platform.SENSOR, @@ -1352,37 +1522,48 @@ async def test_device_unavailable_or_disabled_skips_entity_polling( assert entity.state["state"] is None - elec_measurement_zha_dev.device.rssi = 60 + server_device.device.rssi = 60 - await asyncio.sleep(zha_gateway.global_updater.__polling_interval + 2) + await asyncio.sleep(server_gateway.global_updater.__polling_interval + 2) await zha_gateway.async_block_till_done(wait_background_tasks=True) assert entity.state["state"] == 60 assert entity.enabled is True - assert len(zha_gateway.global_updater._update_listeners) == 5 + assert len(server_gateway.global_updater._update_listeners) == 5 # let's drop the normal update method from the updater entity.disable() + await zha_gateway.async_block_till_done() assert entity.enabled is False - assert len(zha_gateway.global_updater._update_listeners) == 4 + assert len(server_gateway.global_updater._update_listeners) == 4 # wrap the update method so we can count how many times it was called - entity.update = MagicMock(wraps=entity.update) - await asyncio.sleep(zha_gateway.global_updater.__polling_interval + 2) + if hasattr(zha_gateway, "ws_gateway"): + server_entity = server_gateway.devices[ + elec_measurement_zha_dev.ieee + ].platform_entities[(entity.PLATFORM, entity.unique_id)] + server_entity.update = MagicMock(wraps=server_entity.update) + mock_update = server_entity.update + else: + entity.update = MagicMock(wraps=entity.update) + mock_update = entity.update + + await asyncio.sleep(server_gateway.global_updater.__polling_interval + 2) await zha_gateway.async_block_till_done(wait_background_tasks=True) - assert entity.update.call_count == 0 + assert mock_update.call_count == 0 # re-enable the entity and ensure it is back in the updater and that update is called entity.enable() - assert len(zha_gateway.global_updater._update_listeners) == 5 + await zha_gateway.async_block_till_done() + assert len(server_gateway.global_updater._update_listeners) == 5 assert entity.enabled is True - await asyncio.sleep(zha_gateway.global_updater.__polling_interval + 2) + await asyncio.sleep(server_gateway.global_updater.__polling_interval + 2) await zha_gateway.async_block_till_done(wait_background_tasks=True) - assert entity.update.call_count == 1 + assert mock_update.call_count == 1 # knock it off the network and ensure the polling is skipped assert ( @@ -1390,11 +1571,11 @@ async def test_device_unavailable_or_disabled_skips_entity_polling( "available: False, allow polled requests: True" not in caplog.text ) - elec_measurement_zha_dev.on_network = False - await asyncio.sleep(zha_gateway.global_updater.__polling_interval + 2) + server_device.on_network = False + await asyncio.sleep(server_gateway.global_updater.__polling_interval + 2) await zha_gateway.async_block_till_done(wait_background_tasks=True) - assert entity.update.call_count == 2 + assert mock_update.call_count == 2 assert ( "00:0d:6f:00:0a:90:69:e7-1-0-rssi: skipping polling for updated state, " @@ -1434,6 +1615,14 @@ async def zigpy_device_danfoss_thermostat_mock( return zha_device, zigpy_device +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_danfoss_thermostat_sw_error(zha_gateway: Gateway) -> None: """Test quirks defined thermostat.""" diff --git a/tests/test_siren.py b/tests/test_siren.py index 746a79926..82396661d 100644 --- a/tests/test_siren.py +++ b/tests/test_siren.py @@ -3,6 +3,7 @@ import asyncio from unittest.mock import patch +import pytest from zigpy.const import SIG_EP_PROFILE from zigpy.profiles import zha from zigpy.zcl.clusters import general, security @@ -44,6 +45,14 @@ async def siren_mock( return zha_device, zigpy_device.endpoints[1].ias_wd +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_siren(zha_gateway: Gateway) -> None: """Test zha siren platform.""" @@ -119,8 +128,17 @@ async def test_siren(zha_gateway: Gateway) -> None: assert entity.state["state"] is True +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_siren_timed_off(zha_gateway: Gateway) -> None: """Test zha siren platform.""" + zha_device, cluster = await siren_mock(zha_gateway) assert cluster is not None @@ -146,6 +164,7 @@ async def test_siren_timed_off(zha_gateway: Gateway) -> None: # test that the state has changed to on assert entity.state["state"] is True + assert entity.is_on is True await asyncio.sleep(6) diff --git a/tests/test_switch.py b/tests/test_switch.py index 3cbdc7d38..91bf97d21 100644 --- a/tests/test_switch.py +++ b/tests/test_switch.py @@ -39,7 +39,7 @@ from zha.application.platforms import GroupEntity, PlatformEntity from zha.exceptions import ZHAException from zha.zigbee.device import Device -from zha.zigbee.group import Group, GroupMemberReference +from zha.zigbee.group import GroupMemberReference ON = 1 OFF = 0 @@ -109,8 +109,19 @@ async def device_switch_2_mock(zha_gateway: Gateway) -> Device: return zha_device -async def test_switch(zha_gateway: Gateway) -> None: +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) +async def test_switch( + zha_gateway: Gateway, +) -> None: """Test zha switch platform.""" + zigpy_device = create_mock_zigpy_device(zha_gateway, ZIGPY_DEVICE) zigpy_device.node_desc.mac_capability_flags |= ( 0b_0000_0100 # this one is mains powered @@ -124,10 +135,12 @@ async def test_switch(zha_gateway: Gateway) -> None: # turn on at switch await send_attributes_report(zha_gateway, cluster, {1: 0, 0: 1, 2: 2}) assert bool(entity.state["state"]) is True + assert bool(entity.is_on) is True # turn off at switch await send_attributes_report(zha_gateway, cluster, {1: 1, 0: 0, 2: 2}) assert bool(entity.state["state"]) is False + assert bool(entity.is_on) is False # turn on from client with patch( @@ -147,13 +160,18 @@ async def test_switch(zha_gateway: Gateway) -> None: tsn=None, ) + exc_match = ( + "Failed to turn off" + if not hasattr(zha_gateway, "ws_gateway") + else "'PLATFORM_ENTITY_ACTION_ERROR'" + ) # Fail turn off from client with ( patch( "zigpy.zcl.Cluster.request", return_value=[0x01, zcl_f.Status.FAILURE], ), - pytest.raises(ZHAException, match="Failed to turn off"), + pytest.raises(ZHAException, match=exc_match), ): await entity.async_turn_off() await zha_gateway.async_block_till_done() @@ -186,13 +204,18 @@ async def test_switch(zha_gateway: Gateway) -> None: tsn=None, ) + exc_match = ( + "Failed to turn on" + if not hasattr(zha_gateway, "ws_gateway") + else "'PLATFORM_ENTITY_ACTION_ERROR'" + ) # Fail turn on from client with ( patch( "zigpy.zcl.Cluster.request", return_value=[0x01, zcl_f.Status.FAILURE], ), - pytest.raises(ZHAException, match="Failed to turn on"), + pytest.raises(ZHAException, match=exc_match), ): await entity.async_turn_on() await zha_gateway.async_block_till_done() @@ -220,8 +243,17 @@ async def test_switch(zha_gateway: Gateway) -> None: assert bool(entity.state["state"]) is True +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_zha_group_switch_entity(zha_gateway: Gateway) -> None: """Test the switch entity for a ZHA group.""" + device_switch_1 = await device_switch_1_mock(zha_gateway) device_switch_2 = await device_switch_2_mock(zha_gateway) member_ieee_addresses = [device_switch_1.ieee, device_switch_2.ieee] @@ -231,8 +263,14 @@ async def test_zha_group_switch_entity(zha_gateway: Gateway) -> None: ] # test creating a group with 2 members - zha_group: Group = await zha_gateway.async_create_zigpy_group("Test Group", members) - await zha_gateway.async_block_till_done() + if not hasattr(zha_gateway, "ws_gateway"): + zha_group = await zha_gateway.async_create_zigpy_group("Test Group", members) + await zha_gateway.async_block_till_done() + else: + zha_group = await zha_gateway.ws_gateway.async_create_zigpy_group( + "Test Group", members + ) + await zha_gateway.async_block_till_done() assert zha_group is not None assert len(zha_group.members) == 2 @@ -246,8 +284,21 @@ async def test_zha_group_switch_entity(zha_gateway: Gateway) -> None: assert entity.info_object.fallback_name == zha_group.name group_cluster_on_off = zha_group.zigpy_group.endpoint[general.OnOff.cluster_id] - dev1_cluster_on_off = device_switch_1.device.endpoints[1].on_off - dev2_cluster_on_off = device_switch_2.device.endpoints[1].on_off + + if not hasattr(zha_gateway, "ws_gateway"): + dev1_cluster_on_off = device_switch_1.device.endpoints[1].on_off + dev2_cluster_on_off = device_switch_2.device.endpoints[1].on_off + else: + dev1_cluster_on_off = ( + zha_gateway.ws_gateway.devices[device_switch_1.ieee] + .device.endpoints[1] + .on_off + ) + dev2_cluster_on_off = ( + zha_gateway.ws_gateway.devices[device_switch_2.ieee] + .device.endpoints[1] + .on_off + ) # test that the lights were created and are off assert bool(entity.state["state"]) is False @@ -331,9 +382,17 @@ async def test_zha_group_switch_entity(zha_gateway: Gateway) -> None: # test that group light is now back on assert bool(entity.state["state"]) is True - await group_entity_availability_test( - zha_gateway, device_switch_1, device_switch_2, entity - ) + if not hasattr(zha_gateway, "ws_gateway"): + await group_entity_availability_test( + zha_gateway, device_switch_1, device_switch_2, entity + ) + else: + await group_entity_availability_test( + zha_gateway, + zha_gateway.ws_gateway.devices[device_switch_1.ieee], + zha_gateway.ws_gateway.devices[device_switch_2.ieee], + entity, + ) class WindowDetectionFunctionQuirk(CustomDevice): @@ -369,6 +428,14 @@ def __init__(self, *args, **kwargs): } +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_switch_configurable( zha_gateway: Gateway, ) -> None: @@ -482,6 +549,14 @@ async def test_switch_configurable( ] +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_switch_configurable_custom_on_off_values(zha_gateway: Gateway) -> None: """Test ZHA configurable switch platform.""" @@ -559,6 +634,14 @@ async def test_switch_configurable_custom_on_off_values(zha_gateway: Gateway) -> ] +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_switch_configurable_custom_on_off_values_force_inverted( zha_gateway: Gateway, ) -> None: @@ -639,6 +722,14 @@ async def test_switch_configurable_custom_on_off_values_force_inverted( ] +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_switch_configurable_custom_on_off_values_inverter_attribute( zha_gateway: Gateway, ) -> None: @@ -728,10 +819,19 @@ async def test_switch_configurable_custom_on_off_values_inverter_attribute( WCM = closures.WindowCovering.WindowCoveringMode +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_cover_inversion_switch(zha_gateway: Gateway) -> None: """Test ZHA cover platform.""" # load up cover domain + zigpy_cover_device = create_mock_zigpy_device(zha_gateway, ZIGPY_COVER_DEVICE) cluster = zigpy_cover_device.endpoints[1].window_covering cluster.PLUGGED_ATTR_READS = { @@ -743,11 +843,19 @@ async def test_cover_inversion_switch(zha_gateway: Gateway) -> None: } update_attribute_cache(cluster) zha_device = await join_zigpy_device(zha_gateway, zigpy_cover_device) - assert ( - not zha_device.endpoints[1] - .all_cluster_handlers[f"1:0x{cluster.cluster_id:04x}"] - .inverted - ) + + if hasattr(zha_gateway, "ws_gateway"): + ch = ( + zha_gateway.ws_gateway.devices[zha_device.ieee] + .endpoints[1] + .all_cluster_handlers[f"1:0x{cluster.cluster_id:04x}"] + ) + else: + ch = zha_device.endpoints[1].all_cluster_handlers[ + f"1:0x{cluster.cluster_id:04x}" + ] + + assert not ch.inverted assert cluster.read_attributes.call_count == 3 assert ( WCAttrs.current_position_lift_percentage.name @@ -820,10 +928,19 @@ async def test_cover_inversion_switch(zha_gateway: Gateway) -> None: assert bool(entity.state["state"]) is False +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_cover_inversion_switch_not_created(zha_gateway: Gateway) -> None: """Test ZHA cover platform.""" # load up cover domain + zigpy_cover_device = create_mock_zigpy_device(zha_gateway, ZIGPY_COVER_DEVICE) cluster = zigpy_cover_device.endpoints[1].window_covering cluster.PLUGGED_ATTR_READS = { diff --git a/tests/test_update.py b/tests/test_update.py index b2405a644..54cc9ab87 100644 --- a/tests/test_update.py +++ b/tests/test_update.py @@ -31,6 +31,12 @@ ATTR_LATEST_VERSION, ATTR_UPDATE_PERCENTAGE, ) +from zha.application.platforms.update.const import ( + ATTR_RELEASE_NOTES, + ATTR_RELEASE_SUMMARY, + ATTR_RELEASE_URL, + UpdateEntityFeature, +) from zha.exceptions import ZHAException @@ -151,11 +157,24 @@ async def setup_test_data( ) zha_device = await join_zigpy_device(zha_gateway, zigpy_device) - zha_device.async_update_sw_build_id(installed_fw_version) + if hasattr(zha_gateway, "ws_gateway"): + zha_gateway.ws_gateway.devices[zha_device.ieee].async_update_sw_build_id( + installed_fw_version + ) + else: + zha_device.async_update_sw_build_id(installed_fw_version) return zha_device, ota_cluster, fw_image, installed_fw_version +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_firmware_update_notification_from_zigpy(zha_gateway: Gateway) -> None: """Test ZHA update platform - firmware update notification.""" zigpy_device = zigpy_device_mock(zha_gateway) @@ -192,7 +211,39 @@ async def test_firmware_update_notification_from_zigpy(zha_gateway: Gateway) -> == f"0x{fw_image.firmware.header.file_version:08x}" ) + # property coverage + assert entity.installed_version == f"0x{installed_fw_version:08x}" + assert entity.latest_version == f"0x{fw_image.firmware.header.file_version:08x}" + assert entity.in_progress is False + assert entity.update_percentage is None + assert entity.release_notes is None + assert entity.release_url is None + assert ( + entity.supported_features + == UpdateEntityFeature.INSTALL + | UpdateEntityFeature.SPECIFIC_VERSION + | UpdateEntityFeature.PROGRESS + ) + assert entity.release_summary == "This is a test firmware image!" + assert entity.state_attributes == { + ATTR_INSTALLED_VERSION: f"0x{installed_fw_version:08x}", + ATTR_IN_PROGRESS: False, + ATTR_UPDATE_PERCENTAGE: None, + ATTR_LATEST_VERSION: f"0x{fw_image.firmware.header.file_version:08x}", + ATTR_RELEASE_SUMMARY: "This is a test firmware image!", + ATTR_RELEASE_NOTES: None, + ATTR_RELEASE_URL: None, + } + +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) @patch("zigpy.device.AFTER_OTA_ATTR_READ_DELAY", 0.01) async def test_firmware_update_success(zha_gateway: Gateway) -> None: """Test ZHA update platform - firmware update success.""" @@ -304,18 +355,22 @@ async def endpoint_reply(cluster, sequence, data, **kwargs): # make sure the state machine gets progress reports - assert ( - entity.state[ATTR_INSTALLED_VERSION] - == f"0x{installed_fw_version:08x}" - ) - assert entity.state[ATTR_IN_PROGRESS] is True - assert entity.state[ATTR_UPDATE_PERCENTAGE] == pytest.approx( - 100 * (40 / 70) - ) - assert ( - entity.state[ATTR_LATEST_VERSION] - == f"0x{fw_image.firmware.header.file_version:08x}" - ) + # TODO I can't figure out how to allow the server to send the progress to the client in the + # test. This all happens in a tight loop so the state doesn't get to the client until + # this is all complete... I think. + if not hasattr(zha_gateway, "ws_gateway"): + assert ( + entity.state[ATTR_INSTALLED_VERSION] + == f"0x{installed_fw_version:08x}" + ) + assert entity.state[ATTR_IN_PROGRESS] is True + assert entity.state[ATTR_UPDATE_PERCENTAGE] == pytest.approx( + 100 * (40 / 70) + ) + assert ( + entity.state[ATTR_LATEST_VERSION] + == f"0x{fw_image.firmware.header.file_version:08x}" + ) zigpy_device.packet_received( make_packet( @@ -365,12 +420,21 @@ def read_new_fw_version(*args, **kwargs): assert not entity.state[ATTR_IN_PROGRESS] assert entity.state[ATTR_LATEST_VERSION] == entity.state[ATTR_INSTALLED_VERSION] - # If we send a progress notification incorrectly, it won't be handled - entity._update_progress(50, 100, 0.50) + if not hasattr(zha_gateway, "ws_gateway"): + # If we send a progress notification incorrectly, it won't be handled + entity._update_progress(50, 100, 0.50) - assert not entity.state[ATTR_IN_PROGRESS] + assert not entity.state[ATTR_IN_PROGRESS] +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_firmware_update_raises(zha_gateway: Gateway) -> None: """Test ZHA update platform - firmware update raises.""" zigpy_device = zigpy_device_mock(zha_gateway) @@ -448,6 +512,14 @@ async def endpoint_reply(cluster, sequence, data, **kwargs): await zha_gateway.async_block_till_done() +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_firmware_update_downgrade(zha_gateway: Gateway) -> None: """Test ZHA update platform - force a firmware downgrade.""" zigpy_device = zigpy_device_mock(zha_gateway) @@ -523,6 +595,14 @@ async def test_firmware_update_downgrade(zha_gateway: Gateway) -> None: ) +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_firmware_update_no_image(zha_gateway: Gateway) -> None: """Test ZHA update platform - no images exist.""" zigpy_device = zigpy_device_mock(zha_gateway) @@ -566,6 +646,14 @@ async def test_firmware_update_no_image(zha_gateway: Gateway) -> None: assert entity.state[ATTR_LATEST_VERSION] is None +@pytest.mark.parametrize( + "zha_gateway", + [ + "zha_gateway", + "ws_gateways", + ], + indirect=True, +) async def test_firmware_update_latest_version_even_if_downgrade( zha_gateway: Gateway, ) -> None: diff --git a/tests/websocket/__init__.py b/tests/websocket/__init__.py new file mode 100644 index 000000000..a766f6adb --- /dev/null +++ b/tests/websocket/__init__.py @@ -0,0 +1 @@ +"""Websocket tests modules.""" diff --git a/tests/websocket/test_client_controller.py b/tests/websocket/test_client_controller.py new file mode 100644 index 000000000..c8ce6e6a3 --- /dev/null +++ b/tests/websocket/test_client_controller.py @@ -0,0 +1,521 @@ +"""Test zha switch.""" + +import logging +from typing import Optional +from unittest.mock import AsyncMock, MagicMock, call + +import pytest +from zigpy.device import Device as ZigpyDevice +from zigpy.profiles import zha +from zigpy.types.named import EUI64 +from zigpy.zcl.clusters import general + +from tests.conftest import CombinedWebsocketGateways +from zha.application.discovery import Platform +from zha.application.gateway import ( + DeviceJoinedDeviceInfo, + DevicePairingStatus, + RawDeviceInitializedDeviceInfo, + RawDeviceInitializedEvent, + WebSocketServerGateway, +) +from zha.application.model import DeviceJoinedEvent, DeviceLeftEvent +from zha.application.platforms import WebSocketClientEntity +from zha.application.platforms.switch import WebSocketClientSwitchEntity +from zha.const import ControllerEvents +from zha.websocket.server.api.model import ( + ReadClusterAttributesResponse, + WriteClusterAttributeResponse, +) +from zha.zigbee.device import Device, WebSocketClientDevice +from zha.zigbee.group import ( + Group, + GroupMemberReference, + WebSocketClientGroup, + WebSocketClientGroupMember, +) +from zha.zigbee.model import GroupInfo + +from ..common import ( + SIG_EP_INPUT, + SIG_EP_OUTPUT, + SIG_EP_PROFILE, + SIG_EP_TYPE, + create_mock_zigpy_device, + find_entity, + join_zigpy_device, + update_attribute_cache, +) + +ON = 1 +OFF = 0 +IEEE_GROUPABLE_DEVICE = "01:2d:6f:00:0a:90:69:e8" +IEEE_GROUPABLE_DEVICE2 = "02:2d:6f:00:0a:90:69:e8" +_LOGGER = logging.getLogger(__name__) + + +def zigpy_device_mock( + zha_gateway: WebSocketServerGateway, +) -> ZigpyDevice: + """Device tracker zigpy device.""" + endpoints = { + 1: { + SIG_EP_INPUT: [general.Basic.cluster_id, general.OnOff.cluster_id], + SIG_EP_OUTPUT: [], + SIG_EP_TYPE: zha.DeviceType.ON_OFF_SWITCH, + SIG_EP_PROFILE: zha.PROFILE_ID, + } + } + return create_mock_zigpy_device(zha_gateway, endpoints) + + +async def device_switch_1_mock( + zha_gateway: WebSocketServerGateway, +) -> Device: + """Test zha switch platform.""" + + zigpy_dev = create_mock_zigpy_device( + zha_gateway, + { + 1: { + SIG_EP_INPUT: [general.OnOff.cluster_id, general.Groups.cluster_id], + SIG_EP_OUTPUT: [], + SIG_EP_TYPE: zha.DeviceType.ON_OFF_SWITCH, + SIG_EP_PROFILE: zha.PROFILE_ID, + } + }, + ieee=IEEE_GROUPABLE_DEVICE, + ) + zha_device = await join_zigpy_device(zha_gateway, zigpy_dev) + ws_server_device = zha_gateway.ws_gateway.devices[zha_device.ieee] + ws_server_device.update_available(available=True, on_network=zha_device.on_network) + return zha_device + + +def get_group_entity( + group_proxy: WebSocketClientGroup, entity_id: str +) -> Optional[WebSocketClientEntity]: + """Get entity.""" + + return group_proxy.group_entities.get(entity_id) + + +async def device_switch_2_mock( + zha_gateway: WebSocketServerGateway, +) -> Device: + """Test zha switch platform.""" + + zigpy_dev = create_mock_zigpy_device( + zha_gateway, + { + 1: { + SIG_EP_INPUT: [general.OnOff.cluster_id, general.Groups.cluster_id], + SIG_EP_OUTPUT: [], + SIG_EP_TYPE: zha.DeviceType.ON_OFF_SWITCH, + SIG_EP_PROFILE: zha.PROFILE_ID, + } + }, + ieee=IEEE_GROUPABLE_DEVICE2, + ) + zha_device = await join_zigpy_device(zha_gateway, zigpy_dev) + ws_server_device = zha_gateway.ws_gateway.devices[zha_device.ieee] + ws_server_device.update_available(available=True, on_network=zha_device.on_network) + return zha_device + + +@pytest.mark.parametrize( + "zha_gateway", + [ + "ws_gateways", + ], + indirect=True, +) +async def test_ws_client_gateway_devices( + zha_gateway: CombinedWebsocketGateways, +) -> None: + """Test client ws_client_gateway device related functionality.""" + ws_client_gateway = zha_gateway.client_gateway + zigpy_device = zigpy_device_mock(zha_gateway) + zha_device = await join_zigpy_device(zha_gateway, zigpy_device) + + client_device: Optional[WebSocketClientDevice] = ws_client_gateway.devices.get( + zha_device.ieee + ) + assert client_device is not None + + entity = find_entity(client_device, Platform.SWITCH) + assert entity is not None + + assert isinstance(entity, WebSocketClientSwitchEntity) + + assert entity.state["state"] is False + + await ws_client_gateway.load_devices() + devices: dict[EUI64, WebSocketClientDevice] = ws_client_gateway.devices + assert len(devices) == 2 + assert zha_device.ieee in devices + + # test client -> ws_server_gateway + zha_gateway.application_controller.remove = AsyncMock( + wraps=zha_gateway.application_controller.remove + ) + await ws_client_gateway.devices_helper.remove_device( + client_device._extended_device_info + ) + assert zha_gateway.application_controller.remove.await_count == 1 + assert zha_gateway.application_controller.remove.await_args == call( + client_device.ieee + ) + + # test zha_gateway -> client + zha_gateway.ws_gateway.device_removed(zigpy_device) + await zha_gateway.async_block_till_done() + assert len(ws_client_gateway.devices) == 1 + + # rejoin the device + zha_device = await join_zigpy_device(zha_gateway, zigpy_device) + await zha_gateway.async_block_till_done() + assert len(ws_client_gateway.devices) == 2 + + # test rejoining the same device + zha_device = await join_zigpy_device(zha_gateway, zigpy_device) + await zha_gateway.async_block_till_done() + assert len(ws_client_gateway.devices) == 2 + + # test client gateway device removal + await ws_client_gateway.async_remove_device(zha_device.ieee) + await zha_gateway.async_block_till_done() + assert len(ws_client_gateway.devices) == 1 + + # lets kill the network and then start it back up to make sure everything is still in working order + await ws_client_gateway.network.stop_network() + + assert zha_gateway.application_controller is None + + await ws_client_gateway.network.start_network() + + assert zha_gateway.application_controller is not None + + # let's add it back + zha_device = await join_zigpy_device(zha_gateway, zigpy_device) + await zha_gateway.async_block_till_done() + assert len(ws_client_gateway.devices) == 2 + + # we removed and joined the device again so lets get the entity again + client_device = ws_client_gateway.devices.get(zha_device.ieee) + assert client_device is not None + + entity = find_entity(client_device, Platform.SWITCH) + assert entity is not None + + # test device reconfigure + ws_server_device = zha_gateway.ws_gateway.devices[zha_device.ieee] + async_configure_mock = AsyncMock(wraps=ws_server_device.async_configure) + ws_server_device.async_configure = async_configure_mock + + await ws_client_gateway.devices_helper.reconfigure_device( + client_device._extended_device_info + ) + await zha_gateway.async_block_till_done() + assert async_configure_mock.call_count == 1 + assert async_configure_mock.await_count == 1 + assert async_configure_mock.call_args == call() + + # test read cluster attribute + cluster = zigpy_device.endpoints.get(1).on_off + assert cluster is not None + cluster.PLUGGED_ATTR_READS = {general.OnOff.AttributeDefs.on_off.name: 1} + update_attribute_cache(cluster) + await ws_client_gateway.entities.refresh_state(entity.info_object) + await zha_gateway.async_block_till_done() + read_response: ReadClusterAttributesResponse = ( + await ws_client_gateway.devices_helper.read_cluster_attributes( + client_device._extended_device_info, + general.OnOff.cluster_id, + "in", + 1, + [general.OnOff.AttributeDefs.on_off.name], + ) + ) + await zha_gateway.async_block_till_done() + assert read_response is not None + assert read_response.success is True + assert len(read_response.succeeded) == 1 + assert len(read_response.failed) == 0 + assert read_response.succeeded[general.OnOff.AttributeDefs.on_off.name] == 1 + assert read_response.cluster.id == general.OnOff.cluster_id + assert read_response.cluster.endpoint_id == 1 + assert ( + read_response.cluster.endpoint_attribute + == general.OnOff.AttributeDefs.on_off.name + ) + assert read_response.cluster.name == general.OnOff.name + assert entity.state["state"] is True + + # test write cluster attribute + write_response: WriteClusterAttributeResponse = ( + await ws_client_gateway.devices_helper.write_cluster_attribute( + client_device._extended_device_info, + general.OnOff.cluster_id, + "in", + 1, + general.OnOff.AttributeDefs.on_off.name, + 0, + ) + ) + assert write_response is not None + assert write_response.success is True + assert write_response.cluster.id == general.OnOff.cluster_id + assert write_response.cluster.endpoint_id == 1 + assert ( + write_response.cluster.endpoint_attribute + == general.OnOff.AttributeDefs.on_off.name + ) + assert write_response.cluster.name == general.OnOff.name + + await ws_client_gateway.entities.refresh_state(entity.info_object) + await zha_gateway.async_block_till_done() + assert entity.state["state"] is False + + # test ws_client_gateway events + listener = MagicMock() + + # test device joined + ws_client_gateway.on_event(ControllerEvents.DEVICE_JOINED, listener) + device_joined_event = DeviceJoinedEvent( + device_info=DeviceJoinedDeviceInfo( + pairing_status=DevicePairingStatus.PAIRED, + ieee=zigpy_device.ieee, + nwk=zigpy_device.nwk, + ) + ) + zha_gateway.ws_gateway.device_joined(zigpy_device) + await zha_gateway.async_block_till_done() + assert listener.call_count == 1 + assert listener.call_args == call(device_joined_event) + + # test device left + listener.reset_mock() + ws_client_gateway.on_event(ControllerEvents.DEVICE_LEFT, listener) + zha_gateway.ws_gateway.device_left(zigpy_device) + await zha_gateway.async_block_till_done() + assert listener.call_count == 1 + assert listener.call_args == call( + DeviceLeftEvent( + ieee=zigpy_device.ieee, + nwk=str(zigpy_device.nwk).lower(), + ) + ) + + # test raw device initialized + listener.reset_mock() + ws_client_gateway.on_event(ControllerEvents.RAW_DEVICE_INITIALIZED, listener) + zha_gateway.ws_gateway.raw_device_initialized(zigpy_device) + await zha_gateway.async_block_till_done() + assert listener.call_count == 1 + assert listener.call_args == call( + RawDeviceInitializedEvent( + device_info=RawDeviceInitializedDeviceInfo( + pairing_status=DevicePairingStatus.INTERVIEW_COMPLETE, + ieee=zigpy_device.ieee, + nwk=zigpy_device.nwk, + manufacturer=client_device.manufacturer, + model=client_device.model, + signature=client_device._extended_device_info.signature, + ), + ) + ) + + # test topology scan + zha_gateway.application_controller.topology.scan = AsyncMock() + await ws_client_gateway.network.update_topology() + assert zha_gateway.application_controller.topology.scan.await_count == 1 + + # test permit join + zha_gateway.application_controller.permit = AsyncMock() + await ws_client_gateway.network.permit_joining(60) + assert zha_gateway.application_controller.permit.await_count == 1 + assert zha_gateway.application_controller.permit.await_args == call(60, None) + + +@pytest.mark.parametrize( + "zha_gateway", + [ + "ws_gateways", + ], + indirect=True, +) +async def test_ws_client_gateway_groups( + zha_gateway: CombinedWebsocketGateways, +) -> None: + """Test client ws_client_gateway group related functionality.""" + ws_client_gateway = zha_gateway.client_gateway + device_switch_1: Device = await device_switch_1_mock(zha_gateway) + device_switch_2: Device = await device_switch_2_mock(zha_gateway) + member_ieee_addresses = [device_switch_1.ieee, device_switch_2.ieee] + members = [ + GroupMemberReference(ieee=device_switch_1.ieee, endpoint_id=1), + GroupMemberReference(ieee=device_switch_2.ieee, endpoint_id=1), + ] + + # test creating a group with 2 members + zha_group: Group = await zha_gateway.async_create_zigpy_group("Test Group", members) + await zha_gateway.async_block_till_done() + + assert zha_group is not None + assert len(zha_group.members) == 2 + for member in zha_group.members: + assert member.device.ieee in member_ieee_addresses + assert member.group == zha_group + assert member.endpoint_id == 1 + + entity_id = f"{Platform.SWITCH}_zha_group_0x{zha_group.group_id:04x}" + assert entity_id is not None + + group_proxy: Optional[WebSocketClientGroup] = ws_client_gateway.groups.get( + zha_group.group_id + ) + assert group_proxy is not None + + entity: WebSocketClientSwitchEntity = get_group_entity(group_proxy, entity_id) # type: ignore + assert entity is not None + + assert isinstance(entity, WebSocketClientSwitchEntity) + + assert entity is not None + + await ws_client_gateway.load_groups() + groups: dict[int, WebSocketClientGroup] = ws_client_gateway.groups + # the application ws_client_gateway mock starts with a group already created + assert len(groups) == 2 + assert zha_group.group_id in groups + + # test client -> zha_gateway + await ws_client_gateway.groups_helper.remove_groups([group_proxy._group_info]) + await zha_gateway.async_block_till_done() + assert len(ws_client_gateway.groups) == 1 + + # test client create group + client_device1: Optional[WebSocketClientDevice] = ws_client_gateway.devices.get( + device_switch_1.ieee + ) + assert client_device1 is not None + + entity1: WebSocketClientSwitchEntity = find_entity(client_device1, Platform.SWITCH) + assert entity1 is not None + + client_device2: Optional[WebSocketClientDevice] = ws_client_gateway.devices.get( + device_switch_2.ieee + ) + assert client_device2 is not None + + entity2: WebSocketClientSwitchEntity = find_entity(client_device2, Platform.SWITCH) + assert entity2 is not None + + response: GroupInfo = await ws_client_gateway.groups_helper.create_group( + members=[ + GroupMemberReference( + ieee=entity1.info_object.device_ieee, + endpoint_id=entity1.info_object.endpoint_id, + ), + GroupMemberReference( + ieee=entity2.info_object.device_ieee, + endpoint_id=entity2.info_object.endpoint_id, + ), + ], + name="Test Group Controller", + ) + await zha_gateway.async_block_till_done() + assert len(ws_client_gateway.groups) == 2 + assert response.group_id in ws_client_gateway.groups + assert response.name == "Test Group Controller" + assert client_device1.ieee in response.members_by_ieee + assert client_device2.ieee in response.members_by_ieee + + # test get group from ws_client_gateway + group_from_ws_client_gateway = ws_client_gateway.get_group(response.group_id) + assert group_from_ws_client_gateway is not None + assert group_from_ws_client_gateway.group_id == response.group_id + assert group_from_ws_client_gateway.name == response.name + assert ( + group_from_ws_client_gateway.info_object.members_by_ieee + == response.members_by_ieee + ) + + # test get group from ws_client_gateway by group name + group_from_ws_client_gateway = ws_client_gateway.get_group(response.name) + assert group_from_ws_client_gateway is not None + assert group_from_ws_client_gateway.group_id == response.group_id + assert group_from_ws_client_gateway.name == response.name + assert ( + group_from_ws_client_gateway.info_object.members_by_ieee + == response.members_by_ieee + ) + + # test remove member from group from ws_client_gateway + response = await ws_client_gateway.groups_helper.remove_group_members( + response, + [ + GroupMemberReference( + ieee=entity2.info_object.device_ieee, + endpoint_id=entity2.info_object.endpoint_id, + ) + ], + ) + await zha_gateway.async_block_till_done() + assert len(ws_client_gateway.groups) == 2 + assert response.group_id in ws_client_gateway.groups + assert response.name == "Test Group Controller" + assert client_device1.ieee in response.members_by_ieee + assert client_device2.ieee not in response.members_by_ieee + + # test add member to group from ws_client_gateway + response = await ws_client_gateway.groups_helper.add_group_members( + response, + [ + GroupMemberReference( + ieee=entity2.info_object.device_ieee, + endpoint_id=entity2.info_object.endpoint_id, + ) + ], + ) + await zha_gateway.async_block_till_done() + assert len(ws_client_gateway.groups) == 2 + assert response.group_id in ws_client_gateway.groups + assert response.name == "Test Group Controller" + assert client_device1.ieee in response.members_by_ieee + assert client_device2.ieee in response.members_by_ieee + + # test member info and removal from member + + member_info = response.members_by_ieee[client_device1.ieee] + assert member_info is not None + assert member_info.endpoint_id == entity1.info_object.endpoint_id + assert member_info.ieee == entity1.info_object.device_ieee + assert member_info.device_info is not None + assert member_info.device_info.ieee == entity1._device.extended_device_info.ieee + assert member_info.device_info.nwk == entity1._device.extended_device_info.nwk + assert ( + member_info.device_info.manufacturer + == entity1._device.extended_device_info.manufacturer + ) + assert member_info.device_info.model == entity1._device.extended_device_info.model + assert ( + member_info.device_info.signature + == entity1._device.extended_device_info.signature + ) + + client_group: WebSocketClientGroup = ws_client_gateway.get_group(response.group_id) + assert client_group is not None + members = client_group.members + assert len(members) == 2 + entity_1_member: WebSocketClientGroupMember + for member in members: + if member.member_info.ieee == entity1.info_object.device_ieee: + entity_1_member = member + break + + assert entity_1_member is not None + await entity_1_member.async_remove_from_group() + await zha_gateway.async_block_till_done() + assert len(client_group.members) == 1 diff --git a/tests/websocket/test_websocket_server_client.py b/tests/websocket/test_websocket_server_client.py new file mode 100644 index 000000000..e194ffaa7 --- /dev/null +++ b/tests/websocket/test_websocket_server_client.py @@ -0,0 +1,100 @@ +"""Tests for the server and client.""" + +from __future__ import annotations + +from unittest.mock import patch + +import pytest + +from tests.conftest import CombinedWebsocketGateways +from zha.application.gateway import WebSocketClientGateway, WebSocketServerGateway +from zha.application.model import ZHAData +from zha.websocket import ZHAWebSocketException +from zha.websocket.client.client import Client +from zha.websocket.client.helpers import ClientHelper + + +async def test_server_client_connect_disconnect( + zha_data: ZHAData, +) -> None: + """Tests basic connect/disconnect logic.""" + + async with WebSocketServerGateway(zha_data) as gateway: + assert gateway.is_serving + assert gateway._ws_server is not None + + async with Client(f"ws://localhost:{zha_data.ws_server_config.port}") as client: + assert client.connected + assert "connected" in repr(client) + + # The client does not begin listening immediately + assert client._listen_task is None + await client.listen() + assert client._listen_task is not None + + # The listen task is automatically stopped when we disconnect + assert client._listen_task is None + assert "not connected" in repr(client) + assert not client.connected + + async with WebSocketClientGateway(zha_data) as client_gateway: + assert client_gateway.client.connected + assert client_gateway.client._listen_task is not None + + assert not client_gateway.client.connected + assert client_gateway.client._listen_task is None + + assert not gateway.is_serving + assert gateway._ws_server is None + + with ( + pytest.raises(ZHAWebSocketException), + patch("zha.websocket.client.client.Client.connect", side_effect=TimeoutError), + ): + async with WebSocketClientGateway(zha_data) as client_gateway: + assert client_gateway.client.connected + + +async def test_client_helper_disconnect( + zha_data: ZHAData, +) -> None: + """Tests client helper disconnect logic.""" + + async with WebSocketServerGateway(zha_data) as gateway: + assert gateway.is_serving + assert gateway._ws_server is not None + + client = Client(f"ws://localhost:{zha_data.ws_server_config.port}") + client_helper = ClientHelper(client) + + await client.connect() + assert client.connected + assert "connected" in repr(client) + + # The client does not begin listening immediately + assert client._listen_task is None + await client_helper.listen() + assert client._listen_task is not None + + await client_helper.disconnect() + assert client._listen_task is None + assert "not connected" in repr(client) + assert not client.connected + + assert not gateway.is_serving + assert gateway._ws_server is None + + +@pytest.mark.parametrize( + "zha_gateway", + [ + "ws_gateways", + ], + indirect=True, +) +async def test_client_message_id_uniqueness( + zha_gateway: CombinedWebsocketGateways, +) -> None: + """Tests that client message IDs are unique.""" + ids = [zha_gateway.client_gateway.client.new_message_id() for _ in range(1000)] + assert len(ids) == len(set(ids)) diff --git a/zha/application/discovery.py b/zha/application/discovery.py index a66c0cfd6..077edf03e 100644 --- a/zha/application/discovery.py +++ b/zha/application/discovery.py @@ -22,7 +22,6 @@ from zigpy.zcl.clusters.general import Ota from zha.application import Platform, const as zha_const -from zha.application.helpers import DeviceOverridesConfiguration from zha.application.platforms import ( # noqa: F401 pylint: disable=unused-import alarm_control_panel, binary_sensor, @@ -33,6 +32,7 @@ fan, light, lock, + model, number, select, sensor, @@ -40,7 +40,42 @@ switch, update, ) +from zha.application.platforms.alarm_control_panel import AlarmControlPanelEntityInfo +from zha.application.platforms.binary_sensor.model import BinarySensorEntityInfo +from zha.application.platforms.button.model import ( + ButtonEntityInfo, + CommandButtonEntityInfo, + WriteAttributeButtonEntityInfo, +) +from zha.application.platforms.climate.model import ThermostatEntityInfo +from zha.application.platforms.cover.model import CoverEntityInfo, ShadeEntityInfo +from zha.application.platforms.device_tracker.model import DeviceTrackerEntityInfo +from zha.application.platforms.fan.model import FanEntityInfo +from zha.application.platforms.light.model import LightEntityInfo +from zha.application.platforms.lock.model import LockEntityInfo +from zha.application.platforms.number.model import ( + NumberConfigurationEntityInfo, + NumberEntityInfo, +) +from zha.application.platforms.select.model import ( + EnumSelectEntityInfo, + SelectEntityInfo, +) from zha.application.platforms.sensor.const import SensorDeviceClass +from zha.application.platforms.sensor.model import ( + BatteryEntityInfo, + DeviceCounterSensorEntityInfo, + ElectricalMeasurementEntityInfo, + SensorEntityInfo, + SetpointChangeSourceTimestampSensorEntityInfo, + SmartEnergyMeteringEntityInfo, +) +from zha.application.platforms.siren.model import SirenEntityInfo +from zha.application.platforms.switch.model import ( + ConfigurableAttributeSwitchEntityInfo, + SwitchEntityInfo, +) +from zha.application.platforms.update.model import FirmwareUpdateEntityInfo from zha.application.registries import ( DEVICE_CLASS, PLATFORM_ENTITIES, @@ -72,6 +107,7 @@ if TYPE_CHECKING: from zha.application.gateway import Gateway + from zha.application.model import DeviceOverridesConfiguration from zha.zigbee.device import Device from zha.zigbee.endpoint import Endpoint @@ -173,6 +209,35 @@ SensorDeviceClass.TIMESTAMP: sensor.TimestampSensor } +ENTITY_INFO_CLASS_TO_WEBSOCKET_CLIENT_ENTITY_CLASS = { + AlarmControlPanelEntityInfo: alarm_control_panel.WebSocketClientAlarmControlPanel, + BinarySensorEntityInfo: binary_sensor.WebSocketClientBinarySensor, + ButtonEntityInfo: button.WebSocketClientButtonEntity, + CommandButtonEntityInfo: button.WebSocketClientButtonEntity, + WriteAttributeButtonEntityInfo: button.WebSocketClientButtonEntity, + ThermostatEntityInfo: climate.WebSocketClientThermostatEntity, + CoverEntityInfo: cover.WebSocketClientCoverEntity, + ShadeEntityInfo: cover.WebSocketClientCoverEntity, + DeviceTrackerEntityInfo: device_tracker.WebSocketClientDeviceTrackerEntity, + FanEntityInfo: fan.WebSocketClientFanEntity, + LightEntityInfo: light.WebSocketClientLightEntity, + LockEntityInfo: lock.WebSocketClientLockEntity, + NumberEntityInfo: number.WebSocketClientNumberEntity, + SelectEntityInfo: select.WebSocketClientSelectEntity, + SensorEntityInfo: sensor.WebSocketClientSensorEntity, + SirenEntityInfo: siren.WebSocketClientSirenEntity, + SwitchEntityInfo: switch.WebSocketClientSwitchEntity, + FirmwareUpdateEntityInfo: update.WebSocketClientFirmwareUpdateEntity, + BatteryEntityInfo: sensor.WebSocketClientSensorEntity, + ElectricalMeasurementEntityInfo: sensor.WebSocketClientSensorEntity, + SmartEnergyMeteringEntityInfo: sensor.WebSocketClientSensorEntity, + DeviceCounterSensorEntityInfo: sensor.WebSocketClientSensorEntity, + SetpointChangeSourceTimestampSensorEntityInfo: sensor.WebSocketClientSensorEntity, + NumberConfigurationEntityInfo: number.WebSocketClientNumberEntity, + EnumSelectEntityInfo: select.WebSocketClientSelectEntity, + ConfigurableAttributeSwitchEntityInfo: switch.WebSocketClientSwitchEntity, +} + class DeviceProbe: """Probe to discover entities for a device.""" diff --git a/zha/application/gateway.py b/zha/application/gateway.py index 60ab3ca05..9a7d91f18 100644 --- a/zha/application/gateway.py +++ b/zha/application/gateway.py @@ -2,15 +2,19 @@ from __future__ import annotations +from abc import ABC, abstractmethod import asyncio +from collections.abc import Coroutine +import contextlib from contextlib import suppress -from dataclasses import dataclass from datetime import timedelta -from enum import Enum import logging import time -from typing import Any, Final, Self, TypeVar, cast +from types import TracebackType +from typing import TYPE_CHECKING, Any, Final, Self, cast +from async_timeout import timeout +import websockets from zhaquirks import setup as setup_quirks from zigpy.application import ControllerApplication from zigpy.config import ( @@ -26,166 +30,183 @@ from zigpy.quirks.v2 import UNBUILT_QUIRK_BUILDERS from zigpy.state import State from zigpy.types.named import EUI64 +from zigpy.zdo import ZDO from zha.application import discovery from zha.application.const import ( + ATTR_DEVICE_TYPE, + ATTR_ENDPOINTS, + ATTR_MANUFACTURER, + ATTR_MODEL, + ATTR_NODE_DESCRIPTOR, + ATTR_PROFILE_ID, CONF_USE_THREAD, UNKNOWN_MANUFACTURER, UNKNOWN_MODEL, - ZHA_GW_MSG, ZHA_GW_MSG_CONNECTION_LOST, ZHA_GW_MSG_DEVICE_FULL_INIT, ZHA_GW_MSG_DEVICE_JOINED, ZHA_GW_MSG_DEVICE_LEFT, ZHA_GW_MSG_DEVICE_REMOVED, - ZHA_GW_MSG_GROUP_ADDED, - ZHA_GW_MSG_GROUP_MEMBER_ADDED, - ZHA_GW_MSG_GROUP_MEMBER_REMOVED, - ZHA_GW_MSG_GROUP_REMOVED, ZHA_GW_MSG_RAW_INIT, RadioType, ) -from zha.application.helpers import DeviceAvailabilityChecker, GlobalUpdater, ZHAData +from zha.application.helpers import DeviceAvailabilityChecker, GlobalUpdater +from zha.application.model import ( + ConnectionLostEvent, + DeviceFullyInitializedEvent, + DeviceJoinedDeviceInfo, + DeviceJoinedEvent, + DeviceLeftEvent, + DeviceOfflineEvent, + DeviceOnlineEvent, + DevicePairingStatus, + DeviceRemovedEvent, + ExtendedDeviceInfoWithPairingStatus, + GroupAddedEvent, + GroupMemberAddedEvent, + GroupMemberRemovedEvent, + GroupRemovedEvent, + RawDeviceInitializedDeviceInfo, + RawDeviceInitializedEvent, + ZHAData, +) +from zha.application.platforms.websocket_api import load_platform_entity_apis +from zha.application.websocket_api import load_api as load_zigbee_controller_api from zha.async_ import ( AsyncUtilMixin, create_eager_task, gather_with_limited_concurrency, ) +from zha.const import ControllerEvents, DeviceEvents from zha.event import EventBase -from zha.zigbee.device import Device, DeviceInfo, DeviceStatus, ExtendedDeviceInfo -from zha.zigbee.group import Group, GroupInfo, GroupMemberReference +from zha.model import BaseEvent +from zha.websocket import ZHAWebSocketException +from zha.websocket.client.client import Client +from zha.websocket.client.helpers import ( + AlarmControlPanelHelper, + ButtonHelper, + ClientHelper, + ClimateHelper, + CoverHelper, + DeviceHelper, + FanHelper, + GroupHelper, + LightHelper, + LockHelper, + NetworkHelper, + NumberHelper, + PlatformEntityHelper, + SelectHelper, + ServerHelper, + SirenHelper, + SwitchHelper, + UpdateHelper, +) +from zha.websocket.const import WEBSOCKET_API +from zha.websocket.server.client import ClientManager, load_api as load_client_api +from zha.zigbee.device import BaseDevice, Device, WebSocketClientDevice +from zha.zigbee.endpoint import ATTR_IN_CLUSTERS, ATTR_OUT_CLUSTERS +from zha.zigbee.group import ( + BaseGroup, + Group, + GroupMemberReference, + WebSocketClientGroup, +) +from zha.zigbee.model import DeviceStatus + +if TYPE_CHECKING: + from zha.application.platforms.events import EntityStateChangedEvent + from zha.zigbee.model import ExtendedDeviceInfo, ZHAEvent BLOCK_LOG_TIMEOUT: Final[int] = 60 -_R = TypeVar("_R") _LOGGER = logging.getLogger(__name__) -class DevicePairingStatus(Enum): - """Status of a device.""" - - PAIRED = 1 - INTERVIEW_COMPLETE = 2 - CONFIGURED = 3 - INITIALIZED = 4 - - -@dataclass(kw_only=True, frozen=True) -class DeviceInfoWithPairingStatus(DeviceInfo): - """Information about a device with pairing status.""" - - pairing_status: DevicePairingStatus - - -@dataclass(kw_only=True, frozen=True) -class ExtendedDeviceInfoWithPairingStatus(ExtendedDeviceInfo): - """Information about a device with pairing status.""" - - pairing_status: DevicePairingStatus - - -@dataclass(kw_only=True, frozen=True) -class DeviceJoinedDeviceInfo: - """Information about a device.""" - - ieee: str - nwk: int - pairing_status: DevicePairingStatus +class BaseGateway(EventBase, ABC): + """Base gateway class.""" + def __init__(self, config: ZHAData) -> None: + """Initialize the gateway.""" + super().__init__() + self.config: ZHAData = config + self.config.gateway = self -@dataclass(kw_only=True, frozen=True) -class ConnectionLostEvent: - """Event to signal that the connection to the radio has been lost.""" - - event_type: Final[str] = ZHA_GW_MSG - event: Final[str] = ZHA_GW_MSG_CONNECTION_LOST - exception: Exception | None = None - - -@dataclass(kw_only=True, frozen=True) -class DeviceJoinedEvent: - """Event to signal that a device has joined the network.""" - - device_info: DeviceJoinedDeviceInfo - event_type: Final[str] = ZHA_GW_MSG - event: Final[str] = ZHA_GW_MSG_DEVICE_JOINED - - -@dataclass(kw_only=True, frozen=True) -class DeviceLeftEvent: - """Event to signal that a device has left the network.""" - - ieee: EUI64 - nwk: int - event_type: Final[str] = ZHA_GW_MSG - event: Final[str] = ZHA_GW_MSG_DEVICE_LEFT - - -@dataclass(kw_only=True, frozen=True) -class RawDeviceInitializedDeviceInfo(DeviceJoinedDeviceInfo): - """Information about a device that has been initialized without quirks loaded.""" - - model: str - manufacturer: str - signature: dict[str, Any] - - -@dataclass(kw_only=True, frozen=True) -class RawDeviceInitializedEvent: - """Event to signal that a device has been initialized without quirks loaded.""" - - device_info: RawDeviceInitializedDeviceInfo - event_type: Final[str] = ZHA_GW_MSG - event: Final[str] = ZHA_GW_MSG_RAW_INIT + @abstractmethod + async def _async_initialize(self) -> None: + """Initialize controller and connect radio.""" + @abstractmethod + def _find_coordinator_device(self) -> zigpy.device.Device: + """Find the coordinator device.""" -@dataclass(kw_only=True, frozen=True) -class DeviceFullInitEvent: - """Event to signal that a device has been fully initialized.""" + @abstractmethod + async def async_initialize_devices_and_entities(self) -> None: + """Initialize devices and load entities.""" - device_info: ExtendedDeviceInfoWithPairingStatus - new_join: bool = False - event_type: Final[str] = ZHA_GW_MSG - event: Final[str] = ZHA_GW_MSG_DEVICE_FULL_INIT + @property + @abstractmethod + def state(self) -> State: + """Return the active coordinator's network state.""" + @abstractmethod + def get_or_create_device( + self, zigpy_device: zigpy.device.Device | ExtendedDeviceInfo + ) -> BaseDevice: + """Get or create a ZHA device.""" -@dataclass(kw_only=True, frozen=True) -class GroupEvent: - """Event to signal a group event.""" + @abstractmethod + async def async_create_zigpy_group( + self, + name: str, + members: list[GroupMemberReference] | None, + group_id: int | None = None, + ) -> BaseGroup | None: + """Create a new Zigpy Zigbee group.""" - event: str - group_info: GroupInfo - event_type: Final[str] = ZHA_GW_MSG + @abstractmethod + async def async_remove_device(self, ieee: EUI64) -> None: + """Remove a device from ZHA.""" + @abstractmethod + async def async_remove_zigpy_group(self, group_id: int) -> None: + """Remove a Zigbee group from Zigpy.""" -@dataclass(kw_only=True, frozen=True) -class DeviceRemovedEvent: - """Event to signal that a device has been removed.""" + @abstractmethod + async def shutdown(self) -> None: + """Stop ZHA Controller Application.""" - device_info: ExtendedDeviceInfo - event_type: Final[str] = ZHA_GW_MSG - event: Final[str] = ZHA_GW_MSG_DEVICE_REMOVED + @abstractmethod + def broadcast_event(self, event: BaseEvent) -> None: + """Broadcast an event to all listeners.""" -class Gateway(AsyncUtilMixin, EventBase): +class Gateway(AsyncUtilMixin, BaseGateway): """Gateway that handles events that happen on the ZHA Zigbee network.""" def __init__(self, config: ZHAData) -> None: """Initialize the gateway.""" - super().__init__() - self.config: ZHAData = config + super().__init__(config) self._devices: dict[EUI64, Device] = {} self._groups: dict[int, Group] = {} - self.application_controller: ControllerApplication = None self.coordinator_zha_device: Device = None # type: ignore[assignment] - + self.application_controller: ControllerApplication = None self.shutting_down: bool = False self._reload_task: asyncio.Task | None = None - self.global_updater: GlobalUpdater = GlobalUpdater(self) self._device_availability_checker: DeviceAvailabilityChecker = ( DeviceAvailabilityChecker(self) ) - self.config.gateway = self + + @property + def devices(self) -> dict[EUI64, Device]: + """Return devices.""" + return self._devices + + @property + def groups(self) -> dict[int, Group]: + """Return groups.""" + return self._groups @property def radio_type(self) -> RadioType: @@ -436,7 +457,33 @@ def raw_device_initialized(self, device: zigpy.device.Device) -> None: # pylint manufacturer=device.manufacturer if device.manufacturer else UNKNOWN_MANUFACTURER, - signature=device.get_signature(), + signature={ + ATTR_NODE_DESCRIPTOR: device.node_desc.as_dict(), + ATTR_ENDPOINTS: { + ep_id: { + ATTR_PROFILE_ID: f"0x{endpoint.profile_id:04x}" + if endpoint.profile_id is not None + else "", + ATTR_DEVICE_TYPE: f"0x{endpoint.device_type:04x}" + if endpoint.device_type is not None + else "", + ATTR_IN_CLUSTERS: [ + f"0x{cluster_id:04x}" + for cluster_id in sorted(endpoint.in_clusters) + ], + ATTR_OUT_CLUSTERS: [ + f"0x{cluster_id:04x}" + for cluster_id in sorted(endpoint.out_clusters) + ], + } + for ep_id, endpoint in device.endpoints.items() + if not isinstance(endpoint, ZDO) + }, + ATTR_MANUFACTURER: device.manufacturer + if device.manufacturer + else UNKNOWN_MANUFACTURER, + ATTR_MODEL: device.model if device.model else UNKNOWN_MODEL, + }, ) ), ) @@ -472,10 +519,9 @@ def group_member_removed( """Handle zigpy group member removed event.""" # need to handle endpoint correctly on groups zha_group = self.get_or_create_group(zigpy_group) - zha_group.clear_caches() discovery.GROUP_PROBE.discover_group_entities(zha_group) zha_group.info("group_member_removed - endpoint: %s", endpoint) - self._emit_group_gateway_message(zigpy_group, ZHA_GW_MSG_GROUP_MEMBER_REMOVED) + self._emit_group_gateway_message(zigpy_group, GroupMemberRemovedEvent) def group_member_added( self, zigpy_group: zigpy.group.Group, endpoint: zigpy.endpoint.Endpoint @@ -483,38 +529,40 @@ def group_member_added( """Handle zigpy group member added event.""" # need to handle endpoint correctly on groups zha_group = self.get_or_create_group(zigpy_group) - zha_group.clear_caches() discovery.GROUP_PROBE.discover_group_entities(zha_group) zha_group.info("group_member_added - endpoint: %s", endpoint) - self._emit_group_gateway_message(zigpy_group, ZHA_GW_MSG_GROUP_MEMBER_ADDED) + self._emit_group_gateway_message(zigpy_group, GroupMemberAddedEvent) def group_added(self, zigpy_group: zigpy.group.Group) -> None: """Handle zigpy group added event.""" zha_group = self.get_or_create_group(zigpy_group) zha_group.info("group_added") # need to dispatch for entity creation here - self._emit_group_gateway_message(zigpy_group, ZHA_GW_MSG_GROUP_ADDED) + self._emit_group_gateway_message(zigpy_group, GroupAddedEvent) def group_removed(self, zigpy_group: zigpy.group.Group) -> None: """Handle zigpy group removed event.""" - self._emit_group_gateway_message(zigpy_group, ZHA_GW_MSG_GROUP_REMOVED) + self._emit_group_gateway_message(zigpy_group, GroupRemovedEvent) zha_group = self._groups.pop(zigpy_group.group_id) zha_group.info("group_removed") def _emit_group_gateway_message( # pylint: disable=unused-argument self, zigpy_group: zigpy.group.Group, - gateway_message_type: str, + gateway_message_type: GroupRemovedEvent + | GroupAddedEvent + | GroupMemberAddedEvent + | GroupMemberRemovedEvent, ) -> None: """Send the gateway event for a zigpy group event.""" zha_group = self._groups.get(zigpy_group.group_id) if zha_group is not None: + response = gateway_message_type( + group_info=zha_group.info_object, + ) self.emit( - gateway_message_type, - GroupEvent( - event=gateway_message_type, - group_info=zha_group.info_object, - ), + response.event, + response, ) def device_removed(self, device: zigpy.device.Device) -> None: @@ -552,16 +600,6 @@ def state(self) -> State: """Return the active coordinator's network state.""" return self.application_controller.state - @property - def devices(self) -> dict[EUI64, Device]: - """Return devices.""" - return self._devices - - @property - def groups(self) -> dict[int, Group]: - """Return groups.""" - return self._groups - def get_or_create_device(self, zigpy_device: zigpy.device.Device) -> Device: """Get or create a ZHA device.""" if (zha_device := self._devices.get(zigpy_device.ieee)) is None: @@ -587,7 +625,7 @@ def async_update_device( device = self.devices[sender.ieee] # avoid a race condition during new joins if device.status is DeviceStatus.INITIALIZED: - device.update_available(available) + device.update_available(available=available, on_network=available) async def async_device_initialized(self, device: zigpy.device.Device) -> None: """Handle device joined and basic information discovered (async).""" @@ -618,26 +656,26 @@ async def async_device_initialized(self, device: zigpy.device.Device) -> None: device_info = ExtendedDeviceInfoWithPairingStatus( pairing_status=DevicePairingStatus.INITIALIZED, - **zha_device.extended_device_info.__dict__, + **zha_device.extended_device_info.model_dump(), ) self.emit( ZHA_GW_MSG_DEVICE_FULL_INIT, - DeviceFullInitEvent(device_info=device_info), + DeviceFullyInitializedEvent(device_info=device_info), ) async def _async_device_joined(self, zha_device: Device) -> None: - zha_device.available = True - zha_device.on_network = True + zha_device._available = True + zha_device._on_network = True await zha_device.async_configure() + await zha_device.async_initialize(from_cache=False) + self.create_platform_entities() device_info = ExtendedDeviceInfoWithPairingStatus( pairing_status=DevicePairingStatus.CONFIGURED, - **zha_device.extended_device_info.__dict__, + **zha_device.extended_device_info.model_dump(), ) - await zha_device.async_initialize(from_cache=False) - self.create_platform_entities() self.emit( ZHA_GW_MSG_DEVICE_FULL_INIT, - DeviceFullInitEvent(device_info=device_info, new_join=True), + DeviceFullyInitializedEvent(device_info=device_info, new_join=True), ) async def _async_device_rejoined(self, zha_device: Device) -> None: @@ -651,15 +689,13 @@ async def _async_device_rejoined(self, zha_device: Device) -> None: await zha_device.async_configure() device_info = ExtendedDeviceInfoWithPairingStatus( pairing_status=DevicePairingStatus.CONFIGURED, - **zha_device.extended_device_info.__dict__, + **zha_device.extended_device_info.model_dump(), ) self.emit( ZHA_GW_MSG_DEVICE_FULL_INIT, - DeviceFullInitEvent(device_info=device_info), + DeviceFullyInitializedEvent(device_info=device_info), ) - # force async_initialize() to fire so don't explicitly call it - zha_device.available = False - zha_device.on_network = True + await zha_device.async_initialize(False) async def async_create_zigpy_group( self, @@ -728,7 +764,7 @@ async def async_remove_zigpy_group(self, group_id: int) -> None: await asyncio.gather(*tasks) self.application_controller.groups.pop(group_id) - async def shutdown(self) -> None: + async def shutdown(self, call_super=True) -> None: """Stop ZHA Controller Application.""" if self.shutting_down: _LOGGER.debug("Ignoring duplicate shutdown event") @@ -751,7 +787,8 @@ async def shutdown(self) -> None: self.application_controller = None await asyncio.sleep(0.1) # give bellows thread callback a chance to run - await super().shutdown() + if call_super: + await super().shutdown() self._devices.clear() self._groups.clear() @@ -769,3 +806,413 @@ def handle_message( # pylint: disable=unused-argument if sender.ieee in self.devices and not self.devices[sender.ieee].available: self.devices[sender.ieee].on_network = True self.async_update_device(sender, available=True) + + def broadcast_event(self, event: BaseEvent) -> None: + """Broadcast an event to all listeners.""" + + +class WebSocketServerGateway(Gateway): + """ZHA websocket server implementation.""" + + def __init__(self, config: ZHAData) -> None: + """Initialize the websocket server gateway.""" + super().__init__(config) + self._ws_server: websockets.WebSocketServer | None = None + self._client_manager: ClientManager = ClientManager(self) + self._stopped_event: asyncio.Event = asyncio.Event() + self.data: dict[Any, Any] = {} + for platform in discovery.PLATFORMS: + self.data.setdefault(platform, []) + self.data.setdefault(WEBSOCKET_API, {}) + self._register_api_commands() + + @property + def is_serving(self) -> bool: + """Return whether or not the websocket server is serving.""" + return self._ws_server is not None and self._ws_server.is_serving + + @property + def client_manager(self) -> ClientManager: + """Return the zigbee application controller.""" + return self._client_manager + + async def start_server(self) -> None: + """Start the websocket server.""" + assert self._ws_server is None + self._stopped_event.clear() + self._ws_server = await websockets.serve( + self.client_manager.add_client, + self.config.ws_server_config.host, + self.config.ws_server_config.port, + logger=_LOGGER, + ) + if self.config.ws_server_config.network_auto_start: + await self.async_initialize() + await self.async_initialize_devices_and_entities() + + async def async_initialize(self) -> None: + """Initialize controller and connect radio.""" + await super().async_initialize() + self.on_all_events(self.client_manager.broadcast) + + async def stop_server(self) -> None: + """Stop the websocket server.""" + if self._ws_server is None: + self._stopped_event.set() + return + + assert self._ws_server is not None + + await self.shutdown() + + self._ws_server.close() + await self._ws_server.wait_closed() + self._ws_server = None + + self._stopped_event.set() + + async def start_network(self) -> None: + """Start the Zigbee network.""" + await super().async_initialize() # we do this to avoid 2x event registration + await self.async_initialize_devices_and_entities() + + async def stop_network(self) -> None: + """Stop the Zigbee network.""" + await self.shutdown(call_super=False) + + async def wait_closed(self) -> None: + """Wait until the server is not running.""" + await self._stopped_event.wait() + _LOGGER.info("Server stopped. Completing remaining tasks...") + tasks = [ + t + for t in self._tracked_completable_tasks + if not (t.done() or t.cancelled()) + ] + for task in tasks: + _LOGGER.debug("Cancelling task: %s", task) + task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await asyncio.gather(*tasks, return_exceptions=True) + + async def __aenter__(self) -> WebSocketServerGateway: + """Enter the context manager.""" + await self.start_server() + return self + + async def __aexit__( + self, exc_type: Exception, exc_value: str, traceback: TracebackType + ) -> None: + """Exit the context manager.""" + await self.stop_server() + await self.wait_closed() + + def broadcast_event(self, event: BaseEvent) -> None: + """Broadcast an event to all listeners.""" + self.emit(event.event, event) + + def _register_api_commands(self) -> None: + """Load server API commands.""" + + load_zigbee_controller_api(self) + load_platform_entity_apis(self) + load_client_api(self) + + +CONNECT_TIMEOUT = 10 + + +class WebSocketClientGateway(BaseGateway): + """ZHA gateway implementation for a websocket client.""" + + def __init__(self, config: ZHAData) -> None: + """Initialize the websocket client gateway.""" + super().__init__(config) + self._tasks: list[asyncio.Task] = [] + self._ws_server_url: str = ( + f"ws://{config.ws_client_config.host}:{config.ws_client_config.port}" + ) + self._client: Client = Client( + self._ws_server_url, aiohttp_session=config.ws_client_config.aiohttp_session + ) + self._state: State + self._devices: dict[EUI64, WebSocketClientDevice] = {} + self._groups: dict[int, WebSocketClientGroup] = {} + self.coordinator_zha_device: WebSocketClientDevice = None # type: ignore[assignment] + self.lights: LightHelper = LightHelper(self._client) + self.switches: SwitchHelper = SwitchHelper(self._client) + self.sirens: SirenHelper = SirenHelper(self._client) + self.buttons: ButtonHelper = ButtonHelper(self._client) + self.covers: CoverHelper = CoverHelper(self._client) + self.fans: FanHelper = FanHelper(self._client) + self.locks: LockHelper = LockHelper(self._client) + self.numbers: NumberHelper = NumberHelper(self._client) + self.selects: SelectHelper = SelectHelper(self._client) + self.thermostats: ClimateHelper = ClimateHelper(self._client) + self.alarm_control_panels: AlarmControlPanelHelper = AlarmControlPanelHelper( + self._client + ) + self.entities: PlatformEntityHelper = PlatformEntityHelper(self._client) + self.clients: ClientHelper = ClientHelper(self._client) + self.groups_helper: GroupHelper = GroupHelper(self._client) + self.devices_helper: DeviceHelper = DeviceHelper(self._client) + self.network: NetworkHelper = NetworkHelper(self._client) + self.server_helper: ServerHelper = ServerHelper(self._client) + self.update_helper: UpdateHelper = UpdateHelper(self._client) + self._client.on_all_events(self._handle_event_protocol) + + @property + def client(self) -> Client: + """Return the client.""" + return self._client + + @property + def devices(self) -> dict[EUI64, WebSocketClientDevice]: + """Return devices.""" + return self._devices + + @property + def groups(self) -> dict[int, WebSocketClientGroup]: + """Return groups.""" + return self._groups + + @property + def state(self) -> State: + """Return the active coordinator's network state.""" + return self._state + + async def connect(self) -> None: + """Connect to the websocket server.""" + _LOGGER.debug("Connecting to websocket server at: %s", self._ws_server_url) + try: + async with timeout(CONNECT_TIMEOUT): + await self._client.connect() + except TimeoutError as err: + _LOGGER.exception("Unable to connect to the ZHA wss", exc_info=err) + await self._client.disconnect() + raise ZHAWebSocketException from err + + async def disconnect(self) -> None: + """Disconnect from the websocket server.""" + await self._client.disconnect() + + async def __aenter__(self) -> WebSocketClientGateway: + """Connect to the websocket server.""" + await self.connect() + await self.clients.listen() + return self + + async def __aexit__( + self, exc_type: Exception, exc_value: str, traceback: TracebackType + ) -> None: + """Disconnect from the websocket server.""" + await self.disconnect() + + def create_and_track_task(self, coroutine: Coroutine) -> asyncio.Task: + """Create and track a task.""" + task = asyncio.create_task(coroutine) + self._tasks.append(task) + task.add_done_callback(self._tasks.remove) + return task + + async def load_devices(self) -> None: + """Restore ZHA devices from zigpy application state.""" + response_devices = await self.devices_helper.get_devices() + for ieee, device in response_devices.items(): + self._devices[ieee] = self.get_or_create_device(device) + + async def load_groups(self) -> None: + """Initialize ZHA groups.""" + response_groups = await self.groups_helper.get_groups() + for group_id, group in response_groups.items(): + self._groups[group_id] = WebSocketClientGroup(group, self) + + async def load_application_state(self) -> None: + """Load the application state.""" + response = await self.network.get_application_state() + self._state = response.get_converted_state() + + async def async_initialize(self) -> None: + """Initialize controller and connect radio.""" + try: + await self._async_initialize() + except Exception: + await self.shutdown() + raise + + async def _async_initialize(self) -> None: + """Initialize controller and connect radio.""" + + await self.load_application_state() + await self.load_devices() + self.coordinator_zha_device = self._find_coordinator_device() + await self.load_groups() + + def _find_coordinator_device(self) -> WebSocketClientDevice | None: + """Find the coordinator device.""" + for device in self._devices.values(): + if device.is_active_coordinator: + return device + return None + + async def async_initialize_devices_and_entities(self) -> None: + """Initialize devices and load entities.""" + + def get_or_create_device( + self, zigpy_device: zigpy.device.Device | ExtendedDeviceInfo + ) -> WebSocketClientDevice: + """Get or create a ZHA device.""" + if (zha_device := self._devices.get(zigpy_device.ieee)) is None: + zha_device = WebSocketClientDevice(zigpy_device, self) + self._devices[zigpy_device.ieee] = zha_device + else: + self._devices[zigpy_device.ieee].extended_device_info = zigpy_device + return zha_device + + async def async_create_zigpy_group( + self, + name: str, + members: list[GroupMemberReference] | None, + group_id: int | None = None, + ) -> WebSocketClientGroup | None: + """Create a new Zigpy Zigbee group.""" + response = await self.groups_helper.create_group(name, group_id, members) + return self._groups.get(response.group_id) + + def get_device(self, ieee: EUI64) -> WebSocketClientDevice | None: + """Return Device for given ieee.""" + return self._devices.get(ieee) + + def get_group(self, group_id_or_name: int | str) -> WebSocketClientGroup | None: + """Return Group for given group id or group name.""" + if isinstance(group_id_or_name, str): + for group in self.groups.values(): + if group.name == group_id_or_name: + return group + return None + return self.groups.get(group_id_or_name) + + async def async_remove_device(self, ieee: EUI64) -> None: + """Remove a device from ZHA.""" + await self.devices_helper.remove_device(self.devices[ieee].extended_device_info) + + async def async_remove_zigpy_group(self, group_id: int) -> None: + """Remove a Zigbee group from Zigpy.""" + await self.groups_helper.remove_groups([self.groups[group_id].info_object]) + + async def shutdown(self) -> None: + """Stop ZHA Controller Application.""" + await self.server_helper.stop_server() + + def handle_state_changed(self, event: EntityStateChangedEvent) -> None: + """Handle a platform_entity_event from the websocket server.""" + _LOGGER.debug("platform_entity_event: %s", event) + if event.device_ieee: + device = self.devices.get(event.device_ieee) + if device is None: + _LOGGER.warning("Received event from unknown device: %s", event) + return + device.emit_platform_entity_event(event) + elif event.group_id: + group = self.groups.get(event.group_id) + if not group: + _LOGGER.warning("Received event from unknown group: %s", event) + return + group.emit_platform_entity_event(event) + + def handle_zha_event(self, event: ZHAEvent) -> None: + """Handle a zha_event from the websocket server.""" + _LOGGER.debug("zha_event: %s", event) + device = self.devices.get(event.device_ieee) + if device is None: + _LOGGER.warning("Received zha_event from unknown device: %s", event) + return + device.emit(DeviceEvents.ZHA_EVENT, event) + + def handle_device_joined(self, event: DeviceJoinedEvent) -> None: + """Handle device joined. + + At this point, no information about the device is known other than its + address + """ + + self.emit(ZHA_GW_MSG_DEVICE_JOINED, event) + + def handle_raw_device_initialized(self, event: RawDeviceInitializedEvent) -> None: + """Handle a device initialization without quirks loaded.""" + + self.emit(ZHA_GW_MSG_RAW_INIT, event) + + def handle_device_fully_initialized( + self, event: DeviceFullyInitializedEvent + ) -> None: + """Handle device joined and basic information discovered.""" + device_model = event.device_info + _LOGGER.info("Device %s - %s initialized", device_model.ieee, device_model.nwk) + if device_model.ieee in self.devices: + self.devices[device_model.ieee].extended_device_info = device_model + else: + self._devices[device_model.ieee] = self.get_or_create_device(device_model) + self.emit(ControllerEvents.DEVICE_FULLY_INITIALIZED, event) + + def handle_device_left(self, event: DeviceLeftEvent) -> None: + """Handle device leaving the network.""" + _LOGGER.info("Device %s - %s left", event.ieee, event.nwk) + self.emit(ZHA_GW_MSG_DEVICE_LEFT, event) + + def handle_device_removed(self, event: DeviceRemovedEvent) -> None: + """Handle device being removed from the network.""" + device = event.device_info + _LOGGER.info( + "Device %s - %s has been removed from the network", device.ieee, device.nwk + ) + self._devices.pop(device.ieee, None) + self.emit(ZHA_GW_MSG_DEVICE_REMOVED, event) + + def handle_device_online(self, event: DeviceOnlineEvent) -> None: + """Handle device online event.""" + if event.device_info.ieee in self.devices: + device = self.devices[event.device_info.ieee] + device.extended_device_info = event.device_info + device.emit(DeviceEvents.DEVICE_ONLINE, event) + + def handle_device_offline(self, event: DeviceOfflineEvent) -> None: + """Handle device offline event.""" + if event.device_info.ieee in self.devices: + device = self.devices[event.device_info.ieee] + device.extended_device_info = event.device_info + device.emit(DeviceEvents.DEVICE_OFFLINE, event) + + def handle_group_member_removed(self, event: GroupMemberRemovedEvent) -> None: + """Handle group member removed event.""" + if event.group_info.group_id in self.groups: + self.groups[event.group_info.group_id].info_object = event.group_info + self.emit(ControllerEvents.GROUP_MEMBER_REMOVED, event) + + def handle_group_member_added(self, event: GroupMemberAddedEvent) -> None: + """Handle group member added event.""" + if event.group_info.group_id in self.groups: + self.groups[event.group_info.group_id].info_object = event.group_info + self.emit(ControllerEvents.GROUP_MEMBER_ADDED, event) + + def handle_group_added(self, event: GroupAddedEvent) -> None: + """Handle group added event.""" + if event.group_info.group_id in self.groups: + self.groups[event.group_info.group_id].info_object = event.group_info + else: + self.groups[event.group_info.group_id] = WebSocketClientGroup( + event.group_info, self + ) + self.emit(ControllerEvents.GROUP_ADDED, event) + + def handle_group_removed(self, event: GroupRemovedEvent) -> None: + """Handle group removed event.""" + if event.group_info.group_id in self.groups: + self.groups.pop(event.group_info.group_id) + self.emit(ControllerEvents.GROUP_REMOVED, event) + + def connection_lost(self, exc: Exception) -> None: + """Handle connection lost event.""" + + def broadcast_event(self, event: BaseEvent) -> None: + """Broadcast an event to all listeners.""" diff --git a/zha/application/helpers.py b/zha/application/helpers.py index 300de0078..0596c1691 100644 --- a/zha/application/helpers.py +++ b/zha/application/helpers.py @@ -4,11 +4,8 @@ import asyncio import binascii -import collections from collections.abc import Callable -import dataclasses from dataclasses import dataclass -import datetime import enum import logging import re @@ -22,15 +19,10 @@ from zigpy.zcl.foundation import CommandSchema import zigpy.zdo.types as zdo_types -from zha.application import Platform -from zha.application.const import ( - CLUSTER_TYPE_IN, - CLUSTER_TYPE_OUT, - CONF_DEFAULT_CONSIDER_UNAVAILABLE_BATTERY, - CONF_DEFAULT_CONSIDER_UNAVAILABLE_MAINS, -) +from zha.application.const import CLUSTER_TYPE_IN, CLUSTER_TYPE_OUT from zha.async_ import gather_with_limited_concurrency from zha.decorators import periodic +from zha.zigbee.cluster_handlers.registries import BINDABLE_CLUSTERS if TYPE_CHECKING: from zha.application.gateway import Gateway @@ -89,9 +81,6 @@ async def get_matched_clusters( source_zha_device: Device, target_zha_device: Device ) -> list[BindingPair]: """Get matched input/output cluster pairs for 2 devices.""" - from zha.zigbee.cluster_handlers.registries import ( # pylint: disable=import-outside-toplevel - BINDABLE_CLUSTERS, - ) source_clusters = source_zha_device.async_get_std_clusters() target_clusters = target_zha_device.async_get_std_clusters() @@ -165,9 +154,6 @@ def convert_to_zcl_values( def async_is_bindable_target(source_zha_device: Device, target_zha_device: Device): """Determine if target is bindable to source.""" - from zha.zigbee.cluster_handlers.registries import ( # pylint: disable=import-outside-toplevel - BINDABLE_CLUSTERS, - ) if target_zha_device.nwk == 0x0000: return True @@ -261,102 +247,6 @@ def qr_to_install_code(qr_code: str) -> tuple[zigpy.types.EUI64, zigpy.types.Key raise vol.Invalid(f"couldn't convert qr code: {qr_code}") -@dataclass(kw_only=True, slots=True) -class LightOptions: - """ZHA light options.""" - - default_light_transition: float = dataclasses.field(default=0) - enable_enhanced_light_transition: bool = dataclasses.field(default=False) - enable_light_transitioning_flag: bool = dataclasses.field(default=True) - always_prefer_xy_color_mode: bool = dataclasses.field(default=True) - group_members_assume_state: bool = dataclasses.field(default=True) - - -@dataclass(kw_only=True, slots=True) -class DeviceOptions: - """ZHA device options.""" - - enable_identify_on_join: bool = dataclasses.field(default=True) - consider_unavailable_mains: int = dataclasses.field( - default=CONF_DEFAULT_CONSIDER_UNAVAILABLE_MAINS - ) - consider_unavailable_battery: int = dataclasses.field( - default=CONF_DEFAULT_CONSIDER_UNAVAILABLE_BATTERY - ) - enable_mains_startup_polling: bool = dataclasses.field(default=True) - - -@dataclass(kw_only=True, slots=True) -class AlarmControlPanelOptions: - """ZHA alarm control panel options.""" - - master_code: str = dataclasses.field(default="1234") - failed_tries: int = dataclasses.field(default=3) - arm_requires_code: bool = dataclasses.field(default=False) - - -@dataclass(kw_only=True, slots=True) -class CoordinatorConfiguration: - """ZHA coordinator configuration.""" - - path: str - baudrate: int = dataclasses.field(default=115200) - flow_control: str = dataclasses.field(default="hardware") - radio_type: str = dataclasses.field(default="ezsp") - - -@dataclass(kw_only=True, slots=True) -class QuirksConfiguration: - """ZHA quirks configuration.""" - - enabled: bool = dataclasses.field(default=True) - custom_quirks_path: str | None = dataclasses.field(default=None) - - -@dataclass(kw_only=True, slots=True) -class DeviceOverridesConfiguration: - """ZHA device overrides configuration.""" - - type: Platform - - -@dataclass(kw_only=True, slots=True) -class ZHAConfiguration: - """ZHA configuration.""" - - coordinator_configuration: CoordinatorConfiguration = dataclasses.field( - default_factory=CoordinatorConfiguration - ) - quirks_configuration: QuirksConfiguration = dataclasses.field( - default_factory=QuirksConfiguration - ) - device_overrides: dict[str, DeviceOverridesConfiguration] = dataclasses.field( - default_factory=dict - ) - light_options: LightOptions = dataclasses.field(default_factory=LightOptions) - device_options: DeviceOptions = dataclasses.field(default_factory=DeviceOptions) - alarm_control_panel_options: AlarmControlPanelOptions = dataclasses.field( - default_factory=AlarmControlPanelOptions - ) - - -@dataclasses.dataclass(kw_only=True, slots=True) -class ZHAData: - """ZHA data stored in `gateway.data`.""" - - config: ZHAConfiguration - zigpy_config: dict[str, Any] = dataclasses.field(default_factory=dict) - platforms: collections.defaultdict[Platform, list] = dataclasses.field( - default_factory=lambda: collections.defaultdict(list) - ) - gateway: Gateway | None = dataclasses.field(default=None) - device_trigger_cache: dict[str, tuple[str, dict]] = dataclasses.field( - default_factory=dict - ) - allow_polling: bool = dataclasses.field(default=False) - local_timezone: datetime.tzinfo = dataclasses.field(default=datetime.UTC) - - class GlobalUpdater: """Global updater for ZHA. diff --git a/zha/application/model.py b/zha/application/model.py new file mode 100644 index 000000000..912e06b2c --- /dev/null +++ b/zha/application/model.py @@ -0,0 +1,275 @@ +"""Models for the ZHA application module.""" + +from __future__ import annotations + +import collections +import dataclasses +import datetime +from enum import Enum +from typing import TYPE_CHECKING, Any, Literal + +from aiohttp import ClientSession +from pydantic import Field +from zigpy.types.named import EUI64, NWK + +from zha.application import Platform +from zha.application.const import ( + CONF_DEFAULT_CONSIDER_UNAVAILABLE_BATTERY, + CONF_DEFAULT_CONSIDER_UNAVAILABLE_MAINS, +) +from zha.const import ControllerEvents, DeviceEvents, EventTypes +from zha.model import BaseEvent, BaseModel +from zha.zigbee.model import DeviceInfo, ExtendedDeviceInfo, GroupInfo + +if TYPE_CHECKING: + from zha.application.gateway import Gateway + + +class DevicePairingStatus(Enum): + """Status of a device.""" + + PAIRED = 1 + INTERVIEW_COMPLETE = 2 + CONFIGURED = 3 + INITIALIZED = 4 + + +class DeviceInfoWithPairingStatus(DeviceInfo): + """Information about a device with pairing status.""" + + pairing_status: DevicePairingStatus + + +class ExtendedDeviceInfoWithPairingStatus(ExtendedDeviceInfo): + """Information about a device with pairing status.""" + + pairing_status: DevicePairingStatus + + +class DeviceJoinedDeviceInfo(BaseModel): + """Information about a device.""" + + ieee: EUI64 + nwk: NWK + pairing_status: DevicePairingStatus + + +class ConnectionLostEvent(BaseEvent): + """Event to signal that the connection to the radio has been lost.""" + + event_type: Literal[EventTypes.CONTROLLER_EVENT] = EventTypes.CONTROLLER_EVENT + event: Literal[ControllerEvents.CONNECTION_LOST] = ControllerEvents.CONNECTION_LOST + exception: Exception | None = None + + +class DeviceJoinedEvent(BaseEvent): + """Event to signal that a device has joined the network.""" + + device_info: DeviceJoinedDeviceInfo + event_type: Literal[EventTypes.CONTROLLER_EVENT] = EventTypes.CONTROLLER_EVENT + event: Literal[ControllerEvents.DEVICE_JOINED] = ControllerEvents.DEVICE_JOINED + + +class DeviceLeftEvent(BaseEvent): + """Event to signal that a device has left the network.""" + + ieee: EUI64 + nwk: NWK + event_type: Literal[EventTypes.CONTROLLER_EVENT] = EventTypes.CONTROLLER_EVENT + event: Literal[ControllerEvents.DEVICE_LEFT] = ControllerEvents.DEVICE_LEFT + + +class RawDeviceInitializedDeviceInfo(DeviceJoinedDeviceInfo): + """Information about a device that has been initialized without quirks loaded.""" + + model: str + manufacturer: str + signature: dict[str, Any] + + +class RawDeviceInitializedEvent(BaseEvent): + """Event to signal that a device has been initialized without quirks loaded.""" + + device_info: RawDeviceInitializedDeviceInfo + event_type: Literal[EventTypes.CONTROLLER_EVENT] = EventTypes.CONTROLLER_EVENT + event: Literal[ControllerEvents.RAW_DEVICE_INITIALIZED] = ( + ControllerEvents.RAW_DEVICE_INITIALIZED + ) + + +class DeviceFullyInitializedEvent(BaseEvent): + """Event to signal that a device has been fully initialized.""" + + device_info: ExtendedDeviceInfoWithPairingStatus + new_join: bool = False + event_type: Literal[EventTypes.CONTROLLER_EVENT] = EventTypes.CONTROLLER_EVENT + event: Literal[ControllerEvents.DEVICE_FULLY_INITIALIZED] = ( + ControllerEvents.DEVICE_FULLY_INITIALIZED + ) + + +class GroupRemovedEvent(BaseEvent): + """Group removed event.""" + + event_type: Literal[EventTypes.CONTROLLER_EVENT] = EventTypes.CONTROLLER_EVENT + event: Literal[ControllerEvents.GROUP_REMOVED] = ControllerEvents.GROUP_REMOVED + group_info: GroupInfo + + +class GroupAddedEvent(BaseEvent): + """Group added event.""" + + event_type: Literal[EventTypes.CONTROLLER_EVENT] = EventTypes.CONTROLLER_EVENT + event: Literal[ControllerEvents.GROUP_ADDED] = ControllerEvents.GROUP_ADDED + group_info: GroupInfo + + +class GroupMemberAddedEvent(BaseEvent): + """Group member added event.""" + + event_type: Literal[EventTypes.CONTROLLER_EVENT] = EventTypes.CONTROLLER_EVENT + event: Literal[ControllerEvents.GROUP_MEMBER_ADDED] = ( + ControllerEvents.GROUP_MEMBER_ADDED + ) + group_info: GroupInfo + + +class GroupMemberRemovedEvent(BaseEvent): + """Group member removed event.""" + + event_type: Literal[EventTypes.CONTROLLER_EVENT] = EventTypes.CONTROLLER_EVENT + event: Literal[ControllerEvents.GROUP_MEMBER_REMOVED] = ( + ControllerEvents.GROUP_MEMBER_REMOVED + ) + group_info: GroupInfo + + +class DeviceRemovedEvent(BaseEvent): + """Event to signal that a device has been removed.""" + + device_info: ExtendedDeviceInfo + event_type: Literal[EventTypes.CONTROLLER_EVENT] = EventTypes.CONTROLLER_EVENT + event: Literal[ControllerEvents.DEVICE_REMOVED] = ControllerEvents.DEVICE_REMOVED + + +class DeviceOfflineEvent(BaseEvent): + """Device offline event.""" + + event: Literal[DeviceEvents.DEVICE_OFFLINE] = DeviceEvents.DEVICE_OFFLINE + event_type: Literal[EventTypes.DEVICE_EVENT] = EventTypes.DEVICE_EVENT + device_info: ExtendedDeviceInfo + + +class DeviceOnlineEvent(BaseEvent): + """Device online event.""" + + event: Literal[DeviceEvents.DEVICE_ONLINE] = DeviceEvents.DEVICE_ONLINE + event_type: Literal[EventTypes.DEVICE_EVENT] = EventTypes.DEVICE_EVENT + device_info: ExtendedDeviceInfo + + +class LightOptions(BaseModel): + """ZHA light options.""" + + default_light_transition: float = Field(default=0) + enable_enhanced_light_transition: bool = Field(default=False) + enable_light_transitioning_flag: bool = Field(default=True) + always_prefer_xy_color_mode: bool = Field(default=True) + group_members_assume_state: bool = Field(default=True) + + +class DeviceOptions(BaseModel): + """ZHA device options.""" + + enable_identify_on_join: bool = Field(default=True) + consider_unavailable_mains: int = Field( + default=CONF_DEFAULT_CONSIDER_UNAVAILABLE_MAINS + ) + consider_unavailable_battery: int = Field( + default=CONF_DEFAULT_CONSIDER_UNAVAILABLE_BATTERY + ) + enable_mains_startup_polling: bool = Field(default=True) + + +class AlarmControlPanelOptions(BaseModel): + """ZHA alarm control panel options.""" + + master_code: str = Field(default="1234") + failed_tries: int = Field(default=3) + arm_requires_code: bool = Field(default=False) + + +class CoordinatorConfiguration(BaseModel): + """ZHA coordinator configuration.""" + + path: str + baudrate: int = Field(default=115200) + flow_control: str = Field(default="hardware") + radio_type: str = Field(default="ezsp") + + +class QuirksConfiguration(BaseModel): + """ZHA quirks configuration.""" + + enabled: bool = Field(default=True) + custom_quirks_path: str | None = Field(default=None) + + +class DeviceOverridesConfiguration(BaseModel): + """ZHA device overrides configuration.""" + + type: Platform + + +class WebsocketServerConfiguration(BaseModel): + """Websocket Server configuration for zha.""" + + host: str = "0.0.0.0" + port: int = 8001 + network_auto_start: bool = False + + +class WebsocketClientConfiguration(BaseModel): + """Websocket client configuration for zha.""" + + host: str = "0.0.0.0" + port: int = 8001 + aiohttp_session: ClientSession | None = None + + +class ZHAConfiguration(BaseModel): + """ZHA configuration.""" + + coordinator_configuration: CoordinatorConfiguration = Field( + default_factory=CoordinatorConfiguration + ) + quirks_configuration: QuirksConfiguration = Field( + default_factory=QuirksConfiguration + ) + device_overrides: dict[str, DeviceOverridesConfiguration] = Field( + default_factory=dict + ) + light_options: LightOptions = Field(default_factory=LightOptions) + device_options: DeviceOptions = Field(default_factory=DeviceOptions) + alarm_control_panel_options: AlarmControlPanelOptions = Field( + default_factory=AlarmControlPanelOptions + ) + + +@dataclasses.dataclass(kw_only=True, slots=True) +class ZHAData: + """ZHA data stored in `gateway.data`.""" + + config: ZHAConfiguration + ws_server_config: WebsocketServerConfiguration | None = None + ws_client_config: WebsocketClientConfiguration | None = None + zigpy_config: dict[str, Any] = dataclasses.field(default_factory=dict) + platforms: collections.defaultdict[Platform, list] = dataclasses.field( + default_factory=lambda: collections.defaultdict(list) + ) + gateway: Gateway | None = dataclasses.field(default=None) + device_trigger_cache: dict[str, tuple[str, dict]] = dataclasses.field( + default_factory=dict + ) + allow_polling: bool = dataclasses.field(default=False) + local_timezone: datetime.tzinfo = dataclasses.field(default=datetime.UTC) diff --git a/zha/application/platforms/__init__.py b/zha/application/platforms/__init__.py index b0aedf75b..353833738 100644 --- a/zha/application/platforms/__init__.py +++ b/zha/application/platforms/__init__.py @@ -5,25 +5,32 @@ from abc import abstractmethod import asyncio from contextlib import suppress -import dataclasses -from enum import StrEnum from functools import cached_property import logging -from typing import TYPE_CHECKING, Any, Final, Optional, final +from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, final from zigpy.quirks.v2 import EntityMetadata, EntityType from zigpy.types.named import EUI64 from zha.application import Platform -from zha.const import STATE_CHANGED +from zha.application.platforms.const import EntityCategory +from zha.application.platforms.model import ( + BaseEntityInfo, + BaseIdentifiers, + GroupEntityIdentifiers, + PlatformEntityIdentifiers, + T as BaseEntityInfoType, +) +from zha.const import STATE_CHANGED, EntityEvents, EventTypes from zha.debounce import Debouncer from zha.event import EventBase from zha.mixins import LogMixin -from zha.zigbee.cluster_handlers import ClusterHandlerInfo +from zha.model import BaseEvent if TYPE_CHECKING: + from zha.websocket.server.api.model import WebSocketCommandResponse from zha.zigbee.cluster_handlers import ClusterHandler - from zha.zigbee.device import Device + from zha.zigbee.device import Device, WebSocketClientDevice from zha.zigbee.endpoint import Endpoint from zha.zigbee.group import Group @@ -33,76 +40,18 @@ DEFAULT_UPDATE_GROUP_FROM_CHILD_DELAY: float = 0.5 -class EntityCategory(StrEnum): - """Category of an entity.""" - - # Config: An entity which allows changing the configuration of a device. - CONFIG = "config" - - # Diagnostic: An entity exposing some configuration parameter, - # or diagnostics of a device. - DIAGNOSTIC = "diagnostic" - - -@dataclasses.dataclass(frozen=True, kw_only=True) -class BaseEntityInfo: - """Information about a base entity.""" - - fallback_name: str - unique_id: str - platform: str - class_name: str - translation_key: str | None - device_class: str | None - state_class: str | None - entity_category: str | None - entity_registry_enabled_default: bool - enabled: bool = True - - # For platform entities - cluster_handlers: list[ClusterHandlerInfo] - device_ieee: EUI64 | None - endpoint_id: int | None - available: bool | None - - # For group entities - group_id: int | None - - -@dataclasses.dataclass(frozen=True, kw_only=True) -class BaseIdentifiers: - """Identifiers for the base entity.""" - - unique_id: str - platform: str - - -@dataclasses.dataclass(frozen=True, kw_only=True) -class PlatformEntityIdentifiers(BaseIdentifiers): - """Identifiers for the platform entity.""" - - device_ieee: EUI64 - endpoint_id: int - - -@dataclasses.dataclass(frozen=True, kw_only=True) -class GroupEntityIdentifiers(BaseIdentifiers): - """Identifiers for the group entity.""" - - group_id: int - - -@dataclasses.dataclass(frozen=True, kw_only=True) -class EntityStateChangedEvent: +# this class exists solely to break circular imports +class EntityStateChangedEvent(BaseEvent): """Event for when an entity state changes.""" - event_type: Final[str] = "entity" - event: Final[str] = STATE_CHANGED - platform: str + event_type: Literal[EventTypes.ENTITY_EVENT] = EventTypes.ENTITY_EVENT + event: Literal[EntityEvents.STATE_CHANGED] = EntityEvents.STATE_CHANGED + platform: Platform unique_id: str - device_ieee: Optional[EUI64] = None - endpoint_id: Optional[int] = None - group_id: Optional[int] = None + device_ieee: EUI64 | None = None + endpoint_id: int | None = None + group_id: int | None = None + state: dict[str, Any] | None class BaseEntity(LogMixin, EventBase): @@ -124,7 +73,7 @@ def __init__(self, unique_id: str) -> None: self._unique_id: str = unique_id - self.__previous_state: Any = None + self._previous_state: Any = None self._tracked_tasks: list[asyncio.Task] = [] self._tracked_handles: list[asyncio.Handle] = [] @@ -197,7 +146,7 @@ def identifiers(self) -> BaseIdentifiers: platform=self.PLATFORM, ) - @cached_property + @property def info_object(self) -> BaseEntityInfo: """Return a representation of the platform entity.""" @@ -219,6 +168,7 @@ def info_object(self) -> BaseEntityInfo: available=None, # Set by group entities group_id=None, + state=self.state, ) @property @@ -242,10 +192,12 @@ def extra_state_attribute_names(self) -> set[str] | None: def enable(self) -> None: """Enable the entity.""" self.enabled = True + self.maybe_emit_state_changed_event() def disable(self) -> None: """Disable the entity.""" self.enabled = False + self.maybe_emit_state_changed_event() async def on_remove(self) -> None: """Cancel tasks and timers this entity owns.""" @@ -263,11 +215,15 @@ async def on_remove(self) -> None: def maybe_emit_state_changed_event(self) -> None: """Send the state of this platform entity.""" state = self.state - if self.__previous_state != state: + if self._previous_state != state: self.emit( - STATE_CHANGED, EntityStateChangedEvent(**self.identifiers.__dict__) + STATE_CHANGED, + EntityStateChangedEvent( + state=self.state, + **self.identifiers.model_dump(), + ), ) - self.__previous_state = state + self._previous_state = state def log(self, level: int, msg: str, *args: Any, **kwargs: Any) -> None: """Log a message.""" @@ -276,6 +232,9 @@ def log(self, level: int, msg: str, *args: Any, **kwargs: Any) -> None: _LOGGER.log(level, msg, *args, **kwargs) +T = TypeVar("T", bound=BaseEntity) + + class PlatformEntity(BaseEntity): """Class that represents an entity for a device platform.""" @@ -372,15 +331,16 @@ def identifiers(self) -> PlatformEntityIdentifiers: endpoint_id=self.endpoint.id, ) - @cached_property + @property def info_object(self) -> BaseEntityInfo: """Return a representation of the platform entity.""" - return dataclasses.replace( - super().info_object, - cluster_handlers=[ch.info_object for ch in self._cluster_handlers], - device_ieee=self._device.ieee, - endpoint_id=self._endpoint.id, - available=self.available, + return super().info_object.model_copy( + update={ + "cluster_handlers": [ch.info_object for ch in self._cluster_handlers], + "device_ieee": self._device.ieee, + "endpoint_id": self._endpoint.id, + "available": self.available, + } ) @property @@ -410,6 +370,16 @@ def state(self) -> dict[str, Any]: state["available"] = self.available return state + def maybe_emit_state_changed_event(self) -> None: + """Send the state of this platform entity.""" + super().maybe_emit_state_changed_event() + self.device.gateway.broadcast_event( + EntityStateChangedEvent( + state=self.state, + **self.identifiers.model_dump(), + ), + ) + async def async_update(self) -> None: """Retrieve latest state.""" self.debug("polling current state") @@ -453,12 +423,11 @@ def identifiers(self) -> GroupEntityIdentifiers: group_id=self.group_id, ) - @cached_property + @property def info_object(self) -> BaseEntityInfo: """Return a representation of the group.""" - return dataclasses.replace( - super().info_object, - group_id=self.group_id, + return super().info_object.model_copy( + update={"group_id": self.group_id, "available": self.available} ) @property @@ -486,6 +455,16 @@ def group(self) -> Group: """Return the group.""" return self._group + def maybe_emit_state_changed_event(self) -> None: + """Send the state of this platform entity.""" + super().maybe_emit_state_changed_event() + self.group.gateway.broadcast_event( + EntityStateChangedEvent( + state=self.state, + **self.identifiers.model_dump(), + ), + ) + def debounced_update(self, _: Any | None = None) -> None: """Debounce updating group entity from member entity updates.""" # Delay to ensure that we get updates from all members before updating the group entity @@ -505,3 +484,98 @@ def update(self, _: Any | None = None) -> None: async def async_update(self, _: Any | None = None) -> None: """Update the state of this group entity.""" self.update() + + +class WebSocketClientEntity(BaseEntity, Generic[BaseEntityInfoType]): + """Entity repsentation for the websocket client.""" + + def __init__( + self, entity_info: BaseEntityInfoType, device: WebSocketClientDevice + ) -> None: + """Initialize the websocket client entity.""" + super().__init__(entity_info.unique_id) + self.PLATFORM = entity_info.platform + self._device: WebSocketClientDevice = device + self._entity_info: BaseEntityInfoType = entity_info + self._update_attrs_from_entity_info() + + @property + def info_object(self) -> BaseEntityInfoType: + """Return a representation of the alarm control panel.""" + return self._entity_info + + @info_object.setter + def info_object(self, entity_info: BaseEntityInfoType) -> None: + """Set the entity info object.""" + self._entity_info = entity_info + self._update_attrs_from_entity_info() + self.maybe_emit_state_changed_event() + + @property + def state(self) -> dict[str, Any]: + """Return the arguments to use in the command.""" + return self._entity_info.state.model_dump() + + @state.setter + def state(self, value: dict[str, Any]) -> None: + """Set the state of the entity.""" + self._entity_info.state = value + self._attr_enabled = self._entity_info.enabled + self.maybe_emit_state_changed_event() + + @property + def group_id(self) -> int | None: + """Return the group id.""" + return self._entity_info.group_id + + @property + def available(self) -> bool: + """Return true if the device this entity belongs to is available.""" + return bool(self._entity_info.available) + + def enable(self) -> None: + """Enable the entity.""" + task = self._device.gateway.create_and_track_task( + self._device.gateway.entities.enable(self._entity_info) + ) + task.add_done_callback(self._enable) + + def disable(self) -> None: + """Disable the entity.""" + task = self._device.gateway.create_and_track_task( + self._device.gateway.entities.disable(self._entity_info) + ) + task.add_done_callback(self._disable) + + def _enable(self, future: asyncio.Future) -> None: + """Enable the entity.""" + response: WebSocketCommandResponse = future.result() + if response.success: + self._entity_info.enabled = True + self._attr_enabled = True + self.maybe_emit_state_changed_event() + + def _disable(self, future: asyncio.Future) -> None: + """Disable the entity.""" + response: WebSocketCommandResponse = future.result() + if response.success: + self._entity_info.enabled = False + self._attr_enabled = False + self.maybe_emit_state_changed_event() + + def _update_attrs_from_entity_info(self) -> None: + """Update the entity attributes.""" + self._attr_enabled = self._entity_info.enabled + self._attr_fallback_name = self._entity_info.fallback_name + self._attr_translation_key = self._entity_info.translation_key + self._attr_entity_category = self._entity_info.entity_category + self._attr_entity_registry_enabled_default = ( + self._entity_info.entity_registry_enabled_default + ) + self._attr_device_class = self._entity_info.device_class + self._attr_state_class = self._entity_info.state_class + + async def async_update(self) -> None: + """Retrieve latest state.""" + self.debug("polling current state") + await self._device.gateway.entities.refresh_state(self._entity_info) diff --git a/zha/application/platforms/alarm_control_panel/__init__.py b/zha/application/platforms/alarm_control_panel/__init__.py index f1716a4e6..4504259d8 100644 --- a/zha/application/platforms/alarm_control_panel/__init__.py +++ b/zha/application/platforms/alarm_control_panel/__init__.py @@ -2,7 +2,7 @@ from __future__ import annotations -from dataclasses import dataclass +from abc import ABC, abstractmethod import functools import logging from typing import TYPE_CHECKING, Any @@ -10,29 +10,31 @@ from zigpy.zcl.clusters.security import IasAce from zha.application import Platform -from zha.application.platforms import BaseEntityInfo, PlatformEntity +from zha.application.platforms import PlatformEntity, WebSocketClientEntity from zha.application.platforms.alarm_control_panel.const import ( IAS_ACE_STATE_MAP, - SUPPORT_ALARM_ARM_AWAY, - SUPPORT_ALARM_ARM_HOME, - SUPPORT_ALARM_ARM_NIGHT, - SUPPORT_ALARM_TRIGGER, + AlarmControlPanelEntityFeature, AlarmState, CodeFormat, ) +from zha.application.platforms.alarm_control_panel.model import ( + AlarmControlPanelEntityInfo, +) +from zha.application.platforms.model import EntityState from zha.application.registries import PLATFORM_ENTITIES +from zha.const import MODEL_CLASS_NAME from zha.zigbee.cluster_handlers.const import ( CLUSTER_HANDLER_IAS_ACE, CLUSTER_HANDLER_STATE_CHANGED, ) -from zha.zigbee.cluster_handlers.security import ( - ClusterHandlerStateChangedEvent, - IasAceClusterHandler, -) if TYPE_CHECKING: from zha.zigbee.cluster_handlers import ClusterHandler - from zha.zigbee.device import Device + from zha.zigbee.cluster_handlers.security import ( + ClusterHandlerStateChangedEvent, + IasAceClusterHandler, + ) + from zha.zigbee.device import Device, WebSocketClientDevice from zha.zigbee.endpoint import Endpoint STRICT_MATCH = functools.partial( @@ -42,22 +44,51 @@ _LOGGER = logging.getLogger(__name__) -@dataclass(frozen=True, kw_only=True) -class AlarmControlPanelEntityInfo(BaseEntityInfo): - """Alarm control panel entity info.""" +class AlarmControlPanelEntityInterface(ABC): + """Base class for alarm control panels.""" + + @property + @abstractmethod + def code_arm_required(self) -> bool: + """Whether the code is required for arm actions.""" + + @functools.cached_property + @abstractmethod + def code_format(self) -> CodeFormat: + """Code format or None if no code is required.""" + + @functools.cached_property + @abstractmethod + def supported_features(self) -> int: + """Return the list of supported features.""" + + @abstractmethod + async def async_alarm_disarm(self, code: str | None = None, **kwargs) -> None: + """Send disarm command.""" + + @abstractmethod + async def async_alarm_arm_home(self, code: str | None = None, **kwargs) -> None: + """Send arm home command.""" + + @abstractmethod + async def async_alarm_arm_away(self, code: str | None = None, **kwargs) -> None: + """Send arm away command.""" + + @abstractmethod + async def async_alarm_arm_night(self, code: str | None = None, **kwargs) -> None: + """Send arm night command.""" - code_arm_required: bool - code_format: CodeFormat - supported_features: int - translation_key: str + @abstractmethod + async def async_alarm_trigger(self, code: str | None = None, **kwargs) -> None: + """Send alarm trigger command.""" @STRICT_MATCH(cluster_handler_names=CLUSTER_HANDLER_IAS_ACE) -class AlarmControlPanel(PlatformEntity): +class AlarmControlPanel(PlatformEntity, AlarmControlPanelEntityInterface): """Entity for ZHA alarm control devices.""" - _attr_translation_key: str = "alarm_control_panel" PLATFORM = Platform.ALARM_CONTROL_PANEL + _attr_translation_key: str = "alarm_control_panel" def __init__( self, @@ -80,24 +111,26 @@ def __init__( CLUSTER_HANDLER_STATE_CHANGED, self._handle_event_protocol ) - @functools.cached_property + @property def info_object(self) -> AlarmControlPanelEntityInfo: """Return a representation of the alarm control panel.""" return AlarmControlPanelEntityInfo( - **super().info_object.__dict__, + **super().info_object.model_dump(exclude=[MODEL_CLASS_NAME]), code_arm_required=self.code_arm_required, code_format=self.code_format, supported_features=self.supported_features, + max_invalid_tries=self._cluster_handler.max_invalid_tries, ) @property def state(self) -> dict[str, Any]: """Get the state of the alarm control panel.""" - response = super().state - response["state"] = IAS_ACE_STATE_MAP.get( - self._cluster_handler.armed_state, AlarmState.UNKNOWN - ) - return response + return EntityState( + **super().state, + state=IAS_ACE_STATE_MAP.get( + self._cluster_handler.armed_state, AlarmState.UNKNOWN + ), + ).model_dump() @property def code_arm_required(self) -> bool: @@ -110,13 +143,13 @@ def code_format(self) -> CodeFormat: return CodeFormat.NUMBER @functools.cached_property - def supported_features(self) -> int: + def supported_features(self) -> AlarmControlPanelEntityFeature: """Return the list of supported features.""" return ( - SUPPORT_ALARM_ARM_HOME - | SUPPORT_ALARM_ARM_AWAY - | SUPPORT_ALARM_ARM_NIGHT - | SUPPORT_ALARM_TRIGGER + AlarmControlPanelEntityFeature.ARM_HOME + | AlarmControlPanelEntityFeature.ARM_AWAY + | AlarmControlPanelEntityFeature.ARM_NIGHT + | AlarmControlPanelEntityFeature.TRIGGER ) def handle_cluster_handler_state_changed( @@ -126,27 +159,83 @@ def handle_cluster_handler_state_changed( """Handle state changed on cluster.""" self.maybe_emit_state_changed_event() - async def async_alarm_disarm(self, code: str | None = None) -> None: + async def async_alarm_disarm(self, code: str | None = None, **kwargs) -> None: """Send disarm command.""" self._cluster_handler.arm(IasAce.ArmMode.Disarm, code, 0) self.maybe_emit_state_changed_event() - async def async_alarm_arm_home(self, code: str | None = None) -> None: + async def async_alarm_arm_home(self, code: str | None = None, **kwargs) -> None: """Send arm home command.""" self._cluster_handler.arm(IasAce.ArmMode.Arm_Day_Home_Only, code, 0) self.maybe_emit_state_changed_event() - async def async_alarm_arm_away(self, code: str | None = None) -> None: + async def async_alarm_arm_away(self, code: str | None = None, **kwargs) -> None: """Send arm away command.""" self._cluster_handler.arm(IasAce.ArmMode.Arm_All_Zones, code, 0) self.maybe_emit_state_changed_event() - async def async_alarm_arm_night(self, code: str | None = None) -> None: + async def async_alarm_arm_night(self, code: str | None = None, **kwargs) -> None: """Send arm night command.""" self._cluster_handler.arm(IasAce.ArmMode.Arm_Night_Sleep_Only, code, 0) self.maybe_emit_state_changed_event() - async def async_alarm_trigger(self, code: str | None = None) -> None: # pylint: disable=unused-argument + async def async_alarm_trigger(self, code: str | None = None, **kwargs) -> None: # pylint: disable=unused-argument """Send alarm trigger command.""" self._cluster_handler.panic() self.maybe_emit_state_changed_event() + + +class WebSocketClientAlarmControlPanel( + WebSocketClientEntity[AlarmControlPanelEntityInfo], AlarmControlPanelEntityInterface +): + """Alarm control panel entity for the WebSocket API.""" + + PLATFORM = Platform.ALARM_CONTROL_PANEL + _attr_translation_key: str = "alarm_control_panel" + + def __init__( + self, entity_info: AlarmControlPanelEntityInfo, device: WebSocketClientDevice + ) -> None: + """Initialize the ZHA alarm control device.""" + super().__init__(entity_info, device) + + @property + def code_arm_required(self) -> bool: + """Whether the code is required for arm actions.""" + return self._entity_info.code_arm_required + + @functools.cached_property + def code_format(self) -> CodeFormat: + """Code format or None if no code is required.""" + return self._entity_info.code_format + + @functools.cached_property + def supported_features(self) -> int: + """Return the list of supported features.""" + return self._entity_info.supported_features + + async def async_alarm_disarm(self, code: str | None = None, **kwargs) -> None: + """Send disarm command.""" + await self._device.gateway.alarm_control_panels.disarm(self._entity_info, code) + + async def async_alarm_arm_home(self, code: str | None = None, **kwargs) -> None: + """Send arm home command.""" + await self._device.gateway.alarm_control_panels.arm_home( + self._entity_info, code + ) + + async def async_alarm_arm_away(self, code: str | None = None, **kwargs) -> None: + """Send arm away command.""" + await self._device.gateway.alarm_control_panels.arm_away( + self._entity_info, code + ) + + async def async_alarm_arm_night(self, code: str | None = None, **kwargs) -> None: + """Send arm night command.""" + await self._device.gateway.alarm_control_panels.arm_night( + self._entity_info, code + ) + + async def async_alarm_trigger(self, code: str | None = None, **kwargs) -> None: + """Send alarm trigger command.""" + await self._device.gateway.alarm_control_panels.trigger(self._entity_info) diff --git a/zha/application/platforms/alarm_control_panel/const.py b/zha/application/platforms/alarm_control_panel/const.py index a5bdec719..65df5abc4 100644 --- a/zha/application/platforms/alarm_control_panel/const.py +++ b/zha/application/platforms/alarm_control_panel/const.py @@ -1,17 +1,9 @@ """Constants for the alarm control panel platform.""" from enum import IntFlag, StrEnum -from typing import Final from zigpy.zcl.clusters.security import IasAce -SUPPORT_ALARM_ARM_HOME: Final[int] = 1 -SUPPORT_ALARM_ARM_AWAY: Final[int] = 2 -SUPPORT_ALARM_ARM_NIGHT: Final[int] = 4 -SUPPORT_ALARM_TRIGGER: Final[int] = 8 -SUPPORT_ALARM_ARM_CUSTOM_BYPASS: Final[int] = 16 -SUPPORT_ALARM_ARM_VACATION: Final[int] = 32 - class AlarmState(StrEnum): """Alarm state.""" @@ -37,9 +29,6 @@ class AlarmState(StrEnum): IasAce.PanelStatus.In_Alarm: AlarmState.TRIGGERED, } -ATTR_CHANGED_BY: Final[str] = "changed_by" -ATTR_CODE_ARM_REQUIRED: Final[str] = "code_arm_required" - class CodeFormat(StrEnum): """Code formats for the Alarm Control Panel.""" diff --git a/zha/application/platforms/alarm_control_panel/model.py b/zha/application/platforms/alarm_control_panel/model.py new file mode 100644 index 000000000..002a2bf6d --- /dev/null +++ b/zha/application/platforms/alarm_control_panel/model.py @@ -0,0 +1,19 @@ +"""Models for the alarm control panel platform.""" + +from __future__ import annotations + +from zha.application.platforms.alarm_control_panel.const import ( + AlarmControlPanelEntityFeature, + CodeFormat, +) +from zha.application.platforms.model import BasePlatformEntityInfo, EntityState + + +class AlarmControlPanelEntityInfo(BasePlatformEntityInfo): + """Alarm control panel model.""" + + code_format: CodeFormat + supported_features: AlarmControlPanelEntityFeature + code_arm_required: bool + max_invalid_tries: int + state: EntityState diff --git a/zha/application/platforms/alarm_control_panel/websocket_api.py b/zha/application/platforms/alarm_control_panel/websocket_api.py new file mode 100644 index 000000000..0e5d91b1e --- /dev/null +++ b/zha/application/platforms/alarm_control_panel/websocket_api.py @@ -0,0 +1,131 @@ +"""WS api for the alarm control panel platform entity.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal, Union + +from zha.application.discovery import Platform +from zha.application.platforms.websocket_api import ( + PlatformEntityCommand, + execute_platform_entity_command, +) +from zha.websocket.const import APICommands +from zha.websocket.server.api import decorators, register_api_command + +if TYPE_CHECKING: + from zha.application.gateway import WebSocketServerGateway + from zha.websocket.server.client import Client + + +class DisarmCommand(PlatformEntityCommand): + """Disarm command.""" + + command: Literal[APICommands.ALARM_CONTROL_PANEL_DISARM] = ( + APICommands.ALARM_CONTROL_PANEL_DISARM + ) + platform: str = Platform.ALARM_CONTROL_PANEL + code: Union[str, None] = None + + +@decorators.websocket_command(DisarmCommand) +@decorators.async_response +async def disarm( + gateway: WebSocketServerGateway, client: Client, command: DisarmCommand +) -> None: + """Disarm the alarm control panel.""" + await execute_platform_entity_command( + gateway, client, command, "async_alarm_disarm" + ) + + +class ArmHomeCommand(PlatformEntityCommand): + """Arm home command.""" + + command: Literal[APICommands.ALARM_CONTROL_PANEL_ARM_HOME] = ( + APICommands.ALARM_CONTROL_PANEL_ARM_HOME + ) + platform: str = Platform.ALARM_CONTROL_PANEL + code: Union[str, None] = None + + +@decorators.websocket_command(ArmHomeCommand) +@decorators.async_response +async def arm_home( + gateway: WebSocketServerGateway, client: Client, command: ArmHomeCommand +) -> None: + """Arm the alarm control panel in home mode.""" + await execute_platform_entity_command( + gateway, client, command, "async_alarm_arm_home" + ) + + +class ArmAwayCommand(PlatformEntityCommand): + """Arm away command.""" + + command: Literal[APICommands.ALARM_CONTROL_PANEL_ARM_AWAY] = ( + APICommands.ALARM_CONTROL_PANEL_ARM_AWAY + ) + platform: str = Platform.ALARM_CONTROL_PANEL + code: Union[str, None] = None + + +@decorators.websocket_command(ArmAwayCommand) +@decorators.async_response +async def arm_away( + gateway: WebSocketServerGateway, client: Client, command: ArmAwayCommand +) -> None: + """Arm the alarm control panel in away mode.""" + await execute_platform_entity_command( + gateway, client, command, "async_alarm_arm_away" + ) + + +class ArmNightCommand(PlatformEntityCommand): + """Arm night command.""" + + command: Literal[APICommands.ALARM_CONTROL_PANEL_ARM_NIGHT] = ( + APICommands.ALARM_CONTROL_PANEL_ARM_NIGHT + ) + platform: str = Platform.ALARM_CONTROL_PANEL + code: Union[str, None] = None + + +@decorators.websocket_command(ArmNightCommand) +@decorators.async_response +async def arm_night( + gateway: WebSocketServerGateway, client: Client, command: ArmNightCommand +) -> None: + """Arm the alarm control panel in night mode.""" + await execute_platform_entity_command( + gateway, client, command, "async_alarm_arm_night" + ) + + +class TriggerAlarmCommand(PlatformEntityCommand): + """Trigger alarm command.""" + + command: Literal[APICommands.ALARM_CONTROL_PANEL_TRIGGER] = ( + APICommands.ALARM_CONTROL_PANEL_TRIGGER + ) + platform: str = Platform.ALARM_CONTROL_PANEL + code: Union[str, None] = None + + +@decorators.websocket_command(TriggerAlarmCommand) +@decorators.async_response +async def trigger( + gateway: WebSocketServerGateway, client: Client, command: TriggerAlarmCommand +) -> None: + """Trigger the alarm control panel.""" + await execute_platform_entity_command( + gateway, client, command, "async_alarm_trigger" + ) + + +def load_api(gateway: WebSocketServerGateway) -> None: + """Load the api command handlers.""" + register_api_command(gateway, disarm) + register_api_command(gateway, arm_home) + register_api_command(gateway, arm_away) + register_api_command(gateway, arm_night) + register_api_command(gateway, trigger) diff --git a/zha/application/platforms/binary_sensor/__init__.py b/zha/application/platforms/binary_sensor/__init__.py index c35b2b624..a7afa860f 100644 --- a/zha/application/platforms/binary_sensor/__init__.py +++ b/zha/application/platforms/binary_sensor/__init__.py @@ -2,23 +2,26 @@ from __future__ import annotations -from dataclasses import dataclass +from abc import ABC, abstractmethod import functools import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from zhaquirks.quirk_ids import DANFOSS_ALLY_THERMOSTAT from zigpy.quirks.v2 import BinarySensorMetadata from zha.application import Platform -from zha.application.platforms import BaseEntityInfo, EntityCategory, PlatformEntity +from zha.application.platforms import PlatformEntity, WebSocketClientEntity from zha.application.platforms.binary_sensor.const import ( IAS_ZONE_CLASS_MAPPING, BinarySensorDeviceClass, ) +from zha.application.platforms.binary_sensor.model import BinarySensorEntityInfo +from zha.application.platforms.const import EntityCategory from zha.application.platforms.helpers import validate_device_class +from zha.application.platforms.model import EntityState from zha.application.registries import PLATFORM_ENTITIES -from zha.zigbee.cluster_handlers import ClusterAttributeUpdatedEvent +from zha.const import MODEL_CLASS_NAME from zha.zigbee.cluster_handlers.const import ( CLUSTER_HANDLER_ACCELEROMETER, CLUSTER_HANDLER_ATTRIBUTE_UPDATED, @@ -31,8 +34,8 @@ ) if TYPE_CHECKING: - from zha.zigbee.cluster_handlers import ClusterHandler - from zha.zigbee.device import Device + from zha.zigbee.cluster_handlers import ClusterAttributeUpdatedEvent, ClusterHandler + from zha.zigbee.device import Device, WebSocketClientDevice from zha.zigbee.endpoint import Endpoint @@ -46,15 +49,16 @@ _LOGGER = logging.getLogger(__name__) -@dataclass(frozen=True, kw_only=True) -class BinarySensorEntityInfo(BaseEntityInfo): - """Binary sensor entity info.""" +class BinarySensorEntityInterface(ABC): + """Base class for binary sensors.""" - attribute_name: str - device_class: BinarySensorDeviceClass | None + @property + @abstractmethod + def is_on(self) -> bool: + """Return True if the switch is on based on the state machine.""" -class BinarySensor(PlatformEntity): +class BinarySensor(PlatformEntity, BinarySensorEntityInterface): """ZHA BinarySensor.""" _attr_device_class: BinarySensorDeviceClass | None @@ -90,20 +94,21 @@ def _init_from_quirks_metadata(self, entity_metadata: BinarySensorMetadata) -> N _LOGGER, ) - @functools.cached_property + @property def info_object(self) -> BinarySensorEntityInfo: """Return a representation of the binary sensor.""" return BinarySensorEntityInfo( - **super().info_object.__dict__, + **super().info_object.model_dump(exclude=[MODEL_CLASS_NAME]), attribute_name=self._attribute_name, ) @property - def state(self) -> dict: + def state(self) -> dict[str, Any]: """Return the state of the binary sensor.""" - response = super().state - response["state"] = self.is_on - return response + return EntityState( + **super().state, + state=self.is_on, + ).model_dump() @property def is_on(self) -> bool: @@ -400,3 +405,22 @@ class DanfossPreheatStatus(BinarySensor): _attr_translation_key: str = "preheat_status" _attr_entity_registry_enabled_default = False _attr_entity_category = EntityCategory.DIAGNOSTIC + + +class WebSocketClientBinarySensor( + WebSocketClientEntity[BinarySensorEntityInfo], BinarySensorEntityInterface +): + """Base class for binary sensors that are updated via a websocket client.""" + + PLATFORM: Platform = Platform.BINARY_SENSOR + + def __init__( + self, entity_info: BinarySensorEntityInfo, device: WebSocketClientDevice + ) -> None: + """Initialize the ZHA alarm control device.""" + super().__init__(entity_info, device) + + @property + def is_on(self) -> bool: + """Return True if the switch is on based on the state machine.""" + return bool(self.info_object.state.state) diff --git a/zha/application/platforms/binary_sensor/model.py b/zha/application/platforms/binary_sensor/model.py new file mode 100644 index 000000000..ab3b13eed --- /dev/null +++ b/zha/application/platforms/binary_sensor/model.py @@ -0,0 +1,12 @@ +"""Models for the binary sensor platform.""" + +from __future__ import annotations + +from zha.application.platforms.model import BasePlatformEntityInfo, EntityState + + +class BinarySensorEntityInfo(BasePlatformEntityInfo): + """Binary sensor model.""" + + attribute_name: str | None = None + state: EntityState diff --git a/zha/application/platforms/button/__init__.py b/zha/application/platforms/button/__init__.py index fa0d6271d..8408b701f 100644 --- a/zha/application/platforms/button/__init__.py +++ b/zha/application/platforms/button/__init__.py @@ -2,7 +2,7 @@ from __future__ import annotations -from dataclasses import dataclass +from abc import ABC, abstractmethod import functools import logging from typing import TYPE_CHECKING, Any, Self @@ -11,14 +11,22 @@ from zha.application import Platform from zha.application.const import ENTITY_METADATA -from zha.application.platforms import BaseEntityInfo, EntityCategory, PlatformEntity +from zha.application.platforms import PlatformEntity, WebSocketClientEntity from zha.application.platforms.button.const import DEFAULT_DURATION, ButtonDeviceClass +from zha.application.platforms.button.model import ( + ButtonEntityInfo, + CommandButtonEntityInfo, + WriteAttributeButtonEntityInfo, +) +from zha.application.platforms.const import EntityCategory +from zha.application.platforms.model import EntityState from zha.application.registries import PLATFORM_ENTITIES +from zha.const import MODEL_CLASS_NAME from zha.zigbee.cluster_handlers.const import CLUSTER_HANDLER_IDENTIFY if TYPE_CHECKING: from zha.zigbee.cluster_handlers import ClusterHandler - from zha.zigbee.device import Device + from zha.zigbee.device import Device, WebSocketClientDevice from zha.zigbee.endpoint import Endpoint @@ -30,24 +38,15 @@ _LOGGER = logging.getLogger(__name__) -@dataclass(frozen=True, kw_only=True) -class CommandButtonEntityInfo(BaseEntityInfo): - """Command button entity info.""" - - command: str - args: list[Any] - kwargs: dict[str, Any] - - -@dataclass(frozen=True, kw_only=True) -class WriteAttributeButtonEntityInfo(BaseEntityInfo): - """Write attribute button entity info.""" +class ButtonEntityInterface(ABC): + """Base class for ZHA button.""" - attribute_name: str - attribute_value: Any + @abstractmethod + async def async_press(self) -> None: + """Press the button.""" -class Button(PlatformEntity): +class Button(PlatformEntity, ButtonEntityInterface): """Defines a ZHA button.""" PLATFORM = Platform.BUTTON @@ -79,16 +78,23 @@ def _init_from_quirks_metadata( self._args = entity_metadata.args self._kwargs = entity_metadata.kwargs - @functools.cached_property + @property def info_object(self) -> CommandButtonEntityInfo: """Return a representation of the button.""" return CommandButtonEntityInfo( - **super().info_object.__dict__, + **super().info_object.model_dump(exclude=[MODEL_CLASS_NAME]), command=self._command_name, args=self._args, kwargs=self._kwargs, ) + @property + def state(self) -> dict[str, Any]: + """Return the state of the button.""" + return EntityState( + **super().state, + ).model_dump() + @functools.cached_property def args(self) -> list[Any]: """Return the arguments to use in the command.""" @@ -167,15 +173,22 @@ def _init_from_quirks_metadata( self._attribute_name = entity_metadata.attribute_name self._attribute_value = entity_metadata.attribute_value - @functools.cached_property + @property def info_object(self) -> WriteAttributeButtonEntityInfo: """Return a representation of the button.""" return WriteAttributeButtonEntityInfo( - **super().info_object.__dict__, + **super().info_object.model_dump(exclude=[MODEL_CLASS_NAME]), attribute_name=self._attribute_name, attribute_value=self._attribute_value, ) + @property + def state(self) -> dict[str, Any]: + """Return the state of the button.""" + return EntityState( + **super().state, + ).model_dump() + async def async_press(self) -> None: """Write attribute with defined value.""" await self._cluster_handler.write_attributes_safe( @@ -235,3 +248,31 @@ class AqaraSelfTestButton(WriteAttributeButton): _attribute_value = 1 _attr_entity_category = EntityCategory.CONFIG _attr_translation_key = "self_test" + + +class WebSocketClientButtonEntity( + WebSocketClientEntity[ButtonEntityInfo], ButtonEntityInterface +): + """Defines a ZHA button that is controlled via a websocket.""" + + PLATFORM = Platform.BUTTON + + def __init__( + self, entity_info: ButtonEntityInfo, device: WebSocketClientDevice + ) -> None: + """Initialize the ZHA alarm control device.""" + super().__init__(entity_info, device) + + @functools.cached_property + def args(self) -> list[Any]: + """Return the arguments to use in the command.""" + return self._entity_info.args or [] + + @functools.cached_property + def kwargs(self) -> dict[str, Any]: + """Return the keyword arguments to use in the command.""" + return self._entity_info.kwargs or {} + + async def async_press(self) -> None: + """Press the button.""" + await self._device.gateway.buttons.press(self._entity_info) diff --git a/zha/application/platforms/button/model.py b/zha/application/platforms/button/model.py new file mode 100644 index 000000000..3e2c695f5 --- /dev/null +++ b/zha/application/platforms/button/model.py @@ -0,0 +1,39 @@ +"""Models for the button platform.""" + +from __future__ import annotations + +from typing import Any + +from zha.application.platforms.model import ( + BaseEntityInfo, + BasePlatformEntityInfo, + EntityState, +) + + +class ButtonEntityInfo( + BasePlatformEntityInfo +): # TODO split into two models CommandButton and WriteAttributeButton + """Button model.""" + + command: str | None = None + attribute_name: str | None = None + attribute_value: Any | None = None + state: EntityState + + +class CommandButtonEntityInfo(BaseEntityInfo): + """Command button entity info.""" + + command: str + args: list[Any] + kwargs: dict[str, Any] + state: EntityState + + +class WriteAttributeButtonEntityInfo(BaseEntityInfo): + """Write attribute button entity info.""" + + attribute_name: str + attribute_value: Any + state: EntityState diff --git a/zha/application/platforms/button/websocket_api.py b/zha/application/platforms/button/websocket_api.py new file mode 100644 index 000000000..5bdbade62 --- /dev/null +++ b/zha/application/platforms/button/websocket_api.py @@ -0,0 +1,38 @@ +"""WS API for the button platform entity.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal + +from zha.application.discovery import Platform +from zha.application.platforms.websocket_api import ( + PlatformEntityCommand, + execute_platform_entity_command, +) +from zha.websocket.const import APICommands +from zha.websocket.server.api import decorators, register_api_command + +if TYPE_CHECKING: + from zha.application.gateway import WebSocketServerGateway + from zha.websocket.server.client import Client + + +class ButtonPressCommand(PlatformEntityCommand): + """Button press command.""" + + command: Literal[APICommands.BUTTON_PRESS] = APICommands.BUTTON_PRESS + platform: str = Platform.BUTTON + + +@decorators.websocket_command(ButtonPressCommand) +@decorators.async_response +async def press( + gateway: WebSocketServerGateway, client: Client, command: PlatformEntityCommand +) -> None: + """Turn on the button.""" + await execute_platform_entity_command(gateway, client, command, "async_press") + + +def load_api(gateway: WebSocketServerGateway) -> None: + """Load the api command handlers.""" + register_api_command(gateway, press) diff --git a/zha/application/platforms/climate/__init__.py b/zha/application/platforms/climate/__init__.py index c0ba9851b..d4afbc232 100644 --- a/zha/application/platforms/climate/__init__.py +++ b/zha/application/platforms/climate/__init__.py @@ -2,8 +2,8 @@ from __future__ import annotations +from abc import ABC, abstractmethod from asyncio import Task -from dataclasses import dataclass import datetime as dt import functools from typing import TYPE_CHECKING, Any @@ -11,7 +11,7 @@ from zigpy.zcl.clusters.hvac import FanMode, RunningState, SystemMode from zha.application import Platform -from zha.application.platforms import BaseEntityInfo, PlatformEntity +from zha.application.platforms import PlatformEntity, WebSocketClientEntity from zha.application.platforms.climate.const import ( ATTR_HVAC_MODE, ATTR_OCCP_COOL_SETPT, @@ -37,10 +37,14 @@ HVACMode, Preset, ) +from zha.application.platforms.climate.model import ( + ThermostatEntityInfo, + ThermostatState, +) from zha.application.registries import PLATFORM_ENTITIES +from zha.const import MODEL_CLASS_NAME from zha.decorators import periodic from zha.units import UnitOfTemperature -from zha.zigbee.cluster_handlers import ClusterAttributeUpdatedEvent from zha.zigbee.cluster_handlers.const import ( CLUSTER_HANDLER_ATTRIBUTE_UPDATED, CLUSTER_HANDLER_FAN, @@ -48,24 +52,107 @@ ) if TYPE_CHECKING: - from zha.zigbee.cluster_handlers import ClusterHandler - from zha.zigbee.device import Device + from zha.zigbee.cluster_handlers import ClusterAttributeUpdatedEvent, ClusterHandler + from zha.zigbee.device import Device, WebSocketClientDevice from zha.zigbee.endpoint import Endpoint STRICT_MATCH = functools.partial(PLATFORM_ENTITIES.strict_match, Platform.CLIMATE) MULTI_MATCH = functools.partial(PLATFORM_ENTITIES.multipass_match, Platform.CLIMATE) -@dataclass(frozen=True, kw_only=True) -class ThermostatEntityInfo(BaseEntityInfo): - """Thermostat entity info.""" +class ClimateEntityInterface(ABC): + """Climate interface.""" + + @property + @abstractmethod + def current_temperature(self) -> float | None: + """Return the current temperature.""" + + @property + @abstractmethod + def outdoor_temperature(self) -> float | None: + """Return the outdoor temperature.""" + + @property + @abstractmethod + def fan_mode(self) -> str | None: + """Return current FAN mode.""" - max_temp: float - min_temp: float - supported_features: ClimateEntityFeature - fan_modes: list[str] | None - preset_modes: list[str] | None - hvac_modes: list[HVACMode] + @property + @abstractmethod + def fan_modes(self) -> list[str] | None: + """Return supported FAN modes.""" + + @property + @abstractmethod + def hvac_action(self) -> HVACAction | None: + """Return the current HVAC action.""" + + @property + @abstractmethod + def hvac_mode(self) -> HVACMode | None: + """Return HVAC operation mode.""" + + @property + @abstractmethod + def hvac_modes(self) -> list[HVACMode]: + """Return the list of available HVAC operation modes.""" + + @property + @abstractmethod + def preset_mode(self) -> str: + """Return current preset mode.""" + + @property + @abstractmethod + def preset_modes(self) -> list[str] | None: + """Return supported preset modes.""" + + @property + @abstractmethod + def supported_features(self) -> ClimateEntityFeature: + """Return the list of supported features.""" + + @property + @abstractmethod + def target_temperature(self) -> float | None: + """Return the temperature we try to reach.""" + + @property + @abstractmethod + def target_temperature_high(self) -> float | None: + """Return the upper bound temperature we try to reach.""" + + @property + @abstractmethod + def target_temperature_low(self) -> float | None: + """Return the lower bound temperature we try to reach.""" + + @property + @abstractmethod + def max_temp(self) -> float: + """Return the maximum temperature.""" + + @property + @abstractmethod + def min_temp(self) -> float: + """Return the minimum temperature.""" + + @abstractmethod + async def async_set_fan_mode(self, fan_mode: str) -> None: + """Set fan mode.""" + + @abstractmethod + async def async_set_hvac_mode(self, hvac_mode: HVACMode) -> None: + """Set new target operation mode.""" + + @abstractmethod + async def async_set_preset_mode(self, preset_mode: str) -> None: + """Set new preset mode.""" + + @abstractmethod + async def async_set_temperature(self, **kwargs: Any) -> None: + """Set new target temperature.""" @MULTI_MATCH( @@ -73,7 +160,7 @@ class ThermostatEntityInfo(BaseEntityInfo): aux_cluster_handlers=CLUSTER_HANDLER_FAN, stop_on_match_group=CLUSTER_HANDLER_THERMOSTAT, ) -class Thermostat(PlatformEntity): +class Thermostat(PlatformEntity, ClimateEntityInterface): """Representation of a ZHA Thermostat device.""" PLATFORM = Platform.CLIMATE @@ -129,11 +216,11 @@ def __init__( if self._fan_cluster_handler is not None: self._supported_features |= ClimateEntityFeature.FAN_MODE - @functools.cached_property + @property def info_object(self) -> ThermostatEntityInfo: """Return a representation of the thermostat.""" return ThermostatEntityInfo( - **super().info_object.__dict__, + **super().info_object.model_dump(exclude=[MODEL_CLASS_NAME]), max_temp=self.max_temp, min_temp=self.min_temp, supported_features=self.supported_features, @@ -148,30 +235,28 @@ def state(self) -> dict[str, Any]: thermostat = self._thermostat_cluster_handler system_mode = SYSTEM_MODE_2_HVAC.get(thermostat.system_mode, "unknown") - response = super().state - response["current_temperature"] = self.current_temperature - response["outdoor_temperature"] = self.outdoor_temperature - response["target_temperature"] = self.target_temperature - response["target_temperature_high"] = self.target_temperature_high - response["target_temperature_low"] = self.target_temperature_low - response["hvac_action"] = self.hvac_action - response["hvac_mode"] = self.hvac_mode - response["preset_mode"] = self.preset_mode - response["fan_mode"] = self.fan_mode - - response[ATTR_SYS_MODE] = ( - f"[{thermostat.system_mode}]/{system_mode}" + return ThermostatState( + **super().state, + current_temperature=self.current_temperature, + outdoor_temperature=self.outdoor_temperature, + target_temperature=self.target_temperature, + target_temperature_high=self.target_temperature_high, + target_temperature_low=self.target_temperature_low, + hvac_action=self.hvac_action, + hvac_mode=self.hvac_mode, + preset_mode=self.preset_mode, + fan_mode=self.fan_mode, + system_mode=f"[{thermostat.system_mode}]/{system_mode}" if self.hvac_mode is not None - else None - ) - response[ATTR_OCCUPANCY] = thermostat.occupancy - response[ATTR_OCCP_COOL_SETPT] = thermostat.occupied_cooling_setpoint - response[ATTR_OCCP_HEAT_SETPT] = thermostat.occupied_heating_setpoint - response[ATTR_PI_HEATING_DEMAND] = thermostat.pi_heating_demand - response[ATTR_PI_COOLING_DEMAND] = thermostat.pi_cooling_demand - response[ATTR_UNOCCP_COOL_SETPT] = thermostat.unoccupied_cooling_setpoint - response[ATTR_UNOCCP_HEAT_SETPT] = thermostat.unoccupied_heating_setpoint - return response + else None, + occupancy=thermostat.occupancy, + occupied_cooling_setpoint=thermostat.occupied_cooling_setpoint, + occupied_heating_setpoint=thermostat.occupied_heating_setpoint, + pi_heating_demand=thermostat.pi_heating_demand, + pi_cooling_demand=thermostat.pi_cooling_demand, + unoccupied_cooling_setpoint=thermostat.unoccupied_cooling_setpoint, + unoccupied_heating_setpoint=thermostat.unoccupied_heating_setpoint, + ).model_dump() @property def current_temperature(self): @@ -386,7 +471,7 @@ async def _handle_cluster_handler_attribute_updated( ) self.maybe_emit_state_changed_event() - async def async_set_fan_mode(self, fan_mode: str) -> None: + async def async_set_fan_mode(self, fan_mode: str, **kwargs) -> None: """Set fan mode.""" if not self.fan_modes or fan_mode not in self.fan_modes: self.warning("Unsupported '%s' fan mode", fan_mode) @@ -396,7 +481,7 @@ async def async_set_fan_mode(self, fan_mode: str) -> None: await self._fan_cluster_handler.async_set_speed(mode) - async def async_set_hvac_mode(self, hvac_mode: HVACMode) -> None: + async def async_set_hvac_mode(self, hvac_mode: HVACMode, **kwargs) -> None: """Set new target operation mode.""" if hvac_mode not in self.hvac_modes: self.warning( @@ -411,7 +496,7 @@ async def async_set_hvac_mode(self, hvac_mode: HVACMode) -> None: ): self.maybe_emit_state_changed_event() - async def async_set_preset_mode(self, preset_mode: str) -> None: + async def async_set_preset_mode(self, preset_mode: str, **kwargs) -> None: """Set new preset mode.""" if not self.preset_modes or preset_mode not in self.preset_modes: self.debug("Preset mode '%s' is not supported", preset_mode) @@ -472,7 +557,9 @@ async def async_set_temperature(self, **kwargs: Any) -> None: self.maybe_emit_state_changed_event() - async def async_preset_handler(self, preset: str, enable: bool = False) -> None: + async def async_preset_handler( + self, preset: str, enable: bool = False, **kwargs + ) -> None: """Set the preset mode via handler.""" handler = getattr(self, f"async_preset_handler_{preset}") @@ -571,7 +658,7 @@ async def _async_update_time(self) -> None: {"secs_since_2k": secs_2k}, manufacturer=self.manufacturer ) - async def async_preset_handler_away(self, is_away: bool = False) -> None: + async def async_preset_handler_away(self, is_away: bool = False, **kwargs) -> None: """Set occupancy.""" mfg_code = self._device.manufacturer_code await self._thermostat_cluster_handler.write_attributes_safe( @@ -668,7 +755,9 @@ def handle_cluster_handler_attribute_updated( self._preset = Preset.COMPLEX super().handle_cluster_handler_attribute_updated(event) - async def async_preset_handler(self, preset: str, enable: bool = False) -> None: + async def async_preset_handler( + self, preset: str, enable: bool = False, **kwargs + ) -> None: """Set the preset mode.""" mfg_code = self._device.manufacturer_code if not enable: @@ -754,7 +843,9 @@ def handle_cluster_handler_attribute_updated( self._preset = Preset.TEMP_MANUAL super().handle_cluster_handler_attribute_updated(event) - async def async_preset_handler(self, preset: str, enable: bool = False) -> None: + async def async_preset_handler( + self, preset: str, enable: bool = False, **kwargs + ) -> None: """Set the preset mode.""" mfg_code = self._device.manufacturer_code if not enable: @@ -854,7 +945,9 @@ def handle_cluster_handler_attribute_updated( self._preset = self.PRESET_FROST super().handle_cluster_handler_attribute_updated(event) - async def async_preset_handler(self, preset: str, enable: bool = False) -> None: + async def async_preset_handler( + self, preset: str, enable: bool = False, **kwargs + ) -> None: """Set the preset mode.""" mfg_code = self._device.manufacturer_code if not enable: @@ -873,3 +966,114 @@ async def async_preset_handler(self, preset: str, enable: bool = False) -> None: return await self._thermostat_cluster_handler.write_attributes_safe( {"operation_preset": 4}, manufacturer=mfg_code ) + + +class WebSocketClientThermostatEntity( + WebSocketClientEntity[ThermostatEntityInfo], ClimateEntityInterface +): + """Representation of a ZHA Thermostat device.""" + + PLATFORM: Platform = Platform.CLIMATE + + def __init__( + self, entity_info: ThermostatEntityInfo, device: WebSocketClientDevice + ) -> None: + """Initialize the ZHA climate entity.""" + super().__init__(entity_info, device) + + @property + def current_temperature(self) -> float | None: + """Return the current temperature.""" + return self.info_object.state.current_temperature + + @property + def outdoor_temperature(self) -> float | None: + """Return the outdoor temperature.""" + return self.info_object.state.outdoor_temperature + + @property + def fan_mode(self) -> str | None: + """Return current FAN mode.""" + return self.info_object.state.fan_mode + + @property + def fan_modes(self) -> list[str] | None: + """Return supported FAN modes.""" + return self.info_object.fan_modes + + @property + def hvac_action(self) -> HVACAction | None: + """Return the current HVAC action.""" + return self.info_object.state.hvac_action + + @property + def hvac_mode(self) -> HVACMode | None: + """Return HVAC operation mode.""" + return self.info_object.state.hvac_mode + + @property + def hvac_modes(self) -> list[HVACMode]: + """Return the list of available HVAC operation modes.""" + return self.info_object.hvac_modes + + @property + def preset_mode(self) -> str: + """Return current preset mode.""" + return self.info_object.state.preset_mode + + @property + def preset_modes(self) -> list[str] | None: + """Return supported preset modes.""" + return self.info_object.preset_modes + + @property + def supported_features(self) -> ClimateEntityFeature: + """Return the list of supported features.""" + return self.info_object.supported_features + + @property + def target_temperature(self) -> float | None: + """Return the temperature we try to reach.""" + return self.info_object.state.target_temperature + + @property + def target_temperature_high(self) -> float | None: + """Return the upper bound temperature we try to reach.""" + return self.info_object.state.target_temperature_high + + @property + def target_temperature_low(self) -> float | None: + """Return the lower bound temperature we try to reach.""" + return self.info_object.state.target_temperature_low + + @property + def max_temp(self) -> float: + """Return the maximum temperature.""" + return self.info_object.max_temp + + @property + def min_temp(self) -> float: + """Return the minimum temperature.""" + return self.info_object.min_temp + + async def async_set_fan_mode(self, fan_mode: str) -> None: + """Set fan mode.""" + await self._device.gateway.thermostats.set_fan_mode(self.info_object, fan_mode) + + async def async_set_hvac_mode(self, hvac_mode: HVACMode) -> None: + """Set new target operation mode.""" + await self._device.gateway.thermostats.set_hvac_mode( + self.info_object, hvac_mode + ) + + async def async_set_preset_mode(self, preset_mode: str) -> None: + """Set new preset mode.""" + await self._device.gateway.thermostats.set_preset_mode( + self.info_object, preset_mode + ) + + async def async_set_temperature(self, **kwargs: Any) -> None: + """Set new target temperature.""" + await self._device.gateway.thermostats.set_temperature( + self.info_object, **kwargs + ) diff --git a/zha/application/platforms/climate/model.py b/zha/application/platforms/climate/model.py new file mode 100644 index 000000000..ac7fb46e2 --- /dev/null +++ b/zha/application/platforms/climate/model.py @@ -0,0 +1,46 @@ +"""Models for the climate platform.""" + +from __future__ import annotations + +from zha.application.platforms.climate.const import ( + ClimateEntityFeature, + HVACAction, + HVACMode, +) +from zha.application.platforms.model import BasePlatformEntityInfo +from zha.model import TypedBaseModel + + +class ThermostatState(TypedBaseModel): + """Thermostat state model.""" + + current_temperature: float | None = None + outdoor_temperature: float | None = None + target_temperature: float | None = None + target_temperature_low: float | None = None + target_temperature_high: float | None = None + hvac_action: HVACAction | None = None + hvac_mode: HVACMode | None = None + preset_mode: str + fan_mode: str | None = None + system_mode: str | None = None + occupancy: int | None = None + occupied_cooling_setpoint: int | None = None + occupied_heating_setpoint: int | None = None + unoccupied_heating_setpoint: int | None = None + unoccupied_cooling_setpoint: int | None = None + pi_cooling_demand: int | None = None + pi_heating_demand: int | None = None + available: bool + + +class ThermostatEntityInfo(BasePlatformEntityInfo): + """Thermostat entity model.""" + + state: ThermostatState + supported_features: ClimateEntityFeature + hvac_modes: list[HVACMode] + fan_modes: list[str] | None = None + preset_modes: list[str] | None = None + max_temp: float + min_temp: float diff --git a/zha/application/platforms/climate/websocket_api.py b/zha/application/platforms/climate/websocket_api.py new file mode 100644 index 000000000..19cf98d02 --- /dev/null +++ b/zha/application/platforms/climate/websocket_api.py @@ -0,0 +1,137 @@ +"""WS api for the climate platform entity.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal + +from zha.application.discovery import Platform +from zha.application.platforms.websocket_api import ( + PlatformEntityCommand, + execute_platform_entity_command, +) +from zha.websocket.const import APICommands +from zha.websocket.server.api import decorators, register_api_command + +if TYPE_CHECKING: + from zha.application.gateway import WebSocketServerGateway + from zha.websocket.server.client import Client + + +class ClimateSetFanModeCommand(PlatformEntityCommand): + """Set fan mode command.""" + + command: Literal[APICommands.CLIMATE_SET_FAN_MODE] = ( + APICommands.CLIMATE_SET_FAN_MODE + ) + platform: str = Platform.CLIMATE + fan_mode: str + + +@decorators.websocket_command(ClimateSetFanModeCommand) +@decorators.async_response +async def set_fan_mode( + gateway: WebSocketServerGateway, client: Client, command: ClimateSetFanModeCommand +) -> None: + """Set the fan mode for the climate platform entity.""" + await execute_platform_entity_command( + gateway, client, command, "async_set_fan_mode" + ) + + +class ClimateSetHVACModeCommand(PlatformEntityCommand): + """Set HVAC mode command.""" + + command: Literal[APICommands.CLIMATE_SET_HVAC_MODE] = ( + APICommands.CLIMATE_SET_HVAC_MODE + ) + platform: str = Platform.CLIMATE + hvac_mode: Literal[ + "off", # All activity disabled / Device is off/standby + "heat", # Heating + "cool", # Cooling + "heat_cool", # The device supports heating/cooling to a range + "auto", # The temperature is set based on a schedule, learned behavior, AI or some other related mechanism. User is not able to adjust the temperature + "dry", # Device is in Dry/Humidity mode + "fan_only", # Only the fan is on, not fan and another mode like cool + ] + + +@decorators.websocket_command(ClimateSetHVACModeCommand) +@decorators.async_response +async def set_hvac_mode( + gateway: WebSocketServerGateway, client: Client, command: ClimateSetHVACModeCommand +) -> None: + """Set the hvac mode for the climate platform entity.""" + await execute_platform_entity_command( + gateway, client, command, "async_set_hvac_mode" + ) + + +class ClimateSetPresetModeCommand(PlatformEntityCommand): + """Set preset mode command.""" + + command: Literal[APICommands.CLIMATE_SET_PRESET_MODE] = ( + APICommands.CLIMATE_SET_PRESET_MODE + ) + platform: str = Platform.CLIMATE + preset_mode: str + + +@decorators.websocket_command(ClimateSetPresetModeCommand) +@decorators.async_response +async def set_preset_mode( + gateway: WebSocketServerGateway, + client: Client, + command: ClimateSetPresetModeCommand, +) -> None: + """Set the preset mode for the climate platform entity.""" + await execute_platform_entity_command( + gateway, client, command, "async_set_preset_mode" + ) + + +class ClimateSetTemperatureCommand(PlatformEntityCommand): + """Set temperature command.""" + + command: Literal[APICommands.CLIMATE_SET_TEMPERATURE] = ( + APICommands.CLIMATE_SET_TEMPERATURE + ) + platform: str = Platform.CLIMATE + temperature: float | None = None + target_temp_high: float | None = None + target_temp_low: float | None = None + hvac_mode: ( + ( + Literal[ + "off", # All activity disabled / Device is off/standby + "heat", # Heating + "cool", # Cooling + "heat_cool", # The device supports heating/cooling to a range + "auto", # The temperature is set based on a schedule, learned behavior, AI or some other related mechanism. User is not able to adjust the temperature + "dry", # Device is in Dry/Humidity mode + "fan_only", # Only the fan is on, not fan and another mode like cool + ] + ) + | None + ) = None + + +@decorators.websocket_command(ClimateSetTemperatureCommand) +@decorators.async_response +async def set_temperature( + gateway: WebSocketServerGateway, + client: Client, + command: ClimateSetTemperatureCommand, +) -> None: + """Set the temperature and hvac mode for the climate platform entity.""" + await execute_platform_entity_command( + gateway, client, command, "async_set_temperature" + ) + + +def load_api(gateway: WebSocketServerGateway) -> None: + """Load the api command handlers.""" + register_api_command(gateway, set_fan_mode) + register_api_command(gateway, set_hvac_mode) + register_api_command(gateway, set_preset_mode) + register_api_command(gateway, set_temperature) diff --git a/zha/application/platforms/const.py b/zha/application/platforms/const.py new file mode 100644 index 000000000..d3311c299 --- /dev/null +++ b/zha/application/platforms/const.py @@ -0,0 +1,14 @@ +"""Constants for ZHA platforms.""" + +from enum import StrEnum + + +class EntityCategory(StrEnum): + """Category of an entity.""" + + # Config: An entity which allows changing the configuration of a device. + CONFIG = "config" + + # Diagnostic: An entity exposing some configuration parameter, + # or diagnostics of a device. + DIAGNOSTIC = "diagnostic" diff --git a/zha/application/platforms/cover/__init__.py b/zha/application/platforms/cover/__init__.py index 14dfe71b3..b86ee406c 100644 --- a/zha/application/platforms/cover/__init__.py +++ b/zha/application/platforms/cover/__init__.py @@ -2,6 +2,7 @@ from __future__ import annotations +from abc import ABC, abstractmethod import asyncio import functools import logging @@ -11,9 +12,8 @@ from zigpy.zcl.foundation import Status from zha.application import Platform -from zha.application.platforms import PlatformEntity +from zha.application.platforms import PlatformEntity, WebSocketClientEntity from zha.application.platforms.cover.const import ( - ATTR_CURRENT_POSITION, ATTR_POSITION, ATTR_TILT_POSITION, STATE_CLOSED, @@ -26,9 +26,15 @@ CoverEntityFeature, WCAttrs, ) +from zha.application.platforms.cover.model import ( + CoverEntityInfo, + CoverState, + ShadeEntityInfo, + ShadeState, +) from zha.application.registries import PLATFORM_ENTITIES +from zha.const import MODEL_CLASS_NAME from zha.exceptions import ZHAException -from zha.zigbee.cluster_handlers import ClusterAttributeUpdatedEvent from zha.zigbee.cluster_handlers.closures import WindowCoveringClusterHandler from zha.zigbee.cluster_handlers.const import ( CLUSTER_HANDLER_ATTRIBUTE_UPDATED, @@ -38,11 +44,11 @@ CLUSTER_HANDLER_ON_OFF, CLUSTER_HANDLER_SHADE, ) -from zha.zigbee.cluster_handlers.general import LevelChangeEvent if TYPE_CHECKING: - from zha.zigbee.cluster_handlers import ClusterHandler - from zha.zigbee.device import Device + from zha.zigbee.cluster_handlers import ClusterAttributeUpdatedEvent, ClusterHandler + from zha.zigbee.cluster_handlers.general import LevelChangeEvent + from zha.zigbee.device import Device, WebSocketClientDevice from zha.zigbee.endpoint import Endpoint _LOGGER = logging.getLogger(__name__) @@ -50,8 +56,66 @@ MULTI_MATCH = functools.partial(PLATFORM_ENTITIES.multipass_match, Platform.COVER) +class CoverEntityInterface(ABC): + """Representation of a ZHA cover.""" + + @property + @abstractmethod + def supported_features(self) -> CoverEntityFeature: + """Return supported features.""" + + @property + @abstractmethod + def is_closed(self) -> bool | None: + """Return True if the cover is closed.""" + + @property + @abstractmethod + def is_opening(self) -> bool: + """Return if the cover is opening or not.""" + + @property + @abstractmethod + def is_closing(self) -> bool: + """Return if the cover is closing or not.""" + + @property + @abstractmethod + def current_cover_position(self) -> int | None: + """Return the current position of the cover.""" + + @property + @abstractmethod + def current_cover_tilt_position(self) -> int | None: + """Return the current tilt position of the cover.""" + + async def async_open_cover(self, **kwargs: Any) -> None: + """Open the cover.""" + + async def async_open_cover_tilt(self, **kwargs: Any) -> None: + """Open the cover tilt.""" + + async def async_close_cover(self, **kwargs: Any) -> None: + """Close the cover.""" + + async def async_close_cover_tilt(self, **kwargs: Any) -> None: + """Close the cover tilt.""" + + async def async_set_cover_position(self, **kwargs: Any) -> None: + """Move the cover to a specific position.""" + + async def async_set_cover_tilt_position(self, **kwargs: Any) -> None: + """Move the cover tilt to a specific position.""" + + async def async_stop_cover(self, **kwargs: Any) -> None: + """Stop the cover.""" + + async def async_stop_cover_tilt(self, **kwargs: Any) -> None: + """Stop the cover tilt.""" + + @MULTI_MATCH(cluster_handler_names=CLUSTER_HANDLER_COVER) -class Cover(PlatformEntity): +class Cover(PlatformEntity, CoverEntityInterface): """Representation of a ZHA cover.""" PLATFORM = Platform.COVER @@ -100,22 +164,28 @@ def supported_features(self) -> CoverEntityFeature: """Return supported features.""" return self._attr_supported_features + @property + def info_object(self) -> CoverEntityInfo: + """Return the info object for this entity.""" + return CoverEntityInfo( + **super().info_object.model_dump(exclude=[MODEL_CLASS_NAME]), + supported_features=self.supported_features, + ) + @property def state(self) -> dict[str, Any]: """Get the state of the cover.""" - response = super().state - response.update( - { - ATTR_CURRENT_POSITION: self.current_cover_position, - "state": self._state, - "is_opening": self.is_opening, - "is_closing": self.is_closing, - "is_closed": self.is_closed, - "target_lift_position": self._target_lift_position, - "target_tilt_position": self._target_tilt_position, - } - ) - return response + return CoverState( + **super().state, + current_position=self.current_cover_position, + current_tilt_position=self.current_cover_tilt_position, + state=self._state, + is_opening=self.is_opening, + is_closing=self.is_closing, + is_closed=self.is_closed, + target_lift_position=self._target_lift_position, + target_tilt_position=self._target_tilt_position, + ).model_dump() def restore_external_state_attributes( self, @@ -125,11 +195,13 @@ def restore_external_state_attributes( ], # FIXME: why must these be expanded? target_lift_position: int | None, target_tilt_position: int | None, + **kwargs: Any, ): """Restore external state attributes.""" self._state = state self._target_lift_position = target_lift_position self._target_tilt_position = target_tilt_position + self.maybe_emit_state_changed_event() @property def is_closed(self) -> bool | None: @@ -410,6 +482,14 @@ def __init__( | CoverEntityFeature.SET_POSITION ) + @property + def info_object(self) -> ShadeEntityInfo: + """Return the info object for this entity.""" + return ShadeEntityInfo( + **super().info_object.model_dump(exclude=[MODEL_CLASS_NAME]), + supported_features=self.supported_features, + ) + @property def state(self) -> dict[str, Any]: """Get the state of the cover.""" @@ -417,15 +497,15 @@ def state(self) -> dict[str, Any]: state = None else: state = STATE_CLOSED if closed else STATE_OPEN - response = super().state - response.update( - { - ATTR_CURRENT_POSITION: self.current_cover_position, - "is_closed": self.is_closed, - "state": state, - } - ) - return response + + return ShadeState( + **super().state, + current_position=self.current_cover_position, + state=state, + is_opening=self.is_opening, + is_closing=self.is_closing, + is_closed=closed, + ).model_dump() @functools.cached_property def is_opening(self) -> bool: @@ -511,6 +591,19 @@ async def async_stop_cover(self, **kwargs: Any) -> None: # pylint: disable=unus if res[1] != Status.SUCCESS: raise ZHAException(f"Failed to stop cover: {res[1]}") + def restore_external_state_attributes( + self, + *, + state: Literal[ + "open", "opening", "closed", "closing" + ], # FIXME: why must these be expanded? + target_lift_position: int | None, + target_tilt_position: int | None, + **kwargs: Any, + ): + """Restore external state attributes.""" + # Shades don't restore state attributes + @MULTI_MATCH( cluster_handler_names={CLUSTER_HANDLER_LEVEL, CLUSTER_HANDLER_ON_OFF}, @@ -535,3 +628,101 @@ async def async_open_cover(self, **kwargs: Any) -> None: self._is_open = True self._position = position self.maybe_emit_state_changed_event() + + +class WebSocketClientCoverEntity( + WebSocketClientEntity[CoverEntityInfo], CoverEntityInterface +): + """Representation of a ZHA cover.""" + + PLATFORM: Platform = Platform.COVER + + def __init__( + self, entity_info: CoverEntityInfo, device: WebSocketClientDevice + ) -> None: + """Initialize the ZHA fan entity.""" + super().__init__(entity_info, device) + + @property + def supported_features(self) -> CoverEntityFeature: + """Return supported features.""" + return self.info_object.supported_features + + @property + def is_closed(self) -> bool | None: + """Return True if the cover is closed.""" + return self.info_object.state.is_closed + + @property + def is_opening(self) -> bool: + """Return if the cover is opening or not.""" + return self.info_object.state.is_opening + + @property + def is_closing(self) -> bool: + """Return if the cover is closing or not.""" + return self.info_object.state.is_closing + + @property + def current_cover_position(self) -> int | None: + """Return the current position of the cover.""" + return self.info_object.state.current_position + + @property + def current_cover_tilt_position(self) -> int | None: + """Return the current tilt position of the cover.""" + return self.info_object.state.current_tilt_position + + async def async_open_cover(self, **kwargs: Any) -> None: + """Open the cover.""" + await self._device.gateway.covers.open_cover(self.info_object) + + async def async_open_cover_tilt(self, **kwargs: Any) -> None: + """Open the cover tilt.""" + await self._device.gateway.covers.open_cover_tilt(self.info_object) + + async def async_close_cover(self, **kwargs: Any) -> None: + """Close the cover.""" + await self._device.gateway.covers.close_cover(self.info_object) + + async def async_close_cover_tilt(self, **kwargs: Any) -> None: + """Close the cover tilt.""" + await self._device.gateway.covers.close_cover_tilt(self.info_object) + + async def async_set_cover_position(self, **kwargs: Any) -> None: + """Move the cover to a specific position.""" + await self._device.gateway.covers.set_cover_position(self.info_object, **kwargs) + + async def async_set_cover_tilt_position(self, **kwargs: Any) -> None: + """Move the cover tilt to a specific position.""" + await self._device.gateway.covers.set_cover_tilt_position( + self.info_object, **kwargs + ) + + async def async_stop_cover(self, **kwargs: Any) -> None: + """Stop the cover.""" + await self._device.gateway.covers.stop_cover(self.info_object) + + async def async_stop_cover_tilt(self, **kwargs: Any) -> None: + """Stop the cover tilt.""" + await self._device.gateway.covers.stop_cover_tilt(self.info_object) + + def restore_external_state_attributes( + self, + *, + state: Literal[ + "open", "opening", "closed", "closing", "unavailable" + ], # FIXME: why must these be expanded? + target_lift_position: int | None, + target_tilt_position: int | None, + **kwargs: Any, + ): + """Restore external state attributes.""" + self._device.gateway.create_and_track_task( + self._device.gateway.covers.restore_external_state_attributes( + self.info_object, + state=state, + target_lift_position=target_lift_position, + target_tilt_position=target_tilt_position, + ) + ) diff --git a/zha/application/platforms/cover/model.py b/zha/application/platforms/cover/model.py new file mode 100644 index 000000000..a868e8f72 --- /dev/null +++ b/zha/application/platforms/cover/model.py @@ -0,0 +1,46 @@ +"""Models for the device tracker platform.""" + +from __future__ import annotations + +from zha.application.platforms.cover.const import CoverEntityFeature +from zha.application.platforms.model import BasePlatformEntityInfo +from zha.model import TypedBaseModel + + +class CoverState(TypedBaseModel): + """Cover state model.""" + + current_position: int | None = None + current_tilt_position: int | None = None + target_lift_position: int | None = None + target_tilt_position: int | None = None + state: str | None = None + is_opening: bool + is_closing: bool + is_closed: bool | None = None + available: bool + + +class ShadeState(TypedBaseModel): + """Cover state model.""" + + current_position: int | None = ( + None # TODO: how should we represent this when it is None? + ) + is_closed: bool | None = None + state: str | None = None + available: bool + + +class CoverEntityInfo(BasePlatformEntityInfo): + """Cover entity model.""" + + supported_features: CoverEntityFeature + state: CoverState + + +class ShadeEntityInfo(BasePlatformEntityInfo): + """Shade entity model.""" + + supported_features: CoverEntityFeature + state: ShadeState diff --git a/zha/application/platforms/cover/websocket_api.py b/zha/application/platforms/cover/websocket_api.py new file mode 100644 index 000000000..4a6869a31 --- /dev/null +++ b/zha/application/platforms/cover/websocket_api.py @@ -0,0 +1,199 @@ +"""WS API for the cover platform entity.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal + +from zha.application.discovery import Platform +from zha.application.platforms.websocket_api import ( + PlatformEntityCommand, + execute_platform_entity_command, +) +from zha.websocket.const import APICommands +from zha.websocket.server.api import decorators, register_api_command + +if TYPE_CHECKING: + from zha.application.gateway import WebSocketServerGateway + from zha.websocket.server.client import Client + + +class CoverOpenCommand(PlatformEntityCommand): + """Cover open command.""" + + command: Literal[APICommands.COVER_OPEN] = APICommands.COVER_OPEN + platform: str = Platform.COVER + + +@decorators.websocket_command(CoverOpenCommand) +@decorators.async_response +async def open_cover( + gateway: WebSocketServerGateway, client: Client, command: CoverOpenCommand +) -> None: + """Open the cover.""" + await execute_platform_entity_command(gateway, client, command, "async_open_cover") + + +class CoverOpenTiltCommand(PlatformEntityCommand): + """Cover open tilt command.""" + + command: Literal[APICommands.COVER_OPEN_TILT] = APICommands.COVER_OPEN_TILT + platform: str = Platform.COVER + + +@decorators.websocket_command(CoverOpenTiltCommand) +@decorators.async_response +async def open_cover_tilt( + gateway: WebSocketServerGateway, client: Client, command: CoverOpenTiltCommand +) -> None: + """Open the cover tilt.""" + await execute_platform_entity_command( + gateway, client, command, "async_open_cover_tilt" + ) + + +class CoverCloseCommand(PlatformEntityCommand): + """Cover close command.""" + + command: Literal[APICommands.COVER_CLOSE] = APICommands.COVER_CLOSE + platform: str = Platform.COVER + + +@decorators.websocket_command(CoverCloseCommand) +@decorators.async_response +async def close_cover( + gateway: WebSocketServerGateway, client: Client, command: CoverCloseCommand +) -> None: + """Close the cover.""" + await execute_platform_entity_command(gateway, client, command, "async_close_cover") + + +class CoverCloseTiltCommand(PlatformEntityCommand): + """Cover close tilt command.""" + + command: Literal[APICommands.COVER_CLOSE_TILT] = APICommands.COVER_CLOSE_TILT + platform: str = Platform.COVER + + +@decorators.websocket_command(CoverCloseTiltCommand) +@decorators.async_response +async def close_cover_tilt( + gateway: WebSocketServerGateway, client: Client, command: CoverCloseTiltCommand +) -> None: + """Close the cover tilt.""" + await execute_platform_entity_command( + gateway, client, command, "async_close_cover_tilt" + ) + + +class CoverSetPositionCommand(PlatformEntityCommand): + """Cover set position command.""" + + command: Literal[APICommands.COVER_SET_POSITION] = APICommands.COVER_SET_POSITION + platform: str = Platform.COVER + position: int + + +@decorators.websocket_command(CoverSetPositionCommand) +@decorators.async_response +async def set_position( + gateway: WebSocketServerGateway, client: Client, command: CoverSetPositionCommand +) -> None: + """Set the cover position.""" + await execute_platform_entity_command( + gateway, client, command, "async_set_cover_position" + ) + + +class CoverSetTiltPositionCommand(PlatformEntityCommand): + """Cover set position command.""" + + command: Literal[APICommands.COVER_SET_TILT_POSITION] = ( + APICommands.COVER_SET_TILT_POSITION + ) + platform: str = Platform.COVER + tilt_position: int + + +@decorators.websocket_command(CoverSetTiltPositionCommand) +@decorators.async_response +async def set_tilt_position( + gateway: WebSocketServerGateway, + client: Client, + command: CoverSetTiltPositionCommand, +) -> None: + """Set the cover tilt position.""" + await execute_platform_entity_command( + gateway, client, command, "async_set_cover_tilt_position" + ) + + +class CoverStopCommand(PlatformEntityCommand): + """Cover stop command.""" + + command: Literal[APICommands.COVER_STOP] = APICommands.COVER_STOP + platform: str = Platform.COVER + + +@decorators.websocket_command(CoverStopCommand) +@decorators.async_response +async def stop_cover( + gateway: WebSocketServerGateway, client: Client, command: CoverStopCommand +) -> None: + """Stop the cover.""" + await execute_platform_entity_command(gateway, client, command, "async_stop_cover") + + +class CoverStopTiltCommand(PlatformEntityCommand): + """Cover stop tilt command.""" + + command: Literal[APICommands.COVER_STOP_TILT] = APICommands.COVER_STOP_TILT + platform: str = Platform.COVER + + +@decorators.websocket_command(CoverStopTiltCommand) +@decorators.async_response +async def stop_cover_tilt( + gateway: WebSocketServerGateway, client: Client, command: CoverStopTiltCommand +) -> None: + """Stop the cover tilt.""" + await execute_platform_entity_command( + gateway, client, command, "async_stop_cover_tilt" + ) + + +class CoverRestoreExternalStateAttributesCommand(PlatformEntityCommand): + """Cover restore external state attributes command.""" + + command: Literal[APICommands.COVER_RESTORE_EXTERNAL_STATE_ATTRIBUTES] = ( + APICommands.COVER_RESTORE_EXTERNAL_STATE_ATTRIBUTES + ) + platform: str = Platform.COVER + state: Literal["open", "opening", "closed", "closing", "unavailable"] + target_lift_position: int | None = None + target_tilt_position: int | None = None + + +@decorators.websocket_command(CoverRestoreExternalStateAttributesCommand) +@decorators.async_response +async def restore_cover_external_state_attributes( + gateway: WebSocketServerGateway, + client: Client, + command: CoverRestoreExternalStateAttributesCommand, +) -> None: + """Stop the cover tilt.""" + await execute_platform_entity_command( + gateway, client, command, "restore_external_state_attributes" + ) + + +def load_api(gateway: WebSocketServerGateway) -> None: + """Load the api command handlers.""" + register_api_command(gateway, open_cover) + register_api_command(gateway, close_cover) + register_api_command(gateway, set_position) + register_api_command(gateway, stop_cover) + register_api_command(gateway, open_cover_tilt) + register_api_command(gateway, close_cover_tilt) + register_api_command(gateway, set_tilt_position) + register_api_command(gateway, stop_cover_tilt) + register_api_command(gateway, restore_cover_external_state_attributes) diff --git a/zha/application/platforms/device_tracker.py b/zha/application/platforms/device_tracker/__init__.py similarity index 60% rename from zha/application/platforms/device_tracker.py rename to zha/application/platforms/device_tracker/__init__.py index 6c0d0eb07..5b4a50ee6 100644 --- a/zha/application/platforms/device_tracker.py +++ b/zha/application/platforms/device_tracker/__init__.py @@ -2,7 +2,7 @@ from __future__ import annotations -from enum import StrEnum +from abc import ABC, abstractmethod import functools import time from typing import TYPE_CHECKING, Any @@ -10,19 +10,24 @@ from zigpy.zcl.clusters.general import PowerConfiguration from zha.application import Platform -from zha.application.platforms import PlatformEntity +from zha.application.platforms import PlatformEntity, WebSocketClientEntity +from zha.application.platforms.device_tracker.const import SourceType +from zha.application.platforms.device_tracker.model import ( + DeviceTrackerEntityInfo, + DeviceTrackerState, +) from zha.application.platforms.sensor import Battery from zha.application.registries import PLATFORM_ENTITIES +from zha.const import MODEL_CLASS_NAME from zha.decorators import periodic -from zha.zigbee.cluster_handlers import ClusterAttributeUpdatedEvent from zha.zigbee.cluster_handlers.const import ( CLUSTER_HANDLER_ATTRIBUTE_UPDATED, CLUSTER_HANDLER_POWER_CONFIGURATION, ) if TYPE_CHECKING: - from zha.zigbee.cluster_handlers import ClusterHandler - from zha.zigbee.device import Device + from zha.zigbee.cluster_handlers import ClusterAttributeUpdatedEvent, ClusterHandler + from zha.zigbee.device import Device, WebSocketClientDevice from zha.zigbee.endpoint import Endpoint STRICT_MATCH = functools.partial( @@ -30,17 +35,30 @@ ) -class SourceType(StrEnum): - """Source type for device trackers.""" +class DeviceTrackerEntityInterface(ABC): + """Device tracker interface.""" - GPS = "gps" - ROUTER = "router" - BLUETOOTH = "bluetooth" - BLUETOOTH_LE = "bluetooth_le" + @property + @abstractmethod + def is_connected(self) -> bool: + """Return true if the device is connected to the network.""" + + @property + @abstractmethod + def source_type(self) -> SourceType: + """Return the source type, eg gps or router, of the device.""" + + @property + @abstractmethod + def battery_level(self) -> float | None: + """Return the battery level of the device. + + Percentage from 0-100. + """ @STRICT_MATCH(cluster_handler_names=CLUSTER_HANDLER_POWER_CONFIGURATION) -class DeviceScannerEntity(PlatformEntity): +class DeviceScannerEntity(PlatformEntity, DeviceTrackerEntityInterface): """Represent a tracked device.""" PLATFORM = Platform.DEVICE_TRACKER @@ -83,17 +101,22 @@ def __init__( getattr(self, "__polling_interval"), ) + @property + def info_object(self) -> DeviceTrackerEntityInfo: + """Return a representation of the device tracker.""" + return DeviceTrackerEntityInfo( + **super().info_object.model_dump(exclude=[MODEL_CLASS_NAME]) + ) + @property def state(self) -> dict[str, Any]: """Return the state of the device.""" - response = super().state - response.update( - { - "connected": self._connected, - "battery_level": self._battery_level, - } - ) - return response + return DeviceTrackerState( + **super().state, + connected=self._connected, + battery_level=self._battery_level, + source_type=self.source_type, + ).model_dump() @property def is_connected(self): @@ -143,3 +166,35 @@ def handle_cluster_handler_attribute_updated( self._connected = True self._battery_level = Battery.formatter(event.attribute_value) self.maybe_emit_state_changed_event() + + +class WebSocketClientDeviceTrackerEntity( + WebSocketClientEntity[DeviceTrackerEntityInfo], DeviceTrackerEntityInterface +): + """Device tracker entity for the WebSocket API.""" + + PLATFORM = Platform.DEVICE_TRACKER + + def __init__( + self, entity_info: DeviceTrackerEntityInfo, device: WebSocketClientDevice + ) -> None: + """Initialize the ZHA device tracker.""" + super().__init__(entity_info, device) + + @property + def is_connected(self) -> bool: + """Return true if the device is connected to the network.""" + return self.info_object.state.connected + + @property + def source_type(self) -> SourceType: + """Return the source type, eg gps or router, of the device.""" + return self.info_object.state.source_type + + @property + def battery_level(self) -> float | None: + """Return the battery level of the device. + + Percentage from 0-100. + """ + return self.info_object.state.battery_level diff --git a/zha/application/platforms/device_tracker/const.py b/zha/application/platforms/device_tracker/const.py new file mode 100644 index 000000000..cadc487b7 --- /dev/null +++ b/zha/application/platforms/device_tracker/const.py @@ -0,0 +1,14 @@ +"""Constants for the ZHA device tracker platform.""" + +from __future__ import annotations + +from enum import StrEnum + + +class SourceType(StrEnum): + """Source type for device trackers.""" + + GPS = "gps" + ROUTER = "router" + BLUETOOTH = "bluetooth" + BLUETOOTH_LE = "bluetooth_le" diff --git a/zha/application/platforms/device_tracker/model.py b/zha/application/platforms/device_tracker/model.py new file mode 100644 index 000000000..9da67abc9 --- /dev/null +++ b/zha/application/platforms/device_tracker/model.py @@ -0,0 +1,22 @@ +"""Models for the device tracker platform.""" + +from __future__ import annotations + +from zha.application.platforms.device_tracker.const import SourceType +from zha.application.platforms.model import BasePlatformEntityInfo +from zha.model import TypedBaseModel + + +class DeviceTrackerState(TypedBaseModel): + """Device tracker state model.""" + + connected: bool + battery_level: float | None = None + source_type: SourceType + available: bool + + +class DeviceTrackerEntityInfo(BasePlatformEntityInfo): + """Device tracker entity model.""" + + state: DeviceTrackerState diff --git a/zha/application/platforms/events.py b/zha/application/platforms/events.py new file mode 100644 index 000000000..503a7a452 --- /dev/null +++ b/zha/application/platforms/events.py @@ -0,0 +1,61 @@ +"""Events for ZHA platforms.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal + +from zigpy.types.named import EUI64 + +from zha.application import Platform +from zha.application.platforms.climate.model import ThermostatState +from zha.application.platforms.cover.model import CoverState, ShadeState +from zha.application.platforms.device_tracker.model import DeviceTrackerState +from zha.application.platforms.fan.model import FanState +from zha.application.platforms.light.model import LightState +from zha.application.platforms.lock.model import LockState +from zha.application.platforms.model import EntityState +from zha.application.platforms.sensor.model import ( + BatteryState, + DeviceCounterSensorState, + ElectricalMeasurementState, + SmartEnergyMeteringState, + TimestampState, +) +from zha.application.platforms.switch.model import SwitchState +from zha.application.platforms.update.model import FirmwareUpdateState +from zha.const import EntityEvents, EventTypes +from zha.model import BaseEvent, as_tagged_union + +EntityStateUnion = ( + DeviceTrackerState + | CoverState + | ShadeState + | FanState + | LockState + | BatteryState + | ElectricalMeasurementState + | LightState + | SwitchState + | SmartEnergyMeteringState + | EntityState + | ThermostatState + | FirmwareUpdateState + | DeviceCounterSensorState + | TimestampState +) + +if not TYPE_CHECKING: + EntityStateUnion = as_tagged_union(EntityStateUnion) + + +class EntityStateChangedEvent(BaseEvent): + """Event for when an entity state changes.""" + + event_type: Literal[EventTypes.ENTITY_EVENT] = EventTypes.ENTITY_EVENT + event: Literal[EntityEvents.STATE_CHANGED] = EntityEvents.STATE_CHANGED + platform: Platform + unique_id: str + device_ieee: EUI64 | None = None + endpoint_id: int | None = None + group_id: int | None = None + state: EntityStateUnion | None diff --git a/zha/application/platforms/fan/__init__.py b/zha/application/platforms/fan/__init__.py index b3270a1a9..dd9c04886 100644 --- a/zha/application/platforms/fan/__init__.py +++ b/zha/application/platforms/fan/__init__.py @@ -2,8 +2,7 @@ from __future__ import annotations -from abc import abstractmethod -from dataclasses import dataclass +from abc import ABC, abstractmethod import functools import math from typing import TYPE_CHECKING, Any @@ -13,9 +12,9 @@ from zha.application import Platform from zha.application.platforms import ( BaseEntity, - BaseEntityInfo, GroupEntity, PlatformEntity, + WebSocketClientEntity, ) from zha.application.platforms.fan.const import ( ATTR_PERCENTAGE, @@ -38,38 +37,98 @@ percentage_to_ranged_value, ranged_value_to_percentage, ) +from zha.application.platforms.fan.model import FanEntityInfo, FanState from zha.application.registries import PLATFORM_ENTITIES -from zha.zigbee.cluster_handlers import ( - ClusterAttributeUpdatedEvent, - wrap_zigpy_exceptions, -) +from zha.const import MODEL_CLASS_NAME +from zha.zigbee.cluster_handlers import wrap_zigpy_exceptions from zha.zigbee.cluster_handlers.const import ( CLUSTER_HANDLER_ATTRIBUTE_UPDATED, CLUSTER_HANDLER_FAN, ) -from zha.zigbee.group import Group if TYPE_CHECKING: - from zha.zigbee.cluster_handlers import ClusterHandler - from zha.zigbee.device import Device + from zha.zigbee.cluster_handlers import ClusterAttributeUpdatedEvent, ClusterHandler + from zha.zigbee.device import Device, WebSocketClientDevice from zha.zigbee.endpoint import Endpoint + from zha.zigbee.group import Group STRICT_MATCH = functools.partial(PLATFORM_ENTITIES.strict_match, Platform.FAN) GROUP_MATCH = functools.partial(PLATFORM_ENTITIES.group_match, Platform.FAN) MULTI_MATCH = functools.partial(PLATFORM_ENTITIES.multipass_match, Platform.FAN) -@dataclass(frozen=True, kw_only=True) -class FanEntityInfo(BaseEntityInfo): - """Fan entity info.""" +class FanEntityInterface(ABC): + """Fan interface.""" + + @property + @abstractmethod + def preset_modes(self) -> list[str]: + """Return the available preset modes.""" + + @property + @abstractmethod + def default_on_percentage(self) -> int: + """Return the default on percentage.""" + + @property + @abstractmethod + def speed_list(self) -> list[str]: + """Get the list of available speeds.""" + + @property + @abstractmethod + def speed_count(self) -> int: + """Return the number of speeds the fan supports.""" + + @property + @abstractmethod + def supported_features(self) -> FanEntityFeature: + """Flag supported features.""" - preset_modes: list[str] - supported_features: FanEntityFeature - speed_count: int - speed_list: list[str] + @property + @abstractmethod + def is_on(self) -> bool: + """Return true if the entity is on.""" + @property + @abstractmethod + def percentage(self) -> int | None: + """Return the current speed percentage.""" -class BaseFan(BaseEntity): + @property + @abstractmethod + def preset_mode(self) -> str | None: + """Return the current preset mode.""" + + @property + @abstractmethod + def speed(self) -> str | None: + """Return the current speed.""" + + @abstractmethod + async def async_turn_on( + self, + speed: str | None = None, + percentage: int | None = None, + preset_mode: str | None = None, + **kwargs: Any, + ) -> None: + """Turn the entity on.""" + + @abstractmethod + async def async_turn_off(self, **kwargs: Any) -> None: + """Turn the entity off.""" + + @abstractmethod + async def async_set_percentage(self, percentage: int) -> None: + """Set the speed percentage of the fan.""" + + @abstractmethod + async def async_set_preset_mode(self, preset_mode: str) -> None: + """Set the preset mode for the fan.""" + + +class BaseFan(BaseEntity, FanEntityInterface): """Base representation of a ZHA fan.""" PLATFORM = Platform.FAN @@ -156,12 +215,12 @@ async def async_turn_off(self, **kwargs: Any) -> None: # pylint: disable=unused """Turn the entity off.""" await self.async_set_percentage(0) - async def async_set_percentage(self, percentage: int) -> None: + async def async_set_percentage(self, percentage: int, **kwargs) -> None: """Set the speed percentage of the fan.""" fan_mode = math.ceil(percentage_to_ranged_value(self.speed_range, percentage)) await self._async_set_fan_mode(fan_mode) - async def async_set_preset_mode(self, preset_mode: str) -> None: + async def async_set_preset_mode(self, preset_mode: str, **kwargs) -> None: """Set the preset mode for the fan.""" try: mode = self.preset_name_to_mode[preset_mode] @@ -172,7 +231,7 @@ async def async_set_preset_mode(self, preset_mode: str) -> None: await self._async_set_fan_mode(mode) @abstractmethod - async def _async_set_fan_mode(self, fan_mode: int) -> None: + async def _async_set_fan_mode(self, fan_mode: int, **kwargs) -> None: """Set the fan mode for the fan.""" def handle_cluster_handler_attribute_updated( @@ -220,30 +279,29 @@ def __init__( self.handle_cluster_handler_attribute_updated, ) - @functools.cached_property + @property def info_object(self) -> FanEntityInfo: """Return a representation of the binary sensor.""" return FanEntityInfo( - **super().info_object.__dict__, + **super().info_object.model_dump(exclude=[MODEL_CLASS_NAME]), preset_modes=self.preset_modes, supported_features=self.supported_features, speed_count=self.speed_count, speed_list=self.speed_list, + default_on_percentage=self.default_on_percentage, + percentage_step=self.percentage_step, ) @property - def state(self) -> dict: + def state(self) -> dict[str, Any]: """Return the state of the fan.""" - response = super().state - response.update( - { - "preset_mode": self.preset_mode, - "percentage": self.percentage, - "is_on": self.is_on, - "speed": self.speed, - } - ) - return response + return FanState( + **super().state, + preset_mode=self.preset_mode, + percentage=self.percentage, + is_on=self.is_on, + speed=self.speed, + ).model_dump() @property def percentage(self) -> int | None: @@ -273,7 +331,7 @@ def speed(self) -> str | None: return None return self.percentage_to_speed(percentage) - async def _async_set_fan_mode(self, fan_mode: int) -> None: + async def _async_set_fan_mode(self, fan_mode: int, **kwargs) -> None: """Set the fan mode for the fan.""" await self._fan_cluster_handler.async_set_speed(fan_mode) self.maybe_emit_state_changed_event() @@ -289,34 +347,31 @@ def __init__(self, group: Group): super().__init__(group) self._percentage = None self._preset_mode = None - if hasattr(self, "info_object"): - delattr(self, "info_object") self.update() - @functools.cached_property + @property def info_object(self) -> FanEntityInfo: """Return a representation of the binary sensor.""" return FanEntityInfo( - **super().info_object.__dict__, + **super().info_object.model_dump(exclude=[MODEL_CLASS_NAME]), preset_modes=self.preset_modes, supported_features=self.supported_features, speed_count=self.speed_count, speed_list=self.speed_list, + default_on_percentage=self.default_on_percentage, + percentage_step=self.percentage_step, ) @property def state(self) -> dict[str, Any]: """Return the state of the fan.""" - response = super().state - response.update( - { - "preset_mode": self.preset_mode, - "percentage": self.percentage, - "is_on": self.is_on, - "speed": self.speed, - } - ) - return response + return FanState( + **super().state, + preset_mode=self.preset_mode, + percentage=self.percentage, + is_on=self.is_on, + speed=self.speed, + ).model_dump() @property def percentage(self) -> int | None: @@ -337,7 +392,7 @@ def speed(self) -> str | None: return None return self.percentage_to_speed(percentage) - async def _async_set_fan_mode(self, fan_mode: int) -> None: + async def _async_set_fan_mode(self, fan_mode: int, **kwargs) -> None: """Set the fan mode for the group.""" with wrap_zigpy_exceptions(): @@ -446,7 +501,7 @@ async def async_turn_on( else: await super().async_turn_on(speed, percentage, preset_mode) - async def async_set_percentage(self, percentage: int) -> None: + async def async_set_percentage(self, percentage: int, **kwargs) -> None: """Set the speed percentage of the fan.""" fan_mode = math.ceil(percentage_to_ranged_value(self.speed_range, percentage)) # 1 is a mode, not a speed, so we skip to 2 instead. @@ -478,3 +533,93 @@ def speed_range(self) -> tuple[int, int]: def preset_modes_to_name(self) -> dict[int, str]: """Return a dict from preset mode to name.""" return {6: PRESET_MODE_SMART} + + +class WebSocketClientFanEntity( + WebSocketClientEntity[FanEntityInfo], FanEntityInterface +): + """Representation of a ZHA fan over WebSocket.""" + + PLATFORM: Platform = Platform.FAN + + def __init__( + self, entity_info: FanEntityInfo, device: WebSocketClientDevice + ) -> None: + """Initialize the ZHA fan entity.""" + super().__init__(entity_info, device) + + @property + def preset_modes(self) -> list[str]: + """Return the available preset modes.""" + return self.info_object.preset_modes + + @property + def default_on_percentage(self) -> int: + """Return the default on percentage.""" + return self.info_object.default_on_percentage + + @property + def speed_list(self) -> list[str]: + """Get the list of available speeds.""" + return self.info_object.speed_list + + @property + def speed_count(self) -> int: + """Return the number of speeds the fan supports.""" + return self.info_object.speed_count + + @property + def supported_features(self) -> FanEntityFeature: + """Flag supported features.""" + return self.info_object.supported_features + + @property + def is_on(self) -> bool: + """Return true if the entity is on.""" + return self.info_object.state.is_on + + @property + def percentage(self) -> int | None: + """Return the current speed percentage.""" + return self.info_object.state.percentage + + @property + def preset_mode(self) -> str | None: + """Return the current preset mode.""" + return self.info_object.state.preset_mode + + @property + def speed(self) -> str | None: + """Return the current speed.""" + return self.info_object.state.speed + + @property + def percentage_step(self) -> float: + """Return the step size for percentage.""" + return self.info_object.percentage_step + + async def async_turn_on( + self, + speed: str | None = None, + percentage: int | None = None, + preset_mode: str | None = None, + **kwargs: Any, + ) -> None: + """Turn the entity on.""" + await self._device.gateway.fans.turn_on( + self.info_object, speed, percentage, preset_mode + ) + + async def async_turn_off(self, **kwargs: Any) -> None: + """Turn the entity off.""" + await self._device.gateway.fans.turn_off(self.info_object) + + async def async_set_percentage(self, percentage: int) -> None: + """Set the speed percentage of the fan.""" + await self._device.gateway.fans.set_fan_percentage(self.info_object, percentage) + + async def async_set_preset_mode(self, preset_mode: str) -> None: + """Set the preset mode for the fan.""" + await self._device.gateway.fans.set_fan_preset_mode( + self.info_object, preset_mode + ) diff --git a/zha/application/platforms/fan/model.py b/zha/application/platforms/fan/model.py new file mode 100644 index 000000000..5d9c2ae34 --- /dev/null +++ b/zha/application/platforms/fan/model.py @@ -0,0 +1,33 @@ +"""Models for the fan platform.""" + +from __future__ import annotations + +from zha.application.platforms.fan.const import FanEntityFeature +from zha.application.platforms.model import BasePlatformEntityInfo +from zha.model import TypedBaseModel + + +class FanState(TypedBaseModel): + """Fan state model.""" + + preset_mode: str | None = ( + None # TODO: how should we represent these when they are None? + ) + percentage: int | None = ( + None # TODO: how should we represent these when they are None? + ) + is_on: bool + speed: str | None = None + available: bool + + +class FanEntityInfo(BasePlatformEntityInfo): + """Fan model.""" + + preset_modes: list[str] + supported_features: FanEntityFeature + default_on_percentage: int + speed_count: int + speed_list: list[str] + percentage_step: float + state: FanState diff --git a/zha/application/platforms/fan/websocket_api.py b/zha/application/platforms/fan/websocket_api.py new file mode 100644 index 000000000..658546d91 --- /dev/null +++ b/zha/application/platforms/fan/websocket_api.py @@ -0,0 +1,100 @@ +"""WS API for the fan platform entity.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Annotated, Literal + +from pydantic import Field + +from zha.application.discovery import Platform +from zha.application.platforms.websocket_api import ( + PlatformEntityCommand, + execute_platform_entity_command, +) +from zha.websocket.const import APICommands +from zha.websocket.server.api import decorators, register_api_command + +if TYPE_CHECKING: + from zha.application.gateway import WebSocketServerGateway + from zha.websocket.server.client import Client + + +class FanTurnOnCommand(PlatformEntityCommand): + """Fan turn on command.""" + + command: Literal[APICommands.FAN_TURN_ON] = APICommands.FAN_TURN_ON + platform: str = Platform.FAN + speed: str | None = None + percentage: Annotated[int, Field(ge=0, le=100)] | None = None + preset_mode: str | None = None + + +@decorators.websocket_command(FanTurnOnCommand) +@decorators.async_response +async def turn_on( + gateway: WebSocketServerGateway, client: Client, command: FanTurnOnCommand +) -> None: + """Turn fan on.""" + await execute_platform_entity_command(gateway, client, command, "async_turn_on") + + +class FanTurnOffCommand(PlatformEntityCommand): + """Fan turn off command.""" + + command: Literal[APICommands.FAN_TURN_OFF] = APICommands.FAN_TURN_OFF + platform: str = Platform.FAN + + +@decorators.websocket_command(FanTurnOffCommand) +@decorators.async_response +async def turn_off( + gateway: WebSocketServerGateway, client: Client, command: FanTurnOffCommand +) -> None: + """Turn fan off.""" + await execute_platform_entity_command(gateway, client, command, "async_turn_off") + + +class FanSetPercentageCommand(PlatformEntityCommand): + """Fan set percentage command.""" + + command: Literal[APICommands.FAN_SET_PERCENTAGE] = APICommands.FAN_SET_PERCENTAGE + platform: str = Platform.FAN + percentage: Annotated[int, Field(ge=0, le=100)] + + +@decorators.websocket_command(FanSetPercentageCommand) +@decorators.async_response +async def set_percentage( + gateway: WebSocketServerGateway, client: Client, command: FanSetPercentageCommand +) -> None: + """Set the fan speed percentage.""" + await execute_platform_entity_command( + gateway, client, command, "async_set_percentage" + ) + + +class FanSetPresetModeCommand(PlatformEntityCommand): + """Fan set preset mode command.""" + + command: Literal[APICommands.FAN_SET_PRESET_MODE] = APICommands.FAN_SET_PRESET_MODE + platform: str = Platform.FAN + preset_mode: str + + +@decorators.websocket_command(FanSetPresetModeCommand) +@decorators.async_response +async def set_preset_mode( + gateway: WebSocketServerGateway, client: Client, command: FanSetPresetModeCommand +) -> None: + """Set the fan preset mode.""" + await execute_platform_entity_command( + gateway, client, command, "async_set_preset_mode" + ) + + +def load_api(gateway: WebSocketServerGateway) -> None: + """Load the api command handlers.""" + register_api_command(gateway, turn_on) + register_api_command(gateway, turn_off) + register_api_command(gateway, set_percentage) + register_api_command(gateway, set_preset_mode) diff --git a/zha/application/platforms/helpers.py b/zha/application/platforms/helpers.py index adc3086df..0891b97be 100644 --- a/zha/application/platforms/helpers.py +++ b/zha/application/platforms/helpers.py @@ -1,4 +1,4 @@ -"""Entity helpers for the zhaws server.""" +"""Entity helpers for the zha server.""" from __future__ import annotations diff --git a/zha/application/platforms/light/__init__.py b/zha/application/platforms/light/__init__.py index 9dbdfc3eb..61f6cb3b6 100644 --- a/zha/application/platforms/light/__init__.py +++ b/zha/application/platforms/light/__init__.py @@ -4,13 +4,11 @@ from __future__ import annotations -from abc import ABC +from abc import ABC, abstractmethod import asyncio from collections import Counter from collections.abc import Callable import contextlib -import dataclasses -from dataclasses import dataclass import functools import itertools import logging @@ -23,9 +21,9 @@ from zha.application import Platform from zha.application.platforms import ( BaseEntity, - BaseEntityInfo, GroupEntity, PlatformEntity, + WebSocketClientEntity, ) from zha.application.platforms.helpers import ( find_state_attributes, @@ -62,10 +60,11 @@ brightness_supported, filter_supported_color_modes, ) +from zha.application.platforms.light.model import LightEntityInfo, LightState from zha.application.registries import PLATFORM_ENTITIES +from zha.const import MODEL_CLASS_NAME from zha.debounce import Debouncer from zha.decorators import periodic -from zha.zigbee.cluster_handlers import ClusterAttributeUpdatedEvent from zha.zigbee.cluster_handlers.const import ( CLUSTER_HANDLER_ATTRIBUTE_UPDATED, CLUSTER_HANDLER_COLOR, @@ -73,11 +72,11 @@ CLUSTER_HANDLER_LEVEL_CHANGED, CLUSTER_HANDLER_ON_OFF, ) -from zha.zigbee.cluster_handlers.general import LevelChangeEvent if TYPE_CHECKING: - from zha.zigbee.cluster_handlers import ClusterHandler - from zha.zigbee.device import Device + from zha.zigbee.cluster_handlers import ClusterAttributeUpdatedEvent, ClusterHandler + from zha.zigbee.cluster_handlers.general import LevelChangeEvent + from zha.zigbee.device import Device, WebSocketClientDevice from zha.zigbee.endpoint import Endpoint from zha.zigbee.group import Group @@ -87,17 +86,74 @@ GROUP_MATCH = functools.partial(PLATFORM_ENTITIES.group_match, Platform.LIGHT) -@dataclass(frozen=True, kw_only=True) -class LightEntityInfo(BaseEntityInfo): - """Light entity info.""" +class LightEntityInterface(ABC): + """Light interface.""" - effect_list: list[str] | None = dataclasses.field(default=None) - supported_features: LightEntityFeature - min_mireds: int - max_mireds: int + @property + @abstractmethod + def xy_color(self) -> tuple[float, float] | None: + """Return the xy color value [float, float].""" + + @property + @abstractmethod + def color_temp(self) -> int | None: + """Return the CT color value in mireds.""" + + @property + @abstractmethod + def color_mode(self) -> ColorMode | None: + """Return the color mode.""" + + @property + @abstractmethod + def effect_list(self) -> list[str] | None: + """Return the list of supported effects.""" + + @property + @abstractmethod + def effect(self) -> str: + """Return the current effect.""" + + @property + @abstractmethod + def supported_features(self) -> LightEntityFeature: + """Flag supported features.""" + + @property + @abstractmethod + def supported_color_modes(self) -> set[ColorMode]: + """Flag supported color modes.""" + @property + @abstractmethod + def is_on(self) -> bool: + """Return true if entity is on.""" -class BaseLight(BaseEntity, ABC): + @property + @abstractmethod + def brightness(self) -> int | None: + """Return the brightness of this light.""" + + @property + @abstractmethod + def min_mireds(self) -> int | None: + """Return the coldest color_temp that this light supports.""" + + @property + @abstractmethod + def max_mireds(self) -> int | None: + """Return the warmest color_temp that this light supports.""" + + @abstractmethod + async def async_turn_on(self, **kwargs: Any) -> None: + """Turn the entity on.""" + + @abstractmethod + async def async_turn_off(self, **kwargs: Any) -> None: + """Turn the entity off.""" + + +class BaseLight(BaseEntity, LightEntityInterface): """Operations common to all light entities.""" PLATFORM = Platform.LIGHT @@ -137,23 +193,6 @@ def __init__(self, *args, **kwargs): self._transitioning_group: bool = False self._transition_listener: Callable[[], None] | None = None - @property - def state(self) -> dict[str, Any]: - """Return the state of the light.""" - response = super().state - response["on"] = self.is_on - response["brightness"] = self.brightness - response["xy_color"] = self.xy_color - response["color_temp"] = self.color_temp - response["effect_list"] = self.effect_list - response["effect"] = self.effect - response["supported_features"] = self.supported_features - response["color_mode"] = self.color_mode - response["supported_color_modes"] = self._supported_color_modes - response["off_with_transition"] = self._off_with_transition - response["off_brightness"] = self._off_brightness - return response - @property def xy_color(self) -> tuple[float, float] | None: """Return the xy color value [float, float].""" @@ -736,17 +775,36 @@ def __init__( self._refresh_task: asyncio.Task | None = None self.start_polling() - @functools.cached_property + @property def info_object(self) -> LightEntityInfo: """Return a representation of the select.""" return LightEntityInfo( - **super().info_object.__dict__, + **super().info_object.model_dump(exclude=[MODEL_CLASS_NAME]), effect_list=self.effect_list, supported_features=self.supported_features, min_mireds=self.min_mireds, max_mireds=self.max_mireds, + supported_color_modes=self.supported_color_modes, ) + @property + def state(self) -> dict[str, Any]: + """Return the state of the light.""" + return LightState( + **super().state, + on=self.is_on, + brightness=self.brightness, + xy_color=self.xy_color, + color_temp=self.color_temp, + effect_list=self.effect_list, + effect=self.effect, + supported_features=self.supported_features, + color_mode=self.color_mode, + supported_color_modes=self.supported_color_modes, + off_with_transition=self._off_with_transition, + off_brightness=self._off_brightness, + ).model_dump() + def start_polling(self) -> None: """Start polling.""" self._refresh_task = self.device.gateway.async_create_background_task( @@ -960,6 +1018,7 @@ def restore_external_state_attributes( xy_color: tuple[float, float] | None, color_mode: ColorMode | None, effect: str | None, + **kwargs, ) -> None: """Restore extra state attributes that are stored outside of the ZCL cache.""" if state is not None: @@ -978,6 +1037,7 @@ def restore_external_state_attributes( self._color_mode = color_mode if effect is not None: self._effect = effect + self.maybe_emit_state_changed_event() @STRICT_MATCH( @@ -1083,21 +1143,38 @@ def __init__(self, group: Group): function=self._force_member_updates, ) - if hasattr(self, "info_object"): - delattr(self, "info_object") self.update() - @functools.cached_property + @property def info_object(self) -> LightEntityInfo: """Return a representation of the select.""" return LightEntityInfo( - **super().info_object.__dict__, + **super().info_object.model_dump(exclude=[MODEL_CLASS_NAME]), effect_list=self.effect_list, supported_features=self.supported_features, min_mireds=self.min_mireds, max_mireds=self.max_mireds, + supported_color_modes=self.supported_color_modes, ) + @property + def state(self) -> dict[str, Any]: + """Return the state of the light.""" + return LightState( + **super().state, + on=self.is_on, + brightness=self.brightness, + xy_color=self.xy_color, + color_temp=self.color_temp, + effect_list=self.effect_list, + effect=self.effect, + supported_features=self.supported_features, + color_mode=self.color_mode, + supported_color_modes=self.supported_color_modes, + off_with_transition=self._off_with_transition, + off_brightness=self._off_brightness, + ).model_dump() + async def on_remove(self) -> None: """Cancel tasks this entity owns.""" await super().on_remove() @@ -1263,6 +1340,7 @@ def restore_external_state_attributes( xy_color: tuple[float, float] | None, color_mode: ColorMode | None, effect: str | None, + **kwargs: Any, ) -> None: """Restore extra state attributes.""" # Group state is calculated from the members, @@ -1271,3 +1349,108 @@ def restore_external_state_attributes( self._off_with_transition = off_with_transition if off_brightness is not None: self._off_brightness = off_brightness + self.maybe_emit_state_changed_event() + + +class WebSocketClientLightEntity( + WebSocketClientEntity[LightEntityInfo], LightEntityInterface +): + """Light entity that sends commands to a websocket client.""" + + PLATFORM: Platform = Platform.LIGHT + + def __init__( + self, entity_info: LightEntityInfo, device: WebSocketClientDevice + ) -> None: + """Initialize the ZHA lock entity.""" + super().__init__(entity_info, device) + + @property + def xy_color(self) -> tuple[float, float] | None: + """Return the xy color value [float, float].""" + return self.info_object.state.xy_color + + @property + def color_temp(self) -> int | None: + """Return the CT color value in mireds.""" + return self.info_object.state.color_temp + + @property + def color_mode(self) -> ColorMode | None: + """Return the color mode.""" + return self.info_object.state.color_mode + + @property + def effect_list(self) -> list[str] | None: + """Return the list of supported effects.""" + return self.info_object.effect_list + + @property + def effect(self) -> str: + """Return the current effect.""" + return self.info_object.state.effect + + @property + def supported_features(self) -> LightEntityFeature: + """Flag supported features.""" + return self.info_object.supported_features + + @property + def supported_color_modes(self) -> set[ColorMode]: + """Flag supported color modes.""" + return self.info_object.supported_color_modes + + @property + def is_on(self) -> bool: + """Return true if entity is on.""" + return self.info_object.state.on + + @property + def brightness(self) -> int | None: + """Return the brightness of this light.""" + return self.info_object.state.brightness + + @property + def min_mireds(self) -> int | None: + """Return the coldest color_temp that this light supports.""" + return self.info_object.min_mireds + + @property + def max_mireds(self) -> int | None: + """Return the warmest color_temp that this light supports.""" + return self.info_object.max_mireds + + async def async_turn_on(self, **kwargs: Any) -> None: + """Turn the entity on.""" + await self._device.gateway.lights.turn_on(self.info_object, **kwargs) + + async def async_turn_off(self, **kwargs: Any) -> None: + """Turn the entity off.""" + await self._device.gateway.lights.turn_off(self.info_object, **kwargs) + + def restore_external_state_attributes( + self, + *, + state: bool | None, + off_with_transition: bool | None, + off_brightness: int | None, + brightness: int | None, + color_temp: int | None, + xy_color: tuple[float, float] | None, + color_mode: ColorMode | None, + effect: str | None, + ) -> None: + """Restore extra state attributes that are stored outside of the ZCL cache.""" + self._device.gateway.create_and_track_task( + self._device.gateway.lights.restore_external_state_attributes( + self.info_object, + state=state, + off_with_transition=off_with_transition, + off_brightness=off_brightness, + brightness=brightness, + color_temp=color_temp, + xy_color=xy_color, + color_mode=color_mode, + effect=effect, + ) + ) diff --git a/zha/application/platforms/light/model.py b/zha/application/platforms/light/model.py new file mode 100644 index 000000000..7bb8dc67e --- /dev/null +++ b/zha/application/platforms/light/model.py @@ -0,0 +1,32 @@ +"""Models for the light platform.""" + +from __future__ import annotations + +from zha.application.platforms.light.const import ColorMode, LightEntityFeature +from zha.application.platforms.model import BasePlatformEntityInfo +from zha.model import TypedBaseModel + + +class LightState(TypedBaseModel): + """Light state model.""" + + on: bool + brightness: int | None = None + xy_color: tuple[float, float] | None = None + color_temp: int | None = None + effect: str + off_brightness: int | None = None + color_mode: ColorMode | None = None + off_with_transition: bool = False + available: bool + + +class LightEntityInfo(BasePlatformEntityInfo): + """Light model.""" + + supported_features: LightEntityFeature + min_mireds: int + max_mireds: int + effect_list: list[str] | None = None + supported_color_modes: set[ColorMode] + state: LightState diff --git a/zha/application/platforms/light/websocket_api.py b/zha/application/platforms/light/websocket_api.py new file mode 100644 index 000000000..e2f3f9213 --- /dev/null +++ b/zha/application/platforms/light/websocket_api.py @@ -0,0 +1,121 @@ +"""WS API for the light platform entity.""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Annotated, Literal, Union + +from pydantic import Field, ValidationInfo, field_validator + +from zha.application.discovery import Platform +from zha.application.platforms.light.const import ColorMode +from zha.application.platforms.websocket_api import ( + PlatformEntityCommand, + execute_platform_entity_command, +) +from zha.websocket.const import APICommands +from zha.websocket.server.api import decorators, register_api_command + +if TYPE_CHECKING: + from zha.application.gateway import WebSocketServerGateway + from zha.websocket.server.client import Client + +_LOGGER = logging.getLogger(__name__) + + +class LightTurnOnCommand(PlatformEntityCommand): + """Light turn on command.""" + + command: Literal[APICommands.LIGHT_TURN_ON] = APICommands.LIGHT_TURN_ON + platform: str = Platform.LIGHT + brightness: Union[Annotated[int, Field(ge=0, le=255)], None] = None + transition: Union[Annotated[float, Field(ge=0, le=6553)], None] = None + flash: Union[Literal["short", "long"], None] = None + effect: Union[str, None] = None + hs_color: Union[ + None, + ( + tuple[ + Annotated[int, Field(ge=0, le=360)], Annotated[int, Field(ge=0, le=100)] + ] + ), + ] = None + color_temp: Union[int, None] = None + + @field_validator("color_temp", mode="before", check_fields=False) + @classmethod + def check_color_setting_exclusivity( + cls, color_temp: int | None, validation_info: ValidationInfo + ) -> int | None: + """Ensure only one color mode is set.""" + if ( + "hs_color" in validation_info.data + and validation_info.data["hs_color"] is not None + and color_temp is not None + ): + raise ValueError('Only one of "hs_color" and "color_temp" can be set') + return color_temp + + +@decorators.websocket_command(LightTurnOnCommand) +@decorators.async_response +async def turn_on( + gateway: WebSocketServerGateway, client: Client, command: LightTurnOnCommand +) -> None: + """Turn on the light.""" + await execute_platform_entity_command(gateway, client, command, "async_turn_on") + + +class LightTurnOffCommand(PlatformEntityCommand): + """Light turn off command.""" + + command: Literal[APICommands.LIGHT_TURN_OFF] = APICommands.LIGHT_TURN_OFF + platform: str = Platform.LIGHT + transition: Union[Annotated[float, Field(ge=0, le=6553)], None] = None + flash: Union[Literal["short", "long"], None] = None + + +@decorators.websocket_command(LightTurnOffCommand) +@decorators.async_response +async def turn_off( + gateway: WebSocketServerGateway, client: Client, command: LightTurnOffCommand +) -> None: + """Turn on the light.""" + await execute_platform_entity_command(gateway, client, command, "async_turn_off") + + +class LightRestoreExternalStateAttributesCommand(PlatformEntityCommand): + """Light restore external state attributes command.""" + + command: Literal[APICommands.LIGHT_RESTORE_EXTERNAL_STATE_ATTRIBUTES] = ( + APICommands.LIGHT_RESTORE_EXTERNAL_STATE_ATTRIBUTES + ) + platform: str = Platform.LIGHT + state: bool | None = None + off_with_transition: bool | None = None + off_brightness: int | None = None + brightness: int | None = None + color_temp: int | None = None + xy_color: tuple[float, float] | None = None + color_mode: ColorMode | None = None + effect: str | None = None + + +@decorators.websocket_command(LightRestoreExternalStateAttributesCommand) +@decorators.async_response +async def restore_light_external_state_attributes( + gateway: WebSocketServerGateway, + client: Client, + command: LightRestoreExternalStateAttributesCommand, +) -> None: + """Restore external state attributes for lights.""" + await execute_platform_entity_command( + gateway, client, command, "restore_external_state_attributes" + ) + + +def load_api(gateway: WebSocketServerGateway) -> None: + """Load the api command handlers.""" + register_api_command(gateway, turn_on) + register_api_command(gateway, turn_off) + register_api_command(gateway, restore_light_external_state_attributes) diff --git a/zha/application/platforms/lock/__init__.py b/zha/application/platforms/lock/__init__.py index 7bcff82cb..2bb75f480 100644 --- a/zha/application/platforms/lock/__init__.py +++ b/zha/application/platforms/lock/__init__.py @@ -2,6 +2,7 @@ from __future__ import annotations +from abc import ABC, abstractmethod import functools from typing import TYPE_CHECKING, Any, Literal @@ -9,29 +10,64 @@ from zigpy.zcl.foundation import Status from zha.application import Platform -from zha.application.platforms import PlatformEntity +from zha.application.platforms import PlatformEntity, WebSocketClientEntity from zha.application.platforms.lock.const import ( STATE_LOCKED, STATE_UNLOCKED, VALUE_TO_STATE, ) +from zha.application.platforms.lock.model import LockEntityInfo, LockState from zha.application.registries import PLATFORM_ENTITIES -from zha.zigbee.cluster_handlers import ClusterAttributeUpdatedEvent +from zha.const import MODEL_CLASS_NAME from zha.zigbee.cluster_handlers.const import ( CLUSTER_HANDLER_ATTRIBUTE_UPDATED, CLUSTER_HANDLER_DOORLOCK, ) if TYPE_CHECKING: - from zha.zigbee.cluster_handlers import ClusterHandler - from zha.zigbee.device import Device + from zha.zigbee.cluster_handlers import ClusterAttributeUpdatedEvent, ClusterHandler + from zha.zigbee.device import Device, WebSocketClientDevice from zha.zigbee.endpoint import Endpoint MULTI_MATCH = functools.partial(PLATFORM_ENTITIES.multipass_match, Platform.LOCK) +class LockEntityInterface(ABC): + """Lock interface.""" + + @property + @abstractmethod + def is_locked(self) -> bool: + """Return true if the lock is locked.""" + + async def async_lock(self) -> None: + """Lock the lock.""" + + async def async_unlock(self) -> None: + """Unlock the lock.""" + + async def async_set_lock_user_code(self, code_slot: int, user_code: str) -> None: + """Set the user_code to index X on the lock.""" + + async def async_enable_lock_user_code(self, code_slot: int) -> None: + """Enable user_code at index X on the lock.""" + + async def async_disable_lock_user_code(self, code_slot: int) -> None: + """Disable user_code at index X on the lock.""" + + async def async_clear_lock_user_code(self, code_slot: int) -> None: + """Clear the user_code at index X on the lock.""" + + def restore_external_state_attributes( + self, + *, + state: Literal["locked", "unlocked"] | None, + ) -> None: + """Restore extra state attributes that are stored outside of the ZCL cache.""" + + @MULTI_MATCH(cluster_handler_names=CLUSTER_HANDLER_DOORLOCK) -class DoorLock(PlatformEntity): +class DoorLock(PlatformEntity, LockEntityInterface): """Representation of a ZHA lock.""" PLATFORM = Platform.LOCK @@ -58,12 +94,20 @@ def __init__( self.handle_cluster_handler_attribute_updated, ) + @property + def info_object(self) -> LockEntityInfo: + """Return a representation of the lock.""" + return LockEntityInfo( + **super().info_object.model_dump(exclude=[MODEL_CLASS_NAME]) + ) + @property def state(self) -> dict[str, Any]: """Get the state of the lock.""" - response = super().state - response["is_locked"] = self.is_locked - return response + return LockState( + **super().state, + is_locked=self.is_locked, + ).model_dump() @property def is_locked(self) -> bool: @@ -92,7 +136,9 @@ async def async_unlock(self) -> None: self._state = STATE_UNLOCKED self.maybe_emit_state_changed_event() - async def async_set_lock_user_code(self, code_slot: int, user_code: str) -> None: + async def async_set_lock_user_code( + self, code_slot: int, user_code: str, **kwargs + ) -> None: """Set the user_code to index X on the lock.""" if self._doorlock_cluster_handler: await self._doorlock_cluster_handler.async_set_user_code( @@ -100,19 +146,19 @@ async def async_set_lock_user_code(self, code_slot: int, user_code: str) -> None ) self.debug("User code at slot %s set", code_slot) - async def async_enable_lock_user_code(self, code_slot: int) -> None: + async def async_enable_lock_user_code(self, code_slot: int, **kwargs) -> None: """Enable user_code at index X on the lock.""" if self._doorlock_cluster_handler: await self._doorlock_cluster_handler.async_enable_user_code(code_slot) self.debug("User code at slot %s enabled", code_slot) - async def async_disable_lock_user_code(self, code_slot: int) -> None: + async def async_disable_lock_user_code(self, code_slot: int, **kwargs) -> None: """Disable user_code at index X on the lock.""" if self._doorlock_cluster_handler: await self._doorlock_cluster_handler.async_disable_user_code(code_slot) self.debug("User code at slot %s disabled", code_slot) - async def async_clear_lock_user_code(self, code_slot: int) -> None: + async def async_clear_lock_user_code(self, code_slot: int, **kwargs) -> None: """Clear the user_code at index X on the lock.""" if self._doorlock_cluster_handler: await self._doorlock_cluster_handler.async_clear_user_code(code_slot) @@ -127,10 +173,73 @@ def handle_cluster_handler_attribute_updated( self._state = VALUE_TO_STATE.get(event.attribute_value, self._state) self.maybe_emit_state_changed_event() + def restore_external_state_attributes( + self, *, state: Literal["locked", "unlocked"] | None, **kwargs + ) -> None: + """Restore extra state attributes that are stored outside of the ZCL cache.""" + self._state = state + self.maybe_emit_state_changed_event() + + +class WebSocketClientLockEntity( + WebSocketClientEntity[LockEntityInfo], LockEntityInterface +): + """Representation of a ZHA lock on the client side.""" + + PLATFORM: Platform = Platform.LOCK + + def __init__( + self, entity_info: LockEntityInfo, device: WebSocketClientDevice + ) -> None: + """Initialize the ZHA lock entity.""" + super().__init__(entity_info, device) + + @property + def is_locked(self) -> bool: + """Return true if the lock is locked.""" + return self.info_object.state.is_locked + + async def async_lock(self) -> None: + """Lock the lock.""" + await self._device.gateway.locks.lock(self.info_object) + + async def async_unlock(self) -> None: + """Unlock the lock.""" + await self._device.gateway.locks.unlock(self.info_object) + + async def async_set_lock_user_code(self, code_slot: int, user_code: str) -> None: + """Set the user_code to index X on the lock.""" + await self._device.gateway.locks.set_user_lock_code( + self.info_object, code_slot, user_code + ) + + async def async_enable_lock_user_code(self, code_slot: int) -> None: + """Enable user_code at index X on the lock.""" + await self._device.gateway.locks.enable_user_lock_code( + self.info_object, code_slot + ) + + async def async_disable_lock_user_code(self, code_slot: int) -> None: + """Disable user_code at index X on the lock.""" + await self._device.gateway.locks.disable_user_lock_code( + self.info_object, code_slot + ) + + async def async_clear_lock_user_code(self, code_slot: int) -> None: + """Clear the user_code at index X on the lock.""" + await self._device.gateway.locks.clear_user_lock_code( + self.info_object, code_slot + ) + def restore_external_state_attributes( self, *, state: Literal["locked", "unlocked"] | None, ) -> None: """Restore extra state attributes that are stored outside of the ZCL cache.""" - self._state = state + self._device.gateway.create_and_track_task( + self._device.gateway.locks.restore_external_state_attributes( + self.info_object, + state=state, + ) + ) diff --git a/zha/application/platforms/lock/model.py b/zha/application/platforms/lock/model.py new file mode 100644 index 000000000..f0d7b9c30 --- /dev/null +++ b/zha/application/platforms/lock/model.py @@ -0,0 +1,19 @@ +"""Models for the lock platform.""" + +from __future__ import annotations + +from zha.application.platforms.model import BasePlatformEntityInfo +from zha.model import TypedBaseModel + + +class LockState(TypedBaseModel): + """Lock state model.""" + + is_locked: bool + available: bool + + +class LockEntityInfo(BasePlatformEntityInfo): + """Lock entity model.""" + + state: LockState diff --git a/zha/application/platforms/lock/websocket_api.py b/zha/application/platforms/lock/websocket_api.py new file mode 100644 index 000000000..84fcb51c4 --- /dev/null +++ b/zha/application/platforms/lock/websocket_api.py @@ -0,0 +1,172 @@ +"""WS api for the lock platform entity.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal + +from zha.application.discovery import Platform +from zha.application.platforms.websocket_api import ( + PlatformEntityCommand, + execute_platform_entity_command, +) +from zha.websocket.const import APICommands +from zha.websocket.server.api import decorators, register_api_command + +if TYPE_CHECKING: + from zha.application.gateway import WebSocketServerGateway + from zha.websocket.server.client import Client + + +class LockLockCommand(PlatformEntityCommand): + """Lock lock command.""" + + command: Literal[APICommands.LOCK_LOCK] = APICommands.LOCK_LOCK + platform: str = Platform.LOCK + + +@decorators.websocket_command(LockLockCommand) +@decorators.async_response +async def lock( + gateway: WebSocketServerGateway, client: Client, command: LockLockCommand +) -> None: + """Lock the lock.""" + await execute_platform_entity_command(gateway, client, command, "async_lock") + + +class LockUnlockCommand(PlatformEntityCommand): + """Lock unlock command.""" + + command: Literal[APICommands.LOCK_UNLOCK] = APICommands.LOCK_UNLOCK + platform: str = Platform.LOCK + + +@decorators.websocket_command(LockUnlockCommand) +@decorators.async_response +async def unlock( + gateway: WebSocketServerGateway, client: Client, command: LockUnlockCommand +) -> None: + """Unlock the lock.""" + await execute_platform_entity_command(gateway, client, command, "async_unlock") + + +class LockSetUserLockCodeCommand(PlatformEntityCommand): + """Set user lock code command.""" + + command: Literal[APICommands.LOCK_SET_USER_CODE] = APICommands.LOCK_SET_USER_CODE + platform: str = Platform.LOCK + code_slot: int + user_code: str + + +@decorators.websocket_command(LockSetUserLockCodeCommand) +@decorators.async_response +async def set_user_lock_code( + gateway: WebSocketServerGateway, client: Client, command: LockSetUserLockCodeCommand +) -> None: + """Set a user lock code in the specified slot for the lock.""" + await execute_platform_entity_command( + gateway, client, command, "async_set_lock_user_code" + ) + + +class LockEnableUserLockCodeCommand(PlatformEntityCommand): + """Enable user lock code command.""" + + command: Literal[APICommands.LOCK_ENAABLE_USER_CODE] = ( + APICommands.LOCK_ENAABLE_USER_CODE + ) + platform: str = Platform.LOCK + code_slot: int + + +@decorators.websocket_command(LockEnableUserLockCodeCommand) +@decorators.async_response +async def enable_user_lock_code( + gateway: WebSocketServerGateway, + client: Client, + command: LockEnableUserLockCodeCommand, +) -> None: + """Enable a user lock code for the lock.""" + await execute_platform_entity_command( + gateway, client, command, "async_enable_lock_user_code" + ) + + +class LockDisableUserLockCodeCommand(PlatformEntityCommand): + """Disable user lock code command.""" + + command: Literal[APICommands.LOCK_DISABLE_USER_CODE] = ( + APICommands.LOCK_DISABLE_USER_CODE + ) + platform: str = Platform.LOCK + code_slot: int + + +@decorators.websocket_command(LockDisableUserLockCodeCommand) +@decorators.async_response +async def disable_user_lock_code( + gateway: WebSocketServerGateway, + client: Client, + command: LockDisableUserLockCodeCommand, +) -> None: + """Disable a user lock code for the lock.""" + await execute_platform_entity_command( + gateway, client, command, "async_disable_lock_user_code" + ) + + +class LockClearUserLockCodeCommand(PlatformEntityCommand): + """Clear user lock code command.""" + + command: Literal[APICommands.LOCK_CLEAR_USER_CODE] = ( + APICommands.LOCK_CLEAR_USER_CODE + ) + platform: str = Platform.LOCK + code_slot: int + + +@decorators.websocket_command(LockClearUserLockCodeCommand) +@decorators.async_response +async def clear_user_lock_code( + gateway: WebSocketServerGateway, + client: Client, + command: LockClearUserLockCodeCommand, +) -> None: + """Clear a user lock code for the lock.""" + await execute_platform_entity_command( + gateway, client, command, "async_clear_lock_user_code" + ) + + +class LockRestoreExternalStateAttributesCommand(PlatformEntityCommand): + """Restore external state attributes command.""" + + command: Literal[APICommands.LOCK_RESTORE_EXTERNAL_STATE_ATTRIBUTES] = ( + APICommands.LOCK_RESTORE_EXTERNAL_STATE_ATTRIBUTES + ) + platform: str = Platform.LOCK + state: Literal["locked", "unlocked", "unavailable"] | None + + +@decorators.websocket_command(LockRestoreExternalStateAttributesCommand) +@decorators.async_response +async def restore_lock_external_state_attributes( + gateway: WebSocketServerGateway, + client: Client, + command: LockRestoreExternalStateAttributesCommand, +) -> None: + """Restore externally preserved state for locks.""" + await execute_platform_entity_command( + gateway, client, command, "restore_external_state_attributes" + ) + + +def load_api(gateway: WebSocketServerGateway) -> None: + """Load the api command handlers.""" + register_api_command(gateway, lock) + register_api_command(gateway, unlock) + register_api_command(gateway, set_user_lock_code) + register_api_command(gateway, enable_user_lock_code) + register_api_command(gateway, disable_user_lock_code) + register_api_command(gateway, clear_user_lock_code) + register_api_command(gateway, restore_lock_external_state_attributes) diff --git a/zha/application/platforms/model.py b/zha/application/platforms/model.py new file mode 100644 index 000000000..b908283a5 --- /dev/null +++ b/zha/application/platforms/model.py @@ -0,0 +1,72 @@ +"""Models for the ZHA platforms module.""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any, TypeVar + +from zigpy.types.named import EUI64 + +from zha.application.discovery import Platform +from zha.event import EventBase +from zha.model import BaseModel, TypedBaseModel +from zha.zigbee.cluster_handlers.model import ClusterHandlerInfo + + +class BaseEntityInfo(TypedBaseModel): + """Information about a base entity.""" + + platform: Platform + unique_id: str + class_name: str + translation_key: str | None = None + device_class: str | None = None + state_class: str | None = None + entity_category: str | None = None + entity_registry_enabled_default: bool + enabled: bool = True + fallback_name: str | None = None + state: dict[str, Any] + + # For platform entities + cluster_handlers: list[ClusterHandlerInfo] + device_ieee: EUI64 | None = None + endpoint_id: int | None = None + available: bool | None = None + + # For group entities + group_id: int | None = None + + +T = TypeVar("T", bound=BaseEntityInfo) + + +class BaseIdentifiers(BaseModel): + """Identifiers for the base entity.""" + + unique_id: str + platform: Platform + + +class PlatformEntityIdentifiers(BaseIdentifiers): + """Identifiers for the platform entity.""" + + device_ieee: EUI64 + endpoint_id: int + + +class GroupEntityIdentifiers(BaseIdentifiers): + """Identifiers for the group entity.""" + + group_id: int + + +class EntityState(TypedBaseModel): + """Default state model.""" + + available: bool | None = None + state: str | bool | int | float | datetime | None = None + + +class BasePlatformEntityInfo(EventBase, BaseEntityInfo): + """Base platform entity model.""" diff --git a/zha/application/platforms/number/__init__.py b/zha/application/platforms/number/__init__.py index 4bb0468db..86e07b09a 100644 --- a/zha/application/platforms/number/__init__.py +++ b/zha/application/platforms/number/__init__.py @@ -2,10 +2,10 @@ from __future__ import annotations -from dataclasses import dataclass +from abc import ABC, abstractmethod import functools import logging -from typing import TYPE_CHECKING, Any, Self +from typing import TYPE_CHECKING, Any, Self, cast from zhaquirks.quirk_ids import DANFOSS_ALLY_THERMOSTAT from zigpy.quirks.v2 import NumberMetadata @@ -13,17 +13,23 @@ from zha.application import Platform from zha.application.const import ENTITY_METADATA -from zha.application.platforms import BaseEntityInfo, EntityCategory, PlatformEntity +from zha.application.platforms import PlatformEntity, WebSocketClientEntity +from zha.application.platforms.const import EntityCategory from zha.application.platforms.helpers import validate_device_class +from zha.application.platforms.model import EntityState from zha.application.platforms.number.const import ( ICONS, UNITS, NumberDeviceClass, NumberMode, ) +from zha.application.platforms.number.model import ( + NumberConfigurationEntityInfo, + NumberEntityInfo, +) from zha.application.registries import PLATFORM_ENTITIES +from zha.const import MODEL_CLASS_NAME from zha.units import UnitOfMass, UnitOfTemperature, UnitOfTime, validate_unit -from zha.zigbee.cluster_handlers import ClusterAttributeUpdatedEvent from zha.zigbee.cluster_handlers.const import ( CLUSTER_HANDLER_ANALOG_OUTPUT, CLUSTER_HANDLER_ATTRIBUTE_UPDATED, @@ -36,8 +42,8 @@ ) if TYPE_CHECKING: - from zha.zigbee.cluster_handlers import ClusterHandler - from zha.zigbee.device import Device + from zha.zigbee.cluster_handlers import ClusterAttributeUpdatedEvent, ClusterHandler + from zha.zigbee.device import Device, WebSocketClientDevice from zha.zigbee.endpoint import Endpoint _LOGGER = logging.getLogger(__name__) @@ -48,30 +54,56 @@ ) -@dataclass(frozen=True, kw_only=True) -class NumberEntityInfo(BaseEntityInfo): - """Number entity info.""" +class NumberEntityInterface(ABC): + """Number interface.""" + + @property + @abstractmethod + def native_value(self) -> float | None: + """Return the current value.""" + + @property + @abstractmethod + def native_min_value(self) -> float: + """Return the minimum value.""" + + @property + @abstractmethod + def native_max_value(self) -> float: + """Return the maximum value.""" + + @property + @abstractmethod + def native_step(self) -> float | None: + """Return the value step.""" + + @property + @abstractmethod + def native_unit_of_measurement(self) -> str | None: + """Return the unit the value is expressed in.""" - engineering_units: int - application_type: int - min_value: float | None - max_value: float | None - step: float | None + @property + @abstractmethod + def mode(self) -> NumberMode: + """Return the mode of the entity.""" + @property + @abstractmethod + def description(self) -> str | None: + """Return the description of the number entity.""" -@dataclass(frozen=True, kw_only=True) -class NumberConfigurationEntityInfo(BaseEntityInfo): - """Number configuration entity info.""" + @property + @abstractmethod + def icon(self) -> str | None: + """Return the icon to be used for this entity.""" - min_value: float | None - max_value: float | None - step: float | None - multiplier: float | None - device_class: str | None + @abstractmethod + async def async_set_native_value(self, value: float) -> None: + """Update the current value from HA.""" @STRICT_MATCH(cluster_handler_names=CLUSTER_HANDLER_ANALOG_OUTPUT) -class Number(PlatformEntity): +class Number(PlatformEntity, NumberEntityInterface): """Representation of a ZHA Number entity.""" PLATFORM = Platform.NUMBER @@ -96,24 +128,29 @@ def __init__( self.handle_cluster_handler_attribute_updated, ) - @functools.cached_property + @property def info_object(self) -> NumberEntityInfo: """Return a representation of the number entity.""" return NumberEntityInfo( - **super().info_object.__dict__, + **super().info_object.model_dump(exclude=[MODEL_CLASS_NAME]), engineering_units=self._analog_output_cluster_handler.engineering_units, application_type=self._analog_output_cluster_handler.application_type, min_value=self.native_min_value, max_value=self.native_max_value, step=self.native_step, + mode=self.mode, + description=self.description, + icon=self.icon, + unit=self.native_unit_of_measurement, ) @property def state(self) -> dict[str, Any]: """Return the state of the entity.""" - response = super().state - response["state"] = self.native_value - return response + return EntityState( + **super().state, + state=self.native_value, + ).model_dump() @property def native_value(self) -> float | None: @@ -168,7 +205,7 @@ def mode(self) -> NumberMode: """Return the mode of the entity.""" return self._attr_mode - async def async_set_native_value(self, value: float) -> None: + async def async_set_native_value(self, value: float, **kwargs) -> None: """Update the current value from HA.""" await self._analog_output_cluster_handler.async_set_present_value(float(value)) self.maybe_emit_state_changed_event() @@ -270,11 +307,11 @@ def _init_from_quirks_metadata(self, entity_metadata: NumberMetadata) -> None: entity_metadata.unit ).value - @functools.cached_property + @property def info_object(self) -> NumberConfigurationEntityInfo: """Return a representation of the number entity.""" return NumberConfigurationEntityInfo( - **super().info_object.__dict__, + **super().info_object.model_dump(exclude=[MODEL_CLASS_NAME]), min_value=self._attr_native_min_value, max_value=self._attr_native_max_value, step=self._attr_native_step, @@ -284,9 +321,10 @@ def info_object(self) -> NumberConfigurationEntityInfo: @property def state(self) -> dict[str, Any]: """Return the state of the entity.""" - response = super().state - response["state"] = self.native_value - return response + return EntityState( + **super().state, + state=self.native_value, + ).model_dump() @property def native_value(self) -> float | None: @@ -329,7 +367,7 @@ def mode(self) -> NumberMode: """Return the mode of the entity.""" return self._attr_mode - async def async_set_native_value(self, value: float) -> None: + async def async_set_native_value(self, value: float, **kwargs) -> None: """Update the current value from HA.""" await self._cluster_handler.write_attributes_safe( {self._attribute_name: int(value / self._attr_multiplier)} @@ -1064,3 +1102,67 @@ class SinopeLightLEDOffIntensityConfigurationEntity(NumberConfigurationEntity): _attr_native_max_value: float = 100 _attribute_name = "off_led_intensity" _attr_translation_key: str = "off_led_intensity" + + +class WebSocketClientNumberEntity( + WebSocketClientEntity[NumberEntityInfo], NumberEntityInterface +): + """Representation of a WebSocket client number entity.""" + + PLATFORM: Platform = Platform.NUMBER + + def __init__( + self, entity_info: NumberEntityInfo, device: WebSocketClientDevice + ) -> None: + """Initialize the ZHA number entity.""" + super().__init__(entity_info, device) + + @property + def native_value(self) -> float | None: + """Return the current value.""" + return cast( + float, self.info_object.state.state + ) # TODO make a proper state class for number entities + + @property + def native_min_value(self) -> float: + """Return the minimum value.""" + return self.info_object.min_value + + @property + def native_max_value(self) -> float: + """Return the maximum value.""" + return self.info_object.max_value + + @property + def native_step(self) -> float | None: + """Return the value step.""" + return self.info_object.step + + @property + def native_unit_of_measurement(self) -> str | None: + """Return the unit the value is expressed in.""" + return self.info_object.unit + + @property + def mode(self) -> NumberMode: + """Return the mode of the entity.""" + return self.info_object.mode + + @property + def description(self) -> str | None: + """Return the description of the number entity.""" + return self.info_object.description + + @property + def icon(self) -> str | None: + """Return the icon of the number entity.""" + return self.info_object.icon + + async def async_set_value(self, value: float) -> None: + """Update the current value from HA.""" + await self.async_set_native_value(value) + + async def async_set_native_value(self, value: float) -> None: + """Update the current value from HA.""" + await self._device.gateway.numbers.set_value(self.info_object, value) diff --git a/zha/application/platforms/number/model.py b/zha/application/platforms/number/model.py new file mode 100644 index 000000000..78ea52bbc --- /dev/null +++ b/zha/application/platforms/number/model.py @@ -0,0 +1,40 @@ +"""Models for the number platform.""" + +from __future__ import annotations + +from zha.application.platforms.model import BasePlatformEntityInfo, EntityState +from zha.application.platforms.number.const import NumberMode + + +class NumberEntityInfo(BasePlatformEntityInfo): + """Number entity model.""" + + engineering_units: int | None = ( + None # TODO: how should we represent this when it is None? + ) + application_type: int | None = ( + None # TODO: how should we represent this when it is None? + ) + step: float | None = None # TODO: how should we represent this when it is None? + min_value: float + max_value: float + mode: NumberMode = NumberMode.AUTO + unit: str | None = None + description: str | None = None + icon: str | None = None + state: EntityState + + +class NumberConfigurationEntityInfo(BasePlatformEntityInfo): + """Number configuration entity info.""" + + step: float | None + min_value: float | None + max_value: float | None + mode: NumberMode = NumberMode.AUTO + unit: str | None = None + multiplier: float | None + device_class: str | None + description: str | None = None + icon: str | None = None + state: EntityState diff --git a/zha/application/platforms/number/websocket_api.py b/zha/application/platforms/number/websocket_api.py new file mode 100644 index 000000000..753602b57 --- /dev/null +++ b/zha/application/platforms/number/websocket_api.py @@ -0,0 +1,44 @@ +"""WS api for the number platform entity.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal + +from zha.application.discovery import Platform +from zha.application.platforms.websocket_api import ( + PlatformEntityCommand, + execute_platform_entity_command, +) +from zha.websocket.const import APICommands +from zha.websocket.server.api import decorators, register_api_command + +if TYPE_CHECKING: + from zha.application.gateway import WebSocketServerGateway + from zha.websocket.server.client import Client + +ATTR_VALUE = "value" +COMMAND_SET_VALUE = "number_set_value" + + +class NumberSetValueCommand(PlatformEntityCommand): + """Number set value command.""" + + command: Literal[APICommands.NUMBER_SET_VALUE] = APICommands.NUMBER_SET_VALUE + platform: str = Platform.NUMBER + value: float + + +@decorators.websocket_command(NumberSetValueCommand) +@decorators.async_response +async def set_value( + gateway: WebSocketServerGateway, client: Client, command: NumberSetValueCommand +) -> None: + """Select an option.""" + await execute_platform_entity_command( + gateway, client, command, "async_set_native_value" + ) + + +def load_api(gateway: WebSocketServerGateway) -> None: + """Load the api command handlers.""" + register_api_command(gateway, set_value) diff --git a/zha/application/platforms/select.py b/zha/application/platforms/select/__init__.py similarity index 89% rename from zha/application/platforms/select.py rename to zha/application/platforms/select/__init__.py index 101296652..ba6d52586 100644 --- a/zha/application/platforms/select.py +++ b/zha/application/platforms/select/__init__.py @@ -2,7 +2,7 @@ from __future__ import annotations -from dataclasses import dataclass +from abc import ABC, abstractmethod from enum import Enum import functools import logging @@ -23,9 +23,15 @@ from zha.application import Platform from zha.application.const import ENTITY_METADATA, Strobe -from zha.application.platforms import BaseEntityInfo, EntityCategory, PlatformEntity +from zha.application.platforms import PlatformEntity, WebSocketClientEntity +from zha.application.platforms.const import EntityCategory +from zha.application.platforms.model import EntityState +from zha.application.platforms.select.model import ( + EnumSelectEntityInfo, + SelectEntityInfo, +) from zha.application.registries import PLATFORM_ENTITIES -from zha.zigbee.cluster_handlers import ClusterAttributeUpdatedEvent +from zha.const import MODEL_CLASS_NAME from zha.zigbee.cluster_handlers.const import ( CLUSTER_HANDLER_ATTRIBUTE_UPDATED, CLUSTER_HANDLER_HUE_OCCUPANCY, @@ -37,8 +43,8 @@ ) if TYPE_CHECKING: - from zha.zigbee.cluster_handlers import ClusterHandler - from zha.zigbee.device import Device + from zha.zigbee.cluster_handlers import ClusterAttributeUpdatedEvent, ClusterHandler + from zha.zigbee.device import Device, WebSocketClientDevice from zha.zigbee.endpoint import Endpoint @@ -48,15 +54,28 @@ _LOGGER = logging.getLogger(__name__) -@dataclass(frozen=True, kw_only=True) -class EnumSelectInfo(BaseEntityInfo): - """Enum select entity info.""" +class SelectEntityInterface(ABC): + """Select interface for ZHA select entities.""" + + @property + @abstractmethod + def current_option(self) -> str | None: + """Return the selected entity option to represent the entity state.""" + + @abstractmethod + async def async_select_option(self, option: str) -> None: + """Change the selected option.""" - enum: str - options: list[str] + @abstractmethod + def restore_external_state_attributes( + self, + *, + state: str, + ) -> None: + """Restore extra state attributes that are stored outside of the ZCL cache.""" -class EnumSelectEntity(PlatformEntity): +class EnumSelectEntity(PlatformEntity, SelectEntityInterface): """Representation of a ZHA select entity.""" PLATFORM = Platform.SELECT @@ -78,21 +97,22 @@ def __init__( self._attr_options = [entry.name.replace("_", " ") for entry in self._enum] super().__init__(unique_id, cluster_handlers, endpoint, device, **kwargs) - @functools.cached_property - def info_object(self) -> EnumSelectInfo: + @property + def info_object(self) -> EnumSelectEntityInfo: """Return a representation of the select.""" - return EnumSelectInfo( - **super().info_object.__dict__, + return EnumSelectEntityInfo( + **super().info_object.model_dump(exclude=[MODEL_CLASS_NAME]), enum=self._enum.__name__, options=self._attr_options, ) @property - def state(self) -> dict: + def state(self) -> dict[str, Any]: """Return the state of the select.""" - response = super().state - response["state"] = self.current_option - return response + return EntityState( + **super().state, + state=self.current_option, + ).model_dump() @property def current_option(self) -> str | None: @@ -102,21 +122,18 @@ def current_option(self) -> str | None: return None return option.name.replace("_", " ") - async def async_select_option(self, option: str) -> None: + async def async_select_option(self, option: str, **kwargs) -> None: """Change the selected option.""" self._cluster_handler.data_cache[self._attribute_name] = self._enum[ option.replace(" ", "_") ] self.maybe_emit_state_changed_event() - def restore_external_state_attributes( - self, - *, - state: str, - ) -> None: + def restore_external_state_attributes(self, *, state: str, **kwargs) -> None: """Restore extra state attributes that are stored outside of the ZCL cache.""" value = state.replace(" ", "_") self._cluster_handler.data_cache[self._attribute_name] = self._enum[value] + self.maybe_emit_state_changed_event() class NonZCLSelectEntity(EnumSelectEntity): @@ -164,7 +181,7 @@ class DefaultStrobeSelectEntity(NonZCLSelectEntity): _attr_translation_key: str = "default_strobe" -class ZCLEnumSelectEntity(PlatformEntity): +class ZCLEnumSelectEntity(PlatformEntity, SelectEntityInterface): """Representation of a ZHA ZCL enum select entity.""" PLATFORM = Platform.SELECT @@ -223,11 +240,11 @@ def _init_from_quirks_metadata(self, entity_metadata: ZCLEnumMetadata) -> None: self._attribute_name = entity_metadata.attribute_name self._enum = entity_metadata.enum - @functools.cached_property - def info_object(self) -> EnumSelectInfo: + @property + def info_object(self) -> EnumSelectEntityInfo: """Return a representation of the select.""" - return EnumSelectInfo( - **super().info_object.__dict__, + return EnumSelectEntityInfo( + **super().info_object.model_dump(exclude=[MODEL_CLASS_NAME]), enum=self._enum.__name__, options=self._attr_options, ) @@ -235,9 +252,10 @@ def info_object(self) -> EnumSelectInfo: @property def state(self) -> dict[str, Any]: """Return the state of the select.""" - response = super().state - response["state"] = self.current_option - return response + return EntityState( + **super().state, + state=self.current_option, + ).model_dump() @property def current_option(self) -> str | None: @@ -248,7 +266,7 @@ def current_option(self) -> str | None: option = self._enum(option) return option.name.replace("_", " ") - async def async_select_option(self, option: str) -> None: + async def async_select_option(self, option: str, **kwargs) -> None: """Change the selected option.""" await self._cluster_handler.write_attributes_safe( {self._attribute_name: self._enum[option.replace(" ", "_")]} @@ -267,6 +285,7 @@ def restore_external_state_attributes( self, *, state: str, + **kwargs, ) -> None: """Restore extra state attributes.""" # Select entities backed by the ZCL cache don't need to restore their state! @@ -888,3 +907,37 @@ class SinopeLightLEDOnColorSelect(ZCLEnumSelectEntity): _attribute_name = "on_led_color" _attr_translation_key: str = "on_led_color" _enum = SinopeLightLedColors + + +class WebSocketClientSelectEntity( + WebSocketClientEntity[SelectEntityInfo], SelectEntityInterface +): + """Representation of a ZHA select entity controlled via a websocket.""" + + PLATFORM = Platform.SELECT + + def __init__( + self, entity_info: SelectEntityInfo, device: WebSocketClientDevice + ) -> None: + """Initialize the ZHA select entity.""" + super().__init__(entity_info, device) + + @property + def current_option(self) -> str | None: + """Return the selected entity option to represent the entity state.""" + + async def async_select_option(self, option: str) -> None: + """Change the selected option.""" + await self._device.gateway.selects.select_option(self.info_object, option) + + def restore_external_state_attributes( + self, + *, + state: str, + ) -> None: + """Restore extra state attributes.""" + self._device.gateway.create_and_track_task( + self._device.gateway.selects.restore_external_state_attributes( + self.info_object, state + ) + ) diff --git a/zha/application/platforms/select/model.py b/zha/application/platforms/select/model.py new file mode 100644 index 000000000..f1202f77d --- /dev/null +++ b/zha/application/platforms/select/model.py @@ -0,0 +1,21 @@ +"""Models for the select platform.""" + +from __future__ import annotations + +from zha.application.platforms.model import BasePlatformEntityInfo, EntityState + + +class SelectEntityInfo(BasePlatformEntityInfo): + """Select entity model.""" + + enum: str + options: list[str] + state: EntityState + + +class EnumSelectEntityInfo(BasePlatformEntityInfo): + """Enum select entity info.""" + + enum: str + options: list[str] + state: EntityState diff --git a/zha/application/platforms/select/websocket_api.py b/zha/application/platforms/select/websocket_api.py new file mode 100644 index 000000000..cc26671b0 --- /dev/null +++ b/zha/application/platforms/select/websocket_api.py @@ -0,0 +1,67 @@ +"""WS api for the select platform entity.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal + +from zha.application.discovery import Platform +from zha.application.platforms.websocket_api import ( + PlatformEntityCommand, + execute_platform_entity_command, +) +from zha.websocket.const import APICommands +from zha.websocket.server.api import decorators, register_api_command + +if TYPE_CHECKING: + from zha.application.gateway import WebSocketServerGateway + from zha.websocket.server.client import Client + + +class SelectSelectOptionCommand(PlatformEntityCommand): + """Select select option command.""" + + command: Literal[APICommands.SELECT_SELECT_OPTION] = ( + APICommands.SELECT_SELECT_OPTION + ) + platform: str = Platform.SELECT + option: str + + +@decorators.websocket_command(SelectSelectOptionCommand) +@decorators.async_response +async def select_option( + gateway: WebSocketServerGateway, client: Client, command: SelectSelectOptionCommand +) -> None: + """Select an option.""" + await execute_platform_entity_command( + gateway, client, command, "async_select_option" + ) + + +class SelectRestoreExternalStateAttributesCommand(PlatformEntityCommand): + """Select restore external state command.""" + + command: Literal[APICommands.SELECT_RESTORE_EXTERNAL_STATE_ATTRIBUTES] = ( + APICommands.SELECT_RESTORE_EXTERNAL_STATE_ATTRIBUTES + ) + platform: str = Platform.SELECT + state: str + + +@decorators.websocket_command(SelectRestoreExternalStateAttributesCommand) +@decorators.async_response +async def restore_lock_external_state_attributes( + gateway: WebSocketServerGateway, + client: Client, + command: SelectRestoreExternalStateAttributesCommand, +) -> None: + """Restore externally preserved state for selects.""" + await execute_platform_entity_command( + gateway, client, command, "restore_external_state_attributes" + ) + + +def load_api(gateway: WebSocketServerGateway) -> None: + """Load the api command handlers.""" + register_api_command(gateway, select_option) + register_api_command(gateway, restore_lock_external_state_attributes) diff --git a/zha/application/platforms/sensor/__init__.py b/zha/application/platforms/sensor/__init__.py index f5341d98c..ee48478b4 100644 --- a/zha/application/platforms/sensor/__init__.py +++ b/zha/application/platforms/sensor/__init__.py @@ -2,8 +2,8 @@ from __future__ import annotations +from abc import ABC, abstractmethod from asyncio import Task -from dataclasses import dataclass from datetime import UTC, date, datetime import enum import functools @@ -21,21 +21,35 @@ from zha.application import Platform from zha.application.const import ENTITY_METADATA -from zha.application.platforms import ( - BaseEntity, - BaseEntityInfo, - BaseIdentifiers, - EntityCategory, - PlatformEntity, -) +from zha.application.platforms import BaseEntity, PlatformEntity, WebSocketClientEntity from zha.application.platforms.climate.const import HVACAction +from zha.application.platforms.const import EntityCategory from zha.application.platforms.helpers import validate_device_class +from zha.application.platforms.model import EntityState from zha.application.platforms.sensor.const import ( UNIX_EPOCH_TO_ZCL_EPOCH, SensorDeviceClass, SensorStateClass, ) +from zha.application.platforms.sensor.model import ( + BaseSensorEntityInfo, + BatteryEntityInfo, + BatteryState, + DeviceCounterSensorEntityInfo, + DeviceCounterSensorIdentifiers, + DeviceCounterSensorState, + ElectricalMeasurementEntityInfo, + ElectricalMeasurementState, + SensorEntityInfo, + SetpointChangeSourceTimestampSensorEntityInfo, + SmartEnergyMeteringEntityDescription, + SmartEnergyMeteringEntityInfo, + SmartEnergyMeteringState, + SmartEnergySummationEntityDescription, + TimestampState, +) from zha.application.registries import PLATFORM_ENTITIES +from zha.const import MODEL_CLASS_NAME from zha.decorators import periodic from zha.units import ( CONCENTRATION_MICROGRAMS_PER_CUBIC_METER, @@ -58,7 +72,6 @@ UnitOfVolumeFlowRate, validate_unit, ) -from zha.zigbee.cluster_handlers import ClusterAttributeUpdatedEvent from zha.zigbee.cluster_handlers.const import ( CLUSTER_HANDLER_ANALOG_INPUT, CLUSTER_HANDLER_ATTRIBUTE_UPDATED, @@ -82,8 +95,8 @@ ) if TYPE_CHECKING: - from zha.zigbee.cluster_handlers import ClusterHandler - from zha.zigbee.device import Device + from zha.zigbee.cluster_handlers import ClusterAttributeUpdatedEvent, ClusterHandler + from zha.zigbee.device import Device, WebSocketClientDevice from zha.zigbee.endpoint import Endpoint BATTERY_SIZES = { @@ -114,36 +127,13 @@ ) -@dataclass(frozen=True, kw_only=True) -class SensorEntityInfo(BaseEntityInfo): - """Sensor entity info.""" - - attribute: str - decimals: int - divisor: int - multiplier: int - unit: str | None = None - device_class: SensorDeviceClass | None = None - state_class: SensorStateClass | None = None - - -@dataclass(frozen=True, kw_only=True) -class DeviceCounterEntityInfo(BaseEntityInfo): - """Device counter entity info.""" - - device_ieee: str - available: bool - counter: str - counter_value: int - counter_groups: str - counter_group: str - - -@dataclass(frozen=True, kw_only=True) -class DeviceCounterSensorIdentifiers(BaseIdentifiers): - """Device counter sensor identifiers.""" +class SensorEntityInterface(ABC): + """Sensor interface.""" - device_ieee: str + @property + @abstractmethod + def native_value(self) -> date | datetime | str | int | float | None: + """Return the state of the entity.""" class Sensor(PlatformEntity): @@ -221,11 +211,11 @@ def _init_from_quirks_metadata(self, entity_metadata: ZCLSensorMetadata) -> None entity_metadata.unit ).value - @functools.cached_property + @property def info_object(self) -> SensorEntityInfo: """Return a representation of the sensor.""" return SensorEntityInfo( - **super().info_object.__dict__, + **super().info_object.model_dump(exclude=[MODEL_CLASS_NAME]), attribute=self._attribute_name, decimals=self._decimals, divisor=self._divisor, @@ -235,15 +225,20 @@ def info_object(self) -> SensorEntityInfo: if getattr(self, "entity_description", None) is not None else self._attr_native_unit_of_measurement ), + extra_state_attribute_names=getattr( + self, "_attr_extra_state_attribute_names", None + ), + entity_desctiption=getattr(self, "entity_description", None), ) @property - def state(self) -> dict: + def state(self) -> dict[str, Any]: """Return the state for this sensor.""" - response = super().state - native_value = self.native_value - response["state"] = native_value - return response + data = EntityState( + **super().state, + state=self.native_value, + ).model_dump() + return data @property def native_value(self) -> date | datetime | str | int | float | None: @@ -259,16 +254,7 @@ def handle_cluster_handler_attribute_updated( event: ClusterAttributeUpdatedEvent, # pylint: disable=unused-argument ) -> None: """Handle attribute updates from the cluster handler.""" - if ( - event.attribute_name == self._attribute_name - or ( - hasattr(self, "_attr_extra_state_attribute_names") - and event.attribute_name - in getattr(self, "_attr_extra_state_attribute_names") - ) - or self._attribute_name is None - ): - self.maybe_emit_state_changed_event() + self.maybe_emit_state_changed_event() def formatter( self, value: int | enum.IntEnum @@ -421,14 +407,20 @@ def __init__( def identifiers(self) -> DeviceCounterSensorIdentifiers: """Return a dict with the information necessary to identify this entity.""" return DeviceCounterSensorIdentifiers( - **super().identifiers.__dict__, device_ieee=str(self._device.ieee) + **super().identifiers.model_dump(), + device_ieee=str(self._device.ieee), ) - @functools.cached_property - def info_object(self) -> DeviceCounterEntityInfo: + @property + def info_object(self) -> DeviceCounterSensorEntityInfo: """Return a representation of the platform entity.""" - return DeviceCounterEntityInfo( - **super().info_object.__dict__, + data = super().info_object.model_dump(exclude=[MODEL_CLASS_NAME]) + data.pop("device_ieee") + data.pop("available") + return DeviceCounterSensorEntityInfo( + **data, + device_ieee=self._device.ieee, + available=self._device.available, counter=self._zigpy_counter.name, counter_value=self._zigpy_counter.value, counter_groups=self._zigpy_counter_groups, @@ -438,9 +430,11 @@ def info_object(self) -> DeviceCounterEntityInfo: @property def state(self) -> dict[str, Any]: """Return the state for this sensor.""" - response = super().state - response["state"] = self._zigpy_counter.value - return response + return DeviceCounterSensorState( + **super().state, + state=self._zigpy_counter.value, + available=self._device.available, + ).model_dump() @property def native_value(self) -> int | None: @@ -574,6 +568,25 @@ def formatter(value: int) -> int | None: # pylint: disable=arguments-differ value = round(value / 2) return value + @property + def info_object(self) -> BatteryEntityInfo: + """Return a representation of the sensor.""" + return BatteryEntityInfo( + **super(Sensor, self).info_object.model_dump(exclude=[MODEL_CLASS_NAME]), + attribute=self._attribute_name, + decimals=self._decimals, + divisor=self._divisor, + multiplier=self._multiplier, + unit=( + getattr(self, "entity_description").native_unit_of_measurement + if getattr(self, "entity_description", None) is not None + else self._attr_native_unit_of_measurement + ), + extra_state_attribute_names=getattr( + self, "_attr_extra_state_attribute_names", None + ), + ) + @property def state(self) -> dict[str, Any]: """Return the state for battery sensors.""" @@ -587,7 +600,8 @@ def state(self) -> dict[str, Any]: battery_voltage = self._cluster_handler.cluster.get("battery_voltage") if battery_voltage is not None: response["battery_voltage"] = round(battery_voltage / 10, 2) - return response + + return BatteryState(**response).model_dump() @MULTI_MATCH( @@ -620,6 +634,27 @@ def __init__( f"{self._attribute_name}_max", } + @property + def info_object(self) -> ElectricalMeasurementEntityInfo: + """Return a representation of the sensor.""" + return ElectricalMeasurementEntityInfo( + **super(Sensor, self).info_object.model_dump(exclude=[MODEL_CLASS_NAME]), + attribute=self._attribute_name, + decimals=self._decimals, + divisor=self._divisor, + multiplier=self._multiplier, + unit=( + getattr(self, "entity_description").native_unit_of_measurement + if getattr(self, "entity_description", None) is not None + else self._attr_native_unit_of_measurement + ), + measurement_type=self._cluster_handler.measurement_type, + extra_state_attribute_names=getattr( + self, "_attr_extra_state_attribute_names", None + ), + entity_desctiption=getattr(self, "entity_description", None), + ) + @property def state(self) -> dict[str, Any]: """Return the state for this sensor.""" @@ -628,13 +663,13 @@ def state(self) -> dict[str, Any]: response["measurement_type"] = self._cluster_handler.measurement_type max_attr_name = f"{self._attribute_name}_max" - if not hasattr(self._cluster_handler.cluster.AttributeDefs, max_attr_name): - return response - - if (max_v := self._cluster_handler.cluster.get(max_attr_name)) is not None: + if ( + hasattr(self._cluster_handler.cluster.AttributeDefs, max_attr_name) + and (max_v := self._cluster_handler.cluster.get(max_attr_name)) is not None + ): response[max_attr_name] = self.formatter(max_v) - return response + return ElectricalMeasurementState(**response).model_dump() def formatter(self, value: int) -> int | float: """Return 'normalized' value.""" @@ -783,17 +818,6 @@ def formatter(self, value: int) -> int | None: return round(pow(10, ((value - 1) / 10000))) -@dataclass(frozen=True, kw_only=True) -class SmartEnergyMeteringEntityDescription: - """Dataclass that describes a Zigbee smart energy metering entity.""" - - key: str = "instantaneous_demand" - state_class: SensorStateClass | None = SensorStateClass.MEASUREMENT - scale: int = 1 - native_unit_of_measurement: str | None = None - device_class: SensorDeviceClass | None = None - - @MULTI_MATCH( cluster_handler_names=CLUSTER_HANDLER_SMARTENERGY_METERING, stop_on_match_group=CLUSTER_HANDLER_SMARTENERGY_METERING, @@ -887,6 +911,26 @@ def __init__( self._attr_device_class = entity_description.device_class self._attr_state_class = entity_description.state_class + @property + def info_object(self) -> SmartEnergyMeteringEntityInfo: + """Return a representation of the sensor.""" + return SmartEnergyMeteringEntityInfo( + **super(Sensor, self).info_object.model_dump(exclude=[MODEL_CLASS_NAME]), + attribute=self._attribute_name, + decimals=self._decimals, + divisor=self._divisor, + multiplier=self._multiplier, + unit=( + getattr(self, "entity_description").native_unit_of_measurement + if getattr(self, "entity_description", None) is not None + else self._attr_native_unit_of_measurement + ), + extra_state_attribute_names=getattr( + self, "_attr_extra_state_attribute_names", None + ), + entity_desctiption=getattr(self, "entity_description", None), + ) + @property def state(self) -> dict[str, Any]: """Return state for this sensor.""" @@ -901,21 +945,31 @@ def state(self) -> dict[str, Any]: else: response["status"] = str(status)[len(status.__class__.__name__) + 1 :] response["zcl_unit_of_measurement"] = self._cluster_handler.unit_of_measurement - return response + return SmartEnergyMeteringState(**response).model_dump() + + @property + def device_class(self) -> str | None: + """Return the device class.""" + return ( + getattr(self, "entity_description").device_class + if getattr(self, "entity_description", None) is not None + else self._attr_device_class + ) + + @property + def state_class(self) -> str | None: + """Return the state class.""" + return ( + getattr(self, "entity_description").state_class + if getattr(self, "entity_description", None) is not None + else self._attr_state_class + ) def formatter(self, value: int) -> int | float: """Pass through cluster handler formatter.""" return self._cluster_handler.demand_formatter(value) -@dataclass(frozen=True, kw_only=True) -class SmartEnergySummationEntityDescription(SmartEnergyMeteringEntityDescription): - """Dataclass that describes a Zigbee smart energy summation entity.""" - - key: str = "summation_delivered" - state_class: SensorStateClass | None = SensorStateClass.TOTAL_INCREASING - - @MULTI_MATCH( cluster_handler_names=CLUSTER_HANDLER_SMARTENERGY_METERING, stop_on_match_group=CLUSTER_HANDLER_SMARTENERGY_METERING, @@ -1310,7 +1364,7 @@ def create_platform_entity( return cls(unique_id, cluster_handlers, endpoint, device, **kwargs) @property - def state(self) -> dict: + def state(self) -> dict[str, Any]: """Return the current HVAC action.""" response = super().state if ( @@ -1320,7 +1374,7 @@ def state(self) -> dict: response["state"] = self._rm_rs_action else: response["state"] = self._pi_demand_action - return response + return EntityState(**response).model_dump() @property def native_value(self) -> str | None: @@ -1461,11 +1515,11 @@ def __init__( self.device.gateway.global_updater.register_update_listener(self.update) @property - def state(self) -> dict: + def state(self) -> dict[str, Any]: """Return the state of the sensor.""" response = super().state response["state"] = getattr(self.device.device, self._unique_id_suffix) - return response + return EntityState(**response).model_dump() @property def native_value(self) -> str | int | float | None: @@ -1677,6 +1731,32 @@ class SetpointChangeSourceTimestamp(TimestampSensor): _attr_entity_category = EntityCategory.DIAGNOSTIC _attr_device_class = SensorDeviceClass.TIMESTAMP + @property + def info_object(self) -> SetpointChangeSourceTimestampSensorEntityInfo: + """Return the info object for this entity.""" + return SetpointChangeSourceTimestampSensorEntityInfo( + **super(Sensor, self).info_object.model_dump(exclude=[MODEL_CLASS_NAME]), + attribute=self._attribute_name, + decimals=self._decimals, + divisor=self._divisor, + multiplier=self._multiplier, + unit=( + getattr(self, "entity_description").native_unit_of_measurement + if getattr(self, "entity_description", None) is not None + else self._attr_native_unit_of_measurement + ), + extra_state_attribute_names=getattr( + self, "_attr_extra_state_attribute_names", None + ), + ) + + @property + def state(self) -> dict[str, Any]: + """Return the state for this sensor.""" + response = super(Sensor, self).state + response["state"] = self.native_value + return TimestampState(**response).model_dump() + @CONFIG_DIAGNOSTIC_MATCH(cluster_handler_names=CLUSTER_HANDLER_COVER) class WindowCoveringTypeSensor(EnumSensor): @@ -1757,7 +1837,7 @@ def state(self) -> dict[str, Any]: response[bit.name] = False else: response[bit.name] = bit in self._bitmap(value) - return response + return EntityState(**response).model_dump() def formatter(self, _value: int) -> str: """Summary of all attributes.""" @@ -1858,3 +1938,46 @@ class DanfossMotorStepCounter(Sensor): _attribute_name = "motor_step_counter" _attr_translation_key: str = "motor_stepcount" _attr_entity_category = EntityCategory.DIAGNOSTIC + + +class WebSocketClientSensorEntity( + WebSocketClientEntity[BaseSensorEntityInfo], SensorEntityInterface +): + """Representation of a ZHA sensor entity.""" + + PLATFORM: Platform = Platform.SENSOR + + def __init__( + self, entity_info: BaseSensorEntityInfo, device: WebSocketClientDevice + ) -> None: + """Initialize the ZHA alarm control device.""" + super().__init__(entity_info, device) + + @property + def device_class(self) -> str | None: + """Return the device class of the sensor.""" + return self.info_object.device_class + + @property + def state_class(self) -> str | None: + """Return the state class of the sensor.""" + return self.info_object.state_class + + @property + def entity_description( + self, + ) -> SmartEnergyMeteringEntityDescription | SmartEnergySummationEntityDescription: + """Return the entity description for this entity.""" + return self.info_object.entity_description + + @property + def extra_state_attribute_names(self) -> set[str] | None: + """Return the extra state attribute names.""" + if hasattr(self.info_object, "extra_state_attribute_names"): + return self.info_object.extra_state_attribute_names + return None + + @property + def native_value(self) -> date | datetime | str | int | float | None: + """Return the state of the entity.""" + return self.info_object.state.state diff --git a/zha/application/platforms/sensor/model.py b/zha/application/platforms/sensor/model.py new file mode 100644 index 000000000..aec418cdc --- /dev/null +++ b/zha/application/platforms/sensor/model.py @@ -0,0 +1,166 @@ +"""Models for the sensor platform.""" + +from __future__ import annotations + +from datetime import datetime + +from pydantic import ValidationInfo, field_validator +from zigpy.types.named import EUI64 + +from zha.application.platforms.model import ( + BaseEntityInfo, + BaseIdentifiers, + BasePlatformEntityInfo, + EntityState, +) +from zha.application.platforms.sensor.const import SensorDeviceClass, SensorStateClass +from zha.model import BaseEventedModel, BaseModel, TypedBaseModel + + +class BatteryState(TypedBaseModel): + """Battery state model.""" + + state: str | float | int | None = None + battery_size: str | None = None + battery_quantity: int | None = None + battery_voltage: float | None = None + available: bool + + +class ElectricalMeasurementState(TypedBaseModel): + """Electrical measurement state model.""" + + state: str | float | int | None = None + measurement_type: str | None = None + active_power_max: float | None = None + rms_current_max: float | None = None + rms_voltage_max: float | None = None + available: bool + + +class SmartEnergyMeteringState(TypedBaseModel): + """Smare energy metering state model.""" + + state: str | float | int | None = None + device_type: str | None = None + status: str | None = None + available: bool + + +class DeviceCounterSensorState(TypedBaseModel): + """Device counter sensor state model.""" + + state: int + available: bool + + +class SmartEnergyMeteringEntityDescription(BaseModel): + """Model that describes a Zigbee smart energy metering entity.""" + + key: str = "instantaneous_demand" + state_class: SensorStateClass | None = SensorStateClass.MEASUREMENT + scale: int = 1 + native_unit_of_measurement: str | None = None + device_class: SensorDeviceClass | None = None + + +class SmartEnergySummationEntityDescription(SmartEnergyMeteringEntityDescription): + """Model that describes a Zigbee smart energy summation entity.""" + + key: str = "summation_delivered" + state_class: SensorStateClass | None = SensorStateClass.TOTAL_INCREASING + + +class BaseSensorEntityInfo(BasePlatformEntityInfo): + """Sensor model.""" + + attribute: str | None = None + decimals: int + divisor: int + multiplier: int | float + unit: int | str | None = None + device_class: SensorDeviceClass | None = None + state_class: SensorStateClass | None = None + extra_state_attribute_names: set[str] | None = None + + +class SensorEntityInfo(BaseSensorEntityInfo): + """Sensor entity model.""" + + state: EntityState + + +class TimestampState(TypedBaseModel): + """Default state model.""" + + available: bool | None = None + state: datetime | None = None + + +class SetpointChangeSourceTimestampSensorEntityInfo(BaseSensorEntityInfo): + """Setpoint change source timestamp sensor model.""" + + state: TimestampState + + +class DeviceCounterSensorEntityInfo(BaseEventedModel, BaseEntityInfo): + """Device counter sensor model.""" + + counter: str + counter_value: int + counter_groups: str + counter_group: str + state: DeviceCounterSensorState + + @field_validator("state", mode="before", check_fields=False) + @classmethod + def convert_state( + cls, state: dict | int | None, validation_info: ValidationInfo + ) -> DeviceCounterSensorState: + """Convert counter value to counter_value.""" + if state is not None: + if isinstance(state, int): + return DeviceCounterSensorState(state=state) + if isinstance(state, dict): + if "state" in state: + return DeviceCounterSensorState( + state=state["state"], available=state["available"] + ) + else: + return DeviceCounterSensorState( + state=validation_info.data["counter_value"], + available=state["available"], + ) + return DeviceCounterSensorState( + state=validation_info.data["counter_value"], + available=validation_info.data["available"], + ) + + +class BatteryEntityInfo(BaseSensorEntityInfo): + """Battery entity model.""" + + state: BatteryState + + +class ElectricalMeasurementEntityInfo(BaseSensorEntityInfo): + """Electrical measurement entity model.""" + + state: ElectricalMeasurementState + + +class SmartEnergyMeteringEntityInfo(BaseSensorEntityInfo): + """Smare energy metering entity model.""" + + state: SmartEnergyMeteringState + entity_description: ( + SmartEnergySummationEntityDescription + | SmartEnergyMeteringEntityDescription + | None + ) = None + + +class DeviceCounterSensorIdentifiers(BaseIdentifiers): + """Device counter sensor identifiers.""" + + device_ieee: EUI64 diff --git a/zha/application/platforms/siren.py b/zha/application/platforms/siren/__init__.py similarity index 71% rename from zha/application/platforms/siren.py rename to zha/application/platforms/siren/__init__.py index b5ab76b17..6c511cc1f 100644 --- a/zha/application/platforms/siren.py +++ b/zha/application/platforms/siren/__init__.py @@ -2,12 +2,11 @@ from __future__ import annotations +from abc import ABC, abstractmethod import asyncio import contextlib -from dataclasses import dataclass -from enum import IntFlag import functools -from typing import TYPE_CHECKING, Any, Final, cast +from typing import TYPE_CHECKING, Any, cast from zigpy.zcl.clusters.security import IasWd as WD @@ -25,45 +24,53 @@ WARNING_DEVICE_STROBE_NO, Strobe, ) -from zha.application.platforms import BaseEntityInfo, PlatformEntity +from zha.application.platforms import PlatformEntity, WebSocketClientEntity +from zha.application.platforms.model import EntityState +from zha.application.platforms.siren.const import ( + ATTR_DURATION, + ATTR_TONE, + ATTR_VOLUME_LEVEL, + DEFAULT_DURATION, + SirenEntityFeature, +) +from zha.application.platforms.siren.model import SirenEntityInfo from zha.application.registries import PLATFORM_ENTITIES +from zha.const import MODEL_CLASS_NAME from zha.zigbee.cluster_handlers.const import CLUSTER_HANDLER_IAS_WD from zha.zigbee.cluster_handlers.security import IasWdClusterHandler if TYPE_CHECKING: from zha.zigbee.cluster_handlers import ClusterHandler - from zha.zigbee.device import Device + from zha.zigbee.device import Device, WebSocketClientDevice from zha.zigbee.endpoint import Endpoint MULTI_MATCH = functools.partial(PLATFORM_ENTITIES.multipass_match, Platform.SIREN) -DEFAULT_DURATION = 5 # seconds - -ATTR_AVAILABLE_TONES: Final[str] = "available_tones" -ATTR_DURATION: Final[str] = "duration" -ATTR_VOLUME_LEVEL: Final[str] = "volume_level" -ATTR_TONE: Final[str] = "tone" -class SirenEntityFeature(IntFlag): - """Supported features of the siren entity.""" +class SirenEntityInterface(ABC): + """Siren interface.""" - TURN_ON = 1 - TURN_OFF = 2 - TONES = 4 - VOLUME_SET = 8 - DURATION = 16 + @property + @abstractmethod + def is_on(self) -> bool: + """Return true if the entity is on.""" + @property + @abstractmethod + def supported_features(self) -> SirenEntityFeature: + """Return supported features.""" -@dataclass(frozen=True, kw_only=True) -class SirenEntityInfo(BaseEntityInfo): - """Siren entity info.""" + @abstractmethod + async def async_turn_on(self, **kwargs: Any) -> None: + """Turn on the siren.""" - available_tones: dict[int, str] - supported_features: SirenEntityFeature + @abstractmethod + async def async_turn_off(self, **kwargs: Any) -> None: + """Turn off the siren.""" @MULTI_MATCH(cluster_handler_names=CLUSTER_HANDLER_IAS_WD) -class Siren(PlatformEntity): +class Siren(PlatformEntity, SirenEntityInterface): """Representation of a ZHA siren.""" PLATFORM = Platform.SIREN @@ -100,11 +107,11 @@ def __init__( self._attr_is_on: bool = False self._off_listener: asyncio.TimerHandle | None = None - @functools.cached_property + @property def info_object(self) -> SirenEntityInfo: """Return representation of the siren.""" return SirenEntityInfo( - **super().info_object.__dict__, + **super().info_object.model_dump(exclude=[MODEL_CLASS_NAME]), available_tones=self._attr_available_tones, supported_features=self._attr_supported_features, ) @@ -112,9 +119,10 @@ def info_object(self) -> SirenEntityInfo: @property def state(self) -> dict[str, Any]: """Get the state of the siren.""" - response = super().state - response["state"] = self.is_on - return response + return EntityState( + **super().state, + state=self.is_on, + ).model_dump() @property def supported_features(self) -> SirenEntityFeature: @@ -198,3 +206,35 @@ def async_set_off(self) -> None: self._off_listener = None self.maybe_emit_state_changed_event() + + +class WebSocketClientSirenEntity( + WebSocketClientEntity[SirenEntityInfo], SirenEntityInterface +): + """Siren entity for the WebSocket API.""" + + PLATFORM = Platform.SIREN + + def __init__( + self, entity_info: SirenEntityInfo, device: WebSocketClientDevice + ) -> None: + """Initialize the ZHA siren device.""" + super().__init__(entity_info, device) + + @property + def is_on(self) -> bool: + """Return true if the entity is on.""" + return bool(self.info_object.state.state) + + @property + def supported_features(self) -> SirenEntityFeature: + """Return supported features.""" + return self.info_object.supported_features + + async def async_turn_on(self, **kwargs: Any) -> None: + """Turn on the siren.""" + await self._device.gateway.sirens.turn_on(self.info_object, **kwargs) + + async def async_turn_off(self, **kwargs: Any) -> None: + """Turn off the siren.""" + await self._device.gateway.sirens.turn_off(self.info_object, **kwargs) diff --git a/zha/application/platforms/siren/const.py b/zha/application/platforms/siren/const.py new file mode 100644 index 000000000..1b8bea41b --- /dev/null +++ b/zha/application/platforms/siren/const.py @@ -0,0 +1,23 @@ +"""Constants for the Siren platform.""" + +from __future__ import annotations + +from enum import IntFlag +from typing import Final + +DEFAULT_DURATION = 5 # seconds + +ATTR_AVAILABLE_TONES: Final[str] = "available_tones" +ATTR_DURATION: Final[str] = "duration" +ATTR_VOLUME_LEVEL: Final[str] = "volume_level" +ATTR_TONE: Final[str] = "tone" + + +class SirenEntityFeature(IntFlag): + """Supported features of the siren entity.""" + + TURN_ON = 1 + TURN_OFF = 2 + TONES = 4 + VOLUME_SET = 8 + DURATION = 16 diff --git a/zha/application/platforms/siren/model.py b/zha/application/platforms/siren/model.py new file mode 100644 index 000000000..86eec17e3 --- /dev/null +++ b/zha/application/platforms/siren/model.py @@ -0,0 +1,14 @@ +"""Models for the siren platform.""" + +from __future__ import annotations + +from zha.application.platforms.model import BasePlatformEntityInfo, EntityState +from zha.application.platforms.siren.const import SirenEntityFeature + + +class SirenEntityInfo(BasePlatformEntityInfo): + """Siren entity model.""" + + available_tones: dict[int, str] + supported_features: SirenEntityFeature + state: EntityState diff --git a/zha/application/platforms/siren/websocket_api.py b/zha/application/platforms/siren/websocket_api.py new file mode 100644 index 000000000..0b88c6e87 --- /dev/null +++ b/zha/application/platforms/siren/websocket_api.py @@ -0,0 +1,58 @@ +"""WS api for the siren platform entity.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal, Union + +from zha.application.discovery import Platform +from zha.application.platforms.websocket_api import ( + PlatformEntityCommand, + execute_platform_entity_command, +) +from zha.websocket.const import APICommands +from zha.websocket.server.api import decorators, register_api_command + +if TYPE_CHECKING: + from zha.application.gateway import WebSocketServerGateway + from zha.websocket.server.client import Client + + +class SirenTurnOnCommand(PlatformEntityCommand): + """Siren turn on command.""" + + command: Literal[APICommands.SIREN_TURN_ON] = APICommands.SIREN_TURN_ON + platform: str = Platform.SIREN + duration: Union[int, None] = None + tone: Union[int, None] = None + volume_level: Union[int, None] = None + + +@decorators.websocket_command(SirenTurnOnCommand) +@decorators.async_response +async def turn_on( + gateway: WebSocketServerGateway, client: Client, command: SirenTurnOnCommand +) -> None: + """Turn on the siren.""" + await execute_platform_entity_command(gateway, client, command, "async_turn_on") + + +class SirenTurnOffCommand(PlatformEntityCommand): + """Siren turn off command.""" + + command: Literal[APICommands.SIREN_TURN_OFF] = APICommands.SIREN_TURN_OFF + platform: str = Platform.SIREN + + +@decorators.websocket_command(SirenTurnOffCommand) +@decorators.async_response +async def turn_off( + gateway: WebSocketServerGateway, client: Client, command: SirenTurnOffCommand +) -> None: + """Turn on the siren.""" + await execute_platform_entity_command(gateway, client, command, "async_turn_off") + + +def load_api(gateway: WebSocketServerGateway) -> None: + """Load the api command handlers.""" + register_api_command(gateway, turn_on) + register_api_command(gateway, turn_off) diff --git a/zha/application/platforms/switch.py b/zha/application/platforms/switch/__init__.py similarity index 90% rename from zha/application/platforms/switch.py rename to zha/application/platforms/switch/__init__.py index b5f536109..d27cd5b29 100644 --- a/zha/application/platforms/switch.py +++ b/zha/application/platforms/switch/__init__.py @@ -2,8 +2,7 @@ from __future__ import annotations -from abc import ABC -from dataclasses import dataclass +from abc import ABC, abstractmethod import functools import logging from typing import TYPE_CHECKING, Any, Self, cast @@ -18,13 +17,18 @@ from zha.application.const import ENTITY_METADATA from zha.application.platforms import ( BaseEntity, - BaseEntityInfo, - EntityCategory, GroupEntity, PlatformEntity, + WebSocketClientEntity, +) +from zha.application.platforms.const import EntityCategory +from zha.application.platforms.switch.model import ( + ConfigurableAttributeSwitchEntityInfo, + SwitchEntityInfo, + SwitchState, ) from zha.application.registries import PLATFORM_ENTITIES -from zha.zigbee.cluster_handlers import ClusterAttributeUpdatedEvent +from zha.const import MODEL_CLASS_NAME from zha.zigbee.cluster_handlers.const import ( CLUSTER_HANDLER_ATTRIBUTE_UPDATED, CLUSTER_HANDLER_BASIC, @@ -34,12 +38,12 @@ CLUSTER_HANDLER_THERMOSTAT, ) from zha.zigbee.cluster_handlers.general import OnOffClusterHandler -from zha.zigbee.group import Group if TYPE_CHECKING: - from zha.zigbee.cluster_handlers import ClusterHandler - from zha.zigbee.device import Device + from zha.zigbee.cluster_handlers import ClusterAttributeUpdatedEvent, ClusterHandler + from zha.zigbee.device import Device, WebSocketClientDevice from zha.zigbee.endpoint import Endpoint + from zha.zigbee.group import Group STRICT_MATCH = functools.partial(PLATFORM_ENTITIES.strict_match, Platform.SWITCH) GROUP_MATCH = functools.partial(PLATFORM_ENTITIES.group_match, Platform.SWITCH) @@ -50,19 +54,25 @@ _LOGGER = logging.getLogger(__name__) -@dataclass(frozen=True, kw_only=True) -class ConfigurableAttributeSwitchInfo(BaseEntityInfo): - """Switch configuration entity info.""" +class SwitchEntityInterface(ABC): + """Switch interface.""" + + @property + @abstractmethod + def is_on(self) -> bool: + """Return if the switch is on based on the statemachine.""" + + @abstractmethod + async def async_turn_on(self, **kwargs: Any) -> None: + """Turn the entity on.""" - attribute_name: str - invert_attribute_name: str | None - force_inverted: bool - off_value: int - on_value: int + @abstractmethod + async def async_turn_off(self, **kwargs: Any) -> None: + """Turn the entity off.""" -class BaseSwitch(BaseEntity, ABC): - """Common base class for zhawss switches.""" +class BaseSwitch(BaseEntity, SwitchEntityInterface): + """Common base class for zha switches.""" PLATFORM = Platform.SWITCH @@ -75,13 +85,6 @@ def __init__( self._on_off_cluster_handler: OnOffClusterHandler super().__init__(*args, **kwargs) - @property - def state(self) -> dict[str, Any]: - """Return the state of the switch.""" - response = super().state - response["state"] = self.is_on - return response - @property def is_on(self) -> bool: """Return if the switch is on based on the statemachine.""" @@ -125,6 +128,18 @@ def __init__( self.handle_cluster_handler_attribute_updated, ) + @property + def info_object(self) -> SwitchEntityInfo: + """Return representation of the switch entity.""" + return SwitchEntityInfo( + **super().info_object.model_dump(exclude=[MODEL_CLASS_NAME]), + ) + + @property + def state(self) -> dict[str, Any]: + """Return the state of the switch.""" + return SwitchState(**super().state, state=self.is_on).model_dump() + def handle_cluster_handler_attribute_updated( self, event: ClusterAttributeUpdatedEvent, # pylint: disable=unused-argument @@ -143,10 +158,20 @@ def __init__(self, group: Group): super().__init__(group) self._state: bool self._on_off_cluster_handler = group.zigpy_group.endpoint[OnOff.cluster_id] - if hasattr(self, "info_object"): - delattr(self, "info_object") self.update() + @property + def info_object(self) -> SwitchEntityInfo: + """Return representation of the switch entity.""" + return SwitchEntityInfo( + **super().info_object.model_dump(exclude=[MODEL_CLASS_NAME]), + ) + + @property + def state(self) -> dict[str, Any]: + """Return the state of the switch.""" + return SwitchState(**super().state, state=self.is_on).model_dump() + @property def is_on(self) -> bool: """Return if the switch is on based on the statemachine.""" @@ -250,11 +275,11 @@ def _init_from_quirks_metadata(self, entity_metadata: SwitchMetadata) -> None: self._off_value = entity_metadata.off_value self._on_value = entity_metadata.on_value - @functools.cached_property - def info_object(self) -> ConfigurableAttributeSwitchInfo: + @property + def info_object(self) -> ConfigurableAttributeSwitchEntityInfo: """Return representation of the switch configuration entity.""" - return ConfigurableAttributeSwitchInfo( - **super().info_object.__dict__, + return ConfigurableAttributeSwitchEntityInfo( + **super().info_object.model_dump(exclude=[MODEL_CLASS_NAME]), attribute_name=self._attribute_name, invert_attribute_name=self._inverter_attribute_name, force_inverted=self._force_inverted, @@ -265,10 +290,11 @@ def info_object(self) -> ConfigurableAttributeSwitchInfo: @property def state(self) -> dict[str, Any]: """Return the state of the switch.""" - response = super().state - response["state"] = self.is_on - response["inverted"] = self.inverted - return response + return SwitchState( + **super().state, + state=self.is_on, + inverted=self.inverted, + ).model_dump() @property def inverted(self) -> bool: @@ -861,3 +887,30 @@ class SinopeLightDoubleTapFullSwitch(ConfigurableAttributeSwitch): _unique_id_suffix = "double_up_full" _attribute_name = "double_up_full" _attr_translation_key: str = "double_up_full" + + +class WebSocketClientSwitchEntity( + WebSocketClientEntity[SwitchEntityInfo], SwitchEntityInterface +): + """Defines a ZHA switch that is controlled via a websocket.""" + + PLATFORM = Platform.SWITCH + + def __init__( + self, entity_info: SwitchEntityInfo, device: WebSocketClientDevice + ) -> None: + """Initialize the ZHA switch entity.""" + super().__init__(entity_info, device) + + @property + def is_on(self) -> bool: + """Return if the switch is on based on the statemachine.""" + return self.info_object.state.state + + async def async_turn_on(self, **kwargs: Any) -> None: + """Turn the entity on.""" + await self._device.gateway.switches.turn_on(self.info_object) + + async def async_turn_off(self, **kwargs: Any) -> None: + """Turn the entity off.""" + await self._device.gateway.switches.turn_off(self.info_object) diff --git a/zha/application/platforms/switch/model.py b/zha/application/platforms/switch/model.py new file mode 100644 index 000000000..d63e196f5 --- /dev/null +++ b/zha/application/platforms/switch/model.py @@ -0,0 +1,31 @@ +"""Models for the switch platform.""" + +from __future__ import annotations + +from zha.application.platforms.model import BasePlatformEntityInfo +from zha.model import TypedBaseModel + + +class SwitchState(TypedBaseModel): + """Switch state model.""" + + state: bool + available: bool + inverted: bool | None = None + + +class SwitchEntityInfo(BasePlatformEntityInfo): + """Switch entity model.""" + + state: SwitchState + + +class ConfigurableAttributeSwitchEntityInfo(BasePlatformEntityInfo): + """Switch configuration entity info.""" + + attribute_name: str + invert_attribute_name: str | None = None + force_inverted: bool + off_value: int + on_value: int + state: SwitchState diff --git a/zha/application/platforms/switch/websocket_api.py b/zha/application/platforms/switch/websocket_api.py new file mode 100644 index 000000000..9b2ccb7fb --- /dev/null +++ b/zha/application/platforms/switch/websocket_api.py @@ -0,0 +1,55 @@ +"""WS api for the switch platform entity.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal + +from zha.application.discovery import Platform +from zha.application.platforms.websocket_api import ( + PlatformEntityCommand, + execute_platform_entity_command, +) +from zha.websocket.const import APICommands +from zha.websocket.server.api import decorators, register_api_command + +if TYPE_CHECKING: + from zha.application.gateway import WebSocketServerGateway + from zha.websocket.server.client import Client + + +class SwitchTurnOnCommand(PlatformEntityCommand): + """Switch turn on command.""" + + command: Literal[APICommands.SWITCH_TURN_ON] = APICommands.SWITCH_TURN_ON + platform: str = Platform.SWITCH + + +@decorators.websocket_command(SwitchTurnOnCommand) +@decorators.async_response +async def turn_on( + gateway: WebSocketServerGateway, client: Client, command: SwitchTurnOnCommand +) -> None: + """Turn on the switch.""" + await execute_platform_entity_command(gateway, client, command, "async_turn_on") + + +class SwitchTurnOffCommand(PlatformEntityCommand): + """Switch turn off command.""" + + command: Literal[APICommands.SWITCH_TURN_OFF] = APICommands.SWITCH_TURN_OFF + platform: str = Platform.SWITCH + + +@decorators.websocket_command(SwitchTurnOffCommand) +@decorators.async_response +async def turn_off( + gateway: WebSocketServerGateway, client: Client, command: SwitchTurnOffCommand +) -> None: + """Turn on the switch.""" + await execute_platform_entity_command(gateway, client, command, "async_turn_off") + + +def load_api(gateway: WebSocketServerGateway) -> None: + """Load the api command handlers.""" + register_api_command(gateway, turn_on) + register_api_command(gateway, turn_off) diff --git a/zha/application/platforms/update.py b/zha/application/platforms/update/__init__.py similarity index 60% rename from zha/application/platforms/update.py rename to zha/application/platforms/update/__init__.py index 2834ce914..0c2a23a3f 100644 --- a/zha/application/platforms/update.py +++ b/zha/application/platforms/update/__init__.py @@ -2,22 +2,37 @@ from __future__ import annotations -from dataclasses import dataclass -from enum import IntFlag, StrEnum +from abc import ABC, abstractmethod import functools import itertools import logging -from typing import TYPE_CHECKING, Any, Final, final +from typing import TYPE_CHECKING, Any, final from zigpy.ota import OtaImagesResult, OtaImageWithMetadata from zigpy.zcl.clusters.general import Ota, QueryNextImageCommand from zigpy.zcl.foundation import Status from zha.application import Platform -from zha.application.platforms import BaseEntityInfo, EntityCategory, PlatformEntity +from zha.application.platforms import PlatformEntity, WebSocketClientEntity +from zha.application.platforms.const import EntityCategory +from zha.application.platforms.update.const import ( + ATTR_IN_PROGRESS, + ATTR_INSTALLED_VERSION, + ATTR_LATEST_VERSION, + ATTR_RELEASE_NOTES, + ATTR_RELEASE_SUMMARY, + ATTR_RELEASE_URL, + ATTR_UPDATE_PERCENTAGE, + UpdateDeviceClass, + UpdateEntityFeature, +) +from zha.application.platforms.update.model import ( + FirmwareUpdateEntityInfo, + FirmwareUpdateState, +) from zha.application.registries import PLATFORM_ENTITIES +from zha.const import MODEL_CLASS_NAME from zha.exceptions import ZHAException -from zha.zigbee.cluster_handlers import ClusterAttributeUpdatedEvent from zha.zigbee.cluster_handlers.const import ( CLUSTER_HANDLER_ATTRIBUTE_UPDATED, CLUSTER_HANDLER_OTA, @@ -25,8 +40,8 @@ from zha.zigbee.endpoint import Endpoint if TYPE_CHECKING: - from zha.zigbee.cluster_handlers import ClusterHandler - from zha.zigbee.device import Device + from zha.zigbee.cluster_handlers import ClusterAttributeUpdatedEvent, ClusterHandler + from zha.zigbee.device import Device, WebSocketClientDevice _LOGGER = logging.getLogger(__name__) @@ -35,44 +50,74 @@ ) -class UpdateDeviceClass(StrEnum): - """Device class for update.""" +class FirmwareUpdateEntityInterface(ABC): + """Base class for ZHA firmware update entity.""" - FIRMWARE = "firmware" + @property + @abstractmethod + def installed_version(self) -> str | None: + """Version installed and in use.""" + @property + @abstractmethod + def in_progress(self) -> bool | None: + """Update installation progress. -class UpdateEntityFeature(IntFlag): - """Supported features of the update entity.""" + Needs UpdateEntityFeature.PROGRESS flag to be set for it to be used. + + Returns a boolean (True if in progress, False if not). + """ - INSTALL = 1 - SPECIFIC_VERSION = 2 - PROGRESS = 4 - BACKUP = 8 - RELEASE_NOTES = 16 + @property + @abstractmethod + def update_percentage(self) -> float | None: + """Update installation progress. + + Returns a number indicating the progress from 0 to 100%. If an update's progress + is indeterminate, this will return None. + """ + + @property + @abstractmethod + def latest_version(self) -> str | None: + """Latest version available for install.""" + + @property + @abstractmethod + def release_summary(self) -> str | None: + """Summary of the release notes or changelog. + + This is not suitable for long changelogs, but merely suitable + for a short excerpt update description of max 255 characters. + """ + @property + @abstractmethod + def release_notes(self) -> str | None: + """Full release notes of the latest version available.""" -ATTR_BACKUP: Final = "backup" -ATTR_INSTALLED_VERSION: Final = "installed_version" -ATTR_IN_PROGRESS: Final = "in_progress" -ATTR_UPDATE_PERCENTAGE: Final = "update_percentage" -ATTR_LATEST_VERSION: Final = "latest_version" -ATTR_RELEASE_SUMMARY: Final = "release_summary" -ATTR_RELEASE_NOTES: Final = "release_notes" -ATTR_RELEASE_URL: Final = "release_url" -ATTR_VERSION: Final = "version" + @property + @abstractmethod + def release_url(self) -> str | None: + """URL to the full release notes of the latest version available.""" + @property + @abstractmethod + def supported_features(self) -> UpdateEntityFeature: + """Flag supported features.""" -@dataclass(frozen=True, kw_only=True) -class UpdateEntityInfo(BaseEntityInfo): - """Update entity info.""" + @property + @abstractmethod + def state_attributes(self) -> dict[str, Any] | None: + """Return state attributes.""" - supported_features: UpdateEntityFeature - device_class: UpdateDeviceClass - entity_category: EntityCategory + @abstractmethod + async def async_install(self, version: str | None) -> None: + """Install an update.""" @CONFIG_DIAGNOSTIC_MATCH(cluster_handler_names=CLUSTER_HANDLER_OTA) -class FirmwareUpdateEntity(PlatformEntity): +class FirmwareUpdateEntity(PlatformEntity, FirmwareUpdateEntityInterface): """Representation of a ZHA firmware update entity.""" PLATFORM = Platform.UPDATE @@ -119,20 +164,21 @@ def __init__( self.handle_cluster_handler_attribute_updated, ) - @functools.cached_property - def info_object(self) -> UpdateEntityInfo: + @property + def info_object(self) -> FirmwareUpdateEntityInfo: """Return a representation of the entity.""" - return UpdateEntityInfo( - **super().info_object.__dict__, + return FirmwareUpdateEntityInfo( + **super().info_object.model_dump(exclude=[MODEL_CLASS_NAME]), supported_features=self.supported_features, ) @property - def state(self): + def state(self) -> dict[str, Any]: """Get the state for the entity.""" - response = super().state - response.update(self.state_attributes) - return response + return FirmwareUpdateState( + **super().state, + **self.state_attributes, + ).model_dump() @property def installed_version(self) -> str | None: @@ -260,7 +306,7 @@ def _update_progress(self, current: int, total: int, progress: float) -> None: self._attr_update_percentage = progress self.maybe_emit_state_changed_event() - async def async_install(self, version: str | None) -> None: + async def async_install(self, version: str | None = None, **kwargs) -> None: """Install an update.""" if version is None: @@ -309,3 +355,87 @@ async def on_remove(self) -> None: self._attr_in_progress = False self.device.device.remove_listener(self) await super().on_remove() + + +class WebSocketClientFirmwareUpdateEntity( + WebSocketClientEntity[FirmwareUpdateEntityInfo], FirmwareUpdateEntityInterface +): + """Representation of a ZHA firmware update entity.""" + + PLATFORM = Platform.UPDATE + + def __init__( + self, entity_info: FirmwareUpdateEntityInfo, device: WebSocketClientDevice + ) -> None: + """Initialize the ZHA alarm control device.""" + super().__init__(entity_info, device) + + @property + def installed_version(self) -> str | None: + """Version installed and in use.""" + return self.info_object.state.installed_version + + @property + def in_progress(self) -> bool | None: + """Update installation progress. + + Needs UpdateEntityFeature.PROGRESS flag to be set for it to be used. + + Returns a boolean (True if in progress, False if not). + """ + return self.info_object.state.in_progress + + @property + def update_percentage(self) -> float | None: + """Update installation progress. + + Returns a number indicating the progress from 0 to 100%. If an update's progress + is indeterminate, this will return None. + """ + return self.info_object.state.progress + + @property + def latest_version(self) -> str | None: + """Latest version available for install.""" + return self.info_object.state.latest_version + + @property + def release_summary(self) -> str | None: + """Summary of the release notes or changelog. + + This is not suitable for long changelogs, but merely suitable + for a short excerpt update description of max 255 characters. + """ + return self.info_object.state.release_summary + + @property + def release_notes(self) -> str | None: + """Full release notes of the latest version available.""" + return self.info_object.state.release_notes + + @property + def release_url(self) -> str | None: + """URL to the full release notes of the latest version available.""" + return self.info_object.state.release_url + + @property + def supported_features(self) -> UpdateEntityFeature: + """Flag supported features.""" + return self.info_object.supported_features + + @property + def state_attributes(self) -> dict[str, Any] | None: + """Return state attributes.""" + return { + ATTR_INSTALLED_VERSION: self.installed_version, + ATTR_IN_PROGRESS: self.in_progress, + ATTR_UPDATE_PERCENTAGE: self.update_percentage, + ATTR_LATEST_VERSION: self.latest_version, + ATTR_RELEASE_SUMMARY: self.release_summary, + ATTR_RELEASE_NOTES: self.release_notes, + ATTR_RELEASE_URL: self.release_url, + } + + async def async_install(self, version: str | None) -> None: + """Install an update.""" + await self._device.gateway.update_helper.install_firmware(self, version) diff --git a/zha/application/platforms/update/const.py b/zha/application/platforms/update/const.py new file mode 100644 index 000000000..5d54a9358 --- /dev/null +++ b/zha/application/platforms/update/const.py @@ -0,0 +1,34 @@ +"""Constants for the ZHA update platform.""" + +from __future__ import annotations + +from enum import IntFlag, StrEnum +from typing import Final + +SERVICE_INSTALL: Final = "install" + +ATTR_BACKUP: Final = "backup" +ATTR_INSTALLED_VERSION: Final = "installed_version" +ATTR_IN_PROGRESS: Final = "in_progress" +ATTR_UPDATE_PERCENTAGE: Final = "update_percentage" +ATTR_LATEST_VERSION: Final = "latest_version" +ATTR_RELEASE_SUMMARY: Final = "release_summary" +ATTR_RELEASE_NOTES: Final = "release_notes" +ATTR_RELEASE_URL: Final = "release_url" +ATTR_VERSION: Final = "version" + + +class UpdateEntityFeature(IntFlag): + """Supported features of the update entity.""" + + INSTALL = 1 + SPECIFIC_VERSION = 2 + PROGRESS = 4 + BACKUP = 8 + RELEASE_NOTES = 16 + + +class UpdateDeviceClass(StrEnum): + """Device class for update.""" + + FIRMWARE = "firmware" diff --git a/zha/application/platforms/update/model.py b/zha/application/platforms/update/model.py new file mode 100644 index 000000000..6ba7e67c7 --- /dev/null +++ b/zha/application/platforms/update/model.py @@ -0,0 +1,27 @@ +"""Models for the update platform.""" + +from __future__ import annotations + +from zha.application.platforms.model import BasePlatformEntityInfo +from zha.application.platforms.update.const import UpdateEntityFeature +from zha.model import TypedBaseModel + + +class FirmwareUpdateState(TypedBaseModel): + """Firmware update state model.""" + + available: bool + installed_version: str | None = None + in_progress: bool | None = None + progress: int | None = None + latest_version: str | None = None + release_summary: str | None = None + release_notes: str | None = None + release_url: str | None = None + + +class FirmwareUpdateEntityInfo(BasePlatformEntityInfo): + """Firmware update entity model.""" + + state: FirmwareUpdateState + supported_features: UpdateEntityFeature diff --git a/zha/application/platforms/update/websocket_api.py b/zha/application/platforms/update/websocket_api.py new file mode 100644 index 000000000..ce9a991eb --- /dev/null +++ b/zha/application/platforms/update/websocket_api.py @@ -0,0 +1,39 @@ +"""WS api for the select platform entity.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal + +from zha.application.discovery import Platform +from zha.application.platforms.websocket_api import ( + PlatformEntityCommand, + execute_platform_entity_command, +) +from zha.websocket.const import APICommands +from zha.websocket.server.api import decorators, register_api_command + +if TYPE_CHECKING: + from zha.application.gateway import WebSocketServerGateway + from zha.websocket.server.client import Client + + +class InstallFirmwareCommand(PlatformEntityCommand): + """Install firmware command.""" + + command: Literal[APICommands.FIRMWARE_INSTALL] = APICommands.FIRMWARE_INSTALL + platform: str = Platform.UPDATE + version: str | None = None + + +@decorators.websocket_command(InstallFirmwareCommand) +@decorators.async_response +async def install_firmware( + gateway: WebSocketServerGateway, client: Client, command: InstallFirmwareCommand +) -> None: + """Select an option.""" + await execute_platform_entity_command(gateway, client, command, "async_install") + + +def load_api(gateway: WebSocketServerGateway) -> None: + """Load the api command handlers.""" + register_api_command(gateway, install_firmware) diff --git a/zha/application/platforms/websocket_api.py b/zha/application/platforms/websocket_api.py new file mode 100644 index 000000000..877e4e457 --- /dev/null +++ b/zha/application/platforms/websocket_api.py @@ -0,0 +1,177 @@ +"""WS API for common platform entity functionality.""" + +from __future__ import annotations + +import inspect +import logging +from typing import TYPE_CHECKING, Literal + +from zigpy.types.named import EUI64 + +from zha.application import Platform +from zha.websocket.const import APICommands +from zha.websocket.server.api import decorators, register_api_command +from zha.websocket.server.api.model import WebSocketCommand + +if TYPE_CHECKING: + from zha.application.gateway import WebSocketServerGateway + from zha.websocket.server.client import Client + +_LOGGER = logging.getLogger(__name__) + + +class PlatformEntityCommand(WebSocketCommand): + """Base class for platform entity commands.""" + + ieee: EUI64 | None = None + group_id: int | None = None + unique_id: str + platform: Platform + + +async def execute_platform_entity_command( + gateway: WebSocketServerGateway, + client: Client, + command: PlatformEntityCommand, + method_name: str, +) -> None: + """Get the platform entity and execute a method based on the command.""" + + _LOGGER.debug("attempting to execute platform entity command: %s", command) + + if command.group_id: + group = gateway.get_group(command.group_id) + platform_entity = group.group_entities[command.unique_id] + else: + device = gateway.get_device(command.ieee) + platform_entity = device.get_platform_entity( + command.platform, command.unique_id + ) + + if not platform_entity: + client.send_result_error( + command, "PLATFORM_ENTITY_COMMAND_ERROR", "platform entity not found" + ) + return None + + try: + action = getattr(platform_entity, method_name) + arg_spec = inspect.getfullargspec(action) + if arg_spec.varkw: + if inspect.iscoroutinefunction(action): + await action(**command.model_dump()) + else: + action(**command.model_dump()) + elif inspect.iscoroutinefunction(action): + await action() + else: + action() # the only argument is self + + except Exception as err: # pylint: disable=broad-except + _LOGGER.exception("Error executing command: %s", method_name, exc_info=err) + client.send_result_error(command, "PLATFORM_ENTITY_ACTION_ERROR", str(err)) + return + + client.send_result_success(command) + + +class PlatformEntityRefreshStateCommand(PlatformEntityCommand): + """Platform entity refresh state command.""" + + command: Literal[APICommands.PLATFORM_ENTITY_REFRESH_STATE] = ( + APICommands.PLATFORM_ENTITY_REFRESH_STATE + ) + + +@decorators.websocket_command(PlatformEntityRefreshStateCommand) +@decorators.async_response +async def refresh_state( + gateway: WebSocketServerGateway, client: Client, command: PlatformEntityCommand +) -> None: + """Refresh the state of the platform entity.""" + await execute_platform_entity_command(gateway, client, command, "async_update") + + +class PlatformEntityEnableCommand(PlatformEntityCommand): + """Platform entity enable command.""" + + command: Literal[APICommands.PLATFORM_ENTITY_ENABLE] = ( + APICommands.PLATFORM_ENTITY_ENABLE + ) + + +@decorators.websocket_command(PlatformEntityEnableCommand) +@decorators.async_response +async def enable( + gateway: WebSocketServerGateway, + client: Client, + command: PlatformEntityEnableCommand, +) -> None: + """Enable the platform entity.""" + await execute_platform_entity_command(gateway, client, command, "enable") + + +class PlatformEntityDisableCommand(PlatformEntityCommand): + """Platform entity disable command.""" + + command: Literal[APICommands.PLATFORM_ENTITY_DISABLE] = ( + APICommands.PLATFORM_ENTITY_DISABLE + ) + + +@decorators.websocket_command(PlatformEntityDisableCommand) +@decorators.async_response +async def disable( + gateway: WebSocketServerGateway, + client: Client, + command: PlatformEntityDisableCommand, +) -> None: + """Disable the platform entity.""" + await execute_platform_entity_command(gateway, client, command, "disable") + + +# pylint: disable=import-outside-toplevel +def load_platform_entity_apis(gateway: WebSocketServerGateway) -> None: + """Load the ws apis for all platform entities types.""" + from zha.application.platforms.alarm_control_panel.websocket_api import ( + load_api as load_alarm_control_panel_api, + ) + from zha.application.platforms.button.websocket_api import ( + load_api as load_button_api, + ) + from zha.application.platforms.climate.websocket_api import ( + load_api as load_climate_api, + ) + from zha.application.platforms.cover.websocket_api import load_api as load_cover_api + from zha.application.platforms.fan.websocket_api import load_api as load_fan_api + from zha.application.platforms.light.websocket_api import load_api as load_light_api + from zha.application.platforms.lock.websocket_api import load_api as load_lock_api + from zha.application.platforms.number.websocket_api import ( + load_api as load_number_api, + ) + from zha.application.platforms.select.websocket_api import ( + load_api as load_select_api, + ) + from zha.application.platforms.siren.websocket_api import load_api as load_siren_api + from zha.application.platforms.switch.websocket_api import ( + load_api as load_switch_api, + ) + from zha.application.platforms.update.websocket_api import ( + load_api as load_update_api, + ) + + register_api_command(gateway, refresh_state) + register_api_command(gateway, enable) + register_api_command(gateway, disable) + load_alarm_control_panel_api(gateway) + load_button_api(gateway) + load_climate_api(gateway) + load_cover_api(gateway) + load_fan_api(gateway) + load_light_api(gateway) + load_lock_api(gateway) + load_number_api(gateway) + load_select_api(gateway) + load_siren_api(gateway) + load_switch_api(gateway) + load_update_api(gateway) diff --git a/zha/application/websocket_api.py b/zha/application/websocket_api.py new file mode 100644 index 000000000..565f34e2a --- /dev/null +++ b/zha/application/websocket_api.py @@ -0,0 +1,512 @@ +"""Websocket API for zha.""" + +from __future__ import annotations + +import asyncio +import logging +from typing import TYPE_CHECKING, Annotated, Any, Literal, TypeVar, Union + +from pydantic import Field +from zigpy.types.named import EUI64 + +from zha.websocket.const import DEVICE, DEVICES, GROUPS, APICommands +from zha.websocket.server.api import decorators, register_api_command +from zha.websocket.server.api.model import ( + GetApplicationStateResponse, + GetDevicesResponse, + GroupsResponse, + PermitJoiningResponse, + ReadClusterAttributesResponse, + UpdateGroupResponse, + WebSocketCommand, + WriteClusterAttributeResponse, +) +from zha.zigbee.model import GroupMemberReference + +if TYPE_CHECKING: + from zha.application.gateway import WebSocketServerGateway + from zha.websocket.server.client import Client + from zha.zigbee.device import Device + from zha.zigbee.group import Group + + +GROUP = "group" +MFG_CLUSTER_ID_START = 0xFC00 + +_LOGGER = logging.getLogger(__name__) + +T = TypeVar("T") + + +class StartNetworkCommand(WebSocketCommand): + """Start the Zigbee network.""" + + command: Literal[APICommands.START_NETWORK] = APICommands.START_NETWORK + + +@decorators.websocket_command(StartNetworkCommand) +@decorators.async_response +async def start_network( + gateway: WebSocketServerGateway, client: Client, command: StartNetworkCommand +) -> None: + """Start the Zigbee network.""" + await gateway.start_network() + client.send_result_success(command) + + +class StopNetworkCommand(WebSocketCommand): + """Stop the Zigbee network.""" + + command: Literal[APICommands.STOP_NETWORK] = APICommands.STOP_NETWORK + + +@decorators.websocket_command(StopNetworkCommand) +@decorators.async_response +async def stop_network( + gateway: WebSocketServerGateway, client: Client, command: StopNetworkCommand +) -> None: + """Stop the Zigbee network.""" + await gateway.stop_network() + client.send_result_success(command) + + +class UpdateTopologyCommand(WebSocketCommand): + """Stop the Zigbee network.""" + + command: Literal[APICommands.UPDATE_NETWORK_TOPOLOGY] = ( + APICommands.UPDATE_NETWORK_TOPOLOGY + ) + + +@decorators.websocket_command(UpdateTopologyCommand) +@decorators.async_response +async def update_topology( + gateway: WebSocketServerGateway, client: Client, command: WebSocketCommand +) -> None: + """Update the Zigbee network topology.""" + await gateway.application_controller.topology.scan() + client.send_result_success(command) + + +class GetDevicesCommand(WebSocketCommand): + """Get all Zigbee devices.""" + + command: Literal[APICommands.GET_DEVICES] = APICommands.GET_DEVICES + + +@decorators.websocket_command(GetDevicesCommand) +@decorators.async_response +async def get_devices( + gateway: WebSocketServerGateway, client: Client, command: GetDevicesCommand +) -> None: + """Get Zigbee devices.""" + try: + client.send_result_success( + command, + data={ + DEVICES: { + ieee: device.extended_device_info + for ieee, device in gateway.devices.items() + } + }, + response_type=GetDevicesResponse, + ) + except Exception as e: + _LOGGER.exception("Error getting devices", exc_info=e) + client.send_result_error(command, "Error getting devices", str(e)) + + +class ReconfigureDeviceCommand(WebSocketCommand): + """Reconfigure a zigbee device.""" + + command: Literal[APICommands.RECONFIGURE_DEVICE] = APICommands.RECONFIGURE_DEVICE + ieee: EUI64 + + +@decorators.websocket_command(ReconfigureDeviceCommand) +@decorators.async_response +async def reconfigure_device( + gateway: WebSocketServerGateway, client: Client, command: ReconfigureDeviceCommand +) -> None: + """Reconfigure a zigbee device.""" + device = gateway.devices.get(command.ieee) + if device: + await device.async_configure() + client.send_result_success(command) + + +class GetGroupsCommand(WebSocketCommand): + """Get all Zigbee devices.""" + + command: Literal[APICommands.GET_GROUPS] = APICommands.GET_GROUPS + + +@decorators.websocket_command(GetGroupsCommand) +@decorators.async_response +async def get_groups( + gateway: WebSocketServerGateway, client: Client, command: GetGroupsCommand +) -> None: + """Get Zigbee groups.""" + groups: dict[int, Any] = {} + for group_id, group in gateway.groups.items(): + groups[int(group_id)] = ( + group.info_object + ) # maybe we should change the group_id type... + _LOGGER.info("groups: %s", groups) + client.send_result_success( + command, data={GROUPS: groups}, response_type=GroupsResponse + ) + + +class PermitJoiningCommand(WebSocketCommand): + """Permit joining.""" + + command: Literal[APICommands.PERMIT_JOINING] = APICommands.PERMIT_JOINING + duration: Annotated[int, Field(ge=1, le=254)] = 60 + ieee: Union[EUI64, None] = None + + +@decorators.websocket_command(PermitJoiningCommand) +@decorators.async_response +async def permit_joining( + gateway: WebSocketServerGateway, client: Client, command: PermitJoiningCommand +) -> None: + """Permit joining devices to the Zigbee network.""" + # TODO add permit with code support + await gateway.application_controller.permit(command.duration, command.ieee) + client.send_result_success(command, response_type=PermitJoiningResponse) + + +class RemoveDeviceCommand(WebSocketCommand): + """Remove device command.""" + + command: Literal[APICommands.REMOVE_DEVICE] = APICommands.REMOVE_DEVICE + ieee: EUI64 + + +@decorators.websocket_command(RemoveDeviceCommand) +@decorators.async_response +async def remove_device( + gateway: WebSocketServerGateway, client: Client, command: RemoveDeviceCommand +) -> None: + """Permit joining devices to the Zigbee network.""" + await gateway.async_remove_device(command.ieee) + client.send_result_success(command) + + +class ReadClusterAttributesCommand(WebSocketCommand): + """Read cluster attributes command.""" + + command: Literal[APICommands.READ_CLUSTER_ATTRIBUTES] = ( + APICommands.READ_CLUSTER_ATTRIBUTES + ) + ieee: EUI64 + endpoint_id: int + cluster_id: int + cluster_type: Literal["in", "out"] + attributes: list[str] + manufacturer_code: Union[int, None] = None + + +@decorators.websocket_command(ReadClusterAttributesCommand) +@decorators.async_response +async def read_cluster_attributes( + gateway: WebSocketServerGateway, + client: Client, + command: ReadClusterAttributesCommand, +) -> None: + """Read the specified cluster attributes.""" + device: Device = gateway.devices[command.ieee] + if not device: + client.send_result_error( + command, + "Device not found", + f"Device with ieee: {command.ieee} not found", + ) + return + endpoint_id = command.endpoint_id + cluster_id = command.cluster_id + cluster_type = command.cluster_type + attributes = command.attributes + manufacturer = command.manufacturer_code + if cluster_id >= MFG_CLUSTER_ID_START and manufacturer is None: + manufacturer = device.manufacturer_code + cluster = device.async_get_cluster( + endpoint_id, cluster_id, cluster_type=cluster_type + ) + if not cluster: + client.send_result_error( + command, + "Cluster not found", + f"Cluster: {endpoint_id}:{command.cluster_id} not found on device with ieee: {str(command.ieee)} not found", + ) + return + success, failure = await cluster.read_attributes( + attributes, allow_cache=False, only_cache=False, manufacturer=manufacturer + ) + + data = { + DEVICE: device.extended_device_info, + "cluster": { + "id": cluster.cluster_id, + "name": cluster.name, + "type": cluster.cluster_type, + "endpoint_id": cluster.endpoint.endpoint_id, + "endpoint_attribute": cluster.ep_attribute, + }, + "succeeded": success, + "failed": failure, + } + + client.send_result_success( + command, data=data, response_type=ReadClusterAttributesResponse + ) + + +class WriteClusterAttributeCommand(WebSocketCommand): + """Write cluster attribute command.""" + + command: Literal[APICommands.WRITE_CLUSTER_ATTRIBUTE] = ( + APICommands.WRITE_CLUSTER_ATTRIBUTE + ) + ieee: EUI64 + endpoint_id: int + cluster_id: int + cluster_type: Literal["in", "out"] + attribute: str + value: Union[str, int, float, bool] + manufacturer_code: Union[int, None] = None + + +@decorators.websocket_command(WriteClusterAttributeCommand) +@decorators.async_response +async def write_cluster_attribute( + gateway: WebSocketServerGateway, + client: Client, + command: WriteClusterAttributeCommand, +) -> None: + """Set the value of the specific cluster attribute.""" + device: Device = gateway.devices[command.ieee] + if not device: + client.send_result_error( + command, + "Device not found", + f"Device with ieee: {command.ieee} not found", + ) + return + endpoint_id = command.endpoint_id + cluster_id = command.cluster_id + cluster_type = command.cluster_type + attribute = command.attribute + value = command.value + manufacturer = command.manufacturer_code + if cluster_id >= MFG_CLUSTER_ID_START and manufacturer is None: + manufacturer = device.manufacturer_code + cluster = device.async_get_cluster( + endpoint_id, cluster_id, cluster_type=cluster_type + ) + if not cluster: + client.send_result_error( + command, + "Cluster not found", + f"Cluster: {endpoint_id}:{command.cluster_id} not found on device with ieee: {str(command.ieee)} not found", + ) + return + response = await device.write_zigbee_attribute( + endpoint_id, + cluster_id, + attribute, + value, + cluster_type=cluster_type, + manufacturer=manufacturer, + ) + + data = { + DEVICE: device.extended_device_info, + "cluster": { + "id": cluster.cluster_id, + "name": cluster.name, + "type": cluster.cluster_type, + "endpoint_id": cluster.endpoint.endpoint_id, + "endpoint_attribute": cluster.ep_attribute, + }, + "response": { + "attribute": attribute, + "status": response[0][0].status.name, # type: ignore + }, # TODO there has to be a better way to do this + } + + client.send_result_success( + command, data=data, response_type=WriteClusterAttributeResponse + ) + + +class CreateGroupCommand(WebSocketCommand): + """Create group command.""" + + command: Literal[APICommands.CREATE_GROUP] = APICommands.CREATE_GROUP + group_name: str + members: list[GroupMemberReference] + group_id: Union[int, None] = None + + +@decorators.websocket_command(CreateGroupCommand) +@decorators.async_response +async def create_group( + gateway: WebSocketServerGateway, client: Client, command: CreateGroupCommand +) -> None: + """Create a new group.""" + group_name = command.group_name + members = command.members + group_id = command.group_id + group: Group = await gateway.async_create_zigpy_group(group_name, members, group_id) + client.send_result_success( + command, data={GROUP: group.info_object}, response_type=UpdateGroupResponse + ) + + +class RemoveGroupsCommand(WebSocketCommand): + """Remove groups command.""" + + command: Literal[APICommands.REMOVE_GROUPS] = APICommands.REMOVE_GROUPS + group_ids: list[int] + + +@decorators.websocket_command(RemoveGroupsCommand) +@decorators.async_response +async def remove_groups( + gateway: WebSocketServerGateway, client: Client, command: RemoveGroupsCommand +) -> None: + """Remove the specified groups.""" + group_ids = command.group_ids + + if len(group_ids) > 1: + tasks = [] + for group_id in group_ids: + tasks.append(gateway.async_remove_zigpy_group(group_id)) + await asyncio.gather(*tasks) + else: + await gateway.async_remove_zigpy_group(group_ids[0]) + groups: dict[int, Any] = {} + for group_id, group in gateway.groups.items(): + groups[int(group_id)] = group.info_object + _LOGGER.info("groups: %s", groups) + client.send_result_success( + command, data={GROUPS: groups}, response_type=GroupsResponse + ) + + +class AddGroupMembersCommand(WebSocketCommand): + """Add group members command.""" + + command: Literal[ + APICommands.ADD_GROUP_MEMBERS, APICommands.REMOVE_GROUP_MEMBERS + ] = APICommands.ADD_GROUP_MEMBERS + group_id: int + members: list[GroupMemberReference] + + +@decorators.websocket_command(AddGroupMembersCommand) +@decorators.async_response +async def add_group_members( + gateway: WebSocketServerGateway, client: Client, command: AddGroupMembersCommand +) -> None: + """Add members to a ZHA group.""" + group_id = command.group_id + members = command.members + group = None + + if group_id in gateway.groups: + group = gateway.groups[group_id] + await group.async_add_members(members) + if not group: + client.send_result_error(command, "G1", "ZHA Group not found") + return + client.send_result_success( + command, data={GROUP: group.info_object}, response_type=UpdateGroupResponse + ) + + +class RemoveGroupMembersCommand(AddGroupMembersCommand): + """Remove group members command.""" + + command: Literal[APICommands.REMOVE_GROUP_MEMBERS] = ( + APICommands.REMOVE_GROUP_MEMBERS + ) + + +@decorators.websocket_command(RemoveGroupMembersCommand) +@decorators.async_response +async def remove_group_members( + gateway: WebSocketServerGateway, client: Client, command: RemoveGroupMembersCommand +) -> None: + """Remove members from a ZHA group.""" + group_id = command.group_id + members = command.members + group = None + + if group_id in gateway.groups: + group = gateway.groups[group_id] + await group.async_remove_members(members) + if not group: + client.send_result_error(command, "G1", "ZHA Group not found") + return + client.send_result_success( + command, data={GROUP: group.info_object}, response_type=UpdateGroupResponse + ) + + +class StopServerCommand(WebSocketCommand): + """Stop the server.""" + + command: Literal[APICommands.STOP_SERVER] = APICommands.STOP_SERVER + + +@decorators.websocket_command(StopServerCommand) +@decorators.async_response +async def stop_server( + gateway: WebSocketServerGateway, client: Client, command: WebSocketCommand +) -> None: + """Stop the Zigbee network.""" + client.send_result_success(command) + await gateway.stop_server() + + +class GetApplicationStateCommand(WebSocketCommand): + """Get the application state.""" + + command: Literal[APICommands.GET_APPLICATION_STATE] = ( + APICommands.GET_APPLICATION_STATE + ) + + +@decorators.websocket_command(GetApplicationStateCommand) +@decorators.async_response +async def get_application_state( + gateway: WebSocketServerGateway, client: Client, command: GetApplicationStateCommand +) -> None: + """Get the application state.""" + state = gateway.application_controller.state + client.send_result_success( + command, data={"state": state}, response_type=GetApplicationStateResponse + ) + + +def load_api(gateway: WebSocketServerGateway) -> None: + """Load the api command handlers.""" + register_api_command(gateway, start_network) + register_api_command(gateway, stop_network) + register_api_command(gateway, get_devices) + register_api_command(gateway, reconfigure_device) + register_api_command(gateway, get_groups) + register_api_command(gateway, create_group) + register_api_command(gateway, remove_groups) + register_api_command(gateway, add_group_members) + register_api_command(gateway, remove_group_members) + register_api_command(gateway, permit_joining) + register_api_command(gateway, remove_device) + register_api_command(gateway, update_topology) + register_api_command(gateway, read_cluster_attributes) + register_api_command(gateway, write_cluster_attribute) + register_api_command(gateway, stop_server) + register_api_command(gateway, get_application_state) diff --git a/zha/const.py b/zha/const.py index c96c47daf..e963d9a3f 100644 --- a/zha/const.py +++ b/zha/const.py @@ -8,12 +8,73 @@ EVENT_TYPE: Final[str] = "event_type" MESSAGE_TYPE: Final[str] = "message_type" +MODEL_CLASS_NAME: Final[str] = "model_class_name" + +COMMAND: Final[str] = "command" class EventTypes(StrEnum): """WS event types.""" - CONTROLLER_EVENT = "controller_event" + CONTROLLER_EVENT = "zha_gateway_message" PLATFORM_ENTITY_EVENT = "platform_entity_event" RAW_ZCL_EVENT = "raw_zcl_event" DEVICE_EVENT = "device_event" + ENTITY_EVENT = "entity" + CLUSTER_HANDLER_EVENT = "cluster_handler_event" + + +class ClusterHandlerEvents(StrEnum): + """Cluster handler events.""" + + CLUSTER_HANDLER_STATE_CHANGED = "cluster_handler_state_changed" + CLUSTER_HANDLER_ATTRIBUTE_UPDATED = "cluster_handler_attribute_updated" + + +class EntityEvents(StrEnum): + """Entity events.""" + + STATE_CHANGED = "state_changed" + + +class MessageTypes(StrEnum): + """WS message types.""" + + EVENT = "event" + RESULT = "result" + + +class ControllerEvents(StrEnum): + """WS controller events.""" + + DEVICE_JOINED = "device_joined" + RAW_DEVICE_INITIALIZED = "raw_device_initialized" + DEVICE_REMOVED = "device_removed" + DEVICE_LEFT = "device_left" + DEVICE_FULLY_INITIALIZED = "device_fully_initialized" + DEVICE_CONFIGURED = "device_configured" + GROUP_MEMBER_ADDED = "group_member_added" + GROUP_MEMBER_REMOVED = "group_member_removed" + GROUP_ADDED = "group_added" + GROUP_REMOVED = "group_removed" + CONNECTION_LOST = "connection_lost" + + +class PlatformEntityEvents(StrEnum): + """WS platform entity events.""" + + PLATFORM_ENTITY_STATE_CHANGED = "platform_entity_state_changed" + + +class RawZCLEvents(StrEnum): + """WS raw ZCL events.""" + + ATTRIBUTE_UPDATED = "attribute_updated" + + +class DeviceEvents(StrEnum): + """Events that devices can broadcast.""" + + DEVICE_OFFLINE = "device_offline" + DEVICE_ONLINE = "device_online" + ZHA_EVENT = "zha_event" diff --git a/zha/event.py b/zha/event.py index 6a31f775b..b78135089 100644 --- a/zha/event.py +++ b/zha/event.py @@ -1,4 +1,4 @@ -"""Provide Event base classes for zhaws.""" +"""Provide Event base classes for zha.""" from __future__ import annotations diff --git a/zha/model.py b/zha/model.py new file mode 100644 index 000000000..2a7088fd0 --- /dev/null +++ b/zha/model.py @@ -0,0 +1,171 @@ +"""Shared models for ZHA.""" + +from __future__ import annotations + +from collections.abc import Callable +from enum import Enum +import logging +from typing import Annotated, Any, Literal, Optional, Union, get_args + +from pydantic import ( + BaseModel as PydanticBaseModel, + ConfigDict, + Discriminator, + Field, + Tag, + computed_field, + field_serializer, + field_validator, +) +from zigpy.types.named import EUI64, NWK + +from zha.const import MODEL_CLASS_NAME, MessageTypes +from zha.event import EventBase + +_LOGGER = logging.getLogger(__name__) + + +def convert_ieee(ieee: Optional[Union[str, EUI64]]) -> Optional[EUI64]: + """Convert ieee to EUI64.""" + if ieee is None: + return None + if isinstance(ieee, str): + return EUI64.convert(ieee) + return ieee + + +def convert_nwk(nwk: Optional[Union[int, str, NWK]]) -> Optional[NWK]: + """Convert int to NWK.""" + if isinstance(nwk, int) and not isinstance(nwk, NWK): + return NWK(nwk) + if isinstance(nwk, str): + return NWK(int(nwk, base=16)) + return nwk + + +def convert_enum(enum_type: Enum) -> Callable[[str | Enum], Enum]: + """Convert enum name to enum instance.""" + + def _convert_enum(enum_name_or_instance: str | Enum) -> Enum: + """Convert extended_pan_id to ExtendedPanId.""" + if isinstance(enum_name_or_instance, str): + return enum_type[enum_name_or_instance] # type: ignore + return enum_name_or_instance + + return _convert_enum + + +def convert_int(zigpy_type: type) -> Any: + """Convert int to zigpy type.""" + + def _convert_int(value: int) -> Any: + """Convert int to zigpy type.""" + return zigpy_type(value) + + return _convert_int + + +class BaseModel(PydanticBaseModel): + """Base model for ZHA models.""" + + model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow") + + _convert_ieee = field_validator( + "ieee", "device_ieee", mode="before", check_fields=False + )(convert_ieee) + + _convert_nwk = field_validator( + "nwk", "dest_nwk", "next_hop", mode="before", check_fields=False + )(convert_nwk) + + @field_serializer("ieee", "device_ieee", check_fields=False) + def serialize_ieee(self, ieee: EUI64): + """Customize how ieee is serialized.""" + if ieee is not None: + return str(ieee) + return ieee + + @field_serializer( + "nwk", "dest_nwk", "next_hop", when_used="json", check_fields=False + ) + def serialize_nwk(self, nwk: NWK): + """Serialize nwk as hex string.""" + return repr(nwk) + + +class TypedBaseModel(BaseModel): + """Typed base model for use in discriminated unions.""" + + @computed_field # type: ignore + @property + def model_class_name(self) -> str: + """Property to create type field from class name when serializing.""" + return self.__class__.__name__ + + @classmethod + def _tag(cls): + """Create a pydantic `Tag` for this class to include it in tagged unions.""" + return Annotated[cls, Tag(cls.__name__)] + + @staticmethod + def _discriminator(): + """Create a pydantic `Discriminator` for a tagged union of `TypedBaseModel`.""" + return Field(discriminator=Discriminator(TypedBaseModel._get_model_class_name)) + + @staticmethod + def _get_model_class_name(x: Any) -> str | None: + """Get the model_class_name from an instance or serialized `dict` of `TypedBaseModel`. + + This is a callable for pydantic Discriminator to discriminate between types in a + tagged union of `TypedBaseModel` child classes. + + If given an instance of `TypedBaseModel` then this method is being called to + serialize an instance. The model_class_name field of the entry for this instance should be + its class name. + + If given a dictionary, then an instance is being deserialized. The name of the + class to be instantiated is given by the model_class_name field, and the remaining fields + should be passed as fields to the class. + + In any other case, return `None` to cause a pydantic validation error. + + Args: + x: `TypedBaseModel` instance or serialized `dict` of a `TypedBaseModel` + + """ + match x: + case TypedBaseModel(): + return x.__class__.__name__ + case dict() as serialized: + return serialized.pop(MODEL_CLASS_NAME, None) + case _: + return None + + +def as_tagged_union(union): + """Create a tagged union from a `Union` of `TypedBaseModel`. + + Members will be tagged with their class name to be discriminated by pydantic. + + Args: + union: `Union` of `TypedBaseModel` to convert to a tagged union + + """ + union_members = get_args(union) + + return Annotated[ + Union[tuple(cls._tag() for cls in union_members)], + TypedBaseModel._discriminator(), + ] + + +class BaseEvent(TypedBaseModel): + """Base model for ZHA events.""" + + message_type: Literal[MessageTypes.EVENT] = MessageTypes.EVENT + event_type: str + event: str + + +class BaseEventedModel(EventBase, BaseModel): + """Base evented model.""" diff --git a/zha/websocket/__init__.py b/zha/websocket/__init__.py new file mode 100644 index 000000000..0a01109b3 --- /dev/null +++ b/zha/websocket/__init__.py @@ -0,0 +1,9 @@ +"""Websocket module for Zigbee Home Automation.""" + +from __future__ import annotations + +from zha.exceptions import ZHAException + + +class ZHAWebSocketException(ZHAException): + """Exception raised by websocket errors.""" diff --git a/zha/websocket/client/__init__.py b/zha/websocket/client/__init__.py new file mode 100644 index 000000000..fdc0da558 --- /dev/null +++ b/zha/websocket/client/__init__.py @@ -0,0 +1 @@ +"""Client for the ZHA websocket server.""" diff --git a/zha/websocket/client/__main__.py b/zha/websocket/client/__main__.py new file mode 100644 index 000000000..7c42906e2 --- /dev/null +++ b/zha/websocket/client/__main__.py @@ -0,0 +1,9 @@ +"""Main module for zha.""" + +from websockets.__main__ import main as websockets_cli + +if __name__ == "__main__": + # "Importing this module enables command line editing using GNU readline." + import readline # noqa: F401 + + websockets_cli() diff --git a/zha/websocket/client/client.py b/zha/websocket/client/client.py new file mode 100644 index 000000000..d168b592e --- /dev/null +++ b/zha/websocket/client/client.py @@ -0,0 +1,285 @@ +"""Client implementation for the zha.client.""" + +from __future__ import annotations + +import asyncio +import contextlib +import logging +import pprint +from types import TracebackType +from typing import Any + +from aiohttp import ClientSession, ClientWebSocketResponse, client_exceptions +from aiohttp.http_websocket import WSMsgType +from async_timeout import timeout +from pydantic_core import ValidationError + +from zha.const import COMMAND, MESSAGE_TYPE, MessageTypes +from zha.event import EventBase +from zha.websocket import ZHAWebSocketException +from zha.websocket.client.model.messages import Message +from zha.websocket.const import ( + ERROR_CODE, + MESSAGE_ID, + SUCCESS, + ZIGBEE_ERROR, + ZIGBEE_ERROR_CODE, + ZIGBEE_ERROR_MESSAGE, +) +from zha.websocket.server.api.model import WebSocketCommand, WebSocketCommandResponse + +SIZE_PARSE_JSON_EXECUTOR = 8192 +_LOGGER = logging.getLogger(__package__) + + +class Client(EventBase): + """Class to manage the IoT connection.""" + + def __init__( + self, + ws_server_url: str, + *args: Any, + aiohttp_session: ClientSession | None = None, + **kwargs: Any, + ) -> None: + """Initialize the Client class.""" + super().__init__(*args, **kwargs) + self.ws_server_url = ws_server_url + + # Create a session if none is provided + if aiohttp_session is None: + self.aiohttp_session = ClientSession() + self._close_aiohttp_session: bool = True + else: + self.aiohttp_session = aiohttp_session + self._close_aiohttp_session = False + + # The WebSocket client + self._client: ClientWebSocketResponse | None = None + self._loop = asyncio.get_running_loop() + self._result_futures: dict[int, asyncio.Future] = {} + self._listen_task: asyncio.Task | None = None + self._tasks: set[asyncio.Task] = set() + + self._message_id = 0 + + def __repr__(self) -> str: + """Return the representation.""" + prefix = "" if self.connected else "not " + return f"{type(self).__name__}(ws_server_url={self.ws_server_url!r}, {prefix}connected)" + + @property + def connected(self) -> bool: + """Return if we're currently connected.""" + return self._client is not None and not self._client.closed + + def new_message_id(self) -> int: + """Create a new message ID. + + XXX: JSON doesn't define limits for integers but JavaScript itself internally + uses double precision floats for numbers (including in `JSON.parse`), setting + a hard limit of `Number.MAX_SAFE_INTEGER == 2^53 - 1`. We can be more + conservative and just restrict it to the maximum value of a 32-bit signed int. + """ + self._message_id = (self._message_id + 1) % 0x80000000 + return self._message_id + + async def async_send_command( + self, + command: WebSocketCommand, + ) -> WebSocketCommandResponse: + """Send a command and get a response.""" + future: asyncio.Future[WebSocketCommandResponse] = self._loop.create_future() + message_id = command.message_id = self.new_message_id() + self._result_futures[message_id] = future + + try: + async with timeout(20): + await self._send_json_message( + command.model_dump_json(exclude_none=True) + ) + return await future + except TimeoutError: + _LOGGER.exception("Timeout waiting for response") + return WebSocketCommandResponse.model_validate( + {MESSAGE_ID: message_id, SUCCESS: False, COMMAND: command.command} + ) + finally: + self._result_futures.pop(message_id) + + async def async_send_command_no_wait(self, command: WebSocketCommand) -> None: + """Send a command without waiting for the response.""" + command.message_id = self.new_message_id() + task = asyncio.create_task( + self._send_json_message(command.model_dump_json(exclude_none=True)), + name=f"async_send_command_no_wait:{command.command}", + ) + self._tasks.add(task) + task.add_done_callback(self._tasks.remove) + + async def connect(self) -> None: + """Connect to the websocket server.""" + + _LOGGER.debug("Trying to connect") + try: + self._client = await self.aiohttp_session.ws_connect( + self.ws_server_url, + heartbeat=55, + compress=15, + max_msg_size=0, + ) + except client_exceptions.ClientError as err: + _LOGGER.exception("Error connecting to server", exc_info=err) + raise ZHAWebSocketException from err + + async def listen_loop(self) -> None: + """Listen to the websocket.""" + assert self._client is not None + while not self._client.closed: + data = await self._receive_json_or_raise() + self._handle_incoming_message(data) + + async def listen(self) -> None: + """Start listening to the websocket.""" + if not self.connected: + raise ZHAWebSocketException("Not connected when start listening") + + assert self._client + + assert self._listen_task is None + self._listen_task = asyncio.create_task(self.listen_loop()) + + async def disconnect(self) -> None: + """Disconnect the client.""" + _LOGGER.debug("Closing client connection") + + if self._listen_task is not None: + self._listen_task.cancel() + + with contextlib.suppress(asyncio.CancelledError): + await self._listen_task + + self._listen_task = None + + if self._client is not None: + await self._client.close() + + if self._close_aiohttp_session: + await self.aiohttp_session.close() + + _LOGGER.debug("Listen completed. Cleaning up") + + for future in self._result_futures.values(): + future.cancel() + + self._result_futures.clear() + + async def _receive_json_or_raise(self) -> dict: + """Receive json or raise.""" + assert self._client + msg = await self._client.receive() + + if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.CLOSING): + raise ZHAWebSocketException(f"Connection was closed: {msg}") + + if msg.type == WSMsgType.ERROR: + raise ZHAWebSocketException(f"WS message type was ERROR: {msg}") + + if msg.type != WSMsgType.TEXT: + raise ZHAWebSocketException(f"Received non-Text message: {msg}") + + try: + if len(msg.data) > SIZE_PARSE_JSON_EXECUTOR: + data: dict = await self._loop.run_in_executor(None, msg.json) + else: + data = msg.json() + except ValueError as err: + raise ZHAWebSocketException(f"Received invalid JSON: {msg}") from err + + if _LOGGER.isEnabledFor(logging.DEBUG): + _LOGGER.debug("Received message:\n%s\n", pprint.pformat(msg)) + + return data + + def _handle_incoming_message(self, msg: dict) -> None: + """Handle incoming message. + + Run all async tasks in a wrapper to log appropriately. + """ + + try: + message = Message.model_validate(msg).root + except ValidationError as err: + _LOGGER.exception("Error parsing message: %s", msg, exc_info=err) + if msg[MESSAGE_TYPE] == MessageTypes.RESULT: + future = self._result_futures.get(msg[MESSAGE_ID]) + if future is not None: + future.set_exception(ZHAWebSocketException(err)) + return + return + + if message.message_type == MessageTypes.RESULT: + future = self._result_futures.get(message.message_id) + + if future is None: + _LOGGER.debug( + "Unable to handle result message because future for message: {message} is None" + ) + return + + if message.success: + future.set_result(message) + return + + if msg[ERROR_CODE] != ZIGBEE_ERROR: + error = ZHAWebSocketException(msg[MESSAGE_ID], msg[ERROR_CODE]) + else: + error = ZHAWebSocketException( + msg[MESSAGE_ID], + msg[ZIGBEE_ERROR_CODE], + msg[ZIGBEE_ERROR_MESSAGE], + ) + + future.set_exception(error) + return + + if message.message_type != MessageTypes.EVENT: + # Can't handle + _LOGGER.debug( + "Received message with unknown type '%s': %s", + msg[MESSAGE_TYPE], + msg, + ) + return + + try: + self.emit(message.event_type, message) + except Exception as err: # pylint: disable=broad-except + _LOGGER.exception("Error handling event", exc_info=err) + raise ZHAWebSocketException from err + + async def _send_json_message(self, message: str) -> None: + """Send a message. + + Raises NotConnected if client not connected. + """ + if not self.connected: + raise ZHAWebSocketException("Sending message failed: no active connection.") + + _LOGGER.debug("Publishing message:\n%s\n", pprint.pformat(message)) + + assert self._client + assert MESSAGE_ID in message + + await self._client.send_str(message) + + async def __aenter__(self) -> Client: + """Connect to the websocket.""" + await self.connect() + return self + + async def __aexit__( + self, exc_type: Exception, exc_value: str, traceback: TracebackType + ) -> None: + """Disconnect from the websocket.""" + await self.disconnect() diff --git a/zha/websocket/client/helpers.py b/zha/websocket/client/helpers.py new file mode 100644 index 000000000..bf6b3dc00 --- /dev/null +++ b/zha/websocket/client/helpers.py @@ -0,0 +1,1071 @@ +"""Helper classes for zha.client.""" + +from __future__ import annotations + +from typing import Any, Literal, cast + +from zigpy.types.named import EUI64 + +from zha.application.const import ATTR_ENDPOINT_ID, ATTR_IEEE, ATTR_MEMBERS +from zha.application.platforms.alarm_control_panel.model import ( + AlarmControlPanelEntityInfo, +) +from zha.application.platforms.alarm_control_panel.websocket_api import ( + ArmAwayCommand, + ArmHomeCommand, + ArmNightCommand, + DisarmCommand, + TriggerAlarmCommand, +) +from zha.application.platforms.button.model import ButtonEntityInfo +from zha.application.platforms.button.websocket_api import ButtonPressCommand +from zha.application.platforms.climate.model import ThermostatEntityInfo +from zha.application.platforms.climate.websocket_api import ( + ClimateSetFanModeCommand, + ClimateSetHVACModeCommand, + ClimateSetPresetModeCommand, + ClimateSetTemperatureCommand, +) +from zha.application.platforms.cover.model import CoverEntityInfo +from zha.application.platforms.cover.websocket_api import ( + CoverCloseCommand, + CoverCloseTiltCommand, + CoverOpenCommand, + CoverOpenTiltCommand, + CoverRestoreExternalStateAttributesCommand, + CoverSetPositionCommand, + CoverSetTiltPositionCommand, + CoverStopCommand, + CoverStopTiltCommand, +) +from zha.application.platforms.fan.model import FanEntityInfo +from zha.application.platforms.fan.websocket_api import ( + FanSetPercentageCommand, + FanSetPresetModeCommand, + FanTurnOffCommand, + FanTurnOnCommand, +) +from zha.application.platforms.light.const import ColorMode +from zha.application.platforms.light.model import LightEntityInfo +from zha.application.platforms.light.websocket_api import ( + LightRestoreExternalStateAttributesCommand, + LightTurnOffCommand, + LightTurnOnCommand, +) +from zha.application.platforms.lock.model import LockEntityInfo +from zha.application.platforms.lock.websocket_api import ( + LockClearUserLockCodeCommand, + LockDisableUserLockCodeCommand, + LockEnableUserLockCodeCommand, + LockLockCommand, + LockRestoreExternalStateAttributesCommand, + LockSetUserLockCodeCommand, + LockUnlockCommand, +) +from zha.application.platforms.model import BasePlatformEntityInfo +from zha.application.platforms.number.model import NumberEntityInfo +from zha.application.platforms.number.websocket_api import NumberSetValueCommand +from zha.application.platforms.select.model import SelectEntityInfo +from zha.application.platforms.select.websocket_api import ( + SelectRestoreExternalStateAttributesCommand, + SelectSelectOptionCommand, +) +from zha.application.platforms.siren.model import SirenEntityInfo +from zha.application.platforms.siren.websocket_api import ( + SirenTurnOffCommand, + SirenTurnOnCommand, +) +from zha.application.platforms.switch.websocket_api import ( + SwitchTurnOffCommand, + SwitchTurnOnCommand, +) +from zha.application.platforms.update import WebSocketClientFirmwareUpdateEntity +from zha.application.platforms.update.websocket_api import InstallFirmwareCommand +from zha.application.platforms.websocket_api import ( + PlatformEntityDisableCommand, + PlatformEntityEnableCommand, + PlatformEntityRefreshStateCommand, +) +from zha.application.websocket_api import ( + AddGroupMembersCommand, + CreateGroupCommand, + GetApplicationStateCommand, + GetApplicationStateResponse, + GetDevicesCommand, + GetGroupsCommand, + PermitJoiningCommand, + ReadClusterAttributesCommand, + ReconfigureDeviceCommand, + RemoveDeviceCommand, + RemoveGroupMembersCommand, + RemoveGroupsCommand, + StartNetworkCommand, + StopNetworkCommand, + StopServerCommand, + UpdateTopologyCommand, + WriteClusterAttributeCommand, +) +from zha.websocket.client.client import Client +from zha.websocket.const import GROUP_ID, GROUP_IDS, GROUP_NAME +from zha.websocket.server.api.model import ( + GetDevicesResponse, + GroupsResponse, + PermitJoiningResponse, + ReadClusterAttributesResponse, + UpdateGroupResponse, + WebSocketCommandResponse, + WriteClusterAttributeResponse, +) +from zha.websocket.server.client import ( + ClientDisconnectCommand, + ClientListenCommand, + ClientListenRawZCLCommand, +) +from zha.zigbee.model import ExtendedDeviceInfo, GroupInfo, GroupMemberReference + + +class LightHelper: + """Helper to issue light commands.""" + + def __init__(self, client: Client): + """Initialize the light helper.""" + self._client: Client = client + + async def turn_on( + self, + light_platform_entity: LightEntityInfo, + brightness: int | None = None, + transition: int | None = None, + flash: str | None = None, + effect: str | None = None, + xy_color: tuple | None = None, + color_temp: int | None = None, + ) -> WebSocketCommandResponse: + """Turn on a light.""" + command = LightTurnOnCommand( + ieee=light_platform_entity.device_ieee, + group_id=light_platform_entity.group_id, + unique_id=light_platform_entity.unique_id, + brightness=brightness, + transition=transition, + flash=flash, + effect=effect, + xy_color=xy_color, + color_temp=color_temp, + ) + return await self._client.async_send_command(command) + + async def turn_off( + self, + light_platform_entity: LightEntityInfo, + transition: int | None = None, + flash: bool | None = None, + ) -> WebSocketCommandResponse: + """Turn off a light.""" + command = LightTurnOffCommand( + ieee=light_platform_entity.device_ieee, + group_id=light_platform_entity.group_id, + unique_id=light_platform_entity.unique_id, + transition=transition, + flash=flash, + ) + return await self._client.async_send_command(command) + + async def restore_external_state_attributes( + self, + light_platform_entity: LightEntityInfo, + state: bool | None, + off_with_transition: bool | None, + off_brightness: int | None, + brightness: int | None, + color_temp: int | None, + xy_color: tuple[float, float] | None, + color_mode: ColorMode | None, + effect: str | None, + ) -> None: + """Restore extra state attributes that are stored outside of the ZCL cache.""" + command = LightRestoreExternalStateAttributesCommand( + ieee=light_platform_entity.device_ieee, + group_id=light_platform_entity.group_id, + unique_id=light_platform_entity.unique_id, + state=state, + off_with_transition=off_with_transition, + off_brightness=off_brightness, + brightness=brightness, + color_temp=color_temp, + xy_color=xy_color, + color_mode=color_mode, + effect=effect, + ) + await self._client.async_send_command(command) + + +class SwitchHelper: + """Helper to issue switch commands.""" + + def __init__(self, client: Client): + """Initialize the switch helper.""" + self._client: Client = client + + async def turn_on( + self, + switch_platform_entity: LightEntityInfo, + ) -> WebSocketCommandResponse: + """Turn on a switch.""" + command = SwitchTurnOnCommand( + ieee=switch_platform_entity.device_ieee, + group_id=switch_platform_entity.group_id, + unique_id=switch_platform_entity.unique_id, + ) + return await self._client.async_send_command(command) + + async def turn_off( + self, + switch_platform_entity: LightEntityInfo, + ) -> WebSocketCommandResponse: + """Turn off a switch.""" + command = SwitchTurnOffCommand( + ieee=switch_platform_entity.device_ieee, + group_id=switch_platform_entity.group_id, + unique_id=switch_platform_entity.unique_id, + ) + return await self._client.async_send_command(command) + + +class SirenHelper: + """Helper to issue siren commands.""" + + def __init__(self, client: Client): + """Initialize the siren helper.""" + self._client: Client = client + + async def turn_on( + self, + siren_platform_entity: SirenEntityInfo, + duration: int | None = None, + volume_level: int | None = None, + tone: int | None = None, + ) -> WebSocketCommandResponse: + """Turn on a siren.""" + command = SirenTurnOnCommand( + ieee=siren_platform_entity.device_ieee, + unique_id=siren_platform_entity.unique_id, + duration=duration, + volume_level=volume_level, + tone=tone, + ) + return await self._client.async_send_command(command) + + async def turn_off( + self, siren_platform_entity: SirenEntityInfo + ) -> WebSocketCommandResponse: + """Turn off a siren.""" + command = SirenTurnOffCommand( + ieee=siren_platform_entity.device_ieee, + unique_id=siren_platform_entity.unique_id, + ) + return await self._client.async_send_command(command) + + +class ButtonHelper: + """Helper to issue button commands.""" + + def __init__(self, client: Client): + """Initialize the button helper.""" + self._client: Client = client + + async def press( + self, button_platform_entity: ButtonEntityInfo + ) -> WebSocketCommandResponse: + """Press a button.""" + command = ButtonPressCommand( + ieee=button_platform_entity.device_ieee, + unique_id=button_platform_entity.unique_id, + ) + return await self._client.async_send_command(command) + + +class CoverHelper: + """helper to issue cover commands.""" + + def __init__(self, client: Client): + """Initialize the cover helper.""" + self._client: Client = client + + async def open_cover( + self, cover_platform_entity: CoverEntityInfo + ) -> WebSocketCommandResponse: + """Open a cover.""" + command = CoverOpenCommand( + ieee=cover_platform_entity.device_ieee, + unique_id=cover_platform_entity.unique_id, + ) + return await self._client.async_send_command(command) + + async def close_cover( + self, cover_platform_entity: CoverEntityInfo + ) -> WebSocketCommandResponse: + """Close a cover.""" + command = CoverCloseCommand( + ieee=cover_platform_entity.device_ieee, + unique_id=cover_platform_entity.unique_id, + ) + return await self._client.async_send_command(command) + + async def open_cover_tilt( + self, cover_platform_entity: CoverEntityInfo + ) -> WebSocketCommandResponse: + """Open cover tilt.""" + command = CoverOpenTiltCommand( + ieee=cover_platform_entity.device_ieee, + unique_id=cover_platform_entity.unique_id, + ) + return await self._client.async_send_command(command) + + async def close_cover_tilt( + self, cover_platform_entity: CoverEntityInfo + ) -> WebSocketCommandResponse: + """Open cover tilt.""" + command = CoverCloseTiltCommand( + ieee=cover_platform_entity.device_ieee, + unique_id=cover_platform_entity.unique_id, + ) + return await self._client.async_send_command(command) + + async def stop_cover( + self, cover_platform_entity: CoverEntityInfo + ) -> WebSocketCommandResponse: + """Stop a cover.""" + command = CoverStopCommand( + ieee=cover_platform_entity.device_ieee, + unique_id=cover_platform_entity.unique_id, + ) + return await self._client.async_send_command(command) + + async def set_cover_position( + self, + cover_platform_entity: CoverEntityInfo, + position: int, + ) -> WebSocketCommandResponse: + """Set a cover position.""" + command = CoverSetPositionCommand( + ieee=cover_platform_entity.device_ieee, + unique_id=cover_platform_entity.unique_id, + position=position, + ) + return await self._client.async_send_command(command) + + async def set_cover_tilt_position( + self, + cover_platform_entity: CoverEntityInfo, + tilt_position: int, + ) -> WebSocketCommandResponse: + """Set a cover tilt position.""" + command = CoverSetTiltPositionCommand( + ieee=cover_platform_entity.device_ieee, + unique_id=cover_platform_entity.unique_id, + tilt_position=tilt_position, + ) + return await self._client.async_send_command(command) + + async def stop_cover_tilt( + self, cover_platform_entity: CoverEntityInfo + ) -> WebSocketCommandResponse: + """Stop a cover tilt.""" + command = CoverStopTiltCommand( + ieee=cover_platform_entity.device_ieee, + unique_id=cover_platform_entity.unique_id, + ) + return await self._client.async_send_command(command) + + async def restore_external_state_attributes( + self, + cover_platform_entity: CoverEntityInfo, + state: Literal["open", "opening", "closed", "closing"], + target_lift_position: int, + target_tilt_position: int, + ) -> WebSocketCommandResponse: + """Stop a cover tilt.""" + command = CoverRestoreExternalStateAttributesCommand( + ieee=cover_platform_entity.device_ieee, + unique_id=cover_platform_entity.unique_id, + state=state, + target_lift_position=target_lift_position, + target_tilt_position=target_tilt_position, + ) + return await self._client.async_send_command(command) + + +class FanHelper: + """Helper to issue fan commands.""" + + def __init__(self, client: Client): + """Initialize the fan helper.""" + self._client: Client = client + + async def turn_on( + self, + fan_platform_entity: FanEntityInfo, + speed: str | None = None, + percentage: int | None = None, + preset_mode: str | None = None, + ) -> WebSocketCommandResponse: + """Turn on a fan.""" + command = FanTurnOnCommand( + ieee=fan_platform_entity.device_ieee, + group_id=fan_platform_entity.group_id, + unique_id=fan_platform_entity.unique_id, + speed=speed, + percentage=percentage, + preset_mode=preset_mode, + ) + return await self._client.async_send_command(command) + + async def turn_off( + self, + fan_platform_entity: FanEntityInfo, + ) -> WebSocketCommandResponse: + """Turn off a fan.""" + command = FanTurnOffCommand( + ieee=fan_platform_entity.device_ieee, + group_id=fan_platform_entity.group_id, + unique_id=fan_platform_entity.unique_id, + ) + return await self._client.async_send_command(command) + + async def set_fan_percentage( + self, + fan_platform_entity: FanEntityInfo, + percentage: int, + ) -> WebSocketCommandResponse: + """Set a fan percentage.""" + command = FanSetPercentageCommand( + ieee=fan_platform_entity.device_ieee, + group_id=fan_platform_entity.group_id, + unique_id=fan_platform_entity.unique_id, + percentage=percentage, + ) + return await self._client.async_send_command(command) + + async def set_fan_preset_mode( + self, + fan_platform_entity: FanEntityInfo, + preset_mode: str, + ) -> WebSocketCommandResponse: + """Set a fan preset mode.""" + command = FanSetPresetModeCommand( + ieee=fan_platform_entity.device_ieee, + group_id=fan_platform_entity.group_id, + unique_id=fan_platform_entity.unique_id, + preset_mode=preset_mode, + ) + return await self._client.async_send_command(command) + + +class LockHelper: + """Helper to issue lock commands.""" + + def __init__(self, client: Client): + """Initialize the lock helper.""" + self._client: Client = client + + async def lock( + self, lock_platform_entity: LockEntityInfo + ) -> WebSocketCommandResponse: + """Lock a lock.""" + command = LockLockCommand( + ieee=lock_platform_entity.device_ieee, + unique_id=lock_platform_entity.unique_id, + ) + return await self._client.async_send_command(command) + + async def unlock( + self, lock_platform_entity: LockEntityInfo + ) -> WebSocketCommandResponse: + """Unlock a lock.""" + command = LockUnlockCommand( + ieee=lock_platform_entity.device_ieee, + unique_id=lock_platform_entity.unique_id, + ) + return await self._client.async_send_command(command) + + async def set_user_lock_code( + self, + lock_platform_entity: LockEntityInfo, + code_slot: int, + user_code: str, + ) -> WebSocketCommandResponse: + """Set a user lock code.""" + command = LockSetUserLockCodeCommand( + ieee=lock_platform_entity.device_ieee, + unique_id=lock_platform_entity.unique_id, + code_slot=code_slot, + user_code=user_code, + ) + return await self._client.async_send_command(command) + + async def clear_user_lock_code( + self, + lock_platform_entity: LockEntityInfo, + code_slot: int, + ) -> WebSocketCommandResponse: + """Clear a user lock code.""" + command = LockClearUserLockCodeCommand( + ieee=lock_platform_entity.device_ieee, + unique_id=lock_platform_entity.unique_id, + code_slot=code_slot, + ) + return await self._client.async_send_command(command) + + async def enable_user_lock_code( + self, + lock_platform_entity: LockEntityInfo, + code_slot: int, + ) -> WebSocketCommandResponse: + """Enable a user lock code.""" + command = LockEnableUserLockCodeCommand( + ieee=lock_platform_entity.device_ieee, + unique_id=lock_platform_entity.unique_id, + code_slot=code_slot, + ) + return await self._client.async_send_command(command) + + async def disable_user_lock_code( + self, + lock_platform_entity: LockEntityInfo, + code_slot: int, + ) -> WebSocketCommandResponse: + """Disable a user lock code.""" + command = LockDisableUserLockCodeCommand( + ieee=lock_platform_entity.device_ieee, + unique_id=lock_platform_entity.unique_id, + code_slot=code_slot, + ) + return await self._client.async_send_command(command) + + async def restore_external_state_attributes( + self, + lock_platform_entity: LockEntityInfo, + state: Literal["locked", "unlocked"] | None, + ) -> WebSocketCommandResponse: + """Restore external state attributes.""" + command = LockRestoreExternalStateAttributesCommand( + ieee=lock_platform_entity.device_ieee, + unique_id=lock_platform_entity.unique_id, + state=state, + ) + return await self._client.async_send_command(command) + + +class NumberHelper: + """Helper to issue number commands.""" + + def __init__(self, client: Client): + """Initialize the number helper.""" + self._client: Client = client + + async def set_value( + self, + number_platform_entity: NumberEntityInfo, + value: int | float, + ) -> WebSocketCommandResponse: + """Set a number.""" + command = NumberSetValueCommand( + ieee=number_platform_entity.device_ieee, + unique_id=number_platform_entity.unique_id, + value=value, + ) + return await self._client.async_send_command(command) + + +class SelectHelper: + """Helper to issue select commands.""" + + def __init__(self, client: Client): + """Initialize the select helper.""" + self._client: Client = client + + async def select_option( + self, + select_platform_entity: SelectEntityInfo, + option: str | int, + ) -> WebSocketCommandResponse: + """Set a select.""" + command = SelectSelectOptionCommand( + ieee=select_platform_entity.device_ieee, + unique_id=select_platform_entity.unique_id, + option=option, + ) + return await self._client.async_send_command(command) + + async def restore_external_state_attributes( + self, + select_platform_entity: SelectEntityInfo, + state: str | None, + ) -> WebSocketCommandResponse: + """Restore external state attributes.""" + command = SelectRestoreExternalStateAttributesCommand( + ieee=select_platform_entity.device_ieee, + unique_id=select_platform_entity.unique_id, + state=state, + ) + return await self._client.async_send_command(command) + + +class ClimateHelper: + """Helper to issue climate commands.""" + + def __init__(self, client: Client): + """Initialize the climate helper.""" + self._client: Client = client + + async def set_hvac_mode( + self, + climate_platform_entity: ThermostatEntityInfo, + hvac_mode: Literal[ + "heat_cool", "heat", "cool", "auto", "dry", "fan_only", "off" + ], + ) -> WebSocketCommandResponse: + """Set a climate.""" + command = ClimateSetHVACModeCommand( + ieee=climate_platform_entity.device_ieee, + unique_id=climate_platform_entity.unique_id, + hvac_mode=hvac_mode, + ) + return await self._client.async_send_command(command) + + async def set_temperature( + self, + climate_platform_entity: ThermostatEntityInfo, + hvac_mode: None + | ( + Literal["heat_cool", "heat", "cool", "auto", "dry", "fan_only", "off"] + ) = None, + temperature: float | None = None, + target_temp_high: float | None = None, + target_temp_low: float | None = None, + ) -> WebSocketCommandResponse: + """Set a climate.""" + command = ClimateSetTemperatureCommand( + ieee=climate_platform_entity.device_ieee, + unique_id=climate_platform_entity.unique_id, + temperature=temperature, + target_temp_high=target_temp_high, + target_temp_low=target_temp_low, + hvac_mode=hvac_mode, + ) + return await self._client.async_send_command(command) + + async def set_fan_mode( + self, + climate_platform_entity: ThermostatEntityInfo, + fan_mode: str, + ) -> WebSocketCommandResponse: + """Set a climate.""" + command = ClimateSetFanModeCommand( + ieee=climate_platform_entity.device_ieee, + unique_id=climate_platform_entity.unique_id, + fan_mode=fan_mode, + ) + return await self._client.async_send_command(command) + + async def set_preset_mode( + self, + climate_platform_entity: ThermostatEntityInfo, + preset_mode: str, + ) -> WebSocketCommandResponse: + """Set a climate.""" + command = ClimateSetPresetModeCommand( + ieee=climate_platform_entity.device_ieee, + unique_id=climate_platform_entity.unique_id, + preset_mode=preset_mode, + ) + return await self._client.async_send_command(command) + + +class AlarmControlPanelHelper: + """Helper to issue alarm control panel commands.""" + + def __init__(self, client: Client): + """Initialize the alarm control panel helper.""" + self._client: Client = client + + async def disarm( + self, + alarm_control_panel_platform_entity: AlarmControlPanelEntityInfo, + code: str, + ) -> WebSocketCommandResponse: + """Disarm an alarm control panel.""" + command = DisarmCommand( + ieee=alarm_control_panel_platform_entity.device_ieee, + unique_id=alarm_control_panel_platform_entity.unique_id, + code=code, + ) + return await self._client.async_send_command(command) + + async def arm_home( + self, + alarm_control_panel_platform_entity: AlarmControlPanelEntityInfo, + code: str, + ) -> WebSocketCommandResponse: + """Arm an alarm control panel in home mode.""" + command = ArmHomeCommand( + ieee=alarm_control_panel_platform_entity.device_ieee, + unique_id=alarm_control_panel_platform_entity.unique_id, + code=code, + ) + return await self._client.async_send_command(command) + + async def arm_away( + self, + alarm_control_panel_platform_entity: AlarmControlPanelEntityInfo, + code: str, + ) -> WebSocketCommandResponse: + """Arm an alarm control panel in away mode.""" + command = ArmAwayCommand( + ieee=alarm_control_panel_platform_entity.device_ieee, + unique_id=alarm_control_panel_platform_entity.unique_id, + code=code, + ) + return await self._client.async_send_command(command) + + async def arm_night( + self, + alarm_control_panel_platform_entity: AlarmControlPanelEntityInfo, + code: str, + ) -> WebSocketCommandResponse: + """Arm an alarm control panel in night mode.""" + command = ArmNightCommand( + ieee=alarm_control_panel_platform_entity.device_ieee, + unique_id=alarm_control_panel_platform_entity.unique_id, + code=code, + ) + return await self._client.async_send_command(command) + + async def trigger( + self, + alarm_control_panel_platform_entity: AlarmControlPanelEntityInfo, + ) -> WebSocketCommandResponse: + """Trigger an alarm control panel alarm.""" + command = TriggerAlarmCommand( + ieee=alarm_control_panel_platform_entity.device_ieee, + unique_id=alarm_control_panel_platform_entity.unique_id, + ) + return await self._client.async_send_command(command) + + +class PlatformEntityHelper: + """Helper to send global platform entity commands.""" + + def __init__(self, client: Client): + """Initialize the platform entity helper.""" + self._client: Client = client + + async def refresh_state( + self, platform_entity: BasePlatformEntityInfo + ) -> WebSocketCommandResponse: + """Refresh the state of a platform entity.""" + command = PlatformEntityRefreshStateCommand( + ieee=platform_entity.device_ieee, + unique_id=platform_entity.unique_id, + platform=platform_entity.platform, + ) + return await self._client.async_send_command(command) + + async def enable( + self, platform_entity: BasePlatformEntityInfo + ) -> WebSocketCommandResponse: + """Enable a platform entity.""" + command = PlatformEntityEnableCommand( + ieee=platform_entity.device_ieee, + unique_id=platform_entity.unique_id, + platform=platform_entity.platform, + ) + return await self._client.async_send_command(command) + + async def disable( + self, platform_entity: BasePlatformEntityInfo + ) -> WebSocketCommandResponse: + """Disable a platform entity.""" + command = PlatformEntityDisableCommand( + ieee=platform_entity.device_ieee, + unique_id=platform_entity.unique_id, + platform=platform_entity.platform, + ) + return await self._client.async_send_command(command) + + +class ClientHelper: + """Helper to send client specific commands.""" + + def __init__(self, client: Client): + """Initialize the client helper.""" + self._client: Client = client + + async def listen(self) -> WebSocketCommandResponse: + """Listen for incoming messages.""" + command = ClientListenCommand() + await self._client.listen() + return await self._client.async_send_command(command) + + async def listen_raw_zcl(self) -> WebSocketCommandResponse: + """Listen for incoming raw ZCL messages.""" + command = ClientListenRawZCLCommand() + return await self._client.async_send_command(command) + + async def disconnect(self) -> None: + """Disconnect this client from the server.""" + command = ClientDisconnectCommand() + await self._client.async_send_command_no_wait(command) + await self._client.disconnect() + + +class GroupHelper: + """Helper to send group commands.""" + + def __init__(self, client: Client): + """Initialize the group helper.""" + self._client: Client = client + + async def get_groups(self) -> dict[int, GroupInfo]: + """Get the groups.""" + response = cast( + GroupsResponse, + await self._client.async_send_command(GetGroupsCommand()), + ) + return response.groups + + async def create_group( + self, + name: str, + group_id: int | None = None, + members: list[GroupMemberReference] | None = None, + ) -> GroupInfo: + """Create a new group.""" + request_data: dict[str, Any] = { + GROUP_NAME: name, + GROUP_ID: group_id, + } + if members is not None: + request_data[ATTR_MEMBERS] = [ + {ATTR_IEEE: member.ieee, ATTR_ENDPOINT_ID: member.endpoint_id} + for member in members + ] + + command = CreateGroupCommand(**request_data) + response = cast( + UpdateGroupResponse, + await self._client.async_send_command(command), + ) + return response.group + + async def remove_groups(self, groups: list[GroupInfo]) -> dict[int, GroupInfo]: + """Remove groups.""" + request: dict[str, Any] = { + GROUP_IDS: [group.group_id for group in groups], + } + command = RemoveGroupsCommand(**request) + response = cast( + GroupsResponse, + await self._client.async_send_command(command), + ) + return response.groups + + async def add_group_members( + self, group: GroupInfo, members: list[GroupMemberReference] + ) -> GroupInfo: + """Add members to a group.""" + request_data: dict[str, Any] = { + GROUP_ID: group.group_id, + ATTR_MEMBERS: [ + {ATTR_IEEE: member.ieee, ATTR_ENDPOINT_ID: member.endpoint_id} + for member in members + ], + } + + command = AddGroupMembersCommand(**request_data) + response = cast( + UpdateGroupResponse, + await self._client.async_send_command(command), + ) + return response.group + + async def remove_group_members( + self, group: GroupInfo, members: list[GroupMemberReference] + ) -> GroupInfo: + """Remove members from a group.""" + request_data: dict[str, Any] = { + GROUP_ID: group.group_id, + ATTR_MEMBERS: [ + {ATTR_IEEE: member.ieee, ATTR_ENDPOINT_ID: member.endpoint_id} + for member in members + ], + } + + command = RemoveGroupMembersCommand(**request_data) + response = cast( + UpdateGroupResponse, + await self._client.async_send_command(command), + ) + return response.group + + +class UpdateHelper: + """Helper to send firmware update commands.""" + + def __init__(self, client: Client): + """Initialize the device helper.""" + self._client: Client = client + + async def install_firmware( + self, + firmware_update_entity: WebSocketClientFirmwareUpdateEntity, + version: str | None = None, + ) -> dict[EUI64, ExtendedDeviceInfo]: + """Get the groups.""" + + return await self._client.async_send_command( + InstallFirmwareCommand( + ieee=firmware_update_entity.info_object.device_ieee, + unique_id=firmware_update_entity.info_object.unique_id, + platform=firmware_update_entity.info_object.platform, + version=version, + ) + ) + + +class DeviceHelper: + """Helper to send device commands.""" + + def __init__(self, client: Client): + """Initialize the device helper.""" + self._client: Client = client + + async def get_devices(self) -> dict[EUI64, ExtendedDeviceInfo]: + """Get the groups.""" + response = cast( + GetDevicesResponse, + await self._client.async_send_command(GetDevicesCommand()), + ) + return response.devices + + async def reconfigure_device(self, device: ExtendedDeviceInfo) -> None: + """Reconfigure a device.""" + await self._client.async_send_command( + ReconfigureDeviceCommand(ieee=device.ieee) + ) + + async def remove_device(self, device: ExtendedDeviceInfo) -> None: + """Remove a device.""" + await self._client.async_send_command(RemoveDeviceCommand(ieee=device.ieee)) + + async def read_cluster_attributes( + self, + device: ExtendedDeviceInfo, + cluster_id: int, + cluster_type: str, + endpoint_id: int, + attributes: list[str], + manufacturer_code: int | None = None, + ) -> ReadClusterAttributesResponse: + """Read cluster attributes.""" + response = cast( + ReadClusterAttributesResponse, + await self._client.async_send_command( + ReadClusterAttributesCommand( + ieee=device.ieee, + endpoint_id=endpoint_id, + cluster_id=cluster_id, + cluster_type=cluster_type, + attributes=attributes, + manufacturer_code=manufacturer_code, + ) + ), + ) + return response + + async def write_cluster_attribute( + self, + device: ExtendedDeviceInfo, + cluster_id: int, + cluster_type: str, + endpoint_id: int, + attribute: str, + value: Any, + manufacturer_code: int | None = None, + ) -> WriteClusterAttributeResponse: + """Set the value for a cluster attribute.""" + response = cast( + WriteClusterAttributeResponse, + await self._client.async_send_command( + WriteClusterAttributeCommand( + ieee=device.ieee, + endpoint_id=endpoint_id, + cluster_id=cluster_id, + cluster_type=cluster_type, + attribute=attribute, + value=value, + manufacturer_code=manufacturer_code, + ) + ), + ) + return response + + +class NetworkHelper: + """Helper for network commands.""" + + def __init__(self, client: Client): + """Initialize the device helper.""" + self._client: Client = client + + async def permit_joining( + self, duration: int = 255, device: ExtendedDeviceInfo | None = None + ) -> bool: + """Permit joining for a specified duration.""" + # TODO add permit with code support + request_data: dict[str, Any] = { + "duration": duration, + } + if device is not None: + if device.device_type == "EndDevice": + raise ValueError("Device is not a coordinator or router") + request_data[ATTR_IEEE] = device.ieee + command = PermitJoiningCommand(**request_data) + response = cast( + PermitJoiningResponse, + await self._client.async_send_command(command), + ) + return response.success + + async def update_topology(self) -> None: + """Update the network topology.""" + await self._client.async_send_command(UpdateTopologyCommand()) + + async def start_network(self) -> bool: + """Start the Zigbee network.""" + command = StartNetworkCommand() + response = await self._client.async_send_command(command) + return response.success + + async def stop_network(self) -> bool: + """Stop the Zigbee network.""" + response = await self._client.async_send_command(StopNetworkCommand()) + return response.success + + async def get_application_state(self) -> GetApplicationStateResponse: + """Get the application state.""" + return await self._client.async_send_command(GetApplicationStateCommand()) + + +class ServerHelper: + """Helper for server commands.""" + + def __init__(self, client: Client): + """Initialize the helper.""" + self._client: Client = client + + async def stop_server(self) -> bool: + """Stop the websocket server.""" + response = await self._client.async_send_command(StopServerCommand()) + return response.success diff --git a/zha/websocket/client/model/__init__.py b/zha/websocket/client/model/__init__.py new file mode 100644 index 000000000..9f32bfa2f --- /dev/null +++ b/zha/websocket/client/model/__init__.py @@ -0,0 +1 @@ +"""Models for the websocket client module for zha.""" diff --git a/zha/websocket/client/model/messages.py b/zha/websocket/client/model/messages.py new file mode 100644 index 000000000..59a12c10a --- /dev/null +++ b/zha/websocket/client/model/messages.py @@ -0,0 +1,11 @@ +"""Models that represent messages in zha.""" + +from pydantic import RootModel + +from zha.websocket.server.api.model import Messages + + +class Message(RootModel): + """Response model.""" + + root: Messages diff --git a/zha/websocket/const.py b/zha/websocket/const.py new file mode 100644 index 000000000..609980e63 --- /dev/null +++ b/zha/websocket/const.py @@ -0,0 +1,113 @@ +"""Constants.""" + +from enum import StrEnum +from typing import Final + + +class APICommands(StrEnum): + """WS API commands.""" + + # Device commands + GET_DEVICES = "get_devices" + REMOVE_DEVICE = "remove_device" + RECONFIGURE_DEVICE = "reconfigure_device" + READ_CLUSTER_ATTRIBUTES = "read_cluster_attributes" + WRITE_CLUSTER_ATTRIBUTE = "write_cluster_attribute" + + # Zigbee API commands + PERMIT_JOINING = "permit_joining" + START_NETWORK = "start_network" + STOP_NETWORK = "stop_network" + UPDATE_NETWORK_TOPOLOGY = "update_network_topology" + + # Group commands + GET_GROUPS = "get_groups" + CREATE_GROUP = "create_group" + REMOVE_GROUPS = "remove_groups" + ADD_GROUP_MEMBERS = "add_group_members" + REMOVE_GROUP_MEMBERS = "remove_group_members" + + # Server API commands + STOP_SERVER = "stop_server" + GET_APPLICATION_STATE = "get_application_state" + + # Light API commands + LIGHT_TURN_ON = "light_turn_on" + LIGHT_TURN_OFF = "light_turn_off" + LIGHT_RESTORE_EXTERNAL_STATE_ATTRIBUTES = "light_restore_external_state_attributes" + + # Switch API commands + SWITCH_TURN_ON = "switch_turn_on" + SWITCH_TURN_OFF = "switch_turn_off" + + SIREN_TURN_ON = "siren_turn_on" + SIREN_TURN_OFF = "siren_turn_off" + + LOCK_UNLOCK = "lock_unlock" + LOCK_LOCK = "lock_lock" + LOCK_SET_USER_CODE = "lock_set_user_lock_code" + LOCK_ENAABLE_USER_CODE = "lock_enable_user_lock_code" + LOCK_DISABLE_USER_CODE = "lock_disable_user_lock_code" + LOCK_CLEAR_USER_CODE = "lock_clear_user_lock_code" + LOCK_RESTORE_EXTERNAL_STATE_ATTRIBUTES = "lock_restore_external_state_attributes" + + CLIMATE_SET_TEMPERATURE = "climate_set_temperature" + CLIMATE_SET_HVAC_MODE = "climate_set_hvac_mode" + CLIMATE_SET_FAN_MODE = "climate_set_fan_mode" + CLIMATE_SET_PRESET_MODE = "climate_set_preset_mode" + + COVER_OPEN = "cover_open" + COVER_OPEN_TILT = "cover_open_tilt" + COVER_CLOSE = "cover_close" + COVER_CLOSE_TILT = "cover_close_tilt" + COVER_STOP = "cover_stop" + COVER_SET_POSITION = "cover_set_position" + COVER_SET_TILT_POSITION = "cover_set_tilt_position" + COVER_STOP_TILT = "cover_stop_tilt" + COVER_RESTORE_EXTERNAL_STATE_ATTRIBUTES = "cover_restore_external_state_attributes" + + FAN_TURN_ON = "fan_turn_on" + FAN_TURN_OFF = "fan_turn_off" + FAN_SET_PERCENTAGE = "fan_set_percentage" + FAN_SET_PRESET_MODE = "fan_set_preset_mode" + + BUTTON_PRESS = "button_press" + + ALARM_CONTROL_PANEL_DISARM = "alarm_control_panel_disarm" + ALARM_CONTROL_PANEL_ARM_HOME = "alarm_control_panel_arm_home" + ALARM_CONTROL_PANEL_ARM_AWAY = "alarm_control_panel_arm_away" + ALARM_CONTROL_PANEL_ARM_NIGHT = "alarm_control_panel_arm_night" + ALARM_CONTROL_PANEL_TRIGGER = "alarm_control_panel_trigger" + + SELECT_SELECT_OPTION = "select_select_option" + SELECT_RESTORE_EXTERNAL_STATE_ATTRIBUTES = ( + "select_restore_external_state_attributes" + ) + + NUMBER_SET_VALUE = "number_set_value" + + PLATFORM_ENTITY_REFRESH_STATE = "platform_entity_refresh_state" + PLATFORM_ENTITY_ENABLE = "platform_entity_enable" + PLATFORM_ENTITY_DISABLE = "platform_entity_disable" + + CLIENT_LISTEN = "client_listen" + CLIENT_LISTEN_RAW_ZCL = "client_listen_raw_zcl" + CLIENT_DISCONNECT = "client_disconnect" + + FIRMWARE_INSTALL = "firmware_install" + + +DEVICE: Final[str] = "device" +DEVICES: Final[str] = "devices" +GROUPS: Final[str] = "groups" +GROUP_ID: Final[str] = "group_id" +GROUP_IDS: Final[str] = "group_ids" +GROUP_NAME: Final[str] = "group_name" +ERROR_CODE: Final[str] = "error_code" +ERROR_MESSAGE: Final[str] = "error_message" +MESSAGE_ID: Final[str] = "message_id" +SUCCESS: Final[str] = "success" +WEBSOCKET_API: Final[str] = "websocket_api" +ZIGBEE_ERROR_CODE: Final[str] = "zigbee_error_code" +ZIGBEE_ERROR: Final[str] = "zigbee_error" +ZIGBEE_ERROR_MESSAGE: Final[str] = "zigbee_error_message" diff --git a/zha/websocket/server/__init__.py b/zha/websocket/server/__init__.py new file mode 100644 index 000000000..5732f7f2c --- /dev/null +++ b/zha/websocket/server/__init__.py @@ -0,0 +1 @@ +"""Websocket server module for Zigbee Home Automation.""" diff --git a/zha/websocket/server/__main__.py b/zha/websocket/server/__main__.py new file mode 100644 index 000000000..42847319f --- /dev/null +++ b/zha/websocket/server/__main__.py @@ -0,0 +1,73 @@ +"""Websocket application to run a zigpy Zigbee network.""" + +from __future__ import annotations + +import argparse +import asyncio +import json +import logging +from pathlib import Path + +from zha.application.gateway import WebSocketServerGateway +from zha.application.model import ( + WebsocketClientConfiguration, + WebsocketServerConfiguration, + ZHAConfiguration, + ZHAData, +) + +_LOGGER = logging.getLogger(__name__) + + +async def main(config_path: str | None = None) -> None: + """Run the websocket server.""" + if config_path is None: + raise ValueError("config_path must be provided") + else: + _LOGGER.info("Loading configuration from %s", config_path) + path = Path(config_path) + raw_data = json.loads(path.read_text(encoding="utf-8")) + zha_data = ZHAData( + config=ZHAConfiguration.model_validate(raw_data["zha_config"]), + ws_server_config=WebsocketServerConfiguration.model_validate( + raw_data["ws_server_config"] + ), + ws_client_config=WebsocketClientConfiguration.model_validate( + raw_data["ws_client_config"] + ), + zigpy_config=raw_data["zigpy_config"], + ) + async with await WebSocketServerGateway.async_from_config(zha_data) as ws_gateway: + await ws_gateway.async_initialize() + await ws_gateway.async_initialize_devices_and_entities() + await ws_gateway.wait_closed() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Start the ZHAWS gateway") + parser.add_argument( + "--config", type=str, default=None, help="Path to the configuration file" + ) + + args = parser.parse_args() + + from colorlog import ColoredFormatter + + fmt = "%(asctime)s %(levelname)s (%(threadName)s) [%(name)s] %(message)s" + colorfmt = f"%(log_color)s{fmt}%(reset)s" + logging.basicConfig(level=logging.DEBUG) + logging.getLogger().handlers[0].setFormatter( + ColoredFormatter( + colorfmt, + reset=True, + log_colors={ + "DEBUG": "cyan", + "INFO": "green", + "WARNING": "yellow", + "ERROR": "red", + "CRITICAL": "red", + }, + ) + ) + + asyncio.run(main(args.config)) diff --git a/zha/websocket/server/api/__init__.py b/zha/websocket/server/api/__init__.py new file mode 100644 index 000000000..4ebcf0d91 --- /dev/null +++ b/zha/websocket/server/api/__init__.py @@ -0,0 +1,29 @@ +"""Websocket api for zha.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, cast + +from zha.websocket.const import WEBSOCKET_API +from zha.websocket.server.api.model import WebSocketCommand +from zha.websocket.server.api.types import WebSocketCommandHandler + +if TYPE_CHECKING: + from zha.application.gateway import WebSocketServerGateway + + +def register_api_command( + gateway: WebSocketServerGateway, + command_or_handler: str | WebSocketCommandHandler, + handler: WebSocketCommandHandler | None = None, + model: type[WebSocketCommand] | None = None, +) -> None: + """Register a websocket command.""" + # pylint: disable=protected-access + if handler is None: + handler = cast(WebSocketCommandHandler, command_or_handler) + command = handler._ws_command # type: ignore[attr-defined] + model = handler._ws_command_model # type: ignore[attr-defined] + else: + command = command_or_handler + gateway.data[WEBSOCKET_API][command] = (handler, model) diff --git a/zha/websocket/server/api/decorators.py b/zha/websocket/server/api/decorators.py new file mode 100644 index 000000000..e8375efe3 --- /dev/null +++ b/zha/websocket/server/api/decorators.py @@ -0,0 +1,68 @@ +"""Decorators for the Websocket API.""" + +from __future__ import annotations + +from collections.abc import Callable +from functools import wraps +import logging +from typing import TYPE_CHECKING + +from zha.websocket.server.api.model import WebSocketCommand + +if TYPE_CHECKING: + from zha.application.gateway import WebSocketServerGateway + from zha.websocket.server.api.types import ( + AsyncWebSocketCommandHandler, + T_WebSocketCommand, + WebSocketCommandHandler, + ) + from zha.websocket.server.client import Client + +_LOGGER = logging.getLogger(__name__) + + +async def _handle_async_response( + func: AsyncWebSocketCommandHandler, + gateway: WebSocketServerGateway, + client: Client, + msg: T_WebSocketCommand, +) -> None: + """Create a response and handle exception.""" + await func(gateway, client, msg) + + +def async_response( + func: AsyncWebSocketCommandHandler, +) -> WebSocketCommandHandler: + """Decorate an async function to handle WebSocket API messages.""" + + @wraps(func) + def schedule_handler( + gateway: WebSocketServerGateway, client: Client, msg: T_WebSocketCommand + ) -> None: + """Schedule the handler.""" + # As the webserver is now started before the start + # event we do not want to block for websocket responders + gateway.async_create_task( + _handle_async_response(func, gateway, client, msg), + "_handle_async_response", + eager_start=True, + ) + + return schedule_handler + + +def websocket_command( + ws_command: type[WebSocketCommand], +) -> Callable[[WebSocketCommandHandler], WebSocketCommandHandler]: + """Tag a function as a websocket command.""" + command = ws_command.model_fields["command"].default + + def decorate(func: WebSocketCommandHandler) -> WebSocketCommandHandler: + """Decorate ws command function.""" + # pylint: disable=protected-access + func._ws_command_model = ws_command # type: ignore[attr-defined] + func._ws_command = command # type: ignore[attr-defined] + return func + + return decorate diff --git a/zha/websocket/server/api/model.py b/zha/websocket/server/api/model.py new file mode 100644 index 000000000..b5b0dec18 --- /dev/null +++ b/zha/websocket/server/api/model.py @@ -0,0 +1,289 @@ +"""Models for the websocket API.""" + +from typing import TYPE_CHECKING, Any, Literal, Optional + +from pydantic import field_serializer, field_validator +from zigpy.state import CounterGroups, NetworkInfo, NodeInfo, State +from zigpy.types.named import EUI64 + +from zha.application.model import ( + ConnectionLostEvent, + DeviceFullyInitializedEvent, + DeviceJoinedEvent, + DeviceLeftEvent, + DeviceOfflineEvent, + DeviceOnlineEvent, + DeviceRemovedEvent, + GroupAddedEvent, + GroupMemberAddedEvent, + GroupMemberRemovedEvent, + GroupRemovedEvent, + RawDeviceInitializedEvent, +) +from zha.application.platforms.events import EntityStateChangedEvent +from zha.const import MessageTypes +from zha.model import BaseModel, TypedBaseModel, as_tagged_union +from zha.websocket.const import APICommands +from zha.zigbee.cluster_handlers.model import ( + ClusterAttributeUpdatedEvent, + ClusterBindEvent, + ClusterConfigureReportingEvent, + ClusterInfo, + LevelChangeEvent, +) +from zha.zigbee.cluster_handlers.security import ClusterHandlerStateChangedEvent +from zha.zigbee.model import ( + ClusterHandlerConfigurationComplete, + ExtendedDeviceInfo, + GroupInfo, + ZHAEvent, +) + + +class WebSocketCommand(TypedBaseModel): + """Command for the websocket API.""" + + message_id: int = 1 + command: Literal[ + APICommands.STOP_SERVER, + APICommands.CLIENT_LISTEN_RAW_ZCL, + APICommands.CLIENT_DISCONNECT, + APICommands.CLIENT_LISTEN, + APICommands.BUTTON_PRESS, + APICommands.PLATFORM_ENTITY_REFRESH_STATE, + APICommands.PLATFORM_ENTITY_ENABLE, + APICommands.PLATFORM_ENTITY_DISABLE, + APICommands.ALARM_CONTROL_PANEL_DISARM, + APICommands.ALARM_CONTROL_PANEL_ARM_HOME, + APICommands.ALARM_CONTROL_PANEL_ARM_AWAY, + APICommands.ALARM_CONTROL_PANEL_ARM_NIGHT, + APICommands.ALARM_CONTROL_PANEL_TRIGGER, + APICommands.START_NETWORK, + APICommands.STOP_NETWORK, + APICommands.UPDATE_NETWORK_TOPOLOGY, + APICommands.RECONFIGURE_DEVICE, + APICommands.GET_DEVICES, + APICommands.GET_GROUPS, + APICommands.PERMIT_JOINING, + APICommands.ADD_GROUP_MEMBERS, + APICommands.REMOVE_GROUP_MEMBERS, + APICommands.CREATE_GROUP, + APICommands.REMOVE_GROUPS, + APICommands.REMOVE_DEVICE, + APICommands.READ_CLUSTER_ATTRIBUTES, + APICommands.WRITE_CLUSTER_ATTRIBUTE, + APICommands.SIREN_TURN_ON, + APICommands.SIREN_TURN_OFF, + APICommands.SELECT_SELECT_OPTION, + APICommands.SELECT_RESTORE_EXTERNAL_STATE_ATTRIBUTES, + APICommands.NUMBER_SET_VALUE, + APICommands.LOCK_CLEAR_USER_CODE, + APICommands.LOCK_SET_USER_CODE, + APICommands.LOCK_ENAABLE_USER_CODE, + APICommands.LOCK_DISABLE_USER_CODE, + APICommands.LOCK_LOCK, + APICommands.LOCK_UNLOCK, + APICommands.LOCK_RESTORE_EXTERNAL_STATE_ATTRIBUTES, + APICommands.LIGHT_TURN_OFF, + APICommands.LIGHT_TURN_ON, + APICommands.LIGHT_RESTORE_EXTERNAL_STATE_ATTRIBUTES, + APICommands.FAN_SET_PERCENTAGE, + APICommands.FAN_SET_PRESET_MODE, + APICommands.FAN_TURN_ON, + APICommands.FAN_TURN_OFF, + APICommands.COVER_STOP, + APICommands.COVER_SET_POSITION, + APICommands.COVER_OPEN, + APICommands.COVER_CLOSE, + APICommands.COVER_OPEN_TILT, + APICommands.COVER_CLOSE_TILT, + APICommands.COVER_SET_TILT_POSITION, + APICommands.COVER_STOP_TILT, + APICommands.COVER_RESTORE_EXTERNAL_STATE_ATTRIBUTES, + APICommands.CLIMATE_SET_TEMPERATURE, + APICommands.CLIMATE_SET_HVAC_MODE, + APICommands.CLIMATE_SET_FAN_MODE, + APICommands.CLIMATE_SET_PRESET_MODE, + APICommands.SWITCH_TURN_ON, + APICommands.SWITCH_TURN_OFF, + APICommands.FIRMWARE_INSTALL, + APICommands.GET_APPLICATION_STATE, + ] + + +class WebSocketCommandResponse(WebSocketCommand): + """Websocket command response.""" + + message_type: Literal[MessageTypes.RESULT] = MessageTypes.RESULT + success: bool + + +class ErrorResponse(WebSocketCommandResponse): + """Error response model.""" + + success: bool = False + error_code: str + error_message: str + zigbee_error_code: Optional[str] = None + command: APICommands + + +class PermitJoiningResponse(WebSocketCommandResponse): + """Get devices response.""" + + command: Literal[APICommands.PERMIT_JOINING] = APICommands.PERMIT_JOINING + duration: int | None = None + ieee: EUI64 | None = None + + +class GetDevicesResponse(WebSocketCommandResponse): + """Get devices response.""" + + command: Literal[APICommands.GET_DEVICES] = APICommands.GET_DEVICES + devices: dict[EUI64, ExtendedDeviceInfo] + + @field_serializer("devices", check_fields=False) + def serialize_devices(self, devices: dict[EUI64, ExtendedDeviceInfo]) -> dict: + """Serialize devices.""" + return {str(ieee): device for ieee, device in devices.items()} + + @field_validator("devices", mode="before", check_fields=False) + @classmethod + def convert_devices( + cls, devices: dict[str, ExtendedDeviceInfo] + ) -> dict[EUI64, ExtendedDeviceInfo]: + """Convert devices.""" + if all(isinstance(ieee, str) for ieee in devices): + return {EUI64.convert(ieee): device for ieee, device in devices.items()} + return devices + + +class ReadClusterAttributesResponse(WebSocketCommandResponse): + """Read cluster attributes response.""" + + command: Literal[APICommands.READ_CLUSTER_ATTRIBUTES] = ( + APICommands.READ_CLUSTER_ATTRIBUTES + ) + device: ExtendedDeviceInfo + cluster: ClusterInfo + manufacturer_code: Optional[int] + succeeded: dict[str, Any] + failed: dict[str, Any] + + +class AttributeStatus(BaseModel): + """Attribute status.""" + + attribute: str + status: str + + +class WriteClusterAttributeResponse(WebSocketCommandResponse): + """Write cluster attribute response.""" + + command: Literal[APICommands.WRITE_CLUSTER_ATTRIBUTE] = ( + APICommands.WRITE_CLUSTER_ATTRIBUTE + ) + device: ExtendedDeviceInfo + cluster: ClusterInfo + manufacturer_code: Optional[int] + response: AttributeStatus + + +class GroupsResponse(WebSocketCommandResponse): + """Get groups response.""" + + command: Literal[APICommands.GET_GROUPS, APICommands.REMOVE_GROUPS] + groups: dict[int, GroupInfo] + + +class UpdateGroupResponse(WebSocketCommandResponse): + """Update group response.""" + + command: Literal[ + APICommands.CREATE_GROUP, + APICommands.ADD_GROUP_MEMBERS, + APICommands.REMOVE_GROUP_MEMBERS, + ] + group: GroupInfo + + +class GetApplicationStateResponse(WebSocketCommandResponse): + """Get devices response.""" + + command: Literal[APICommands.GET_APPLICATION_STATE] = ( + APICommands.GET_APPLICATION_STATE + ) + state: dict[str, Any] + + @field_validator("state", mode="before", check_fields=False) + @classmethod + def validate_state(cls, value: State | dict[str, Any]) -> dict[str, Any]: + """Validate the state.""" + if isinstance(value, State): + return { + "node_info": value.node_info.as_dict(), + "network_info": value.network_info.as_dict(), + "counters": value.counters, + "broadcast_counters": value.broadcast_counters, + "device_counters": value.device_counters, + "group_counters": value.group_counters, + } + return value + + def get_converted_state(self) -> State: + """Convert state.""" + state: State = State() + state.network_info = NetworkInfo.from_dict(self.state["network_info"]) + state.node_info = NodeInfo.from_dict(self.state["node_info"]) + state.broadcast_counters = CounterGroups().update( + **self.state["broadcast_counters"] + ) + state.counters = CounterGroups().update(**self.state["counters"]) + state.device_counters = CounterGroups().update(**self.state["device_counters"]) + state.group_counters = CounterGroups().update(**self.state["group_counters"]) + return state + + +CommandResponses = ( + WebSocketCommandResponse + | ErrorResponse + | GetDevicesResponse + | GroupsResponse + | PermitJoiningResponse + | UpdateGroupResponse + | ReadClusterAttributesResponse + | WriteClusterAttributeResponse + | GetApplicationStateResponse +) + + +Events = ( + EntityStateChangedEvent + | DeviceJoinedEvent + | RawDeviceInitializedEvent + | DeviceFullyInitializedEvent + | DeviceLeftEvent + | DeviceRemovedEvent + | GroupRemovedEvent + | GroupAddedEvent + | GroupMemberAddedEvent + | GroupMemberRemovedEvent + | DeviceOfflineEvent + | DeviceOnlineEvent + | ZHAEvent + | ConnectionLostEvent + | ClusterAttributeUpdatedEvent + | ClusterBindEvent + | ClusterConfigureReportingEvent + | LevelChangeEvent + | ClusterHandlerStateChangedEvent + | ClusterHandlerConfigurationComplete +) + +Messages = CommandResponses | Events + +if not TYPE_CHECKING: + CommandResponses = as_tagged_union(CommandResponses) + Events = as_tagged_union(Events) + Messages = as_tagged_union(Messages) diff --git a/zha/websocket/server/api/types.py b/zha/websocket/server/api/types.py new file mode 100644 index 000000000..5819a91ca --- /dev/null +++ b/zha/websocket/server/api/types.py @@ -0,0 +1,15 @@ +"""Type information for the websocket api module.""" + +from __future__ import annotations + +from collections.abc import Callable, Coroutine +from typing import Any, TypeVar + +from zha.websocket.server.api.model import WebSocketCommand + +T_WebSocketCommand = TypeVar("T_WebSocketCommand", bound=WebSocketCommand) + +AsyncWebSocketCommandHandler = Callable[ + [Any, Any, T_WebSocketCommand], Coroutine[Any, Any, None] +] +WebSocketCommandHandler = Callable[[Any, Any, T_WebSocketCommand], None] diff --git a/zha/websocket/server/client.py b/zha/websocket/server/client.py new file mode 100644 index 000000000..cab6c64ab --- /dev/null +++ b/zha/websocket/server/client.py @@ -0,0 +1,303 @@ +"""Client classes for zha.""" + +from __future__ import annotations + +from collections.abc import Callable +import json +import logging +from typing import TYPE_CHECKING, Any, Literal + +from pydantic import BaseModel, ValidationError +from websockets.server import WebSocketServerProtocol + +from zha.const import COMMAND, MODEL_CLASS_NAME, EventTypes, MessageTypes +from zha.model import BaseEvent +from zha.websocket.const import ( + ERROR_CODE, + ERROR_MESSAGE, + MESSAGE_ID, + SUCCESS, + WEBSOCKET_API, + ZIGBEE_ERROR, + ZIGBEE_ERROR_CODE, + APICommands, +) +from zha.websocket.server.api import decorators, register_api_command +from zha.websocket.server.api.model import ( + ErrorResponse, + WebSocketCommand, + WebSocketCommandResponse, +) + +if TYPE_CHECKING: + from zha.application.gateway import WebSocketServerGateway + +_LOGGER = logging.getLogger(__name__) + + +class Client: + """ZHA websocket server client implementation.""" + + def __init__( + self, + websocket: WebSocketServerProtocol, + client_manager: ClientManager, + ): + """Initialize the client.""" + self._websocket: WebSocketServerProtocol = websocket + self._client_manager: ClientManager = client_manager + self.receive_events: bool = False + self.receive_raw_zcl_events: bool = False + + @property + def is_connected(self) -> bool: + """Return True if the websocket connection is connected.""" + return self._websocket.open + + def disconnect(self) -> None: + """Disconnect this client and close the websocket.""" + self._client_manager.server_gateway.async_create_task( + self._websocket.close(), name="disconnect", eager_start=True + ) + + def send_event(self, message: BaseEvent) -> None: + """Send event data to this client.""" + message.message_type = MessageTypes.EVENT + self._send_data(message) + + def send_result_success( + self, + command: WebSocketCommand, + data: dict[str, Any] | BaseModel | None = None, + response_type: type[WebSocketCommandResponse] = WebSocketCommandResponse, + ) -> None: + """Send success result prompted by a client request.""" + if data and isinstance(data, BaseModel): + self._send_data(data) + else: + if data is None: + data = {} + self._send_data( + response_type( + **command.model_dump(exclude=[MODEL_CLASS_NAME]), + success=True, + **data, + ) + ) + + def send_result_error( + self, + command: WebSocketCommand, + error_code: str, + error_message: str, + data: dict[str, Any] | None = None, + ) -> None: + """Send error result prompted by a client request.""" + message = { + SUCCESS: False, + MESSAGE_ID: command.message_id, + COMMAND: command.command, + ERROR_CODE: error_code, + ERROR_MESSAGE: error_message, + } + if data: + message.update(data) + self._send_data(ErrorResponse(**message)) + + def send_result_zigbee_error( + self, + command: WebSocketCommand, + error_message: str, + zigbee_error_code: str, + ) -> None: + """Send zigbee error result prompted by a client zigbee request.""" + self.send_result_error( + command, + error_code=ZIGBEE_ERROR, + error_message=error_message, + data={ZIGBEE_ERROR_CODE: zigbee_error_code}, + ) + + def _send_data(self, message: dict[str, Any] | BaseModel) -> None: + """Send data to this client.""" + try: + if isinstance(message, BaseModel): + message_json = message.model_dump_json() + else: + message_json = json.dumps(message) + except ValueError as exc: + _LOGGER.exception("Couldn't serialize data: %s", message, exc_info=exc) + raise exc + else: + self._client_manager.server_gateway.async_create_task( + self._websocket.send(message_json), name="send_data", eager_start=True + ) + + async def _handle_incoming_message(self, message: str | bytes) -> None: + """Handle an incoming message.""" + _LOGGER.info("Message received: %s", message) + handlers: dict[str, tuple[Callable, WebSocketCommand]] = ( + self._client_manager.server_gateway.data[WEBSOCKET_API] + ) + + try: + msg = WebSocketCommand.model_validate_json(message) + except ValidationError as exception: + _LOGGER.exception( + "Received invalid command[unable to parse command]: %s on websocket: %s", + message, + self._websocket.id, + exc_info=exception, + ) + return + + if msg.command not in handlers: + _LOGGER.error( + "Received invalid command[command not registered]: %s", message + ) + return + + handler, model = handlers[msg.command] + + try: + handler( + self._client_manager.server_gateway, + self, + model.model_validate_json(message), + ) + except Exception as err: # pylint: disable=broad-except + # TODO Fix this - make real error codes with error messages + _LOGGER.exception("Error handling message: %s", message, exc_info=err) + self.send_result_error(message, "INTERNAL_ERROR", f"Internal error: {err}") + + async def listen(self) -> None: + """Listen for incoming messages.""" + async for message in self._websocket: + self._client_manager.server_gateway.async_create_task( + self._handle_incoming_message(message), + name="handle_incoming_message", + eager_start=True, + ) + + def will_accept_message(self, message: BaseEvent) -> bool: + """Determine if client accepts this type of message.""" + if not self.receive_events: + return False + + if ( + message.event_type == EventTypes.RAW_ZCL_EVENT + and not self.receive_raw_zcl_events + ): + _LOGGER.info( + "Client %s not accepting raw ZCL events: %s", + self._websocket.id, + message, + ) + return False + + return True + + +class ClientListenRawZCLCommand(WebSocketCommand): + """Listen to raw ZCL data.""" + + command: Literal[APICommands.CLIENT_LISTEN_RAW_ZCL] = ( + APICommands.CLIENT_LISTEN_RAW_ZCL + ) + + +class ClientListenCommand(WebSocketCommand): + """Listen for zha messages.""" + + command: Literal[APICommands.CLIENT_LISTEN] = APICommands.CLIENT_LISTEN + + +class ClientDisconnectCommand(WebSocketCommand): + """Disconnect this client.""" + + command: Literal[APICommands.CLIENT_DISCONNECT] = APICommands.CLIENT_DISCONNECT + + +@decorators.websocket_command(ClientListenRawZCLCommand) +@decorators.async_response +async def listen_raw_zcl( + gateway: WebSocketServerGateway, client: Client, command: WebSocketCommand +) -> None: + """Listen for raw ZCL events.""" + client.receive_raw_zcl_events = True + client.send_result_success(command) + + +@decorators.websocket_command(ClientListenCommand) +@decorators.async_response +async def listen( + gateway: WebSocketServerGateway, client: Client, command: WebSocketCommand +) -> None: + """Listen for events.""" + client.receive_events = True + client.send_result_success(command) + + +@decorators.websocket_command(ClientDisconnectCommand) +@decorators.async_response +async def disconnect( + gateway: WebSocketServerGateway, client: Client, command: WebSocketCommand +) -> None: + """Disconnect the client.""" + gateway.client_manager.remove_client(client) + + +def load_api(gateway: WebSocketServerGateway) -> None: + """Load the api command handlers.""" + register_api_command(gateway, listen_raw_zcl) + register_api_command(gateway, listen) + register_api_command(gateway, disconnect) + + +class ClientManager: + """ZHA websocket server client manager implementation.""" + + def __init__(self, gateway: WebSocketServerGateway): + """Initialize the client.""" + self._gateway: WebSocketServerGateway = gateway + self._clients: list[Client] = [] + + @property + def server_gateway(self) -> WebSocketServerGateway: + """Return the server this ClientManager belongs to.""" + return self._gateway + + async def add_client(self, websocket: WebSocketServerProtocol) -> None: + """Add a new client to the client manager.""" + client: Client = Client(websocket, self) + self._clients.append(client) + await client.listen() + + def remove_client(self, client: Client) -> None: + """Remove a client from the client manager.""" + client.disconnect() + self._clients.remove(client) + + def broadcast(self, message: BaseEvent) -> None: + """Broadcast a message to all connected clients.""" + clients_to_remove = [] + + for client in self._clients: + if not client.is_connected: + # XXX: We cannot remove elements from `_clients` while iterating over it + clients_to_remove.append(client) + continue + + if not client.will_accept_message(message): + continue + + _LOGGER.info( + "Broadcasting message: %s to client: %s", + message, + client._websocket.id, + ) + # TODO use the receive flags on the client to determine if the client should receive the message + client.send_event(message) + + for client in clients_to_remove: + self.remove_client(client) diff --git a/zha/zigbee/cluster_handlers/__init__.py b/zha/zigbee/cluster_handlers/__init__.py index 321b9e194..d6289594b 100644 --- a/zha/zigbee/cluster_handlers/__init__.py +++ b/zha/zigbee/cluster_handlers/__init__.py @@ -4,11 +4,9 @@ from collections.abc import Awaitable, Callable, Coroutine, Iterator import contextlib -from dataclasses import dataclass -from enum import Enum import functools import logging -from typing import TYPE_CHECKING, Any, Final, ParamSpec, TypedDict +from typing import TYPE_CHECKING, Any, ParamSpec, TypedDict import zigpy.exceptions import zigpy.util @@ -21,7 +19,6 @@ ) from zha.application.const import ( - ZHA_CLUSTER_HANDLER_MSG, ZHA_CLUSTER_HANDLER_MSG_BIND, ZHA_CLUSTER_HANDLER_MSG_CFG_RPT, ) @@ -35,7 +32,6 @@ ATTRIBUTE_NAME, ATTRIBUTE_VALUE, CLUSTER_HANDLER_ATTRIBUTE_UPDATED, - CLUSTER_HANDLER_EVENT, CLUSTER_HANDLER_ZDO, CLUSTER_ID, CLUSTER_READS_PER_REQ, @@ -46,6 +42,14 @@ UNIQUE_ID, VALUE, ) +from zha.zigbee.cluster_handlers.model import ( + ClusterAttributeUpdatedEvent, + ClusterBindEvent, + ClusterConfigureReportingEvent, + ClusterHandlerInfo, + ClusterHandlerStatus, + ClusterInfo, +) if TYPE_CHECKING: from zha.zigbee.endpoint import Endpoint @@ -114,75 +118,6 @@ def parse_and_log_command(cluster_handler, tsn, command_id, args): return name -class ClusterHandlerStatus(Enum): - """Status of a cluster handler.""" - - CREATED = 1 - CONFIGURED = 2 - INITIALIZED = 3 - - -@dataclass(kw_only=True, frozen=True) -class ClusterAttributeUpdatedEvent: - """Event to signal that a cluster attribute has been updated.""" - - attribute_id: int - attribute_name: str - attribute_value: Any - cluster_handler_unique_id: str - cluster_id: int - event_type: Final[str] = CLUSTER_HANDLER_EVENT - event: Final[str] = CLUSTER_HANDLER_ATTRIBUTE_UPDATED - - -@dataclass(kw_only=True, frozen=True) -class ClusterBindEvent: - """Event generated when the cluster is bound.""" - - cluster_name: str - cluster_id: int - success: bool - cluster_handler_unique_id: str - event_type: Final[str] = ZHA_CLUSTER_HANDLER_MSG - event: Final[str] = ZHA_CLUSTER_HANDLER_MSG_BIND - - -@dataclass(kw_only=True, frozen=True) -class ClusterConfigureReportingEvent: - """Event generates when a cluster configures attribute reporting.""" - - cluster_name: str - cluster_id: int - attributes: dict[str, dict[str, Any]] - cluster_handler_unique_id: str - event_type: Final[str] = ZHA_CLUSTER_HANDLER_MSG - event: Final[str] = ZHA_CLUSTER_HANDLER_MSG_CFG_RPT - - -@dataclass(kw_only=True, frozen=True) -class ClusterInfo: - """Cluster information.""" - - id: int - name: str - type: str - commands: dict[int, str] - - -@dataclass(kw_only=True, frozen=True) -class ClusterHandlerInfo: - """Cluster handler information.""" - - class_name: str - generic_id: str - endpoint_id: str - cluster: ClusterInfo - id: str - unique_id: str - status: ClusterHandlerStatus - value_attribute: str | None = None - - class ClusterHandler(LogMixin, EventBase): """Base cluster handler for a Zigbee cluster.""" @@ -217,7 +152,7 @@ def matches(cls, cluster: zigpy.zcl.Cluster, endpoint: Endpoint) -> bool: # pyl """Filter the cluster match for specific devices.""" return True - @functools.cached_property + @property def info_object(self) -> ClusterHandlerInfo: """Return info about this cluster handler.""" return ClusterHandlerInfo( @@ -228,11 +163,12 @@ def info_object(self) -> ClusterHandlerInfo: id=self._cluster.cluster_id, name=self._cluster.name, type="client" if self._cluster.is_client else "server", - commands=self._cluster.commands, + endpoint_id=self._cluster.endpoint.endpoint_id, + endpoint_attribute=self._cluster.ep_attribute, ), id=self._id, unique_id=self._unique_id, - status=self._status.name, + status=self._status, value_attribute=getattr(self, "value_attribute", None), ) @@ -547,7 +483,7 @@ async def async_update(self) -> None: def _get_attribute_name(self, attrid: int) -> str | int: if attrid not in self.cluster.attributes: - return attrid + return "Unknown" return self.cluster.attributes[attrid].name diff --git a/zha/zigbee/cluster_handlers/general.py b/zha/zigbee/cluster_handlers/general.py index e103f1199..60b8f7bee 100644 --- a/zha/zigbee/cluster_handlers/general.py +++ b/zha/zigbee/cluster_handlers/general.py @@ -4,9 +4,8 @@ import asyncio from collections.abc import Coroutine -from dataclasses import dataclass from datetime import datetime -from typing import TYPE_CHECKING, Any, Final +from typing import TYPE_CHECKING, Any from zhaquirks.quirk_ids import TUYA_PLUG_ONOFF import zigpy.exceptions @@ -64,20 +63,12 @@ SIGNAL_SET_LEVEL, ) from zha.zigbee.cluster_handlers.helpers import is_hue_motion_sensor +from zha.zigbee.cluster_handlers.model import LevelChangeEvent if TYPE_CHECKING: from zha.zigbee.endpoint import Endpoint -@dataclass(frozen=True, kw_only=True) -class LevelChangeEvent: - """Event to signal that a cluster attribute has been updated.""" - - level: int - event: str - event_type: Final[str] = "cluster_handler_event" - - @registries.CLUSTER_HANDLER_REGISTRY.register(Alarms.cluster_id) class AlarmsClusterHandler(ClusterHandler): """Alarms cluster handler.""" diff --git a/zha/zigbee/cluster_handlers/model.py b/zha/zigbee/cluster_handlers/model.py new file mode 100644 index 000000000..f4fb8d0ce --- /dev/null +++ b/zha/zigbee/cluster_handlers/model.py @@ -0,0 +1,88 @@ +"""Models for the ZHA cluster handlers module.""" + +from enum import StrEnum +from typing import Any, Literal + +from zha.const import ClusterHandlerEvents, EventTypes +from zha.model import BaseEvent, BaseModel + + +class ClusterHandlerStatus(StrEnum): + """Status of a cluster handler.""" + + CREATED = "created" + CONFIGURED = "configured" + INITIALIZED = "initialized" + + +class ClusterAttributeUpdatedEvent(BaseEvent): + """Event to signal that a cluster attribute has been updated.""" + + attribute_id: int + attribute_name: str + attribute_value: Any + cluster_handler_unique_id: str + cluster_id: int + event_type: Literal[EventTypes.CLUSTER_HANDLER_EVENT] = ( + EventTypes.CLUSTER_HANDLER_EVENT + ) + event: Literal[ClusterHandlerEvents.CLUSTER_HANDLER_ATTRIBUTE_UPDATED] = ( + ClusterHandlerEvents.CLUSTER_HANDLER_ATTRIBUTE_UPDATED + ) + + +class ClusterBindEvent(BaseEvent): + """Event generated when the cluster is bound.""" + + cluster_name: str + cluster_id: int + success: bool + cluster_handler_unique_id: str + event_type: Literal["zha_channel_message"] = "zha_channel_message" + event: Literal["zha_channel_bind"] = "zha_channel_bind" + + +class ClusterConfigureReportingEvent(BaseEvent): + """Event generates when a cluster configures attribute reporting.""" + + cluster_name: str + cluster_id: int + attributes: dict[str, dict[str, Any]] + cluster_handler_unique_id: str + event_type: Literal["zha_channel_message"] = "zha_channel_message" + event: Literal["zha_channel_configure_reporting"] = ( + "zha_channel_configure_reporting" + ) + + +class ClusterInfo(BaseModel): + """Cluster information.""" + + id: int + name: str + type: str + endpoint_id: int + endpoint_attribute: str | None = None + + +class ClusterHandlerInfo(BaseModel): + """Cluster handler information.""" + + class_name: str + generic_id: str + endpoint_id: int + cluster: ClusterInfo + id: str + unique_id: str + status: ClusterHandlerStatus + value_attribute: str | None = None + + +class LevelChangeEvent(BaseEvent): + """Event to signal that a cluster attribute has been updated.""" + + level: int + event: str + event_type: Literal[EventTypes.CLUSTER_HANDLER_EVENT] = ( + EventTypes.CLUSTER_HANDLER_EVENT + ) diff --git a/zha/zigbee/cluster_handlers/registries.py b/zha/zigbee/cluster_handlers/registries.py index 07c8dc85e..d24db0580 100644 --- a/zha/zigbee/cluster_handlers/registries.py +++ b/zha/zigbee/cluster_handlers/registries.py @@ -1,7 +1,13 @@ """Mapping registries for zha cluster handlers.""" +from __future__ import annotations + +from typing import TYPE_CHECKING + from zha.decorators import DictRegistry, NestedDictRegistry, SetRegistry -from zha.zigbee.cluster_handlers import ClientClusterHandler, ClusterHandler + +if TYPE_CHECKING: + from zha.zigbee.cluster_handlers import ClientClusterHandler, ClusterHandler BINDABLE_CLUSTERS = SetRegistry() CLUSTER_HANDLER_ONLY_CLUSTERS = SetRegistry() diff --git a/zha/zigbee/cluster_handlers/security.py b/zha/zigbee/cluster_handlers/security.py index ea9d364c4..129129fcb 100644 --- a/zha/zigbee/cluster_handlers/security.py +++ b/zha/zigbee/cluster_handlers/security.py @@ -3,8 +3,7 @@ from __future__ import annotations from collections.abc import Callable -import dataclasses -from typing import TYPE_CHECKING, Any, Final +from typing import TYPE_CHECKING, Any, Literal import zigpy.zcl from zigpy.zcl.clusters.security import ( @@ -17,7 +16,9 @@ WarningType, ) +from zha.const import ClusterHandlerEvents, EventTypes from zha.exceptions import ZHAException +from zha.model import BaseEvent from zha.zigbee.cluster_handlers import ClusterHandler, ClusterHandlerStatus, registries from zha.zigbee.cluster_handlers.const import CLUSTER_HANDLER_STATE_CHANGED @@ -28,12 +29,15 @@ SIGNAL_ALARM_TRIGGERED = "zha_armed_triggered" -@dataclasses.dataclass(frozen=True, kw_only=True) -class ClusterHandlerStateChangedEvent: +class ClusterHandlerStateChangedEvent(BaseEvent): """Event to signal that a cluster attribute has been updated.""" - event_type: Final[str] = "cluster_handler_event" - event: Final[str] = "cluster_handler_state_changed" + event_type: Literal[EventTypes.CLUSTER_HANDLER_EVENT] = ( + EventTypes.CLUSTER_HANDLER_EVENT + ) + event: Literal[ClusterHandlerEvents.CLUSTER_HANDLER_STATE_CHANGED] = ( + ClusterHandlerEvents.CLUSTER_HANDLER_STATE_CHANGED + ) @registries.CLUSTER_HANDLER_REGISTRY.register(AceCluster.cluster_id) diff --git a/zha/zigbee/device.py b/zha/zigbee/device.py index 0cebe1856..a329a3bd1 100644 --- a/zha/zigbee/device.py +++ b/zha/zigbee/device.py @@ -4,21 +4,20 @@ from __future__ import annotations +from abc import ABC, abstractmethod import asyncio -from dataclasses import dataclass -from enum import Enum from functools import cached_property import logging import time -from typing import TYPE_CHECKING, Any, Final, Self +from typing import TYPE_CHECKING, Any, Generic, Self from zigpy.device import Device as ZigpyDevice import zigpy.exceptions from zigpy.profiles import PROFILES import zigpy.quirks from zigpy.quirks.v2 import QuirksV2RegistryEntry -from zigpy.types import uint1_t, uint8_t, uint16_t -from zigpy.types.named import EUI64, NWK, ExtendedPanId +from zigpy.types import uint8_t, uint16_t +from zigpy.types.named import EUI64, NWK from zigpy.zcl.clusters import Cluster from zigpy.zcl.clusters.general import Groups, Identify from zigpy.zcl.foundation import ( @@ -27,7 +26,6 @@ ZCLCommandDef, ) import zigpy.zdo.types as zdo_types -from zigpy.zdo.types import RouteStatus, _NeighborEnums from zha.application import Platform, discovery from zha.application.const import ( @@ -56,19 +54,31 @@ UNKNOWN_MANUFACTURER, UNKNOWN_MODEL, ZHA_CLUSTER_HANDLER_CFG_DONE, - ZHA_CLUSTER_HANDLER_MSG, ZHA_EVENT, ) from zha.application.helpers import convert_to_zcl_values -from zha.application.platforms import BaseEntityInfo, PlatformEntity +from zha.application.model import DeviceOfflineEvent, DeviceOnlineEvent +from zha.application.platforms import PlatformEntity, T, WebSocketClientEntity from zha.event import EventBase from zha.exceptions import ZHAException from zha.mixins import LogMixin from zha.zigbee.cluster_handlers import ClusterHandler, ZDOClusterHandler from zha.zigbee.endpoint import Endpoint +from zha.zigbee.model import ( + ClusterBinding, + ClusterHandlerConfigurationComplete, + DeviceInfo, + DeviceStatus, + EndpointNameInfo, + ExtendedDeviceInfo, + NeighborInfo, + RouteInfo, + ZHAEvent, +) if TYPE_CHECKING: - from zha.application.gateway import Gateway + from zha.application.gateway import Gateway, WebSocketClientGateway + from zha.application.platforms.events import EntityStateChangedEvent _LOGGER = logging.getLogger(__name__) _CHECKIN_GRACE_PERIODS = 2 @@ -84,112 +94,145 @@ def get_device_automation_triggers( } -@dataclass(frozen=True, kw_only=True) -class ClusterBinding: - """Describes a cluster binding.""" +class BaseDevice(LogMixin, EventBase, ABC, Generic[T]): + """Base device for Zigbee Home Automation.""" - name: str - type: str - id: int - endpoint_id: int + def __init__(self, gateway) -> None: + """Initialize base device.""" + super().__init__() + self._gateway = gateway + @cached_property + @abstractmethod + def name(self) -> str: + """Return device name.""" -class DeviceStatus(Enum): - """Status of a device.""" + @property + @abstractmethod + def ieee(self) -> EUI64: + """Return ieee address for device.""" - CREATED = 1 - INITIALIZED = 2 + @cached_property + @abstractmethod + def manufacturer(self) -> str: + """Return manufacturer for device.""" + @cached_property + @abstractmethod + def model(self) -> str: + """Return model for device.""" -@dataclass(kw_only=True, frozen=True) -class ZHAEvent: - """Event generated when a device wishes to send an arbitrary event.""" + @cached_property + @abstractmethod + def manufacturer_code(self) -> int | None: + """Return the manufacturer code for the device.""" - device_ieee: EUI64 - unique_id: str - data: dict[str, Any] - event_type: Final[str] = ZHA_EVENT - event: Final[str] = ZHA_EVENT + @property + @abstractmethod + def nwk(self) -> NWK: + """Return nwk for device.""" + @property + @abstractmethod + def lqi(self): + """Return lqi for device.""" -@dataclass(kw_only=True, frozen=True) -class ClusterHandlerConfigurationComplete: - """Event generated when all cluster handlers are configured.""" + @property + @abstractmethod + def rssi(self): + """Return rssi for device.""" - device_ieee: EUI64 - unique_id: str - event_type: Final[str] = ZHA_CLUSTER_HANDLER_MSG - event: Final[str] = ZHA_CLUSTER_HANDLER_CFG_DONE - - -@dataclass(kw_only=True, frozen=True) -class DeviceInfo: - """Describes a device.""" - - ieee: EUI64 - nwk: NWK - manufacturer: str - model: str - name: str - quirk_applied: bool - quirk_class: str - quirk_id: str | None - manufacturer_code: int | None - power_source: str - lqi: int - rssi: int - last_seen: str - available: bool - device_type: str - signature: dict[str, Any] - - -@dataclass(kw_only=True, frozen=True) -class NeighborInfo: - """Describes a neighbor.""" - - device_type: _NeighborEnums.DeviceType - rx_on_when_idle: _NeighborEnums.RxOnWhenIdle - relationship: _NeighborEnums.Relationship - extended_pan_id: ExtendedPanId - ieee: EUI64 - nwk: NWK - permit_joining: _NeighborEnums.PermitJoins - depth: uint8_t - lqi: uint8_t - - -@dataclass(kw_only=True, frozen=True) -class RouteInfo: - """Describes a route.""" - - dest_nwk: NWK - route_status: RouteStatus - memory_constrained: uint1_t - many_to_one: uint1_t - route_record_required: uint1_t - next_hop: NWK - - -@dataclass(kw_only=True, frozen=True) -class EndpointNameInfo: - """Describes an endpoint name.""" - - name: str - - -@dataclass(kw_only=True, frozen=True) -class ExtendedDeviceInfo(DeviceInfo): - """Describes a ZHA device.""" - - active_coordinator: bool - entities: dict[str, BaseEntityInfo] - neighbors: list[NeighborInfo] - routes: list[RouteInfo] - endpoint_names: list[EndpointNameInfo] - - -class Device(LogMixin, EventBase): + @property + @abstractmethod + def last_seen(self) -> float | None: + """Return last_seen for device.""" + + @cached_property + @abstractmethod + def is_mains_powered(self) -> bool | None: + """Return true if device is mains powered.""" + + @cached_property + @abstractmethod + def device_type(self) -> str: + """Return the logical device type for the device.""" + + @property + @abstractmethod + def power_source(self) -> str: + """Return the power source for the device.""" + + @cached_property + @abstractmethod + def is_router(self) -> bool | None: + """Return true if this is a routing capable device.""" + + @cached_property + @abstractmethod + def is_coordinator(self) -> bool | None: + """Return true if this device represents a coordinator.""" + + @property + @abstractmethod + def is_active_coordinator(self) -> bool: + """Return true if this device is the active coordinator.""" + + @cached_property + @abstractmethod + def is_end_device(self) -> bool | None: + """Return true if this device is an end device.""" + + @property + @abstractmethod + def is_groupable(self) -> bool: + """Return true if this device has a group cluster.""" + + @cached_property + @abstractmethod + def device_automation_triggers(self) -> dict[tuple[str, str], dict[str, Any]]: + """Return the device automation triggers for this device.""" + + @property + @abstractmethod + def available(self): + """Return True if device is available.""" + + @cached_property + @abstractmethod + def zigbee_signature(self) -> dict[str, Any]: + """Get zigbee signature for this device.""" + + @property + @abstractmethod + def sw_version(self) -> int | None: + """Return the software version for this device.""" + + @property + @abstractmethod + def platform_entities(self) -> dict[tuple[Platform, str], T]: + """Return the platform entities for this device.""" + + def get_platform_entity(self, platform: Platform, unique_id: str) -> T: + """Get a platform entity by unique id.""" + return self.platform_entities[(platform, unique_id)] + + @cached_property + def device_automation_commands(self) -> dict[str, list[tuple[str, str]]]: + """Return the a lookup of commands to etype/sub_type.""" + commands: dict[str, list[tuple[str, str]]] = {} + for etype_subtype, trigger in self.device_automation_triggers.items(): + if command := trigger.get(ATTR_COMMAND): + commands.setdefault(command, []).append(etype_subtype) + return commands + + def log(self, level: int, msg: str, *args: Any, **kwargs: Any) -> None: + """Log a message.""" + msg = f"[%s](%s): {msg}" + args = (self.nwk, self.model) + args + _LOGGER.log(level, msg, *args, **kwargs) + + +class Device(BaseDevice[PlatformEntity]): """ZHA Zigbee device object.""" unique_id: str @@ -199,12 +242,9 @@ def __init__( zigpy_device: zigpy.device.Device, _gateway: Gateway, ) -> None: - """Initialize the gateway.""" - super().__init__() - + """Initialize the device.""" + super().__init__(_gateway) self.unique_id = str(zigpy_device.ieee) - - self._gateway: Gateway = _gateway self._zigpy_device: ZigpyDevice = zigpy_device self.quirk_applied: bool = isinstance( self._zigpy_device, zigpy.quirks.BaseCustomDevice @@ -410,21 +450,12 @@ def skip_configuration(self) -> bool: return self._zigpy_device.skip_configuration or bool(self.is_active_coordinator) @property - def gateway(self): + def gateway(self) -> Gateway: """Return the gateway for this device.""" return self._gateway @cached_property - def device_automation_commands(self) -> dict[str, list[tuple[str, str]]]: - """Return the a lookup of commands to etype/sub_type.""" - commands: dict[str, list[tuple[str, str]]] = {} - for etype_subtype, trigger in self.device_automation_triggers.items(): - if command := trigger.get(ATTR_COMMAND): - commands.setdefault(command, []).append(etype_subtype) - return commands - - @cached_property - def device_automation_triggers(self) -> dict[tuple[str, str], dict[str, str]]: + def device_automation_triggers(self) -> dict[tuple[str, str], dict[str, Any]]: """Return the device automation triggers for this device.""" return get_device_automation_triggers(self._zigpy_device) @@ -433,11 +464,6 @@ def available(self): """Return True if device is available.""" return self.is_active_coordinator or (self._available and self.on_network) - @available.setter - def available(self, new_availability: bool) -> None: - """Set device availability.""" - self._available = new_availability - @property def on_network(self): """Return True if device is currently on the network.""" @@ -446,8 +472,7 @@ def on_network(self): @on_network.setter def on_network(self, new_on_network: bool) -> None: """Set device on_network flag.""" - self.update_available(new_on_network) - self._on_network = new_on_network + self.update_available(available=new_on_network, on_network=new_on_network) if not new_on_network: self.debug("Device is not on the network, marking unavailable") @@ -526,10 +551,7 @@ def platform_entities(self) -> dict[tuple[Platform, str], PlatformEntity]: def get_platform_entity(self, platform: Platform, unique_id: str) -> PlatformEntity: """Get a platform entity by unique id.""" - entity = self._platform_entities.get((platform, unique_id)) - if entity is None: - raise KeyError(f"Entity {unique_id} not found") - return entity + return self._platform_entities[(platform, unique_id)] @classmethod def new( @@ -552,7 +574,7 @@ async def _check_available(self, *_: Any) -> None: return if self.last_seen is None: self.debug("last_seen is None, marking the device unavailable") - self.update_available(False) + self.update_available(available=False, on_network=self.on_network) return difference = time.time() - self.last_seen @@ -560,7 +582,7 @@ async def _check_available(self, *_: Any) -> None: self.debug( "Device seen - marking the device available and resetting counter" ) - self.update_available(True) + self.update_available(available=True, on_network=self.on_network) self._checkins_missed_count = 0 return @@ -577,7 +599,7 @@ async def _check_available(self, *_: Any) -> None: ), difference, ) - self.update_available(False) + self.update_available(available=False, on_network=False) return self._checkins_missed_count += 1 @@ -585,30 +607,37 @@ async def _check_available(self, *_: Any) -> None: "Attempting to checkin with device - missed checkins: %s", self._checkins_missed_count, ) - if not self.basic_ch: + if not self._basic_ch: self.debug("does not have a mandatory basic cluster") - self.update_available(False) + self.update_available(available=False, on_network=False) return - res = await self.basic_ch.get_attribute_value( + res = await self._basic_ch.get_attribute_value( ATTR_MANUFACTURER, from_cache=False ) if res is not None: self._checkins_missed_count = 0 - def update_available(self, available: bool) -> None: + def update_available( + self, available: bool = False, on_network: bool = False + ) -> None: """Update device availability and signal entities.""" self.debug( ( "Update device availability - device available: %s - new availability:" - " %s - changed: %s" + " %s - changed: %s - on network: %s - new on network: %s - changed: %s" ), self.available, available, self.available ^ available, + self.on_network, + on_network, + self.on_network ^ on_network, ) availability_changed = self.available ^ available - self.available = available - if availability_changed and available: + on_network_changed = self.on_network ^ on_network + self._available = available + self._on_network = on_network + if (availability_changed or on_network_changed) and (available and on_network): # reinit cluster handlers then signal entities self.debug( "Device availability changed and device became available," @@ -620,8 +649,13 @@ def update_available(self, available: bool) -> None: eager_start=True, ) return - if availability_changed and not available: + if (availability_changed or on_network_changed) and not ( + available and on_network + ): self.debug("Device availability changed and device became unavailable") + self.gateway.broadcast_event( + DeviceOfflineEvent(device_info=self.extended_device_info) + ) for entity in self.platform_entities.values(): entity.maybe_emit_state_changed_event() self.emit_zha_event( @@ -632,17 +666,19 @@ def update_available(self, available: bool) -> None: def emit_zha_event(self, event_data: dict[str, str | int]) -> None: # pylint: disable=unused-argument """Relay events directly.""" - self.emit( - ZHA_EVENT, - ZHAEvent( - device_ieee=self.ieee, - unique_id=str(self.ieee), - data=event_data, - ), + event: ZHAEvent = ZHAEvent( + device_ieee=self.ieee, + unique_id=str(self.ieee), + data=event_data, ) + self.emit(ZHA_EVENT, event) + self.gateway.broadcast_event(event) async def _async_became_available(self) -> None: """Update device availability and signal entities.""" + self.gateway.broadcast_event( + DeviceOnlineEvent(device_info=self.extended_device_info) + ) await self.async_initialize(False) for platform_entity in self._platform_entities.values(): platform_entity.maybe_emit_state_changed_event() @@ -666,10 +702,14 @@ def device_info(self) -> DeviceInfo: power_source=self.power_source, lqi=self.lqi, rssi=self.rssi, - last_seen=update_time, + last_seen=self.last_seen, + last_seen_time=update_time, available=self.available, + on_network=self.on_network, + is_groupable=self.is_groupable, device_type=self.device_type, signature=self.zigbee_signature, + sw_version=self.sw_version, ) @property @@ -695,11 +735,11 @@ def extended_device_info(self) -> ExtendedDeviceInfo: ) return ExtendedDeviceInfo( - **self.device_info.__dict__, + **self.device_info.model_dump(), active_coordinator=self.is_active_coordinator, entities={ - platform_entity.unique_id: platform_entity.info_object - for platform_entity in self.platform_entities.values() + platform_entity_key: platform_entity.info_object.model_dump() + for platform_entity_key, platform_entity in self.platform_entities.items() }, neighbors=[ NeighborInfo( @@ -727,6 +767,7 @@ def extended_device_info(self) -> ExtendedDeviceInfo: for route in topology.routes[self.ieee] ], endpoint_names=names, + device_automation_triggers=self.device_automation_triggers, ) async def async_configure(self) -> None: @@ -750,7 +791,7 @@ async def async_configure(self) -> None: ZHA_CLUSTER_HANDLER_CFG_DONE, ClusterHandlerConfigurationComplete( device_ieee=self.ieee, - unique_id=self.ieee, + unique_id=self.unique_id, ), ) @@ -758,10 +799,10 @@ async def async_configure(self) -> None: if ( should_identify - and self.identify_ch is not None + and self._identify_ch is not None and not self.skip_configuration ): - await self.identify_ch.trigger_effect( + await self._identify_ch.trigger_effect( effect_id=Identify.EffectIdentifier.Okay, effect_variant=Identify.EffectVariant.Default, ) @@ -1082,8 +1123,190 @@ async def _async_group_binding_operation( fmt = f"{log_msg[1]} completed: %s" zdo.debug(fmt, *(log_msg[2] + (outcome,))) - def log(self, level: int, msg: str, *args: Any, **kwargs: Any) -> None: - """Log a message.""" - msg = f"[%s](%s): {msg}" - args = (self.nwk, self.model) + args - _LOGGER.log(level, msg, *args, **kwargs) + +class WebSocketClientDevice(BaseDevice[WebSocketClientEntity]): + """ZHA device object for the websocket client.""" + + def __init__( + self, + extended_device_info: ExtendedDeviceInfo, + gateway: WebSocketClientGateway, + ) -> None: + """Initialize the device.""" + super().__init__(gateway) + self._extended_device_info: ExtendedDeviceInfo = extended_device_info + self.unique_id: str = str(extended_device_info.ieee) + self._entities: dict[tuple[Platform, str], WebSocketClientEntity] = {} + if self._extended_device_info.entities: + self._build_or_update_entities() + + @property + def quirk_id(self) -> str | None: + """Return the quirk id for this device.""" + return self._extended_device_info.quirk_id + + @property + def quirk_class(self) -> str: + """Return the quirk class for this device.""" + return self._extended_device_info.quirk_class + + @property + def quirk_applied(self) -> bool: + """Return the quirk applied status for this device.""" + return self._extended_device_info.quirk_applied + + @property + def extended_device_info(self) -> ExtendedDeviceInfo: + """Get extended device information.""" + return self._extended_device_info + + @extended_device_info.setter + def extended_device_info(self, extended_device_info: ExtendedDeviceInfo) -> None: + """Set extended device information.""" + self._extended_device_info = extended_device_info + self._build_or_update_entities() + + @property + def gateway(self) -> WebSocketClientGateway: + """Return the gateway for this device.""" + return self._gateway + + @cached_property + def name(self) -> str: + """Return device name.""" + return self._extended_device_info.name + + @property + def ieee(self) -> EUI64: + """Return ieee address for device.""" + return self._extended_device_info.ieee + + @cached_property + def manufacturer(self) -> str: + """Return manufacturer for device.""" + return self._extended_device_info.manufacturer + + @cached_property + def model(self) -> str: + """Return model for device.""" + return self._extended_device_info.model + + @cached_property + def manufacturer_code(self) -> int | None: + """Return the manufacturer code for the device.""" + return self._extended_device_info.manufacturer_code + + @property + def nwk(self) -> NWK: + """Return nwk for device.""" + return self._extended_device_info.nwk + + @property + def lqi(self): + """Return lqi for device.""" + + @property + def rssi(self): + """Return rssi for device.""" + + @property + def last_seen(self) -> float | None: + """Return last_seen for device.""" + return self._extended_device_info.last_seen + + @cached_property + def is_mains_powered(self) -> bool | None: + """Return true if device is mains powered.""" + return self._extended_device_info.power_source == POWER_MAINS_POWERED + + @cached_property + def device_type(self) -> str: + """Return the logical device type for the device.""" + return self._extended_device_info.device_type + + @property + def power_source(self) -> str: + """Return the power source for the device.""" + return self._extended_device_info.power_source + + @cached_property + def is_router(self) -> bool | None: + """Return true if this is a routing capable device.""" + return ( + self._extended_device_info.device_type == zdo_types.LogicalType.Router.name + ) + + @cached_property + def is_coordinator(self) -> bool | None: + """Return true if this device represents a coordinator.""" + return ( + self._extended_device_info.device_type + == zdo_types.LogicalType.Coordinator.name + ) + + @property + def is_active_coordinator(self) -> bool: + """Return true if this device is the active coordinator.""" + return self._extended_device_info.active_coordinator + + @cached_property + def is_end_device(self) -> bool | None: + """Return true if this device is an end device.""" + return ( + self._extended_device_info.device_type + == zdo_types.LogicalType.EndDevice.name + ) + + @property + def is_groupable(self) -> bool: + """Return true if this device has a group cluster.""" + return self._extended_device_info.is_groupable + + @cached_property + def device_automation_triggers(self) -> dict[tuple[str, str], dict[str, Any]]: + """Return the device automation triggers for this device.""" + return self._extended_device_info.device_automation_triggers + + @property + def available(self): + """Return True if device is available.""" + return self._extended_device_info.available + + @property + def on_network(self): + """Return True if device is currently on the network.""" + return self._extended_device_info.on_network + + @cached_property + def zigbee_signature(self) -> dict[str, Any]: + """Get zigbee signature for this device.""" + return self._extended_device_info.signature + + @property + def sw_version(self) -> int | None: + """Return the software version for this device.""" + return self._extended_device_info.sw_version + + @property + def platform_entities(self) -> dict[tuple[Platform, str], WebSocketClientEntity]: + """Return the platform entities for this device.""" + return self._entities + + def _build_or_update_entities(self): + """Build the entities for this device or rebuild them from extended device info.""" + for entity_info in self._extended_device_info.entities.values(): + entity_key = (entity_info.platform, entity_info.unique_id) + if entity_key in self._entities: + self._entities[entity_key].info_object = entity_info + else: + self._entities[entity_key] = ( + discovery.ENTITY_INFO_CLASS_TO_WEBSOCKET_CLIENT_ENTITY_CLASS[ + entity_info.__class__ + ](entity_info, self) + ) + + def emit_platform_entity_event(self, event: EntityStateChangedEvent) -> None: + """Proxy the firing of an entity event.""" + entity = self.get_platform_entity(event.platform, event.unique_id) + if entity is not None: + entity.state = event.state diff --git a/zha/zigbee/endpoint.py b/zha/zigbee/endpoint.py index e222606cc..24bc72275 100644 --- a/zha/zigbee/endpoint.py +++ b/zha/zigbee/endpoint.py @@ -20,6 +20,7 @@ CLIENT_CLUSTER_HANDLER_REGISTRY, CLUSTER_HANDLER_REGISTRY, ) +from zha.zigbee.model import DeviceStatus if TYPE_CHECKING: from zigpy import Endpoint as ZigpyEndpoint @@ -219,9 +220,6 @@ def async_new_entity( **kwargs: Any, ) -> None: """Create a new entity.""" - from zha.zigbee.device import ( # pylint: disable=import-outside-toplevel - DeviceStatus, - ) if self.device.status == DeviceStatus.INITIALIZED: return diff --git a/zha/zigbee/group.py b/zha/zigbee/group.py index 4ec96a7f2..673ed6d8e 100644 --- a/zha/zigbee/group.py +++ b/zha/zigbee/group.py @@ -2,80 +2,84 @@ from __future__ import annotations +from abc import ABC, abstractmethod import asyncio from collections.abc import Callable -from dataclasses import dataclass -from functools import cached_property import logging -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Generic import zigpy.exceptions from zigpy.types.named import EUI64 -from zha.application.platforms import ( - BaseEntityInfo, - EntityStateChangedEvent, - PlatformEntity, -) +from zha.application import discovery +from zha.application.platforms import PlatformEntity, T, WebSocketClientEntity from zha.const import STATE_CHANGED +from zha.event import EventBase from zha.mixins import LogMixin -from zha.zigbee.device import ExtendedDeviceInfo +from zha.zigbee.model import GroupInfo, GroupMemberInfo, GroupMemberReference if TYPE_CHECKING: from zigpy.group import Group as ZigpyGroup, GroupEndpoint - from zha.application.gateway import Gateway + from zha.application.gateway import Gateway, WebSocketClientGateway from zha.application.platforms import GroupEntity - from zha.zigbee.device import Device + from zha.application.platforms.events import EntityStateChangedEvent + from zha.zigbee.device import Device, WebSocketClientDevice _LOGGER = logging.getLogger(__name__) -@dataclass(frozen=True, kw_only=True) -class GroupMemberReference: - """Describes a group member.""" - - ieee: EUI64 - endpoint_id: int - +class BaseGroupMember(LogMixin, ABC): + """Composite object that represents a device endpoint in a Zigbee group.""" -@dataclass(frozen=True, kw_only=True) -class GroupEntityReference: - """Reference to a group entity.""" + def __init__(self, zha_group, device, endpoint_id: int) -> None: + """Initialize the group member.""" + self._group = zha_group + self._device = device + self._endpoint_id: int = endpoint_id - entity_id: int - name: str | None = None - original_name: str | None = None + @property + @abstractmethod + def group(self): + """Return the group this member belongs to.""" + @property + def endpoint_id(self) -> int: + """Return the endpoint id for this group member.""" + return self._endpoint_id -@dataclass(frozen=True, kw_only=True) -class GroupMemberInfo: - """Describes a group member.""" + @property + @abstractmethod + def device(self): + """Return the ZHA device for this group member.""" - ieee: EUI64 - endpoint_id: int - device_info: ExtendedDeviceInfo - entities: dict[str, BaseEntityInfo] + @property + @abstractmethod + def member_info(self) -> GroupMemberInfo: + """Get ZHA group info.""" + @property + @abstractmethod + def associated_entities(self) -> list[PlatformEntity]: + """Return the list of entities that were derived from this endpoint.""" -@dataclass(frozen=True, kw_only=True) -class GroupInfo: - """Describes a group.""" + @abstractmethod + async def async_remove_from_group(self) -> None: + """Remove the device endpoint from the provided zigbee group.""" - group_id: int - name: str - members: list[GroupMemberInfo] - entities: dict[str, BaseEntityInfo] + def log(self, level: int, msg: str, *args: Any, **kwargs) -> None: + """Log a message.""" + msg = f"[%s](%s): {msg}" + args = (f"0x{self._group.group_id:04x}", self.endpoint_id) + args + _LOGGER.log(level, msg, *args, **kwargs) -class GroupMember(LogMixin): +class GroupMember(BaseGroupMember): """Composite object that represents a device endpoint in a Zigbee group.""" def __init__(self, zha_group: Group, device: Device, endpoint_id: int) -> None: """Initialize the group member.""" - self._group: Group = zha_group - self._device: Device = device - self._endpoint_id: int = endpoint_id + super().__init__(zha_group, device, endpoint_id) @property def group(self) -> Group: @@ -83,11 +87,6 @@ def group(self) -> Group: return self._group @property - def endpoint_id(self) -> int: - """Return the endpoint id for this group member.""" - return self._endpoint_id - - @cached_property def endpoint(self) -> GroupEndpoint: """Return the endpoint for this group member.""" return self._device.device.endpoints.get(self.endpoint_id) @@ -97,7 +96,7 @@ def device(self) -> Device: """Return the ZHA device for this group member.""" return self._device - @cached_property + @property def member_info(self) -> GroupMemberInfo: """Get ZHA group info.""" return GroupMemberInfo( @@ -105,12 +104,12 @@ def member_info(self) -> GroupMemberInfo: endpoint_id=self.endpoint_id, device_info=self.device.extended_device_info, entities={ - entity.unique_id: entity.info_object + entity.unique_id: entity.info_object.model_dump() for entity in self.associated_entities }, ) - @cached_property + @property def associated_entities(self) -> list[PlatformEntity]: """Return the list of entities that were derived from this endpoint.""" return [ @@ -138,14 +137,95 @@ async def async_remove_from_group(self) -> None: str(ex), ) - def log(self, level: int, msg: str, *args: Any, **kwargs) -> None: - """Log a message.""" - msg = f"[%s](%s): {msg}" - args = (f"0x{self._group.group_id:04x}", self.endpoint_id) + args - _LOGGER.log(level, msg, *args, **kwargs) +class WebSocketClientGroupMember(BaseGroupMember): + """Composite object that represents a device endpoint in a Zigbee group.""" + + def __init__( + self, + zha_group: WebSocketClientGroup, + device: WebSocketClientDevice, + endpoint_id: int, + member_info: GroupMemberInfo, + ) -> None: + """Initialize the group member.""" + super().__init__(zha_group, device, endpoint_id) + self._member_info = member_info + + @property + def group(self) -> WebSocketClientGroup: + """Return the group this member belongs to.""" + return self._group + + @property + def device(self) -> WebSocketClientDevice: + """Return the ZHA device for this group member.""" + return self._device + + @property + def member_info(self) -> GroupMemberInfo: + """Get ZHA group info.""" + return self._member_info + + @property + def associated_entities(self) -> list[PlatformEntity]: + """Return the list of entities that were derived from this endpoint.""" + return [ + platform_entity + for platform_entity in self._device.platform_entities.values() + if platform_entity.info_object.endpoint_id == self.endpoint_id + ] + + async def async_remove_from_group(self) -> None: + """Remove the device endpoint from the provided zigbee group.""" + await self.group.gateway.groups_helper.remove_group_members( + self.group.info_object, [self.member_info] + ) + + +class BaseGroup(LogMixin, EventBase, ABC, Generic[T]): + """Base class for Zigbee groups.""" + + def __init__( + self, + gateway: Gateway, + ) -> None: + """Initialize the group.""" + super().__init__() + self._gateway = gateway -class Group(LogMixin): + @property + def gateway(self) -> Gateway: + """Return the gateway for this group.""" + return self._gateway + + @property + @abstractmethod + def name(self) -> str: + """Return group name.""" + + @property + @abstractmethod + def group_id(self) -> int: + """Return group name.""" + + @property + @abstractmethod + def group_entities(self) -> dict[str, T]: + """Return the platform entities of the group.""" + + @property + @abstractmethod + def members(self): + """Return the ZHA devices that are members of this group.""" + + @property + @abstractmethod + def info_object(self) -> GroupInfo: + """Get ZHA group info.""" + + +class Group(BaseGroup): """ZHA Zigbee group object.""" def __init__( @@ -154,7 +234,7 @@ def __init__( zigpy_group: zigpy.group.Group, ) -> None: """Initialize the group.""" - self._gateway = gateway + super().__init__(gateway) self._zigpy_group = zigpy_group self._group_entities: dict[str, GroupEntity] = {} self._entity_unsubs: dict[str, Callable] = {} @@ -189,7 +269,7 @@ def gateway(self) -> Gateway: """Return the gateway for this group.""" return self._gateway - @cached_property + @property def members(self) -> list[GroupMember]: """Return the ZHA devices that are members of this group.""" return [ @@ -198,7 +278,7 @@ def members(self) -> list[GroupMember]: if member_ieee in self._gateway.devices ] - @cached_property + @property def info_object(self) -> GroupInfo: """Get ZHA group info.""" return GroupInfo( @@ -206,12 +286,12 @@ def info_object(self) -> GroupInfo: name=self.name, members=[member.member_info for member in self.members], entities={ - unique_id: entity.info_object + unique_id: entity.info_object.model_dump() for unique_id, entity in self._group_entities.items() }, ) - @cached_property + @property def all_member_entity_unique_ids(self) -> list[str]: """Return all platform entities unique ids for the members of this group.""" all_entity_unique_ids: list[str] = [] @@ -245,15 +325,6 @@ async def _maybe_update_group_members(self, event: EntityStateChangedEvent) -> N if tasks: await asyncio.gather(*tasks) - def clear_caches(self) -> None: - """Clear cached properties.""" - if hasattr(self, "all_member_entity_unique_ids"): - delattr(self, "all_member_entity_unique_ids") - if hasattr(self, "info_object"): - delattr(self, "info_object") - if hasattr(self, "members"): - delattr(self, "members") - def update_entity_subscriptions(self) -> None: """Update the entity event subscriptions. @@ -264,7 +335,6 @@ def update_entity_subscriptions(self) -> None: for group entities and the platrom entities that we processed. Then we loop over all of the unsub ids and we execute the unsubscribe method for each one that isn't in the combined list. """ - self.clear_caches() group_entity_ids = list(self._group_entities.keys()) processed_platform_entity_ids = [] @@ -349,3 +419,97 @@ async def on_remove(self) -> None: """Cancel tasks this group owns.""" for group_entity in self._group_entities.values(): await group_entity.on_remove() + + +class WebSocketClientGroup(BaseGroup): + """ZHA Zigbee group object for the websocket client.""" + + def __init__( + self, + group_info: GroupInfo, + gateway: WebSocketClientGateway, + ) -> None: + """Initialize the group.""" + super().__init__(gateway) + self._group_info = group_info + self._entities: dict[str, WebSocketClientEntity] = {} + if self._group_info.entities: + self._build_or_update_entities() + + @property + def name(self) -> str: + """Return group name.""" + return self._group_info.name + + @property + def group_id(self) -> int: + """Return group name.""" + return self._group_info.group_id + + @property + def group_entities(self) -> dict[str, WebSocketClientEntity]: + """Return the platform entities of the group.""" + return self._entities + + @property + def members(self) -> list[WebSocketClientGroupMember]: + """Return the ZHA devices that are members of this group.""" + return [ + WebSocketClientGroupMember( + self, self._gateway.devices[member.ieee], member.endpoint_id, member + ) + for member in self._group_info.members + ] + + @property + def all_member_entity_unique_ids(self) -> list[str]: + """Return all platform entities unique ids for the members of this group.""" + all_entity_unique_ids: list[str] = [] + for member in self.members: + entities = member.associated_entities + for entity in entities: + all_entity_unique_ids.append(entity.unique_id) + return all_entity_unique_ids + + @property + def info_object(self) -> GroupInfo: + """Get ZHA group info.""" + return self._group_info + + @info_object.setter + def info_object(self, group_info: GroupInfo) -> None: + """Set ZHA group info.""" + self._group_info = group_info + self._build_or_update_entities() + + def _build_or_update_entities(self): + """Build the entities for this device or rebuild them from extended device info.""" + current_entity_ids = set(self._entities.keys()) + for unique_id, entity_info in self._group_info.entities.items(): + if unique_id in self._entities: + self._entities[unique_id].entity_info = entity_info + current_entity_ids.remove(unique_id) + else: + self._entities[unique_id] = ( + discovery.ENTITY_INFO_CLASS_TO_WEBSOCKET_CLIENT_ENTITY_CLASS[ + entity_info.__class__ + ](entity_info, self) + ) + for entity_id in current_entity_ids: + self._entities.pop(entity_id, None) + + def emit_platform_entity_event(self, event: EntityStateChangedEvent) -> None: + """Proxy the firing of an entity event.""" + entity = self.group_entities.get(event.unique_id) + if entity is not None: + entity.state = event.state + + async def async_add_members(self, members: list[GroupMemberReference]) -> None: + """Add members to this group.""" + await self._gateway.groups_helper.add_group_members(self.info_object, members) + + async def async_remove_members(self, members: list[GroupMemberReference]) -> None: + """Remove members from this group.""" + await self._gateway.groups_helper.remove_group_members( + self.info_object, members + ) diff --git a/zha/zigbee/model.py b/zha/zigbee/model.py new file mode 100644 index 000000000..bb5daa60d --- /dev/null +++ b/zha/zigbee/model.py @@ -0,0 +1,328 @@ +"""Models for the ZHA zigbee module.""" + +from __future__ import annotations + +from enum import Enum, StrEnum +from typing import TYPE_CHECKING, Any, Literal, TypeAlias, Union + +from pydantic import field_serializer, field_validator +from zigpy.types import uint1_t, uint8_t +from zigpy.types.named import EUI64, NWK, ExtendedPanId +from zigpy.zdo.types import RouteStatus, _NeighborEnums + +from zha.application import Platform +from zha.application.platforms.alarm_control_panel.model import ( + AlarmControlPanelEntityInfo, +) +from zha.application.platforms.binary_sensor.model import BinarySensorEntityInfo +from zha.application.platforms.button.model import ( + ButtonEntityInfo, + CommandButtonEntityInfo, + WriteAttributeButtonEntityInfo, +) +from zha.application.platforms.climate.model import ThermostatEntityInfo +from zha.application.platforms.cover.model import CoverEntityInfo, ShadeEntityInfo +from zha.application.platforms.device_tracker.model import DeviceTrackerEntityInfo +from zha.application.platforms.fan.model import FanEntityInfo +from zha.application.platforms.light.model import LightEntityInfo +from zha.application.platforms.lock.model import LockEntityInfo +from zha.application.platforms.number.model import ( + NumberConfigurationEntityInfo, + NumberEntityInfo, +) +from zha.application.platforms.select.model import ( + EnumSelectEntityInfo, + SelectEntityInfo, +) +from zha.application.platforms.sensor.model import ( + BatteryEntityInfo, + DeviceCounterSensorEntityInfo, + ElectricalMeasurementEntityInfo, + SensorEntityInfo, + SetpointChangeSourceTimestampSensorEntityInfo, + SmartEnergyMeteringEntityInfo, +) +from zha.application.platforms.siren.model import SirenEntityInfo +from zha.application.platforms.switch.model import ( + ConfigurableAttributeSwitchEntityInfo, + SwitchEntityInfo, +) +from zha.application.platforms.update.model import FirmwareUpdateEntityInfo +from zha.const import DeviceEvents, EventTypes +from zha.model import BaseEvent, BaseModel, as_tagged_union, convert_enum, convert_int + + +class DeviceStatus(StrEnum): + """Status of a device.""" + + CREATED = "created" + INITIALIZED = "initialized" + + +class ZHAEvent(BaseEvent): + """Event generated when a device wishes to send an arbitrary event.""" + + device_ieee: EUI64 + unique_id: str + data: dict[str, Any] + event_type: Literal[EventTypes.DEVICE_EVENT] = EventTypes.DEVICE_EVENT + event: Literal[DeviceEvents.ZHA_EVENT] = DeviceEvents.ZHA_EVENT + + +class ClusterHandlerConfigurationComplete(BaseEvent): + """Event generated when all cluster handlers are configured.""" + + device_ieee: EUI64 + unique_id: str + event_type: Literal["zha_channel_message"] = "zha_channel_message" + event: Literal["zha_channel_cfg_done"] = "zha_channel_cfg_done" + + +class ClusterBinding(BaseModel): + """Describes a cluster binding.""" + + name: str + type: str + id: int + endpoint_id: int + + +class DeviceInfo(BaseModel): + """Describes a device.""" + + ieee: EUI64 + nwk: NWK + manufacturer: str + model: str + name: str + quirk_applied: bool + quirk_class: str + quirk_id: str | None + manufacturer_code: int | None + power_source: str + lqi: int | None + rssi: int | None + last_seen: float | None = None + last_seen_time: str | None = None + available: bool + on_network: bool + is_groupable: bool + device_type: str + signature: dict[str, Any] + sw_version: int | None = None + + @field_serializer("signature", check_fields=False) + def serialize_signature(self, signature: dict[str, Any]): + """Serialize signature.""" + if "node_descriptor" in signature and not isinstance( + signature["node_descriptor"], dict + ): + signature["node_descriptor"] = signature["node_descriptor"].as_dict() + return signature + + +class NeighborInfo(BaseModel): + """Describes a neighbor.""" + + device_type: _NeighborEnums.DeviceType + rx_on_when_idle: _NeighborEnums.RxOnWhenIdle + relationship: _NeighborEnums.Relationship + extended_pan_id: ExtendedPanId + ieee: EUI64 + nwk: NWK + permit_joining: _NeighborEnums.PermitJoins + depth: uint8_t + lqi: uint8_t + + _convert_device_type = field_validator( + "device_type", mode="before", check_fields=False + )(convert_enum(_NeighborEnums.DeviceType)) + + _convert_rx_on_when_idle = field_validator( + "rx_on_when_idle", mode="before", check_fields=False + )(convert_enum(_NeighborEnums.RxOnWhenIdle)) + + _convert_relationship = field_validator( + "relationship", mode="before", check_fields=False + )(convert_enum(_NeighborEnums.Relationship)) + + _convert_permit_joining = field_validator( + "permit_joining", mode="before", check_fields=False + )(convert_enum(_NeighborEnums.PermitJoins)) + + _convert_depth = field_validator("depth", mode="before", check_fields=False)( + convert_int(uint8_t) + ) + _convert_lqi = field_validator("lqi", mode="before", check_fields=False)( + convert_int(uint8_t) + ) + + @field_validator("extended_pan_id", mode="before", check_fields=False) + @classmethod + def convert_extended_pan_id( + cls, extended_pan_id: Union[str, ExtendedPanId] + ) -> ExtendedPanId: + """Convert extended_pan_id to ExtendedPanId.""" + if isinstance(extended_pan_id, str): + return ExtendedPanId.convert(extended_pan_id) + return extended_pan_id + + @field_serializer("extended_pan_id", check_fields=False) + def serialize_extended_pan_id(self, extended_pan_id: ExtendedPanId): + """Customize how extended_pan_id is serialized.""" + return str(extended_pan_id) + + @field_serializer( + "device_type", + "rx_on_when_idle", + "relationship", + "permit_joining", + check_fields=False, + ) + def serialize_enums(self, enum_value: Enum): + """Serialize enums by name.""" + return enum_value.name + + +class RouteInfo(BaseModel): + """Describes a route.""" + + dest_nwk: NWK + route_status: RouteStatus + memory_constrained: uint1_t + many_to_one: uint1_t + route_record_required: uint1_t + next_hop: NWK + + _convert_route_status = field_validator( + "route_status", mode="before", check_fields=False + )(convert_enum(RouteStatus)) + + _convert_memory_constrained = field_validator( + "memory_constrained", mode="before", check_fields=False + )(convert_int(uint1_t)) + + _convert_many_to_one = field_validator( + "many_to_one", mode="before", check_fields=False + )(convert_int(uint1_t)) + + _convert_route_record_required = field_validator( + "route_record_required", mode="before", check_fields=False + )(convert_int(uint1_t)) + + @field_serializer( + "route_status", + check_fields=False, + ) + def serialize_route_status(self, route_status: RouteStatus): + """Serialize route_status as name.""" + return route_status.name + + +class EndpointNameInfo(BaseModel): + """Describes an endpoint name.""" + + name: str + + +EntityInfoUnion: TypeAlias = ( + SirenEntityInfo + | SelectEntityInfo + | NumberEntityInfo + | LightEntityInfo + | FanEntityInfo + | ButtonEntityInfo + | CommandButtonEntityInfo + | WriteAttributeButtonEntityInfo + | AlarmControlPanelEntityInfo + | FirmwareUpdateEntityInfo + | SensorEntityInfo + | BinarySensorEntityInfo + | DeviceTrackerEntityInfo + | ShadeEntityInfo + | CoverEntityInfo + | LockEntityInfo + | SwitchEntityInfo + | BatteryEntityInfo + | ElectricalMeasurementEntityInfo + | SmartEnergyMeteringEntityInfo + | ThermostatEntityInfo + | DeviceCounterSensorEntityInfo + | SetpointChangeSourceTimestampSensorEntityInfo + | NumberConfigurationEntityInfo + | EnumSelectEntityInfo + | ConfigurableAttributeSwitchEntityInfo +) + +if not TYPE_CHECKING: + EntityInfoUnion = as_tagged_union(EntityInfoUnion) + + +class ExtendedDeviceInfo(DeviceInfo): + """Describes a ZHA device.""" + + active_coordinator: bool + entities: dict[tuple[Platform, str], EntityInfoUnion] + neighbors: list[NeighborInfo] + routes: list[RouteInfo] + endpoint_names: list[EndpointNameInfo] + device_automation_triggers: dict[tuple[str, str], dict[str, Any]] + + @field_validator( + "device_automation_triggers", "entities", mode="before", check_fields=False + ) + @classmethod + def validate_tuple_keyed_dicts( + cls, + tuple_keyed_dict: dict[tuple[str, str], Any] | dict[str, dict[str, Any]], + ) -> dict[tuple[str, str], Any] | dict[str, dict[str, Any]]: + """Validate device_automation_triggers.""" + if all(isinstance(key, str) for key in tuple_keyed_dict): + return { + tuple(key.split(",")): item for key, item in tuple_keyed_dict.items() + } + return tuple_keyed_dict + + +class GroupMemberReference(BaseModel): + """Describes a group member.""" + + ieee: EUI64 + endpoint_id: int + + +class GroupEntityReference(BaseModel): + """Reference to a group entity.""" + + entity_id: str + name: str | None = None + original_name: str | None = None + + +class GroupMemberInfo(BaseModel): + """Describes a group member.""" + + ieee: EUI64 + endpoint_id: int + device_info: ExtendedDeviceInfo + entities: dict[str, EntityInfoUnion] + + +GroupEntityUnion: TypeAlias = LightEntityInfo | FanEntityInfo | SwitchEntityInfo + +if not TYPE_CHECKING: + GroupEntityUnion = as_tagged_union(GroupEntityUnion) + + +class GroupInfo(BaseModel): + """Describes a group.""" + + group_id: int + name: str + members: list[GroupMemberInfo] + entities: dict[str, GroupEntityUnion] + + @property + def members_by_ieee(self) -> dict[EUI64, GroupMemberInfo]: + """Return members by ieee.""" + return {member.ieee: member for member in self.members}