Skip to content

Commit 7779a7c

Browse files
committed
select api connected
1 parent 7ed03f8 commit 7779a7c

File tree

3 files changed

+74
-16
lines changed

3 files changed

+74
-16
lines changed

tests/test_select.py

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from unittest.mock import call, patch
44

5+
import pytest
56
from zhaquirks import (
67
DEVICE_TYPE,
78
ENDPOINTS,
@@ -27,14 +28,19 @@
2728
join_zigpy_device,
2829
send_attributes_report,
2930
)
31+
from tests.conftest import CombinedGateways
3032
from zha.application import Platform
31-
from zha.application.gateway import Gateway
32-
from zha.application.platforms import EntityCategory
33+
from zha.application.platforms import EntityCategory, PlatformEntity
3334
from zha.application.platforms.select import AqaraMotionSensitivities
3435

3536

36-
async def test_select(zha_gateway: Gateway) -> None:
37+
@pytest.mark.parametrize(
38+
"gateway_type",
39+
["zha_gateway", "ws_gateway"],
40+
)
41+
async def test_select(zha_gateways: CombinedGateways, gateway_type: str) -> None:
3742
"""Test zha select platform."""
43+
zha_gateway = getattr(zha_gateways, gateway_type)
3844
zigpy_device = create_mock_zigpy_device(
3945
zha_gateway,
4046
{
@@ -63,7 +69,9 @@ async def test_select(zha_gateway: Gateway) -> None:
6369
"Fire Panic",
6470
"Emergency Panic",
6571
]
66-
assert entity._enum == security.IasWd.Warning.WarningMode
72+
73+
if isinstance(entity, PlatformEntity):
74+
assert entity._enum == security.IasWd.Warning.WarningMode
6775

6876
# change value from client
6977
await entity.async_select_option(security.IasWd.Warning.WarningMode.Burglar.name)
@@ -107,9 +115,16 @@ def __init__(self, *args, **kwargs):
107115
}
108116

109117

110-
async def test_on_off_select_attribute_report(zha_gateway: Gateway) -> None:
118+
@pytest.mark.parametrize(
119+
"gateway_type",
120+
["zha_gateway", "ws_gateway"],
121+
)
122+
async def test_on_off_select_attribute_report(
123+
zha_gateways: CombinedGateways, gateway_type: str
124+
) -> None:
111125
"""Test ZHA attribute report parsing for select platform."""
112126

