Skip to content

Commit fe8ee1e

Browse files
committed
fix(lib): simplify global Tokio runtime management and improve Windows compatibility
1 parent 8dd20c9 commit fe8ee1e

File tree

1 file changed

+32
-35
lines changed

1 file changed

+32
-35
lines changed

src/lib.rs

Lines changed: 32 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,6 @@ use serde_json;
1818
// Global runtime - one runtime per process that stays alive
1919
static GLOBAL_RUNTIME: OnceLock<tokio::runtime::Runtime> = OnceLock::new();
2020

21-
// Flag to track if runtime has been entered (Windows fix for nested spawns)
22-
static RUNTIME_ENTERED: OnceLock<()> = OnceLock::new();
23-
2421
// Global client instance - one client per process
2522
static GLOBAL_CLIENT: OnceLock<Mutex<Option<McpClient>>> = OnceLock::new();
2623

@@ -247,30 +244,24 @@ pub struct McpClient {
247244
server_url: Mutex<Option<String>>,
248245
}
249246

250-
/// Get or create the global Tokio runtime and ensure it's running
247+
/// Get or create the global Tokio runtime
251248
fn get_runtime() -> &'static tokio::runtime::Runtime {
252-
let runtime = GLOBAL_RUNTIME.get_or_init(|| {
253-
tokio::runtime::Runtime::new().expect("Failed to create Tokio runtime")
254-
});
255-
256-
// On first call, spawn a background thread to keep the runtime active forever (Windows fix)
257-
RUNTIME_ENTERED.get_or_init(|| {
258-
// SAFETY: The runtime is stored in a static and lives for the entire program duration
259-
unsafe {
260-
let runtime_ptr: *const tokio::runtime::Runtime = runtime;
261-
let runtime_ref: &'static tokio::runtime::Runtime = &*runtime_ptr;
262-
263-
// Spawn a background thread that keeps the runtime active
264-
std::thread::spawn(move || {
265-
runtime_ref.block_on(async {
266-
// Keep this future running forever to maintain runtime context
267-
std::future::pending::<()>().await
268-
});
269-
});
270-
}
271-
});
249+
GLOBAL_RUNTIME.get_or_init(|| {
250+
tokio::runtime::Builder::new_multi_thread()
251+
.enable_all()
252+
.build()
253+
.expect("Failed to create Tokio runtime")
254+
})
255+
}
272256

273-
runtime
257+
/// Run an async block with proper runtime context (Windows compatibility)
258+
fn run_async<F, T>(future: F) -> T
259+
where
260+
F: std::future::Future<Output = T>,
261+
{
262+
let runtime = get_runtime();
263+
let _enter = runtime.enter();
264+
runtime.block_on(future)
274265
}
275266

276267
/// Create a new MCP client
@@ -364,7 +355,7 @@ pub extern "C" fn mcp_connect(
364355

365356
let (result, maybe_service) = if use_sse {
366357
// Use SSE transport (legacy) with optional custom headers
367-
get_runtime().block_on(async {
358+
run_async(async {
368359
// Create HTTP client with optional custom headers
369360
let mut client_builder = reqwest::Client::builder();
370361
if let Some(ref headers_map) = headers_map {
@@ -450,7 +441,7 @@ pub extern "C" fn mcp_connect(
450441
})
451442
} else {
452443
// Use streamable HTTP transport (default) with optional custom headers
453-
get_runtime().block_on(async {
444+
run_async(async {
454445
// For Streamable HTTP, we need to extract the Authorization header specifically
455446
// since it has a dedicated field, and we'll use a custom HTTP client for other headers
456447
let auth_header_value = headers_map.as_ref().and_then(|m| m.get("Authorization")).map(|s| s.clone());
@@ -549,7 +540,7 @@ pub extern "C" fn mcp_connect(
549540

550541
// Store service and URL if connection succeeded
551542
if let Some((service, url)) = maybe_service {
552-
get_runtime().block_on(async {
543+
run_async(async {
553544
*new_client.service.lock().await = Some(service);
554545
});
555546
*new_client.server_url.lock().unwrap() = Some(url);
@@ -628,7 +619,7 @@ pub extern "C" fn mcp_list_tools_json(_client_ptr: *mut McpClient) -> *mut c_cha
628619
}
629620
};
630621

631-
let result = get_runtime().block_on(async {
622+
let result = run_async(async {
632623
let service_guard = client.service.lock().await;
633624
let service = match service_guard.as_ref() {
634625
Some(s) => s,
@@ -722,7 +713,7 @@ pub extern "C" fn mcp_call_tool_json(
722713
}
723714
};
724715

725-
let result = get_runtime().block_on(async {
716+
let result = run_async(async {
726717
let service_guard = client.service.lock().await;
727718
let service = match service_guard.as_ref() {
728719
Some(s) => s,
@@ -811,8 +802,11 @@ pub extern "C" fn mcp_list_tools_init() -> usize {
811802
if let Some(client) = client_opt.as_ref() {
812803
// Clone the Arc to share the service across async boundaries
813804
let service_arc = client.service.clone();
805+
// Enter runtime context before spawning (Windows fix)
806+
let runtime = get_runtime();
807+
let _enter = runtime.enter();
814808
// Spawn async task directly on the global runtime (like the official rmcp examples)
815-
get_runtime().spawn(async move {
809+
runtime.spawn(async move {
816810
let service_guard = service_arc.lock().await;
817811
if let Some(service) = service_guard.as_ref() {
818812
match service.list_tools(None).await {
@@ -894,8 +888,11 @@ pub extern "C" fn mcp_call_tool_init(tool_name: *const c_char, arguments: *const
894888
let client_opt = client_mutex.lock().unwrap();
895889
if let Some(client) = client_opt.as_ref() {
896890
let service_arc = client.service.clone();
891+
// Enter runtime context before spawning (Windows fix)
892+
let runtime = get_runtime();
893+
let _enter = runtime.enter();
897894
// Spawn async task directly on the global runtime (like the official rmcp examples)
898-
get_runtime().spawn(async move {
895+
runtime.spawn(async move {
899896
let service_guard = service_arc.lock().await;
900897
if let Some(service) = service_guard.as_ref() {
901898
// Parse arguments
@@ -990,12 +987,12 @@ pub extern "C" fn mcp_stream_wait(stream_id: usize, timeout_ms: u64) -> *mut Str
990987
let client_mutex = GLOBAL_CLIENT.get_or_init(|| Mutex::new(None));
991988
let client_opt = client_mutex.lock().unwrap();
992989

993-
if let Some(client) = client_opt.as_ref() {
990+
if let Some(_client) = client_opt.as_ref() {
994991
// Get mutable reference to receiver outside of async block
995992
let mut channels = STREAM_CHANNELS.lock().unwrap();
996993
if let Some(rx) = channels.get_mut(&stream_id) {
997-
// We need to use block_on for the async recv operation
998-
get_runtime().block_on(async {
994+
// We need to use run_async for the async recv operation
995+
run_async(async {
999996
let timeout = tokio::time::Duration::from_millis(timeout_ms);
1000997
match tokio::time::timeout(timeout, rx.recv()).await {
1001998
Ok(Some(chunk)) => {

0 commit comments

Comments
 (0)