Skip to content

Commit a5dbf60

Browse files
committed
Implement async connection pool
1 parent 2b2c795 commit a5dbf60

File tree

3 files changed

+207
-80
lines changed

3 files changed

+207
-80
lines changed

Cargo.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ ssl = ["openssl"]
2323
sslv = ["openssl/vendored"]
2424
# async
2525
async = ["bytes", "tokio", "aio-pool"]
26-
aio-pool = ["bb8"]
26+
aio-pool = ["bb8", "async-trait"]
2727
# async TLS
2828
aio-ssl = ["tokio-openssl", "openssl"]
2929
aio-sslv = ["tokio-openssl", "openssl/vendored"]
@@ -42,6 +42,10 @@ tokio = { version = "1.15.0", features = [
4242
tokio-openssl = { version = "0.6.3", optional = true }
4343
r2d2 = { version = "0.8.9", optional = true }
4444
bb8 = { version = "0.7.1", optional = true }
45+
async-trait = { version = "0.1.52", optional = true }
46+
47+
[dev-dependencies]
48+
tokio = { version = "1.16.1", features = ["test-util", "macros"] }
4549

4650
[package.metadata.docs.rs]
4751
all-features = true

src/error.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ pub enum Error {
9393
SkyError(SkyhashError),
9494
/// An application level parse error occurred
9595
ParseError(String),
96-
/// An error occurred in the configuration of the connection
96+
/// A configuration error
9797
ConfigurationError(&'static str),
9898
}
9999

@@ -163,7 +163,7 @@ impl fmt::Display for Error {
163163
write!(f, "Server sent unknown data type for this client version")
164164
}
165165
},
166-
Self::ConfigurationError(e) => write!(f, "Connection setup error: {}", e),
166+
Self::ConfigurationError(e) => write!(f, "Configuration error: {}", e),
167167
}
168168
}
169169
}

src/pool.rs

Lines changed: 200 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -13,140 +13,263 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*/
16+
//! # Connection pooling
17+
//!
18+
//! This module provides utilities to use connection pooling. As we already know, it is far more
19+
//! efficient to maintain a number of live connections to a database and share them across multiple
20+
//! "worker threads", using a "connection pool" because creating individual connections whenever
21+
//! a worker receives a task is slow while maintaining a connection per worker might be cumbersome
22+
//! to implement.
23+
//!
24+
//! To provide connection pooling, we use [`r2d2`] for a sync connection pool while we use [`bb8`]
25+
//! to provide an async connection pool.
26+
//!
27+
//! ## Sync usage
28+
//!
29+
//! Example usage for TLS and non-TLS connection pools are given below.
30+
//!
31+
//! ```no_run
32+
//! use skytable::pool::{ConnectionManager, Pool, TlsPool};
33+
//! use skytable::sync::{Connection, TlsConnection};
34+
//!
35+
//! // non-TLS (TCP pool)
36+
//! let notls_manager = ConnectionManager::<Connection>::new_notls("127.0.0.1".into(), 2003);
37+
//! let notls_pool = Pool::builder()
38+
//! .max_size(10)
39+
//! .build(notls_manager)
40+
//! .unwrap();
41+
//!
42+
//! // TLS pool
43+
//! let tls_manager = ConnectionManager::<TlsConnection>::new_tls(
44+
//! "127.0.0.1".into(), 2003, "cert.pem".into()
45+
//! );
46+
//! let notls_pool = TlsPool::builder()
47+
//! .max_size(10)
48+
//! .build(tls_manager)
49+
//! .unwrap();
50+
//!```
51+
//!
52+
//! ## Async usage
53+
//!
54+
//! Example usage for TLS and non-TLS connection pools are given below.
55+
//!
56+
//! ```no_run
57+
//! use skytable::pool::{ConnectionManager, AsyncPool, AsyncTlsPool};
58+
//! use skytable::aio::{Connection, TlsConnection};
59+
//! async fn run() {
60+
//! // non-TLS (TCP pool)
61+
//! let notls_manager = ConnectionManager::<Connection>::new_notls("127.0.0.1".into(), 2003);
62+
//! let notls_pool = AsyncPool::builder()
63+
//! .max_size(10)
64+
//! .build(notls_manager)
65+
//! .await
66+
//! .unwrap();
67+
//!
68+
//! // TLS pool
69+
//! let tls_manager = ConnectionManager::<TlsConnection>::new_tls(
70+
//! "127.0.0.1".into(), 2003, "cert.pem".into()
71+
//! );
72+
//! let notls_pool = AsyncTlsPool::builder()
73+
//! .max_size(10)
74+
//! .build(tls_manager)
75+
//! .await
76+
//! .unwrap();
77+
//! }
78+
//!```
79+
//!
1680
1781
// re-exports
82+
// sync
1883
#[cfg(any(feature = "sync", feature = "pool"))]
19-
pub use self::sync_impls::{ConnectionManager, Pool, TlsPool};
84+
pub use self::sync_impls::{Pool, TlsPool};
85+
#[cfg(any(feature = "sync", feature = "pool"))]
86+
/// [`r2d2`](https://docs.rs/r2d2)'s error type
87+
pub use r2d2::Error as r2d2Error;
88+
// async
89+
#[cfg(any(feature = "async", feature = "aio-pool"))]
90+
pub use self::async_impls::{Pool as AsyncPool, TlsPool as AsyncTlsPool};
91+
#[cfg(any(feature = "async", feature = "aio-pool"))]
92+
/// [`bb8`](https://docs.rs/bb8)'s error type
93+
pub use bb8::RunError as bb8Error;
2094

