Skip to content

Commit df90062

Browse files
committed
Working handle of websocket requests from ucapi
Improved requests signatures Handled unkown entity types
1 parent d6c5611 commit df90062

File tree

2 files changed

+213
-50
lines changed

2 files changed

+213
-50
lines changed

ucapi/api.py

Lines changed: 188 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from zeroconf.asyncio import AsyncServiceInfo, AsyncZeroconf
2828

2929
from . import api_definitions as uc
30+
from .api_definitions import LocalizationCfg, Version
3031
from .entities import Entities
3132
from .entity import EntityTypes
3233
from .media_player import Attributes as MediaAttr
@@ -107,6 +108,13 @@ def __init__(self, loop: AbstractEventLoop | None = None):
107108
# Enforce: at most one active session per entity_id (across all websockets)
108109
self._voice_session_by_entity: dict[str, VoiceSessionKey] = {}
109110

111+
# One receiver per websocket (already in _handle_ws). Responses are dispatched to futures here.
112+
self._ws_pending: dict[Any, dict[int, asyncio.Future]] = {}
113+
self._ws_send_locks: dict[Any, asyncio.Lock] = {}
114+
self._req_id_lock = asyncio.Lock()
115+
116+
self._supported_entity_types: list[EntityTypes] | None = None
117+
110118
# Setup event loop
111119
asyncio.set_event_loop(self._loop)
112120

@@ -209,11 +217,17 @@ async def _start_web_socket_server(self, host: str, port: int) -> None:
209217
async def _handle_ws(self, websocket) -> None:
210218
try:
211219
self._clients.add(websocket)
220+
# Init per-websocket pending requests map + send lock
221+
self._ws_pending[websocket] = {}
222+
self._ws_send_locks[websocket] = asyncio.Lock()
212223
_LOG.info("WS: Client added: %s", websocket.remote_address)
213224

214225
# authenticate on connection
215226
await self._authenticate(websocket, True)
216227

228+
# Request supported entity types from remote
229+
asyncio.create_task(self._update_supported_entity_types(websocket))
230+
217231
self._events.emit(uc.Events.CLIENT_CONNECTED)
218232

219233
async for message in websocket:
@@ -263,7 +277,12 @@ async def _handle_ws(self, websocket) -> None:
263277
key[1],
264278
ex,
265279
)
266-
280+
# Cancel all pending requests for this websocket (client disconnected)
281+
pending = self._ws_pending.pop(websocket, {})
282+
for _, fut in pending.items():
283+
if not fut.done():
284+
fut.set_exception(ConnectionError("WebSocket disconnected"))
285+
self._ws_send_locks.pop(websocket, None)
267286
self._clients.remove(websocket)
268287
_LOG.info("[%s] WS: Client removed", websocket.remote_address)
269288
self._events.emit(uc.Events.CLIENT_DISCONNECTED)
@@ -414,6 +433,102 @@ async def _process_ws_message(self, websocket, message) -> None:
414433
await self._handle_ws_request_msg(websocket, msg, req_id, msg_data)
415434
elif kind == "event":
416435
await self._handle_ws_event_msg(msg, msg_data)
436+
elif kind == "resp":
437+
# Response to a previously sent request
438+
# Some implementations use "req_id", others use "id"
439+
resp_id = data.get("req_id", data.get("id"))
440+
if resp_id is None:
441+
_LOG.warning(
442+
"[%s] WS: Received resp without req_id/id: %s",
443+
websocket.remote_address,
444+
message,
445+
)
446+
return
447+
448+
pending = self._ws_pending.get(websocket)
449+
if not pending:
450+
_LOG.debug(
451+
"[%s] WS: No pending map for resp_id=%s (late resp?)",
452+
websocket.remote_address,
453+
resp_id,
454+
)
455+
return
456+
fut = pending.get(int(resp_id))
457+
if fut is None:
458+
_LOG.debug(
459+
"[%s] WS: Unmatched resp_id=%s (not pending). msg=%s",
460+
websocket.remote_address,
461+
resp_id,
462+
msg,
463+
)
464+
return
465+
466+
if not fut.done():
467+
fut.set_result(data)
468+
469+
async def _ws_request(
470+
self,
471+
websocket,
472+
msg: str,
473+
msg_data: dict[str, Any] | None = None,
474+
*,
475+
timeout: float = 10.0,
476+
) -> dict[str, Any]:
477+
"""
478+
Send a request over websocket and await the matching response.
479+
480+
- Uses a Future stored in self._ws_pending[websocket][req_id]
481+
- Reader task (_handle_ws -> _process_ws_message) completes the future on 'resp'
482+
- Raises TimeoutError on timeout
483+
"""
484+
if websocket is None:
485+
if not self._clients:
486+
raise RuntimeError("No active websocket connection!")
487+
websocket = next(iter(self._clients))
488+
489+
# Ensure per-socket structures exist (in case you call before _handle_ws init)
490+
if websocket not in self._ws_pending:
491+
self._ws_pending[websocket] = {}
492+
if websocket not in self._ws_send_locks:
493+
self._ws_send_locks[websocket] = asyncio.Lock()
494+
495+
# Allocate req_id safely
496+
async with self._req_id_lock:
497+
req_id = self._req_id
498+
self._req_id += 1
499+
500+
fut = self._loop.create_future()
501+
self._ws_pending[websocket][req_id] = fut
502+
503+
try:
504+
payload: dict[str, Any] = {"kind": "req", "id": req_id, "msg": msg}
505+
if msg_data is not None:
506+
payload["msg_data"] = msg_data
507+
508+
if _LOG.isEnabledFor(logging.DEBUG):
509+
_LOG.debug(
510+
"[%s] ->: %s",
511+
websocket.remote_address,
512+
filter_log_msg_data(payload),
513+
)
514+
# Serialize sends to avoid interleaving issues (optional but recommended)
515+
async with self._ws_send_locks[websocket]:
516+
await websocket.send(json.dumps(payload))
517+
518+
# Await response
519+
resp = await asyncio.wait_for(fut, timeout=timeout)
520+
return resp
521+
522+
except asyncio.TimeoutError as ex:
523+
raise TimeoutError(
524+
f"Timeout waiting for response to '{msg}' (req_id={req_id})"
525+
) from ex
526+
527+
finally:
528+
# Cleanup pending future entry
529+
pending = self._ws_pending.get(websocket)
530+
if pending:
531+
pending.pop(req_id, None)
417532

