8 changes: 5 additions & 3 deletions src/pg/mod.rs
@@ -387,9 +387,11 @@ fn update_transaction_manager_status<T>(
     if let Err(diesel::result::Error::DatabaseError(DatabaseErrorKind::SerializationFailure, _)) =
         query_result
     {
-        transaction_manager
-            .status
-            .set_requires_rollback_maybe_up_to_top_level(true)
+        if !transaction_manager.is_commit {
+            transaction_manager
+                .status
+                .set_requires_rollback_maybe_up_to_top_level(true);
+        }
     }
     query_result
 }
21 changes: 19 additions & 2 deletions src/transaction_manager.rs
@@ -146,6 +146,11 @@ pub struct AnsiTransactionManager {
     // See https://github.com/weiznich/diesel_async/issues/198 for
     // details
     pub(crate) is_broken: Arc<AtomicBool>,
+    // This boolean flag tracks whether we are currently in the process
+    // of trying to commit the transaction. This matters because if we
+    // are and we hit a serialization failure, we must not mark the
+    // transaction as requiring a rollback up the chain.
+    pub(crate) is_commit: bool,
 }
 
 impl AnsiTransactionManager {
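The added `is_commit` field acts as a scoped marker: `commit_transaction` (in the hunk below) sets it just before executing the `COMMIT` statement and clears it right afterwards, while `update_transaction_manager_status` in `src/pg/mod.rs` (the hunk above) checks it, so a serialization failure raised by the `COMMIT` itself no longer flags the transaction as requiring a rollback. PostgreSQL has already aborted and ended the transaction when `COMMIT` fails this way, so there is nothing left to roll back. A minimal, self-contained sketch of that interplay, using illustrative stand-in types rather than the crate's real `AnsiTransactionManager`:

```rust
// Stand-in types for illustration only; not diesel_async's actual API.
#[derive(Default)]
struct TxState {
    is_commit: bool,
    requires_rollback: bool,
}

impl TxState {
    // Mirrors the guard added in src/pg/mod.rs: a serialization failure may
    // only request a rollback if it did not happen while committing.
    fn record_serialization_failure(&mut self) {
        if !self.is_commit {
            self.requires_rollback = true;
        }
    }
}

fn main() {
    // Failure during an ordinary statement: a rollback is still requested.
    let mut state = TxState::default();
    state.record_serialization_failure();
    assert!(state.requires_rollback);

    // Failure raised by the COMMIT itself: the flag is set for the duration
    // of the commit attempt, so no rollback is queued.
    let mut state = TxState::default();
    state.is_commit = true; // set right before batch_execute(&commit_sql)
    state.record_serialization_failure();
    state.is_commit = false; // cleared right after the commit attempt
    assert!(!state.requires_rollback);
}
```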
@@ -355,9 +360,18 @@ where
         conn.instrumentation()
             .on_connection_event(InstrumentationEvent::commit_transaction(depth));
 
-        let is_broken = conn.transaction_state().is_broken.clone();
+        let is_broken = {
+            let transaction_state = conn.transaction_state();
+            transaction_state.is_commit = true;
+            transaction_state.is_broken.clone()
+        };
 
+        let res =
+            Self::critical_transaction_block(&is_broken, conn.batch_execute(&commit_sql)).await;
+
+        conn.transaction_state().is_commit = false;
+
-        match Self::critical_transaction_block(&is_broken, conn.batch_execute(&commit_sql)).await {
+        match res {
             Ok(()) => {
                 match Self::get_transaction_state(conn)?
                     .change_transaction_depth(TransactionDepthChange::DecreaseDepth)
@@ -392,6 +406,9 @@
                             });
                         }
                     }
+                } else {
+                    Self::get_transaction_state(conn)?
+                        .change_transaction_depth(TransactionDepthChange::DecreaseDepth)?;
                 }
                 Err(commit_error)
             }
122 changes: 122 additions & 0 deletions tests/transactions.rs
@@ -104,3 +104,125 @@ async fn concurrent_serializable_transactions_behave_correctly() {
res.unwrap_err()
);
}

#[cfg(feature = "postgres")]
#[tokio::test]
async fn commit_with_serialization_failure_already_ends_transaction() {
use diesel::prelude::*;
use diesel_async::{AsyncConnection, RunQueryDsl};
use std::sync::Arc;
use tokio::sync::Barrier;

table! {
users4 {
id -> Integer,
}
}

// create an async connection
let mut conn = super::connection_without_transaction().await;

struct A(Vec<&'static str>);
impl diesel::connection::Instrumentation for A {
fn on_connection_event(&mut self, event: diesel::connection::InstrumentationEvent<'_>) {
if let diesel::connection::InstrumentationEvent::StartQuery { query, .. } = event {
let q = query.to_string();
let q = q.split_once(' ').map(|(a, _)| a).unwrap_or(&q);

if matches!(q, "BEGIN" | "COMMIT" | "ROLLBACK") {
assert_eq!(q, self.0.pop().unwrap());
}
}
}
}
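// The expected statements are popped from the back, so the asserted order on
// this connection is BEGIN, COMMIT (which fails), BEGIN, COMMIT; crucially,
// no ROLLBACK is ever issued.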
conn.set_instrumentation(A(vec!["COMMIT", "BEGIN", "COMMIT", "BEGIN"]));

let mut conn1 = super::connection_without_transaction().await;

diesel::sql_query("CREATE TABLE IF NOT EXISTS users4 (id int);")
.execute(&mut conn)
.await
.unwrap();

let barrier_1 = Arc::new(Barrier::new(2));
let barrier_2 = Arc::new(Barrier::new(2));
let barrier_1_for_tx1 = barrier_1.clone();
let barrier_1_for_tx2 = barrier_1.clone();
let barrier_2_for_tx1 = barrier_2.clone();
let barrier_2_for_tx2 = barrier_2.clone();

let mut tx = conn.build_transaction().serializable().read_write();

let res = tx.run(|conn| {
Box::pin(async {
users4::table.select(users4::id).load::<i32>(conn).await?;

barrier_1_for_tx1.wait().await;
diesel::insert_into(users4::table)
.values(users4::id.eq(1))
.execute(conn)
.await?;
barrier_2_for_tx1.wait().await;

Ok::<_, diesel::result::Error>(())
})
});

let mut tx1 = conn1.build_transaction().serializable().read_write();

let res1 = async {
let res = tx1
.run(|conn| {
Box::pin(async {
users4::table.select(users4::id).load::<i32>(conn).await?;

barrier_1_for_tx2.wait().await;
diesel::insert_into(users4::table)
.values(users4::id.eq(1))
.execute(conn)
.await?;

Ok::<_, diesel::result::Error>(())
})
})
.await;
barrier_2_for_tx2.wait().await;
res
};

let (res, res1) = tokio::join!(res, res1);
let _ = diesel::sql_query("DROP TABLE users4")
.execute(&mut conn1)
.await;

assert!(
res1.is_ok(),
"Expected the second transaction to be succussfull, but got an error: {:?}",
res1.unwrap_err()
);

assert!(res.is_err(), "Expected the first transaction to fail");
let err = res.unwrap_err();
assert!(
matches!(
&err,
diesel::result::Error::DatabaseError(
diesel::result::DatabaseErrorKind::SerializationFailure,
_
)
),
"Expected an serialization failure but got another error: {err:?}"
);

let mut tx = conn.build_transaction();

let res = tx
.run(|_| Box::pin(async { Ok::<_, diesel::result::Error>(()) }))
.await;

assert!(
res.is_ok(),
"Expect transaction to run fine but got an error: {:?}",
res.unwrap_err()
);
}
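Not part of the PR, but a sketch of the usage pattern this behavior enables, assuming an `AsyncPgConnection` and a `users4`-style table as in the test above, and mirroring the test's `Box::pin` callback style. Because the failed `COMMIT` already ends the transaction and the manager's depth is decreased on the error path, a retry is simply another `run` call, with no manual `ROLLBACK` in between:

```rust
// Sketch only (not part of this PR). Assumes a `users4` table like the one in
// the test above; retries indefinitely to keep the example short.
use diesel::result::{DatabaseErrorKind, Error};
use diesel_async::{AsyncPgConnection, RunQueryDsl};

async fn insert_with_retry(conn: &mut AsyncPgConnection) -> Result<(), Error> {
    loop {
        let res = conn
            .build_transaction()
            .serializable()
            .read_write()
            .run(|conn| {
                Box::pin(async {
                    diesel::sql_query("INSERT INTO users4 (id) VALUES (1)")
                        .execute(conn)
                        .await?;
                    Ok::<_, Error>(())
                })
            })
            .await;
        match res {
            // With this fix the failed COMMIT has already ended the
            // transaction, so the next iteration can issue BEGIN directly.
            Err(Error::DatabaseError(DatabaseErrorKind::SerializationFailure, _)) => continue,
            other => return other,
        }
    }
}
```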