Skip to content

Commit 833a380

Browse files
committed
Clean up voice session handling
Verify that the AudioConfiguration received in the VoiceBegin message matches the session's configuration from the voice_start command. This requires firmware 2.8.2 or newer.
1 parent 1f83d48 commit 833a380

File tree

3 files changed

+195
-18
lines changed

3 files changed

+195
-18
lines changed

tests/test_voice_assistant.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import unittest
2+
3+
from ucapi.voice_assistant import (
4+
DEFAULT_AUDIO_CHANNELS,
5+
DEFAULT_SAMPLE_FORMAT,
6+
DEFAULT_SAMPLE_RATE,
7+
AudioConfiguration,
8+
SampleFormat,
9+
)
10+
from ucapi.proto import ucr_integration_voice_pb2 as pb2
11+
12+
13+
class TestVoiceAssistantConversions(unittest.TestCase):
    """Conversion tests for ``SampleFormat.from_proto`` and ``AudioConfiguration.from_proto``."""

    def test_sample_format_from_proto_supported(self):
        """Proto enum values, their ints and strings resolve to the local enum."""
        expectations = (
            (pb2.I16, SampleFormat.I16),
            (int(pb2.U32), SampleFormat.U32),
            ("f32", SampleFormat.F32),
        )
        for raw, wanted in expectations:
            self.assertEqual(SampleFormat.from_proto(raw), wanted)

    def test_sample_format_from_proto_unsupported_to_none(self):
        """Inputs that do not exist in the local enum must map to ``None``."""
        for raw in (pb2.I8, pb2.U8, "i8", "unknown"):
            self.assertIsNone(SampleFormat.from_proto(raw))

    def test_audio_cfg_from_proto_message(self):
        """A protobuf ``AudioConfiguration`` message is converted field by field."""
        proto_cfg = pb2.AudioConfiguration(
            channels=2,
            sample_rate=22050,
            sample_format=pb2.I32,
            format=pb2.PCM,
        )

        result = AudioConfiguration.from_proto(proto_cfg)

        self.assertIsInstance(result, AudioConfiguration)
        self.assertEqual(result.channels, 2)
        self.assertEqual(result.sample_rate, 22050)
        self.assertEqual(result.sample_format, SampleFormat.I32)

    def test_audio_cfg_from_dict_and_string_coercions(self):
        """String values inside a dict are coerced to ints / enum members."""
        raw_cfg = {
            "channels": "2",
            "sample_rate": "16000",
            "sample_format": "u16",
        }

        result = AudioConfiguration.from_proto(raw_cfg)

        self.assertEqual(result.channels, 2)
        self.assertEqual(result.sample_rate, 16000)
        self.assertEqual(result.sample_format, SampleFormat.U16)

    def test_audio_cfg_defaults_and_unsupported_sample_format(self):
        """Garbage values and an unsupported sample format fall back to defaults."""
        raw_cfg = {
            "channels": "x",
            "sample_rate": "",
            "sample_format": pb2.U8,  # unsupported in local enum
        }

        result = AudioConfiguration.from_proto(raw_cfg)

        self.assertEqual(result.channels, DEFAULT_AUDIO_CHANNELS)
        self.assertEqual(result.sample_rate, DEFAULT_SAMPLE_RATE)
        # An unsupported sample_format is replaced by the default one
        self.assertEqual(result.sample_format, DEFAULT_SAMPLE_FORMAT)

    def test_audio_cfg_from_proto_none(self):
        """``None`` input is passed through as ``None``."""
        converted = AudioConfiguration.from_proto(None)
        self.assertIsNone(converted)
68+
69+
70+
if __name__ == "__main__":
71+
unittest.main()

