Skip to content

Commit 967f98e

Browse files
hugocasamossbanayjxs
authored
Add support for TLS in postgres/tokio-postgres using native-tls (#353)
* Add support for using TLS with PostgreSQL (#260) * Make use_tls optional in config * Include openssl crates in refinery module * tls support for postgres using native-tls * use tls for tokio_postgres as well * nit * nit * fix issues * avoid touching main Cargo.toml * add new tests * fix serde bug --------- Co-authored-by: moss <[email protected]> Co-authored-by: João Oliveira <[email protected]>
1 parent 057ef74 commit 967f98e

File tree

5 files changed

+193
-17
lines changed

5 files changed

+193
-17
lines changed

refinery/tests/postgres.rs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use barrel::backend::Pg as Sql;
44
mod postgres {
55
use assert_cmd::prelude::*;
66
use predicates::str::contains;
7+
use refinery::config::ConfigDbType;
78
use refinery::{
89
config::Config, embed_migrations, error::Kind, Migrate, Migration, Runner, Target,
910
};
@@ -728,4 +729,38 @@ mod postgres {
728729
.stdout(contains("applying migration: V3__add_brand_to_cars_table"));
729730
})
730731
}
732+
733+
#[test]
734+
fn migrates_with_tls_enabled() {
735+
run_test(|| {
736+
let mut config = Config::new(ConfigDbType::Postgres)
737+
.set_db_name("postgres")
738+
.set_db_user("postgres")
739+
.set_db_host("localhost")
740+
.set_db_port("5432")
741+
.set_use_tls(true);
742+
743+
let migrations = get_migrations();
744+
let runner = Runner::new(&migrations)
745+
.set_grouped(false)
746+
.set_abort_divergent(true)
747+
.set_abort_missing(true);
748+
749+
let report = runner.run(&mut config).unwrap();
750+
751+
let applied_migrations = report.applied_migrations();
752+
assert_eq!(5, applied_migrations.len());
753+
754+
let last_migration = runner
755+
.get_last_applied_migration(&mut config)
756+
.unwrap()
757+
.unwrap();
758+
759+
assert_eq!(5, last_migration.version());
760+
assert_eq!(migrations[4].name(), last_migration.name());
761+
assert_eq!(migrations[4].checksum(), last_migration.checksum());
762+
763+
assert!(config.use_tls());
764+
});
765+
}
731766
}

refinery/tests/tokio_postgres.rs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ mod tokio_postgres {
1212
use refinery_core::tokio_postgres;
1313
use refinery_core::tokio_postgres::NoTls;
1414
use std::panic::AssertUnwindSafe;
15+
use std::str::FromStr;
1516
use time::OffsetDateTime;
1617

1718
const DEFAULT_TABLE_NAME: &str = "refinery_schema_history";
@@ -953,4 +954,37 @@ mod tokio_postgres {
953954
})
954955
.await;
955956
}
957+
958+
#[tokio::test]
959+
async fn migrates_with_tls_enabled() {
960+
run_test(async {
961+
let mut config =
962+
Config::from_str("postgres://postgres@localhost:5432/postgres?sslmode=require")
963+
.unwrap();
964+
965+
let migrations = get_migrations();
966+
let runner = Runner::new(&migrations)
967+
.set_grouped(false)
968+
.set_abort_divergent(true)
969+
.set_abort_missing(true);
970+
971+
let report = runner.run_async(&mut config).await.unwrap();
972+
973+
let applied_migrations = report.applied_migrations();
974+
assert_eq!(5, applied_migrations.len());
975+
976+
let last_migration = runner
977+
.get_last_applied_migration_async(&mut config)
978+
.await
979+
.unwrap()
980+
.unwrap();
981+
982+
assert_eq!(5, last_migration.version());
983+
assert_eq!(migrations[4].name(), last_migration.name());
984+
assert_eq!(migrations[4].checksum(), last_migration.checksum());
985+
986+
assert!(config.use_tls());
987+
})
988+
.await;
989+
}
956990
}

refinery_core/Cargo.toml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,13 @@ edition = "2021"
1010

1111
[features]
1212
default = []
13+
mysql_async = ["dep:mysql_async"]
14+
postgres = ["dep:postgres", "dep:postgres-native-tls", "dep:native-tls"]
1315
rusqlite-bundled = ["rusqlite", "rusqlite/bundled"]
16+
serde = ["dep:serde"]
1417
tiberius = ["dep:tiberius", "futures", "tokio", "tokio/net"]
1518
tiberius-config = ["tiberius", "tokio", "tokio-util", "serde"]
16-
tokio-postgres = ["dep:tokio-postgres", "tokio", "tokio/rt"]
17-
mysql_async = ["dep:mysql_async"]
18-
serde = ["dep:serde"]
19+
tokio-postgres = ["dep:postgres-native-tls", "dep:native-tls", "dep:tokio-postgres", "tokio", "tokio/rt"]
1920
toml = ["serde", "dep:toml"]
2021

