diff --git a/src/pymc_core/node/__init__.py b/src/pymc_core/node/__init__.py index ef3d95f..ca840c1 100644 --- a/src/pymc_core/node/__init__.py +++ b/src/pymc_core/node/__init__.py @@ -18,6 +18,7 @@ TraceHandler, ) from .node import MeshNode +from .contact_book import ContactBook, ContactBookPreferences, ContactPermissions, ContactRecord __all__ = [ "MeshNode", @@ -36,4 +37,8 @@ "ProtocolResponseHandler", "AnonReqResponseHandler", "TraceHandler", + "ContactBook", + "ContactBookPreferences", + "ContactPermissions", + "ContactRecord", ] diff --git a/src/pymc_core/node/contact_book.py b/src/pymc_core/node/contact_book.py new file mode 100644 index 0000000..039a758 --- /dev/null +++ b/src/pymc_core/node/contact_book.py @@ -0,0 +1,223 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +import time +from typing import Iterable, List, Optional, Sequence + +from ..protocol import ( + CONTACT_TYPE_CHAT_NODE, + CONTACT_TYPE_HYBRID, + CONTACT_TYPE_REPEATER, + CONTACT_TYPE_ROOM_SERVER, + CONTACT_TYPE_UNKNOWN, +) + + +@dataclass +class ContactPermissions: + allow_cli: bool = False + allow_telemetry: bool = False + allow_bridge: bool = False + + +@dataclass +class ContactRecord: + public_key: str + name: str = "" + contact_type: int = CONTACT_TYPE_UNKNOWN + flags: int = 0 + longitude: float = 0.0 + latitude: float = 0.0 + last_advert: int = 0 + tags: set[str] = field(default_factory=set) + permissions: ContactPermissions = field(default_factory=ContactPermissions) + out_path: List[int] = field(default_factory=list) + last_path_update: int = 0 + + def src_hash(self) -> Optional[int]: + try: + return bytes.fromhex(self.public_key)[0] + except Exception: + return None + + +@dataclass +class ContactBookPreferences: + allow_read_only: bool = False + bridge_enabled: bool = False + + +class ContactBook: + """Contact store with MeshCore-style ACL helpers.""" + + def __init__( + self, + contacts: Optional[Iterable[ContactRecord | dict]] = None, + prefs: Optional[ContactBookPreferences] = None, + ) -> None: + self.prefs = prefs or ContactBookPreferences() + self.contacts: List[ContactRecord] = [] + if contacts: + for entry in contacts: + self.add_contact(entry) + + # ------------------------------------------------------------------ + # Contact CRUD + # ------------------------------------------------------------------ + def add_contact(self, contact: ContactRecord | dict) -> ContactRecord: + record = self._normalize_contact(contact) + existing = self.get_by_public_key(record.public_key) + if existing: + self._update_contact(existing, record) + return existing + + self._apply_default_permissions(record) + self.contacts.append(record) + return record + + def list_contacts(self) -> List[ContactRecord]: + return list(self.contacts) + + def get_by_public_key(self, pubkey_hex: str) -> Optional[ContactRecord]: + for contact in self.contacts: + if contact.public_key.lower() == pubkey_hex.lower(): + return contact + return None + + def get_by_hash(self, hash_byte: int) -> Optional[ContactRecord]: + for contact in self.contacts: + if contact.src_hash() == hash_byte: + return contact + return None + + def get_by_name(self, name: str) -> Optional[ContactRecord]: + for contact in self.contacts: + if contact.name == name: + return contact + return None + + def remove_contact(self, pubkey_hex: str) -> bool: + before = len(self.contacts) + self.contacts = [c for c in self.contacts if c.public_key.lower() != pubkey_hex.lower()] + return len(self.contacts) != before + + # 
------------------------------------------------------------------ + # Preferences / ACL management + # ------------------------------------------------------------------ + def update_preferences( + self, + *, + allow_read_only: Optional[bool] = None, + bridge_enabled: Optional[bool] = None, + ) -> None: + if allow_read_only is not None: + self.prefs.allow_read_only = allow_read_only + if bridge_enabled is not None: + self.prefs.bridge_enabled = bridge_enabled + for contact in self.contacts: + self._apply_default_permissions(contact, overwrite=False) + + def set_permissions( + self, + pubkey_hex: str, + *, + allow_cli: Optional[bool] = None, + allow_telemetry: Optional[bool] = None, + allow_bridge: Optional[bool] = None, + ) -> None: + contact = self.get_by_public_key(pubkey_hex) + if not contact: + raise ValueError(f"Unknown contact {pubkey_hex}") + if allow_cli is not None: + contact.permissions.allow_cli = allow_cli + if allow_telemetry is not None: + contact.permissions.allow_telemetry = allow_telemetry + if allow_bridge is not None: + contact.permissions.allow_bridge = allow_bridge + + # ------------------------------------------------------------------ + # Permission helpers + # ------------------------------------------------------------------ + def can_execute_cli(self, contact: ContactRecord | int | str) -> bool: + record = self._resolve_contact(contact) + return bool(record and record.permissions.allow_cli) + + def can_receive_telemetry(self, contact: ContactRecord | int | str) -> bool: + record = self._resolve_contact(contact) + return bool(record and record.permissions.allow_telemetry) + + def can_use_bridge(self, contact: ContactRecord | int | str) -> bool: + record = self._resolve_contact(contact) + return bool(record and record.permissions.allow_bridge) + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + def _normalize_contact(self, contact: ContactRecord | dict) -> ContactRecord: + if isinstance(contact, ContactRecord): + return contact + + return ContactRecord( + public_key=contact["public_key"], + name=contact.get("name", ""), + contact_type=contact.get("type", contact.get("contact_type", CONTACT_TYPE_UNKNOWN)), + flags=contact.get("flags", 0), + longitude=contact.get("longitude", 0.0), + latitude=contact.get("latitude", 0.0), + last_advert=contact.get("last_advert", 0), + out_path=list(contact.get("out_path", []) or []), + last_path_update=contact.get("last_path_update", 0), + ) + + def _update_contact(self, dest: ContactRecord, src: ContactRecord) -> None: + dest.name = src.name or dest.name + dest.contact_type = src.contact_type or dest.contact_type + dest.flags = src.flags or dest.flags + dest.longitude = src.longitude or dest.longitude + dest.latitude = src.latitude or dest.latitude + dest.last_advert = src.last_advert or dest.last_advert + if src.tags: + dest.tags.update(src.tags) + if src.permissions != dest.permissions: + dest.permissions = src.permissions + if src.out_path: + dest.out_path = list(src.out_path) + dest.last_path_update = src.last_path_update or dest.last_path_update + + def _apply_default_permissions(self, contact: ContactRecord, *, overwrite: bool = True) -> None: + if overwrite: + perms = ContactPermissions() + else: + perms = contact.permissions + perms.allow_cli = False + perms.allow_telemetry = False + perms.allow_bridge = False + if contact.contact_type in (CONTACT_TYPE_ROOM_SERVER, CONTACT_TYPE_HYBRID): + perms.allow_cli = True 
+ perms.allow_telemetry = True + elif contact.contact_type == CONTACT_TYPE_CHAT_NODE and self.prefs.allow_read_only: + perms.allow_cli = True + if self.prefs.bridge_enabled and contact.contact_type in ( + CONTACT_TYPE_ROOM_SERVER, + CONTACT_TYPE_HYBRID, + ): + perms.allow_bridge = True + contact.permissions = perms + + def _resolve_contact(self, ref: ContactRecord | int | str) -> Optional[ContactRecord]: + if isinstance(ref, ContactRecord): + return ref + if isinstance(ref, int): + return self.get_by_hash(ref) + return self.get_by_public_key(ref) + + # ------------------------------------------------------------------ + # Path helpers + # ------------------------------------------------------------------ + def update_out_path(self, contact: ContactRecord | str | int, path: Sequence[int]) -> Optional[ContactRecord]: + record = self._resolve_contact(contact) + if not record: + return None + record.out_path = list(path) + record.last_path_update = int(time.time()) + return record \ No newline at end of file diff --git a/src/pymc_core/node/dispatcher.py b/src/pymc_core/node/dispatcher.py index 0531123..4ad5848 100644 --- a/src/pymc_core/node/dispatcher.py +++ b/src/pymc_core/node/dispatcher.py @@ -3,12 +3,14 @@ import asyncio import enum import logging +import time from typing import Any, Awaitable, Callable, Optional -from ..protocol import Packet +from ..protocol import Packet, PacketTimingUtils from ..protocol.constants import ( # Payload types PAYLOAD_TYPE_ACK, PAYLOAD_TYPE_ADVERT, + PAYLOAD_TYPE_MULTIPART, PH_TYPE_SHIFT, ) from ..protocol.utils import PAYLOAD_TYPES, ROUTE_TYPES, format_packet_info @@ -28,6 +30,8 @@ ) ACK_TIMEOUT = 5.0 # seconds to wait for an ACK +OWN_PACKET_CACHE_TTL = 180.0 # seconds to keep outbound packet hashes +OWN_PACKET_CACHE_MAX = 2048 # max outbound packet hashes to track class DispatcherState(str, enum.Enum): @@ -56,10 +60,17 @@ def __init__( tx_delay: float = 0.05, log_fn: Optional[Callable[[str], None]] = None, packet_filter: Optional[Any] = None, + radio_config: Optional[dict] = None, ) -> None: self.radio = radio self.tx_delay = tx_delay self.state: DispatcherState = DispatcherState.IDLE + self.radio_config: dict = dict(radio_config or {}) + self._score_delay_threshold_ms = 50 + self._next_tx_allowed_at: float = 0.0 + self._recent_tx_packets: dict[int, float] = {} + self._own_packet_cache_ttl = OWN_PACKET_CACHE_TTL + self._own_packet_cache_max = OWN_PACKET_CACHE_MAX self.packet_received_callback: Optional[Callable[[Packet], Awaitable[None] | None]] = None self.packet_sent_callback: Optional[Callable[[Packet], Awaitable[None] | None]] = None @@ -148,6 +159,8 @@ def register_default_handlers( """Quick setup for all the standard packet handlers.""" # Keep our identity handy for detecting our own packets self.local_identity = local_identity + if radio_config is not None: + self.radio_config = dict(radio_config) # Set up ACK handler with callback to us ack_handler = AckHandler(self._log, self) @@ -159,6 +172,7 @@ def register_default_handlers( AdvertHandler(contacts, self._log, local_identity, event_service), ) self.register_handler(AckHandler.payload_type(), ack_handler) + self.register_handler(PAYLOAD_TYPE_MULTIPART, ack_handler) # Text message handler - needs to send ACKs back through us text_message_handler = TextMessageHandler( @@ -204,7 +218,13 @@ def register_default_handlers( self.telemetry_response_handler = protocol_response_handler # PATH handler - for route discovery packets, with ACK and protocol response processing - path_handler = 
PathHandler(self._log, ack_handler, protocol_response_handler) + path_handler = PathHandler( + self._log, + local_identity=local_identity, + contact_book=contacts, + ack_handler=ack_handler, + protocol_response_handler=protocol_response_handler, + ) self.register_handler(PathHandler.payload_type(), path_handler) # Login response handler for PAYLOAD_TYPE_RESPONSE packets @@ -259,22 +279,17 @@ def _get_handler(self, ptype: int): return self._handlers.get(ptype, self._fallback_handler) def _is_own_packet(self, pkt: Packet) -> bool: - """Check if this packet came from us by comparing the source hash.""" - if not self.local_identity or len(pkt.payload) < 2: - return False + """Detect our own packets by matching recently transmitted CRCs.""" - # Get our public key hash (first byte) - our_pubkey = self.local_identity.get_public_key() - our_hash = our_pubkey[0] if len(our_pubkey) > 0 else 0 - - # Compare with src_hash in payload[1] - src_hash = pkt.payload[1] - is_own = src_hash == our_hash + if not self.local_identity: + return False - if is_own: - self._log(f"Own packet detected: src_hash={src_hash:02X}, our_hash={our_hash:02X}") + crc = pkt.get_crc() + if self._is_recent_outbound_crc(crc): + self._log(f"Own packet detected via CRC {crc:08X}") + return True - return is_own + return False def set_packet_received_callback( self, callback: Callable[[Packet], Awaitable[None] | None] @@ -346,8 +361,6 @@ async def _process_received_packet(self, data: bytes) -> None: # Let the node know about this packet for analysis (statistics, caching, etc.) if self.packet_analysis_callback: try: - import asyncio - if asyncio.iscoroutinefunction(self.packet_analysis_callback): await self.packet_analysis_callback(pkt, data) else: @@ -372,10 +385,145 @@ async def _process_received_packet(self, data: bytes) -> None: self._log(f"Ignoring own packet (type={pkt.header >> 4:02X}) to prevent loops") return + if pkt.is_route_flood(): + delay_ms, score, airtime_ms = self._calculate_flood_delay_ms(pkt) + if delay_ms >= self._score_delay_threshold_ms: + self._log( + "[RX DEBUG] Flood packet delay: " + f"{delay_ms}ms (score={score:.2f}, airtime={airtime_ms:.1f}ms)" + ) + await asyncio.sleep(delay_ms / 1000.0) + else: + self._log( + "[RX DEBUG] Flood score delay below threshold " + f"({delay_ms}ms), processing immediately" + ) + # Handle ACK matching for waiting senders self._log("[RX DEBUG] Dispatching packet to handlers") await self._dispatch(pkt) + def _get_spreading_factor(self) -> Optional[int]: + if not self.radio_config: + return None + sf = self.radio_config.get("spreading_factor") + try: + return int(sf) if sf is not None else None + except (TypeError, ValueError): + return None + + def _calculate_flood_delay_ms(self, pkt: Packet) -> tuple[int, float, float]: + packet_len = pkt.get_raw_length() + airtime_ms = PacketTimingUtils.estimate_airtime_ms(packet_len, self.radio_config or None) + snr_db = pkt.snr if pkt.snr is not None else 0.0 + score = PacketTimingUtils.calculate_packet_score( + snr_db, + packet_len, + self._get_spreading_factor(), + ) + delay_ms = PacketTimingUtils.calc_rx_delay_ms(score, airtime_ms) + return delay_ms, score, airtime_ms + + def _estimate_packet_airtime_ms(self, packet: Packet) -> float: + return PacketTimingUtils.estimate_airtime_ms(packet.get_raw_length(), self.radio_config or None) + + async def _await_tx_budget_window(self) -> None: + if self._next_tx_allowed_at <= 0: + return + now = asyncio.get_event_loop().time() + delay = self._next_tx_allowed_at - now + if delay > 0: + self._log(f"[TX 
DEBUG] Airtime budget wait {delay * 1000:.0f}ms") + await asyncio.sleep(delay) + + def _schedule_next_tx_window(self, packet_airtime_ms: float) -> None: + delay_ms = PacketTimingUtils.calc_airtime_budget_delay_ms(packet_airtime_ms) + if delay_ms <= 0: + self._next_tx_allowed_at = 0.0 + return + now = asyncio.get_event_loop().time() + self._next_tx_allowed_at = now + (delay_ms / 1000.0) + self._log(f"[TX DEBUG] Next TX allowed in {delay_ms:.0f}ms") + + async def _ensure_channel_clear(self) -> None: + cad_method = getattr(self.radio, "perform_cad", None) + if not callable(cad_method): + return + + retry_delay = PacketTimingUtils.get_cad_fail_retry_delay_ms() / 1000.0 + max_duration = PacketTimingUtils.get_cad_fail_max_duration_ms() / 1000.0 + start = asyncio.get_event_loop().time() + attempt = 0 + + while True: + attempt += 1 + try: + cad_result = await cad_method() + except Exception as exc: + self._log(f"[TX DEBUG] CAD attempt {attempt} failed: {exc}; continuing with TX") + return + + if isinstance(cad_result, dict): + channel_busy = bool( + cad_result.get("detected") + or cad_result.get("cad_detected") + or cad_result.get("activity") + ) + else: + channel_busy = bool(cad_result) + + if not channel_busy: + if attempt > 1: + self._log(f"[TX DEBUG] CAD cleared after {attempt} attempts") + return + + elapsed = asyncio.get_event_loop().time() - start + remaining = max_duration - elapsed + if remaining <= 0: + self._log("[TX DEBUG] CAD busy window exceeded, forcing transmit") + return + + backoff = min(retry_delay, remaining) + self._log( + f"[TX DEBUG] CAD detected activity (attempt {attempt}), backing off {backoff * 1000:.0f}ms" + ) + await asyncio.sleep(backoff) + + def _record_outbound_packet_crc(self, crc: int | None) -> None: + if crc is None: + return + now = time.monotonic() + self._recent_tx_packets[crc] = now + self._prune_recent_outbound(now) + + def _is_recent_outbound_crc(self, crc: int | None) -> bool: + if crc is None: + return False + now = time.monotonic() + self._prune_recent_outbound(now) + ts = self._recent_tx_packets.get(crc) + if ts is None: + return False + if now - ts > self._own_packet_cache_ttl: + self._recent_tx_packets.pop(crc, None) + return False + return True + + def _prune_recent_outbound(self, now: float) -> None: + ttl = self._own_packet_cache_ttl + expired = [crc for crc, ts in self._recent_tx_packets.items() if now - ts > ttl] + for crc in expired: + self._recent_tx_packets.pop(crc, None) + + if len(self._recent_tx_packets) <= self._own_packet_cache_max: + return + + # Drop oldest entries until within cap to bound memory usage + for crc, _ in sorted(self._recent_tx_packets.items(), key=lambda item: item[1]): + self._recent_tx_packets.pop(crc, None) + if len(self._recent_tx_packets) <= self._own_packet_cache_max: + break + # ------------------------------------------------------------------ # Public interface - sending and receiving packets # ------------------------------------------------------------------ @@ -396,6 +544,7 @@ async def send_packet( If None, will be calculated from packet. 
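        Descriptive note on the TX sequence enforced by this method (added for clarity; it restates what the surrounding code does): the dispatcher first waits out the airtime-budget window left over from the previous send, runs CAD via ``radio.perform_cad()`` when the radio exposes that method, transmits, records the packet CRC so a later re-received copy is suppressed by ``_is_own_packet``, and finally schedules the next allowed TX time from the estimated airtime.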
""" payload_type = packet.header >> PH_TYPE_SHIFT + packet_crc = packet.get_crc() # ------------------------------------------------------------------ # # Make sure we're not already busy @@ -404,6 +553,16 @@ async def send_packet( self._log("Busy, skipping TX.") return False + packet_airtime_ms = self._estimate_packet_airtime_ms(packet) + + self.state = DispatcherState.WAIT + try: + await self._await_tx_budget_window() + await self._ensure_channel_clear() + except Exception: + self.state = DispatcherState.IDLE + raise + # ------------------------------------------------------------------ # # Send the packet # ------------------------------------------------------------------ # @@ -415,6 +574,9 @@ async def send_packet( self._log(f"Radio transmit error: {e}") self.state = DispatcherState.IDLE return False + + self._record_outbound_packet_crc(packet_crc) + self._schedule_next_tx_window(packet_airtime_ms) # Log what we sent type_name = PAYLOAD_TYPES.get(payload_type, f"UNKNOWN_{payload_type}") route_name = ROUTE_TYPES.get(packet.get_route_type(), f"UNKNOWN_{packet.get_route_type()}") @@ -438,7 +600,7 @@ async def send_packet( if expected_crc is not None: self._current_expected_crc = expected_crc else: - self._current_expected_crc = packet.get_crc() + self._current_expected_crc = packet_crc self._log( f"Waiting for ACK with CRC {self._current_expected_crc:08X} (timeout: {ACK_TIMEOUT}s)" diff --git a/src/pymc_core/node/handlers/ack.py b/src/pymc_core/node/handlers/ack.py index 7d260cb..43295f4 100644 --- a/src/pymc_core/node/handlers/ack.py +++ b/src/pymc_core/node/handlers/ack.py @@ -1,17 +1,12 @@ from typing import Callable, Optional from ...protocol import Packet -from ...protocol.constants import PAYLOAD_TYPE_ACK +from ...protocol.constants import PAYLOAD_TYPE_ACK, PAYLOAD_TYPE_MULTIPART from .base import BaseHandler class AckHandler(BaseHandler): - """ - ACK handler that processes all ACK variants: - 1. Discrete ACK packets (payload type 1) - 2. Bundled ACKs in PATH packets - 3. 
Encrypted ACK responses (20-byte PATH packets) - """ + """Process discrete, multipart, and path-embedded ACK packets.""" @staticmethod def payload_type() -> int: @@ -31,30 +26,76 @@ def set_dispatcher(self, dispatcher): self.dispatcher = dispatcher async def __call__(self, packet: Packet) -> None: - """Handle discrete ACK packets (payload type 1).""" - ack_crc = await self.process_discrete_ack(packet) + """Handle all ACK packets, including multipart variants.""" + payload_type = packet.get_payload_type() + + if payload_type == PAYLOAD_TYPE_MULTIPART: + ack_crc = await self.process_multipart_ack(packet) + else: + ack_crc = await self.process_discrete_ack(packet) + if ack_crc is not None: await self._notify_ack_received(ack_crc) async def process_discrete_ack(self, packet: Packet) -> Optional[int]: """Process a discrete ACK packet and return the CRC if valid.""" - self.log(f"Processing discrete ACK: payload_len={len(packet.payload)}") - self.log(f"ACK payload (hex): {packet.payload.hex().upper()}") + payload_len = packet.payload_len or len(packet.payload) + payload = bytes(packet.payload[:payload_len]) + + self.log(f"Processing discrete ACK: payload_len={payload_len}") + self.log(f"ACK payload (hex): {payload.hex().upper()}") - if len(packet.payload) != 4: - self.log(f"Invalid ACK length: {len(packet.payload)} bytes (expected 4)") + if payload_len != 4: + self.log(f"Invalid ACK length: {payload_len} bytes (expected 4)") return None - # Extract CRC checksum (4 bytes, little endian per protocol spec) - crc = int.from_bytes(packet.payload, "little") + crc = int.from_bytes(payload, "little") self.log(f"Discrete ACK received: CRC={crc:08X}") return crc + async def process_multipart_ack(self, packet: Packet) -> Optional[int]: + """Process multipart ACK packets (PAYLOAD_TYPE_MULTIPART).""" + payload_len = packet.payload_len or len(packet.payload) + payload = bytes(packet.payload[:payload_len]) + + if not payload: + self.log("Multipart ACK missing payload") + return None + + multi_header = payload[0] + remaining = multi_header >> 4 + inner_type = multi_header & 0x0F + self.log( + f"Processing multipart ACK: remaining={remaining}, inner_type=0x{inner_type:02X}, total_len={payload_len}" + ) + + if inner_type != PAYLOAD_TYPE_ACK: + self.log("Multipart packet inner type is not ACK; ignoring") + return None + + ack_payload = payload[1:] + if len(ack_payload) < 4: + self.log( + f"Multipart ACK payload too short: {len(ack_payload)} bytes (expected >= 4)" + ) + return None + + crc = int.from_bytes(ack_payload[:4], "little") + self.log(f"Multipart ACK decoded: CRC={crc:08X} (remaining={remaining})") + return crc + async def process_path_ack_variants(self, packet: Packet) -> Optional[int]: """ Process PATH packets that may contain ACKs in different forms. Returns CRC if ACK found, None otherwise. 
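        Note on check order (descriptive only): bundled ACKs are read first from ``packet.decrypted["path_inner"]``, which PathHandler populates after decrypting the PATH payload; if none is found, the raw payload is probed for an encrypted ACK response against the dispatcher's currently waiting CRCs.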
""" + inner = getattr(packet, "decrypted", {}).get("path_inner") if packet else None + if inner: + bundled_crc = await self._process_bundled_ack_in_path(inner) + if bundled_crc is not None: + self.log(f"Found bundled ACK in PATH payload: CRC={bundled_crc:08X}") + return bundled_crc + if not self.dispatcher: return None @@ -80,12 +121,6 @@ async def process_path_ack_variants(self, packet: Packet) -> Optional[int]: self.log(f"Found encrypted ACK response: CRC={ack_crc:08X}") return ack_crc - # Check for bundled ACKs in returned path messages - bundled_crc = await self._process_bundled_ack_in_path(payload) - if bundled_crc is not None: - self.log(f"Found bundled ACK: CRC={bundled_crc:08X}") - return bundled_crc - return None async def _try_decrypt_encrypted_ack(self, payload: bytes) -> Optional[int]: @@ -110,11 +145,9 @@ async def _try_decrypt_encrypted_ack(self, payload: bytes) -> Optional[int]: # Decrypt (skip dest_hash and src_hash) mac_and_ciphertext = payload[2:] decrypted = CryptoUtils.mac_then_decrypt(aes_key, shared_secret, mac_and_ciphertext) - if not decrypted or len(decrypted) < 4: return None - # Look for expected CRC in decrypted data expected_crcs = set(self.dispatcher._waiting_acks.keys()) for i in range(len(decrypted) - 3): crc_bytes = decrypted[i : i + 4] @@ -133,7 +166,7 @@ async def _try_decrypt_encrypted_ack(self, payload: bytes) -> Optional[int]: return None async def _process_bundled_ack_in_path(self, payload: bytes) -> Optional[int]: - """Process bundled ACKs in returned path messages according to protocol spec.""" + """Process bundled ACKs from already-decrypted PATH payloads.""" if len(payload) < 1: return None diff --git a/src/pymc_core/node/handlers/control.py b/src/pymc_core/node/handlers/control.py index bf417d3..9ec7af4 100644 --- a/src/pymc_core/node/handlers/control.py +++ b/src/pymc_core/node/handlers/control.py @@ -6,14 +6,26 @@ import struct import time -from typing import Any, Callable, Dict +from typing import Any, Callable, Dict, List, Optional from ...protocol import Packet -from ...protocol.constants import PAYLOAD_TYPE_CONTROL +from ...protocol.constants import ( + ADV_TYPE_CHAT, + ADV_TYPE_LABELS, + ADV_TYPE_REPEATER, + ADV_TYPE_ROOM, + ADV_TYPE_SENSOR, + PAYLOAD_TYPE_CONTROL, +) # Control packet type constants (upper 4 bits of first byte) CTL_TYPE_NODE_DISCOVER_REQ = 0x80 # Discovery request CTL_TYPE_NODE_DISCOVER_RESP = 0x90 # Discovery response +CTL_TYPE_MASK = 0xF0 +DISCOVER_REQ_PREFIX_FLAG = 0x01 + +KNOWN_ADV_TYPES = [ADV_TYPE_CHAT, ADV_TYPE_REPEATER, ADV_TYPE_ROOM, ADV_TYPE_SENSOR] +KNOWN_ADV_TYPE_MASK = sum(1 << adv_type for adv_type in KNOWN_ADV_TYPES) class ControlHandler: @@ -31,9 +43,9 @@ def __init__(self, log_fn: Callable[[str], None]): """ self._log = log_fn - # Callbacks for discovery responses + # Callbacks for discovery responses/requests self._response_callbacks: Dict[int, Callable[[Dict[str, Any]], None]] = {} - self._request_callbacks: Dict[int, Callable[[Dict[str, Any]], None]] = {} + self._request_callback: Optional[Callable[[Dict[str, Any]], None]] = None @staticmethod def payload_type() -> int: @@ -53,11 +65,11 @@ def set_request_callback( self, callback: Callable[[Dict[str, Any]], None] ) -> None: """Set callback for discovery requests (for logging/monitoring).""" - self._request_callbacks[0] = callback + self._request_callback = callback def clear_request_callback(self) -> None: """Clear callback for discovery requests.""" - self._request_callbacks.pop(0, None) + self._request_callback = None async def __call__(self, pkt: 
Packet) -> None: """Handle incoming control packet.""" @@ -74,7 +86,7 @@ async def __call__(self, pkt: Packet) -> None: return # Extract control type (upper 4 bits of first byte) - control_type = pkt.payload[0] & 0xF0 + control_type = pkt.payload[0] & CTL_TYPE_MASK if control_type == CTL_TYPE_NODE_DISCOVER_REQ: await self._handle_discovery_request(pkt) @@ -104,8 +116,8 @@ async def _handle_discovery_request(self, pkt: Packet) -> None: # Parse request flags_byte = pkt.payload[0] - prefix_only = (flags_byte & 0x01) != 0 - filter_byte = pkt.payload[1] + prefix_only = (flags_byte & DISCOVER_REQ_PREFIX_FLAG) != 0 + filter_mask = pkt.payload[1] tag = struct.unpack(" None: if len(pkt.payload) >= 10: since = struct.unpack(" 0, "prefix_only": prefix_only, - "snr": pkt._snr, - "rssi": pkt._rssi, + "snr": pkt.get_snr(), + "raw_snr": pkt._snr, + "rssi": pkt.rssi, "timestamp": time.time(), + "payload_len": pkt.payload_len, } + unknown_filter_bits = filter_mask & ~KNOWN_ADV_TYPE_MASK + if unknown_filter_bits: + request_data["unknown_filter_bits"] = unknown_filter_bits + # Call request callback if registered (for logging/monitoring) - if 0 in self._request_callbacks: - callback = self._request_callbacks[0] - if callback: - callback(request_data) + if self._request_callback: + self._request_callback(request_data) except Exception as e: self._log(f"[ControlHandler] Error handling discovery request: {e}") @@ -148,7 +172,7 @@ async def _handle_discovery_response(self, pkt: Packet) -> None: - bytes 6-onwards: public key (8 or 32 bytes) """ try: - if len(pkt.payload) < 6: + if len(pkt.payload) < 14: self._log("[ControlHandler] Discovery response too short") return @@ -157,37 +181,53 @@ async def _handle_discovery_response(self, pkt: Packet) -> None: node_type = type_byte & 0x0F snr_byte = pkt.payload[1] # Convert signed byte to float SNR (C++ stores as int8_t multiplied by 4) - inbound_snr = (snr_byte if snr_byte < 128 else snr_byte - 256) / 4.0 + inbound_snr = self._decode_response_snr(snr_byte) tag = struct.unpack(" None: except Exception as e: self._log(f"[ControlHandler] Error handling discovery response: {e}") + + @staticmethod + def _decode_filter_types(filter_mask: int) -> List[int]: + if filter_mask is None: + return [] + matched = [] + for adv_type in KNOWN_ADV_TYPES: + if filter_mask & (1 << adv_type): + matched.append(adv_type) + return matched + + @staticmethod + def _decode_response_snr(raw_snr: int) -> float: + if raw_snr is None: + return 0.0 + signed = raw_snr if raw_snr < 128 else raw_snr - 256 + return signed / 4.0 diff --git a/src/pymc_core/node/handlers/path.py b/src/pymc_core/node/handlers/path.py index 98eddf1..3ca1435 100644 --- a/src/pymc_core/node/handlers/path.py +++ b/src/pymc_core/node/handlers/path.py @@ -1,37 +1,26 @@ """Path packet handler for mesh network routing.""" -from typing import Callable +from typing import Any, Callable, Optional -from ...protocol import Packet -from ...protocol.constants import PAYLOAD_TYPE_PATH +from ...protocol import CryptoUtils, Identity, Packet +from ...protocol.constants import PAYLOAD_TYPE_PATH, PAYLOAD_TYPE_RESPONSE class PathHandler: - """Handler for PATH packets (payload type 0x08) - "Returned path" packets. - - According to the official documentation, PATH packets are used for returning - responses through the mesh network along discovered routing paths. 
- - Official Packet Structure: - - Header [1B]: Route type (0-1) + Payload type (2-5) + Version (6-7) - - Path Length [1B]: Length of the path field - - Path [up to 64B]: Routing path data (if applicable) - - Payload [up to 184B]: The actual data being transmitted - - For PATH packets, the payload typically contains: - - [1B] dest_hash: Destination node hash - - [1B] src_hash: Source node hash - - [2B] MAC: Message Authentication Code (for payload version 0x00) - - [NB] encrypted_data: Contains ACK or other response data - """ + """Decrypt PATH packets, update cached routes, and surface extras.""" def __init__( self, log_fn: Callable[[str], None], + *, + local_identity: Any = None, + contact_book: Any = None, ack_handler=None, protocol_response_handler=None, ): self._log = log_fn + self._local_identity = local_identity + self._contact_book = contact_book self._ack_handler = ack_handler self._protocol_response_handler = protocol_response_handler @@ -44,63 +33,147 @@ def set_ack_handler(self, ack_handler): self._ack_handler = ack_handler async def __call__(self, pkt: Packet) -> None: - """Handle incoming PATH packet according to official specification.""" + """Handle incoming PATH packet: decrypt, update contact path, route extras.""" try: - # First, check if this PATH packet contains protocol responses - if self._protocol_response_handler: - await self._protocol_response_handler(pkt) + decoded = self._decode_path_payload(pkt) + if not decoded: + return + + # Surface decrypted payload for downstream consumers (ACK/analysis) + if not isinstance(pkt.decrypted, dict): + pkt.decrypted = {} + pkt.decrypted["path_inner"] = decoded["raw_inner"] + pkt.decrypted["path_meta"] = { + "path": decoded["path"], + "extra_type": decoded["extra_type"], + "src_hash": decoded["src_hash"], + } + + self._update_contact_path(decoded) + await self._handle_response_extra(decoded) - # Then, check if this PATH packet contains ACKs and delegate to ACK handler if self._ack_handler: ack_crc = await self._ack_handler.process_path_ack_variants(pkt) if ack_crc is not None: - # ACK was found, notify dispatcher await self._ack_handler._notify_ack_received(ack_crc) - # Optional PATH packet analysis if analyzer is available - try: - # Try to use any available packet analyzer through callback - if hasattr(self, "_dispatcher") and hasattr( - self._dispatcher, "packet_analysis_callback" - ): - if self._dispatcher.packet_analysis_callback: - self._dispatcher.packet_analysis_callback(pkt) - self._log("PATH packet analysis delegated to app") - else: - self._log("PATH packet received - hop analysis requires app-level analyzer") - - except Exception as e: - self._log(f"PATH packet analysis failed: {e}") - - # Extract and log key PATH information directly from packet - try: - payload = pkt.get_payload() - if len(payload) >= 2: - hop_count = payload[1] - self._log(f"PATH packet: hop_count={hop_count}, payload_len={len(payload)}") - self._log(f"Path contains {hop_count} hops") - else: - self._log("PATH packet received with minimal payload") - - # Log basic routing behavior based on header - try: - # These constants are already imported at the top - # from ...protocol.constants import ( - # ROUTE_TYPE_DIRECT, - # ROUTE_TYPE_FLOOD, - # ) - - # Extract route type from packet header if possible - # This is a simplified version without full analysis - self._log("PATH packet routing analysis requires app-level analyzer") - except ImportError: - pass - - except Exception as e: - self._log(f"Error extracting PATH information: {e}") + 
self._run_packet_analysis(pkt) + self._log_basic_path_stats(decoded) except Exception as e: self._log(f"Error in PATH handler: {e}") import traceback self._log(traceback.format_exc()) + + def _decode_path_payload(self, pkt: Packet) -> Optional[dict]: + if not self._local_identity or not self._contact_book: + self._log("PATH handler missing identity/contact book; skipping decrypt") + return None + + payload = pkt.payload + if len(payload) < 2: + self._log("PATH packet missing dest/src hashes") + return None + + dest_hash, src_hash = payload[0], payload[1] + local_hash = self._local_identity.get_public_key()[0] + if dest_hash != local_hash: + self._log( + f"PATH packet dest hash mismatch (dest=0x{dest_hash:02X}, local=0x{local_hash:02X})" + ) + return None + + contact = getattr(self._contact_book, "get_by_hash", lambda _: None)(src_hash) + if not contact: + self._log(f"PATH packet from unknown contact hash 0x{src_hash:02X}") + return None + + try: + peer_identity = Identity(bytes.fromhex(contact.public_key)) + shared_secret = peer_identity.calc_shared_secret( + self._local_identity.get_private_key() + ) + aes_key = shared_secret[:16] + decrypted = CryptoUtils.mac_then_decrypt(aes_key, shared_secret, payload[2:]) + except Exception as err: + self._log(f"PATH payload decryption failed: {err}") + return None + + if not decrypted: + self._log("PATH payload decrypted to empty data") + return None + if len(decrypted) < 1: + self._log("PATH payload missing path length byte") + return None + + path_len = decrypted[0] + if len(decrypted) < 1 + path_len: + self._log( + f"PATH payload truncated for path_len={path_len} (have {len(decrypted)})" + ) + return None + + path_bytes = list(int(b) & 0xFF for b in decrypted[1 : 1 + path_len]) + extra_offset = 1 + path_len + extra_type: Optional[int] = None + extra_payload = b"" + if len(decrypted) > extra_offset: + extra_type = decrypted[extra_offset] + extra_payload = bytes(decrypted[extra_offset + 1 :]) + + return { + "dest_hash": dest_hash, + "src_hash": src_hash, + "contact": contact, + "path": path_bytes, + "extra_type": extra_type, + "extra_payload": extra_payload, + "raw_inner": bytes(decrypted), + } + + def _update_contact_path(self, decoded: dict) -> None: + contact = decoded.get("contact") + path = decoded.get("path") or [] + if not contact or not path: + return + + updater = getattr(self._contact_book, "update_out_path", None) + if callable(updater): + updater(contact, path) + self._log( + f"Updated cached path for contact hash 0x{decoded['src_hash']:02X}: {path}" + ) + + async def _handle_response_extra(self, decoded: dict) -> None: + if decoded.get("extra_type") != PAYLOAD_TYPE_RESPONSE: + return + if not self._protocol_response_handler: + self._log("PATH response extra ignored (no protocol handler)") + return + + extra_payload: bytes = decoded.get("extra_payload", b"") + if not extra_payload: + self._log("PATH response extra empty; nothing to deliver") + return + + await self._protocol_response_handler.handle_plaintext_response( + decoded["src_hash"], decoded.get("contact"), extra_payload + ) + + def _run_packet_analysis(self, pkt: Packet) -> None: + try: + dispatcher = getattr(self, "_dispatcher", None) + if dispatcher and getattr(dispatcher, "packet_analysis_callback", None): + dispatcher.packet_analysis_callback(pkt) + self._log("PATH packet analysis delegated to app") + except Exception as exc: + self._log(f"PATH packet analysis failed: {exc}") + + def _log_basic_path_stats(self, decoded: dict) -> None: + path = decoded.get("path") or [] + 
extra_type = decoded.get("extra_type") + self._log( + f"PATH packet: hops={len(path)}, extra_type=" + f"{('0x%02X' % extra_type) if extra_type is not None else 'none'}" + ) diff --git a/src/pymc_core/node/handlers/protocol_response.py b/src/pymc_core/node/handlers/protocol_response.py index a55c7fe..1e8d704 100644 --- a/src/pymc_core/node/handlers/protocol_response.py +++ b/src/pymc_core/node/handlers/protocol_response.py @@ -63,27 +63,23 @@ async def __call__(self, pkt: Packet) -> None: ) # Try to decrypt the response - success, decoded_text, parsed_data = await self._decrypt_protocol_response( + success, decoded_text, parsed_data, contact = await self._decrypt_protocol_response( pkt, src_hash ) - - # Call the waiting callback - callback = self._response_callbacks[src_hash] - if callback: - callback(success, decoded_text, parsed_data) + await self._deliver_response(src_hash, success, decoded_text, parsed_data, contact) except Exception as e: self._log(f"[ProtocolResponse] Error processing protocol response: {e}") async def _decrypt_protocol_response( self, pkt: Packet, src_hash: int - ) -> tuple[bool, str, Dict[str, Any]]: + ) -> tuple[bool, str, Dict[str, Any], Optional[Any]]: """Decrypt and parse a protocol response packet.""" try: # Find the contact by hash contact = self._find_contact_by_hash(src_hash) if not contact: - return False, f"Unknown contact for hash 0x{src_hash:02X}", {} + return False, f"Unknown contact for hash 0x{src_hash:02X}", {}, None # Get encryption keys contact_pubkey = bytes.fromhex(contact.public_key) @@ -100,11 +96,54 @@ async def _decrypt_protocol_response( self._log(f"[ProtocolResponse] Successfully decrypted {len(decrypted)} bytes") # Parse based on content type - return self._parse_protocol_response(decrypted) + success, decoded_text, parsed_data = self._parse_protocol_response(decrypted) + return success, decoded_text, parsed_data, contact except Exception as e: self._log(f"[ProtocolResponse] Decryption failed: {e}") - return False, f"Decryption failed: {e}", {} + return False, f"Decryption failed: {e}", {}, contact if 'contact' in locals() else None + + async def handle_plaintext_response( + self, src_hash: int, contact: Optional[Any], plaintext: bytes + ) -> None: + """Process a plaintext response that has already been decrypted.""" + try: + success, decoded_text, parsed_data = self._parse_protocol_response(plaintext) + except Exception as exc: + self._log(f"[ProtocolResponse] Plaintext parsing failed: {exc}") + success, decoded_text, parsed_data = False, f"Parse error: {exc}", {} + + await self._deliver_response(src_hash, success, decoded_text, parsed_data, contact) + + async def _deliver_response( + self, + src_hash: int, + success: bool, + decoded_text: str, + parsed_data: Dict[str, Any], + contact: Optional[Any], + ) -> None: + """Invoke waiting callback with ACL enforcement.""" + if src_hash not in self._response_callbacks: + return + + if ( + success + and parsed_data.get("type") == "telemetry" + and contact + and hasattr(self._contact_book, "can_receive_telemetry") + ): + if not self._contact_book.can_receive_telemetry(contact): + self._log( + "[ProtocolResponse] Telemetry blocked by ACL for " f"0x{src_hash:02X}" + ) + success = False + decoded_text = "Telemetry blocked by ACL" + parsed_data = {"type": "telemetry", "acl_blocked": True} + + callback = self._response_callbacks.get(src_hash) + if callback: + callback(success, decoded_text, parsed_data) def _parse_protocol_response(self, data: bytes) -> tuple[bool, str, Dict[str, Any]]: """Parse 
decrypted protocol response data.""" diff --git a/src/pymc_core/node/handlers/text.py b/src/pymc_core/node/handlers/text.py index 2482f34..05c68d3 100644 --- a/src/pymc_core/node/handlers/text.py +++ b/src/pymc_core/node/handlers/text.py @@ -74,6 +74,17 @@ async def __call__(self, packet: Packet) -> None: pubkey = bytes.fromhex(matched_contact.public_key) timestamp_int = int.from_bytes(timestamp, "little") + if flags == 0x01: + can_cli = True + permission_fn = getattr(self.contacts, "can_execute_cli", None) + if callable(permission_fn): + can_cli = permission_fn(matched_contact) + if not can_cli: + self.log( + f"CLI command from '{matched_contact.name}' blocked by ACL" + ) + return + # Determine message routing type from packet header route_type = packet.header & 0x03 # Route type is in bits 0-1 is_flood = route_type == 1 # ROUTE_TYPE_FLOOD = 1 diff --git a/src/pymc_core/node/node.py b/src/pymc_core/node/node.py index 266be07..f3945e0 100644 --- a/src/pymc_core/node/node.py +++ b/src/pymc_core/node/node.py @@ -11,6 +11,7 @@ setattr(collections, "Hashable", collections.abc.Hashable) from ..protocol import LocalIdentity +from .contact_book import ContactBook from .dispatcher import Dispatcher logger = logging.getLogger("Node") @@ -54,7 +55,7 @@ def __init__( """ self.radio = radio self.identity = local_identity - self.contacts = contacts # App can inject contact storage + self.contacts = contacts or ContactBook() self.channel_db = channel_db # App can inject channel database self.event_service = event_service # App can inject event service diff --git a/src/pymc_core/protocol/__init__.py b/src/pymc_core/protocol/__init__.py index 62ffd46..43fb47a 100644 --- a/src/pymc_core/protocol/__init__.py +++ b/src/pymc_core/protocol/__init__.py @@ -11,6 +11,7 @@ ADVERT_FLAG_IS_CHAT_NODE, ADVERT_FLAG_IS_REPEATER, ADVERT_FLAG_IS_ROOM_SERVER, + ADVERT_FLAG_IS_SENSOR, CIPHER_BLOCK_SIZE, CIPHER_MAC_SIZE, CONTACT_TYPE_CHAT_NODE, @@ -71,7 +72,13 @@ PacketValidationUtils, RouteTypeUtils, ) -from .transport_keys import calc_transport_code, get_auto_key_for +from .transport_keys import ( + TransportKey, + TransportKeyStore, + calc_transport_code, + derive_auto_key, + get_auto_key_for, +) from .utils import decode_appdata, parse_advert_payload __all__ = [ @@ -136,6 +143,7 @@ "ADVERT_FLAG_IS_CHAT_NODE", "ADVERT_FLAG_IS_REPEATER", "ADVERT_FLAG_IS_ROOM_SERVER", + "ADVERT_FLAG_IS_SENSOR", "ADVERT_FLAG_HAS_LOCATION", "ADVERT_FLAG_HAS_FEATURE1", "ADVERT_FLAG_HAS_FEATURE2", diff --git a/src/pymc_core/protocol/constants.py b/src/pymc_core/protocol/constants.py index 71af37f..9d29f2c 100644 --- a/src/pymc_core/protocol/constants.py +++ b/src/pymc_core/protocol/constants.py @@ -60,10 +60,26 @@ TIMESTAMP_SIZE = 4 # 4 bytes for a timestamp (32-bit unsigned int) # --------------------------------------------------------------------------- +# Node Advert Types (ADV_TYPE_* from firmware) +ADV_TYPE_NONE = 0 +ADV_TYPE_CHAT = 1 +ADV_TYPE_REPEATER = 2 +ADV_TYPE_ROOM = 3 +ADV_TYPE_SENSOR = 4 + +ADV_TYPE_LABELS = { + ADV_TYPE_NONE: "unknown", + ADV_TYPE_CHAT: "chat", + ADV_TYPE_REPEATER: "repeater", + ADV_TYPE_ROOM: "room", + ADV_TYPE_SENSOR: "sensor", +} + # Node Advert Flags (bitfield values) ADVERT_FLAG_IS_CHAT_NODE = 0x01 ADVERT_FLAG_IS_REPEATER = 0x02 ADVERT_FLAG_IS_ROOM_SERVER = 0x04 +ADVERT_FLAG_IS_SENSOR = 0x08 ADVERT_FLAG_HAS_LOCATION = 0x10 ADVERT_FLAG_HAS_FEATURE1 = 0x20 ADVERT_FLAG_HAS_FEATURE2 = 0x40 @@ -79,6 +95,8 @@ def describe_advert_flags(flags: int) -> str: labels.append("is repeater") if flags & 
ADVERT_FLAG_IS_ROOM_SERVER: labels.append("is room server") + if flags & ADVERT_FLAG_IS_SENSOR: + labels.append("is sensor") if flags & ADVERT_FLAG_HAS_LOCATION: labels.append("has location") if flags & ADVERT_FLAG_HAS_FEATURE1: diff --git a/src/pymc_core/protocol/packet.py b/src/pymc_core/protocol/packet.py index 5ce80dd..368d6b3 100644 --- a/src/pymc_core/protocol/packet.py +++ b/src/pymc_core/protocol/packet.py @@ -17,6 +17,9 @@ ) from .packet_utils import PacketDataUtils, PacketHashingUtils, PacketValidationUtils + +DO_NOT_RETRANSMIT_HEADER = 0xFF + """ ╔═══════════════════════════════════════════════════════════════════════════╗ ║ MESH PACKET STRUCTURE OVERVIEW ║ @@ -110,7 +113,6 @@ class Packet: "transport_codes", "_snr", "_rssi", - "_do_not_retransmit", ) def __init__(self): @@ -130,7 +132,6 @@ def __init__(self): self.transport_codes = [0, 0] # Array of two 16-bit transport codes self._snr = 0 self._rssi = 0 - self._do_not_retransmit = False def get_route_type(self) -> int: """ @@ -150,14 +151,21 @@ def get_payload_type(self) -> int: Extract the 4-bit payload type from the packet header. Returns: - int: Payload type value indicating the type of data in the packet: - - 0: Plain text message - - 1: Encrypted message - - 2: ACK packet - - 3: Advertisement - - 4: Login request/response - - 5: Protocol control - - 6-15: Reserved for future use + int: Payload type value (bits 2-5) describing the payload semantics: + - 0x00: `PAYLOAD_TYPE_REQ` (protocol requests) + - 0x01: `PAYLOAD_TYPE_RESPONSE` (login / protocol responses) + - 0x02: `PAYLOAD_TYPE_TXT_MSG` (direct text datagrams) + - 0x03: `PAYLOAD_TYPE_ACK` (delivery acknowledgements) + - 0x04: `PAYLOAD_TYPE_ADVERT` (node advertisements) + - 0x05: `PAYLOAD_TYPE_GRP_TXT` (group/channel text) + - 0x06: `PAYLOAD_TYPE_GRP_DATA` (channel binary data) + - 0x07: `PAYLOAD_TYPE_ANON_REQ` (anonymous requests/login) + - 0x08: `PAYLOAD_TYPE_PATH` (returned path + responses) + - 0x09: `PAYLOAD_TYPE_TRACE` (trace diagnostics) + - 0x0A: `PAYLOAD_TYPE_MULTIPART` (wrapper with inner payload) + - 0x0B: `PAYLOAD_TYPE_CONTROL` (discovery/control plane) + - 0x0F: `PAYLOAD_TYPE_RAW_CUSTOM` (vendor-specific) + Remaining values are reserved by MeshCore. """ return (self.header >> PH_TYPE_SHIFT) & PH_TYPE_MASK @@ -286,8 +294,8 @@ def write_to(self) -> bytes: # Add transport codes if this packet type requires them if self.has_transport_codes(): # Pack two 16-bit transport codes (4 bytes total) in little-endian format - out.extend(self.transport_codes[0].to_bytes(2, 'little')) - out.extend(self.transport_codes[1].to_bytes(2, 'little')) + out.extend((self.transport_codes[0] & 0xFFFF).to_bytes(2, "little")) + out.extend((self.transport_codes[1] & 0xFFFF).to_bytes(2, "little")) out.append(self.path_len) out += self.path @@ -319,8 +327,8 @@ def read_from(self, data: ByteString) -> bool: if self.has_transport_codes(): self._check_bounds(idx, 4, data_len, "missing transport codes") # Unpack two 16-bit transport codes from little-endian format - self.transport_codes[0] = int.from_bytes(data[idx:idx+2], 'little') - self.transport_codes[1] = int.from_bytes(data[idx+2:idx+4], 'little') + self.transport_codes[0] = int.from_bytes(data[idx:idx+2], "little") + self.transport_codes[1] = int.from_bytes(data[idx+2:idx+4], "little") idx += 4 else: self.transport_codes = [0, 0] @@ -471,7 +479,7 @@ def mark_do_not_retransmit(self) -> None: Used by destination nodes after successfully decrypting and processing a message intended for them. 
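        Note: the marker is now stored by overwriting ``header`` with DO_NOT_RETRANSMIT_HEADER (0xFF), so the original header value is not preserved once a packet is marked.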
""" - self._do_not_retransmit = True + self.header = DO_NOT_RETRANSMIT_HEADER def is_marked_do_not_retransmit(self) -> bool: """ @@ -482,4 +490,4 @@ def is_marked_do_not_retransmit(self) -> bool: This indicates the packet has reached its destination or should remain local to the receiving node. """ - return self._do_not_retransmit + return self.header == DO_NOT_RETRANSMIT_HEADER diff --git a/src/pymc_core/protocol/packet_filter.py b/src/pymc_core/protocol/packet_filter.py index 1b91420..d101f86 100644 --- a/src/pymc_core/protocol/packet_filter.py +++ b/src/pymc_core/protocol/packet_filter.py @@ -1,24 +1,39 @@ -""" -Simple packet filter for dispatcher-level routing decisions. +"""MeshCore-aligned packet filter used by the dispatcher. -This handles only the essential routing concerns: -- Duplicate detection -- Packet blacklisting for malformed packets -- Basic packet hash tracking +Matches firmware heuristics by: +- Keeping a bounded pool of recent packet hashes for duplicate detection. +- Tracking a timed blacklist so malformed packets eventually expire. +- Providing a delayed-queue helper so callers can avoid reprocessing floods + that are already waiting for their score-based holdoff window. """ import hashlib import time -from typing import Dict, Set +from collections import OrderedDict +from typing import MutableMapping class PacketFilter: - """Lightweight packet filter for dispatcher routing decisions.""" + """Stateful packet filter mirroring MeshCore's inbound manager heuristics.""" + + def __init__( + self, + window_seconds: int = 45, + *, + blacklist_duration: int = 180, + max_tracked_packets: int = 4096, + max_blacklist_size: int = 512, + max_delayed_packets: int = 512, + ): + self.window_seconds = max(0, window_seconds) + self.blacklist_duration = max(1, blacklist_duration) + self.max_tracked_packets = max(1, max_tracked_packets) + self.max_blacklist_size = max(1, max_blacklist_size) + self.max_delayed_packets = max(1, max_delayed_packets) - def __init__(self, window_seconds: int = 30): - self.window_seconds = window_seconds - self._packet_hashes: Dict[str, float] = {} # packet_hash -> timestamp - self._blacklist: Set[str] = set() # blacklisted packet hashes + self._packet_hashes: MutableMapping[str, float] = OrderedDict() + self._blacklist: MutableMapping[str, float] = OrderedDict() + self._delayed_packets: MutableMapping[str, float] = OrderedDict() def generate_hash(self, data: bytes) -> str: """Generate a hash for packet data.""" @@ -26,43 +41,144 @@ def generate_hash(self, data: bytes) -> str: def is_duplicate(self, packet_hash: str) -> bool: """Check if we've seen this packet recently.""" + if self.window_seconds == 0: + # Deduplication disabled - always treat as new packet. + self._packet_hashes.pop(packet_hash, None) + return False + now = time.time() - if packet_hash in self._packet_hashes: - age = now - self._packet_hashes[packet_hash] - if age < self.window_seconds: - return True - return False + timestamp = self._packet_hashes.get(packet_hash) + if timestamp is None: + return False + + if (now - timestamp) >= self.window_seconds: + # Entry aged out – drop from pool and treat as new. + self._packet_hashes.pop(packet_hash, None) + return False + + return True def track_packet(self, packet_hash: str) -> None: """Track a packet hash with current timestamp.""" - self._packet_hashes[packet_hash] = time.time() + now = time.time() + self._packet_hashes[packet_hash] = now + # Maintain insertion order so we can evict the oldest hashes first. 
+ if isinstance(self._packet_hashes, OrderedDict): + self._packet_hashes.move_to_end(packet_hash) + self._evict_old_packets(now) def blacklist(self, packet_hash: str) -> None: """Add a packet hash to the blacklist.""" - self._blacklist.add(packet_hash) + expiry = time.time() + self.blacklist_duration + self._blacklist[packet_hash] = expiry + if isinstance(self._blacklist, OrderedDict): + self._blacklist.move_to_end(packet_hash) + self._evict_old_blacklist_entries(time.time()) def is_blacklisted(self, packet_hash: str) -> bool: """Check if a packet hash is blacklisted.""" - return packet_hash in self._blacklist + expiry = self._blacklist.get(packet_hash) + if expiry is None: + return False + if time.time() >= expiry: + self._blacklist.pop(packet_hash, None) + return False + return True + + def schedule_delay(self, packet_hash: str, delay_seconds: float) -> None: + """Register that a packet is being delayed before processing.""" + + expiry = time.time() + max(0.0, delay_seconds) + self._delayed_packets[packet_hash] = expiry + if isinstance(self._delayed_packets, OrderedDict): + self._delayed_packets.move_to_end(packet_hash) + self._evict_old_delays(time.time()) + + def is_delay_active(self, packet_hash: str) -> bool: + """Return True if a packet is currently waiting in the delayed queue.""" + + expiry = self._delayed_packets.get(packet_hash) + if expiry is None: + return False + if time.time() >= expiry: + self._delayed_packets.pop(packet_hash, None) + return False + return True def cleanup_old_hashes(self) -> None: """Clean up old packet hashes beyond the deduplication window.""" current_time = time.time() - old_hashes = [ - h for h, ts in self._packet_hashes.items() if current_time - ts > self.window_seconds - ] - for h in old_hashes: - del self._packet_hashes[h] + self._evict_old_packets(current_time) + self._evict_old_blacklist_entries(current_time) + self._evict_old_delays(current_time) def get_stats(self) -> dict: """Get basic filter statistics.""" return { "tracked_packets": len(self._packet_hashes), "blacklisted_packets": len(self._blacklist), + "delayed_packets": len(self._delayed_packets), "window_seconds": self.window_seconds, + "blacklist_duration": self.blacklist_duration, } def clear(self) -> None: """Clear all tracked data.""" self._packet_hashes.clear() self._blacklist.clear() + self._delayed_packets.clear() + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _evict_old_packets(self, now: float) -> None: + if not self._packet_hashes: + return + cutoff = now - self.window_seconds if self.window_seconds else None + keys_to_remove = [] + for packet_hash, ts in self._packet_hashes.items(): + if cutoff is not None and ts < cutoff: + keys_to_remove.append(packet_hash) + elif len(self._packet_hashes) - len(keys_to_remove) > self.max_tracked_packets: + keys_to_remove.append(packet_hash) + else: + # OrderedDict is chronological; break when remaining entries are recent. 
+ if cutoff is not None and ts >= cutoff: + break + for key in keys_to_remove: + self._packet_hashes.pop(key, None) + + # Trim if still above cap (window might be 0 -> no cutoff) + while len(self._packet_hashes) > self.max_tracked_packets: + self._packet_hashes.popitem(last=False) + + def _evict_old_blacklist_entries(self, now: float) -> None: + if not self._blacklist: + return + keys_to_remove = [] + for packet_hash, expiry in self._blacklist.items(): + if expiry <= now or len(self._blacklist) - len(keys_to_remove) > self.max_blacklist_size: + keys_to_remove.append(packet_hash) + else: + break + for key in keys_to_remove: + self._blacklist.pop(key, None) + + while len(self._blacklist) > self.max_blacklist_size: + self._blacklist.popitem(last=False) + + def _evict_old_delays(self, now: float) -> None: + if not self._delayed_packets: + return + keys_to_remove = [] + for packet_hash, expiry in self._delayed_packets.items(): + if expiry <= now or len(self._delayed_packets) - len(keys_to_remove) > self.max_delayed_packets: + keys_to_remove.append(packet_hash) + else: + break + for key in keys_to_remove: + self._delayed_packets.pop(key, None) + + while len(self._delayed_packets) > self.max_delayed_packets: + self._delayed_packets.popitem(last=False) diff --git a/src/pymc_core/protocol/packet_utils.py b/src/pymc_core/protocol/packet_utils.py index f74b1a9..daf487a 100644 --- a/src/pymc_core/protocol/packet_utils.py +++ b/src/pymc_core/protocol/packet_utils.py @@ -4,8 +4,9 @@ """ import hashlib +import math import struct -from typing import Any, List, Union +from typing import Any, List, Optional, Union from .constants import ( MAX_HASH_SIZE, @@ -214,7 +215,8 @@ def calculate_packet_hash(payload_type: int, path_len: int, payload: bytes) -> b sha = hashlib.sha256() sha.update(bytes([payload_type])) if payload_type == PAYLOAD_TYPE_TRACE: - sha.update(bytes([path_len])) + # MeshCore feeds the two-byte path_len field as-is for TRACE hashes + sha.update(struct.pack(" int class PacketTimingUtils: """Utilities for packet transmission timing calculations.""" + DEFAULT_BUDGET_FACTOR = 2.0 # Matches Dispatcher::getAirtimeBudgetFactor + CAD_FAIL_RETRY_DELAY_MS = 200 # Matches Dispatcher::getCADFailRetryDelay + CAD_FAIL_MAX_DURATION_MS = 4000 # Matches Dispatcher::getCADFailMaxDuration + MAX_RX_DELAY_MS = 32_000 # Dispatcher::MAX_RX_DELAY_MILLIS + DEFAULT_SPREADING_FACTOR = 10 + SNR_THRESHOLDS = { + 7: -7.5, + 8: -10.0, + 9: -12.5, + 10: -15.0, + 11: -17.5, + 12: -20.0, + } + @staticmethod def estimate_airtime_ms(packet_length_bytes: int, radio_config: dict = None) -> float: """ @@ -288,38 +304,57 @@ def estimate_airtime_ms(packet_length_bytes: int, radio_config: dict = None) -> if radio_config is None: radio_config = { "spreading_factor": 10, - "bandwidth": 250000, # 250kHz - "coding_rate": 5, + "bandwidth": 250_000, # Hz + "coding_rate": 5, # LoRa denominator (4/5 defaults to 5) "preamble_length": 8, + "explicit_header": True, + "crc_enabled": True, } if "measured_airtime_ms" in radio_config: - return radio_config["measured_airtime_ms"] - - sf = radio_config.get("spreading_factor", 10) - bw = radio_config.get("bandwidth", 250000) # Hz or kHz - convert to Hz if needed - cr = radio_config.get("coding_rate", 5) - preamble = radio_config.get("preamble_length", 8) - - # Convert bandwidth to Hz if it's in kHz (values < 1000 are assumed to be kHz) - if bw < 1000: - bw = bw * 1000 # Convert kHz to Hz - - symbol_time = (2**sf) / bw # seconds per symbol - - # Preamble time - preamble_time = preamble * symbol_time - 
- # Payload symbols (simplified) - payload_symbols = 8 + max(0, (packet_length_bytes * 8 - 4 * sf + 28) // (4 * (sf - 2))) * ( - cr + 4 + return float(radio_config["measured_airtime_ms"]) + + sf = int(radio_config.get("spreading_factor", 10)) + bw_hz = float(radio_config.get("bandwidth", 250_000)) + coding_rate = int(radio_config.get("coding_rate", 5)) + preamble = float(radio_config.get("preamble_length", 8)) + explicit_header = bool(radio_config.get("explicit_header", True)) + crc_enabled = bool(radio_config.get("crc_enabled", True)) + + if bw_hz < 1000: + bw_hz *= 1000.0 # accept kHz inputs + + # Low data rate optimization mirrors RadioLib: enable for SF >= 11 at <=125kHz unless overridden. + ldro = radio_config.get("low_data_rate_optimize") + if ldro is None: + ldro = sf >= 11 and bw_hz <= 125_000 + de = 1 if ldro else 0 + + # Convert coding rate to the 1..4 representation used by Semtech's formula. + if coding_rate > 4: + cr_setting = coding_rate - 4 + else: + cr_setting = coding_rate + cr_setting = max(1, min(4, cr_setting)) + + h = 0 if explicit_header else 1 + crc_term = 16 if crc_enabled else 0 + denom = 4 * (sf - 2 * de) + if denom <= 0: + raise ValueError("Invalid radio configuration: spreading_factor too low for LDRO setting") + + payload_numerator = (8 * packet_length_bytes) - (4 * sf) + 28 + crc_term - (20 * h) + payload_symbols = 8 + max( + math.ceil(payload_numerator / denom) * (cr_setting + 4), + 0, ) - payload_time = payload_symbols * symbol_time - total_time_ms = (preamble_time + payload_time) * 1000 + symbol_time_sec = (2**sf) / bw_hz + preamble_time_sec = (preamble + 4.25) * symbol_time_sec + payload_time_sec = payload_symbols * symbol_time_sec - # Add some overhead for processing and turnaround - return max(total_time_ms, 50.0) # Minimum 50ms + total_time_ms = (preamble_time_sec + payload_time_sec) * 1000.0 + return max(total_time_ms, 1.0) @staticmethod def calc_flood_timeout_ms(packet_airtime_ms: float) -> float: @@ -359,3 +394,64 @@ def calc_direct_timeout_ms(packet_airtime_ms: float, path_len: int) -> float: (packet_airtime_ms * DIRECT_SEND_PERHOP_FACTOR + DIRECT_SEND_PERHOP_EXTRA_MILLIS) * (path_len + 1) ) + + @staticmethod + def calc_rx_delay_ms( + score: float, + packet_airtime_ms: float, + max_delay_ms: Optional[int] = None, + ) -> int: + """Replicate Dispatcher::calcRxDelay using the MeshCore score heuristic.""" + + delay = (math.pow(10.0, 0.85 - score) - 1.0) * packet_airtime_ms + if delay <= 0: + return 0 + limit = PacketTimingUtils.MAX_RX_DELAY_MS if max_delay_ms is None else max_delay_ms + return min(int(round(delay)), limit) + + @staticmethod + def calc_airtime_budget_delay_ms( + packet_airtime_ms: float, + budget_factor: Optional[float] = None, + ) -> float: + """Return the radio silence interval enforced after a send (Dispatcher::getAirtimeBudgetFactor).""" + + factor = PacketTimingUtils.DEFAULT_BUDGET_FACTOR if budget_factor is None else budget_factor + return max(packet_airtime_ms * factor, 0.0) + + @staticmethod + def get_cad_fail_retry_delay_ms() -> int: + """Expose Dispatcher::getCADFailRetryDelay for CAD back-off modeling.""" + + return PacketTimingUtils.CAD_FAIL_RETRY_DELAY_MS + + @staticmethod + def get_cad_fail_max_duration_ms() -> int: + """Expose Dispatcher::getCADFailMaxDuration for CAD timeout modeling.""" + + return PacketTimingUtils.CAD_FAIL_MAX_DURATION_MS + + @staticmethod + def calculate_packet_score( + snr_db: float, + packet_len_bytes: int, + spreading_factor: Optional[int] = None, + ) -> float: + """Approximate 
mesh::Radio::packetScore from MeshCore RadioLib wrappers.""" + + sf = spreading_factor or PacketTimingUtils.DEFAULT_SPREADING_FACTOR + if sf < 7: + return 0.0 + + threshold = PacketTimingUtils.SNR_THRESHOLDS.get(sf, PacketTimingUtils.SNR_THRESHOLDS[10]) + if snr_db < threshold: + return 0.0 + + success_rate = (snr_db - threshold) / 10.0 + collision_penalty = 1.0 - min(max(packet_len_bytes / 256.0, 0.0), 1.0) + score = success_rate * collision_penalty + if score < 0.0: + return 0.0 + if score > 1.0: + return 1.0 + return score diff --git a/src/pymc_core/protocol/transport_keys.py b/src/pymc_core/protocol/transport_keys.py index 9518f3e..651d442 100644 --- a/src/pymc_core/protocol/transport_keys.py +++ b/src/pymc_core/protocol/transport_keys.py @@ -1,70 +1,138 @@ -""" -Transport Key utilities for mesh packet authentication. +"""Transport key helpers that mirror MeshCore's TransportKeyStore.""" -Simple implementation matching the C++ MeshCore transport key functionality: -- Generate 128-bit key from region name (SHA256 of ASCII name) -- Calculate transport codes using HMAC-SHA256 -""" +from __future__ import annotations +from collections import deque import struct +from typing import Iterable, List, Sequence + from .crypto import CryptoUtils +MAX_TKS_ENTRIES = 16 + + +def derive_auto_key(name: str) -> bytes: + """Derive a 128-bit key from a region hashtag (MeshCore parity).""" -def get_auto_key_for(name: str) -> bytes: - """ - Generate 128-bit transport key from region name. - - Matches C++ implementation: - void TransportKeyStore::getAutoKeyFor(uint16_t id, const char* name, TransportKey& dest) - - Args: - name: Region name including '#' (e.g., "#usa") - - Returns: - bytes: 16-byte transport key - """ if not name: raise ValueError("Region name cannot be empty") - if not name.startswith('#'): + if not name.startswith("#"): raise ValueError("Region name must start with '#'") if len(name) > 64: raise ValueError("Region name is too long (max 64 characters)") - key_hash = CryptoUtils.sha256(name.encode('ascii')) - return key_hash[:16] # First 16 bytes (128 bits) + key_hash = CryptoUtils.sha256(name.encode("ascii")) + return key_hash[:16] + + +def get_auto_key_for(name: str) -> bytes: + """Backward-compatible alias for :func:`derive_auto_key`.""" + + return derive_auto_key(name) def calc_transport_code(key: bytes, packet) -> int: - """ - Calculate transport code for a packet. 
- - Matches C++ implementation: - uint16_t TransportKey::calcTransportCode(const mesh::Packet* packet) const - - Args: - key: 16-byte transport key - packet: Packet with payload_type and payload - - Returns: - int: 16-bit transport code - """ + """Calculate the transport code (HMAC-SHA256, reserve 0x0000/0xFFFF).""" + if len(key) != 16: raise ValueError(f"Transport key must be 16 bytes, got {len(key)}") + payload_type = packet.get_payload_type() payload_data = packet.get_payload() - - # HMAC input: payload_type (1 byte) + payload hmac_data = bytes([payload_type]) + payload_data - - # Calculate HMAC-SHA256 hmac_digest = CryptoUtils._hmac_sha256(key, hmac_data) - - # Extract first 2 bytes as little-endian uint16 (matches Arduino platform endianness) - code = struct.unpack(' None: + self.key = bytes(16) + if key is not None: + self.set_key(key) + + def set_key(self, key: bytes | bytearray | memoryview) -> None: + data = bytes(key) + if len(data) != 16: + raise ValueError(f"Transport key must be 16 bytes, got {len(data)}") + self.key = data + + def is_null(self) -> bool: + return all(b == 0 for b in self.key) + + def calc_transport_code(self, packet) -> int: + return calc_transport_code(self.key, packet) + + def copy(self) -> "TransportKey": + return TransportKey(self.key) + + def __repr__(self) -> str: # pragma: no cover - debug helper + return f"TransportKey({self.key.hex()})" + + +class TransportKeyStore: + """In-memory cache that mirrors MeshCore's TransportKeyStore behavior.""" + + def __init__(self, max_entries: int = MAX_TKS_ENTRIES) -> None: + if max_entries <= 0: + raise ValueError("max_entries must be positive") + self.max_entries = max_entries + self._cache: deque[tuple[int, TransportKey]] = deque(maxlen=max_entries) + + # -- cache primitives ------------------------------------------------- + def _put_cache(self, region_id: int, key: TransportKey) -> None: + self._cache.append((region_id, key.copy())) + + def invalidate_cache(self) -> None: + self._cache.clear() + + def _cached_keys_for(self, region_id: int) -> List[TransportKey]: + return [entry.copy() for rid, entry in self._cache if rid == region_id] + + # -- public API ------------------------------------------------------- + def get_auto_key_for(self, region_id: int, name: str) -> TransportKey: + cached = self._cached_keys_for(region_id) + if cached: + return cached[0] + + derived = TransportKey(derive_auto_key(name)) + self._put_cache(region_id, derived) + return derived.copy() + + def load_keys_for(self, region_id: int, max_num: int | None = None) -> List[TransportKey]: + keys = self._cached_keys_for(region_id) + if max_num is not None: + return keys[:max_num] + return keys + + def save_keys_for(self, region_id: int, keys: Sequence[bytes | TransportKey]) -> bool: + if not keys: + return False + for key in keys: + tk = key if isinstance(key, TransportKey) else TransportKey(key) + self._put_cache(region_id, tk) + return True + + def remove_keys(self, region_id: int) -> bool: + original = len(self._cache) + self._cache = deque([(rid, key) for rid, key in self._cache if rid != region_id], maxlen=self.max_entries) + return len(self._cache) != original + + def clear(self) -> bool: + changed = len(self._cache) > 0 + self.invalidate_cache() + return changed + + def cache_snapshot(self) -> List[tuple[int, TransportKey]]: + """Return a copy of the cache for diagnostics/testing.""" + + return [(rid, key.copy()) for rid, key in self._cache] diff --git a/src/pymc_core/protocol/utils.py b/src/pymc_core/protocol/utils.py index 
bd3f045..ed73b73 100644 --- a/src/pymc_core/protocol/utils.py +++ b/src/pymc_core/protocol/utils.py @@ -2,6 +2,8 @@ Centralized protocol utility functions and lookup tables for mesh network. """ +import struct + from .constants import ( ADVERT_FLAG_HAS_FEATURE1, ADVERT_FLAG_HAS_FEATURE2, @@ -10,6 +12,7 @@ ADVERT_FLAG_IS_CHAT_NODE, ADVERT_FLAG_IS_REPEATER, ADVERT_FLAG_IS_ROOM_SERVER, + ADVERT_FLAG_IS_SENSOR, PUB_KEY_SIZE, SIGNATURE_SIZE, TIMESTAMP_SIZE, @@ -17,13 +20,14 @@ # Lookup tables APPDATA_FLAGS = { - 0x01: "is_chat_node", - 0x02: "is_repeater", - 0x04: "is_room_server", - 0x10: "has_location", - 0x20: "has_feature_1", - 0x40: "has_feature_2", - 0x80: "has_name", + ADVERT_FLAG_IS_CHAT_NODE: "is_chat_node", + ADVERT_FLAG_IS_REPEATER: "is_repeater", + ADVERT_FLAG_IS_ROOM_SERVER: "is_room_server", + ADVERT_FLAG_IS_SENSOR: "is_sensor", + ADVERT_FLAG_HAS_LOCATION: "has_location", + ADVERT_FLAG_HAS_FEATURE1: "has_feature_1", + ADVERT_FLAG_HAS_FEATURE2: "has_feature_2", + ADVERT_FLAG_HAS_NAME: "has_name", } REQUEST_TYPES = {0x01: "get_status", 0x02: "keepalive", 0x03: "get_telemetry_data"} @@ -96,48 +100,48 @@ def parse_advert_payload(payload: bytes): def decode_appdata(appdata: bytes) -> dict: - result = {} - offset = 0 if len(appdata) < 1: raise ValueError("Appdata too short to contain flags") + + result: dict[str, object] = {} + offset = 0 flags = appdata[offset] result["flags"] = flags offset += 1 - # Parse conditional fields based on flags (following the same logic as packet_analyzer) - if flags & 0x10: # has_location - if len(appdata) >= offset + 8: - import struct - - lat_raw = struct.unpack(" bytes: + nonlocal offset + end = offset + length + if end > len(appdata): + raise ValueError( + f"Appdata indicates {field}, but only {len(appdata) - offset} bytes remain" + ) + chunk = appdata[offset:end] + offset = end + return chunk - if flags & 0x20: # has_feature_1 - if len(appdata) >= offset + 2: - import struct - - result["feature_1"] = struct.unpack("= offset + 2: - import struct + if flags & ADVERT_FLAG_HAS_FEATURE1: + (feature_one,) = struct.unpack(" offset: + if flags & ADVERT_FLAG_HAS_NAME: + name_bytes = appdata[offset:] + if name_bytes: try: - name = appdata[offset:].decode("utf-8").rstrip("\x00").strip() - if name: # Only add if non-empty + name = name_bytes.decode("utf-8").rstrip("\x00").strip() + if name: result["node_name"] = name except UnicodeDecodeError: - # If UTF-8 decoding fails, store as hex for debugging - result["raw_name_bytes"] = appdata[offset:].hex() + result["raw_name_bytes"] = name_bytes.hex() result["name_decode_error"] = True return result diff --git a/tests/test_contact_book.py b/tests/test_contact_book.py new file mode 100644 index 0000000..d8342bd --- /dev/null +++ b/tests/test_contact_book.py @@ -0,0 +1,83 @@ +import pytest + +from pymc_core.node.contact_book import ContactBook, ContactBookPreferences +from pymc_core.protocol import ( + CONTACT_TYPE_CHAT_NODE, + CONTACT_TYPE_HYBRID, + CONTACT_TYPE_ROOM_SERVER, +) + + +def _pubkey(prefix: int) -> str: + return (bytes([prefix]) + b"\xAA" * 31).hex() + + +def test_room_server_gets_cli_and_telemetry_by_default(): + book = ContactBook() + record = book.add_contact({"public_key": _pubkey(0x21), "type": CONTACT_TYPE_ROOM_SERVER}) + + assert book.can_execute_cli(record) + assert book.can_receive_telemetry(record) + assert book.can_use_bridge(record) is False + + +def test_allow_read_only_enables_cli_for_chat_nodes(): + prefs = ContactBookPreferences(allow_read_only=True) + book = ContactBook(prefs=prefs) + record 
= book.add_contact({"public_key": _pubkey(0x11), "type": CONTACT_TYPE_CHAT_NODE}) + + assert book.can_execute_cli(record) + assert book.can_receive_telemetry(record) is False + + +def test_bridge_permission_tracks_preference(): + prefs = ContactBookPreferences(bridge_enabled=True) + book = ContactBook(prefs=prefs) + record = book.add_contact({"public_key": _pubkey(0x33), "type": CONTACT_TYPE_HYBRID}) + + assert book.can_use_bridge(record) + + book.update_preferences(bridge_enabled=False) + assert book.can_use_bridge(record) is False + + +def test_set_permissions_overrides_defaults(): + book = ContactBook() + record = book.add_contact({"public_key": _pubkey(0x44), "type": CONTACT_TYPE_CHAT_NODE}) + + assert book.can_execute_cli(record) is False + + book.set_permissions(record.public_key, allow_cli=True, allow_telemetry=True) + + assert book.can_execute_cli(record) + assert book.can_receive_telemetry(record) + + +def test_get_by_hash_and_remove(): + book = ContactBook() + record = book.add_contact({"public_key": _pubkey(0x55), "name": "peer"}) + + retrieved = book.get_by_hash(record.src_hash()) + assert retrieved is record + + removed = book.remove_contact(record.public_key) + assert removed is True + assert book.list_contacts() == [] + + +def test_add_contact_dict_preserves_fields(): + book = ContactBook() + data = { + "public_key": _pubkey(0x66), + "name": "sensor", + "type": CONTACT_TYPE_CHAT_NODE, + "longitude": 1.23, + "latitude": 4.56, + "flags": 0xAA, + } + record = book.add_contact(data) + + assert record.name == "sensor" + assert record.longitude == pytest.approx(1.23) + assert record.latitude == pytest.approx(4.56) + assert record.flags == 0xAA diff --git a/tests/test_dispatcher.py b/tests/test_dispatcher.py index a805815..363148d 100644 --- a/tests/test_dispatcher.py +++ b/tests/test_dispatcher.py @@ -4,7 +4,7 @@ import pytest from pymc_core.node.dispatcher import Dispatcher, DispatcherState -from pymc_core.protocol import Packet +from pymc_core.protocol import Packet, PacketTimingUtils from pymc_core.protocol.constants import PAYLOAD_TYPE_ACK, PAYLOAD_TYPE_ADVERT, PAYLOAD_TYPE_TXT_MSG from pymc_core.protocol.packet_filter import PacketFilter @@ -230,6 +230,89 @@ async def test_process_received_packet_duplicate(self, dispatcher): assert mock_handler.call_count == 1 +class TestDispatcherOwnPacketDetection: + @pytest.mark.asyncio + async def test_is_own_packet_detects_recent_crc(self, dispatcher): + payload = b"loopback" + packet_data = create_test_packet(PAYLOAD_TYPE_TXT_MSG, payload) + pkt = Packet() + pkt.read_from(packet_data) + + dispatcher._record_outbound_packet_crc(pkt.get_crc()) + + assert dispatcher._is_own_packet(pkt) is True + + @pytest.mark.asyncio + async def test_is_own_packet_requires_tracked_crc(self, dispatcher): + payload = bytearray(b"hash_collision") + if len(payload) < 2: + payload.extend(b"\x00" * (2 - len(payload))) + packet_data = create_test_packet(PAYLOAD_TYPE_TXT_MSG, bytes(payload)) + pkt = Packet() + pkt.read_from(packet_data) + + # Force src hash byte to match local identity to simulate collision + pkt.payload[1] = dispatcher.local_identity.get_public_key()[0] + + assert dispatcher._is_own_packet(pkt) is False + + @pytest.mark.asyncio + async def test_recent_crc_tracking_expires(self, dispatcher): + payload = b"expire_me" + packet_data = create_test_packet(PAYLOAD_TYPE_TXT_MSG, payload) + pkt = Packet() + pkt.read_from(packet_data) + + dispatcher._own_packet_cache_ttl = 0.05 + dispatcher._record_outbound_packet_crc(pkt.get_crc()) + + await 
asyncio.sleep(0.1) + + assert dispatcher._is_own_packet(pkt) is False + + +class TestDispatcherFloodDelays: + @pytest.mark.asyncio + async def test_flood_packet_delay_applied(self, dispatcher, monkeypatch): + payload = b"delay_me" + packet_data = create_test_packet(PAYLOAD_TYPE_TXT_MSG, payload) + + sleep_calls: list[float] = [] + + async def fake_sleep(delay: float): + sleep_calls.append(delay) + + monkeypatch.setattr("asyncio.sleep", fake_sleep) + + dispatcher._calculate_flood_delay_ms = Mock(return_value=(200, 0.6, 300.0)) + dispatcher._dispatch = AsyncMock() + + await dispatcher._process_received_packet(packet_data) + + assert sleep_calls == [0.2] + dispatcher._dispatch.assert_awaited_once() + + @pytest.mark.asyncio + async def test_flood_packet_immediate_when_delay_low(self, dispatcher, monkeypatch): + payload = b"fast" + packet_data = create_test_packet(PAYLOAD_TYPE_TXT_MSG, payload) + + sleep_calls: list[float] = [] + + async def fake_sleep(delay: float): + sleep_calls.append(delay) + + monkeypatch.setattr("asyncio.sleep", fake_sleep) + + dispatcher._calculate_flood_delay_ms = Mock(return_value=(10, 0.1, 150.0)) + dispatcher._dispatch = AsyncMock() + + await dispatcher._process_received_packet(packet_data) + + assert sleep_calls == [] + dispatcher._dispatch.assert_awaited_once() + + class TestDispatcherACKSystem: """Test ACK system functionality.""" @@ -400,6 +483,70 @@ async def test_send_packet_with_ack_waiting(self, dispatcher): assert result is True + @pytest.mark.asyncio + async def test_send_packet_waits_for_airtime_budget(self, dispatcher, monkeypatch): + packet = Packet() + packet.header = (0 << 6) | (PAYLOAD_TYPE_TXT_MSG << 2) + packet.payload = bytearray(b"budget") + packet.payload_len = len(packet.payload) + packet.path_len = 0 + + dispatcher.radio.send = AsyncMock(return_value=True) + + sleep_calls: list[float] = [] + + async def fake_sleep(delay: float): + sleep_calls.append(delay) + + dispatcher._next_tx_allowed_at = asyncio.get_event_loop().time() + 0.1 + + monkeypatch.setattr("asyncio.sleep", fake_sleep) + monkeypatch.setattr( + "pymc_core.node.dispatcher.PacketTimingUtils.estimate_airtime_ms", + Mock(return_value=120.0), + ) + monkeypatch.setattr( + "pymc_core.node.dispatcher.PacketTimingUtils.calc_airtime_budget_delay_ms", + Mock(return_value=400.0), + ) + + await dispatcher.send_packet(packet, wait_for_ack=False) + + assert sleep_calls and pytest.approx(sleep_calls[0], rel=1e-3) == 0.1 + assert dispatcher._next_tx_allowed_at >= asyncio.get_event_loop().time() + + @pytest.mark.asyncio + async def test_send_packet_retries_cad_until_clear(self, dispatcher, monkeypatch): + packet = Packet() + packet.header = (0 << 6) | (PAYLOAD_TYPE_TXT_MSG << 2) + packet.payload = bytearray(b"cad") + packet.payload_len = len(packet.payload) + packet.path_len = 0 + + dispatcher.radio.send = AsyncMock(return_value=True) + dispatcher.radio.perform_cad = AsyncMock(side_effect=[True, True, False]) + + sleep_calls: list[float] = [] + + async def fake_sleep(delay: float): + sleep_calls.append(delay) + + monkeypatch.setattr("asyncio.sleep", fake_sleep) + monkeypatch.setattr( + "pymc_core.node.dispatcher.PacketTimingUtils.estimate_airtime_ms", + Mock(return_value=80.0), + ) + monkeypatch.setattr( + "pymc_core.node.dispatcher.PacketTimingUtils.calc_airtime_budget_delay_ms", + Mock(return_value=200.0), + ) + + await dispatcher.send_packet(packet, wait_for_ack=False) + + assert dispatcher.radio.perform_cad.await_count == 3 + expected_delay = PacketTimingUtils.get_cad_fail_retry_delay_ms() / 
1000.0 + assert sleep_calls == [expected_delay, expected_delay] + def test_own_packet_detection(self, dispatcher): """Test detection of own packets.""" # Create packet with our own address as source diff --git a/tests/test_handlers.py b/tests/test_handlers.py index eb9a91a..3e97545 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -1,12 +1,15 @@ +import struct from unittest.mock import AsyncMock, MagicMock import pytest # from pymc_core.node.events import MeshEvents # Not currently used +from pymc_core.node.contact_book import ContactBook from pymc_core.node.handlers import ( AckHandler, AdvertHandler, BaseHandler, + ControlHandler, GroupTextHandler, LoginResponseHandler, PathHandler, @@ -16,14 +19,19 @@ ) from pymc_core.protocol import LocalIdentity, Packet, PacketBuilder from pymc_core.protocol.constants import ( + ADV_TYPE_REPEATER, + ADV_TYPE_SENSOR, PAYLOAD_TYPE_ACK, PAYLOAD_TYPE_ADVERT, PAYLOAD_TYPE_ANON_REQ, + PAYLOAD_TYPE_CONTROL, PAYLOAD_TYPE_GRP_TXT, + PAYLOAD_TYPE_MULTIPART, PAYLOAD_TYPE_PATH, PAYLOAD_TYPE_RESPONSE, PAYLOAD_TYPE_TRACE, PAYLOAD_TYPE_TXT_MSG, + PH_TYPE_SHIFT, PUB_KEY_SIZE, SIGNATURE_SIZE, TIMESTAMP_SIZE, @@ -127,6 +135,31 @@ async def test_call_discrete_ack(self): callback.assert_called_once_with(0x12345678) + @pytest.mark.asyncio + async def test_process_multipart_ack_valid(self): + """Test decoding of multipart ACK payloads.""" + packet = Packet() + packet.payload = bytearray(b"\x13\x78\x56\x34\x12") + packet.payload_len = len(packet.payload) + + crc = await self.handler.process_multipart_ack(packet) + assert crc == 0x12345678 + + @pytest.mark.asyncio + async def test_call_multipart_ack(self): + """Ensure multipart ACK packets trigger dispatcher callback.""" + packet = Packet() + packet.payload = bytearray(b"\x13\x78\x56\x34\x12") + packet.payload_len = len(packet.payload) + packet.header = PAYLOAD_TYPE_MULTIPART << PH_TYPE_SHIFT + + callback = MagicMock() + self.handler.set_ack_received_callback(callback) + + await self.handler(packet) + + callback.assert_called_once_with(0x12345678) + # Text Message Handler Tests class TestTextMessageHandler: @@ -173,6 +206,90 @@ async def test_call_with_short_payload(self): # Should return early without processing self.log_fn.assert_called() +class TestTextHandlerACL: + @pytest.mark.asyncio + async def test_cli_command_blocked_without_permission(self, monkeypatch): + local_identity = LocalIdentity() + contacts = ContactBook() + peer = LocalIdentity() + record = contacts.add_contact( + { + "public_key": peer.get_public_key().hex(), + "name": "peer", + "type": 1, + } + ) + contacts.set_permissions(record.public_key, allow_cli=False) + + log_fn = MagicMock() + send_packet = AsyncMock() + handler = TextMessageHandler(local_identity, contacts, log_fn, send_packet, None) + + dest_hash = local_identity.get_public_key()[0] + src_hash = record.src_hash() + packet = Packet() + packet.payload = bytearray([dest_hash, src_hash]) + bytearray(b"\x00" * 16) + packet.payload_len = len(packet.payload) + packet.header = 0 + + plaintext = (1234).to_bytes(4, "little") + bytes([0x01]) + b"reboot" + monkeypatch.setattr( + "pymc_core.node.handlers.text.CryptoUtils.mac_then_decrypt", + lambda *args, **kwargs: plaintext, + ) + + await handler(packet) + + send_packet.assert_not_called() + log_fn.assert_called() + + +class TestProtocolResponseACL: + @pytest.mark.asyncio + async def test_telemetry_blocked_when_not_allowed(self, monkeypatch): + contacts = ContactBook() + local_identity = LocalIdentity() + peer = LocalIdentity() + record 
= contacts.add_contact( + { + "public_key": peer.get_public_key().hex(), + "name": "peer", + "type": 1, + } + ) + contacts.set_permissions(record.public_key, allow_telemetry=False) + + log_fn = MagicMock() + handler = ProtocolResponseHandler(log_fn, local_identity, contacts) + + captured = [] + + def callback(success, text, parsed): + captured.append((success, text, parsed)) + + src_hash = record.src_hash() + handler.set_response_callback(src_hash, callback) + + packet = Packet() + dest_hash = local_identity.get_public_key()[0] + packet.payload = bytearray([dest_hash, src_hash]) + bytearray(b"\x00" * 8) + packet.payload_len = len(packet.payload) + packet.header = (PAYLOAD_TYPE_PATH << 2) + + plaintext = (4321).to_bytes(4, "little") + b"\x01\x02\x03" + monkeypatch.setattr( + "pymc_core.node.handlers.protocol_response.CryptoUtils.mac_then_decrypt", + lambda *args, **kwargs: plaintext, + ) + + await handler(packet) + + assert captured, "Expected callback to be invoked" + success, text, parsed = captured[0] + assert success is False + assert text == "Telemetry blocked by ACL" + assert parsed == {"type": "telemetry", "acl_blocked": True} + # Advert Handler Tests class TestAdvertHandler: @@ -243,9 +360,30 @@ async def test_advert_handler_ignores_self_advert(self): class TestPathHandler: def setup_method(self): self.log_fn = MagicMock() + self.local_identity = LocalIdentity() + self.remote_identity = LocalIdentity() + self.contacts = ContactBook() + self.contact = self.contacts.add_contact( + { + "public_key": self.remote_identity.get_public_key().hex(), + "name": "peer", + "type": 1, + } + ) + self.contacts.set_permissions(self.contact.public_key, allow_telemetry=True) self.ack_handler = AckHandler(self.log_fn) - self.protocol_response_handler = MagicMock() - self.handler = PathHandler(self.log_fn, self.ack_handler, self.protocol_response_handler) + self.received_crcs: list[int] = [] + self.ack_handler.set_ack_received_callback(lambda crc: self.received_crcs.append(crc)) + self.protocol_response_handler = ProtocolResponseHandler( + self.log_fn, self.local_identity, self.contacts + ) + self.handler = PathHandler( + self.log_fn, + local_identity=self.local_identity, + contact_book=self.contacts, + ack_handler=self.ack_handler, + protocol_response_handler=self.protocol_response_handler, + ) def test_payload_type(self): """Test path handler payload type.""" @@ -256,6 +394,56 @@ def test_path_handler_initialization(self): assert self.handler._log == self.log_fn assert self.handler._ack_handler == self.ack_handler assert self.handler._protocol_response_handler == self.protocol_response_handler + assert self.handler._contact_book == self.contacts + assert self.handler._local_identity == self.local_identity + + def _shared_secret(self) -> bytes: + return self.remote_identity.calc_shared_secret(self.local_identity.get_private_key()) + + def _build_path_packet(self, extra_type: int, extra_payload: bytes, path: list[int]): + return PacketBuilder.create_path_return( + dest_hash=self.local_identity.get_public_key()[0], + src_hash=self.remote_identity.get_public_key()[0], + secret=self._shared_secret(), + path=path, + extra_type=extra_type, + extra=extra_payload, + ) + + @pytest.mark.asyncio + async def test_path_handler_updates_path_and_notifies_ack(self): + path = [1, 2, 3] + crc = 0x11223344 + packet = self._build_path_packet(PAYLOAD_TYPE_ACK, crc.to_bytes(4, "little"), path) + + await self.handler(packet) + + contact = self.contacts.get_by_public_key(self.remote_identity.get_public_key().hex()) + assert 
contact is not None + assert contact.out_path == path + assert self.received_crcs == [crc] + + @pytest.mark.asyncio + async def test_path_handler_forwards_protocol_responses(self): + captured = [] + src_hash = self.remote_identity.get_public_key()[0] + self.protocol_response_handler.set_response_callback( + src_hash, lambda success, text, parsed: captured.append((success, text, parsed)) + ) + + payload = b"mesh ok" + packet = self._build_path_packet(PAYLOAD_TYPE_RESPONSE, payload, [7, 8]) + + await self.handler(packet) + + assert self.received_crcs == [] + assert len(captured) == 1 + success, text, parsed = captured[0] + assert success is True + assert text.startswith("Unknown telemetry") + assert parsed.get("type") == "telemetry" + contact = self.contacts.get_by_public_key(self.remote_identity.get_public_key().hex()) + assert contact.out_path == [7, 8] # Group Text Handler Tests @@ -304,6 +492,136 @@ def test_login_response_handler_initialization(self): assert self.handler.local_identity == self.local_identity assert self.handler.local_identity == self.local_identity +class TestControlHandler: + def setup_method(self): + self.log_fn = MagicMock() + self.handler = ControlHandler(self.log_fn) + + def test_payload_type(self): + assert ControlHandler.payload_type() == PAYLOAD_TYPE_CONTROL + + @pytest.mark.asyncio + async def test_discovery_request_callback_receives_details(self): + captured: list[dict] = [] + + def on_request(data: dict): + captured.append(data) + + self.handler.set_request_callback(on_request) + + pkt = Packet() + tag = 0xA1B2C3D4 + since = 1_700_000_000 + filter_mask = (1 << ADV_TYPE_REPEATER) | (1 << ADV_TYPE_SENSOR) + payload = bytearray() + payload.append(0x80 | 0x01) + payload.append(filter_mask) + payload.extend(struct.pack(" Packet: + pkt = Packet() + pkt.header = (1 << 6) | (0 << 4) | (3 << 2) | 0 + pkt.payload = bytearray(b"hello world") + pkt.payload_len = len(pkt.payload) + pkt.path_len = 0 + return pkt + + +def test_calc_transport_code_matches_reference(sample_packet): + key = bytes(range(16)) + assert calc_transport_code(key, sample_packet) == 0x8D7B + + +def test_transport_key_wrapper_matches_function(sample_packet): + key = bytes(range(16)) + tk = TransportKey(key) + assert tk.calc_transport_code(sample_packet) == calc_transport_code(key, sample_packet) + + +def test_calc_transport_code_reserves_zero_and_ffff(sample_packet, monkeypatch): + called = {"data": None} + + def fake_hmac(key, data): + called["data"] = data + return b"\x00\x00" + b"\x00" * 30 + + monkeypatch.setattr( + "pymc_core.protocol.transport_keys.CryptoUtils._hmac_sha256", + fake_hmac, + ) + assert calc_transport_code(bytes(range(16)), sample_packet) == 1 + + def fake_hmac_ff(key, data): + return b"\xFF\xFF" + b"\x00" * 30 + + monkeypatch.setattr( + "pymc_core.protocol.transport_keys.CryptoUtils._hmac_sha256", + fake_hmac_ff, + ) + assert calc_transport_code(bytes(range(16)), sample_packet) == 0xFFFE + + +@pytest.mark.parametrize( + "name", + ["", "usa", "#" + "a" * 65], +) +def test_derive_auto_key_validation(name): + with pytest.raises(ValueError): + derive_auto_key(name) + + +def test_key_store_auto_key_caches(monkeypatch): + store = TransportKeyStore(max_entries=4) + calls = {"count": 0} + + def fake_derive(name): + calls["count"] += 1 + return b"A" * 16 + + monkeypatch.setattr( + "pymc_core.protocol.transport_keys.derive_auto_key", + fake_derive, + ) + + first = store.get_auto_key_for(17, "#test") + second = store.get_auto_key_for(17, "#other") + + assert first.key == second.key == b"A" * 
16 + assert calls["count"] == 1 + + +def test_key_store_save_load_remove(): + store = TransportKeyStore(max_entries=2) + key_a = bytes.fromhex("00" * 15 + "01") + key_b = bytes.fromhex("00" * 15 + "02") + + assert store.save_keys_for(1, [key_a, TransportKey(key_b)]) + loaded = store.load_keys_for(1) + assert len(loaded) == 2 + assert {k.key for k in loaded} == {key_a, key_b} + + assert store.remove_keys(1) + assert store.load_keys_for(1) == [] + + assert store.clear() is False # already empty + + +def test_key_store_cache_bounds(): + store = TransportKeyStore(max_entries=3) + for region_id in range(6): + store.save_keys_for(region_id, [bytes([region_id % 256]) * 16]) + + snapshot = store.cache_snapshot() + assert len(snapshot) == 3 + assert [rid for rid, _ in snapshot] == [3, 4, 5]
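
Worked example of the Semtech time-on-air formula that estimate_airtime_ms now follows. This is a standalone sketch for checking the math (explicit header, CRC on, and LDRO off are assumed, so h = 0 and de = 0); with the default SF10 / 250 kHz / CR 4/5 profile, a 32-byte frame works out to roughly 226 ms.

import math

def lora_toa_ms(n_bytes: int, sf: int = 10, bw_hz: float = 250_000.0,
                cr_setting: int = 1, preamble_symbols: float = 8.0) -> float:
    # cr_setting = 1 means coding rate 4/5; +16 accounts for the payload CRC.
    numerator = 8 * n_bytes - 4 * sf + 28 + 16
    payload_symbols = 8 + max(math.ceil(numerator / (4 * sf)) * (cr_setting + 4), 0)
    symbol_time_sec = (2 ** sf) / bw_hz
    return ((preamble_symbols + 4.25) + payload_symbols) * symbol_time_sec * 1000.0

print(round(lora_toa_ms(32), 1))  # -> 226.3 ms with the defaults above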
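
How the new receive-side helpers compose: a strong packet (score clamped to 1.0) is rebroadcast with no extra hold-off, while a weak one is held back for several multiples of its airtime, capped at MAX_RX_DELAY_MS. The import path matches the test modules above.

from pymc_core.protocol import PacketTimingUtils

airtime_ms = PacketTimingUtils.estimate_airtime_ms(48)   # default SF10 / 250 kHz profile
score = PacketTimingUtils.calculate_packet_score(
    snr_db=5.0, packet_len_bytes=48, spreading_factor=10
)
delay_ms = PacketTimingUtils.calc_rx_delay_ms(score, airtime_ms)
# An SNR well above the SF10 threshold (-15 dB) clamps the score to 1.0, so the
# 10**(0.85 - score) curve yields no extra hold-off; a score of 0.0 would
# stretch the delay to roughly six packet airtimes before the cap applies.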
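
The dispatcher tests imply a send path that first waits out the airtime-budget window left by the previous transmission, then backs off on channel activity using the CAD constants. The loop below is an illustrative sketch of that sequencing, not the dispatcher's actual code; radio.perform_cad, radio.send, and _next_tx_allowed_at are taken from the tests.

import asyncio
from pymc_core.protocol import PacketTimingUtils

async def send_with_budget(dispatcher, packet, packet_len: int) -> bool:
    loop = asyncio.get_event_loop()

    # Honour the silence window reserved by the previous transmission.
    wait = dispatcher._next_tx_allowed_at - loop.time()
    if wait > 0:
        await asyncio.sleep(wait)

    # Back off while the channel is busy, bounded by the CAD fail ceiling.
    retry_s = PacketTimingUtils.get_cad_fail_retry_delay_ms() / 1000.0
    deadline = loop.time() + PacketTimingUtils.get_cad_fail_max_duration_ms() / 1000.0
    while await dispatcher.radio.perform_cad() and loop.time() < deadline:
        await asyncio.sleep(retry_s)

    ok = await dispatcher.radio.send(packet)

    # Reserve the next budget window in proportion to this packet's airtime.
    airtime_ms = PacketTimingUtils.estimate_airtime_ms(packet_len)
    budget_ms = PacketTimingUtils.calc_airtime_budget_delay_ms(airtime_ms)
    dispatcher._next_tx_allowed_at = loop.time() + budget_ms / 1000.0
    return ok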
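
The own-packet tests rely on a CRC cache with a TTL: CRCs recorded at transmit time mark a heard-back packet as our own echo until the TTL lapses. A minimal standalone sketch follows; the class and field names are illustrative, not the dispatcher's internals.

import time

class OwnPacketFilter:
    def __init__(self, ttl_sec: float = 5.0) -> None:
        self.ttl_sec = ttl_sec
        self._sent: dict[int, float] = {}          # crc -> time it was recorded

    def record_outbound(self, crc: int) -> None:
        self._sent[crc] = time.monotonic()

    def is_own_packet(self, crc: int) -> bool:
        recorded = self._sent.get(crc)
        if recorded is None:
            return False                           # never sent by this node
        if time.monotonic() - recorded > self.ttl_sec:
            self._sent.pop(crc, None)              # entry expired, forget it
            return False
        return True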
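
Round-trip for the flag-driven appdata decoder, assuming the historical bit value of the name flag (0x80, per the old lookup table) and that decode_appdata remains importable from pymc_core.protocol.utils.

from pymc_core.protocol.utils import decode_appdata

appdata = bytes([0x80]) + "repeater-7".encode("utf-8")   # has_name flag + UTF-8 name
info = decode_appdata(appdata)
# -> {"flags": 128, "node_name": "repeater-7"}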
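
A sketch of the ACL gate the handler tests exercise: look the sender up by source hash and consult the contact book before acting on a decrypted CLI command. The helper name is illustrative; the ContactBook calls are the ones used in the tests.

from pymc_core.node.contact_book import ContactBook

def should_run_cli(book: ContactBook, src_hash: int) -> bool:
    contact = book.get_by_hash(src_hash)
    if contact is None:
        return False                      # unknown sender: never run CLI commands
    return book.can_execute_cli(contact)  # type defaults plus explicit set_permissions grants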
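
Typical use of the cache-backed key store: an auto key is derived once per region and then served from the bounded cache, and manually provisioned keys share the same cache. Region id 7 and the "#emea" hashtag are made-up values.

from pymc_core.protocol.transport_keys import TransportKeyStore

store = TransportKeyStore(max_entries=4)
first = store.get_auto_key_for(7, "#emea")    # derived via SHA-256 and cached
second = store.get_auto_key_for(7, "#emea")   # served from the cache, no re-derivation
assert first.key == second.key

store.save_keys_for(7, [bytes(range(16))])    # manually provisioned 16-byte key
assert len(store.load_keys_for(7)) == 2       # auto key + saved key
store.remove_keys(7)
assert store.load_keys_for(7) == []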