diff --git a/examples/discover_nodes.py b/examples/discover_nodes.py new file mode 100644 index 0000000..097be79 --- /dev/null +++ b/examples/discover_nodes.py @@ -0,0 +1,186 @@ +#!/usr/bin/env python3 +""" +Minimal example: Discover nearby mesh nodes. + +This example demonstrates how to broadcast a discovery request +and collect responses from nearby repeaters and nodes in the mesh network. + +The discovery request is sent as a zero-hop broadcast, and nearby nodes +will respond with their public key and signal strength information. + +Features: +- Asynchronous callback-based response collection +- Configurable discovery filter (node types to discover) +- Signal strength data (SNR and RSSI) for each discovered node +- Automatic timeout after specified duration +""" + +import asyncio +import random +import time + +from common import create_mesh_node + +from pymc_core.protocol.packet_builder import PacketBuilder + +# ADV_TYPE_REPEATER = 2, so filter mask is (1 << 2) = 0x04 +FILTER_REPEATERS = 0x04 # Bit 2 set for repeater node type + + +async def discover_nodes( + radio_type: str = "waveshare", + serial_port: str = "/dev/ttyUSB0", + timeout: float = 5.0, + filter_mask: int = FILTER_REPEATERS, +): + """ + Discover nearby mesh nodes using control packets. + + Args: + radio_type: Radio hardware type ("waveshare", "uconsole", etc.) + serial_port: Serial port for KISS TNC + timeout: How long to wait for responses (seconds) + filter_mask: Node types to discover (bitmask of ADV_TYPE values, e.g., ADV_TYPE_REPEATER = 2, so mask = 0x04 for repeaters) + """ + mesh_node, identity = create_mesh_node("DiscoveryNode", radio_type, serial_port) + + # Dictionary to store discovered nodes + discovered_nodes = {} + + # Create callback to collect discovery responses + def on_discovery_response(response_data: dict): + """Handle discovery response callback.""" + tag = response_data.get("tag", 0) + node_type = response_data.get("node_type", 0) + inbound_snr = response_data.get("inbound_snr", 0.0) # Their RX of our request + response_snr = response_data.get("response_snr", 0.0) # Our RX of their response + rssi = response_data.get("rssi", 0) + pub_key = response_data.get("pub_key", "") + timestamp = response_data.get("timestamp", 0) + + # Get node type name + node_type_names = {1: "Chat Node", 2: "Repeater", 3: "Room Server"} + node_type_name = node_type_names.get(node_type, f"Unknown({node_type})") + + # Store node info + node_id = pub_key[:16] # Use first 8 bytes as ID + if node_id not in discovered_nodes: + discovered_nodes[node_id] = { + "pub_key": pub_key, + "node_type": node_type_name, + "inbound_snr": inbound_snr, + "response_snr": response_snr, + "rssi": rssi, + "timestamp": timestamp, + } + + print( + f"✓ Discovered {node_type_name}: {node_id}... " + f"(TX→RX SNR: {inbound_snr:+.1f}dB, RX←TX SNR: {response_snr:+.1f}dB, " + f"RSSI: {rssi}dBm)" + ) + + # Get the control handler and set up callback + control_handler = mesh_node.dispatcher.control_handler + if not control_handler: + print("Error: Control handler not available") + return + + # Generate random tag for this discovery request + discovery_tag = random.randint(0, 0xFFFFFFFF) + + # Set up callback for responses matching this tag + control_handler.set_response_callback(discovery_tag, on_discovery_response) + + # Create discovery request packet + # filter_mask: 0x04 = bit 2 set (1 << ADV_TYPE_REPEATER where ADV_TYPE_REPEATER=2) + # since: 0 = discover all nodes regardless of modification time + pkt = PacketBuilder.create_discovery_request( + tag=discovery_tag, filter_mask=filter_mask, since=0, prefix_only=False + ) + + print(f"Sending discovery request (tag: 0x{discovery_tag:08X})...") + print(f"Filter mask: 0x{filter_mask:02X} (node types to discover)") + print(f"Waiting {timeout} seconds for responses...\n") + + # Send as zero-hop broadcast (no routing path) + success = await mesh_node.dispatcher.send_packet(pkt, wait_for_ack=False) + + if success: + print("Discovery request sent successfully") + + # Wait for responses + start_time = time.time() + while time.time() - start_time < timeout: + await asyncio.sleep(0.1) + + # Display results + print(f"\n{'='*60}") + print(f"Discovery complete - found {len(discovered_nodes)} node(s)") + print(f"{'='*60}\n") + + if discovered_nodes: + for node_id, info in discovered_nodes.items(): + print(f"Node: {node_id}...") + print(f" Type: {info['node_type']}") + print(f" TX→RX SNR: {info['inbound_snr']:+.1f} dB (our request at their end)") + print(f" RX←TX SNR: {info['response_snr']:+.1f} dB (their response at our end)") + print(f" RSSI: {info['rssi']} dBm") + print(f" Public Key: {info['pub_key']}") + print() + else: + print("No nodes discovered.") + print("This could mean:") + print(" - No nodes are within range") + print(" - No nodes match the filter criteria") + print(" - Radio configuration mismatch") + + else: + print("Failed to send discovery request") + + # Clean up callback + control_handler.clear_response_callback(discovery_tag) + + +def main(): + """Main function for running the discovery example.""" + import argparse + + parser = argparse.ArgumentParser(description="Discover nearby mesh nodes") + parser.add_argument( + "--radio-type", + choices=["waveshare", "uconsole", "meshadv-mini", "kiss-tnc"], + default="waveshare", + help="Radio hardware type (default: waveshare)", + ) + parser.add_argument( + "--serial-port", + default="/dev/ttyUSB0", + help="Serial port for KISS TNC (default: /dev/ttyUSB0)", + ) + parser.add_argument( + "--timeout", + type=float, + default=5.0, + help="Discovery timeout in seconds (default: 5.0)", + ) + parser.add_argument( + "--filter", + type=lambda x: int(x, 0), + default=FILTER_REPEATERS, + help="Node type filter mask (default: 0x04 for repeaters, bit position = node type)", + ) + + args = parser.parse_args() + + print(f"Using {args.radio_type} radio configuration") + if args.radio_type == "kiss-tnc": + print(f"Serial port: {args.serial_port}") + + asyncio.run( + discover_nodes(args.radio_type, args.serial_port, args.timeout, args.filter) + ) + + +if __name__ == "__main__": + main() diff --git a/examples/login_server.py b/examples/login_server.py new file mode 100644 index 0000000..f3dbb61 --- /dev/null +++ b/examples/login_server.py @@ -0,0 +1,439 @@ +#!/usr/bin/env python3 +""" +Simple login server example: Accept and authenticate client logins. + +This example demonstrates how to set up a basic authentication server +that responds to login requests from mesh clients, validates credentials, +and manages an access control list (ACL). + +The server supports: +- Admin password authentication +- Guest password authentication +- ACL-based authentication (blank password) +- Automatic response to login attempts + +This example implements the application-level authentication logic +(password validation, ACL management, permission assignment). +The handler (LoginServerHandler) performs only protocol operations. +""" + +import asyncio +import time +from typing import Dict, Optional + +from common import create_mesh_node + +from pymc_core.node.handlers.login_server import LoginServerHandler +from pymc_core.protocol import Identity, LocalIdentity +from pymc_core.protocol.constants import PUB_KEY_SIZE + + +def create_mesh_node_with_identity( + node_name: str, radio_type: str, serial_port: str, identity: LocalIdentity +) -> tuple[any, LocalIdentity]: + """Create a mesh node with a specific identity (modified from common.py)""" + import logging + import os + import sys + + # Set up logging (copied from common.py) + logger = logging.getLogger(__name__) + + # Add the src directory to the path so we can import pymc_core + sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) + + from common import create_radio + + from pymc_core.node.node import MeshNode + + logger.info(f"Creating mesh node with name: {node_name} using {radio_type} radio") + + try: + # Use the provided identity instead of creating a new one + logger.info( + f"Using provided identity with public key: {identity.get_public_key().hex()[:16]}..." + ) + + # Create the radio (copied from common.py logic) + radio = create_radio(radio_type, serial_port) + + # Initialize radio based on type + if radio_type == "kiss-tnc": + import time + + time.sleep(1) # Give KISS time to initialize + if hasattr(radio, "begin"): + radio.begin() + # Check KISS status + if hasattr(radio, "kiss_mode_active") and radio.kiss_mode_active: + logger.info("KISS mode is active") + else: + logger.warning("KISS mode may not be active") + print("Warning: KISS mode may not be active") + else: + logger.debug("Calling radio.begin()...") + radio.begin() + logger.info("Radio initialized successfully") + + # Create a mesh node with the radio and identity + config = {"node": {"name": node_name}} + logger.debug(f"Creating MeshNode with config: {config}") + mesh_node = MeshNode(radio=radio, local_identity=identity, config=config) + logger.info(f"MeshNode created successfully: {node_name}") + + return mesh_node, identity + + except Exception as e: + logger.error(f"Failed to create mesh node: {e}") + raise + + +# ============================================================================= +# HARDCODED EXAMPLE CREDENTIALS - FOR TESTING ONLY! +# ============================================================================= +# Server identity (hardcoded for easy testing) +# Use a deterministic seed to always generate the same identity +EXAMPLE_SEED = bytes.fromhex("1111111111111111111111111111111111111111111111111111111111111111") + +# Example credentials +EXAMPLE_ADMIN_PASSWORD = "admin123" +EXAMPLE_GUEST_PASSWORD = "guest123" +# ============================================================================= + +# Permission levels +PERM_ACL_GUEST = 0x01 +PERM_ACL_ADMIN = 0x02 +PERM_ACL_ROLE_MASK = 0x03 + + +class ClientInfo: + """Represents an authenticated client in the access control list.""" + + def __init__(self, identity: Identity, permissions: int = 0): + self.id = identity + self.permissions = permissions + self.shared_secret = b"" + self.last_timestamp = 0 + self.last_activity = 0 + self.last_login_success = 0 + self.out_path_len = -1 # -1 means no path, need to discover + self.out_path = bytearray() + + def is_admin(self) -> bool: + """Check if client has admin permissions.""" + return (self.permissions & PERM_ACL_ROLE_MASK) == PERM_ACL_ADMIN + + def is_guest(self) -> bool: + """Check if client has guest permissions.""" + return (self.permissions & PERM_ACL_ROLE_MASK) == PERM_ACL_GUEST + + +class ClientACL: + """ + Access Control List for managing authenticated clients. + + Implements application-level authentication logic: + - Password validation + - Client state management + - Permission assignment + - Replay attack detection + """ + + def __init__( + self, + max_clients: int = 32, + admin_password: str = "admin123", + guest_password: str = "guest123", + ): + self.max_clients = max_clients + self.admin_password = admin_password + self.guest_password = guest_password + self.clients: Dict[bytes, ClientInfo] = {} # pub_key -> ClientInfo + + def authenticate_client( + self, client_identity: Identity, shared_secret: bytes, password: str, timestamp: int + ) -> tuple[bool, int]: + """ + Authenticate a client login request. + + This is the authentication callback used by LoginServerHandler. + It implements the application's password validation and ACL logic. + + Args: + client_identity: Client's identity + shared_secret: ECDH shared secret for encryption + password: Password provided by client + timestamp: Timestamp from client request + + Returns: + (success: bool, permissions: int) - True/permissions on success, False/0 on failure + """ + pub_key = client_identity.get_public_key()[:PUB_KEY_SIZE] + + # Check for blank password (ACL-only authentication) + if not password: + client = self.clients.get(pub_key) + if client is None: + print(f"[ACL] Blank password, sender not in ACL") + return False, 0 + # Client exists in ACL, allow login with existing permissions + print(f"[ACL] ACL-based login for {pub_key[:6].hex()}...") + return True, client.permissions + + # Validate password + permissions = 0 + if password == self.admin_password: + permissions = PERM_ACL_ADMIN + print(f"[ACL] Admin password validated") + elif self.guest_password and password == self.guest_password: + permissions = PERM_ACL_GUEST + print(f"[ACL] Guest password validated") + else: + print(f"[ACL] Invalid password") + return False, 0 + + # Get or create client + client = self.clients.get(pub_key) + if client is None: + # Check capacity + if len(self.clients) >= self.max_clients: + print(f"[ACL] ACL full, cannot add client") + return False, 0 + + # Add new client + client = ClientInfo(client_identity, 0) + self.clients[pub_key] = client + print(f"[ACL] Added new client {pub_key[:6].hex()}...") + + # Check for replay attack + if timestamp <= client.last_timestamp: + print( + f"[ACL] Possible replay attack! timestamp={timestamp}, last={client.last_timestamp}" + ) + return False, 0 + + # Update client state + client.last_timestamp = timestamp + client.last_activity = int(time.time()) + client.last_login_success = int(time.time()) + client.permissions &= ~PERM_ACL_ROLE_MASK + client.permissions |= permissions + client.shared_secret = shared_secret + + print(f"[ACL] Login success! Permissions: {'ADMIN' if client.is_admin() else 'GUEST'}") + return True, client.permissions + + def get_client(self, pub_key: bytes) -> Optional[ClientInfo]: + """Get client by public key.""" + return self.clients.get(pub_key[:PUB_KEY_SIZE]) + + def get_num_clients(self) -> int: + """Get number of clients in ACL.""" + return len(self.clients) + + def get_all_clients(self): + """Get all clients.""" + return list(self.clients.values()) + + def remove_client(self, pub_key: bytes) -> bool: + """Remove client from ACL.""" + key = pub_key[:PUB_KEY_SIZE] + if key in self.clients: + del self.clients[key] + return True + return False + + +async def run_login_server( + radio_type: str = "waveshare", + serial_port: str = "/dev/ttyUSB0", + admin_password: str = EXAMPLE_ADMIN_PASSWORD, + guest_password: str = EXAMPLE_GUEST_PASSWORD, + use_hardcoded_identity: bool = True, +): + """ + Run a login authentication server. + + Args: + radio_type: Radio hardware type ("waveshare", "uconsole", etc.) + serial_port: Serial port for KISS TNC + admin_password: Password for admin access + guest_password: Password for guest access (empty string to disable) + use_hardcoded_identity: Use hardcoded identity for easy testing + """ + print("=" * 60) + print("PyMC Core - Login Server Example") + print("=" * 60) + print(f"Admin Password: {admin_password}") + print(f"Guest Password: {guest_password if guest_password else ''}") + print(f"Hardcoded Identity: {use_hardcoded_identity}") + print("=" * 60) + + # Create mesh node with optional hardcoded identity + if use_hardcoded_identity: + print("Using hardcoded example identity for easy testing...") + hardcoded_identity = LocalIdentity(seed=EXAMPLE_SEED) + mesh_node, identity = create_mesh_node_with_identity( + "LoginServer", radio_type, serial_port, hardcoded_identity + ) + else: + mesh_node, identity = create_mesh_node("LoginServer", radio_type, serial_port) + + # Get our public key info + our_pub_key = identity.get_public_key() + our_hash = our_pub_key[0] + print(f"Server Identity: {our_pub_key.hex()}") + print(f"Server Hash: 0x{our_hash:02X}") + print() + + # Create ACL for managing authenticated clients + acl = ClientACL(max_clients=32, admin_password=admin_password, guest_password=guest_password) + + # Create login server handler with authentication callback + login_handler = LoginServerHandler( + local_identity=identity, + log_fn=lambda msg: print(msg), + authenticate_callback=acl.authenticate_client, # Delegate authentication to ACL + ) + + # Set up packet sending callback + def send_packet_with_delay(packet, delay_ms: int): + """Send a packet with a delay.""" + asyncio.create_task(delayed_send(packet, delay_ms)) + + async def delayed_send(packet, delay_ms: int): + """Send packet after delay.""" + await asyncio.sleep(delay_ms / 1000.0) + try: + await mesh_node.dispatcher.send_packet(packet, wait_for_ack=False) + except Exception as e: + print(f"Error sending response: {e}") + + login_handler.set_send_packet_callback(send_packet_with_delay) + + # Register the handler with the dispatcher + mesh_node.dispatcher.register_handler(LoginServerHandler.payload_type(), login_handler) + + print("Login server started and listening...") + print(" Waiting for login requests from clients...") + print() + print("Commands:") + print(" - Press Ctrl+C to stop") + print(" - Type 'status' to show ACL status") + print(" - Type 'list' to list authenticated clients") + print() + + # Command processor + async def process_commands(): + """Process user commands.""" + import sys + + loop = asyncio.get_event_loop() + + while True: + # Check for stdin input + try: + if sys.stdin.readable(): + cmd = await loop.run_in_executor(None, sys.stdin.readline) + cmd = cmd.strip().lower() + + if cmd == "status": + print(f"\nACL Status:") + print( + f" Authenticated clients: {acl.get_num_clients()}/{acl.max_clients}" + ) + print() + + elif cmd == "list": + clients = acl.get_all_clients() + print(f"\n👥 Authenticated Clients ({len(clients)}):") + if not clients: + print(" ") + else: + for i, client in enumerate(clients, 1): + pub_key_hex = client.id.get_public_key()[:8].hex() + role = "ADMIN" if client.is_admin() else "GUEST" + print(f" {i}. {pub_key_hex}... [{role}]") + print() + + except Exception: + pass + + await asyncio.sleep(0.1) + + # Run command processor in background + asyncio.create_task(process_commands()) + + try: + while True: + await asyncio.sleep(1) + except KeyboardInterrupt: + print("\n\nShutting down login server...") + print(f" Final ACL size: {acl.get_num_clients()} clients") + + +def main(): + """Main function for running the example.""" + import argparse + + parser = argparse.ArgumentParser( + description="Run a login authentication server for the mesh network" + ) + parser.add_argument( + "--radio-type", + choices=["waveshare", "uconsole", "meshadv-mini", "kiss-tnc"], + default="waveshare", + help="Radio hardware type (default: waveshare)", + ) + parser.add_argument( + "--serial-port", + default="/dev/ttyUSB0", + help="Serial port for KISS TNC (default: /dev/ttyUSB0)", + ) + parser.add_argument( + "--admin-password", + default=EXAMPLE_ADMIN_PASSWORD, + help=f"Admin password (default: {EXAMPLE_ADMIN_PASSWORD})", + ) + parser.add_argument( + "--guest-password", + default=EXAMPLE_GUEST_PASSWORD, + help=f"Guest password (default: {EXAMPLE_GUEST_PASSWORD}, empty to disable)", + ) + parser.add_argument( + "--use-random-identity", + action="store_true", + help="Use random identity instead of hardcoded example identity", + ) + + args = parser.parse_args() + + print(f"Using {args.radio_type} radio configuration") + if args.radio_type == "kiss-tnc": + print(f"Serial port: {args.serial_port}") + + # Show the identity that will be used + if not args.use_random_identity: + # Create a temporary identity to show what the keys will be + temp_identity = LocalIdentity(seed=EXAMPLE_SEED) + temp_pubkey = temp_identity.get_public_key() + print(f"Server Public Key: {temp_pubkey.hex()}") + print(f"Server Hash: 0x{temp_pubkey[0]:02X}") + print("(Use --use-random-identity to generate random keys instead)") + + try: + asyncio.run( + run_login_server( + args.radio_type, + args.serial_port, + args.admin_password, + args.guest_password, + use_hardcoded_identity=not args.use_random_identity, + ) + ) + except KeyboardInterrupt: + print("\nExample terminated by user") + + +if __name__ == "__main__": + main() diff --git a/examples/respond_to_discovery.py b/examples/respond_to_discovery.py new file mode 100644 index 0000000..bf8fb93 --- /dev/null +++ b/examples/respond_to_discovery.py @@ -0,0 +1,151 @@ +#!/usr/bin/env python3 +""" +Minimal example: Respond to discovery requests. + +This example demonstrates how to listen for discovery requests from other nodes +and automatically respond with this node's information. + +Simply run this script and it will respond to any discovery requests until stopped. +""" + +import asyncio + +from common import create_mesh_node + +from pymc_core.protocol.packet_builder import PacketBuilder + +# Node type values from C++ AdvertDataHelpers.h +ADV_TYPE_REPEATER = 2 +ADV_TYPE_CHAT_NODE = 1 +ADV_TYPE_ROOM_SERVER = 3 + + +async def respond_to_discovery( + radio_type: str = "waveshare", + serial_port: str = "/dev/ttyUSB0", + node_type: int = ADV_TYPE_REPEATER, +): + """ + Listen for discovery requests and respond with node information. + + Args: + radio_type: Radio hardware type ("waveshare", "uconsole", etc.) + serial_port: Serial port for KISS TNC + node_type: Type of this node (1=chat, 2=repeater, 3=room_server) + """ + mesh_node, identity = create_mesh_node("DiscoveryResponder", radio_type, serial_port) + + # Get our public key for responses + our_pub_key = identity.get_public_key() + + # Node type names for logging + node_type_names = { + ADV_TYPE_CHAT_NODE: "Chat Node", + ADV_TYPE_REPEATER: "Repeater", + ADV_TYPE_ROOM_SERVER: "Room Server", + } + node_type_name = node_type_names.get(node_type, f"Unknown({node_type})") + + # Create callback to handle discovery requests + def on_discovery_request(request_data: dict): + """Handle incoming discovery request.""" + tag = request_data.get("tag", 0) + filter_byte = request_data.get("filter", 0) + prefix_only = request_data.get("prefix_only", False) + snr = request_data.get("snr", 0.0) + rssi = request_data.get("rssi", 0) + + print( + f"📡 Discovery request: tag=0x{tag:08X}, " + f"filter=0x{filter_byte:02X}, SNR={snr:+.1f}dB, RSSI={rssi}dBm" + ) + + # Check if filter matches our node type + filter_mask = 1 << node_type + if (filter_byte & filter_mask) == 0: + print(f" ↳ Filter doesn't match, ignoring") + return + + # Create and send discovery response + print(f" ↳ Sending response...") + + pkt = PacketBuilder.create_discovery_response( + tag=tag, + node_type=node_type, + inbound_snr=snr, + pub_key=our_pub_key, + prefix_only=prefix_only, + ) + + # Send the response + asyncio.create_task(send_response(mesh_node, pkt, tag)) + + async def send_response(node, pkt, tag): + """Send discovery response packet.""" + try: + success = await node.dispatcher.send_packet(pkt, wait_for_ack=False) + if success: + print(f" ✓ Response sent\n") + else: + print(f" ✗ Failed to send\n") + except Exception as e: + print(f" ✗ Error: {e}\n") + + # Get the control handler and set up request callback + control_handler = mesh_node.dispatcher.control_handler + if not control_handler: + print("Error: Control handler not available") + return + + control_handler.set_request_callback(on_discovery_request) + + print(f" Listening for discovery requests as {node_type_name}") + print(f" Node type: {node_type} (filter: 0x{1 << node_type:02X})") + print(f" Public key: {our_pub_key.hex()[:32]}...") + print(f" Press Ctrl+C to stop\n") + + # Listen forever + try: + while True: + await asyncio.sleep(1) + except KeyboardInterrupt: + print("\n\n Stopped\n") + + control_handler.clear_request_callback() + + +def main(): + """Main function for running the discovery responder example.""" + import argparse + + parser = argparse.ArgumentParser(description="Respond to mesh node discovery requests") + parser.add_argument( + "--radio-type", + choices=["waveshare", "uconsole", "meshadv-mini", "kiss-tnc"], + default="waveshare", + help="Radio hardware type (default: waveshare)", + ) + parser.add_argument( + "--serial-port", + default="/dev/ttyUSB0", + help="Serial port for KISS TNC (default: /dev/ttyUSB0)", + ) + parser.add_argument( + "--node-type", + type=int, + choices=[1, 2, 3], + default=ADV_TYPE_CHAT_NODE, + help="Node type: 1=chat, 2=repeater, 3=room_server (default: 1)", + ) + + args = parser.parse_args() + + print(f"Using {args.radio_type} radio configuration") + if args.radio_type == "kiss-tnc": + print(f"Serial port: {args.serial_port}") + + asyncio.run(respond_to_discovery(args.radio_type, args.serial_port, args.node_type)) + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 193fbf9..96aecfa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "pymc_core" -version = "1.0.5" +version = "1.0.6" authors = [ {name = "Lloyd Newton", email = "lloyd@rightup.co.uk"}, ] @@ -35,8 +35,7 @@ dependencies = [ [project.optional-dependencies] radio = ["pyserial>=3.5"] hardware = [ - "gpiozero>=2.0.0", - "lgpio>=0.2.0", + "python-periphery>=2.4.1", "spidev>=3.5", "pyserial>=3.5", ] diff --git a/src/pymc_core/__init__.py b/src/pymc_core/__init__.py index b44ac9b..02f0cf5 100644 --- a/src/pymc_core/__init__.py +++ b/src/pymc_core/__init__.py @@ -3,7 +3,7 @@ Clean, simple API for building mesh network applications. """ -__version__ = "1.0.5" +__version__ = "1.0.6" # Core mesh functionality from .node.node import MeshNode diff --git a/src/pymc_core/hardware/base.py b/src/pymc_core/hardware/base.py index b63a3b6..35b8837 100644 --- a/src/pymc_core/hardware/base.py +++ b/src/pymc_core/hardware/base.py @@ -9,7 +9,7 @@ def begin(self): @abstractmethod async def send(self, data: bytes): - """Send a packet asynchronously.""" + """Send a packet asynchronously. Returns transmission metadata dict or None.""" pass @abstractmethod diff --git a/src/pymc_core/hardware/gpio_manager.py b/src/pymc_core/hardware/gpio_manager.py index 804db8c..399a5da 100644 --- a/src/pymc_core/hardware/gpio_manager.py +++ b/src/pymc_core/hardware/gpio_manager.py @@ -1,43 +1,137 @@ """ -GPIO Pin Manager for Raspberry Pi -Manages GPIO pins abstraction using gpiozero +GPIO Pin Manager for Linux SBCs +Manages GPIO pins abstraction using python-periphery +Works on Raspberry Pi, Orange Pi, Luckfox, and other Linux SBCs """ -import asyncio +import glob import logging -from typing import Callable, Optional - -from gpiozero import Button, Device, OutputDevice - -# Force gpiozero to use LGPIOFactory - no RPi.GPIO fallback -from gpiozero.pins.lgpio import LGPIOFactory +import sys +import threading +import time +from typing import Callable, Dict, Optional + +try: + from periphery import GPIO, EdgeEvent + + PERIPHERY_AVAILABLE = True +except ImportError: + # Mock GPIO classes for testing/non-hardware environments + PERIPHERY_AVAILABLE = False + GPIO = None + EdgeEvent = None + + class GPIOImportError(ImportError): + """Raised when GPIO functionality is used without python-periphery""" + + def __init__(self): + super().__init__( + "\n\nError: python-periphery library is required for GPIO management.\n" + "━" * 60 + "\n" + "This application requires GPIO hardware access which is only\n" + "available on Linux-based systems (Raspberry Pi, Orange Pi, etc.)\n\n" + "Reason: python-periphery uses Linux kernel interfaces that\n" + " don't exist on macOS or Windows.\n\n" + "Solutions:\n" + " • Run this application on a Linux SBC\n" + "━" * 60 + ) -Device.pin_factory = LGPIOFactory() logger = logging.getLogger("GPIOPinManager") class GPIOPinManager: - """Manages GPIO pins abstraction""" + """Manages GPIO pins abstraction using Linux GPIO character device interface""" + + def __init__(self, gpio_chip: str = "/dev/gpiochip0"): + """ + Initialize GPIO Pin Manager + + Args: + gpio_chip: Path to GPIO chip device (default: /dev/gpiochip0) + Set to "auto" to auto-detect first available chip - def __init__(self): - self._pins = {} - self._led_tasks = {} # Track active LED tasks + Raises: + GPIOImportError: If python-periphery is not available + """ + if not PERIPHERY_AVAILABLE: + raise GPIOImportError() + + self._gpio_chip = self._resolve_gpio_chip(gpio_chip) + self._pins: Dict[int, GPIO] = {} + self._led_threads: Dict[int, threading.Thread] = {} # Track active LED threads + self._led_stop_events: Dict[int, threading.Event] = {} # Stop events for LED threads + self._input_callbacks: Dict[int, Callable] = {} # Track input pin callbacks + self._edge_threads: Dict[int, threading.Thread] = {} # Track edge detection threads + self._edge_stop_events: Dict[int, threading.Event] = {} # Stop events for edge threads + + logger.debug(f"GPIO Manager initialized with chip: {self._gpio_chip}") + + def _resolve_gpio_chip(self, gpio_chip: str) -> str: + """Resolve GPIO chip path, auto-detecting if needed""" + if gpio_chip == "auto": + chips = sorted(glob.glob("/dev/gpiochip*")) + if chips: + logger.info(f"Auto-detected GPIO chips: {chips}, using {chips[0]}") + return chips[0] + else: + logger.warning("No GPIO chips found, defaulting to /dev/gpiochip0") + return "/dev/gpiochip0" + return gpio_chip def setup_output_pin(self, pin_number: int, initial_value: bool = False) -> bool: - """Setup an output pin with initial value""" + """ + Setup an output pin with initial value + + Args: + pin_number: GPIO line number + initial_value: Initial state (True=HIGH, False=LOW) + """ if pin_number == -1: return False try: + # Close existing pin if already configured if pin_number in self._pins: self._pins[pin_number].close() + del self._pins[pin_number] + + # Open GPIO pin as output + gpio = GPIO(self._gpio_chip, pin_number, "out") + gpio.write(initial_value) + self._pins[pin_number] = gpio - self._pins[pin_number] = OutputDevice(pin_number, initial_value=initial_value) + logger.debug(f"Output pin {pin_number} configured (initial={initial_value})") return True except Exception as e: - logger.warning(f"Failed to setup output pin {pin_number}: {e}") - return False + error_msg = str(e).lower() + if "busy" in error_msg or "device or resource busy" in error_msg: + logger.error(f"GPIO pin {pin_number} is already in use by another process: {e}") + print(f"\nFATAL: GPIO Pin {pin_number} is already in use") + print("━" * 60) + print("The pin is being used by another process.") + print(f"\nDebug: sudo lsof /dev/gpiochip* | grep {pin_number}") + print("\nThe system cannot function without GPIO access.") + print("━" * 60) + sys.exit(1) + elif "permission denied" in error_msg: + logger.error(f"Permission denied for GPIO pin {pin_number}: {e}") + print(f"\nFATAL: Permission denied for GPIO pin {pin_number}") + print("━" * 60) + print("Solutions:") + print(" • Add user to gpio group: sudo usermod -a -G gpio $USER") + print(" • Then logout and login again") + print("━" * 60) + sys.exit(1) + else: + logger.error(f"Failed to setup output pin {pin_number}: {e}") + print(f"\nFATAL: Cannot setup GPIO pin {pin_number}") + print("━" * 60) + print(f"Error: {e}") + print("\nThe system cannot function without GPIO access.") + print("━" * 60) + sys.exit(1) def setup_input_pin( self, @@ -45,120 +139,326 @@ def setup_input_pin( pull_up: bool = False, callback: Optional[Callable] = None, ) -> bool: - """Setup an input pin with optional interrupt callback""" + """ + Setup an input pin with optional callback using hardware edge detection + + Args: + pin_number: GPIO line number + pull_up: Enable pull-up resistor (not all chips support this) + callback: Function to call on rising edge (hardware interrupt) + """ if pin_number == -1: return False try: + # Close existing pin if already configured if pin_number in self._pins: self._pins[pin_number].close() + del self._pins[pin_number] - self._pins[pin_number] = Button(pin_number, pull_up=pull_up) - if callback: - self._pins[pin_number].when_activated = callback + # Determine bias setting + bias = "pull_up" if pull_up else "default" + # Open GPIO pin as input with edge detection if callback provided + if callback: + gpio = GPIO(self._gpio_chip, pin_number, "in", bias=bias, edge="rising") + self._input_callbacks[pin_number] = callback + self._start_edge_detection(pin_number) + else: + # No callback, just simple input + gpio = GPIO(self._gpio_chip, pin_number, "in", bias=bias) + + self._pins[pin_number] = gpio + + logger.debug( + f"Input pin {pin_number} configured " + f"(pull_up={pull_up}, callback={callback is not None})" + ) return True except Exception as e: - logger.warning(f"Failed to setup input pin {pin_number}: {e}") - return False + error_msg = str(e).lower() + if "busy" in error_msg or "device or resource busy" in error_msg: + logger.error(f"GPIO pin {pin_number} is already in use by another process: {e}") + print(f"\nFATAL: GPIO Pin {pin_number} is already in use") + print("━" * 60) + print("The pin is being used by another process.") + print(f"\nDebug: sudo lsof /dev/gpiochip* | grep {pin_number}") + print("\nThe system cannot function without GPIO access.") + print("━" * 60) + sys.exit(1) + elif "permission denied" in error_msg: + logger.error(f"Permission denied for GPIO pin {pin_number}: {e}") + print(f"\nFATAL: Permission denied for GPIO pin {pin_number}") + print("━" * 60) + print("Solutions:") + print(" • Add user to gpio group: sudo usermod -a -G gpio $USER") + print(" • Then logout and login again") + print("━" * 60) + sys.exit(1) + else: + logger.error(f"Failed to setup input pin {pin_number}: {e}") + print(f"\nFATAL: Cannot setup GPIO pin {pin_number}") + print("━" * 60) + print(f"Error: {e}") + print("\nThe system cannot function without GPIO access.") + print("━" * 60) + sys.exit(1) def setup_interrupt_pin( self, pin_number: int, pull_up: bool = False, callback: Optional[Callable] = None, - ) -> Optional[Button]: - """Setup an interrupt pin and return the Button object for direct access""" + ) -> Optional[GPIO]: + """ + Setup an interrupt pin with edge detection (alias for setup_input_pin) + + Args: + pin_number: GPIO line number + pull_up: Enable pull-up resistor + callback: Function to call on rising edge (hardware interrupt) + + Returns: + GPIO object for direct access, or None on failure + """ if pin_number == -1: return None try: + # Close existing pin if already configured if pin_number in self._pins: self._pins[pin_number].close() + del self._pins[pin_number] + + # Determine bias setting + bias = "pull_up" if pull_up else "default" + + # Open GPIO pin as input with edge detection on rising edge + gpio = GPIO(self._gpio_chip, pin_number, "in", bias=bias, edge="rising") + self._pins[pin_number] = gpio - button = Button(pin_number, pull_up=pull_up) + # Setup callback with async edge monitoring if callback: - button.when_activated = callback + self._input_callbacks[pin_number] = callback + self._start_edge_detection(pin_number) - self._pins[pin_number] = button - return button + logger.debug( + f"Interrupt pin {pin_number} configured " + f"(pull_up={pull_up}, callback={callback is not None})" + ) + return gpio except Exception as e: - logger.warning(f"Failed to setup interrupt pin {pin_number}: {e}") - return None + error_msg = str(e).lower() + if "busy" in error_msg or "device or resource busy" in error_msg: + print(f"\nFATAL: GPIO Pin {pin_number} is already in use") + print("━" * 60) + print("The pin is being used by another process.") + print(f"\nDebug: sudo lsof /dev/gpiochip* | grep {pin_number}") + print("\nThe system cannot function without GPIO access.") + print("━" * 60) + sys.exit(1) + elif "permission denied" in error_msg: + print(f"\nFATAL: Permission denied for GPIO pin {pin_number}") + print("━" * 60) + print("Solutions:") + print(" • Add user to gpio group: sudo usermod -a -G gpio $USER") + print(" • Then logout and login again") + print("━" * 60) + sys.exit(1) + else: + logger.error(f"Failed to setup interrupt pin {pin_number}: {e}") + print(f"\nFATAL: Cannot setup GPIO pin {pin_number}") + print("━" * 60) + print(f"Error: {e}") + print("\nThe system cannot function without GPIO access.") + print("━" * 60) + sys.exit(1) + + def _start_edge_detection(self, pin_number: int) -> None: + """Start hardware edge detection thread""" + stop_event = threading.Event() + self._edge_stop_events[pin_number] = stop_event + + thread = threading.Thread( + target=self._monitor_edge_events, + args=(pin_number, stop_event), + daemon=True, + name=f"GPIO-Edge-{pin_number}", + ) + thread.start() + self._edge_threads[pin_number] = thread + logger.debug(f"Edge detection thread started for pin {pin_number}") + + def _monitor_edge_events(self, pin_number: int, stop_event: threading.Event) -> None: + """Monitor hardware edge events using poll() for interrupts""" + try: + gpio = self._pins.get(pin_number) + if not gpio: + return + + while not stop_event.is_set() and pin_number in self._pins: + try: + # Wait for edge event (kernel blocks until interrupt) + if gpio.poll(30.0) and not stop_event.is_set(): + # Consume event from kernel queue to prevent repeated triggers + event: EdgeEvent = gpio.read_event() + + # Only process rising edges (kernel filters, but verify) + if event.edge == "rising": + callback = self._input_callbacks.get(pin_number) + if callback: + callback() + + except Exception: + if not stop_event.is_set(): + time.sleep(0.1) # Brief pause on errors + + except Exception as e: + logger.error(f"Edge detection error for pin {pin_number}: {e}") def set_pin_high(self, pin_number: int) -> bool: """Set output pin to HIGH""" - if pin_number in self._pins and hasattr(self._pins[pin_number], "on"): + if pin_number in self._pins: try: - self._pins[pin_number].on() - return True + gpio = self._pins[pin_number] + if gpio.direction == "out": + gpio.write(True) + return True + else: + logger.warning(f"Pin {pin_number} is not configured as output") except Exception as e: logger.warning(f"Failed to set pin {pin_number} HIGH: {e}") return False def set_pin_low(self, pin_number: int) -> bool: """Set output pin to LOW""" - if pin_number in self._pins and hasattr(self._pins[pin_number], "off"): + if pin_number in self._pins: try: - self._pins[pin_number].off() - return True + gpio = self._pins[pin_number] + if gpio.direction == "out": + gpio.write(False) + return True + else: + logger.warning(f"Pin {pin_number} is not configured as output") except Exception as e: logger.warning(f"Failed to set pin {pin_number} LOW: {e}") return False + def read_pin(self, pin_number: int) -> Optional[bool]: + """ + Read current state of a pin + + Returns: + True for HIGH, False for LOW, None if pin not configured or error + """ + if pin_number in self._pins: + try: + return self._pins[pin_number].read() + except Exception as e: + logger.warning(f"Failed to read pin {pin_number}: {e}") + return None + def cleanup_pin(self, pin_number: int) -> None: """Clean up a specific pin""" + # Stop any LED thread for this pin + if pin_number in self._led_stop_events: + self._led_stop_events[pin_number].set() + if pin_number in self._led_threads: + self._led_threads[pin_number].join(timeout=2.0) + del self._led_threads[pin_number] + if pin_number in self._led_stop_events: + del self._led_stop_events[pin_number] + + # Stop any edge detection thread for this pin + if pin_number in self._edge_stop_events: + self._edge_stop_events[pin_number].set() + if pin_number in self._edge_threads: + self._edge_threads[pin_number].join(timeout=2.0) + del self._edge_threads[pin_number] + if pin_number in self._edge_stop_events: + del self._edge_stop_events[pin_number] + + # Remove callback + if pin_number in self._input_callbacks: + del self._input_callbacks[pin_number] + + # Close GPIO pin if pin_number in self._pins: try: self._pins[pin_number].close() del self._pins[pin_number] + logger.debug(f"Pin {pin_number} cleaned up") except Exception as e: logger.warning(f"Failed to cleanup pin {pin_number}: {e}") def cleanup_all(self) -> None: """Clean up all managed pins""" - # Cancel any running LED tasks - for task in self._led_tasks.values(): - if not task.done(): - task.cancel() - self._led_tasks.clear() - - # Clean up pins + # Stop all LED threads + for stop_event in self._led_stop_events.values(): + stop_event.set() + for thread in self._led_threads.values(): + thread.join(timeout=2.0) + self._led_threads.clear() + self._led_stop_events.clear() + + # Stop all edge detection threads + for stop_event in self._edge_stop_events.values(): + stop_event.set() + for thread in self._edge_threads.values(): + thread.join(timeout=2.0) + self._edge_threads.clear() + self._edge_stop_events.clear() + + # Clear callbacks + self._input_callbacks.clear() + + # Clean up all pins for pin_number in list(self._pins.keys()): - self.cleanup_pin(pin_number) + try: + self._pins[pin_number].close() + del self._pins[pin_number] + except Exception as e: + logger.warning(f"Failed to cleanup pin {pin_number}: {e}") + + logger.debug("All GPIO pins cleaned up") - async def _led_blink_task(self, pin_number: int, duration: float = 3.0) -> None: - """Internal task to blink LED for specified duration""" + def _led_blink_thread( + self, pin_number: int, duration: float, stop_event: threading.Event + ) -> None: + """Internal thread function to blink LED for specified duration""" try: # Turn LED on self.set_pin_high(pin_number) logger.debug(f"LED {pin_number} turned ON for {duration}s") - # Wait for duration - await asyncio.sleep(duration) + # Wait for duration or stop event + stop_event.wait(timeout=duration) # Turn LED off self.set_pin_low(pin_number) logger.debug(f"LED {pin_number} turned OFF") - except asyncio.CancelledError: - # Turn off LED if task was cancelled - self.set_pin_low(pin_number) - logger.debug(f"LED {pin_number} task cancelled, LED turned OFF") except Exception as e: - logger.warning(f"LED {pin_number} task error: {e}") + logger.warning(f"LED {pin_number} thread error: {e}") + # Ensure LED is off on error + try: + self.set_pin_low(pin_number) + except Exception: + pass finally: - # Remove from active tasks - if pin_number in self._led_tasks: - del self._led_tasks[pin_number] + # Remove from active threads + if pin_number in self._led_threads: + del self._led_threads[pin_number] + if pin_number in self._led_stop_events: + del self._led_stop_events[pin_number] - def blink_led(self, pin_number: int, duration: float = 3.0) -> None: + def blink_led(self, pin_number: int, duration: float = 0.2) -> None: """ Blink LED for specified duration (non-blocking) Args: pin_number: GPIO pin number for LED - duration: How long to keep LED on (seconds, default: 3.0) + duration: How long to keep LED on (seconds, default: 0.2) """ if pin_number == -1: return # LED disabled @@ -168,19 +468,24 @@ def blink_led(self, pin_number: int, duration: float = 3.0) -> None: return try: - # Cancel any existing LED task for this pin - if pin_number in self._led_tasks and not self._led_tasks[pin_number].done(): - self._led_tasks[pin_number].cancel() - - # Start new LED task - loop = asyncio.get_running_loop() - self._led_tasks[pin_number] = loop.create_task( - self._led_blink_task(pin_number, duration) + # Stop any existing LED thread for this pin + if pin_number in self._led_stop_events: + self._led_stop_events[pin_number].set() + if pin_number in self._led_threads: + self._led_threads[pin_number].join(timeout=0.1) + + # Start new LED thread + stop_event = threading.Event() + self._led_stop_events[pin_number] = stop_event + + thread = threading.Thread( + target=self._led_blink_thread, + args=(pin_number, duration, stop_event), + daemon=True, + name=f"GPIO-LED-{pin_number}", ) + thread.start() + self._led_threads[pin_number] = thread - except RuntimeError: - # No event loop running - just turn on LED (won't auto-turn off) - logger.warning(f"No event loop, LED pin {pin_number} turned on (manual off required)") - self.set_pin_high(pin_number) except Exception as e: - logger.warning(f"Failed to start LED task for pin {pin_number}: {e}") + logger.warning(f"Failed to start LED thread for pin {pin_number}: {e}") diff --git a/src/pymc_core/hardware/kiss_serial_wrapper.py b/src/pymc_core/hardware/kiss_serial_wrapper.py index c41f37e..5ff2f6c 100644 --- a/src/pymc_core/hardware/kiss_serial_wrapper.py +++ b/src/pymc_core/hardware/kiss_serial_wrapper.py @@ -66,7 +66,8 @@ def __init__( Initialize KISS Serial Wrapper Args: - port: Serial port device path (e.g., '/dev/ttyUSB0', '/dev/cu.usbserial-0001', 'comm1', etc.) + port: Serial port device path (e.g., '/dev/ttyUSB0', + '/dev/cu.usbserial-0001', 'comm1', etc.) baudrate: Serial communication baud rate (default: 115200) timeout: Serial read timeout in seconds (default: 1.0) kiss_port: KISS port number (0-15, default: 0) @@ -476,7 +477,7 @@ def begin(self): async def send(self, data: bytes) -> None: """ - Send data via KISS TNC + Send data via KISS TNC. Returns None (no metadata available). Args: data: Data to send @@ -487,6 +488,7 @@ async def send(self, data: bytes) -> None: success = self.send_frame(data) if not success: raise Exception("Failed to send frame via KISS TNC") + return None async def wait_for_rx(self) -> bytes: """ diff --git a/src/pymc_core/hardware/lora/LoRaRF/SX126x.py b/src/pymc_core/hardware/lora/LoRaRF/SX126x.py index 421bbd5..31bedd4 100644 --- a/src/pymc_core/hardware/lora/LoRaRF/SX126x.py +++ b/src/pymc_core/hardware/lora/LoRaRF/SX126x.py @@ -1,58 +1,64 @@ import time - import spidev +from ...signal_utils import snr_register_to_db from .base import BaseLoRa - spi = spidev.SpiDev() +_gpio_manager = None -from gpiozero import Device - -# Force gpiozero to use LGPIOFactory - no RPi.GPIO fallback -from gpiozero.pins.lgpio import LGPIOFactory - -Device.pin_factory = LGPIOFactory() - -# GPIOZero helpers for pin management -from gpiozero import DigitalInputDevice, DigitalOutputDevice - -_gpio_pins = {} +def set_gpio_manager(gpio_manager): + """Set the GPIO manager instance to be used by this module""" + global _gpio_manager + _gpio_manager = gpio_manager def _get_output(pin): - if pin not in _gpio_pins: - _gpio_pins[pin] = DigitalOutputDevice(pin) - return _gpio_pins[pin] + """Get output pin via centralized GPIO manager (setup only if needed)""" + if _gpio_manager is None: + raise RuntimeError("GPIO manager not initialized. Call set_gpio_manager() first.") + # Only setup if pin doesn't exist yet + if pin not in _gpio_manager._pins: + _gpio_manager.setup_output_pin(pin, initial_value=True) + return _gpio_manager._pins[pin] def _get_input(pin): - if pin not in _gpio_pins: - _gpio_pins[pin] = DigitalInputDevice(pin) - return _gpio_pins[pin] + """Get input pin via centralized GPIO manager (setup only if needed)""" + if _gpio_manager is None: + raise RuntimeError("GPIO manager not initialized. Call set_gpio_manager() first.") + # Only setup if pin doesn't exist yet + if pin not in _gpio_manager._pins: + _gpio_manager.setup_input_pin(pin) + return _gpio_manager._pins[pin] def _get_output_safe(pin): """Get output pin safely - return None if GPIO busy""" - try: - if pin not in _gpio_pins: - _gpio_pins[pin] = DigitalOutputDevice(pin) - return _gpio_pins[pin] - except Exception as e: - if "GPIO busy" in str(e): + if _gpio_manager is None: + return None + # Only setup if pin doesn't exist yet + if pin not in _gpio_manager._pins: + if not _gpio_manager.setup_output_pin(pin, initial_value=True): return None - raise e + return _gpio_manager._pins.get(pin) def _get_input_safe(pin): """Get input pin safely - return None if GPIO busy""" - try: - if pin not in _gpio_pins: - _gpio_pins[pin] = DigitalInputDevice(pin) - return _gpio_pins[pin] - except Exception as e: - if "GPIO busy" in str(e): + if _gpio_manager is None: + return None + # Only setup if pin doesn't exist yet + if pin not in _gpio_manager._pins: + if not _gpio_manager.setup_input_pin(pin): return None - raise e + return _gpio_manager._pins.get(pin) + + +def _rssi_register_to_dbm(raw_value): + """Convert RSSI register units (-0.5 dBm per LSB) into dBm.""" + if raw_value is None: + return 0.0 + return raw_value / -2.0 class SX126x(BaseLoRa): @@ -386,24 +392,17 @@ def end(self): except Exception: pass - # Close all GPIO pins - global _gpio_pins - for pin_num, pin_obj in list(_gpio_pins.items()): - try: - pin_obj.close() - except Exception: - pass - - _gpio_pins.clear() + # GPIO cleanup is handled by the centralized gpio_manager + # No need to close pins here as they're managed by GPIOPinManager def reset(self) -> bool: reset_pin = _get_output_safe(self._reset) if reset_pin is None: return True # Continue if reset pin unavailable - reset_pin.off() + reset_pin.write(False) # periphery: write(False) = LOW time.sleep(0.001) - reset_pin.on() + reset_pin.write(True) # periphery: write(True) = HIGH return not self.busyCheck() def sleep(self, option=SLEEP_WARM_START): @@ -417,7 +416,7 @@ def wake(self): if self._wake != -1: wake_pin = _get_output_safe(self._wake) if wake_pin: - wake_pin.off() + wake_pin.write(False) time.sleep(0.0005) self.setStandby(self.STANDBY_RC) self._fixResistanceAntenna() @@ -432,7 +431,7 @@ def busyCheck(self, timeout: int = _busyTimeout): return False # Assume not busy to continue t = time.time() - while busy_pin.value: + while busy_pin.read(): # periphery: read() returns True for HIGH if (time.time() - t) > (timeout / 1000): return True return False @@ -496,16 +495,12 @@ def setPins( self._txen = txen self._rxen = rxen self._wake = wake - # gpiozero pins are initialized on first use by _get_output/_get_input + # periphery pins are initialized on first use by _get_output/_get_input _get_output(reset) _get_input(busy) _get_output(self._cs_define) - # Only create a DigitalInputDevice for IRQ if not already managed externally - # (e.g., by main driver with gpiozero.Button). This avoids double allocation errors. - # If you use a Button in the main driver, do NOT call _get_input here. - # Commented out to prevent double allocation: - # if irq != -1: - # _get_input(irq) + # IRQ pin managed externally by sx1262_wrapper.py via gpio_manager + # Do NOT initialize it here to avoid double allocation if txen != -1: _get_output(txen) # if rxen != -1: _get_output(rxen) @@ -573,57 +568,116 @@ def setFrequency(self, frequency: int): self.setRfFrequency(rfFreq) def setTxPower(self, txPower: int, version=TX_POWER_SX1262): - # maximum TX power is 22 dBm and 15 dBm for SX1261 + # ----------------------------- + # Chipset-specific hard limits + # ----------------------------- if txPower > 22: txPower = 22 - elif txPower > 15 and version == self.TX_POWER_SX1261: + if version == self.TX_POWER_SX1261 and txPower > 15: txPower = 15 + if txPower < -17: + txPower = -17 + # Default configuration paDutyCycle = 0x00 hpMax = 0x00 deviceSel = 0x00 - power = 0x0E - if version == self.TX_POWER_SX1261: + paLut = 0x01 + + # ============================= + # SX1262 (E22 modules) + # ============================= + if version == self.TX_POWER_SX1262: + # Per datasheet 13.4.4: power parameter is in dBm directly + powerReg = txPower + + # Configure OCP (Over Current Protection) for high power only + # For high power (≥20 dBm), need 140 mA current limit + # For lower power, leave chip default (matches RadioLib behavior) + if txPower >= 20: + # High power: Set OCP to 140 mA + # Formula: I_max = 2.5 * (OCP + 1) mA + # 0x38 = 56 decimal → (56 + 1) * 2.5 = 142.5 mA + self.setCurrentProtection(0x38) # 140 mA + + # Matches RadioLib's SX1262::setOutputPower() implementation + deviceSel = 0x00 # SX1262 PA (0x00 for SX1262, 0x01 for SX1261) + paDutyCycle = 0x04 # Optimal duty cycle for high power + hpMax = 0x07 # Maximum clamping level (allows full +22 dBm) + + # Note: For E22-900M30S modules, 22 dBm from SX1262 chip + # → ~30 dBm (1W) output via external YP2233W PA + + # ============================= + # SX1261 + # ============================= + elif version == self.TX_POWER_SX1261: deviceSel = 0x01 - # set parameters for PA config and TX params configuration - if txPower == 22: - paDutyCycle = 0x04 - hpMax = 0x07 - power = 0x16 - elif txPower >= 20: - paDutyCycle = 0x03 - hpMax = 0x05 - power = 0x16 - elif txPower >= 17: - paDutyCycle = 0x02 - hpMax = 0x03 - power = 0x16 - elif txPower >= 14 and version == self.TX_POWER_SX1261: - paDutyCycle = 0x04 - hpMax = 0x00 - power = 0x0E - elif txPower >= 14 and version == self.TX_POWER_SX1262: - paDutyCycle = 0x02 - hpMax = 0x02 - power = 0x16 - elif txPower >= 14 and version == self.TX_POWER_SX1268: - paDutyCycle = 0x04 - hpMax = 0x06 - power = 0x0F - elif txPower >= 10 and version == self.TX_POWER_SX1261: - paDutyCycle = 0x01 - hpMax = 0x00 - power = 0x0D - elif txPower >= 10 and version == self.TX_POWER_SX1268: - paDutyCycle = 0x00 - hpMax = 0x03 - power = 0x0F + + if txPower >= 14: + paDutyCycle = 0x04 + hpMax = 0x00 + powerReg = 14 # Cap at max + elif txPower >= 10: + paDutyCycle = 0x01 + hpMax = 0x00 + powerReg = txPower + else: + # Low power mode + paDutyCycle = 0x00 + hpMax = 0x00 + powerReg = txPower + + # ============================= + # SX1268 + # ============================= + elif version == self.TX_POWER_SX1268: + if txPower >= 14: + paDutyCycle = 0x04 + hpMax = 0x06 + powerReg = txPower + deviceSel = 0x01 # High power PA + elif txPower >= 10: + paDutyCycle = 0x00 + hpMax = 0x03 + powerReg = txPower + deviceSel = 0x00 # Low power PA + else: + paDutyCycle = 0x00 + hpMax = 0x00 + powerReg = txPower + deviceSel = 0x00 + + # ============================= + # Unknown version (fallback) + # ============================= else: - return + if txPower == 22: + paDutyCycle = 0x04 + hpMax = 0x07 + deviceSel = 0x01 + powerReg = 22 + elif txPower >= 20: + paDutyCycle = 0x03 + hpMax = 0x05 + deviceSel = 0x01 + powerReg = txPower + elif txPower >= 17: + paDutyCycle = 0x02 + hpMax = 0x03 + deviceSel = 0x01 + powerReg = txPower + else: + paDutyCycle = 0x00 + hpMax = 0x00 + deviceSel = 0x00 + powerReg = txPower - # set power amplifier and TX power configuration - self.setPaConfig(paDutyCycle, hpMax, deviceSel, 0x01) - self.setTxParams(power, self.PA_RAMP_800U) + # ============================= + # APPLY FINAL CONFIG + # ============================= + self.setPaConfig(paDutyCycle, hpMax, deviceSel, paLut) + self.setTxParams(powerReg, self.PA_RAMP_40U) def setRxGain(self, rxGain): # set power saving or boosted gain in register @@ -821,8 +875,8 @@ def beginPacket(self): self.setBufferBaseAddress(self._bufferIndex, (self._bufferIndex + 0xFF) % 0xFF) # save current txen pin state and set txen pin to LOW if self._txen != -1: - self._txState = _get_output(self._txen).value - _get_output(self._txen).off() + self._txState = _get_output(self._txen).read() + _get_output(self._txen).write(False) self._fixLoRaBw500(self._bw) def endPacket(self, timeout: int = TX_SINGLE) -> bool: @@ -900,11 +954,11 @@ def request(self, timeout: int = RX_SINGLE) -> bool: self._statusWait = self.STATUS_RX_CONTINUOUS # save current txen pin state and set txen pin to high if self._txen != -1: - self._txState = _get_output(self._txen).value - _get_output(self._txen).on() + self._txState = _get_output(self._txen).read() + _get_output(self._txen).write(True) # set device to receive mode with configured timeout, single, or continuous operation self.setRx(rxTimeout) - # IRQ event handling should be implemented in the higher-level driver using gpiozero Button + # IRQ event handling should be implemented in the higher-level driver using periphery GPIO return True def listen(self, rxPeriod: int, sleepPeriod: int) -> bool: @@ -925,11 +979,11 @@ def listen(self, rxPeriod: int, sleepPeriod: int) -> bool: sleepPeriod = 0x00FFFFFF # save current txen pin state and set txen pin to high if self._txen != -1: - self._txState = _get_output(self._txen).value - _get_output(self._txen).on() + self._txState = _get_output(self._txen).read() + _get_output(self._txen).write(True) # set device to receive mode with configured receive and sleep period self.setRxDutyCycle(rxPeriod, sleepPeriod) - # IRQ event handling should be implemented in the higher-level driver using gpiozero Button + # IRQ event handling should be implemented in the higher-level driver using periphery GPIO return True def available(self) -> int: @@ -997,18 +1051,12 @@ def wait(self, timeout: int = 0) -> bool: # for transmit, calculate transmit time and set back txen pin to previous state self._transmitTime = time.time() - self._transmitTime if self._txen != -1: - if self._txState: - _get_output(self._txen).on() - else: - _get_output(self._txen).off() + _get_output(self._txen).write(self._txState) elif self._statusWait == self.STATUS_RX_WAIT: # for receive, get received payload length and buffer index and set back txen pin to previous state (self._payloadTxRx, self._bufferIndex) = self.getRxBufferStatus() if self._txen != -1: - if self._txState: - _get_output(self._txen).on() - else: - _get_output(self._txen).off() + _get_output(self._txen).write(self._txState) self._fixRxTimeout() elif self._statusWait == self.STATUS_RX_CONTINUOUS: # for receive continuous, get received payload length and buffer index and clear IRQ status @@ -1052,19 +1100,26 @@ def dataRate(self) -> float: def packetRssi(self) -> float: # get relative signal strength index (RSSI) of last incoming package - (rssiPkt, snrPkt, signalRssiPkt) = self.getPacketStatus() - return rssiPkt / -2.0 + rssi_dbm, _, _ = self.getSignalMetrics() + return rssi_dbm def snr(self) -> float: # get signal to noise ratio (SNR) of last incoming package - (rssiPkt, snrPkt, signalRssiPkt) = self.getPacketStatus() - if snrPkt > 127: - snrPkt = snrPkt - 256 - return snrPkt / 4.0 + _, snr_db, _ = self.getSignalMetrics() + return snr_db def signalRssi(self) -> float: - (rssiPkt, snrPkt, signalRssiPkt) = self.getPacketStatus() - return signalRssiPkt / -2.0 + _, _, signal_rssi_dbm = self.getSignalMetrics() + return signal_rssi_dbm + + def getSignalMetrics(self) -> tuple: + """Return RSSI, SNR, and signal RSSI (all in dB) for the last packet.""" + rssiPkt, snrPkt, signalRssiPkt = self.getPacketStatus() + return ( + _rssi_register_to_dbm(rssiPkt), + snr_register_to_db(snrPkt), + _rssi_register_to_dbm(signalRssiPkt), + ) def rssiInst(self) -> float: return self.getRssiInst() / -2.0 @@ -1096,10 +1151,7 @@ def _interruptTx(self, channel=None): self._transmitTime = time.time() - self._transmitTime # set back txen pin to previous state if self._txen != -1: - if self._txState: - _get_output(self._txen).on() - else: - _get_output(self._txen).off() + _get_output(self._txen).write(self._txState) # store IRQ status self._statusIrq = self.getIrqStatus() # call onTransmit function @@ -1109,10 +1161,7 @@ def _interruptTx(self, channel=None): def _interruptRx(self, channel=None): # set back txen pin to previous state if self._txen != -1: - if self._txState: - _get_output(self._txen).on() - else: - _get_output(self._txen).off() + _get_output(self._txen).write(self._txState) self._fixRxTimeout() # store IRQ status self._statusIrq = self.getIrqStatus() @@ -1435,23 +1484,23 @@ def _writeBytes(self, opCode: int, data: tuple, nBytes: int): # Adaptive CS control based on CS pin type if self._cs_define != 8: # Manual CS pin (like Waveshare GPIO 21) # Simple CS control for manual pins - _get_output(self._cs_define).off() + _get_output(self._cs_define).write(False) buf = [opCode] for i in range(nBytes): buf.append(data[i]) spi.xfer2(buf) - _get_output(self._cs_define).on() + _get_output(self._cs_define).write(True) else: # Kernel CS pin (like ClockworkPi GPIO 8) # Timing-based CS control for kernel CS pins - _get_output(self._cs_define).on() # Initial high state - _get_output(self._cs_define).off() + _get_output(self._cs_define).write(True) # Initial high state + _get_output(self._cs_define).write(False) time.sleep(0.000001) # 1µs setup time for CS buf = [opCode] for i in range(nBytes): buf.append(data[i]) spi.xfer2(buf) time.sleep(0.000001) # 1µs hold time before CS release - _get_output(self._cs_define).on() + _get_output(self._cs_define).write(True) def _readBytes(self, opCode: int, nBytes: int, address: tuple = (), nAddress: int = 0) -> tuple: if self.busyCheck(): @@ -1460,18 +1509,18 @@ def _readBytes(self, opCode: int, nBytes: int, address: tuple = (), nAddress: in # Adaptive CS control based on CS pin type if self._cs_define != 8: # Manual CS pin (like Waveshare GPIO 21) # Simple CS control for manual pins - _get_output(self._cs_define).off() + _get_output(self._cs_define).write(False) buf = [opCode] for i in range(nAddress): buf.append(address[i]) for i in range(nBytes): buf.append(0x00) feedback = spi.xfer2(buf) - _get_output(self._cs_define).on() + _get_output(self._cs_define).write(True) else: # Kernel CS pin (like ClockworkPi GPIO 8) # Timing-based CS control for kernel CS pins - _get_output(self._cs_define).on() # Initial high state - _get_output(self._cs_define).off() + _get_output(self._cs_define).write(True) # Initial high state + _get_output(self._cs_define).write(False) time.sleep(0.000001) # 1µs setup time for CS buf = [opCode] for i in range(nAddress): @@ -1480,7 +1529,7 @@ def _readBytes(self, opCode: int, nBytes: int, address: tuple = (), nAddress: in buf.append(0x00) feedback = spi.xfer2(buf) time.sleep(0.000001) # 1µs hold time before CS release - _get_output(self._cs_define).on() + _get_output(self._cs_define).write(True) return tuple(feedback[nAddress + 1 :]) diff --git a/src/pymc_core/hardware/lora/LoRaRF/SX127x.py b/src/pymc_core/hardware/lora/LoRaRF/SX127x.py index 790fe0a..d88d64f 100644 --- a/src/pymc_core/hardware/lora/LoRaRF/SX127x.py +++ b/src/pymc_core/hardware/lora/LoRaRF/SX127x.py @@ -2,6 +2,7 @@ import spidev +from ...signal_utils import snr_register_to_db from .base import BaseLoRa spi = spidev.SpiDev() @@ -31,6 +32,13 @@ def _get_input(pin): return _gpio_pins[pin] +def _rssi_register_to_dbm(raw_value): + """Convert RSSI register units (-0.5 dBm per LSB) into dBm.""" + if raw_value is None: + return 0.0 + return raw_value / -2.0 + + class SX126x(BaseLoRa): """Class for SX1261/62/68 and LLCC68 LoRa chipsets from Semtech""" @@ -989,19 +997,26 @@ def dataRate(self) -> float: def packetRssi(self) -> float: # get relative signal strength index (RSSI) of last incoming package - (rssiPkt, snrPkt, signalRssiPkt) = self.getPacketStatus() - return rssiPkt / -2.0 + rssi_dbm, _, _ = self.getSignalMetrics() + return rssi_dbm def snr(self) -> float: # get signal to noise ratio (SNR) of last incoming package - (rssiPkt, snrPkt, signalRssiPkt) = self.getPacketStatus() - if snrPkt > 127: - snrPkt = snrPkt - 256 - return snrPkt / 4.0 + _, snr_db, _ = self.getSignalMetrics() + return snr_db def signalRssi(self) -> float: - (rssiPkt, snrPkt, signalRssiPkt) = self.getPacketStatus() - return signalRssiPkt / -2.0 + _, _, signal_rssi_dbm = self.getSignalMetrics() + return signal_rssi_dbm + + def getSignalMetrics(self) -> tuple: + """Return RSSI, SNR, and signal RSSI (all in dB) for the last packet.""" + rssiPkt, snrPkt, signalRssiPkt = self.getPacketStatus() + return ( + _rssi_register_to_dbm(rssiPkt), + snr_register_to_db(snrPkt), + _rssi_register_to_dbm(signalRssiPkt), + ) def rssiInst(self) -> float: return self.getRssiInst() / -2.0 diff --git a/src/pymc_core/hardware/lora/LoRaRF/__init__.py b/src/pymc_core/hardware/lora/LoRaRF/__init__.py index 17668a0..5dd5c74 100644 --- a/src/pymc_core/hardware/lora/LoRaRF/__init__.py +++ b/src/pymc_core/hardware/lora/LoRaRF/__init__.py @@ -1,4 +1,5 @@ # __init__.py from .base import BaseLoRa from .SX126x import SX126x -from .SX127x import SX127x + +# from .SX127x import SX127x # Commented out to avoid lgpio initialization when not using SX127x hardware diff --git a/src/pymc_core/hardware/signal_utils.py b/src/pymc_core/hardware/signal_utils.py new file mode 100644 index 0000000..65b57c0 --- /dev/null +++ b/src/pymc_core/hardware/signal_utils.py @@ -0,0 +1,25 @@ +"""Helpers for translating raw radio signal metrics to engineering units.""" + +from __future__ import annotations + + +def snr_register_to_db(raw_value: int | None, *, bits: int = 8) -> float: + """Convert signed SX126x/SX127x SNR register (value * 4) into dB. + + Args: + raw_value: Raw register value as read from firmware/packet (unsigned). + bits: Width, in bits, of the stored value. Defaults to 8-bit registers but + discovery responses may use 16-bit fields. + """ + if raw_value is None: + return 0.0 + if bits <= 0 or bits > 32: + raise ValueError("bits must be between 1 and 32") + + max_value = 1 << bits + mask = max_value - 1 + value = raw_value & mask + sign_bit = 1 << (bits - 1) + if value >= sign_bit: + value -= max_value + return value / 4.0 diff --git a/src/pymc_core/hardware/sx1262_wrapper.py b/src/pymc_core/hardware/sx1262_wrapper.py index e5584c4..c8d54e5 100644 --- a/src/pymc_core/hardware/sx1262_wrapper.py +++ b/src/pymc_core/hardware/sx1262_wrapper.py @@ -1,11 +1,6 @@ """ SX1262 LoRa Radio Driver for Raspberry Pi Implements the LoRaRadio interface using the SX126x library - - -I have made some experimental changes to the cad section that I need to revisit. - - """ import asyncio @@ -17,7 +12,7 @@ from .base import LoRaRadio from .gpio_manager import GPIOPinManager -from .lora.LoRaRF.SX126x import SX126x +from .lora.LoRaRF.SX126x import SX126x, set_gpio_manager logger = logging.getLogger("SX1262_wrapper") @@ -115,6 +110,7 @@ def __init__( self.lora: Optional[SX126x] = None self.last_rssi: int = -99 self.last_snr: float = 0.0 + self.last_signal_rssi: int = -99 self._initialized = False self._rx_lock = asyncio.Lock() self._tx_lock = asyncio.Lock() @@ -126,10 +122,24 @@ def __init__( self._txled_pin_setup = False self._rxled_pin_setup = False + # Share GPIO manager instance with SX126x low-level driver + # This ensures singleton behavior - all GPIO access goes through one manager + set_gpio_manager(self._gpio_manager) + self._tx_done_event = asyncio.Event() self._rx_done_event = asyncio.Event() self._cad_event = asyncio.Event() + # Store last IRQ status for background task + self._last_irq_status = 0 + + # Track event loop for thread-safe interrupt handling + self._event_loop = None + + # Store CAD results from interrupt handler + self._last_cad_detected = False + self._last_cad_irq_status = 0 + # Custom CAD thresholds (None means use defaults) self._custom_cad_peak = None self._custom_cad_min = None @@ -170,6 +180,18 @@ def _get_tx_irq_mask(self) -> int: """Get the standard TX interrupt mask""" return self.lora.IRQ_TX_DONE | self.lora.IRQ_TIMEOUT + def _irq_trampoline(self): + """Lightweight trampoline called by GPIO thread - schedules real handler on event loop.""" + try: + if self._event_loop is not None: + self._event_loop.call_soon_threadsafe(self._handle_interrupt) + else: + logger.warning( + "IRQ received before event loop initialized; ignoring early interrupt" + ) + except Exception as e: + logger.error(f"IRQ trampoline error: {e}", exc_info=True) + def _safe_radio_operation( self, operation_name: str, operation_func, success_msg: str = None ) -> bool: @@ -189,7 +211,9 @@ def _safe_radio_operation( def _basic_radio_setup(self, use_busy_check: bool = False) -> bool: """Common radio setup: reset, standby, and LoRa packet type""" self.lora.reset() + time.sleep(0.01) # Give hardware time to complete reset self.lora.setStandby(self.lora.STANDBY_RC) + time.sleep(0.01) # Give hardware time to enter standby mode # Check if standby mode was set correctly (different methods for different boards) if use_busy_check: @@ -205,62 +229,89 @@ def _basic_radio_setup(self, use_busy_check: bool = False) -> bool: return True def _handle_interrupt(self): - """Simple instance method interrupt handler""" - logger.debug("Interrupt handler called!") + """instance method interrupt handler""" + try: if not self._initialized or not self.lora: logger.warning("Interrupt called but radio not initialized") return - # Read IRQ status and handle irqStat = self.lora.getIrqStatus() - logger.debug(f"Interrupt IRQ status: 0x{irqStat:04X}") - # Log specific interrupt types for debugging + if irqStat != 0: + self.lora.clearIrqStatus(irqStat) + + self._last_irq_status = irqStat if irqStat & self.lora.IRQ_TX_DONE: logger.debug("[TX] TX_DONE interrupt (0x{:04X})".format(self.lora.IRQ_TX_DONE)) self._tx_done_event.set() - # Check for CAD interrupts (needed for LBT) if irqStat & (self.lora.IRQ_CAD_DETECTED | self.lora.IRQ_CAD_DONE): cad_detected = bool(irqStat & self.lora.IRQ_CAD_DETECTED) - cad_done = bool(irqStat & self.lora.IRQ_CAD_DONE) - logger.debug( - f"[CAD] interrupt detected: {cad_detected}, done: {cad_done} (0x{irqStat:04X})" - ) + if cad_detected: + logger.debug(f"[CAD] Channel activity detected (0x{irqStat:04X})") + else: + logger.debug(f"[CAD] Channel clear detected (0x{irqStat:04X})") + + self._last_cad_detected = cad_detected + self._last_cad_irq_status = irqStat if hasattr(self, "_cad_event"): self._cad_event.set() - # Handle RX interrupts normally - no filtering needed since they're disabled during TX rx_interrupts = self._get_rx_irq_mask() - if irqStat & self.lora.IRQ_RX_DONE: - logger.debug("[RX] RX_DONE interrupt (0x{:04X})".format(self.lora.IRQ_RX_DONE)) - if not self._tx_lock.locked(): - self._rx_done_event.set() - else: - logger.debug("[RX] Ignoring RX_DONE during TX operation") - elif irqStat & self.lora.IRQ_CRC_ERR: - logger.debug("[RX] CRC_ERR interrupt (0x{:04X})".format(self.lora.IRQ_CRC_ERR)) - if not self._tx_lock.locked(): - self._rx_done_event.set() - else: - logger.debug("[RX] Ignoring CRC_ERR during TX operation") - elif irqStat & self.lora.IRQ_TIMEOUT: - logger.debug("[RX] TIMEOUT interrupt (0x{:04X})".format(self.lora.IRQ_TIMEOUT)) - if not self._tx_lock.locked(): - self._rx_done_event.set() - else: - logger.debug("[RX] Ignoring TIMEOUT during TX operation") - elif irqStat & rx_interrupts: - logger.debug(f"[RX] Other RX interrupt detected: 0x{irqStat & rx_interrupts:04X}") - if not self._tx_lock.locked(): - self._rx_done_event.set() - else: + if irqStat & rx_interrupts: + # Define terminal interrupts (packet complete or failed - need action) + terminal_interrupts = ( + self.lora.IRQ_RX_DONE + | self.lora.IRQ_CRC_ERR + | self.lora.IRQ_TIMEOUT + | self.lora.IRQ_HEADER_ERR + ) + + # Log all interrupt types for debugging + if irqStat & self.lora.IRQ_RX_DONE: + logger.debug("[RX] RX_DONE interrupt (0x{:04X})".format(self.lora.IRQ_RX_DONE)) + if irqStat & self.lora.IRQ_CRC_ERR: + logger.debug("[RX] CRC_ERR interrupt (0x{:04X})".format(self.lora.IRQ_CRC_ERR)) + if irqStat & self.lora.IRQ_TIMEOUT: + logger.debug("[RX] TIMEOUT interrupt (0x{:04X})".format(self.lora.IRQ_TIMEOUT)) + if irqStat & self.lora.IRQ_HEADER_ERR: + logger.debug( + "[RX] HEADER_ERR interrupt (0x{:04X})".format(self.lora.IRQ_HEADER_ERR) + ) + if irqStat & self.lora.IRQ_PREAMBLE_DETECTED: logger.debug( - f"[RX] Ignoring spurious interrupt " - f"0x{irqStat & rx_interrupts:04X} during TX operation" + "[RX] PREAMBLE_DETECTED interrupt (0x{:04X})".format( + self.lora.IRQ_PREAMBLE_DETECTED + ) + ) + if irqStat & self.lora.IRQ_SYNC_WORD_VALID: + logger.debug( + "[RX] SYNC_WORD_VALID interrupt (0x{:04X})".format( + self.lora.IRQ_SYNC_WORD_VALID + ) + ) + if irqStat & self.lora.IRQ_HEADER_VALID: + logger.debug( + "[RX] HEADER_VALID interrupt (0x{:04X})".format(self.lora.IRQ_HEADER_VALID) ) + # Only wake the background task for TERMINAL interrupts + # Intermediate interrupts (preamble, sync, header valid) are just progress updates + if irqStat & terminal_interrupts: + if not self._tx_lock.locked(): + self._rx_done_event.set() + logger.debug( + f"[RX] Terminal interrupt 0x{irqStat:04X} - waking background task" + ) + else: + logger.debug( + f"[RX] Ignoring terminal interrupt 0x{irqStat:04X} during TX operation" + ) + else: + # Non-terminal interrupt - just log it, don't wake background task + logger.debug(f"[RX] Progress interrupt 0x{irqStat:04X} - packet still incoming") + except Exception as e: logger.error(f"IRQ handler error: {e}") # Fallback: set both events if we can't read status @@ -283,9 +334,11 @@ def set_rx_callback(self, callback): ): try: loop = asyncio.get_running_loop() + # Capture event loop for thread-safe interrupt handling + self._event_loop = loop self._rx_irq_task = loop.create_task(self._rx_irq_background_task()) except RuntimeError: - pass + logger.debug("No event loop available for RX task startup") except Exception as e: logger.warning(f"Failed to start delayed RX IRQ background handler: {e}") @@ -293,109 +346,146 @@ async def _rx_irq_background_task(self): """Background task: waits for RX_DONE IRQ and processes received packets automatically.""" logger.debug("[RX] Starting RX IRQ background task") rx_check_count = 0 - while self._initialized: - if self._interrupt_setup: - # Wait for RX_DONE event - try: - await asyncio.wait_for( - self._rx_done_event.wait(), timeout=self.RADIO_TIMING_DELAY - ) - self._rx_done_event.clear() - logger.debug("[RX] RX_DONE event triggered!") - - # Mark that we're processing a packet (prevents noise floor sampling) - self._is_receiving_packet = True - self._last_packet_activity = time.time() + while self._initialized: + try: + if self._interrupt_setup: + # Wait for RX_DONE event try: - # Read and process the received packet - irqStat = self.lora.getIrqStatus() - logger.debug(f"[RX] IRQ Status: 0x{irqStat:04X}") - - # Clear ALL interrupt flags immediately to prevent duplicate processing - if irqStat != 0: - self.lora.clearIrqStatus(irqStat) - - if irqStat & self.lora.IRQ_RX_DONE: - ( - payloadLengthRx, - rxStartBufferPointer, - ) = self.lora.getRxBufferStatus() - rssiPkt, snrPkt, signalRssiPkt = self.lora.getPacketStatus() - self.last_rssi = int(rssiPkt / -2) - self.last_snr = snrPkt / 4 + await asyncio.wait_for( + self._rx_done_event.wait(), timeout=self.RADIO_TIMING_DELAY + ) + self._rx_done_event.clear() + logger.debug("[RX] RX_DONE event triggered!") - logger.debug( - f"[RX] Packet received: length={payloadLengthRx}, " - f"RSSI={self.last_rssi}dBm, SNR={self.last_snr}dB" - ) + # Mark that we're processing a packet (prevents noise floor sampling) + self._is_receiving_packet = True + self._last_packet_activity = time.time() - # Trigger RX LED - self._gpio_manager.blink_led(self.rxled_pin) + try: + # Use the IRQ status stored by the interrupt handler + irqStat = self._last_irq_status + if irqStat & self.lora.IRQ_RX_DONE: + ( + payloadLengthRx, + rxStartBufferPointer, + ) = self.lora.getRxBufferStatus() + ( + packet_rssi_dbm, + snr_db, + signal_rssi_dbm, + ) = self.lora.getSignalMetrics() + self.last_rssi = int(packet_rssi_dbm) + self.last_snr = snr_db + self.last_signal_rssi = int(signal_rssi_dbm) - if payloadLengthRx > 0: - buffer = self.lora.readBuffer(rxStartBufferPointer, payloadLengthRx) - packet_data = bytes(buffer) logger.debug( - f"[RX] Packet data: {packet_data.hex()[:32]}... " - f"({len(packet_data)} bytes)" + f"[RX] Packet received: length={payloadLengthRx}, " + f"RSSI={self.last_rssi}dBm, SNR={self.last_snr}dB" ) - # Call user RX callback if set - if self.rx_callback: - try: - logger.debug("[RX] Calling dispatcher callback") - self.rx_callback(packet_data) - except Exception as cb_exc: - logger.error(f"RX callback error: {cb_exc}") + # Trigger RX LED + self._gpio_manager.blink_led(self.rxled_pin) + + if payloadLengthRx > 0: + buffer = self.lora.readBuffer( + rxStartBufferPointer, payloadLengthRx + ) + packet_data = bytes(buffer) + logger.debug( + f"[RX] Packet data: {packet_data.hex()[:32]}... " + f"({len(packet_data)} bytes)" + ) + + # Call user RX callback if set + if self.rx_callback: + try: + self.rx_callback(packet_data) + except Exception as cb_exc: + logger.error(f"RX callback error: {cb_exc}") + else: + logger.warning("[RX] No RX callback registered!") else: - logger.warning("[RX] No RX callback registered!") + logger.warning("[RX] Empty packet received") + elif irqStat & self.lora.IRQ_CRC_ERR: + logger.warning("[RX] CRC error detected") + elif irqStat & self.lora.IRQ_TIMEOUT: + logger.warning("[RX] RX timeout detected") + elif irqStat & self.lora.IRQ_HEADER_ERR: + logger.warning( + f"[RX] Header error detected (0x{irqStat:04X}) - " + "corrupted header, restoring RX mode" + ) + elif irqStat & self.lora.IRQ_PREAMBLE_DETECTED: + logger.debug("[RX] Preamble detected - packet incoming") + elif irqStat & self.lora.IRQ_SYNC_WORD_VALID: + logger.debug("[RX] Sync word valid - receiving packet data") + elif irqStat & self.lora.IRQ_HEADER_VALID: + logger.debug( + "[RX] Header valid - packet header received, payload coming" + ) else: - logger.warning("[RX] Empty packet received") - elif irqStat & self.lora.IRQ_CRC_ERR: - logger.warning("[RX] CRC error detected") - elif irqStat & self.lora.IRQ_TIMEOUT: - logger.warning("[RX] RX timeout detected") - elif irqStat & self.lora.IRQ_PREAMBLE_DETECTED: - pass - elif irqStat & self.lora.IRQ_SYNC_WORD_VALID: - pass # Sync word valid - receiving packet data... - elif irqStat & self.lora.IRQ_HEADER_VALID: - pass # Header valid - packet header received, payload coming... - elif irqStat & self.lora.IRQ_HEADER_ERR: - pass # Header error - corrupted header, packet dropped - else: - pass # Other RX interrupt - - # Always restore RX continuous mode after processing any interrupt - # This ensures the radio stays ready for the next packet - try: - self.lora.setRx(self.lora.RX_CONTINUOUS) - await asyncio.sleep(self.RADIO_TIMING_DELAY) + logger.debug(f"[RX] Other interrupt: 0x{irqStat:04X}") + + # Always restore RX continuous mode after processing any interrupt + # This ensures the radio stays ready for the next packet + try: + self.lora.request(self.lora.RX_CONTINUOUS) + await asyncio.sleep(self.RADIO_TIMING_DELAY) + logger.debug( + f"[RX] Restored RX continuous mode after IRQ 0x{irqStat:04X}" + ) + except Exception as e: + logger.error(f"Failed to restore RX mode: {e}") except Exception as e: - logger.debug(f"Failed to restore RX mode: {e}") - except Exception as e: - logger.error(f"[IRQ RX] Error processing received packet: {e}") - finally: - # Clear packet processing flag - self._is_receiving_packet = False - - except asyncio.TimeoutError: - # No RX event within timeout - normal operation - rx_check_count += 1 - - # Sample noise floor during quiet periods - self._sample_noise_floor() - - # Log every 500 checks (roughly every 5 seconds) to show RX task is alive - if rx_check_count % 500 == 0: - logger.debug( - f"[RX Task] Status check #{rx_check_count}, " - f"noise_floor={self._noise_floor:.1f}dBm" - ) + logger.error(f"[IRQ RX] Error processing received packet: {e}") + finally: + # Clear packet processing flag + self._is_receiving_packet = False - else: - await asyncio.sleep(0.1) # Longer delay when interrupts not set up + except asyncio.TimeoutError: + # No RX event within timeout - normal operation + rx_check_count += 1 + + # Sample noise floor during quiet periods + self._sample_noise_floor() + + # Log every 500 checks (roughly every 5 seconds) to show RX task is alive + if rx_check_count % 500 == 0: + logger.debug( + f"[RX Task] Status check #{rx_check_count}, " + f"noise_floor={self._noise_floor:.1f}dBm" + ) + + else: + await asyncio.sleep(0.1) + + except Exception as e: + logger.error(f"[RX Task] Unexpected error: {e}") + await asyncio.sleep(1.0) # Wait and continue + + logger.warning("[RX] RX IRQ background task exiting") + + def check_radio_health(self) -> bool: + """Simple health check - restart RX task if it's dead.""" + if not self._initialized: + return False + + # Check if RX task is dead and restart it + if ( + not hasattr(self, "_rx_irq_task") + or self._rx_irq_task is None + or self._rx_irq_task.done() + ): + try: + loop = asyncio.get_running_loop() + self._rx_irq_task = loop.create_task(self._rx_irq_background_task()) + logger.warning("[RX] Restarted dead RX task") + return False # Was dead, now restarted + except Exception: + return False # Failed to restart + + return True # Task is alive def begin(self) -> bool: """Initialize the SX1262 radio module. Returns True if successful, False otherwise.""" @@ -407,8 +497,10 @@ def begin(self) -> bool: try: logger.debug("Initializing SX1262 radio...") self.lora = SX126x() + + # Register GPIO interrupt using lightweight trampoline self.irq_pin = self._gpio_manager.setup_interrupt_pin( - self.irq_pin_number, pull_up=False, callback=self._handle_interrupt + self.irq_pin_number, pull_up=False, callback=self._irq_trampoline ) if self.irq_pin is not None: @@ -423,9 +515,6 @@ def begin(self) -> bool: # Override CS pin for special boards (e.g., Waveshare HAT) self.lora.setManualCsPin(self.cs_pin) - # Don't call setPins! It creates duplicate GPIO objects that conflict - # with our Button/GPIOManager - # Instead, manually set the pin variables the SX126x needs self.lora._reset = self.reset_pin self.lora._busy = self.busy_pin self.lora._irq = self.irq_pin_number @@ -495,19 +584,20 @@ def begin(self) -> bool: False, # IQ standard ) - self.lora.setPaConfig(0x02, 0x03, 0x00, 0x01) - self.lora.setTxParams(self.tx_power, self.lora.PA_RAMP_200U) + # Use RadioLib-compatible PA configuration and optimized setTxPower + # This automatically configures PA based on requested power level + self.lora.setTxPower(self.tx_power, self.lora.TX_POWER_SX1262) # Configure RX interrupts (critical for RX functionality!) rx_mask = self._get_rx_irq_mask() - self.lora.setDioIrqParams(rx_mask, rx_mask, self.lora.IRQ_NONE, self.lora.IRQ_NONE) self.lora.clearIrqStatus(0xFFFF) + self.lora.setDioIrqParams(rx_mask, rx_mask, self.lora.IRQ_NONE, self.lora.IRQ_NONE) else: # Use full initialization # Reset RF module and set to standby if not self._basic_radio_setup(use_busy_check=True): return False - + self.lora._fixResistanceAntenna() # Configure TCXO, regulator, calibration and RF switch if self.use_dio3_tcxo: # Map voltage to DIO3 constants following Meshtastic pattern @@ -544,7 +634,7 @@ def begin(self) -> bool: self.lora.setRegulatorMode(self.lora.REGULATOR_DC_DC) self.lora.calibrate(0x7F) - self.lora.setDio2RfSwitch() + self.lora.setDio2RfSwitch(False) # Set packet type and frequency rfFreq = int(self.frequency * 33554432 / 32000000) @@ -552,8 +642,10 @@ def begin(self) -> bool: # Set RX gain and TX power self.lora.writeRegister(self.lora.REG_RX_GAIN, [self.lora.RX_GAIN_POWER_SAVING], 1) - self.lora.setPaConfig(0x02, 0x03, 0x00, 0x01) - self.lora.setTxParams(self.tx_power, self.lora.PA_RAMP_200U) + # Use setTxPower for automatic PA configuration based on power level + # For E22 modules: 22 dBm from SX1262 → ~30 dBm (1W) via external YP2233W PA + logger.info(f"Setting TX power to {self.tx_power} dBm during initialization") + self.lora.setTxPower(self.tx_power, self.lora.TX_POWER_SX1262) # Configure modulation and packet parameters # Enable LDRO if symbol duration > 16ms (SF11/62.5kHz = 32.768ms) @@ -576,8 +668,10 @@ def begin(self) -> bool: # Configure RX interrupts rx_mask = self._get_rx_irq_mask() - self.lora.setDioIrqParams(rx_mask, rx_mask, self.lora.IRQ_NONE, self.lora.IRQ_NONE) self.lora.clearIrqStatus(0xFFFF) + self.lora.setDioIrqParams(rx_mask, rx_mask, self.lora.IRQ_NONE, self.lora.IRQ_NONE) + # Configure RX gain for maximum sensitivity (boosted mode) + self.lora.setRxGain(self.lora.RX_GAIN_BOOSTED) # Program custom CAD thresholds to chip hardware if available if self._custom_cad_peak is not None and self._custom_cad_min is not None: @@ -597,8 +691,7 @@ def begin(self) -> bool: except Exception as e: logger.warning(f"Failed to write CAD thresholds: {e}") - # Set to RX continuous mode for initial operation - self.lora.setRx(self.lora.RX_CONTINUOUS) + self.lora.request(self.lora.RX_CONTINUOUS) self._initialized = True logger.info("SX1262 radio initialized successfully") @@ -614,6 +707,8 @@ def begin(self) -> bool: ): try: loop = asyncio.get_running_loop() + # Capture event loop for thread-safe interrupt handling + self._event_loop = loop except RuntimeError: # No event loop running, we'll start the task later # when one is available @@ -633,20 +728,48 @@ def begin(self) -> bool: raise RuntimeError(f"Failed to initialize SX1262 radio: {e}") from e def _calculate_tx_timeout(self, packet_length: int) -> tuple[int, int]: - """Calculate transmission timeout using C++ MeshCore formula""" + """ + Calculate the LoRa packet airtime and transmission timeout using the standard + Semtech formula. + + This method implements the LoRa airtime calculation as described in the Semtech + LoRa Modem Designer's Guide (AN1200.13, section 4.1), taking into account the + following parameters: + - Spreading Factor (SF) + - Bandwidth (BW) + - Coding Rate (CR) + - Preamble length + - Explicit/implicit header mode (always explicit here) + - CRC enabled (always enabled here) + - Low Data Rate Optimization (enabled if SF >= 11 and BW <= 125 kHz) + - Payload length (packet_length) - symbol_time = float(1 << self.spreading_factor) / float(self.bandwidth) - preamble_time = (self.preamble_length + 4.25) * symbol_time - tmp = (8 * packet_length) - (4 * self.spreading_factor) + 28 + 16 - # CRC is enabled - tmp -= 16 + Returns: + timeout_ms (int): Calculated packet transmission timeout in milliseconds + (airtime + margin). + driver_timeout (int): Timeout value in units required by the radio driver + (typically ms * 64). + """ + sf = self.spreading_factor + bw_hz = int(self.bandwidth) # your class already stores Hz + cr = self.coding_rate # 1→4/5, 2→4/6, 3→4/7, 4→4/8 + preamble = self.preamble_length + crc_on = True # you always enable CRC + explicit_header = True # you always use explicit header + low_dr_opt = 1 if (sf >= 11 and bw_hz <= 125000) else 0 + symbol_time = (1 << sf) / float(bw_hz) + preamble_time = (preamble + 4.25) * symbol_time + ih = 0 if explicit_header else 1 + crc = 1 if crc_on else 0 + + tmp = 8 * packet_length - 4 * sf + 28 + 16 * crc - 20 * ih + + denom = 4 * (sf - 2 * low_dr_opt) if tmp > 0: - payload_symbols = 8.0 + math.ceil(float(tmp) / float(4 * self.spreading_factor)) * ( - self.coding_rate + 4 - ) + payload_symbols = 8 + max(math.ceil(tmp / denom) * (cr + 4), 0) else: - payload_symbols = 8.0 + payload_symbols = 8 payload_time = payload_symbols * symbol_time air_time_ms = (preamble_time + payload_time) * 1000.0 @@ -654,14 +777,14 @@ def _calculate_tx_timeout(self, packet_length: int) -> tuple[int, int]: driver_timeout = timeout_ms * 64 logger.debug( - f"TX timing SF{self.spreading_factor}/{self.bandwidth/1000:.1f}kHz " - f"CR4/{self.coding_rate} {packet_length}B: " - f"symbol={symbol_time*1000:.1f}ms, " - f"preamble={preamble_time*1000:.0f}ms, " + f"TX timing SF{sf}/{bw_hz/1000:.1f}kHz " + f"CR4/{cr} {packet_length}B: " + f"symbol={symbol_time*1000:.3f}ms, " + f"preamble={preamble_time*1000:.1f}ms, " f"tmp={tmp}, " f"payload_syms={payload_symbols:.1f}, " - f"payload={payload_time*1000:.0f}ms, " - f"air_time={air_time_ms:.0f}ms, " + f"payload={payload_time*1000:.1f}ms, " + f"air_time={air_time_ms:.1f}ms, " f"timeout={timeout_ms}ms, " f"driver_timeout={driver_timeout}" ) @@ -670,13 +793,8 @@ def _calculate_tx_timeout(self, packet_length: int) -> tuple[int, int]: def _prepare_packet_transmission(self, data_list: list, length: int) -> None: """Prepare radio for packet transmission""" - # Set buffer base address self.lora.setBufferBaseAddress(0x00, 0x80) - - # Write the message to buffer self.lora.writeBuffer(0x00, data_list, length) - - # Configure packet parameters for this transmission headerType = self.lora.HEADER_EXPLICIT preambleLength = self.preamble_length crcType = self.lora.CRC_ON @@ -686,22 +804,18 @@ def _prepare_packet_transmission(self, data_list: list, length: int) -> None: def _setup_tx_interrupts(self) -> None: """Configure interrupts for transmission - TX and CAD only, disable RX interrupts""" - # Set up TX and CAD interrupts only - this prevents spurious RX interrupts during TX mask = self._get_tx_irq_mask() | self.lora.IRQ_CAD_DONE | self.lora.IRQ_CAD_DETECTED self.lora.setDioIrqParams(mask, mask, self.lora.IRQ_NONE, self.lora.IRQ_NONE) - # Clear any existing interrupt flags before starting existing_irq = self.lora.getIrqStatus() if existing_irq != 0: self.lora.clearIrqStatus(existing_irq) - async def _prepare_radio_for_tx(self) -> bool: - """Prepare radio hardware for transmission. Returns True if successful.""" - # Clear the TX done event before starting transmission + async def _prepare_radio_for_tx(self) -> tuple[bool, list[float]]: + """Prepare radio hardware for transmission. Returns (success, lbt_backoff_delays_ms).""" self._tx_done_event.clear() - - # Ensure radio is in standby before TX setup self.lora.setStandby(self.lora.STANDBY_RC) + await asyncio.sleep(self.RADIO_TIMING_DELAY) # Give hardware time to enter standby if self.lora.busyCheck(): busy_wait = 0 while self.lora.busyCheck() and busy_wait < 20: @@ -711,14 +825,19 @@ async def _prepare_radio_for_tx(self) -> bool: # Listen Before Talk (LBT) - Check for channel activity using CAD lbt_attempts = 0 max_lbt_attempts = 5 + lbt_backoff_delays = [] # Track each backoff delay in ms + while lbt_attempts < max_lbt_attempts: try: # Perform CAD with your custom thresholds channel_busy = await self.perform_cad(timeout=0.5) if not channel_busy: - logger.debug(f"Channel clear after {lbt_attempts + 1} CAD checks") + logger.debug( + f"CAD check clear - channel available after {lbt_attempts + 1} attempts" + ) break else: + logger.debug("CAD check still busy - channel activity detected") lbt_attempts += 1 if lbt_attempts < max_lbt_attempts: # Jitter (50-200ms) @@ -728,23 +847,25 @@ async def _prepare_radio_for_tx(self) -> bool: # Cap at 5 seconds maximum backoff_ms = min(backoff_ms, 5000) + # Record this backoff delay + lbt_backoff_delays.append(float(backoff_ms)) + logger.debug( - f"Channel busy (CAD detected activity), backing off {backoff_ms}ms " - f"- attempt {lbt_attempts}/{max_lbt_attempts} (exponential backoff)" + f"CAD backoff - waiting {backoff_ms}ms before retry " + f"(attempt {lbt_attempts}/{max_lbt_attempts})" ) await asyncio.sleep(backoff_ms / 1000.0) else: logger.warning( - f"Channel still busy after {max_lbt_attempts} CAD attempts - tx anyway" + f"CAD max attempts reached - channel still busy after " + f"{max_lbt_attempts} attempts, transmitting anyway" ) except Exception as e: - logger.debug(f"CAD check failed: {e}, proceeding with transmission") + logger.warning(f"CAD check failed: {e}, proceeding with transmission") break - # Set TXEN/RXEN pins for TX mode self._control_tx_rx_pins(tx_mode=True) - # Check busy status before starting transmission if self.lora.busyCheck(): logger.warning("Radio is busy before starting transmission") # Wait for radio to become ready @@ -754,9 +875,9 @@ async def _prepare_radio_for_tx(self) -> bool: busy_timeout += 1 if self.lora.busyCheck(): logger.error("Radio stayed busy - cannot start transmission") - return False + return False, lbt_backoff_delays - return True + return True, lbt_backoff_delays def _control_tx_rx_pins(self, tx_mode: bool) -> None: """Control TXEN/RXEN pins for the E22 module (simple and deterministic).""" @@ -858,7 +979,8 @@ def _finalize_transmission(self) -> None: elif irqStat & self.lora.IRQ_TIMEOUT: logger.warning("TX_TIMEOUT interrupt received - transmission failed") else: - logger.warning(f"Unexpected interrupt status: 0x{irqStat:04X}") + # No warning for 0x0000 - interrupt already cleared by handler + pass # Get transmission stats if available try: @@ -880,21 +1002,12 @@ async def _restore_rx_mode(self) -> None: logger.debug("[TX->RX] Starting RX mode restoration after transmission") try: if self.lora: - # Clear any interrupt flags and set standby self.lora.clearIrqStatus(0xFFFF) self.lora.setStandby(self.lora.STANDBY_RC) - - # Brief delay for radio to settle await asyncio.sleep(0.05) - # Configure full RX interrupts and set RX continuous mode - rx_mask = ( - self._get_rx_irq_mask() | self.lora.IRQ_CAD_DONE | self.lora.IRQ_CAD_DETECTED - ) - self.lora.setDioIrqParams(rx_mask, rx_mask, self.lora.IRQ_NONE, self.lora.IRQ_NONE) - self.lora.setRx(self.lora.RX_CONTINUOUS) + self.lora.request(self.lora.RX_CONTINUOUS) - # Final clear of any spurious flags and we're done await asyncio.sleep(0.05) self.lora.clearIrqStatus(0xFFFF) @@ -903,23 +1016,23 @@ async def _restore_rx_mode(self) -> None: except Exception as e: logger.warning(f"[TX->RX] Failed to restore RX mode after TX: {e}") - async def send(self, data: bytes) -> None: - """Send a packet asynchronously""" + async def send(self, data: bytes) -> dict: + """Send a packet asynchronously. Returns transmission metadata including LBT metrics.""" if not self._initialized or self.lora is None: logger.error("Radio not initialized") - return + return None async with self._tx_lock: try: - # Convert bytes to list of integers data_list = list(data) length = len(data_list) - # Calculate transmission timeout + # Calculate transmission timeout and airtime final_timeout_ms, driver_timeout = self._calculate_tx_timeout(length) - timeout_seconds = (final_timeout_ms / 1000.0) + 3.0 # Add 3 seconds buffer + timeout_seconds = (final_timeout_ms / 1000.0) + 0.5 # Add margin + # Airtime is the timeout minus the 1000ms margin we add + airtime_ms = final_timeout_ms - 1000 - # Prepare packet for transmission self._prepare_packet_transmission(data_list, length) logger.debug( @@ -927,30 +1040,38 @@ async def send(self, data: bytes) -> None: f"(tOut={driver_timeout}) for {length} bytes" ) - if not await self._prepare_radio_for_tx(): - return + # Prepare for TX and capture LBT metrics + tx_ready, lbt_backoff_delays = await self._prepare_radio_for_tx() + if not tx_ready: + return None # Setup TX interrupts AFTER CAD checks (CAD changes interrupt config) self._setup_tx_interrupts() await asyncio.sleep(self.RADIO_TIMING_DELAY) + self.lora.setTxPower(self.tx_power, self.lora.TX_POWER_SX1262) - # Execute the transmission if not await self._execute_transmission(driver_timeout): - return + return None - # Wait for transmission to complete if not await self._wait_for_transmission_complete(timeout_seconds): - return + return None - # Finalize transmission and log results self._finalize_transmission() # Trigger TX LED self._gpio_manager.blink_led(self.txled_pin) + # Build and return transmission metadata + return { + "airtime_ms": airtime_ms, + "lbt_attempts": len(lbt_backoff_delays), + "lbt_backoff_delays_ms": lbt_backoff_delays, + "lbt_channel_busy": len(lbt_backoff_delays) > 0, + } + except Exception as e: logger.error(f"Failed to send packet: {e}") - return + return None finally: # Always leave radio in RX continuous mode after TX await self._restore_rx_mode() @@ -978,6 +1099,10 @@ def get_last_snr(self) -> float: """Return last received SNR in dB""" return self.last_snr + def get_last_signal_rssi(self) -> int: + """Return last received signal RSSI in dBm""" + return self.last_signal_rssi + def _sample_noise_floor(self) -> None: """Sample noise floor""" if not self._initialized or self.lora is None: @@ -1094,6 +1219,7 @@ def get_status(self) -> dict: "coding_rate": self.coding_rate, "last_rssi": self.last_rssi, "last_snr": self.last_snr, + "last_signal_rssi": self.last_signal_rssi, } if self._initialized and self.lora: @@ -1176,6 +1302,7 @@ async def perform_cad( try: # Put radio in standby mode before CAD configuration self.lora.setStandby(self.lora.STANDBY_RC) + await asyncio.sleep(0.01) # Give hardware time to enter standby # Clear any existing interrupt flags existing_irq = self.lora.getIrqStatus() @@ -1194,27 +1321,34 @@ async def perform_cad( 0, # no timeout ) - # Clear CAD event before starting self._cad_event.clear() - - # Start CAD operation self.lora.setCad() + await asyncio.sleep(0.01) # Give hardware time to start CAD operation - logger.debug(f"CAD started with peak={det_peak}, min={det_min}") + logger.debug( + f"CAD operation started - checking channel with peak={det_peak}, min={det_min}" + ) - # Wait for CAD completion try: await asyncio.wait_for(self._cad_event.wait(), timeout=timeout) self._cad_event.clear() - irq = self.lora.getIrqStatus() - logger.debug(f"CAD completed with IRQ status: 0x{irq:04X}") - self.lora.clearIrqStatus(irq) - detected = bool(irq & self.lora.IRQ_CAD_DETECTED) + # Use CAD results stored by interrupt handler (avoids race condition) + irq = self._last_cad_irq_status + detected = self._last_cad_detected cad_done = bool(irq & self.lora.IRQ_CAD_DONE) - if not cad_done: - logger.warning("CAD interrupt received but CAD_DONE flag not set") + logger.debug(f"CAD operation completed - IRQ status: 0x{irq:04X}") + + if detected: + logger.debug("CAD result: BUSY - channel activity detected") + else: + logger.debug("CAD result: CLEAR - no channel activity detected") + + # Clear hardware IRQ status + current_irq = self.lora.getIrqStatus() + if current_irq != 0: + self.lora.clearIrqStatus(current_irq) if calibration: return { @@ -1231,8 +1365,7 @@ async def perform_cad( return detected except asyncio.TimeoutError: - logger.debug("CAD operation timed out") - # Check if there were any interrupt flags set anyway + logger.debug("CAD operation timed out - assuming channel clear") irq = self.lora.getIrqStatus() if irq != 0: logger.debug(f"CAD timeout but IRQ status: 0x{irq:04X}") @@ -1266,11 +1399,16 @@ async def perform_cad( else: return False finally: - # Restore RX mode after CAD try: - rx_mask = self._get_rx_irq_mask() - self.lora.setDioIrqParams(rx_mask, rx_mask, self.lora.IRQ_NONE, self.lora.IRQ_NONE) - self.lora.setRx(self.lora.RX_CONTINUOUS) + # Full sequence required to prevent SX1262 lockups after CAD + self.lora.clearIrqStatus(0xFFFF) + self.lora.setStandby(self.lora.STANDBY_RC) + await asyncio.sleep(0.001) + + # Use standardized RX restoration method like everywhere else + self.lora.request(self.lora.RX_CONTINUOUS) + await asyncio.sleep(0.001) + self.lora.clearIrqStatus(0xFFFF) except Exception as e: logger.warning(f"Failed to restore RX mode after CAD: {e}") diff --git a/src/pymc_core/hardware/wsradio.py b/src/pymc_core/hardware/wsradio.py index 358c7f8..637b3ab 100644 --- a/src/pymc_core/hardware/wsradio.py +++ b/src/pymc_core/hardware/wsradio.py @@ -114,6 +114,7 @@ def begin(self): # Connection will be established when needed async def send(self, data: bytes): + """Send packet via WebSocket. Returns None (no metadata available).""" await self._ensure() try: if self.ws is not None and self._connected: diff --git a/src/pymc_core/node/dispatcher.py b/src/pymc_core/node/dispatcher.py index f405d37..5a0a2ea 100644 --- a/src/pymc_core/node/dispatcher.py +++ b/src/pymc_core/node/dispatcher.py @@ -18,6 +18,7 @@ AckHandler, AdvertHandler, AnonReqResponseHandler, + ControlHandler, GroupTextHandler, LoginResponseHandler, PathHandler, @@ -82,6 +83,9 @@ def __init__( self._recent_acks: dict[int, float] = {} # {crc: timestamp} self._waiting_acks = {} + # Simple TX lock to prevent concurrent transmissions + self._tx_lock = asyncio.Lock() + # Use provided packet filter or create default if packet_filter is not None: self.packet_filter = packet_filter @@ -155,7 +159,7 @@ def register_default_handlers( # Register all the standard handlers self.register_handler( AdvertHandler.payload_type(), - AdvertHandler(contacts, self._log, local_identity, event_service), + AdvertHandler(self._log), ) self.register_handler(AckHandler.payload_type(), ack_handler) @@ -203,7 +207,9 @@ def register_default_handlers( self.telemetry_response_handler = protocol_response_handler # PATH handler - for route discovery packets, with ACK and protocol response processing - path_handler = PathHandler(self._log, ack_handler, protocol_response_handler) + path_handler = PathHandler( + self._log, ack_handler, protocol_response_handler, login_response_handler + ) self.register_handler(PathHandler.payload_type(), path_handler) # Login response handler for PAYLOAD_TYPE_RESPONSE packets @@ -227,6 +233,15 @@ def register_default_handlers( # Keep a reference for the node self.trace_handler = trace_handler + # CONTROL handler for node discovery + control_handler = ControlHandler(self._log) + self.register_handler( + ControlHandler.payload_type(), + control_handler, + ) + # Keep a reference for the node + self.control_handler = control_handler + self._logger.info("Default handlers registered.") # Set up a fallback handler for unknown packet types @@ -284,8 +299,6 @@ def set_raw_packet_callback( def _on_packet_received(self, data: bytes) -> None: """Called by the radio when a packet comes in.""" - self._log(f"[RX DEBUG] Packet received: {len(data)} bytes") - # Schedule the packet processing in the event loop try: loop = asyncio.get_running_loop() @@ -378,6 +391,7 @@ async def send_packet( ) -> bool: """ Send a packet and optionally wait for an ACK. + Uses a lock to serialize transmissions instead of dropping packets. Args: packet: The packet to send @@ -385,22 +399,26 @@ async def send_packet( expected_crc: The expected CRC for ACK matching. If None, will be calculated from packet. """ - payload_type = packet.header >> PH_TYPE_SHIFT + async with self._tx_lock: # Wait our turn + return await self._send_packet_immediate(packet, wait_for_ack, expected_crc) - # ------------------------------------------------------------------ # - # Make sure we're not already busy - # ------------------------------------------------------------------ # - if self.state != DispatcherState.IDLE: - self._log("Busy, skipping TX.") - return False + async def _send_packet_immediate( + self, + packet: Packet, + wait_for_ack: bool = True, + expected_crc: Optional[int] = None, + ) -> bool: + """Send a packet immediately (assumes lock is held).""" + payload_type = packet.header >> PH_TYPE_SHIFT # ------------------------------------------------------------------ # - # Send the packet + # Send the packet (lock ensures only one transmission at a time) # ------------------------------------------------------------------ # self.state = DispatcherState.TRANSMIT raw = packet.write_to() + tx_metadata = None try: - await self.radio.send(raw) + tx_metadata = await self.radio.send(raw) except Exception as e: self._log(f"Radio transmit error: {e}") self.state = DispatcherState.IDLE @@ -410,6 +428,10 @@ async def send_packet( route_name = ROUTE_TYPES.get(packet.get_route_type(), f"UNKNOWN_{packet.get_route_type()}") self._log(f"TX {packet.get_raw_length()} bytes (type={type_name}, route={route_name})") + # Store metadata on packet for access by handlers + if tx_metadata: + packet._tx_metadata = tx_metadata + if self.packet_sent_callback: await self._invoke_callback(self.packet_sent_callback, packet) @@ -516,6 +538,7 @@ def _register_ack_received(self, crc: int) -> None: async def run_forever(self) -> None: """Run the dispatcher maintenance loop indefinitely (call this in an asyncio task).""" + health_check_counter = 0 while True: # Clean out old ACK CRCs (older than 5 seconds) now = asyncio.get_event_loop().time() @@ -524,6 +547,13 @@ async def run_forever(self) -> None: # Clean old packet hashes for deduplication self.packet_filter.cleanup_old_hashes() + # Simple health check every 60 seconds + health_check_counter += 1 + if health_check_counter >= 60: + health_check_counter = 0 + if hasattr(self.radio, "check_radio_health"): + self.radio.check_radio_health() + # With callback-based RX, just do maintenance tasks await asyncio.sleep(1.0) # Check every second for cleanup @@ -576,7 +606,9 @@ def _log(self, msg: str) -> None: def get_filter_stats(self) -> dict: """Get current packet filter statistics.""" - return self.packet_filter.get_stats() + stats = self.packet_filter.get_stats() + stats["tx_lock_locked"] = self._tx_lock.locked() + return stats def clear_packet_filter(self) -> None: """Clear packet filter data.""" @@ -597,3 +629,7 @@ async def _find_contact_by_hash(self, src_hash: int): except Exception: continue return None + + def cleanup(self): + """Clean up resources when shutting down.""" + self._log("Dispatcher cleanup completed") diff --git a/src/pymc_core/node/handlers/__init__.py b/src/pymc_core/node/handlers/__init__.py index 16b4205..27fc5c5 100644 --- a/src/pymc_core/node/handlers/__init__.py +++ b/src/pymc_core/node/handlers/__init__.py @@ -5,6 +5,7 @@ from .ack import AckHandler from .advert import AdvertHandler from .base import BaseHandler +from .control import ControlHandler from .group_text import GroupTextHandler from .login_response import AnonReqResponseHandler, LoginResponseHandler from .path import PathHandler @@ -23,4 +24,5 @@ "ProtocolResponseHandler", "AnonReqResponseHandler", "TraceHandler", + "ControlHandler", ] diff --git a/src/pymc_core/node/handlers/advert.py b/src/pymc_core/node/handlers/advert.py index acc740c..3fd30c1 100644 --- a/src/pymc_core/node/handlers/advert.py +++ b/src/pymc_core/node/handlers/advert.py @@ -1,8 +1,16 @@ import time +from typing import Optional, Dict, Any -from ...protocol import Packet, decode_appdata -from ...protocol.constants import PAYLOAD_TYPE_ADVERT, PUB_KEY_SIZE, describe_advert_flags -from ...protocol.utils import determine_contact_type_from_flags +from ...protocol import Identity, Packet, decode_appdata +from ...protocol.constants import ( + MAX_ADVERT_DATA_SIZE, + PAYLOAD_TYPE_ADVERT, + PUB_KEY_SIZE, + SIGNATURE_SIZE, + TIMESTAMP_SIZE, + describe_advert_flags, +) +from ...protocol.utils import determine_contact_type_from_flags, get_contact_type_name from .base import BaseHandler @@ -11,61 +19,114 @@ class AdvertHandler(BaseHandler): def payload_type() -> int: return PAYLOAD_TYPE_ADVERT - def __init__(self, contacts, log_fn, identity=None, event_service=None): - self.contacts = contacts + def __init__(self, log_fn): self.log = log_fn - self.identity = identity - self.event_service = event_service - - async def __call__(self, packet: Packet) -> None: - pubkey_bytes = packet.payload[:PUB_KEY_SIZE] - pubkey_hex = pubkey_bytes.hex() - - self.log("<<< Advert packet received >>>") - - if self.contacts is not None: - self.log(f"Processing advert for pubkey: {pubkey_hex}") - contact = next((c for c in self.contacts.contacts if c.public_key == pubkey_hex), None) - if contact: - self.log(f"Peer identity already known: {contact.name}") - contact.last_advert = int(time.time()) - else: - self.log(f"<<< New contact discovered (pubkey={pubkey_hex[:8]}...) >>>") - appdata = packet.get_payload_app_data() - decoded = decode_appdata(appdata) - - # Extract name from decoded data - name = decoded.get("node_name") or decoded.get("name") - - # Require valid name - ignore packet if no name present - if not name: - self.log(f"Ignoring advert packet without name (pubkey={pubkey_hex[:8]}...)") - return - - self.log(f"Processing contact with name: {name}") - lon = decoded.get("lon") or 0.0 - lat = decoded.get("lat") or 0.0 - flags_int = decoded.get("flags", 0) - flags = describe_advert_flags(flags_int) - contact_type = determine_contact_type_from_flags(flags_int) - - new_contact_data = { - "type": contact_type, - "name": name, - "longitude": lon, - "latitude": lat, - "flags": flags, - "public_key": pubkey_hex, - "last_advert": int(time.time()), - } - - self.contacts.add_contact(new_contact_data) - - # Publish new contact event - if self.event_service: - try: - from ..events import MeshEvents - - self.event_service.publish_sync(MeshEvents.NEW_CONTACT, new_contact_data) - except Exception as broadcast_error: - self.log(f"Failed to publish new contact event: {broadcast_error}") + + def _extract_advert_components(self, packet: Packet): + """Extract and validate advert packet components.""" + payload = packet.get_payload() + header_len = PUB_KEY_SIZE + TIMESTAMP_SIZE + SIGNATURE_SIZE + if len(payload) < header_len: + self.log( + f"Advert payload too short ({len(payload)} bytes, expected at least {header_len})" + ) + return None + + sig_offset = PUB_KEY_SIZE + TIMESTAMP_SIZE + pubkey = payload[:PUB_KEY_SIZE] + timestamp = payload[PUB_KEY_SIZE:sig_offset] + signature = payload[sig_offset : sig_offset + SIGNATURE_SIZE] + appdata = payload[sig_offset + SIGNATURE_SIZE :] + + if len(appdata) > MAX_ADVERT_DATA_SIZE: + self.log( + f"Advert appdata too large ({len(appdata)} bytes); truncating to {MAX_ADVERT_DATA_SIZE}" + ) + appdata = appdata[:MAX_ADVERT_DATA_SIZE] + + return pubkey, timestamp, signature, appdata + + def _verify_advert_signature( + self, pubkey: bytes, timestamp: bytes, appdata: bytes, signature: bytes + ) -> bool: + """Verify the cryptographic signature of the advert packet.""" + try: + + if len(pubkey) != PUB_KEY_SIZE: + self.log(f"Invalid public key length: {len(pubkey)} bytes (expected {PUB_KEY_SIZE})") + return False + + if len(signature) != SIGNATURE_SIZE: + self.log(f"Invalid signature length: {len(signature)} bytes (expected {SIGNATURE_SIZE})") + return False + + peer_identity = Identity(pubkey) + except ValueError as exc: + self.log(f"Unable to construct peer identity - invalid key format: {exc}") + return False + except Exception as exc: + self.log(f"Unable to construct peer identity: {type(exc).__name__}: {exc}") + return False + + signed_region = pubkey + timestamp + appdata + if not peer_identity.verify(signed_region, signature): + return False + return True + + async def __call__(self, packet: Packet) -> Optional[Dict[str, Any]]: + """Process advert packet and return parsed data with signature verification.""" + try: + # Extract and validate packet components + components = self._extract_advert_components(packet) + if not components: + return None + + pubkey_bytes, timestamp_bytes, signature_bytes, appdata = components + pubkey_hex = pubkey_bytes.hex() + + # Verify cryptographic signature + if not self._verify_advert_signature(pubkey_bytes, timestamp_bytes, appdata, signature_bytes): + self.log(f"Rejecting advert with invalid signature (pubkey={pubkey_hex[:8]}...)") + return None + + self.log(f"Processing advert for pubkey: {pubkey_hex[:16]}...") + + # Decode application data + decoded = decode_appdata(appdata) + + # Extract name from decoded data + name = decoded.get("node_name") or decoded.get("name") + if not name: + self.log(f"Ignoring advert without name (pubkey={pubkey_hex[:8]}...)") + return None + + # Extract location and flags + lon = decoded.get("longitude") or decoded.get("lon") or 0.0 + lat = decoded.get("latitude") or decoded.get("lat") or 0.0 + flags_int = decoded.get("flags", 0) + flags_description = describe_advert_flags(flags_int) + contact_type_id = determine_contact_type_from_flags(flags_int) + contact_type = get_contact_type_name(contact_type_id) + + # Build parsed advert data + advert_data = { + "public_key": pubkey_hex, + "name": name, + "longitude": lon, + "latitude": lat, + "flags": flags_int, + "flags_description": flags_description, + "contact_type_id": contact_type_id, + "contact_type": contact_type, + "timestamp": int(time.time()), + "snr": packet._snr if hasattr(packet, '_snr') else 0.0, + "rssi": packet._rssi if hasattr(packet, '_rssi') else 0, + "valid": True, + } + + self.log(f"Parsed advert: {name} ({contact_type})") + return advert_data + + except Exception as e: + self.log(f"Error parsing advert packet: {e}") + return None diff --git a/src/pymc_core/node/handlers/base.py b/src/pymc_core/node/handlers/base.py index 925af0e..41dc88d 100644 --- a/src/pymc_core/node/handlers/base.py +++ b/src/pymc_core/node/handlers/base.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from typing import Any, Optional class BaseHandler(ABC): @@ -9,5 +10,5 @@ def payload_type() -> int: pass @abstractmethod - async def __call__(self, packet): + async def __call__(self, packet) -> Optional[Any]: pass diff --git a/src/pymc_core/node/handlers/control.py b/src/pymc_core/node/handlers/control.py new file mode 100644 index 0000000..93148b4 --- /dev/null +++ b/src/pymc_core/node/handlers/control.py @@ -0,0 +1,210 @@ +"""Control packet handler for mesh network discovery. + +Handles control packets for node discovery requests and responses. +These are zero-hop packets used for network topology discovery. +""" + +import struct +import time +from typing import Any, Callable, Dict, Optional + +from ...hardware.signal_utils import snr_register_to_db +from ...protocol import Packet +from ...protocol.constants import PAYLOAD_TYPE_CONTROL + +# Control packet type constants (upper 4 bits of first byte) +CTL_TYPE_NODE_DISCOVER_REQ = 0x80 # Discovery request +CTL_TYPE_NODE_DISCOVER_RESP = 0x90 # Discovery response + + +class ControlHandler: + """Handler for control packets (payload type 0x0B). + + Control packets are used for node discovery and network topology mapping. + This handler processes incoming discovery requests and responses. + """ + + def __init__(self, log_fn: Callable[[str], None]): + """Initialize control handler. + + Args: + log_fn: Logging function + """ + self._log = log_fn + + # Callbacks for discovery responses + self._response_callbacks: Dict[int, Callable[[Dict[str, Any]], None]] = {} + self._request_callbacks: Dict[int, Callable[[Dict[str, Any]], None]] = {} + + @staticmethod + def payload_type() -> int: + return PAYLOAD_TYPE_CONTROL + + def set_response_callback( + self, tag: int, callback: Callable[[Dict[str, Any]], None] + ) -> None: + """Set callback for discovery responses with a specific tag.""" + self._response_callbacks[tag] = callback + + def clear_response_callback(self, tag: int) -> None: + """Clear callback for discovery responses with a specific tag.""" + self._response_callbacks.pop(tag, None) + + def set_request_callback( + self, callback: Callable[[Dict[str, Any]], None] + ) -> None: + """Set callback for discovery requests (for logging/monitoring).""" + self._request_callbacks[0] = callback + + def clear_request_callback(self) -> None: + """Clear callback for discovery requests.""" + self._request_callbacks.pop(0, None) + + async def __call__(self, pkt: Packet) -> Optional[Dict[str, Any]]: + """Handle incoming control packet and return parsed data.""" + try: + if not pkt.payload or len(pkt.payload) == 0: + self._log("[ControlHandler] Empty payload, ignoring") + return None + + # Check if this is a zero-hop packet (path_len must be 0) + if pkt.path_len != 0: + self._log( + f"[ControlHandler] Non-zero path length ({pkt.path_len}), ignoring" + ) + return None + + # Extract control type (upper 4 bits of first byte) + control_type = pkt.payload[0] & 0xF0 + + if control_type == CTL_TYPE_NODE_DISCOVER_REQ: + return await self._handle_discovery_request(pkt) + elif control_type == CTL_TYPE_NODE_DISCOVER_RESP: + return await self._handle_discovery_response(pkt) + else: + self._log( + f"[ControlHandler] Unknown control type: 0x{control_type:02X}" + ) + return None + + except Exception as e: + self._log(f"[ControlHandler] Error processing control packet: {e}") + return None + + async def _handle_discovery_request(self, pkt: Packet) -> Optional[Dict[str, Any]]: + """Handle node discovery request packet and return parsed data. + + Expected format: + - byte 0: type (0x80) + flags (bit 0: prefix_only) + - byte 1: filter (bitfield of node types to respond) + - bytes 2-5: tag (uint32_t, little-endian) + - bytes 6-9: since timestamp (uint32_t, optional) + """ + try: + if len(pkt.payload) < 6: + self._log("[ControlHandler] Discovery request too short") + return None + + # Parse request + flags_byte = pkt.payload[0] + prefix_only = (flags_byte & 0x01) != 0 + filter_byte = pkt.payload[1] + tag = struct.unpack("= 10: + since = struct.unpack(" Optional[Dict[str, Any]]: + """Handle node discovery response packet and return parsed data. + + Response format: + - byte 0: type (0x90) + node_type (lower 4 bits) + - byte 1: SNR of our request (int8_t, multiplied by 4) + - bytes 2-5: tag (matches our request) + - bytes 6-onwards: public key (8 or 32 bytes) + """ + try: + if len(pkt.payload) < 6: + self._log("[ControlHandler] Discovery response too short") + return None + + # Parse response + type_byte = pkt.payload[0] + node_type = type_byte & 0x0F + snr_byte = pkt.payload[1] + inbound_snr = snr_register_to_db(snr_byte) + tag = struct.unpack(" int: + return PAYLOAD_TYPE_ANON_REQ + + def __init__( + self, + local_identity, + log_fn: Callable[[str], None], + authenticate_callback: Callable[[Identity, bytes, str, int], tuple[bool, int]], + is_room_server: bool = False, + ): + """ + Initialize login server handler. + + Args: + local_identity: Server's local identity + log_fn: Logging function + authenticate_callback: Function(client_identity, shared_secret, password, timestamp) + Returns: (success: bool, permissions: int) + is_room_server: True if this identity is a room server (expects sync_since field), + False if repeater (no sync_since field) + """ + self.local_identity = local_identity + self.log = log_fn + self.authenticate = authenticate_callback + self.is_room_server = is_room_server + self._send_packet_callback: Optional[Callable[[Packet, int], None]] = None + + def set_send_packet_callback(self, callback: Callable[[Packet, int], None]): + """Set callback for sending response packets.""" + self._send_packet_callback = callback + + async def __call__(self, packet: Packet) -> None: + """Handle ANON_REQ login packet from client.""" + try: + # Debug: Log packet routing info + path_data = list(packet.path[: packet.path_len]) if packet.path_len > 0 else [] + self.log( + f"[LoginServer] Packet route flood: {packet.is_route_flood()}, " + f"path_len: {packet.path_len}, path: {path_data}" + ) + + # Parse ANON_REQ structure: dest_hash(1) + client_pubkey(32) + encrypted_data + if len(packet.payload) < 34: + self.log("[LoginServer] ANON_REQ packet too short") + return + + dest_hash = packet.payload[0] + client_pubkey = bytes(packet.payload[1:33]) + encrypted_data = bytes(packet.payload[33:]) + + # Verify this is for us + our_hash = self.local_identity.get_public_key()[0] + if dest_hash != our_hash: + return # Not for us + + # Create client identity and calculate shared secret + client_identity = Identity(client_pubkey) + shared_secret = client_identity.calc_shared_secret( + self.local_identity.get_private_key() + ) + aes_key = shared_secret[:16] + + # Decrypt the login request + try: + plaintext = CryptoUtils.mac_then_decrypt(aes_key, shared_secret, encrypted_data) + except Exception as e: + self.log(f"[LoginServer] Failed to decrypt login request: {e}") + return + + if len(plaintext) < 4: + self.log("[LoginServer] Decrypted data too short") + return + + # Parse plaintext - two formats: + # Repeater format: timestamp(4) + password(variable) + null + # Room server format: timestamp(4) + sync_since(4) + password(variable) + null + client_timestamp = struct.unpack(" 0 else '(empty)'}") + else: + # Repeater format: password only + # Find null terminator after timestamp (starting from byte 4) + null_idx = plaintext.find(b"\x00", 4) + if null_idx == -1: + null_idx = len(plaintext) + + password_bytes = plaintext[4:null_idx] + self.log(f"[LoginServer] Repeater format: password from byte 4 to {null_idx}") + + # Null-terminate password + null_idx = password_bytes.find(b"\x00") + if null_idx >= 0: + password_bytes = password_bytes[:null_idx] + password = password_bytes.decode("utf-8", errors="ignore") + + self.log( + f"[LoginServer] Login request from {client_pubkey[:6].hex()}... " + f"password={'' if not password else ''}" + ) + + # Call application authentication logic with optional sync_since parameter + # For backwards compatibility, check if authenticate accepts sync_since + import inspect + + sig = inspect.signature(self.authenticate) + if "sync_since" in sig.parameters: + success, permissions = self.authenticate( + client_identity, shared_secret, password, client_timestamp, sync_since + ) + else: + # Old signature without sync_since + success, permissions = self.authenticate( + client_identity, shared_secret, password, client_timestamp + ) + + if success: + self.log("[LoginServer] Authentication successful") + # Send success response + await self._send_login_response( + client_identity, + shared_secret, + packet.is_route_flood(), + RESP_SERVER_LOGIN_OK, + permissions, + packet, + ) + else: + self.log("[LoginServer] Authentication failed") + # Optionally send failure response (or just ignore) + # Most implementations just ignore failed attempts + + except Exception as e: + self.log(f"[LoginServer] Error handling login packet: {e}") + + async def _send_login_response( + self, + client_identity: Identity, + shared_secret: bytes, + is_flood: bool, + response_code: int, + permissions: int, + original_packet: Packet = None, + ): + """Build and send login response packet to client.""" + if self._send_packet_callback is None: + self.log("[LoginServer] No send packet callback set, cannot send response") + return + + try: + # Build response data (13 bytes total) + # timestamp(4) + response_code(1) + keep_alive(1) + is_admin(1) + + # permissions(1) + random(4) + firmware_ver(1) + reply_data = bytearray(13) + current_time = int(time.time()) + + struct.pack_into(" 0 + else [] + ) + + self.log( + f"[LoginServer] Creating PATH response: " + f"client_hash=0x{client_hash:02X}, " + f"server_hash=0x{server_hash:02X}, path={path_list}, " + f"original_flood={is_flood}" + ) + + response_pkt = PacketBuilder.create_path_return( + dest_hash=client_hash, + src_hash=server_hash, + secret=shared_secret, + path=path_list, + extra_type=PAYLOAD_TYPE_RESPONSE, + extra=bytes(reply_data), + ) + packet_type_name = "PATH" + + # Debug: Log packet details + self.log( + f"[LoginServer] RESPONSE packet details: " + f"header=0x{response_pkt.header:02X}, " + f"payload_len={response_pkt.payload_len}, " + f"path_len={response_pkt.path_len}, " + f"payload[0:2]={bytes(response_pkt.payload[:2]).hex()}" + ) + + # Send with delay (matches C++ SERVER_RESPONSE_DELAY) + delay_ms = 300 + self._send_packet_callback(response_pkt, delay_ms) + + self.log( + f"[LoginServer] Sent login response ({packet_type_name}) to " + f"{client_identity.get_public_key()[:6].hex()}..." + ) + + except Exception as e: + self.log(f"[LoginServer] Failed to send login response: {e}") diff --git a/src/pymc_core/node/handlers/path.py b/src/pymc_core/node/handlers/path.py index 98eddf1..225190c 100644 --- a/src/pymc_core/node/handlers/path.py +++ b/src/pymc_core/node/handlers/path.py @@ -30,10 +30,12 @@ def __init__( log_fn: Callable[[str], None], ack_handler=None, protocol_response_handler=None, + login_response_handler=None, ): self._log = log_fn self._ack_handler = ack_handler self._protocol_response_handler = protocol_response_handler + self._login_response_handler = login_response_handler @staticmethod def payload_type() -> int: @@ -50,6 +52,10 @@ async def __call__(self, pkt: Packet) -> None: if self._protocol_response_handler: await self._protocol_response_handler(pkt) + # Then, check if this PATH packet contains login responses + if self._login_response_handler: + await self._login_response_handler(pkt) + # Then, check if this PATH packet contains ACKs and delegate to ACK handler if self._ack_handler: ack_crc = await self._ack_handler.process_path_ack_variants(pkt) @@ -75,10 +81,19 @@ async def __call__(self, pkt: Packet) -> None: # Extract and log key PATH information directly from packet try: payload = pkt.get_payload() + hop_count = pkt.path_len if len(payload) >= 2: - hop_count = payload[1] - self._log(f"PATH packet: hop_count={hop_count}, payload_len={len(payload)}") - self._log(f"Path contains {hop_count} hops") + dest_hash = payload[0] + src_hash = payload[1] + self._log( + f"PATH packet: hop_count={hop_count}, " + f"dest=0x{dest_hash:02X}, src=0x{src_hash:02X}, " + f"payload_len={len(payload)}" + ) + if hop_count > 0: + self._log(f"Path contains {hop_count} hops") + else: + self._log("Direct PATH (no intermediate hops)") else: self._log("PATH packet received with minimal payload") diff --git a/src/pymc_core/node/handlers/protocol_request.py b/src/pymc_core/node/handlers/protocol_request.py new file mode 100644 index 0000000..6541635 --- /dev/null +++ b/src/pymc_core/node/handlers/protocol_request.py @@ -0,0 +1,240 @@ +""" +Protocol request handler for authenticated client requests. + +Handles REQ packets and sends RESPONSE packets with requested data. +""" + +import struct +from typing import Optional, Callable, Any + +from pymc_core.protocol.constants import PAYLOAD_TYPE_REQ, PAYLOAD_TYPE_RESPONSE +from pymc_core.protocol.crypto import CryptoUtils +from pymc_core.protocol import PacketBuilder + +# Request type codes (matching C++ implementation) +REQ_TYPE_GET_STATUS = 0x01 +REQ_TYPE_KEEP_ALIVE = 0x02 +REQ_TYPE_GET_TELEMETRY_DATA = 0x03 +REQ_TYPE_GET_ACCESS_LIST = 0x05 +REQ_TYPE_GET_NEIGHBOURS = 0x06 + +# Response delay (matching C++ SERVER_RESPONSE_DELAY) +SERVER_RESPONSE_DELAY_MS = 500 + + +class ProtocolRequestHandler: + """ + Handler for protocol request packets (PAYLOAD_TYPE_REQ). + + Processes encrypted request packets from authenticated clients and sends + appropriate RESPONSE packets. Request handling is delegated to callbacks + for application-specific logic. + """ + + @staticmethod + def payload_type(): + """Return the payload type this handler processes.""" + return PAYLOAD_TYPE_REQ + + def __init__( + self, + local_identity, + contacts, + get_client_fn: Optional[Callable] = None, + request_handlers: Optional[dict] = None, + log_fn: Optional[Callable] = None, + ): + """ + Initialize protocol request handler. + + Args: + local_identity: LocalIdentity for this handler + contacts: Contact manager or wrapper providing client lookup + get_client_fn: Optional function to get client info by hash + request_handlers: Dict mapping request type codes to handler functions + log_fn: Optional logging function + """ + self.local_identity = local_identity + self.contacts = contacts + self.get_client_fn = get_client_fn + self.request_handlers = request_handlers or {} + self.log = log_fn if log_fn else lambda msg: None + + async def __call__(self, packet): + """ + Process a protocol request packet. + + Args: + packet: Packet instance with REQ payload + + Returns: + Packet: RESPONSE packet to send, or None + """ + try: + if len(packet.payload) < 2: + return None + + dest_hash = packet.payload[0] + src_hash = packet.payload[1] + + # Verify this packet is for us + our_hash = self.local_identity.get_public_key()[0] + if dest_hash != our_hash: + return None + + self.log(f"Processing REQ from 0x{src_hash:02X}") + + # Get client info + client = self._get_client(src_hash) + if not client: + self.log(f"REQ from unknown client 0x{src_hash:02X}") + return None + + # Get shared secret + shared_secret = self._get_shared_secret(client) + if not shared_secret: + self.log(f"No shared secret for client 0x{src_hash:02X}") + return None + + # Decrypt request + encrypted_data = packet.payload[2:] + aes_key = shared_secret[:16] + + try: + plaintext = CryptoUtils.mac_then_decrypt(aes_key, shared_secret, bytes(encrypted_data)) + except Exception as e: + self.log(f"Failed to decrypt REQ: {e}") + return None + + # Parse request + if len(plaintext) < 5: + self.log("REQ packet too short") + return None + + timestamp = struct.unpack(' 5 else b'' + + self.log(f"REQ type=0x{req_type:02X}, timestamp={timestamp}") + + # Handle request + response_data = await self._handle_request(client, timestamp, req_type, req_data) + + if response_data: + return self._build_response(packet, client, response_data, shared_secret) + + return None + + except Exception as e: + self.log(f"Error processing REQ: {e}") + return None + + def _get_client(self, src_hash: int): + """Get client info by source hash.""" + if self.get_client_fn: + return self.get_client_fn(src_hash) + + # Fallback: search in contacts + if hasattr(self.contacts, 'contacts'): + for contact in self.contacts.contacts: + if hasattr(contact, 'public_key'): + pk = bytes.fromhex(contact.public_key) if isinstance(contact.public_key, str) else contact.public_key + if pk[0] == src_hash: + return contact + + return None + + def _get_shared_secret(self, client): + """Get shared secret for client.""" + if hasattr(client, 'shared_secret'): + return client.shared_secret + + if hasattr(client, 'public_key'): + pk = bytes.fromhex(client.public_key) if isinstance(client.public_key, str) else client.public_key + from pymc_core.protocol.identity import Identity + identity = Identity(pk) + return identity.calc_shared_secret(self.local_identity.get_private_key()) + + return None + + async def _handle_request(self, client, timestamp: int, req_type: int, req_data: bytes): + """ + Handle request and generate response. + + Args: + client: Client info object + timestamp: Request timestamp + req_type: Request type code + req_data: Request payload + + Returns: + bytes: Response data (timestamp + payload) or None + """ + # Build response with reflected timestamp + response = bytearray(struct.pack('= 0 and len(client.out_path) > 0: + reply_packet.path = bytearray(client.out_path[:client.out_path_len]) + reply_packet.path_len = client.out_path_len + + self.log(f"RESPONSE built for 0x{client_identity.get_public_key()[0]:02X} via {route_type.upper()}") + + return reply_packet + + except Exception as e: + self.log(f"Error building RESPONSE: {e}") + return None diff --git a/src/pymc_core/node/handlers/protocol_response.py b/src/pymc_core/node/handlers/protocol_response.py index a55c7fe..c82ca4f 100644 --- a/src/pymc_core/node/handlers/protocol_response.py +++ b/src/pymc_core/node/handlers/protocol_response.py @@ -7,6 +7,7 @@ import struct from typing import Any, Callable, Dict, Optional +from ...hardware.signal_utils import snr_register_to_db from ...protocol import CryptoUtils, Identity, Packet from ...protocol.constants import PAYLOAD_TYPE_PATH @@ -168,7 +169,7 @@ def _parse_stats_response(self, data: bytes) -> Optional[Dict[str, Any]]: "total_up_time_secs": parsed[11], # Uptime in seconds "total_air_time_secs": parsed[13], # Air time in seconds "err_events": parsed[17], # Error events count - "last_snr": self._convert_signed_16bit(parsed[19]) / 4.0, # SNR in dB (scaled by 4) + "last_snr": snr_register_to_db(parsed[19], bits=16), "n_flood_dups": parsed[22], # Flood duplicate packets "n_direct_dups": parsed[23], # Direct duplicate packets } diff --git a/src/pymc_core/node/handlers/text.py b/src/pymc_core/node/handlers/text.py index 2482f34..174b1df 100644 --- a/src/pymc_core/node/handlers/text.py +++ b/src/pymc_core/node/handlers/text.py @@ -69,6 +69,7 @@ async def __call__(self, packet: Packet) -> None: timestamp = decrypted[:4] # First 4 bytes are the timestamp flags = decrypted[4] # 5th byte contains flags attempt = flags & 0x03 # Last 2 bits are the attempt number + txt_type = (flags >> 2) & 0x3F # Upper 6 bits are txt_type message_body = decrypted[5:] # Rest is the message content pubkey = bytes.fromhex(matched_contact.public_key) @@ -80,72 +81,81 @@ async def __call__(self, packet: Packet) -> None: self.log( f"Processing message - route_type: {route_type}, is_flood: {is_flood}, " - f"timestamp: {timestamp_int}" + f"timestamp: {timestamp_int}, txt_type: {txt_type}" ) - # Create appropriate ACK response - if is_flood: - # FLOOD messages use PATH ACK responses with ACK hash in extra payload - text_bytes = message_body.rstrip(b"\x00") - - # Calculate ACK hash using standard method (same as DIRECT messages) - pack_data = PacketBuilder._pack_timestamp_data(timestamp_int, attempt, text_bytes) - ack_hash = CryptoUtils.sha256(pack_data + pubkey)[:4] - - # Create PATH ACK response - incoming_path = list(packet.path if hasattr(packet, "path") else []) - - ack_packet = PacketBuilder.create_path_return( - dest_hash=PacketBuilder._hash_byte(pubkey), - src_hash=PacketBuilder._hash_byte(self.local_identity.get_public_key()), - secret=shared_secret, - path=incoming_path, - extra_type=PAYLOAD_TYPE_ACK, - extra=ack_hash, - ) - - packet_len = len(ack_packet.write_to()) - ack_airtime = PacketTimingUtils.estimate_airtime_ms(packet_len, self.radio_config) - ack_timeout_ms = PacketTimingUtils.calc_flood_timeout_ms(ack_airtime) - - self.log( - f"FLOOD ACK timing - packet:{packet_len}B, airtime:{ack_airtime:.1f}ms, " - f"delay:{ack_timeout_ms:.1f}ms" - ) - ack_timeout_ms = ack_timeout_ms / 1000.0 # Convert to seconds + # Skip ACK for TXT_TYPE_CLI_DATA (0x01) - CLI commands don't need ACKs + # Following C++ pattern: only TXT_TYPE_PLAIN (0x00) gets ACKs + TXT_TYPE_PLAIN = 0x00 + TXT_TYPE_CLI_DATA = 0x01 + send_ack = (txt_type == TXT_TYPE_PLAIN) + + if send_ack: + # Create appropriate ACK response + if is_flood: + # FLOOD messages use PATH ACK responses with ACK hash in extra payload + text_bytes = message_body.rstrip(b"\x00") + + # Calculate ACK hash using standard method (same as DIRECT messages) + pack_data = PacketBuilder._pack_timestamp_data(timestamp_int, attempt, text_bytes) + ack_hash = CryptoUtils.sha256(pack_data + pubkey)[:4] + + # Create PATH ACK response + incoming_path = list(packet.path if hasattr(packet, "path") else []) + + ack_packet = PacketBuilder.create_path_return( + dest_hash=PacketBuilder._hash_byte(pubkey), + src_hash=PacketBuilder._hash_byte(self.local_identity.get_public_key()), + secret=shared_secret, + path=incoming_path, + extra_type=PAYLOAD_TYPE_ACK, + extra=ack_hash, + ) + + packet_len = len(ack_packet.write_to()) + ack_airtime = PacketTimingUtils.estimate_airtime_ms(packet_len, self.radio_config) + ack_timeout_ms = PacketTimingUtils.calc_flood_timeout_ms(ack_airtime) - else: - # DIRECT messages use discrete ACK packets - ack_packet = PacketBuilder.create_ack( - pubkey=pubkey, - timestamp=timestamp_int, - attempt=attempt, - text=message_body.rstrip(b"\x00"), - ) - - packet_len = len(ack_packet.write_to()) - ack_airtime = PacketTimingUtils.estimate_airtime_ms(packet_len, self.radio_config) - ack_timeout_ms = PacketTimingUtils.calc_direct_timeout_ms(ack_airtime, 0) - - self.log( - f"DIRECT ACK timing - packet:{packet_len}B, airtime:{ack_airtime:.1f}ms, " - f"delay:{ack_timeout_ms:.1f}ms, radio_config:{self.radio_config}" - ) - ack_timeout_ms = ack_timeout_ms / 1000.0 # Convert to seconds - - async def send_delayed_ack(): - await asyncio.sleep(ack_timeout_ms) - try: - await self.send_packet(ack_packet, wait_for_ack=False) self.log( - f"ACK packet sent successfully (delayed {ack_timeout_ms*1000:.1f}ms) " - f"for timestamp {timestamp_int}" + f"FLOOD ACK timing - packet:{packet_len}B, airtime:{ack_airtime:.1f}ms, " + f"delay:{ack_timeout_ms:.1f}ms" + ) + ack_timeout_ms = ack_timeout_ms / 1000.0 # Convert to seconds + + else: + # DIRECT messages use discrete ACK packets + ack_packet = PacketBuilder.create_ack( + pubkey=pubkey, + timestamp=timestamp_int, + attempt=attempt, + text=message_body.rstrip(b"\x00"), ) - except Exception as ack_send_error: - self.log(f"Failed to send ACK packet: {ack_send_error}") - # Schedule ACK to be sent after delay (non-blocking) - asyncio.create_task(send_delayed_ack()) + packet_len = len(ack_packet.write_to()) + ack_airtime = PacketTimingUtils.estimate_airtime_ms(packet_len, self.radio_config) + ack_timeout_ms = PacketTimingUtils.calc_direct_timeout_ms(ack_airtime, 0) + + self.log( + f"DIRECT ACK timing - packet:{packet_len}B, airtime:{ack_airtime:.1f}ms, " + f"delay:{ack_timeout_ms:.1f}ms, radio_config:{self.radio_config}" + ) + ack_timeout_ms = ack_timeout_ms / 1000.0 # Convert to seconds + + async def send_delayed_ack(): + await asyncio.sleep(ack_timeout_ms) + try: + await self.send_packet(ack_packet, wait_for_ack=False) + self.log( + f"ACK packet sent successfully (delayed {ack_timeout_ms*1000:.1f}ms) " + f"for timestamp {timestamp_int}" + ) + except Exception as ack_send_error: + self.log(f"Failed to send ACK packet: {ack_send_error}") + + # Schedule ACK to be sent after delay (non-blocking) + asyncio.create_task(send_delayed_ack()) + else: + self.log(f"Skipping ACK for txt_type={txt_type} (CLI command)") decoded_msg = message_body.decode("utf-8", "replace") self.log(f"Received TXT_MSG: {decoded_msg}") diff --git a/src/pymc_core/node/handlers/trace.py b/src/pymc_core/node/handlers/trace.py index 9f2c594..d87f5a5 100644 --- a/src/pymc_core/node/handlers/trace.py +++ b/src/pymc_core/node/handlers/trace.py @@ -5,7 +5,7 @@ """ import struct -from typing import Any, Callable, Dict +from typing import Any, Callable, Dict, Optional from ...protocol import Packet from ...protocol.constants import PAYLOAD_TYPE_TRACE @@ -39,8 +39,8 @@ def clear_response_callback(self, contact_hash: int) -> None: """Clear callback for trace responses from a specific contact.""" self._response_callbacks.pop(contact_hash, None) - async def __call__(self, pkt: Packet) -> None: - """Handle incoming trace packet.""" + async def __call__(self, pkt: Packet) -> Optional[Dict[str, Any]]: + """Handle incoming trace packet and return parsed data.""" try: self._log(f"[TraceHandler] Processing trace packet: {len(pkt.payload)} bytes") @@ -97,8 +97,11 @@ async def __call__(self, pkt: Packet) -> None: f"for 0x{contact_hash:02X}" ) + return parsed_data + except Exception as e: self._log(f"[TraceHandler] Error processing trace packet: {e}") + return None def _parse_trace_payload(self, payload: bytes) -> Dict[str, Any]: """Parse trace packet payload. diff --git a/src/pymc_core/protocol/__init__.py b/src/pymc_core/protocol/__init__.py index 1e23f07..62ffd46 100644 --- a/src/pymc_core/protocol/__init__.py +++ b/src/pymc_core/protocol/__init__.py @@ -71,6 +71,7 @@ PacketValidationUtils, RouteTypeUtils, ) +from .transport_keys import calc_transport_code, get_auto_key_for from .utils import decode_appdata, parse_advert_payload __all__ = [ @@ -85,6 +86,8 @@ "parse_advert_payload", "decode_appdata", "describe_advert_flags", + "get_auto_key_for", + "calc_transport_code", # Utility classes "PacketValidationUtils", "PacketDataUtils", diff --git a/src/pymc_core/protocol/constants.py b/src/pymc_core/protocol/constants.py index 8f7fd95..7a3028a 100644 --- a/src/pymc_core/protocol/constants.py +++ b/src/pymc_core/protocol/constants.py @@ -31,6 +31,8 @@ PAYLOAD_TYPE_ANON_REQ = 0x07 PAYLOAD_TYPE_PATH = 0x08 PAYLOAD_TYPE_TRACE = 0x09 +PAYLOAD_TYPE_MULTIPART = 0x0A +PAYLOAD_TYPE_CONTROL = 0x0B PAYLOAD_TYPE_RAW_CUSTOM = 0x0F # --------------------------------------------------------------------------- @@ -61,7 +63,7 @@ # Node Advert Flags (bitfield values) ADVERT_FLAG_IS_CHAT_NODE = 0x01 ADVERT_FLAG_IS_REPEATER = 0x02 -ADVERT_FLAG_IS_ROOM_SERVER = 0x04 +ADVERT_FLAG_IS_ROOM_SERVER = 0x03 ADVERT_FLAG_HAS_LOCATION = 0x10 ADVERT_FLAG_HAS_FEATURE1 = 0x20 ADVERT_FLAG_HAS_FEATURE2 = 0x40 @@ -71,12 +73,18 @@ def describe_advert_flags(flags: int) -> str: labels = [] - if flags & ADVERT_FLAG_IS_CHAT_NODE: + # Extract node type from bits 0-3 + node_type = flags & 0x0F + if node_type == ADVERT_FLAG_IS_CHAT_NODE: labels.append("is chat node") - if flags & ADVERT_FLAG_IS_REPEATER: + elif node_type == ADVERT_FLAG_IS_REPEATER: labels.append("is repeater") - if flags & ADVERT_FLAG_IS_ROOM_SERVER: + elif node_type == ADVERT_FLAG_IS_ROOM_SERVER: labels.append("is room server") + elif node_type == 0x04: + labels.append("is sensor") + + # Check feature flags (bits 4-7) if flags & ADVERT_FLAG_HAS_LOCATION: labels.append("has location") if flags & ADVERT_FLAG_HAS_FEATURE1: diff --git a/src/pymc_core/protocol/packet.py b/src/pymc_core/protocol/packet.py index 3fc4f6c..f77fd9b 100644 --- a/src/pymc_core/protocol/packet.py +++ b/src/pymc_core/protocol/packet.py @@ -8,6 +8,10 @@ PH_VER_MASK, PH_VER_SHIFT, PUB_KEY_SIZE, + ROUTE_TYPE_DIRECT, + ROUTE_TYPE_FLOOD, + ROUTE_TYPE_TRANSPORT_DIRECT, + ROUTE_TYPE_TRANSPORT_FLOOD, SIGNATURE_SIZE, TIMESTAMP_SIZE, ) @@ -22,6 +26,9 @@ ║ Header (1 byte) ║ Encodes route type (2 bits), payload type (4 bits), ║ ║ ║ and version (2 bits). ║ ╠════════════════════╬══════════════════════════════════════════════════════╣ +║ Transport Codes ║ Two 16-bit codes (4 bytes total). Only present for ║ +║ (0 or 4 bytes) ║ TRANSPORT_FLOOD and TRANSPORT_DIRECT route types. ║ +╠════════════════════╬══════════════════════════════════════════════════════╣ ║ Path Length (1 B) ║ Number of path hops (0–15). ║ ╠════════════════════╬══════════════════════════════════════════════════════╣ ║ Path (N bytes) ║ List of node hashes (1 byte each), length = path_len ║ @@ -35,8 +42,9 @@ ╔═══════════╦════════════╦════════════════════════════════╗ ║ Bits ║ Name ║ Meaning ║ ╠═══════════╬════════════╬════════════════════════════════╣ -║ 0–1 ║ RouteType ║ 00: Flood, 01: Direct, ║ -║ ║ ║ 10: TransportFlood, 11: Direct ║ +║ 0–1 ║ RouteType ║ 00: TransportFlood, ║ +║ ║ ║ 01: Flood, 10: Direct, ║ +║ ║ ║ 11: TransportDirect ║ ╠═══════════╬════════════╬════════════════════════════════╣ ║ 2–5 ║ PayloadType║ See PAYLOAD_TYPE_* constants ║ ╠═══════════╬════════════╬════════════════════════════════╣ @@ -45,6 +53,7 @@ Notes: - `write_to()` and `read_from()` enforce the exact structure used in firmware. +- Transport codes are included only for route types 0x00 and 0x03. - Payload size must be ≤ MAX_PACKET_PAYLOAD (typically 254). - `calculate_packet_hash()` includes payload type + path_len (only for TRACE). """ @@ -52,7 +61,7 @@ class Packet: """ - Represents a mesh network packet with header, path, and payload components. + Represents a mesh network packet with header, transport codes, path, and payload components. This class handles serialization and deserialization of packets in the mesh protocol, providing methods for packet validation, hashing, and data extraction. It maintains @@ -60,6 +69,7 @@ class Packet: Attributes: header (int): Single byte header containing packet type and flags. + transport_codes (list): Two 16-bit transport codes for TRANSPORT route types. path_len (int): Length of the path component in bytes. path (bytearray): Variable-length path data for routing. payload (bytearray): Variable-length payload data. @@ -70,7 +80,7 @@ class Packet: Example: ```python packet = Packet() - packet.header = 0x01 + packet.header = 0x01 # Flood routing packet.path = b"node1->node2" packet.path_len = len(packet.path) packet.payload = b"Hello World" @@ -97,8 +107,12 @@ class Packet: "payload_len", "path", "payload", + "transport_codes", "_snr", "_rssi", + "_do_not_retransmit", + "drop_reason", + "_tx_metadata", ) def __init__(self): @@ -115,8 +129,12 @@ def __init__(self): self.decrypted = {} self.path_len = 0 self.payload_len = 0 + self.transport_codes = [0, 0] # Array of two 16-bit transport codes self._snr = 0 self._rssi = 0 + # Repeater flag to prevent retransmission and log drop reason + self._do_not_retransmit = False + self.drop_reason = None # Optional: reason for dropping packet def get_route_type(self) -> int: """ @@ -124,10 +142,10 @@ def get_route_type(self) -> int: Returns: int: Route type value (0-3) indicating routing method: - - 0: Flood routing - - 1: Direct routing - - 2: Transport flood routing - - 3: Reserved + - 0: Transport flood routing (with transport codes) + - 1: Flood routing + - 2: Direct routing + - 3: Transport direct routing (with transport codes) """ return self.header & PH_ROUTE_MASK @@ -157,6 +175,37 @@ def get_payload_ver(self) -> int: """ return (self.header >> PH_VER_SHIFT) & PH_VER_MASK + def has_transport_codes(self) -> bool: + """ + Check if this packet includes transport codes in its format. + + Returns: + bool: True if the packet uses transport flood or transport direct + routing, which includes 4 bytes of transport codes after the header. + """ + route_type = self.get_route_type() + return route_type == ROUTE_TYPE_TRANSPORT_FLOOD or route_type == ROUTE_TYPE_TRANSPORT_DIRECT + + def is_route_flood(self) -> bool: + """ + Check if this packet uses flood routing (with or without transport codes). + + Returns: + bool: True if the packet uses any form of flood routing. + """ + route_type = self.get_route_type() + return route_type == ROUTE_TYPE_TRANSPORT_FLOOD or route_type == ROUTE_TYPE_FLOOD + + def is_route_direct(self) -> bool: + """ + Check if this packet uses direct routing (with or without transport codes). + + Returns: + bool: True if the packet uses any form of direct routing. + """ + route_type = self.get_route_type() + return route_type == ROUTE_TYPE_TRANSPORT_DIRECT or route_type == ROUTE_TYPE_DIRECT + def get_payload(self) -> bytes: """ Get the packet payload as immutable bytes, truncated to declared length. @@ -227,7 +276,8 @@ def write_to(self) -> bytes: Returns: bytes: Serialized packet data in the format: - ``header(1) | path_len(1) | path(N) | payload(M)`` + ``header(1) | [transport_codes(4)] | path_len(1) | path(N) | payload(M)`` + Transport codes are only included if has_transport_codes() is True. Raises: ValueError: If internal length values don't match actual buffer lengths, @@ -236,6 +286,13 @@ def write_to(self) -> bytes: self._validate_lengths() out = bytearray([self.header]) + + # Add transport codes if this packet type requires them + if self.has_transport_codes(): + # Pack two 16-bit transport codes (4 bytes total) in little-endian format + out.extend(self.transport_codes[0].to_bytes(2, "little")) + out.extend(self.transport_codes[1].to_bytes(2, "little")) + out.append(self.path_len) out += self.path out += self.payload[: self.payload_len] @@ -262,6 +319,16 @@ def read_from(self, data: ByteString) -> bool: self.header = data[idx] idx += 1 + # Read transport codes if present + if self.has_transport_codes(): + self._check_bounds(idx, 4, data_len, "missing transport codes") + # Unpack two 16-bit transport codes from little-endian format + self.transport_codes[0] = int.from_bytes(data[idx : idx + 2], "little") + self.transport_codes[1] = int.from_bytes(data[idx + 2 : idx + 4], "little") + idx += 4 + else: + self.transport_codes = [0, 0] + self._check_bounds(idx, 1, data_len, "missing path_len") self.path_len = data[idx] idx += 1 @@ -299,6 +366,24 @@ def calculate_packet_hash(self) -> bytes: self.get_payload_type(), self.path_len, self.payload ) + def get_packet_hash_hex(self, length: int | None = None) -> str: + """ + Return upper-case hex string representation of this packet's hash. + + Args: + length (int | None, optional): Maximum length of the returned hex string. + Defaults to None (full hash string). + + Returns: + str: Upper-case hex string of the packet hash. + """ + return PacketHashingUtils.calculate_packet_hash_string( + payload_type=self.get_payload_type(), + path_len=self.path_len, + payload=self.payload, + length=length, + ) + def get_crc(self) -> int: """ Calculate a 4-byte CRC from SHA256 digest for ACK confirmation. @@ -328,12 +413,14 @@ def get_raw_length(self) -> int: Returns: int: Total packet size in bytes, calculated as: - header(1) + path_len(1) + path(N) + payload(M) + header(1) + [transport_codes(4)] + path_len(1) + path(N) + payload(M) + Transport codes are only included if has_transport_codes() is True. Note: This matches the wire format used by write_to() and expected by read_from(). """ - return 2 + self.path_len + self.payload_len # header + path_len + path + payload + base_length = 2 + self.path_len + self.payload_len # header + path_len + path + payload + return base_length + (4 if self.has_transport_codes() else 0) def get_snr(self) -> float: """ @@ -376,3 +463,27 @@ def snr(self) -> float: signal below noise floor. """ return self.get_snr() + + def mark_do_not_retransmit(self) -> None: + """ + Mark this packet to prevent retransmission. + + Sets a flag indicating this packet should not be forwarded by repeaters. + This is typically set when a packet has been successfully delivered to + its intended destination to prevent unnecessary network traffic. + + Used by destination nodes after successfully decrypting and processing + a message intended for them. + """ + self._do_not_retransmit = True + + def is_marked_do_not_retransmit(self) -> bool: + """ + Check if this packet is marked to prevent retransmission. + + Returns: + bool: True if the packet should not be retransmitted/forwarded. + This indicates the packet has reached its destination or should + remain local to the receiving node. + """ + return self._do_not_retransmit diff --git a/src/pymc_core/protocol/packet_builder.py b/src/pymc_core/protocol/packet_builder.py index d40bed1..3f0a67f 100644 --- a/src/pymc_core/protocol/packet_builder.py +++ b/src/pymc_core/protocol/packet_builder.py @@ -13,11 +13,13 @@ ADVERT_FLAG_IS_CHAT_NODE, CIPHER_BLOCK_SIZE, CONTACT_TYPE_ROOM_SERVER, + MAX_ADVERT_DATA_SIZE, MAX_PACKET_PAYLOAD, MAX_PATH_SIZE, PAYLOAD_TYPE_ACK, PAYLOAD_TYPE_ADVERT, PAYLOAD_TYPE_ANON_REQ, + PAYLOAD_TYPE_CONTROL, PAYLOAD_TYPE_GRP_DATA, PAYLOAD_TYPE_GRP_TXT, PAYLOAD_TYPE_PATH, @@ -303,9 +305,13 @@ def create_advert( pubkey = local_identity.get_public_key() ts_bytes = struct.pack(" MAX_ADVERT_DATA_SIZE: + raise ValueError( + f"advert appdata too large: {len(appdata)} bytes (max {MAX_ADVERT_DATA_SIZE})" + ) - # Sign the first part of the payload (pubkey + timestamp + first 32 bytes of appdata) - body_to_sign = pubkey + ts_bytes + appdata[:32] + # Sign the payload (pubkey + timestamp + appdata) + body_to_sign = pubkey + ts_bytes + appdata signature = local_identity.sign(body_to_sign) # Create payload: pubkey + timestamp + signature + appdata @@ -383,12 +389,15 @@ def create_datagram( if ptype not in (PAYLOAD_TYPE_TXT_MSG, PAYLOAD_TYPE_REQ, PAYLOAD_TYPE_RESPONSE): raise ValueError("invalid payload type") - aes_key = CryptoUtils.sha256(secret) + aes_key = secret[:16] cipher = PacketBuilder._encrypt_payload(aes_key, secret, plaintext) payload = PacketBuilder._hash_bytes(dest.get_public_key(), local_identity) + cipher header = PacketBuilder._create_header(ptype, route_type) - return PacketBuilder._create_packet(header, payload) + pkt = PacketBuilder._create_packet(header, payload) + pkt.path_len = 0 + pkt.path = bytearray() + return pkt @staticmethod def create_anon_req( @@ -878,3 +887,124 @@ def create_telem_request( protocol_code=REQ_TYPE_GET_TELEMETRY_DATA, data=bytes([inv]), # Just the permission mask as additional data ) + + # ---------- Control/Discovery Packets ---------- + + @staticmethod + def create_discovery_request( + tag: int, + filter_mask: int, + since: int = 0, + prefix_only: bool = False, + ) -> Packet: + """Create a node discovery request packet. + + Generates a control packet to discover nearby nodes on the mesh network. + This is a zero-hop broadcast packet that nearby nodes will respond to. + + Args: + tag: Random identifier to match responses (uint32_t). + filter_mask: Bitmask of node types to discover; the bit at position `node_type` is set + to select that type (e.g., for ADV_TYPE_REPEATER=2, use (1 << 2) == 0x04). + since: Optional timestamp - only nodes modified after this respond (uint32_t). + prefix_only: Request 8-byte key prefix instead of full 32-byte key. + + Returns: + Packet: Discovery request packet ready to send as zero-hop. + + Example: + ```python + import random + tag = random.randint(0, 0xFFFFFFFF) + # Filter for repeaters: ADV_TYPE_REPEATER=2, so (1 << 2) = 0x04 + packet = PacketBuilder.create_discovery_request(tag, filter_mask=0x04) + # Send as zero-hop broadcast + ``` + """ + # Build payload: type+flags(1) + filter(1) + tag(4) + since(4, optional) + payload = bytearray() + + # First byte: CTL_TYPE_NODE_DISCOVER_REQ (0x80) + flags + flags = 0x01 if prefix_only else 0x00 + payload.append(0x80 | flags) + + # Filter byte + payload.append(filter_mask & 0xFF) + + # Tag (4 bytes, little-endian) + payload.extend(struct.pack(" 0: + payload.extend(struct.pack(" Packet: + """Create a node discovery response packet. + + Generates a control packet in response to a discovery request. + This is sent as a zero-hop packet to the requester. + + Args: + tag: Tag from the discovery request to match. + node_type: Type of this node (0-15, e.g., 1 for repeater). + inbound_snr: SNR of the received request (will be multiplied by 4). + pub_key: Node's public key (32 bytes). + prefix_only: Send only 8-byte key prefix instead of full key. + + Returns: + Packet: Discovery response packet ready to send as zero-hop. + + Example: + ```python + identity = LocalIdentity() + packet = PacketBuilder.create_discovery_response( + tag=0x12345678, + node_type=1, # Repeater + inbound_snr=8.5, + pub_key=identity.get_public_key() + ) + ``` + """ + # Build payload: type+node_type(1) + snr(1) + tag(4) + pub_key(8 or 32) + payload = bytearray() + + # First byte: CTL_TYPE_NODE_DISCOVER_RESP (0x90) + node_type (lower 4 bits) + payload.append(0x90 | (node_type & 0x0F)) + + # SNR byte (multiply by 4, clamp to signed int8_t range, and encode as unsigned byte) + snr_byte = max(-128, min(127, int(inbound_snr * 4))) + payload.append(snr_byte & 0xFF) + + # Tag (4 bytes, little-endian) + payload.extend(struct.pack(" b sha.update(payload) return sha.digest()[:MAX_HASH_SIZE] + @staticmethod + def calculate_packet_hash_string( + payload_type: int, + path_len: int, + payload: bytes, + length: int | None = None, + ) -> str: + """ + Return upper-case hex string representation of the packet hash. + + Args: + payload_type: Packet payload type + path_len: Path length (only used for TRACE packets) + payload: Packet payload bytes + length: Optional maximum length of the returned hex string. + + Returns: + str: Upper-case hex string of the packet hash, optionally truncated. + """ + raw_hash = PacketHashingUtils.calculate_packet_hash(payload_type, path_len, payload) + hex_str = raw_hash.hex().upper() + return hex_str if length is None else hex_str[:length] + @staticmethod def calculate_crc(payload_type: int, path_len: int, payload: bytes) -> int: """Calculate 4-byte CRC from packet hash.""" diff --git a/src/pymc_core/protocol/transport_keys.py b/src/pymc_core/protocol/transport_keys.py new file mode 100644 index 0000000..9518f3e --- /dev/null +++ b/src/pymc_core/protocol/transport_keys.py @@ -0,0 +1,70 @@ +""" +Transport Key utilities for mesh packet authentication. + +Simple implementation matching the C++ MeshCore transport key functionality: +- Generate 128-bit key from region name (SHA256 of ASCII name) +- Calculate transport codes using HMAC-SHA256 +""" + +import struct +from .crypto import CryptoUtils + + +def get_auto_key_for(name: str) -> bytes: + """ + Generate 128-bit transport key from region name. + + Matches C++ implementation: + void TransportKeyStore::getAutoKeyFor(uint16_t id, const char* name, TransportKey& dest) + + Args: + name: Region name including '#' (e.g., "#usa") + + Returns: + bytes: 16-byte transport key + """ + if not name: + raise ValueError("Region name cannot be empty") + if not name.startswith('#'): + raise ValueError("Region name must start with '#'") + if len(name) > 64: + raise ValueError("Region name is too long (max 64 characters)") + key_hash = CryptoUtils.sha256(name.encode('ascii')) + return key_hash[:16] # First 16 bytes (128 bits) + + +def calc_transport_code(key: bytes, packet) -> int: + """ + Calculate transport code for a packet. + + Matches C++ implementation: + uint16_t TransportKey::calcTransportCode(const mesh::Packet* packet) const + + Args: + key: 16-byte transport key + packet: Packet with payload_type and payload + + Returns: + int: 16-bit transport code + """ + if len(key) != 16: + raise ValueError(f"Transport key must be 16 bytes, got {len(key)}") + payload_type = packet.get_payload_type() + payload_data = packet.get_payload() + + # HMAC input: payload_type (1 byte) + payload + hmac_data = bytes([payload_type]) + payload_data + + # Calculate HMAC-SHA256 + hmac_digest = CryptoUtils._hmac_sha256(key, hmac_data) + + # Extract first 2 bytes as little-endian uint16 (matches Arduino platform endianness) + code = struct.unpack(' int: ADVERT_FLAG_IS_ROOM_SERVER, ) - is_chat = bool(flags & ADVERT_FLAG_IS_CHAT_NODE) - is_repeater = bool(flags & ADVERT_FLAG_IS_REPEATER) - is_room_server = bool(flags & ADVERT_FLAG_IS_ROOM_SERVER) - if is_room_server: + # Extract node type from bits 0-3 (mask with 0x0F) + node_type = flags & 0x0F + + if node_type == ADVERT_FLAG_IS_ROOM_SERVER: # 0x03 return 3 # CONTACT_TYPE_ROOM_SERVER - elif is_repeater and is_chat: - return 4 # CONTACT_TYPE_HYBRID - elif is_repeater: + elif node_type == ADVERT_FLAG_IS_REPEATER: # 0x02 return 2 # CONTACT_TYPE_REPEATER - elif is_chat: + elif node_type == ADVERT_FLAG_IS_CHAT_NODE: # 0x01 return 1 # CONTACT_TYPE_CHAT_NODE + elif node_type == 0x04: # Sensor (if defined) + return 5 # CONTACT_TYPE_SENSOR (you may need to add this constant) else: return 0 # CONTACT_TYPE_UNKNOWN def get_contact_type_name(contact_type: int) -> str: - type_names = { - 0: "Unknown", - 1: "Chat Node", - 2: "Repeater", - 3: "Room Server", - 4: "Hybrid Node", - } + type_names = {0: "Unknown", 1: "Chat Node", 2: "Repeater", 3: "Room Server", 4: "Sensor"} return type_names.get(contact_type, f"Unknown Type ({contact_type})") diff --git a/tests/test_basic.py b/tests/test_basic.py index d953316..0dcd5c9 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -2,7 +2,7 @@ def test_version(): - assert __version__ == "1.0.5" + assert __version__ == "1.0.6" def test_import(): diff --git a/tests/test_handlers.py b/tests/test_handlers.py index cd3816a..835aaf3 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -14,7 +14,7 @@ TextMessageHandler, TraceHandler, ) -from pymc_core.protocol import LocalIdentity, Packet +from pymc_core.protocol import LocalIdentity, Packet, PacketBuilder from pymc_core.protocol.constants import ( PAYLOAD_TYPE_ACK, PAYLOAD_TYPE_ADVERT, @@ -24,18 +24,27 @@ PAYLOAD_TYPE_RESPONSE, PAYLOAD_TYPE_TRACE, PAYLOAD_TYPE_TXT_MSG, + PUB_KEY_SIZE, + SIGNATURE_SIZE, + TIMESTAMP_SIZE, ) # Mock classes for testing class MockContact: - def __init__(self, public_key="0123456789abcdef0123456789abcdef"): + def __init__(self, public_key="0123456789abcdef0123456789abcdef", name="mock"): self.public_key = public_key + self.name = name + self.last_advert = 0 class MockContactBook: def __init__(self): - self.contacts = [MockContact()] + self.contacts = [] + self.added_contacts = [] + + def add_contact(self, contact_data): + self.added_contacts.append(contact_data) class MockDispatcher: @@ -49,6 +58,7 @@ def __init__(self): class MockEventService: def __init__(self): self.publish = AsyncMock() + self.publish_sync = MagicMock() # Base Handler Tests @@ -167,13 +177,8 @@ async def test_call_with_short_payload(self): # Advert Handler Tests class TestAdvertHandler: def setup_method(self): - self.contacts = MockContactBook() self.log_fn = MagicMock() - self.local_identity = LocalIdentity() - self.event_service = MockEventService() - self.handler = AdvertHandler( - self.contacts, self.log_fn, self.local_identity, self.event_service - ) + self.handler = AdvertHandler(self.log_fn) def test_payload_type(self): """Test advert handler payload type.""" @@ -181,10 +186,49 @@ def test_payload_type(self): def test_advert_handler_initialization(self): """Test advert handler initialization.""" - assert self.handler.contacts == self.contacts assert self.handler.log == self.log_fn - assert self.handler.identity == self.local_identity - assert self.handler.event_service == self.event_service + + @pytest.mark.asyncio + async def test_advert_handler_accepts_valid_signature(self): + remote_identity = LocalIdentity() + packet = PacketBuilder.create_advert(remote_identity, "RemoteNode") + + result = await self.handler(packet) + + assert result is not None + assert result["valid"] is True + assert result["public_key"] == remote_identity.get_public_key().hex() + assert result["name"] == "RemoteNode" + + @pytest.mark.asyncio + async def test_advert_handler_rejects_invalid_signature(self): + remote_identity = LocalIdentity() + packet = PacketBuilder.create_advert(remote_identity, "RemoteNode") + appdata_offset = PUB_KEY_SIZE + TIMESTAMP_SIZE + SIGNATURE_SIZE + 5 + if appdata_offset >= packet.payload_len: + appdata_offset = packet.payload_len - 1 + packet.payload[appdata_offset] ^= 0x01 + + result = await self.handler(packet) + + assert result is None + assert any( + "invalid signature" in call.args[0].lower() + for call in self.log_fn.call_args_list + if call.args + ) + + @pytest.mark.asyncio + async def test_advert_handler_ignores_self_advert(self): + """Test that handler processes self-advert (dispatcher handles filtering).""" + local_identity = LocalIdentity() + packet = PacketBuilder.create_advert(local_identity, "SelfNode") + + result = await self.handler(packet) + + # Handler should still return parsed data; dispatcher filters self-adverts + assert result is not None + assert result["name"] == "SelfNode" # Path Handler Tests @@ -333,7 +377,7 @@ async def test_handlers_can_be_called(): handlers = [ AckHandler(log_fn), TextMessageHandler(local_identity, contacts, log_fn, send_packet_fn, event_service), - AdvertHandler(contacts, log_fn, local_identity, event_service), + AdvertHandler(log_fn), PathHandler(log_fn), GroupTextHandler(local_identity, contacts, log_fn, send_packet_fn), LoginResponseHandler(local_identity, contacts, log_fn), diff --git a/tests/test_packet_utils.py b/tests/test_packet_utils.py index 8ad59fc..c5bb384 100644 --- a/tests/test_packet_utils.py +++ b/tests/test_packet_utils.py @@ -3,7 +3,11 @@ import pytest from pymc_core.protocol.constants import MAX_PACKET_PAYLOAD, MAX_PATH_SIZE -from pymc_core.protocol.packet_utils import PacketDataUtils, PacketValidationUtils +from pymc_core.protocol.packet_utils import ( + PacketDataUtils, + PacketHashingUtils, + PacketValidationUtils, +) class TestPacketValidationUtils: @@ -139,3 +143,40 @@ def test_pack_timestamp_data_edge_cases(self): result = PacketDataUtils.pack_timestamp_data(0) expected = struct.pack("