diff --git a/src/pymc_core/node/handlers/advert.py b/src/pymc_core/node/handlers/advert.py index acc740c..c26916a 100644 --- a/src/pymc_core/node/handlers/advert.py +++ b/src/pymc_core/node/handlers/advert.py @@ -1,7 +1,14 @@ import time -from ...protocol import Packet, decode_appdata -from ...protocol.constants import PAYLOAD_TYPE_ADVERT, PUB_KEY_SIZE, describe_advert_flags +from ...protocol import Identity, Packet, decode_appdata +from ...protocol.constants import ( + MAX_ADVERT_DATA_SIZE, + PAYLOAD_TYPE_ADVERT, + PUB_KEY_SIZE, + SIGNATURE_SIZE, + TIMESTAMP_SIZE, + describe_advert_flags, +) from ...protocol.utils import determine_contact_type_from_flags from .base import BaseHandler @@ -17,10 +24,59 @@ def __init__(self, contacts, log_fn, identity=None, event_service=None): self.identity = identity self.event_service = event_service + def _extract_advert_components(self, packet: Packet): + payload = packet.get_payload() + header_len = PUB_KEY_SIZE + TIMESTAMP_SIZE + SIGNATURE_SIZE + if len(payload) < header_len: + self.log( + f"Advert payload too short ({len(payload)} bytes, expected at least {header_len})" + ) + return None + + sig_offset = PUB_KEY_SIZE + TIMESTAMP_SIZE + pubkey = payload[:PUB_KEY_SIZE] + timestamp = payload[PUB_KEY_SIZE:sig_offset] + signature = payload[sig_offset : sig_offset + SIGNATURE_SIZE] + appdata = payload[sig_offset + SIGNATURE_SIZE :] + + if len(appdata) > MAX_ADVERT_DATA_SIZE: + self.log( + f"Advert appdata too large ({len(appdata)} bytes); truncating to {MAX_ADVERT_DATA_SIZE}" + ) + appdata = appdata[:MAX_ADVERT_DATA_SIZE] + + return pubkey, timestamp, signature, appdata + + def _verify_advert_signature( + self, pubkey: bytes, timestamp: bytes, appdata: bytes, signature: bytes + ) -> bool: + try: + peer_identity = Identity(pubkey) + except Exception as exc: + self.log(f"Unable to construct peer identity: {exc}") + return False + + signed_region = pubkey + timestamp + appdata + if not peer_identity.verify(signed_region, signature): + return False + return True + async def __call__(self, packet: Packet) -> None: - pubkey_bytes = packet.payload[:PUB_KEY_SIZE] + components = self._extract_advert_components(packet) + if not components: + return + + pubkey_bytes, timestamp_bytes, signature_bytes, appdata = components pubkey_hex = pubkey_bytes.hex() + if not self._verify_advert_signature(pubkey_bytes, timestamp_bytes, appdata, signature_bytes): + self.log(f"Rejecting advert with invalid signature (pubkey={pubkey_hex[:8]}...)") + return + + if self.identity and pubkey_bytes == self.identity.get_public_key(): + self.log("Ignoring self advert packet") + return + self.log("<<< Advert packet received >>>") if self.contacts is not None: @@ -31,7 +87,6 @@ async def __call__(self, packet: Packet) -> None: contact.last_advert = int(time.time()) else: self.log(f"<<< New contact discovered (pubkey={pubkey_hex[:8]}...) >>>") - appdata = packet.get_payload_app_data() decoded = decode_appdata(appdata) # Extract name from decoded data diff --git a/src/pymc_core/protocol/packet_builder.py b/src/pymc_core/protocol/packet_builder.py index 58fc65b..9954cea 100644 --- a/src/pymc_core/protocol/packet_builder.py +++ b/src/pymc_core/protocol/packet_builder.py @@ -12,6 +12,7 @@ ADVERT_FLAG_HAS_NAME, ADVERT_FLAG_IS_CHAT_NODE, CIPHER_BLOCK_SIZE, + MAX_ADVERT_DATA_SIZE, CONTACT_TYPE_ROOM_SERVER, MAX_PACKET_PAYLOAD, MAX_PATH_SIZE, @@ -304,9 +305,13 @@ def create_advert( pubkey = local_identity.get_public_key() ts_bytes = struct.pack(" MAX_ADVERT_DATA_SIZE: + raise ValueError( + f"advert appdata too large: {len(appdata)} bytes (max {MAX_ADVERT_DATA_SIZE})" + ) - # Sign the first part of the payload (pubkey + timestamp + first 32 bytes of appdata) - body_to_sign = pubkey + ts_bytes + appdata[:32] + # Sign the payload (pubkey + timestamp + appdata) + body_to_sign = pubkey + ts_bytes + appdata signature = local_identity.sign(body_to_sign) # Create payload: pubkey + timestamp + signature + appdata diff --git a/tests/test_handlers.py b/tests/test_handlers.py index cd3816a..eb9a91a 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -14,7 +14,7 @@ TextMessageHandler, TraceHandler, ) -from pymc_core.protocol import LocalIdentity, Packet +from pymc_core.protocol import LocalIdentity, Packet, PacketBuilder from pymc_core.protocol.constants import ( PAYLOAD_TYPE_ACK, PAYLOAD_TYPE_ADVERT, @@ -24,18 +24,27 @@ PAYLOAD_TYPE_RESPONSE, PAYLOAD_TYPE_TRACE, PAYLOAD_TYPE_TXT_MSG, + PUB_KEY_SIZE, + SIGNATURE_SIZE, + TIMESTAMP_SIZE, ) # Mock classes for testing class MockContact: - def __init__(self, public_key="0123456789abcdef0123456789abcdef"): + def __init__(self, public_key="0123456789abcdef0123456789abcdef", name="mock"): self.public_key = public_key + self.name = name + self.last_advert = 0 class MockContactBook: def __init__(self): - self.contacts = [MockContact()] + self.contacts = [] + self.added_contacts = [] + + def add_contact(self, contact_data): + self.added_contacts.append(contact_data) class MockDispatcher: @@ -49,6 +58,7 @@ def __init__(self): class MockEventService: def __init__(self): self.publish = AsyncMock() + self.publish_sync = MagicMock() # Base Handler Tests @@ -186,6 +196,48 @@ def test_advert_handler_initialization(self): assert self.handler.identity == self.local_identity assert self.handler.event_service == self.event_service + @pytest.mark.asyncio + async def test_advert_handler_accepts_valid_signature(self): + remote_identity = LocalIdentity() + packet = PacketBuilder.create_advert(remote_identity, "RemoteNode") + + await self.handler(packet) + + assert len(self.contacts.added_contacts) == 1 + added_contact = self.contacts.added_contacts[0] + assert added_contact["public_key"] == remote_identity.get_public_key().hex() + + @pytest.mark.asyncio + async def test_advert_handler_rejects_invalid_signature(self): + remote_identity = LocalIdentity() + packet = PacketBuilder.create_advert(remote_identity, "RemoteNode") + appdata_offset = PUB_KEY_SIZE + TIMESTAMP_SIZE + SIGNATURE_SIZE + 5 + if appdata_offset >= packet.payload_len: + appdata_offset = packet.payload_len - 1 + packet.payload[appdata_offset] ^= 0x01 + + await self.handler(packet) + + assert len(self.contacts.added_contacts) == 0 + assert any( + "invalid signature" in call.args[0].lower() + for call in self.log_fn.call_args_list + if call.args + ) + + @pytest.mark.asyncio + async def test_advert_handler_ignores_self_advert(self): + packet = PacketBuilder.create_advert(self.local_identity, "SelfNode") + + await self.handler(packet) + + assert len(self.contacts.added_contacts) == 0 + assert any( + "self advert" in call.args[0].lower() + for call in self.log_fn.call_args_list + if call.args + ) + # Path Handler Tests class TestPathHandler: