@@ -5,13 +5,14 @@ use socks5_impl::{
55} ;
66use std:: {
77 net:: { SocketAddr , ToSocketAddrs } ,
8- sync:: { Arc , atomic :: AtomicBool } ,
8+ sync:: Arc ,
99} ;
1010use tokio:: {
1111 io,
1212 net:: { TcpStream , UdpSocket } ,
1313 sync:: Mutex ,
1414} ;
15+ use tokio_util:: sync:: CancellationToken ;
1516
1617/// Simple socks5 proxy server.
1718#[ derive( clap:: Parser , Debug , Clone , PartialEq , Eq ) ]
@@ -55,52 +56,48 @@ async fn main() -> Result<()> {
5556 let default = format ! ( "{}={:?}" , module_path!( ) , opt. verbosity) ;
5657 env_logger:: Builder :: from_env ( env_logger:: Env :: default ( ) . default_filter_or ( default) ) . init ( ) ;
5758
58- let exiting_flag = Arc :: new ( AtomicBool :: new ( false ) ) ;
59- let exiting_flag_clone = exiting_flag . clone ( ) ;
59+ let token = CancellationToken :: new ( ) ;
60+ let cloned_token = token . clone ( ) ;
6061
61- let local_addr = opt. listen_addr ;
62-
63- ctrlc2:: set_async_handler ( async move {
64- exiting_flag_clone. store ( true , std:: sync:: atomic:: Ordering :: Relaxed ) ;
65-
66- let addr = if local_addr. is_ipv6 ( ) {
67- SocketAddr :: from ( ( std:: net:: Ipv6Addr :: LOCALHOST , local_addr. port ( ) ) )
68- } else {
69- SocketAddr :: from ( ( std:: net:: Ipv4Addr :: LOCALHOST , local_addr. port ( ) ) )
70- } ;
71- let _ = std:: net:: TcpStream :: connect ( addr) ;
62+ let ctrlc = ctrlc2:: AsyncCtrlC :: new ( move || {
7263 log:: info!( "" ) ;
7364 log:: info!( "Ctrl-C received, shutting down..." ) ;
65+ cloned_token. cancel ( ) ;
66+ true
7467 } )
75- . await ;
68+ . expect ( "Failed to set up Ctrl-C handler" ) ;
7669
7770 match ( opt. username , opt. password ) {
7871 ( Some ( username) , password) => {
7972 let password = password. unwrap_or_default ( ) ;
8073 let auth = Arc :: new ( auth:: UserKeyAuth :: new ( & username, & password) ) ;
81- main_loop ( auth, opt. listen_addr , Some ( exiting_flag ) ) . await ?;
74+ main_loop ( auth, opt. listen_addr , token ) . await ?;
8275 }
8376 _ => {
8477 let auth = Arc :: new ( auth:: NoAuth ) ;
85- main_loop ( auth, opt. listen_addr , Some ( exiting_flag ) ) . await ?;
78+ main_loop ( auth, opt. listen_addr , token ) . await ?;
8679 }
8780 }
8881
82+ ctrlc. await ;
83+
8984 Ok ( ( ) )
9085}
9186
92- async fn main_loop < S > ( auth : auth:: AuthAdaptor < S > , listen_addr : SocketAddr , exiting_flag : Option < Arc < AtomicBool > > ) -> Result < ( ) >
87+ async fn main_loop < S > ( auth : auth:: AuthAdaptor < S > , listen_addr : SocketAddr , token : CancellationToken ) -> Result < ( ) >
9388where
9489 S : Send + Sync + ' static ,
9590{
9691 let server = Server :: bind ( listen_addr, auth) . await ?;
9792
98- while let Ok ( ( conn, _) ) = server. accept ( ) . await {
99- if let Some ( exiting_flag) = & exiting_flag {
100- if exiting_flag. load ( std:: sync:: atomic:: Ordering :: Relaxed ) {
93+ loop {
94+ let ( conn, _) = tokio:: select! {
95+ _ = token. cancelled( ) => {
96+ log:: info!( "CancellationToken fired, session will be closed" ) ;
10197 break ;
10298 }
103- }
99+ conn = server. accept( ) => conn?,
100+ } ;
104101 tokio:: spawn ( async move {
105102 if let Err ( err) = handle ( conn) . await {
106103 log:: error!( "{err}" ) ;
0 commit comments