@@ -88,6 +88,8 @@ pub trait CustomMessageHandler: wire::CustomMessageReader {
8888/// May return an `Err(())` if the features the peer supports are not sufficient to communicate 
8989/// with us. Implementors should be somewhat conservative about doing so, however, as other 
9090/// message handlers may still wish to communicate with this peer. 
91+ /// 
92+ /// [`Self::peer_disconnected`] will not be called if `Err(())` is returned. 
9193fn  peer_connected ( & self ,  their_node_id :  PublicKey ,  msg :  & Init ,  inbound :  bool )  -> Result < ( ) ,  ( ) > ; 
9294
9395	/// Gets the node feature flags which this handler itself supports. All available handlers are 
@@ -119,6 +121,7 @@ impl RoutingMessageHandler for IgnoringMessageHandler {
119121		Option < ( msgs:: ChannelAnnouncement ,  Option < msgs:: ChannelUpdate > ,  Option < msgs:: ChannelUpdate > ) >  {  None  } 
120122	fn  get_next_node_announcement ( & self ,  _starting_point :  Option < & NodeId > )  -> Option < msgs:: NodeAnnouncement >  {  None  } 
121123	fn  peer_connected ( & self ,  _their_node_id :  PublicKey ,  _init :  & msgs:: Init ,  _inbound :  bool )  -> Result < ( ) ,  ( ) >  {  Ok ( ( ) )  } 
124+ 	fn  peer_disconnected ( & self ,  _their_node_id :  PublicKey )  {  } 
122125	fn  handle_reply_channel_range ( & self ,  _their_node_id :  PublicKey ,  _msg :  msgs:: ReplyChannelRange )  -> Result < ( ) ,  LightningError >  {  Ok ( ( ) )  } 
123126	fn  handle_reply_short_channel_ids_end ( & self ,  _their_node_id :  PublicKey ,  _msg :  msgs:: ReplyShortChannelIdsEnd )  -> Result < ( ) ,  LightningError >  {  Ok ( ( ) )  } 
124127	fn  handle_query_channel_range ( & self ,  _their_node_id :  PublicKey ,  _msg :  msgs:: QueryChannelRange )  -> Result < ( ) ,  LightningError >  {  Ok ( ( ) )  } 
@@ -1714,14 +1717,20 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
17141717			} 
17151718			if  let  Err ( ( ) )  = self . message_handler . chan_handler . peer_connected ( their_node_id,  & msg,  peer_lock. inbound_connection )  { 
17161719				log_debug ! ( logger,  "Channel Handler decided we couldn't communicate with peer {}" ,  log_pubkey!( their_node_id) ) ; 
1720+ 				self . message_handler . route_handler . peer_disconnected ( their_node_id) ; 
17171721				return  Err ( PeerHandleError  {  } . into ( ) ) ; 
17181722			} 
17191723			if  let  Err ( ( ) )  = self . message_handler . onion_message_handler . peer_connected ( their_node_id,  & msg,  peer_lock. inbound_connection )  { 
17201724				log_debug ! ( logger,  "Onion Message Handler decided we couldn't communicate with peer {}" ,  log_pubkey!( their_node_id) ) ; 
1725+ 				self . message_handler . route_handler . peer_disconnected ( their_node_id) ; 
1726+ 				self . message_handler . chan_handler . peer_disconnected ( their_node_id) ; 
17211727				return  Err ( PeerHandleError  {  } . into ( ) ) ; 
17221728			} 
17231729			if  let  Err ( ( ) )  = self . message_handler . custom_message_handler . peer_connected ( their_node_id,  & msg,  peer_lock. inbound_connection )  { 
17241730				log_debug ! ( logger,  "Custom Message Handler decided we couldn't communicate with peer {}" ,  log_pubkey!( their_node_id) ) ; 
1731+ 				self . message_handler . route_handler . peer_disconnected ( their_node_id) ; 
1732+ 				self . message_handler . chan_handler . peer_disconnected ( their_node_id) ; 
1733+ 				self . message_handler . onion_message_handler . peer_disconnected ( their_node_id) ; 
17251734				return  Err ( PeerHandleError  {  } . into ( ) ) ; 
17261735			} 
17271736
@@ -2533,6 +2542,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
25332542		debug_assert ! ( peer. their_node_id. is_some( ) ) ; 
25342543		if  let  Some ( ( node_id,  _) )  = peer. their_node_id  { 
25352544			log_trace ! ( WithContext :: from( & self . logger,  Some ( node_id) ,  None ,  None ) ,  "Disconnecting peer with id {} due to {}" ,  node_id,  reason) ; 
2545+ 			self . message_handler . route_handler . peer_disconnected ( node_id) ; 
25362546			self . message_handler . chan_handler . peer_disconnected ( node_id) ; 
25372547			self . message_handler . onion_message_handler . peer_disconnected ( node_id) ; 
25382548			self . message_handler . custom_message_handler . peer_disconnected ( node_id) ; 
@@ -2557,6 +2567,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
25572567					let  removed = self . node_id_to_descriptor . lock ( ) . unwrap ( ) . remove ( & node_id) ; 
25582568					debug_assert ! ( removed. is_some( ) ,  "descriptor maps should be consistent" ) ; 
25592569					if  !peer. handshake_complete ( )  {  return ;  } 
2570+ 					self . message_handler . route_handler . peer_disconnected ( node_id) ; 
25602571					self . message_handler . chan_handler . peer_disconnected ( node_id) ; 
25612572					self . message_handler . onion_message_handler . peer_disconnected ( node_id) ; 
25622573					self . message_handler . custom_message_handler . peer_disconnected ( node_id) ; 
@@ -2856,6 +2867,16 @@ mod tests {
28562867
28572868	struct  TestCustomMessageHandler  { 
28582869		features :  InitFeatures , 
2870+ 		conn_tracker :  test_utils:: ConnectionTracker , 
2871+ 	} 
2872+ 
2873+ 	impl  TestCustomMessageHandler  { 
2874+ 		fn  new ( features :  InitFeatures )  -> Self  { 
2875+ 			Self  { 
2876+ 				features, 
2877+ 				conn_tracker :  test_utils:: ConnectionTracker :: new ( ) , 
2878+ 			} 
2879+ 		} 
28592880	} 
28602881
28612882	impl  wire:: CustomMessageReader  for  TestCustomMessageHandler  { 
@@ -2872,10 +2893,13 @@ mod tests {
28722893
28732894		fn  get_and_clear_pending_msg ( & self )  -> Vec < ( PublicKey ,  Self :: CustomMessage ) >  {  Vec :: new ( )  } 
28742895
2896+ 		fn  peer_disconnected ( & self ,  their_node_id :  PublicKey )  { 
2897+ 			self . conn_tracker . peer_disconnected ( their_node_id) ; 
2898+ 		} 
28752899
2876- 		fn  peer_disconnected ( & self ,  _their_node_id :  PublicKey )   { } 
2877- 
2878- 		fn   peer_connected ( & self ,   _their_node_id :   PublicKey ,   _msg :   & Init ,   _inbound :   bool )  ->  Result < ( ) ,   ( ) >   {   Ok ( ( ) )   } 
2900+ 		fn  peer_connected ( & self ,  their_node_id :  PublicKey ,   _msg :   & Init ,   _inbound :   bool )  ->  Result < ( ) ,   ( ) >   { 
2901+ 			 self . conn_tracker . peer_connected ( their_node_id ) 
2902+ 		} 
28792903
28802904		fn  provided_node_features ( & self )  -> NodeFeatures  {  NodeFeatures :: empty ( )  } 
28812905
@@ -2898,7 +2922,7 @@ mod tests {
28982922					chan_handler :  test_utils:: TestChannelMessageHandler :: new ( ChainHash :: using_genesis_block ( Network :: Testnet ) ) , 
28992923					logger :  test_utils:: TestLogger :: with_id ( i. to_string ( ) ) , 
29002924					routing_handler :  test_utils:: TestRoutingMessageHandler :: new ( ) , 
2901- 					custom_handler :  TestCustomMessageHandler   {   features  } , 
2925+ 					custom_handler :  TestCustomMessageHandler :: new ( features) , 
29022926					node_signer :  test_utils:: TestNodeSigner :: new ( node_secret) , 
29032927				} 
29042928			) ; 
@@ -2921,7 +2945,7 @@ mod tests {
29212945					chan_handler :  test_utils:: TestChannelMessageHandler :: new ( ChainHash :: using_genesis_block ( Network :: Testnet ) ) , 
29222946					logger :  test_utils:: TestLogger :: new ( ) , 
29232947					routing_handler :  test_utils:: TestRoutingMessageHandler :: new ( ) , 
2924- 					custom_handler :  TestCustomMessageHandler   {   features  } , 
2948+ 					custom_handler :  TestCustomMessageHandler :: new ( features) , 
29252949					node_signer :  test_utils:: TestNodeSigner :: new ( node_secret) , 
29262950				} 
29272951			) ; 
@@ -2941,7 +2965,7 @@ mod tests {
29412965					chan_handler :  test_utils:: TestChannelMessageHandler :: new ( network) , 
29422966					logger :  test_utils:: TestLogger :: new ( ) , 
29432967					routing_handler :  test_utils:: TestRoutingMessageHandler :: new ( ) , 
2944- 					custom_handler :  TestCustomMessageHandler   {   features  } , 
2968+ 					custom_handler :  TestCustomMessageHandler :: new ( features) , 
29452969					node_signer :  test_utils:: TestNodeSigner :: new ( node_secret) , 
29462970				} 
29472971			) ; 
@@ -2965,19 +2989,16 @@ mod tests {
29652989		peers
29662990	} 
29672991
2968- 	fn  establish_connection < ' a > ( peer_a :  & PeerManager < FileDescriptor ,  & ' a  test_utils:: TestChannelMessageHandler ,  & ' a  test_utils:: TestRoutingMessageHandler ,  IgnoringMessageHandler ,  & ' a  test_utils:: TestLogger ,  & ' a  TestCustomMessageHandler ,  & ' a  test_utils:: TestNodeSigner > ,  peer_b :  & PeerManager < FileDescriptor ,  & ' a  test_utils:: TestChannelMessageHandler ,  & ' a  test_utils:: TestRoutingMessageHandler ,  IgnoringMessageHandler ,  & ' a  test_utils:: TestLogger ,  & ' a  TestCustomMessageHandler ,  & ' a  test_utils:: TestNodeSigner > )  -> ( FileDescriptor ,  FileDescriptor )  { 
2992+ 	fn  try_establish_connection < ' a > ( peer_a :  & PeerManager < FileDescriptor ,  & ' a  test_utils:: TestChannelMessageHandler ,  & ' a  test_utils:: TestRoutingMessageHandler ,  IgnoringMessageHandler ,  & ' a  test_utils:: TestLogger ,  & ' a  TestCustomMessageHandler ,  & ' a  test_utils:: TestNodeSigner > ,  peer_b :  & PeerManager < FileDescriptor ,  & ' a  test_utils:: TestChannelMessageHandler ,  & ' a  test_utils:: TestRoutingMessageHandler ,  IgnoringMessageHandler ,  & ' a  test_utils:: TestLogger ,  & ' a  TestCustomMessageHandler ,  & ' a  test_utils:: TestNodeSigner > )  -> ( FileDescriptor ,  FileDescriptor ,  Result < bool ,  PeerHandleError > ,  Result < bool ,  PeerHandleError > )  { 
2993+ 		let  addr_a = SocketAddress :: TcpIpV4 { addr :  [ 127 ,  0 ,  0 ,  1 ] ,  port :  1000 } ; 
2994+ 		let  addr_b = SocketAddress :: TcpIpV4 { addr :  [ 127 ,  0 ,  0 ,  1 ] ,  port :  1001 } ; 
2995+ 
29692996		static  FD_COUNTER :  AtomicUsize  = AtomicUsize :: new ( 0 ) ; 
29702997		let  fd = FD_COUNTER . fetch_add ( 1 ,  Ordering :: Relaxed )  as  u16 ; 
29712998
29722999		let  id_a = peer_a. node_signer . get_node_id ( Recipient :: Node ) . unwrap ( ) ; 
29733000		let  mut  fd_a = FileDescriptor :: new ( fd) ; 
2974- 		let  addr_a = SocketAddress :: TcpIpV4 { addr :  [ 127 ,  0 ,  0 ,  1 ] ,  port :  1000 } ; 
2975- 
2976- 		let  id_b = peer_b. node_signer . get_node_id ( Recipient :: Node ) . unwrap ( ) ; 
2977- 		let  features_a = peer_a. init_features ( id_b) ; 
2978- 		let  features_b = peer_b. init_features ( id_a) ; 
29793001		let  mut  fd_b = FileDescriptor :: new ( fd) ; 
2980- 		let  addr_b = SocketAddress :: TcpIpV4 { addr :  [ 127 ,  0 ,  0 ,  1 ] ,  port :  1001 } ; 
29813002
29823003		let  initial_data = peer_b. new_outbound_connection ( id_a,  fd_b. clone ( ) ,  Some ( addr_a. clone ( ) ) ) . unwrap ( ) ; 
29833004		peer_a. new_inbound_connection ( fd_a. clone ( ) ,  Some ( addr_b. clone ( ) ) ) . unwrap ( ) ; 
@@ -2989,11 +3010,30 @@ mod tests {
29893010
29903011		peer_b. process_events ( ) ; 
29913012		let  b_data = fd_b. outbound_data . lock ( ) . unwrap ( ) . split_off ( 0 ) ; 
2992- 		assert_eq ! ( peer_a. read_event( & mut  fd_a,  & b_data) . unwrap ( ) ,   false ) ; 
3013+ 		let  a_refused =  peer_a. read_event ( & mut  fd_a,  & b_data) ; 
29933014
29943015		peer_a. process_events ( ) ; 
29953016		let  a_data = fd_a. outbound_data . lock ( ) . unwrap ( ) . split_off ( 0 ) ; 
2996- 		assert_eq ! ( peer_b. read_event( & mut  fd_b,  & a_data) . unwrap( ) ,  false ) ; 
3017+ 		let  b_refused = peer_b. read_event ( & mut  fd_b,  & a_data) ; 
3018+ 
3019+ 		( fd_a,  fd_b,  a_refused,  b_refused) 
3020+ 	} 
3021+ 
3022+ 
3023+ 	fn  establish_connection < ' a > ( peer_a :  & PeerManager < FileDescriptor ,  & ' a  test_utils:: TestChannelMessageHandler ,  & ' a  test_utils:: TestRoutingMessageHandler ,  IgnoringMessageHandler ,  & ' a  test_utils:: TestLogger ,  & ' a  TestCustomMessageHandler ,  & ' a  test_utils:: TestNodeSigner > ,  peer_b :  & PeerManager < FileDescriptor ,  & ' a  test_utils:: TestChannelMessageHandler ,  & ' a  test_utils:: TestRoutingMessageHandler ,  IgnoringMessageHandler ,  & ' a  test_utils:: TestLogger ,  & ' a  TestCustomMessageHandler ,  & ' a  test_utils:: TestNodeSigner > )  -> ( FileDescriptor ,  FileDescriptor )  { 
3024+ 		let  addr_a = SocketAddress :: TcpIpV4 { addr :  [ 127 ,  0 ,  0 ,  1 ] ,  port :  1000 } ; 
3025+ 		let  addr_b = SocketAddress :: TcpIpV4 { addr :  [ 127 ,  0 ,  0 ,  1 ] ,  port :  1001 } ; 
3026+ 
3027+ 		let  id_a = peer_a. node_signer . get_node_id ( Recipient :: Node ) . unwrap ( ) ; 
3028+ 		let  id_b = peer_b. node_signer . get_node_id ( Recipient :: Node ) . unwrap ( ) ; 
3029+ 
3030+ 		let  features_a = peer_a. init_features ( id_b) ; 
3031+ 		let  features_b = peer_b. init_features ( id_a) ; 
3032+ 
3033+ 		let  ( fd_a,  fd_b,  a_refused,  b_refused)  = try_establish_connection ( peer_a,  peer_b) ; 
3034+ 
3035+ 		assert_eq ! ( a_refused. unwrap( ) ,  false ) ; 
3036+ 		assert_eq ! ( b_refused. unwrap( ) ,  false ) ; 
29973037
29983038		assert_eq ! ( peer_a. peer_by_node_id( & id_b) . unwrap( ) . counterparty_node_id,  id_b) ; 
29993039		assert_eq ! ( peer_a. peer_by_node_id( & id_b) . unwrap( ) . socket_address,  Some ( addr_b) ) ; 
@@ -3246,6 +3286,50 @@ mod tests {
32463286		assert_eq ! ( peers[ 0 ] . peers. read( ) . unwrap( ) . len( ) ,  0 ) ; 
32473287	} 
32483288
3289+ 	fn  do_test_peer_connected_error_disconnects ( handler :  usize )  { 
3290+ 		// Test that if a message handler fails a connection in `peer_connected` we reliably 
3291+ 		// produce `peer_disconnected` events for all other message handlers (that saw a 
3292+ 		// corresponding `peer_connected`). 
3293+ 		let  cfgs = create_peermgr_cfgs ( 2 ) ; 
3294+ 		let  peers = create_network ( 2 ,  & cfgs) ; 
3295+ 
3296+ 		match  handler &  !1  { 
3297+ 			0  => { 
3298+ 				peers[ handler &  1 ] . message_handler . chan_handler . conn_tracker . fail_connections . store ( true ,  Ordering :: Release ) ; 
3299+ 			} 
3300+ 			2  => { 
3301+ 				peers[ handler &  1 ] . message_handler . route_handler . conn_tracker . fail_connections . store ( true ,  Ordering :: Release ) ; 
3302+ 			} 
3303+ 			4  => { 
3304+ 				peers[ handler &  1 ] . message_handler . custom_message_handler . conn_tracker . fail_connections . store ( true ,  Ordering :: Release ) ; 
3305+ 			} 
3306+ 			_ => panic ! ( ) , 
3307+ 		} 
3308+ 		let  ( _sd1,  _sd2,  a_refused,  b_refused)  = try_establish_connection ( & peers[ 0 ] ,  & peers[ 1 ] ) ; 
3309+ 		if  handler &  1  == 0  { 
3310+ 			assert ! ( a_refused. is_err( ) ) ; 
3311+ 			assert ! ( peers[ 0 ] . list_peers( ) . is_empty( ) ) ; 
3312+ 		}  else  { 
3313+ 			assert ! ( b_refused. is_err( ) ) ; 
3314+ 			assert ! ( peers[ 1 ] . list_peers( ) . is_empty( ) ) ; 
3315+ 		} 
3316+ 		// At least one message handler should have seen the connection. 
3317+ 		assert ! ( peers[ handler &  1 ] . message_handler. chan_handler. conn_tracker. had_peers. load( Ordering :: Acquire )  ||
3318+ 			peers[ handler &  1 ] . message_handler. route_handler. conn_tracker. had_peers. load( Ordering :: Acquire )  ||
3319+ 			peers[ handler &  1 ] . message_handler. custom_message_handler. conn_tracker. had_peers. load( Ordering :: Acquire ) ) ; 
3320+ 		// And both message handlers doing tracking should see the disconnection 
3321+ 		assert ! ( peers[ handler &  1 ] . message_handler. chan_handler. conn_tracker. connected_peers. lock( ) . unwrap( ) . is_empty( ) ) ; 
3322+ 		assert ! ( peers[ handler &  1 ] . message_handler. route_handler. conn_tracker. connected_peers. lock( ) . unwrap( ) . is_empty( ) ) ; 
3323+ 		assert ! ( peers[ handler &  1 ] . message_handler. custom_message_handler. conn_tracker. connected_peers. lock( ) . unwrap( ) . is_empty( ) ) ; 
3324+ 	} 
3325+ 
3326+ 	#[ test]  
3327+ 	fn  test_peer_connected_error_disconnects ( )  { 
3328+ 		for  i in  0 ..6  { 
3329+ 			do_test_peer_connected_error_disconnects ( i) ; 
3330+ 		} 
3331+ 	} 
3332+ 
32493333	#[ test]  
32503334	fn  test_do_attempt_write_data ( )  { 
32513335		// Create 2 peers with custom TestRoutingMessageHandlers and connect them. 
0 commit comments