@@ -4,7 +4,7 @@ use axum::{
44 body:: Bytes ,
55 extract:: {
66 ws:: { Message , WebSocket , WebSocketUpgrade } ,
7- Path ,
7+ Path , Query ,
88 } ,
99 response:: IntoResponse ,
1010 Extension ,
@@ -66,18 +66,25 @@ impl WsSetting {
6666 }
6767}
6868
69+ #[ derive( Debug , serde:: Deserialize ) ]
70+ pub struct ConnectQueryParams {
71+ #[ serde( default ) ]
72+ reconnect : bool ,
73+ }
74+
6975pub async fn ws_handler (
7076 Extension ( pool) : Extension < Arc < WsSetting > > ,
7177 ws : WebSocketUpgrade ,
7278 Path ( id) : Path < String > ,
79+ Query ( params) : Query < ConnectQueryParams > ,
7380) -> impl IntoResponse {
7481 let request_id = uuid:: Uuid :: new_v4 ( ) . as_u128 ( ) ;
75- log:: info!( "{id}:{request_id:x} connected." ) ;
82+ log:: info!( "{id}:{request_id:x} connected. {:?}" , params ) ;
7683
7784 ws. on_upgrade ( move |socket| async move {
7885 let id = id. clone ( ) ;
7986 let pool = pool. clone ( ) ;
80- if let Err ( e) = handle_socket ( socket, & id, pool. clone ( ) ) . await {
87+ if let Err ( e) = handle_socket ( socket, & id, pool. clone ( ) , params ) . await {
8188 log:: error!( "{id}:{request_id:x} error: {e}" ) ;
8289 } ;
8390 log:: info!( "{id}:{request_id:x} disconnected." ) ;
@@ -1013,9 +1020,6 @@ async fn handle_audio(
10131020 r = submit_to_ai( & pool, & mut ws_tx, & mut chat_session, asr_result) => {
10141021 if let Err ( e) = r {
10151022 log:: error!( "`{id}` error: {e}" ) ;
1016- if let Err ( e) = ws_tx. send( WsCommand :: AsrResult ( vec![ ] ) ) {
1017- log:: error!( "`{id}` error: {e}" ) ;
1018- } ;
10191023 }
10201024 if let Err ( e) = ws_tx. send( WsCommand :: EndResponse ) {
10211025 log:: error!( "`{id}` error: {e}" ) ;
@@ -1130,11 +1134,12 @@ async fn handle_socket(
11301134 mut socket : WebSocket ,
11311135 id : & str ,
11321136 pool : Arc < WsSetting > ,
1137+ connect_params : ConnectQueryParams ,
11331138) -> anyhow:: Result < ( ) > {
11341139 let ( cmd_tx, mut cmd_rx) = tokio:: sync:: mpsc:: unbounded_channel :: < WsCommand > ( ) ;
11351140
11361141 if let Some ( hello_wav) = & pool. hello_wav {
1137- if !hello_wav. is_empty ( ) {
1142+ if !hello_wav. is_empty ( ) && !connect_params . reconnect {
11381143 send_hello_wav ( & mut socket, hello_wav) . await ?;
11391144 }
11401145 }
0 commit comments