Skip to content

Commit 2b2c795

Browse files
committed
Add r2d2 impl
1 parent 6aa409f commit 2b2c795

File tree

7 files changed

+191
-12
lines changed

7 files changed

+191
-12
lines changed

Cargo.toml

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,22 @@ version = "0.6.2"
1414
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
1515

1616
[features]
17+
default = ["sync"]
18+
# sync
19+
sync = ["pool"]
20+
pool = ["r2d2"]
21+
# sync TLS
22+
ssl = ["openssl"]
23+
sslv = ["openssl/vendored"]
24+
# async
25+
async = ["bytes", "tokio", "aio-pool"]
26+
aio-pool = ["bb8"]
27+
# async TLS
1728
aio-ssl = ["tokio-openssl", "openssl"]
1829
aio-sslv = ["tokio-openssl", "openssl/vendored"]
19-
async = ["bytes", "tokio"]
30+
# utilities
2031
const-gen = []
2132
dbg = []
22-
default = ["sync"]
23-
ssl = ["openssl"]
24-
sslv = ["openssl/vendored"]
25-
sync = []
2633

2734
[dependencies]
2835
bytes = { version = "1.1.0", optional = true }
@@ -33,6 +40,9 @@ tokio = { version = "1.15.0", features = [
3340
"io-std",
3441
], optional = true }
3542
tokio-openssl = { version = "0.6.3", optional = true }
43+
r2d2 = { version = "0.8.9", optional = true }
44+
bb8 = { version = "0.7.1", optional = true }
45+
3646
[package.metadata.docs.rs]
3747
all-features = true
3848
rustdoc-args = ["--cfg", "docsrs"]

src/actions.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,10 @@ use crate::SkyQueryResult;
5252
use crate::SkyResult;
5353

5454
cfg_async!(
55-
use core::{future::Future, pin::Pin};
55+
use crate::AsyncResult;
5656
);
5757

5858
cfg_async!(
59-
/// A special result that is returned when running actions (async)
60-
pub type AsyncResult<'s, T> = Pin<Box<dyn Future<Output = T> + Send + Sync + 's>>;
6159
#[doc(hidden)]
6260
/// A raw async connection to the database server
6361
pub trait AsyncSocket: Send + Sync {

src/aio.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ macro_rules! impl_async_methods {
110110
}
111111
}
112112
impl crate::actions::AsyncSocket for $ty {
113-
fn run(&mut self, q: Query) -> crate::actions::AsyncResult<SkyQueryResult> {
113+
fn run(&mut self, q: Query) -> crate::AsyncResult<SkyQueryResult> {
114114
Box::pin(async move { self.run_simple_query(&q).await })
115115
}
116116
}

src/ddl.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ use crate::RespCode;
4343
use crate::SkyResult;
4444

4545
cfg_async! {
46-
use crate::actions::AsyncResult;
46+
use crate::AsyncResult;
4747
use crate::actions::AsyncSocket;
4848
}
4949

src/error.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ pub enum Error {
9393
SkyError(SkyhashError),
9494
/// An application level parse error occurred
9595
ParseError(String),
96+
/// An error occurred in the configuration of the connection
97+
ConfigurationError(&'static str),
9698
}
9799

98100
impl PartialEq for Error {
@@ -118,6 +120,7 @@ impl PartialEq for Error {
118120
)))
119121
)]
120122
(SslError(a), SslError(b)) => a.to_string() == b.to_string(),
123+
(ConfigurationError(a), ConfigurationError(b)) => a == b,
121124
_ => false,
122125
}
123126
}
@@ -160,6 +163,7 @@ impl fmt::Display for Error {
160163
write!(f, "Server sent unknown data type for this client version")
161164
}
162165
},
166+
Self::ConfigurationError(e) => write!(f, "Connection setup error: {}", e),
163167
}
164168
}
165169
}
@@ -219,3 +223,5 @@ impl From<std::string::FromUtf8Error> for Error {
219223
Self::ParseError(e.to_string())
220224
}
221225
}
226+
227+
impl std::error::Error for Error {}

src/lib.rs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,13 @@ mod util;
181181
pub mod actions;
182182
pub mod ddl;
183183
pub mod error;
184+
#[cfg(any(
185+
feature = "sync",
186+
feature = "pool",
187+
feature = "async",
188+
feature = "aio-pool"
189+
))]
190+
pub mod pool;
184191
pub mod types;
185192
// endof public mods
186193
// private mods
@@ -202,10 +209,14 @@ pub const DEFAULT_PORT: u16 = 2003;
202209
pub const DEFAULT_ENTITY: &str = "default:default";
203210

