From 749eb2680924b4a6cf290b87dfdd36c89332cda3 Mon Sep 17 00:00:00 2001 From: Tobias Bieniek Date: Tue, 19 Nov 2024 18:04:42 +0100 Subject: [PATCH] test_db: Extract `async_connect()` fn --- Cargo.lock | 1 + crates/crates_io_test_db/Cargo.toml | 1 + crates/crates_io_test_db/src/lib.rs | 8 ++++++ src/controllers/user/session.rs | 3 +-- src/index.rs | 3 +-- src/models/category.rs | 13 +++++---- src/rate_limiter.rs | 27 +++++++++---------- src/tests/categories.rs | 10 +++---- src/tests/util/test_app.rs | 5 ++-- src/typosquat/database.rs | 3 +-- src/worker/jobs/archive_version_downloads.rs | 5 ++-- .../downloads/clean_processed_log_files.rs | 4 +-- src/worker/jobs/expiry_notification.rs | 3 +-- src/worker/jobs/rss/sync_crate_feed.rs | 3 +-- src/worker/jobs/rss/sync_crates_feed.rs | 4 +-- src/worker/jobs/rss/sync_updates_feed.rs | 3 +-- src/worker/jobs/typosquat.rs | 3 +-- 17 files changed, 49 insertions(+), 50 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c19b6002e0e..5c97b471045 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1288,6 +1288,7 @@ version = "0.0.0" dependencies = [ "crates_io_env_vars", "diesel", + "diesel-async", "diesel_migrations", "rand", "tracing", diff --git a/crates/crates_io_test_db/Cargo.toml b/crates/crates_io_test_db/Cargo.toml index 8b67c768767..40e75f0a003 100644 --- a/crates/crates_io_test_db/Cargo.toml +++ b/crates/crates_io_test_db/Cargo.toml @@ -10,6 +10,7 @@ workspace = true [dependencies] crates_io_env_vars = { path = "../crates_io_env_vars" } diesel = { version = "=2.2.4", features = ["postgres", "r2d2"] } +diesel-async = { version = "=0.5.1", features = ["postgres"] } diesel_migrations = { version = "=2.2.0", features = ["postgres"] } rand = "=0.8.5" tracing = "=0.1.40" diff --git a/crates/crates_io_test_db/src/lib.rs b/crates/crates_io_test_db/src/lib.rs index 9f493b912fa..eb513307fec 100644 --- a/crates/crates_io_test_db/src/lib.rs +++ b/crates/crates_io_test_db/src/lib.rs @@ -2,6 +2,7 @@ use crates_io_env_vars::required_var_parsed; use diesel::prelude::*; use diesel::r2d2::{ConnectionManager, Pool, PooledConnection}; use diesel::sql_query; +use diesel_async::{AsyncConnection, AsyncPgConnection}; use diesel_migrations::{FileBasedMigrations, MigrationHarness}; use rand::Rng; use std::sync::LazyLock; @@ -128,6 +129,13 @@ impl TestDatabase { .get() .expect("Failed to get database connection") } + + #[instrument(skip(self))] + pub async fn async_connect(&self) -> AsyncPgConnection { + AsyncPgConnection::establish(self.url()) + .await + .expect("Failed to connect to database") + } } impl Drop for TestDatabase { diff --git a/src/controllers/user/session.rs b/src/controllers/user/session.rs index 4c694ec4b86..91a369e111f 100644 --- a/src/controllers/user/session.rs +++ b/src/controllers/user/session.rs @@ -168,14 +168,13 @@ pub async fn logout(session: SessionExtension) -> Json { mod tests { use super::*; use crates_io_test_db::TestDatabase; - use diesel_async::AsyncConnection; #[tokio::test] async fn gh_user_with_invalid_email_doesnt_fail() { let emails = Emails::new_in_memory(); let test_db = TestDatabase::new(); - let mut conn = AsyncPgConnection::establish(test_db.url()).await.unwrap(); + let mut conn = test_db.async_connect().await; let gh_user = GithubUser { email: Some("String.Format(\"{0}.{1}@live.com\", FirstName, LastName)".into()), diff --git a/src/index.rs b/src/index.rs index 07dc8c84ae3..3e6f8eaa844 100644 --- a/src/index.rs +++ b/src/index.rs @@ -140,14 +140,13 @@ mod tests { use crate::tests::builders::{CrateBuilder, VersionBuilder}; use chrono::{Days, Utc}; use crates_io_test_db::TestDatabase; - use diesel_async::AsyncConnection; use insta::assert_json_snapshot; #[tokio::test] async fn test_index_metadata() { let test_db = TestDatabase::new(); let mut conn = test_db.connect(); - let mut async_conn = AsyncPgConnection::establish(test_db.url()).await.unwrap(); + let mut async_conn = test_db.async_connect().await; let user_id = diesel::insert_into(users::table) .values(( diff --git a/src/models/category.rs b/src/models/category.rs index a23ddfd4c2c..320d696b296 100644 --- a/src/models/category.rs +++ b/src/models/category.rs @@ -155,7 +155,6 @@ pub struct NewCategory<'a> { mod tests { use super::*; use crates_io_test_db::TestDatabase; - use diesel_async::AsyncConnection; use diesel_async::RunQueryDsl; #[tokio::test] @@ -163,7 +162,7 @@ mod tests { use self::categories; let test_db = TestDatabase::new(); - let mut conn = AsyncPgConnection::establish(test_db.url()).await.unwrap(); + let mut conn = test_db.async_connect().await; insert_into(categories::table) .values(&vec![ @@ -207,7 +206,7 @@ mod tests { }; let test_db = TestDatabase::new(); - let mut conn = AsyncPgConnection::establish(test_db.url()).await.unwrap(); + let mut conn = test_db.async_connect().await; insert_into(categories::table) .values(&vec![ @@ -238,7 +237,7 @@ mod tests { use self::categories; let test_db = TestDatabase::new(); - let mut conn = AsyncPgConnection::establish(test_db.url()).await.unwrap(); + let mut conn = test_db.async_connect().await; insert_into(categories::table) .values(&vec![ @@ -287,7 +286,7 @@ mod tests { }; let test_db = TestDatabase::new(); - let mut conn = AsyncPgConnection::establish(test_db.url()).await.unwrap(); + let mut conn = test_db.async_connect().await; insert_into(categories::table) .values(&vec![ @@ -329,7 +328,7 @@ mod tests { }; let test_db = TestDatabase::new(); - let mut conn = AsyncPgConnection::establish(test_db.url()).await.unwrap(); + let mut conn = test_db.async_connect().await; insert_into(categories::table) .values(&vec![ @@ -376,7 +375,7 @@ mod tests { }; let test_db = TestDatabase::new(); - let mut conn = AsyncPgConnection::establish(test_db.url()).await.unwrap(); + let mut conn = test_db.async_connect().await; insert_into(categories::table) .values(&vec![ diff --git a/src/rate_limiter.rs b/src/rate_limiter.rs index 1d0ded07b23..01cbc1c6748 100644 --- a/src/rate_limiter.rs +++ b/src/rate_limiter.rs @@ -187,12 +187,11 @@ mod tests { use super::*; use crate::schema::users; use crates_io_test_db::TestDatabase; - use diesel_async::AsyncConnection; #[tokio::test] async fn default_rate_limits() -> anyhow::Result<()> { let test_db = TestDatabase::new(); - let mut conn = AsyncPgConnection::establish(test_db.url()).await?; + let mut conn = test_db.async_connect().await; let now = now(); // Set the defaults as if no env vars have been set in production @@ -267,7 +266,7 @@ mod tests { #[tokio::test] async fn take_token_with_no_bucket_creates_new_one() -> anyhow::Result<()> { let test_db = TestDatabase::new(); - let mut conn = AsyncPgConnection::establish(test_db.url()).await?; + let mut conn = test_db.async_connect().await; let now = now(); let rate = SampleRateLimiter { @@ -319,7 +318,7 @@ mod tests { #[tokio::test] async fn take_token_with_existing_bucket_modifies_existing_bucket() -> anyhow::Result<()> { let test_db = TestDatabase::new(); - let mut conn = AsyncPgConnection::establish(test_db.url()).await?; + let mut conn = test_db.async_connect().await; let now = now(); let rate = SampleRateLimiter { @@ -345,7 +344,7 @@ mod tests { #[tokio::test] async fn take_token_after_delay_refills() -> anyhow::Result<()> { let test_db = TestDatabase::new(); - let mut conn = AsyncPgConnection::establish(test_db.url()).await?; + let mut conn = test_db.async_connect().await; let now = now(); let rate = SampleRateLimiter { @@ -372,7 +371,7 @@ mod tests { #[tokio::test] async fn refill_subsecond_rate() -> anyhow::Result<()> { let test_db = TestDatabase::new(); - let mut conn = AsyncPgConnection::establish(test_db.url()).await?; + let mut conn = test_db.async_connect().await; // Subsecond rates have floating point rounding issues, so use a known // timestamp that rounds fine let now = @@ -403,7 +402,7 @@ mod tests { #[tokio::test] async fn last_refill_always_advanced_by_multiple_of_rate() -> anyhow::Result<()> { let test_db = TestDatabase::new(); - let mut conn = AsyncPgConnection::establish(test_db.url()).await?; + let mut conn = test_db.async_connect().await; let now = now(); let rate = SampleRateLimiter { @@ -435,7 +434,7 @@ mod tests { #[tokio::test] async fn zero_tokens_returned_when_user_has_no_tokens_left() -> anyhow::Result<()> { let test_db = TestDatabase::new(); - let mut conn = AsyncPgConnection::establish(test_db.url()).await?; + let mut conn = test_db.async_connect().await; let now = now(); let rate = SampleRateLimiter { @@ -466,7 +465,7 @@ mod tests { #[tokio::test] async fn a_user_with_no_tokens_gets_a_token_after_exactly_rate() -> anyhow::Result<()> { let test_db = TestDatabase::new(); - let mut conn = AsyncPgConnection::establish(test_db.url()).await?; + let mut conn = test_db.async_connect().await; let now = now(); let rate = SampleRateLimiter { @@ -494,7 +493,7 @@ mod tests { #[tokio::test] async fn tokens_never_refill_past_burst() -> anyhow::Result<()> { let test_db = TestDatabase::new(); - let mut conn = AsyncPgConnection::establish(test_db.url()).await?; + let mut conn = test_db.async_connect().await; let now = now(); let rate = SampleRateLimiter { @@ -522,7 +521,7 @@ mod tests { #[tokio::test] async fn two_actions_dont_interfere_with_each_other() -> anyhow::Result<()> { let test_db = TestDatabase::new(); - let mut conn = AsyncPgConnection::establish(test_db.url()).await?; + let mut conn = test_db.async_connect().await; let now = now(); let mut config = HashMap::new(); @@ -571,7 +570,7 @@ mod tests { use diesel_async::RunQueryDsl; let test_db = TestDatabase::new(); - let mut conn = AsyncPgConnection::establish(test_db.url()).await?; + let mut conn = test_db.async_connect().await; let now = now(); let rate = SampleRateLimiter { @@ -609,7 +608,7 @@ mod tests { use diesel_async::RunQueryDsl; let test_db = TestDatabase::new(); - let mut conn = AsyncPgConnection::establish(test_db.url()).await?; + let mut conn = test_db.async_connect().await; let now = now(); let rate = SampleRateLimiter { @@ -669,7 +668,7 @@ mod tests { use diesel_async::RunQueryDsl; let test_db = TestDatabase::new(); - let mut conn = AsyncPgConnection::establish(test_db.url()).await?; + let mut conn = test_db.async_connect().await; let now = now(); let user_id = new_user(&mut conn, "user").await?; diff --git a/src/tests/categories.rs b/src/tests/categories.rs index b03befccdfa..79976676a1f 100644 --- a/src/tests/categories.rs +++ b/src/tests/categories.rs @@ -1,7 +1,7 @@ use crate::schema::categories; use crates_io_test_db::TestDatabase; use diesel::*; -use diesel_async::{AsyncConnection, AsyncPgConnection, RunQueryDsl}; +use diesel_async::{AsyncPgConnection, RunQueryDsl}; const ALGORITHMS: &str = r#" [algorithms] @@ -50,7 +50,7 @@ async fn select_slugs(conn: &mut AsyncPgConnection) -> Vec { #[tokio::test] async fn sync_adds_new_categories() { let test_db = TestDatabase::new(); - let mut conn = AsyncPgConnection::establish(test_db.url()).await.unwrap(); + let mut conn = test_db.async_connect().await; crate::boot::categories::sync_with_connection(ALGORITHMS_AND_SUCH, &mut conn) .await @@ -63,7 +63,7 @@ async fn sync_adds_new_categories() { #[tokio::test] async fn sync_removes_missing_categories() { let test_db = TestDatabase::new(); - let mut conn = AsyncPgConnection::establish(test_db.url()).await.unwrap(); + let mut conn = test_db.async_connect().await; crate::boot::categories::sync_with_connection(ALGORITHMS_AND_SUCH, &mut conn) .await @@ -79,7 +79,7 @@ async fn sync_removes_missing_categories() { #[tokio::test] async fn sync_adds_and_removes() { let test_db = TestDatabase::new(); - let mut conn = AsyncPgConnection::establish(test_db.url()).await.unwrap(); + let mut conn = test_db.async_connect().await; crate::boot::categories::sync_with_connection(ALGORITHMS_AND_SUCH, &mut conn) .await @@ -95,7 +95,7 @@ async fn sync_adds_and_removes() { #[tokio::test] async fn test_real_categories() { let test_db = TestDatabase::new(); - let mut conn = AsyncPgConnection::establish(test_db.url()).await.unwrap(); + let mut conn = test_db.async_connect().await; const TOML: &str = include_str!("../boot/categories.toml"); assert_ok!(crate::boot::categories::sync_with_connection(TOML, &mut conn).await); diff --git a/src/tests/util/test_app.rs b/src/tests/util/test_app.rs index 645b2d8d069..ada6c1aba56 100644 --- a/src/tests/util/test_app.rs +++ b/src/tests/util/test_app.rs @@ -20,7 +20,7 @@ use crates_io_test_db::TestDatabase; use crates_io_worker::Runner; use diesel::r2d2::{ConnectionManager, PooledConnection}; use diesel::PgConnection; -use diesel_async::{AsyncConnection, AsyncPgConnection}; +use diesel_async::AsyncPgConnection; use futures_util::TryStreamExt; use oauth2::{ClientId, ClientSecret}; use regex::Regex; @@ -119,8 +119,7 @@ impl TestApp { /// Obtain an async database connection from the primary database pool. pub async fn async_db_conn(&self) -> AsyncPgConnection { - let result = AsyncPgConnection::establish(self.0.test_database.url()).await; - result.expect("Failed to get database connection") + self.0.test_database.async_connect().await } /// Create a new user with a verified email address in the database diff --git a/src/typosquat/database.rs b/src/typosquat/database.rs index f3bd71e5822..1f3e38215f1 100644 --- a/src/typosquat/database.rs +++ b/src/typosquat/database.rs @@ -176,7 +176,6 @@ mod tests { use super::*; use crate::typosquat::test_util::faker; use crates_io_test_db::TestDatabase; - use diesel_async::AsyncConnection; use thiserror::Error; #[tokio::test] @@ -198,7 +197,7 @@ mod tests { faker::add_crate_to_team(&mut conn, &user_b, &top_b, ¬_the_a_team)?; faker::add_crate_to_team(&mut conn, &user_b, ¬_top_c, ¬_the_a_team)?; - let mut async_conn = AsyncPgConnection::establish(test_db.url()).await?; + let mut async_conn = test_db.async_connect().await; let top_crates = TopCrates::new(&mut async_conn, 2).await?; // Let's ensure the top crates include what we expect (which is a and b, since we asked for diff --git a/src/worker/jobs/archive_version_downloads.rs b/src/worker/jobs/archive_version_downloads.rs index e44ef601bc8..44e57ab55bb 100644 --- a/src/worker/jobs/archive_version_downloads.rs +++ b/src/worker/jobs/archive_version_downloads.rs @@ -249,13 +249,12 @@ mod tests { use super::*; use crate::schema::{crates, version_downloads, versions}; use crates_io_test_db::TestDatabase; - use diesel_async::AsyncConnection; use insta::assert_snapshot; #[tokio::test] async fn test_export() { let test_db = TestDatabase::new(); - let mut conn = AsyncPgConnection::establish(test_db.url()).await.unwrap(); + let mut conn = test_db.async_connect().await; prepare_database(&mut conn).await; let tempdir = tempdir().unwrap(); @@ -357,7 +356,7 @@ mod tests { #[tokio::test] async fn test_delete() { let test_db = TestDatabase::new(); - let mut conn = AsyncPgConnection::establish(test_db.url()).await.unwrap(); + let mut conn = test_db.async_connect().await; prepare_database(&mut conn).await; let dates = vec![NaiveDate::from_ymd_opt(2021, 1, 1).unwrap()]; diff --git a/src/worker/jobs/downloads/clean_processed_log_files.rs b/src/worker/jobs/downloads/clean_processed_log_files.rs index dee4fbdb87c..9850a7b5842 100644 --- a/src/worker/jobs/downloads/clean_processed_log_files.rs +++ b/src/worker/jobs/downloads/clean_processed_log_files.rs @@ -44,13 +44,13 @@ mod tests { use super::*; use chrono::{DateTime, Utc}; use crates_io_test_db::TestDatabase; - use diesel_async::{AsyncConnection, AsyncPgConnection}; + use diesel_async::AsyncPgConnection; use insta::assert_debug_snapshot; #[tokio::test] async fn test_cleanup() { let test_db = TestDatabase::new(); - let mut conn = AsyncPgConnection::establish(test_db.url()).await.unwrap(); + let mut conn = test_db.async_connect().await; let now = chrono::Utc::now(); let cut_off_date = cut_off_date(); diff --git a/src/worker/jobs/expiry_notification.rs b/src/worker/jobs/expiry_notification.rs index 5a07de5a144..8b6f3ddf0c6 100644 --- a/src/worker/jobs/expiry_notification.rs +++ b/src/worker/jobs/expiry_notification.rs @@ -170,13 +170,12 @@ mod tests { use crate::{models::token::ApiToken, schema::api_tokens, util::token::PlainToken}; use crates_io_test_db::TestDatabase; use diesel::dsl::IntervalDsl; - use diesel_async::AsyncConnection; use lettre::Address; #[tokio::test] async fn test_expiry_notification() -> anyhow::Result<()> { let test_db = TestDatabase::new(); - let mut conn = AsyncPgConnection::establish(test_db.url()).await?; + let mut conn = test_db.async_connect().await; // Set up a user and a token that is about to expire. let user = NewUser::new(0, "a", None, None, "token"); diff --git a/src/worker/jobs/rss/sync_crate_feed.rs b/src/worker/jobs/rss/sync_crate_feed.rs index 50e67f136e6..8fe2a4eb91a 100644 --- a/src/worker/jobs/rss/sync_crate_feed.rs +++ b/src/worker/jobs/rss/sync_crate_feed.rs @@ -179,7 +179,6 @@ mod tests { use super::*; use chrono::NaiveDateTime; use crates_io_test_db::TestDatabase; - use diesel_async::AsyncConnection; use futures_util::future::join_all; use insta::assert_debug_snapshot; use std::borrow::Cow; @@ -190,7 +189,7 @@ mod tests { crate::util::tracing::init_for_test(); let db = TestDatabase::new(); - let mut conn = AsyncPgConnection::establish(db.url()).await.unwrap(); + let mut conn = db.async_connect().await; let now = chrono::Utc::now().naive_utc(); diff --git a/src/worker/jobs/rss/sync_crates_feed.rs b/src/worker/jobs/rss/sync_crates_feed.rs index de28909f53a..f0f9b4d6633 100644 --- a/src/worker/jobs/rss/sync_crates_feed.rs +++ b/src/worker/jobs/rss/sync_crates_feed.rs @@ -157,7 +157,7 @@ mod tests { use super::*; use chrono::NaiveDateTime; use crates_io_test_db::TestDatabase; - use diesel_async::{AsyncConnection, AsyncPgConnection}; + use diesel_async::AsyncPgConnection; use futures_util::future::join_all; use insta::assert_debug_snapshot; use std::borrow::Cow; @@ -168,7 +168,7 @@ mod tests { crate::util::tracing::init_for_test(); let db = TestDatabase::new(); - let mut conn = AsyncPgConnection::establish(db.url()).await.unwrap(); + let mut conn = db.async_connect().await; let now = chrono::Utc::now().naive_utc(); diff --git a/src/worker/jobs/rss/sync_updates_feed.rs b/src/worker/jobs/rss/sync_updates_feed.rs index 3a817f732f8..95db5b90687 100644 --- a/src/worker/jobs/rss/sync_updates_feed.rs +++ b/src/worker/jobs/rss/sync_updates_feed.rs @@ -173,7 +173,6 @@ mod tests { use super::*; use chrono::NaiveDateTime; use crates_io_test_db::TestDatabase; - use diesel_async::AsyncConnection; use futures_util::future::join_all; use insta::assert_debug_snapshot; use std::borrow::Cow; @@ -184,7 +183,7 @@ mod tests { crate::util::tracing::init_for_test(); let db = TestDatabase::new(); - let mut conn = AsyncPgConnection::establish(db.url()).await.unwrap(); + let mut conn = db.async_connect().await; let now = chrono::Utc::now().naive_utc(); diff --git a/src/worker/jobs/typosquat.rs b/src/worker/jobs/typosquat.rs index 7423fcf7737..ced952eae7d 100644 --- a/src/worker/jobs/typosquat.rs +++ b/src/worker/jobs/typosquat.rs @@ -124,7 +124,6 @@ mod tests { use super::*; use crate::typosquat::test_util::faker; use crates_io_test_db::TestDatabase; - use diesel_async::AsyncConnection; use lettre::Address; #[tokio::test] @@ -138,7 +137,7 @@ mod tests { faker::crate_and_version(&mut conn, "my-crate", "It's awesome", &user, 100)?; // Prime the cache so it only includes the crate we just created. - let mut async_conn = AsyncPgConnection::establish(test_db.url()).await?; + let mut async_conn = test_db.async_connect().await; let cache = Cache::new(vec!["admin@example.com".to_string()], &mut async_conn).await?; let cache = Arc::new(cache);