Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 59 additions & 4 deletions src/pymc_core/node/handlers/advert.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
import time

from ...protocol import Packet, decode_appdata
from ...protocol.constants import PAYLOAD_TYPE_ADVERT, PUB_KEY_SIZE, describe_advert_flags
from ...protocol import Identity, Packet, decode_appdata
from ...protocol.constants import (
MAX_ADVERT_DATA_SIZE,
PAYLOAD_TYPE_ADVERT,
PUB_KEY_SIZE,
SIGNATURE_SIZE,
TIMESTAMP_SIZE,
describe_advert_flags,
)
from ...protocol.utils import determine_contact_type_from_flags
from .base import BaseHandler

Expand All @@ -17,10 +24,59 @@ def __init__(self, contacts, log_fn, identity=None, event_service=None):
self.identity = identity
self.event_service = event_service

def _extract_advert_components(self, packet: Packet):
payload = packet.get_payload()
header_len = PUB_KEY_SIZE + TIMESTAMP_SIZE + SIGNATURE_SIZE
if len(payload) < header_len:
self.log(
f"Advert payload too short ({len(payload)} bytes, expected at least {header_len})"
)
return None

sig_offset = PUB_KEY_SIZE + TIMESTAMP_SIZE
pubkey = payload[:PUB_KEY_SIZE]
timestamp = payload[PUB_KEY_SIZE:sig_offset]
signature = payload[sig_offset : sig_offset + SIGNATURE_SIZE]
appdata = payload[sig_offset + SIGNATURE_SIZE :]

if len(appdata) > MAX_ADVERT_DATA_SIZE:
self.log(
f"Advert appdata too large ({len(appdata)} bytes); truncating to {MAX_ADVERT_DATA_SIZE}"
)
appdata = appdata[:MAX_ADVERT_DATA_SIZE]

return pubkey, timestamp, signature, appdata

def _verify_advert_signature(
self, pubkey: bytes, timestamp: bytes, appdata: bytes, signature: bytes
) -> bool:
try:
peer_identity = Identity(pubkey)
except Exception as exc:
self.log(f"Unable to construct peer identity: {exc}")
return False

signed_region = pubkey + timestamp + appdata
if not peer_identity.verify(signed_region, signature):
return False
return True

async def __call__(self, packet: Packet) -> None:
pubkey_bytes = packet.payload[:PUB_KEY_SIZE]
components = self._extract_advert_components(packet)
if not components:
return

pubkey_bytes, timestamp_bytes, signature_bytes, appdata = components
pubkey_hex = pubkey_bytes.hex()

if not self._verify_advert_signature(pubkey_bytes, timestamp_bytes, appdata, signature_bytes):
self.log(f"Rejecting advert with invalid signature (pubkey={pubkey_hex[:8]}...)")
return

if self.identity and pubkey_bytes == self.identity.get_public_key():
self.log("Ignoring self advert packet")
return

self.log("<<< Advert packet received >>>")

if self.contacts is not None:
Expand All @@ -31,7 +87,6 @@ async def __call__(self, packet: Packet) -> None:
contact.last_advert = int(time.time())
else:
self.log(f"<<< New contact discovered (pubkey={pubkey_hex[:8]}...) >>>")
appdata = packet.get_payload_app_data()
decoded = decode_appdata(appdata)

# Extract name from decoded data
Expand Down
9 changes: 7 additions & 2 deletions src/pymc_core/protocol/packet_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
ADVERT_FLAG_HAS_NAME,
ADVERT_FLAG_IS_CHAT_NODE,
CIPHER_BLOCK_SIZE,
MAX_ADVERT_DATA_SIZE,
CONTACT_TYPE_ROOM_SERVER,
MAX_PACKET_PAYLOAD,
MAX_PATH_SIZE,
Expand Down Expand Up @@ -304,9 +305,13 @@ def create_advert(
pubkey = local_identity.get_public_key()
ts_bytes = struct.pack("<I", timestamp)
appdata = PacketBuilder._encode_advert_data(name, lat, lon, feature1, feature2, flags)
if len(appdata) > MAX_ADVERT_DATA_SIZE:
raise ValueError(
f"advert appdata too large: {len(appdata)} bytes (max {MAX_ADVERT_DATA_SIZE})"
)

# Sign the first part of the payload (pubkey + timestamp + first 32 bytes of appdata)
body_to_sign = pubkey + ts_bytes + appdata[:32]
# Sign the payload (pubkey + timestamp + appdata)
body_to_sign = pubkey + ts_bytes + appdata
signature = local_identity.sign(body_to_sign)

# Create payload: pubkey + timestamp + signature + appdata
Expand Down
58 changes: 55 additions & 3 deletions tests/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
TextMessageHandler,
TraceHandler,
)
from pymc_core.protocol import LocalIdentity, Packet
from pymc_core.protocol import LocalIdentity, Packet, PacketBuilder
from pymc_core.protocol.constants import (
PAYLOAD_TYPE_ACK,
PAYLOAD_TYPE_ADVERT,
Expand All @@ -24,18 +24,27 @@
PAYLOAD_TYPE_RESPONSE,
PAYLOAD_TYPE_TRACE,
PAYLOAD_TYPE_TXT_MSG,
PUB_KEY_SIZE,
SIGNATURE_SIZE,
TIMESTAMP_SIZE,
)


# Mock classes for testing
class MockContact:
def __init__(self, public_key="0123456789abcdef0123456789abcdef"):
def __init__(self, public_key="0123456789abcdef0123456789abcdef", name="mock"):
self.public_key = public_key
self.name = name
self.last_advert = 0


class MockContactBook:
def __init__(self):
self.contacts = [MockContact()]
self.contacts = []
self.added_contacts = []

def add_contact(self, contact_data):
self.added_contacts.append(contact_data)


class MockDispatcher:
Expand All @@ -49,6 +58,7 @@ def __init__(self):
class MockEventService:
def __init__(self):
self.publish = AsyncMock()
self.publish_sync = MagicMock()


# Base Handler Tests
Expand Down Expand Up @@ -186,6 +196,48 @@ def test_advert_handler_initialization(self):
assert self.handler.identity == self.local_identity
assert self.handler.event_service == self.event_service

@pytest.mark.asyncio
async def test_advert_handler_accepts_valid_signature(self):
remote_identity = LocalIdentity()
packet = PacketBuilder.create_advert(remote_identity, "RemoteNode")

await self.handler(packet)

assert len(self.contacts.added_contacts) == 1
added_contact = self.contacts.added_contacts[0]
assert added_contact["public_key"] == remote_identity.get_public_key().hex()

@pytest.mark.asyncio
async def test_advert_handler_rejects_invalid_signature(self):
remote_identity = LocalIdentity()
packet = PacketBuilder.create_advert(remote_identity, "RemoteNode")
appdata_offset = PUB_KEY_SIZE + TIMESTAMP_SIZE + SIGNATURE_SIZE + 5
if appdata_offset >= packet.payload_len:
appdata_offset = packet.payload_len - 1
packet.payload[appdata_offset] ^= 0x01

await self.handler(packet)

assert len(self.contacts.added_contacts) == 0
assert any(
"invalid signature" in call.args[0].lower()
for call in self.log_fn.call_args_list
if call.args
)

@pytest.mark.asyncio
async def test_advert_handler_ignores_self_advert(self):
packet = PacketBuilder.create_advert(self.local_identity, "SelfNode")

await self.handler(packet)

assert len(self.contacts.added_contacts) == 0
assert any(
"self advert" in call.args[0].lower()
for call in self.log_fn.call_args_list
if call.args
)


# Path Handler Tests
class TestPathHandler:
Expand Down