Skip to content

Commit 352d47b

Browse files
committed
tagged unions for events, commands and responses
1 parent 1094059 commit 352d47b

File tree

6 files changed

+106
-180
lines changed

6 files changed

+106
-180
lines changed

tests/test_model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,14 @@ def test_ser_deser_zha_event():
3131
"device_ieee": "00:00:00:00:00:00:00:00",
3232
"unique_id": "00:00:00:00:00:00:00:00",
3333
"data": {"key": "value"},
34+
"model_class_name": "ZHAEvent",
3435
}
3536

3637
assert (
3738
zha_event.model_dump_json()
3839
== '{"message_type":"event","event_type":"device_event","event":"zha_event",'
39-
'"device_ieee":"00:00:00:00:00:00:00:00","unique_id":"00:00:00:00:00:00:00:00","data":{"key":"value"}}'
40+
'"device_ieee":"00:00:00:00:00:00:00:00","unique_id":"00:00:00:00:00:00:00:00",'
41+
'"data":{"key":"value"},"model_class_name":"ZHAEvent"}'
4042
)
4143

4244
device_info = DeviceInfo(

zha/application/websocket_api.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,15 @@
99
from pydantic import Field
1010
from zigpy.types.named import EUI64
1111

12-
from zha.websocket.const import DURATION, GROUPS, APICommands
12+
from zha.websocket.const import GROUPS, APICommands
1313
from zha.websocket.server.api import decorators, register_api_command
1414
from zha.websocket.server.api.model import (
1515
GetApplicationStateResponse,
1616
GetDevicesResponse,
17+
GroupsResponse,
18+
PermitJoiningResponse,
1719
ReadClusterAttributesResponse,
20+
UpdateGroupResponse,
1821
WebSocketCommand,
1922
WriteClusterAttributeResponse,
2023
)
@@ -150,7 +153,14 @@ async def get_groups(
150153
group.info_object
151154
) # maybe we should change the group_id type...
152155
_LOGGER.info("groups: %s", groups)
153-
client.send_result_success(command, {GROUPS: groups})
156+
client.send_result_success(
157+
command,
158+
GroupsResponse(
159+
**command.model_dump(exclude="model_class_name"),
160+
groups=groups,
161+
success=True,
162+
),
163+
)
154164

155165

156166
class PermitJoiningCommand(WebSocketCommand):
@@ -169,10 +179,13 @@ async def permit_joining(
169179
"""Permit joining devices to the Zigbee network."""
170180
# TODO add permit with code support
171181
await gateway.application_controller.permit(command.duration, command.ieee)
172-
client.send_result_success(
173-
command,
174-
{DURATION: command.duration},
182+
response = PermitJoiningResponse(
183+
**command.model_dump(exclude="model_class_name"),
184+
success=True,
185+
duration=command.duration,
186+
ieee=command.ieee,
175187
)
188+
client.send_result_success(command, response)
176189

177190

178191
class RemoveDeviceCommand(WebSocketCommand):
@@ -358,7 +371,12 @@ async def create_group(
358371
members = command.members
359372
group_id = command.group_id
360373
group: Group = await gateway.async_create_zigpy_group(group_name, members, group_id)
361-
client.send_result_success(command, {"group": group.info_object})
374+
response = UpdateGroupResponse(
375+
**command.model_dump(exclude="model_class_name"),
376+
group=group.info_object,
377+
success=True,
378+
)
379+
client.send_result_success(command, response)
362380

363381

364382
class RemoveGroupsCommand(WebSocketCommand):
@@ -416,7 +434,12 @@ async def add_group_members(
416434
if not group:
417435
client.send_result_error(command, "G1", "ZHA Group not found")
418436
return
419-
client.send_result_success(command, {GROUP: group.info_object})
437+
response = UpdateGroupResponse(
438+
**command.model_dump(exclude="model_class_name"),
439+
group=group.info_object,
440+
success=True,
441+
)
442+
client.send_result_success(command, response)
420443

421444

422445
class RemoveGroupMembersCommand(AddGroupMembersCommand):
@@ -443,7 +466,12 @@ async def remove_group_members(
443466
if not group:
444467
client.send_result_error(command, "G1", "ZHA Group not found")
445468
return
446-
client.send_result_success(command, {GROUP: group.info_object})
469+
response = UpdateGroupResponse(
470+
**command.model_dump(exclude="model_class_name"),
471+
group=group.info_object,
472+
success=True,
473+
)
474+
client.send_result_success(command, response)
447475

448476

449477
class StopServerCommand(WebSocketCommand):

zha/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def as_tagged_union(union):
160160
]
161161

162162

163-
class BaseEvent(BaseModel):
163+
class BaseEvent(TypedBaseModel):
164164
"""Base model for ZHA events."""
165165

166166
message_type: Literal["event"] = "event"
Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,11 @@
11
"""Models that represent messages in zha."""
22

3-
from typing import Annotated
4-
53
from pydantic import RootModel
6-
from pydantic.fields import Field
74

8-
from zha.websocket.server.api.model import CommandResponses, Events
5+
from zha.websocket.server.api.model import Messages
96

107

118
class Message(RootModel):
129
"""Response model."""
1310

14-
root: Annotated[
15-
CommandResponses | Events,
16-
Field(discriminator="message_type"),
17-
]
11+
root: Messages

0 commit comments

Comments
 (0)