Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 81 additions & 16 deletions ucapi/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
"""

import asyncio
import inspect
import json
import logging
import os
import socket
from asyncio import AbstractEventLoop
from copy import deepcopy
from dataclasses import asdict, dataclass
from functools import wraps
from typing import Any, Callable

import websockets
Expand Down Expand Up @@ -212,7 +214,7 @@ async def _handle_ws(self, websocket) -> None:
# authenticate on connection
await self._authenticate(websocket, True)

self._events.emit(uc.Events.CLIENT_CONNECTED)
self._events.emit(uc.Events.CLIENT_CONNECTED, websocket=websocket)

async for message in websocket:
# Distinguish between text (str) and binary (bytes-like) messages
Expand Down Expand Up @@ -264,7 +266,7 @@ async def _handle_ws(self, websocket) -> None:

self._clients.remove(websocket)
_LOG.info("[%s] WS: Client removed", websocket.remote_address)
self._events.emit(uc.Events.CLIENT_DISCONNECTED)
self._events.emit(uc.Events.CLIENT_DISCONNECTED, websocket=websocket)

async def _send_ok_result(
self, websocket, req_id: int, msg_data: dict[str, Any] | list | None = None
Expand Down Expand Up @@ -411,7 +413,7 @@ async def _process_ws_message(self, websocket, message) -> None:
else:
await self._handle_ws_request_msg(websocket, msg, req_id, msg_data)
elif kind == "event":
await self._handle_ws_event_msg(msg, msg_data)
await self._handle_ws_event_msg(websocket, msg, msg_data)

async def _process_ws_binary_message(self, websocket, data: bytes) -> None:
"""Process a binary WebSocket message using protobuf IntegrationMessage.
Expand Down Expand Up @@ -710,10 +712,10 @@ async def _handle_ws_request_msg(
elif msg == uc.WsMessages.ENTITY_COMMAND:
await self._entity_command(websocket, req_id, msg_data)
elif msg == uc.WsMessages.SUBSCRIBE_EVENTS:
await self._subscribe_events(msg_data)
await self._subscribe_events(websocket, msg_data)
await self._send_ok_result(websocket, req_id)
elif msg == uc.WsMessages.UNSUBSCRIBE_EVENTS:
await self._unsubscribe_events(msg_data)
await self._unsubscribe_events(websocket, msg_data)
await self._send_ok_result(websocket, req_id)
elif msg == uc.WsMessages.GET_DRIVER_METADATA:
await self._send_ws_response(
Expand All @@ -730,16 +732,16 @@ async def _handle_ws_request_msg(
await self.driver_setup_error(websocket)

async def _handle_ws_event_msg(
self, msg: str, msg_data: dict[str, Any] | None
self, websocket: Any, msg: str, msg_data: dict[str, Any] | None
) -> None:
if msg == uc.WsMsgEvents.CONNECT:
self._events.emit(uc.Events.CONNECT)
self._events.emit(uc.Events.CONNECT, websocket=websocket)
elif msg == uc.WsMsgEvents.DISCONNECT:
self._events.emit(uc.Events.DISCONNECT)
self._events.emit(uc.Events.DISCONNECT, websocket=websocket)
elif msg == uc.WsMsgEvents.ENTER_STANDBY:
self._events.emit(uc.Events.ENTER_STANDBY)
self._events.emit(uc.Events.ENTER_STANDBY, websocket=websocket)
elif msg == uc.WsMsgEvents.EXIT_STANDBY:
self._events.emit(uc.Events.EXIT_STANDBY)
self._events.emit(uc.Events.EXIT_STANDBY, websocket=websocket)
elif msg == uc.WsMsgEvents.ABORT_DRIVER_SETUP:
if not self._setup_handler:
_LOG.warning(
Expand Down Expand Up @@ -792,7 +794,9 @@ async def set_device_state(self, state: uc.DeviceStates) -> None:
uc.EventCategory.DEVICE,
)

async def _subscribe_events(self, msg_data: dict[str, Any] | None) -> None:
async def _subscribe_events(
self, websocket: Any, msg_data: dict[str, Any] | None
) -> None:
if msg_data is None:
_LOG.warning("Ignoring _subscribe_events: called with empty msg_data")
return
Expand All @@ -806,9 +810,15 @@ async def _subscribe_events(self, msg_data: dict[str, Any] | None) -> None:
entity_id,
)

self._events.emit(uc.Events.SUBSCRIBE_ENTITIES, msg_data["entity_ids"])
self._events.emit(
uc.Events.SUBSCRIBE_ENTITIES,
entity_ids=msg_data["entity_ids"],
websocket=websocket,
)

async def _unsubscribe_events(self, msg_data: dict[str, Any] | None) -> bool:
async def _unsubscribe_events(
self, websocket: Any, msg_data: dict[str, Any] | None
) -> bool:
if msg_data is None:
_LOG.warning("Ignoring _unsubscribe_events: called with empty msg_data")
return False
Expand All @@ -819,7 +829,11 @@ async def _unsubscribe_events(self, msg_data: dict[str, Any] | None) -> bool:
if self._configured_entities.remove(entity_id) is False:
res = False

self._events.emit(uc.Events.UNSUBSCRIBE_ENTITIES, msg_data["entity_ids"])
self._events.emit(
uc.Events.UNSUBSCRIBE_ENTITIES,
entity_ids=msg_data["entity_ids"],
websocket=websocket,
)

return res

Expand Down Expand Up @@ -1114,14 +1128,65 @@ async def driver_setup_error(self, websocket, error="OTHER") -> None:
websocket, uc.WsMsgEvents.DRIVER_SETUP_CHANGE, data, uc.EventCategory.DEVICE
)

@staticmethod
def _wrap_event_listener(listener: Callable) -> Callable:
"""Event listener wrapper for backwards compatibility.

Wrap an event listener so it remains compatible if the library starts emitting
additional event parameters later.

Example:
- listener() keeps working even if emitter calls listener(websocket)
- listener(websocket) keeps working if emitter calls listener(websocket, x, y)
"""
try:
sig = inspect.signature(listener)
except (TypeError, ValueError):
# Builtins / callables without inspectable signature: fall back to raw call.
return listener

params = list(sig.parameters.values())

accepts_varargs = any(
p.kind == inspect.Parameter.VAR_POSITIONAL for p in params
)
accepts_varkw = any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params)

