Skip to content

Commit 01461e0

Browse files
committed
[service-client] H2 Pool integation and usage in service client
- This PR finally enable using of the new H2 pool in the http service client. - It also fixes #4456 by making sure: - Request stream is closed immediately after the we receive a terminal state - Drain the response stream. This also fixes a connection thrashing issue
1 parent 7b5ed03 commit 01461e0

File tree

6 files changed

+197
-61
lines changed

6 files changed

+197
-61
lines changed

crates/invoker-impl/src/invocation_task/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,8 @@ impl ResponseStream {
486486
// This task::spawn won't be required by hyper 1.0, as the connection will be driven by a task
487487
// spawned somewhere else (perhaps in the connection pool).
488488
// See: https://github.com/restatedev/restate/issues/96 and https://github.com/restatedev/restate/issues/76
489+
490+
//todo: this is a temp clone to test
489491
Self::WaitingHeaders {
490492
join_handle: AbortOnDropHandle::new(tokio::task::spawn(client.call(req))),
491493
}

crates/invoker-impl/src/invocation_task/service_protocol_runner_v4.rs

Lines changed: 53 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,6 @@ where
147147
.try_into()
148148
.expect("must be able to build a valid invocation path");
149149

150-
let journal_size = journal_metadata.length;
151-
152150
debug!(
153151
restate.invocation.id = %self.invocation_task.invocation_id,
154152
deployment.address = %deployment.address_display(),
@@ -159,15 +157,14 @@ where
159157

160158
// Create an arc of the parent SpanContext.
161159
// We send this with every journal entry to correctly link new spans generated from journal entries.
162-
let service_invocation_span_context = journal_metadata.span_context;
163160

164161
// Prepare the request
165-
let (mut http_stream_tx, request) = Self::prepare_request(
162+
let (http_stream_tx, request) = Self::prepare_request(
166163
path,
167164
deployment,
168165
self.service_protocol_version,
169166
&self.invocation_task.invocation_id,
170-
&service_invocation_span_context,
167+
&journal_metadata.span_context,
171168
);
172169

173170
// Initialize the response stream state
@@ -183,6 +180,49 @@ where
183180
.throttle(self.invocation_task.action_token_bucket.take())
184181
);
185182

183+
let result = self
184+
.run_inner(
185+
txn,
186+
protocol_type,
187+
journal_metadata,
188+
keyed_service_id,
189+
cached_journal_items,
190+
http_stream_tx,
191+
&mut decoder_stream,
192+
)
193+
.await;
194+
// Sanity check of the stream decoder
195+
if decoder_stream.inner().has_remaining() {
196+
warn_it!(
197+
InvokerError::WriteAfterEndOfStream,
198+
"The read buffer is non empty after the stream has been closed."
199+
);
200+
}
201+
202+
let inner_stream = &mut decoder_stream.inner_pin_mut().inner;
203+
204+
while inner_stream.next().await.is_some() {}
205+
206+
result
207+
}
208+
209+
#[allow(clippy::too_many_arguments)]
210+
async fn run_inner<Txn, S>(
211+
&mut self,
212+
txn: Txn,
213+
protocol_type: ProtocolType,
214+
journal_metadata: JournalMetadata,
215+
keyed_service_id: Option<ServiceId>,
216+
cached_journal_items: Option<Vec<JournalEntry>>,
217+
mut http_stream_tx: mpsc::Sender<Result<Frame<Bytes>, Infallible>>,
218+
decoder_stream: &mut S,
219+
) -> TerminalLoopState<()>
220+
where
221+
Txn: InvocationReaderTransaction,
222+
S: Stream<Item = Result<DecoderStreamItem, InvokerError>> + Unpin,
223+
{
224+
let journal_size = journal_metadata.length;
225+
let service_invocation_span_context = journal_metadata.span_context;
186226
// === Replay phase (transaction alive) ===
187227
{
188228
// Read state if needed (state is collected for the START message)
@@ -215,7 +255,7 @@ where
215255
crate::shortcircuit!(
216256
self.replay_loop(
217257
&mut http_stream_tx,
218-
&mut decoder_stream,
258+
decoder_stream,
219259
journal_stream,
220260
journal_metadata.length
221261
)
@@ -234,7 +274,7 @@ where
234274
crate::shortcircuit!(
235275
self.replay_loop(
236276
&mut http_stream_tx,
237-
&mut decoder_stream,
277+
decoder_stream,
238278
journal_stream,
239279
journal_metadata.length
240280
)
@@ -255,32 +295,24 @@ where
255295
self.bidi_stream_loop(
256296
&service_invocation_span_context,
257297
http_stream_tx,
258-
&mut decoder_stream
298+
decoder_stream
259299
)
260300
.await
261301
);
262302
} else {
263-
trace!("Protocol is in bidi stream mode, will now drop the sender side of the request");
303+
trace!(
304+
"Protocol is not in bidi stream mode, will now drop the sender side of the request"
305+
);
264306
// Drop the http_stream_tx.
265307
// This is required in HTTP/1.1 to let the deployment send the headers back
266308
drop(http_stream_tx)
267309
}
268310

269311
// We don't have the invoker_rx, so we simply consume the response
270312
trace!("Sender side of the request has been dropped, now processing the response");
271-
let result = self
272-
.response_stream_loop(&service_invocation_span_context, &mut decoder_stream)
273-
.await;
274-
275-
// Sanity check of the stream decoder
276-
if decoder_stream.inner().has_remaining() {
277-
warn_it!(
278-
InvokerError::WriteAfterEndOfStream,
279-
"The read buffer is non empty after the stream has been closed."
280-
);
281-
}
282313

283-
result
314+
self.response_stream_loop(&service_invocation_span_context, decoder_stream)
315+
.await
284316
}
285317

286318
fn prepare_request(

crates/service-client/src/http.rs

Lines changed: 126 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,17 @@
1010

1111
use super::proxy::ProxyConnector;
1212

13+
use crate::pool::conn::PermittedRecvStream;
14+
use crate::pool::tls::TlsConnector;
15+
use crate::pool::{self, Pool, TcpConnector};
1316
use crate::utils::ErrorExt;
1417

1518
use bytes::Bytes;
1619
use futures::FutureExt;
17-
use futures::future::Either;
20+
use futures::future::{self, Either};
1821
use http::Version;
19-
use http_body_util::BodyExt;
20-
use hyper::body::Body;
22+
use http_body_util::{BodyExt, Either as EitherBody};
23+
use hyper::body::{Body, Incoming};
2124
use hyper::http::HeaderValue;
2225
use hyper::http::uri::PathAndQuery;
2326
use hyper::{HeaderMap, Method, Request, Response, Uri};
@@ -26,10 +29,13 @@ use hyper_util::client::legacy::connect::HttpConnector;
2629
use restate_types::config::HttpOptions;
2730
use rustls::{ClientConfig, KeyLogFile};
2831
use std::error::Error;
32+
use std::fmt;
2933
use std::fmt::Debug;
30-
use std::future::Future;
34+
use std::num::NonZeroU32;
35+
use std::pin::Pin;
3136
use std::sync::{Arc, LazyLock};
32-
use std::{fmt, future};
37+
use std::task::{Context, Poll, ready};
38+
use tower::Layer;
3339

3440
type ProxiedHttpsConnector = ProxyConnector<HttpsConnector<HttpConnector>>;
3541

@@ -55,7 +61,7 @@ static TLS_CLIENT_CONFIG: LazyLock<ClientConfig> = LazyLock::new(|| {
5561
type BoxError = Box<dyn Error + Send + Sync + 'static>;
5662
type BoxBody = http_body_util::combinators::BoxBody<Bytes, BoxError>;
5763

58-
#[derive(Clone, Debug)]
64+
#[derive(Clone)]
5965
pub struct HttpClient {
6066
/// Client used for HTTPS as long as HTTP1.1 or HTTP2 was not specifically requested.
6167
/// All HTTP versions are possible.
@@ -68,7 +74,7 @@ pub struct HttpClient {
6874
/// Client when HTTP2 was specifically requested - for cleartext, we use h2c,
6975
/// and for HTTPS, we will fail unless the ALPN supports h2.
7076
/// In practice, at discovery time we never force h2 for HTTPS.
71-
h2_client: hyper_util::client::legacy::Client<ProxiedHttpsConnector, BoxBody>,
77+
h2_pool: Pool<ProxyConnector<TlsConnector<TcpConnector>>>,
7278
}
7379

7480
impl HttpClient {
@@ -78,7 +84,7 @@ impl HttpClient {
7884
builder.timer(hyper_util::rt::TokioTimer::default());
7985

8086
builder
81-
.http2_initial_max_send_streams(options.initial_max_send_streams)
87+
.http2_initial_max_send_streams(options.initial_max_send_streams.map(|v| v as usize))
8288
.http2_adaptive_window(true)
8389
.http2_keep_alive_timeout(options.http_keep_alive_options.timeout.into())
8490
.http2_keep_alive_interval(Some(options.http_keep_alive_options.interval.into()));
@@ -101,11 +107,25 @@ impl HttpClient {
101107
.enable_http1()
102108
.wrap_connector(http_connector.clone());
103109

104-
let https_h2_connector = hyper_rustls::HttpsConnectorBuilder::new()
105-
.with_tls_config(TLS_CLIENT_CONFIG.clone())
106-
.https_or_http()
107-
.enable_http2()
108-
.wrap_connector(http_connector.clone());
110+
let h2_pool = {
111+
let connector = pool::tls::TlsConnectorLayer::new(TLS_CLIENT_CONFIG.clone())
112+
.layer(pool::TcpConnector::new(options.connect_timeout.into()));
113+
let connector = ProxyConnector::new(
114+
options.http_proxy.clone(),
115+
options.no_proxy.clone(),
116+
connector,
117+
);
118+
119+
let builder =
120+
pool::PoolBuilder::default().max_connections(options.max_http2_connections);
121+
122+
let builder = match options.initial_max_send_streams.and_then(NonZeroU32::new) {
123+
Some(value) => builder.init_max_send_streams(value),
124+
None => builder,
125+
};
126+
127+
builder.build(connector)
128+
};
109129

110130
HttpClient {
111131
alpn_client: builder.clone().build::<_, BoxBody>(ProxyConnector::new(
@@ -118,14 +138,7 @@ impl HttpClient {
118138
options.no_proxy.clone(),
119139
https_h1_connector,
120140
)),
121-
h2_client: {
122-
builder.http2_only(true);
123-
builder.build::<_, BoxBody>(ProxyConnector::new(
124-
options.http_proxy.clone(),
125-
options.no_proxy.clone(),
126-
https_h2_connector,
127-
))
128-
},
141+
h2_pool,
129142
}
130143
}
131144

@@ -186,10 +199,10 @@ impl HttpClient {
186199
body: B,
187200
path: PathAndQuery,
188201
headers: HeaderMap<HeaderValue>,
189-
) -> impl Future<Output = Result<Response<hyper::body::Incoming>, HttpError>> + Send + 'static
202+
) -> impl Future<Output = Result<Response<ResponseBody>, HttpError>> + Send + 'static
190203
where
191204
B: Body<Data = Bytes> + Send + Sync + Unpin + Sized + 'static,
192-
<B as Body>::Error: Error + Send + Sync + 'static,
205+
B::Error: std::error::Error + Send + Sync + 'static,
193206
{
194207
let request = match Self::build_request(uri, version, body, method, path, headers) {
195208
Ok(request) => request,
@@ -198,21 +211,98 @@ impl HttpClient {
198211

199212
let fut = match version {
200213
// version is set to http1.1 when use_http1.1 is set
201-
Some(Version::HTTP_11) => self.h1_client.request(request),
214+
Some(Version::HTTP_11) => ResponseMapper {
215+
fut: self.h1_client.request(request),
216+
}
217+
.left_future(),
202218
// version is set to http2 for cleartext urls when use_http1.1 is not set
203-
Some(Version::HTTP_2) => self.h2_client.request(request),
219+
Some(Version::HTTP_2) => ResponseMapper {
220+
fut: self.h2_pool.request(request),
221+
}
222+
.right_future(),
204223
// version is currently set to none for https urls when use_http1.1 is not set
205-
None => self.alpn_client.request(request),
224+
None => ResponseMapper {
225+
fut: self.alpn_client.request(request),
226+
}
227+
.left_future(),
206228
// nothing currently sets a different version, but the alpn client is a sensible default
207-
Some(_) => self.alpn_client.request(request),
229+
Some(_) => ResponseMapper {
230+
fut: self.alpn_client.request(request),
231+
}
232+
.left_future(),
208233
};
209234

210-
Either::Left(async move {
211-
match fut.await {
212-
Ok(res) => Ok(res),
213-
Err(err) => Err(err.into()),
214-
}
215-
})
235+
Either::Left(fut)
236+
}
237+
}
238+
239+
#[pin_project::pin_project]
240+
struct ResponseMapper<F, B, E>
241+
where
242+
F: Future<Output = Result<Response<B>, E>>,
243+
E: Into<HttpError>,
244+
B: Into<ResponseBody>,
245+
{
246+
#[pin]
247+
fut: F,
248+
}
249+
250+
impl<F, B, E> Future for ResponseMapper<F, B, E>
251+
where
252+
F: Future<Output = Result<Response<B>, E>>,
253+
E: Into<HttpError>,
254+
B: Into<ResponseBody>,
255+
{
256+
type Output = Result<Response<ResponseBody>, HttpError>;
257+
258+
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
259+
let result = ready!(self.project().fut.poll(cx))
260+
.map_err(Into::into)
261+
.map(|response| response.map(Into::into));
262+
263+
Poll::Ready(result)
264+
}
265+
}
266+
267+
/// A wrapper around [`http_body_util::Either`] to hide
268+
/// type complexity for higher layer
269+
#[pin_project::pin_project]
270+
pub struct ResponseBody {
271+
#[pin]
272+
inner: EitherBody<Incoming, PermittedRecvStream>,
273+
}
274+
275+
impl From<Incoming> for ResponseBody {
276+
fn from(value: Incoming) -> Self {
277+
Self {
278+
inner: EitherBody::Left(value),
279+
}
280+
}
281+
}
282+
283+
impl From<PermittedRecvStream> for ResponseBody {
284+
fn from(value: PermittedRecvStream) -> Self {
285+
Self {
286+
inner: EitherBody::Right(value),
287+
}
288+
}
289+
}
290+
291+
impl Body for ResponseBody {
292+
type Data = Bytes;
293+
type Error = Box<dyn std::error::Error + Send + Sync>;
294+
295+
fn is_end_stream(&self) -> bool {
296+
self.inner.is_end_stream()
297+
}
298+
fn poll_frame(
299+
self: std::pin::Pin<&mut Self>,
300+
cx: &mut std::task::Context<'_>,
301+
) -> std::task::Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
302+
self.project().inner.poll_frame(cx)
303+
}
304+
fn size_hint(&self) -> http_body::SizeHint {
305+
self.inner.size_hint()
216306
}
217307
}
218308

@@ -228,6 +318,8 @@ pub enum HttpError {
228318
Connect(#[source] hyper_util::client::legacy::Error),
229319
#[error("{}", FormatHyperError(.0))]
230320
Hyper(#[source] hyper_util::client::legacy::Error),
321+
#[error("h2 pool connection error: {0}")]
322+
PoolError(#[from] pool::ConnectionError),
231323
}
232324

233325
impl HttpError {
@@ -240,6 +332,7 @@ impl HttpError {
240332
HttpError::PossibleHTTP11Only(_) => false,
241333
HttpError::PossibleHTTP2Only(_) => false,
242334
HttpError::Connect(_) => true,
335+
HttpError::PoolError(_) => true,
243336
}
244337
}
245338

0 commit comments

Comments
 (0)