Skip to content

Commit 22e7c98

Browse files
committed
fix(windows/lib.rs): revert to the original Tokio individual runtime per client implementation
1 parent fc722ad commit 22e7c98

File tree

1 file changed

+58
-101
lines changed

1 file changed

+58
-101
lines changed

src/lib.rs

Lines changed: 58 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -8,66 +8,15 @@
88
use std::ffi::{CStr, CString};
99
use std::os::raw::c_char;
1010
use std::ptr;
11-
use std::sync::{Arc, Mutex, OnceLock};
12-
use std::thread;
11+
use std::sync::{Mutex, OnceLock};
1312

1413
use rmcp::transport::{SseClientTransport, StreamableHttpClientTransport};
1514
use rmcp::{ServiceExt, RoleClient};
1615
use rmcp::model::{ClientInfo, ClientCapabilities, Implementation};
1716

18-
// Global client instance
17+
// Global client instance - one client per process
1918
static GLOBAL_CLIENT: OnceLock<Mutex<Option<McpClient>>> = OnceLock::new();
2019

21-
// Global background runtime
22-
static BACKGROUND_RUNTIME: OnceLock<(
23-
tokio::sync::mpsc::UnboundedSender<Box<dyn FnOnce() + Send + 'static>>,
24-
thread::JoinHandle<()>
25-
)> = OnceLock::new();
26-
27-
/// Initialize background runtime thread
28-
fn get_background_runtime() -> &'static tokio::sync::mpsc::UnboundedSender<Box<dyn FnOnce() + Send + 'static>> {
29-
let (tx, _handle) = BACKGROUND_RUNTIME.get_or_init(|| {
30-
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<Box<dyn FnOnce() + Send + 'static>>();
31-
32-
let handle = thread::spawn(move || {
33-
let rt = tokio::runtime::Builder::new_current_thread()
34-
.enable_all()
35-
.build()
36-
.expect("Failed to create background runtime");
37-
38-
rt.block_on(async {
39-
while let Some(task) = rx.recv().await {
40-
task();
41-
}
42-
});
43-
});
44-
45-
(tx, handle)
46-
});
47-
tx
48-
}
49-
50-
/// Execute async code using dedicated background thread
51-
fn execute_async_sync<F, R>(future: F) -> R
52-
where
53-
F: std::future::Future<Output = R> + Send + 'static,
54-
R: Send + 'static,
55-
{
56-
let (tx, rx) = std::sync::mpsc::channel();
57-
let sender = get_background_runtime();
58-
59-
let task = Box::new(move || {
60-
let handle = tokio::runtime::Handle::current();
61-
handle.spawn(async move {
62-
let result = future.await;
63-
let _ = tx.send(result);
64-
});
65-
});
66-
67-
sender.send(task).expect("Failed to send task to background thread");
68-
rx.recv().expect("Failed to receive result from background thread")
69-
}
70-
7120
/// Initialize the MCP library
7221
/// Returns 0 on success, non-zero on error
7322
#[no_mangle]
@@ -101,19 +50,26 @@ type RunningClient = rmcp::service::RunningService<RoleClient, ClientInfo>;
10150

10251
/// Opaque handle for MCP client
10352
pub struct McpClient {
104-
service: Mutex<Option<Arc<RunningClient>>>,
53+
runtime: tokio::runtime::Runtime,
54+
service: Mutex<Option<RunningClient>>,
10555
server_url: Mutex<Option<String>>,
10656
}
10757

10858
/// Create a new MCP client
10959
/// Returns NULL on error
11060
#[no_mangle]
11161
pub extern "C" fn mcp_client_new() -> *mut McpClient {
112-
let client = Box::new(McpClient {
113-
service: Mutex::new(None),
114-
server_url: Mutex::new(None),
115-
});
116-
Box::into_raw(client)
62+
match tokio::runtime::Runtime::new() {
63+
Ok(runtime) => {
64+
let client = Box::new(McpClient {
65+
runtime,
66+
service: Mutex::new(None),
67+
server_url: Mutex::new(None),
68+
});
69+
Box::into_raw(client)
70+
}
71+
Err(_) => ptr::null_mut(),
72+
}
11773
}
11874

11975
/// Free an MCP client
@@ -180,8 +136,15 @@ pub extern "C" fn mcp_connect(
180136
}
181137
};
182138

