Skip to content

Commit 70f4db6

Browse files
committed
fix(lib): refactor MCP client to manage its own Tokio runtime and improve async handling
1 parent 0c53f76 commit 70f4db6

File tree

1 file changed

+78
-80
lines changed

1 file changed

+78
-80
lines changed

src/lib.rs

Lines changed: 78 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,6 @@ use rmcp::{ServiceExt, RoleClient};
1515
use rmcp::model::{ClientInfo, ClientCapabilities, Implementation};
1616
use serde_json;
1717

18-
// Global runtime - one runtime per process that stays alive
19-
static GLOBAL_RUNTIME: OnceLock<tokio::runtime::Runtime> = OnceLock::new();
20-
2118
// Global client instance - one client per process
2219
static GLOBAL_CLIENT: OnceLock<Mutex<Option<McpClient>>> = OnceLock::new();
2320

@@ -240,41 +237,26 @@ type RunningClient = rmcp::service::RunningService<RoleClient, ClientInfo>;
240237

241238
/// Opaque handle for MCP client
242239
pub struct McpClient {
240+
runtime: tokio::runtime::Runtime,
243241
service: Arc<TokioMutex<Option<RunningClient>>>,
244242
server_url: Mutex<Option<String>>,
245243
}
246244

247-
/// Get or create the global Tokio runtime
248-
fn get_runtime() -> &'static tokio::runtime::Runtime {
249-
GLOBAL_RUNTIME.get_or_init(|| {
250-
// Use single-threaded runtime to avoid TLS issues with nested spawns on Windows
251-
tokio::runtime::Builder::new_current_thread()
252-
.enable_all()
253-
.build()
254-
.expect("Failed to create Tokio runtime")
255-
})
256-
}
257-
258-
/// Run an async block with proper runtime context (Windows compatibility)
259-
fn run_async<F, T>(future: F) -> T
260-
where
261-
F: std::future::Future<Output = T>,
262-
{
263-
get_runtime().block_on(future)
264-
}
265-
266245
/// Create a new MCP client
267246
/// Returns NULL on error
268247
#[no_mangle]
269248
pub extern "C" fn mcp_client_new() -> *mut McpClient {
270-
// Ensure runtime is initialized
271-
let _ = get_runtime();
272-
273-
let client = Box::new(McpClient {
274-
service: Arc::new(TokioMutex::new(None)),
275-
server_url: Mutex::new(None),
276-
});
277-
Box::into_raw(client)
249+
match tokio::runtime::Runtime::new() {
250+
Ok(runtime) => {
251+
let client = Box::new(McpClient {
252+
runtime,
253+
service: Arc::new(TokioMutex::new(None)),
254+
server_url: Mutex::new(None),
255+
});
256+
Box::into_raw(client)
257+
}
258+
Err(_) => ptr::null_mut(),
259+
}
278260
}
279261

280262
/// Free an MCP client
@@ -341,11 +323,15 @@ pub extern "C" fn mcp_connect(
341323
}
342324
};
343325

