Skip to content

Commit afd3a4d

Browse files
authored
Merge pull request #71 from gleb-chipiga/blocking-call-tool-to-thread-pool
Wrap plugin calls in tokio spawn_blocking tasks
2 parents 667ac6e + aa0b62e commit afd3a4d

File tree

1 file changed

+58
-27
lines changed

1 file changed

+58
-27
lines changed

src/plugins.rs

Lines changed: 58 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,13 @@ use std::str::FromStr;
1111
use serde_json::json;
1212
use sha2::{Digest, Sha256};
1313
use std::collections::HashMap;
14-
use std::sync::Arc;
14+
use std::sync::{Arc, Mutex};
1515
use tokio::sync::RwLock;
1616

1717
#[derive(Clone)]
1818
pub struct PluginService {
1919
config: Config,
20-
plugins: Arc<RwLock<HashMap<String, Plugin>>>,
20+
plugins: Arc<RwLock<HashMap<String, Arc<Mutex<Plugin>>>>>,
2121
tool_plugin_map: Arc<RwLock<HashMap<String, String>>>,
2222
oci_downloader: Arc<OciDownloader>,
2323
}
@@ -122,11 +122,18 @@ impl PluginService {
122122
}
123123
}
124124
}
125-
let mut plugin = Plugin::new(&manifest, [], true).unwrap();
125+
let plugin = Arc::new(Mutex::new(Plugin::new(&manifest, [], true).unwrap()));
126+
let plugin_clone = Arc::clone(&plugin);
126127

127128
// Try to get tool information from the plugin and populate the cache
128-
if let Ok(result) = plugin.call::<&str, &str>("describe", "") {
129-
if let Ok(parsed) = serde_json::from_str::<ListToolsResult>(result) {
129+
let describe_result = tokio::task::spawn_blocking(move || {
130+
let mut plugin = plugin_clone.lock().unwrap();
131+
plugin.call::<&str, String>("describe", "")
132+
})
133+
.await;
134+
135+
if let Ok(Ok(result)) = describe_result {
136+
if let Ok(parsed) = serde_json::from_str::<ListToolsResult>(&result) {
130137
let mut cache = self.tool_plugin_map.write().await;
131138
let skip_tools = plugin_cfg
132139
.runtime_config
@@ -189,7 +196,7 @@ impl ServerHandler for PluginService {
189196
request: CallToolRequestParam,
190197
_context: RequestContext<RoleServer>,
191198
) -> Result<CallToolResult, McpError> {
192-
let mut plugins = self.plugins.write().await;
199+
let plugins = self.plugins.read().await;
193200
let tool_cache = self.tool_plugin_map.read().await;
194201

195202
let tool_name = request.name.clone();
@@ -201,23 +208,35 @@ impl ServerHandler for PluginService {
201208

202209
// Check if the tool exists in the cache
203210
if let Some(plugin_name) = tool_cache.get(&tool_name.to_string()) {
204-
if let Some(plugin) = plugins.get_mut(plugin_name) {
205-
return match plugin.call::<&str, &str>("call", &json_string) {
206-
Ok(result) => match serde_json::from_str::<CallToolResult>(result) {
211+
if let Some(plugin_arc) = plugins.get(plugin_name) {
212+
let plugin_clone = Arc::clone(plugin_arc);
213+
let plugin_name_clone = plugin_name.clone();
214+
215+
let result = tokio::task::spawn_blocking(move || {
216+
let mut plugin = plugin_clone.lock().unwrap();
217+
plugin.call::<&str, String>("call", &json_string)
218+
})
219+
.await;
220+
221+
return match result {
222+
Ok(Ok(result)) => match serde_json::from_str::<CallToolResult>(&result) {
207223
Ok(parsed) => Ok(parsed),
208-
Err(e) => {
209-
return Err(McpError::internal_error(
210-
format!("Failed to deserialize data: {}", e),
211-
None,
212-
));
213-
}
214-
},
215-
Err(e) => {
216-
return Err(McpError::internal_error(
217-
format!("Failed to execute plugin {}: {}", plugin_name, e),
224+
Err(e) => Err(McpError::internal_error(
225+
format!("Failed to deserialize data: {}", e),
218226
None,
219-
));
220-
}
227+
)),
228+
},
229+
Ok(Err(e)) => Err(McpError::internal_error(
230+
format!("Failed to execute plugin {}: {}", plugin_name_clone, e),
231+
None,
232+
)),
233+
Err(e) => Err(McpError::internal_error(
234+
format!(
235+
"Failed to spawn blocking task for plugin {}: {}",
236+
plugin_name_clone, e
237+
),
238+
None,
239+
)),
221240
};
222241
}
223242
}
@@ -231,7 +250,7 @@ impl ServerHandler for PluginService {
231250
_context: RequestContext<RoleServer>,
232251
) -> std::result::Result<ListToolsResult, McpError> {
233252
tracing::info!("got tools/list request {:?}", request);
234-
let mut plugins = self.plugins.write().await;
253+
let plugins = self.plugins.write().await;
235254
let mut tool_cache = self.tool_plugin_map.write().await;
236255

237256
let mut payload = ListToolsResult::default();
@@ -240,10 +259,19 @@ impl ServerHandler for PluginService {
240259
tool_cache.clear();
241260

242261
for plugin_cfg in &self.config.plugins {
243-
if let Some(plugin) = plugins.get_mut(&plugin_cfg.name) {
244-
match plugin.call::<&str, &str>("describe", "") {
245-
Ok(result) => {
246-
if let Ok(parsed) = serde_json::from_str::<ListToolsResult>(result) {
262+
if let Some(plugin_arc) = plugins.get(&plugin_cfg.name) {
263+
let plugin_clone = Arc::clone(plugin_arc);
264+
let plugin_name = plugin_cfg.name.clone();
265+
266+
let result = tokio::task::spawn_blocking(move || {
267+
let mut plugin = plugin_clone.lock().unwrap();
268+
plugin.call::<&str, String>("describe", "")
269+
})
270+
.await;
271+
272+
match result {
273+
Ok(Ok(result)) => {
274+
if let Ok(parsed) = serde_json::from_str::<ListToolsResult>(&result) {
247275
let skip_tools = plugin_cfg
248276
.runtime_config
249277
.as_ref()
@@ -262,8 +290,11 @@ impl ServerHandler for PluginService {
262290
}
263291
}
264292
}
293+
Ok(Err(e)) => {
294+
log::error!("tool {} describe() error: {}", plugin_name, e);
295+
}
265296
Err(e) => {
266-
log::error!("tool {} describe() error: {}", plugin_cfg.name, e);
297+
log::error!("tool {} spawn_blocking error: {}", plugin_name, e);
267298
}
268299
}
269300
}

0 commit comments

Comments
 (0)