@@ -2867,6 +2867,16 @@ mod tests {
28672867
28682868	struct  TestCustomMessageHandler  { 
28692869		features :  InitFeatures , 
2870+ 		conn_tracker :  test_utils:: ConnectionTracker , 
2871+ 	} 
2872+ 
2873+ 	impl  TestCustomMessageHandler  { 
2874+ 		fn  new ( features :  InitFeatures )  -> Self  { 
2875+ 			Self  { 
2876+ 				features, 
2877+ 				conn_tracker :  test_utils:: ConnectionTracker :: new ( ) , 
2878+ 			} 
2879+ 		} 
28702880	} 
28712881
28722882	impl  wire:: CustomMessageReader  for  TestCustomMessageHandler  { 
@@ -2883,10 +2893,13 @@ mod tests {
28832893
28842894		fn  get_and_clear_pending_msg ( & self )  -> Vec < ( PublicKey ,  Self :: CustomMessage ) >  {  Vec :: new ( )  } 
28852895
2896+ 		fn  peer_disconnected ( & self ,  their_node_id :  PublicKey )  { 
2897+ 			self . conn_tracker . peer_disconnected ( their_node_id) ; 
2898+ 		} 
28862899
2887- 		fn  peer_disconnected ( & self ,  _their_node_id :  PublicKey )   { } 
2888- 
2889- 		fn   peer_connected ( & self ,   _their_node_id :   PublicKey ,   _msg :   & Init ,   _inbound :   bool )  ->  Result < ( ) ,   ( ) >   {   Ok ( ( ) )   } 
2900+ 		fn  peer_connected ( & self ,  their_node_id :  PublicKey ,   _msg :   & Init ,   _inbound :   bool )  ->  Result < ( ) ,   ( ) >   { 
2901+ 			 self . conn_tracker . peer_connected ( their_node_id ) 
2902+ 		} 
28902903
28912904		fn  provided_node_features ( & self )  -> NodeFeatures  {  NodeFeatures :: empty ( )  } 
28922905
@@ -2909,7 +2922,7 @@ mod tests {
29092922					chan_handler :  test_utils:: TestChannelMessageHandler :: new ( ChainHash :: using_genesis_block ( Network :: Testnet ) ) , 
29102923					logger :  test_utils:: TestLogger :: with_id ( i. to_string ( ) ) , 
29112924					routing_handler :  test_utils:: TestRoutingMessageHandler :: new ( ) , 
2912- 					custom_handler :  TestCustomMessageHandler   {   features  } , 
2925+ 					custom_handler :  TestCustomMessageHandler :: new ( features) , 
29132926					node_signer :  test_utils:: TestNodeSigner :: new ( node_secret) , 
29142927				} 
29152928			) ; 
@@ -2932,7 +2945,7 @@ mod tests {
29322945					chan_handler :  test_utils:: TestChannelMessageHandler :: new ( ChainHash :: using_genesis_block ( Network :: Testnet ) ) , 
29332946					logger :  test_utils:: TestLogger :: new ( ) , 
29342947					routing_handler :  test_utils:: TestRoutingMessageHandler :: new ( ) , 
2935- 					custom_handler :  TestCustomMessageHandler   {   features  } , 
2948+ 					custom_handler :  TestCustomMessageHandler :: new ( features) , 
29362949					node_signer :  test_utils:: TestNodeSigner :: new ( node_secret) , 
29372950				} 
29382951			) ; 
@@ -2952,7 +2965,7 @@ mod tests {
29522965					chan_handler :  test_utils:: TestChannelMessageHandler :: new ( network) , 
29532966					logger :  test_utils:: TestLogger :: new ( ) , 
29542967					routing_handler :  test_utils:: TestRoutingMessageHandler :: new ( ) , 
2955- 					custom_handler :  TestCustomMessageHandler   {   features  } , 
2968+ 					custom_handler :  TestCustomMessageHandler :: new ( features) , 
29562969					node_signer :  test_utils:: TestNodeSigner :: new ( node_secret) , 
29572970				} 
29582971			) ; 
@@ -2976,19 +2989,16 @@ mod tests {
29762989		peers
29772990	} 
29782991
2979- 	fn  establish_connection < ' a > ( peer_a :  & PeerManager < FileDescriptor ,  & ' a  test_utils:: TestChannelMessageHandler ,  & ' a  test_utils:: TestRoutingMessageHandler ,  IgnoringMessageHandler ,  & ' a  test_utils:: TestLogger ,  & ' a  TestCustomMessageHandler ,  & ' a  test_utils:: TestNodeSigner > ,  peer_b :  & PeerManager < FileDescriptor ,  & ' a  test_utils:: TestChannelMessageHandler ,  & ' a  test_utils:: TestRoutingMessageHandler ,  IgnoringMessageHandler ,  & ' a  test_utils:: TestLogger ,  & ' a  TestCustomMessageHandler ,  & ' a  test_utils:: TestNodeSigner > )  -> ( FileDescriptor ,  FileDescriptor )  { 
2992+ 	fn  try_establish_connection < ' a > ( peer_a :  & PeerManager < FileDescriptor ,  & ' a  test_utils:: TestChannelMessageHandler ,  & ' a  test_utils:: TestRoutingMessageHandler ,  IgnoringMessageHandler ,  & ' a  test_utils:: TestLogger ,  & ' a  TestCustomMessageHandler ,  & ' a  test_utils:: TestNodeSigner > ,  peer_b :  & PeerManager < FileDescriptor ,  & ' a  test_utils:: TestChannelMessageHandler ,  & ' a  test_utils:: TestRoutingMessageHandler ,  IgnoringMessageHandler ,  & ' a  test_utils:: TestLogger ,  & ' a  TestCustomMessageHandler ,  & ' a  test_utils:: TestNodeSigner > )  -> ( FileDescriptor ,  FileDescriptor ,  Result < bool ,  PeerHandleError > ,  Result < bool ,  PeerHandleError > )  { 
2993+ 		let  addr_a = SocketAddress :: TcpIpV4 { addr :  [ 127 ,  0 ,  0 ,  1 ] ,  port :  1000 } ; 
2994+ 		let  addr_b = SocketAddress :: TcpIpV4 { addr :  [ 127 ,  0 ,  0 ,  1 ] ,  port :  1001 } ; 
2995+ 
29802996		static  FD_COUNTER :  AtomicUsize  = AtomicUsize :: new ( 0 ) ; 
29812997		let  fd = FD_COUNTER . fetch_add ( 1 ,  Ordering :: Relaxed )  as  u16 ; 
29822998
29832999		let  id_a = peer_a. node_signer . get_node_id ( Recipient :: Node ) . unwrap ( ) ; 
29843000		let  mut  fd_a = FileDescriptor :: new ( fd) ; 
2985- 		let  addr_a = SocketAddress :: TcpIpV4 { addr :  [ 127 ,  0 ,  0 ,  1 ] ,  port :  1000 } ; 
2986- 
2987- 		let  id_b = peer_b. node_signer . get_node_id ( Recipient :: Node ) . unwrap ( ) ; 
2988- 		let  features_a = peer_a. init_features ( id_b) ; 
2989- 		let  features_b = peer_b. init_features ( id_a) ; 
29903001		let  mut  fd_b = FileDescriptor :: new ( fd) ; 
2991- 		let  addr_b = SocketAddress :: TcpIpV4 { addr :  [ 127 ,  0 ,  0 ,  1 ] ,  port :  1001 } ; 
29923002
29933003		let  initial_data = peer_b. new_outbound_connection ( id_a,  fd_b. clone ( ) ,  Some ( addr_a. clone ( ) ) ) . unwrap ( ) ; 
29943004		peer_a. new_inbound_connection ( fd_a. clone ( ) ,  Some ( addr_b. clone ( ) ) ) . unwrap ( ) ; 
@@ -3000,11 +3010,30 @@ mod tests {
30003010
30013011		peer_b. process_events ( ) ; 
30023012		let  b_data = fd_b. outbound_data . lock ( ) . unwrap ( ) . split_off ( 0 ) ; 
3003- 		assert_eq ! ( peer_a. read_event( & mut  fd_a,  & b_data) . unwrap ( ) ,   false ) ; 
3013+ 		let  a_refused =  peer_a. read_event ( & mut  fd_a,  & b_data) ; 
30043014
30053015		peer_a. process_events ( ) ; 
30063016		let  a_data = fd_a. outbound_data . lock ( ) . unwrap ( ) . split_off ( 0 ) ; 
3007- 		assert_eq ! ( peer_b. read_event( & mut  fd_b,  & a_data) . unwrap( ) ,  false ) ; 
3017+ 		let  b_refused = peer_b. read_event ( & mut  fd_b,  & a_data) ; 
3018+ 
3019+ 		( fd_a,  fd_b,  a_refused,  b_refused) 
3020+ 	} 
3021+ 
3022+ 
3023+ 	fn  establish_connection < ' a > ( peer_a :  & PeerManager < FileDescriptor ,  & ' a  test_utils:: TestChannelMessageHandler ,  & ' a  test_utils:: TestRoutingMessageHandler ,  IgnoringMessageHandler ,  & ' a  test_utils:: TestLogger ,  & ' a  TestCustomMessageHandler ,  & ' a  test_utils:: TestNodeSigner > ,  peer_b :  & PeerManager < FileDescriptor ,  & ' a  test_utils:: TestChannelMessageHandler ,  & ' a  test_utils:: TestRoutingMessageHandler ,  IgnoringMessageHandler ,  & ' a  test_utils:: TestLogger ,  & ' a  TestCustomMessageHandler ,  & ' a  test_utils:: TestNodeSigner > )  -> ( FileDescriptor ,  FileDescriptor )  { 
3024+ 		let  addr_a = SocketAddress :: TcpIpV4 { addr :  [ 127 ,  0 ,  0 ,  1 ] ,  port :  1000 } ; 
3025+ 		let  addr_b = SocketAddress :: TcpIpV4 { addr :  [ 127 ,  0 ,  0 ,  1 ] ,  port :  1001 } ; 
3026+ 
3027+ 		let  id_a = peer_a. node_signer . get_node_id ( Recipient :: Node ) . unwrap ( ) ; 
3028+ 		let  id_b = peer_b. node_signer . get_node_id ( Recipient :: Node ) . unwrap ( ) ; 
3029+ 
3030+ 		let  features_a = peer_a. init_features ( id_b) ; 
3031+ 		let  features_b = peer_b. init_features ( id_a) ; 
3032+ 
3033+ 		let  ( fd_a,  fd_b,  a_refused,  b_refused)  = try_establish_connection ( peer_a,  peer_b) ; 
3034+ 
3035+ 		assert_eq ! ( a_refused. unwrap( ) ,  false ) ; 
3036+ 		assert_eq ! ( b_refused. unwrap( ) ,  false ) ; 
30083037
30093038		assert_eq ! ( peer_a. peer_by_node_id( & id_b) . unwrap( ) . counterparty_node_id,  id_b) ; 
30103039		assert_eq ! ( peer_a. peer_by_node_id( & id_b) . unwrap( ) . socket_address,  Some ( addr_b) ) ; 
@@ -3257,6 +3286,50 @@ mod tests {
32573286		assert_eq ! ( peers[ 0 ] . peers. read( ) . unwrap( ) . len( ) ,  0 ) ; 
32583287	} 
32593288
3289+ 	fn  do_test_peer_connected_error_disconnects ( handler :  usize )  { 
3290+ 		// Test that if a message handler fails a connection in `peer_connected` we reliably 
3291+ 		// produce `peer_disconnected` events for all other message handlers (that saw a 
3292+ 		// corresponding `peer_connected`). 
3293+ 		let  cfgs = create_peermgr_cfgs ( 2 ) ; 
3294+ 		let  peers = create_network ( 2 ,  & cfgs) ; 
3295+ 
3296+ 		match  handler &  !1  { 
3297+ 			0  => { 
3298+ 				peers[ handler &  1 ] . message_handler . chan_handler . conn_tracker . fail_connections . store ( true ,  Ordering :: Release ) ; 
3299+ 			} 
3300+ 			2  => { 
3301+ 				peers[ handler &  1 ] . message_handler . route_handler . conn_tracker . fail_connections . store ( true ,  Ordering :: Release ) ; 
3302+ 			} 
3303+ 			4  => { 
3304+ 				peers[ handler &  1 ] . message_handler . custom_message_handler . conn_tracker . fail_connections . store ( true ,  Ordering :: Release ) ; 
3305+ 			} 
3306+ 			_ => panic ! ( ) , 
3307+ 		} 
3308+ 		let  ( _sd1,  _sd2,  a_refused,  b_refused)  = try_establish_connection ( & peers[ 0 ] ,  & peers[ 1 ] ) ; 
3309+ 		if  handler &  1  == 0  { 
3310+ 			assert ! ( a_refused. is_err( ) ) ; 
3311+ 			assert ! ( peers[ 0 ] . list_peers( ) . is_empty( ) ) ; 
3312+ 		}  else  { 
3313+ 			assert ! ( b_refused. is_err( ) ) ; 
3314+ 			assert ! ( peers[ 1 ] . list_peers( ) . is_empty( ) ) ; 
3315+ 		} 
3316+ 		// At least one message handler should have seen the connection. 
3317+ 		assert ! ( peers[ handler &  1 ] . message_handler. chan_handler. conn_tracker. had_peers. load( Ordering :: Acquire )  ||
3318+ 			peers[ handler &  1 ] . message_handler. route_handler. conn_tracker. had_peers. load( Ordering :: Acquire )  ||
3319+ 			peers[ handler &  1 ] . message_handler. custom_message_handler. conn_tracker. had_peers. load( Ordering :: Acquire ) ) ; 
3320+ 		// And both message handlers doing tracking should see the disconnection 
3321+ 		assert ! ( peers[ handler &  1 ] . message_handler. chan_handler. conn_tracker. connected_peers. lock( ) . unwrap( ) . is_empty( ) ) ; 
3322+ 		assert ! ( peers[ handler &  1 ] . message_handler. route_handler. conn_tracker. connected_peers. lock( ) . unwrap( ) . is_empty( ) ) ; 
3323+ 		assert ! ( peers[ handler &  1 ] . message_handler. custom_message_handler. conn_tracker. connected_peers. lock( ) . unwrap( ) . is_empty( ) ) ; 
3324+ 	} 
3325+ 
3326+ 	#[ test]  
3327+ 	fn  test_peer_connected_error_disconnects ( )  { 
3328+ 		for  i in  0 ..6  { 
3329+ 			do_test_peer_connected_error_disconnects ( i) ; 
3330+ 		} 
3331+ 	} 
3332+ 
32603333	#[ test]  
32613334	fn  test_do_attempt_write_data ( )  { 
32623335		// Create 2 peers with custom TestRoutingMessageHandlers and connect them. 
0 commit comments