Skip to content

Commit 871b320

Browse files
authored
Callback-based gRPC client with C interface (#963)
1 parent 19ae5b7 commit 871b320

File tree

8 files changed

+762
-38
lines changed

8 files changed

+762
-38
lines changed

client/src/callback_based.rs

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
//! This module implements support for callback-based gRPC service that has a callback invoked for
2+
//! every gRPC call instead of directly using the network.
3+
4+
use anyhow::anyhow;
5+
use bytes::{BufMut, BytesMut};
6+
use futures_util::future::BoxFuture;
7+
use futures_util::stream;
8+
use http::{HeaderMap, Request, Response};
9+
use http_body_util::{BodyExt, StreamBody, combinators::BoxBody};
10+
use hyper::body::{Bytes, Frame};
11+
use std::{
12+
sync::Arc,
13+
task::{Context, Poll},
14+
};
15+
use tonic::{Status, metadata::GRPC_CONTENT_TYPE};
16+
use tower::Service;
17+
18+
/// gRPC request for use by a callback.
19+
pub struct GrpcRequest {
20+
/// Fully qualified gRPC service name.
21+
pub service: String,
22+
/// RPC name.
23+
pub rpc: String,
24+
/// Request headers.
25+
pub headers: HeaderMap,
26+
/// Protobuf bytes of the request.
27+
pub proto: Bytes,
28+
}
29+
30+
/// Successful gRPC response returned by a callback.
31+
pub struct GrpcSuccessResponse {
32+
/// Response headers.
33+
pub headers: HeaderMap,
34+
35+
/// Response proto bytes.
36+
pub proto: Vec<u8>,
37+
}
38+
39+
/// gRPC service that invokes the given callback on each call.
40+
#[derive(Clone)]
41+
pub struct CallbackBasedGrpcService {
42+
/// Callback to invoke on each RPC call.
43+
#[allow(clippy::type_complexity)] // Signature is not that complex
44+
pub callback: Arc<
45+
dyn Fn(GrpcRequest) -> BoxFuture<'static, Result<GrpcSuccessResponse, Status>>
46+
+ Send
47+
+ Sync,
48+
>,
49+
}
50+
51+
impl Service<Request<tonic::body::Body>> for CallbackBasedGrpcService {
52+
type Response = http::Response<tonic::body::Body>;
53+
type Error = anyhow::Error;
54+
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
55+
56+
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
57+
Poll::Ready(Ok(()))
58+
}
59+
60+
fn call(&mut self, req: Request<tonic::body::Body>) -> Self::Future {
61+
let callback = self.callback.clone();
62+
63+
Box::pin(async move {
64+
// Build req
65+
let (parts, body) = req.into_parts();
66+
let mut path_parts = parts.uri.path().trim_start_matches('/').split('/');
67+
let req_body = body.collect().await.map_err(|e| anyhow!(e))?.to_bytes();
68+
// Body is flag saying whether compressed (we do not support that), then 32-bit length,
69+
// then the actual proto.
70+
if req_body.len() < 5 {
71+
return Err(anyhow!("Too few request bytes: {}", req_body.len()));
72+
} else if req_body[0] != 0 {
73+
return Err(anyhow!("Compression not supported"));
74+
}
75+
let req_proto_len =
76+
u32::from_be_bytes([req_body[1], req_body[2], req_body[3], req_body[4]]) as usize;
77+
if req_body.len() < 5 + req_proto_len {
78+
return Err(anyhow!(
79+
"Expected request body length at least {}, got {}",
80+
5 + req_proto_len,
81+
req_body.len()
82+
));
83+
}
84+
let req = GrpcRequest {
85+
service: path_parts.next().unwrap_or_default().to_owned(),
86+
rpc: path_parts.next().unwrap_or_default().to_owned(),
87+
headers: parts.headers,
88+
proto: req_body.slice(5..5 + req_proto_len),
89+
};
90+
91+
// Invoke and handle response
92+
match (callback)(req).await {
93+
Ok(success) => {
94+
// Create body bytes which requires a flag saying whether compressed, then
95+
// message len, then actual message. So we create a Bytes for those 5 prepend
96+
// parts, then stream it alongside the user-provided Vec. This allows us to
97+
// avoid copying the vec
98+
let mut body_prepend = BytesMut::with_capacity(5);
99+
body_prepend.put_u8(0); // 0 means no compression
100+
body_prepend.put_u32(success.proto.len() as u32);
101+
let stream = stream::iter(vec![
102+
Ok::<_, Status>(Frame::data(Bytes::from(body_prepend))),
103+
Ok::<_, Status>(Frame::data(Bytes::from(success.proto))),
104+
]);
105+
let stream_body = StreamBody::new(stream);
106+
let full_body = BoxBody::new(stream_body).boxed();
107+
108+
// Build response appending headers
109+
let mut resp_builder = Response::builder()
110+
.status(200)
111+
.header(http::header::CONTENT_TYPE, GRPC_CONTENT_TYPE);
112+
for (key, value) in success.headers.iter() {
113+
resp_builder = resp_builder.header(key, value);
114+
}
115+
Ok(resp_builder
116+
.body(tonic::body::Body::new(full_body))
117+
.map_err(|e| anyhow!(e))?)
118+
}
119+
Err(status) => Ok(status.into_http()),
120+
}
121+
})
122+
}
123+
}

