Skip to content

Commit 6c3a4ab

Browse files
committed
Add channel_binding=disable/prefer/require to config
Closes #487
1 parent a8d945c commit 6c3a4ab

File tree

7 files changed

+124
-109
lines changed

7 files changed

+124
-109
lines changed

postgres-native-tls/src/test.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,7 @@ where
1212
T: TlsConnect<TcpStream>,
1313
T::Stream: 'static + Send,
1414
{
15-
let stream = TcpStream::connect("127.0.0.1:5433")
16-
.await
17-
.unwrap();
15+
let stream = TcpStream::connect("127.0.0.1:5433").await.unwrap();
1816

1917
let builder = s.parse::<tokio_postgres::Config>().unwrap();
2018
let (mut client, connection) = builder.connect_raw(stream, tls).await.unwrap();

postgres-openssl/src/test.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,32 @@ async fn scram_user() {
6565
.await;
6666
}
6767

68+
#[tokio::test]
69+
async fn require_channel_binding_err() {
70+
let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();
71+
builder.set_ca_file("../test/server.crt").unwrap();
72+
let ctx = builder.build();
73+
let connector = TlsConnector::new(ctx.configure().unwrap(), "localhost");
74+
75+
let stream = TcpStream::connect("127.0.0.1:5433").await.unwrap();
76+
let builder = "user=pass_user password=password dbname=postgres channel_binding=require"
77+
.parse::<tokio_postgres::Config>()
78+
.unwrap();
79+
builder.connect_raw(stream, connector).await.err().unwrap();
80+
}
81+
82+
#[tokio::test]
83+
async fn require_channel_binding_ok() {
84+
let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();
85+
builder.set_ca_file("../test/server.crt").unwrap();
86+
let ctx = builder.build();
87+
smoke_test(
88+
"user=scram_user password=password dbname=postgres channel_binding=require",
89+
TlsConnector::new(ctx.configure().unwrap(), "localhost"),
90+
)
91+
.await;
92+
}
93+
6894
#[tokio::test]
6995
#[cfg(feature = "runtime")]
7096
async fn runtime() {

postgres/src/config.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ use tokio_postgres::tls::{MakeTlsConnect, TlsConnect};
1414
use tokio_postgres::{Error, Socket};
1515

1616
#[doc(inline)]
17-
pub use tokio_postgres::config::{SslMode, TargetSessionAttrs};
17+
pub use tokio_postgres::config::{SslMode, TargetSessionAttrs, ChannelBinding};
1818

1919
use crate::{Client, RUNTIME};
2020

@@ -234,6 +234,14 @@ impl Config {
234234
self
235235
}
236236

237+
/// Sets the channel binding behavior.
238+
///
239+
/// Defaults to `prefer`.
240+
pub fn channel_binding(&mut self, channel_binding: ChannelBinding) -> &mut Config {
241+
self.config.channel_binding(channel_binding);
242+
self
243+
}
244+
237245
/// Sets the executor used to run the connection futures.
238246
///
239247
/// Defaults to a postgres-specific tokio `Runtime`.

tokio-postgres/src/config.rs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,19 @@ pub enum SslMode {
4646
__NonExhaustive,
4747
}
4848

49+
/// Channel binding configuration.
50+
#[derive(Debug, Copy, Clone, PartialEq)]
51+
pub enum ChannelBinding {
52+
/// Do not use channel binding.
53+
Disable,
54+
/// Attempt to use channel binding but allow sessions without.
55+
Prefer,
56+
/// Require the use of channel binding.
57+
Require,
58+
#[doc(hidden)]
59+
__NonExhaustive,
60+
}
61+
4962
#[derive(Debug, Clone, PartialEq)]
5063
pub(crate) enum Host {
5164
Tcp(String),
@@ -87,6 +100,9 @@ pub(crate) enum Host {
87100
/// * `target_session_attrs` - Specifies requirements of the session. If set to `read-write`, the client will check that
88101
/// the `transaction_read_write` session parameter is set to `on`. This can be used to connect to the primary server
89102
/// in a database cluster as opposed to the secondary read-only mirrors. Defaults to `all`.
103+
/// * `channel_binding` - Controls usage of channel binding in the authentication process. If set to `disable`, channel
104+
/// binding will not be used. If set to `prefer`, channel binding will be used if available, but not used otherwise.
105+
/// If set to `require`, the authentication process will fail if channel binding is not used. Defaults to `prefer`.
90106
///
91107
/// ## Examples
92108
///
@@ -140,6 +156,7 @@ pub struct Config {
140156
pub(crate) keepalives: bool,
141157
pub(crate) keepalives_idle: Duration,
142158
pub(crate) target_session_attrs: TargetSessionAttrs,
159+
pub(crate) channel_binding: ChannelBinding,
143160
}
144161

145162
impl Default for Config {
@@ -164,6 +181,7 @@ impl Config {
164181
keepalives: true,
165182
keepalives_idle: Duration::from_secs(2 * 60 * 60),
166183
target_session_attrs: TargetSessionAttrs::Any,
184+
channel_binding: ChannelBinding::Prefer,
167185
}
168186
}
169187

@@ -287,6 +305,14 @@ impl Config {
287305
self
288306
}
289307

308+
/// Sets the channel binding behavior.
309+
///
310+
/// Defaults to `prefer`.
311+
pub fn channel_binding(&mut self, channel_binding: ChannelBinding) -> &mut Config {
312+
self.channel_binding = channel_binding;
313+
self
314+
}
315+
290316
fn param(&mut self, key: &str, value: &str) -> Result<(), Error> {
291317
match key {
292318
"user" => {
@@ -363,6 +389,19 @@ impl Config {
363389
};
364390
self.target_session_attrs(target_session_attrs);
365391
}
392+
"channel_binding" => {
393+
let channel_binding = match value {
394+
"disable" => ChannelBinding::Disable,
395+
"prefer" => ChannelBinding::Prefer,
396+
"require" => ChannelBinding::Require,
397+
_ => {
398+
return Err(Error::config_parse(Box::new(InvalidValue(
399+
"channel_binding",
400+
))))
401+
}
402+
};
403+
self.channel_binding(channel_binding);
404+
}
366405
key => {
367406
return Err(Error::config_parse(Box::new(UnknownOption(
368407
key.to_string(),
@@ -434,6 +473,7 @@ impl fmt::Debug for Config {
434473
.field("keepalives", &self.keepalives)
435474
.field("keepalives_idle", &self.keepalives_idle)
436475
.field("target_session_attrs", &self.target_session_attrs)
476+
.field("channel_binding", &self.channel_binding)
437477
.finish()
438478
}
439479
}

tokio-postgres/src/connect_raw.rs

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
2-
use crate::config::Config;
2+
use crate::config::{self, Config};
33
use crate::connect_tls::connect_tls;
44
use crate::maybe_tls_stream::MaybeTlsStream;
55
use crate::tls::{ChannelBinding, TlsConnect};
@@ -141,8 +141,13 @@ where
141141
T: AsyncRead + AsyncWrite + Unpin,
142142
{
143143
match stream.try_next().await.map_err(Error::io)? {
144-
Some(Message::AuthenticationOk) => return Ok(()),
144+
Some(Message::AuthenticationOk) => {
145+
no_channel_binding(config)?;
146+
return Ok(());
147+
}
145148
Some(Message::AuthenticationCleartextPassword) => {
149+
no_channel_binding(config)?;
150+
146151
let pass = config
147152
.password
148153
.as_ref()
@@ -151,6 +156,8 @@ where
151156
authenticate_password(stream, pass).await?;
152157
}
153158
Some(Message::AuthenticationMd5Password(body)) => {
159+
no_channel_binding(config)?;
160+
154161
let user = config
155162
.user
156163
.as_ref()
@@ -164,12 +171,7 @@ where
164171
authenticate_password(stream, output.as_bytes()).await?;
165172
}
166173
Some(Message::AuthenticationSasl(body)) => {
167-
let pass = config
168-
.password
169-
.as_ref()
170-
.ok_or_else(|| Error::config("password missing".into()))?;
171-
172-
authenticate_sasl(stream, body, channel_binding, pass).await?;
174+
authenticate_sasl(stream, body, channel_binding, config).await?;
173175
}
174176
Some(Message::AuthenticationKerberosV5)
175177
| Some(Message::AuthenticationScmCredential)
@@ -192,6 +194,16 @@ where
192194
}
193195
}
194196

197+
fn no_channel_binding(config: &Config) -> Result<(), Error> {
198+
match config.channel_binding {
199+
config::ChannelBinding::Disable | config::ChannelBinding::Prefer => Ok(()),
200+
config::ChannelBinding::Require => Err(Error::authentication(
201+
"server did not use channel binding".into(),
202+
)),
203+
config::ChannelBinding::__NonExhaustive => unreachable!(),
204+
}
205+
}
206+
195207
async fn authenticate_password<S, T>(
196208
stream: &mut StartupStream<S, T>,
197209
password: &[u8],
@@ -213,12 +225,17 @@ async fn authenticate_sasl<S, T>(
213225
stream: &mut StartupStream<S, T>,
214226
body: AuthenticationSaslBody,
215227
channel_binding: ChannelBinding,
216-
password: &[u8],
228+
config: &Config,
217229
) -> Result<(), Error>
218230
where
219231
S: AsyncRead + AsyncWrite + Unpin,
220232
T: AsyncRead + AsyncWrite + Unpin,
221233
{
234+
let password = config
235+
.password
236+
.as_ref()
237+
.ok_or_else(|| Error::config("password missing".into()))?;
238+
222239
let mut has_scram = false;
223240
let mut has_scram_plus = false;
224241
let mut mechanisms = body.mechanisms();
@@ -232,6 +249,7 @@ where
232249

233250
let channel_binding = channel_binding
234251
.tls_server_end_point
252+
.filter(|_| config.channel_binding != config::ChannelBinding::Disable)
235253
.map(sasl::ChannelBinding::tls_server_end_point);
236254

237255
let (channel_binding, mechanism) = if has_scram_plus {
@@ -240,6 +258,8 @@ where
240258
None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256),
241259
}
242260
} else if has_scram {
261+
no_channel_binding(config)?;
262+
243263
match channel_binding {
244264
Some(_) => (sasl::ChannelBinding::unrequested(), sasl::SCRAM_SHA_256),
245265
None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256),

tokio-postgres/src/prepare.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,11 @@ pub fn prepare(
106106
}
107107
}
108108

109-
fn prepare_rec(client: Arc<InnerClient>, query: &str, types: &[Type]) -> Pin<Box<dyn Future<Output = Result<Statement, Error>> + 'static + Send>> {
109+
fn prepare_rec(
110+
client: Arc<InnerClient>,
111+
query: &str,
112+
types: &[Type],
113+
) -> Pin<Box<dyn Future<Output = Result<Statement, Error>> + 'static + Send>> {
110114
Box::pin(prepare(client, query, types))
111115
}
112116

tokio-postgres/tests/test/main.rs

Lines changed: 14 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@ mod types;
2020
async fn connect_raw(s: &str) -> Result<(Client, Connection<TcpStream, NoTlsStream>), Error> {
2121
let socket = TcpStream::connect("127.0.0.1:5433").await.unwrap();
2222
let config = s.parse::<Config>().unwrap();
23-
// FIXME https://github.com/rust-lang/rust/issues/64391
24-
async move { config.connect_raw(socket, NoTls).await }.await
23+
config.connect_raw(socket, NoTls).await
2524
}
2625

2726
async fn connect(s: &str) -> Client {
@@ -608,100 +607,20 @@ async fn query_portal() {
608607
assert_eq!(r3.len(), 0);
609608
}
610609

611-
/*
612-
#[test]
613-
fn poll_idle_running() {
614-
struct DelayStream(Delay);
615-
616-
impl Stream for DelayStream {
617-
type Item = Vec<u8>;
618-
type Error = tokio_postgres::Error;
619-
620-
fn poll(&mut self) -> Poll<Option<Vec<u8>>, tokio_postgres::Error> {
621-
try_ready!(self.0.poll().map_err(|e| panic!("{}", e)));
622-
QUERY_DONE.store(true, Ordering::SeqCst);
623-
Ok(Async::Ready(None))
624-
}
625-
}
626-
627-
struct IdleFuture(tokio_postgres::Client);
628-
629-
impl Future for IdleFuture {
630-
type Item = ();
631-
type Error = tokio_postgres::Error;
632-
633-
fn poll(&mut self) -> Poll<(), tokio_postgres::Error> {
634-
try_ready!(self.0.poll_idle());
635-
assert!(QUERY_DONE.load(Ordering::SeqCst));
636-
Ok(Async::Ready(()))
637-
}
638-
}
639-
640-
static QUERY_DONE: AtomicBool = AtomicBool::new(false);
641-
642-
let _ = env_logger::try_init();
643-
let mut runtime = Runtime::new().unwrap();
644-
645-
let (mut client, connection) = runtime.block_on(connect("user=postgres")).unwrap();
646-
let connection = connection.map_err(|e| panic!("{}", e));
647-
runtime.handle().spawn(connection).unwrap();
648-
649-
let execute = client
650-
.simple_query("CREATE TEMPORARY TABLE foo (id INT)")
651-
.for_each(|_| Ok(()));
652-
runtime.block_on(execute).unwrap();
653-
654-
let prepare = client.prepare("COPY foo FROM STDIN");
655-
let stmt = runtime.block_on(prepare).unwrap();
656-
let copy_in = client.copy_in(
657-
&stmt,
658-
&[],
659-
DelayStream(Delay::new(Instant::now() + Duration::from_millis(10))),
660-
);
661-
let copy_in = copy_in.map(|_| ()).map_err(|e| panic!("{}", e));
662-
runtime.spawn(copy_in);
663-
664-
let future = IdleFuture(client);
665-
runtime.block_on(future).unwrap();
610+
#[tokio::test]
611+
async fn require_channel_binding() {
612+
connect_raw("user=postgres channel_binding=require")
613+
.await
614+
.err()
615+
.unwrap();
666616
}
667617

668-
#[test]
669-
fn poll_idle_new() {
670-
struct IdleFuture {
671-
client: tokio_postgres::Client,
672-
prepare: Option<impls::Prepare>,
673-
}
674-
675-
impl Future for IdleFuture {
676-
type Item = ();
677-
type Error = tokio_postgres::Error;
678-
679-
fn poll(&mut self) -> Poll<(), tokio_postgres::Error> {
680-
match self.prepare.take() {
681-
Some(_future) => {
682-
assert!(!self.client.poll_idle().unwrap().is_ready());
683-
Ok(Async::NotReady)
684-
}
685-
None => {
686-
assert!(self.client.poll_idle().unwrap().is_ready());
687-
Ok(Async::Ready(()))
688-
}
689-
}
690-
}
691-
}
692-
693-
let _ = env_logger::try_init();
694-
let mut runtime = Runtime::new().unwrap();
695-
696-
let (mut client, connection) = runtime.block_on(connect("user=postgres")).unwrap();
697-
let connection = connection.map_err(|e| panic!("{}", e));
698-
runtime.handle().spawn(connection).unwrap();
618+
#[tokio::test]
619+
async fn prefer_channel_binding() {
620+
connect("user=postgres channel_binding=prefer").await;
621+
}
699622

700-
let prepare = client.prepare("");
701-
let future = IdleFuture {
702-
client,
703-
prepare: Some(prepare),
704-
};
705-
runtime.block_on(future).unwrap();
623+
#[tokio::test]
624+
async fn disable_channel_binding() {
625+
connect("user=postgres channel_binding=disable").await;
706626
}
707-
*/

0 commit comments

Comments
 (0)