2525
2626import io
2727import hashlib
28- from typing import Sequence , List , Tuple , NamedTuple , TYPE_CHECKING , Dict , Any , Optional , Union
28+ from typing import Sequence , List , Tuple , NamedTuple , TYPE_CHECKING , Dict , Any , Optional , Union , Mapping
2929from enum import IntEnum
30+ from dataclasses import dataclass , field , replace
31+ from types import MappingProxyType
3032
3133import electrum_ecc as ecc
3234
@@ -53,18 +55,22 @@ class InvalidOnionPubkey(Exception): pass
5355class InvalidPayloadSize (Exception ): pass
5456
5557
56- class OnionHopsDataSingle : # called HopData in lnd
58+ @dataclass (frozen = True , kw_only = True )
59+ class OnionHopsDataSingle :
60+ payload : Mapping = field (default_factory = lambda : MappingProxyType ({}))
61+ hmac : Optional [bytes ] = None
62+ tlv_stream_name : str = 'payload'
63+ blind_fields : Mapping = field (default_factory = lambda : MappingProxyType ({}))
64+ _raw_bytes_payload : Optional [bytes ] = None
5765
58- def __init__ (self , * , payload : dict = None , tlv_stream_name : str = 'payload' , blind_fields : dict = None ):
59- if payload is None :
60- payload = {}
61- self .payload = payload
62- self .hmac = None
63- self .tlv_stream_name = tlv_stream_name
64- if blind_fields is None :
65- blind_fields = {}
66- self .blind_fields = blind_fields
67- self ._raw_bytes_payload = None # used in unit tests
66+ def __post_init__ (self ):
67+ # make all fields immutable recursively
68+ object .__setattr__ (self , 'payload' , util .make_object_immutable (self .payload ))
69+ object .__setattr__ (self , 'blind_fields' , util .make_object_immutable (self .blind_fields ))
70+ assert isinstance (self .payload , MappingProxyType )
71+ assert isinstance (self .blind_fields , MappingProxyType )
72+ assert isinstance (self .tlv_stream_name , str )
73+ assert (isinstance (self .hmac , bytes ) and len (self .hmac ) == PER_HOP_HMAC_SIZE ) or self .hmac is None
6874
6975 def to_bytes (self ) -> bytes :
7076 hmac_ = self .hmac if self .hmac is not None else bytes (PER_HOP_HMAC_SIZE )
@@ -101,32 +107,35 @@ def from_fd(cls, fd: io.BytesIO, *, tlv_stream_name: str = 'payload') -> 'OnionH
101107 hop_payload = fd .read (hop_payload_length )
102108 if hop_payload_length != len (hop_payload ):
103109 raise Exception (f"unexpected EOF" )
104- ret = OnionHopsDataSingle (tlv_stream_name = tlv_stream_name )
105- ret .payload = OnionWireSerializer .read_tlv_stream (fd = io .BytesIO (hop_payload ),
106- tlv_stream_name = tlv_stream_name )
107- ret .hmac = fd .read (PER_HOP_HMAC_SIZE )
108- assert len (ret .hmac ) == PER_HOP_HMAC_SIZE
110+ payload = OnionWireSerializer .read_tlv_stream (fd = io .BytesIO (hop_payload ),
111+ tlv_stream_name = tlv_stream_name )
112+ ret = OnionHopsDataSingle (
113+ tlv_stream_name = tlv_stream_name ,
114+ payload = payload ,
115+ hmac = fd .read (PER_HOP_HMAC_SIZE )
116+ )
109117 return ret
110118
111119 def __repr__ (self ):
112- return f"<OnionHopsDataSingle. payload= { self .payload } . hmac= { self .hmac } >"
120+ return f"<OnionHopsDataSingle. { self .payload = } . { self .hmac = } >"
113121
114122
123+ @dataclass (frozen = True , kw_only = True )
115124class OnionPacket :
116-
117- def __init__ (self , * , public_key : bytes , hops_data : bytes , hmac : bytes , version : int = 0 ):
118- assert len (public_key ) == 33
119- assert len (hops_data ) in [HOPS_DATA_SIZE , TRAMPOLINE_HOPS_DATA_SIZE , ONION_MESSAGE_LARGE_SIZE ]
120- assert len (hmac ) == PER_HOP_HMAC_SIZE
121- self .version = version
122- self .public_key = public_key
123- self .hops_data = hops_data # also called RoutingInfo in bolt-04
124- self .hmac = hmac
125- if not ecc .ECPubkey .is_pubkey_bytes (public_key ):
125+ public_key : bytes
126+ hops_data : bytes # also called RoutingInfo in bolt-04
127+ hmac : bytes
128+ version : int = 0
129+ # for debugging our own onions:
130+ _debug_hops_data : Optional [Sequence [OnionHopsDataSingle ]] = None
131+ _debug_route : Optional ['LNPaymentRoute' ] = None
132+
133+ def __post_init__ (self ):
134+ assert len (self .public_key ) == 33
135+ assert len (self .hops_data ) in [HOPS_DATA_SIZE , TRAMPOLINE_HOPS_DATA_SIZE , ONION_MESSAGE_LARGE_SIZE ]
136+ assert len (self .hmac ) == PER_HOP_HMAC_SIZE
137+ if not ecc .ECPubkey .is_pubkey_bytes (self .public_key ):
126138 raise InvalidOnionPubkey ()
127- # for debugging our own onions:
128- self ._debug_hops_data = None # type: Optional[Sequence[OnionHopsDataSingle]]
129- self ._debug_route = None # type: Optional[LNPaymentRoute]
130139
131140 def to_bytes (self ) -> bytes :
132141 ret = bytes ([self .version ])
@@ -138,7 +147,7 @@ def to_bytes(self) -> bytes:
138147 return ret
139148
140149 @classmethod
141- def from_bytes (cls , b : bytes ):
150+ def from_bytes (cls , b : bytes ) -> 'OnionPacket' :
142151 if len (b ) - 66 not in [HOPS_DATA_SIZE , TRAMPOLINE_HOPS_DATA_SIZE , ONION_MESSAGE_LARGE_SIZE ]:
143152 raise Exception ('unexpected length {}' .format (len (b )))
144153 return OnionPacket (
@@ -187,7 +196,7 @@ def get_blinded_node_id(node_id: bytes, shared_secret: bytes):
187196def new_onion_packet (
188197 payment_path_pubkeys : Sequence [bytes ],
189198 session_key : bytes ,
190- hops_data : Sequence [OnionHopsDataSingle ],
199+ hops_data : List [OnionHopsDataSingle ],
191200 * ,
192201 associated_data : bytes = b'' ,
193202 trampoline : bool = False ,
@@ -226,7 +235,7 @@ def new_onion_packet(
226235 for i in range (num_hops - 1 , - 1 , - 1 ):
227236 rho_key = get_bolt04_onion_key (b'rho' , hop_shared_secrets [i ])
228237 mu_key = get_bolt04_onion_key (b'mu' , hop_shared_secrets [i ])
229- hops_data [i ]. hmac = next_hmac
238+ hops_data [i ] = replace ( hops_data [ i ], hmac = next_hmac )
230239 stream_bytes = generate_cipher_stream (rho_key , data_size )
231240 hop_data_bytes = hops_data [i ].to_bytes ()
232241 mix_header = mix_header [:- len (hop_data_bytes )]
0 commit comments