Skip to content

Commit 48874dc

Browse files
committed
IpAddr + try hostaddr first
1 parent 3697f6b commit 48874dc

File tree

3 files changed

+110
-39
lines changed

3 files changed

+110
-39
lines changed

tokio-postgres/src/config.rs

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use crate::{Client, Connection, Error};
1313
use std::borrow::Cow;
1414
#[cfg(unix)]
1515
use std::ffi::OsStr;
16+
use std::net::IpAddr;
1617
use std::ops::Deref;
1718
#[cfg(unix)]
1819
use std::os::unix::ffi::OsStrExt;
@@ -98,7 +99,9 @@ pub enum Host {
9899
/// - or if host specifies an IP address, that value will be used directly.
99100
/// Using `hostaddr` allows the application to avoid a host name look-up, which might be important in applications
100101
/// with time constraints. However, a host name is required for verify-full SSL certificate verification.
101-
/// Note that `host` is always required regardless of whether `hostaddr` is present.
102+
/// Specifically:
103+
/// * If `hostaddr` is specified without `host`, the value for `hostaddr` gives the server network address.
104+
/// The connection attempt will fail if the authentication method requires a host name;
102105
/// * If `host` is specified without `hostaddr`, a host name lookup occurs;
103106
/// * If both `host` and `hostaddr` are specified, the value for `hostaddr` gives the server network address.
104107
/// The value for `host` is ignored unless the authentication method requires it,
@@ -174,7 +177,7 @@ pub struct Config {
174177
pub(crate) application_name: Option<String>,
175178
pub(crate) ssl_mode: SslMode,
176179
pub(crate) host: Vec<Host>,
177-
pub(crate) hostaddr: Vec<String>,
180+
pub(crate) hostaddr: Vec<IpAddr>,
178181
pub(crate) port: Vec<u16>,
179182
pub(crate) connect_timeout: Option<Duration>,
180183
pub(crate) keepalives: bool,
@@ -317,7 +320,7 @@ impl Config {
317320
}
318321

319322
/// Gets the hostaddrs that have been added to the configuration with `hostaddr`.
320-
pub fn get_hostaddrs(&self) -> &[String] {
323+
pub fn get_hostaddrs(&self) -> &[IpAddr] {
321324
self.hostaddr.deref()
322325
}
323326

@@ -337,8 +340,8 @@ impl Config {
337340
///
338341
/// Multiple hostaddrs can be specified by calling this method multiple times, and each will be tried in order.
339342
/// There must be either no hostaddrs, or the same number of hostaddrs as hosts.
340-
pub fn hostaddr(&mut self, hostaddr: &str) -> &mut Config {
341-
self.hostaddr.push(hostaddr.to_string());
343+
pub fn hostaddr(&mut self, hostaddr: IpAddr) -> &mut Config {
344+
self.hostaddr.push(hostaddr);
342345
self
343346
}
344347

@@ -489,7 +492,10 @@ impl Config {
489492
}
490493
"hostaddr" => {
491494
for hostaddr in value.split(',') {
492-
self.hostaddr(hostaddr);
495+
let addr = hostaddr
496+
.parse()
497+
.map_err(|_| Error::config_parse(Box::new(InvalidValue("hostaddr"))))?;
498+
self.hostaddr(addr);
493499
}
494500
}
495501
"port" => {
@@ -1016,6 +1022,8 @@ impl<'a> UrlParser<'a> {
10161022

10171023
#[cfg(test)]
10181024
mod tests {
1025+
use std::net::IpAddr;
1026+
10191027
use crate::{config::Host, Config};
10201028

10211029
#[test]
@@ -1032,16 +1040,14 @@ mod tests {
10321040
config.get_hosts(),
10331041
);
10341042

1035-
assert_eq!(["127.0.0.1", "127.0.0.2"], config.get_hostaddrs(),);
1043+
assert_eq!(
1044+
[
1045+
"127.0.0.1".parse::<IpAddr>().unwrap(),
1046+
"127.0.0.2".parse::<IpAddr>().unwrap()
1047+
],
1048+
config.get_hostaddrs(),
1049+
);
10361050

10371051
assert_eq!(1, 1);
10381052
}
1039-
1040-
#[test]
1041-
fn test_empty_hostaddrs() {
1042-
let s =
1043-
"user=pass_user dbname=postgres host=host1,host2,host3 hostaddr=127.0.0.1,,127.0.0.2";
1044-
let config = s.parse::<Config>().unwrap();
1045-
assert_eq!(["127.0.0.1", "", "127.0.0.2"], config.get_hostaddrs(),);
1046-
}
10471053
}

tokio-postgres/src/connect.rs

Lines changed: 37 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ use crate::connect_socket::connect_socket;
55
use crate::tls::{MakeTlsConnect, TlsConnect};
66
use crate::{Client, Config, Connection, Error, SimpleQueryMessage, Socket};
77
use futures_util::{future, pin_mut, Future, FutureExt, Stream};
8-
use std::io;
98
use std::task::Poll;
9+
use std::{cmp, io};
1010

1111
pub async fn connect<T>(
1212
mut tls: T,
@@ -15,25 +15,35 @@ pub async fn connect<T>(
1515
where
1616
T: MakeTlsConnect<Socket>,
1717
{
18-
if config.host.is_empty() {
19-
return Err(Error::config("host missing".into()));
18+
if config.host.is_empty() && config.hostaddr.is_empty() {
19+
return Err(Error::config("both host and hostaddr are missing".into()));
2020
}
2121

22-
if config.port.len() > 1 && config.port.len() != config.host.len() {
23-
return Err(Error::config("invalid number of ports".into()));
24-
}
25-
26-
if !config.hostaddr.is_empty() && config.hostaddr.len() != config.host.len() {
22+
if !config.host.is_empty()
23+
&& !config.hostaddr.is_empty()
24+
&& config.host.len() != config.hostaddr.len()
25+
{
2726
let msg = format!(
28-
"invalid number of hostaddrs ({}). Possible values: 0 or number of hosts ({})",
29-
config.hostaddr.len(),
27+
"number of hosts ({}) is different from number of hostaddrs ({})",
3028
config.host.len(),
29+
config.hostaddr.len(),
3130
);
3231
return Err(Error::config(msg.into()));
3332
}
3433

34+
// At this point, either one of the following two scenarios could happen:
35+
// (1) either config.host or config.hostaddr must be empty;
36+
// (2) if both config.host and config.hostaddr are NOT empty; their lengths must be equal.
37+
let num_hosts = cmp::max(config.host.len(), config.hostaddr.len());
38+
39+
if config.port.len() > 1 && config.port.len() != num_hosts {
40+
return Err(Error::config("invalid number of ports".into()));
41+
}
42+
3543
let mut error = None;
36-
for (i, host) in config.host.iter().enumerate() {
44+
for i in 0..num_hosts {
45+
let host = config.host.get(i);
46+
let hostaddr = config.hostaddr.get(i);
3747
let port = config
3848
.port
3949
.get(i)
@@ -42,27 +52,30 @@ where
4252
.unwrap_or(5432);
4353

4454
// The value of host is always used as the hostname for TLS validation.
55+
// postgres doesn't support TLS over unix sockets, so the choice for Host::Unix variant here doesn't matter
4556
let hostname = match host {
46-
Host::Tcp(host) => host.as_str(),
47-
// postgres doesn't support TLS over unix sockets, so the choice here doesn't matter
48-
#[cfg(unix)]
49-
Host::Unix(_) => "",
57+
Some(Host::Tcp(host)) => host.as_str(),
58+
_ => "",
5059
};
5160
let tls = tls
5261
.make_tls_connect(hostname)
5362
.map_err(|e| Error::tls(e.into()))?;
5463

55-
// If both host and hostaddr are specified, the value of hostaddr is used to to establish the TCP connection.
56-
let hostaddr = match host {
57-
Host::Tcp(_hostname) => match config.hostaddr.get(i) {
58-
Some(hostaddr) if hostaddr.is_empty() => Host::Tcp(hostaddr.clone()),
59-
_ => host.clone(),
60-
},
61-
#[cfg(unix)]
62-
Host::Unix(_v) => host.clone(),
64+
// Try to use the value of hostaddr to establish the TCP connection,
65+
// fallback to host if hostaddr is not present.
66+
let addr = match hostaddr {
67+
Some(ipaddr) => Host::Tcp(ipaddr.to_string()),
68+
None => {
69+
if let Some(host) = host {
70+
host.clone()
71+
} else {
72+
// This is unreachable.
73+
return Err(Error::config("both host and hostaddr are empty".into()));
74+
}
75+
}
6376
};
6477

65-
match connect_once(&hostaddr, port, tls, config).await {
78+
match connect_once(&addr, port, tls, config).await {
6679
Ok((client, connection)) => return Ok((client, connection)),
6780
Err(e) => error = Some(e),
6881
}

tokio-postgres/tests/test/main.rs

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,58 @@ async fn scram_password_ok() {
147147
connect("user=scram_user password=password dbname=postgres").await;
148148
}
149149

150+
#[tokio::test]
151+
async fn host_only_ok() {
152+
let _ = tokio_postgres::connect(
153+
"host=localhost port=5433 user=pass_user dbname=postgres password=password",
154+
NoTls,
155+
)
156+
.await
157+
.unwrap();
158+
}
159+
160+
#[tokio::test]
161+
async fn hostaddr_only_ok() {
162+
let _ = tokio_postgres::connect(
163+
"hostaddr=127.0.0.1 port=5433 user=pass_user dbname=postgres password=password",
164+
NoTls,
165+
)
166+
.await
167+
.unwrap();
168+
}
169+
170+
#[tokio::test]
171+
async fn hostaddr_and_host_ok() {
172+
let _ = tokio_postgres::connect(
173+
"hostaddr=127.0.0.1 host=localhost port=5433 user=pass_user dbname=postgres password=password",
174+
NoTls,
175+
)
176+
.await
177+
.unwrap();
178+
}
179+
180+
#[tokio::test]
181+
async fn hostaddr_host_mismatch() {
182+
let _ = tokio_postgres::connect(
183+
"hostaddr=127.0.0.1,127.0.0.2 host=localhost port=5433 user=pass_user dbname=postgres password=password",
184+
NoTls,
185+
)
186+
.await
187+
.err()
188+
.unwrap();
189+
}
190+
191+
#[tokio::test]
192+
async fn hostaddr_host_both_missing() {
193+
let _ = tokio_postgres::connect(
194+
"port=5433 user=pass_user dbname=postgres password=password",
195+
NoTls,
196+
)
197+
.await
198+
.err()
199+
.unwrap();
200+
}
201+
150202
#[tokio::test]
151203
async fn pipelined_prepare() {
152204
let client = connect("user=postgres").await;

0 commit comments

Comments
 (0)