@@ -88,6 +88,8 @@ pub trait CustomMessageHandler: wire::CustomMessageReader {
8888 /// May return an `Err(())` if the features the peer supports are not sufficient to communicate
8989 /// with us. Implementors should be somewhat conservative about doing so, however, as other
9090 /// message handlers may still wish to communicate with this peer.
91+ ///
92+ /// [`Self::peer_disconnected`] will not be called if `Err(())` is returned.
9193 fn peer_connected ( & self , their_node_id : PublicKey , msg : & Init , inbound : bool ) -> Result < ( ) , ( ) > ;
9294
9395 /// Gets the node feature flags which this handler itself supports. All available handlers are
@@ -119,6 +121,7 @@ impl RoutingMessageHandler for IgnoringMessageHandler {
119121 Option < ( msgs:: ChannelAnnouncement , Option < msgs:: ChannelUpdate > , Option < msgs:: ChannelUpdate > ) > { None }
120122 fn get_next_node_announcement ( & self , _starting_point : Option < & NodeId > ) -> Option < msgs:: NodeAnnouncement > { None }
121123 fn peer_connected ( & self , _their_node_id : PublicKey , _init : & msgs:: Init , _inbound : bool ) -> Result < ( ) , ( ) > { Ok ( ( ) ) }
124+ fn peer_disconnected ( & self , _their_node_id : PublicKey ) { }
122125 fn handle_reply_channel_range ( & self , _their_node_id : PublicKey , _msg : msgs:: ReplyChannelRange ) -> Result < ( ) , LightningError > { Ok ( ( ) ) }
123126 fn handle_reply_short_channel_ids_end ( & self , _their_node_id : PublicKey , _msg : msgs:: ReplyShortChannelIdsEnd ) -> Result < ( ) , LightningError > { Ok ( ( ) ) }
124127 fn handle_query_channel_range ( & self , _their_node_id : PublicKey , _msg : msgs:: QueryChannelRange ) -> Result < ( ) , LightningError > { Ok ( ( ) ) }
@@ -1714,14 +1717,20 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
17141717 }
17151718 if let Err ( ( ) ) = self . message_handler . chan_handler . peer_connected ( their_node_id, & msg, peer_lock. inbound_connection ) {
17161719 log_debug ! ( logger, "Channel Handler decided we couldn't communicate with peer {}" , log_pubkey!( their_node_id) ) ;
1720+ self . message_handler . route_handler . peer_disconnected ( their_node_id) ;
17171721 return Err ( PeerHandleError { } . into ( ) ) ;
17181722 }
17191723 if let Err ( ( ) ) = self . message_handler . onion_message_handler . peer_connected ( their_node_id, & msg, peer_lock. inbound_connection ) {
17201724 log_debug ! ( logger, "Onion Message Handler decided we couldn't communicate with peer {}" , log_pubkey!( their_node_id) ) ;
1725+ self . message_handler . route_handler . peer_disconnected ( their_node_id) ;
1726+ self . message_handler . chan_handler . peer_disconnected ( their_node_id) ;
17211727 return Err ( PeerHandleError { } . into ( ) ) ;
17221728 }
17231729 if let Err ( ( ) ) = self . message_handler . custom_message_handler . peer_connected ( their_node_id, & msg, peer_lock. inbound_connection ) {
17241730 log_debug ! ( logger, "Custom Message Handler decided we couldn't communicate with peer {}" , log_pubkey!( their_node_id) ) ;
1731+ self . message_handler . route_handler . peer_disconnected ( their_node_id) ;
1732+ self . message_handler . chan_handler . peer_disconnected ( their_node_id) ;
1733+ self . message_handler . onion_message_handler . peer_disconnected ( their_node_id) ;
17251734 return Err ( PeerHandleError { } . into ( ) ) ;
17261735 }
17271736
@@ -2533,6 +2542,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
25332542 debug_assert ! ( peer. their_node_id. is_some( ) ) ;
25342543 if let Some ( ( node_id, _) ) = peer. their_node_id {
25352544 log_trace ! ( WithContext :: from( & self . logger, Some ( node_id) , None , None ) , "Disconnecting peer with id {} due to {}" , node_id, reason) ;
2545+ self . message_handler . route_handler . peer_disconnected ( node_id) ;
25362546 self . message_handler . chan_handler . peer_disconnected ( node_id) ;
25372547 self . message_handler . onion_message_handler . peer_disconnected ( node_id) ;
25382548 self . message_handler . custom_message_handler . peer_disconnected ( node_id) ;
@@ -2557,6 +2567,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
25572567 let removed = self . node_id_to_descriptor . lock ( ) . unwrap ( ) . remove ( & node_id) ;
25582568 debug_assert ! ( removed. is_some( ) , "descriptor maps should be consistent" ) ;
25592569 if !peer. handshake_complete ( ) { return ; }
2570+ self . message_handler . route_handler . peer_disconnected ( node_id) ;
25602571 self . message_handler . chan_handler . peer_disconnected ( node_id) ;
25612572 self . message_handler . onion_message_handler . peer_disconnected ( node_id) ;
25622573 self . message_handler . custom_message_handler . peer_disconnected ( node_id) ;
@@ -2856,6 +2867,16 @@ mod tests {
28562867
28572868 struct TestCustomMessageHandler {
28582869 features : InitFeatures ,
2870+ conn_tracker : test_utils:: ConnectionTracker ,
2871+ }
2872+
2873+ impl TestCustomMessageHandler {
2874+ fn new ( features : InitFeatures ) -> Self {
2875+ Self {
2876+ features,
2877+ conn_tracker : test_utils:: ConnectionTracker :: new ( ) ,
2878+ }
2879+ }
28592880 }
28602881
28612882 impl wire:: CustomMessageReader for TestCustomMessageHandler {
@@ -2872,10 +2893,13 @@ mod tests {
28722893
28732894 fn get_and_clear_pending_msg ( & self ) -> Vec < ( PublicKey , Self :: CustomMessage ) > { Vec :: new ( ) }
28742895
2896+ fn peer_disconnected ( & self , their_node_id : PublicKey ) {
2897+ self . conn_tracker . peer_disconnected ( their_node_id) ;
2898+ }
28752899
2876- fn peer_disconnected ( & self , _their_node_id : PublicKey ) { }
2877-
2878- fn peer_connected ( & self , _their_node_id : PublicKey , _msg : & Init , _inbound : bool ) -> Result < ( ) , ( ) > { Ok ( ( ) ) }
2900+ fn peer_connected ( & self , their_node_id : PublicKey , _msg : & Init , _inbound : bool ) -> Result < ( ) , ( ) > {
2901+ self . conn_tracker . peer_connected ( their_node_id )
2902+ }
28792903
28802904 fn provided_node_features ( & self ) -> NodeFeatures { NodeFeatures :: empty ( ) }
28812905
@@ -2898,7 +2922,7 @@ mod tests {
28982922 chan_handler : test_utils:: TestChannelMessageHandler :: new ( ChainHash :: using_genesis_block ( Network :: Testnet ) ) ,
28992923 logger : test_utils:: TestLogger :: with_id ( i. to_string ( ) ) ,
29002924 routing_handler : test_utils:: TestRoutingMessageHandler :: new ( ) ,
2901- custom_handler : TestCustomMessageHandler { features } ,
2925+ custom_handler : TestCustomMessageHandler :: new ( features) ,
29022926 node_signer : test_utils:: TestNodeSigner :: new ( node_secret) ,
29032927 }
29042928 ) ;
@@ -2921,7 +2945,7 @@ mod tests {
29212945 chan_handler : test_utils:: TestChannelMessageHandler :: new ( ChainHash :: using_genesis_block ( Network :: Testnet ) ) ,
29222946 logger : test_utils:: TestLogger :: new ( ) ,
29232947 routing_handler : test_utils:: TestRoutingMessageHandler :: new ( ) ,
2924- custom_handler : TestCustomMessageHandler { features } ,
2948+ custom_handler : TestCustomMessageHandler :: new ( features) ,
29252949 node_signer : test_utils:: TestNodeSigner :: new ( node_secret) ,
29262950 }
29272951 ) ;
@@ -2941,7 +2965,7 @@ mod tests {
29412965 chan_handler : test_utils:: TestChannelMessageHandler :: new ( network) ,
29422966 logger : test_utils:: TestLogger :: new ( ) ,
29432967 routing_handler : test_utils:: TestRoutingMessageHandler :: new ( ) ,
2944- custom_handler : TestCustomMessageHandler { features } ,
2968+ custom_handler : TestCustomMessageHandler :: new ( features) ,
29452969 node_signer : test_utils:: TestNodeSigner :: new ( node_secret) ,
29462970 }
29472971 ) ;
@@ -2965,19 +2989,16 @@ mod tests {
29652989 peers
29662990 }
29672991
2968- fn establish_connection < ' a > ( peer_a : & PeerManager < FileDescriptor , & ' a test_utils:: TestChannelMessageHandler , & ' a test_utils:: TestRoutingMessageHandler , IgnoringMessageHandler , & ' a test_utils:: TestLogger , & ' a TestCustomMessageHandler , & ' a test_utils:: TestNodeSigner > , peer_b : & PeerManager < FileDescriptor , & ' a test_utils:: TestChannelMessageHandler , & ' a test_utils:: TestRoutingMessageHandler , IgnoringMessageHandler , & ' a test_utils:: TestLogger , & ' a TestCustomMessageHandler , & ' a test_utils:: TestNodeSigner > ) -> ( FileDescriptor , FileDescriptor ) {
2992+ fn try_establish_connection < ' a > ( peer_a : & PeerManager < FileDescriptor , & ' a test_utils:: TestChannelMessageHandler , & ' a test_utils:: TestRoutingMessageHandler , IgnoringMessageHandler , & ' a test_utils:: TestLogger , & ' a TestCustomMessageHandler , & ' a test_utils:: TestNodeSigner > , peer_b : & PeerManager < FileDescriptor , & ' a test_utils:: TestChannelMessageHandler , & ' a test_utils:: TestRoutingMessageHandler , IgnoringMessageHandler , & ' a test_utils:: TestLogger , & ' a TestCustomMessageHandler , & ' a test_utils:: TestNodeSigner > ) -> ( FileDescriptor , FileDescriptor , Result < bool , PeerHandleError > , Result < bool , PeerHandleError > ) {
2993+ let addr_a = SocketAddress :: TcpIpV4 { addr : [ 127 , 0 , 0 , 1 ] , port : 1000 } ;
2994+ let addr_b = SocketAddress :: TcpIpV4 { addr : [ 127 , 0 , 0 , 1 ] , port : 1001 } ;
2995+
29692996 static FD_COUNTER : AtomicUsize = AtomicUsize :: new ( 0 ) ;
29702997 let fd = FD_COUNTER . fetch_add ( 1 , Ordering :: Relaxed ) as u16 ;
29712998
29722999 let id_a = peer_a. node_signer . get_node_id ( Recipient :: Node ) . unwrap ( ) ;
29733000 let mut fd_a = FileDescriptor :: new ( fd) ;
2974- let addr_a = SocketAddress :: TcpIpV4 { addr : [ 127 , 0 , 0 , 1 ] , port : 1000 } ;
2975-
2976- let id_b = peer_b. node_signer . get_node_id ( Recipient :: Node ) . unwrap ( ) ;
2977- let features_a = peer_a. init_features ( id_b) ;
2978- let features_b = peer_b. init_features ( id_a) ;
29793001 let mut fd_b = FileDescriptor :: new ( fd) ;
2980- let addr_b = SocketAddress :: TcpIpV4 { addr : [ 127 , 0 , 0 , 1 ] , port : 1001 } ;
29813002
29823003 let initial_data = peer_b. new_outbound_connection ( id_a, fd_b. clone ( ) , Some ( addr_a. clone ( ) ) ) . unwrap ( ) ;
29833004 peer_a. new_inbound_connection ( fd_a. clone ( ) , Some ( addr_b. clone ( ) ) ) . unwrap ( ) ;
@@ -2989,11 +3010,30 @@ mod tests {
29893010
29903011 peer_b. process_events ( ) ;
29913012 let b_data = fd_b. outbound_data . lock ( ) . unwrap ( ) . split_off ( 0 ) ;
2992- assert_eq ! ( peer_a. read_event( & mut fd_a, & b_data) . unwrap ( ) , false ) ;
3013+ let a_refused = peer_a. read_event ( & mut fd_a, & b_data) ;
29933014
29943015 peer_a. process_events ( ) ;
29953016 let a_data = fd_a. outbound_data . lock ( ) . unwrap ( ) . split_off ( 0 ) ;
2996- assert_eq ! ( peer_b. read_event( & mut fd_b, & a_data) . unwrap( ) , false ) ;
3017+ let b_refused = peer_b. read_event ( & mut fd_b, & a_data) ;
3018+
3019+ ( fd_a, fd_b, a_refused, b_refused)
3020+ }
3021+
3022+
3023+ fn establish_connection < ' a > ( peer_a : & PeerManager < FileDescriptor , & ' a test_utils:: TestChannelMessageHandler , & ' a test_utils:: TestRoutingMessageHandler , IgnoringMessageHandler , & ' a test_utils:: TestLogger , & ' a TestCustomMessageHandler , & ' a test_utils:: TestNodeSigner > , peer_b : & PeerManager < FileDescriptor , & ' a test_utils:: TestChannelMessageHandler , & ' a test_utils:: TestRoutingMessageHandler , IgnoringMessageHandler , & ' a test_utils:: TestLogger , & ' a TestCustomMessageHandler , & ' a test_utils:: TestNodeSigner > ) -> ( FileDescriptor , FileDescriptor ) {
3024+ let addr_a = SocketAddress :: TcpIpV4 { addr : [ 127 , 0 , 0 , 1 ] , port : 1000 } ;
3025+ let addr_b = SocketAddress :: TcpIpV4 { addr : [ 127 , 0 , 0 , 1 ] , port : 1001 } ;
3026+
3027+ let id_a = peer_a. node_signer . get_node_id ( Recipient :: Node ) . unwrap ( ) ;
3028+ let id_b = peer_b. node_signer . get_node_id ( Recipient :: Node ) . unwrap ( ) ;
3029+
3030+ let features_a = peer_a. init_features ( id_b) ;
3031+ let features_b = peer_b. init_features ( id_a) ;
3032+
3033+ let ( fd_a, fd_b, a_refused, b_refused) = try_establish_connection ( peer_a, peer_b) ;
3034+
3035+ assert_eq ! ( a_refused. unwrap( ) , false ) ;
3036+ assert_eq ! ( b_refused. unwrap( ) , false ) ;
29973037
29983038 assert_eq ! ( peer_a. peer_by_node_id( & id_b) . unwrap( ) . counterparty_node_id, id_b) ;
29993039 assert_eq ! ( peer_a. peer_by_node_id( & id_b) . unwrap( ) . socket_address, Some ( addr_b) ) ;
@@ -3246,6 +3286,50 @@ mod tests {
32463286 assert_eq ! ( peers[ 0 ] . peers. read( ) . unwrap( ) . len( ) , 0 ) ;
32473287 }
32483288
3289+ fn do_test_peer_connected_error_disconnects ( handler : usize ) {
3290+ // Test that if a message handler fails a connection in `peer_connected` we reliably
3291+ // produce `peer_disconnected` events for all other message handlers (that saw a
3292+ // corresponding `peer_connected`).
3293+ let cfgs = create_peermgr_cfgs ( 2 ) ;
3294+ let peers = create_network ( 2 , & cfgs) ;
3295+
3296+ match handler & !1 {
3297+ 0 => {
3298+ peers[ handler & 1 ] . message_handler . chan_handler . conn_tracker . fail_connections . store ( true , Ordering :: Release ) ;
3299+ }
3300+ 2 => {
3301+ peers[ handler & 1 ] . message_handler . route_handler . conn_tracker . fail_connections . store ( true , Ordering :: Release ) ;
3302+ }
3303+ 4 => {
3304+ peers[ handler & 1 ] . message_handler . custom_message_handler . conn_tracker . fail_connections . store ( true , Ordering :: Release ) ;
3305+ }
3306+ _ => panic ! ( ) ,
3307+ }
3308+ let ( _sd1, _sd2, a_refused, b_refused) = try_establish_connection ( & peers[ 0 ] , & peers[ 1 ] ) ;
3309+ if handler & 1 == 0 {
3310+ assert ! ( a_refused. is_err( ) ) ;
3311+ assert ! ( peers[ 0 ] . list_peers( ) . is_empty( ) ) ;
3312+ } else {
3313+ assert ! ( b_refused. is_err( ) ) ;
3314+ assert ! ( peers[ 1 ] . list_peers( ) . is_empty( ) ) ;
3315+ }
3316+ // At least one message handler should have seen the connection.
3317+ assert ! ( peers[ handler & 1 ] . message_handler. chan_handler. conn_tracker. had_peers. load( Ordering :: Acquire ) ||
3318+ peers[ handler & 1 ] . message_handler. route_handler. conn_tracker. had_peers. load( Ordering :: Acquire ) ||
3319+ peers[ handler & 1 ] . message_handler. custom_message_handler. conn_tracker. had_peers. load( Ordering :: Acquire ) ) ;
3320+ // And both message handlers doing tracking should see the disconnection
3321+ assert ! ( peers[ handler & 1 ] . message_handler. chan_handler. conn_tracker. connected_peers. lock( ) . unwrap( ) . is_empty( ) ) ;
3322+ assert ! ( peers[ handler & 1 ] . message_handler. route_handler. conn_tracker. connected_peers. lock( ) . unwrap( ) . is_empty( ) ) ;
3323+ assert ! ( peers[ handler & 1 ] . message_handler. custom_message_handler. conn_tracker. connected_peers. lock( ) . unwrap( ) . is_empty( ) ) ;
3324+ }
3325+
3326+ #[ test]
3327+ fn test_peer_connected_error_disconnects ( ) {
3328+ for i in 0 ..6 {
3329+ do_test_peer_connected_error_disconnects ( i) ;
3330+ }
3331+ }
3332+
32493333 #[ test]
32503334 fn test_do_attempt_write_data ( ) {
32513335 // Create 2 peers with custom TestRoutingMessageHandlers and connect them.
0 commit comments