Skip to content

Commit ffcaa7a

Browse files
committed
chore: update start_stresam
1 parent bc19630 commit ffcaa7a

File tree

1 file changed

+79
-20
lines changed

1 file changed

+79
-20
lines changed

crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs

Lines changed: 79 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -150,12 +150,12 @@ impl McpServer for ServerRuntime {
150150
match mcp_messages {
151151
ClientMessages::Single(client_message) => {
152152
let transport = transport.clone();
153-
let self_clone = self.clone();
153+
let self = self.clone();
154154
let tx = tx.clone();
155155

156156
// Handle incoming messages in a separate task to avoid blocking the stream.
157157
tokio::spawn(async move {
158-
let result = self_clone.handle_message(client_message, &transport).await;
158+
let result = self.handle_message(client_message, &transport).await;
159159

160160
let send_result: SdkResult<_> = match result {
161161
Ok(result) => {
@@ -423,7 +423,8 @@ impl ServerRuntime {
423423

424424
self.store_transport(stream_id, Arc::new(transport)).await?;
425425

426-
let transport = self.transport_by_stream(stream_id).await?;
426+
let self_clone = self.clone();
427+
let transport = self_clone.transport_by_stream(stream_id).await?;
427428

428429
let (disconnect_tx, mut disconnect_rx) = oneshot::channel::<()>();
429430
let abort_alive_task = transport
@@ -441,40 +442,98 @@ impl ServerRuntime {
441442
transport.consume_string_payload(&payload).await?;
442443
}
443444

445+
// Create a channel to collect results from spawned tasks
446+
let (tx, mut rx) = mpsc::channel(TASK_CHANNEL_CAPACITY);
447+
444448
loop {
445449
tokio::select! {
446450
Some(mcp_messages) = stream.next() =>{
447451

448452
match mcp_messages {
449453
ClientMessages::Single(client_message) => {
450-
let result = self.handle_message(client_message, &transport).await?;
451-
if let Some(result) = result {
452-
transport.send_message(ServerMessages::Single(result), None).await?;
453-
}
454+
let transport = transport.clone();
455+
let self_clone = self.clone();
456+
let tx = tx.clone();
457+
tokio::spawn(async move {
458+
459+
let result = self_clone.handle_message(client_message, &transport).await;
460+
461+
let send_result: SdkResult<_> = match result {
462+
Ok(result) => {
463+
if let Some(result) = result {
464+
transport
465+
.send_message(ServerMessages::Single(result), None)
466+
.map_err(|e| e.into())
467+
.await
468+
} else {
469+
Ok(None)
470+
}
471+
}
472+
Err(error) => {
473+
tracing::error!("Error handling message : {}", error);
474+
Ok(None)
475+
}
476+
};
477+
if let Err(error) = tx.send(send_result).await {
478+
tracing::error!("Failed to send batch result to channel: {}", error);
479+
}
480+
});
454481
}
455482
ClientMessages::Batch(client_messages) => {
456483

457-
let handling_tasks: Vec<_> = client_messages
458-
.into_iter()
459-
.map(|client_message| self.handle_message(client_message, &transport))
460-
.collect();
461-
462-
let results: Vec<_> = try_join_all(handling_tasks).await?;
463-
464-
let results: Vec<_> = results.into_iter().flatten().collect();
465-
466-
467-
if !results.is_empty() {
468-
transport.send_message(ServerMessages::Batch(results), None).await?;
469-
}
484+
let transport = transport.clone();
485+
let self_clone = self_clone.clone();
486+
let tx = tx.clone();
487+
488+
tokio::spawn(async move {
489+
let handling_tasks: Vec<_> = client_messages
490+
.into_iter()
491+
.map(|client_message| self_clone.handle_message(client_message, &transport))
492+
.collect();
493+
494+
let send_result = match try_join_all(handling_tasks).await {
495+
Ok(results) => {
496+
let results: Vec<_> = results.into_iter().flatten().collect();
497+
if !results.is_empty() {
498+
transport.send_message(ServerMessages::Batch(results), None)
499+
.map_err(|e| e.into())
500+
.await
501+
}else {
502+
Ok(None)
503+
}
504+
},
505+
Err(error) => Err(error),
506+
};
507+
if let Err(error) = tx.send(send_result).await {
508+
tracing::error!("Failed to send batch result to channel: {}", error);
509+
}
510+
});
470511
}
471512
}
513+
514+
// Check for results from spawned tasks to propagate errors
515+
while let Ok(result) = rx.try_recv() {
516+
result?; // Propagate errors
517+
}
518+
472519
// close the stream after all messages are sent, unless it is a standalone stream
473520
if !stream_id.eq(DEFAULT_STREAM_ID){
521+
// Drop tx to close the channel and collect remaining results
522+
drop(tx);
523+
while let Some(result) = rx.recv().await {
524+
println!(">>> 3000 {:?} ", result);
525+
526+
result?; // Propagate errors
527+
}
474528
return Ok(());
475529
}
476530
}
477531
_ = &mut disconnect_rx => {
532+
// Drop tx to close the channel and collect remaining results
533+
drop(tx);
534+
while let Some(result) = rx.recv().await {
535+
result?; // Propagate errors
536+
}
478537
self.remove_transport(stream_id).await?;
479538
// Disconnection detected by keep-alive task
480539
return Err(SdkError::connection_closed().into());

0 commit comments

Comments
 (0)