Skip to content

Commit e56fd77

Browse files
committed
Refactor voice session handling with context class
Simplify, too many dicts keeping track of a voice session! Enforce that there's only one session per voice-assistant entity.
1 parent 833a380 commit e56fd77

File tree

1 file changed

+96
-85
lines changed

1 file changed

+96
-85
lines changed

ucapi/api.py

Lines changed: 96 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import socket
1313
from asyncio import AbstractEventLoop
1414
from copy import deepcopy
15-
from dataclasses import asdict
15+
from dataclasses import asdict, dataclass
1616
from typing import Any, Callable
1717

1818
import websockets
@@ -57,10 +57,23 @@
5757
_LOG.setLevel(logging.DEBUG)
5858

5959

60+
VoiceSessionKey = tuple[Any, int]
61+
"""Tuple of (websocket, session_id)"""
62+
63+
64+
@dataclass(slots=True)
65+
class _VoiceSessionContext:
66+
session: VoiceSession
67+
timeout_task: asyncio.Task | None = None
68+
handler_task: asyncio.Task | None = None
69+
70+
6071
# pylint: disable=too-many-public-methods, too-many-lines
6172
class IntegrationAPI:
6273
"""Integration API to communicate with Remote Two/3."""
6374

75+
DEFAULT_VOICE_SESSION_TIMEOUT_S: int = 30
76+
6477
def __init__(self, loop: AbstractEventLoop):
6578
"""
6679
Create an integration driver API instance.
@@ -69,36 +82,38 @@ def __init__(self, loop: AbstractEventLoop):
6982
"""
7083
self._loop = loop
7184
self._events = AsyncIOEventEmitter(self._loop)
85+
7286
self._setup_handler: uc.SetupHandler | None = None
7387
self._driver_info: dict[str, Any] = {}
7488
self._driver_path: str | None = None
7589
self._state: uc.DeviceStates = uc.DeviceStates.DISCONNECTED
90+
7691
self._server_task = None
7792
self._clients = set()
7893

79-
self._config_dir_path: str = (
80-
os.getenv("UC_CONFIG_HOME") or os.getenv("HOME") or "./"
81-
)
94+
self._config_dir_path = self._resolve_config_dir()
8295

8396
self._available_entities = Entities("available", self._loop)
8497
self._configured_entities = Entities("configured", self._loop)
8598

86-
# Voice stream handler
8799
self._voice_handler: VoiceStreamHandler | None = None
88-
self._voice_sessions: dict[int, VoiceSession] = {}
89-
# Track which websocket owns which voice session ids, to cleanup on disconnect
90-
self._voice_ws_sessions: dict[Any, set[int]] = {}
91-
# Owner mapping: session_id -> websocket to facilitate cleanup
92-
self._voice_session_owner: dict[int, Any] = {}
93-
# Voice session timeout management (seconds)
94-
self._voice_session_timeout: int = 30
95-
self._voice_session_timeouts: dict[int, asyncio.Task] = {}
96-
# Track handler tasks per session (to avoid double-start and for observability)
97-
self._voice_handler_tasks: dict[int, asyncio.Task] = {}
100+
self._voice_session_timeout: int = self.DEFAULT_VOICE_SESSION_TIMEOUT_S
101+
# Active voice sessions
102+
self._voice_sessions: dict[VoiceSessionKey, _VoiceSessionContext] = {}
103+
# Enforce: at most one active session per entity_id (across all websockets)
104+
self._voice_session_by_entity: dict[str, VoiceSessionKey] = {}
98105

99106
# Setup event loop
100107
asyncio.set_event_loop(self._loop)
101108

109+
@staticmethod
110+
def _resolve_config_dir() -> str:
111+
return os.getenv("UC_CONFIG_HOME") or os.getenv("HOME") or "./"
112+
113+
@staticmethod
114+
def _voice_key(websocket: Any, session_id: int) -> VoiceSessionKey:
115+
return websocket, int(session_id)
116+
102117
async def init(
103118
self, driver_path: str, setup_handler: uc.SetupHandler | None = None
104119
):
@@ -233,18 +248,16 @@ async def _handle_ws(self, websocket) -> None:
233248

