diff --git a/refinery_cli/src/migrate.rs b/refinery_cli/src/migrate.rs index 77deae02..cd9557a4 100644 --- a/refinery_cli/src/migrate.rs +++ b/refinery_cli/src/migrate.rs @@ -3,7 +3,7 @@ use std::path::Path; use anyhow::Context; use refinery_core::{ config::{Config, ConfigDbType}, - find_migration_files, Migration, MigrationType, Runner, SchemaVersion, Target, + find_migration_files, Migration, MigrationType, SchemaVersion, Target, }; use crate::cli::MigrateArgs; @@ -19,8 +19,7 @@ pub fn handle_migration_command(args: MigrateArgs) -> anyhow::Result<()> { args.env_var.as_deref(), &args.path, &args.table_name, - )?; - Ok(()) + ) } #[allow(clippy::too_many_arguments)] @@ -73,7 +72,7 @@ fn run_migrations( .context("Can't start tokio runtime")?; runtime.block_on(async { - Runner::new(&migrations) + refinery_core::Runner::new(&migrations) .set_grouped(grouped) .set_target(target) .set_abort_divergent(divergent) @@ -90,7 +89,7 @@ fn run_migrations( _db_type @ (ConfigDbType::Mysql | ConfigDbType::Postgres | ConfigDbType::Sqlite) => { cfg_if::cfg_if! { if #[cfg(any(feature = "mysql", feature = "postgresql", feature = "sqlite"))] { - Runner::new(&migrations) + refinery_core::Runner::new(&migrations) .set_grouped(grouped) .set_abort_divergent(divergent) .set_abort_missing(missing) @@ -103,7 +102,6 @@ fn run_migrations( } } }; - Ok(()) } diff --git a/refinery_cli/src/setup.rs b/refinery_cli/src/setup.rs index e6e06f90..6429df71 100644 --- a/refinery_cli/src/setup.rs +++ b/refinery_cli/src/setup.rs @@ -59,35 +59,45 @@ fn get_config_from_input() -> Result { } } - print!("Enter database host: "); - io::stdout().flush()?; - let mut db_host = String::new(); - io::stdin().read_line(&mut db_host)?; - config = config.set_db_host(db_host.trim()); + cfg_if::cfg_if! { + if #[cfg(any( + feature = "mysql", + feature = "mssql", + feature = "postgresql", + ))]{ + print!("Enter database host: "); + io::stdout().flush()?; + let mut db_host = String::new(); + io::stdin().read_line(&mut db_host)?; + config = config.set_db_host(db_host.trim()); - print!("Enter database port: "); - io::stdout().flush()?; - let mut db_port = String::new(); - io::stdin().read_line(&mut db_port)?; - config = config.set_db_port(db_port.trim()); + print!("Enter database port: "); + io::stdout().flush()?; + let mut db_port = String::new(); + io::stdin().read_line(&mut db_port)?; + config = config.set_db_port(db_port.trim()); - print!("Enter database username: "); - io::stdout().flush()?; - let mut db_user = String::new(); - io::stdin().read_line(&mut db_user)?; - config = config.set_db_user(db_user.trim()); + print!("Enter database username: "); + io::stdout().flush()?; + let mut db_user = String::new(); + io::stdin().read_line(&mut db_user)?; + config = config.set_db_user(db_user.trim()); - print!("Enter database password: "); - io::stdout().flush()?; - let mut db_pass = String::new(); - io::stdin().read_line(&mut db_pass)?; - config = config.set_db_pass(db_pass.trim()); + print!("Enter database password: "); + io::stdout().flush()?; + let mut db_pass = String::new(); + io::stdin().read_line(&mut db_pass)?; + config = config.set_db_pass(db_pass.trim()); - print!("Enter database name: "); - io::stdout().flush()?; - let mut db_name = String::new(); - io::stdin().read_line(&mut db_name)?; - config = config.set_db_name(db_name.trim()); + print!("Enter database name: "); + io::stdout().flush()?; + let mut db_name = String::new(); + io::stdin().read_line(&mut db_name)?; + config = config.set_db_name(db_name.trim()); - Ok(config) + Ok(config) + } else { + Ok(config) + } + } } diff --git a/refinery_core/src/config.rs b/refinery_core/src/config.rs index bbb3b1da..ca1b0101 100644 --- a/refinery_core/src/config.rs +++ b/refinery_core/src/config.rs @@ -1,14 +1,13 @@ use crate::error::Kind; use crate::Error; -use std::convert::TryFrom; -use std::path::PathBuf; -use std::str::FromStr; #[cfg(any( feature = "postgres", feature = "tokio-postgres", feature = "tiberius-config" ))] -use std::{borrow::Cow, collections::HashMap}; +use std::borrow::Cow; +use std::convert::TryFrom; +use std::str::FromStr; use url::Url; // refinery config file used by migrate_from_config if migration from a Config struct is preferred instead of using the macros @@ -32,19 +31,7 @@ impl Config { /// create a new config instance pub fn new(db_type: ConfigDbType) -> Config { Config { - main: Main { - db_type, - db_path: None, - db_host: None, - db_port: None, - db_user: None, - db_pass: None, - db_name: None, - #[cfg(any(feature = "postgres", feature = "tokio-postgres"))] - use_tls: false, - #[cfg(feature = "tiberius-config")] - trust_cert: false, - }, + main: Main::new(db_type), } } @@ -59,6 +46,10 @@ impl Config { Config::from_str(&value) } + pub fn db_type(&self) -> ConfigDbType { + self.main.db_type + } + /// create a new Config instance from a config file located on the file system #[cfg(feature = "toml")] pub fn from_file_location>(location: T) -> Result { @@ -69,7 +60,7 @@ impl Config { ) })?; - let mut config: Config = toml::from_str(&file).map_err(|err| { + let config: Config = toml::from_str(&file).map_err(|err| { Error::new( Kind::ConfigError(format!("could not parse config file, {err}")), None, @@ -77,7 +68,9 @@ impl Config { })?; //replace relative path with canonical path in case of Sqlite db + #[cfg(feature = "rusqlite")] if config.main.db_type == ConfigDbType::Sqlite { + let mut config = config; let mut config_db_path = config.main.db_path.ok_or_else(|| { Error::new( Kind::ConfigError("field path must be present for Sqlite database type".into()), @@ -103,42 +96,28 @@ impl Config { None, ) })?; - config.main.db_path = Some(config_db_path); - } - - Ok(config) - } - cfg_if::cfg_if! { - if #[cfg(feature = "rusqlite")] { - pub(crate) fn db_path(&self) -> Option<&std::path::Path> { - self.main.db_path.as_deref() - } - - pub fn set_db_path(self, db_path: &str) -> Config { - Config { - main: Main { - db_path: Some(db_path.into()), - ..self.main - }, - } - } + return Ok(config); } - } - cfg_if::cfg_if! { - if #[cfg(feature = "tiberius-config")] { - pub fn set_trust_cert(&mut self) { - self.main.trust_cert = true; - } - } + Ok(config) } - pub fn db_type(&self) -> ConfigDbType { - self.main.db_type + #[cfg(feature = "tiberius-config")] + pub fn set_trust_cert(&mut self) { + self.main.trust_cert = true; } +} +#[cfg(any( + feature = "mysql", + feature = "postgres", + feature = "tokio-postgres", + feature = "mysql_async", + feature = "tiberius-config" +))] +impl Config { pub fn db_host(&self) -> Option<&str> { self.main.db_host.as_deref() } @@ -147,11 +126,6 @@ impl Config { self.main.db_port.as_deref() } - #[cfg(any(feature = "postgres", feature = "tokio-postgres"))] - pub fn use_tls(&self) -> bool { - self.main.use_tls - } - pub fn set_db_user(self, db_user: &str) -> Config { Config { main: Main { @@ -196,8 +170,30 @@ impl Config { }, } } +} + +#[cfg(feature = "rusqlite")] +impl Config { + pub(crate) fn db_path(&self) -> Option<&std::path::Path> { + self.main.db_path.as_deref() + } + + pub fn set_db_path(self, db_path: &str) -> Config { + Config { + main: Main { + db_path: Some(db_path.into()), + ..self.main + }, + } + } +} + +#[cfg(any(feature = "postgres", feature = "tokio-postgres"))] +impl Config { + pub fn use_tls(&self) -> bool { + self.main.use_tls + } - #[cfg(any(feature = "postgres", feature = "tokio-postgres"))] pub fn set_use_tls(self, use_tls: bool) -> Config { Config { main: Main { @@ -226,45 +222,10 @@ impl TryFrom for Config { } }; - #[cfg(any( - feature = "postgres", - feature = "tokio-postgres", - feature = "tiberius-config" - ))] - let query_params = url - .query_pairs() - .collect::, Cow<'_, str>>>(); - - cfg_if::cfg_if! { - if #[cfg(feature = "tiberius-config")] { - let trust_cert = query_params. - get("trust_cert") - .unwrap_or(&Cow::Borrowed("false")) - .parse::() - .map_err(|_| { - Error::new( - Kind::ConfigError("Invalid trust_cert value, please use true/false".into()), - None, - ) - })?; - } - } - - #[cfg(any(feature = "postgres", feature = "tokio-postgres"))] - let use_tls = match query_params.get("sslmode") { - Some(Cow::Borrowed("require")) => true, - Some(Cow::Borrowed("disable")) | None => false, - _ => { - return Err(Error::new( - Kind::ConfigError("Invalid sslmode value, please use disable/require".into()), - None, - )) - } - }; - Ok(Self { main: Main { db_type, + #[cfg(feature = "rusqlite")] db_path: Some( url.as_str()[url.scheme().len()..] .trim_start_matches(':') @@ -272,15 +233,78 @@ impl TryFrom for Config { .to_string() .into(), ), + #[cfg(any( + feature = "mysql", + feature = "postgres", + feature = "tokio-postgres", + feature = "mysql_async", + feature = "tiberius-config" + ))] db_host: url.host_str().map(|r| r.to_string()), + #[cfg(any( + feature = "mysql", + feature = "postgres", + feature = "tokio-postgres", + feature = "mysql_async", + feature = "tiberius-config" + ))] db_port: url.port().map(|r| r.to_string()), + #[cfg(any( + feature = "mysql", + feature = "postgres", + feature = "tokio-postgres", + feature = "mysql_async", + feature = "tiberius-config" + ))] db_user: Some(url.username().to_string()), + #[cfg(any( + feature = "mysql", + feature = "postgres", + feature = "tokio-postgres", + feature = "mysql_async", + feature = "tiberius-config" + ))] db_pass: url.password().map(|r| r.to_string()), + #[cfg(any( + feature = "mysql", + feature = "postgres", + feature = "tokio-postgres", + feature = "mysql_async", + feature = "tiberius-config" + ))] db_name: Some(url.path().trim_start_matches('/').to_string()), #[cfg(any(feature = "postgres", feature = "tokio-postgres"))] - use_tls, + use_tls: match url + .query_pairs() + .collect::, Cow<'_, str>>>() + .get("sslmode") + { + Some(Cow::Borrowed("require")) => true, + Some(Cow::Borrowed("disable")) | None => false, + _ => { + return Err(Error::new( + Kind::ConfigError( + "Invalid sslmode value, please use disable/require".into(), + ), + None, + )) + } + }, #[cfg(feature = "tiberius-config")] - trust_cert, + trust_cert: url + .query_pairs() + .collect::, Cow<'_, str>>>() + .get("trust_cert") + .unwrap_or(&Cow::Borrowed("false")) + .parse::() + .map_err(|_| { + Error::new( + Kind::ConfigError( + "Invalid trust_cert value, please use true/false".into(), + ), + None, + ) + })?, }, }) } @@ -305,11 +329,47 @@ impl FromStr for Config { #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] struct Main { db_type: ConfigDbType, - db_path: Option, + #[cfg(feature = "rusqlite")] + db_path: Option, + #[cfg(any( + feature = "mysql", + feature = "postgres", + feature = "tokio-postgres", + feature = "mysql_async", + feature = "tiberius-config" + ))] db_host: Option, + #[cfg(any( + feature = "mysql", + feature = "postgres", + feature = "tokio-postgres", + feature = "mysql_async", + feature = "tiberius-config" + ))] db_port: Option, + #[cfg(any( + feature = "mysql", + feature = "postgres", + feature = "tokio-postgres", + feature = "mysql_async", + feature = "tiberius-config" + ))] db_user: Option, + #[cfg(any( + feature = "mysql", + feature = "postgres", + feature = "tokio-postgres", + feature = "mysql_async", + feature = "tiberius-config" + ))] db_pass: Option, + #[cfg(any( + feature = "mysql", + feature = "postgres", + feature = "tokio-postgres", + feature = "mysql_async", + feature = "tiberius-config" + ))] db_name: Option, #[cfg(any(feature = "postgres", feature = "tokio-postgres"))] #[cfg_attr(feature = "serde", serde(default))] @@ -319,11 +379,65 @@ struct Main { trust_cert: bool, } +impl Main { + fn new(db_type: ConfigDbType) -> Self { + Main { + db_type, + #[cfg(feature = "rusqlite")] + db_path: None, + #[cfg(any( + feature = "mysql", + feature = "postgres", + feature = "tokio-postgres", + feature = "mysql_async", + feature = "tiberius-config" + ))] + db_host: None, + #[cfg(any( + feature = "mysql", + feature = "postgres", + feature = "tokio-postgres", + feature = "mysql_async", + feature = "tiberius-config" + ))] + db_port: None, + #[cfg(any( + feature = "mysql", + feature = "postgres", + feature = "tokio-postgres", + feature = "mysql_async", + feature = "tiberius-config" + ))] + db_user: None, + #[cfg(any( + feature = "mysql", + feature = "postgres", + feature = "tokio-postgres", + feature = "mysql_async", + feature = "tiberius-config" + ))] + db_pass: None, + #[cfg(any( + feature = "mysql", + feature = "postgres", + feature = "tokio-postgres", + feature = "mysql_async", + feature = "tiberius-config" + ))] + db_name: None, + #[cfg(any(feature = "postgres", feature = "tokio-postgres"))] + use_tls: false, + #[cfg(feature = "tiberius-config")] + trust_cert: false, + } + } +} + #[cfg(any( feature = "mysql", feature = "postgres", feature = "tokio-postgres", - feature = "mysql_async" + feature = "mysql_async", ))] pub(crate) fn build_db_url(name: &str, config: &Config) -> String { let mut url: String = name.to_string() + "://"; @@ -350,52 +464,58 @@ pub(crate) fn build_db_url(name: &str, config: &Config) -> String { url } -cfg_if::cfg_if! { - if #[cfg(feature = "tiberius-config")] { - use tiberius::{AuthMethod, Config as TConfig}; - - impl TryFrom<&Config> for TConfig { - type Error=Error; - - fn try_from(config: &Config) -> Result { - let mut tconfig = TConfig::new(); - if let Some(host) = &config.main.db_host { - tconfig.host(host); - } +#[cfg(feature = "tiberius-config")] +impl TryFrom<&Config> for tiberius::Config { + type Error = Error; - if let Some(port) = &config.main.db_port { - let port = port.parse().map_err(|_| Error::new( - Kind::ConfigError(format!("Couldn't parse value {port} as mssql port")), - None, - ))?; - tconfig.port(port); - } + fn try_from(config: &Config) -> Result { + let mut tconfig = tiberius::Config::new(); + if let Some(host) = &config.main.db_host { + tconfig.host(host); + } - if let Some(db) = &config.main.db_name { - tconfig.database(db); - } + if let Some(port) = &config.main.db_port { + let port = port.parse().map_err(|_| { + Error::new( + Kind::ConfigError(format!("Couldn't parse value {port} as mssql port")), + None, + ) + })?; + tconfig.port(port); + } - let user = config.main.db_user.as_deref().unwrap_or(""); - let pass = config.main.db_pass.as_deref().unwrap_or(""); + if let Some(db) = &config.main.db_name { + tconfig.database(db); + } - if config.main.trust_cert { - tconfig.trust_cert(); - } - tconfig.authentication(AuthMethod::sql_server(user, pass)); + let user = config.main.db_user.as_deref().unwrap_or(""); + let pass = config.main.db_pass.as_deref().unwrap_or(""); - Ok(tconfig) - } + if config.main.trust_cert { + tconfig.trust_cert(); } + tconfig.authentication(tiberius::AuthMethod::sql_server(user, pass)); + + Ok(tconfig) } } #[cfg(test)] mod tests { - use super::{build_db_url, Config, Kind}; + use super::{Config, Kind}; use std::io::Write; use std::str::FromStr; + #[cfg(any( + feature = "mysql", + feature = "postgres", + feature = "tokio-postgres", + feature = "mysql_async" + ))] + use super::build_db_url; + #[test] + #[cfg(feature = "toml")] fn returns_config_error_from_invalid_config_location() { let config = Config::from_file_location("invalid_path").unwrap_err(); match config.kind() { @@ -405,6 +525,7 @@ mod tests { } #[test] + #[cfg(feature = "toml")] fn returns_config_error_from_invalid_toml_file() { let config = "[<$% db_type = \"Sqlite\" \n"; @@ -419,6 +540,7 @@ mod tests { } #[test] + #[cfg(all(feature = "toml", feature = "rusqlite"))] fn returns_config_error_from_sqlite_with_missing_path() { let config = "[main] \n db_type = \"Sqlite\" \n"; @@ -435,6 +557,7 @@ mod tests { } #[test] + #[cfg(all(feature = "toml", feature = "rusqlite"))] fn builds_sqlite_path_from_relative_path() { let db_file = tempfile::NamedTempFile::new_in(".").unwrap(); @@ -458,6 +581,15 @@ mod tests { } #[test] + #[cfg(all( + feature = "toml", + any( + feature = "mysql", + feature = "postgres", + feature = "tokio-postgres", + feature = "mysql_async" + ) + ))] fn builds_db_url() { let config = "[main] \n db_type = \"Postgres\" \n @@ -476,12 +608,18 @@ mod tests { } #[test] + #[cfg(any( + feature = "mysql", + feature = "postgres", + feature = "tokio-postgres", + feature = "mysql_async" + ))] fn builds_db_env_var() { std::env::set_var( - "DATABASE_URL", + "TEST_DATABASE_URL", "postgres://root:1234@localhost:5432/refinery", ); - let config = Config::from_env_var("DATABASE_URL").unwrap(); + let config = Config::from_env_var("TEST_DATABASE_URL").unwrap(); assert_eq!( "postgres://root:1234@localhost:5432/refinery", build_db_url("postgres", &config) @@ -489,6 +627,12 @@ mod tests { } #[test] + #[cfg(any( + feature = "mysql", + feature = "postgres", + feature = "tokio-postgres", + feature = "mysql_async" + ))] fn builds_from_str() { let config = Config::from_str("postgres://root:1234@localhost:5432/refinery").unwrap(); assert_eq!( @@ -538,8 +682,8 @@ mod tests { #[test] fn builds_db_env_var_failure() { - std::env::set_var("DATABASE_URL", "this_is_not_a_url"); - let config = Config::from_env_var("DATABASE_URL"); + std::env::set_var("TEST_DATABASE_URL_INVALID", "this_is_not_a_url"); + let config = Config::from_env_var("TEST_DATABASE_URL_INVALID"); assert!(config.is_err()); } } diff --git a/refinery_core/src/drivers/config.rs b/refinery_core/src/drivers/config.rs index 9a9067ba..3f1cfd74 100644 --- a/refinery_core/src/drivers/config.rs +++ b/refinery_core/src/drivers/config.rs @@ -1,16 +1,21 @@ +use crate::config::Config; +use crate::traits::r#async::{AsyncQuery, AsyncTransaction}; +use crate::traits::sync::{Query, Transaction}; +use crate::Migration; #[cfg(any( feature = "mysql", feature = "postgres", + feature = "rusqlite", feature = "tokio-postgres", - feature = "mysql_async" + feature = "mysql_async", + feature = "tiberius-config" ))] -use crate::config::build_db_url; -use crate::config::{Config, ConfigDbType}; -use crate::error::WrapMigrationError; -use crate::traits::r#async::{AsyncQuery, AsyncTransaction}; -use crate::traits::sync::{Query, Transaction}; -use crate::traits::{GET_APPLIED_MIGRATIONS_QUERY, GET_LAST_APPLIED_MIGRATION_QUERY}; -use crate::{Error, Migration, Report, Target}; +use crate::{ + config::ConfigDbType, + error::WrapMigrationError, + traits::{GET_APPLIED_MIGRATIONS_QUERY, GET_LAST_APPLIED_MIGRATION_QUERY}, + Error, Report, Target, +}; use async_trait::async_trait; use std::convert::Infallible; @@ -63,7 +68,7 @@ macro_rules! with_connection { ConfigDbType::Mysql => { cfg_if::cfg_if! { if #[cfg(feature = "mysql")] { - let url = build_db_url("mysql", &$config); + let url = crate::config::build_db_url("mysql", &$config); let opts = mysql::Opts::from_url(&url).migration_err("could not parse url", None)?; let conn = mysql::Conn::new(opts).migration_err("could not connect to database", None)?; $op(conn) @@ -87,7 +92,7 @@ macro_rules! with_connection { ConfigDbType::Postgres => { cfg_if::cfg_if! { if #[cfg(feature = "postgres")] { - let path = build_db_url("postgresql", &$config); + let path = crate::config::build_db_url("postgresql", &$config); let conn; if $config.use_tls() { @@ -123,7 +128,7 @@ macro_rules! with_connection_async { ConfigDbType::Mysql => { cfg_if::cfg_if! { if #[cfg(feature = "mysql_async")] { - let url = build_db_url("mysql", $config); + let url = crate::config::build_db_url("mysql", $config); let pool = mysql_async::Pool::from_url(&url).migration_err("could not connect to the database", None)?; $op(pool).await } else { @@ -137,7 +142,7 @@ macro_rules! with_connection_async { ConfigDbType::Postgres => { cfg_if::cfg_if! { if #[cfg(feature = "tokio-postgres")] { - let path = build_db_url("postgresql", $config); + let path = crate::config::build_db_url("postgresql", $config); if $config.use_tls() { let connector = native_tls::TlsConnector::new().unwrap(); let connector = postgres_native_tls::MakeTlsConnector::new(connector);