client/src/lib.rs

Lines changed: 53 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#[macro_use]
88
extern crate tracing;
99

10+
pub mod callback_based;
1011
mod metrics;
1112
mod proxy;
1213
mod raw;
@@ -35,7 +36,7 @@ pub use workflow_handle::{
3536
};
3637

3738
use crate::{
38-
metrics::{GrpcMetricSvc, MetricsContext},
39+
metrics::{ChannelOrGrpcOverride, GrpcMetricSvc, MetricsContext},
3940
raw::{AttachMetricLabels, sealed::RawClientLike},
4041
sealed::WfHandleClient,
4142
workflow_handle::UntypedWorkflowHandle,
@@ -434,34 +435,59 @@ impl ClientOptions {
434435
metrics_meter: Option<TemporalMeter>,
435436
) -> Result<RetryClient<ConfiguredClient<TemporalServiceClientWithMetrics>>, ClientInitError>
436437
{
437-
let channel = Channel::from_shared(self.target_url.to_string())?;
438-
let channel = self.add_tls_to_channel(channel).await?;
439-
let channel = if let Some(keep_alive) = self.keep_alive.as_ref() {
440-
channel
441-
.keep_alive_while_idle(true)
442-
.http2_keep_alive_interval(keep_alive.interval)
443-
.keep_alive_timeout(keep_alive.timeout)
444-
} else {
445-
channel
446-
};
447-
let channel = if let Some(origin) = self.override_origin.clone() {
448-
channel.origin(origin)
449-
} else {
450-
channel
451-
};
452-
// If there is a proxy, we have to connect that way
453-
let channel = if let Some(proxy) = self.http_connect_proxy.as_ref() {
454-
proxy.connect_endpoint(&channel).await?
455-
} else {
456-
channel.connect().await?
457-
};
458-
let service = ServiceBuilder::new()
459-
.layer_fn(move |channel| GrpcMetricSvc {
460-
inner: channel,
438+
self.connect_no_namespace_with_service_override(metrics_meter, None)
439+
.await
440+
}
441+
442+
/// Attempt to establish a connection to the Temporal server and return a gRPC client which is
443+
/// intercepted with retry, default headers functionality, and metrics if provided. If a
444+
/// service_override is present, network-specific options are ignored and the callback is
445+
/// invoked for each gRPC call.
446+
///
447+
/// See [RetryClient] for more
448+
pub async fn connect_no_namespace_with_service_override(
449+
&self,
450+
metrics_meter: Option<TemporalMeter>,
451+
service_override: Option<callback_based::CallbackBasedGrpcService>,
452+
) -> Result<RetryClient<ConfiguredClient<TemporalServiceClientWithMetrics>>, ClientInitError>
453+
{
454+
let service = if let Some(service_override) = service_override {
455+
GrpcMetricSvc {
456+
inner: ChannelOrGrpcOverride::GrpcOverride(service_override),
461457
metrics: metrics_meter.clone().map(MetricsContext::new),
462458
disable_errcode_label: self.disable_error_code_metric_tags,
463-
})
464-
.service(channel);
459+
}
460+
} else {
461+
let channel = Channel::from_shared(self.target_url.to_string())?;
462+
let channel = self.add_tls_to_channel(channel).await?;
463+
let channel = if let Some(keep_alive) = self.keep_alive.as_ref() {
464+
channel
465+
.keep_alive_while_idle(true)
466+
.http2_keep_alive_interval(keep_alive.interval)
467+
.keep_alive_timeout(keep_alive.timeout)
468+
} else {
469+
channel
470+
};
471+
let channel = if let Some(origin) = self.override_origin.clone() {
472+
channel.origin(origin)
473+
} else {
474+
channel
475+
};
476+
// If there is a proxy, we have to connect that way
477+
let channel = if let Some(proxy) = self.http_connect_proxy.as_ref() {
478+
proxy.connect_endpoint(&channel).await?
479+
} else {
480+
channel.connect().await?
481+
};
482+
ServiceBuilder::new()
483+
.layer_fn(move |channel| GrpcMetricSvc {
484+
inner: ChannelOrGrpcOverride::Channel(channel),
485+
metrics: metrics_meter.clone().map(MetricsContext::new),
486+
disable_errcode_label: self.disable_error_code_metric_tags,
487+
})
488+
.service(channel)
489+
};
490+
465491
let headers = Arc::new(RwLock::new(ClientHeaders {
466492
user_headers: self.headers.clone().unwrap_or_default(),
467493
api_key: self.api_key.clone(),

client/src/metrics.rs

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1-
use crate::{AttachMetricLabels, CallType, dbg_panic};
1+
use crate::{AttachMetricLabels, CallType, callback_based, dbg_panic};
2+
use futures_util::TryFutureExt;
3+
use futures_util::future::Either;
24
use futures_util::{FutureExt, future::BoxFuture};
35
use std::{
6+
fmt,
47
sync::Arc,
58
task::{Context, Poll},
69
time::{Duration, Instant},
@@ -205,19 +208,37 @@ fn code_as_screaming_snake(code: &Code) -> &'static str {
205208
/// Implements metrics functionality for gRPC (really, any http) calls
206209
#[derive(Debug, Clone)]
207210
pub struct GrpcMetricSvc {
208-
pub(crate) inner: Channel,
211+
pub(crate) inner: ChannelOrGrpcOverride,
209212
// If set to none, metrics are a no-op
210213
pub(crate) metrics: Option<MetricsContext>,
211214
pub(crate) disable_errcode_label: bool,
212215
}
213216

217+
#[derive(Clone)]
218+
pub(crate) enum ChannelOrGrpcOverride {
219+
Channel(Channel),
220+
GrpcOverride(callback_based::CallbackBasedGrpcService),
221+
}
222+
223+
impl fmt::Debug for ChannelOrGrpcOverride {
224+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
225+
match self {
226+
ChannelOrGrpcOverride::Channel(inner) => fmt::Debug::fmt(inner, f),
227+
ChannelOrGrpcOverride::GrpcOverride(_) => f.write_str("<callback-based-grpc-service>"),
228+
}
229+
}
230+
}
231+
214232
impl Service<http::Request<Body>> for GrpcMetricSvc {
215233
type Response = http::Response<Body>;
216-
type Error = tonic::transport::Error;
234+
type Error = Box<dyn std::error::Error + Send + Sync>;
217235
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
218236

219237
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
220-
self.inner.poll_ready(cx).map_err(Into::into)
238+
match &mut self.inner {
239+
ChannelOrGrpcOverride::Channel(inner) => inner.poll_ready(cx).map_err(Into::into),
240+
ChannelOrGrpcOverride::GrpcOverride(inner) => inner.poll_ready(cx).map_err(Into::into),
241+
}
221242
}
222243

223244
fn call(&mut self, mut req: http::Request<Body>) -> Self::Future {
@@ -245,7 +266,14 @@ impl Service<http::Request<Body>> for GrpcMetricSvc {
245266
metrics
246267
})
247268
});
248-
let callfut = self.inner.call(req);
269+
let callfut = match &mut self.inner {
270+
ChannelOrGrpcOverride::Channel(inner) => {
271+
Either::Left(inner.call(req).map_err(Into::into))
272+
}
273+
ChannelOrGrpcOverride::GrpcOverride(inner) => {
274+
Either::Right(inner.call(req).map_err(Into::into))
275+
}
276+
};
249277
let errcode_label_disabled = self.disable_errcode_label;
250278
async move {
251279
let started = Instant::now();

core-c-bridge/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ crate-type = ["cdylib"]
1010
[dependencies]
1111
anyhow = "1.0"
1212
async-trait = "0.1"
13+
futures-util = { version = "0.3", default-features = false }
14+
http = "1.1"
1315
libc = "0.2"
1416
prost = { workspace = true }
1517
# We rely on Cargo semver rules not updating a 0.x to 0.y. Per the rand

0 commit comments

Comments
 (0)