Skip to content

Commit fc722ad

Browse files
committed
fix(windows/lib.rs): Tokio runtime in background thread to isolate Tokio context and avoid new EnterGuards ordering issues
1 parent 511ec2e commit fc722ad

File tree

1 file changed

+92
-54
lines changed

1 file changed

+92
-54
lines changed

src/lib.rs

Lines changed: 92 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -8,35 +8,70 @@
88
use std::ffi::{CStr, CString};
99
use std::os::raw::c_char;
1010
use std::ptr;
11-
use std::sync::{Mutex, OnceLock};
11+
use std::sync::{Arc, Mutex, OnceLock};
12+
use std::thread;
1213

1314
use rmcp::transport::{SseClientTransport, StreamableHttpClientTransport};
1415
use rmcp::{ServiceExt, RoleClient};
1516
use rmcp::model::{ClientInfo, ClientCapabilities, Implementation};
1617

17-
// Global runtime instance - single runtime for entire process
18-
static GLOBAL_RUNTIME: OnceLock<tokio::runtime::Runtime> = OnceLock::new();
19-
2018
// Global client instance
2119
static GLOBAL_CLIENT: OnceLock<Mutex<Option<McpClient>>> = OnceLock::new();
2220

23-
/// Get or create the global Tokio runtime
24-
fn get_runtime() -> &'static tokio::runtime::Runtime {
25-
GLOBAL_RUNTIME.get_or_init(|| {
26-
// Create a simple single-threaded runtime to avoid Windows issues
27-
tokio::runtime::Builder::new_current_thread()
28-
.enable_all()
29-
.build()
30-
.expect("Failed to create Tokio runtime")
31-
})
21+
// Global background runtime
22+
static BACKGROUND_RUNTIME: OnceLock<(
23+
tokio::sync::mpsc::UnboundedSender<Box<dyn FnOnce() + Send + 'static>>,
24+
thread::JoinHandle<()>
25+
)> = OnceLock::new();
26+
27+
/// Initialize background runtime thread
28+
fn get_background_runtime() -> &'static tokio::sync::mpsc::UnboundedSender<Box<dyn FnOnce() + Send + 'static>> {
29+
let (tx, _handle) = BACKGROUND_RUNTIME.get_or_init(|| {
30+
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<Box<dyn FnOnce() + Send + 'static>>();
31+
32+
let handle = thread::spawn(move || {
33+
let rt = tokio::runtime::Builder::new_current_thread()
34+
.enable_all()
35+
.build()
36+
.expect("Failed to create background runtime");
37+
38+
rt.block_on(async {
39+
while let Some(task) = rx.recv().await {
40+
task();
41+
}
42+
});
43+
});
44+
45+
(tx, handle)
46+
});
47+
tx
48+
}
49+
50+
/// Execute async code using dedicated background thread
51+
fn execute_async_sync<F, R>(future: F) -> R
52+
where
53+
F: std::future::Future<Output = R> + Send + 'static,
54+
R: Send + 'static,
55+
{
56+
let (tx, rx) = std::sync::mpsc::channel();
57+
let sender = get_background_runtime();
58+
59+
let task = Box::new(move || {
60+
let handle = tokio::runtime::Handle::current();
61+
handle.spawn(async move {
62+
let result = future.await;
63+
let _ = tx.send(result);
64+
});
65+
});
66+
67+
sender.send(task).expect("Failed to send task to background thread");
68+
rx.recv().expect("Failed to receive result from background thread")
3269
}
3370

3471
/// Initialize the MCP library
3572
/// Returns 0 on success, non-zero on error
3673
#[no_mangle]
3774
pub extern "C" fn mcp_init() -> i32 {
38-
// Initialize the runtime early
39-
let _ = get_runtime();
4075
0
4176
}
4277

@@ -66,17 +101,14 @@ type RunningClient = rmcp::service::RunningService<RoleClient, ClientInfo>;
66101

67102
/// Opaque handle for MCP client
68103
pub struct McpClient {
69-
service: Mutex<Option<RunningClient>>,
104+
service: Mutex<Option<Arc<RunningClient>>>,
70105
server_url: Mutex<Option<String>>,
71106
}
72107

