diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9fd45e0..4f3a988 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -18,5 +18,7 @@ jobs: - uses: actions/checkout@v4 - name: Build run: cargo build --verbose + - name: Build examples + run: cargo build --examples - name: Run tests run: cargo test --verbose diff --git a/Cargo.lock b/Cargo.lock index a0cd56f..05aeb13 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -622,7 +622,7 @@ dependencies = [ "tokio", "tokio-rustls", "tower-service", - "webpki-roots", + "webpki-roots 0.26.11", ] [[package]] @@ -902,6 +902,12 @@ version = "0.4.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94" +[[package]] +name = "lru-slab" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" + [[package]] name = "matchit" version = "0.8.4" @@ -977,6 +983,18 @@ dependencies = [ "tempfile", ] +[[package]] +name = "nix" +version = "0.30.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74523f3a35e05aba87a1d978330aef40f67b0304ac79c1c00b294c9830543db6" +dependencies = [ + "bitflags", + "cfg-if", + "cfg_aliases", + "libc", +] + [[package]] name = "nu-ansi-term" version = "0.46.0" @@ -1142,11 +1160,25 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "process-wrap" +version = "8.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3ef4f2f0422f23a82ec9f628ea2acd12871c81a9362b02c43c1aa86acfc3ba1" +dependencies = [ + "futures", + "indexmap", + "nix", + "tokio", + "tracing", + "windows", +] + [[package]] name = "quinn" -version = "0.11.7" +version = "0.11.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3bd15a6f2967aef83887dcb9fec0014580467e33720d073560cf015a5683012" +checksum = "626214629cda6781b6dc1d316ba307189c85ba657213ce642d9c77670f8202c8" dependencies = [ "bytes", "cfg_aliases", @@ -1164,12 +1196,13 @@ dependencies = [ [[package]] name = "quinn-proto" -version = "0.11.11" +version = "0.11.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bcbafbbdbb0f638fe3f35f3c56739f77a8a1d070cb25603226c83339b391472b" +checksum = "49df843a9161c85bb8aae55f101bc0bac8bcafd637a620d9122fd7e0b2f7422e" dependencies = [ "bytes", "getrandom 0.3.2", + "lru-slab", "rand", "ring", "rustc-hash", @@ -1193,7 +1226,7 @@ dependencies = [ "once_cell", "socket2", "tracing", - "windows-sys 0.59.0", + "windows-sys 0.52.0", ] [[package]] @@ -1296,7 +1329,7 @@ dependencies = [ "wasm-bindgen-futures", "wasm-streams", "web-sys", - "webpki-roots", + "webpki-roots 0.26.11", "windows-registry", ] @@ -1317,14 +1350,16 @@ dependencies = [ [[package]] name = "rmcp" version = "0.1.5" -source = "git+https://github.com/modelcontextprotocol/rust-sdk.git?rev=afb8a905e54b87c69e880f9377cfe8424aa6f13b#afb8a905e54b87c69e880f9377cfe8424aa6f13b" +source = "git+https://github.com/modelcontextprotocol/rust-sdk.git?rev=076dc2c2cd8910bee56bae13f29bbcff8c279666#076dc2c2cd8910bee56bae13f29bbcff8c279666" dependencies = [ "axum", "base64 0.21.7", "chrono", "futures", + "http", "paste", "pin-project-lite", + "process-wrap", "rand", "reqwest", "rmcp-macros", @@ -1337,16 +1372,17 @@ dependencies = [ "tokio-stream", "tokio-util", "tracing", - "url", + "uuid", ] [[package]] name = "rmcp-macros" version = "0.1.5" -source = "git+https://github.com/modelcontextprotocol/rust-sdk.git?rev=afb8a905e54b87c69e880f9377cfe8424aa6f13b#afb8a905e54b87c69e880f9377cfe8424aa6f13b" +source = "git+https://github.com/modelcontextprotocol/rust-sdk.git?rev=076dc2c2cd8910bee56bae13f29bbcff8c279666#076dc2c2cd8910bee56bae13f29bbcff8c279666" dependencies = [ "proc-macro2", "quote", + "serde_json", "syn", ] @@ -1445,6 +1481,7 @@ version = "0.8.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3fbf2ae1b8bc8e02df939598064d22402220cd5bbcca1c76f7d6a310974d5615" dependencies = [ + "chrono", "dyn-clone", "schemars_derive", "serde", @@ -2115,9 +2152,18 @@ dependencies = [ [[package]] name = "webpki-roots" -version = "0.26.10" +version = "0.26.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "521bc38abb08001b01866da9f51eb7c5d647a19260e00054a8c7fd5f9e57f7a9" +dependencies = [ + "webpki-roots 1.0.0", +] + +[[package]] +name = "webpki-roots" +version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37493cadf42a2a939ed404698ded7fb378bf301b5011f973361779a3a74f8c93" +checksum = "2853738d1cc4f2da3a225c18ec6c3721abb31961096e9dbf5ab35fa88b19cfdb" dependencies = [ "rustls-pki-types", ] @@ -2144,6 +2190,28 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows" +version = "0.61.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c5ee8f3d025738cb02bad7868bbb5f8a6327501e870bf51f1b455b0a2454a419" +dependencies = [ + "windows-collections", + "windows-core", + "windows-future", + "windows-link", + "windows-numerics", +] + +[[package]] +name = "windows-collections" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3beeceb5e5cfd9eb1d76b381630e82c4241ccd0d27f1a39ed41b2760b255c5e8" +dependencies = [ + "windows-core", +] + [[package]] name = "windows-core" version = "0.61.0" @@ -2157,6 +2225,16 @@ dependencies = [ "windows-strings 0.4.0", ] +[[package]] +name = "windows-future" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a1d6bbefcb7b60acd19828e1bc965da6fcf18a7e39490c5f8be71e54a19ba32" +dependencies = [ + "windows-core", + "windows-link", +] + [[package]] name = "windows-implement" version = "0.60.0" @@ -2185,6 +2263,16 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "76840935b766e1b0a05c0066835fb9ec80071d4c09a16f6bd5f7e655e3c14c38" +[[package]] +name = "windows-numerics" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9150af68066c4c5c07ddc0ce30421554771e528bde427614c61038bc2c92c2b1" +dependencies = [ + "windows-core", + "windows-link", +] + [[package]] name = "windows-registry" version = "0.4.0" diff --git a/Cargo.toml b/Cargo.toml index c02262d..751c060 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,16 @@ version = "0.1.1" edition = "2024" [dependencies] -rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk.git", rev = "afb8a905e54b87c69e880f9377cfe8424aa6f13b", features = ["server", "client", "transport-sse", "transport-child-process"] } +rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk.git", rev = "076dc2c2cd8910bee56bae13f29bbcff8c279666", features = [ + "server", + "client", + "reqwest", + "client-side-sse", + "transport-sse-client", + "transport-streamable-http-client", + "transport-worker", + "transport-child-process" +] } clap = { version = "4.5.37", features = ["derive"] } tokio = { version = "1", features = ["full"] } tracing = "0.1.41" @@ -20,7 +29,17 @@ version = "0.9" features = ["vendored"] [dev-dependencies] -rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk.git", rev = "afb8a905e54b87c69e880f9377cfe8424aa6f13b", features = ["server", "client", "transport-sse", "transport-sse-server", "transport-child-process", "macros"] } +rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk.git", rev = "076dc2c2cd8910bee56bae13f29bbcff8c279666", features = [ + "server", + "client", + "reqwest", + "client-side-sse", + "transport-sse-client", + "transport-sse-server", + "transport-child-process", + "transport-streamable-http-server", + "macros" +] } axum = { version = "0.8", features = ["macros"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" diff --git a/examples/echo.rs b/examples/echo.rs new file mode 100644 index 0000000..6a8bd72 --- /dev/null +++ b/examples/echo.rs @@ -0,0 +1,59 @@ +use anyhow::Context; +use clap::Parser; +use rmcp::transport::SseServer; +use tracing_subscriber::FmtSubscriber; + +use rmcp::{ + ServerHandler, + model::{ServerCapabilities, ServerInfo}, + schemars, tool, +}; +#[derive(Debug, Clone, Default)] +pub struct Echo; +#[tool(tool_box)] +impl Echo { + #[tool(description = "Echo a message")] + fn echo(&self, #[tool(param)] message: String) -> String { + message + } +} + +#[tool(tool_box)] +impl ServerHandler for Echo { + fn get_info(&self) -> ServerInfo { + ServerInfo { + instructions: Some("A simple echo server".into()), + capabilities: ServerCapabilities::builder().enable_tools().build(), + ..Default::default() + } + } +} + +#[derive(Parser)] +#[command(author, version, about, long_about = None)] +struct Args { + /// Address to bind the server to + #[arg(short, long, default_value = "127.0.0.1:8080")] + address: std::net::SocketAddr, +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + let subscriber = FmtSubscriber::builder() + .with_max_level(tracing::Level::DEBUG) + .with_writer(std::io::stderr) + .finish(); + + // Parse command line arguments + let args = Args::parse(); + + tracing::subscriber::set_global_default(subscriber).context("Failed to set up logging")?; + + let ct = SseServer::serve(args.address) + .await? + .with_service(Echo::default); + + tokio::signal::ctrl_c().await?; + ct.cancel(); + Ok(()) +} diff --git a/examples/echo_streamable.rs b/examples/echo_streamable.rs new file mode 100644 index 0000000..a3ef7a7 --- /dev/null +++ b/examples/echo_streamable.rs @@ -0,0 +1,59 @@ +use anyhow::Context; +use clap::Parser; +use rmcp::transport::StreamableHttpServer; +use tracing_subscriber::FmtSubscriber; + +use rmcp::{ + ServerHandler, + model::{ServerCapabilities, ServerInfo}, + schemars, tool, +}; +#[derive(Debug, Clone, Default)] +pub struct Echo; +#[tool(tool_box)] +impl Echo { + #[tool(description = "Echo a message")] + fn echo(&self, #[tool(param)] message: String) -> String { + message + } +} + +#[tool(tool_box)] +impl ServerHandler for Echo { + fn get_info(&self) -> ServerInfo { + ServerInfo { + instructions: Some("A simple echo server".into()), + capabilities: ServerCapabilities::builder().enable_tools().build(), + ..Default::default() + } + } +} + +#[derive(Parser)] +#[command(author, version, about, long_about = None)] +struct Args { + /// Address to bind the server to + #[arg(short, long, default_value = "127.0.0.1:8080")] + address: std::net::SocketAddr, +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + let subscriber = FmtSubscriber::builder() + .with_max_level(tracing::Level::DEBUG) + .with_writer(std::io::stderr) + .finish(); + + // Parse command line arguments + let args = Args::parse(); + + tracing::subscriber::set_global_default(subscriber).context("Failed to set up logging")?; + + let ct = StreamableHttpServer::serve(args.address) + .await? + .with_service(Echo::default); + + tokio::signal::ctrl_c().await?; + ct.cancel(); + Ok(()) +} diff --git a/src/core.rs b/src/core.rs index 534246e..6bc2b66 100644 --- a/src/core.rs +++ b/src/core.rs @@ -1,13 +1,11 @@ use crate::state::{AppState, BufferMode, ProxyState, ReconnectFailureReason}; -use crate::{DISCONNECTED_ERROR_CODE, SseClientTransport, StdoutSink, TRANSPORT_SEND_ERROR_CODE}; -use anyhow::Result; -use futures::{FutureExt, SinkExt, StreamExt}; -use rmcp::{ - model::{ - ClientJsonRpcMessage, ClientRequest, EmptyResult, ErrorData, RequestId, - ServerJsonRpcMessage, - }, - transport::sse::{ReqwestSseClient, SseTransport, SseTransportRetryConfig}, +use crate::{DISCONNECTED_ERROR_CODE, SseClientType, StdoutSink, TRANSPORT_SEND_ERROR_CODE}; +use anyhow::{Result, anyhow}; +use futures::FutureExt; +use futures::SinkExt; +use rmcp::model::{ + ClientJsonRpcMessage, ClientNotification, ClientRequest, ErrorData, RequestId, + ServerJsonRpcMessage, }; use std::time::Duration; use tracing::{debug, error, info}; @@ -36,11 +34,83 @@ pub(crate) async fn reply_disconnected(id: &RequestId, stdout_sink: &mut StdoutS Ok(()) } +pub(crate) async fn connect(app_state: &AppState) -> Result { + // this function should try sending a POST request to the sse_url and see if + // the server responds with 405 method not supported. If so, it should call + // connect_with_sse, otherwise it should call connect_with_streamable. + let result = reqwest::Client::new() + .post(app_state.url.clone()) + .header("Accept", "application/json,text/event-stream") + .header("Content-Type", "application/json") + .body(r#"{"jsonrpc":"2.0","id":"init","method":"initialize","params":{"protocolVersion":"2025-03-26","capabilities":{},"clientInfo":{"name":"test","version":"0.1.0"}}}"#) + .send() + .await?; + + if result.status() == reqwest::StatusCode::METHOD_NOT_ALLOWED { + debug!("Server responded with 405, using SSE transport"); + return connect_with_sse(app_state).await; + } else if result.status().is_success() { + debug!("Server responded successfully, using streamable transport"); + return connect_with_streamable(app_state).await; + } else { + error!("Server returned unexpected status: {}", result.status()); + anyhow::bail!("Server returned unexpected status: {}", result.status()); + } +} + +pub(crate) async fn connect_with_streamable(app_state: &AppState) -> Result { + let result = rmcp::transport::StreamableHttpClientTransport::with_client( + reqwest::Client::default(), + rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig { + uri: app_state.url.clone().into(), + // we don't want the sdk to perform any retries + retry_config: std::sync::Arc::new( + rmcp::transport::common::client_side_sse::FixedInterval { + max_times: Some(0), + duration: Duration::from_millis(0), + }, + ), + channel_buffer_capacity: 16, + }, + ); + + Ok(SseClientType::Streamable(result)) +} + +pub(crate) async fn connect_with_sse(app_state: &AppState) -> Result { + let result = rmcp::transport::SseClientTransport::start_with_client( + reqwest::Client::default(), + rmcp::transport::sse_client::SseClientConfig { + sse_endpoint: app_state.url.clone().into(), + // we don't want the sdk to perform any retries + retry_policy: std::sync::Arc::new( + rmcp::transport::common::client_side_sse::FixedInterval { + max_times: Some(0), + duration: Duration::from_millis(0), + }, + ), + use_message_endpoint: None, + }, + ) + .await; + + match result { + Ok(transport) => { + info!("Successfully reconnected to SSE server"); + Ok(SseClientType::Sse(transport)) + } + Err(e) => { + error!("Failed to reconnect: {}", e); + Err(anyhow!("Connection failed: {}", e)) + } + } +} + /// Attempts to reconnect to the SSE server with backoff. /// Does not mutate AppState directly. pub(crate) async fn try_reconnect( app_state: &AppState, -) -> Result { +) -> Result { let backoff = app_state.get_backoff_duration(); info!( "Attempting to reconnect in {}s (attempt {})", @@ -53,21 +123,16 @@ pub(crate) async fn try_reconnect( return Err(ReconnectFailureReason::TimeoutExceeded); } - let client = ReqwestSseClient::new(&app_state.url) - .map_err(|e| ReconnectFailureReason::ConnectionFailed(e.into()))?; + let result = connect(app_state).await; - match SseTransport::start_with_client(client).await { - Ok(mut new_transport) => { - new_transport.retry_config = SseTransportRetryConfig { - max_times: Some(0), - min_duration: Duration::from_millis(0), - }; + match result { + Ok(transport) => { info!("Successfully reconnected to SSE server"); - Ok(new_transport) + Ok(transport) } Err(e) => { error!("Failed to reconnect: {}", e); - Err(ReconnectFailureReason::ConnectionFailed(e.into())) + Err(ReconnectFailureReason::ConnectionFailed(e)) } } } @@ -75,9 +140,9 @@ pub(crate) async fn try_reconnect( /// Sends a JSON-RPC request to the SSE server and handles any transport errors. /// Returns true if the send was successful, false otherwise. pub(crate) async fn send_request_to_sse( - transport: &mut SseClientTransport, + transport: &mut SseClientType, request: ClientJsonRpcMessage, - original_id: RequestId, + original_message: ClientJsonRpcMessage, stdout_sink: &mut StdoutSink, app_state: &mut AppState, ) -> Result { @@ -86,26 +151,11 @@ pub(crate) async fn send_request_to_sse( Ok(_) => Ok(true), Err(e) => { error!("Error sending to SSE: {}", e); - app_state.disconnected(); - - if app_state.buf_mode == BufferMode::Store { - debug!("Buffering request for later retry"); - app_state.in_buf.push(request); - app_state.schedule_flush_timer(); - app_state.schedule_reconnect(); - } else { - let error_response = ServerJsonRpcMessage::error( - ErrorData::new( - TRANSPORT_SEND_ERROR_CODE, - format!("Transport error: {}", e), - None, - ), - original_id, - ); - if let Err(write_err) = stdout_sink.send(error_response).await { - error!("Error writing error response to stdout: {}", write_err); - } - } + app_state.handle_fatal_transport_error(); + app_state + .maybe_handle_message_while_disconnected(original_message, stdout_sink) + .await?; + Ok(false) } } @@ -116,29 +166,9 @@ pub(crate) async fn send_request_to_sse( pub(crate) async fn process_client_request( message: ClientJsonRpcMessage, app_state: &mut AppState, - transport: &mut SseClientTransport, + transport: &mut SseClientType, stdout_sink: &mut StdoutSink, ) -> Result<()> { - // Handle ping directly if disconnected - if let ClientJsonRpcMessage::Request(ref req) = message { - if let ClientRequest::PingRequest(_) = &req.request { - if app_state.state == ProxyState::Disconnected { - debug!( - "Received Ping request while disconnected, replying directly: {:?}", - req.id - ); - let response = ServerJsonRpcMessage::response( - rmcp::model::ServerResult::EmptyResult(EmptyResult {}), - req.id.clone(), - ); - if let Err(e) = stdout_sink.send(response).await { - error!("Error sending direct ping response to stdout: {}", e); - } - return Ok(()); - } - } - } - // Try mapping the ID first (for Response/Error cases). // If it returns None, the ID was unknown, so we skip processing/forwarding. let message = match app_state.map_client_response_error_id(message) { @@ -146,14 +176,40 @@ pub(crate) async fn process_client_request( None => return Ok(()), // Skip forwarding if ID was not mapped }; - // Check if disconnected and buffer if necessary (both requests and mapped responses/errors) - if app_state.state == ProxyState::Disconnected && app_state.buf_mode == BufferMode::Store { - debug!("Buffering message while disconnected: {:?}", message); - app_state.in_buf.push(message); - return Ok(()); + // Handle ping directly if disconnected + match app_state + .maybe_handle_message_while_disconnected(message.clone(), stdout_sink) + .await + { + Err(_) => {} + Ok(_) => return Ok(()), + } + + match &message { + ClientJsonRpcMessage::Request(req) => { + if app_state.init_message.is_none() { + if let ClientRequest::InitializeRequest(_) = req.request { + debug!("Stored client initialization message"); + app_state.init_message = Some(message.clone()); + app_state.state = ProxyState::WaitingForServerInit(req.id.clone()); + } + } + } + ClientJsonRpcMessage::Notification(notification) => { + if let ClientNotification::InitializedNotification(_) = notification.notification { + if app_state.state == ProxyState::WaitingForClientInitialized { + debug!("Received client initialized notification, proxy fully connected."); + app_state.connected(); + } else { + debug!("Forwarding client initialized notification outside of expected state."); + } + } + } + _ => {} } // Process requests separately to map their IDs before sending + let original_message = message.clone(); if let ClientJsonRpcMessage::Request(req) = message { let request_id = req.id.clone(); let mut req = req.clone(); @@ -167,7 +223,7 @@ pub(crate) async fn process_client_request( let _success = send_request_to_sse( transport, ClientJsonRpcMessage::Request(req), - request_id, // Pass the original ID for potential error reporting + original_message, stdout_sink, app_state, ) @@ -188,7 +244,7 @@ pub(crate) async fn process_client_request( /// Process buffered messages after a successful reconnection pub(crate) async fn process_buffered_messages( app_state: &mut AppState, - transport: &mut SseClientTransport, + transport: &mut SseClientType, stdout_sink: &mut StdoutSink, ) -> Result<()> { let buffered_messages = std::mem::take(&mut app_state.in_buf); @@ -266,7 +322,7 @@ pub(crate) async fn flush_buffer_with_errors( /// Returns Ok(false) if sending the init message failed (triggers disconnect). pub(crate) async fn initiate_post_reconnect_handshake( app_state: &mut AppState, - transport: &mut SseClientTransport, + transport: &mut SseClientType, stdout_sink: &mut StdoutSink, ) -> Result { if let Some(init_msg) = &app_state.init_message { @@ -274,9 +330,7 @@ pub(crate) async fn initiate_post_reconnect_handshake( req.id.clone() } else { error!("Stored init_message is not a request: {:?}", init_msg); - process_buffered_messages(app_state, transport, stdout_sink).await?; - app_state.state = ProxyState::Connected; - return Ok(true); + return Ok(false); }; debug!( @@ -285,7 +339,9 @@ pub(crate) async fn initiate_post_reconnect_handshake( ); app_state.state = ProxyState::WaitingForServerInitHidden(id.clone()); - if let Err(e) = transport.send(init_msg.clone()).await { + if let Err(e) = + process_client_request(init_msg.clone(), app_state, transport, stdout_sink).await + { info!("Error resending init message during handshake: {}", e); app_state.handle_fatal_transport_error(); Ok(false) @@ -308,11 +364,11 @@ pub(crate) async fn initiate_post_reconnect_handshake( /// Returns Some(true) if alive, Some(false) if dead, None if check not needed. pub(crate) async fn send_heartbeat_if_needed( app_state: &AppState, - transport: &mut SseClientTransport, + transport: &mut SseClientType, ) -> Option { if app_state.last_heartbeat.elapsed() > Duration::from_secs(5) { debug!("Checking SSE connection state due to inactivity..."); - match transport.next().now_or_never() { + match transport.receive().now_or_never() { Some(Some(_)) => { debug!("Heartbeat check: Received message/event, connection alive."); Some(true) diff --git a/src/main.rs b/src/main.rs index 10e696e..d8b54a5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,7 +3,7 @@ use clap::Parser; use futures::StreamExt; use rmcp::{ model::{ClientJsonRpcMessage, ErrorCode, ServerJsonRpcMessage}, - transport::sse::{ReqwestSseClient, SseTransport, SseTransportRetryConfig}, + transport::{StreamableHttpClientTransport, Transport, sse_client::SseClientTransport}, }; use std::env; use tokio::io::{Stdin, Stdout}; @@ -18,16 +18,41 @@ mod core; mod state; use crate::cli::Args; -use crate::core::flush_buffer_with_errors; +use crate::core::{connect, flush_buffer_with_errors}; use crate::state::{AppState, ProxyState}; // Only needed directly by main for final check // Custom Error Codes (Keep here or move to common/state? Keeping here for now) const DISCONNECTED_ERROR_CODE: ErrorCode = ErrorCode(-32010); const TRANSPORT_SEND_ERROR_CODE: ErrorCode = ErrorCode(-32011); -type SseClientTransport = SseTransport; -type StdinCodec = rmcp::transport::io::JsonRpcMessageCodec; -type StdoutCodec = rmcp::transport::io::JsonRpcMessageCodec; +enum SseClientType { + Sse(SseClientTransport), + Streamable(StreamableHttpClientTransport), +} + +impl SseClientType { + async fn send( + &mut self, + item: ClientJsonRpcMessage, + ) -> Result<(), Box> { + match self { + SseClientType::Sse(transport) => transport.send(item).await.map_err(|e| e.into()), + SseClientType::Streamable(transport) => { + transport.send(item).await.map_err(|e| e.into()) + } + } + } + + async fn receive(&mut self) -> Option { + match self { + SseClientType::Sse(transport) => transport.receive().await, + SseClientType::Streamable(transport) => transport.receive().await, + } + } +} + +type StdinCodec = rmcp::transport::async_rw::JsonRpcMessageCodec; +type StdoutCodec = rmcp::transport::async_rw::JsonRpcMessageCodec; type StdinStream = FramedRead; type StdoutSink = FramedWrite; @@ -35,7 +60,7 @@ type StdoutSink = FramedWrite; const INITIAL_CONNECT_TIMEOUT: Duration = Duration::from_secs(5 * 60); // 5 minutes /// Attempts to establish the initial SSE connection, retrying on failure. -async fn connect_with_retry(sse_url: &str, delay: Duration) -> Result { +async fn connect_with_retry(app_state: &AppState, delay: Duration) -> Result { let start_time = Instant::now(); let mut attempts = 0; @@ -46,26 +71,16 @@ async fn connect_with_retry(sse_url: &str, delay: Duration) -> Result { - match SseTransport::start_with_client(client).await { - Ok(mut transport) => { - info!("Initial connection successful!"); - // Configure transport to not retry internally after initial connect - transport.retry_config = SseTransportRetryConfig { - max_times: Some(0), - min_duration: Duration::from_millis(0), - }; - return Ok(transport); - } - Err(e) => { - warn!("Attempt {} failed to start transport: {}", attempts, e); - } - } + let result = connect(app_state).await; + + // Try creating the transport + match result { + Ok(transport) => { + info!("Initial connection successful!"); + return Ok(transport); } Err(e) => { - warn!("Attempt {} failed to create SSE client: {}", attempts, e); + warn!("Attempt {} failed to start transport: {}", attempts, e); } } @@ -126,7 +141,7 @@ async fn main() -> Result<()> { // Establish initial SSE connection using the retry helper info!("Attempting initial connection to {}...", sse_url); let mut transport = - connect_with_retry(&sse_url, Duration::from_secs(args.initial_retry_interval)).await?; + connect_with_retry(&app_state, Duration::from_secs(args.initial_retry_interval)).await?; info!("Connection established. Proxy operational."); app_state.state = ProxyState::WaitingForClientInit; @@ -154,7 +169,7 @@ async fn main() -> Result<()> { } } // Handle message from SSE server - result = transport.next(), if app_state.transport_valid => { + result = transport.receive(), if app_state.transport_valid => { if !app_state.handle_sse_message(result, &mut transport, &mut stdout_sink).await? { break; } diff --git a/src/state.rs b/src/state.rs index 9f65a80..28ee1c0 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,12 +1,13 @@ use crate::core::{ flush_buffer_with_errors, generate_id, initiate_post_reconnect_handshake, - process_buffered_messages, process_client_request, send_heartbeat_if_needed, try_reconnect, + process_buffered_messages, process_client_request, reply_disconnected, + send_heartbeat_if_needed, try_reconnect, }; -use crate::{SseClientTransport, StdoutSink}; +use crate::{SseClientType, StdoutSink}; use anyhow::Result; use futures::SinkExt; use rmcp::model::{ - ClientJsonRpcMessage, ClientNotification, ClientRequest, InitializedNotification, + ClientJsonRpcMessage, ClientNotification, ClientRequest, EmptyResult, InitializedNotification, InitializedNotificationMethod, RequestId, ServerJsonRpcMessage, }; use std::collections::HashMap; @@ -113,9 +114,9 @@ impl AppState { self.state = ProxyState::Disconnected; self.disconnected_since = Some(Instant::now()); self.buf_mode = BufferMode::Store; - self.connect_tries += 1; self.transport_valid = false; } + self.connect_tries += 1; } pub fn disconnected_too_long(&self) -> bool { @@ -181,39 +182,14 @@ impl AppState { /// Returns Ok(true) to continue processing, Ok(false) to break the main loop. pub(crate) async fn handle_stdin_message( &mut self, - msg: Option>, - transport: &mut SseClientTransport, + msg: Option< + Result, + >, + transport: &mut SseClientType, stdout_sink: &mut StdoutSink, ) -> Result { match msg { Some(Ok(message)) => { - match &message { - ClientJsonRpcMessage::Request(req) => { - if let ClientRequest::InitializeRequest(_) = req.request { - debug!("Stored client initialization message"); - self.init_message = Some(message.clone()); - self.state = ProxyState::WaitingForServerInit(req.id.clone()); - } - } - ClientJsonRpcMessage::Notification(notification) => { - if let ClientNotification::InitializedNotification(_) = - notification.notification - { - if self.state == ProxyState::WaitingForClientInitialized { - debug!( - "Received client initialized notification, proxy fully connected." - ); - self.state = ProxyState::Connected; - } else { - debug!( - "Forwarding client initialized notification outside of expected state." - ); - } - } - } - _ => {} - } - process_client_request(message, self, transport, stdout_sink).await?; Ok(true) } @@ -233,9 +209,10 @@ impl AppState { pub(crate) async fn handle_sse_message( &mut self, result: Option, - transport: &mut SseClientTransport, + transport: &mut SseClientType, stdout_sink: &mut StdoutSink, ) -> Result { + debug!("Received SSE message: {:?}", result); match result { Some(mut message) => { self.update_heartbeat(); @@ -256,6 +233,10 @@ impl AppState { } // --- End Server-Initiated Request Handling --- else { + match self.map_server_response_error_id(message) { + Some(mapped_message) => message = mapped_message, + None => return Ok(true), // Skip forwarding this message + } // --- Handle Initialization Response --- (Only for Response/Error) let is_init_response = match &message { ServerJsonRpcMessage::Response(response) => match self.state { @@ -271,11 +252,16 @@ impl AppState { _ => false, }; + debug!( + "Handling initialization response - state: {:?}, message: {:?}, is_init_response: {}", + self.state, message, is_init_response + ); + if is_init_response { let was_hidden = matches!(self.state, ProxyState::WaitingForServerInitHidden(_)); if was_hidden { - self.state = ProxyState::Connected; + self.connected(); debug!("Reconnection successful, received hidden init response"); let initialized_notification = ClientJsonRpcMessage::notification( ClientNotification::InitializedNotification( @@ -294,24 +280,15 @@ impl AppState { } else { process_buffered_messages(self, transport, stdout_sink).await?; } + return Ok(true); // Don't forward the init response } else { debug!( "Initial connection successful, received init response. Waiting for client initialized." ); self.state = ProxyState::WaitingForClientInitialized; } - return Ok(true); // Don't forward the init response } // --- End Initialization Response Handling --- - - // --- Handle Regular Response/Error ID Mapping --- (Client->Server flow) - // Map Response/Error back to original client ID if possible - // If mapping fails (returns None), skip forwarding. - match self.map_server_response_error_id(message) { - Some(mapped_message) => message = mapped_message, - None => return Ok(true), // Skip forwarding this message - } - // --- End Regular Response/Error ID Mapping --- } // Forward the (potentially modified) message to stdout @@ -332,19 +309,54 @@ impl AppState { } } + pub(crate) async fn maybe_handle_message_while_disconnected( + &mut self, + message: ClientJsonRpcMessage, + stdout_sink: &mut StdoutSink, + ) -> Result<()> { + if self.state != ProxyState::Disconnected { + return Err(anyhow::anyhow!("Not disconnected")); + } + + // Handle ping directly if disconnected + if let ClientJsonRpcMessage::Request(ref req) = message { + if let ClientRequest::PingRequest(_) = &req.request { + debug!( + "Received Ping request while disconnected, replying directly: {:?}", + req.id + ); + let response = ServerJsonRpcMessage::response( + rmcp::model::ServerResult::EmptyResult(EmptyResult {}), + req.id.clone(), + ); + if let Err(e) = stdout_sink.send(response).await { + error!("Error sending direct ping response to stdout: {}", e); + } + return Ok(()); + } + if self.buf_mode == BufferMode::Store { + debug!("Buffering request for later retry"); + self.in_buf.push(message); + } else { + reply_disconnected(&req.id, stdout_sink).await?; + } + } + + Ok(()) + } + /// Handles the reconnect signal. /// Returns the potentially new transport if reconnection was successful. pub(crate) async fn handle_reconnect_signal( &mut self, stdout_sink: &mut StdoutSink, - ) -> Result> { + ) -> Result> { debug!("Received reconnect signal"); self.reconnect_scheduled = false; if self.state == ProxyState::Disconnected { match try_reconnect(self).await { Ok(mut new_transport) => { - self.connected(); self.transport_valid = true; initiate_post_reconnect_handshake(self, &mut new_transport, stdout_sink) @@ -405,7 +417,7 @@ impl AppState { /// Handles the heartbeat interval tick. pub(crate) async fn handle_heartbeat_tick( &mut self, - transport: &mut SseClientTransport, + transport: &mut SseClientType, ) -> Result<()> { if self.state == ProxyState::Connected { let check_result = send_heartbeat_if_needed(self, transport).await; @@ -414,7 +426,6 @@ impl AppState { self.update_heartbeat(); } Some(false) => { - debug!("Heartbeat check failed - connection confirmed down"); self.handle_fatal_transport_error(); } None => {} diff --git a/tests/advanced_test.rs b/tests/advanced_test.rs index 011b7a4..46e8dd3 100644 --- a/tests/advanced_test.rs +++ b/tests/advanced_test.rs @@ -4,82 +4,146 @@ use rmcp::{ ServiceExt, model::CallToolRequestParam, object, - transport::{SseServer, TokioChildProcess}, + transport::{ConfigureCommandExt, TokioChildProcess}, +}; +use std::{ + net::SocketAddr, + sync::{Arc, Mutex}, + time::Duration, }; -use std::{net::SocketAddr, time::Duration}; use tokio::{ io::{AsyncBufReadExt, AsyncWriteExt, BufReader}, time::{sleep, timeout}, }; -// Creates a new SSE server for testing -async fn create_sse_server( - address: SocketAddr, -) -> Result<(tokio_util::sync::CancellationToken, String)> { - let url = format!("http://{}/sse", address); - let ct = tokio_util::sync::CancellationToken::new(); - - tracing::info!("Creating SSE server at {}", url); - - let config = rmcp::transport::sse_server::SseServerConfig { - bind: address, - sse_path: "/sse".to_string(), - post_path: "/message".to_string(), - ct: ct.clone(), - sse_keep_alive: None, - }; +/// A guard that ensures processes are killed on drop, especially on test failures (panics) +struct TestGuard { + child: Option, + server_handle: Option, + stderr_buffer: Arc>>, +} - let (sse_server, router) = SseServer::new(config); +impl TestGuard { + fn new( + child: tokio::process::Child, + server_handle: tokio::process::Child, + stderr_buffer: Arc>>, + ) -> Self { + Self { + child: Some(child), + server_handle: Some(server_handle), + stderr_buffer, + } + } +} - // Bind the listener for the server - let listener = tokio::net::TcpListener::bind(sse_server.config.bind).await?; - tracing::debug!("SSE server bound to {}", sse_server.config.bind); +impl Drop for TestGuard { + fn drop(&mut self) { + // If we're dropping because of a panic, print the stderr content + if std::thread::panicking() { + eprintln!("Test failed! Process stderr output:"); + for line in self.stderr_buffer.lock().unwrap().iter() { + eprintln!("{}", line); + } + } - // Create a child token for cancellation - let child_ct = sse_server.config.ct.child_token(); + // Force kill both processes + if let Some(mut child) = self.child.take() { + let _ = child.start_kill(); + } + if let Some(mut server_handle) = self.server_handle.take() { + let _ = server_handle.start_kill(); + } + } +} - // Spawn the server task with graceful shutdown - let server = axum::serve(listener, router).with_graceful_shutdown(async move { - tracing::info!("Waiting for cancellation signal..."); - child_ct.cancelled().await; - tracing::info!("SSE server cancelled"); - }); +/// Spawns a proxy process with stdin, stdout, and stderr all captured +async fn spawn_proxy( + server_url: &str, + extra_args: Vec<&str>, +) -> Result<( + tokio::process::Child, + tokio::io::BufReader, + tokio::io::BufReader, + tokio::process::ChildStdin, +)> { + let mut cmd = tokio::process::Command::new("./target/debug/mcp-proxy"); + cmd.arg(server_url) + .args(extra_args) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::piped()) + .stdin(std::process::Stdio::piped()); + + let mut child = cmd.spawn()?; + let stdin = child.stdin.take().unwrap(); + let stdout = BufReader::new(child.stdout.take().unwrap()); + let stderr = BufReader::new(child.stderr.take().unwrap()); + + Ok((child, stdout, stderr, stdin)) +} + +/// Collects stderr lines in the background +fn collect_stderr( + mut stderr_reader: BufReader, +) -> Arc>> { + let stderr_buffer = Arc::new(Mutex::new(Vec::new())); + let buffer_clone = stderr_buffer.clone(); - tracing::debug!("Starting SSE server task"); tokio::spawn(async move { - if let Err(e) = server.await { - tracing::error!(error = %e, "SSE server shutdown with error"); - } else { - tracing::info!("SSE server shutdown successfully"); + let mut line = String::new(); + while let Ok(bytes_read) = stderr_reader.read_line(&mut line).await { + if bytes_read == 0 { + break; + } + buffer_clone.lock().unwrap().push(line.clone()); + line.clear(); } }); - // Create the echo service - let service_ct = sse_server.with_service(echo::Echo::default); - tracing::info!("SSE server created successfully with Echo service"); - - // Force using this cancellation token to ensure proper shutdown - Ok((service_ct, url)) + stderr_buffer } -#[tokio::test] -async fn test_protocol_initialization() -> Result<()> { - const BIND_ADDRESS: &str = "127.0.0.1:8181"; - // Start the SSE server - let (server_handle, server_url) = create_sse_server(BIND_ADDRESS.parse()?).await?; +// Creates a new SSE server for testing +// Starts the echo-server as a subprocess +async fn create_sse_server( + server_name: &str, + address: SocketAddr, +) -> Result<(tokio::process::Child, String)> { + let url = if server_name == "echo_streamable" { + format!("http://{}", address) + } else { + format!("http://{}/sse", address) + }; - // Create a child process for the proxy - let mut cmd = tokio::process::Command::new("./target/debug/mcp-proxy"); - cmd.arg(&server_url) - .stdout(std::process::Stdio::piped()) - .stdin(std::process::Stdio::piped()); + tracing::info!("Starting echo-server at {}", url); - let mut child = cmd.spawn()?; + // Create echo-server process + let mut cmd = tokio::process::Command::new(format!("./target/debug/examples/{}", server_name)); + cmd.arg("--address").arg(address.to_string()); - // Get stdin and stdout handles - let mut stdin = child.stdin.take().unwrap(); - let stdout = child.stdout.take().unwrap(); - let mut reader = BufReader::new(stdout); + tracing::debug!("cmd: {:?}", cmd); + + // Start the process with stdout/stderr redirected to null + let child = cmd + .stdout(std::process::Stdio::null()) + .stderr(std::process::Stdio::null()) + .spawn()?; + + // Give the server time to start up + sleep(Duration::from_millis(500)).await; + tracing::info!("{} server started successfully", server_name); + + Ok((child, url)) +} + +async fn protocol_initialization(server_name: &str) -> Result<()> { + const BIND_ADDRESS: &str = "127.0.0.1:8181"; + let (server_handle, server_url) = create_sse_server(server_name, BIND_ADDRESS.parse()?).await?; + + // Create a child process for the proxy with stderr capture + let (child, mut reader, stderr_reader, mut stdin) = spawn_proxy(&server_url, vec![]).await?; + let stderr_buffer = collect_stderr(stderr_reader); + let _guard = TestGuard::new(child, server_handle, stderr_buffer); // Send initialization message let init_message = r#"{"jsonrpc":"2.0","id":"init-1","method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"0.1.0"}}}"#; @@ -112,15 +176,18 @@ async fn test_protocol_initialization() -> Result<()> { assert!(echo_response.contains("\"id\":\"call-1\"")); assert!(echo_response.contains("Hey!")); - // Clean up - child.kill().await?; - server_handle.cancel(); - Ok(()) } #[tokio::test] -async fn test_reconnection_handling() -> Result<()> { +async fn test_protocol_initialization() -> Result<()> { + protocol_initialization("echo").await?; + protocol_initialization("echo_streamable").await?; + + Ok(()) +} + +async fn reconnection_handling(server_name: &str) -> Result<()> { let subscriber = tracing_subscriber::fmt() .with_max_level(tracing::Level::DEBUG) .with_test_writer() @@ -130,22 +197,14 @@ async fn test_reconnection_handling() -> Result<()> { const BIND_ADDRESS: &str = "127.0.0.1:8182"; // Start the SSE server - println!("Test: Starting initial SSE server"); - let (server_handle, server_url) = create_sse_server(BIND_ADDRESS.parse()?).await?; + tracing::info!("Test: Starting initial SSE server"); + let (server_handle, server_url) = create_sse_server(server_name, BIND_ADDRESS.parse()?).await?; // Create a child process for the proxy - println!("Test: Creating proxy process"); - let mut cmd = tokio::process::Command::new("./target/debug/mcp-proxy"); - cmd.arg(&server_url) - .stdout(std::process::Stdio::piped()) - .stdin(std::process::Stdio::piped()); - - let mut child = cmd.spawn()?; - - // Get stdin and stdout handles - let mut stdin = child.stdin.take().unwrap(); - let stdout = child.stdout.take().unwrap(); - let mut reader = BufReader::new(stdout); + tracing::info!("Test: Creating proxy process"); + let (child, mut reader, stderr_reader, mut stdin) = spawn_proxy(&server_url, vec![]).await?; + let stderr_buffer = collect_stderr(stderr_reader); + let mut test_guard = TestGuard::new(child, server_handle, stderr_buffer); // Send initialization message let init_message = r#"{"jsonrpc":"2.0","id":"init-1","method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"0.1.0"}}}"#; @@ -174,19 +233,25 @@ async fn test_reconnection_handling() -> Result<()> { ); // Shutdown the server - server_handle.cancel(); + if let Some(mut server) = test_guard.server_handle.take() { + server.kill().await?; + } // Give the server time to shut down sleep(Duration::from_millis(1000)).await; // Create a new server on the same address - println!("Test: Starting new SSE server"); - let (new_ct, new_url) = create_sse_server(BIND_ADDRESS.parse()?).await?; + tracing::info!("Test: Starting new SSE server"); + let (new_server_handle, new_url) = + create_sse_server(server_name, BIND_ADDRESS.parse()?).await?; assert_eq!( server_url, new_url, "New server URL should match the original" ); + // Update the test guard with the new server handle + test_guard.server_handle = Some(new_server_handle); + // Give the proxy time to reconnect sleep(Duration::from_millis(3000)).await; @@ -199,29 +264,36 @@ async fn test_reconnection_handling() -> Result<()> { let mut echo_response = String::new(); reader.read_line(&mut echo_response).await?; + tracing::info!("Test: Received echo response: {}", echo_response.trim()); + // Even if the response contains an error, we should at least get a response assert!( echo_response.contains("\"id\":\"call-2\""), "No response received after reconnection" ); - // Clean up - new_ct.cancel(); - sleep(Duration::from_millis(500)).await; // Give server time to shutdown - child.kill().await?; - Ok(()) } #[tokio::test] -async fn test_server_info_and_capabilities() -> Result<()> { +async fn test_reconnection_handling() -> Result<()> { + reconnection_handling("echo").await?; + reconnection_handling("echo_streamable").await?; + + Ok(()) +} + +async fn server_info_and_capabilities(server_name: &str) -> Result<()> { const BIND_ADDRESS: &str = "127.0.0.1:8183"; // Start the SSE server - let (server_handle, server_url) = create_sse_server(BIND_ADDRESS.parse()?).await?; + let (mut server_handle, server_url) = + create_sse_server(server_name, BIND_ADDRESS.parse()?).await?; // Create a transport for the proxy let transport = TokioChildProcess::new( - tokio::process::Command::new("./target/debug/mcp-proxy").arg(&server_url), + tokio::process::Command::new("./target/debug/mcp-proxy").configure(|cmd| { + cmd.arg(&server_url); + }), )?; // Connect a client to the proxy @@ -253,13 +325,20 @@ async fn test_server_info_and_capabilities() -> Result<()> { // Clean up drop(client); - server_handle.cancel(); + server_handle.kill().await?; Ok(()) } #[tokio::test] -async fn test_initial_connection_retry() -> Result<()> { +async fn test_server_info_and_capabilities() -> Result<()> { + server_info_and_capabilities("echo").await?; + server_info_and_capabilities("echo_streamable").await?; + + Ok(()) +} + +async fn initial_connection_retry(server_name: &str) -> Result<()> { // Set up custom logger for this test to clearly see what's happening let subscriber = tracing_subscriber::fmt() .with_max_level(tracing::Level::INFO) @@ -268,28 +347,24 @@ async fn test_initial_connection_retry() -> Result<()> { let _guard = tracing::subscriber::set_default(subscriber); const BIND_ADDRESS: &str = "127.0.0.1:8184"; - let server_url = format!("http://{}/sse", BIND_ADDRESS); + let server_url = if server_name == "echo_streamable" { + format!("http://{}", BIND_ADDRESS) + } else { + format!("http://{}/sse", BIND_ADDRESS) + }; let bind_addr: SocketAddr = BIND_ADDRESS.parse()?; // 1. Start the proxy process BEFORE the server - println!("Test: Starting proxy process..."); - let mut cmd = tokio::process::Command::new("./target/debug/mcp-proxy"); - cmd.arg(&server_url) - .arg("--initial-retry-interval") - .arg("1") - .stdout(std::process::Stdio::piped()) - .stdin(std::process::Stdio::piped()); - let mut child = cmd.spawn()?; + tracing::info!("Test: Starting proxy process..."); + let (child, mut reader, stderr_reader, mut stdin) = + spawn_proxy(&server_url, vec!["--initial-retry-interval", "1"]).await?; - // Get stdin and stdout handles - let mut stdin = child.stdin.take().unwrap(); - let stdout = child.stdout.take().unwrap(); - let mut reader = BufReader::new(stdout); + let stderr_buffer = collect_stderr(stderr_reader); // 2. Wait for slightly longer than the proxy's retry delay // This ensures the proxy has attempted connection at least once and is retrying. let retry_wait = Duration::from_secs(2); - println!( + tracing::info!( "Test: Waiting {:?} for proxy to attempt connection...", retry_wait ); @@ -297,19 +372,21 @@ async fn test_initial_connection_retry() -> Result<()> { // Send initialize message WHILE proxy is still trying to connect // (it will be buffered by the OS pipe until proxy reads stdin) - println!("Test: Sending initialize request (before server starts)..."); + tracing::info!("Test: Sending initialize request (before server starts)..."); let init_message = r#"{"jsonrpc":"2.0","id":"init-retry","method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"retry-test","version":"0.1.0"}}}"#; stdin.write_all(init_message.as_bytes()).await?; stdin.write_all(b"\n").await?; // 3. Start the SSE server AFTER the wait and AFTER sending init - println!("Test: Starting SSE server on {}", BIND_ADDRESS); - let (server_handle, returned_url) = create_sse_server(bind_addr).await?; + tracing::info!("Test: Starting SSE server on {}", BIND_ADDRESS); + let (server_handle, returned_url) = create_sse_server(server_name, bind_addr).await?; assert_eq!(server_url, returned_url, "Server URL mismatch"); + let _test_guard = TestGuard::new(child, server_handle, stderr_buffer); + // 4. Proceed with initialization handshake (Proxy should now process buffered init) // Read the initialize response (with a timeout) - println!("Test: Waiting for initialize response..."); + tracing::info!("Test: Waiting for initialize response..."); let mut init_response = String::new(); match timeout( Duration::from_secs(10), @@ -318,7 +395,7 @@ async fn test_initial_connection_retry() -> Result<()> { .await { Ok(Ok(_)) => { - println!( + tracing::info!( "Test: Received initialize response: {}", init_response.trim() ); @@ -335,22 +412,22 @@ async fn test_initial_connection_retry() -> Result<()> { Err(_) => return Err(anyhow::anyhow!("Timed out waiting for init response")), } - println!("Test: Sending initialized notification..."); + tracing::info!("Test: Sending initialized notification..."); let initialized_message = r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#; stdin.write_all(initialized_message.as_bytes()).await?; stdin.write_all(b"\n").await?; // 5. Test basic functionality (e.g., echo tool call) - println!("Test: Sending echo request..."); + tracing::info!("Test: Sending echo request..."); let echo_call = r#"{"jsonrpc":"2.0","id":"call-retry","method":"tools/call","params":{"name":"echo","arguments":{"message":"Hello after initial retry!"}}}"#; stdin.write_all(echo_call.as_bytes()).await?; stdin.write_all(b"\n").await?; - println!("Test: Waiting for echo response..."); + tracing::info!("Test: Waiting for echo response..."); let mut echo_response = String::new(); match timeout(Duration::from_secs(5), reader.read_line(&mut echo_response)).await { Ok(Ok(_)) => { - println!("Test: Received echo response: {}", echo_response.trim()); + tracing::info!("Test: Received echo response: {}", echo_response.trim()); assert!( echo_response.contains("\"id\":\"call-retry\""), "Echo response missing correct ID" @@ -364,18 +441,19 @@ async fn test_initial_connection_retry() -> Result<()> { Err(_) => return Err(anyhow::anyhow!("Timed out waiting for echo response")), } - // 6. Cleanup - println!("Test: Cleaning up..."); - child.kill().await?; - server_handle.cancel(); - sleep(Duration::from_millis(500)).await; // Give server time to shutdown - - println!("Test: Completed successfully"); + tracing::info!("Test: Completed successfully"); Ok(()) } #[tokio::test] -async fn test_ping_when_disconnected() -> Result<()> { +async fn test_initial_connection_retry() -> Result<()> { + initial_connection_retry("echo").await?; + initial_connection_retry("echo_streamable").await?; + + Ok(()) +} + +async fn ping_when_disconnected(server_name: &str) -> Result<()> { const BIND_ADDRESS: &str = "127.0.0.1:8185"; let subscriber = tracing_subscriber::fmt() .with_max_level(tracing::Level::DEBUG) @@ -385,22 +463,16 @@ async fn test_ping_when_disconnected() -> Result<()> { // 1. Start the SSE server tracing::info!("Test: Starting SSE server for ping test"); - let (server_handle, server_url) = create_sse_server(BIND_ADDRESS.parse()?).await?; + let (server_handle, server_url) = create_sse_server(server_name, BIND_ADDRESS.parse()?).await?; // Create a child process for the proxy tracing::info!("Test: Creating proxy process"); - let mut cmd = tokio::process::Command::new("./target/debug/mcp-proxy"); - cmd.arg(&server_url) - .arg("--debug") - .stdout(std::process::Stdio::piped()) - .stdin(std::process::Stdio::piped()); + let (child, mut reader, stderr_reader, mut stdin) = + spawn_proxy(&server_url, vec!["--debug"]).await?; - let mut child = cmd.spawn()?; + let stderr_buffer = collect_stderr(stderr_reader); - // Get stdin and stdout handles - let mut stdin = child.stdin.take().unwrap(); - let stdout = child.stdout.take().unwrap(); - let mut reader = BufReader::new(stdout); + let mut test_guard = TestGuard::new(child, server_handle, stderr_buffer); // 2. Initializes everything let init_message = r#"{"jsonrpc":"2.0","id":"init-ping","method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"ping-test","version":"0.1.0"}}}"#; @@ -410,12 +482,7 @@ async fn test_ping_when_disconnected() -> Result<()> { // Read the initialize response let mut init_response = String::new(); - match timeout( - Duration::from_secs(15), - reader.read_line(&mut init_response), - ) - .await - { + match timeout(Duration::from_secs(5), reader.read_line(&mut init_response)).await { Ok(Ok(_)) => { tracing::info!( "Test: Received initialize response: {}", @@ -437,9 +504,11 @@ async fn test_ping_when_disconnected() -> Result<()> { // 3. Kills the SSE server tracing::info!("Test: Shutting down SSE server"); - server_handle.cancel(); + if let Some(mut server) = test_guard.server_handle.take() { + server.kill().await?; + } // Give the server time to shut down and the proxy time to notice - sleep(Duration::from_secs(2)).await; + sleep(Duration::from_secs(3)).await; // 4. Sends a ping request let ping_message = r#"{"jsonrpc":"2.0","id":"ping-1","method":"ping"}"#; @@ -466,9 +535,13 @@ async fn test_ping_when_disconnected() -> Result<()> { Err(_) => panic!("Timed out waiting for ping response"), } - // Clean up - tracing::info!("Test: Cleaning up proxy process"); - child.kill().await?; + Ok(()) +} + +#[tokio::test] +async fn test_ping_when_disconnected() -> Result<()> { + ping_when_disconnected("echo").await?; + ping_when_disconnected("echo_streamable").await?; Ok(()) } diff --git a/tests/basic_test.rs b/tests/basic_test.rs index d5f1bac..641bdd8 100644 --- a/tests/basic_test.rs +++ b/tests/basic_test.rs @@ -1,7 +1,7 @@ mod echo; use rmcp::{ ServiceExt, - transport::{SseServer, TokioChildProcess}, + transport::{ConfigureCommandExt, SseServer, TokioChildProcess}, }; const BIND_ADDRESS: &str = "127.0.0.1:8099"; @@ -14,7 +14,9 @@ async fn test_proxy_connects_to_real_server() -> anyhow::Result<()> { .with_service(echo::Echo::default); let transport = TokioChildProcess::new( - tokio::process::Command::new("./target/debug/mcp-proxy").arg(TEST_SERVER_URL), + tokio::process::Command::new("./target/debug/mcp-proxy").configure(|cmd| { + cmd.arg(TEST_SERVER_URL); + }), )?; let client = ().serve(transport).await?;