Skip to content

Commit cf3f8b5

Browse files
committed
fix: Reload the prompt when the device restarts
1 parent 38b0310 commit cf3f8b5

File tree

2 files changed

+52
-18
lines changed

2 files changed

+52
-18
lines changed

src/ai/mod.rs

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,25 @@ pub async fn llm_stable<'p, I: IntoIterator<Item = C>, C: AsRef<llm::Content>>(
433433

434434
let tool_choice = if tools.is_empty() { "" } else { "auto" };
435435

436+
let tool_name = tools
437+
.iter()
438+
.map(|t| t.function.name.as_str())
439+
.collect::<Vec<_>>();
440+
441+
log::debug!(
442+
"#### send to llm:\n{}\n#####",
443+
serde_json::to_string_pretty(&serde_json::json!(
444+
{
445+
"stream": true,
446+
"chat_id": chat_id,
447+
"messages": messages,
448+
"model": model.to_string(),
449+
"tools": tool_name,
450+
"tool_choice": tool_choice,
451+
}
452+
))?
453+
);
454+
436455
let request = StableLlmRequest {
437456
stream: true,
438457
chat_id: chat_id.unwrap_or_default(),
@@ -442,11 +461,6 @@ pub async fn llm_stable<'p, I: IntoIterator<Item = C>, C: AsRef<llm::Content>>(
442461
tool_choice,
443462
};
444463

445-
log::debug!(
446-
"#### send to llm:\n{}\n#####",
447-
serde_json::to_string_pretty(&request)?
448-
);
449-
450464
let response = response_builder
451465
.header(reqwest::header::USER_AGENT, "curl/7.81.0")
452466
.json(&request)

src/services/ws/stable/mod.rs

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ async fn handle_socket(
6464
request_id,
6565
cmd_tx,
6666
client_rx,
67+
is_reconnect: params.reconnect,
6768
})
6869
.map_err(|e| anyhow::anyhow!("send session error: {}", e))?;
6970

@@ -76,6 +77,7 @@ pub struct Session {
7677
request_id: u128,
7778
cmd_tx: super::WsTx,
7879
client_rx: super::ClientRx,
80+
is_reconnect: bool,
7981
}
8082

8183
async fn run_session(
@@ -240,7 +242,10 @@ pub async fn run_session_manager(
240242
tools: &ToolSet<McpToolAdapter>,
241243
mut session_rx: tokio::sync::mpsc::UnboundedReceiver<Session>,
242244
) -> anyhow::Result<()> {
243-
let mut sessions: HashMap<String, tokio::sync::mpsc::UnboundedSender<Session>> = HashMap::new();
245+
let mut sessions: HashMap<
246+
String,
247+
tokio::sync::mpsc::UnboundedSender<(Session, Option<llm::PromptParts>)>,
248+
> = HashMap::new();
244249

245250
let mut tts_session_pool = tts::TTSSessionPool::new(tts.clone(), 4);
246251
let (tts_req_tx, tts_req_rx) = tokio::sync::mpsc::channel(128);
@@ -252,22 +257,33 @@ pub async fn run_session_manager(
252257
});
253258

254259
while let Some(session) = session_rx.recv().await {
255-
let session = if let Some(tx) = sessions.get(&session.id) {
256-
if let Err(e) = tx.send(session) {
260+
let prompts;
261+
if !session.is_reconnect {
262+
prompts = Some(llm.prompts().await)
263+
} else {
264+
prompts = None
265+
}
266+
let (session, mut prompts) = if let Some(tx) = sessions.get(&session.id) {
267+
if let Err(e) = tx.send((session, prompts)) {
257268
e.0
258269
} else {
259270
continue;
260271
}
261272
} else {
262-
session
273+
(session, prompts)
263274
};
264275

276+
// device reconnects but server restarted
277+
if prompts.is_none() {
278+
prompts = Some(llm.prompts().await);
279+
}
280+
265281
// run session
266282
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
267283

268284
let id = session.id.clone();
269285
log::info!("Starting new session for id: {}", id);
270-
let _ = tx.send(session);
286+
let _ = tx.send((session, prompts));
271287
let asr = asr.clone();
272288

273289
let mut chat_session = super::ChatSession::new(
@@ -279,10 +295,6 @@ pub async fn run_session_manager(
279295
tools.clone(),
280296
);
281297

282-
let part = llm.prompts().await;
283-
chat_session.system_prompts = part.sys_prompts;
284-
chat_session.messages = part.dynamic_prompts;
285-
286298
sessions.insert(id.clone(), tx);
287299

288300
let mut tts_req_tx = tts_req_tx.clone();
@@ -293,13 +305,17 @@ pub async fn run_session_manager(
293305
anyhow::anyhow!("error creating asr session for id `{}`: {}", id, e)
294306
})?;
295307

296-
let mut session = rx
308+
let (mut session, mut prompts) = rx
297309
.recv()
298310
.await
299311
.ok_or_else(|| anyhow::anyhow!("no session received for id `{}`", id))?;
300312

301313
loop {
302314
log::info!("Running session for id `{}`", id);
315+
if let Some(prompts) = prompts.take() {
316+
chat_session.system_prompts = prompts.sys_prompts;
317+
chat_session.messages = prompts.dynamic_prompts;
318+
}
303319

304320
let run_fut = run_session(
305321
&mut chat_session,
@@ -326,9 +342,10 @@ pub async fn run_session_manager(
326342
Ok(Err(e)) => {
327343
log::error!("session for id `{}` error: {}", id, e);
328344
}
329-
Err(Some(new_session)) => {
345+
Err(Some((new_session, new_prompts))) => {
330346
log::info!("received new session for id `{}`, restarting session", id);
331347
session = new_session;
348+
prompts = new_prompts;
332349
continue;
333350
}
334351
Err(None) => {
@@ -337,8 +354,11 @@ pub async fn run_session_manager(
337354
}
338355
}
339356

340-
session = match rx.recv().await {
341-
Some(s) => s,
357+
match rx.recv().await {
358+
Some(s) => {
359+
session = s.0;
360+
prompts = s.1;
361+
}
342362
None => {
343363
log::info!("no more sessions for id `{}`, exiting", id);
344364
break;

0 commit comments

Comments
 (0)