127+
zha_gateway = getattr(zha_gateways, gateway_type)
113128
zigpy_device = create_mock_zigpy_device(
114129
zha_gateway,
115130
{
@@ -126,7 +141,7 @@ async def test_on_off_select_attribute_report(zha_gateway: Gateway) -> None:
126141

127142
zigpy_device = get_device(zigpy_device)
128143
aqara_sensor = await join_zigpy_device(zha_gateway, zigpy_device)
129-
cluster = aqara_sensor.device.endpoints.get(1).opple_cluster
144+
cluster = zigpy_device.endpoints.get(1).opple_cluster
130145

131146
entity = get_entity(aqara_sensor, platform=Platform.SELECT)
132147
assert entity.state["state"] == AqaraMotionSensitivities.Medium.name
@@ -160,11 +175,16 @@ async def test_on_off_select_attribute_report(zha_gateway: Gateway) -> None:
160175
)
161176

162177

178+
@pytest.mark.parametrize(
179+
"gateway_type",
180+
["zha_gateway", "ws_gateway"],
181+
)
163182
async def test_on_off_select_attribute_report_v2(
164-
zha_gateway: Gateway,
183+
zha_gateways: CombinedGateways, gateway_type: str
165184
) -> None:
166185
"""Test ZHA attribute report parsing for select platform."""
167186

187+
zha_gateway = getattr(zha_gateways, gateway_type)
168188
zigpy_device = create_mock_zigpy_device(
169189
zha_gateway,
170190
{
@@ -184,7 +204,7 @@ async def test_on_off_select_attribute_report_v2(
184204

185205
zha_device = await join_zigpy_device(zha_gateway, zigpy_device)
186206
cluster = zigpy_device.endpoints[1].opple_cluster
187-
assert isinstance(zha_device.device, CustomDeviceV2)
207+
assert isinstance(zigpy_device, CustomDeviceV2)
188208

189209
entity = get_entity(zha_device, platform=Platform.SELECT)
190210

@@ -228,8 +248,15 @@ async def test_on_off_select_attribute_report_v2(
228248
)
229249

230250

231-
async def test_non_zcl_select_state_restoration(zha_gateway: Gateway) -> None:
251+
@pytest.mark.parametrize(
252+
"gateway_type",
253+
["zha_gateway", "ws_gateway"],
254+
)
255+
async def test_non_zcl_select_state_restoration(
256+
zha_gateways: CombinedGateways, gateway_type: str
257+
) -> None:
232258
"""Test the non-ZCL select state restoration."""
259+
zha_gateway = getattr(zha_gateways, gateway_type)
233260
zigpy_device = create_mock_zigpy_device(
234261
zha_gateway,
235262
{
@@ -251,9 +278,11 @@ async def test_non_zcl_select_state_restoration(zha_gateway: Gateway) -> None:
251278
entity.restore_external_state_attributes(
252279
state=security.IasWd.Warning.WarningMode.Burglar.name
253280
)
281+
await zha_gateway.async_block_till_done() # needed for WS operations
254282
assert entity.state["state"] == security.IasWd.Warning.WarningMode.Burglar.name
255283

256284
entity.restore_external_state_attributes(
257285
state=security.IasWd.Warning.WarningMode.Fire.name
258286
)
287+
await zha_gateway.async_block_till_done() # needed for WS operations
259288
assert entity.state["state"] == security.IasWd.Warning.WarningMode.Fire.name

zha/application/platforms/select/__init__.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
from abc import ABC, abstractmethod
6+
import asyncio
67
from enum import Enum
78
import functools
89
import logging
@@ -116,21 +117,18 @@ def current_option(self) -> str | None:
116117
return None
117118
return option.name.replace("_", " ")
118119

119-
async def async_select_option(self, option: str) -> None:
120+
async def async_select_option(self, option: str, **kwargs) -> None:
120121
"""Change the selected option."""
121122
self._cluster_handler.data_cache[self._attribute_name] = self._enum[
122123
option.replace(" ", "_")
123124
]
124125
self.maybe_emit_state_changed_event()
125126

126-
def restore_external_state_attributes(
127-
self,
128-
*,
129-
state: str,
130-
) -> None:
127+
def restore_external_state_attributes(self, *, state: str, **kwargs) -> None:
131128
"""Restore extra state attributes that are stored outside of the ZCL cache."""
132129
value = state.replace(" ", "_")
133130
self._cluster_handler.data_cache[self._attribute_name] = self._enum[value]
131+
self.maybe_emit_state_changed_event()
134132

135133

136134
class NonZCLSelectEntity(EnumSelectEntity):
@@ -262,7 +260,7 @@ def current_option(self) -> str | None:
262260
option = self._enum(option)
263261
return option.name.replace("_", " ")
264262

265-
async def async_select_option(self, option: str) -> None:
263+
async def async_select_option(self, option: str, **kwargs) -> None:
266264
"""Change the selected option."""
267265
await self._cluster_handler.write_attributes_safe(
268266
{self._attribute_name: self._enum[option.replace(" ", "_")]}
@@ -916,17 +914,26 @@ def __init__(
916914
) -> None:
917915
"""Initialize the ZHA select entity."""
918916
super().__init__(entity_info, device)
917+
self._tasks: list[asyncio.Task] = []
919918

920919
@property
921920
def current_option(self) -> str | None:
922921
"""Return the selected entity option to represent the entity state."""
923922

924923
async def async_select_option(self, option: str) -> None:
925924
"""Change the selected option."""
925+
await self._device.gateway.selects.select_option(self.info_object, option)
926926

927927
def restore_external_state_attributes(
928928
self,
929929
*,
930930
state: str,
931931
) -> None:
932932
"""Restore extra state attributes."""
933+
task = asyncio.create_task(
934+
self._device.gateway.selects.restore_external_state_attributes(
935+
self.info_object, state
936+
)
937+
)
938+
self._tasks.append(task)
939+
task.add_done_callback(self._tasks.remove)

zha/application/platforms/select/websocket_api.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,28 @@ async def select_option(
3838
)
3939

4040

41+
class SelectRestoreExternalStateAttributesCommand(PlatformEntityCommand):
42+
"""Select restore external state command."""
43+
44+
command: Literal[APICommands.SELECT_RESTORE_EXTERNAL_STATE_ATTRIBUTES] = (
45+
APICommands.SELECT_RESTORE_EXTERNAL_STATE_ATTRIBUTES
46+
)
47+
platform: str = Platform.SELECT
48+
state: str
49+
50+
51+
@decorators.websocket_command(SelectRestoreExternalStateAttributesCommand)
52+
@decorators.async_response
53+
async def restore_lock_external_state_attributes(
54+
server: Server, client: Client, command: SelectRestoreExternalStateAttributesCommand
55+
) -> None:
56+
"""Restore externally preserved state for selects."""
57+
await execute_platform_entity_command(
58+
server, client, command, "restore_external_state_attributes"
59+
)
60+
61+
4162
def load_api(server: Server) -> None:
4263
"""Load the api command handlers."""
4364
register_api_command(server, select_option)
65+
register_api_command(server, restore_lock_external_state_attributes)

0 commit comments

Comments
 (0)