1212import socket
1313from asyncio import AbstractEventLoop
1414from copy import deepcopy
15- from dataclasses import asdict
15+ from dataclasses import asdict , dataclass
1616from typing import Any , Callable
1717
1818import websockets
5757_LOG .setLevel (logging .DEBUG )
5858
5959
60+ VoiceSessionKey = tuple [Any , int ]
61+ """Tuple of (websocket, session_id)"""
62+
63+
64+ @dataclass (slots = True )
65+ class _VoiceSessionContext :
66+ session : VoiceSession
67+ timeout_task : asyncio .Task | None = None
68+ handler_task : asyncio .Task | None = None
69+
70+
6071# pylint: disable=too-many-public-methods, too-many-lines
6172class IntegrationAPI :
6273 """Integration API to communicate with Remote Two/3."""
6374
75+ DEFAULT_VOICE_SESSION_TIMEOUT_S : int = 30
76+
6477 def __init__ (self , loop : AbstractEventLoop ):
6578 """
6679 Create an integration driver API instance.
@@ -69,36 +82,38 @@ def __init__(self, loop: AbstractEventLoop):
6982 """
7083 self ._loop = loop
7184 self ._events = AsyncIOEventEmitter (self ._loop )
85+
7286 self ._setup_handler : uc .SetupHandler | None = None
7387 self ._driver_info : dict [str , Any ] = {}
7488 self ._driver_path : str | None = None
7589 self ._state : uc .DeviceStates = uc .DeviceStates .DISCONNECTED
90+
7691 self ._server_task = None
7792 self ._clients = set ()
7893
79- self ._config_dir_path : str = (
80- os .getenv ("UC_CONFIG_HOME" ) or os .getenv ("HOME" ) or "./"
81- )
94+ self ._config_dir_path = self ._resolve_config_dir ()
8295
8396 self ._available_entities = Entities ("available" , self ._loop )
8497 self ._configured_entities = Entities ("configured" , self ._loop )
8598
86- # Voice stream handler
8799 self ._voice_handler : VoiceStreamHandler | None = None
88- self ._voice_sessions : dict [int , VoiceSession ] = {}
89- # Track which websocket owns which voice session ids, to cleanup on disconnect
90- self ._voice_ws_sessions : dict [Any , set [int ]] = {}
91- # Owner mapping: session_id -> websocket to facilitate cleanup
92- self ._voice_session_owner : dict [int , Any ] = {}
93- # Voice session timeout management (seconds)
94- self ._voice_session_timeout : int = 30
95- self ._voice_session_timeouts : dict [int , asyncio .Task ] = {}
96- # Track handler tasks per session (to avoid double-start and for observability)
97- self ._voice_handler_tasks : dict [int , asyncio .Task ] = {}
100+ self ._voice_session_timeout : int = self .DEFAULT_VOICE_SESSION_TIMEOUT_S
101+ # Active voice sessions
102+ self ._voice_sessions : dict [VoiceSessionKey , _VoiceSessionContext ] = {}
103+ # Enforce: at most one active session per entity_id (across all websockets)
104+ self ._voice_session_by_entity : dict [str , VoiceSessionKey ] = {}
98105
99106 # Setup event loop
100107 asyncio .set_event_loop (self ._loop )
101108
109+ @staticmethod
110+ def _resolve_config_dir () -> str :
111+ return os .getenv ("UC_CONFIG_HOME" ) or os .getenv ("HOME" ) or "./"
112+
113+ @staticmethod
114+ def _voice_key (websocket : Any , session_id : int ) -> VoiceSessionKey :
115+ return websocket , int (session_id )
116+
102117 async def init (
103118 self , driver_path : str , setup_handler : uc .SetupHandler | None = None
104119 ):
@@ -233,18 +248,16 @@ async def _handle_ws(self, websocket) -> None:
233248
234249 finally :
235250 # Cleanup any active voice sessions associated with this websocket
236- try :
237- session_ids = self ._voice_ws_sessions .pop (websocket , set ())
238- except Exception : # pylint: disable=W0718
239- session_ids = set ()
240- for sid in list (session_ids ):
251+ keys_to_cleanup = [k for k in self ._voice_sessions if k [0 ] is websocket ]
252+ for key in keys_to_cleanup :
241253 try :
242- await self ._cleanup_voice_session (sid , VoiceEndReason .REMOTE )
243- except Exception : # pylint: disable=W0718
254+ await self ._cleanup_voice_session (key , VoiceEndReason .REMOTE )
255+ except Exception as ex : # pylint: disable=W0718
244256 _LOG .exception (
245- "[%s] WS: Error during voice session cleanup for %s" ,
257+ "[%s] WS: Error during voice session cleanup for session_id=%s: %s" ,
246258 websocket .remote_address ,
247- sid ,
259+ key [1 ],
260+ ex ,
248261 )
249262
250263 self ._clients .remove (websocket )
@@ -456,15 +469,18 @@ async def _on_remote_voice_begin(self, websocket, msg: RemoteVoiceBegin) -> None
456469 return
457470
458471 session_id = int (getattr (msg , "session_id" , 0 ) or 0 )
459- session = self ._voice_sessions .get (session_id )
460- if not session :
472+ key = self ._voice_key (websocket , session_id )
473+ ctx = self ._voice_sessions .get (key )
474+ if ctx is None :
461475 _LOG .error (
462476 "[%s] proto VoiceBegin: no voice session for session_id=%s" ,
463477 websocket .remote_address ,
464478 session_id ,
465479 )
466480 return
467481
482+ session = ctx .session
483+
468484 # verify AudioConfiguration in session from voice_start command
469485 cfg = getattr (msg , "configuration" , None )
470486 audio_cfg = AudioConfiguration .from_proto (cfg ) or AudioConfiguration ()
@@ -475,11 +491,6 @@ async def _on_remote_voice_begin(self, websocket, msg: RemoteVoiceBegin) -> None
475491 )
476492 return
477493
478- # Track ownership for cleanup on disconnect
479- owners = self ._voice_ws_sessions .setdefault (websocket , set ())
480- owners .add (session_id )
481- self ._voice_session_owner [session_id ] = websocket
482-
483494 if _LOG .isEnabledFor (logging .DEBUG ):
484495 _LOG .debug (
485496 "[%s] proto VoiceBegin: session_id=%s cfg(ch=%s sr=%s fmt=%s)" ,
@@ -491,8 +502,7 @@ async def _on_remote_voice_begin(self, websocket, msg: RemoteVoiceBegin) -> None
491502 )
492503
493504 # Invoke handler in the background so the WS loop is not blocked
494- task = self ._loop .create_task (self ._run_voice_handler (session ))
495- self ._voice_handler_tasks [session .session_id ] = task
505+ ctx .handler_task = self ._loop .create_task (self ._run_voice_handler (session ))
496506
497507 async def _on_remote_voice_data (self , websocket , msg : RemoteVoiceData ) -> None :
498508 """Handle a RemoteVoiceData protobuf message.
@@ -505,8 +515,9 @@ async def _on_remote_voice_data(self, websocket, msg: RemoteVoiceData) -> None:
505515 return
506516
507517 session_id = int (getattr (msg , "session_id" , 0 ) or 0 )
508- session = self ._voice_sessions .get (session_id )
509- if not session :
518+ key = self ._voice_key (websocket , session_id )
519+ ctx = self ._voice_sessions .get (key )
520+ if ctx is None :
510521 _LOG .error (
511522 "[%s] proto VoiceData: no voice session for session_id=%s" ,
512523 websocket .remote_address ,
@@ -517,7 +528,7 @@ async def _on_remote_voice_data(self, websocket, msg: RemoteVoiceData) -> None:
517528 samples = getattr (msg , "samples" , b"" ) or b""
518529 if samples :
519530 try :
520- session .feed (bytes (samples ))
531+ ctx . session .feed (bytes (samples ))
521532 except Exception as ex : # pylint: disable=W0718
522533 _LOG .error (
523534 "[%s] proto VoiceData: session %s processing error: %s" ,
@@ -526,50 +537,41 @@ async def _on_remote_voice_data(self, websocket, msg: RemoteVoiceData) -> None:
526537 ex ,
527538 )
528539
529- async def _on_remote_voice_end (self , _websocket , msg : RemoteVoiceEnd ) -> None :
540+ async def _on_remote_voice_end (self , websocket , msg : RemoteVoiceEnd ) -> None :
530541 """Handle a RemoteVoiceEnd protobuf message.
531542
532543 If no voice handler is registered, do nothing (default ignore behavior).
533544 """
534545 if self ._voice_handler is None :
535546 return
536547 session_id = int (getattr (msg , "session_id" , 0 ) or 0 )
537- await self ._cleanup_voice_session (session_id )
548+ await self ._cleanup_voice_session (self . _voice_key ( websocket , session_id ) )
538549
539550 async def _cleanup_voice_session (
540- self , session_id : int , end_reason : VoiceEndReason = VoiceEndReason .NORMAL
551+ self ,
552+ key : VoiceSessionKey ,
553+ end_reason : VoiceEndReason = VoiceEndReason .NORMAL ,
541554 ) -> None :
542- """Cleanup internal state for a voice session.
555+ """Cleanup internal state for a voice session context."""
556+ ctx = self ._voice_sessions .pop (key , None )
557+ if ctx is None :
558+ return
543559
544- - Cancel and remove any pending timeout task for the session.
545- - End the session iterator if still open.
546- - Remove bookkeeping: _voice_sessions, _voice_session_owner, _voice_ws_sessions.
547- - Forget handler task reference (do not cancel it; handler should exit on end()).
548- """
549560 # Cancel timeout task if present
550- t = self . _voice_session_timeouts . pop ( session_id , None )
561+ t = ctx . timeout_task
551562 if t is not None and not t .done ():
552563 t .cancel ()
553564
554- # End and remove session
555- session = self ._voice_sessions .pop (session_id , None )
556- if session is not None and not session .closed :
557- session .end (end_reason )
565+ # Enforce entity_id uniqueness index cleanup
566+ if self ._voice_session_by_entity .get (ctx .session .entity_id ) == key :
567+ self ._voice_session_by_entity .pop (ctx .session .entity_id , None )
558568
559- # Remove ownership mappings
560- try :
561- owner_ws = self ._voice_session_owner .pop (session_id , None )
562- if owner_ws is not None :
563- owners = self ._voice_ws_sessions .get (owner_ws )
564- if owners is not None :
565- owners .discard (session_id )
566- if not owners :
567- self ._voice_ws_sessions .pop (owner_ws , None )
568- except Exception : # pylint: disable=W0718
569- pass
570-
571- # Drop handler task reference (don't cancel; allow graceful finish)
572- self ._voice_handler_tasks .pop (session_id , None )
569+ # End session iterator
570+ if not ctx .session .closed :
571+ ctx .session .end (end_reason )
572+
573+ # Note: do not cancel handler task; handler should exit on session.end()
574+ ctx .handler_task = None
573575
574576 def set_voice_stream_handler (self , handler : VoiceStreamHandler | None ) -> None :
575577 """Register or clear the voice stream handler.
@@ -618,56 +620,57 @@ async def _run_voice_handler(self, session: VoiceSession) -> None:
618620 # Ensure iterator is unblocked and session is cleaned up
619621 await self ._cleanup_voice_session (session .session_id )
620622
621- def _schedule_voice_timeout (self , session_id : int ) -> None :
623+ def _schedule_voice_timeout (self , key : VoiceSessionKey ) -> None :
622624 """Schedule the timeout task for a voice session.
623625
624626 Starts counting immediately at creation time. When the timeout expires and the
625627 session is still active, the handler is notified (invoked) if not already
626628 started, the session is ended, and cleanup is performed.
627629 """
630+ ctx = self ._voice_sessions .get (key )
631+ if ctx is None :
632+ return
633+
628634 # Cancel pre-existing task if any (defensive)
629- existing = self . _voice_session_timeouts . pop ( session_id , None )
635+ existing = ctx . timeout_task
630636 if existing is not None and not existing .done ():
631637 existing .cancel ()
632638
633- task = self ._loop .create_task (self ._voice_session_timeout_task (session_id ))
634- self ._voice_session_timeouts [session_id ] = task
639+ ctx .timeout_task = self ._loop .create_task (self ._voice_session_timeout_task (key ))
635640
636- async def _voice_session_timeout_task (self , session_id : int ) -> None :
641+ async def _voice_session_timeout_task (self , key : VoiceSessionKey ) -> None :
637642 """Timeout watchdog for a voice session."""
638643 try :
639644 await asyncio .sleep (self ._voice_session_timeout )
640645 except asyncio .CancelledError :
641646 return
642647
643648 # If still active after timeout; take action
644- session = self ._voice_sessions .get (session_id )
645- if session is None :
649+ ctx = self ._voice_sessions .get (key )
650+ if ctx is None :
646651 return
647652
648653 _LOG .warning (
649654 "Voice session %s timed out after %ss" ,
650- session_id ,
655+ ctx . session . session_id ,
651656 self ._voice_session_timeout ,
652657 )
653658
654- # If handler not started yet (e.g., no VoiceBegin received), notify it now
655- if (
656- session_id not in self ._voice_handler_tasks
657- and self ._voice_handler is not None
658- ):
659+ # If handler not started yet, start it now (best effort)
660+ if ctx .handler_task is None and self ._voice_handler is not None :
659661 try :
660- task = self ._loop .create_task (self ._run_voice_handler (session ))
661- self ._voice_handler_tasks [session_id ] = task
662+ ctx .handler_task = self ._loop .create_task (
663+ self ._run_voice_handler (ctx .session )
664+ )
662665 except Exception : # pylint: disable=W0718
663666 _LOG .exception (
664667 "Failed to start voice handler on timeout for session %s" ,
665- session_id ,
668+ ctx . session . session_id ,
666669 )
667670
668671 # End and cleanup
669- session .end (VoiceEndReason .TIMEOUT )
670- await self ._cleanup_voice_session (session_id )
672+ ctx . session .end (VoiceEndReason .TIMEOUT )
673+ await self ._cleanup_voice_session (key )
671674
672675 async def _handle_ws_request_msg (
673676 self , websocket , msg : str , req_id : int , msg_data : dict [str , Any ] | None
@@ -853,7 +856,7 @@ async def _entity_command(
853856 and "params" in msg_data
854857 ):
855858 params = msg_data ["params" ]
856- session_id = params .get ("session_id" )
859+ session_id = int ( params .get ("session_id" ) )
857860 cfg = params .get ("audio_cfg" )
858861 audio_cfg = (
859862 AudioConfiguration (
@@ -865,6 +868,12 @@ async def _entity_command(
865868 else AudioConfiguration ()
866869 )
867870
871+ # Enforce: only one active session per entity_id across all websockets
872+ existing_key = self ._voice_session_by_entity .get (entity_id )
873+ if existing_key is not None :
874+ await self ._cleanup_voice_session (existing_key , VoiceEndReason .LOCAL )
875+
876+ key = self ._voice_key (websocket , session_id )
868877 session = VoiceSession (
869878 session_id ,
870879 entity_id ,
@@ -873,9 +882,11 @@ async def _entity_command(
873882 websocket = websocket ,
874883 loop = self ._loop ,
875884 )
876- self ._voice_sessions [session_id ] = session
885+ self ._voice_sessions [key ] = _VoiceSessionContext (session = session )
886+ self ._voice_session_by_entity [entity_id ] = key
887+
877888 # Start timeout immediately on session creation
878- self ._schedule_voice_timeout (session_id )
889+ self ._schedule_voice_timeout (key )
879890
880891 result = await entity .command (
881892 cmd_id , msg_data ["params" ] if "params" in msg_data else None
0 commit comments