|
27 | 27 | from zeroconf.asyncio import AsyncServiceInfo, AsyncZeroconf |
28 | 28 |
|
29 | 29 | from . import api_definitions as uc |
| 30 | +from .api_definitions import LocalizationCfg, Version |
30 | 31 | from .entities import Entities |
31 | 32 | from .entity import EntityTypes |
32 | 33 | from .media_player import Attributes as MediaAttr |
@@ -107,6 +108,13 @@ def __init__(self, loop: AbstractEventLoop | None = None): |
107 | 108 | # Enforce: at most one active session per entity_id (across all websockets) |
108 | 109 | self._voice_session_by_entity: dict[str, VoiceSessionKey] = {} |
109 | 110 |
|
| 111 | + # One receiver per websocket (already in _handle_ws). Responses are dispatched to futures here. |
| 112 | + self._ws_pending: dict[Any, dict[int, asyncio.Future]] = {} |
| 113 | + self._ws_send_locks: dict[Any, asyncio.Lock] = {} |
| 114 | + self._req_id_lock = asyncio.Lock() |
| 115 | + |
| 116 | + self._supported_entity_types: list[EntityTypes] | None = None |
| 117 | + |
110 | 118 | # Setup event loop |
111 | 119 | asyncio.set_event_loop(self._loop) |
112 | 120 |
|
@@ -209,11 +217,17 @@ async def _start_web_socket_server(self, host: str, port: int) -> None: |
209 | 217 | async def _handle_ws(self, websocket) -> None: |
210 | 218 | try: |
211 | 219 | self._clients.add(websocket) |
| 220 | + # Init per-websocket pending requests map + send lock |
| 221 | + self._ws_pending[websocket] = {} |
| 222 | + self._ws_send_locks[websocket] = asyncio.Lock() |
212 | 223 | _LOG.info("WS: Client added: %s", websocket.remote_address) |
213 | 224 |
|
214 | 225 | # authenticate on connection |
215 | 226 | await self._authenticate(websocket, True) |
216 | 227 |
|
| 228 | + # Request supported entity types from remote |
| 229 | + asyncio.create_task(self._update_supported_entity_types(websocket)) |
| 230 | + |
217 | 231 | self._events.emit(uc.Events.CLIENT_CONNECTED) |
218 | 232 |
|
219 | 233 | async for message in websocket: |
@@ -263,7 +277,12 @@ async def _handle_ws(self, websocket) -> None: |
263 | 277 | key[1], |
264 | 278 | ex, |
265 | 279 | ) |
266 | | - |
| 280 | + # Cancel all pending requests for this websocket (client disconnected) |
| 281 | + pending = self._ws_pending.pop(websocket, {}) |
| 282 | + for _, fut in pending.items(): |
| 283 | + if not fut.done(): |
| 284 | + fut.set_exception(ConnectionError("WebSocket disconnected")) |
| 285 | + self._ws_send_locks.pop(websocket, None) |
267 | 286 | self._clients.remove(websocket) |
268 | 287 | _LOG.info("[%s] WS: Client removed", websocket.remote_address) |
269 | 288 | self._events.emit(uc.Events.CLIENT_DISCONNECTED) |
@@ -414,6 +433,102 @@ async def _process_ws_message(self, websocket, message) -> None: |
414 | 433 | await self._handle_ws_request_msg(websocket, msg, req_id, msg_data) |
415 | 434 | elif kind == "event": |
416 | 435 | await self._handle_ws_event_msg(msg, msg_data) |
| 436 | + elif kind == "resp": |
| 437 | + # Response to a previously sent request |
| 438 | + # Some implementations use "req_id", others use "id" |
| 439 | + resp_id = data.get("req_id", data.get("id")) |
| 440 | + if resp_id is None: |
| 441 | + _LOG.warning( |
| 442 | + "[%s] WS: Received resp without req_id/id: %s", |
| 443 | + websocket.remote_address, |
| 444 | + message, |
| 445 | + ) |
| 446 | + return |
| 447 | + |
| 448 | + pending = self._ws_pending.get(websocket) |
| 449 | + if not pending: |
| 450 | + _LOG.debug( |
| 451 | + "[%s] WS: No pending map for resp_id=%s (late resp?)", |
| 452 | + websocket.remote_address, |
| 453 | + resp_id, |
| 454 | + ) |
| 455 | + return |
| 456 | + fut = pending.get(int(resp_id)) |
| 457 | + if fut is None: |
| 458 | + _LOG.debug( |
| 459 | + "[%s] WS: Unmatched resp_id=%s (not pending). msg=%s", |
| 460 | + websocket.remote_address, |
| 461 | + resp_id, |
| 462 | + msg, |
| 463 | + ) |
| 464 | + return |
| 465 | + |
| 466 | + if not fut.done(): |
| 467 | + fut.set_result(data) |
| 468 | + |
| 469 | + async def _ws_request( |
| 470 | + self, |
| 471 | + websocket, |
| 472 | + msg: str, |
| 473 | + msg_data: dict[str, Any] | None = None, |
| 474 | + *, |
| 475 | + timeout: float = 10.0, |
| 476 | + ) -> dict[str, Any]: |
| 477 | + """ |
| 478 | + Send a request over websocket and await the matching response. |
| 479 | +
|
| 480 | + - Uses a Future stored in self._ws_pending[websocket][req_id] |
| 481 | + - Reader task (_handle_ws -> _process_ws_message) completes the future on 'resp' |
| 482 | + - Raises TimeoutError on timeout |
| 483 | + """ |
| 484 | + if websocket is None: |
| 485 | + if not self._clients: |
| 486 | + raise RuntimeError("No active websocket connection!") |
| 487 | + websocket = next(iter(self._clients)) |
| 488 | + |
| 489 | + # Ensure per-socket structures exist (in case you call before _handle_ws init) |
| 490 | + if websocket not in self._ws_pending: |
| 491 | + self._ws_pending[websocket] = {} |
| 492 | + if websocket not in self._ws_send_locks: |
| 493 | + self._ws_send_locks[websocket] = asyncio.Lock() |
| 494 | + |
| 495 | + # Allocate req_id safely |
| 496 | + async with self._req_id_lock: |
| 497 | + req_id = self._req_id |
| 498 | + self._req_id += 1 |
| 499 | + |
| 500 | + fut = self._loop.create_future() |
| 501 | + self._ws_pending[websocket][req_id] = fut |
| 502 | + |
| 503 | + try: |
| 504 | + payload: dict[str, Any] = {"kind": "req", "id": req_id, "msg": msg} |
| 505 | + if msg_data is not None: |
| 506 | + payload["msg_data"] = msg_data |
| 507 | + |
| 508 | + if _LOG.isEnabledFor(logging.DEBUG): |
| 509 | + _LOG.debug( |
| 510 | + "[%s] ->: %s", |
| 511 | + websocket.remote_address, |
| 512 | + filter_log_msg_data(payload), |
| 513 | + ) |
| 514 | + # Serialize sends to avoid interleaving issues (optional but recommended) |
| 515 | + async with self._ws_send_locks[websocket]: |
| 516 | + await websocket.send(json.dumps(payload)) |
| 517 | + |
| 518 | + # Await response |
| 519 | + resp = await asyncio.wait_for(fut, timeout=timeout) |
| 520 | + return resp |
| 521 | + |
| 522 | + except asyncio.TimeoutError as ex: |
| 523 | + raise TimeoutError( |
| 524 | + f"Timeout waiting for response to '{msg}' (req_id={req_id})" |
| 525 | + ) from ex |
| 526 | + |
| 527 | + finally: |
| 528 | + # Cleanup pending future entry |
| 529 | + pending = self._ws_pending.get(websocket) |
| 530 | + if pending: |
| 531 | + pending.pop(req_id, None) |
417 | 532 |
|
418 | 533 | async def _process_ws_binary_message(self, websocket, data: bytes) -> None: |
419 | 534 | """Process a binary WebSocket message using protobuf IntegrationMessage. |
@@ -694,12 +809,11 @@ async def _handle_ws_request_msg( |
694 | 809 | {"state": self.device_state}, |
695 | 810 | ) |
696 | 811 | elif msg == uc.WsMessages.GET_AVAILABLE_ENTITIES: |
697 | | - available_entities = self._available_entities.get_all() |
698 | 812 | await self._send_ws_response( |
699 | 813 | websocket, |
700 | 814 | req_id, |
701 | 815 | uc.WsMsgEvents.AVAILABLE_ENTITIES, |
702 | | - {"available_entities": available_entities}, |
| 816 | + {"available_entities": self._available_entities.get_all()}, |
703 | 817 | ) |
704 | 818 | elif msg == uc.WsMessages.GET_ENTITY_STATES: |
705 | 819 | entity_states = await self._configured_entities.get_states() |
@@ -1158,53 +1272,77 @@ def remove_all_listeners(self, event: uc.Events | None) -> None: |
1158 | 1272 | """ |
1159 | 1273 | self._events.remove_all_listeners(event) |
1160 | 1274 |
|
1161 | | - async def get_supported_entity_types(self, websocket=None): |
1162 | | - """Send get_supported_entity_types request and wait for response.""" |
1163 | | - if websocket is None: |
1164 | | - if not self._clients: |
1165 | | - raise RuntimeError("No active websocket connection!") |
1166 | | - websocket = next(iter(self._clients)) |
1167 | | - req_id = self._req_id |
1168 | | - self._req_id += 1 |
1169 | | - request = {"kind": "req", "id": req_id, "msg": "get_supported_entity_types"} |
1170 | | - await websocket.send(json.dumps(request)) |
1171 | | - while True: |
1172 | | - response = await websocket.recv() |
1173 | | - data = json.loads(response) |
1174 | | - if data.get("kind") == "resp" and data.get("req_id") == req_id and data.get("msg") == "supported_entity_types": |
1175 | | - return data.get("msg_data") |
1176 | | - |
1177 | | - async def get_version(self, websocket=None): |
1178 | | - """Send get_version request and wait for response.""" |
1179 | | - if websocket is None: |
1180 | | - if not self._clients: |
1181 | | - raise RuntimeError("No active websocket connection!") |
1182 | | - websocket = next(iter(self._clients)) |
1183 | | - req_id = self._req_id |
1184 | | - self._req_id += 1 |
1185 | | - request = {"kind": "req", "id": req_id, "msg": "get_version"} |
1186 | | - await websocket.send(json.dumps(request)) |
1187 | | - while True: |
1188 | | - response = await websocket.recv() |
1189 | | - data = json.loads(response) |
1190 | | - if data.get("kind") == "resp" and data.get("req_id") == req_id and data.get("msg") == "version": |
1191 | | - return data.get("msg_data") |
1192 | | - |
1193 | | - async def get_localization_cfg(self, websocket=None): |
1194 | | - """Send get_localization_cfg request and wait for response.""" |
1195 | | - if websocket is None: |
1196 | | - if not self._clients: |
1197 | | - raise RuntimeError("No active websocket connection!") |
1198 | | - websocket = next(iter(self._clients)) |
1199 | | - req_id = self._req_id |
1200 | | - self._req_id += 1 |
1201 | | - request = {"kind": "req", "id": req_id, "msg": "get_localization_cfg"} |
1202 | | - await websocket.send(json.dumps(request)) |
1203 | | - while True: |
1204 | | - response = await websocket.recv() |
1205 | | - data = json.loads(response) |
1206 | | - if data.get("kind") == "resp" and data.get("req_id") == req_id and data.get("msg") == "localization_cfg": |
1207 | | - return data.get("msg_data") |
| 1275 | + async def get_supported_entity_types( |
| 1276 | + self, websocket=None, *, timeout: float = 5.0 |
| 1277 | + ) -> list[EntityTypes]: |
| 1278 | + """Request supported entity types from client and return msg_data.""" |
| 1279 | + resp = await self._ws_request( |
| 1280 | + websocket, |
| 1281 | + "get_supported_entity_types", |
| 1282 | + timeout=timeout, |
| 1283 | + ) |
| 1284 | + if resp.get("msg") != "supported_entity_types": |
| 1285 | + _LOG.debug( |
| 1286 | + "[%s] Unexpected resp msg for get_supported_entity_types: %s", |
| 1287 | + websocket.remote_address if websocket else "", |
| 1288 | + resp.get("msg"), |
| 1289 | + ) |
| 1290 | + entity_types: list[EntityTypes] = [] |
| 1291 | + for entity_type in resp.get("msg_data", []): |
| 1292 | + try: |
| 1293 | + entity_types.append(EntityTypes(entity_type)) |
| 1294 | + except ValueError: |
| 1295 | + pass |
| 1296 | + return entity_types |
| 1297 | + |
| 1298 | + async def _update_supported_entity_types( |
| 1299 | + self, websocket=None, *, timeout: float = 5.0 |
| 1300 | + ) -> None: |
| 1301 | + """Update supported entity types by remote.""" |
| 1302 | + await asyncio.sleep(0) |
| 1303 | + self._supported_entity_types = await self.get_supported_entity_types( |
| 1304 | + websocket, timeout=timeout |
| 1305 | + ) |
| 1306 | + _LOG.debug( |
| 1307 | + "[%s] Supported entity types %s", |
| 1308 | + websocket.remote_address if websocket else "", |
| 1309 | + self._supported_entity_types, |
| 1310 | + ) |
| 1311 | + |
| 1312 | + async def get_version(self, websocket=None, *, timeout: float = 5.0) -> Version: |
| 1313 | + """Request client version and return msg_data.""" |
| 1314 | + resp = await self._ws_request( |
| 1315 | + websocket, |
| 1316 | + "get_version", |
| 1317 | + timeout=timeout, |
| 1318 | + ) |
| 1319 | + if resp.get("msg") != "version": |
| 1320 | + _LOG.debug( |
| 1321 | + "[%s] Unexpected resp msg for get_version: %s", |
| 1322 | + websocket.remote_address if websocket else "", |
| 1323 | + resp.get("msg"), |
| 1324 | + ) |
| 1325 | + |
| 1326 | + return resp.get("msg_data") |
| 1327 | + |
| 1328 | + async def get_localization_cfg( |
| 1329 | + self, websocket=None, *, timeout: float = 5.0 |
| 1330 | + ) -> LocalizationCfg: |
| 1331 | + """Request localization config and return msg_data.""" |
| 1332 | + resp = await self._ws_request( |
| 1333 | + websocket, |
| 1334 | + "get_localization_cfg", |
| 1335 | + timeout=timeout, |
| 1336 | + ) |
| 1337 | + |
| 1338 | + if resp.get("msg") != "localization_cfg": |
| 1339 | + _LOG.debug( |
| 1340 | + "[%s] Unexpected resp msg for get_localization_cfg: %s", |
| 1341 | + websocket.remote_address if websocket else "", |
| 1342 | + resp.get("msg"), |
| 1343 | + ) |
| 1344 | + |
| 1345 | + return resp.get("msg_data") |
1208 | 1346 |
|
1209 | 1347 | ############## |
1210 | 1348 | # Properties # |
|
0 commit comments