|
1 |
| -use futures::{BoxFuture, Future, IntoFuture, Async}; |
| 1 | +use futures::{BoxFuture, Future, IntoFuture, Async, Sink, Stream as FuturesStream}; |
| 2 | +use futures::future::Either; |
2 | 3 | use postgres_shared::params::Host;
|
3 | 4 | use postgres_protocol::message::backend::{self, ParseResult};
|
| 5 | +use postgres_protocol::message::frontend; |
4 | 6 | use std::io::{self, Read, Write};
|
5 | 7 | use tokio_core::io::{Io, Codec, EasyBuf, Framed};
|
6 | 8 | use tokio_core::net::TcpStream;
|
7 | 9 | use tokio_core::reactor::Handle;
|
8 | 10 | use tokio_dns;
|
9 | 11 | use tokio_uds::UnixStream;
|
10 | 12 |
|
11 |
| -pub type PostgresStream = Framed<InnerStream, PostgresCodec>; |
| 13 | +use TlsMode; |
| 14 | +use error::ConnectError; |
| 15 | +use tls::TlsStream; |
12 | 16 |
|
13 |
| -pub fn connect(host: &Host, |
14 |
| - port: u16, |
15 |
| - handle: &Handle) |
16 |
| - -> BoxFuture<PostgresStream, io::Error> { |
17 |
| - match *host { |
| 17 | +pub type PostgresStream = Framed<Box<TlsStream>, PostgresCodec>; |
| 18 | + |
| 19 | +pub fn connect(host: Host, |
| 20 | + port: u16, |
| 21 | + tls_mode: TlsMode, |
| 22 | + handle: &Handle) |
| 23 | + -> BoxFuture<PostgresStream, ConnectError> { |
| 24 | + let inner = match host { |
18 | 25 | Host::Tcp(ref host) => {
|
19 |
| - tokio_dns::tcp_connect((&**host, port), handle.remote().clone()) |
20 |
| - .map(|s| InnerStream::Tcp(s).framed(PostgresCodec)) |
21 |
| - .boxed() |
| 26 | + Either::A(tokio_dns::tcp_connect((&**host, port), handle.remote().clone()) |
| 27 | + .map(|s| Stream(InnerStream::Tcp(s)))) |
22 | 28 | }
|
23 | 29 | Host::Unix(ref host) => {
|
24 | 30 | let addr = host.join(format!(".s.PGSQL.{}", port));
|
25 |
| - UnixStream::connect(addr, handle) |
26 |
| - .map(|s| InnerStream::Unix(s).framed(PostgresCodec)) |
27 |
| - .into_future() |
28 |
| - .boxed() |
| 31 | + Either::B(UnixStream::connect(addr, handle) |
| 32 | + .map(|s| Stream(InnerStream::Unix(s))) |
| 33 | + .into_future()) |
29 | 34 | }
|
30 |
| - } |
| 35 | + }; |
| 36 | + |
| 37 | + let (required, mut handshaker) = match tls_mode { |
| 38 | + TlsMode::Require(h) => (true, h), |
| 39 | + TlsMode::Prefer(h) => (false, h), |
| 40 | + TlsMode::None => { |
| 41 | + return inner.map(|s| { |
| 42 | + let s: Box<TlsStream> = Box::new(s); |
| 43 | + s.framed(PostgresCodec) |
| 44 | + }) |
| 45 | + .map_err(ConnectError::Io) |
| 46 | + .boxed() |
| 47 | + }, |
| 48 | + }; |
| 49 | + |
| 50 | + inner.map(|s| s.framed(SslCodec)) |
| 51 | + .and_then(|s| { |
| 52 | + let mut buf = vec![]; |
| 53 | + frontend::ssl_request(&mut buf); |
| 54 | + s.send(buf) |
| 55 | + }) |
| 56 | + .and_then(|s| s.into_future().map_err(|e| e.0)) |
| 57 | + .map_err(ConnectError::Io) |
| 58 | + .and_then(move |(m, s)| { |
| 59 | + let s = s.into_inner(); |
| 60 | + match (m, required) { |
| 61 | + (Some(b'N'), true) => { |
| 62 | + Either::A(Err(ConnectError::Tls("the server does not support TLS".into())).into_future()) |
| 63 | + } |
| 64 | + (Some(b'N'), false) => { |
| 65 | + let s: Box<TlsStream> = Box::new(s); |
| 66 | + Either::A(Ok(s).into_future()) |
| 67 | + }, |
| 68 | + (None, _) => Either::A(Err(ConnectError::Io(io::Error::new(io::ErrorKind::UnexpectedEof, "unexpected EOF"))).into_future()), |
| 69 | + _ => { |
| 70 | + let host = match host { |
| 71 | + Host::Tcp(ref host) => host, |
| 72 | + Host::Unix(_) => unreachable!(), |
| 73 | + }; |
| 74 | + Either::B(handshaker.handshake(host, s).map_err(ConnectError::Tls)) |
| 75 | + } |
| 76 | + } |
| 77 | + }) |
| 78 | + .map(|s| s.framed(PostgresCodec)) |
| 79 | + .boxed() |
31 | 80 | }
|
32 | 81 |
|
33 |
| -pub enum InnerStream { |
| 82 | +pub struct Stream(InnerStream); |
| 83 | + |
| 84 | +enum InnerStream { |
34 | 85 | Tcp(TcpStream),
|
35 | 86 | Unix(UnixStream),
|
36 | 87 | }
|
37 | 88 |
|
38 |
| -impl Read for InnerStream { |
| 89 | +impl Read for Stream { |
39 | 90 | fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
40 |
| - match *self { |
| 91 | + match self.0 { |
41 | 92 | InnerStream::Tcp(ref mut s) => s.read(buf),
|
42 | 93 | InnerStream::Unix(ref mut s) => s.read(buf),
|
43 | 94 | }
|
44 | 95 | }
|
45 | 96 | }
|
46 | 97 |
|
47 |
| -impl Write for InnerStream { |
| 98 | +impl Write for Stream { |
48 | 99 | fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
49 |
| - match *self { |
| 100 | + match self.0 { |
50 | 101 | InnerStream::Tcp(ref mut s) => s.write(buf),
|
51 | 102 | InnerStream::Unix(ref mut s) => s.write(buf),
|
52 | 103 | }
|
53 | 104 | }
|
54 | 105 |
|
55 | 106 | fn flush(&mut self) -> io::Result<()> {
|
56 |
| - match *self { |
| 107 | + match self.0 { |
57 | 108 | InnerStream::Tcp(ref mut s) => s.flush(),
|
58 | 109 | InnerStream::Unix(ref mut s) => s.flush(),
|
59 | 110 | }
|
60 | 111 | }
|
61 | 112 | }
|
62 | 113 |
|
63 |
| -impl Io for InnerStream { |
| 114 | +impl Io for Stream { |
64 | 115 | fn poll_read(&mut self) -> Async<()> {
|
65 |
| - match *self { |
| 116 | + match self.0 { |
66 | 117 | InnerStream::Tcp(ref mut s) => s.poll_read(),
|
67 | 118 | InnerStream::Unix(ref mut s) => s.poll_read(),
|
68 | 119 | }
|
69 | 120 | }
|
70 | 121 |
|
71 | 122 | fn poll_write(&mut self) -> Async<()> {
|
72 |
| - match *self { |
| 123 | + match self.0 { |
73 | 124 | InnerStream::Tcp(ref mut s) => s.poll_write(),
|
74 | 125 | InnerStream::Unix(ref mut s) => s.poll_write(),
|
75 | 126 | }
|
@@ -98,3 +149,25 @@ impl Codec for PostgresCodec {
|
98 | 149 | Ok(())
|
99 | 150 | }
|
100 | 151 | }
|
| 152 | + |
| 153 | +struct SslCodec; |
| 154 | + |
| 155 | +impl Codec for SslCodec { |
| 156 | + type In = u8; |
| 157 | + type Out = Vec<u8>; |
| 158 | + |
| 159 | + fn decode(&mut self, buf: &mut EasyBuf) -> io::Result<Option<u8>> { |
| 160 | + if buf.as_slice().is_empty() { |
| 161 | + Ok(None) |
| 162 | + } else { |
| 163 | + let byte = buf.as_slice()[0]; |
| 164 | + buf.drain_to(1); |
| 165 | + Ok(Some(byte)) |
| 166 | + } |
| 167 | + } |
| 168 | + |
| 169 | + fn encode(&mut self, msg: Vec<u8>, buf: &mut Vec<u8>) -> io::Result<()> { |
| 170 | + buf.extend_from_slice(&msg); |
| 171 | + Ok(()) |
| 172 | + } |
| 173 | +} |
0 commit comments