diff --git a/zha/application/platforms/sensor/__init__.py b/zha/application/platforms/sensor/__init__.py index 5c616ef11..ba53845e7 100644 --- a/zha/application/platforms/sensor/__init__.py +++ b/zha/application/platforms/sensor/__init__.py @@ -2,7 +2,7 @@ from __future__ import annotations -from asyncio import Task +import asyncio from dataclasses import dataclass from datetime import UTC, date, datetime import enum @@ -329,7 +329,7 @@ def __init__( ) -> None: """Init this sensor.""" super().__init__(unique_id, cluster_handlers, endpoint, device, **kwargs) - self._polling_task: Task | None = None + self._polling_task: asyncio.Task | None = None def on_add(self) -> None: """Run when entity is added.""" @@ -604,6 +604,10 @@ class ElectricalMeasurement(PollableSensor): _multiplier_attribute_name: str | None = "ac_power_multiplier" _attr_max_attribute_name: str = None + # The final state is computed from up to three attributes, wait for them all to come + # in before emitting a change + _aggregate_attribute_reports_timeout: float = 2.0 + def __init__( self, unique_id: str, @@ -619,6 +623,90 @@ def __init__( self._max_attribute_name, } + self._pending_state_update_attributes: set[str] = set() + self._pending_state_update_timer: asyncio.TimerHandle | None = None + + @property + def _all_state_update_attributes(self) -> set[str]: + """Return a set of attributes that are required to compute state.""" + return { + attr_name + for attr_name in ( + ( + self._attribute_name, + self._divisor_attribute_name, + self._multiplier_attribute_name, + ) + + tuple(self._attr_extra_state_attribute_names) + ) + if ( + attr_name is not None + and attr_name + not in self._cluster_handler.cluster.unsupported_attributes + ) + } - {"measurement_type"} + + async def on_remove(self) -> None: + """Run when entity is removed.""" + if self._pending_state_update_timer is not None: + self._pending_state_update_timer.cancel() + self._pending_state_update_timer = None + + await super().on_remove() + + def handle_cluster_handler_attribute_updated( + self, + event: ClusterAttributeUpdatedEvent, + ) -> None: + """Handle attribute updates from the cluster handler.""" + state_update_attrs = self._all_state_update_attributes + + if len(state_update_attrs) == 1 or not ( + event.attribute_name == self._attribute_name + or event.attribute_name in self._attr_extra_state_attribute_names + ): + super().handle_cluster_handler_attribute_updated(event) + return + + # We need to wait for all of the relevant attributes to be received before we + # can emit a state change event + if not self._pending_state_update_attributes: + self._pending_state_update_attributes = state_update_attrs + + loop = asyncio.get_running_loop() + self._pending_state_update_timer = loop.call_later( + self._aggregate_attribute_reports_timeout, + self._emit_state_change_after_attributes_received, + ) + + # If we have no attributes to wait for *or* we receive a new attribute report + # for an existing attribute during a timeout window, we need to emit immediately + if ( + not self._pending_state_update_attributes + or event.attribute_name not in self._pending_state_update_attributes + ): + self._emit_state_change_after_attributes_received() + else: + self._pending_state_update_attributes.discard(event.attribute_name) + _LOGGER.debug( + "Waiting for attributes to be reported before changing state: %s", + self._pending_state_update_attributes, + ) + + def _emit_state_change_after_attributes_received(self) -> None: + """Emit a state change after all attributes have been received.""" + self._pending_state_update_attributes.clear() + + if self._pending_state_update_timer is not None: + self._pending_state_update_timer.cancel() + self._pending_state_update_timer = None + + _LOGGER.debug( + "Emitting state changed event, pending attributes: %s", + self._pending_state_update_attributes, + ) + self.maybe_emit_state_changed_event() + @property def _max_attribute_name(self) -> str: """Return the max attribute name."""