204211
cfg_async!(
212+
use core::{future::Future, pin::Pin};
205213
pub mod aio;
206214
pub use aio::Connection as AsyncConnection;
207215
use tokio::io::AsyncWriteExt;
216+
/// A special result that is returned when running actions (async)
217+
pub type AsyncResult<'s, T> = Pin<Box<dyn Future<Output = T> + Send + Sync + 's>>;
208218
);
219+
209220
cfg_sync!(
210221
pub mod sync;
211222
pub use sync::Connection;
@@ -364,8 +375,6 @@ cfg_sync! {
364375
}
365376

366377
cfg_async! {
367-
use core::pin::Pin;
368-
use core::future::Future;
369378
use tokio::io::AsyncWrite;
370379
type FutureRet<'s> = Pin<Box<dyn Future<Output = IoResult<()>> + Send + Sync + 's>>;
371380
trait WriteQueryAsync<T: AsyncWrite + Unpin + Send + Sync>: Unpin + Sync + Send {

src/pool.rs

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
/*
2+
* Copyright 2022, Sayan Nandan <[email protected]>
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
// re-exports
18+
#[cfg(any(feature = "sync", feature = "pool"))]
19+
pub use self::sync_impls::{ConnectionManager, Pool, TlsPool};
20+
21+
use crate::error::Error;
22+
use core::fmt;
23+
24+
#[derive(Debug)]
25+
pub enum PoolError {
26+
PoolError(String),
27+
Other(Error),
28+
}
29+
30+
impl From<Error> for PoolError {
31+
fn from(e: Error) -> Self {
32+
Self::Other(e)
33+
}
34+
}
35+
36+
impl fmt::Display for PoolError {
37+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
38+
match self {
39+
Self::PoolError(e) => write!(f, "Pool error: {}", e),
40+
Self::Other(e) => write!(f, "Pool connection error: {}", e),
41+
}
42+
}
43+
}
44+
45+
impl std::error::Error for PoolError {}
46+
47+
#[derive(Debug)]
48+
pub enum ConnectionProfile {
49+
NoTls {
50+
host: String,
51+
port: u16,
52+
},
53+
Tls {
54+
host: String,
55+
port: u16,
56+
cert: String,
57+
},
58+
}
59+
60+
#[cfg(any(feature = "sync", feature = "pool"))]
61+
mod sync_impls {
62+
use super::{ConnectionProfile, PoolError};
63+
use crate::sync::{Connection as SyncConnection, TlsConnection as SyncTlsConnection};
64+
use crate::{
65+
error::{Error, SkyhashError},
66+
Element, Query, SkyQueryResult, SkyResult,
67+
};
68+
use core::marker::PhantomData;
69+
use r2d2::ManageConnection;
70+
71+
pub type Pool = r2d2::Pool<ConnectionManager<SyncConnection>>;
72+
pub type TlsPool = r2d2::Pool<ConnectionManager<SyncTlsConnection>>;
73+
74+
#[derive(Debug)]
75+
pub struct ConnectionManager<C> {
76+
profile: ConnectionProfile,
77+
_m: PhantomData<C>,
78+
}
79+
80+
impl<C> ConnectionManager<C> {
81+
fn _new(profile: ConnectionProfile) -> Self {
82+
Self {
83+
profile,
84+
_m: PhantomData,
85+
}
86+
}
87+
}
88+
89+
impl ConnectionManager<SyncConnection> {
90+
pub fn new(host: String, port: u16) -> Self {
91+
Self::_new(ConnectionProfile::NoTls { host, port })
92+
}
93+
}
94+
95+
impl ConnectionManager<SyncTlsConnection> {
96+
pub fn new_tls(host: String, port: u16, cert: String) -> Self {
97+
Self::_new(ConnectionProfile::Tls { host, port, cert })
98+
}
99+
}
100+
101+
pub trait PoolableConnection: Send + Sync + Sized {
102+
fn get_connection(profile: &ConnectionProfile) -> SkyResult<Self>;
103+
fn run_query(&mut self, q: Query) -> SkyQueryResult;
104+
}
105+
106+
impl PoolableConnection for SyncConnection {
107+
fn get_connection(profile: &ConnectionProfile) -> SkyResult<Self> {
108+
if let ConnectionProfile::NoTls { host, port } = profile {
109+
let c = Self::new(host, *port)?;
110+
Ok(c)
111+
} else {
112+
Err(Error::ConfigurationError(
113+
"Connection profile is TLS. Expected TCP",
114+
))
115+
}
116+
}
117+
fn run_query(&mut self, q: Query) -> SkyQueryResult {
118+
self.run_simple_query(&q)
119+
}
120+
}
121+
122+
impl PoolableConnection for SyncTlsConnection {
123+
fn get_connection(profile: &ConnectionProfile) -> SkyResult<Self> {
124+
if let ConnectionProfile::Tls { host, port, cert } = profile {
125+
let c = Self::new(host, *port, cert)?;
126+
Ok(c)
127+
} else {
128+
Err(Error::ConfigurationError(
129+
"Connection profile is TCP. Expected TLS",
130+
))
131+
}
132+
}
133+
fn run_query(&mut self, q: Query) -> SkyQueryResult {
134+
self.run_simple_query(&q)
135+
}
136+
}
137+
impl<C: PoolableConnection + 'static> ManageConnection for ConnectionManager<C> {
138+
type Error = PoolError;
139+
type Connection = C;
140+
fn connect(&self) -> Result<Self::Connection, Self::Error> {
141+
C::get_connection(&self.profile).map_err(|e| Self::Error::Other(e))
142+
}
143+
fn is_valid(&self, con: &mut Self::Connection) -> Result<(), Self::Error> {
144+
let q = crate::query!("HEYA");
145+
match con.run_query(q)? {
146+
Element::String(st) if st.eq("HEY!") => Ok(()),
147+
_ => Err(PoolError::Other(Error::SkyError(
148+
SkyhashError::UnexpectedResponse,
149+
))),
150+
}
151+
}
152+
fn has_broken(&self, _: &mut Self::Connection) -> bool {
153+
false
154+
}
155+
}
156+
}

0 commit comments

Comments
 (0)