234249
finally:
235250
# Cleanup any active voice sessions associated with this websocket
236-
try:
237-
session_ids = self._voice_ws_sessions.pop(websocket, set())
238-
except Exception: # pylint: disable=W0718
239-
session_ids = set()
240-
for sid in list(session_ids):
251+
keys_to_cleanup = [k for k in self._voice_sessions if k[0] is websocket]
252+
for key in keys_to_cleanup:
241253
try:
242-
await self._cleanup_voice_session(sid, VoiceEndReason.REMOTE)
243-
except Exception: # pylint: disable=W0718
254+
await self._cleanup_voice_session(key, VoiceEndReason.REMOTE)
255+
except Exception as ex: # pylint: disable=W0718
244256
_LOG.exception(
245-
"[%s] WS: Error during voice session cleanup for %s",
257+
"[%s] WS: Error during voice session cleanup for session_id=%s: %s",
246258
websocket.remote_address,
247-
sid,
259+
key[1],
260+
ex,
248261
)
249262

250263
self._clients.remove(websocket)
@@ -456,15 +469,18 @@ async def _on_remote_voice_begin(self, websocket, msg: RemoteVoiceBegin) -> None
456469
return
457470

458471
session_id = int(getattr(msg, "session_id", 0) or 0)
459-
session = self._voice_sessions.get(session_id)
460-
if not session:
472+
key = self._voice_key(websocket, session_id)
473+
ctx = self._voice_sessions.get(key)
474+
if ctx is None:
461475
_LOG.error(
462476
"[%s] proto VoiceBegin: no voice session for session_id=%s",
463477
websocket.remote_address,
464478
session_id,
465479
)
466480
return
467481

482+
session = ctx.session
483+
468484
# verify AudioConfiguration in session from voice_start command
469485
cfg = getattr(msg, "configuration", None)
470486
audio_cfg = AudioConfiguration.from_proto(cfg) or AudioConfiguration()
@@ -475,11 +491,6 @@ async def _on_remote_voice_begin(self, websocket, msg: RemoteVoiceBegin) -> None
475491
)
476492
return
477493

478-
# Track ownership for cleanup on disconnect
479-
owners = self._voice_ws_sessions.setdefault(websocket, set())
480-
owners.add(session_id)
481-
self._voice_session_owner[session_id] = websocket
482-
483494
if _LOG.isEnabledFor(logging.DEBUG):
484495
_LOG.debug(
485496
"[%s] proto VoiceBegin: session_id=%s cfg(ch=%s sr=%s fmt=%s)",
@@ -491,8 +502,7 @@ async def _on_remote_voice_begin(self, websocket, msg: RemoteVoiceBegin) -> None
491502
)
492503

493504
# Invoke handler in the background so the WS loop is not blocked
494-
task = self._loop.create_task(self._run_voice_handler(session))
495-
self._voice_handler_tasks[session.session_id] = task
505+
ctx.handler_task = self._loop.create_task(self._run_voice_handler(session))
496506