183-
// Create a new McpClient
139+
// Create a new McpClient with runtime
184140
let new_client = McpClient {
141+
runtime: match tokio::runtime::Runtime::new() {
142+
Ok(r) => r,
143+
Err(e) => {
144+
let error = format!(r#"{{"error": "Failed to create runtime: {}"}}"#, e);
145+
return CString::new(error).unwrap_or_default().into_raw();
146+
}
147+
},
185148
service: Mutex::new(None),
186149
server_url: Mutex::new(None),
187150
};
@@ -190,7 +153,7 @@ pub extern "C" fn mcp_connect(
190153

191154
let (result, maybe_service) = if use_sse {
192155
// Use SSE transport (legacy) with optional custom headers
193-
execute_async_sync(async move {
156+
new_client.runtime.block_on(async {
194157
// Create HTTP client with optional custom headers
195158
let mut client_builder = reqwest::Client::builder();
196159
if let Some(ref headers_map) = headers_map {
@@ -276,7 +239,7 @@ pub extern "C" fn mcp_connect(
276239
})
277240
} else {
278241
// Use streamable HTTP transport (default) with optional custom headers
279-
execute_async_sync(async move {
242+
new_client.runtime.block_on(async {
280243
// For Streamable HTTP, we need to extract the Authorization header specifically
281244
// since it has a dedicated field, and we'll use a custom HTTP client for other headers
282245
let auth_header_value = headers_map.as_ref().and_then(|m| m.get("Authorization")).map(|s| s.clone());
@@ -375,7 +338,7 @@ pub extern "C" fn mcp_connect(
375338

376339
// Store service and URL if connection succeeded
377340
if let Some((service, url)) = maybe_service {
378-
*new_client.service.lock().unwrap() = Some(Arc::new(service));
341+
*new_client.service.lock().unwrap() = Some(service);
379342
*new_client.server_url.lock().unwrap() = Some(url);
380343

381344
// Store the client globally
@@ -393,29 +356,26 @@ pub extern "C" fn mcp_connect(
393356
/// Returns: JSON string with tools list (must be freed with mcp_free_string)
394357
#[no_mangle]
395358
pub extern "C" fn mcp_list_tools(_client_ptr: *mut McpClient) -> *mut c_char {
396-
// Get global client and extract service in main thread
397-
let service = {
398-
let global_client_guard = GLOBAL_CLIENT.get()
399-
.and_then(|c| Some(c.lock().unwrap()));
400-
let client = match global_client_guard.as_ref().and_then(|g| g.as_ref()) {
401-
Some(c) => c,
402-
None => {
403-
let error = r#"{"error": "Not connected. Call mcp_connect() first"}"#;
404-
return CString::new(error).unwrap_or_default().into_raw();
405-
}
406-
};
407-
359+
// Get global client
360+
let global_client_guard = GLOBAL_CLIENT.get()
361+
.and_then(|c| Some(c.lock().unwrap()));
362+
let client = match global_client_guard.as_ref().and_then(|g| g.as_ref()) {
363+
Some(c) => c,
364+
None => {
365+
let error = r#"{"error": "Not connected. Call mcp_connect() first"}"#;
366+
return CString::new(error).unwrap_or_default().into_raw();
367+
}
368+
};
369+
370+
let result = client.runtime.block_on(async {
408371
let service_guard = client.service.lock().unwrap();
409-
match service_guard.as_ref() {
410-
Some(s) => Arc::clone(s), // Clone the Arc to move to thread
372+
let service = match service_guard.as_ref() {
373+
Some(s) => s,
411374
None => {
412-
let error = r#"{"error": "Not connected to server"}"#;
413-
return CString::new(error).unwrap_or_default().into_raw();
375+
return r#"{"error": "Not connected to server"}"#.to_string();
414376
}
415-
}
416-
};
377+
};
417378

418-
let result = execute_async_sync(async move {
419379
match service.list_tools(Default::default()).await {
420380
Ok(tools_response) => {
421381
let tools_json: Vec<serde_json::Value> = tools_response
@@ -490,29 +450,26 @@ pub extern "C" fn mcp_call_tool(
490450
}
491451
};
492452

493-
// Get global client and extract service in main thread
494-
let service = {
495-
let global_client_guard = GLOBAL_CLIENT.get()
496-
.and_then(|c| Some(c.lock().unwrap()));
497-
let client = match global_client_guard.as_ref().and_then(|g| g.as_ref()) {
498-
Some(c) => c,
499-
None => {
500-
let error = r#"{"error": "Not connected. Call mcp_connect() first"}"#;
501-
return CString::new(error).unwrap_or_default().into_raw();
502-
}
503-
};
504-
453+
// Get global client
454+
let global_client_guard = GLOBAL_CLIENT.get()
455+
.and_then(|c| Some(c.lock().unwrap()));
456+
let client = match global_client_guard.as_ref().and_then(|g| g.as_ref()) {
457+
Some(c) => c,
458+
None => {
459+
let error = r#"{"error": "Not connected. Call mcp_connect() first"}"#;
460+
return CString::new(error).unwrap_or_default().into_raw();
461+
}
462+
};
463+
464+
let result = client.runtime.block_on(async {
505465
let service_guard = client.service.lock().unwrap();
506-
match service_guard.as_ref() {
507-
Some(s) => Arc::clone(s), // Clone the Arc to move to thread
466+
let service = match service_guard.as_ref() {
467+
Some(s) => s,
508468
None => {
509-
let error = r#"{"error": "Not connected to server"}"#;
510-
return CString::new(error).unwrap_or_default().into_raw();
469+
return r#"{"error": "Not connected to server"}"#.to_string();
511470
}
512-
}
513-
};
471+
};
514472

515-
let result = execute_async_sync(async move {
516473
let call_param = rmcp::model::CallToolRequestParam {
517474
name: std::borrow::Cow::Owned(tool_name_str),
518475
arguments: arguments.as_object().cloned(),

0 commit comments

Comments
 (0)