Skip to content

Commit 6737b36

Browse files
authored
Convert ZCL attributes from strings when writing (#267)
* Convert ZCL attributes when writing * Handle bytes too * Add some unit tests * Address review comments
1 parent c494a3e commit 6737b36

File tree

3 files changed

+141
-22
lines changed

3 files changed

+141
-22
lines changed

tests/test_application_helpers.py

Lines changed: 84 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
"""Test zha application helpers."""
22

3+
from typing import Any
4+
5+
import pytest
36
from zigpy.device import Device as ZigpyDevice
47
from zigpy.profiles import zha
5-
from zigpy.zcl.clusters.general import Basic, OnOff
8+
import zigpy.types as t
9+
from zigpy.zcl.clusters.general import Basic, Identify, OnOff
610
from zigpy.zcl.clusters.security import IasZone
711

812
from tests.common import (
@@ -14,7 +18,12 @@
1418
join_zigpy_device,
1519
)
1620
from zha.application.gateway import Gateway
17-
from zha.application.helpers import async_is_bindable_target, get_matched_clusters
21+
from zha.application.helpers import (
22+
async_is_bindable_target,
23+
convert_to_zcl_values,
24+
convert_zcl_value,
25+
get_matched_clusters,
26+
)
1827

1928
IEEE_GROUPABLE_DEVICE = "01:2d:6f:00:0a:90:69:e8"
2029
IEEE_GROUPABLE_DEVICE2 = "02:2d:6f:00:0a:90:69:e8"
@@ -105,3 +114,76 @@ async def test_get_matched_clusters(
105114
assert matches[0].target_ep_id == 1
106115

107116
assert not await get_matched_clusters(not_bindable_zha_device, remote_zha_device)
117+
118+
119+
class SomeEnum(t.enum8):
120+
"""Some enum."""
121+
122+
value_1 = 0x12
123+
value_2 = 0x34
124+
value_3 = 0x56
125+
126+
127+
class SomeFlag(t.bitmap8):
128+
"""Some bitmap."""
129+
130+
flag_1 = 0b00000001
131+
flag_2 = 0b00000010
132+
flag_3 = 0b00000100
133+
134+
135+
@pytest.mark.parametrize(
136+
("text", "field_type", "result"),
137+
[
138+
# Bytes
139+
(
140+
"b'Some data\\x00\\x01'",
141+
t.SerializableBytes,
142+
t.SerializableBytes(b"Some data\x00\x01"),
143+
),
144+
(
145+
'b"Some data\\x00\\x01"',
146+
t.SerializableBytes,
147+
t.SerializableBytes(b"Some data\x00\x01"),
148+
),
149+
(
150+
b"Some data\x00\x01".hex(),
151+
t.SerializableBytes,
152+
t.SerializableBytes(b"Some data\x00\x01"),
153+
),
154+
# Enum
155+
("value 1", SomeEnum, SomeEnum.value_1),
156+
("value_1", SomeEnum, SomeEnum.value_1),
157+
("SomeEnum.value_1", SomeEnum, SomeEnum.value_1),
158+
(0x12, SomeEnum, SomeEnum.value_1),
159+
# Flag
160+
("flag 1", SomeFlag, SomeFlag.flag_1),
161+
("flag_1", SomeFlag, SomeFlag.flag_1),
162+
("SomeFlag.flag_1", SomeFlag, SomeFlag.flag_1),
163+
("SomeFlag.flag_1|flag_2", SomeFlag, SomeFlag.flag_1 | SomeFlag.flag_2),
164+
(0b00000001, SomeFlag, SomeFlag.flag_1),
165+
([0b00000001], SomeFlag, SomeFlag.flag_1),
166+
([0b00000001, 0b00000010], SomeFlag, SomeFlag.flag_1 | SomeFlag.flag_2),
167+
(["flag_1", "flag_2"], SomeFlag, SomeFlag.flag_1 | SomeFlag.flag_2),
168+
# Int
169+
(0x1234, t.uint16_t, 0x1234),
170+
("0x1234", t.uint16_t, 0x1234),
171+
("4660", t.uint16_t, 0x1234),
172+
# Some fallthrough type
173+
(1.000, t.Single, t.Single(1.000)),
174+
("1.000", t.Single, t.Single(1.000)),
175+
],
176+
)
177+
def test_convert_zcl_value(text: Any, field_type: Any, result: Any) -> None:
178+
"""Test converting ZCL values."""
179+
assert convert_zcl_value(text, field_type) == result
180+
181+
182+
def test_convert_to_zcl_values() -> None:
183+
"""Test converting ZCL values."""
184+
185+
identify_schema = Identify.ServerCommandDefs.identify.schema
186+
assert convert_to_zcl_values(
187+
fields={"identify_time": "1"},
188+
schema=identify_schema,
189+
) == {"identify_time": 1}

zha/application/helpers.py

Lines changed: 53 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22

33
from __future__ import annotations
44

5+
import ast
56
import asyncio
67
import binascii
78
import collections
89
from collections.abc import Callable
10+
import contextlib
911
import dataclasses
1012
from dataclasses import dataclass
1113
import datetime
@@ -126,6 +128,53 @@ async def get_matched_clusters(
126128
return clusters_to_bind
127129

128130

131+
def convert_zcl_value(value: Any, field_type: Any) -> Any:
132+
"""Convert user input to ZCL value."""
133+
if issubclass(field_type, enum.Flag):
134+
if isinstance(value, str):
135+
with contextlib.suppress(ValueError):
136+
value = int(value)
137+
138+
if isinstance(value, int):
139+
value = field_type(value)
140+
elif isinstance(value, str):
141+
# List of flags: `SomeFlag.field1 | field2`
142+
value = [v.strip() for v in value.split(".", 1)[-1].split("|")]
143+
144+
if isinstance(value, list):
145+
new_value = 0
146+
147+
for flag in value:
148+
if isinstance(flag, str):
149+
new_value |= field_type[flag.replace(" ", "_")]
150+
else:
151+
new_value |= flag
152+
153+
value = field_type(new_value)
154+
elif issubclass(field_type, enum.Enum):
155+
value = (
156+
field_type[value.replace(" ", "_").split(".", 1)[-1]]
157+
if isinstance(value, str)
158+
else field_type(value)
159+
)
160+
elif issubclass(field_type, zigpy.types.SerializableBytes):
161+
if value.startswith(("b'", 'b"')):
162+
value = ast.literal_eval(value)
163+
else:
164+
value = bytes.fromhex(value)
165+
166+
value = field_type(value)
167+
elif issubclass(field_type, int):
168+
if isinstance(value, str) and value.startswith("0x"):
169+
value = int(value, 16)
170+
171+
value = field_type(value)
172+
else:
173+
value = field_type(value)
174+
175+
return value
176+
177+
129178
def convert_to_zcl_values(
130179
fields: dict[str, Any], schema: CommandSchema
131180
) -> dict[str, Any]:
@@ -134,32 +183,17 @@ def convert_to_zcl_values(
134183
for field in schema.fields:
135184
if field.name not in fields:
136185
continue
137-
value = fields[field.name]
138-
if issubclass(field.type, enum.Flag) and isinstance(value, list):
139-
new_value = 0
140186

141-
for flag in value:
142-
if isinstance(flag, str):
143-
new_value |= field.type[flag.replace(" ", "_")]
144-
else:
145-
new_value |= flag
187+
value = fields[field.name]
188+
new_value = converted_fields[field.name] = convert_zcl_value(value, field.type)
146189

147-
value = field.type(new_value)
148-
elif issubclass(field.type, enum.Enum):
149-
value = (
150-
field.type[value.replace(" ", "_")]
151-
if isinstance(value, str)
152-
else field.type(value)
153-
)
154-
else:
155-
value = field.type(value)
156190
_LOGGER.debug(
157191
"Converted ZCL schema field(%s) value from: %s to: %s",
158192
field.name,
159-
fields[field.name],
160193
value,
194+
new_value,
161195
)
162-
converted_fields[field.name] = value
196+
163197
return converted_fields
164198

165199

zha/zigbee/device.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
ZHA_CLUSTER_HANDLER_MSG,
6161
ZHA_EVENT,
6262
)
63-
from zha.application.helpers import convert_to_zcl_values
63+
from zha.application.helpers import convert_to_zcl_values, convert_zcl_value
6464
from zha.application.platforms import BaseEntityInfo, PlatformEntity
6565
from zha.event import EventBase
6666
from zha.exceptions import ZHAException
@@ -874,6 +874,9 @@ async def write_zigbee_attribute(
874874
f" writing attribute {attribute} with value {value}"
875875
) from exc
876876

877+
attr_def = cluster.find_attribute(attribute)
878+
value = convert_zcl_value(value, attr_def.type)
879+
877880
try:
878881
response = await cluster.write_attributes(
879882
{attribute: value}, manufacturer=manufacturer

0 commit comments

Comments
 (0)