Skip to content

Commit ee80065

Browse files
authored
Merge pull request #10312 from SomberNight/202511_pr10230_1
lnonion: immutable OnionPacket and OnionHopsDataSingle
2 parents c09b3d2 + 1b600b4 commit ee80065

File tree

8 files changed

+125
-74
lines changed

8 files changed

+125
-74
lines changed

electrum/lnmsg.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import os
22
import csv
33
import io
4-
from typing import Callable, Tuple, Any, Dict, List, Sequence, Union, Optional
4+
from typing import Callable, Tuple, Any, Dict, List, Sequence, Union, Optional, Mapping
5+
from types import MappingProxyType
56
from collections import OrderedDict
67

78
from .lnutil import OnionFailureCodeMetaFlag
@@ -289,7 +290,7 @@ def _write_tlv_record(*, fd: io.BytesIO, tlv_type: int, tlv_val: bytes) -> None:
289290
_write_primitive_field(fd=fd, field_type="byte", count=tlv_len, value=tlv_val)
290291

291292

292-
def _resolve_field_count(field_count_str: str, *, vars_dict: dict, allow_any=False) -> Union[int, str]:
293+
def _resolve_field_count(field_count_str: str, *, vars_dict: Mapping, allow_any=False) -> Union[int, str]:
293294
"""Returns an evaluated field count, typically an int.
294295
If allow_any is True, the return value can be a str with value=="...".
295296
"""
@@ -403,7 +404,7 @@ def write_field(
403404
fd: io.BytesIO,
404405
field_type: str,
405406
count: Union[int, str],
406-
value: Union[List[Dict[str, Any]], Dict[str, Any]]
407+
value: Union[Sequence[Mapping[str, Any]], Mapping[str, Any]],
407408
) -> None:
408409
assert fd
409410

