diff --git a/src/lib.rs b/src/lib.rs index b1be8dc..8102312 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -125,12 +125,8 @@ pub trait SimpleAsyncConnection { fn batch_execute(&mut self, query: &str) -> impl Future> + Send; } -/// An async connection to a database -/// -/// This trait represents a n async database connection. It can be used to query the database through -/// the query dsl provided by diesel, custom extensions or raw sql queries. It essentially mirrors -/// the sync diesel [`Connection`](diesel::connection::Connection) implementation -pub trait AsyncConnection: SimpleAsyncConnection + Sized + Send { +/// Core trait for an async database connection +pub trait AsyncConnectionCore: SimpleAsyncConnection + Send { /// The future returned by `AsyncConnection::execute` type ExecuteFuture<'conn, 'query>: Future> + Send; /// The future returned by `AsyncConnection::load` @@ -143,6 +139,37 @@ pub trait AsyncConnection: SimpleAsyncConnection + Sized + Send { /// The backend this type connects to type Backend: Backend; + #[doc(hidden)] + fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query> + where + T: AsQuery + 'query, + T::Query: QueryFragment + QueryId + 'query; + + #[doc(hidden)] + fn execute_returning_count<'conn, 'query, T>( + &'conn mut self, + source: T, + ) -> Self::ExecuteFuture<'conn, 'query> + where + T: QueryFragment + QueryId + 'query; + + // These functions allow the associated types (`ExecuteFuture`, `LoadFuture`, etc.) to + // compile without a `where Self: '_` clause. This is needed the because bound causes + // lifetime issues when using `transaction()` with generic `AsyncConnection`s. + // + // See: https://github.com/rust-lang/rust/issues/87479 + #[doc(hidden)] + fn _silence_lint_on_execute_future(_: Self::ExecuteFuture<'_, '_>) {} + #[doc(hidden)] + fn _silence_lint_on_load_future(_: Self::LoadFuture<'_, '_>) {} +} + +/// An async connection to a database +/// +/// This trait represents an async database connection. It can be used to query the database through +/// the query dsl provided by diesel, custom extensions or raw sql queries. It essentially mirrors +/// the sync diesel [`Connection`](diesel::connection::Connection) implementation +pub trait AsyncConnection: AsyncConnectionCore + Sized { #[doc(hidden)] type TransactionManager: TransactionManager; @@ -336,35 +363,11 @@ pub trait AsyncConnection: SimpleAsyncConnection + Sized + Send { }) } - #[doc(hidden)] - fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query> - where - T: AsQuery + 'query, - T::Query: QueryFragment + QueryId + 'query; - - #[doc(hidden)] - fn execute_returning_count<'conn, 'query, T>( - &'conn mut self, - source: T, - ) -> Self::ExecuteFuture<'conn, 'query> - where - T: QueryFragment + QueryId + 'query; - #[doc(hidden)] fn transaction_state( &mut self, ) -> &mut >::TransactionStateData; - // These functions allow the associated types (`ExecuteFuture`, `LoadFuture`, etc.) to - // compile without a `where Self: '_` clause. This is needed the because bound causes - // lifetime issues when using `transaction()` with generic `AsyncConnection`s. - // - // See: https://github.com/rust-lang/rust/issues/87479 - #[doc(hidden)] - fn _silence_lint_on_execute_future(_: Self::ExecuteFuture<'_, '_>) {} - #[doc(hidden)] - fn _silence_lint_on_load_future(_: Self::LoadFuture<'_, '_>) {} - #[doc(hidden)] fn instrumentation(&mut self) -> &mut dyn Instrumentation; diff --git a/src/mysql/mod.rs b/src/mysql/mod.rs index b25e5e0..1d44650 100644 --- a/src/mysql/mod.rs +++ b/src/mysql/mod.rs @@ -1,5 +1,5 @@ use crate::stmt_cache::{CallbackHelper, QueryFragmentHelper}; -use crate::{AnsiTransactionManager, AsyncConnection, SimpleAsyncConnection}; +use crate::{AnsiTransactionManager, AsyncConnection, AsyncConnectionCore, SimpleAsyncConnection}; use diesel::connection::statement_cache::{ MaybeCached, QueryFragmentForCachedStatement, StatementCache, }; @@ -64,30 +64,13 @@ const CONNECTION_SETUP_QUERIES: &[&str] = &[ "SET character_set_results = 'utf8mb4'", ]; -impl AsyncConnection for AsyncMysqlConnection { +impl AsyncConnectionCore for AsyncMysqlConnection { type ExecuteFuture<'conn, 'query> = BoxFuture<'conn, QueryResult>; type LoadFuture<'conn, 'query> = BoxFuture<'conn, QueryResult>>; type Stream<'conn, 'query> = BoxStream<'conn, QueryResult>>; type Row<'conn, 'query> = MysqlRow; type Backend = Mysql; - type TransactionManager = AnsiTransactionManager; - - async fn establish(database_url: &str) -> diesel::ConnectionResult { - let mut instrumentation = DynInstrumentation::default_instrumentation(); - instrumentation.on_connection_event(InstrumentationEvent::start_establish_connection( - database_url, - )); - let r = Self::establish_connection_inner(database_url).await; - instrumentation.on_connection_event(InstrumentationEvent::finish_establish_connection( - database_url, - r.as_ref().err(), - )); - let mut conn = r?; - conn.instrumentation = instrumentation; - Ok(conn) - } - fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query> where T: diesel::query_builder::AsQuery, @@ -173,6 +156,25 @@ impl AsyncConnection for AsyncMysqlConnection { .map_err(|e| diesel::result::Error::DeserializationError(Box::new(e))) }) } +} + +impl AsyncConnection for AsyncMysqlConnection { + type TransactionManager = AnsiTransactionManager; + + async fn establish(database_url: &str) -> diesel::ConnectionResult { + let mut instrumentation = DynInstrumentation::default_instrumentation(); + instrumentation.on_connection_event(InstrumentationEvent::start_establish_connection( + database_url, + )); + let r = Self::establish_connection_inner(database_url).await; + instrumentation.on_connection_event(InstrumentationEvent::finish_establish_connection( + database_url, + r.as_ref().err(), + )); + let mut conn = r?; + conn.instrumentation = instrumentation; + Ok(conn) + } fn transaction_state(&mut self) -> &mut AnsiTransactionManager { &mut self.transaction_manager diff --git a/src/pg/mod.rs b/src/pg/mod.rs index 62a51b7..39811b3 100644 --- a/src/pg/mod.rs +++ b/src/pg/mod.rs @@ -8,7 +8,7 @@ use self::error_helper::ErrorHelper; use self::row::PgRow; use self::serialize::ToSqlHelper; use crate::stmt_cache::{CallbackHelper, QueryFragmentHelper}; -use crate::{AnsiTransactionManager, AsyncConnection, SimpleAsyncConnection}; +use crate::{AnsiTransactionManager, AsyncConnection, AsyncConnectionCore, SimpleAsyncConnection}; use diesel::connection::statement_cache::{ PrepareForCache, QueryFragmentForCachedStatement, StatementCache, }; @@ -114,6 +114,48 @@ const FAKE_OID: u32 = 0; /// # } /// ``` /// +/// For more complex cases, an immutable reference to the connection need to be used: +/// ```rust +/// # include!("../doctest_setup.rs"); +/// use diesel_async::RunQueryDsl; +/// +/// # +/// # #[tokio::main(flavor = "current_thread")] +/// # async fn main() { +/// # run_test().await.unwrap(); +/// # } +/// # +/// # async fn run_test() -> QueryResult<()> { +/// # use diesel::sql_types::{Text, Integer}; +/// # let conn = &mut establish_connection().await; +/// # +/// async fn fn12(mut conn: &AsyncPgConnection) -> QueryResult<(i32, i32)> { +/// let f1 = diesel::select(1_i32.into_sql::()).get_result::(&mut conn); +/// let f2 = diesel::select(2_i32.into_sql::()).get_result::(&mut conn); +/// +/// futures_util::try_join!(f1, f2) +/// } +/// +/// async fn fn34(mut conn: &AsyncPgConnection) -> QueryResult<(i32, i32)> { +/// let f3 = diesel::select(3_i32.into_sql::()).get_result::(&mut conn); +/// let f4 = diesel::select(4_i32.into_sql::()).get_result::(&mut conn); +/// +/// futures_util::try_join!(f3, f4) +/// } +/// +/// let f12 = fn12(&conn); +/// let f34 = fn34(&conn); +/// +/// let ((r1, r2), (r3, r4)) = futures_util::try_join!(f12, f34).unwrap(); +/// +/// assert_eq!(r1, 1); +/// assert_eq!(r2, 2); +/// assert_eq!(r3, 3); +/// assert_eq!(r4, 4); +/// # Ok(()) +/// # } +/// ``` +/// /// ## TLS /// /// Connections created by [`AsyncPgConnection::establish`] do not support TLS. @@ -136,6 +178,12 @@ pub struct AsyncPgConnection { } impl SimpleAsyncConnection for AsyncPgConnection { + async fn batch_execute(&mut self, query: &str) -> QueryResult<()> { + SimpleAsyncConnection::batch_execute(&mut &*self, query).await + } +} + +impl SimpleAsyncConnection for &AsyncPgConnection { async fn batch_execute(&mut self, query: &str) -> QueryResult<()> { self.record_instrumentation(InstrumentationEvent::start_query(&StrQueryHelper::new( query, @@ -160,12 +208,69 @@ impl SimpleAsyncConnection for AsyncPgConnection { } } -impl AsyncConnection for AsyncPgConnection { +impl AsyncConnectionCore for AsyncPgConnection { type LoadFuture<'conn, 'query> = BoxFuture<'query, QueryResult>>; type ExecuteFuture<'conn, 'query> = BoxFuture<'query, QueryResult>; type Stream<'conn, 'query> = BoxStream<'static, QueryResult>; type Row<'conn, 'query> = PgRow; type Backend = diesel::pg::Pg; + + fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query> + where + T: AsQuery + 'query, + T::Query: QueryFragment + QueryId + 'query, + { + AsyncConnectionCore::load(&mut &*self, source) + } + + fn execute_returning_count<'conn, 'query, T>( + &'conn mut self, + source: T, + ) -> Self::ExecuteFuture<'conn, 'query> + where + T: QueryFragment + QueryId + 'query, + { + AsyncConnectionCore::execute_returning_count(&mut &*self, source) + } +} + +impl AsyncConnectionCore for &AsyncPgConnection { + type LoadFuture<'conn, 'query> = + ::LoadFuture<'conn, 'query>; + + type ExecuteFuture<'conn, 'query> = + ::ExecuteFuture<'conn, 'query>; + + type Stream<'conn, 'query> = ::Stream<'conn, 'query>; + + type Row<'conn, 'query> = ::Row<'conn, 'query>; + + type Backend = ::Backend; + + fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query> + where + T: AsQuery + 'query, + T::Query: QueryFragment + QueryId + 'query, + { + let query = source.as_query(); + let load_future = self.with_prepared_statement(query, load_prepared); + + self.run_with_connection_future(load_future) + } + + fn execute_returning_count<'conn, 'query, T>( + &'conn mut self, + source: T, + ) -> Self::ExecuteFuture<'conn, 'query> + where + T: QueryFragment + QueryId + 'query, + { + let execute = self.with_prepared_statement(source, execute_prepared); + self.run_with_connection_future(execute) + } +} + +impl AsyncConnection for AsyncPgConnection { type TransactionManager = AnsiTransactionManager; async fn establish(database_url: &str) -> ConnectionResult { @@ -198,28 +303,6 @@ impl AsyncConnection for AsyncPgConnection { r } - fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query> - where - T: AsQuery + 'query, - T::Query: QueryFragment + QueryId + 'query, - { - let query = source.as_query(); - let load_future = self.with_prepared_statement(query, load_prepared); - - self.run_with_connection_future(load_future) - } - - fn execute_returning_count<'conn, 'query, T>( - &'conn mut self, - source: T, - ) -> Self::ExecuteFuture<'conn, 'query> - where - T: QueryFragment + QueryId + 'query, - { - let execute = self.with_prepared_statement(source, execute_prepared); - self.run_with_connection_future(execute) - } - fn transaction_state(&mut self) -> &mut AnsiTransactionManager { // there should be no other pending future when this is called // that means there is only one instance of this arc and @@ -467,7 +550,7 @@ impl AsyncPgConnection { } fn with_prepared_statement<'a, T, F, R>( - &mut self, + &self, query: T, callback: fn(Arc, Statement, Vec) -> F, ) -> BoxFuture<'a, QueryResult> @@ -502,7 +585,7 @@ impl AsyncPgConnection { } fn with_prepared_statement_after_sql_built<'a, F, R>( - &mut self, + &self, callback: fn(Arc, Statement, Vec) -> F, is_safe_to_cache_prepared: QueryResult, query_id: Option, @@ -939,11 +1022,15 @@ mod tests { use crate::run_query_dsl::RunQueryDsl; use diesel::sql_types::Integer; use diesel::IntoSql; + use futures_util::future::try_join; + use futures_util::try_join; + use scoped_futures::ScopedFutureExt; #[tokio::test] async fn pipelining() { let database_url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set in order to run tests"); + let mut conn = crate::AsyncPgConnection::establish(&database_url) .await .unwrap(); @@ -954,9 +1041,100 @@ mod tests { let f1 = q1.get_result::(&mut conn); let f2 = q2.get_result::(&mut conn); - let (r1, r2) = futures_util::try_join!(f1, f2).unwrap(); + let (r1, r2) = try_join!(f1, f2).unwrap(); assert_eq!(r1, 1); assert_eq!(r2, 2); } + + #[tokio::test] + async fn pipelining_with_composed_futures() { + let database_url = + std::env::var("DATABASE_URL").expect("DATABASE_URL must be set in order to run tests"); + + let conn = crate::AsyncPgConnection::establish(&database_url) + .await + .unwrap(); + + async fn fn12(mut conn: &AsyncPgConnection) -> QueryResult<(i32, i32)> { + let f1 = diesel::select(1_i32.into_sql::()).get_result::(&mut conn); + let f2 = diesel::select(2_i32.into_sql::()).get_result::(&mut conn); + + try_join!(f1, f2) + } + + async fn fn34(mut conn: &AsyncPgConnection) -> QueryResult<(i32, i32)> { + let f3 = diesel::select(3_i32.into_sql::()).get_result::(&mut conn); + let f4 = diesel::select(4_i32.into_sql::()).get_result::(&mut conn); + + try_join!(f3, f4) + } + + let f12 = fn12(&conn); + let f34 = fn34(&conn); + + let ((r1, r2), (r3, r4)) = try_join!(f12, f34).unwrap(); + + assert_eq!(r1, 1); + assert_eq!(r2, 2); + assert_eq!(r3, 3); + assert_eq!(r4, 4); + } + + #[tokio::test] + async fn pipelining_with_composed_futures_and_transaction() { + let database_url = + std::env::var("DATABASE_URL").expect("DATABASE_URL must be set in order to run tests"); + + let mut conn = crate::AsyncPgConnection::establish(&database_url) + .await + .unwrap(); + + fn erase<'a, T: Future + Send + 'a>(t: T) -> impl Future + Send + 'a { + t + } + + async fn fn12(mut conn: &AsyncPgConnection) -> QueryResult<(i32, i32)> { + let f1 = diesel::select(1_i32.into_sql::()).get_result::(&mut conn); + let f2 = diesel::select(2_i32.into_sql::()).get_result::(&mut conn); + + erase(try_join(f1, f2)).await + } + + async fn fn34(mut conn: &AsyncPgConnection) -> QueryResult<(i32, i32)> { + let f3 = diesel::select(3_i32.into_sql::()).get_result::(&mut conn); + let f4 = diesel::select(4_i32.into_sql::()).get_result::(&mut conn); + + try_join(f3, f4).boxed().await + } + + async fn fn56(mut conn: &AsyncPgConnection) -> QueryResult<(i32, i32)> { + let f5 = diesel::select(5_i32.into_sql::()).get_result::(&mut conn); + let f6 = diesel::select(6_i32.into_sql::()).get_result::(&mut conn); + + try_join!(f5.boxed(), f6.boxed()) + } + + conn.transaction(|conn| { + async move { + let f12 = fn12(conn); + let f34 = fn34(conn); + let f56 = fn56(conn); + + let ((r1, r2), (r3, r4), (r5, r6)) = try_join!(f12, f34, f56).unwrap(); + + assert_eq!(r1, 1); + assert_eq!(r2, 2); + assert_eq!(r3, 3); + assert_eq!(r4, 4); + assert_eq!(r5, 5); + assert_eq!(r6, 6); + + QueryResult::<_>::Ok(()) + } + .scope_boxed() + }) + .await + .unwrap(); + } } diff --git a/src/pooled_connection/mod.rs b/src/pooled_connection/mod.rs index f155d3c..cbe9f60 100644 --- a/src/pooled_connection/mod.rs +++ b/src/pooled_connection/mod.rs @@ -5,7 +5,7 @@ //! * [deadpool](self::deadpool) //! * [bb8](self::bb8) //! * [mobc](self::mobc) -use crate::{AsyncConnection, SimpleAsyncConnection}; +use crate::{AsyncConnection, AsyncConnectionCore, SimpleAsyncConnection}; use crate::{TransactionManager, UpdateAndFetchResults}; use diesel::associations::HasTable; use diesel::connection::{CacheSize, Instrumentation}; @@ -176,27 +176,18 @@ where } } -impl AsyncConnection for C +impl AsyncConnectionCore for C where C: DerefMut + Send, - C::Target: AsyncConnection, + C::Target: AsyncConnectionCore, { type ExecuteFuture<'conn, 'query> = - ::ExecuteFuture<'conn, 'query>; - type LoadFuture<'conn, 'query> = ::LoadFuture<'conn, 'query>; - type Stream<'conn, 'query> = ::Stream<'conn, 'query>; - type Row<'conn, 'query> = ::Row<'conn, 'query>; + ::ExecuteFuture<'conn, 'query>; + type LoadFuture<'conn, 'query> = ::LoadFuture<'conn, 'query>; + type Stream<'conn, 'query> = ::Stream<'conn, 'query>; + type Row<'conn, 'query> = ::Row<'conn, 'query>; - type Backend = ::Backend; - - type TransactionManager = - PoolTransactionManager<::TransactionManager>; - - async fn establish(_database_url: &str) -> diesel::ConnectionResult { - Err(diesel::result::ConnectionError::BadConnection( - String::from("Cannot directly establish a pooled connection"), - )) - } + type Backend = ::Backend; fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query> where @@ -221,6 +212,21 @@ where let conn = self.deref_mut(); conn.execute_returning_count(source) } +} + +impl AsyncConnection for C +where + C: DerefMut + Send, + C::Target: AsyncConnection, +{ + type TransactionManager = + PoolTransactionManager<::TransactionManager>; + + async fn establish(_database_url: &str) -> diesel::ConnectionResult { + Err(diesel::result::ConnectionError::BadConnection( + String::from("Cannot directly establish a pooled connection"), + )) + } fn transaction_state( &mut self, diff --git a/src/run_query_dsl/mod.rs b/src/run_query_dsl/mod.rs index b1ed693..437d2a2 100644 --- a/src/run_query_dsl/mod.rs +++ b/src/run_query_dsl/mod.rs @@ -1,4 +1,4 @@ -use crate::AsyncConnection; +use crate::AsyncConnectionCore; use diesel::associations::HasTable; use diesel::query_builder::IntoUpdateTarget; use diesel::result::QueryResult; @@ -31,9 +31,9 @@ pub mod methods { /// to call `execute` from generic code. /// /// [`RunQueryDsl`]: super::RunQueryDsl - pub trait ExecuteDsl::Backend> + pub trait ExecuteDsl::Backend> where - Conn: AsyncConnection, + Conn: AsyncConnectionCore, DB: Backend, { /// Execute this command @@ -47,7 +47,7 @@ pub mod methods { impl ExecuteDsl for T where - Conn: AsyncConnection, + Conn: AsyncConnectionCore, DB: Backend, T: QueryFragment + QueryId + Send, { @@ -69,7 +69,7 @@ pub mod methods { /// to call `load` from generic code. /// /// [`RunQueryDsl`]: super::RunQueryDsl - pub trait LoadQuery<'query, Conn: AsyncConnection, U> { + pub trait LoadQuery<'query, Conn: AsyncConnectionCore, U> { /// The future returned by [`LoadQuery::internal_load`] type LoadFuture<'conn>: Future>> + Send where @@ -85,7 +85,7 @@ pub mod methods { impl<'query, Conn, DB, T, U, ST> LoadQuery<'query, Conn, U> for T where - Conn: AsyncConnection, + Conn: AsyncConnectionCore, U: Send, DB: Backend + 'static, T: AsQuery + Send + 'query, @@ -227,7 +227,7 @@ pub trait RunQueryDsl: Sized { /// ``` fn execute<'conn, 'query>(self, conn: &'conn mut Conn) -> Conn::ExecuteFuture<'conn, 'query> where - Conn: AsyncConnection + Send, + Conn: AsyncConnectionCore + Send, Self: methods::ExecuteDsl + 'query, { methods::ExecuteDsl::execute(self, conn) @@ -343,7 +343,7 @@ pub trait RunQueryDsl: Sized { ) -> return_futures::LoadFuture<'conn, 'query, Self, Conn, U> where U: Send, - Conn: AsyncConnection, + Conn: AsyncConnectionCore, Self: methods::LoadQuery<'query, Conn, U> + 'query, { fn collect_result(stream: S) -> stream::TryCollect> @@ -481,7 +481,7 @@ pub trait RunQueryDsl: Sized { /// ``` fn load_stream<'conn, 'query, U>(self, conn: &'conn mut Conn) -> Self::LoadFuture<'conn> where - Conn: AsyncConnection, + Conn: AsyncConnectionCore, U: 'conn, Self: methods::LoadQuery<'query, Conn, U> + 'query, { @@ -544,7 +544,7 @@ pub trait RunQueryDsl: Sized { ) -> return_futures::GetResult<'conn, 'query, Self, Conn, U> where U: Send + 'conn, - Conn: AsyncConnection, + Conn: AsyncConnectionCore, Self: methods::LoadQuery<'query, Conn, U> + 'query, { #[allow(clippy::type_complexity)] @@ -584,7 +584,7 @@ pub trait RunQueryDsl: Sized { ) -> return_futures::LoadFuture<'conn, 'query, Self, Conn, U> where U: Send, - Conn: AsyncConnection, + Conn: AsyncConnectionCore, Self: methods::LoadQuery<'query, Conn, U> + 'query, { self.load(conn) @@ -640,7 +640,7 @@ pub trait RunQueryDsl: Sized { ) -> return_futures::GetResult<'conn, 'query, diesel::dsl::Limit, Conn, U> where U: Send + 'conn, - Conn: AsyncConnection, + Conn: AsyncConnectionCore, Self: diesel::query_dsl::methods::LimitDsl, diesel::dsl::Limit: methods::LoadQuery<'query, Conn, U> + Send + 'query, { @@ -734,7 +734,7 @@ impl SaveChangesDsl for T where /// For implementing this trait for a custom backend: /// * The `Changes` generic parameter represents the changeset that should be stored /// * The `Output` generic parameter represents the type of the response. -pub trait UpdateAndFetchResults: AsyncConnection +pub trait UpdateAndFetchResults: AsyncConnectionCore where Changes: diesel::prelude::Identifiable + HasTable, { diff --git a/src/sync_connection_wrapper/mod.rs b/src/sync_connection_wrapper/mod.rs index 2bbf570..cbb8436 100644 --- a/src/sync_connection_wrapper/mod.rs +++ b/src/sync_connection_wrapper/mod.rs @@ -89,7 +89,7 @@ pub use self::implementation::SyncConnectionWrapper; pub use self::implementation::SyncTransactionManagerWrapper; mod implementation { - use crate::{AsyncConnection, SimpleAsyncConnection, TransactionManager}; + use crate::{AsyncConnection, AsyncConnectionCore, SimpleAsyncConnection, TransactionManager}; use diesel::backend::{Backend, DieselReserveSpecialization}; use diesel::connection::{CacheSize, Instrumentation}; use diesel::connection::{ @@ -133,7 +133,7 @@ mod implementation { } } - impl AsyncConnection for SyncConnectionWrapper + impl AsyncConnectionCore for SyncConnectionWrapper where // Backend bounds ::Backend: std::default::Default + DieselReserveSpecialization, @@ -158,19 +158,6 @@ mod implementation { type Stream<'conn, 'query> = BoxStream<'static, QueryResult>>; type Row<'conn, 'query> = O; type Backend = ::Backend; - type TransactionManager = - SyncTransactionManagerWrapper<::TransactionManager>; - - async fn establish(database_url: &str) -> ConnectionResult { - let database_url = database_url.to_string(); - let mut runtime = S::get_runtime(); - - runtime - .spawn_blocking(move || C::establish(&database_url)) - .await - .unwrap_or_else(|e| Err(diesel::ConnectionError::BadConnection(e.to_string()))) - .map(move |c| SyncConnectionWrapper::with_runtime(c, runtime)) - } fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query> where @@ -209,6 +196,40 @@ mod implementation { conn.execute_returning_count(&query) }) } + } + + impl AsyncConnection for SyncConnectionWrapper + where + // Backend bounds + ::Backend: std::default::Default + DieselReserveSpecialization, + ::QueryBuilder: std::default::Default, + // Connection bounds + C: Connection + LoadConnection + WithMetadataLookup + 'static, + ::TransactionManager: Send, + // BindCollector bounds + MD: Send + 'static, + for<'a> ::BindCollector<'a>: + MoveableBindCollector + std::default::Default, + // Row bounds + O: 'static + Send + for<'conn> diesel::row::Row<'conn, C::Backend>, + for<'conn, 'query> ::Row<'conn, 'query>: + IntoOwnedRow<'conn, ::Backend, OwnedRow = O>, + // SpawnBlocking bounds + S: SpawnBlocking + Send, + { + type TransactionManager = + SyncTransactionManagerWrapper<::TransactionManager>; + + async fn establish(database_url: &str) -> ConnectionResult { + let database_url = database_url.to_string(); + let mut runtime = S::get_runtime(); + + runtime + .spawn_blocking(move || C::establish(&database_url)) + .await + .unwrap_or_else(|e| Err(diesel::ConnectionError::BadConnection(e.to_string()))) + .map(move |c| SyncConnectionWrapper::with_runtime(c, runtime)) + } fn transaction_state( &mut self, diff --git a/tests/instrumentation.rs b/tests/instrumentation.rs index e14a0c3..899189d 100644 --- a/tests/instrumentation.rs +++ b/tests/instrumentation.rs @@ -5,6 +5,7 @@ use diesel::connection::InstrumentationEvent; use diesel::query_builder::AsQuery; use diesel::QueryResult; use diesel_async::AsyncConnection; +use diesel_async::AsyncConnectionCore; use diesel_async::SimpleAsyncConnection; use std::num::NonZeroU32; use std::sync::Arc; @@ -107,7 +108,7 @@ async fn check_events_are_emitted_for_execute_returning_count() { #[tokio::test] async fn check_events_are_emitted_for_load() { let (events_to_check, mut conn) = setup_test_case().await; - let _ = AsyncConnection::load(&mut conn, users::table.as_query()) + let _ = AsyncConnectionCore::load(&mut conn, users::table.as_query()) .await .unwrap(); let events = events_to_check.lock().unwrap(); @@ -133,7 +134,7 @@ async fn check_events_are_emitted_for_execute_returning_count_does_not_contain_c #[tokio::test] async fn check_events_are_emitted_for_load_does_not_contain_cache_for_uncached_queries() { let (events_to_check, mut conn) = setup_test_case().await; - let _ = AsyncConnection::load(&mut conn, diesel::sql_query("select 1")) + let _ = AsyncConnectionCore::load(&mut conn, diesel::sql_query("select 1")) .await .unwrap(); let events = events_to_check.lock().unwrap(); @@ -157,7 +158,7 @@ async fn check_events_are_emitted_for_execute_returning_count_does_contain_error #[tokio::test] async fn check_events_are_emitted_for_load_does_contain_error_for_failures() { let (events_to_check, mut conn) = setup_test_case().await; - let _ = AsyncConnection::load(&mut conn, diesel::sql_query("invalid")).await; + let _ = AsyncConnectionCore::load(&mut conn, diesel::sql_query("invalid")).await; let events = events_to_check.lock().unwrap(); assert_eq!(events.len(), 2, "{:?}", events); assert_matches!(events[0], Event::StartQuery { .. }); @@ -185,10 +186,10 @@ async fn check_events_are_emitted_for_execute_returning_count_repeat_does_not_re #[tokio::test] async fn check_events_are_emitted_for_load_repeat_does_not_repeat_cache() { let (events_to_check, mut conn) = setup_test_case().await; - let _ = AsyncConnection::load(&mut conn, users::table.as_query()) + let _ = AsyncConnectionCore::load(&mut conn, users::table.as_query()) .await .unwrap(); - let _ = AsyncConnection::load(&mut conn, users::table.as_query()) + let _ = AsyncConnectionCore::load(&mut conn, users::table.as_query()) .await .unwrap(); let events = events_to_check.lock().unwrap(); diff --git a/tests/lib.rs b/tests/lib.rs index c3fa5e4..24cd2a6 100644 --- a/tests/lib.rs +++ b/tests/lib.rs @@ -100,7 +100,7 @@ type TestConnection = sync_connection_wrapper::SyncConnectionWrapper; #[allow(dead_code)] -type TestBackend = ::Backend; +type TestBackend = ::Backend; #[tokio::test] async fn test_basic_insert_and_load() -> QueryResult<()> { diff --git a/tests/type_check.rs b/tests/type_check.rs index 52ff8c3..a074796 100644 --- a/tests/type_check.rs +++ b/tests/type_check.rs @@ -4,14 +4,14 @@ use diesel::expression::{AsExpression, ValidGrouping}; use diesel::prelude::*; use diesel::query_builder::{NoFromClause, QueryFragment, QueryId}; use diesel::sql_types::{self, HasSqlType, SingleValue}; -use diesel_async::{AsyncConnection, RunQueryDsl}; +use diesel_async::{AsyncConnectionCore, RunQueryDsl}; use std::fmt::Debug; async fn type_check(conn: &mut TestConnection, value: T) where T: Clone + AsExpression - + FromSqlRow::Backend> + + FromSqlRow::Backend> + Send + PartialEq + Debug @@ -19,10 +19,10 @@ where + 'static, T::Expression: ValidGrouping<()> + SelectableExpression - + QueryFragment<::Backend> + + QueryFragment<::Backend> + QueryId + Send, - ::Backend: HasSqlType, + ::Backend: HasSqlType, ST: SingleValue, { let res = diesel::select(value.clone().into_sql())