@@ -2867,6 +2867,16 @@ mod tests {
28672867
28682868 struct TestCustomMessageHandler {
28692869 features : InitFeatures ,
2870+ conn_tracker : test_utils:: ConnectionTracker ,
2871+ }
2872+
2873+ impl TestCustomMessageHandler {
2874+ fn new ( features : InitFeatures ) -> Self {
2875+ Self {
2876+ features,
2877+ conn_tracker : test_utils:: ConnectionTracker :: new ( ) ,
2878+ }
2879+ }
28702880 }
28712881
28722882 impl wire:: CustomMessageReader for TestCustomMessageHandler {
@@ -2883,10 +2893,13 @@ mod tests {
28832893
28842894 fn get_and_clear_pending_msg ( & self ) -> Vec < ( PublicKey , Self :: CustomMessage ) > { Vec :: new ( ) }
28852895
2896+ fn peer_disconnected ( & self , their_node_id : PublicKey ) {
2897+ self . conn_tracker . peer_disconnected ( their_node_id) ;
2898+ }
28862899
2887- fn peer_disconnected ( & self , _their_node_id : PublicKey ) { }
2888-
2889- fn peer_connected ( & self , _their_node_id : PublicKey , _msg : & Init , _inbound : bool ) -> Result < ( ) , ( ) > { Ok ( ( ) ) }
2900+ fn peer_connected ( & self , their_node_id : PublicKey , _msg : & Init , _inbound : bool ) -> Result < ( ) , ( ) > {
2901+ self . conn_tracker . peer_connected ( their_node_id )
2902+ }
28902903
28912904 fn provided_node_features ( & self ) -> NodeFeatures { NodeFeatures :: empty ( ) }
28922905
@@ -2909,7 +2922,7 @@ mod tests {
29092922 chan_handler : test_utils:: TestChannelMessageHandler :: new ( ChainHash :: using_genesis_block ( Network :: Testnet ) ) ,
29102923 logger : test_utils:: TestLogger :: with_id ( i. to_string ( ) ) ,
29112924 routing_handler : test_utils:: TestRoutingMessageHandler :: new ( ) ,
2912- custom_handler : TestCustomMessageHandler { features } ,
2925+ custom_handler : TestCustomMessageHandler :: new ( features) ,
29132926 node_signer : test_utils:: TestNodeSigner :: new ( node_secret) ,
29142927 }
29152928 ) ;
@@ -2932,7 +2945,7 @@ mod tests {
29322945 chan_handler : test_utils:: TestChannelMessageHandler :: new ( ChainHash :: using_genesis_block ( Network :: Testnet ) ) ,
29332946 logger : test_utils:: TestLogger :: new ( ) ,
29342947 routing_handler : test_utils:: TestRoutingMessageHandler :: new ( ) ,
2935- custom_handler : TestCustomMessageHandler { features } ,
2948+ custom_handler : TestCustomMessageHandler :: new ( features) ,
29362949 node_signer : test_utils:: TestNodeSigner :: new ( node_secret) ,
29372950 }
29382951 ) ;
@@ -2952,7 +2965,7 @@ mod tests {
29522965 chan_handler : test_utils:: TestChannelMessageHandler :: new ( network) ,
29532966 logger : test_utils:: TestLogger :: new ( ) ,
29542967 routing_handler : test_utils:: TestRoutingMessageHandler :: new ( ) ,
2955- custom_handler : TestCustomMessageHandler { features } ,
2968+ custom_handler : TestCustomMessageHandler :: new ( features) ,
29562969 node_signer : test_utils:: TestNodeSigner :: new ( node_secret) ,
29572970 }
29582971 ) ;
@@ -2976,19 +2989,16 @@ mod tests {
29762989 peers
29772990 }
29782991
2979- fn establish_connection < ' a > ( peer_a : & PeerManager < FileDescriptor , & ' a test_utils:: TestChannelMessageHandler , & ' a test_utils:: TestRoutingMessageHandler , IgnoringMessageHandler , & ' a test_utils:: TestLogger , & ' a TestCustomMessageHandler , & ' a test_utils:: TestNodeSigner > , peer_b : & PeerManager < FileDescriptor , & ' a test_utils:: TestChannelMessageHandler , & ' a test_utils:: TestRoutingMessageHandler , IgnoringMessageHandler , & ' a test_utils:: TestLogger , & ' a TestCustomMessageHandler , & ' a test_utils:: TestNodeSigner > ) -> ( FileDescriptor , FileDescriptor ) {
2992+ fn try_establish_connection < ' a > ( peer_a : & PeerManager < FileDescriptor , & ' a test_utils:: TestChannelMessageHandler , & ' a test_utils:: TestRoutingMessageHandler , IgnoringMessageHandler , & ' a test_utils:: TestLogger , & ' a TestCustomMessageHandler , & ' a test_utils:: TestNodeSigner > , peer_b : & PeerManager < FileDescriptor , & ' a test_utils:: TestChannelMessageHandler , & ' a test_utils:: TestRoutingMessageHandler , IgnoringMessageHandler , & ' a test_utils:: TestLogger , & ' a TestCustomMessageHandler , & ' a test_utils:: TestNodeSigner > ) -> ( FileDescriptor , FileDescriptor , Result < bool , PeerHandleError > , Result < bool , PeerHandleError > ) {
2993+ let addr_a = SocketAddress :: TcpIpV4 { addr : [ 127 , 0 , 0 , 1 ] , port : 1000 } ;
2994+ let addr_b = SocketAddress :: TcpIpV4 { addr : [ 127 , 0 , 0 , 1 ] , port : 1001 } ;
2995+
29802996 static FD_COUNTER : AtomicUsize = AtomicUsize :: new ( 0 ) ;
29812997 let fd = FD_COUNTER . fetch_add ( 1 , Ordering :: Relaxed ) as u16 ;
29822998
29832999 let id_a = peer_a. node_signer . get_node_id ( Recipient :: Node ) . unwrap ( ) ;
29843000 let mut fd_a = FileDescriptor :: new ( fd) ;
2985- let addr_a = SocketAddress :: TcpIpV4 { addr : [ 127 , 0 , 0 , 1 ] , port : 1000 } ;
2986-
2987- let id_b = peer_b. node_signer . get_node_id ( Recipient :: Node ) . unwrap ( ) ;
2988- let features_a = peer_a. init_features ( id_b) ;
2989- let features_b = peer_b. init_features ( id_a) ;
29903001 let mut fd_b = FileDescriptor :: new ( fd) ;
2991- let addr_b = SocketAddress :: TcpIpV4 { addr : [ 127 , 0 , 0 , 1 ] , port : 1001 } ;
29923002
29933003 let initial_data = peer_b. new_outbound_connection ( id_a, fd_b. clone ( ) , Some ( addr_a. clone ( ) ) ) . unwrap ( ) ;
29943004 peer_a. new_inbound_connection ( fd_a. clone ( ) , Some ( addr_b. clone ( ) ) ) . unwrap ( ) ;
@@ -3000,11 +3010,30 @@ mod tests {
30003010
30013011 peer_b. process_events ( ) ;
30023012 let b_data = fd_b. outbound_data . lock ( ) . unwrap ( ) . split_off ( 0 ) ;
3003- assert_eq ! ( peer_a. read_event( & mut fd_a, & b_data) . unwrap ( ) , false ) ;
3013+ let a_refused = peer_a. read_event ( & mut fd_a, & b_data) ;
30043014
30053015 peer_a. process_events ( ) ;
30063016 let a_data = fd_a. outbound_data . lock ( ) . unwrap ( ) . split_off ( 0 ) ;
3007- assert_eq ! ( peer_b. read_event( & mut fd_b, & a_data) . unwrap( ) , false ) ;
3017+ let b_refused = peer_b. read_event ( & mut fd_b, & a_data) ;
3018+
3019+ ( fd_a, fd_b, a_refused, b_refused)
3020+ }
3021+
3022+
3023+ fn establish_connection < ' a > ( peer_a : & PeerManager < FileDescriptor , & ' a test_utils:: TestChannelMessageHandler , & ' a test_utils:: TestRoutingMessageHandler , IgnoringMessageHandler , & ' a test_utils:: TestLogger , & ' a TestCustomMessageHandler , & ' a test_utils:: TestNodeSigner > , peer_b : & PeerManager < FileDescriptor , & ' a test_utils:: TestChannelMessageHandler , & ' a test_utils:: TestRoutingMessageHandler , IgnoringMessageHandler , & ' a test_utils:: TestLogger , & ' a TestCustomMessageHandler , & ' a test_utils:: TestNodeSigner > ) -> ( FileDescriptor , FileDescriptor ) {
3024+ let addr_a = SocketAddress :: TcpIpV4 { addr : [ 127 , 0 , 0 , 1 ] , port : 1000 } ;
3025+ let addr_b = SocketAddress :: TcpIpV4 { addr : [ 127 , 0 , 0 , 1 ] , port : 1001 } ;
3026+
3027+ let id_a = peer_a. node_signer . get_node_id ( Recipient :: Node ) . unwrap ( ) ;
3028+ let id_b = peer_b. node_signer . get_node_id ( Recipient :: Node ) . unwrap ( ) ;
3029+
3030+ let features_a = peer_a. init_features ( id_b) ;
3031+ let features_b = peer_b. init_features ( id_a) ;
3032+
3033+ let ( fd_a, fd_b, a_refused, b_refused) = try_establish_connection ( peer_a, peer_b) ;
3034+
3035+ assert_eq ! ( a_refused. unwrap( ) , false ) ;
3036+ assert_eq ! ( b_refused. unwrap( ) , false ) ;
30083037
30093038 assert_eq ! ( peer_a. peer_by_node_id( & id_b) . unwrap( ) . counterparty_node_id, id_b) ;
30103039 assert_eq ! ( peer_a. peer_by_node_id( & id_b) . unwrap( ) . socket_address, Some ( addr_b) ) ;
@@ -3257,6 +3286,50 @@ mod tests {
32573286 assert_eq ! ( peers[ 0 ] . peers. read( ) . unwrap( ) . len( ) , 0 ) ;
32583287 }
32593288
3289+ fn do_test_peer_connected_error_disconnects ( handler : usize ) {
3290+ // Test that if a message handler fails a connection in `peer_connected` we reliably
3291+ // produce `peer_disconnected` events for all other message handlers (that saw a
3292+ // corresponding `peer_connected`).
3293+ let cfgs = create_peermgr_cfgs ( 2 ) ;
3294+ let peers = create_network ( 2 , & cfgs) ;
3295+
3296+ match handler & !1 {
3297+ 0 => {
3298+ peers[ handler & 1 ] . message_handler . chan_handler . conn_tracker . fail_connections . store ( true , Ordering :: Release ) ;
3299+ }
3300+ 2 => {
3301+ peers[ handler & 1 ] . message_handler . route_handler . conn_tracker . fail_connections . store ( true , Ordering :: Release ) ;
3302+ }
3303+ 4 => {
3304+ peers[ handler & 1 ] . message_handler . custom_message_handler . conn_tracker . fail_connections . store ( true , Ordering :: Release ) ;
3305+ }
3306+ _ => panic ! ( ) ,
3307+ }
3308+ let ( _sd1, _sd2, a_refused, b_refused) = try_establish_connection ( & peers[ 0 ] , & peers[ 1 ] ) ;
3309+ if handler & 1 == 0 {
3310+ assert ! ( a_refused. is_err( ) ) ;
3311+ assert ! ( peers[ 0 ] . list_peers( ) . is_empty( ) ) ;
3312+ } else {
3313+ assert ! ( b_refused. is_err( ) ) ;
3314+ assert ! ( peers[ 1 ] . list_peers( ) . is_empty( ) ) ;
3315+ }
3316+ // At least one message handler should have seen the connection.
3317+ assert ! ( peers[ handler & 1 ] . message_handler. chan_handler. conn_tracker. had_peers. load( Ordering :: Acquire ) ||
3318+ peers[ handler & 1 ] . message_handler. route_handler. conn_tracker. had_peers. load( Ordering :: Acquire ) ||
3319+ peers[ handler & 1 ] . message_handler. custom_message_handler. conn_tracker. had_peers. load( Ordering :: Acquire ) ) ;
3320+ // And both message handlers doing tracking should see the disconnection
3321+ assert ! ( peers[ handler & 1 ] . message_handler. chan_handler. conn_tracker. connected_peers. lock( ) . unwrap( ) . is_empty( ) ) ;
3322+ assert ! ( peers[ handler & 1 ] . message_handler. route_handler. conn_tracker. connected_peers. lock( ) . unwrap( ) . is_empty( ) ) ;
3323+ assert ! ( peers[ handler & 1 ] . message_handler. custom_message_handler. conn_tracker. connected_peers. lock( ) . unwrap( ) . is_empty( ) ) ;
3324+ }
3325+
3326+ #[ test]
3327+ fn test_peer_connected_error_disconnects ( ) {
3328+ for i in 0 ..6 {
3329+ do_test_peer_connected_error_disconnects ( i) ;
3330+ }
3331+ }
3332+
32603333 #[ test]
32613334 fn test_do_attempt_write_data ( ) {
32623335 // Create 2 peers with custom TestRoutingMessageHandlers and connect them.
0 commit comments