|
13 | 13 | * See the License for the specific language governing permissions and |
14 | 14 | * limitations under the License. |
15 | 15 | */ |
| 16 | +//! # Connection pooling |
| 17 | +//! |
| 18 | +//! This module provides utilities to use connection pooling. As we already know, it is far more |
| 19 | +//! efficient to maintain a number of live connections to a database and share them across multiple |
| 20 | +//! "worker threads", using a "connection pool" because creating individual connections whenever |
| 21 | +//! a worker receives a task is slow while maintaining a connection per worker might be cumbersome |
| 22 | +//! to implement. |
| 23 | +//! |
| 24 | +//! To provide connection pooling, we use [`r2d2`] for a sync connection pool while we use [`bb8`] |
| 25 | +//! to provide an async connection pool. |
| 26 | +//! |
| 27 | +//! ## Sync usage |
| 28 | +//! |
| 29 | +//! Example usage for TLS and non-TLS connection pools are given below. |
| 30 | +//! |
| 31 | +//! ```no_run |
| 32 | +//! use skytable::pool::{ConnectionManager, Pool, TlsPool}; |
| 33 | +//! use skytable::sync::{Connection, TlsConnection}; |
| 34 | +//! |
| 35 | +//! // non-TLS (TCP pool) |
| 36 | +//! let notls_manager = ConnectionManager::<Connection>::new_notls("127.0.0.1".into(), 2003); |
| 37 | +//! let notls_pool = Pool::builder() |
| 38 | +//! .max_size(10) |
| 39 | +//! .build(notls_manager) |
| 40 | +//! .unwrap(); |
| 41 | +//! |
| 42 | +//! // TLS pool |
| 43 | +//! let tls_manager = ConnectionManager::<TlsConnection>::new_tls( |
| 44 | +//! "127.0.0.1".into(), 2003, "cert.pem".into() |
| 45 | +//! ); |
| 46 | +//! let notls_pool = TlsPool::builder() |
| 47 | +//! .max_size(10) |
| 48 | +//! .build(tls_manager) |
| 49 | +//! .unwrap(); |
| 50 | +//!``` |
| 51 | +//! |
| 52 | +//! ## Async usage |
| 53 | +//! |
| 54 | +//! Example usage for TLS and non-TLS connection pools are given below. |
| 55 | +//! |
| 56 | +//! ```no_run |
| 57 | +//! use skytable::pool::{ConnectionManager, AsyncPool, AsyncTlsPool}; |
| 58 | +//! use skytable::aio::{Connection, TlsConnection}; |
| 59 | +//! async fn run() { |
| 60 | +//! // non-TLS (TCP pool) |
| 61 | +//! let notls_manager = ConnectionManager::<Connection>::new_notls("127.0.0.1".into(), 2003); |
| 62 | +//! let notls_pool = AsyncPool::builder() |
| 63 | +//! .max_size(10) |
| 64 | +//! .build(notls_manager) |
| 65 | +//! .await |
| 66 | +//! .unwrap(); |
| 67 | +//! |
| 68 | +//! // TLS pool |
| 69 | +//! let tls_manager = ConnectionManager::<TlsConnection>::new_tls( |
| 70 | +//! "127.0.0.1".into(), 2003, "cert.pem".into() |
| 71 | +//! ); |
| 72 | +//! let notls_pool = AsyncTlsPool::builder() |
| 73 | +//! .max_size(10) |
| 74 | +//! .build(tls_manager) |
| 75 | +//! .await |
| 76 | +//! .unwrap(); |
| 77 | +//! } |
| 78 | +//!``` |
| 79 | +//! |
16 | 80 |
|
17 | 81 | // re-exports |
| 82 | +// sync |
18 | 83 | #[cfg(any(feature = "sync", feature = "pool"))] |
19 | | -pub use self::sync_impls::{ConnectionManager, Pool, TlsPool}; |
| 84 | +pub use self::sync_impls::{Pool, TlsPool}; |
| 85 | +#[cfg(any(feature = "sync", feature = "pool"))] |
| 86 | +/// [`r2d2`](https://docs.rs/r2d2)'s error type |
| 87 | +pub use r2d2::Error as r2d2Error; |
| 88 | +// async |
| 89 | +#[cfg(any(feature = "async", feature = "aio-pool"))] |
| 90 | +pub use self::async_impls::{Pool as AsyncPool, TlsPool as AsyncTlsPool}; |
| 91 | +#[cfg(any(feature = "async", feature = "aio-pool"))] |
| 92 | +/// [`bb8`](https://docs.rs/bb8)'s error type |
| 93 | +pub use bb8::RunError as bb8Error; |
20 | 94 |
|
21 | | -use crate::error::Error; |
22 | | -use core::fmt; |
| 95 | +// imports |
| 96 | +use core::marker::PhantomData; |
23 | 97 |
|
24 | 98 | #[derive(Debug)] |
25 | | -pub enum PoolError { |
26 | | - PoolError(String), |
27 | | - Other(Error), |
28 | | -} |
29 | | - |
30 | | -impl From<Error> for PoolError { |
31 | | - fn from(e: Error) -> Self { |
32 | | - Self::Other(e) |
33 | | - } |
| 99 | +pub struct ConnectionManager<C> { |
| 100 | + host: String, |
| 101 | + port: u16, |
| 102 | + cert: Option<String>, |
| 103 | + _m: PhantomData<C>, |
34 | 104 | } |
35 | 105 |
|
36 | | -impl fmt::Display for PoolError { |
37 | | - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
38 | | - match self { |
39 | | - Self::PoolError(e) => write!(f, "Pool error: {}", e), |
40 | | - Self::Other(e) => write!(f, "Pool connection error: {}", e), |
| 106 | +impl<C> ConnectionManager<C> { |
| 107 | + fn _new(host: String, port: u16, cert: Option<String>) -> Self { |
| 108 | + Self { |
| 109 | + host, |
| 110 | + port, |
| 111 | + cert, |
| 112 | + _m: PhantomData, |
41 | 113 | } |
42 | 114 | } |
43 | 115 | } |
44 | 116 |
|
45 | | -impl std::error::Error for PoolError {} |
46 | | - |
47 | | -#[derive(Debug)] |
48 | | -pub enum ConnectionProfile { |
49 | | - NoTls { |
50 | | - host: String, |
51 | | - port: u16, |
52 | | - }, |
53 | | - Tls { |
54 | | - host: String, |
55 | | - port: u16, |
56 | | - cert: String, |
57 | | - }, |
58 | | -} |
59 | | - |
60 | 117 | #[cfg(any(feature = "sync", feature = "pool"))] |
61 | 118 | mod sync_impls { |
62 | | - use super::{ConnectionProfile, PoolError}; |
| 119 | + use super::ConnectionManager; |
63 | 120 | use crate::sync::{Connection as SyncConnection, TlsConnection as SyncTlsConnection}; |
64 | 121 | use crate::{ |
65 | 122 | error::{Error, SkyhashError}, |
66 | 123 | Element, Query, SkyQueryResult, SkyResult, |
67 | 124 | }; |
68 | | - use core::marker::PhantomData; |
69 | 125 | use r2d2::ManageConnection; |
70 | 126 |
|
71 | 127 | pub type Pool = r2d2::Pool<ConnectionManager<SyncConnection>>; |
72 | 128 | pub type TlsPool = r2d2::Pool<ConnectionManager<SyncTlsConnection>>; |
73 | 129 |
|
74 | | - #[derive(Debug)] |
75 | | - pub struct ConnectionManager<C> { |
76 | | - profile: ConnectionProfile, |
77 | | - _m: PhantomData<C>, |
78 | | - } |
79 | | - |
80 | | - impl<C> ConnectionManager<C> { |
81 | | - fn _new(profile: ConnectionProfile) -> Self { |
82 | | - Self { |
83 | | - profile, |
84 | | - _m: PhantomData, |
85 | | - } |
86 | | - } |
87 | | - } |
88 | | - |
89 | 130 | impl ConnectionManager<SyncConnection> { |
90 | | - pub fn new(host: String, port: u16) -> Self { |
91 | | - Self::_new(ConnectionProfile::NoTls { host, port }) |
| 131 | + pub fn new_notls(host: String, port: u16) -> Self { |
| 132 | + Self::_new(host, port, None) |
92 | 133 | } |
93 | 134 | } |
94 | | - |
95 | 135 | impl ConnectionManager<SyncTlsConnection> { |
96 | 136 | pub fn new_tls(host: String, port: u16, cert: String) -> Self { |
97 | | - Self::_new(ConnectionProfile::Tls { host, port, cert }) |
| 137 | + Self::_new(host, port, Some(cert)) |
98 | 138 | } |
99 | 139 | } |
100 | 140 |
|
101 | 141 | pub trait PoolableConnection: Send + Sync + Sized { |
102 | | - fn get_connection(profile: &ConnectionProfile) -> SkyResult<Self>; |
| 142 | + fn get_connection(host: &str, port: u16, tls_cert: Option<&String>) -> SkyResult<Self>; |
103 | 143 | fn run_query(&mut self, q: Query) -> SkyQueryResult; |
104 | 144 | } |
105 | 145 |
|
106 | 146 | impl PoolableConnection for SyncConnection { |
107 | | - fn get_connection(profile: &ConnectionProfile) -> SkyResult<Self> { |
108 | | - if let ConnectionProfile::NoTls { host, port } = profile { |
109 | | - let c = Self::new(host, *port)?; |
110 | | - Ok(c) |
111 | | - } else { |
112 | | - Err(Error::ConfigurationError( |
113 | | - "Connection profile is TLS. Expected TCP", |
114 | | - )) |
115 | | - } |
| 147 | + fn get_connection(host: &str, port: u16, _tls_cert: Option<&String>) -> SkyResult<Self> { |
| 148 | + let c = Self::new(host, port)?; |
| 149 | + Ok(c) |
116 | 150 | } |
117 | 151 | fn run_query(&mut self, q: Query) -> SkyQueryResult { |
118 | 152 | self.run_simple_query(&q) |
119 | 153 | } |
120 | 154 | } |
121 | 155 |
|
122 | 156 | impl PoolableConnection for SyncTlsConnection { |
123 | | - fn get_connection(profile: &ConnectionProfile) -> SkyResult<Self> { |
124 | | - if let ConnectionProfile::Tls { host, port, cert } = profile { |
125 | | - let c = Self::new(host, *port, cert)?; |
126 | | - Ok(c) |
127 | | - } else { |
128 | | - Err(Error::ConfigurationError( |
129 | | - "Connection profile is TCP. Expected TLS", |
130 | | - )) |
131 | | - } |
| 157 | + fn get_connection(host: &str, port: u16, tls_cert: Option<&String>) -> SkyResult<Self> { |
| 158 | + let c = Self::new( |
| 159 | + &host, |
| 160 | + port, |
| 161 | + tls_cert.ok_or(Error::ConfigurationError( |
| 162 | + "Expected TLS certificate in `ConnectionManager`", |
| 163 | + ))?, |
| 164 | + )?; |
| 165 | + Ok(c) |
132 | 166 | } |
133 | 167 | fn run_query(&mut self, q: Query) -> SkyQueryResult { |
134 | 168 | self.run_simple_query(&q) |
135 | 169 | } |
136 | 170 | } |
137 | 171 | impl<C: PoolableConnection + 'static> ManageConnection for ConnectionManager<C> { |
138 | | - type Error = PoolError; |
| 172 | + type Error = Error; |
139 | 173 | type Connection = C; |
140 | 174 | fn connect(&self) -> Result<Self::Connection, Self::Error> { |
141 | | - C::get_connection(&self.profile).map_err(|e| Self::Error::Other(e)) |
| 175 | + C::get_connection(self.host.as_ref(), self.port, self.cert.as_ref()) |
142 | 176 | } |
143 | 177 | fn is_valid(&self, con: &mut Self::Connection) -> Result<(), Self::Error> { |
144 | 178 | let q = crate::query!("HEYA"); |
145 | 179 | match con.run_query(q)? { |
146 | 180 | Element::String(st) if st.eq("HEY!") => Ok(()), |
147 | | - _ => Err(PoolError::Other(Error::SkyError( |
148 | | - SkyhashError::UnexpectedResponse, |
149 | | - ))), |
| 181 | + _ => Err(Error::SkyError(SkyhashError::UnexpectedResponse)), |
| 182 | + } |
| 183 | + } |
| 184 | + fn has_broken(&self, _: &mut Self::Connection) -> bool { |
| 185 | + false |
| 186 | + } |
| 187 | + } |
| 188 | +} |
| 189 | + |
| 190 | +#[cfg(any(feature = "async", feature = "aio-pool"))] |
| 191 | +mod async_impls { |
| 192 | + use super::ConnectionManager; |
| 193 | + use crate::aio::{Connection as AsyncConnection, TlsConnection as AsyncTlsConnection}; |
| 194 | + use crate::{ |
| 195 | + error::{Error, SkyhashError}, |
| 196 | + Element, Query, SkyQueryResult, SkyResult, |
| 197 | + }; |
| 198 | + use async_trait::async_trait; |
| 199 | + use bb8::{ManageConnection, PooledConnection}; |
| 200 | + |
| 201 | + pub type Pool = bb8::Pool<ConnectionManager<AsyncConnection>>; |
| 202 | + pub type TlsPool = bb8::Pool<ConnectionManager<AsyncTlsConnection>>; |
| 203 | + |
| 204 | + #[async_trait] |
| 205 | + pub trait PoolableConnection: Send + Sync + Sized { |
| 206 | + async fn get_connection( |
| 207 | + host: &str, |
| 208 | + port: u16, |
| 209 | + tls_cert: Option<&String>, |
| 210 | + ) -> SkyResult<Self>; |
| 211 | + async fn run_query(&mut self, q: Query) -> SkyQueryResult; |
| 212 | + } |
| 213 | + |
| 214 | + #[async_trait] |
| 215 | + impl PoolableConnection for AsyncConnection { |
| 216 | + async fn get_connection( |
| 217 | + host: &str, |
| 218 | + port: u16, |
| 219 | + _tls_cert: Option<&String>, |
| 220 | + ) -> SkyResult<Self> { |
| 221 | + let con = AsyncConnection::new(&host, port).await?; |
| 222 | + Ok(con) |
| 223 | + } |
| 224 | + async fn run_query(&mut self, q: Query) -> SkyQueryResult { |
| 225 | + self.run_simple_query(&q).await |
| 226 | + } |
| 227 | + } |
| 228 | + |
| 229 | + #[async_trait] |
| 230 | + impl PoolableConnection for AsyncTlsConnection { |
| 231 | + async fn get_connection( |
| 232 | + host: &str, |
| 233 | + port: u16, |
| 234 | + tls_cert: Option<&String>, |
| 235 | + ) -> SkyResult<Self> { |
| 236 | + let con = AsyncTlsConnection::new( |
| 237 | + &host, |
| 238 | + port, |
| 239 | + tls_cert.ok_or(Error::ConfigurationError( |
| 240 | + "Expected TLS certificate in `ConnectionManager`", |
| 241 | + ))?, |
| 242 | + ) |
| 243 | + .await?; |
| 244 | + Ok(con) |
| 245 | + } |
| 246 | + async fn run_query(&mut self, q: Query) -> SkyQueryResult { |
| 247 | + self.run_simple_query(&q).await |
| 248 | + } |
| 249 | + } |
| 250 | + |
| 251 | + impl ConnectionManager<AsyncConnection> { |
| 252 | + pub fn new_notls(host: String, port: u16) -> Self { |
| 253 | + Self::_new(host, port, None) |
| 254 | + } |
| 255 | + } |
| 256 | + impl ConnectionManager<AsyncTlsConnection> { |
| 257 | + pub fn new_tls(host: String, port: u16, cert: String) -> Self { |
| 258 | + Self::_new(host, port, Some(cert)) |
| 259 | + } |
| 260 | + } |
| 261 | + |
| 262 | + #[async_trait] |
| 263 | + impl<C: PoolableConnection + 'static> ManageConnection for ConnectionManager<C> { |
| 264 | + type Connection = C; |
| 265 | + type Error = Error; |
| 266 | + async fn connect(&self) -> Result<Self::Connection, Self::Error> { |
| 267 | + C::get_connection(&self.host, self.port, self.cert.as_ref()).await |
| 268 | + } |
| 269 | + async fn is_valid(&self, con: &mut PooledConnection<'_, Self>) -> Result<(), Self::Error> { |
| 270 | + match con.run_query(crate::query!("HEYA")).await? { |
| 271 | + Element::String(st) if st.eq("HEY!") => Ok(()), |
| 272 | + _ => Err(Error::SkyError(SkyhashError::UnexpectedResponse)), |
150 | 273 | } |
151 | 274 | } |
152 | 275 | fn has_broken(&self, _: &mut Self::Connection) -> bool { |
|
0 commit comments