Skip to content

Commit 0476212

Browse files
committed
fix(lib): refactor MCP client to use a global Tokio runtime for improved resource management
1 parent 8865be1 commit 0476212

File tree

1 file changed

+32
-34
lines changed

1 file changed

+32
-34
lines changed

src/lib.rs

Lines changed: 32 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ use rmcp::{ServiceExt, RoleClient};
1515
use rmcp::model::{ClientInfo, ClientCapabilities, Implementation};
1616
use serde_json;
1717

18+
// Global runtime - one runtime per process that stays alive
19+
static GLOBAL_RUNTIME: OnceLock<tokio::runtime::Runtime> = OnceLock::new();
20+
1821
// Global client instance - one client per process
1922
static GLOBAL_CLIENT: OnceLock<Mutex<Option<McpClient>>> = OnceLock::new();
2023

@@ -237,26 +240,29 @@ type RunningClient = rmcp::service::RunningService<RoleClient, ClientInfo>;
237240

238241
/// Opaque handle for MCP client
239242
pub struct McpClient {
240-
runtime: tokio::runtime::Runtime,
241243
service: Arc<TokioMutex<Option<RunningClient>>>,
242244
server_url: Mutex<Option<String>>,
243245
}
244246

247+
/// Get or create the global Tokio runtime
248+
fn get_runtime() -> &'static tokio::runtime::Runtime {
249+
GLOBAL_RUNTIME.get_or_init(|| {
250+
tokio::runtime::Runtime::new().expect("Failed to create Tokio runtime")
251+
})
252+
}
253+
245254
/// Create a new MCP client
246255
/// Returns NULL on error
247256
#[no_mangle]
248257
pub extern "C" fn mcp_client_new() -> *mut McpClient {
249-
match tokio::runtime::Runtime::new() {
250-
Ok(runtime) => {
251-
let client = Box::new(McpClient {
252-
runtime,
253-
service: Arc::new(TokioMutex::new(None)),
254-
server_url: Mutex::new(None),
255-
});
256-
Box::into_raw(client)
257-
}
258-
Err(_) => ptr::null_mut(),
259-
}
258+
// Ensure runtime is initialized
259+
let _ = get_runtime();
260+
261+
let client = Box::new(McpClient {
262+
service: Arc::new(TokioMutex::new(None)),
263+
server_url: Mutex::new(None),
264+
});
265+
Box::into_raw(client)
260266
}
261267

262268
/// Free an MCP client
@@ -323,15 +329,11 @@ pub extern "C" fn mcp_connect(
323329
}
324330
};
325331

326-
// Create a new McpClient with runtime
332+
// Ensure global runtime is initialized
333+
let _ = get_runtime();
334+
335+
// Create a new McpClient
327336
let new_client = McpClient {
328-
runtime: match tokio::runtime::Runtime::new() {
329-
Ok(r) => r,
330-
Err(e) => {
331-
let error = format!(r#"{{"error": "Failed to create runtime: {}"}}"#, e);
332-
return CString::new(error).unwrap_or_default().into_raw();
333-
}
334-
},
335337
service: Arc::new(TokioMutex::new(None)),
336338
server_url: Mutex::new(None),
337339
};
@@ -340,7 +342,7 @@ pub extern "C" fn mcp_connect(
340342

341343
let (result, maybe_service) = if use_sse {
342344
// Use SSE transport (legacy) with optional custom headers
343-
new_client.runtime.block_on(async {
345+
get_runtime().block_on(async {
344346
// Create HTTP client with optional custom headers
345347
let mut client_builder = reqwest::Client::builder();
346348
if let Some(ref headers_map) = headers_map {
@@ -426,7 +428,7 @@ pub extern "C" fn mcp_connect(
426428
})
427429
} else {
428430
// Use streamable HTTP transport (default) with optional custom headers
429-
new_client.runtime.block_on(async {
431+
get_runtime().block_on(async {
430432
// For Streamable HTTP, we need to extract the Authorization header specifically
431433
// since it has a dedicated field, and we'll use a custom HTTP client for other headers
432434
let auth_header_value = headers_map.as_ref().and_then(|m| m.get("Authorization")).map(|s| s.clone());
@@ -525,7 +527,7 @@ pub extern "C" fn mcp_connect(
525527

526528
// Store service and URL if connection succeeded
527529
if let Some((service, url)) = maybe_service {
528-
new_client.runtime.block_on(async {
530+
get_runtime().block_on(async {
529531
*new_client.service.lock().await = Some(service);
530532
});
531533
*new_client.server_url.lock().unwrap() = Some(url);
@@ -604,7 +606,7 @@ pub extern "C" fn mcp_list_tools_json(_client_ptr: *mut McpClient) -> *mut c_cha
604606
}
605607
};
606608

607-
let result = client.runtime.block_on(async {
609+
let result = get_runtime().block_on(async {
608610
let service_guard = client.service.lock().await;
609611
let service = match service_guard.as_ref() {
610612
Some(s) => s,
@@ -698,7 +700,7 @@ pub extern "C" fn mcp_call_tool_json(
698700
}
699701
};
700702

701-
let result = client.runtime.block_on(async {
703+
let result = get_runtime().block_on(async {
702704
let service_guard = client.service.lock().await;
703705
let service = match service_guard.as_ref() {
704706
Some(s) => s,
@@ -787,10 +789,8 @@ pub extern "C" fn mcp_list_tools_init() -> usize {
787789
if let Some(client) = client_opt.as_ref() {
788790
// Clone the Arc to share the service across async boundaries
789791
let service_arc = client.service.clone();
790-
let runtime_handle = client.runtime.handle().clone();
791-
792-
// Spawn async task directly on the runtime (like the official rmcp examples)
793-
runtime_handle.spawn(async move {
792+
// Spawn async task directly on the global runtime (like the official rmcp examples)
793+
get_runtime().spawn(async move {
794794
let service_guard = service_arc.lock().await;
795795
if let Some(service) = service_guard.as_ref() {
796796
match service.list_tools(None).await {
@@ -872,10 +872,8 @@ pub extern "C" fn mcp_call_tool_init(tool_name: *const c_char, arguments: *const
872872
let client_opt = client_mutex.lock().unwrap();
873873
if let Some(client) = client_opt.as_ref() {
874874
let service_arc = client.service.clone();
875-
let runtime_handle = client.runtime.handle().clone();
876-
877-
// Spawn async task directly on the runtime (like the official rmcp examples)
878-
runtime_handle.spawn(async move {
875+
// Spawn async task directly on the global runtime (like the official rmcp examples)
876+
get_runtime().spawn(async move {
879877
let service_guard = service_arc.lock().await;
880878
if let Some(service) = service_guard.as_ref() {
881879
// Parse arguments
@@ -975,7 +973,7 @@ pub extern "C" fn mcp_stream_wait(stream_id: usize, timeout_ms: u64) -> *mut Str
975973
let mut channels = STREAM_CHANNELS.lock().unwrap();
976974
if let Some(rx) = channels.get_mut(&stream_id) {
977975
// We need to use block_on for the async recv operation
978-
client.runtime.block_on(async {
976+
get_runtime().block_on(async {
979977
let timeout = tokio::time::Duration::from_millis(timeout_ms);
980978
match tokio::time::timeout(timeout, rx.recv()).await {
981979
Ok(Some(chunk)) => {

0 commit comments

Comments
 (0)