Skip to content

Commit 8519509

Browse files
committed
feat: Support custom messages when calling mcp
1 parent f9d904a commit 8519509

File tree

6 files changed

+64
-10
lines changed

6 files changed

+64
-10
lines changed

src/ai/mod.rs

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -682,6 +682,11 @@ impl ChatSession {
682682
Ok(response)
683683
}
684684

685+
pub fn get_tool_call_message(&self, tool_call: &llm::ToolCall) -> Option<String> {
686+
let tool = self.tools.get_tool(tool_call.function.name.as_str())?;
687+
Some(tool.call_mcp_message().to_string())
688+
}
689+
685690
pub async fn execute_tool(&mut self, tool_call: &llm::ToolCall) -> anyhow::Result<()> {
686691
use crate::ai::openai::tool::Tool;
687692

@@ -775,6 +780,7 @@ pub async fn load_sse_tools(
775780
rmcp::service::RunningService<rmcp::RoleClient, rmcp::model::InitializeRequestParam>,
776781
>,
777782
mcp_servers_url: &str,
783+
call_mcp_message: &str,
778784
) -> anyhow::Result<()> {
779785
// load MCP
780786
let transport = SseClientTransport::start(mcp_servers_url).await?;
@@ -794,7 +800,11 @@ pub async fn load_sse_tools(
794800
for tool in tools {
795801
let server = client.peer().clone();
796802
log::info!("add tool: {}", tool.name);
797-
tool_set.add_tool(McpToolAdapter::new(tool, server));
803+
tool_set.add_tool(McpToolAdapter::new(
804+
tool,
805+
call_mcp_message.to_string(),
806+
server,
807+
));
798808
}
799809
clients.push(client);
800810
Ok(())
@@ -806,6 +816,7 @@ pub async fn load_http_streamable_tools(
806816
rmcp::service::RunningService<rmcp::RoleClient, rmcp::model::InitializeRequestParam>,
807817
>,
808818
mcp_servers_url: &str,
819+
call_mcp_message: &str,
809820
) -> anyhow::Result<()> {
810821
// load MCP
811822
let transport = StreamableHttpClientTransport::from_uri(mcp_servers_url);
@@ -825,7 +836,11 @@ pub async fn load_http_streamable_tools(
825836
for tool in tools {
826837
let server = client.peer().clone();
827838
log::info!("add tool: {}", tool.name);
828-
tool_set.add_tool(McpToolAdapter::new(tool, server));
839+
tool_set.add_tool(McpToolAdapter::new(
840+
tool,
841+
call_mcp_message.to_string(),
842+
server,
843+
));
829844
}
830845

831846
clients.push(client);
@@ -857,7 +872,7 @@ async fn test_chat_session() {
857872
let mut clients = vec![];
858873

859874
let mut tools = ToolSet::default();
860-
load_http_streamable_tools(&mut tools, &mut clients, "http://localhost:8000/mcp")
875+
load_http_streamable_tools(&mut tools, &mut clients, "http://localhost:8000/mcp", "")
861876
.await
862877
.unwrap();
863878

src/ai/openai/tool.rs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,23 @@ pub trait Tool: Send + Sync {
1717

1818
pub struct McpToolAdapter {
1919
tool: McpTool,
20+
call_mcp_message: String,
2021
server: ServerSink,
2122
}
2223

2324
impl McpToolAdapter {
24-
pub fn new(tool: McpTool, server: ServerSink) -> Self {
25-
Self { tool, server }
25+
pub fn new(tool: McpTool, call_mcp_message: String, server: ServerSink) -> Self {
26+
Self {
27+
tool,
28+
call_mcp_message,
29+
server,
30+
}
31+
}
32+
}
33+
34+
impl McpToolAdapter {
35+
pub fn call_mcp_message(&self) -> &str {
36+
&self.call_mcp_message
2637
}
2738
}
2839

src/config.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ pub struct MCPServerConfig {
2222
pub api_key: String,
2323
#[serde(rename = "type", default)]
2424
pub type_: MCPType,
25+
#[serde(default)]
26+
pub call_mcp_message: String,
2527
}
2628

2729
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]

src/main.rs

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,16 +66,25 @@ async fn routes(
6666
for server in &llm.mcp_server {
6767
match server.type_ {
6868
config::MCPType::SSE => {
69-
if let Err(e) =
70-
ai::load_sse_tools(&mut tool_set, clients, &server.server).await
69+
if let Err(e) = ai::load_sse_tools(
70+
&mut tool_set,
71+
clients,
72+
&server.server,
73+
&server.call_mcp_message,
74+
)
75+
.await
7176
{
7277
log::error!("Failed to load tools from {}: {}", &server.server, e);
7378
}
7479
}
7580
config::MCPType::HttpStreamable => {
76-
if let Err(e) =
77-
ai::load_http_streamable_tools(&mut tool_set, clients, &server.server)
78-
.await
81+
if let Err(e) = ai::load_http_streamable_tools(
82+
&mut tool_set,
83+
clients,
84+
&server.server,
85+
&server.call_mcp_message,
86+
)
87+
.await
7988
{
8089
log::error!("Failed to load tools from {}: {}", &server.server, e);
8190
}

src/services/ws/stable/llm.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,19 @@ pub async fn chat(
204204
log::info!("llm functions: {:#?}", functions);
205205
chat_session.add_assistant_tool_call(functions.clone());
206206
for function in functions {
207+
if let Some(message) = chat_session.get_tool_call_message(&function) {
208+
log::info!("tool {} call message: {}", &function.function.name, message);
209+
if !message.is_empty() {
210+
let (tts_resp_tx, tts_resp_rx) = tokio::sync::mpsc::unbounded_channel();
211+
drop(tts_resp_tx);
212+
213+
chunks_tx.send((message, tts_resp_rx)).map_err(|e| {
214+
anyhow::anyhow!(
215+
"error sending tts chunks receiver for llm chunk: {e}"
216+
)
217+
})?;
218+
}
219+
}
207220
chat_session.execute_tool(&function).await?
208221
}
209222
resp = chat_session.complete().await?;

src/services/ws/stable/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,10 @@ async fn handle_tts_requests(mut chunks_rx: ChunksRx, session: &mut Session) ->
210210
tts_chunk.len()
211211
);
212212

213+
if tts_chunk.is_empty() {
214+
continue;
215+
}
216+
213217
session
214218
.cmd_tx
215219
.send(super::WsCommand::Audio(tts_chunk))

0 commit comments

Comments
 (0)