Skip to content

Commit 84d51a2

Browse files
authored
Header propagation fixes and improvements (#11391)
TODO * Sibling: rerun-io/dataplatform#1774
1 parent 06bc521 commit 84d51a2

File tree

6 files changed

+222
-59
lines changed

6 files changed

+222
-59
lines changed

Cargo.lock

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9062,7 +9062,9 @@ name = "re_protos"
90629062
version = "0.26.0-alpha.1+dev"
90639063
dependencies = [
90649064
"arrow",
9065+
"http",
90659066
"jiff",
9067+
"pin-project-lite",
90669068
"prost",
90679069
"prost-types",
90689070
"pyo3",
@@ -9076,6 +9078,7 @@ dependencies = [
90769078
"serde",
90779079
"thiserror 1.0.69",
90789080
"tonic",
9081+
"tower",
90799082
"url",
90809083
]
90819084

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,7 @@ paste = "1.0"
301301
pathdiff = "0.2"
302302
percent-encoding = "2.3"
303303
pico-args = "0.5"
304+
pin-project-lite = "0.2"
304305
ply-rs = { version = "0.1", default-features = false }
305306
poll-promise = "0.3"
306307
pollster = "0.4"

crates/store/re_protos/Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,15 @@ re_tuid.workspace = true
3131

3232
# External
3333
arrow.workspace = true
34+
http.workspace = true
3435
jiff.workspace = true
36+
pin-project-lite.workspace = true
3537
prost-types.workspace = true
3638
prost.workspace = true
3739
pyo3 = { workspace = true, optional = true }
3840
serde.workspace = true
3941
thiserror.workspace = true
42+
tower.workspace = true
4043
url = { workspace = true, features = ["serde"] }
4144

4245
# Native dependencies:

crates/store/re_protos/src/headers.rs

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,3 +126,188 @@ impl<T> RerunHeadersExtractorExt for tonic::Request<T> {
126126
Ok(Some(entry_name))
127127
}
128128
}
129+
130+
/// Creates a new [`tower::Layer`] middleware that always makes sure to propagate Rerun headers
131+
/// back and forth across requests and responses.
132+
pub fn new_rerun_headers_propagation_layer() -> PropagateHeadersLayer {
133+
PropagateHeadersLayer::new(
134+
[
135+
http::HeaderName::from_static(RERUN_HTTP_HEADER_ENTRY_ID),
136+
http::HeaderName::from_static("x-request-id"),
137+
]
138+
.into_iter()
139+
.collect(),
140+
)
141+
}
142+
143+
// ---
144+
145+
// NOTE: This if a fork of <https://docs.rs/tower-http/0.6.6/tower_http/propagate_header/struct.PropagateHeader.html>.
146+
//
147+
// It exists to prevent never-ending chains of generics when propagating multiple headers, e.g.:
148+
// ```
149+
// pub type RedapClientInner =
150+
// re_perf_telemetry::external::tower_http::propagate_header::PropagateHeader<
151+
// re_perf_telemetry::external::tower_http::propagate_header::PropagateHeader<
152+
// re_perf_telemetry::external::tower_http::propagate_header::PropagateHeader<
153+
// re_perf_telemetry::external::tower_http::propagate_header::PropagateHeader<
154+
// re_perf_telemetry::external::tower_http::trace::Trace<
155+
// tonic::service::interceptor::InterceptedService<
156+
// tonic::service::interceptor::InterceptedService<
157+
// tonic::transport::Channel,
158+
// re_auth::client::AuthDecorator,
159+
// >,
160+
// re_perf_telemetry::TracingInjectorInterceptor,
161+
// >,
162+
// re_perf_telemetry::external::tower_http::classify::SharedClassifier<
163+
// re_perf_telemetry::external::tower_http::classify::GrpcErrorsAsFailures,
164+
// >,
165+
// re_perf_telemetry::GrpcMakeSpan,
166+
// >,
167+
// >,
168+
// >,
169+
// >,
170+
// >;
171+
// ```
172+
// which instead becomes this:
173+
// ```
174+
// pub type RedapClientInner =
175+
// PropagateHeaders<
176+
// re_perf_telemetry::external::tower_http::trace::Trace<
177+
// tonic::service::interceptor::InterceptedService<
178+
// tonic::service::interceptor::InterceptedService<
179+
// tonic::transport::Channel,
180+
// re_auth::client::AuthDecorator,
181+
// >,
182+
// re_perf_telemetry::TracingInjectorInterceptor,
183+
// >,
184+
// re_perf_telemetry::external::tower_http::classify::SharedClassifier<
185+
// re_perf_telemetry::external::tower_http::classify::GrpcErrorsAsFailures,
186+
// >,
187+
// re_perf_telemetry::GrpcMakeSpan,
188+
// >,
189+
// >;
190+
// ```
191+
192+
use std::collections::HashSet;
193+
use std::future::Future;
194+
use std::{
195+
pin::Pin,
196+
task::{Context, Poll, ready},
197+
};
198+
199+
use http::{HeaderValue, Request, Response, header::HeaderName};
200+
use pin_project_lite::pin_project;
201+
use tower::Service;
202+
use tower::layer::Layer;
203+
204+
/// Layer that applies [`PropagateHeaders`] which propagates multiple headers at once from requests to responses.
205+
///
206+
/// If the headers are present on the request they'll be applied to the response as well. This could
207+
/// for example be used to propagate headers such as `x-rerun-entry-id`, `x-rerun-client-version`, etc.
208+
#[derive(Clone, Debug)]
209+
pub struct PropagateHeadersLayer {
210+
headers: HashSet<HeaderName>,
211+
}
212+
213+
impl PropagateHeadersLayer {
214+
/// Create a new [`PropagateHeadersLayer`].
215+
pub fn new(headers: HashSet<HeaderName>) -> Self {
216+
Self { headers }
217+
}
218+
}
219+
220+
impl<S> Layer<S> for PropagateHeadersLayer {
221+
type Service = PropagateHeaders<S>;
222+
223+
fn layer(&self, inner: S) -> Self::Service {
224+
PropagateHeaders {
225+
inner,
226+
headers: self.headers.clone(),
227+
}
228+
}
229+
}
230+
231+
/// Middleware that propagates multiple headers at once from requests to responses.
232+
///
233+
/// If the headers are present on the request they'll be applied to the response as well. This could
234+
/// for example be used to propagate headers such as `x-rerun-entry-id`, `x-rerun-client-version`, etc.
235+
#[derive(Clone, Debug)]
236+
pub struct PropagateHeaders<S> {
237+
inner: S,
238+
headers: HashSet<HeaderName>,
239+
}
240+
241+
impl<S> PropagateHeaders<S> {
242+
/// Create a new [`PropagateHeaders`] that propagates the given header.
243+
pub fn new(inner: S, headers: HashSet<HeaderName>) -> Self {
244+
Self { inner, headers }
245+
}
246+
247+
/// Returns a new [`Layer`] that wraps services with a `PropagateHeaders` middleware.
248+
///
249+
/// [`Layer`]: tower::layer::Layer
250+
pub fn layer(headers: HashSet<HeaderName>) -> PropagateHeadersLayer {
251+
PropagateHeadersLayer::new(headers)
252+
}
253+
}
254+
255+
impl<ReqBody, ResBody, S> Service<Request<ReqBody>> for PropagateHeaders<S>
256+
where
257+
S: Service<Request<ReqBody>, Response = Response<ResBody>>,
258+
{
259+
type Response = S::Response;
260+
type Error = S::Error;
261+
type Future = ResponseFuture<S::Future>;
262+
263+
#[inline]
264+
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
265+
self.inner.poll_ready(cx)
266+
}
267+
268+
fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
269+
let headers_and_values = self
270+
.headers
271+
.iter()
272+
.filter_map(|name| {
273+
req.headers()
274+
.get(name)
275+
.cloned()
276+
.map(|value| (name.clone(), value))
277+
})
278+
.collect();
279+
280+
ResponseFuture {
281+
future: self.inner.call(req),
282+
headers_and_values,
283+
}
284+
}
285+
}
286+
287+
pin_project! {
288+
/// Response future for [`PropagateHeaders`].
289+
#[derive(Debug)]
290+
pub struct ResponseFuture<F> {
291+
#[pin]
292+
future: F,
293+
headers_and_values: Vec<(HeaderName, HeaderValue)>,
294+
}
295+
}
296+
297+
impl<F, ResBody, E> Future for ResponseFuture<F>
298+
where
299+
F: Future<Output = Result<Response<ResBody>, E>>,
300+
{
301+
type Output = F::Output;
302+
303+
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
304+
let this = self.project();
305+
let mut res = ready!(this.future.poll(cx)?);
306+
307+
for (header, value) in std::mem::take(this.headers_and_values) {
308+
res.headers_mut().insert(header, value);
309+
}
310+
311+
Poll::Ready(Ok(res))
312+
}
313+
}

