Skip to content

Commit 0a1a6c3

Browse files
committed
finish cover api and clean up tests
1 parent 0b21b3a commit 0a1a6c3

File tree

7 files changed

+123
-29
lines changed

7 files changed

+123
-29
lines changed

tests/test_cover.py

Lines changed: 37 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -94,22 +94,21 @@
9494

9595

9696
@pytest.mark.parametrize(
97-
"gateway_type, entity_type",
97+
"gateway_type",
9898
[
99-
("zha_gateway", Platform.COVER),
100-
("ws_gateway", Platform.COVER),
99+
"zha_gateway",
100+
"ws_gateway",
101101
],
102102
)
103103
@pytest.mark.looptime
104104
async def test_cover_non_tilt_initial_state( # pylint: disable=unused-argument
105105
zha_gateways: CombinedGateways,
106-
zigpy_cover_device,
107106
gateway_type: str,
108-
entity_type: type,
109107
) -> None:
110108
"""Test ZHA cover platform."""
111109

112110
zha_gateway = getattr(zha_gateways, gateway_type)
111+
zigpy_cover_device = create_mock_zigpy_device(zha_gateway, ZIGPY_COVER_DEVICE)
113112
# load up cover domain
114113
zigpy_cover_device = create_mock_zigpy_device(zha_gateway, ZIGPY_COVER_DEVICE)
115114
cluster = zigpy_cover_device.endpoints[1].window_covering
@@ -164,17 +163,16 @@ async def test_cover_non_tilt_initial_state( # pylint: disable=unused-argument
164163

165164

166165
@pytest.mark.parametrize(
167-
"gateway_type, entity_type",
166+
"gateway_type",
168167
[
169-
("zha_gateway", Platform.COVER),
170-
("ws_gateway", Platform.COVER),
168+
"zha_gateway",
169+
"ws_gateway",
171170
],
172171
)
173172
@pytest.mark.looptime
174173
async def test_cover(
175174
zha_gateways: CombinedGateways,
176175
gateway_type: str,
177-
entity_type: type,
178176
) -> None:
179177
"""Test zha cover platform."""
180178

@@ -410,21 +408,21 @@ async def test_cover(
410408

411409

412410
@pytest.mark.parametrize(
413-
"gateway_type, entity_type",
411+
"gateway_type",
414412
[
415-
("zha_gateway", Platform.COVER),
416-
("ws_gateway", Platform.COVER),
413+
"zha_gateway",
414+
"ws_gateway",
417415
],
418416
)
419417
@pytest.mark.looptime
420418
async def test_cover_failures(
421419
zha_gateways: CombinedGateways,
422420
gateway_type: str,
423-
entity_type: type,
424421
) -> None:
425422
"""Test ZHA cover platform failure cases."""
426423

427424
zha_gateway = getattr(zha_gateways, gateway_type)
425+
zigpy_cover_device = create_mock_zigpy_device(zha_gateway, ZIGPY_COVER_DEVICE)
428426
# load up cover domain
429427
zigpy_cover_device = create_mock_zigpy_device(zha_gateway, ZIGPY_COVER_DEVICE)
430428
cluster = zigpy_cover_device.endpoints[1].window_covering
@@ -621,17 +619,16 @@ async def test_cover_failures(
621619

622620

623621
@pytest.mark.parametrize(
624-
"gateway_type, entity_type",
622+
"gateway_type",
625623
[
626-
("zha_gateway", Platform.COVER),
627-
("ws_gateway", Platform.COVER),
624+
"zha_gateway",
625+
"ws_gateway",
628626
],
629627
)
630628
@pytest.mark.looptime
631629
async def test_shade(
632630
zha_gateways: CombinedGateways,
633631
gateway_type: str,
634-
entity_type: type,
635632
) -> None:
636633
"""Test zha cover platform for shade device type."""
637634

@@ -813,17 +810,16 @@ async def test_shade(
813810

814811

815812
@pytest.mark.parametrize(
816-
"gateway_type, entity_type",
813+
"gateway_type",
817814
[
818-
("zha_gateway", Platform.COVER),
819-
("ws_gateway", Platform.COVER),
815+
"zha_gateway",
816+
"ws_gateway",
820817
],
821818
)
822819
@pytest.mark.looptime
823820
async def test_keen_vent(
824821
zha_gateways: CombinedGateways,
825822
gateway_type: str,
826-
entity_type: type,
827823
) -> None:
828824
"""Test keen vent."""
829825

@@ -893,21 +889,21 @@ async def test_keen_vent(
893889

894890

895891
@pytest.mark.parametrize(
896-
"gateway_type, entity_type",
892+
"gateway_type",
897893
[
898-
("zha_gateway", Platform.COVER),
899-
("ws_gateway", Platform.COVER),
894+
"zha_gateway",
895+
"ws_gateway",
900896
],
901897
)
902898
@pytest.mark.looptime
903899
async def test_cover_remote(
904900
zha_gateways: CombinedGateways,
905901
gateway_type: str,
906-
entity_type: type,
907902
) -> None:
908903
"""Test ZHA cover remote."""
909904

910905
zha_gateway = getattr(zha_gateways, gateway_type)
906+
zigpy_cover_remote = create_mock_zigpy_device(zha_gateway, ZIGPY_COVER_REMOTE)
911907
# load up cover domain
912908
zigpy_cover_remote = create_mock_zigpy_device(zha_gateway, ZIGPY_COVER_REMOTE)
913909
zha_device = await join_zigpy_device(zha_gateway, zigpy_cover_remote)
@@ -946,12 +942,21 @@ async def test_cover_remote(
946942
assert device.emit_zha_event.call_args[0][0][ATTR_COMMAND] == "down_close"
947943

948944

949-
# TODO parametrize this test and add service to restore state attributes
945+
@pytest.mark.parametrize(
946+
"gateway_type",
947+
[
948+
"zha_gateway",
949+
"ws_gateway",
950+
],
951+
)
950952
@pytest.mark.looptime
951953
async def test_cover_state_restoration(
952-
zha_gateway: Gateway,
954+
zha_gateways: CombinedGateways,
955+
gateway_type: str,
953956
) -> None:
954957
"""Test the cover state restoration."""
958+
959+
zha_gateway = getattr(zha_gateways, gateway_type)
955960
zigpy_cover_device = create_mock_zigpy_device(zha_gateway, ZIGPY_COVER_DEVICE)
956961
zha_device = await join_zigpy_device(zha_gateway, zigpy_cover_device)
957962
entity = get_entity(zha_device, platform=Platform.COVER)
@@ -966,6 +971,11 @@ async def test_cover_state_restoration(
966971
target_tilt_position=34,
967972
)
968973

974+
# ws impl needs a round trip to get the state back to the client
975+
# maybe we make this optimistic, set the state manually on the client
976+
# and avoid the round trip refresh call?
977+
await zha_gateway.async_block_till_done()
978+
969979
assert entity.state["state"] == STATE_CLOSED
970980
assert entity.state["target_lift_position"] == 12
971981
assert entity.state["target_tilt_position"] == 34

zha/application/platforms/cover/__init__.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ def restore_external_state_attributes(
191191
], # FIXME: why must these be expanded?
192192
target_lift_position: int | None,
193193
target_tilt_position: int | None,
194+
**kwargs: Any,
194195
):
195196
"""Restore external state attributes."""
196197
self._state = state
@@ -624,6 +625,7 @@ def __init__(
624625
) -> None:
625626
"""Initialize the ZHA fan entity."""
626627
super().__init__(entity_info, device)
628+
self._tasks: list[asyncio.Task] = []
627629

628630
@property
629631
def supported_features(self) -> CoverEntityFeature:
@@ -688,3 +690,33 @@ async def async_stop_cover(self, **kwargs: Any) -> None:
688690
async def async_stop_cover_tilt(self, **kwargs: Any) -> None:
689691
"""Stop the cover tilt."""
690692
await self._device.gateway.covers.stop_cover_tilt(self.info_object)
693+
694+
def restore_external_state_attributes(
695+
self,
696+
*,
697+
state: Literal[
698+
"open", "opening", "closed", "closing"
699+
], # FIXME: why must these be expanded?
700+
target_lift_position: int | None,
701+
target_tilt_position: int | None,
702+
):
703+
"""Restore external state attributes."""
704+
705+
def refresh_state():
706+
refresh_task = asyncio.create_task(
707+
self._device.gateway.entities.refresh_state(self.info_object)
708+
)
709+
self._tasks.append(refresh_task)
710+
refresh_task.add_done_callback(self._tasks.remove)
711+
712+
task = asyncio.create_task(
713+
self._device.gateway.covers.restore_external_state_attributes(
714+
self.info_object,
715+
state=state,
716+
target_lift_position=target_lift_position,
717+
target_tilt_position=target_tilt_position,
718+
)
719+
)
720+
self._tasks.append(task)
721+
task.add_done_callback(self._tasks.remove)
722+
task.add_done_callback(lambda _: refresh_state())

zha/application/platforms/cover/websocket_api.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,29 @@ async def stop_cover_tilt(
155155
)
156156

157157

158+
class CoverRestoreExternalStateAttributesCommand(PlatformEntityCommand):
159+
"""Cover restore external state attributes command."""
160+
161+
command: Literal[APICommands.COVER_RESTORE_EXTERNAL_STATE_ATTRIBUTES] = (
162+
APICommands.COVER_RESTORE_EXTERNAL_STATE_ATTRIBUTES
163+
)
164+
platform: str = Platform.COVER
165+
state: Literal["open", "opening", "closed", "closing"]
166+
target_lift_position: int
167+
target_tilt_position: int
168+
169+
170+
@decorators.websocket_command(CoverRestoreExternalStateAttributesCommand)
171+
@decorators.async_response
172+
async def restore_cover_external_state_attributes(
173+
server: Server, client: Client, command: CoverRestoreExternalStateAttributesCommand
174+
) -> None:
175+
"""Stop the cover tilt."""
176+
await execute_platform_entity_command(
177+
server, client, command, "restore_external_state_attributes"
178+
)
179+
180+
158181
def load_api(server: Server) -> None:
159182
"""Load the api command handlers."""
160183
register_api_command(server, open_cover)
@@ -165,3 +188,4 @@ def load_api(server: Server) -> None:
165188
register_api_command(server, close_cover_tilt)
166189
register_api_command(server, set_tilt_position)
167190
register_api_command(server, stop_cover_tilt)
191+
register_api_command(server, restore_cover_external_state_attributes)

zha/application/platforms/websocket_api.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,14 @@ async def execute_platform_entity_command(
6060
action = getattr(platform_entity, method_name)
6161
arg_spec = inspect.getfullargspec(action)
6262
if arg_spec.varkw:
63-
await action(**command.model_dump(exclude_none=True))
63+
if inspect.iscoroutinefunction(action):
64+
await action(**command.model_dump(exclude_none=True))
65+
else:
66+
action(**command.model_dump(exclude_none=True))
67+
elif inspect.iscoroutinefunction(action):
68+
await action()
6469
else:
65-
await action() # the only argument is self
70+
action() # the only argument is self
6671

6772
except Exception as err:
6873
_LOGGER.exception("Error executing command: %s", method_name, exc_info=err)

zha/websocket/client/helpers.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
CoverCloseTiltCommand,
2828
CoverOpenCommand,
2929
CoverOpenTiltCommand,
30+
CoverRestoreExternalStateAttributesCommand,
3031
CoverSetPositionCommand,
3132
CoverSetTiltPositionCommand,
3233
CoverStopCommand,
@@ -359,6 +360,24 @@ async def stop_cover_tilt(
359360
)
360361
return await self._client.async_send_command(command)
361362

363+
async def restore_external_state_attributes(
364+
self,
365+
cover_platform_entity: BasePlatformEntityInfo,
366+
state: Literal["open", "opening", "closed", "closing"],
367+
target_lift_position: int,
368+
target_tilt_position: int,
369+
) -> WebSocketCommandResponse:
370+
"""Stop a cover tilt."""
371+
ensure_platform_entity(cover_platform_entity, Platform.COVER)
372+
command = CoverRestoreExternalStateAttributesCommand(
373+
ieee=cover_platform_entity.device_ieee,
374+
unique_id=cover_platform_entity.unique_id,
375+
state=state,
376+
target_lift_position=target_lift_position,
377+
target_tilt_position=target_tilt_position,
378+
)
379+
return await self._client.async_send_command(command)
380+
362381

363382
class FanHelper:
364383
"""Helper to issue fan commands."""

zha/websocket/const.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ class APICommands(StrEnum):
6161
COVER_SET_POSITION = "cover_set_position"
6262
COVER_SET_TILT_POSITION = "cover_set_tilt_position"
6363
COVER_STOP_TILT = "cover_stop_tilt"
64+
COVER_RESTORE_EXTERNAL_STATE_ATTRIBUTES = "cover_restore_external_state_attributes"
6465

6566
FAN_TURN_ON = "fan_turn_on"
6667
FAN_TURN_OFF = "fan_turn_off"

zha/websocket/server/api/model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ class WebSocketCommand(BaseModel):
7979
APICommands.COVER_CLOSE_TILT,
8080
APICommands.COVER_SET_TILT_POSITION,
8181
APICommands.COVER_STOP_TILT,
82+
APICommands.COVER_RESTORE_EXTERNAL_STATE_ATTRIBUTES,
8283
APICommands.CLIMATE_SET_TEMPERATURE,
8384
APICommands.CLIMATE_SET_HVAC_MODE,
8485
APICommands.CLIMATE_SET_FAN_MODE,
@@ -129,6 +130,7 @@ class ErrorResponse(WebSocketCommandResponse):
129130
"error.cover_set_tilt_position",
130131
"error.cover_stop",
131132
"error.cover_stop_tilt",
133+
"error.cover_restore_external_state_attributes",
132134
"error.climate_set_fan_mode",
133135
"error.climate_set_hvac_mode",
134136
"error.climate_set_preset_mode",
@@ -182,6 +184,7 @@ class DefaultResponse(WebSocketCommandResponse):
182184
"cover_open_tilt",
183185
"cover_close_tilt",
184186
"cover_set_tilt_position",
187+
"cover_restore_external_state_attributes",
185188
"climate_set_fan_mode",
186189
"climate_set_hvac_mode",
187190
"climate_set_preset_mode",

0 commit comments

Comments
 (0)