Skip to content

Commit eeaf92e

Browse files
committed
fix(lib): simplify async task spawning on runtime for better compatibility
1 parent b5659ed commit eeaf92e

File tree

1 file changed

+60
-68
lines changed

1 file changed

+60
-68
lines changed

src/lib.rs

Lines changed: 60 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -797,42 +797,38 @@ pub extern "C" fn mcp_list_tools_init() -> usize {
797797
// Get the global client
798798
let client_mutex = GLOBAL_CLIENT.get_or_init(|| Mutex::new(None));
799799

800-
// Spawn the async task using spawn_blocking + block_on for Windows compatibility
800+
// Spawn the async task on the runtime
801801
{
802802
let client_opt = client_mutex.lock().unwrap();
803803
if let Some(client) = client_opt.as_ref() {
804804
// Clone the Arc to share the service across async boundaries
805805
let service_arc = client.service.clone();
806806
let runtime_handle = client.runtime.handle().clone();
807807

808-
// Use spawn_blocking to run a blocking task that calls block_on
809-
// This ensures the reactor is running when async code executes (Windows fix)
810-
std::thread::spawn(move || {
811-
let _guard = runtime_handle.enter(); // Enter runtime context first
812-
runtime_handle.block_on(async move {
813-
let service_guard = service_arc.lock().await;
814-
if let Some(service) = service_guard.as_ref() {
815-
match service.list_tools(None).await {
816-
Ok(response) => {
817-
// Send each tool as a separate chunk
818-
for tool in response.tools {
819-
if let Ok(tool_json) = serde_json::to_value(&tool) {
820-
let _ = tx.send(StreamChunk::Tool(tool_json));
821-
}
808+
// Spawn directly on the runtime (works with both multi-threaded and current-thread runtimes)
809+
runtime_handle.spawn(async move {
810+
let service_guard = service_arc.lock().await;
811+
if let Some(service) = service_guard.as_ref() {
812+
match service.list_tools(None).await {
813+
Ok(response) => {
814+
// Send each tool as a separate chunk
815+
for tool in response.tools {
816+
if let Ok(tool_json) = serde_json::to_value(&tool) {
817+
let _ = tx.send(StreamChunk::Tool(tool_json));
822818
}
823-
let _ = tx.send(StreamChunk::Done);
824-
}
825-
Err(e) => {
826-
let _ = tx.send(StreamChunk::Error(format!("Failed to list tools: {}", e)));
827-
let _ = tx.send(StreamChunk::Done);
828819
}
820+
let _ = tx.send(StreamChunk::Done);
821+
}
822+
Err(e) => {
823+
let _ = tx.send(StreamChunk::Error(format!("Failed to list tools: {}", e)));
824+
let _ = tx.send(StreamChunk::Done);
829825
}
830-
} else {
831-
// No service connected
832-
let _ = tx.send(StreamChunk::Error("Not connected. Call mcp_connect() first".to_string()));
833-
let _ = tx.send(StreamChunk::Done);
834826
}
835-
});
827+
} else {
828+
// No service connected
829+
let _ = tx.send(StreamChunk::Error("Not connected. Call mcp_connect() first".to_string()));
830+
let _ = tx.send(StreamChunk::Done);
831+
}
836832
});
837833
} else {
838834
// No client initialized
@@ -887,65 +883,61 @@ pub extern "C" fn mcp_call_tool_init(tool_name: *const c_char, arguments: *const
887883
// Get the global client
888884
let client_mutex = GLOBAL_CLIENT.get_or_init(|| Mutex::new(None));
889885

890-
// Spawn the async task using spawn_blocking + block_on for Windows compatibility
886+
// Spawn the async task on the runtime
891887
{
892888
let client_opt = client_mutex.lock().unwrap();
893889
if let Some(client) = client_opt.as_ref() {
894890
let service_arc = client.service.clone();
895891
let runtime_handle = client.runtime.handle().clone();
896892

897-
// Use spawn_blocking to run a blocking task that calls block_on
898-
// This ensures the reactor is running when async code executes (Windows fix)
899-
std::thread::spawn(move || {
900-
let _guard = runtime_handle.enter(); // Enter runtime context first
901-
runtime_handle.block_on(async move {
902-
let service_guard = service_arc.lock().await;
903-
if let Some(service) = service_guard.as_ref() {
904-
// Parse arguments
905-
let arguments_json: serde_json::Value = match serde_json::from_str(&arguments_str) {
906-
Ok(v) => v,
907-
Err(e) => {
908-
let _ = tx.send(StreamChunk::Error(format!("Invalid JSON arguments: {}", e)));
909-
let _ = tx.send(StreamChunk::Done);
910-
return;
911-
}
912-
};
893+
// Spawn directly on the runtime (works with both multi-threaded and current-thread runtimes)
894+
runtime_handle.spawn(async move {
895+
let service_guard = service_arc.lock().await;
896+
if let Some(service) = service_guard.as_ref() {
897+
// Parse arguments
898+
let arguments_json: serde_json::Value = match serde_json::from_str(&arguments_str) {
899+
Ok(v) => v,
900+
Err(e) => {
901+
let _ = tx.send(StreamChunk::Error(format!("Invalid JSON arguments: {}", e)));
902+
let _ = tx.send(StreamChunk::Done);
903+
return;
904+
}
905+
};
913906

914-
// Create the call tool parameter
915-
let call_param = rmcp::model::CallToolRequestParam {
916-
name: std::borrow::Cow::Owned(tool_name_str),
917-
arguments: arguments_json.as_object().cloned(),
918-
};
907+
// Create the call tool parameter
908+
let call_param = rmcp::model::CallToolRequestParam {
909+
name: std::borrow::Cow::Owned(tool_name_str),
910+
arguments: arguments_json.as_object().cloned(),
911+
};
919912

920-
// Call the tool
921-
match service.call_tool(call_param).await {
922-
Ok(result) => {
923-
// Serialize the result to JSON and extract text content
924-
if let Ok(result_json) = serde_json::to_value(&result) {
925-
if let Some(content_array) = result_json.get("content").and_then(|v| v.as_array()) {
926-
for item in content_array {
927-
if let Some(item_type) = item.get("type").and_then(|v| v.as_str()) {
928-
if item_type == "text" {
929-
if let Some(text) = item.get("text").and_then(|v| v.as_str()) {
930-
let _ = tx.send(StreamChunk::Text(text.to_string()));
931-
}
913+
// Call the tool
914+
match service.call_tool(call_param).await {
915+
Ok(result) => {
916+
// Serialize the result to JSON and extract text content
917+
if let Ok(result_json) = serde_json::to_value(&result) {
918+
if let Some(content_array) = result_json.get("content").and_then(|v| v.as_array()) {
919+
for item in content_array {
920+
if let Some(item_type) = item.get("type").and_then(|v| v.as_str()) {
921+
if item_type == "text" {
922+
if let Some(text) = item.get("text").and_then(|v| v.as_str()) {
923+
let _ = tx.send(StreamChunk::Text(text.to_string()));
932924
}
933925
}
934926
}
935927
}
936928
}
937-
let _ = tx.send(StreamChunk::Done);
938-
}
939-
Err(e) => {
940-
let _ = tx.send(StreamChunk::Error(format!("Failed to call tool: {}", e)));
941-
let _ = tx.send(StreamChunk::Done);
942929
}
930+
let _ = tx.send(StreamChunk::Done);
931+
}
932+
Err(e) => {
933+
let _ = tx.send(StreamChunk::Error(format!("Failed to call tool: {}", e)));
934+
let _ = tx.send(StreamChunk::Done);
943935
}
944-
} else {
945-
let _ = tx.send(StreamChunk::Error("Not connected. Call mcp_connect() first".to_string()));
946-
let _ = tx.send(StreamChunk::Done);
947936
}
948-
});
937+
} else {
938+
let _ = tx.send(StreamChunk::Error("Not connected. Call mcp_connect() first".to_string()));
939+
let _ = tx.send(StreamChunk::Done);
940+
}
949941
});
950942
} else {
951943
let _ = tx.send(StreamChunk::Error("Client not initialized".to_string()));

0 commit comments

Comments
 (0)