Skip to content

Commit de09725

Browse files
committed
TLS support
1 parent d8aed09 commit de09725

File tree

6 files changed

+250
-41
lines changed

6 files changed

+250
-41
lines changed

postgres-tokio/Cargo.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ name = "postgres-tokio"
33
version = "0.1.0"
44
authors = ["Steven Fackler <[email protected]>"]
55

6+
[features]
7+
with-openssl = ["tokio-openssl", "openssl"]
8+
69
[dependencies]
710
fallible-iterator = "0.1.3"
811
futures = "0.1.7"
@@ -12,3 +15,6 @@ postgres-protocol = "0.2"
1215
tokio-core = "0.1"
1316
tokio-dns-unofficial = "0.1"
1417
tokio-uds = "0.1"
18+
19+
tokio-openssl = { version = "0.1", optional = true }
20+
openssl = { version = "0.9", optional = true }

postgres-tokio/src/lib.rs

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@ extern crate tokio_core;
77
extern crate tokio_dns;
88
extern crate tokio_uds;
99

10+
#[cfg(feature = "tokio-openssl")]
11+
extern crate tokio_openssl;
12+
#[cfg(feature = "openssl")]
13+
extern crate openssl;
14+
1015
use fallible_iterator::FallibleIterator;
1116
use futures::{Future, IntoFuture, BoxFuture, Stream, Sink, Poll, StartSend};
1217
use futures::future::Either;
@@ -31,13 +36,21 @@ use error::{ConnectError, Error, DbError};
3136
use params::{ConnectParams, IntoConnectParams};
3237
use stream::PostgresStream;
3338
use types::{Oid, Type, ToSql, SessionInfo, IsNull, FromSql, WrongType};
39+
use tls::Handshake;
3440

3541
pub mod error;
3642
mod stream;
43+
pub mod tls;
3744

3845
#[cfg(test)]
3946
mod test;
4047