# How many positional args can the listener accept (excluding *args/**kwargs)?
positional_kinds = (
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
)
max_positional = sum(1 for p in params if p.kind in positional_kinds)

# Which named kwargs are accepted (if no **kwargs)?
accepted_kw = {
p.name
for p in params
if p.kind
in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
}

@wraps(listener)
def wrapper(*args: Any, **kwargs: Any):
call_args = args if accepts_varargs else args[:max_positional]
call_kwargs = (
kwargs
if accepts_varkw
else {k: v for k, v in kwargs.items() if k in accepted_kw}
)
return listener(*call_args, **call_kwargs)

return wrapper

def add_listener(self, event: uc.Events, f: Callable) -> None:
"""
Register a callback handler for the given event.

:param event: the event
:param f: callback handler
"""
self._events.add_listener(event, f)
self._events.add_listener(event, self._wrap_event_listener(f))

def listens_to(self, event: uc.Events) -> Callable[[Callable], Callable]:
"""
Expand All @@ -1132,7 +1197,7 @@ def listens_to(self, event: uc.Events) -> Callable[[Callable], Callable]:
"""

def on(f: Callable) -> Callable:
self._events.add_listener(event, f)
self._events.add_listener(event, self._wrap_event_listener(f))
return f

return on
Expand Down
70 changes: 61 additions & 9 deletions ucapi/api_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,25 +81,77 @@ class WsMsgEvents(str, Enum):


class Events(str, Enum):
"""Internal library events."""
"""Internal library events.

All event parameters are named parameters and optional.
"""

CLIENT_CONNECTED = "client_connected"
"""WebSocket client connected."""
"""WebSocket client connected.

Named parameters:

- websocket: WebSocket client connection
"""
CLIENT_DISCONNECTED = "client_disconnected"
"""WebSocket client disconnected."""
"""WebSocket client disconnected.

Named parameters:

- websocket: WebSocket client connection
"""
ENTITY_ATTRIBUTES_UPDATED = "entity_attributes_updated"
"""Entity attributes updated.

Named parameters:

- entity_id: entity identifier
- entity_type: entity type
- attributes: updated attributes"""
SUBSCRIBE_ENTITIES = "subscribe_entities"
"""Integration API `subscribe_events` message."""
"""Integration API `subscribe_events` message.

Named parameters:

- entity_ids: list of entity IDs to subscribe to
- websocket: WebSocket client connection
"""
UNSUBSCRIBE_ENTITIES = "unsubscribe_entities"
"""Integration API `unsubscribe_events` message."""
"""Integration API `unsubscribe_events` message.

Named parameters:

- entity_ids: list of entity IDs to unsubscribe
- websocket: WebSocket client connection
"""
CONNECT = "connect"
"""Integration-API `connect` event message."""
"""Integration-API `connect` event message.

Named parameters:

- websocket: WebSocket client connection
"""
DISCONNECT = "disconnect"
"""Integration-API `disconnect` event message."""
"""Integration-API `disconnect` event message.

Named parameters:

- websocket: WebSocket client connection
"""
ENTER_STANDBY = "enter_standby"
"""Integration-API `enter_standby` event message."""
"""Integration-API `enter_standby` event message.

Named parameters:

- websocket: WebSocket client connection
"""
EXIT_STANDBY = "exit_standby"
"""Integration-API `exit_standby` event message."""
"""Integration-API `exit_standby` event message.

Named parameters:

- websocket: WebSocket client connection
"""


# Does EventCategory need to be public?
Expand Down