Skip to content

Commit 00bc21e

Browse files
author
鄢胜
committed
feat: allow custom host from url param
1 parent 08bbecf commit 00bc21e

File tree

3 files changed

+17
-16
lines changed

3 files changed

+17
-16
lines changed

src/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ fn main() -> windows_service::Result<()> {
5050
.with(file_layer)
5151
.with(time_layer)
5252
.init();
53-
run()
53+
run()
5454
}
5555
#[cfg(feature = "daemonize")]
5656
fn run() -> windows_service::Result<()>{

src/service.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,12 @@ pub fn run_service() -> Result<()> {
9090
let handle = axum_server::Handle::new();
9191

9292
runtime.block_on(async {
93-
tokio::spawn(graceful_shutdown(handle.clone(), shutdown_rx));
93+
let t = tokio::spawn(graceful_shutdown(handle.clone(), shutdown_rx));
9494
start_server(handle).await;
95+
match t.await {
96+
Ok(_) => {},
97+
Err(e) => {println!("graceful_shutdown thread error:{}", e)},
98+
}
9599
});
96100

97101
status_handle.set_service_status(ServiceStatus {

src/web.rs

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use axum::{
22
extract::{
33
ws::{Message, WebSocket, WebSocketUpgrade},
4-
State,
4+
State, Path,
55
},
66
http::StatusCode,
77
response::IntoResponse,
@@ -24,6 +24,8 @@ use tokio::{
2424
},
2525
sync::oneshot::{self, error::TryRecvError, Receiver},
2626
};
27+
use tokio::net::lookup_host;
28+
2729
use tower_http::{
2830
services::ServeDir,
2931
trace::{DefaultMakeSpan, TraceLayer},
@@ -51,13 +53,6 @@ pub async fn start_server(handle: Handle) {
5153
Ok(it) => it,
5254
Err(err) => {error!("Arg parse error:{}", err); panic!("")},
5355
};
54-
// tracing_subscriber::registry()
55-
// .with(
56-
// tracing_subscriber::EnvFilter::try_from_default_env()
57-
// .unwrap_or_else(|_| "example_websockets=debug,tower_http=debug".into()),
58-
// )
59-
// .with(tracing_subscriber::fmt::layer())
60-
// .init();
6156
let assets_dir = args.web.clone(); //PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("assets");
6257
info!("source 地址:{}", args.source);
6358
let target_addr = lookup_host(&args.target)
@@ -84,11 +79,10 @@ pub async fn start_server(handle: Handle) {
8479
app = app
8580
// routes are matched from bottom to top, so we have to put `nest` at the
8681
// top since it matches all routes
87-
.route("/websockify", get(ws_handler).with_state(state))
82+
.route("/websockify/:socket_address", get(ws_handler).with_state(state))
8883
// logging so we can see whats going on
8984
.layer(TraceLayer::new_for_http().make_span_with(DefaultMakeSpan::default().include_headers(true)));
9085

91-
use tokio::net::lookup_host;
9286
let addr = lookup_host(&args.source)
9387
.await
9488
.expect("Wrong source address")
@@ -104,10 +98,13 @@ pub async fn start_server(handle: Handle) {
10498
};
10599
}
106100

107-
async fn ws_handler(ws: WebSocketUpgrade, State(state): State<AppState>) -> impl IntoResponse {
108-
// accept connections and process them serially
109-
110-
// 也要在handle_socket里把从websocket读到的内容写入tcp
101+
async fn ws_handler(Path(socket_address):Path<String>, ws: WebSocketUpgrade, State(state): State<AppState>) -> impl IntoResponse {
102+
// 如果socket_address是合法的地址,那就优先用这个地址
103+
if let Ok(mut iter) = lookup_host(socket_address).await {
104+
if let Some(addr) = iter.next(){
105+
return ws.on_upgrade(move |socket| handle_socket(socket, addr));
106+
}
107+
}
111108
ws.on_upgrade(move |socket| handle_socket(socket, state.addr))
112109
}
113110

0 commit comments

Comments
 (0)