Skip to content

Commit c59799e

Browse files
committed
Refactor connectparams
1 parent 4c91a68 commit c59799e

File tree

7 files changed

+179
-104
lines changed

7 files changed

+179
-104
lines changed

postgres-shared/src/params/mod.rs

Lines changed: 134 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,142 @@
11
//! Connection parameters
22
use std::error::Error;
33
use std::path::PathBuf;
4+
use std::mem;
45

56
use params::url::Url;
67

78
mod url;
89

9-
/// Specifies the target server to connect to.
10+
/// The host.
1011
#[derive(Clone, Debug)]
11-
pub enum ConnectTarget {
12-
/// Connect via TCP to the specified host.
12+
pub enum Host {
13+
/// A TCP hostname.
1314
Tcp(String),
14-
/// Connect via a Unix domain socket in the specified directory.
15-
///
16-
/// Unix sockets are only supported on Unixy platforms (i.e. not Windows).
15+
/// The path to a directory containing the server's Unix socket.
1716
Unix(PathBuf),
1817
}
1918

2019
/// Authentication information.
2120
#[derive(Clone, Debug)]
22-
pub struct UserInfo {
21+
pub struct User {
22+
name: String,
23+
password: Option<String>,
24+
}
25+
26+
impl User {
2327
/// The username.
24-
pub user: String,
28+
pub fn name(&self) -> &str {
29+
&self.name
30+
}
31+
2532
/// An optional password.
26-
pub password: Option<String>,
33+
pub fn password(&self) -> Option<&str> {
34+
self.password.as_ref().map(|p| &**p)
35+
}
2736
}
2837

2938
/// Information necessary to open a new connection to a Postgres server.
3039
#[derive(Clone, Debug)]
3140
pub struct ConnectParams {
32-
/// The target server.
33-
pub target: ConnectTarget,
41+
host: Host,
42+
port: u16,
43+
user: Option<User>,
44+
database: Option<String>,
45+
options: Vec<(String, String)>,
46+
}
47+
48+
impl ConnectParams {
49+
/// Returns a new builder.
50+
pub fn builder() -> Builder {
51+
Builder::new()
52+
}
53+
54+
/// The target host.
55+
pub fn host(&self) -> &Host {
56+
&self.host
57+
}
58+
3459
/// The target port.
3560
///
36-
/// Defaults to 5432 if not specified.
37-
pub port: Option<u16>,
38-
/// The user to login as.
61+
/// Defaults to 5432.
62+
pub fn port(&self) -> u16 {
63+
self.port
64+
}
65+
66+
/// The user to log in as.
3967
///
40-
/// `Connection::connect` requires a user but `cancel_query` does not.
41-
pub user: Option<UserInfo>,
68+
/// A user is required to open a new connection but not to cancel a query.
69+
pub fn user(&self) -> Option<&User> {
70+
self.user.as_ref()
71+
}
72+
4273
/// The database to connect to.
43-
///
44-
/// Defaults the value of `user`.
45-
pub database: Option<String>,
74+
pub fn database(&self) -> Option<&str> {
75+
self.database.as_ref().map(|d| &**d)
76+
}
77+
4678
/// Runtime parameters to be passed to the Postgres backend.
47-
pub options: Vec<(String, String)>,
79+
pub fn options(&self) -> &[(String, String)] {
80+
&self.options
81+
}
82+
}
83+
84+
/// A builder for `ConnectParams`.
85+
pub struct Builder {
86+
port: u16,
87+
user: Option<User>,
88+
database: Option<String>,
89+
options: Vec<(String, String)>,
90+
}
91+
92+
impl Builder {
93+
/// Creates a new builder.
94+
pub fn new() -> Builder {
95+
Builder {
96+
port: 5432,
97+
user: None,
98+
database: None,
99+
options: vec![],
100+
}
101+
}
102+
103+
/// Sets the port.
104+
pub fn port(&mut self, port: u16) -> &mut Builder {
105+
self.port = port;
106+
self
107+
}
108+
109+
/// Sets the user.
110+
pub fn user(&mut self, name: &str, password: Option<&str>) -> &mut Builder {
111+
self.user = Some(User {
112+
name: name.to_string(),
113+
password: password.map(ToString::to_string),
114+
});
115+
self
116+
}
117+
118+
/// Sets the database.
119+
pub fn database(&mut self, database: &str) -> &mut Builder {
120+
self.database = Some(database.to_string());
121+
self
122+
}
123+
124+
/// Adds a runtime parameter.
125+
pub fn option(&mut self, name: &str, value: &str) -> &mut Builder {
126+
self.options.push((name.to_string(), value.to_string()));
127+
self
128+
}
129+
130+
/// Constructs a `ConnectParams` from the builder.
131+
pub fn build(&mut self, host: Host) -> ConnectParams {
132+
ConnectParams {
133+
host: host,
134+
port: self.port,
135+
user: self.user.take(),
136+
database: self.database.take(),
137+
options: mem::replace(&mut self.options, vec![]),
138+
}
139+
}
48140
}
49141

50142
/// A trait implemented by types that can be converted into a `ConnectParams`.
@@ -78,35 +170,33 @@ impl IntoConnectParams for Url {
78170
fn into_connect_params(self) -> Result<ConnectParams, Box<Error + Sync + Send>> {
79171
let Url { host, port, user, path: url::Path { mut path, query: options, .. }, .. } = self;
80172

81-
let maybe_path = url::decode_component(&host)?;
82-
let target = if maybe_path.starts_with('/') {
83-
ConnectTarget::Unix(PathBuf::from(maybe_path))
84-
} else {
85-
ConnectTarget::Tcp(host)
86-
};
173+
let mut builder = ConnectParams::builder();
87174

88-
let user = user.map(|url::UserInfo { user, pass }| {
89-
UserInfo {
90-
user: user,
91-
password: pass,
92-
}
93-
});
175+
if let Some(port) = port {
176+
builder.port(port);
177+
}
94178

95-
let database = if path.is_empty() {
96-
None
97-
} else {
179+
if let Some(info) = user {
180+
builder.user(&info.user, info.pass.as_ref().map(|p| &**p));
181+
}
182+
183+
if !path.is_empty() {
98184
// path contains the leading /
99-
path.remove(0);
100-
Some(path)
185+
builder.database(&path[1..]);
186+
}
187+
188+
for (name, value) in options {
189+
builder.option(&name, &value);
190+
}
191+
192+
let maybe_path = url::decode_component(&host)?;
193+
let host = if maybe_path.starts_with('/') {
194+
Host::Unix(maybe_path.into())
195+
} else {
196+
Host::Tcp(maybe_path)
101197
};
102198

103-
Ok(ConnectParams {
104-
target: target,
105-
port: port,
106-
user: user,
107-
database: database,
108-
options: options,
109-
})
199+
Ok(builder.build(host))
110200
}
111201
}
112202

postgres/src/lib.rs

Lines changed: 15 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ use postgres_shared::rows::RowData;
9696
use error::{Error, ConnectError, SqlState, DbError};
9797
use tls::TlsHandshake;
9898
use notification::{Notifications, Notification};
99-
use params::{ConnectParams, IntoConnectParams, UserInfo};
99+
use params::{ConnectParams, IntoConnectParams, User};
100100
use priv_io::MessageStream;
101101
use rows::{Rows, LazyRows};
102102
use stmt::{Statement, Column};
@@ -255,9 +255,7 @@ impl InnerConnection {
255255
let params = params.into_connect_params().map_err(ConnectError::ConnectParams)?;
256256
let stream = priv_io::initialize_stream(&params, tls)?;
257257

258-
let ConnectParams { user, database, mut options, .. } = params;
259-
260-
let user = match user {
258+
let user = match params.user() {
261259
Some(user) => user,
262260
None => {
263261
return Err(ConnectError::ConnectParams("User missing from connection parameters"
@@ -285,14 +283,15 @@ impl InnerConnection {
285283
has_typeinfo_composite_query: false,
286284
};
287285

286+
let mut options = params.options().to_owned();
288287
options.push(("client_encoding".to_owned(), "UTF8".to_owned()));
289288
// Postgres uses the value of TimeZone as the time zone for TIMESTAMP
290289
// WITH TIME ZONE values. Timespec converts to GMT internally.
291290
options.push(("timezone".to_owned(), "GMT".to_owned()));
292291
// We have to clone here since we need the user again for auth
293-
options.push(("user".to_owned(), user.user.clone()));
294-
if let Some(database) = database {
295-
options.push(("database".to_owned(), database));
292+
options.push(("user".to_owned(), user.name().to_owned()));
293+
if let Some(database) = params.database() {
294+
options.push(("database".to_owned(), database.to_owned()));
296295
}
297296

298297
let options = options.iter().map(|&(ref a, ref b)| (&**a, &**b));
@@ -390,21 +389,21 @@ impl InnerConnection {
390389
}
391390
}
392391

393-
fn handle_auth(&mut self, user: UserInfo) -> result::Result<(), ConnectError> {
392+
fn handle_auth(&mut self, user: &User) -> result::Result<(), ConnectError> {
394393
match self.read_message()? {
395394
backend::Message::AuthenticationOk => return Ok(()),
396395
backend::Message::AuthenticationCleartextPassword => {
397-
let pass = user.password.ok_or_else(|| {
396+
let pass = user.password().ok_or_else(|| {
398397
ConnectError::ConnectParams("a password was requested but not provided".into())
399398
})?;
400-
self.stream.write_message(|buf| frontend::password_message(&pass, buf))?;
399+
self.stream.write_message(|buf| frontend::password_message(pass, buf))?;
401400
self.stream.flush()?;
402401
}
403402
backend::Message::AuthenticationMd5Password(body) => {
404-
let pass = user.password.ok_or_else(|| {
403+
let pass = user.password().ok_or_else(|| {
405404
ConnectError::ConnectParams("a password was requested but not provided".into())
406405
})?;
407-
let output = authentication::md5_hash(user.user.as_bytes(),
406+
let output = authentication::md5_hash(user.name().as_bytes(),
408407
pass.as_bytes(),
409408
body.salt());
410409
self.stream.write_message(|buf| frontend::password_message(&output, buf))?;
@@ -932,22 +931,15 @@ impl Connection {
932931
///
933932
/// ```rust,no_run
934933
/// use postgres::{Connection, TlsMode};
935-
/// use postgres::params::{UserInfo, ConnectParams, ConnectTarget};
934+
/// use postgres::params::{ConnectParams, Host};
936935
/// # use std::path::PathBuf;
937936
///
938937
/// # #[cfg(unix)]
939938
/// # fn f() {
940939
/// # let some_crazy_path = PathBuf::new();
941-
/// let params = ConnectParams {
942-
/// target: ConnectTarget::Unix(some_crazy_path),
943-
/// port: None,
944-
/// user: Some(UserInfo {
945-
/// user: "postgres".to_owned(),
946-
/// password: None
947-
/// }),
948-
/// database: None,
949-
/// options: vec![],
950-
/// };
940+
/// let params = ConnectParams::builder()
941+
/// .user("postgres", None)
942+
/// .build(Host::Unix(some_crazy_path));
951943
/// let conn = Connection::connect(params, TlsMode::None).unwrap();
952944
/// # }
953945
/// ```

postgres/src/priv_io.rs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ use postgres_protocol::message::backend::{self, ParseResult};
1616
use TlsMode;
1717
use error::ConnectError;
1818
use tls::TlsStream;
19-
use params::{ConnectParams, ConnectTarget};
19+
use params::{ConnectParams, Host};
2020

2121
const DEFAULT_PORT: u16 = 5432;
2222
const MESSAGE_HEADER_SIZE: usize = 5;
@@ -221,18 +221,18 @@ impl Write for InternalStream {
221221
}
222222

223223
fn open_socket(params: &ConnectParams) -> Result<InternalStream, ConnectError> {
224-
let port = params.port.unwrap_or(DEFAULT_PORT);
225-
match params.target {
226-
ConnectTarget::Tcp(ref host) => {
224+
let port = params.port();
225+
match *params.host() {
226+
Host::Tcp(ref host) => {
227227
Ok(TcpStream::connect(&(&**host, port)).map(InternalStream::Tcp)?)
228228
}
229229
#[cfg(unix)]
230-
ConnectTarget::Unix(ref path) => {
230+
Host::Unix(ref path) => {
231231
let path = path.join(&format!(".s.PGSQL.{}", port));
232232
Ok(UnixStream::connect(&path).map(InternalStream::Unix)?)
233233
}
234234
#[cfg(not(unix))]
235-
ConnectTarget::Unix(..) => {
235+
Host::Unix(..) => {
236236
Err(ConnectError::Io(io::Error::new(io::ErrorKind::InvalidInput,
237237
"unix sockets are not supported on this system")))
238238
}
@@ -265,10 +265,10 @@ pub fn initialize_stream(params: &ConnectParams,
265265
}
266266
}
267267

268-
let host = match params.target {
269-
ConnectTarget::Tcp(ref host) => host,
268+
let host = match *params.host() {
269+
Host::Tcp(ref host) => host,
270270
// Postgres doesn't support TLS over unix sockets
271-
ConnectTarget::Unix(_) => return Err(ConnectError::Io(::bad_response())),
271+
Host::Unix(_) => return Err(ConnectError::Io(::bad_response())),
272272
};
273273

274274
handshaker.tls_handshake(host, socket).map_err(ConnectError::Tls)

postgres/tests/test.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -978,8 +978,8 @@ fn url_unencoded_password() {
978978
#[test]
979979
fn url_encoded_password() {
980980
let params = "postgresql://username%7b%7c:password%7b%7c@localhost".into_connect_params().unwrap();
981-
assert_eq!("username{|", &params.user.as_ref().unwrap().user[..]);
982-
assert_eq!("password{|", &params.user.as_ref().unwrap().password.as_ref().unwrap()[..]);
981+
assert_eq!("username{|", params.user().unwrap().name());
982+
assert_eq!("password{|", params.user().unwrap().password().unwrap());
983983
}
984984

985985
#[test]

0 commit comments

Comments
 (0)