Skip to content

Commit 2d3b9bb

Browse files
committed
Move the TLS mode into config
1 parent dfc614b commit 2d3b9bb

File tree

18 files changed

+356
-424
lines changed

18 files changed

+356
-424
lines changed

postgres/src/client.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::io::{self, Read};
33
use tokio_postgres::types::{ToSql, Type};
44
use tokio_postgres::Error;
55
#[cfg(feature = "runtime")]
6-
use tokio_postgres::{MakeTlsMode, Socket, TlsMode};
6+
use tokio_postgres::{MakeTlsConnect, Socket, TlsConnect};
77

88
#[cfg(feature = "runtime")]
99
use crate::Config;
@@ -15,10 +15,10 @@ impl Client {
1515
#[cfg(feature = "runtime")]
1616
pub fn connect<T>(params: &str, tls_mode: T) -> Result<Client, Error>
1717
where
18-
T: MakeTlsMode<Socket> + 'static + Send,
19-
T::TlsMode: Send,
18+
T: MakeTlsConnect<Socket> + 'static + Send,
19+
T::TlsConnect: Send,
2020
T::Stream: Send,
21-
<T::TlsMode as TlsMode<Socket>>::Future: Send,
21+
<T::TlsConnect as TlsConnect<Socket>>::Future: Send,
2222
{
2323
params.parse::<Config>()?.connect(tls_mode)
2424
}

postgres/src/config.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use log::error;
44
use std::path::Path;
55
use std::str::FromStr;
66
use std::time::Duration;
7-
use tokio_postgres::{Error, MakeTlsMode, Socket, TargetSessionAttrs, TlsMode};
7+
use tokio_postgres::{Error, MakeTlsConnect, Socket, TargetSessionAttrs, TlsConnect};
88

99
use crate::{Client, RUNTIME};
1010

@@ -94,10 +94,10 @@ impl Config {
9494

9595
pub fn connect<T>(&self, tls_mode: T) -> Result<Client, Error>
9696
where
97-
T: MakeTlsMode<Socket> + 'static + Send,
98-
T::TlsMode: Send,
97+
T: MakeTlsConnect<Socket> + 'static + Send,
98+
T::TlsConnect: Send,
9999
T::Stream: Send,
100-
<T::TlsMode as TlsMode<Socket>>::Future: Send,
100+
<T::TlsConnect as TlsConnect<Socket>>::Future: Send,
101101
{
102102
let connect = self.0.connect(tls_mode);
103103
let (client, connection) = oneshot::spawn(connect, &RUNTIME.executor()).wait()?;

tokio-postgres-native-tls/src/test.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@ use futures::{Future, Stream};
22
use native_tls::{self, Certificate};
33
use tokio::net::TcpStream;
44
use tokio::runtime::current_thread::Runtime;
5-
use tokio_postgres::{self, PreferTls, RequireTls, TlsMode};
5+
use tokio_postgres::TlsConnect;
66

77
use crate::TlsConnector;
88

99
fn smoke_test<T>(s: &str, tls: T)
1010
where
11-
T: TlsMode<TcpStream>,
11+
T: TlsConnect<TcpStream>,
1212
T::Stream: 'static,
1313
{
1414
let mut runtime = Runtime::new().unwrap();
@@ -44,8 +44,8 @@ fn require() {
4444
.build()
4545
.unwrap();
4646
smoke_test(
47-
"user=ssl_user dbname=postgres",
48-
RequireTls(TlsConnector::with_connector(connector, "localhost")),
47+
"user=ssl_user dbname=postgres sslmode=require",
48+
TlsConnector::with_connector(connector, "localhost"),
4949
);
5050
}
5151

@@ -59,7 +59,7 @@ fn prefer() {
5959
.unwrap();
6060
smoke_test(
6161
"user=ssl_user dbname=postgres",
62-
PreferTls(TlsConnector::with_connector(connector, "localhost")),
62+
TlsConnector::with_connector(connector, "localhost"),
6363
);
6464
}
6565

@@ -72,7 +72,7 @@ fn scram_user() {
7272
.build()
7373
.unwrap();
7474
smoke_test(
75-
"user=scram_user password=password dbname=postgres",
76-
RequireTls(TlsConnector::with_connector(connector, "localhost")),
75+
"user=scram_user password=password dbname=postgres sslmode=require",
76+
TlsConnector::with_connector(connector, "localhost"),
7777
);
7878
}

tokio-postgres-openssl/src/test.rs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@ use futures::{Future, Stream};
22
use openssl::ssl::{SslConnector, SslMethod};
33
use tokio::net::TcpStream;
44
use tokio::runtime::current_thread::Runtime;
5-
use tokio_postgres::{self, PreferTls, RequireTls, TlsMode};
5+
use tokio_postgres::TlsConnect;
66

77
use super::*;
88

99
fn smoke_test<T>(s: &str, tls: T)
1010
where
11-
T: TlsMode<TcpStream>,
11+
T: TlsConnect<TcpStream>,
1212
T::Stream: 'static,
1313
{
1414
let mut runtime = Runtime::new().unwrap();
@@ -41,8 +41,8 @@ fn require() {
4141
builder.set_ca_file("../test/server.crt").unwrap();
4242
let ctx = builder.build();
4343
smoke_test(
44-
"user=ssl_user dbname=postgres",
45-
RequireTls(TlsConnector::new(ctx.configure().unwrap(), "localhost")),
44+
"user=ssl_user dbname=postgres sslmode=require",
45+
TlsConnector::new(ctx.configure().unwrap(), "localhost"),
4646
);
4747
}
4848

@@ -53,7 +53,7 @@ fn prefer() {
5353
let ctx = builder.build();
5454
smoke_test(
5555
"user=ssl_user dbname=postgres",
56-
PreferTls(TlsConnector::new(ctx.configure().unwrap(), "localhost")),
56+
TlsConnector::new(ctx.configure().unwrap(), "localhost"),
5757
);
5858
}
5959

@@ -63,8 +63,8 @@ fn scram_user() {
6363
builder.set_ca_file("../test/server.crt").unwrap();
6464
let ctx = builder.build();
6565
smoke_test(
66-
"user=scram_user password=password dbname=postgres",
67-
RequireTls(TlsConnector::new(ctx.configure().unwrap(), "localhost")),
66+
"user=scram_user password=password dbname=postgres sslmode=require",
67+
TlsConnector::new(ctx.configure().unwrap(), "localhost"),
6868
);
6969
}
7070

@@ -78,8 +78,8 @@ fn runtime() {
7878
let connector = MakeTlsConnector::new(builder.build());
7979

8080
let connect = tokio_postgres::connect(
81-
"host=localhost port=5433 user=postgres",
82-
RequireTls(connector),
81+
"host=localhost port=5433 user=postgres sslmode=require",
82+
connector,
8383
);
8484
let (mut client, connection) = runtime.block_on(connect).unwrap();
8585
let connection = connection.map_err(|e| panic!("{}", e));

tokio-postgres/Cargo.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ postgres-protocol = { version = "0.3.0", path = "../postgres-protocol" }
4949
state_machine_future = "0.1.7"
5050
tokio-codec = "0.1"
5151
tokio-io = "0.1"
52-
void = "1.0"
5352

5453
tokio-tcp = { version = "0.1", optional = true }
5554
futures-cpupool = { version = "0.1", optional = true }

tokio-postgres/src/config.rs

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ use tokio_io::{AsyncRead, AsyncWrite};
1919
use crate::proto::ConnectFuture;
2020
use crate::proto::ConnectRawFuture;
2121
#[cfg(feature = "runtime")]
22-
use crate::{Connect, MakeTlsMode, Socket};
23-
use crate::{ConnectRaw, Error, TlsMode};
22+
use crate::{Connect, MakeTlsConnect, Socket};
23+
use crate::{ConnectRaw, Error, TlsConnect};
2424

2525
/// Properties required of a session.
2626
#[cfg(feature = "runtime")]
@@ -34,6 +34,17 @@ pub enum TargetSessionAttrs {
3434
__NonExhaustive,
3535
}
3636

37+
/// TLS configuration.
38+
#[derive(Debug, Copy, Clone, PartialEq)]
39+
pub enum SslMode {
40+
/// Do not use TLS.
41+
Disable,
42+
/// Attempt to connect with TLS but allow sessions without.
43+
Prefer,
44+
/// Require the use of TLS.
45+
Require,
46+
}
47+
3748
#[cfg(feature = "runtime")]
3849
#[derive(Debug, Clone, PartialEq)]
3950
pub(crate) enum Host {
@@ -49,6 +60,7 @@ pub(crate) struct Inner {
4960
pub(crate) dbname: Option<String>,
5061
pub(crate) options: Option<String>,
5162
pub(crate) application_name: Option<String>,
63+
pub(crate) ssl_mode: SslMode,
5264
#[cfg(feature = "runtime")]
5365
pub(crate) host: Vec<Host>,
5466
#[cfg(feature = "runtime")]
@@ -79,6 +91,8 @@ pub(crate) struct Inner {
7991
/// * `dbname` - The name of the database to connect to. Defaults to the username.
8092
/// * `options` - Command line options used to configure the server.
8193
/// * `application_name` - Sets the `application_name` parameter on the server.
94+
/// * `sslmode` - Controls usage of TLS. If set to `disable`, TLS will not be used. If set to `prefer`, TLS will be used
95+
/// if available, but not used otherwise. If set to `require`, TLS will be forced to be used. Defaults to `prefer`.
8296
/// * `host` - The host to connect to. On Unix platforms, if the host starts with a `/` character it is treated as the
8397
/// path to the directory containing Unix domain sockets. Otherwise, it is treated as a hostname. Multiple hosts
8498
/// can be specified, separated by commas. Each host will be tried in turn when connecting. Required if connecting
@@ -152,6 +166,7 @@ impl Config {
152166
dbname: None,
153167
options: None,
154168
application_name: None,
169+
ssl_mode: SslMode::Prefer,
155170
#[cfg(feature = "runtime")]
156171
host: vec![],
157172
#[cfg(feature = "runtime")]
@@ -204,6 +219,14 @@ impl Config {
204219
self
205220
}
206221

222+
/// Sets the SSL configuration.
223+
///
224+
/// Defaults to `prefer`.
225+
pub fn ssl_mode(&mut self, ssl_mode: SslMode) -> &mut Config {
226+
Arc::make_mut(&mut self.0).ssl_mode = ssl_mode;
227+
self
228+
}
229+
207230
/// Adds a host to the configuration.
208231
///
209232
/// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order. On Unix
@@ -320,6 +343,15 @@ impl Config {
320343
"application_name" => {
321344
self.application_name(&value);
322345
}
346+
"sslmode" => {
347+
let mode = match value {
348+
"disable" => SslMode::Disable,
349+
"prefer" => SslMode::Prefer,
350+
"require" => SslMode::Require,
351+
_ => return Err(Error::config_parse(Box::new(InvalidValue("sslmode")))),
352+
};
353+
self.ssl_mode(mode);
354+
}
323355
#[cfg(feature = "runtime")]
324356
"host" => {
325357
for host in value.split(',') {
@@ -390,22 +422,22 @@ impl Config {
390422
///
391423
/// Requires the `runtime` Cargo feature (enabled by default).
392424
#[cfg(feature = "runtime")]
393-
pub fn connect<T>(&self, make_tls_mode: T) -> Connect<T>
425+
pub fn connect<T>(&self, tls: T) -> Connect<T>
394426
where
395-
T: MakeTlsMode<Socket>,
427+
T: MakeTlsConnect<Socket>,
396428
{
397-
Connect(ConnectFuture::new(make_tls_mode, Ok(self.clone())))
429+
Connect(ConnectFuture::new(tls, Ok(self.clone())))
398430
}
399431

400432
/// Connects to a PostgreSQL database over an arbitrary stream.
401433
///
402434
/// All of the settings other than `user`, `password`, `dbname`, `options`, and `application` name are ignored.
403-
pub fn connect_raw<S, T>(&self, stream: S, tls_mode: T) -> ConnectRaw<S, T>
435+
pub fn connect_raw<S, T>(&self, stream: S, tls: T) -> ConnectRaw<S, T>
404436
where
405437
S: AsyncRead + AsyncWrite,
406-
T: TlsMode<S>,
438+
T: TlsConnect<S>,
407439
{
408-
ConnectRaw(ConnectRawFuture::new(stream, tls_mode, self.clone(), None))
440+
ConnectRaw(ConnectRawFuture::new(stream, tls, self.clone(), None))
409441
}
410442
}
411443

0 commit comments

Comments
 (0)