|
| 1 | +//! This module implements support for callback-based gRPC service that has a callback invoked for |
| 2 | +//! every gRPC call instead of directly using the network. |
| 3 | +
|
| 4 | +use anyhow::anyhow; |
| 5 | +use bytes::{BufMut, BytesMut}; |
| 6 | +use futures_util::future::BoxFuture; |
| 7 | +use futures_util::stream; |
| 8 | +use http::{HeaderMap, Request, Response}; |
| 9 | +use http_body_util::{BodyExt, StreamBody, combinators::BoxBody}; |
| 10 | +use hyper::body::{Bytes, Frame}; |
| 11 | +use std::{ |
| 12 | + sync::Arc, |
| 13 | + task::{Context, Poll}, |
| 14 | +}; |
| 15 | +use tonic::{Status, metadata::GRPC_CONTENT_TYPE}; |
| 16 | +use tower::Service; |
| 17 | + |
| 18 | +/// gRPC request for use by a callback. |
| 19 | +pub struct GrpcRequest { |
| 20 | + /// Fully qualified gRPC service name. |
| 21 | + pub service: String, |
| 22 | + /// RPC name. |
| 23 | + pub rpc: String, |
| 24 | + /// Request headers. |
| 25 | + pub headers: HeaderMap, |
| 26 | + /// Protobuf bytes of the request. |
| 27 | + pub proto: Bytes, |
| 28 | +} |
| 29 | + |
| 30 | +/// Successful gRPC response returned by a callback. |
| 31 | +pub struct GrpcSuccessResponse { |
| 32 | + /// Response headers. |
| 33 | + pub headers: HeaderMap, |
| 34 | + |
| 35 | + /// Response proto bytes. |
| 36 | + pub proto: Vec<u8>, |
| 37 | +} |
| 38 | + |
| 39 | +/// gRPC service that invokes the given callback on each call. |
| 40 | +#[derive(Clone)] |
| 41 | +pub struct CallbackBasedGrpcService { |
| 42 | + /// Callback to invoke on each RPC call. |
| 43 | + #[allow(clippy::type_complexity)] // Signature is not that complex |
| 44 | + pub callback: Arc< |
| 45 | + dyn Fn(GrpcRequest) -> BoxFuture<'static, Result<GrpcSuccessResponse, Status>> |
| 46 | + + Send |
| 47 | + + Sync, |
| 48 | + >, |
| 49 | +} |
| 50 | + |
| 51 | +impl Service<Request<tonic::body::Body>> for CallbackBasedGrpcService { |
| 52 | + type Response = http::Response<tonic::body::Body>; |
| 53 | + type Error = anyhow::Error; |
| 54 | + type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>; |
| 55 | + |
| 56 | + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { |
| 57 | + Poll::Ready(Ok(())) |
| 58 | + } |
| 59 | + |
| 60 | + fn call(&mut self, req: Request<tonic::body::Body>) -> Self::Future { |
| 61 | + let callback = self.callback.clone(); |
| 62 | + |
| 63 | + Box::pin(async move { |
| 64 | + // Build req |
| 65 | + let (parts, body) = req.into_parts(); |
| 66 | + let mut path_parts = parts.uri.path().trim_start_matches('/').split('/'); |
| 67 | + let req_body = body.collect().await.map_err(|e| anyhow!(e))?.to_bytes(); |
| 68 | + // Body is flag saying whether compressed (we do not support that), then 32-bit length, |
| 69 | + // then the actual proto. |
| 70 | + if req_body.len() < 5 { |
| 71 | + return Err(anyhow!("Too few request bytes: {}", req_body.len())); |
| 72 | + } else if req_body[0] != 0 { |
| 73 | + return Err(anyhow!("Compression not supported")); |
| 74 | + } |
| 75 | + let req_proto_len = |
| 76 | + u32::from_be_bytes([req_body[1], req_body[2], req_body[3], req_body[4]]) as usize; |
| 77 | + if req_body.len() < 5 + req_proto_len { |
| 78 | + return Err(anyhow!( |
| 79 | + "Expected request body length at least {}, got {}", |
| 80 | + 5 + req_proto_len, |
| 81 | + req_body.len() |
| 82 | + )); |
| 83 | + } |
| 84 | + let req = GrpcRequest { |
| 85 | + service: path_parts.next().unwrap_or_default().to_owned(), |
| 86 | + rpc: path_parts.next().unwrap_or_default().to_owned(), |
| 87 | + headers: parts.headers, |
| 88 | + proto: req_body.slice(5..5 + req_proto_len), |
| 89 | + }; |
| 90 | + |
| 91 | + // Invoke and handle response |
| 92 | + match (callback)(req).await { |
| 93 | + Ok(success) => { |
| 94 | + // Create body bytes which requires a flag saying whether compressed, then |
| 95 | + // message len, then actual message. So we create a Bytes for those 5 prepend |
| 96 | + // parts, then stream it alongside the user-provided Vec. This allows us to |
| 97 | + // avoid copying the vec |
| 98 | + let mut body_prepend = BytesMut::with_capacity(5); |
| 99 | + body_prepend.put_u8(0); // 0 means no compression |
| 100 | + body_prepend.put_u32(success.proto.len() as u32); |
| 101 | + let stream = stream::iter(vec![ |
| 102 | + Ok::<_, Status>(Frame::data(Bytes::from(body_prepend))), |
| 103 | + Ok::<_, Status>(Frame::data(Bytes::from(success.proto))), |
| 104 | + ]); |
| 105 | + let stream_body = StreamBody::new(stream); |
| 106 | + let full_body = BoxBody::new(stream_body).boxed(); |
| 107 | + |
| 108 | + // Build response appending headers |
| 109 | + let mut resp_builder = Response::builder() |
| 110 | + .status(200) |
| 111 | + .header(http::header::CONTENT_TYPE, GRPC_CONTENT_TYPE); |
| 112 | + for (key, value) in success.headers.iter() { |
| 113 | + resp_builder = resp_builder.header(key, value); |
| 114 | + } |
| 115 | + Ok(resp_builder |
| 116 | + .body(tonic::body::Body::new(full_body)) |
| 117 | + .map_err(|e| anyhow!(e))?) |
| 118 | + } |
| 119 | + Err(status) => Ok(status.into_http()), |
| 120 | + } |
| 121 | + }) |
| 122 | + } |
| 123 | +} |
0 commit comments