ucapi/api.py

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -449,20 +449,13 @@ async def _on_remote_voice_begin(self, websocket, msg: RemoteVoiceBegin) -> None
449449
"""
450450
if self._voice_handler is None:
451451
# Log once per stream and ignore further binary messages.
452-
cfg = getattr(msg, "configuration", None)
453452
_LOG.warning(
454-
"[%s] proto VoiceBegin: session_id=%s cfg(ch=%s sr=%s fmt=%s af=%s) (no voice handler)",
453+
"[%s] proto VoiceBegin: no voice handler registered! Ignoring voice stream",
455454
websocket.remote_address,
456-
getattr(msg, "session_id", None),
457-
getattr(cfg, "channels", None) if cfg else None,
458-
getattr(cfg, "sample_rate", None) if cfg else None,
459-
getattr(cfg, "sample_format", None) if cfg else None,
460-
getattr(cfg, "format", None) if cfg else None,
461455
)
462456
return
463457

464458
session_id = int(getattr(msg, "session_id", 0) or 0)
465-
session_id = 0 # FIXME(voice) until core is fixed
466459
session = self._voice_sessions.get(session_id)
467460
if not session:
468461
_LOG.error(
@@ -472,13 +465,15 @@ async def _on_remote_voice_begin(self, websocket, msg: RemoteVoiceBegin) -> None
472465
)
473466
return
474467

475-
# TODO(voice) verify AudioConfiguration in session from voice_start command?
476-
# cfg = getattr(msg, "configuration", None)
477-
# audio_cfg = AudioConfiguration(
478-
# channels=int(getattr(cfg, "channels", 1) or 1),
479-
# sample_rate=int(getattr(cfg, "sample_rate", 0) or 0),
480-
# sample_format=int(getattr(cfg, "sample_format", 0) or 0), # FIXME convert
481-
# )
468+
# verify AudioConfiguration in session from voice_start command
469+
cfg = getattr(msg, "configuration", None)
470+
audio_cfg = AudioConfiguration.from_proto(cfg) or AudioConfiguration()
471+
if audio_cfg != session.config:
472+
_LOG.error(
473+
"[%s] proto VoiceBegin: audio cfg does not match voice_start",
474+
websocket.remote_address,
475+
)
476+
return
482477

483478
# Track ownership for cleanup on disconnect
484479
owners = self._voice_ws_sessions.setdefault(websocket, set())
@@ -510,7 +505,6 @@ async def _on_remote_voice_data(self, websocket, msg: RemoteVoiceData) -> None:
510505
return
511506

512507
session_id = int(getattr(msg, "session_id", 0) or 0)
513-
session_id = 0 # FIXME(voice) until core is fixed
514508
session = self._voice_sessions.get(session_id)
515509
if not session:
516510
_LOG.error(
@@ -540,7 +534,6 @@ async def _on_remote_voice_end(self, _websocket, msg: RemoteVoiceEnd) -> None:
540534
if self._voice_handler is None:
541535
return
542536
session_id = int(getattr(msg, "session_id", 0) or 0)
543-
session_id = 0 # FIXME(voice) until core is fixed
544537
await self._cleanup_voice_session(session_id)
545538

546539
async def _cleanup_voice_session(
@@ -861,7 +854,6 @@ async def _entity_command(
861854
):
862855
params = msg_data["params"]
863856
session_id = params.get("session_id")
864-
session_id = 0 # FIXME(voice) until core is fixed
865857
cfg = params.get("audio_cfg")
866858
audio_cfg = (
867859
AudioConfiguration(

ucapi/voice_assistant.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,18 @@
1616
from enum import Enum
1717
from typing import Any, Optional
1818

19+
# Import specific enum constants to avoid pylint no-member on dynamic attributes
1920
from ucapi.api_definitions import CommandHandler
2021
from ucapi.entity import Entity, EntityTypes
2122

23+
from ucapi.proto.ucr_integration_voice_pb2 import ( # pylint: disable=no-name-in-module # isort:skip # noqa
24+
F32 as PB_F32,
25+
I16 as PB_I16,
26+
I32 as PB_I32,
27+
U16 as PB_U16,
28+
U32 as PB_U32,
29+
)
30+
2231
DEFAULT_AUDIO_CHANNELS = 1
2332
DEFAULT_SAMPLE_RATE = 16000
2433

@@ -97,6 +106,48 @@ class SampleFormat(str, Enum):
97106
F32 = "F32"
98107
"""Float 32 bit."""
99108

109+
@classmethod
def from_proto(cls, value: Any) -> Optional["SampleFormat"]:
    """Convert a protobuf ``SampleFormat`` value to this Python enum.

    Accepted inputs:

    - Protobuf enum value or its integer number: looked up in the table of
      supported formats (I16, I32, U16, U32, F32).
    - String: stripped, upper-cased and matched against the member *names*
      of this enum (e.g. ``"i16"`` -> ``SampleFormat.I16``).
    - Any other object: converted with ``int()`` and looked up like an
      integer.

    :param value: value to convert.
    :return: the matching member, or ``None`` when ``value`` is ``None``, a
             ``bool``, not convertible, or has no counterpart in this enum
             (e.g. ``I8``, ``U8``).
    """
    # bool is a subclass of int: without this guard True/False would be
    # looked up as protobuf enum numbers. AudioConfiguration._to_int rejects
    # bools for the same reason.
    if value is None or isinstance(value, bool):
        return None

    if isinstance(value, str):
        # Only map to names that exist in this Python enum
        try:
            return cls[value.strip().upper()]
        except KeyError:
            return None

    # Plain ints and enum-like wrappers (protobuf enum values behave like ints)
    try:
        number = int(value)
    except (TypeError, ValueError):
        return None

    # Protobuf enum number -> Python enum member
    mapping: dict[int, SampleFormat] = {
        int(PB_I16): cls.I16,
        int(PB_I32): cls.I32,
        int(PB_U16): cls.U16,
        int(PB_U32): cls.U32,
        int(PB_F32): cls.F32,
    }
    return mapping.get(number)
150+
100151

101152
DEFAULT_SAMPLE_FORMAT = SampleFormat.I16
102153

@@ -121,6 +172,69 @@ class AudioConfiguration:
121172
sample_format: SampleFormat = DEFAULT_SAMPLE_FORMAT
122173
"""Audio sample format."""
123174

175+
@staticmethod
def _to_int(value: Any, default: int) -> int:
    """Best-effort conversion to ``int`` with a sensible default.

    Only ``int`` and ``str`` inputs are converted; strings are stripped of
    surrounding whitespace first.

    :param value: value to convert.
    :param default: value returned whenever no usable integer is found.
    :return: the converted integer, or ``default`` if ``value`` is ``None``,
             a ``bool``, of an unsupported type, not parseable, or converts
             to ``0``.
    """
    # bool is a subclass of int, but True/False are not meaningful numbers here
    if value is None or isinstance(value, bool):
        return default

    if isinstance(value, str):
        value = value.strip()
        if not value:
            return default
    elif not isinstance(value, int):
        return default

    try:
        # ``or default`` treats 0 as "not set" for ints *and* strings, so
        # that "0" behaves exactly like 0.
        return int(value) or default
    except (TypeError, ValueError):
        return default
195+
196+
@classmethod
def from_proto(cls, value: Any) -> Optional["AudioConfiguration"]:
    """Convert a protobuf ``AudioConfiguration`` (or mapping) to the Python model.

    - ``None`` returns ``None``.
    - An object whose class is named ``AudioConfiguration`` (e.g. the
      protobuf message): its ``channels``, ``sample_rate`` and
      ``sample_format`` attributes are read.
    - Any mapping (``dict`` or other ``collections.abc.Mapping``): the keys
      ``channels``, ``sample_rate`` and ``sample_format`` are read
      (strings/ints acceptable).
    - Any other type returns ``None``.

    Channels and sample rate are converted with ``_to_int``; the sample
    format is converted with ``SampleFormat.from_proto``. Missing or
    unusable values fall back to ``DEFAULT_AUDIO_CHANNELS``,
    ``DEFAULT_SAMPLE_RATE`` and ``DEFAULT_SAMPLE_FORMAT``.

    The protobuf field ``format`` (``AudioFormat``) is currently ignored in
    the Python model.

    :param value: protobuf message, mapping or ``None``.
    :return: the converted configuration, or ``None`` for ``None`` and
             unsupported input types.
    """
    # Function-scope import keeps this method self-contained
    from collections.abc import Mapping  # pylint: disable=import-outside-toplevel

    if value is None:
        return None

    # Extract raw field values from either a proto message or a mapping
    if value.__class__.__name__ == "AudioConfiguration":
        # Matched by class name only: likely a protobuf message instance
        ch = getattr(value, "channels", DEFAULT_AUDIO_CHANNELS)
        sr = getattr(value, "sample_rate", DEFAULT_SAMPLE_RATE)
        sf = getattr(value, "sample_format", None)
    elif isinstance(value, Mapping):
        ch = value.get("channels", DEFAULT_AUDIO_CHANNELS)
        sr = value.get("sample_rate", DEFAULT_SAMPLE_RATE)
        sf = value.get("sample_format", None)
    else:
        # Unsupported type
        return None

    return cls(
        channels=cls._to_int(ch, DEFAULT_AUDIO_CHANNELS),
        sample_rate=cls._to_int(sr, DEFAULT_SAMPLE_RATE),
        # Unknown / unsupported sample formats fall back to the default
        sample_format=SampleFormat.from_proto(sf) or DEFAULT_SAMPLE_FORMAT,
    )
237+
124238

125239
@dataclass(slots=True)
126240
class VoiceAssistantProfile:

0 commit comments

Comments
 (0)