From 30c5a4bb6b4278bc921f5d6e3a4b0d21f982fa04 Mon Sep 17 00:00:00 2001 From: Muhamad Awad Date: Tue, 31 Mar 2026 09:33:06 +0200 Subject: [PATCH 1/3] Add HTTP/2 multiplexed connection with concurrency control ## Summary - Introduce Connection, a Tower Service-based HTTP/2 connection that multiplexes requests over a single H2 session with semaphore-backed concurrency control - Add TcpConnector service and ConnectionInfo/IntoConnectionInfo abstractions for URI-based TCP connection establishment --- Cargo.lock | 10 +- crates/service-client/Cargo.toml | 12 +- crates/service-client/src/lib.rs | 1 + crates/service-client/src/pool/conn.rs | 1079 +++++++++++++++++ .../src/pool/conn/concurrency.rs | 334 +++++ crates/service-client/src/pool/mod.rs | 161 +++ crates/service-client/src/pool/test_util.rs | 157 +++ crates/service-client/src/pool/tls.rs | 309 +++++ crates/service-client/src/proxy.rs | 2 +- workspace-hack/Cargo.toml | 4 +- 10 files changed, 2063 insertions(+), 6 deletions(-) create mode 100644 crates/service-client/src/pool/conn.rs create mode 100644 crates/service-client/src/pool/conn/concurrency.rs create mode 100644 crates/service-client/src/pool/mod.rs create mode 100644 crates/service-client/src/pool/test_util.rs create mode 100644 crates/service-client/src/pool/tls.rs diff --git a/Cargo.lock b/Cargo.lock index f920c6cb75..79ff715a03 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7879,17 +7879,21 @@ dependencies = [ "bs58", "bytes", "bytestring", + "derive_builder", "futures", "googletest", "h2", "http 1.4.0", + "http-body 1.0.1", "http-body-util", "http-serde", "hyper", "hyper-rustls", "hyper-util", "jsonwebtoken 10.3.0", + "parking_lot", "pem", + "pin-project", "restate-types", "restate-workspace-hack", "ring", @@ -7899,7 +7903,11 @@ dependencies = [ "serde_with", "tempfile", "thiserror 2.0.18", - "tower-service", + "tokio", + "tokio-rustls", + "tokio-stream", + "tokio-util", + "tower", "tracing", "zstd", ] diff --git a/crates/service-client/Cargo.toml 
b/crates/service-client/Cargo.toml index a6cc57ab71..9d2a1c303f 100644 --- a/crates/service-client/Cargo.toml +++ b/crates/service-client/Cargo.toml @@ -25,26 +25,34 @@ base64 = { workspace = true } bs58 = { workspace = true } bytes = { workspace = true } bytestring = { workspace = true } +derive_builder = { workspace = true } futures = { workspace = true } h2 = "0.4.12" http = { workspace = true } +http-body = { workspace = true } http-body-util = { workspace = true } http-serde = { workspace = true } hyper = { workspace = true, features = ["http1", "http2", "client"] } hyper-rustls = { workspace = true } hyper-util = { workspace = true, features = ["client-legacy"] } jsonwebtoken = { workspace = true } +parking_lot = { workspace = true } pem = { version = "3.0.6" } +pin-project = { workspace = true } ring = { version = "0.17.14" } rustls = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } serde_with = { workspace = true } thiserror = { workspace = true } -tower-service = { version = "0.3" } +tokio = { workspace = true } +tokio-util = { workspace = true } +tokio-rustls = "0.26" +tower = { workspace = true } tracing = { workspace = true } zstd = { workspace = true } [dev-dependencies] googletest = { workspace = true } -tempfile = { workspace = true } \ No newline at end of file +tempfile = { workspace = true } +tokio-stream = {workspace = true} \ No newline at end of file diff --git a/crates/service-client/src/lib.rs b/crates/service-client/src/lib.rs index ba80cb9e3a..2f47b69f1c 100644 --- a/crates/service-client/src/lib.rs +++ b/crates/service-client/src/lib.rs @@ -36,6 +36,7 @@ use std::sync::Arc; mod http; mod lambda; +pub mod pool; mod proxy; mod request_identity; mod utils; diff --git a/crates/service-client/src/pool/conn.rs b/crates/service-client/src/pool/conn.rs new file mode 100644 index 0000000000..e805249681 --- /dev/null +++ b/crates/service-client/src/pool/conn.rs @@ -0,0 +1,1079 @@ +// Copyright (c) 2023 - 2026 
Restate Software, Inc., Restate GmbH. +// All rights reserved. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0. + +use std::io; +use std::pin::Pin; +use std::sync::atomic::{AtomicU8, Ordering}; +use std::sync::{Arc, OnceLock}; +use std::task::{Context, Poll}; +use std::time::Duration; + +use bytes::Bytes; +use futures::channel::oneshot; +use futures::future::{BoxFuture, poll_fn}; +use futures::{FutureExt, ready}; +use h2::client::{ResponseFuture as H2ResponseFuture, SendRequest}; +use h2::{Reason, RecvStream, SendStream}; +use http::{HeaderMap, Request, Response, Uri}; +use http_body::{Body, Frame}; +use http_body_util::BodyExt; +use parking_lot::Mutex; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_util::sync::{CancellationToken, DropGuard}; +use tower::Service; +use tracing::debug; + +use crate::pool::conn::concurrency::{Concurrency, Permit, PermitFuture}; + +mod concurrency; + +/// Errors that can occur during the lifecycle of an H2 connection. +#[derive(thiserror::Error, Debug)] +pub enum ConnectionError { + #[error(transparent)] + IO(#[from] io::Error), + #[error(transparent)] + H2(#[from] h2::Error), + #[error("connection is closed")] + Closed, + #[error("connection keep-alive timeout")] + KeepAliveTimeout, +} + +const STATE_NEW: u8 = 0; +const STATE_CONNECTING: u8 = 1; +const STATE_CONNECTED: u8 = 2; +const STATE_CLOSED: u8 = 3; + +/// The H2 handle obtained after a successful handshake. Set exactly once. +struct H2Handle { + send_request: SendRequest, + cancel: CancellationToken, +} + +/// Lock-free shared state for an H2 connection. +/// +/// State transitions: `New → Connecting → Connected → Closed`. +/// The `state` field tracks the discriminant atomically. 
The `h2` handle is set +/// once via `OnceLock` when transitioning to `Connected`. Only the waiter list +/// requires a brief lock during the `Connecting` phase. +struct ConnectionShared { + state: AtomicU8, + h2: OnceLock, + /// Waiters registered during the Connecting phase. Narrowly-scoped lock. + waiters: Mutex>>>, +} + +impl ConnectionShared { + fn new() -> Self { + Self { + state: AtomicU8::new(STATE_NEW), + h2: OnceLock::new(), + waiters: Mutex::new(Some(Vec::new())), + } + } + + /// Mark the connection as closed and wake any pending waiters. + fn close(&self) { + self.state.store(STATE_CLOSED, Ordering::Relaxed); + // Drop all waiter senders so receivers get Err (Cancelled) + self.waiters.lock().take(); + } +} + +#[derive(Clone, Copy, derive_builder::Builder)] +#[builder(pattern = "owned", default)] +pub struct ConnectionConfig { + initial_max_send_streams: u32, + keep_alive_timeout: Duration, + keep_alive_interval: Option, +} + +impl Default for ConnectionConfig { + fn default() -> Self { + Self { + initial_max_send_streams: 50, + keep_alive_timeout: Duration::from_secs(20), + keep_alive_interval: None, + } + } +} + +/// A lazily-initialized, multiplexed HTTP/2 connection. +/// +/// `Connection` wraps a connector `C` (a Tower [`Service`] that produces an async I/O stream) +/// and lazily performs the H2 handshake on the first request. Subsequent requests reuse the +/// same underlying H2 connection. +/// +/// Concurrency is bounded by a semaphore that limits the number of in-flight H2 streams +/// (configured via `init_max_streams`, which sets both the semaphore and +/// `h2::client::Builder::initial_max_send_streams`). Callers must call +/// [`poll_ready`](Self::poll_ready) (or [`ready`](Self::ready)) before each +/// [`request`](Self::request) to acquire a stream permit. +/// +/// Cloning a `Connection` shares the underlying H2 session; the clone starts without a +/// permit or in-progress acquire future. 
+pub struct Connection { + connector: C, + /// Connection configuration + config: ConnectionConfig, + /// Tracks the current semaphore size and updates it + /// based on the last known max_concurrent_streams + concurrency: Concurrency, + /// Lock-free shared connection state. + shared: Arc, + /// Permit acquired via [`poll_ready`](Self::poll_ready), consumed by [`request`](Self::request). + permit: Option, + /// In-progress semaphore acquire, if any. + acquire: Option>>, +} + +impl Clone for Connection +where + C: Clone, +{ + fn clone(&self) -> Self { + Self { + connector: self.connector.clone(), + config: self.config, + concurrency: self.concurrency.clone(), + shared: Arc::clone(&self.shared), + permit: None, + acquire: None, + } + } +} + +impl Connection +where + C: Service, + C::Response: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static, + C::Future: Send + 'static, + C::Error: Into, +{ + pub fn new(connector: C, config: ConnectionConfig) -> Self { + Self { + connector, + config, + concurrency: Concurrency::new(config.initial_max_send_streams as usize), + shared: Arc::new(ConnectionShared::new()), + permit: None, + acquire: None, + } + } + + pub async fn ready(&mut self) -> Result<(), ConnectionError> { + poll_fn(|cx| self.poll_ready(cx)).await + } + + #[cfg(test)] + pub fn try_ready(&mut self) -> Option> { + use std::task::Waker; + match self.poll_ready(&mut Context::from_waker(Waker::noop())) { + Poll::Pending => { + // drop the acquire future since it will never be polled again + // to clear up Semaphore resources + self.acquire = None; + self.permit = None; + None + } + Poll::Ready(result) => Some(result), + } + } + + /// Return the number of the available streams on this connection. + /// + /// This does not guarantee that poll_ready(), try_ready(), or ready() + /// will succeed. 
It can only be used to get an estimate of how many + /// h2 streams are available + pub fn available_streams(&self) -> usize { + self.concurrency.available() + } + + pub fn max_concurrent_streams(&self) -> usize { + self.concurrency.size() + } + + /// Returns `true` if the connection has been closed or encountered a fatal error. + pub fn is_closed(&self) -> bool { + self.shared.state.load(Ordering::Relaxed) == STATE_CLOSED + } + + /// Must be polled before each request. This makes sure we acquire the permit + /// to open a new h2 stream. + /// This should return immediately if connection has enough permits. Otherwise + /// it will return Pending. + /// + /// If you want to wait on the connection to be ready, use `ready()` instead. + pub fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + match self.shared.state.load(Ordering::Relaxed) { + STATE_NEW => { + ready!(self.connector.poll_ready(cx)).map_err(Into::into)?; + } + STATE_CLOSED => { + return Poll::Ready(Err(ConnectionError::Closed)); + } + STATE_CONNECTED => { + let h2 = self.shared.h2.get().expect("h2 must be set in Connected"); + if h2.cancel.is_cancelled() { + return Poll::Ready(Err(ConnectionError::Closed)); + } + + // this is a good synchronization point to update the permits + // to the last known size known by the send_request object. + self.concurrency + .resize(h2.send_request.current_max_send_streams()); + } + STATE_CONNECTING => {} + _ => unreachable!(), + } + + if self.permit.is_some() { + return Poll::Ready(Ok(())); + } + + if self.acquire.is_none() { + self.acquire = Some(Box::pin(self.concurrency.acquire())); + } + + let acquire = self.acquire.as_mut().unwrap(); + + self.permit = Some(ready!(acquire.poll_unpin(cx))); + self.acquire = None; + + Poll::Ready(Ok(())) + } + + /// Sends an HTTP request over the shared H2 connection. + /// + /// # Panics + /// Panics if called without a prior successful [`poll_ready`](Self::poll_ready) call. 
+ pub fn request(&mut self, request: http::Request) -> ResponseFuture + where + B: Body + Send + Sync + 'static, + B::Error: Into>, + { + assert!( + self.permit.is_some(), + "called request() without calling poll_ready()" + ); + // we already have a permit. + let permit = self.permit.take().unwrap(); + + let state = match self.shared.state.load(Ordering::Relaxed) { + STATE_CLOSED => ResponseFutureState::error(ConnectionError::Closed), + STATE_NEW => { + // CAS New → Connecting. Only one request wins and drives the handshake. + match self.shared.state.compare_exchange( + STATE_NEW, + STATE_CONNECTING, + Ordering::Relaxed, + Ordering::Relaxed, + ) { + Ok(_) => self.drive_handshake(&request), + Err(current) => self.handle_state(current), + } + } + other => self.handle_state(other), + }; + + ResponseFuture { + permit: Some(permit), + shared: Arc::clone(&self.shared), + request: Some(request), + state, + } + } + + /// Build the state for the request that lost the New → Connecting race + /// or arrived after the transition. + fn handle_state(&self, mut state: u8) -> ResponseFutureState { + loop { + match state { + STATE_CONNECTING => { + let (tx, rx) = oneshot::channel(); + match &mut *self.shared.waiters.lock() { + // If `waiters` is None, it means that the connection + // state has changed to either connected, or closed. + // In that case we need to re-load the state and check again + None => { + state = self.shared.state.load(Ordering::Relaxed); + assert!(state != STATE_CONNECTING); + } + Some(waiters) => { + waiters.push(tx); + return ResponseFutureState::WaitingConnection { rx }; + } + } + } + STATE_CONNECTED => { + return ResponseFutureState::PreFlight { + send_request: self + .shared + .h2 + .get() + .expect("h2 must be set in Connected") + .send_request + .clone(), + }; + } + STATE_CLOSED => return ResponseFutureState::error(ConnectionError::Closed), + _ => unreachable!(), + } + } + } + + /// Create the driving future that performs the H2 handshake. 
+ fn drive_handshake(&mut self, request: &http::Request) -> ResponseFutureState { + let weak_shared = Arc::downgrade(&self.shared); + let connect = self.connector.call(request.uri().clone()); + let config = self.config; + ResponseFutureState::drive(async move { + let stream = connect.await.map_err(Into::into)?; + + let (send_request, mut connection) = h2::client::Builder::new() + .initial_max_send_streams(config.initial_max_send_streams as usize) + .handshake::<_, Bytes>(stream) + .await?; + + let ping_pong = connection.ping_pong().expect("to succeed on first call"); + let cancel = CancellationToken::new(); + let cancellation = cancel.clone().drop_guard(); + tokio::spawn(async move { + let mut connection = std::pin::pin!(connection); + let mut keep_alive = std::pin::pin!(Self::keep_alive(ping_pong, config)); + + tokio::select! { + result = &mut connection => match result { + Ok(_) => { + debug!("h2 connection shutdown"); + }, + Err(err) => { + debug!("h2 connection shutdown with error: {err}"); + } + }, + Err(err) = &mut keep_alive => { + debug!("h2 connection keep-alive error: {err}"); + } + _ = cancel.cancelled() => { + debug!("h2 connection cancelled"); + } + }; + + // set state to closed + if let Some(shared) = weak_shared.upgrade() { + shared.close(); + } + }); + + Ok((send_request, cancellation)) + }) + } + + async fn keep_alive( + mut ping_pong: h2::PingPong, + config: ConnectionConfig, + ) -> Result<(), ConnectionError> { + let keep_alive_interval = match config.keep_alive_interval { + None => { + let _: () = futures::future::pending().await; + return Ok(()); + } + Some(interval) => interval, + }; + + let mut interval = tokio::time::interval(keep_alive_interval); + interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); + + loop { + interval.tick().await; + + match tokio::time::timeout( + config.keep_alive_timeout, + ping_pong.ping(h2::Ping::opaque()), + ) + .await + { + Ok(Ok(_)) => {} + Ok(Err(err)) => return Err(err.into()), + Err(_) 
=> { + return Err(ConnectionError::KeepAliveTimeout); + } + } + } + } +} + +/// Internal state machine for a single in-flight request on a [`Connection`]. +/// +/// Each variant represents a phase of the request lifecycle: +/// - **Driving** – this request is driving the initial H2 handshake. +/// - **WaitingConnection** – another request is driving the handshake; we wait for notification. +/// - **PreFlight** – we have a `SendRequest` handle and are waiting for H2 stream capacity. +/// - **InFlight** – the request has been sent; we are waiting for the response. +/// - **Error** – a terminal error was captured for the caller to consume. +enum ResponseFutureState { + Driving { + fut: BoxFuture<'static, Result<(SendRequest, DropGuard), ConnectionError>>, + }, + WaitingConnection { + rx: oneshot::Receiver<()>, + }, + PreFlight { + send_request: SendRequest, + }, + InFlight { + fut: H2ResponseFuture, + }, + Error { + err: Option, + }, +} + +impl ResponseFutureState { + fn error(err: impl Into) -> Self { + Self::Error { + err: Some(err.into()), + } + } + + fn drive(fut: F) -> Self + where + F: Future, DropGuard), ConnectionError>> + + Send + + 'static, + { + Self::Driving { fut: Box::pin(fut) } + } +} + +/// Future returned by [`Connection::request`]. +/// +/// Drives the request through its [`RequestFutureState`] state machine until a response +/// is received. Holds a semaphore permit for the duration of the request to bound +/// concurrent H2 streams. +/// +/// On drop, if this future was responsible for driving the H2 handshake (i.e. the +/// connection is still in `Connecting` state), the connection is moved to `Closed` to +/// prevent waiters from hanging indefinitely. 
+pub struct ResponseFuture { + permit: Option, + shared: Arc, + request: Option>, + state: ResponseFutureState, +} + +impl Drop for ResponseFuture { + fn drop(&mut self) { + // if the driving future was dropped (while in Connecting state), we need to + // immediately switch connection to closed to make sure + // waiters are immediately notified otherwise they will be stuck forever + if let ResponseFutureState::Driving { .. } = &self.state { + self.shared.close(); + } + } +} + +impl Future for ResponseFuture +where + B: Body + Unpin + Send + Sync + 'static, + B::Error: Into> + Send, +{ + type Output = Result, ConnectionError>; + + // Error handling strategy for h2 errors: + // + // Errors are classified as either "connection-level" or "stream-level" based on + // when they occur in the request lifecycle: + // + // **Connection-level** (close the entire connection via `shared.close()`): + // - Errors during `Driving` (handshake failures). + // - Errors from `send_request.poll_ready()` in `PreFlight`. + // These go through the `Error` state which cancels the h2 handle and marks + // the connection as closed. + // + // **Stream-level** (returned only to the individual caller): + // - Errors from `send_request.send_request()` in `PreFlight`. + // - Errors from the response future in `InFlight`. + // These are returned directly without closing the connection. If the + // underlying cause was actually a connection-level h2 error, it will + // surface on subsequent requests either via `send_request.poll_ready()` + // (triggering a connection close in PreFlight), or via the background + // connection task detecting the h2 shutdown and calling `shared.close()`. + // + // This simplifies error handling here: we don't need to distinguish h2 + // connection errors from stream errors ourselves — we let the phase of + // the lifecycle determine the behavior. 
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + loop { + match this.state { + ResponseFutureState::Error { ref mut err } => { + // Cancel the h2 connection if we have a handle + if let Some(h2) = this.shared.h2.get() { + h2.cancel.cancel(); + } + + this.shared.close(); + return Poll::Ready(Err(err.take().expect("future polled after finish"))); + } + ResponseFutureState::Driving { ref mut fut } => { + let (send_request, cancel) = match ready!(fut.poll_unpin(cx)) { + Ok(stream) => stream, + Err(err) => { + this.state = ResponseFutureState::error(err); + continue; + } + }; + + this.state = ResponseFutureState::PreFlight { + send_request: send_request.clone(), + }; + + // Store h2 handle (set once, visible to all threads via OnceLock) + let _ = this.shared.h2.set(H2Handle { + send_request: send_request.clone(), + cancel: cancel.disarm(), + }); + this.shared.state.store(STATE_CONNECTED, Ordering::Relaxed); + + // Drain and notify waiters + if let Some(waiters) = this.shared.waiters.lock().take() { + for waiter in waiters { + let _ = waiter.send(()); + } + } + } + ResponseFutureState::WaitingConnection { ref mut rx } => { + match ready!(rx.poll_unpin(cx)) { + Ok(_) => { + let send_request = this + .shared + .h2 + .get() + .expect("h2 must be set after notification") + .send_request + .clone(); + this.state = ResponseFutureState::PreFlight { send_request }; + } + Err(_) => { + this.state = ResponseFutureState::error(ConnectionError::Closed); + continue; + } + } + } + ResponseFutureState::PreFlight { + ref mut send_request, + } => { + if let Err(err) = ready!(send_request.poll_ready(cx)) { + this.state = ResponseFutureState::error(err); + continue; + } + + // we finally can forward the request now + let (parts, body) = this.request.take().unwrap().into_parts(); + + let req = Request::from_parts(parts, ()); + let end_stream = body.is_end_stream(); + let (fut, send_stream) = send_request.send_request(req, end_stream)?; + + if 
!end_stream { + tokio::spawn(RequestPumpTask::new(send_stream, body).run()); + } + this.state = ResponseFutureState::InFlight { fut }; + } + ResponseFutureState::InFlight { ref mut fut } => { + let resp = ready!(fut.poll_unpin(cx)).map_err(ConnectionError::from)?; + let permit = this.permit.take().expect("permit not taken"); + let resp = resp.map(|recv| PermittedRecvStream::new(recv, permit)); + return Poll::Ready(Ok(resp)); + } + } + } + } +} + +/// Background task that streams the request body into an H2 `SendStream`. +/// +/// Spawned by [`RequestFuture`] once the H2 stream is established. Reads frames from +/// the body, respects H2 flow-control by reserving and polling capacity before each +/// write, and sends trailers (or empty trailers) once the body is exhausted. +struct RequestPumpTask { + send_stream: SendStream, + body: B, +} + +impl RequestPumpTask +where + B: http_body::Body + Unpin, + B::Error: Into> + Send, +{ + fn new(send_stream: SendStream, body: B) -> Self { + Self { send_stream, body } + } + + async fn run(mut self) { + if let Err(err) = self.run_inner().await { + debug!(%err, "error while sending request stream"); + + self.send_stream + .send_reset(err.reason().unwrap_or(Reason::INTERNAL_ERROR)); + } + } + + async fn run_inner(&mut self) -> Result<(), h2::Error> { + while let Some(frame) = self.body.frame().await { + match frame { + Ok(frame) => { + if self.handle_frame(frame, self.body.is_end_stream()).await? { + // end stream already sent! + return Ok(()); + } + } + Err(err) => { + debug!("error while reading request stream: {}", err.into()); + return Err(Reason::CANCEL.into()); + } + } + } + + // Send an explicit end stream + self.send_stream.send_trailers(HeaderMap::default())?; + + Ok(()) + } + + /// handle a frame, returns true if it's last frame or trailers. 
It's illegal to + /// send more data frames after handle_frame returns true + async fn handle_frame( + &mut self, + frame: Frame, + end_of_stream: bool, + ) -> Result { + if frame.is_data() { + let mut data = frame.into_data().unwrap(); + + let mut end = false; + while !data.is_empty() { + self.send_stream.reserve_capacity(data.len()); + let size = poll_fn(|cx| self.send_stream.poll_capacity(cx)) + .await + .ok_or(Reason::INTERNAL_ERROR)??; + + let chunk = data.split_to(size.min(data.len())); + end = end_of_stream && data.is_empty(); + self.send_stream.send_data(chunk, end)?; + } + Ok(end) + } else if frame.is_trailers() { + let trailers = frame.into_trailers().unwrap(); + self.send_stream.send_trailers(trailers)?; + Ok(true) + } else { + Err(Reason::PROTOCOL_ERROR.into()) + } + } +} + +/// Response body stream that holds an H2 stream permit for its lifetime. +/// +/// Implements [`http_body::Body`] by delegating to the inner [`RecvStream`], +/// automatically releasing H2 flow-control capacity after each data frame. +/// The semaphore permit is held until this stream is dropped, ensuring the +/// concurrency slot remains occupied while the response body is being consumed. +#[derive(Debug)] +pub struct PermittedRecvStream { + stream: RecvStream, + /// Tracks whether all data frames have been consumed and we should poll trailers next. 
+ data_done: bool, + _permit: Permit, +} + +impl PermittedRecvStream { + fn new(stream: RecvStream, permit: Permit) -> Self { + Self { + stream, + data_done: false, + _permit: permit, + } + } +} + +impl Body for PermittedRecvStream { + type Data = Bytes; + type Error = h2::Error; + + fn poll_frame( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + if !self.data_done { + match ready!(self.stream.poll_data(cx)) { + Some(Ok(data)) => { + let len = data.len(); + let _ = self.stream.flow_control().release_capacity(len); + return Poll::Ready(Some(Ok(Frame::data(data)))); + } + Some(Err(err)) => return Poll::Ready(Some(Err(err))), + None => { + self.data_done = true; + } + } + } + + // Data is exhausted, poll for trailers + match ready!(self.stream.poll_trailers(cx)) { + Ok(Some(trailers)) => Poll::Ready(Some(Ok(Frame::trailers(trailers)))), + Ok(None) => Poll::Ready(None), + Err(err) => Poll::Ready(Some(Err(err))), + } + } + + fn is_end_stream(&self) -> bool { + self.stream.is_end_stream() + } +} + +#[cfg(test)] +mod test { + use bytes::Bytes; + use http::Request; + use http_body::Frame; + use http_body_util::BodyExt; + use tokio::sync::mpsc; + use tokio::task::JoinSet; + + use crate::pool::conn::ConnectionConfigBuilder; + use crate::pool::test_util::{ControlledConnector, TestConnector}; + + use super::Connection; + + /// Sends a request with an empty body and returns the response body stream. + /// The response body must be consumed or dropped by the caller. + async fn send_request(conn: &mut Connection) -> super::PermittedRecvStream { + conn.ready().await.unwrap(); + let resp = conn + .request( + Request::builder() + .uri("http://test-host:80") + .body(http_body_util::Empty::::new()) + .unwrap(), + ) + .await + .unwrap(); + resp.into_body() + } + + /// Client starts with init_max_streams=100 but server advertises + /// max_concurrent_streams=5. 
After the first request round-trip triggers + /// the handshake and a subsequent poll_ready reads the updated setting, + /// the semaphore must shrink to 5. + #[tokio::test] + async fn permits_sync_with_server_max_concurrent_streams() { + let mut connection = Connection::new( + TestConnector::new(5), + ConnectionConfigBuilder::default() + .initial_max_send_streams(100) + .build() + .unwrap(), + ); + + // First request triggers the handshake. Drop the body to release permit. + drop(send_request(&mut connection).await); + + // Second ready() hits the Connected arm which calls semaphore_updater.update(), + // syncing the semaphore down to 5. Send a request to consume that permit too. + drop(send_request(&mut connection).await); + + // Now hold exactly 5 response bodies to exhaust the synced semaphore. + let mut held_bodies = Vec::new(); + for _ in 0..5 { + let mut c = connection.clone(); + held_bodies.push(send_request(&mut c).await); + } + + // 6th try_ready must fail (no permits left) + let mut c6 = connection.clone(); + assert!( + c6.try_ready().is_none(), + "expected try_ready to return None at capacity" + ); + + // Drop all held bodies, permits are released + drop(held_bodies); + } + + /// With max_concurrent_streams=2, holding two response bodies should + /// exhaust permits. Dropping one should free a slot. 
+ #[tokio::test] + async fn try_ready_fails_at_capacity() { + let mut connection = Connection::new( + TestConnector::new(2), + ConnectionConfigBuilder::default() + .initial_max_send_streams(2) + .build() + .unwrap(), + ); + + // Open two streams, hold both response bodies + let body1 = send_request(&mut connection).await; + let body2 = send_request(&mut connection).await; + + // A third try_ready must fail + let mut c3 = connection.clone(); + assert!( + c3.try_ready().is_none(), + "expected try_ready to return None when at capacity" + ); + + // Drop one body, freeing a permit + drop(body1); + + // Now ready should succeed + c3.ready().await.unwrap(); + + drop(body2); + } + + /// Multiple tasks sharing a single Connection can send concurrent + /// requests. The first request triggers the handshake; subsequent + /// requests wait for it and then reuse the same H2 session. + #[tokio::test] + async fn concurrent_requests_on_shared_connection() { + let connection = Connection::new( + TestConnector::new(10), + ConnectionConfigBuilder::default() + .initial_max_send_streams(10) + .build() + .unwrap(), + ); + + let mut handles = JoinSet::default(); + for i in 0u8..50 { + let mut c = connection.clone(); + handles.spawn(async move { + c.ready().await.unwrap(); + let resp = c + .request( + Request::builder() + .uri("http://test-host:80") + .body(http_body_util::Full::new(Bytes::from(vec![i; 4]))) + .unwrap(), + ) + .await + .unwrap(); + + let collected = resp.into_body().collect().await.unwrap().to_bytes(); + assert_eq!( + collected.as_ref(), + &[i; 4], + "response should echo request body" + ); + }); + } + + handles.join_all().await; + } + + /// Sends multiple data frames over a streaming request body and reads + /// back each echoed frame from the response, verifying bidirectional + /// streaming over an open H2 stream. 
+ #[tokio::test] + async fn streaming_request_and_response() { + let mut connection = Connection::new( + TestConnector::new(10), + ConnectionConfigBuilder::default() + .initial_max_send_streams(10) + .build() + .unwrap(), + ); + connection.ready().await.unwrap(); + + let (tx, rx) = mpsc::channel::, std::convert::Infallible>>(10); + let resp = connection + .request( + Request::builder() + .uri("http://test-host:80") + .body(http_body_util::StreamBody::new( + tokio_stream::wrappers::ReceiverStream::new(rx), + )) + .unwrap(), + ) + .await + .unwrap(); + + let mut body = resp.into_body(); + + // Send 3 messages, reading the echo after each one + for i in 0u8..3 { + let msg = Bytes::from(vec![i; 8]); + tx.send(Ok(Frame::data(msg.clone()))).await.unwrap(); + + let frame = body.frame().await.unwrap().unwrap(); + assert_eq!( + frame.data_ref().unwrap().as_ref(), + msg.as_ref(), + "echo for message {i} should match" + ); + } + + // Close the request body stream, then expect trailers from the server + drop(tx); + let trailer_frame = body.frame().await.unwrap().unwrap(); + assert!(trailer_frame.is_trailers(), "expected trailers frame"); + + // Stream should be done + assert!(body.frame().await.is_none()); + } + + fn waiter_count(conn: &Connection) -> usize { + conn.shared.waiters.lock().as_ref().map_or(0, |v| v.len()) + } + + /// Spins until the expected number of waiters are registered on the connection. 
+ async fn wait_for_waiters(conn: &Connection, expected: usize) { + for _ in 0..10_000 { + if waiter_count(conn) >= expected { + return; + } + tokio::task::yield_now().await; + } + panic!( + "timed out waiting for {expected} waiters, got {}", + waiter_count(conn) + ); + } + + #[tokio::test] + async fn waiters_notified_on_successful_connection() { + let (connector, gate) = ControlledConnector::new(10); + let connection = Connection::new( + connector, + ConnectionConfigBuilder::default() + .initial_max_send_streams(10) + .build() + .unwrap(), + ); + + let mut handles = JoinSet::default(); + for i in 0u8..5 { + let mut c = connection.clone(); + handles.spawn(async move { + c.ready().await.unwrap(); + c.request( + Request::builder() + .uri("http://test-host:80") + .body(http_body_util::Full::new(Bytes::from(vec![i; 4]))) + .unwrap(), + ) + .await + }); + } + + // 1 task drives the handshake, the other 4 become waiters + wait_for_waiters(&connection, 4).await; + gate.notify_waiters(); + + let results = handles.join_all().await; + for result in &results { + assert!(result.is_ok(), "all requests should succeed: {result:?}"); + } + assert!(!connection.is_closed()); + } + + #[tokio::test] + async fn waiters_notified_on_connection_failure() { + let (connector, gate) = ControlledConnector::with_error(10); + let connection = Connection::new( + connector, + ConnectionConfigBuilder::default() + .initial_max_send_streams(10) + .build() + .unwrap(), + ); + + let mut handles = JoinSet::default(); + for _ in 0..5 { + let mut c = connection.clone(); + handles.spawn(async move { + c.ready().await.unwrap(); + c.request( + Request::builder() + .uri("http://test-host:80") + .body(http_body_util::Empty::::new()) + .unwrap(), + ) + .await + }); + } + + wait_for_waiters(&connection, 4).await; + gate.notify_waiters(); + + let results = handles.join_all().await; + for result in &results { + assert!(result.is_err(), "all requests should fail: {result:?}"); + } + 
assert!(connection.is_closed()); + } + + #[tokio::test] + async fn waiters_notified_on_driver_drop() { + let (connector, _gate) = ControlledConnector::new(10); + let connection = Connection::new( + connector, + ConnectionConfigBuilder::default() + .initial_max_send_streams(10) + .build() + .unwrap(), + ); + + // First request wins the CAS and becomes the driver + let mut driver = connection.clone(); + driver.ready().await.unwrap(); + let driving_fut = driver.request( + Request::builder() + .uri("http://test-host:80") + .body(http_body_util::Empty::::new()) + .unwrap(), + ); + + // Spawn waiter tasks + let mut handles = JoinSet::default(); + for _ in 0..4 { + let mut c = connection.clone(); + handles.spawn(async move { + c.ready().await.unwrap(); + c.request( + Request::builder() + .uri("http://test-host:80") + .body(http_body_util::Empty::::new()) + .unwrap(), + ) + .await + }); + } + + wait_for_waiters(&connection, 4).await; + + // Drop the driver: Drop impl CAS CONNECTING→CLOSED, calls close() + drop(driving_fut); + + let results = handles.join_all().await; + for result in &results { + let err = result.as_ref().unwrap_err(); + assert!( + matches!(err, super::ConnectionError::Closed), + "expected ConnectionError::Closed, got {err:?}" + ); + } + assert!(connection.is_closed()); + } +} diff --git a/crates/service-client/src/pool/conn/concurrency.rs b/crates/service-client/src/pool/conn/concurrency.rs new file mode 100644 index 0000000000..c68623c1dd --- /dev/null +++ b/crates/service-client/src/pool/conn/concurrency.rs @@ -0,0 +1,334 @@ +// Copyright (c) 2023 - 2026 Restate Software, Inc., Restate GmbH. +// All rights reserved. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0. 
+ +use std::{ + pin::Pin, + sync::{ + Arc, + atomic::{AtomicUsize, Ordering}, + }, + task::{Context, Poll, ready}, +}; + +use tokio::sync::{Notify, futures::OwnedNotified}; + +/// Shared state for the concurrency limiter. +#[derive(Debug)] +struct ConcurrencyInner { + size: AtomicUsize, + inflight: AtomicUsize, +} + +/// An async concurrency limiter that controls how many tasks can run simultaneously. +/// +/// Permits are acquired asynchronously and released automatically when dropped. +/// The concurrency limit can be dynamically resized at runtime. +#[derive(Clone, Debug)] +pub struct Concurrency { + inner: Arc, + notify: Arc, +} + +impl Concurrency { + /// Creates a new concurrency limiter with the given maximum number of permits. + pub fn new(size: usize) -> Self { + Self { + inner: Arc::new(ConcurrencyInner { + size: AtomicUsize::new(size), + inflight: AtomicUsize::new(0), + }), + notify: Arc::new(Notify::new()), + } + } + + /// Returns a future that resolves to a [`Permit`] once capacity is available. + pub fn acquire(&self) -> PermitFuture { + PermitFuture::new(self.clone()) + } + + /// Dynamically changes the concurrency limit. If the new limit is larger than + /// the previous one, waiting tasks are notified so they can attempt to acquire. + /// Existing permits beyond the new limit are not revoked; they drain naturally. + pub fn resize(&self, size: usize) { + let prev = self.inner.size.swap(size, Ordering::Relaxed); + if prev < size { + self.notify.notify_waiters(); + } + } + + /// Returns the current concurrency limit. + pub fn size(&self) -> usize { + self.inner.size.load(Ordering::Relaxed) + } + + /// Returns the number of currently acquired (in-flight) permits. + pub fn acquired(&self) -> usize { + self.inner.inflight.load(Ordering::Relaxed) + } + + /// Returns an approximate number of available permits. 
+ /// + /// This is a best-effort snapshot since `size` and `inflight` are read + /// as two separate atomic loads and may be inconsistent under contention. + pub fn available(&self) -> usize { + self.size().saturating_sub(self.acquired()) + } +} + +/// A guard representing an acquired concurrency permit. +/// The permit is released automatically when dropped, waking one waiting task. +#[derive(Debug)] +pub struct Permit { + inner: Concurrency, +} + +impl Drop for Permit { + fn drop(&mut self) { + let prev = self.inner.inner.inflight.fetch_sub(1, Ordering::Relaxed); + debug_assert!(prev != 0); + self.inner.notify.notify_one(); + } +} + +#[pin_project::pin_project(project=PermitFutureStateProject)] +enum PermitFutureState { + TryAcquire, + Waiting { + #[pin] + notified: OwnedNotified, + }, +} + +/// A future that resolves to a [`Permit`] when concurrency capacity becomes available. +#[pin_project::pin_project] +pub struct PermitFuture { + concurrency: Concurrency, + #[pin] + state: PermitFutureState, +} + +impl PermitFuture { + fn new(concurrency: Concurrency) -> Self { + Self { + concurrency, + state: PermitFutureState::TryAcquire, + } + } +} + +impl Future for PermitFuture { + type Output = Permit; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.project(); + loop { + match this.state.as_mut().project() { + PermitFutureStateProject::TryAcquire => { + let mut size = this.concurrency.inner.size.load(Ordering::Relaxed); + let mut inflight = this.concurrency.inner.inflight.load(Ordering::Relaxed); + + // make sure we register interest in waking up + let notified = this.concurrency.notify.clone().notified_owned(); + + while inflight < size { + match this.concurrency.inner.inflight.compare_exchange( + inflight, + inflight + 1, + Ordering::Relaxed, + Ordering::Relaxed, + ) { + Ok(_) => { + return Poll::Ready(Permit { + inner: this.concurrency.clone(), + }); + } + Err(value) => { + inflight = value; + // also update the size in case 
the size has changed + // to make sure we never go above the limits + size = this.concurrency.inner.size.load(Ordering::Relaxed); + } + } + } + + this.state.set(PermitFutureState::Waiting { notified }); + } + PermitFutureStateProject::Waiting { notified } => { + ready!(notified.poll(cx)); + this.state.set(PermitFutureState::TryAcquire); + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio::sync::oneshot; + + #[tokio::test] + async fn acquire_within_limit() { + let conc = Concurrency::new(2); + let _p1 = conc.acquire().await; + let _p2 = conc.acquire().await; + assert_eq!(conc.acquired(), 2); + assert_eq!(conc.available(), 0); + } + + #[tokio::test] + async fn drop_permit_releases_capacity() { + let conc = Concurrency::new(1); + + let p1 = conc.acquire().await; + assert_eq!(conc.acquired(), 1); + + // Spawn a task that waits for a permit + let conc2 = conc.clone(); + let (tx, rx) = oneshot::channel(); + tokio::spawn(async move { + let _p2 = conc2.acquire().await; + tx.send(conc2.acquired()).unwrap(); + }); + + // Give the spawned task a chance to register as a waiter + tokio::task::yield_now().await; + + // Dropping p1 should wake the waiting task and grant it a permit + drop(p1); + let acquired = rx.await.unwrap(); + assert_eq!(acquired, 1); + } + + #[tokio::test] + async fn resize_up_unblocks_waiters() { + let conc = Concurrency::new(0); + + // Spawn tasks that wait for permits on a zero-sized limiter + let conc1 = conc.clone(); + let (tx1, rx1) = oneshot::channel(); + tokio::spawn(async move { + let _p = conc1.acquire().await; + tx1.send(()).unwrap(); + }); + + let conc2 = conc.clone(); + let (tx2, rx2) = oneshot::channel(); + tokio::spawn(async move { + let _p = conc2.acquire().await; + tx2.send(()).unwrap(); + }); + + // Let both tasks reach the waiting state + tokio::task::yield_now().await; + assert_eq!(conc.acquired(), 0); + + // Resize to 2 — both waiters should be unblocked + conc.resize(2); + + rx1.await.unwrap(); + 
rx2.await.unwrap(); + assert_eq!(conc.size(), 2); + } + + #[tokio::test] + async fn resize_down_drains_existing_permits() { + let conc = Concurrency::new(3); + + let p1 = conc.acquire().await; + let p2 = conc.acquire().await; + let _p3 = conc.acquire().await; + assert_eq!(conc.acquired(), 3); + + // Resize down to 1 — existing permits are not revoked + conc.resize(1); + assert_eq!(conc.acquired(), 3); + + // New acquires should block until enough permits are dropped + let conc2 = conc.clone(); + let (tx, rx) = oneshot::channel(); + tokio::spawn(async move { + let _p = conc2.acquire().await; + tx.send(conc2.acquired()).unwrap(); + }); + + tokio::task::yield_now().await; + + // Drop two permits — inflight goes from 3 to 1, which is at the limit. + // The waiter still can't acquire since 1 + 1 > 1. + drop(p1); + drop(p2); + + // Now inflight is 1. Drop _p3 (held by the outer scope) would unblock, + // but let's resize back up instead to verify combined behavior. + conc.resize(2); + + let acquired = rx.await.unwrap(); + assert_eq!(acquired, 2); + } + + #[tokio::test] + async fn multiple_waiters_released_one_at_a_time() { + let conc = Concurrency::new(1); + let p1 = conc.acquire().await; + + // Each spawned task sends its permit back so it stays alive + let conc2 = conc.clone(); + let (tx1, rx1) = oneshot::channel(); + tokio::spawn(async move { + let p = conc2.acquire().await; + tx1.send(p).unwrap(); + }); + + let conc3 = conc.clone(); + let (tx2, rx2) = oneshot::channel(); + tokio::spawn(async move { + let p = conc3.acquire().await; + tx2.send(p).unwrap(); + }); + + tokio::task::yield_now().await; + + // Drop p1 — only one waiter should be woken (notify_one) + drop(p1); + + // One of the two should complete + let _permit = tokio::select! 
{ + p = rx1 => p.unwrap(), + p = rx2 => p.unwrap(), + }; + + // The other is still waiting since limit is 1 + assert_eq!(conc.acquired(), 1); + } + + #[tokio::test] + async fn resize_up_while_at_capacity() { + let conc = Concurrency::new(1); + let _p1 = conc.acquire().await; + + // Waiter blocked because at capacity — send permit back to keep it alive + let conc2 = conc.clone(); + let (tx, rx) = oneshot::channel(); + tokio::spawn(async move { + let p = conc2.acquire().await; + tx.send(p).unwrap(); + }); + + tokio::task::yield_now().await; + assert_eq!(conc.acquired(), 1); + + // Resize from 1 to 2 — the waiter should now be granted a permit + conc.resize(2); + + let _p2 = rx.await.unwrap(); + assert_eq!(conc.acquired(), 2); + } +} diff --git a/crates/service-client/src/pool/mod.rs b/crates/service-client/src/pool/mod.rs new file mode 100644 index 0000000000..47f5878682 --- /dev/null +++ b/crates/service-client/src/pool/mod.rs @@ -0,0 +1,161 @@ +// Copyright (c) 2023 - 2026 Restate Software, Inc., Restate GmbH. +// All rights reserved. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0. +pub mod conn; +#[cfg(test)] +pub(crate) mod test_util; +pub mod tls; + +use std::{ + io::{self, ErrorKind}, + net::IpAddr, + str::FromStr, + task::{Context, Poll}, + time::Duration, +}; + +use futures::future::BoxFuture; +use http::Uri; +use rustls::pki_types::{DnsName, ServerName}; +use tokio::net::TcpStream; +use tower::Service; +use tracing::trace; + +/// A Tower [`Service`] that establishes TCP connections to a given URI. +/// +/// Extracts the host and port from the URI (defaulting to port 80 for HTTP +/// and 443 for HTTPS) and connects via [`TcpStream`]. 
+#[derive(Debug, Clone, Copy)] +pub struct TcpConnector { + connect_timeout: Duration, +} + +impl TcpConnector { + pub fn new(connect_timeout: Duration) -> Self { + Self { connect_timeout } + } +} + +impl Service for TcpConnector { + type Response = TcpStream; + type Error = io::Error; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: Uri) -> Self::Future { + let connect_timeout = self.connect_timeout; + let fut = async move { + let req = req.get_connection_info(); + trace!("connecting to {:?}:{:?}", req.host, req.port()); + + let host = req + .host() + .ok_or_else(|| io::Error::new(ErrorKind::InvalidInput, "unknown host name"))?; + let port = req + .port() + .ok_or_else(|| io::Error::new(ErrorKind::InvalidInput, "missing port number"))?; + + let stream = tokio::time::timeout(connect_timeout, async { + match host { + Host::IpAddress(addr) => TcpStream::connect((*addr, port)).await, + Host::DnsName(dns) => TcpStream::connect((dns.as_ref(), port)).await, + } + }) + .await + .map_err(|_| io::Error::new(ErrorKind::TimedOut, "connect timeout"))??; + + stream.set_nodelay(true)?; + Ok(stream) + }; + + Box::pin(fut) + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum Host { + IpAddress(IpAddr), + DnsName(DnsName<'static>), +} + +impl From for ServerName<'static> { + fn from(value: Host) -> Self { + match value { + Host::IpAddress(addr) => ServerName::IpAddress(addr.into()), + Host::DnsName(dns) => ServerName::DnsName(dns), + } + } +} + +trait IntoConnectionInfo { + fn get_connection_info(&self) -> ConnectionInfo; +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +enum Schema { + Unknown, + Http, + Https, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +struct ConnectionInfo { + schema: Schema, + host: Option, + port: Option, +} + +impl ConnectionInfo { + pub fn port(&self) -> Option { + self.port + } + + pub fn host(&self) -> Option<&Host> { 
+ self.host.as_ref() + } + + pub fn schema(&self) -> Schema { + self.schema + } +} + +impl IntoConnectionInfo for Uri { + fn get_connection_info(&self) -> ConnectionInfo { + let (schema, default_port) = match self.scheme() { + None => (Schema::Unknown, None), + Some(schema) => match schema.as_str() { + "http" => (Schema::Http, Some(80)), + "https" => (Schema::Https, Some(443)), + _ => (Schema::Unknown, None), + }, + }; + + let port = self.port_u16().or(default_port); + let host = match self.host() { + None => None, + Some(host) => match std::net::IpAddr::from_str(host) { + Ok(addr) => Some(Host::IpAddress(addr)), + Err(_) => DnsName::try_from_str(host) + .ok() + .map(|x| Host::DnsName(x.to_owned())), + }, + }; + + ConnectionInfo { schema, host, port } + } +} + +impl IntoConnectionInfo for http::Request { + fn get_connection_info(&self) -> ConnectionInfo { + self.uri().get_connection_info() + } +} diff --git a/crates/service-client/src/pool/test_util.rs b/crates/service-client/src/pool/test_util.rs new file mode 100644 index 0000000000..f3fd74bf22 --- /dev/null +++ b/crates/service-client/src/pool/test_util.rs @@ -0,0 +1,157 @@ +// Copyright (c) 2023 - 2026 Restate Software, Inc., Restate GmbH. +// All rights reserved. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0. + +use std::{ + io, + sync::{ + Arc, + atomic::{AtomicBool, Ordering}, + }, + task::{Context, Poll}, +}; + +use bytes::Bytes; +use futures::future::BoxFuture; +use http::{Response, StatusCode, Uri}; +use tokio::{io::DuplexStream, sync::Notify}; +use tower::Service; + +/// In-process h2 server configuration. 
+struct ServerConfig { + max_concurrent_streams: u32, +} + +/// A test connector that creates in-memory duplex streams and spawns an +/// h2 server on the other end. +#[derive(Clone)] +pub struct TestConnector { + config: std::sync::Arc, +} + +impl TestConnector { + pub fn new(max_concurrent_streams: u32) -> Self { + Self { + config: std::sync::Arc::new(ServerConfig { + max_concurrent_streams, + }), + } + } +} + +impl Service for TestConnector { + type Response = DuplexStream; + type Error = io::Error; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, _req: Uri) -> Self::Future { + let config = std::sync::Arc::clone(&self.config); + Box::pin(async move { + let (client, server) = tokio::io::duplex(64 * 1024); + tokio::spawn(run_server(server, config)); + Ok(client) + }) + } +} + +/// Runs an h2 server on the given stream. For each request, echoes the +/// request body back in the response and sends empty trailers when done. 
+async fn run_server(stream: DuplexStream, config: std::sync::Arc) { + let mut h2 = h2::server::Builder::new() + .max_concurrent_streams(config.max_concurrent_streams) + .handshake::<_, Bytes>(stream) + .await + .unwrap(); + + while let Some(request) = h2.accept().await { + let (request, mut send_response) = request.unwrap(); + tokio::spawn(async move { + let response = Response::builder().status(StatusCode::OK).body(()).unwrap(); + let mut send_stream = send_response.send_response(response, false).unwrap(); + let mut request_body = request.into_body(); + + while let Some(data) = request_body.data().await { + let data = data.unwrap(); + request_body + .flow_control() + .release_capacity(data.len()) + .unwrap(); + + send_stream.reserve_capacity(data.len()); + let _ = futures::future::poll_fn(|cx| send_stream.poll_capacity(cx)).await; + if send_stream.send_data(data, false).is_err() { + return; + } + } + + let _ = send_stream.send_trailers(http::HeaderMap::new()); + }); + } +} + +/// A test connector that blocks `call()` on a [`tokio::sync::Notify`] gate, +/// giving tests explicit control over when the handshake proceeds. 
+#[derive(Clone)] +pub struct ControlledConnector { + server_config: Arc, + gate: Arc, + force_error: Arc, +} + +impl ControlledConnector { + pub fn new(max_concurrent_streams: u32) -> (Self, Arc) { + let gate = Arc::new(tokio::sync::Notify::new()); + let connector = Self { + server_config: Arc::new(ServerConfig { + max_concurrent_streams, + }), + gate: Arc::clone(&gate), + force_error: Arc::new(AtomicBool::new(false)), + }; + (connector, gate) + } + + pub fn with_error(max_concurrent_streams: u32) -> (Self, Arc) { + let (connector, gate) = Self::new(max_concurrent_streams); + connector.force_error.store(true, Ordering::Relaxed); + (connector, gate) + } +} + +impl Service for ControlledConnector { + type Response = DuplexStream; + type Error = io::Error; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, _req: Uri) -> Self::Future { + let config = Arc::clone(&self.server_config); + let gate = Arc::clone(&self.gate); + let force_error = Arc::clone(&self.force_error); + Box::pin(async move { + gate.notified().await; + if force_error.load(Ordering::Relaxed) { + return Err(io::Error::new( + io::ErrorKind::ConnectionRefused, + "forced error", + )); + } + let (client, server) = tokio::io::duplex(64 * 1024); + tokio::spawn(run_server(server, config)); + Ok(client) + }) + } +} diff --git a/crates/service-client/src/pool/tls.rs b/crates/service-client/src/pool/tls.rs new file mode 100644 index 0000000000..b2b80ebdcf --- /dev/null +++ b/crates/service-client/src/pool/tls.rs @@ -0,0 +1,309 @@ +// Copyright (c) 2023 - 2026 Restate Software, Inc., Restate GmbH. +// All rights reserved. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0. 
+ +//! TLS connection layer for HTTP clients. +//! +//! Provides Tower middleware for establishing TLS connections using rustls + +use std::{ + io, mem, + pin::Pin, + sync::Arc, + task::{Context, Poll, ready}, +}; + +use futures::FutureExt; +use http::Uri; +use rustls::ClientConfig; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_rustls::client::TlsStream; +use tower::{Layer, Service}; + +use crate::pool::{IntoConnectionInfo, Schema}; + +use super::ConnectionInfo; + +/// A Tower [`Layer`] that adds TLS support to transport services. +/// +/// Wraps a TCP connector and performs TLS handshakes for HTTPS URIs, +/// passing through plain connections for HTTP URIs. +pub struct TlsConnectorLayer { + config: ClientConfig, +} + +impl TlsConnectorLayer { + pub fn new(config: ClientConfig) -> Self { + Self { config } + } +} +impl Layer for TlsConnectorLayer { + type Service = TlsConnector; + fn layer(&self, inner: S) -> Self::Service { + TlsConnector::new(inner, self.config.clone()) + } +} + +/// A service that optionally wraps connections in TLS. +/// +/// For HTTPS URIs, performs a TLS handshake with ALPN support for both +/// "http/1.1" and "h2" protocols. For HTTP URIs, passes through the +/// plain connection. Uses native system certificates for validation. 
+#[derive(Clone)] +pub struct TlsConnector { + inner: S, + connector: tokio_rustls::TlsConnector, +} + +impl TlsConnector { + pub fn new(inner: S, mut config: ClientConfig) -> Self { + // only support h2 + config.alpn_protocols = vec!["h2".into()]; + + let connector = tokio_rustls::TlsConnector::from(Arc::new(config)); + Self { inner, connector } + } +} + +impl Service for TlsConnector +where + S: Service + Send + Clone + 'static, + S::Response: AsyncRead + AsyncWrite + Unpin + Send + 'static, + S::Future: Send, +{ + type Error = io::Error; + type Future = TlsConnectorFuture; + type Response = MaybeTlsStream; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Uri) -> Self::Future { + let connection_info = req.get_connection_info(); + + let mut this = mem::replace(self, self.clone()); + let fut = this.inner.call(req); + + TlsConnectorFutureInner::connecting(fut, connection_info, this.connector).into() + } +} + +/// Future returned by [`TlsConnectorService`] that resolves an inner connection future +/// and optionally upgrades it to TLS, yielding a [`MaybeTlsStream`]. +/// +// This wrapper exists to keep [`TlsConnectorFutureInner`] private; exposing the enum +// directly would make all its variants public as well. +#[pin_project::pin_project] +pub struct TlsConnectorFuture { + #[pin] + inner: TlsConnectorFutureInner, +} + +impl From> for TlsConnectorFuture { + fn from(inner: TlsConnectorFutureInner) -> Self { + Self { inner } + } +} + +impl Future for TlsConnectorFuture +where + F: Future>, + IO: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + type Output = Result, io::Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.project().inner.poll(cx) + } +} + +#[pin_project::pin_project(project = TlsConnectorFutureInnerProject)] +enum TlsConnectorFutureInner { + /// Waiting for the TCP connection to be established. 
+ Connecting { + #[pin] + fut: F, + connection_info: ConnectionInfo, + connector: tokio_rustls::TlsConnector, + }, + /// TCP connected, performing TLS handshake. + Handshaking { + // the Connect fut is too big. + // clippy suggested to wrap it in a box + fut: Box>, + }, +} + +impl TlsConnectorFutureInner { + fn connecting( + fut: F, + connection_info: ConnectionInfo, + connector: tokio_rustls::TlsConnector, + ) -> Self { + Self::Connecting { + fut, + connection_info, + connector, + } + } +} + +impl Future for TlsConnectorFutureInner +where + F: Future>, + IO: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + type Output = Result, io::Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self; + + loop { + match this.as_mut().project() { + TlsConnectorFutureInnerProject::Connecting { + fut, + connection_info, + connector, + } => match ready!(fut.poll(cx)) { + Err(err) => { + return Poll::Ready(Err(err)); + } + Ok(socket) => { + let host = match connection_info.host() { + Some(host) => host, + None => { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidData, + "unknown host", + ))); + } + }; + + match connection_info.schema() { + Schema::Unknown => { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidData, + "unknown schema", + ))); + } + Schema::Http => { + return Poll::Ready(Ok(MaybeTlsStream::Plain(socket))); + } + Schema::Https => { + let fut = connector.clone().connect(host.clone().into(), socket); + this.set(TlsConnectorFutureInner::Handshaking { + fut: Box::new(fut), + }); + } + } + } + }, + TlsConnectorFutureInnerProject::Handshaking { fut } => { + match ready!(fut.poll_unpin(cx)) { + Ok(stream) => { + return Poll::Ready(Ok(MaybeTlsStream::TLS(Box::new(stream)))); + } + Err(err) => { + return Poll::Ready(Err(err)); + } + } + } + } + } + } +} + +/// A stream that may or may not be wrapped in TLS. 
+/// +/// Provides a unified interface for both encrypted (HTTPS) and plain (HTTP) +/// connections. Implements [`AsyncRead`] and [`AsyncWrite`] by delegating +/// to the inner stream. +pub enum MaybeTlsStream { + /// A TLS-encrypted stream (HTTPS). + TLS(Box>), + /// A plain, unencrypted stream (HTTP). + Plain(S), +} + +impl AsyncRead for MaybeTlsStream +where + S: AsyncRead + AsyncWrite + Unpin, +{ + fn poll_read( + self: std::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + match self.get_mut() { + MaybeTlsStream::Plain(plain) => { + let plain = Pin::new(plain); + plain.poll_read(cx, buf) + } + MaybeTlsStream::TLS(tls) => { + let tls = Pin::new(tls); + tls.poll_read(cx, buf) + } + } + } +} + +impl AsyncWrite for MaybeTlsStream +where + S: AsyncRead + AsyncWrite + Unpin, +{ + fn poll_write( + self: std::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + match self.get_mut() { + MaybeTlsStream::Plain(plain) => { + let plain = Pin::new(plain); + plain.poll_write(cx, buf) + } + MaybeTlsStream::TLS(tls) => { + let tls = Pin::new(tls); + tls.poll_write(cx, buf) + } + } + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + match self.get_mut() { + MaybeTlsStream::Plain(plain) => { + let plain = Pin::new(plain); + plain.poll_flush(cx) + } + MaybeTlsStream::TLS(tls) => { + let tls = Pin::new(tls); + tls.poll_flush(cx) + } + } + } + + fn poll_shutdown( + self: std::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + match self.get_mut() { + MaybeTlsStream::Plain(plain) => { + let plain = Pin::new(plain); + plain.poll_shutdown(cx) + } + MaybeTlsStream::TLS(tls) => { + let tls = Pin::new(tls); + tls.poll_shutdown(cx) + } + } + } +} diff --git a/crates/service-client/src/proxy.rs b/crates/service-client/src/proxy.rs index 7a8d7ee8d3..8fab1bde08 100644 --- a/crates/service-client/src/proxy.rs +++ b/crates/service-client/src/proxy.rs @@ -15,7 +15,7 @@ 
use std::{ sync::Arc, task::{Context, Poll}, }; -use tower_service::Service; +use tower::Service; #[derive(Clone, Debug)] pub struct ProxyConnector { diff --git a/workspace-hack/Cargo.toml b/workspace-hack/Cargo.toml index 8cb2cd9966..4882a6d98e 100644 --- a/workspace-hack/Cargo.toml +++ b/workspace-hack/Cargo.toml @@ -118,7 +118,7 @@ stable_deref_trait = { version = "1" } syn = { version = "2", features = ["extra-traits", "fold", "full", "visit", "visit-mut"] } sync_wrapper = { version = "1", default-features = false, features = ["futures"] } time = { version = "0.3", features = ["formatting", "local-offset", "macros", "parsing"] } -tokio-rustls = { version = "0.26", default-features = false, features = ["logging", "ring", "tls12"] } +tokio-rustls = { version = "0.26", features = ["ring"] } tokio-stream = { version = "0.1", features = ["net", "sync"] } tokio-util = { version = "0.7", features = ["codec", "io-util", "net", "rt", "time"] } toml_datetime = { version = "0.6", default-features = false, features = ["serde"] } @@ -250,7 +250,7 @@ stable_deref_trait = { version = "1" } syn = { version = "2", features = ["extra-traits", "fold", "full", "visit", "visit-mut"] } sync_wrapper = { version = "1", default-features = false, features = ["futures"] } time = { version = "0.3", features = ["formatting", "local-offset", "macros", "parsing"] } -tokio-rustls = { version = "0.26", default-features = false, features = ["logging", "ring", "tls12"] } +tokio-rustls = { version = "0.26", features = ["ring"] } tokio-stream = { version = "0.1", features = ["net", "sync"] } tokio-util = { version = "0.7", features = ["codec", "io-util", "net", "rt", "time"] } toml_datetime = { version = "0.6", default-features = false, features = ["serde"] } From b441cf98dae9c5208474e1a42e61d2cf53480ba0 Mon Sep 17 00:00:00 2001 From: Muhamad Awad Date: Tue, 31 Mar 2026 17:45:14 +0200 Subject: [PATCH 2/3] Add H2 connection pool with per-authority multiplexing ## Summary - Add AuthorityPool that 
manages multiple H2 connections to a single authority (scheme+host+port), creating connections on demand when streams are exhausted and evicting failed ones - Add Pool that routes requests to the correct AuthorityPool via a DashMap> - Add PoolBuilder with configurable max_connections and init_max_streams per authority --- Cargo.lock | 4 + crates/service-client/Cargo.toml | 9 +- crates/service-client/benches/README.md | 73 +++ .../benches/h2_pool_benchmark.rs | 575 ++++++++++++++++++ crates/service-client/src/pool/authority.rs | 391 ++++++++++++ crates/service-client/src/pool/config.rs | 60 ++ crates/service-client/src/pool/mod.rs | 186 +++++- workspace-hack/Cargo.toml | 4 + 8 files changed, 1298 insertions(+), 4 deletions(-) create mode 100644 crates/service-client/benches/README.md create mode 100644 crates/service-client/benches/h2_pool_benchmark.rs create mode 100644 crates/service-client/src/pool/authority.rs create mode 100644 crates/service-client/src/pool/config.rs diff --git a/Cargo.lock b/Cargo.lock index 79ff715a03..c0a533a00a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7879,6 +7879,8 @@ dependencies = [ "bs58", "bytes", "bytestring", + "criterion", + "dashmap", "derive_builder", "futures", "googletest", @@ -7894,6 +7896,7 @@ dependencies = [ "parking_lot", "pem", "pin-project", + "pprof", "restate-types", "restate-workspace-hack", "ring", @@ -8421,6 +8424,7 @@ dependencies = [ "log", "md-5", "memchr", + "miniz_oxide", "mio", "nom", "num", diff --git a/crates/service-client/Cargo.toml b/crates/service-client/Cargo.toml index 9d2a1c303f..3370b43574 100644 --- a/crates/service-client/Cargo.toml +++ b/crates/service-client/Cargo.toml @@ -25,6 +25,7 @@ base64 = { workspace = true } bs58 = { workspace = true } bytes = { workspace = true } bytestring = { workspace = true } +dashmap = { workspace = true } derive_builder = { workspace = true } futures = { workspace = true } h2 = "0.4.12" @@ -53,6 +54,12 @@ tracing = { workspace = true } zstd = { workspace = true 
} [dev-dependencies] +criterion = { workspace = true, features = ["async_tokio"] } googletest = { workspace = true } +pprof = { version = "0.15", features = ["criterion", "flamegraph"] } tempfile = { workspace = true } -tokio-stream = {workspace = true} \ No newline at end of file +tokio-stream = {workspace = true} + +[[bench]] +name = "h2_pool_benchmark" +harness = false diff --git a/crates/service-client/benches/README.md b/crates/service-client/benches/README.md new file mode 100644 index 0000000000..8dd0b33d87 --- /dev/null +++ b/crates/service-client/benches/README.md @@ -0,0 +1,73 @@ +# Service Client Benchmarks + +Compares the custom H2 connection pool against hyper_util's legacy client and the raw h2 client using in-memory duplex streams. + +## Benchmark groups + +| Group | What it measures | +|-------|-----------------| +| `sequential` | Single request latency | +| `concurrent/{10,50}` | Throughput under H2 multiplexing | +| `body-{1KB,64KB}` | Data echo throughput | + +## Running benchmarks + +Run all benchmarks: + +```bash +cargo bench -p restate-service-client --bench h2_pool_benchmark +``` + +Dry-run (verify they execute without measuring): + +```bash +cargo bench -p restate-service-client --bench h2_pool_benchmark -- --test +``` + +Run a single benchmark by name filter: + +```bash +cargo bench -p restate-service-client --bench h2_pool_benchmark -- "sequential/custom-pool" +``` + +## CPU profiling with pprof (built-in) + +The benchmarks include [pprof](https://github.com/tikv/pprof-rs) integration that generates flamegraph SVGs. 
+ +### Prerequisites + +On Linux, allow perf events: + +```bash +sudo sysctl kernel.perf_event_paranoid=-1 +``` + +### Profile a benchmark + +Pass `--profile-time=` to activate profiling: + +```bash +cargo bench -p restate-service-client --bench h2_pool_benchmark -- "sequential/custom-pool" --profile-time=30 +``` + +The flamegraph is written to: + +``` +target/criterion/sequential/custom-pool/profile/flamegraph.svg +``` + +Open it in a browser for an interactive view. + +## CPU profiling with samply (external) + +[samply](https://github.com/mstange/samply) can profile the benchmark binary without any code changes. + +```bash +# Build the benchmark binary (release mode) +cargo bench -p restate-service-client --bench h2_pool_benchmark --no-run + +# Find and run with samply +samply record target/release/deps/h2_pool_benchmark-* --bench "sequential/custom-pool" --profile-time=30 +``` + +This opens the Firefox Profiler UI automatically. diff --git a/crates/service-client/benches/h2_pool_benchmark.rs b/crates/service-client/benches/h2_pool_benchmark.rs new file mode 100644 index 0000000000..ba14de9abc --- /dev/null +++ b/crates/service-client/benches/h2_pool_benchmark.rs @@ -0,0 +1,575 @@ +// Copyright (c) 2023 - 2026 Restate Software, Inc., Restate GmbH. +// All rights reserved. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0. 
+
+use std::io;
+use std::num::{NonZeroU32, NonZeroUsize};
+use std::pin::Pin;
+use std::sync::Arc;
+use std::task::{Context, Poll};
+use std::time::Duration;
+
+use bytes::Bytes;
+use criterion::{BenchmarkId, Criterion, Throughput, criterion_group, criterion_main};
+use futures::future::BoxFuture;
+use http::{Request, StatusCode, Uri};
+use http_body_util::BodyExt;
+use hyper_util::client::legacy::connect::{Connected, Connection};
+use pprof::criterion::{Output, PProfProfiler};
+use pprof::flamegraph::Options;
+use tokio::io::{AsyncRead, AsyncWrite, DuplexStream, ReadBuf};
+use tokio::runtime::Builder;
+use tokio::task::JoinSet;
+use tower::Service;
+
+use restate_service_client::pool::PoolBuilder;
+use restate_types::errors::GenericError;
+
+const MAX_CONCURRENT_STREAMS: u32 = 100;
+
+// ---------------------------------------------------------------------------
+// Raw h2 helpers
+// ---------------------------------------------------------------------------
+
+/// Creates a raw h2 client `SendRequest` over an in-memory duplex, with the
+/// echo server on the other end and the connection driver spawned.
+async fn make_raw_h2() -> h2::client::SendRequest<Bytes> {
+    let (client_stream, server_stream) = tokio::io::duplex(64 * 1024);
+    let config = Arc::new(ServerConfig {
+        max_concurrent_streams: MAX_CONCURRENT_STREAMS,
+    });
+    tokio::spawn(run_server(server_stream, config));
+
+    let (send_request, connection) = h2::client::Builder::new()
+        .initial_max_send_streams(MAX_CONCURRENT_STREAMS as usize)
+        .handshake::<_, Bytes>(client_stream)
+        .await
+        .unwrap();
+    tokio::spawn(async move {
+        let _ = connection.await;
+    });
+    send_request
+}
+
+/// Send a request with body via raw h2 and drain the echoed response.
+async fn raw_h2_request(send_request: &h2::client::SendRequest<Bytes>, payload: Bytes) {
+    let mut ready = send_request.clone().ready().await.unwrap();
+    let req = Request::builder()
+        .uri("http://bench-host:80")
+        .body(())
+        .unwrap();
+    let (resp_fut, mut send_stream) = ready.send_request(req, payload.is_empty()).unwrap();
+
+    // Send the body
+    send_stream.reserve_capacity(payload.len());
+    let _ = futures::future::poll_fn(|cx| send_stream.poll_capacity(cx)).await;
+    if !payload.is_empty() {
+        send_stream.send_data(payload, true).unwrap();
+    }
+
+    // Drain response
+    let resp = resp_fut.await.unwrap();
+    let mut body = resp.into_body();
+    while let Some(chunk) = body.data().await {
+        let chunk = chunk.unwrap();
+        body.flow_control().release_capacity(chunk.len()).unwrap();
+    }
+}
+
+// ---------------------------------------------------------------------------
+// Shared in-memory H2 echo server infrastructure
+// ---------------------------------------------------------------------------
+
+struct ServerConfig {
+    max_concurrent_streams: u32,
+}
+
+/// A connector that creates in-memory duplex streams and spawns an H2 echo
+/// server on the other end. Used by the custom pool benchmarks.
+#[derive(Clone)] +struct TestConnector { + config: Arc, +} + +impl TestConnector { + fn new(max_concurrent_streams: u32) -> Self { + Self { + config: Arc::new(ServerConfig { + max_concurrent_streams, + }), + } + } +} + +impl Service for TestConnector { + type Response = DuplexStream; + type Error = io::Error; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, _req: Uri) -> Self::Future { + let config = Arc::clone(&self.config); + Box::pin(async move { + let (client, server) = tokio::io::duplex(64 * 1024); + tokio::spawn(run_server(server, config)); + Ok(client) + }) + } +} + +/// A wrapper around `DuplexStream` that implements hyper_util's `Connection` +/// trait (plus hyper's `Read`/`Write`) for use with the legacy client. +struct DuplexConnection(DuplexStream); + +impl Connection for DuplexConnection { + fn connected(&self) -> Connected { + Connected::new() + } +} + +impl hyper::rt::Read for DuplexConnection { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + mut buf: hyper::rt::ReadBufCursor<'_>, + ) -> Poll> { + let n = { + let mut tbuf = ReadBuf::uninit(unsafe { buf.as_mut() }); + match Pin::new(&mut self.0).poll_read(cx, &mut tbuf) { + Poll::Ready(Ok(())) => tbuf.filled().len(), + other => return other, + } + }; + unsafe { buf.advance(n) }; + Poll::Ready(Ok(())) + } +} + +impl hyper::rt::Write for DuplexConnection { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.0).poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_shutdown(cx) + } +} + +/// A connector wrapping `TestConnector` that returns `DuplexConnection`, +/// which implements the traits required by hyper_util's 
legacy client. +#[derive(Clone)] +struct HyperUtilConnector { + inner: TestConnector, +} + +impl HyperUtilConnector { + fn new(max_concurrent_streams: u32) -> Self { + Self { + inner: TestConnector::new(max_concurrent_streams), + } + } +} + +impl Service for HyperUtilConnector { + type Response = DuplexConnection; + type Error = io::Error; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Uri) -> Self::Future { + let fut = self.inner.call(req); + Box::pin(async move { fut.await.map(DuplexConnection) }) + } +} + +/// Runs an H2 echo server: for each request, echoes the request body back +/// and sends empty trailers when done. +async fn run_server(stream: DuplexStream, config: Arc) { + let mut h2 = h2::server::Builder::new() + .max_concurrent_streams(config.max_concurrent_streams) + .handshake::<_, Bytes>(stream) + .await + .unwrap(); + + while let Some(request) = h2.accept().await { + let (request, mut respond) = request.unwrap(); + tokio::spawn(async move { + let response = http::Response::builder() + .status(StatusCode::OK) + .body(()) + .unwrap(); + let mut send_stream = respond.send_response(response, false).unwrap(); + let mut recv_body = request.into_body(); + + while let Some(data) = recv_body.data().await { + let data = data.unwrap(); + recv_body + .flow_control() + .release_capacity(data.len()) + .unwrap(); + + send_stream.reserve_capacity(data.len()); + let _ = futures::future::poll_fn(|cx| send_stream.poll_capacity(cx)).await; + if send_stream.send_data(data, false).is_err() { + return; + } + } + + let _ = send_stream.send_trailers(http::HeaderMap::new()); + }); + } +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +fn make_custom_pool(max_connections: usize) -> restate_service_client::pool::Pool { + 
PoolBuilder::default() + .max_connections(NonZeroUsize::new(max_connections).unwrap()) + .initial_max_send_streams(NonZeroU32::new(MAX_CONCURRENT_STREAMS).unwrap()) + .build(TestConnector::new(MAX_CONCURRENT_STREAMS)) +} + +type BoxBody = http_body_util::combinators::BoxBody; + +fn make_hyper_util_client() -> hyper_util::client::legacy::Client { + hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::default()) + .timer(hyper_util::rt::TokioTimer::default()) + .http2_only(true) + .build(HyperUtilConnector::new(MAX_CONCURRENT_STREAMS)) +} + +fn empty_request(uri: &str) -> Request> { + Request::builder() + .uri(uri) + .body(http_body_util::Empty::::new()) + .unwrap() +} + +fn body_request(uri: &str, payload: Bytes) -> Request> { + Request::builder() + .uri(uri) + .body(http_body_util::Full::new(payload)) + .unwrap() +} + +fn boxed_empty_request(uri: &str) -> Request { + Request::builder() + .uri(uri) + .body( + http_body_util::Empty::::new() + .map_err(|e| e.into()) + .boxed(), + ) + .unwrap() +} + +fn boxed_body_request(uri: &str, payload: Bytes) -> Request { + Request::builder() + .uri(uri) + .body( + http_body_util::Full::new(payload) + .map_err(|e| e.into()) + .boxed(), + ) + .unwrap() +} + +fn flamegraph_options<'a>() -> Options<'a> { + #[allow(unused_mut)] + let mut options = Options::default(); + if cfg!(target_os = "macos") { + options.base = vec!["__pthread_joiner_wake".to_string(), "_main".to_string()]; + } + options +} + +// --------------------------------------------------------------------------- +// Benchmarks +// --------------------------------------------------------------------------- + +fn bench_sequential_requests(c: &mut Criterion) { + let rt = Builder::new_multi_thread().enable_all().build().unwrap(); + let mut group = c.benchmark_group("sequential"); + group + .sample_size(50) + .measurement_time(Duration::from_secs(10)); + + group.bench_function("custom-pool", |b| { + let pool = make_custom_pool(1); + // Warm up: 
establish H2 connection + rt.block_on(async { + let resp = pool + .request(empty_request("http://bench-host:80")) + .await + .unwrap(); + drop(resp.into_body().collect().await); + }); + + b.to_async(&rt).iter(|| { + let pool = pool.clone(); + async move { + let resp = pool + .request(empty_request("http://bench-host:80")) + .await + .unwrap(); + resp.into_body().collect().await.unwrap(); + } + }); + }); + + group.bench_function("hyper-util-legacy", |b| { + let client = make_hyper_util_client(); + // Warm up + rt.block_on(async { + let resp = client + .request(boxed_empty_request("http://bench-host:80")) + .await + .unwrap(); + drop(resp.into_body().collect().await); + }); + + b.to_async(&rt).iter(|| { + let client = client.clone(); + async move { + let resp = client + .request(boxed_empty_request("http://bench-host:80")) + .await + .unwrap(); + resp.into_body().collect().await.unwrap(); + } + }); + }); + + group.bench_function("h2-raw", |b| { + let send_request = rt.block_on(make_raw_h2()); + // Warm up + rt.block_on(raw_h2_request(&send_request, Bytes::default())); + + b.to_async(&rt).iter(|| { + let send_request = send_request.clone(); + async move { + raw_h2_request(&send_request, Bytes::default()).await; + } + }); + }); + + group.finish(); +} + +fn bench_concurrent_requests(c: &mut Criterion) { + let rt = Builder::new_multi_thread().enable_all().build().unwrap(); + + for concurrency in [10, 50] { + let mut group = c.benchmark_group("concurrent"); + group + .sample_size(30) + .measurement_time(Duration::from_secs(15)) + .throughput(Throughput::Elements(concurrency as u64)); + + group.bench_with_input( + BenchmarkId::new("custom-pool", concurrency), + &concurrency, + |b, &n| { + let pool = make_custom_pool(1); + // Warm up + rt.block_on(async { + let resp = pool + .request(empty_request("http://bench-host:80")) + .await + .unwrap(); + drop(resp.into_body().collect().await); + }); + + b.to_async(&rt).iter(|| { + let pool = pool.clone(); + async move { + let mut 
set = JoinSet::new(); + for _ in 0..n { + let pool = pool.clone(); + set.spawn(async move { + let resp = pool + .request(empty_request("http://bench-host:80")) + .await + .unwrap(); + resp.into_body().collect().await.unwrap(); + }); + } + set.join_all().await; + } + }); + }, + ); + + group.bench_with_input( + BenchmarkId::new("hyper-util-legacy", concurrency), + &concurrency, + |b, &n| { + let client = make_hyper_util_client(); + // Warm up + rt.block_on(async { + let resp = client + .request(boxed_empty_request("http://bench-host:80")) + .await + .unwrap(); + drop(resp.into_body().collect().await); + }); + + b.to_async(&rt).iter(|| { + let client = client.clone(); + async move { + let mut set = JoinSet::new(); + for _ in 0..n { + let client = client.clone(); + set.spawn(async move { + let resp = client + .request(boxed_empty_request("http://bench-host:80")) + .await + .unwrap(); + resp.into_body().collect().await.unwrap(); + }); + } + set.join_all().await; + } + }); + }, + ); + + group.bench_with_input( + BenchmarkId::new("h2-raw", concurrency), + &concurrency, + |b, &n| { + let send_request = rt.block_on(make_raw_h2()); + // Warm up + rt.block_on(raw_h2_request(&send_request, Bytes::default())); + + b.to_async(&rt).iter(|| { + let send_request = send_request.clone(); + async move { + let mut set = JoinSet::new(); + for _ in 0..n { + let send_request = send_request.clone(); + set.spawn(async move { + raw_h2_request(&send_request, Bytes::default()).await; + }); + } + set.join_all().await; + } + }); + }, + ); + + group.finish(); + } +} + +fn bench_body_throughput(c: &mut Criterion) { + let rt = Builder::new_multi_thread().enable_all().build().unwrap(); + + for (label, size) in [("1KB", 1024), ("64KB", 64 * 1024)] { + let payload = Bytes::from(vec![0xABu8; size]); + let mut group = c.benchmark_group(format!("body-{label}")); + group + .sample_size(30) + .measurement_time(Duration::from_secs(10)) + .throughput(Throughput::Bytes(size as u64)); + + 
group.bench_function("custom-pool", |b| { + let pool = make_custom_pool(1); + let payload = payload.clone(); + // Warm up + rt.block_on(async { + let resp = pool + .request(body_request("http://bench-host:80", payload.clone())) + .await + .unwrap(); + resp.into_body().collect().await.unwrap(); + }); + + b.to_async(&rt).iter(|| { + let pool = pool.clone(); + let payload = payload.clone(); + async move { + let resp = pool + .request(body_request("http://bench-host:80", payload)) + .await + .unwrap(); + let collected = resp.into_body().collect().await.unwrap(); + std::hint::black_box(collected); + } + }); + }); + + group.bench_function("hyper-util-legacy", |b| { + let client = make_hyper_util_client(); + let payload = payload.clone(); + // Warm up + rt.block_on(async { + let resp = client + .request(boxed_body_request("http://bench-host:80", payload.clone())) + .await + .unwrap(); + resp.into_body().collect().await.unwrap(); + }); + + b.to_async(&rt).iter(|| { + let client = client.clone(); + let payload = payload.clone(); + async move { + let resp = client + .request(boxed_body_request("http://bench-host:80", payload)) + .await + .unwrap(); + let collected = resp.into_body().collect().await.unwrap(); + std::hint::black_box(collected); + } + }); + }); + + group.bench_function("h2-raw", |b| { + let send_request = rt.block_on(make_raw_h2()); + let payload = payload.clone(); + // Warm up + rt.block_on(raw_h2_request(&send_request, payload.clone())); + + b.to_async(&rt).iter(|| { + let send_request = send_request.clone(); + let payload = payload.clone(); + async move { + raw_h2_request(&send_request, payload).await; + } + }); + }); + + group.finish(); + } +} + +criterion_group!( + name = benches; + config = Criterion::default() + .with_profiler(PProfProfiler::new(997, Output::Flamegraph(Some(flamegraph_options())))); + targets = bench_sequential_requests, bench_concurrent_requests, bench_body_throughput +); +criterion_main!(benches); diff --git 
a/crates/service-client/src/pool/authority.rs b/crates/service-client/src/pool/authority.rs new file mode 100644 index 0000000000..503ea49d66 --- /dev/null +++ b/crates/service-client/src/pool/authority.rs @@ -0,0 +1,391 @@ +// Copyright (c) 2023 - 2026 Restate Software, Inc., Restate GmbH. +// All rights reserved. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0. + +//! A pool of HTTP/2 connections to a single HTTP authority (scheme + host + port). +//! +//! [`AuthorityPool`] manages multiple [`Connection`] instances, creating new ones +//! on demand when existing connections are fully utilized, and evicting connections +//! that have failed. + +use std::sync::Arc; +use std::task::{Context, Poll}; + +use bytes::Bytes; +use http::Uri; +use http_body::Body; +use parking_lot::Mutex; +use tokio::io::{AsyncRead, AsyncWrite}; +use tower::Service; + +use crate::pool::PoolConfig; +use crate::pool::conn::ConnectionConfigBuilder; + +use super::conn::{Connection, ConnectionError, ResponseFuture}; + +/// Shared mutable state for the pool. +struct AuthorityPoolInner { + connections: Vec>, +} + +/// A pool of HTTP/2 connections to a single HTTP authority. +/// +/// Manages multiple [`Connection`] instances, creating new ones on demand when +/// all existing connections are fully utilized (no available H2 streams), and +/// evicting connections that have entered a closed/failed state. +/// +/// Cloning an `AuthorityPool` shares the underlying connection set; each clone +/// maintains its own per-handle state for the `poll_ready`/`call` cycle. +pub struct AuthorityPool { + connector: C, + config: PoolConfig, + inner: Arc>>, + /// The readied connection (permit acquired). Consumed by [`call`]. 
+ ready: Option>, + /// Connections being polled for readiness. When all connections are at + /// capacity, we poll all of them so we're woken no matter which one + /// frees up a stream. + candidates: Vec>, +} + +impl Clone for AuthorityPool { + fn clone(&self) -> Self { + Self { + connector: self.connector.clone(), + config: self.config.clone(), + inner: Arc::clone(&self.inner), + ready: None, + candidates: Vec::new(), + } + } +} + +impl AuthorityPool +where + C: Service + Clone, + C::Response: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static, + C::Future: Send + 'static, + C::Error: Into, +{ + pub fn new(connector: C, config: PoolConfig) -> Self { + Self { + connector, + config, + inner: Arc::new(Mutex::new(AuthorityPoolInner { + connections: Vec::new(), + })), + ready: None, + candidates: Vec::new(), + } + } + + pub fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + if self.ready.is_some() { + return Poll::Ready(Ok(())); + } + + loop { + let mut failure = None; + for candidate in &mut self.candidates { + match candidate.poll_ready(cx) { + Poll::Ready(Ok(_)) => { + // since we polled the candidate within the candidates + // vec. We need to take it out, so we swap it with a copy + // of itself + let mut ready = candidate.clone(); + std::mem::swap(candidate, &mut ready); + self.ready = Some(ready); + self.candidates.clear(); + return Poll::Ready(Ok(())); + } + Poll::Ready(Err(err)) => { + // record the fact we failed on one connection + // but we keep trying other candidates! If all + // candidates failed we try with a fresh set of candidates + failure = Some(err); + } + Poll::Pending => { + // We will try the next candidate! + // + // Waker registered inside conn.poll_ready — we'll be + // woken when this connection's h2/permits frees up. + } + } + } + + self.candidates.retain(|c| !c.is_closed()); + + if let Some(err) = failure + && self.candidates.is_empty() + { + // no more candidates to check. 
+ // delegate the error back to the caller + // they can decide to retry + return Poll::Ready(Err(err)); + } + + if !self.candidates.is_empty() { + return Poll::Pending; + } + + // extend the candidates from the current set of connections. + let mut inner = self.inner.lock(); + + // evict and filter out exhausted candidates + let mut i = 0usize; + while !inner.connections.is_empty() && i < inner.connections.len() { + let candidate = &mut inner.connections[i]; + if candidate.is_closed() { + inner.connections.swap_remove(i); + continue; + } + i += 1; + if candidate.available_streams() == 0 { + continue; + } + self.candidates.push(candidate.clone()); + } + + if !self.candidates.is_empty() { + continue; + } + + // No connection with available capacity. Create a new one if under limit. + if inner.connections.len() < self.config.max_connections.get() { + let mut candidate = Connection::new( + self.connector.clone(), + ConnectionConfigBuilder::default() + .initial_max_send_streams(self.config.initial_max_send_streams.get()) + .keep_alive_interval(self.config.keep_alive_interval) + .keep_alive_timeout(self.config.keep_alive_timeout) + .build() + .unwrap(), + ); + inner.connections.push(candidate.clone()); + drop(inner); + + match candidate.poll_ready(cx) { + Poll::Pending => {} + Poll::Ready(Ok(_)) => { + self.ready = Some(candidate); + self.candidates.clear(); + return Poll::Ready(Ok(())); + } + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + } + } + } + } + + /// Sends a request over a connection selected by [`poll_ready`]. + /// + /// # Panics + /// Panics if called without a prior successful [`poll_ready`]. 
+ pub fn call(&mut self, request: http::Request) -> ResponseFuture + where + B: Body + Send + Sync + 'static, + B::Error: Into>, + { + let conn = self + .ready + .as_mut() + .expect("call() invoked without prior poll_ready()"); + let fut = conn.request(request); + self.ready = None; + fut + } +} + +#[cfg(test)] +mod test { + use std::num::{NonZeroU32, NonZeroUsize}; + use std::task::Poll; + + use bytes::Bytes; + use http::Request; + use http_body_util::BodyExt; + + use crate::pool::conn::PermittedRecvStream; + use crate::pool::test_util::TestConnector; + + use super::{AuthorityPool, PoolConfig}; + + fn make_pool( + max_concurrent_streams: u32, + max_connections: usize, + ) -> AuthorityPool { + AuthorityPool::new( + TestConnector::new(max_concurrent_streams), + PoolConfig { + max_connections: NonZeroUsize::new(max_connections).unwrap(), + initial_max_send_streams: NonZeroU32::new(max_concurrent_streams).unwrap(), + ..Default::default() + }, + ) + } + + async fn send_empty_request(pool: &mut AuthorityPool) -> PermittedRecvStream { + futures::future::poll_fn(|cx| pool.poll_ready(cx)) + .await + .unwrap(); + let resp = pool + .call( + Request::builder() + .uri("http://test-host:80") + .body(http_body_util::Empty::::new()) + .unwrap(), + ) + .await + .unwrap(); + resp.into_body() + } + + /// First request creates a connection; the pool starts empty. + #[tokio::test] + async fn creates_connection_on_demand() { + let mut pool = make_pool(10, 4); + + { + let inner = pool.inner.lock(); + assert_eq!(inner.connections.len(), 0); + } + + let body = send_empty_request(&mut pool).await; + + { + let inner = pool.inner.lock(); + assert_eq!(inner.connections.len(), 1); + } + + drop(body); + } + + /// When all streams on existing connections are busy, a new connection is + /// created (up to max_connections). + #[tokio::test] + async fn scales_up_when_streams_exhausted() { + // 1 stream per connection, max 3 connections. 
+ let mut pool = make_pool(1, 3); + + // Hold response bodies to keep streams occupied. + let b1 = send_empty_request(&mut pool).await; + + // Second request should trigger a second connection. + let b2 = send_empty_request(&mut pool).await; + + { + let inner = pool.inner.lock(); + assert_eq!(inner.connections.len(), 2); + } + + // Third request -> third connection. + let b3 = send_empty_request(&mut pool).await; + + { + let inner = pool.inner.lock(); + assert_eq!(inner.connections.len(), 3); + } + + drop((b1, b2, b3)); + } + + /// The pool does not create more connections than max_connections. + /// When at capacity and all streams busy, poll_ready returns Pending. + /// Dropping a held response body frees a stream and unblocks poll_ready. + #[tokio::test] + async fn respects_max_connections() { + // 1 stream per connection, max 2 connections. + let mut pool = make_pool(1, 2); + + let b1 = send_empty_request(&mut pool).await; + let b2 = send_empty_request(&mut pool).await; + + { + let inner = pool.inner.lock(); + assert_eq!(inner.connections.len(), 2); + } + + // Third poll_ready should return Pending (no capacity). + let mut pool_clone = pool.clone(); + let result = futures::future::poll_fn(|cx| match pool_clone.poll_ready(cx) { + Poll::Ready(r) => Poll::Ready(Some(r)), + Poll::Pending => Poll::Ready(None), + }) + .await; + assert!(result.is_none(), "expected Pending when at max capacity"); + + // Drop one body, freeing a stream. + drop(b1); + + // Now poll_ready should succeed (wakers were registered on all connections). + futures::future::poll_fn(|cx| pool_clone.poll_ready(cx)) + .await + .unwrap(); + + drop(b2); + } + + /// Cloned pools share the same connection set. + #[tokio::test] + async fn clones_share_connections() { + let pool = make_pool(10, 4); + let mut p1 = pool.clone(); + let mut p2 = pool.clone(); + + let _b1 = send_empty_request(&mut p1).await; + + // p2 should see the connection created by p1. 
+ { + let inner = p2.inner.lock(); + assert_eq!(inner.connections.len(), 1); + } + + let _b2 = send_empty_request(&mut p2).await; + + // Still 1 connection (10 streams available, only 2 used). + { + let inner = p1.inner.lock(); + assert_eq!(inner.connections.len(), 1); + } + } + + /// Concurrent requests with body echo work correctly through the pool. + #[tokio::test] + async fn concurrent_requests_with_echo() { + let pool = make_pool(10, 4); + let mut handles = tokio::task::JoinSet::default(); + + for i in 0u8..5 { + let mut p = pool.clone(); + handles.spawn(async move { + futures::future::poll_fn(|cx| p.poll_ready(cx)) + .await + .unwrap(); + let resp = p + .call( + Request::builder() + .uri("http://test-host:80") + .body(http_body_util::Full::new(Bytes::from(vec![i; 4]))) + .unwrap(), + ) + .await + .unwrap(); + + let collected = resp.into_body().collect().await.unwrap().to_bytes(); + assert_eq!( + collected.as_ref(), + &[i; 4], + "response should echo request body" + ); + }); + } + + handles.join_all().await; + } +} diff --git a/crates/service-client/src/pool/config.rs b/crates/service-client/src/pool/config.rs new file mode 100644 index 0000000000..07c71815d6 --- /dev/null +++ b/crates/service-client/src/pool/config.rs @@ -0,0 +1,60 @@ +// Copyright (c) 2023 - 2026 Restate Software, Inc., Restate GmbH. +// All rights reserved. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0. + +use std::num::{NonZeroU32, NonZeroUsize}; +use std::time::Duration; + +/// Configuration for an [`AuthorityPool`]. 
+#[derive(Debug, Clone, derive_builder::Builder)] +#[builder( + pattern = "owned", + build_fn(name = "build_inner", private), + name = "PoolBuilder", + default +)] +pub struct PoolConfig { + /// Maximum number of connections to open to a single authority. + pub(crate) max_connections: NonZeroUsize, + + /// Initial max H2 send streams per connection (passed to [`Connection::new`]). + /// + /// Most HTTP/2 frameworks default to 100 max-concurrent-streams. We use a + /// lower initial value of 50 so the pool scales up sooner under load, + /// limiting the number of requests queued behind a single pending connection. + /// Once the connection is established, it discovers the remote peer's actual + /// max-concurrent-streams and adjusts accordingly. + #[builder(default = NonZeroU32::new(50).unwrap())] + pub(crate) initial_max_send_streams: NonZeroU32, + /// Maximum time to wait for an HTTP/2 PING response before declaring the + /// connection dead and returning [`ConnectionError::KeepAliveTimeout`]. + /// Only meaningful when `keep_alive_interval` is `Some`. Defaults to 20 s. + pub(crate) keep_alive_timeout: Duration, + /// How often to send HTTP/2 PING frames to keep idle connections alive. + /// `None` disables keep-alive pings entirely. Defaults to `None`. 
+ pub(crate) keep_alive_interval: Option, +} + +impl Default for PoolConfig { + fn default() -> Self { + Self { + max_connections: NonZeroUsize::new(1).unwrap(), + initial_max_send_streams: NonZeroU32::new(50).unwrap(), + keep_alive_interval: None, + keep_alive_timeout: Duration::from_secs(20), + } + } +} + +impl PoolBuilder { + pub fn build(self, connector: C) -> super::Pool { + let config = self.build_inner().unwrap(); + super::Pool::new(connector, config) + } +} diff --git a/crates/service-client/src/pool/mod.rs b/crates/service-client/src/pool/mod.rs index 47f5878682..9eec9f1693 100644 --- a/crates/service-client/src/pool/mod.rs +++ b/crates/service-client/src/pool/mod.rs @@ -7,26 +7,107 @@ // As of the Change Date specified in that file, in accordance with // the Business Source License, use of this software will be governed // by the Apache License, Version 2.0. + +pub mod authority; +mod config; pub mod conn; #[cfg(test)] pub(crate) mod test_util; pub mod tls; use std::{ + future::poll_fn, + hash::Hash, io::{self, ErrorKind}, net::IpAddr, str::FromStr, + sync::Arc, task::{Context, Poll}, time::Duration, }; +use bytes::Bytes; +use dashmap::DashMap; use futures::future::BoxFuture; -use http::Uri; +use http::{Response, Uri}; +use http_body::Body; use rustls::pki_types::{DnsName, ServerName}; -use tokio::net::TcpStream; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + net::TcpStream, +}; use tower::Service; use tracing::trace; +use crate::pool::{authority::AuthorityPool, conn::PermittedRecvStream}; + +pub use config::PoolBuilder; +use config::PoolConfig; +pub use conn::ConnectionError; + +#[derive(Clone)] +pub struct Pool { + connector: C, + config: PoolConfig, + authorities: Arc>>, +} + +impl Pool { + fn new(connector: C, config: PoolConfig) -> Self { + Self { + config, + connector, + authorities: Arc::new(DashMap::default()), + } + } +} + +impl Pool +where + C: Service + Send + Clone + 'static, + C::Response: AsyncRead + AsyncWrite + Unpin + Send + Sync + 
'static, + C::Future: Send + 'static, + C::Error: Into, +{ + pub fn request( + &self, + request: http::Request, + ) -> impl Future, ConnectionError>> + Send + 'static + where + B: Body + Unpin + Send + Sync + 'static, + B::Error: Into> + Send, + { + let key = PoolKey::from_uri(request.uri()); + + let mut authority_pool = self + .authorities + .entry(key) + .or_insert_with(|| AuthorityPool::new(self.connector.clone(), self.config.clone())) + .value() + .clone(); + + async move { + poll_fn(|cx| authority_pool.poll_ready(cx)).await?; + authority_pool.call(request).await + } + } +} + +#[derive(PartialEq, Eq, Hash)] +struct PoolKey { + scheme: Option, + authority: Option, +} + +impl PoolKey { + fn from_uri(u: &Uri) -> Self { + Self { + scheme: u.scheme().cloned(), + authority: u.authority().cloned(), + } + } +} + /// A Tower [`Service`] that establishes TCP connections to a given URI. /// /// Extracts the host and port from the URI (defaulting to port 80 for HTTP @@ -82,7 +163,7 @@ impl Service for TcpConnector { } #[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub enum Host { +enum Host { IpAddress(IpAddr), DnsName(DnsName<'static>), } @@ -159,3 +240,102 @@ impl IntoConnectionInfo for http::Request { self.uri().get_connection_info() } } + +#[cfg(test)] +mod test { + + use bytes::Bytes; + use http::{Request, Uri}; + use http_body_util::BodyExt; + + use crate::pool::PoolBuilder; + use crate::pool::test_util::TestConnector; + + fn make_pool( + max_concurrent_streams: u32, + max_connections: usize, + ) -> super::Pool { + PoolBuilder::default() + .max_connections(std::num::NonZeroUsize::new(max_connections).unwrap()) + .initial_max_send_streams(std::num::NonZeroU32::new(max_concurrent_streams).unwrap()) + .build(TestConnector::new(max_concurrent_streams)) + } + + /// Requests to different hosts create separate authority pools. 
+ #[tokio::test] + async fn routes_to_separate_authorities() { + let pool = make_pool(10, 4); + + assert_eq!(pool.authorities.len(), 0); + + pool.request( + Request::builder() + .uri("http://host-a:80") + .body(http_body_util::Empty::::new()) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(pool.authorities.len(), 1); + + pool.request( + Request::builder() + .uri("http://host-b:80") + .body(http_body_util::Empty::::new()) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(pool.authorities.len(), 2); + } + + /// Multiple requests to the same authority reuse the same pool entry. + #[tokio::test] + async fn same_authority_shares_pool() { + let pool = make_pool(10, 4); + + for _ in 0..3 { + pool.request( + Request::builder() + .uri("http://host-a:80") + .body(http_body_util::Empty::::new()) + .unwrap(), + ) + .await + .unwrap(); + } + + assert_eq!(pool.authorities.len(), 1); + } + + /// Requests to multiple authorities with echo payloads all resolve correctly. + #[tokio::test] + async fn multi_authority_echo() { + let pool = make_pool(10, 4); + + for (i, host) in ["host-a", "host-b", "host-c"].iter().enumerate() { + let uri: Uri = format!("http://{}:80", host).parse().unwrap(); + let resp = pool + .request( + Request::builder() + .uri(uri) + .body(http_body_util::Full::new(Bytes::from(vec![i as u8; 4]))) + .unwrap(), + ) + .await + .unwrap(); + + let collected = resp.into_body().collect().await.unwrap().to_bytes(); + assert_eq!( + collected.as_ref(), + &[i as u8; 4], + "response should echo request body for {}", + host, + ); + } + + assert_eq!(pool.authorities.len(), 3); + } +} diff --git a/workspace-hack/Cargo.toml b/workspace-hack/Cargo.toml index 4882a6d98e..75aaaf01f3 100644 --- a/workspace-hack/Cargo.toml +++ b/workspace-hack/Cargo.toml @@ -285,6 +285,7 @@ getrandom-9fbad63c4bcf4a8f = { package = "getrandom", version = "0.4", default-f hyper-rustls = { version = "0.27", default-features = false, features = ["webpki-tokio"] } jemalloc_pprof = { version = 
"0.8", default-features = false, features = ["flamegraph", "symbolize"] } libc = { version = "0.2", default-features = false, features = ["use_std"] } +miniz_oxide = { version = "0.8", default-features = false, features = ["simd", "with-alloc"] } mio = { version = "1", features = ["net", "os-ext"] } num = { version = "0.4" } object = { version = "0.37", default-features = false, features = ["read", "std"] } @@ -303,6 +304,7 @@ getrandom-9fbad63c4bcf4a8f = { package = "getrandom", version = "0.4", default-f hyper-rustls = { version = "0.27", default-features = false, features = ["webpki-tokio"] } jemalloc_pprof = { version = "0.8", default-features = false, features = ["flamegraph", "symbolize"] } libc = { version = "0.2", default-features = false, features = ["use_std"] } +miniz_oxide = { version = "0.8", default-features = false, features = ["simd", "with-alloc"] } mio = { version = "1", features = ["net", "os-ext"] } num = { version = "0.4" } object = { version = "0.37", default-features = false, features = ["read", "std"] } @@ -322,6 +324,7 @@ getrandom-9fbad63c4bcf4a8f = { package = "getrandom", version = "0.4", default-f hyper-rustls = { version = "0.27", default-features = false, features = ["webpki-tokio"] } jemalloc_pprof = { version = "0.8", default-features = false, features = ["flamegraph", "symbolize"] } libc = { version = "0.2", default-features = false, features = ["use_std"] } +miniz_oxide = { version = "0.8", default-features = false, features = ["simd", "with-alloc"] } num = { version = "0.4" } object = { version = "0.37", default-features = false, features = ["read", "std"] } pprof_util = { version = "0.8", default-features = false, features = ["flamegraph"] } @@ -340,6 +343,7 @@ getrandom-9fbad63c4bcf4a8f = { package = "getrandom", version = "0.4", default-f hyper-rustls = { version = "0.27", default-features = false, features = ["webpki-tokio"] } jemalloc_pprof = { version = "0.8", default-features = false, features = ["flamegraph", "symbolize"] 
} libc = { version = "0.2", default-features = false, features = ["use_std"] } +miniz_oxide = { version = "0.8", default-features = false, features = ["simd", "with-alloc"] } num = { version = "0.4" } object = { version = "0.37", default-features = false, features = ["read", "std"] } pprof_util = { version = "0.8", default-features = false, features = ["flamegraph"] } From d2a297bd2740d2fa57dce2c4d618387496b57b9b Mon Sep 17 00:00:00 2001 From: Muhamad Awad Date: Tue, 31 Mar 2026 17:45:14 +0200 Subject: [PATCH 3/3] [Service-Client] H2 Pool integration and usage in service client - This PR finally enables the use of the new H2 pool in the http service client. Fixes #4451 --- .../invoker-impl/src/invocation_task/mod.rs | 2 + crates/service-client/src/http.rs | 175 ++++++++++++++---- crates/service-client/src/lib.rs | 8 +- crates/service-protocol/src/discovery.rs | 2 +- crates/types/src/config/http.rs | 24 ++- 5 files changed, 164 insertions(+), 47 deletions(-) diff --git a/crates/invoker-impl/src/invocation_task/mod.rs b/crates/invoker-impl/src/invocation_task/mod.rs index 1190e48c60..b6a638b4ee 100644 --- a/crates/invoker-impl/src/invocation_task/mod.rs +++ b/crates/invoker-impl/src/invocation_task/mod.rs @@ -496,6 +496,8 @@ impl ResponseStream { // This task::spawn won't be required by hyper 1.0, as the connection will be driven by a task // spawned somewhere else (perhaps in the connection pool).
// See: https://github.com/restatedev/restate/issues/96 and https://github.com/restatedev/restate/issues/76 + + //todo: this is a temp clone to test Self::WaitingHeaders { join_handle: AbortOnDropHandle::new(tokio::task::spawn(client.call(req))), } diff --git a/crates/service-client/src/http.rs b/crates/service-client/src/http.rs index 6aad45eb58..847beb7f40 100644 --- a/crates/service-client/src/http.rs +++ b/crates/service-client/src/http.rs @@ -10,14 +10,17 @@ use super::proxy::ProxyConnector; +use crate::pool::conn::PermittedRecvStream; +use crate::pool::tls::TlsConnector; +use crate::pool::{self, Pool, TcpConnector}; use crate::utils::ErrorExt; use bytes::Bytes; use futures::FutureExt; -use futures::future::Either; +use futures::future::{self, Either}; use http::Version; -use http_body_util::BodyExt; -use hyper::body::Body; +use http_body_util::{BodyExt, Either as EitherBody}; +use hyper::body::{Body, Incoming}; use hyper::http::HeaderValue; use hyper::http::uri::PathAndQuery; use hyper::{HeaderMap, Method, Request, Response, Uri}; @@ -26,10 +29,14 @@ use hyper_util::client::legacy::connect::HttpConnector; use restate_types::config::HttpOptions; use rustls::{ClientConfig, KeyLogFile}; use std::error::Error; +use std::fmt; use std::fmt::Debug; -use std::future::Future; +use std::num::NonZeroU32; +use std::pin::Pin; use std::sync::{Arc, LazyLock}; -use std::{fmt, future}; +use std::task::{Context, Poll, ready}; +use std::time::Duration; +use tower::Layer; type ProxiedHttpsConnector = ProxyConnector>; @@ -55,7 +62,7 @@ static TLS_CLIENT_CONFIG: LazyLock = LazyLock::new(|| { type BoxError = Box; type BoxBody = http_body_util::combinators::BoxBody; -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct HttpClient { /// Client used for HTTPS as long as HTTP1.1 or HTTP2 was not specifically requested. /// All HTTP versions are possible. 
@@ -68,7 +75,7 @@ pub struct HttpClient { /// Client when HTTP2 was specifically requested - for cleartext, we use h2c, /// and for HTTPS, we will fail unless the ALPN supports h2. /// In practice, at discovery time we never force h2 for HTTPS. - h2_client: hyper_util::client::legacy::Client, + h2_pool: Pool>>, } impl HttpClient { @@ -77,11 +84,18 @@ impl HttpClient { hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::default()); builder.timer(hyper_util::rt::TokioTimer::default()); + let keep_alive_interval: Duration = options.http_keep_alive_options.interval.into(); + let keep_alive_interval = if keep_alive_interval == Duration::ZERO { + None + } else { + Some(keep_alive_interval) + }; + builder - .http2_initial_max_send_streams(options.initial_max_send_streams) + .http2_initial_max_send_streams(options.initial_max_send_streams.map(|v| v as usize)) .http2_adaptive_window(true) .http2_keep_alive_timeout(options.http_keep_alive_options.timeout.into()) - .http2_keep_alive_interval(Some(options.http_keep_alive_options.interval.into())); + .http2_keep_alive_interval(keep_alive_interval); let mut http_connector = HttpConnector::new(); http_connector.enforce_http(false); @@ -101,11 +115,27 @@ impl HttpClient { .enable_http1() .wrap_connector(http_connector.clone()); - let https_h2_connector = hyper_rustls::HttpsConnectorBuilder::new() - .with_tls_config(TLS_CLIENT_CONFIG.clone()) - .https_or_http() - .enable_http2() - .wrap_connector(http_connector.clone()); + let h2_pool = { + let connector = pool::tls::TlsConnectorLayer::new(TLS_CLIENT_CONFIG.clone()) + .layer(pool::TcpConnector::new(options.connect_timeout.into())); + let connector = ProxyConnector::new( + options.http_proxy.clone(), + options.no_proxy.clone(), + connector, + ); + + let builder = pool::PoolBuilder::default() + .max_connections(options.max_http2_connections) + .keep_alive_interval(keep_alive_interval) + .keep_alive_timeout(options.http_keep_alive_options.timeout.into()); + + 
let builder = match options.initial_max_send_streams.and_then(NonZeroU32::new) { + Some(value) => builder.initial_max_send_streams(value), + None => builder, + }; + + builder.build(connector) + }; HttpClient { alpn_client: builder.clone().build::<_, BoxBody>(ProxyConnector::new( @@ -118,14 +148,7 @@ impl HttpClient { options.no_proxy.clone(), https_h1_connector, )), - h2_client: { - builder.http2_only(true); - builder.build::<_, BoxBody>(ProxyConnector::new( - options.http_proxy.clone(), - options.no_proxy.clone(), - https_h2_connector, - )) - }, + h2_pool, } } @@ -138,7 +161,7 @@ impl HttpClient { headers: HeaderMap, ) -> Result, http::Error> where - B: Body + Send + Sync + Unpin + Sized + 'static, + B: Body + Send + Sync + Sized + 'static, ::Error: Error + Send + Sync + 'static, { let mut uri_parts = uri.into_parts(); @@ -186,10 +209,10 @@ impl HttpClient { body: B, path: PathAndQuery, headers: HeaderMap, - ) -> impl Future, HttpError>> + Send + 'static + ) -> impl Future, HttpError>> + Send + 'static where - B: Body + Send + Sync + Unpin + Sized + 'static, - ::Error: Error + Send + Sync + 'static, + B: Body + Send + Sync + Sized + 'static, + B::Error: std::error::Error + Send + Sync + 'static, { let request = match Self::build_request(uri, version, body, method, path, headers) { Ok(request) => request, @@ -198,21 +221,98 @@ impl HttpClient { let fut = match version { // version is set to http1.1 when use_http1.1 is set - Some(Version::HTTP_11) => self.h1_client.request(request), + Some(Version::HTTP_11) => ResponseMapper { + fut: self.h1_client.request(request), + } + .left_future(), // version is set to http2 for cleartext urls when use_http1.1 is not set - Some(Version::HTTP_2) => self.h2_client.request(request), + Some(Version::HTTP_2) => ResponseMapper { + fut: self.h2_pool.request(request), + } + .right_future(), // version is currently set to none for https urls when use_http1.1 is not set - None => self.alpn_client.request(request), + None => 
ResponseMapper { + fut: self.alpn_client.request(request), + } + .left_future(), // nothing currently sets a different version, but the alpn client is a sensible default - Some(_) => self.alpn_client.request(request), + Some(_) => ResponseMapper { + fut: self.alpn_client.request(request), + } + .left_future(), }; - Either::Left(async move { - match fut.await { - Ok(res) => Ok(res), - Err(err) => Err(err.into()), - } - }) + Either::Left(fut) + } +} + +#[pin_project::pin_project] +struct ResponseMapper +where + F: Future, E>>, + E: Into, + B: Into, +{ + #[pin] + fut: F, +} + +impl Future for ResponseMapper +where + F: Future, E>>, + E: Into, + B: Into, +{ + type Output = Result, HttpError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let result = ready!(self.project().fut.poll(cx)) + .map_err(Into::into) + .map(|response| response.map(Into::into)); + + Poll::Ready(result) + } +} + +/// A wrapper around [`http_body_util::Either`] to hide +/// type complexity for higher layer +#[pin_project::pin_project] +pub struct ResponseBody { + #[pin] + inner: EitherBody, +} + +impl From for ResponseBody { + fn from(value: Incoming) -> Self { + Self { + inner: EitherBody::Left(value), + } + } +} + +impl From for ResponseBody { + fn from(value: PermittedRecvStream) -> Self { + Self { + inner: EitherBody::Right(value), + } + } +} + +impl Body for ResponseBody { + type Data = Bytes; + type Error = Box; + + fn is_end_stream(&self) -> bool { + self.inner.is_end_stream() + } + fn poll_frame( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll, Self::Error>>> { + self.project().inner.poll_frame(cx) + } + fn size_hint(&self) -> http_body::SizeHint { + self.inner.size_hint() } } @@ -228,6 +328,8 @@ pub enum HttpError { Connect(#[source] hyper_util::client::legacy::Error), #[error("{}", FormatHyperError(.0))] Hyper(#[source] hyper_util::client::legacy::Error), + #[error("h2 pool connection error: {0}")] + PoolError(#[from] 
pool::ConnectionError), } impl HttpError { @@ -240,6 +342,7 @@ impl HttpError { HttpError::PossibleHTTP11Only(_) => false, HttpError::PossibleHTTP2Only(_) => false, HttpError::Connect(_) => true, + HttpError::PoolError(_) => true, } } diff --git a/crates/service-client/src/lib.rs b/crates/service-client/src/lib.rs index 2f47b69f1c..5c5bfea19b 100644 --- a/crates/service-client/src/lib.rs +++ b/crates/service-client/src/lib.rs @@ -19,7 +19,7 @@ use arc_swap::ArcSwapOption; use bytes::Bytes; use bytestring::ByteString; use core::fmt; -use futures::FutureExt; +use futures::{FutureExt, future}; use http_body_util::Full; use hyper::body::Body; use hyper::http::uri::PathAndQuery; @@ -30,8 +30,6 @@ use restate_types::schema::deployment::EndpointLambdaCompression; use std::collections::HashMap; use std::error::Error; use std::fmt::Formatter; -use std::future; -use std::future::Future; use std::sync::Arc; mod http; @@ -41,9 +39,9 @@ mod proxy; mod request_identity; mod utils; -pub type ResponseBody = http_body_util::Either>; +pub type ResponseBody = http_body_util::Either>; -#[derive(Debug, Clone)] +#[derive(Clone)] pub struct ServiceClient { // TODO a single client uses the pooling provided by hyper, but this is not enough. // See https://github.com/restatedev/restate/issues/76 for more background on the topic. 
diff --git a/crates/service-protocol/src/discovery.rs b/crates/service-protocol/src/discovery.rs index 1e29becec6..1b67548dc4 100644 --- a/crates/service-protocol/src/discovery.rs +++ b/crates/service-protocol/src/discovery.rs @@ -173,7 +173,7 @@ impl DiscoveryError { } } -#[derive(Debug, Clone)] +#[derive(Clone)] pub struct ServiceDiscovery { retry_policy: RetryPolicy, client: ServiceClient, diff --git a/crates/types/src/config/http.rs b/crates/types/src/config/http.rs index 96edfcd1f0..f29daf054a 100644 --- a/crates/types/src/config/http.rs +++ b/crates/types/src/config/http.rs @@ -8,10 +8,12 @@ // the Business Source License, use of this software will be governed // by the Apache License, Version 2.0. +use std::num::NonZeroUsize; + use serde::{Deserialize, Serialize}; use serde_with::serde_as; -use restate_time_util::NonZeroFriendlyDuration; +use restate_time_util::{FriendlyDuration, NonZeroFriendlyDuration}; /// # HTTP client options #[derive(Debug, Clone, Serialize, Deserialize, derive_builder::Builder)] @@ -23,7 +25,6 @@ pub struct HttpOptions { /// # HTTP/2 Keep-alive /// /// Configuration for the HTTP/2 keep-alive mechanism, using PING frames. - /// If unset, HTTP/2 keep-alive are disabled. pub http_keep_alive_options: Http2KeepAliveOptions, /// # Proxy URI /// @@ -59,7 +60,15 @@ pub struct HttpOptions { /// /// **NOTE**: Setting this value to None (default) users the default /// recommended value from HTTP2 specs - pub initial_max_send_streams: Option, + pub initial_max_send_streams: Option, + + /// # Max HTTP2 Connections + /// + /// Sets the maximum number of open HTTP/2 connections per + /// client for a single host. 
+ /// + /// Default: 20 + pub max_http2_connections: NonZeroUsize, } impl Default for HttpOptions { @@ -70,6 +79,7 @@ impl Default for HttpOptions { no_proxy: None, connect_timeout: NonZeroFriendlyDuration::from_secs_unchecked(10), initial_max_send_streams: None, + max_http2_connections: NonZeroUsize::new(20).unwrap(), } } } @@ -100,8 +110,10 @@ pub struct Http2KeepAliveOptions { /// Sets an interval for HTTP/2 PING frames should be sent to keep a /// connection alive. /// + /// `0` disables keep-alive pings entirely. Defaults to `40s`. + /// /// You should set this timeout with a value lower than the `abort_timeout`. - pub interval: NonZeroFriendlyDuration, + pub interval: FriendlyDuration, /// # Timeout /// @@ -109,13 +121,15 @@ pub struct Http2KeepAliveOptions { /// /// If the ping is not acknowledged within the timeout, the connection will /// be closed. + /// + /// Only meaningful when `interval` is not zero. Defaults to 20 s. pub timeout: NonZeroFriendlyDuration, } impl Default for Http2KeepAliveOptions { fn default() -> Self { Self { - interval: NonZeroFriendlyDuration::from_secs_unchecked(40), + interval: FriendlyDuration::from_secs(40), timeout: NonZeroFriendlyDuration::from_secs_unchecked(20), } }