|
| 1 | +#![doc = include_str!("../README.md")] |
| 2 | + |
| 3 | +use rama::{ |
| 4 | + Service, |
| 5 | + error::OpaqueError, |
| 6 | + http::{Request, server::HttpServer, service::web::response::IntoResponse}, |
| 7 | + tcp::server::TcpListener, |
| 8 | +}; |
| 9 | +use shuttle_runtime::{CustomError, Error, tokio}; |
| 10 | +use std::{convert::Infallible, fmt, net::SocketAddr}; |
| 11 | + |
| 12 | +/// A wrapper type for [`Service`] so we can implement [`shuttle_runtime::Service`] for it. |
| 13 | +pub struct RamaService<T, State> { |
| 14 | + svc: T, |
| 15 | + state: State, |
| 16 | +} |
| 17 | + |
| 18 | +impl<T: Clone, State: Clone> Clone for RamaService<T, State> { |
| 19 | + fn clone(&self) -> Self { |
| 20 | + Self { |
| 21 | + svc: self.svc.clone(), |
| 22 | + state: self.state.clone(), |
| 23 | + } |
| 24 | + } |
| 25 | +} |
| 26 | + |
| 27 | +impl<T: fmt::Debug, State: fmt::Debug> fmt::Debug for RamaService<T, State> { |
| 28 | + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
| 29 | + f.debug_struct("RamaService") |
| 30 | + .field("svc", &self.svc) |
| 31 | + .field("state", &self.state) |
| 32 | + .finish() |
| 33 | + } |
| 34 | +} |
| 35 | + |
| 36 | +/// Private type wrapper to indicate [`RamaService`] |
| 37 | +/// is used by the user from the Transport layer (tcp). |
| 38 | +pub struct Transport<S>(S); |
| 39 | + |
| 40 | +/// Private type wrapper to indicate [`RamaService`] |
| 41 | +/// is used by the user from the Application layer (http(s)). |
| 42 | +pub struct Application<S>(S); |
| 43 | + |
| 44 | +macro_rules! impl_wrapper_derive_traits { |
| 45 | + ($name:ident) => { |
| 46 | + impl<S: Clone> Clone for $name<S> { |
| 47 | + fn clone(&self) -> Self { |
| 48 | + Self(self.0.clone()) |
| 49 | + } |
| 50 | + } |
| 51 | + |
| 52 | + impl<S: fmt::Debug> fmt::Debug for $name<S> { |
| 53 | + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
| 54 | + f.debug_tuple(stringify!($name)).field(&self.0).finish() |
| 55 | + } |
| 56 | + } |
| 57 | + }; |
| 58 | +} |
| 59 | + |
| 60 | +impl_wrapper_derive_traits!(Transport); |
| 61 | +impl_wrapper_derive_traits!(Application); |
| 62 | + |
| 63 | +impl<S> RamaService<Transport<S>, ()> { |
| 64 | + pub fn transport(svc: S) -> Self { |
| 65 | + Self { |
| 66 | + svc: Transport(svc), |
| 67 | + state: (), |
| 68 | + } |
| 69 | + } |
| 70 | +} |
| 71 | + |
| 72 | +impl<S> RamaService<Application<S>, ()> { |
| 73 | + pub fn application(svc: S) -> Self { |
| 74 | + Self { |
| 75 | + svc: Application(svc), |
| 76 | + state: (), |
| 77 | + } |
| 78 | + } |
| 79 | +} |
| 80 | + |
| 81 | +impl<T> RamaService<T, ()> { |
| 82 | + /// Attach state to this [`RamaService`], such that it will be passed |
| 83 | + /// as part of each request's [`Context`]. |
| 84 | + /// |
| 85 | + /// [`Context`]: rama::Context |
| 86 | + pub fn with_state<State>(self, state: State) -> RamaService<T, State> |
| 87 | + where |
| 88 | + State: Clone + Send + Sync + 'static, |
| 89 | + { |
| 90 | + RamaService { |
| 91 | + svc: self.svc, |
| 92 | + state, |
| 93 | + } |
| 94 | + } |
| 95 | +} |
| 96 | + |
| 97 | +#[shuttle_runtime::async_trait] |
| 98 | +impl<S, State> shuttle_runtime::Service for RamaService<Transport<S>, State> |
| 99 | +where |
| 100 | + S: Service<State, tokio::net::TcpStream>, |
| 101 | + State: Clone + Send + Sync + 'static, |
| 102 | +{ |
| 103 | + /// Takes the service that is returned by the user in their [shuttle_runtime::main] function |
| 104 | + /// and binds to an address passed in by shuttle. |
| 105 | + async fn bind(self, addr: SocketAddr) -> Result<(), Error> { |
| 106 | + TcpListener::build_with_state(self.state) |
| 107 | + .bind(addr) |
| 108 | + .await |
| 109 | + .map_err(|err| Error::BindPanic(err.to_string()))? |
| 110 | + .serve(self.svc.0) |
| 111 | + .await; |
| 112 | + Ok(()) |
| 113 | + } |
| 114 | +} |
| 115 | + |
| 116 | +#[shuttle_runtime::async_trait] |
| 117 | +impl<S, State, Response> shuttle_runtime::Service for RamaService<Application<S>, State> |
| 118 | +where |
| 119 | + S: Service<State, Request, Response = Response, Error = Infallible>, |
| 120 | + Response: IntoResponse + Send + 'static, |
| 121 | + State: Clone + Send + Sync + 'static, |
| 122 | +{ |
| 123 | + /// Takes the service that is returned by the user in their [shuttle_runtime::main] function |
| 124 | + /// and binds to an address passed in by shuttle. |
| 125 | + async fn bind(self, addr: SocketAddr) -> Result<(), Error> { |
| 126 | + // shuttle only supports h1 between load balancer <=> web service, |
| 127 | + // h2 is terminated by shuttle's load balancer |
| 128 | + HttpServer::http1() |
| 129 | + .listen_with_state(self.state, addr, self.svc.0) |
| 130 | + .await |
| 131 | + .map_err(|err| CustomError::new(OpaqueError::from_boxed(err)))?; |
| 132 | + Ok(()) |
| 133 | + } |
| 134 | +} |
| 135 | + |
| 136 | +#[doc = include_str!("../README.md")] |
| 137 | +pub type ShuttleRamaTransport<S, State = ()> = Result<RamaService<Transport<S>, State>, Error>; |
| 138 | + |
| 139 | +#[doc = include_str!("../README.md")] |
| 140 | +pub type ShuttleRamaApplication<S, State = ()> = Result<RamaService<Application<S>, State>, Error>; |
| 141 | + |
| 142 | +pub use shuttle_runtime::{Error as ShuttleError, Service as ShuttleService}; |
0 commit comments