48+
pub enum TlsMode {
49+
Require(Box<Handshake>),
50+
Prefer(Box<Handshake>),
51+
None,
52+
}
53+
4154
#[derive(Debug, Copy, Clone)]
4255
pub struct CancelData {
4356
pub process_id: i32,
@@ -119,16 +132,18 @@ impl fmt::Debug for Connection {
119132
}
120133

121134
impl Connection {
122-
pub fn connect<T>(params: T, handle: &Handle) -> BoxFuture<Connection, ConnectError>
135+
pub fn connect<T>(params: T,
136+
tls_mode: TlsMode,
137+
handle: &Handle)
138+
-> BoxFuture<Connection, ConnectError>
123139
where T: IntoConnectParams
124140
{
125141
let params = match params.into_connect_params() {
126142
Ok(params) => params,
127143
Err(e) => return futures::failed(ConnectError::ConnectParams(e)).boxed(),
128144
};
129145

130-
stream::connect(params.host(), params.port(), handle)
131-
.map_err(ConnectError::Io)
146+
stream::connect(params.host().clone(), params.port(), tls_mode, handle)
132147
.map(|s| {
133148
let (sender, receiver) = mpsc::channel();
134149
Connection(InnerConnection {

postgres-tokio/src/stream.rs

Lines changed: 97 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,75 +1,126 @@
1-
use futures::{BoxFuture, Future, IntoFuture, Async};
1+
use futures::{BoxFuture, Future, IntoFuture, Async, Sink, Stream as FuturesStream};
2+
use futures::future::Either;
23
use postgres_shared::params::Host;
34
use postgres_protocol::message::backend::{self, ParseResult};
5+
use postgres_protocol::message::frontend;
46
use std::io::{self, Read, Write};
57
use tokio_core::io::{Io, Codec, EasyBuf, Framed};
68
use tokio_core::net::TcpStream;
79
use tokio_core::reactor::Handle;
810
use tokio_dns;
911
use tokio_uds::UnixStream;
1012

11-
pub type PostgresStream = Framed<InnerStream, PostgresCodec>;
13+
use TlsMode;
14+
use error::ConnectError;
15+
use tls::TlsStream;
1216

13-
pub fn connect(host: &Host,
14-
port: u16,
15-
handle: &Handle)
16-
-> BoxFuture<PostgresStream, io::Error> {
17-
match *host {
17+
pub type PostgresStream = Framed<Box<TlsStream>, PostgresCodec>;
18+
19+
pub fn connect(host: Host,
20+
port: u16,
21+
tls_mode: TlsMode,
22+
handle: &Handle)
23+
-> BoxFuture<PostgresStream, ConnectError> {
24+
let inner = match host {
1825
Host::Tcp(ref host) => {
19-
tokio_dns::tcp_connect((&**host, port), handle.remote().clone())
20-
.map(|s| InnerStream::Tcp(s).framed(PostgresCodec))
21-
.boxed()
26+
Either::A(tokio_dns::tcp_connect((&**host, port), handle.remote().clone())
27+
.map(|s| Stream(InnerStream::Tcp(s))))
2228
}
2329
Host::Unix(ref host) => {
2430
let addr = host.join(format!(".s.PGSQL.{}", port));
25-
UnixStream::connect(addr, handle)
26-
.map(|s| InnerStream::Unix(s).framed(PostgresCodec))
27-
.into_future()
28-
.boxed()
31+
Either::B(UnixStream::connect(addr, handle)
32+
.map(|s| Stream(InnerStream::Unix(s)))
33+
.into_future())
2934
}
30-
}
35+
};
36+
37+
let (required, mut handshaker) = match tls_mode {
38+
TlsMode::Require(h) => (true, h),
39+
TlsMode::Prefer(h) => (false, h),
40+
TlsMode::None => {
41+
return inner.map(|s| {
42+
let s: Box<TlsStream> = Box::new(s);
43+
s.framed(PostgresCodec)
44+
})
45+
.map_err(ConnectError::Io)
46+
.boxed()
47+
},
48+
};
49+
50+
inner.map(|s| s.framed(SslCodec))
51+
.and_then(|s| {
52+
let mut buf = vec![];
53+
frontend::ssl_request(&mut buf);
54+
s.send(buf)
55+
})
56+
.and_then(|s| s.into_future().map_err(|e| e.0))
57+
.map_err(ConnectError::Io)
58+
.and_then(move |(m, s)| {
59+
let s = s.into_inner();
60+
match (m, required) {
61+
(Some(b'N'), true) => {
62+
Either::A(Err(ConnectError::Tls("the server does not support TLS".into())).into_future())
63+
}
64+
(Some(b'N'), false) => {
65+
let s: Box<TlsStream> = Box::new(s);
66+
Either::A(Ok(s).into_future())
67+
},
68+
(None, _) => Either::A(Err(ConnectError::Io(io::Error::new(io::ErrorKind::UnexpectedEof, "unexpected EOF"))).into_future()),
69+
_ => {
70+
let host = match host {
71+
Host::Tcp(ref host) => host,
72+
Host::Unix(_) => unreachable!(),
73+
};
74+
Either::B(handshaker.handshake(host, s).map_err(ConnectError::Tls))
75+
}
76+
}
77+
})
78+
.map(|s| s.framed(PostgresCodec))
79+
.boxed()
3180
}
3281

33-
pub enum InnerStream {
82+
pub struct Stream(InnerStream);
83+
84+
enum InnerStream {
3485
Tcp(TcpStream),
3586
Unix(UnixStream),
3687
}
3788

38-
impl Read for InnerStream {
89+
impl Read for Stream {
3990
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
40-
match *self {
91+
match self.0 {
4192
InnerStream::Tcp(ref mut s) => s.read(buf),
4293
InnerStream::Unix(ref mut s) => s.read(buf),
4394
}
4495
}
4596
}
4697

47-
impl Write for InnerStream {
98+
impl Write for Stream {
4899
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
49-
match *self {
100+
match self.0 {
50101
InnerStream::Tcp(ref mut s) => s.write(buf),
51102
InnerStream::Unix(ref mut s) => s.write(buf),
52103
}
53104
}
54105

55106
fn flush(&mut self) -> io::Result<()> {
56-
match *self {
107+
match self.0 {
57108
InnerStream::Tcp(ref mut s) => s.flush(),
58109
InnerStream::Unix(ref mut s) => s.flush(),
59110
}
60111
}
61112
}
62113

63-
impl Io for InnerStream {
114+
impl Io for Stream {
64115
fn poll_read(&mut self) -> Async<()> {
65-
match *self {
116+
match self.0 {
66117
InnerStream::Tcp(ref mut s) => s.poll_read(),
67118
InnerStream::Unix(ref mut s) => s.poll_read(),
68119
}
69120
}
70121

71122
fn poll_write(&mut self) -> Async<()> {
72-
match *self {
123+
match self.0 {
73124
InnerStream::Tcp(ref mut s) => s.poll_write(),
74125
InnerStream::Unix(ref mut s) => s.poll_write(),
75126
}
@@ -98,3 +149,25 @@ impl Codec for PostgresCodec {
98149
Ok(())
99150
}
100151
}
152+
153+
struct SslCodec;
154+
155+
impl Codec for SslCodec {
156+
type In = u8;
157+
type Out = Vec<u8>;
158+
159+
fn decode(&mut self, buf: &mut EasyBuf) -> io::Result<Option<u8>> {
160+
if buf.as_slice().is_empty() {
161+
Ok(None)
162+
} else {
163+
let byte = buf.as_slice()[0];
164+
buf.drain_to(1);
165+
Ok(Some(byte))
166+
}
167+
}
168+
169+
fn encode(&mut self, msg: Vec<u8>, buf: &mut Vec<u8>) -> io::Result<()> {
170+
buf.extend_from_slice(&msg);
171+
Ok(())
172+
}
173+
}

0 commit comments

Comments
 (0)