21-
use crate::error::Error;
22-
use core::fmt;
95+
// imports
96+
use core::marker::PhantomData;
2397

2498
#[derive(Debug)]
25-
pub enum PoolError {
26-
PoolError(String),
27-
Other(Error),
28-
}
29-
30-
impl From<Error> for PoolError {
31-
fn from(e: Error) -> Self {
32-
Self::Other(e)
33-
}
99+
pub struct ConnectionManager<C> {
100+
host: String,
101+
port: u16,
102+
cert: Option<String>,
103+
_m: PhantomData<C>,
34104
}
35105

36-
impl fmt::Display for PoolError {
37-
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
38-
match self {
39-
Self::PoolError(e) => write!(f, "Pool error: {}", e),
40-
Self::Other(e) => write!(f, "Pool connection error: {}", e),
106+
impl<C> ConnectionManager<C> {
107+
fn _new(host: String, port: u16, cert: Option<String>) -> Self {
108+
Self {
109+
host,
110+
port,
111+
cert,
112+
_m: PhantomData,
41113
}
42114
}
43115
}
44116

45-
impl std::error::Error for PoolError {}
46-
47-
#[derive(Debug)]
48-
pub enum ConnectionProfile {
49-
NoTls {
50-
host: String,
51-
port: u16,
52-
},
53-
Tls {
54-
host: String,
55-
port: u16,
56-
cert: String,
57-
},
58-
}
59-
60117
#[cfg(any(feature = "sync", feature = "pool"))]
61118
mod sync_impls {
62-
use super::{ConnectionProfile, PoolError};
119+
use super::ConnectionManager;
63120
use crate::sync::{Connection as SyncConnection, TlsConnection as SyncTlsConnection};
64121
use crate::{
65122
error::{Error, SkyhashError},
66123
Element, Query, SkyQueryResult, SkyResult,
67124
};
68-
use core::marker::PhantomData;
69125
use r2d2::ManageConnection;
70126

71127
pub type Pool = r2d2::Pool<ConnectionManager<SyncConnection>>;
72128
pub type TlsPool = r2d2::Pool<ConnectionManager<SyncTlsConnection>>;
73129

74-
#[derive(Debug)]
75-
pub struct ConnectionManager<C> {
76-
profile: ConnectionProfile,
77-
_m: PhantomData<C>,
78-
}
79-
80-
impl<C> ConnectionManager<C> {
81-
fn _new(profile: ConnectionProfile) -> Self {
82-
Self {
83-
profile,
84-
_m: PhantomData,
85-
}
86-
}
87-
}
88-
89130
impl ConnectionManager<SyncConnection> {
90-
pub fn new(host: String, port: u16) -> Self {
91-
Self::_new(ConnectionProfile::NoTls { host, port })
131+
pub fn new_notls(host: String, port: u16) -> Self {
132+
Self::_new(host, port, None)
92133
}
93134
}
94-
95135
impl ConnectionManager<SyncTlsConnection> {
96136
pub fn new_tls(host: String, port: u16, cert: String) -> Self {
97-
Self::_new(ConnectionProfile::Tls { host, port, cert })
137+
Self::_new(host, port, Some(cert))
98138
}
99139
}
100140

101141
pub trait PoolableConnection: Send + Sync + Sized {
102-
fn get_connection(profile: &ConnectionProfile) -> SkyResult<Self>;
142+
fn get_connection(host: &str, port: u16, tls_cert: Option<&String>) -> SkyResult<Self>;
103143
fn run_query(&mut self, q: Query) -> SkyQueryResult;
104144
}
105145

106146
impl PoolableConnection for SyncConnection {
107-
fn get_connection(profile: &ConnectionProfile) -> SkyResult<Self> {
108-
if let ConnectionProfile::NoTls { host, port } = profile {
109-
let c = Self::new(host, *port)?;
110-
Ok(c)
111-
} else {
112-
Err(Error::ConfigurationError(
113-
"Connection profile is TLS. Expected TCP",
114-
))
115-
}
147+
fn get_connection(host: &str, port: u16, _tls_cert: Option<&String>) -> SkyResult<Self> {
148+
let c = Self::new(host, port)?;
149+
Ok(c)
116150
}
117151
fn run_query(&mut self, q: Query) -> SkyQueryResult {
118152
self.run_simple_query(&q)
119153
}
120154
}
121155

122156
impl PoolableConnection for SyncTlsConnection {
123-
fn get_connection(profile: &ConnectionProfile) -> SkyResult<Self> {
124-
if let ConnectionProfile::Tls { host, port, cert } = profile {
125-
let c = Self::new(host, *port, cert)?;
126-
Ok(c)
127-
} else {
128-
Err(Error::ConfigurationError(
129-
"Connection profile is TCP. Expected TLS",
130-
))
131-
}
157+
fn get_connection(host: &str, port: u16, tls_cert: Option<&String>) -> SkyResult<Self> {
158+
let c = Self::new(
159+
&host,
160+
port,
161+
tls_cert.ok_or(Error::ConfigurationError(
162+
"Expected TLS certificate in `ConnectionManager`",
163+
))?,
164+
)?;
165+
Ok(c)
132166
}
133167
fn run_query(&mut self, q: Query) -> SkyQueryResult {
134168
self.run_simple_query(&q)
135169
}
136170
}
137171
impl<C: PoolableConnection + 'static> ManageConnection for ConnectionManager<C> {
138-
type Error = PoolError;
172+
type Error = Error;
139173
type Connection = C;
140174
fn connect(&self) -> Result<Self::Connection, Self::Error> {
141-
C::get_connection(&self.profile).map_err(|e| Self::Error::Other(e))
175+
C::get_connection(self.host.as_ref(), self.port, self.cert.as_ref())
142176
}
143177
fn is_valid(&self, con: &mut Self::Connection) -> Result<(), Self::Error> {
144178
let q = crate::query!("HEYA");
145179
match con.run_query(q)? {
146180
Element::String(st) if st.eq("HEY!") => Ok(()),
147-
_ => Err(PoolError::Other(Error::SkyError(
148-
SkyhashError::UnexpectedResponse,
149-
))),
181+
_ => Err(Error::SkyError(SkyhashError::UnexpectedResponse)),
182+
}
183+
}
184+
fn has_broken(&self, _: &mut Self::Connection) -> bool {
185+
false
186+
}
187+
}
188+
}
189+
190+
#[cfg(any(feature = "async", feature = "aio-pool"))]
191+
mod async_impls {
192+
use super::ConnectionManager;
193+
use crate::aio::{Connection as AsyncConnection, TlsConnection as AsyncTlsConnection};
194+
use crate::{
195+
error::{Error, SkyhashError},
196+
Element, Query, SkyQueryResult, SkyResult,
197+
};
198+
use async_trait::async_trait;
199+
use bb8::{ManageConnection, PooledConnection};
200+
201+
pub type Pool = bb8::Pool<ConnectionManager<AsyncConnection>>;
202+
pub type TlsPool = bb8::Pool<ConnectionManager<AsyncTlsConnection>>;
203+
204+
#[async_trait]
205+
pub trait PoolableConnection: Send + Sync + Sized {
206+
async fn get_connection(
207+
host: &str,
208+
port: u16,
209+
tls_cert: Option<&String>,
210+
) -> SkyResult<Self>;
211+
async fn run_query(&mut self, q: Query) -> SkyQueryResult;
212+
}
213+
214+
#[async_trait]
215+
impl PoolableConnection for AsyncConnection {
216+
async fn get_connection(
217+
host: &str,
218+
port: u16,
219+
_tls_cert: Option<&String>,
220+
) -> SkyResult<Self> {
221+
let con = AsyncConnection::new(&host, port).await?;
222+
Ok(con)
223+
}
224+
async fn run_query(&mut self, q: Query) -> SkyQueryResult {
225+
self.run_simple_query(&q).await
226+
}
227+
}
228+
229+
#[async_trait]
230+
impl PoolableConnection for AsyncTlsConnection {
231+
async fn get_connection(
232+
host: &str,
233+
port: u16,
234+
tls_cert: Option<&String>,
235+
) -> SkyResult<Self> {
236+
let con = AsyncTlsConnection::new(
237+
&host,
238+
port,
239+
tls_cert.ok_or(Error::ConfigurationError(
240+
"Expected TLS certificate in `ConnectionManager`",
241+
))?,
242+
)
243+
.await?;
244+
Ok(con)
245+
}
246+
async fn run_query(&mut self, q: Query) -> SkyQueryResult {
247+
self.run_simple_query(&q).await
248+
}
249+
}
250+
251+
impl ConnectionManager<AsyncConnection> {
252+
pub fn new_notls(host: String, port: u16) -> Self {
253+
Self::_new(host, port, None)
254+
}
255+
}
256+
impl ConnectionManager<AsyncTlsConnection> {
257+
pub fn new_tls(host: String, port: u16, cert: String) -> Self {
258+
Self::_new(host, port, Some(cert))
259+
}
260+
}
261+
262+
#[async_trait]
263+
impl<C: PoolableConnection + 'static> ManageConnection for ConnectionManager<C> {
264+
type Connection = C;
265+
type Error = Error;
266+
async fn connect(&self) -> Result<Self::Connection, Self::Error> {
267+
C::get_connection(&self.host, self.port, self.cert.as_ref()).await
268+
}
269+
async fn is_valid(&self, con: &mut PooledConnection<'_, Self>) -> Result<(), Self::Error> {
270+
match con.run_query(crate::query!("HEYA")).await? {
271+
Element::String(st) if st.eq("HEY!") => Ok(()),
272+
_ => Err(Error::SkyError(SkyhashError::UnexpectedResponse)),
150273
}
151274
}
152275
fn has_broken(&self, _: &mut Self::Connection) -> bool {

0 commit comments

Comments
 (0)