@@ -421,10 +422,10 @@ def write_field(
421422
return
422423

423424
if count == 1:
424-
assert isinstance(value, dict) or isinstance(value, list)
425-
values = [value] if isinstance(value, dict) else value
425+
assert isinstance(value, (MappingProxyType, dict)) or isinstance(value, (list, tuple)), type(value)
426+
values = [value] if isinstance(value, (MappingProxyType, dict)) else value
426427
else:
427-
assert isinstance(value, list), f'{field_type=}, expected value of type list for {count=}'
428+
assert isinstance(value, (tuple, list)), f'{field_type=}, expected value of type list/tuple for {count=}'
428429
values = value
429430

430431
if count == '...':

electrum/lnonion.py

Lines changed: 43 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,10 @@
2525

2626
import io
2727
import hashlib
28-
from typing import Sequence, List, Tuple, NamedTuple, TYPE_CHECKING, Dict, Any, Optional, Union
28+
from typing import Sequence, List, Tuple, NamedTuple, TYPE_CHECKING, Dict, Any, Optional, Union, Mapping
2929
from enum import IntEnum
30+
from dataclasses import dataclass, field, replace
31+
from types import MappingProxyType
3032

3133
import electrum_ecc as ecc
3234

@@ -53,18 +55,22 @@ class InvalidOnionPubkey(Exception): pass
5355
class InvalidPayloadSize(Exception): pass
5456

5557

56-
class OnionHopsDataSingle: # called HopData in lnd
58+
@dataclass(frozen=True, kw_only=True)
59+
class OnionHopsDataSingle:
60+
payload: Mapping = field(default_factory=lambda: MappingProxyType({}))
61+
hmac: Optional[bytes] = None
62+
tlv_stream_name: str = 'payload'
63+
blind_fields: Mapping = field(default_factory=lambda: MappingProxyType({}))
64+
_raw_bytes_payload: Optional[bytes] = None
5765

58-
def __init__(self, *, payload: dict = None, tlv_stream_name: str = 'payload', blind_fields: dict = None):
59-
if payload is None:
60-
payload = {}
61-
self.payload = payload
62-
self.hmac = None
63-
self.tlv_stream_name = tlv_stream_name
64-
if blind_fields is None:
65-
blind_fields = {}
66-
self.blind_fields = blind_fields
67-
self._raw_bytes_payload = None # used in unit tests
66+
def __post_init__(self):
67+
# make all fields immutable recursively
68+
object.__setattr__(self, 'payload', util.make_object_immutable(self.payload))
69+
object.__setattr__(self, 'blind_fields', util.make_object_immutable(self.blind_fields))
70+
assert isinstance(self.payload, MappingProxyType)
71+
assert isinstance(self.blind_fields, MappingProxyType)
72+
assert isinstance(self.tlv_stream_name, str)
73+
assert (isinstance(self.hmac, bytes) and len(self.hmac) == PER_HOP_HMAC_SIZE) or self.hmac is None
6874

6975
def to_bytes(self) -> bytes:
7076
hmac_ = self.hmac if self.hmac is not None else bytes(PER_HOP_HMAC_SIZE)
@@ -101,32 +107,35 @@ def from_fd(cls, fd: io.BytesIO, *, tlv_stream_name: str = 'payload') -> 'OnionH
101107
hop_payload = fd.read(hop_payload_length)
102108
if hop_payload_length != len(hop_payload):
103109
raise Exception(f"unexpected EOF")
104-
ret = OnionHopsDataSingle(tlv_stream_name=tlv_stream_name)
105-
ret.payload = OnionWireSerializer.read_tlv_stream(fd=io.BytesIO(hop_payload),
106-
tlv_stream_name=tlv_stream_name)
107-
ret.hmac = fd.read(PER_HOP_HMAC_SIZE)
108-
assert len(ret.hmac) == PER_HOP_HMAC_SIZE
110+
payload = OnionWireSerializer.read_tlv_stream(fd=io.BytesIO(hop_payload),
111+
tlv_stream_name=tlv_stream_name)
112+
ret = OnionHopsDataSingle(
113+
tlv_stream_name=tlv_stream_name,
114+
payload=payload,
115+
hmac=fd.read(PER_HOP_HMAC_SIZE)
116+
)
109117
return ret
110118

111119
def __repr__(self):
112-
return f"<OnionHopsDataSingle. payload={self.payload}. hmac={self.hmac}>"
120+
return f"<OnionHopsDataSingle. {self.payload=}. {self.hmac=}>"
113121

114122

123+
@dataclass(frozen=True, kw_only=True)
115124
class OnionPacket:
116-
117-
def __init__(self, *, public_key: bytes, hops_data: bytes, hmac: bytes, version: int = 0):
118-
assert len(public_key) == 33
119-
assert len(hops_data) in [HOPS_DATA_SIZE, TRAMPOLINE_HOPS_DATA_SIZE, ONION_MESSAGE_LARGE_SIZE]
120-
assert len(hmac) == PER_HOP_HMAC_SIZE
121-
self.version = version
122-
self.public_key = public_key
123-
self.hops_data = hops_data # also called RoutingInfo in bolt-04
124-
self.hmac = hmac
125-
if not ecc.ECPubkey.is_pubkey_bytes(public_key):
125+
public_key: bytes
126+
hops_data: bytes # also called RoutingInfo in bolt-04
127+
hmac: bytes
128+
version: int = 0
129+
# for debugging our own onions:
130+
_debug_hops_data: Optional[Sequence[OnionHopsDataSingle]] = None
131+
_debug_route: Optional['LNPaymentRoute'] = None
132+
133+
def __post_init__(self):
134+
assert len(self.public_key) == 33
135+
assert len(self.hops_data) in [HOPS_DATA_SIZE, TRAMPOLINE_HOPS_DATA_SIZE, ONION_MESSAGE_LARGE_SIZE]
136+
assert len(self.hmac) == PER_HOP_HMAC_SIZE
137+
if not ecc.ECPubkey.is_pubkey_bytes(self.public_key):
126138
raise InvalidOnionPubkey()
127-
# for debugging our own onions:
128-
self._debug_hops_data = None # type: Optional[Sequence[OnionHopsDataSingle]]
129-
self._debug_route = None # type: Optional[LNPaymentRoute]
130139

131140
def to_bytes(self) -> bytes:
132141
ret = bytes([self.version])
@@ -138,7 +147,7 @@ def to_bytes(self) -> bytes:
138147
return ret
139148

140149
@classmethod
141-
def from_bytes(cls, b: bytes):
150+
def from_bytes(cls, b: bytes) -> 'OnionPacket':
142151
if len(b) - 66 not in [HOPS_DATA_SIZE, TRAMPOLINE_HOPS_DATA_SIZE, ONION_MESSAGE_LARGE_SIZE]:
143152
raise Exception('unexpected length {}'.format(len(b)))
144153
return OnionPacket(
@@ -187,7 +196,7 @@ def get_blinded_node_id(node_id: bytes, shared_secret: bytes):
187196
def new_onion_packet(
188197
payment_path_pubkeys: Sequence[bytes],
189198
session_key: bytes,
190-
hops_data: Sequence[OnionHopsDataSingle],
199+
hops_data: List[OnionHopsDataSingle],
191200
*,
192201
associated_data: bytes = b'',
193202
trampoline: bool = False,
@@ -226,7 +235,7 @@ def new_onion_packet(
226235
for i in range(num_hops-1, -1, -1):
227236
rho_key = get_bolt04_onion_key(b'rho', hop_shared_secrets[i])
228237
mu_key = get_bolt04_onion_key(b'mu', hop_shared_secrets[i])
229-
hops_data[i].hmac = next_hmac
238+
hops_data[i] = replace(hops_data[i], hmac=next_hmac)
230239
stream_bytes = generate_cipher_stream(rho_key, data_size)
231240
hop_data_bytes = hops_data[i].to_bytes()
232241
mix_header = mix_header[:-len(hop_data_bytes)]

electrum/lnworker.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
Optional, Sequence, Tuple, List, Set, Dict, TYPE_CHECKING, NamedTuple, Mapping, Any, Iterable, AsyncGenerator,
1313
Callable, Awaitable
1414
)
15+
from types import MappingProxyType
1516
import threading
1617
import socket
1718
from functools import partial
@@ -3723,13 +3724,14 @@ def create_onion_for_route(
37233724
# if we are forwarding a trampoline payment, add trampoline onion
37243725
if trampoline_onion:
37253726
self.logger.info(f'adding trampoline onion to final payload')
3726-
trampoline_payload = hops_data[-1].payload
3727+
trampoline_payload = dict(hops_data[-1].payload)
37273728
trampoline_payload["trampoline_onion_packet"] = {
37283729
"version": trampoline_onion.version,
37293730
"public_key": trampoline_onion.public_key,
37303731
"hops_data": trampoline_onion.hops_data,
37313732
"hmac": trampoline_onion.hmac
37323733
}
3734+
hops_data[-1] = dataclasses.replace(hops_data[-1], payload=trampoline_payload)
37333735
if t_hops_data := trampoline_onion._debug_hops_data: # None if trampoline-forwarding
37343736
t_route = trampoline_onion._debug_route
37353737
assert t_route is not None

electrum/onion_message.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,11 @@
2727
import os
2828
import threading
2929
import time
30+
import dataclasses
3031
from random import random
32+
from types import MappingProxyType
3133

32-
from typing import TYPE_CHECKING, Optional, Sequence, NamedTuple
34+
from typing import TYPE_CHECKING, Optional, Sequence, NamedTuple, List
3335

3436
import electrum_ecc as ecc
3537

@@ -139,7 +141,7 @@ def is_onion_message_node(node_id: bytes, node_info: Optional['NodeInfo']) -> bo
139141

140142

141143
def encrypt_onionmsg_tlv_hops_data(
142-
hops_data: Sequence[OnionHopsDataSingle],
144+
hops_data: List[OnionHopsDataSingle],
143145
hop_shared_secrets: Sequence[bytes]
144146
) -> None:
145147
"""encrypt unencrypted onionmsg_tlv.encrypted_recipient_data for hops with blind_fields"""
@@ -148,7 +150,9 @@ def encrypt_onionmsg_tlv_hops_data(
148150
if hops_data[i].tlv_stream_name == 'onionmsg_tlv' and 'encrypted_recipient_data' not in hops_data[i].payload:
149151
# construct encrypted_recipient_data from blind_fields
150152
encrypted_recipient_data = encrypt_onionmsg_data_tlv(shared_secret=hop_shared_secrets[i], **hops_data[i].blind_fields)
151-
hops_data[i].payload['encrypted_recipient_data'] = {'encrypted_recipient_data': encrypted_recipient_data}
153+
new_payload = dict(hops_data[i].payload)
154+
new_payload['encrypted_recipient_data'] = {'encrypted_recipient_data': encrypted_recipient_data}
155+
hops_data[i] = dataclasses.replace(hops_data[i], payload=new_payload)
152156

153157

154158
def create_onion_message_route_to(lnwallet: 'LNWallet', node_id: bytes) -> Sequence[PathEdge]:
@@ -280,7 +284,7 @@ def send_onion_message_to(
280284
hops_data = [
281285
OnionHopsDataSingle(
282286
tlv_stream_name='onionmsg_tlv',
283-
blind_fields={'next_node_id': {'node_id': x.end_node}}
287+
blind_fields={'next_node_id': {'node_id': x.end_node}},
284288
) for x in path[:-1]
285289
]
286290

@@ -290,7 +294,7 @@ def send_onion_message_to(
290294
blind_fields={
291295
'next_node_id': {'node_id': introduction_point},
292296
'next_path_key_override': {'path_key': blinded_path['first_path_key']},
293-
}
297+
},
294298
)
295299
hops_data.append(final_hop_pre_ip)
296300

@@ -299,9 +303,11 @@ def send_onion_message_to(
299303
encrypted_recipient_data = encrypt_onionmsg_data_tlv(
300304
shared_secret=hop_shared_secrets[i],
301305
**hops_data[i].blind_fields)
302-
hops_data[i].payload['encrypted_recipient_data'] = {
306+
payload = dict(hops_data[i].payload)
307+
payload['encrypted_recipient_data'] = {
303308
'encrypted_recipient_data': encrypted_recipient_data
304309
}
310+
hops_data[i] = dataclasses.replace(hops_data[i], payload=payload)
305311

306312
path_key = ecc.ECPrivkey(session_key).get_public_key_bytes()
307313

@@ -345,13 +351,13 @@ def send_onion_message_to(
345351
hops_data = [
346352
OnionHopsDataSingle(
347353
tlv_stream_name='onionmsg_tlv',
348-
blind_fields={'next_node_id': {'node_id': x.end_node}}
354+
blind_fields={'next_node_id': {'node_id': x.end_node}},
349355
) for x in path[1:]
350356
]
351357

352358
final_hop = OnionHopsDataSingle(
353359
tlv_stream_name='onionmsg_tlv',
354-
payload=destination_payload
360+
payload=destination_payload,
355361
)
356362

357363
hops_data.append(final_hop)

electrum/trampoline.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import io
22
import os
33
import random
4+
import dataclasses
45
from typing import Mapping, Tuple, Optional, List, Iterable, Sequence, Set, Any
6+
from types import MappingProxyType
57

68
from .lnutil import LnFeatures, PaymentFeeBudget, FeeBudgetExceeded
79
from .lnonion import (
@@ -302,12 +304,12 @@ def create_trampoline_onion(
302304
for i in range(num_hops):
303305
route_edge = route[i]
304306
assert route_edge.is_trampoline()
305-
payload = hops_data[i].payload
307+
payload = dict(hops_data[i].payload)
306308
if i < num_hops - 1:
307309
payload.pop('short_channel_id')
308310
next_edge = route[i+1]
309311
assert next_edge.is_trampoline()
310-
hops_data[i].payload["outgoing_node_id"] = {"outgoing_node_id": next_edge.node_id}
312+
payload["outgoing_node_id"] = {"outgoing_node_id": next_edge.node_id}
311313
# only for final
312314
if i == num_hops - 1:
313315
payload["payment_data"] = {
@@ -322,10 +324,11 @@ def create_trampoline_onion(
322324
"payment_secret": payment_secret,
323325
"total_msat": total_msat
324326
}
327+
hops_data[i] = dataclasses.replace(hops_data[i], payload=payload)
325328

326329
if (index := routing_info_payload_index) is not None:
327330
# fill the remaining payload space with available routing hints (r_tags)
328-
payload: dict = hops_data[index].payload
331+
payload = dict(hops_data[index].payload)
329332
# try different r_tag order on each attempt
330333
invoice_routing_info = random_shuffled_copy(route[index].invoice_routing_info)
331334
remaining_payload_space = TRAMPOLINE_HOPS_DATA_SIZE \
@@ -341,12 +344,16 @@ def create_trampoline_onion(
341344
remaining_payload_space -= r_tag_size
342345
# add the chosen r_tags to the payload
343346
payload["invoice_routing_info"] = {"invoice_routing_info": b''.join(routing_info_to_use)}
347+
hops_data[index] = dataclasses.replace(hops_data[index], payload=payload)
344348
_logger.debug(f"Using {len(routing_info_to_use)} of {len(invoice_routing_info)} r_tags")
345349

346350
trampoline_session_key = os.urandom(32)
347351
trampoline_onion = new_onion_packet(payment_path_pubkeys, trampoline_session_key, hops_data, associated_data=payment_hash, trampoline=True)
348-
trampoline_onion._debug_hops_data = hops_data
349-
trampoline_onion._debug_route = route
352+
trampoline_onion = dataclasses.replace(
353+
trampoline_onion,
354+
_debug_hops_data=hops_data,
355+
_debug_route=route,
356+
)
350357
return trampoline_onion, amount_msat, cltv_abs
351358

352359

electrum/util.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
NamedTuple, Union, TYPE_CHECKING, Tuple, Optional, Callable, Any, Sequence, Dict, Generic, TypeVar, List, Iterable,
3333
Set, Awaitable
3434
)
35+
from types import MappingProxyType
3536
from datetime import datetime, timezone, timedelta
3637
import decimal
3738
from decimal import Decimal
@@ -1875,6 +1876,21 @@ def __setitem__(self, key, *args, **kwargs):
18751876
return ret
18761877

18771878

1879+
def make_object_immutable(obj):
1880+
"""Makes the passed object immutable recursively."""
1881+
allowed_types = (
1882+
dict, MappingProxyType, list, tuple, set, frozenset, str, int, float, bool, bytes, type(None)
1883+
)
1884+
assert isinstance(obj, allowed_types), f"{type(obj)=} cannot be made immutable"
1885+
if isinstance(obj, (dict, MappingProxyType)):
1886+
return MappingProxyType({k: make_object_immutable(v) for k, v in obj.items()})
1887+
elif isinstance(obj, (list, tuple)):
1888+
return tuple(make_object_immutable(item) for item in obj)
1889+
elif isinstance(obj, (set, frozenset)):
1890+
return frozenset(make_object_immutable(item) for item in obj)
1891+
return obj
1892+
1893+
18781894
def multisig_type(wallet_type):
18791895
"""If wallet_type is mofn multi-sig, return [m, n],
18801896
otherwise return None."""

0 commit comments

Comments
 (0)