Skip to content

Commit 2fa605e

Browse files
committed
Add connection builder and new error types
1 parent b3daab2 commit 2fa605e

File tree

4 files changed

+175
-10
lines changed

4 files changed

+175
-10
lines changed

src/aio.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ macro_rules! impl_async_methods {
6969
self.buffer.clear();
7070
return Ok(Response::InvalidResponse);
7171
}
72-
ParseError::DataTypeParseError => return Ok(Response::ParseError),
72+
ParseError::DataTypeError => return Ok(Response::ParseError),
7373
ParseError::Empty => {
7474
return Err(Error::from(ErrorKind::ConnectionReset))
7575
}

src/deserializer.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ pub enum ParseError {
9292
/// A data type was given but the parser failed to serialize it into this type
9393
///
9494
/// This can happen not just for elements but can also happen for their sizes ([`Self::parse_into_u64`])
95-
DataTypeParseError,
95+
DataTypeError,
9696
/// A data type that the client doesn't know was passed into the query
9797
///
9898
/// This is a frequent problem that can arise between different server editions as more data types
@@ -188,7 +188,7 @@ impl<'a> Parser<'a> {
188188
for dig in byte_iter {
189189
if !dig.is_ascii_digit() {
190190
// dig has to be an ASCII digit
191-
return Err(ParseError::DataTypeParseError);
191+
return Err(ParseError::DataTypeError);
192192
}
193193
// 48 is the ASCII code for 0, and 57 is the ascii code for 9
194194
// so if 0 is given, the subtraction should give 0; similarly
@@ -200,11 +200,11 @@ impl<'a> Parser<'a> {
200200
// The usize can overflow; check that case
201201
let product = match item_usize.checked_mul(10) {
202202
Some(not_overflowed) => not_overflowed,
203-
None => return Err(ParseError::DataTypeParseError),
203+
None => return Err(ParseError::DataTypeError),
204204
};
205205
let sum = match product.checked_add(curdig) {
206206
Some(not_overflowed) => not_overflowed,
207-
None => return Err(ParseError::DataTypeParseError),
207+
None => return Err(ParseError::DataTypeError),
208208
};
209209
item_usize = sum;
210210
}
@@ -220,7 +220,7 @@ impl<'a> Parser<'a> {
220220
for dig in byte_iter {
221221
if !dig.is_ascii_digit() {
222222
// dig has to be an ASCII digit
223-
return Err(ParseError::DataTypeParseError);
223+
return Err(ParseError::DataTypeError);
224224
}
225225
// 48 is the ASCII code for 0, and 57 is the ascii code for 9
226226
// so if 0 is given, the subtraction should give 0; similarly
@@ -232,11 +232,11 @@ impl<'a> Parser<'a> {
232232
// Now the entire u64 can overflow, so let's attempt to check it
233233
let product = match item_u64.checked_mul(10) {
234234
Some(not_overflowed) => not_overflowed,
235-
None => return Err(ParseError::DataTypeParseError),
235+
None => return Err(ParseError::DataTypeError),
236236
};
237237
let sum = match product.checked_add(curdig) {
238238
Some(not_overflowed) => not_overflowed,
239-
None => return Err(ParseError::DataTypeParseError),
239+
None => return Err(ParseError::DataTypeError),
240240
};
241241
item_u64 = sum;
242242
}
@@ -340,7 +340,7 @@ impl<'a> Parser<'a> {
340340
fn parse_next_flat_array(&mut self) -> ParseResult<Vec<Vec<u8>>> {
341341
let (start, stop) = self.read_line();
342342
if let Some(our_size_chunk) = self.buffer.get(start..stop) {
343-
let array_size = Self::parse_into_usize(&our_size_chunk)?;
343+
let array_size = Self::parse_into_usize(our_size_chunk)?;
344344
let mut array = Vec::with_capacity(array_size);
345345
for _ in 0..array_size {
346346
if let Some(tsymbol) = self.buffer.get(self.cursor) {

src/lib.rs

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,106 @@ cfg_sync!(
149149
pub use sync::Connection;
150150
);
151151

152+
#[derive(Debug)]
153+
/// A connection builder for easily building connections
154+
///
155+
/// ## Example (sync)
156+
/// ```no_run
157+
/// let con =
158+
/// ConnectionBuilder::new()
159+
/// .set_host("127.0.0.1")
160+
/// .set_port(2003)
161+
/// .get_connection()
162+
/// .unwrap();
163+
/// ```
164+
///
165+
/// ## Example (async)
166+
/// ```no_run
167+
/// let con =
168+
/// ConnectionBuilder::new()
169+
/// .set_host("127.0.0.1")
170+
/// .set_port(2003)
171+
/// .get_async_connection()
172+
/// .unwrap();
173+
/// ```
174+
pub struct ConnectionBuilder<'a> {
175+
port: Option<u16>,
176+
host: Option<&'a str>,
177+
}
178+
179+
impl<'a> Default for ConnectionBuilder<'a> {
180+
fn default() -> Self {
181+
Self::new()
182+
}
183+
}
184+
185+
pub type ConnectionBuilderResult<T> = Result<T, error::Error>;
186+
187+
impl<'a> ConnectionBuilder<'a> {
188+
/// Create an empty connection builder
189+
pub fn new() -> Self {
190+
Self {
191+
port: None,
192+
host: None,
193+
}
194+
}
195+
/// Set the port
196+
pub fn set_port(mut self, port: u16) -> Self {
197+
self.port = Some(port);
198+
self
199+
}
200+
/// Set the host
201+
pub fn set_host(mut self, host: &'a str) -> Self {
202+
self.host = Some(host);
203+
self
204+
}
205+
cfg_sync! {
206+
/// Get a [sync connection](sync::Connection) to the database
207+
pub fn get_connection(&self) -> ConnectionBuilderResult<sync::Connection> {
208+
let con =
209+
sync::Connection::new(self.host.unwrap_or("127.0.0.1"), self.port.unwrap_or(2003))?;
210+
Ok(con)
211+
}
212+
cfg_sync_ssl_any! {
213+
/// Get a [sync TLS connection](sync::TlsConnection) to the database
214+
pub fn get_tls_connection(
215+
&self,
216+
sslcert: String,
217+
) -> ConnectionBuilderResult<sync::TlsConnection> {
218+
let con = sync::TlsConnection::new(
219+
self.host.unwrap_or("127.0.0.1"),
220+
self.port.unwrap_or(2003),
221+
&sslcert,
222+
)?;
223+
Ok(con)
224+
}
225+
}
226+
}
227+
cfg_async! {
228+
/// Get an [async connection](aio::Connection) to the database
229+
pub async fn get_async_connection(&self) -> ConnectionBuilderResult<aio::Connection> {
230+
let con = aio::Connection::new(self.host.unwrap_or("127.0.0.1"), self.port.unwrap_or(2003))
231+
.await?;
232+
Ok(con)
233+
}
234+
cfg_async_ssl_any! {
235+
/// Get an [async TLS connection](aio::TlsConnection) to the database
236+
pub async fn get_async_tls_connection(
237+
&self,
238+
sslcert: String,
239+
) -> ConnectionBuilderResult<aio::TlsConnection> {
240+
let con = aio::TlsConnection::new(
241+
self.host.unwrap_or("127.0.0.1"),
242+
self.port.unwrap_or(2003),
243+
&sslcert,
244+
)
245+
.await?;
246+
Ok(con)
247+
}
248+
}
249+
}
250+
}
251+
152252
#[macro_export]
153253
/// A macro that can be used to easily create queries with _almost_ variadic properties.
154254
/// Where you'd normally create queries like this:
@@ -438,4 +538,69 @@ pub mod error {
438538
}
439539
}
440540
);
541+
#[derive(Debug)]
542+
/// An error originating from the Skyhash protocol
543+
pub enum SkyhashError {
544+
/// The server sent an invalid response
545+
InvalidResponse,
546+
/// The server sent a response but it could not be parsed
547+
ParseError,
548+
/// The server sent a data type not supported by this client version
549+
UnsupportedDataType,
550+
}
551+
552+
#[derive(Debug)]
553+
/// A standard error type for the client driver
554+
pub enum Error {
555+
/// An I/O error occurred
556+
IoError(std::io::Error),
557+
#[cfg(all(feature = "sync", any(feature = "ssl", feature = "sslv")))]
558+
#[cfg_attr(
559+
docsrs,
560+
doc(cfg(all(feature = "sync", any(feature = "ssl", feature = "sslv"))))
561+
)]
562+
/// An SSL error occurred
563+
SslError(openssl::ssl::Error),
564+
/// A Skyhash error occurred
565+
SkyError(SkyhashError),
566+
/// An application level parse error occurred
567+
ParseError,
568+
}
569+
570+
#[cfg(all(feature = "sync", any(feature = "ssl", feature = "sslv")))]
571+
#[cfg_attr(
572+
docsrs,
573+
doc(cfg(all(feature = "sync", any(feature = "ssl", feature = "sslv"))))
574+
)]
575+
impl From<openssl::ssl::Error> for Error {
576+
fn from(err: openssl::ssl::Error) -> Self {
577+
Self::SslError(err)
578+
}
579+
}
580+
581+
impl From<std::io::Error> for Error {
582+
fn from(err: std::io::Error) -> Self {
583+
Self::IoError(err)
584+
}
585+
}
586+
587+
#[cfg(all(feature = "sync", any(feature = "ssl", feature = "sslv")))]
588+
#[cfg_attr(
589+
docsrs,
590+
doc(cfg(all(feature = "sync", any(feature = "ssl", feature = "sslv"))))
591+
)]
592+
impl From<SslError> for Error {
593+
fn from(err: SslError) -> Self {
594+
match err {
595+
SslError::IoError(ioerr) => Self::IoError(ioerr),
596+
SslError::SslError(sslerr) => Self::SslError(sslerr),
597+
}
598+
}
599+
}
600+
601+
impl From<SkyhashError> for Error {
602+
fn from(err: SkyhashError) -> Self {
603+
Self::SkyError(err)
604+
}
605+
}
441606
}

src/sync.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ macro_rules! impl_sync_methods {
7070
self.buffer.clear();
7171
return Ok(Response::InvalidResponse);
7272
}
73-
ParseError::DataTypeParseError => return Ok(Response::ParseError),
73+
ParseError::DataTypeError => return Ok(Response::ParseError),
7474
ParseError::Empty => {
7575
return Err(Error::from(ErrorKind::ConnectionReset))
7676
}

0 commit comments

Comments
 (0)