11use axum:: {
22 extract:: {
33 ws:: { Message , WebSocket , WebSocketUpgrade } ,
4- State ,
4+ State , Path ,
55 } ,
66 http:: StatusCode ,
77 response:: IntoResponse ,
@@ -24,6 +24,8 @@ use tokio::{
2424 } ,
2525 sync:: oneshot:: { self , error:: TryRecvError , Receiver } ,
2626} ;
27+ use tokio:: net:: lookup_host;
28+
2729use tower_http:: {
2830 services:: ServeDir ,
2931 trace:: { DefaultMakeSpan , TraceLayer } ,
@@ -51,13 +53,6 @@ pub async fn start_server(handle: Handle) {
5153 Ok ( it) => it,
5254 Err ( err) => { error ! ( "Arg parse error:{}" , err) ; panic ! ( "" ) } ,
5355 } ;
54- // tracing_subscriber::registry()
55- // .with(
56- // tracing_subscriber::EnvFilter::try_from_default_env()
57- // .unwrap_or_else(|_| "example_websockets=debug,tower_http=debug".into()),
58- // )
59- // .with(tracing_subscriber::fmt::layer())
60- // .init();
6156 let assets_dir = args. web . clone ( ) ; //PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("assets");
6257 info ! ( "source 地址:{}" , args. source) ;
6358 let target_addr = lookup_host ( & args. target )
@@ -84,11 +79,10 @@ pub async fn start_server(handle: Handle) {
8479 app = app
8580 // routes are matched from bottom to top, so we have to put `nest` at the
8681 // top since it matches all routes
87- . route ( "/websockify" , get ( ws_handler) . with_state ( state) )
82+ . route ( "/websockify/:socket_address " , get ( ws_handler) . with_state ( state) )
8883 // logging so we can see whats going on
8984 . layer ( TraceLayer :: new_for_http ( ) . make_span_with ( DefaultMakeSpan :: default ( ) . include_headers ( true ) ) ) ;
9085
91- use tokio:: net:: lookup_host;
9286 let addr = lookup_host ( & args. source )
9387 . await
9488 . expect ( "Wrong source address" )
@@ -104,10 +98,13 @@ pub async fn start_server(handle: Handle) {
10498 } ;
10599}
106100
107- async fn ws_handler ( ws : WebSocketUpgrade , State ( state) : State < AppState > ) -> impl IntoResponse {
108- // accept connections and process them serially
109-
110- // 也要在handle_socket里把从websocket读到的内容写入tcp
101+ async fn ws_handler ( Path ( socket_address) : Path < String > , ws : WebSocketUpgrade , State ( state) : State < AppState > ) -> impl IntoResponse {
102+ // 如果socket_address是合法的地址,那就优先用这个地址
103+ if let Ok ( mut iter) = lookup_host ( socket_address) . await {
104+ if let Some ( addr) = iter. next ( ) {
105+ return ws. on_upgrade ( move |socket| handle_socket ( socket, addr) ) ;
106+ }
107+ }
111108 ws. on_upgrade ( move |socket| handle_socket ( socket, state. addr ) )
112109}
113110
0 commit comments