418533
async def _process_ws_binary_message(self, websocket, data: bytes) -> None:
419534
"""Process a binary WebSocket message using protobuf IntegrationMessage.
@@ -694,12 +809,11 @@ async def _handle_ws_request_msg(
694809
{"state": self.device_state},
695810
)
696811
elif msg == uc.WsMessages.GET_AVAILABLE_ENTITIES:
697-
available_entities = self._available_entities.get_all()
698812
await self._send_ws_response(
699813
websocket,
700814
req_id,
701815
uc.WsMsgEvents.AVAILABLE_ENTITIES,
702-
{"available_entities": available_entities},
816+
{"available_entities": self._available_entities.get_all()},
703817
)
704818
elif msg == uc.WsMessages.GET_ENTITY_STATES:
705819
entity_states = await self._configured_entities.get_states()
@@ -1158,53 +1272,77 @@ def remove_all_listeners(self, event: uc.Events | None) -> None:
11581272
"""
11591273
self._events.remove_all_listeners(event)
11601274

1161-
async def get_supported_entity_types(self, websocket=None):
1162-
"""Send get_supported_entity_types request and wait for response."""
1163-
if websocket is None:
1164-
if not self._clients:
1165-
raise RuntimeError("No active websocket connection!")
1166-
websocket = next(iter(self._clients))
1167-
req_id = self._req_id
1168-
self._req_id += 1
1169-
request = {"kind": "req", "id": req_id, "msg": "get_supported_entity_types"}
1170-
await websocket.send(json.dumps(request))
1171-
while True:
1172-
response = await websocket.recv()
1173-
data = json.loads(response)
1174-
if data.get("kind") == "resp" and data.get("req_id") == req_id and data.get("msg") == "supported_entity_types":
1175-
return data.get("msg_data")
1176-
1177-
async def get_version(self, websocket=None):
1178-
"""Send get_version request and wait for response."""
1179-
if websocket is None:
1180-
if not self._clients:
1181-
raise RuntimeError("No active websocket connection!")
1182-
websocket = next(iter(self._clients))
1183-
req_id = self._req_id
1184-
self._req_id += 1
1185-
request = {"kind": "req", "id": req_id, "msg": "get_version"}
1186-
await websocket.send(json.dumps(request))
1187-
while True:
1188-
response = await websocket.recv()
1189-
data = json.loads(response)
1190-
if data.get("kind") == "resp" and data.get("req_id") == req_id and data.get("msg") == "version":
1191-
return data.get("msg_data")
1192-
1193-
async def get_localization_cfg(self, websocket=None):
1194-
"""Send get_localization_cfg request and wait for response."""
1195-
if websocket is None:
1196-
if not self._clients:
1197-
raise RuntimeError("No active websocket connection!")
1198-
websocket = next(iter(self._clients))
1199-
req_id = self._req_id
1200-
self._req_id += 1
1201-
request = {"kind": "req", "id": req_id, "msg": "get_localization_cfg"}
1202-
await websocket.send(json.dumps(request))
1203-
while True:
1204-
response = await websocket.recv()
1205-
data = json.loads(response)
1206-
if data.get("kind") == "resp" and data.get("req_id") == req_id and data.get("msg") == "localization_cfg":
1207-
return data.get("msg_data")
1275+
async def get_supported_entity_types(
1276+
self, websocket=None, *, timeout: float = 5.0
1277+
) -> list[EntityTypes]:
1278+
"""Request supported entity types from client and return msg_data."""
1279+
resp = await self._ws_request(
1280+
websocket,
1281+
"get_supported_entity_types",
1282+
timeout=timeout,
1283+
)
1284+
if resp.get("msg") != "supported_entity_types":
1285+
_LOG.debug(
1286+
"[%s] Unexpected resp msg for get_supported_entity_types: %s",
1287+
websocket.remote_address if websocket else "",
1288+
resp.get("msg"),
1289+
)
1290+
entity_types: list[EntityTypes] = []
1291+
for entity_type in resp.get("msg_data", []):
1292+
try:
1293+
entity_types.append(EntityTypes(entity_type))
1294+
except ValueError:
1295+
pass
1296+
return entity_types
1297+
1298+
async def _update_supported_entity_types(
1299+
self, websocket=None, *, timeout: float = 5.0
1300+
) -> None:
1301+
"""Update supported entity types by remote."""
1302+
await asyncio.sleep(0)
1303+
self._supported_entity_types = await self.get_supported_entity_types(
1304+
websocket, timeout=timeout
1305+
)
1306+
_LOG.debug(
1307+
"[%s] Supported entity types %s",
1308+
websocket.remote_address if websocket else "",
1309+
self._supported_entity_types,
1310+
)
1311+
1312+
async def get_version(self, websocket=None, *, timeout: float = 5.0) -> Version:
1313+
"""Request client version and return msg_data."""
1314+
resp = await self._ws_request(
1315+
websocket,
1316+
"get_version",
1317+
timeout=timeout,
1318+
)
1319+
if resp.get("msg") != "version":
1320+
_LOG.debug(
1321+
"[%s] Unexpected resp msg for get_version: %s",
1322+
websocket.remote_address if websocket else "",
1323+
resp.get("msg"),
1324+
)
1325+
1326+
return resp.get("msg_data")
1327+
1328+
async def get_localization_cfg(
1329+
self, websocket=None, *, timeout: float = 5.0
1330+
) -> LocalizationCfg:
1331+
"""Request localization config and return msg_data."""
1332+
resp = await self._ws_request(
1333+
websocket,
1334+
"get_localization_cfg",
1335+
timeout=timeout,
1336+
)
1337+
1338+
if resp.get("msg") != "localization_cfg":
1339+
_LOG.debug(
1340+
"[%s] Unexpected resp msg for get_localization_cfg: %s",
1341+
websocket.remote_address if websocket else "",
1342+
resp.get("msg"),
1343+
)
1344+
1345+
return resp.get("msg_data")
12081346

12091347
##############
12101348
# Properties #

ucapi/api_definitions.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,3 +302,28 @@ class AssistantEvent:
302302
entity_id: str
303303
session_id: int
304304
data: AssistantEventData | None = None
305+
306+
307+
@dataclass
308+
class Version:
309+
"""Version response payload sent via the ``get_version`` request."""
310+
311+
model: str
312+
device_name: str
313+
hostname: str
314+
address: str
315+
api: str
316+
core: str
317+
ui: str
318+
os: str
319+
320+
321+
@dataclass
322+
class LocalizationCfg:
323+
"""Localization response payload sent via the ``get_localization_cfg`` request."""
324+
325+
language_code: str
326+
country_code: str
327+
time_zone: str
328+
time_format_24h: bool
329+
measurement_unit: str

0 commit comments

Comments
 (0)