497507
async def _on_remote_voice_data(self, websocket, msg: RemoteVoiceData) -> None:
498508
"""Handle a RemoteVoiceData protobuf message.
@@ -505,8 +515,9 @@ async def _on_remote_voice_data(self, websocket, msg: RemoteVoiceData) -> None:
505515
return
506516

507517
session_id = int(getattr(msg, "session_id", 0) or 0)
508-
session = self._voice_sessions.get(session_id)
509-
if not session:
518+
key = self._voice_key(websocket, session_id)
519+
ctx = self._voice_sessions.get(key)
520+
if ctx is None:
510521
_LOG.error(
511522
"[%s] proto VoiceData: no voice session for session_id=%s",
512523
websocket.remote_address,
@@ -517,7 +528,7 @@ async def _on_remote_voice_data(self, websocket, msg: RemoteVoiceData) -> None:
517528
samples = getattr(msg, "samples", b"") or b""
518529
if samples:
519530
try:
520-
session.feed(bytes(samples))
531+
ctx.session.feed(bytes(samples))
521532
except Exception as ex: # pylint: disable=W0718
522533
_LOG.error(
523534
"[%s] proto VoiceData: session %s processing error: %s",
@@ -526,50 +537,41 @@ async def _on_remote_voice_data(self, websocket, msg: RemoteVoiceData) -> None:
526537
ex,
527538
)
528539

529-
async def _on_remote_voice_end(self, _websocket, msg: RemoteVoiceEnd) -> None:
540+
async def _on_remote_voice_end(self, websocket, msg: RemoteVoiceEnd) -> None:
530541
"""Handle a RemoteVoiceEnd protobuf message.
531542
532543
If no voice handler is registered, do nothing (default ignore behavior).
533544
"""
534545
if self._voice_handler is None:
535546
return
536547
session_id = int(getattr(msg, "session_id", 0) or 0)
537-
await self._cleanup_voice_session(session_id)
548+
await self._cleanup_voice_session(self._voice_key(websocket, session_id))
538549

539550
async def _cleanup_voice_session(
540-
self, session_id: int, end_reason: VoiceEndReason = VoiceEndReason.NORMAL
551+
self,
552+
key: VoiceSessionKey,
553+
end_reason: VoiceEndReason = VoiceEndReason.NORMAL,
541554
) -> None:
542-
"""Cleanup internal state for a voice session.
555+
"""Cleanup internal state for a voice session context."""
556+
ctx = self._voice_sessions.pop(key, None)
557+
if ctx is None:
558+
return
543559

544-
- Cancel and remove any pending timeout task for the session.
545-
- End the session iterator if still open.
546-
- Remove bookkeeping: _voice_sessions, _voice_session_owner, _voice_ws_sessions.
547-
- Forget handler task reference (do not cancel it; handler should exit on end()).
548-
"""
549560
# Cancel timeout task if present
550-
t = self._voice_session_timeouts.pop(session_id, None)
561+
t = ctx.timeout_task
551562
if t is not None and not t.done():
552563
t.cancel()
553564

554-
# End and remove session
555-
session = self._voice_sessions.pop(session_id, None)
556-
if session is not None and not session.closed:
557-
session.end(end_reason)
565+
# Enforce entity_id uniqueness index cleanup
566+
if self._voice_session_by_entity.get(ctx.session.entity_id) == key:
567+
self._voice_session_by_entity.pop(ctx.session.entity_id, None)
558568

559-
# Remove ownership mappings
560-
try:
561-
owner_ws = self._voice_session_owner.pop(session_id, None)
562-
if owner_ws is not None:
563-
owners = self._voice_ws_sessions.get(owner_ws)
564-
if owners is not None:
565-
owners.discard(session_id)
566-
if not owners:
567-
self._voice_ws_sessions.pop(owner_ws, None)
568-
except Exception: # pylint: disable=W0718
569-
pass
570-
571-
# Drop handler task reference (don't cancel; allow graceful finish)
572-
self._voice_handler_tasks.pop(session_id, None)
569+
# End session iterator
570+
if not ctx.session.closed:
571+
ctx.session.end(end_reason)
572+
573+
# Note: do not cancel handler task; handler should exit on session.end()
574+
ctx.handler_task = None
573575

574576
def set_voice_stream_handler(self, handler: VoiceStreamHandler | None) -> None:
575577
"""Register or clear the voice stream handler.
@@ -618,56 +620,57 @@ async def _run_voice_handler(self, session: VoiceSession) -> None:
618620
# Ensure iterator is unblocked and session is cleaned up
619621
await self._cleanup_voice_session(session.session_id)
620622

621-
def _schedule_voice_timeout(self, session_id: int) -> None:
623+
def _schedule_voice_timeout(self, key: VoiceSessionKey) -> None:
622624
"""Schedule the timeout task for a voice session.
623625
624626
Starts counting immediately at creation time. When the timeout expires and the
625627
session is still active, the handler is notified (invoked) if not already
626628
started, the session is ended, and cleanup is performed.
627629
"""
630+
ctx = self._voice_sessions.get(key)
631+
if ctx is None:
632+
return
633+
628634
# Cancel pre-existing task if any (defensive)
629-
existing = self._voice_session_timeouts.pop(session_id, None)
635+
existing = ctx.timeout_task
630636
if existing is not None and not existing.done():
631637
existing.cancel()
632638