2122
[dependencies]
@@ -31,6 +32,8 @@ walkdir = "2.3.1"
3132
# allow multiple versions of the same dependency if API is similar
3233
rusqlite = { version = ">= 0.23, <= 0.37", optional = true }
3334
postgres = { version = ">=0.17, <= 0.19", optional = true }
35+
native-tls = { version = "0.2", optional = true }
36+
postgres-native-tls = { version = "0.5", optional = true}
3437
tokio-postgres = { version = ">= 0.5, <= 0.7", optional = true }
3538
mysql = { version = ">= 21.0.0, <= 26", optional = true, default-features = false, features = ["minimal"] }
3639
mysql_async = { version = ">= 0.28, <= 0.35", optional = true, default-features = false, features = ["minimal"] }

refinery_core/src/config.rs

Lines changed: 89 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,12 @@ use crate::Error;
33
use std::convert::TryFrom;
44
use std::path::PathBuf;
55
use std::str::FromStr;
6+
#[cfg(any(
7+
feature = "postgres",
8+
feature = "tokio-postgres",
9+
feature = "tiberius-config"
10+
))]
11+
use std::{borrow::Cow, collections::HashMap};
612
use url::Url;
713

814
// refinery config file used by migrate_from_config if migration from a Config struct is preferred instead of using the macros
@@ -34,6 +40,8 @@ impl Config {
3440
db_user: None,
3541
db_pass: None,
3642
db_name: None,
43+
#[cfg(any(feature = "postgres", feature = "tokio-postgres"))]
44+
use_tls: false,
3745
#[cfg(feature = "tiberius-config")]
3846
trust_cert: false,
3947
},
@@ -139,6 +147,11 @@ impl Config {
139147
self.main.db_port.as_deref()
140148
}
141149

150+
#[cfg(any(feature = "postgres", feature = "tokio-postgres"))]
151+
pub fn use_tls(&self) -> bool {
152+
self.main.use_tls
153+
}
154+
142155
pub fn set_db_user(self, db_user: &str) -> Config {
143156
Config {
144157
main: Main {
@@ -183,6 +196,16 @@ impl Config {
183196
},
184197
}
185198
}
199+
200+
#[cfg(any(feature = "postgres", feature = "tokio-postgres"))]
201+
pub fn set_use_tls(self, use_tls: bool) -> Config {
202+
Config {
203+
main: Main {
204+
use_tls,
205+
..self.main
206+
},
207+
}
208+
}
186209
}
187210

188211
impl TryFrom<Url> for Config {
@@ -203,13 +226,17 @@ impl TryFrom<Url> for Config {
203226
}
204227
};
205228

229+
#[cfg(any(
230+
feature = "postgres",
231+
feature = "tokio-postgres",
232+
feature = "tiberius-config"
233+
))]
234+
let query_params = url
235+
.query_pairs()
236+
.collect::<HashMap<Cow<'_, str>, Cow<'_, str>>>();
237+
206238
cfg_if::cfg_if! {
207239
if #[cfg(feature = "tiberius-config")] {
208-
use std::{borrow::Cow, collections::HashMap};
209-
let query_params = url
210-
.query_pairs()
211-
.collect::<HashMap< Cow<'_, str>, Cow<'_, str>>>();
212-
213240
let trust_cert = query_params.
214241
get("trust_cert")
215242
.unwrap_or(&Cow::Borrowed("false"))
@@ -223,6 +250,18 @@ impl TryFrom<Url> for Config {
223250
}
224251
}
225252