73108
/// Create a new MCP client
74109
/// Returns NULL on error
75110
#[no_mangle]
76111
pub extern "C" fn mcp_client_new() -> *mut McpClient {
77-
// Initialize runtime early
78-
let _ = get_runtime();
79-
80112
let client = Box::new(McpClient {
81113
service: Mutex::new(None),
82114
server_url: Mutex::new(None),
@@ -158,7 +190,7 @@ pub extern "C" fn mcp_connect(
158190

159191
let (result, maybe_service) = if use_sse {
160192
// Use SSE transport (legacy) with optional custom headers
161-
get_runtime().block_on(async {
193+
execute_async_sync(async move {
162194
// Create HTTP client with optional custom headers
163195
let mut client_builder = reqwest::Client::builder();
164196
if let Some(ref headers_map) = headers_map {
@@ -244,7 +276,7 @@ pub extern "C" fn mcp_connect(
244276
})
245277
} else {
246278
// Use streamable HTTP transport (default) with optional custom headers
247-
get_runtime().block_on(async {
279+
execute_async_sync(async move {
248280
// For Streamable HTTP, we need to extract the Authorization header specifically
249281
// since it has a dedicated field, and we'll use a custom HTTP client for other headers
250282
let auth_header_value = headers_map.as_ref().and_then(|m| m.get("Authorization")).map(|s| s.clone());
@@ -343,7 +375,7 @@ pub extern "C" fn mcp_connect(
343375

344376
// Store service and URL if connection succeeded
345377
if let Some((service, url)) = maybe_service {
346-
*new_client.service.lock().unwrap() = Some(service);
378+
*new_client.service.lock().unwrap() = Some(Arc::new(service));
347379
*new_client.server_url.lock().unwrap() = Some(url);
348380

349381
// Store the client globally
@@ -361,26 +393,29 @@ pub extern "C" fn mcp_connect(
361393
/// Returns: JSON string with tools list (must be freed with mcp_free_string)
362394
#[no_mangle]
363395
pub extern "C" fn mcp_list_tools(_client_ptr: *mut McpClient) -> *mut c_char {
364-
// Get global client
365-
let global_client_guard = GLOBAL_CLIENT.get()
366-
.and_then(|c| Some(c.lock().unwrap()));
367-
let client = match global_client_guard.as_ref().and_then(|g| g.as_ref()) {
368-
Some(c) => c,
369-
None => {
370-
let error = r#"{"error": "Not connected. Call mcp_connect() first"}"#;
371-
return CString::new(error).unwrap_or_default().into_raw();
372-
}
373-
};
374-
375-
let result = get_runtime().block_on(async {
376-
let service_guard = client.service.lock().unwrap();
377-
let service = match service_guard.as_ref() {
378-
Some(s) => s,
396+
// Get global client and extract service in main thread
397+
let service = {
398+
let global_client_guard = GLOBAL_CLIENT.get()
399+
.and_then(|c| Some(c.lock().unwrap()));
400+
let client = match global_client_guard.as_ref().and_then(|g| g.as_ref()) {
401+
Some(c) => c,
379402
None => {
380-
return r#"{"error": "Not connected to server"}"#.to_string();
403+
let error = r#"{"error": "Not connected. Call mcp_connect() first"}"#;
404+
return CString::new(error).unwrap_or_default().into_raw();
381405
}
382406
};
407+
408+
let service_guard = client.service.lock().unwrap();
409+
match service_guard.as_ref() {
410+
Some(s) => Arc::clone(s), // Clone the Arc to move to thread
411+
None => {
412+
let error = r#"{"error": "Not connected to server"}"#;
413+
return CString::new(error).unwrap_or_default().into_raw();
414+
}
415+
}
416+
};
383417

418+
let result = execute_async_sync(async move {
384419
match service.list_tools(Default::default()).await {
385420
Ok(tools_response) => {
386421
let tools_json: Vec<serde_json::Value> = tools_response
@@ -455,26 +490,29 @@ pub extern "C" fn mcp_call_tool(
455490
}
456491
};
457492

458-
// Get global client
459-
let global_client_guard = GLOBAL_CLIENT.get()
460-
.and_then(|c| Some(c.lock().unwrap()));
461-
let client = match global_client_guard.as_ref().and_then(|g| g.as_ref()) {
462-
Some(c) => c,
463-
None => {
464-
let error = r#"{"error": "Not connected. Call mcp_connect() first"}"#;
465-
return CString::new(error).unwrap_or_default().into_raw();
466-
}
467-
};
468-
469-
let result = get_runtime().block_on(async {
470-
let service_guard = client.service.lock().unwrap();
471-
let service = match service_guard.as_ref() {
472-
Some(s) => s,
493+
// Get global client and extract service in main thread
494+
let service = {
495+
let global_client_guard = GLOBAL_CLIENT.get()
496+
.and_then(|c| Some(c.lock().unwrap()));
497+
let client = match global_client_guard.as_ref().and_then(|g| g.as_ref()) {
498+
Some(c) => c,
473499
None => {
474-
return r#"{"error": "Not connected to server"}"#.to_string();
500+
let error = r#"{"error": "Not connected. Call mcp_connect() first"}"#;
501+
return CString::new(error).unwrap_or_default().into_raw();
475502
}
476503
};
504+
505+
let service_guard = client.service.lock().unwrap();
506+
match service_guard.as_ref() {
507+
Some(s) => Arc::clone(s), // Clone the Arc to move to thread
508+
None => {
509+
let error = r#"{"error": "Not connected to server"}"#;
510+
return CString::new(error).unwrap_or_default().into_raw();
511+
}
512+
}
513+
};
477514

515+
let result = execute_async_sync(async move {
478516
let call_param = rmcp::model::CallToolRequestParam {
479517
name: std::borrow::Cow::Owned(tool_name_str),
480518
arguments: arguments.as_object().cloned(),

0 commit comments

Comments
 (0)