Skip to content

Commit 4deea3b

Browse files
authored
Merge pull request #488 from sfackler/channel-binding-require
Add channel_binding=disable/prefer/require to config
2 parents a8d945c + c9469ea commit 4deea3b

File tree

7 files changed

+129
-110
lines changed

7 files changed

+129
-110
lines changed

postgres-native-tls/src/test.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,7 @@ where
1212
T: TlsConnect<TcpStream>,
1313
T::Stream: 'static + Send,
1414
{
15-
let stream = TcpStream::connect("127.0.0.1:5433")
16-
.await
17-
.unwrap();
15+
let stream = TcpStream::connect("127.0.0.1:5433").await.unwrap();
1816

1917
let builder = s.parse::<tokio_postgres::Config>().unwrap();
2018
let (mut client, connection) = builder.connect_raw(stream, tls).await.unwrap();

postgres-openssl/src/test.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,32 @@ async fn scram_user() {
6565
.await;
6666
}
6767

68+
#[tokio::test]
69+
async fn require_channel_binding_err() {
70+
let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();
71+
builder.set_ca_file("../test/server.crt").unwrap();
72+
let ctx = builder.build();
73+
let connector = TlsConnector::new(ctx.configure().unwrap(), "localhost");
74+
75+
let stream = TcpStream::connect("127.0.0.1:5433").await.unwrap();
76+
let builder = "user=pass_user password=password dbname=postgres channel_binding=require"
77+
.parse::<tokio_postgres::Config>()
78+
.unwrap();
79+
builder.connect_raw(stream, connector).await.err().unwrap();
80+
}
81+
82+
#[tokio::test]
83+
async fn require_channel_binding_ok() {
84+
let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();
85+
builder.set_ca_file("../test/server.crt").unwrap();
86+
let ctx = builder.build();
87+
smoke_test(
88+
"user=scram_user password=password dbname=postgres channel_binding=require",
89+
TlsConnector::new(ctx.configure().unwrap(), "localhost"),
90+
)
91+
.await;
92+
}
93+
6894
#[tokio::test]
6995
#[cfg(feature = "runtime")]
7096
async fn runtime() {

postgres/src/config.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ use tokio_postgres::tls::{MakeTlsConnect, TlsConnect};
1414
use tokio_postgres::{Error, Socket};
1515

1616
#[doc(inline)]
17-
pub use tokio_postgres::config::{SslMode, TargetSessionAttrs};
17+
pub use tokio_postgres::config::{SslMode, TargetSessionAttrs, ChannelBinding};
1818

1919
use crate::{Client, RUNTIME};
2020

@@ -234,6 +234,14 @@ impl Config {
234234
self
235235
}
236236

237+
/// Sets the channel binding behavior.
238+
///
239+
/// Defaults to `prefer`.
240+
pub fn channel_binding(&mut self, channel_binding: ChannelBinding) -> &mut Config {
241+
self.config.channel_binding(channel_binding);
242+
self
243+
}
244+
237245
/// Sets the executor used to run the connection futures.
238246
///
239247
/// Defaults to a postgres-specific tokio `Runtime`.

tokio-postgres/src/config.rs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,19 @@ pub enum SslMode {
4646
__NonExhaustive,
4747
}
4848

49+
/// Channel binding configuration.
50+
#[derive(Debug, Copy, Clone, PartialEq)]
51+
pub enum ChannelBinding {
52+
/// Do not use channel binding.
53+
Disable,
54+
/// Attempt to use channel binding but allow sessions without.
55+
Prefer,
56+
/// Require the use of channel binding.
57+
Require,
58+
#[doc(hidden)]
59+
__NonExhaustive,
60+
}
61+
4962
#[derive(Debug, Clone, PartialEq)]
5063
pub(crate) enum Host {
5164
Tcp(String),
@@ -87,6 +100,9 @@ pub(crate) enum Host {
87100
/// * `target_session_attrs` - Specifies requirements of the session. If set to `read-write`, the client will check that
88101
/// the `transaction_read_write` session parameter is set to `on`. This can be used to connect to the primary server
89102
/// in a database cluster as opposed to the secondary read-only mirrors. Defaults to `all`.
103+
/// * `channel_binding` - Controls usage of channel binding in the authentication process. If set to `disable`, channel
104+
/// binding will not be used. If set to `prefer`, channel binding will be used if available, but not used otherwise.
105+
/// If set to `require`, the authentication process will fail if channel binding is not used. Defaults to `prefer`.
90106
///
91107
/// ## Examples
92108
///
@@ -140,6 +156,7 @@ pub struct Config {
140156
pub(crate) keepalives: bool,
141157
pub(crate) keepalives_idle: Duration,
142158
pub(crate) target_session_attrs: TargetSessionAttrs,
159+
pub(crate) channel_binding: ChannelBinding,
143160
}
144161

145162
impl Default for Config {
@@ -164,6 +181,7 @@ impl Config {
164181
keepalives: true,
165182
keepalives_idle: Duration::from_secs(2 * 60 * 60),
166183
target_session_attrs: TargetSessionAttrs::Any,
184+
channel_binding: ChannelBinding::Prefer,
167185
}
168186
}
169187

@@ -287,6 +305,14 @@ impl Config {
287305
self
288306
}
289307

308+
/// Sets the channel binding behavior.
309+
///
310+
/// Defaults to `prefer`.
311+
pub fn channel_binding(&mut self, channel_binding: ChannelBinding) -> &mut Config {
312+
self.channel_binding = channel_binding;
313+
self
314+
}
315+
290316
fn param(&mut self, key: &str, value: &str) -> Result<(), Error> {
291317
match key {
292318
"user" => {
@@ -363,6 +389,19 @@ impl Config {
363389
};
364390
self.target_session_attrs(target_session_attrs);
365391
}
392+
"channel_binding" => {
393+
let channel_binding = match value {
394+
"disable" => ChannelBinding::Disable,
395+
"prefer" => ChannelBinding::Prefer,
396+
"require" => ChannelBinding::Require,
397+
_ => {
398+
return Err(Error::config_parse(Box::new(InvalidValue(
399+
"channel_binding",
400+
))))
401+
}
402+
};
403+
self.channel_binding(channel_binding);
404+
}
366405
key => {
367406
return Err(Error::config_parse(Box::new(UnknownOption(
368407
key.to_string(),
@@ -434,6 +473,7 @@ impl fmt::Debug for Config {
434473
.field("keepalives", &self.keepalives)
435474
.field("keepalives_idle", &self.keepalives_idle)
436475
.field("target_session_attrs", &self.target_session_attrs)
476+
.field("channel_binding", &self.channel_binding)
437477
.finish()
438478
}
439479
}

tokio-postgres/src/connect_raw.rs

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
2-
use crate::config::Config;
2+
use crate::config::{self, Config};
33
use crate::connect_tls::connect_tls;
44
use crate::maybe_tls_stream::MaybeTlsStream;
55
use crate::tls::{ChannelBinding, TlsConnect};
@@ -141,8 +141,13 @@ where
141141
T: AsyncRead + AsyncWrite + Unpin,
142142
{
143143
match stream.try_next().await.map_err(Error::io)? {
144-
Some(Message::AuthenticationOk) => return Ok(()),
144+
Some(Message::AuthenticationOk) => {
145+
can_skip_channel_binding(config)?;
146+
return Ok(());
147+
}
145148
Some(Message::AuthenticationCleartextPassword) => {
149+
can_skip_channel_binding(config)?;
150+
146151
let pass = config
147152
.password
148153
.as_ref()
@@ -151,6 +156,8 @@ where
151156
authenticate_password(stream, pass).await?;
152157
}
153158
Some(Message::AuthenticationMd5Password(body)) => {
159+
can_skip_channel_binding(config)?;
160+
154161
let user = config
155162
.user
156163
.as_ref()
@@ -164,12 +171,7 @@ where
164171
authenticate_password(stream, output.as_bytes()).await?;
165172
}
166173
Some(Message::AuthenticationSasl(body)) => {
167-
let pass = config
168-
.password
169-
.as_ref()
170-
.ok_or_else(|| Error::config("password missing".into()))?;
171-
172-
authenticate_sasl(stream, body, channel_binding, pass).await?;
174+
authenticate_sasl(stream, body, channel_binding, config).await?;
173175
}
174176
Some(Message::AuthenticationKerberosV5)
175177
| Some(Message::AuthenticationScmCredential)
@@ -192,6 +194,16 @@ where
192194
}
193195
}
194196

197+
fn can_skip_channel_binding(config: &Config) -> Result<(), Error> {
198+
match config.channel_binding {
199+
config::ChannelBinding::Disable | config::ChannelBinding::Prefer => Ok(()),
200+
config::ChannelBinding::Require => Err(Error::authentication(
201+
"server did not use channel binding".into(),
202+
)),
203+
config::ChannelBinding::__NonExhaustive => unreachable!(),
204+
}
205+
}
206+
195207
async fn authenticate_password<S, T>(
196208
stream: &mut StartupStream<S, T>,
197209
password: &[u8],
@@ -213,12 +225,17 @@ async fn authenticate_sasl<S, T>(
213225
stream: &mut StartupStream<S, T>,
214226
body: AuthenticationSaslBody,
215227
channel_binding: ChannelBinding,
216-
password: &[u8],
228+
config: &Config,
217229
) -> Result<(), Error>
218230
where
219231
S: AsyncRead + AsyncWrite + Unpin,
220232
T: AsyncRead + AsyncWrite + Unpin,
221233
{
234+
let password = config
235+
.password
236+
.as_ref()
237+
.ok_or_else(|| Error::config("password missing".into()))?;
238+
222239
let mut has_scram = false;
223240
let mut has_scram_plus = false;
224241
let mut mechanisms = body.mechanisms();
@@ -232,12 +249,15 @@ where
232249

233250
let channel_binding = channel_binding
234251
.tls_server_end_point
252+
.filter(|_| config.channel_binding != config::ChannelBinding::Disable)
235253
.map(sasl::ChannelBinding::tls_server_end_point);
236254

237255
let (channel_binding, mechanism) = if has_scram_plus {
238256
match channel_binding {
239257
Some(channel_binding) => (channel_binding, sasl::SCRAM_SHA_256_PLUS),
240-
None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256),
258+
None => {
259+
(sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256)
260+
},
241261
}
242262
} else if has_scram {
243263
match channel_binding {
@@ -248,6 +268,10 @@ where
248268
return Err(Error::authentication("unsupported SASL mechanism".into()));
249269
};
250270

271+
if mechanism != sasl::SCRAM_SHA_256_PLUS {
272+
can_skip_channel_binding(config)?;
273+
}
274+
251275
let mut scram = ScramSha256::new(password, channel_binding);
252276

253277
let mut buf = vec![];

tokio-postgres/src/prepare.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,11 @@ pub fn prepare(
106106
}
107107
}
108108

109-
fn prepare_rec(client: Arc<InnerClient>, query: &str, types: &[Type]) -> Pin<Box<dyn Future<Output = Result<Statement, Error>> + 'static + Send>> {
109+
fn prepare_rec(
110+
client: Arc<InnerClient>,
111+
query: &str,
112+
types: &[Type],
113+
) -> Pin<Box<dyn Future<Output = Result<Statement, Error>> + 'static + Send>> {
110114
Box::pin(prepare(client, query, types))
111115
}
112116

0 commit comments

Comments
 (0)