Skip to content

Commit b03606e

Browse files
authored
Merge pull request #60 from KatsuoRyuu/feature/test_transaction
Adding test transaction to the trait AsyncConnection
2 parents 91ce1d0 + a58433f commit b03606e

File tree

5 files changed

+80
-10
lines changed

5 files changed

+80
-10
lines changed

src/lib.rs

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,14 +66,17 @@
6666
//! ```
6767
6868
#![warn(missing_docs)]
69+
6970
use diesel::backend::Backend;
7071
use diesel::query_builder::{AsQuery, QueryFragment, QueryId};
72+
use diesel::result::Error;
7173
use diesel::row::Row;
7274
use diesel::{ConnectionResult, QueryResult};
7375
use futures_util::{Future, Stream};
76+
use std::fmt::Debug;
7477

7578
pub use scoped_futures;
76-
use scoped_futures::ScopedBoxFuture;
79+
use scoped_futures::{ScopedBoxFuture, ScopedFutureExt};
7780

7881
#[cfg(feature = "mysql")]
7982
mod mysql;
@@ -254,6 +257,65 @@ pub trait AsyncConnection: SimpleAsyncConnection + Sized + Send {
254257
Ok(())
255258
}
256259

260+
/// Executes the given function inside a transaction, but does not commit
261+
/// it. Panics if the given function returns an error.
262+
///
263+
/// # Example
264+
///
265+
/// ```rust
266+
/// # include!("doctest_setup.rs");
267+
/// use diesel::result::Error;
268+
/// use scoped_futures::ScopedFutureExt;
269+
///
270+
/// # #[tokio::main(flavor = "current_thread")]
271+
/// # async fn main() {
272+
/// # run_test().await.unwrap();
273+
/// # }
274+
/// #
275+
/// # async fn run_test() -> QueryResult<()> {
276+
/// # use schema::users::dsl::*;
277+
/// # let conn = &mut establish_connection().await;
278+
/// conn.test_transaction::<_, Error, _>(|conn| async move {
279+
/// diesel::insert_into(users)
280+
/// .values(name.eq("Ruby"))
281+
/// .execute(conn)
282+
/// .await?;
283+
///
284+
/// let all_names = users.select(name).load::<String>(conn).await?;
285+
/// assert_eq!(vec!["Sean", "Tess", "Ruby"], all_names);
286+
///
287+
/// Ok(())
288+
/// }.scope_boxed()).await;
289+
///
290+
/// // Even though we returned `Ok`, the transaction wasn't committed.
291+
/// let all_names = users.select(name).load::<String>(conn).await?;
292+
/// assert_eq!(vec!["Sean", "Tess"], all_names);
293+
/// # Ok(())
294+
/// # }
295+
/// ```
296+
async fn test_transaction<'a, R, E, F>(&'a mut self, f: F) -> R
297+
where
298+
F: for<'r> FnOnce(&'r mut Self) -> ScopedBoxFuture<'a, 'r, Result<R, E>> + Send + 'a,
299+
E: Debug + Send + 'a,
300+
R: Send + 'a,
301+
Self: 'a,
302+
{
303+
use futures_util::TryFutureExt;
304+
305+
let mut user_result = None;
306+
let _ = self
307+
.transaction::<R, _, _>(|c| {
308+
f(c).map_err(|_| Error::RollbackTransaction)
309+
.and_then(|r| {
310+
user_result = Some(r);
311+
futures_util::future::ready(Err(Error::RollbackTransaction))
312+
})
313+
.scope_boxed()
314+
})
315+
.await;
316+
user_result.expect("Transaction did not succeed")
317+
}
318+
257319
#[doc(hidden)]
258320
fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query>
259321
where

src/pg/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ impl AsyncConnection for AsyncPgConnection {
123123
.map_err(ErrorHelper)?;
124124
tokio::spawn(async move {
125125
if let Err(e) = connection.await {
126-
eprintln!("connection error: {}", e);
126+
eprintln!("connection error: {e}");
127127
}
128128
});
129129
Self::try_from(client).await

src/pooled_connection/mod.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,17 @@ impl fmt::Display for PoolError {
3939

4040
impl std::error::Error for PoolError {}
4141

42+
type SetupCallback<C> =
43+
Box<dyn Fn(&str) -> future::BoxFuture<diesel::ConnectionResult<C>> + Send + Sync>;
44+
4245
/// An connection manager for use with diesel-async.
4346
///
4447
/// See the concrete pool implementations for examples:
4548
/// * [deadpool](self::deadpool)
4649
/// * [bb8](self::bb8)
4750
/// * [mobc](self::mobc)
4851
pub struct AsyncDieselConnectionManager<C> {
49-
setup: Box<dyn Fn(&str) -> future::BoxFuture<diesel::ConnectionResult<C>> + Send + Sync>,
52+
setup: SetupCallback<C>,
5053
connection_url: String,
5154
}
5255

src/transaction_manager.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ where
292292
let start_transaction_sql = match transaction_state.transaction_depth() {
293293
None => Cow::from("BEGIN"),
294294
Some(transaction_depth) => {
295-
Cow::from(format!("SAVEPOINT diesel_savepoint_{}", transaction_depth))
295+
Cow::from(format!("SAVEPOINT diesel_savepoint_{transaction_depth}"))
296296
}
297297
};
298298
conn.batch_execute(&start_transaction_sql).await?;

tests/type_check.rs

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,8 @@ async fn test_timestamp() {
161161
type_check::<_, sql_types::Timestamp>(
162162
conn,
163163
chrono::NaiveDateTime::new(
164-
chrono::NaiveDate::from_ymd(2021, 09, 27),
165-
chrono::NaiveTime::from_hms_milli(17, 44, 23, 0),
164+
chrono::NaiveDate::from_ymd_opt(2021, 09, 27).unwrap(),
165+
chrono::NaiveTime::from_hms_milli_opt(17, 44, 23, 0).unwrap(),
166166
),
167167
)
168168
.await;
@@ -171,13 +171,18 @@ async fn test_timestamp() {
171171
#[tokio::test]
172172
async fn test_date() {
173173
let conn = &mut connection().await;
174-
type_check::<_, sql_types::Date>(conn, chrono::NaiveDate::from_ymd(2021, 09, 27)).await;
174+
type_check::<_, sql_types::Date>(conn, chrono::NaiveDate::from_ymd_opt(2021, 09, 27).unwrap())
175+
.await;
175176
}
176177

177178
#[tokio::test]
178179
async fn test_time() {
179180
let conn = &mut connection().await;
180-
type_check::<_, sql_types::Time>(conn, chrono::NaiveTime::from_hms_milli(17, 44, 23, 0)).await;
181+
type_check::<_, sql_types::Time>(
182+
conn,
183+
chrono::NaiveTime::from_hms_milli_opt(17, 44, 23, 0).unwrap(),
184+
)
185+
.await;
181186
}
182187

183188
#[cfg(feature = "mysql")]
@@ -187,8 +192,8 @@ async fn test_datetime() {
187192
type_check::<_, sql_types::Datetime>(
188193
conn,
189194
chrono::NaiveDateTime::new(
190-
chrono::NaiveDate::from_ymd(2021, 09, 30),
191-
chrono::NaiveTime::from_hms_milli(12, 06, 42, 0),
195+
chrono::NaiveDate::from_ymd_opt(2021, 09, 30).unwrap(),
196+
chrono::NaiveTime::from_hms_milli_opt(12, 06, 42, 0).unwrap(),
192197
),
193198
)
194199
.await;

0 commit comments

Comments
 (0)