253+
#[cfg(any(feature = "postgres", feature = "tokio-postgres"))]
254+
let use_tls = match query_params.get("sslmode") {
255+
Some(Cow::Borrowed("require")) => true,
256+
Some(Cow::Borrowed("disable")) | None => false,
257+
_ => {
258+
return Err(Error::new(
259+
Kind::ConfigError("Invalid sslmode value, please use disable/require".into()),
260+
None,
261+
))
262+
}
263+
};
264+
226265
Ok(Self {
227266
main: Main {
228267
db_type,
@@ -238,6 +277,8 @@ impl TryFrom<Url> for Config {
238277
db_user: Some(url.username().to_string()),
239278
db_pass: url.password().map(|r| r.to_string()),
240279
db_name: Some(url.path().trim_start_matches('/').to_string()),
280+
#[cfg(any(feature = "postgres", feature = "tokio-postgres"))]
281+
use_tls,
241282
#[cfg(feature = "tiberius-config")]
242283
trust_cert,
243284
},
@@ -270,8 +311,11 @@ struct Main {
270311
db_user: Option<String>,
271312
db_pass: Option<String>,
272313
db_name: Option<String>,
314+
#[cfg(any(feature = "postgres", feature = "tokio-postgres"))]
315+
#[cfg_attr(feature = "serde", serde(default))]
316+
use_tls: bool,
273317
#[cfg(feature = "tiberius-config")]
274-
#[serde(default)]
318+
#[cfg_attr(feature = "serde", serde(default))]
275319
trust_cert: bool,
276320
}
277321

@@ -453,6 +497,45 @@ mod tests {
453497
);
454498
}
455499

500+
#[cfg(any(feature = "postgres", feature = "tokio-postgres"))]
501+
#[test]
502+
fn builds_from_sslmode_str() {
503+
use crate::config::ConfigDbType;
504+
505+
let config_disable =
506+
Config::from_str("postgres://root:1234@localhost:5432/refinery?sslmode=disable")
507+
.unwrap();
508+
assert!(!config_disable.use_tls());
509+
510+
let config_require =
511+
Config::from_str("postgres://root:1234@localhost:5432/refinery?sslmode=require")
512+
.unwrap();
513+
assert!(config_require.use_tls());
514+
515+
// Verify that manually created config matches parsed URL config
516+
let manual_config_disable = Config::new(ConfigDbType::Postgres)
517+
.set_db_user("root")
518+
.set_db_pass("1234")
519+
.set_db_host("localhost")
520+
.set_db_port("5432")
521+
.set_db_name("refinery")
522+
.set_use_tls(false);
523+
assert_eq!(config_disable.use_tls(), manual_config_disable.use_tls());
524+
525+
let manual_config_require = Config::new(ConfigDbType::Postgres)
526+
.set_db_user("root")
527+
.set_db_pass("1234")
528+
.set_db_host("localhost")
529+
.set_db_port("5432")
530+
.set_db_name("refinery")
531+
.set_use_tls(true);
532+
assert_eq!(config_require.use_tls(), manual_config_require.use_tls());
533+
534+
let config =
535+
Config::from_str("postgres://root:1234@localhost:5432/refinery?sslmode=invalidvalue");
536+
assert!(config.is_err());
537+
}
538+
456539
#[test]
457540
fn builds_db_env_var_failure() {
458541
std::env::set_var("DATABASE_URL", "this_is_not_a_url");

refinery_core/src/drivers/config.rs

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,16 @@ macro_rules! with_connection {
8282
cfg_if::cfg_if! {
8383
if #[cfg(feature = "postgres")] {
8484
let path = build_db_url("postgresql", &$config);
85-
let conn = postgres::Client::connect(path.as_str(), postgres::NoTls).migration_err("could not connect to database", None)?;
85+
86+
let conn;
87+
if $config.use_tls() {
88+
let connector = native_tls::TlsConnector::new().unwrap();
89+
let connector = postgres_native_tls::MakeTlsConnector::new(connector);
90+
conn = postgres::Client::connect(path.as_str(), connector).migration_err("could not connect to database", None)?;
91+
} else {
92+
conn = postgres::Client::connect(path.as_str(), postgres::NoTls).migration_err("could not connect to database", None)?;
93+
}
94+
8695
$op(conn)
8796
} else {
8897
panic!("tried to migrate from config for a postgresql database, but feature postgres not enabled!");
@@ -123,13 +132,25 @@ macro_rules! with_connection_async {
123132
cfg_if::cfg_if! {
124133
if #[cfg(feature = "tokio-postgres")] {
125134
let path = build_db_url("postgresql", $config);
126-
let (client, connection ) = tokio_postgres::connect(path.as_str(), tokio_postgres::NoTls).await.migration_err("could not connect to database", None)?;
127-
tokio::spawn(async move {
128-
if let Err(e) = connection.await {
129-
eprintln!("connection error: {}", e);
130-
}
131-
});
132-
$op(client).await
135+
if $config.use_tls() {
136+
let connector = native_tls::TlsConnector::new().unwrap();
137+
let connector = postgres_native_tls::MakeTlsConnector::new(connector);
138+
let (client, connection) = tokio_postgres::connect(path.as_str(), connector).await.migration_err("could not connect to database", None)?;
139+
tokio::spawn(async move {
140+
if let Err(e) = connection.await {
141+
eprintln!("connection error: {}", e);
142+
}
143+
});
144+
$op(client).await
145+
} else {
146+
let (client, connection) = tokio_postgres::connect(path.as_str(), tokio_postgres::NoTls).await.migration_err("could not connect to database", None)?;
147+
tokio::spawn(async move {
148+
if let Err(e) = connection.await {
149+
eprintln!("connection error: {}", e);
150+
}
151+
});
152+
$op(client).await
153+
}
133154
} else {
134155
panic!("tried to migrate async from config for a postgresql database, but tokio-postgres was not enabled!");
135156
}

0 commit comments

Comments
 (0)