crates/store/re_redap_client/src/grpc.rs

Lines changed: 28 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,9 @@ pub async fn channel(origin: Origin) -> Result<tonic::transport::Channel, Connec
9595
}
9696

9797
#[cfg(target_arch = "wasm32")]
98-
pub type RedapClientInner =
99-
tonic::service::interceptor::InterceptedService<tonic_web_wasm_client::Client, AuthDecorator>;
98+
pub type RedapClientInner = re_protos::headers::PropagateHeaders<
99+
tonic::service::interceptor::InterceptedService<tonic_web_wasm_client::Client, AuthDecorator>,
100+
>;
100101

101102
#[cfg(target_arch = "wasm32")]
102103
pub(crate) async fn client(
@@ -108,40 +109,39 @@ pub(crate) async fn client(
108109
let auth = AuthDecorator::new(token);
109110

110111
let middlewares = tower::ServiceBuilder::new()
111-
.layer(tonic::service::interceptor::InterceptorLayer::new(auth))
112-
.into_inner();
112+
.layer(re_protos::headers::new_rerun_headers_propagation_layer())
113+
.layer(tonic::service::interceptor::InterceptorLayer::new(auth));
113114

114115
let svc = tower::ServiceBuilder::new()
115-
.layer(middlewares)
116+
.layer(middlewares.into_inner())
116117
.service(channel);
117118

118119
Ok(RerunCloudServiceClient::new(svc).max_decoding_message_size(MAX_DECODING_MESSAGE_SIZE))
119120
}
120121

121122
#[cfg(all(not(target_arch = "wasm32"), feature = "perf_telemetry"))]
122-
pub type RedapClientInner =
123-
re_perf_telemetry::external::tower_http::propagate_header::PropagateHeader<
124-
re_perf_telemetry::external::tower_http::propagate_header::PropagateHeader<
125-
re_perf_telemetry::external::tower_http::trace::Trace<
126-
tonic::service::interceptor::InterceptedService<
127-
tonic::service::interceptor::InterceptedService<
128-
tonic::transport::Channel,
129-
re_auth::client::AuthDecorator,
130-
>,
131-
re_perf_telemetry::TracingInjectorInterceptor,
132-
>,
133-
re_perf_telemetry::external::tower_http::classify::SharedClassifier<
134-
re_perf_telemetry::external::tower_http::classify::GrpcErrorsAsFailures,
135-
>,
136-
re_perf_telemetry::GrpcMakeSpan,
123+
pub type RedapClientInner = re_protos::headers::PropagateHeaders<
124+
re_perf_telemetry::external::tower_http::trace::Trace<
125+
tonic::service::interceptor::InterceptedService<
126+
tonic::service::interceptor::InterceptedService<
127+
tonic::transport::Channel,
128+
re_auth::client::AuthDecorator,
137129
>,
130+
re_perf_telemetry::TracingInjectorInterceptor,
131+
>,
132+
re_perf_telemetry::external::tower_http::classify::SharedClassifier<
133+
re_perf_telemetry::external::tower_http::classify::GrpcErrorsAsFailures,
138134
>,
139-
>;
135+
re_perf_telemetry::GrpcMakeSpan,
136+
>,
137+
>;
140138

141139
#[cfg(all(not(target_arch = "wasm32"), not(feature = "perf_telemetry")))]
142-
pub type RedapClientInner = tonic::service::interceptor::InterceptedService<
143-
tonic::transport::Channel,
144-
re_auth::client::AuthDecorator,
140+
pub type RedapClientInner = re_protos::headers::PropagateHeaders<
141+
tonic::service::interceptor::InterceptedService<
142+
tonic::transport::Channel,
143+
re_auth::client::AuthDecorator,
144+
>,
145145
>;
146146

147147
pub type RedapClient = RerunCloudServiceClient<RedapClientInner>;
@@ -155,17 +155,16 @@ pub(crate) async fn client(
155155

156156
let auth = AuthDecorator::new(token);
157157

158-
let middlewares = tower::ServiceBuilder::new();
158+
let middlewares = tower::ServiceBuilder::new()
159+
.layer(re_protos::headers::new_rerun_headers_propagation_layer());
159160

160161
#[cfg(feature = "perf_telemetry")]
161162
let middlewares = middlewares.layer(re_perf_telemetry::new_client_telemetry_layer());
162163

163-
let middlewares = middlewares
164-
.layer(tonic::service::interceptor::InterceptorLayer::new(auth))
165-
.into_inner();
164+
let middlewares = middlewares.layer(tonic::service::interceptor::InterceptorLayer::new(auth));
166165

167166
let svc = tower::ServiceBuilder::new()
168-
.layer(middlewares)
167+
.layer(middlewares.into_inner())
169168
.service(channel);
170169

171170
Ok(RerunCloudServiceClient::new(svc).max_decoding_message_size(MAX_DECODING_MESSAGE_SIZE))

crates/utils/re_perf_telemetry/src/grpc.rs

Lines changed: 2 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -492,13 +492,7 @@ pub type ServerTelemetryLayer = tower::layer::util::Stack<
492492
GrpcOnFirstBodyChunk,
493493
GrpcOnEos,
494494
>,
495-
tower::layer::util::Stack<
496-
tower_http::propagate_header::PropagateHeaderLayer,
497-
tower::layer::util::Stack<
498-
tower_http::propagate_header::PropagateHeaderLayer,
499-
tower::layer::util::Identity,
500-
>,
501-
>,
495+
tower::layer::util::Identity,
502496
>,
503497
>;
504498

@@ -507,12 +501,6 @@ pub type ServerTelemetryLayer = tower::layer::util::Stack<
507501
/// * Logs all gRPC responses (status, latency, etc).
508502
/// * Measures all gRPC responses (status, latency, etc).
509503
pub fn new_server_telemetry_layer() -> ServerTelemetryLayer {
510-
use tower_http::propagate_header::PropagateHeaderLayer;
511-
let dataset_id_propagation_layer =
512-
PropagateHeaderLayer::new(http::HeaderName::from_static(RERUN_HTTP_HEADER_ENTRY_ID));
513-
let request_id_propagation_layer =
514-
PropagateHeaderLayer::new(http::HeaderName::from_static("x-request-id"));
515-
516504
let trace_layer = tower_http::trace::TraceLayer::new_for_grpc()
517505
.make_span_with(GrpcMakeSpan::new())
518506
.on_request(GrpcOnRequest::new())
@@ -521,8 +509,6 @@ pub fn new_server_telemetry_layer() -> ServerTelemetryLayer {
521509
.on_eos(GrpcOnEos::new());
522510

523511
tower::ServiceBuilder::new()
524-
.layer(dataset_id_propagation_layer)
525-
.layer(request_id_propagation_layer)
526512
.layer(trace_layer)
527513
.layer(TracingExtractorInterceptor::new_layer())
528514
.into_inner()
@@ -532,13 +518,7 @@ pub type ClientTelemetryLayer = tower::layer::util::Stack<
532518
tonic::service::interceptor::InterceptorLayer<TracingInjectorInterceptor>,
533519
tower::layer::util::Stack<
534520
tower_http::trace::TraceLayer<tower_http::trace::GrpcMakeClassifier, GrpcMakeSpan>,
535-
tower::layer::util::Stack<
536-
tower_http::propagate_header::PropagateHeaderLayer,
537-
tower::layer::util::Stack<
538-
tower_http::propagate_header::PropagateHeaderLayer,
539-
tower::layer::util::Identity,
540-
>,
541-
>,
521+
tower::layer::util::Identity,
542522
>,
543523
>;
544524

@@ -550,18 +530,10 @@ pub type ClientTelemetryLayer = tower::layer::util::Stack<
550530
// TODO(cmc): at the moment there's little value to have anything beyond traces on the client, but
551531
// we ultimately can add all the same things that we have on the server as we need them.
552532
pub fn new_client_telemetry_layer() -> ClientTelemetryLayer {
553-
use tower_http::propagate_header::PropagateHeaderLayer;
554-
let dataset_id_propagation_layer =
555-
PropagateHeaderLayer::new(http::HeaderName::from_static(RERUN_HTTP_HEADER_ENTRY_ID));
556-
let request_id_propagation_layer =
557-
PropagateHeaderLayer::new(http::HeaderName::from_static("x-request-id"));
558-
559533
let trace_layer =
560534
tower_http::trace::TraceLayer::new_for_grpc().make_span_with(GrpcMakeSpan::new());
561535

562536
tower::ServiceBuilder::new()
563-
.layer(dataset_id_propagation_layer)
564-
.layer(request_id_propagation_layer)
565537
.layer(trace_layer)
566538
.layer(TracingInjectorInterceptor::new_layer())
567539
.into_inner()

0 commit comments

Comments
 (0)