66"""
77
88import asyncio
9+ import inspect
910import json
1011import logging
1112import os
1213import socket
1314from asyncio import AbstractEventLoop
1415from copy import deepcopy
1516from dataclasses import asdict , dataclass
17+ from functools import wraps
1618from typing import Any , Callable
1719
1820import websockets
@@ -212,7 +214,7 @@ async def _handle_ws(self, websocket) -> None:
212214 # authenticate on connection
213215 await self ._authenticate (websocket , True )
214216
215- self ._events .emit (uc .Events .CLIENT_CONNECTED )
217+ self ._events .emit (uc .Events .CLIENT_CONNECTED , websocket = websocket )
216218
217219 async for message in websocket :
218220 # Distinguish between text (str) and binary (bytes-like) messages
@@ -264,7 +266,7 @@ async def _handle_ws(self, websocket) -> None:
264266
265267 self ._clients .remove (websocket )
266268 _LOG .info ("[%s] WS: Client removed" , websocket .remote_address )
267- self ._events .emit (uc .Events .CLIENT_DISCONNECTED )
269+ self ._events .emit (uc .Events .CLIENT_DISCONNECTED , websocket = websocket )
268270
269271 async def _send_ok_result (
270272 self , websocket , req_id : int , msg_data : dict [str , Any ] | list | None = None
@@ -411,7 +413,7 @@ async def _process_ws_message(self, websocket, message) -> None:
411413 else :
412414 await self ._handle_ws_request_msg (websocket , msg , req_id , msg_data )
413415 elif kind == "event" :
414- await self ._handle_ws_event_msg (msg , msg_data )
416+ await self ._handle_ws_event_msg (websocket , msg , msg_data )
415417
416418 async def _process_ws_binary_message (self , websocket , data : bytes ) -> None :
417419 """Process a binary WebSocket message using protobuf IntegrationMessage.
@@ -710,10 +712,10 @@ async def _handle_ws_request_msg(
710712 elif msg == uc .WsMessages .ENTITY_COMMAND :
711713 await self ._entity_command (websocket , req_id , msg_data )
712714 elif msg == uc .WsMessages .SUBSCRIBE_EVENTS :
713- await self ._subscribe_events (msg_data )
715+ await self ._subscribe_events (websocket , msg_data )
714716 await self ._send_ok_result (websocket , req_id )
715717 elif msg == uc .WsMessages .UNSUBSCRIBE_EVENTS :
716- await self ._unsubscribe_events (msg_data )
718+ await self ._unsubscribe_events (websocket , msg_data )
717719 await self ._send_ok_result (websocket , req_id )
718720 elif msg == uc .WsMessages .GET_DRIVER_METADATA :
719721 await self ._send_ws_response (
@@ -730,16 +732,16 @@ async def _handle_ws_request_msg(
730732 await self .driver_setup_error (websocket )
731733
732734 async def _handle_ws_event_msg (
733- self , msg : str , msg_data : dict [str , Any ] | None
735+ self , websocket : Any , msg : str , msg_data : dict [str , Any ] | None
734736 ) -> None :
735737 if msg == uc .WsMsgEvents .CONNECT :
736- self ._events .emit (uc .Events .CONNECT )
738+ self ._events .emit (uc .Events .CONNECT , websocket = websocket )
737739 elif msg == uc .WsMsgEvents .DISCONNECT :
738- self ._events .emit (uc .Events .DISCONNECT )
740+ self ._events .emit (uc .Events .DISCONNECT , websocket = websocket )
739741 elif msg == uc .WsMsgEvents .ENTER_STANDBY :
740- self ._events .emit (uc .Events .ENTER_STANDBY )
742+ self ._events .emit (uc .Events .ENTER_STANDBY , websocket = websocket )
741743 elif msg == uc .WsMsgEvents .EXIT_STANDBY :
742- self ._events .emit (uc .Events .EXIT_STANDBY )
744+ self ._events .emit (uc .Events .EXIT_STANDBY , websocket = websocket )
743745 elif msg == uc .WsMsgEvents .ABORT_DRIVER_SETUP :
744746 if not self ._setup_handler :
745747 _LOG .warning (
@@ -792,7 +794,9 @@ async def set_device_state(self, state: uc.DeviceStates) -> None:
792794 uc .EventCategory .DEVICE ,
793795 )
794796
795- async def _subscribe_events (self , msg_data : dict [str , Any ] | None ) -> None :
797+ async def _subscribe_events (
798+ self , websocket : Any , msg_data : dict [str , Any ] | None
799+ ) -> None :
796800 if msg_data is None :
797801 _LOG .warning ("Ignoring _subscribe_events: called with empty msg_data" )
798802 return
@@ -806,9 +810,15 @@ async def _subscribe_events(self, msg_data: dict[str, Any] | None) -> None:
806810 entity_id ,
807811 )
808812
809- self ._events .emit (uc .Events .SUBSCRIBE_ENTITIES , msg_data ["entity_ids" ])
813+ self ._events .emit (
814+ uc .Events .SUBSCRIBE_ENTITIES ,
815+ entity_ids = msg_data ["entity_ids" ],
816+ websocket = websocket ,
817+ )
810818
811- async def _unsubscribe_events (self , msg_data : dict [str , Any ] | None ) -> bool :
819+ async def _unsubscribe_events (
820+ self , websocket : Any , msg_data : dict [str , Any ] | None
821+ ) -> bool :
812822 if msg_data is None :
813823 _LOG .warning ("Ignoring _unsubscribe_events: called with empty msg_data" )
814824 return False
@@ -819,7 +829,11 @@ async def _unsubscribe_events(self, msg_data: dict[str, Any] | None) -> bool:
819829 if self ._configured_entities .remove (entity_id ) is False :
820830 res = False
821831
822- self ._events .emit (uc .Events .UNSUBSCRIBE_ENTITIES , msg_data ["entity_ids" ])
832+ self ._events .emit (
833+ uc .Events .UNSUBSCRIBE_ENTITIES ,
834+ entity_ids = msg_data ["entity_ids" ],
835+ websocket = websocket ,
836+ )
823837
824838 return res
825839
@@ -1114,14 +1128,65 @@ async def driver_setup_error(self, websocket, error="OTHER") -> None:
11141128 websocket , uc .WsMsgEvents .DRIVER_SETUP_CHANGE , data , uc .EventCategory .DEVICE
11151129 )
11161130
1131+ @staticmethod
1132+ def _wrap_event_listener (listener : Callable ) -> Callable :
1133+ """Event listener wrapper for backwards compatibility.
1134+
1135+ Wrap an event listener so it remains compatible if the library starts emitting
1136+ additional event parameters later.
1137+
1138+ Example:
1139+ - listener() keeps working even if emitter calls listener(websocket)
1140+ - listener(websocket) keeps working if emitter calls listener(websocket, x, y)
1141+ """
1142+ try :
1143+ sig = inspect .signature (listener )
1144+ except (TypeError , ValueError ):
1145+ # Builtins / callables without inspectable signature: fall back to raw call.
1146+ return listener
1147+
1148+ params = list (sig .parameters .values ())
1149+
1150+ accepts_varargs = any (
1151+ p .kind == inspect .Parameter .VAR_POSITIONAL for p in params
1152+ )
1153+ accepts_varkw = any (p .kind == inspect .Parameter .VAR_KEYWORD for p in params )
1154+
1155+ # How many positional args can the listener accept (excluding *args/**kwargs)?
1156+ positional_kinds = (
1157+ inspect .Parameter .POSITIONAL_ONLY ,
1158+ inspect .Parameter .POSITIONAL_OR_KEYWORD ,
1159+ )
1160+ max_positional = sum (1 for p in params if p .kind in positional_kinds )
1161+
1162+ # Which named kwargs are accepted (if no **kwargs)?
1163+ accepted_kw = {
1164+ p .name
1165+ for p in params
1166+ if p .kind
1167+ in (inspect .Parameter .POSITIONAL_OR_KEYWORD , inspect .Parameter .KEYWORD_ONLY )
1168+ }
1169+
1170+ @wraps (listener )
1171+ def wrapper (* args : Any , ** kwargs : Any ):
1172+ call_args = args if accepts_varargs else args [:max_positional ]
1173+ call_kwargs = (
1174+ kwargs
1175+ if accepts_varkw
1176+ else {k : v for k , v in kwargs .items () if k in accepted_kw }
1177+ )
1178+ return listener (* call_args , ** call_kwargs )
1179+
1180+ return wrapper
1181+
11171182 def add_listener (self , event : uc .Events , f : Callable ) -> None :
11181183 """
11191184 Register a callback handler for the given event.
11201185
11211186 :param event: the event
11221187 :param f: callback handler
11231188 """
1124- self ._events .add_listener (event , f )
1189+ self ._events .add_listener (event , self . _wrap_event_listener ( f ) )
11251190
11261191 def listens_to (self , event : uc .Events ) -> Callable [[Callable ], Callable ]:
11271192 """
@@ -1132,7 +1197,7 @@ def listens_to(self, event: uc.Events) -> Callable[[Callable], Callable]:
11321197 """
11331198
11341199 def on (f : Callable ) -> Callable :
1135- self ._events .add_listener (event , f )
1200+ self ._events .add_listener (event , self . _wrap_event_listener ( f ) )
11361201 return f
11371202
11381203 return on
0 commit comments