633-
task = self._loop.create_task(self._voice_session_timeout_task(session_id))
634-
self._voice_session_timeouts[session_id] = task
639+
ctx.timeout_task = self._loop.create_task(self._voice_session_timeout_task(key))
635640

636-
async def _voice_session_timeout_task(self, session_id: int) -> None:
641+
async def _voice_session_timeout_task(self, key: VoiceSessionKey) -> None:
637642
"""Timeout watchdog for a voice session."""
638643
try:
639644
await asyncio.sleep(self._voice_session_timeout)
640645
except asyncio.CancelledError:
641646
return
642647

643648
# If still active after timeout; take action
644-
session = self._voice_sessions.get(session_id)
645-
if session is None:
649+
ctx = self._voice_sessions.get(key)
650+
if ctx is None:
646651
return
647652

648653
_LOG.warning(
649654
"Voice session %s timed out after %ss",
650-
session_id,
655+
ctx.session.session_id,
651656
self._voice_session_timeout,
652657
)
653658

654-
# If handler not started yet (e.g., no VoiceBegin received), notify it now
655-
if (
656-
session_id not in self._voice_handler_tasks
657-
and self._voice_handler is not None
658-
):
659+
# If handler not started yet, start it now (best effort)
660+
if ctx.handler_task is None and self._voice_handler is not None:
659661
try:
660-
task = self._loop.create_task(self._run_voice_handler(session))
661-
self._voice_handler_tasks[session_id] = task
662+
ctx.handler_task = self._loop.create_task(
663+
self._run_voice_handler(ctx.session)
664+
)
662665
except Exception: # pylint: disable=W0718
663666
_LOG.exception(
664667
"Failed to start voice handler on timeout for session %s",
665-
session_id,
668+
ctx.session.session_id,
666669
)
667670

668671
# End and cleanup
669-
session.end(VoiceEndReason.TIMEOUT)
670-
await self._cleanup_voice_session(session_id)
672+
ctx.session.end(VoiceEndReason.TIMEOUT)
673+
await self._cleanup_voice_session(key)
671674

672675
async def _handle_ws_request_msg(
673676
self, websocket, msg: str, req_id: int, msg_data: dict[str, Any] | None
@@ -853,7 +856,7 @@ async def _entity_command(
853856
and "params" in msg_data
854857
):
855858
params = msg_data["params"]
856-
session_id = params.get("session_id")
859+
session_id = int(params.get("session_id"))
857860
cfg = params.get("audio_cfg")
858861
audio_cfg = (
859862
AudioConfiguration(
@@ -865,6 +868,12 @@ async def _entity_command(
865868
else AudioConfiguration()
866869
)
867870

871+
# Enforce: only one active session per entity_id across all websockets
872+
existing_key = self._voice_session_by_entity.get(entity_id)
873+
if existing_key is not None:
874+
await self._cleanup_voice_session(existing_key, VoiceEndReason.LOCAL)
875+
876+
key = self._voice_key(websocket, session_id)
868877
session = VoiceSession(
869878
session_id,
870879
entity_id,
@@ -873,9 +882,11 @@ async def _entity_command(
873882
websocket=websocket,
874883
loop=self._loop,
875884
)
876-
self._voice_sessions[session_id] = session
885+
self._voice_sessions[key] = _VoiceSessionContext(session=session)
886+
self._voice_session_by_entity[entity_id] = key
887+
877888
# Start timeout immediately on session creation
878-
self._schedule_voice_timeout(session_id)
889+
self._schedule_voice_timeout(key)
879890

880891
result = await entity.command(
881892
cmd_id, msg_data["params"] if "params" in msg_data else None

0 commit comments

Comments
 (0)