344-
// Ensure global runtime is initialized
345-
let _ = get_runtime();
346-
347-
// Create a new McpClient
326+
// Create a new McpClient with runtime
348327
let new_client = McpClient {
328+
runtime: match tokio::runtime::Runtime::new() {
329+
Ok(r) => r,
330+
Err(e) => {
331+
let error = format!(r#"{{"error": "Failed to create runtime: {}"}}"#, e);
332+
return CString::new(error).unwrap_or_default().into_raw();
333+
}
334+
},
349335
service: Arc::new(TokioMutex::new(None)),
350336
server_url: Mutex::new(None),
351337
};
@@ -354,7 +340,7 @@ pub extern "C" fn mcp_connect(
354340

355341
let (result, maybe_service) = if use_sse {
356342
// Use SSE transport (legacy) with optional custom headers
357-
run_async(async {
343+
new_client.runtime.block_on(async {
358344
// Create HTTP client with optional custom headers
359345
let mut client_builder = reqwest::Client::builder();
360346
if let Some(ref headers_map) = headers_map {
@@ -440,7 +426,7 @@ pub extern "C" fn mcp_connect(
440426
})
441427
} else {
442428
// Use streamable HTTP transport (default) with optional custom headers
443-
run_async(async {
429+
new_client.runtime.block_on(async {
444430
// For Streamable HTTP, we need to extract the Authorization header specifically
445431
// since it has a dedicated field, and we'll use a custom HTTP client for other headers
446432
let auth_header_value = headers_map.as_ref().and_then(|m| m.get("Authorization")).map(|s| s.clone());
@@ -539,7 +525,7 @@ pub extern "C" fn mcp_connect(
539525

540526
// Store service and URL if connection succeeded
541527
if let Some((service, url)) = maybe_service {
542-
run_async(async {
528+
new_client.runtime.block_on(async {
543529
*new_client.service.lock().await = Some(service);
544530
});
545531
*new_client.server_url.lock().unwrap() = Some(url);
@@ -593,10 +579,13 @@ pub extern "C" fn mcp_disconnect() -> *mut c_char {
593579

594580
// Also clear any active stream channels
595581
{
596-
let mut channels = STREAM_CHANNELS.lock().unwrap();
597-
channels.clear();
582+
let runtime = tokio::runtime::Runtime::new().unwrap();
583+
runtime.block_on(async {
584+
let mut channels = STREAM_CHANNELS.lock().await;
585+
channels.clear();
586+
});
598587
}
599-
588+
600589
// Reset stream counter
601590
*STREAM_COUNTER.lock().unwrap() = 0;
602591

@@ -618,7 +607,7 @@ pub extern "C" fn mcp_list_tools_json(_client_ptr: *mut McpClient) -> *mut c_cha
618607
}
619608
};
620609

621-
let result = run_async(async {
610+
let result = client.runtime.block_on(async {
622611
let service_guard = client.service.lock().await;
623612
let service = match service_guard.as_ref() {
624613
Some(s) => s,
@@ -712,7 +701,7 @@ pub extern "C" fn mcp_call_tool_json(
712701
}
713702
};
714703

715-
let result = run_async(async {
704+
let result = client.runtime.block_on(async {
716705
let service_guard = client.service.lock().await;
717706
let service = match service_guard.as_ref() {
718707
Some(s) => s,
@@ -751,8 +740,8 @@ use std::collections::HashMap;
751740
use tokio::sync::Mutex as TokioMutex;
752741

753742
lazy_static::lazy_static! {
754-
static ref STREAM_CHANNELS: Arc<Mutex<HashMap<usize, tokio::sync::mpsc::UnboundedReceiver<StreamChunk>>>> =
755-
Arc::new(Mutex::new(HashMap::new()));
743+
static ref STREAM_CHANNELS: Arc<TokioMutex<HashMap<usize, tokio::sync::mpsc::UnboundedReceiver<StreamChunk>>>> =
744+
Arc::new(TokioMutex::new(HashMap::new()));
756745
static ref STREAM_COUNTER: Mutex<usize> = Mutex::new(0);
757746
}
758747

@@ -795,14 +784,15 @@ pub extern "C" fn mcp_list_tools_init() -> usize {
795784
// Get the global client
796785
let client_mutex = GLOBAL_CLIENT.get_or_init(|| Mutex::new(None));
797786

798-
// Spawn the async task on the runtime
787+
// Spawn the async task
799788
{
800789
let client_opt = client_mutex.lock().unwrap();
801790
if let Some(client) = client_opt.as_ref() {
802791
// Clone the Arc to share the service across async boundaries
803792
let service_arc = client.service.clone();
804-
// Spawn async task on the runtime
805-
get_runtime().spawn(async move {
793+
794+
// Use the client's runtime to spawn the task
795+
client.runtime.spawn(async move {
806796
let service_guard = service_arc.lock().await;
807797
if let Some(service) = service_guard.as_ref() {
808798
match service.list_tools(None).await {
@@ -834,9 +824,12 @@ pub extern "C" fn mcp_list_tools_init() -> usize {
834824
} // Release the lock here
835825

836826
// Store the receiver in global storage (now safe to acquire lock again)
837-
{
838-
let mut channels = STREAM_CHANNELS.lock().unwrap();
839-
channels.insert(stream_id, rx);
827+
let client_opt = client_mutex.lock().unwrap();
828+
if let Some(client) = client_opt.as_ref() {
829+
client.runtime.block_on(async {
830+
let mut channels = STREAM_CHANNELS.lock().await;
831+
channels.insert(stream_id, rx);
832+
});
840833
}
841834

842835
stream_id
@@ -879,13 +872,14 @@ pub extern "C" fn mcp_call_tool_init(tool_name: *const c_char, arguments: *const
879872
// Get the global client
880873
let client_mutex = GLOBAL_CLIENT.get_or_init(|| Mutex::new(None));
881874

882-
// Spawn the async task on the runtime
875+
// Spawn the async task
883876
{
884877
let client_opt = client_mutex.lock().unwrap();
885878
if let Some(client) = client_opt.as_ref() {
886879
let service_arc = client.service.clone();
887-
// Spawn async task on the runtime
888-
get_runtime().spawn(async move {
880+
881+
// Use the client's runtime to spawn the task
882+
client.runtime.spawn(async move {
889883
let service_guard = service_arc.lock().await;
890884
if let Some(service) = service_guard.as_ref() {
891885
// Parse arguments
@@ -940,9 +934,12 @@ pub extern "C" fn mcp_call_tool_init(tool_name: *const c_char, arguments: *const
940934
}
941935

942936
// Store the receiver
943-
{
944-
let mut channels = STREAM_CHANNELS.lock().unwrap();
945-
channels.insert(stream_id, rx);
937+
let client_opt = client_mutex.lock().unwrap();
938+
if let Some(client) = client_opt.as_ref() {
939+
client.runtime.block_on(async {
940+
let mut channels = STREAM_CHANNELS.lock().await;
941+
channels.insert(stream_id, rx);
942+
});
946943
}
947944

948945
stream_id
@@ -955,18 +952,20 @@ pub extern "C" fn mcp_stream_next(stream_id: usize) -> *mut StreamResult {
955952
let client_mutex = GLOBAL_CLIENT.get_or_init(|| Mutex::new(None));
956953
let client_opt = client_mutex.lock().unwrap();
957954

958-
if let Some(_client) = client_opt.as_ref() {
959-
let mut channels = STREAM_CHANNELS.lock().unwrap();
960-
if let Some(rx) = channels.get_mut(&stream_id) {
961-
match rx.try_recv() {
962-
Ok(chunk) => {
963-
Box::into_raw(Box::new(chunk_to_stream_result(chunk)))
955+
if let Some(client) = client_opt.as_ref() {
956+
client.runtime.block_on(async {
957+
let mut channels = STREAM_CHANNELS.lock().await;
958+
if let Some(rx) = channels.get_mut(&stream_id) {
959+
match rx.try_recv() {
960+
Ok(chunk) => {
961+
Box::into_raw(Box::new(chunk_to_stream_result(chunk)))
962+
}
963+
Err(_) => ptr::null_mut(),
964964
}
965-
Err(_) => ptr::null_mut(),
965+
} else {
966+
ptr::null_mut()
966967
}
967-
} else {
968-
ptr::null_mut()
969-
}
968+
})
970969
} else {
971970
ptr::null_mut()
972971
}
@@ -980,23 +979,21 @@ pub extern "C" fn mcp_stream_wait(stream_id: usize, timeout_ms: u64) -> *mut Str
980979
let client_mutex = GLOBAL_CLIENT.get_or_init(|| Mutex::new(None));
981980
let client_opt = client_mutex.lock().unwrap();
982981

983-
if let Some(_client) = client_opt.as_ref() {
984-
// Get mutable reference to receiver outside of async block
985-
let mut channels = STREAM_CHANNELS.lock().unwrap();
986-
if let Some(rx) = channels.get_mut(&stream_id) {
987-
// We need to use run_async for the async recv operation
988-
run_async(async {
982+
if let Some(client) = client_opt.as_ref() {
983+
client.runtime.block_on(async {
984+
let mut channels = STREAM_CHANNELS.lock().await;
985+
if let Some(rx) = channels.get_mut(&stream_id) {
989986
let timeout = tokio::time::Duration::from_millis(timeout_ms);
990987
match tokio::time::timeout(timeout, rx.recv()).await {
991988
Ok(Some(chunk)) => {
992989
Box::into_raw(Box::new(chunk_to_stream_result(chunk)))
993990
}
994991
_ => ptr::null_mut(),
995992
}
996-
})
997-
} else {
998-
ptr::null_mut()
999-
}
993+
} else {
994+
ptr::null_mut()
995+
}
996+
})
1000997
} else {
1001998
ptr::null_mut()
1002999
}
@@ -1008,9 +1005,11 @@ pub extern "C" fn mcp_stream_cleanup(stream_id: usize) {
10081005
let client_mutex = GLOBAL_CLIENT.get_or_init(|| Mutex::new(None));
10091006
let client_opt = client_mutex.lock().unwrap();
10101007

1011-
if let Some(_client) = client_opt.as_ref() {
1012-
let mut channels = STREAM_CHANNELS.lock().unwrap();
1013-
channels.remove(&stream_id);
1008+
if let Some(client) = client_opt.as_ref() {
1009+
client.runtime.block_on(async {
1010+
let mut channels = STREAM_CHANNELS.lock().await;
1011+
channels.remove(&stream_id);
1012+
});
10141013
}
10151014
}
10161015

@@ -1061,4 +1060,3 @@ fn chunk_to_stream_result(chunk: StreamChunk) -> StreamResult {
10611060
}
10621061
}
10631062
}
1064-

0 commit comments

Comments
 (0)