Skip to content

Commit 2e4075a

Browse files
committed
Fix a bug in how we handle serialization errors
This commit fixes a bug in how we handle serialization errors with the postgres backend and transactions. It turns out that we never update the transaction manager state for `batch_execute` calls, which in turn are used for executing the transaction SQL itself. That could lead to situations in which we don't roll back the transaction, but in which we should have done that. Fixes #241
1 parent 7d45634 commit 2e4075a

File tree

8 files changed

+119
-11
lines changed

8 files changed

+119
-11
lines changed

examples/postgres/pooled-with-rustls/src/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
4141
Ok(())
4242
}
4343

44-
fn establish_connection(config: &str) -> BoxFuture<ConnectionResult<AsyncPgConnection>> {
44+
fn establish_connection(config: &str) -> BoxFuture<'_, ConnectionResult<AsyncPgConnection>> {
4545
let fut = async {
4646
// We first set up the way we want rustls to work.
4747
let rustls_config = ClientConfig::with_platform_verifier();

examples/postgres/run-pending-migrations-with-rustls/src/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
2727
Ok(())
2828
}
2929

30-
fn establish_connection(config: &str) -> BoxFuture<ConnectionResult<AsyncPgConnection>> {
30+
fn establish_connection(config: &str) -> BoxFuture<'_, ConnectionResult<AsyncPgConnection>> {
3131
let fut = async {
3232
// We first set up the way we want rustls to work.
3333
let rustls_config = ClientConfig::with_platform_verifier();

src/mysql/row.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ impl<'a> diesel::row::Row<'a, Mysql> for MysqlRow {
121121
Some(field)
122122
}
123123

124-
fn partial_row(&self, range: std::ops::Range<usize>) -> PartialRow<Self::InnerPartialRow> {
124+
fn partial_row(&self, range: std::ops::Range<usize>) -> PartialRow<'_, Self::InnerPartialRow> {
125125
PartialRow::new(self, range)
126126
}
127127
}

src/pg/mod.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,12 @@ impl SimpleAsyncConnection for AsyncPgConnection {
146146
.batch_execute(query)
147147
.map_err(ErrorHelper)
148148
.map_err(Into::into);
149+
149150
let r = drive_future(connection_future, batch_execute).await;
151+
let r = {
152+
let mut transaction_manager = self.transaction_state.lock().await;
153+
update_transaction_manager_status(r, &mut transaction_manager)
154+
};
150155
self.record_instrumentation(InstrumentationEvent::finish_query(
151156
&StrQueryHelper::new(query),
152157
r.as_ref().err(),
@@ -379,7 +384,7 @@ impl AsyncPgConnection {
379384
/// .await
380385
/// # }
381386
/// ```
382-
pub fn build_transaction(&mut self) -> TransactionBuilder<Self> {
387+
pub fn build_transaction(&mut self) -> TransactionBuilder<'_, Self> {
383388
TransactionBuilder::new(self)
384389
}
385390

src/pg/row.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ impl<'a> diesel::row::Row<'a, diesel::pg::Pg> for PgRow {
4141
fn partial_row(
4242
&self,
4343
range: std::ops::Range<usize>,
44-
) -> diesel::row::PartialRow<Self::InnerPartialRow> {
44+
) -> diesel::row::PartialRow<'_, Self::InnerPartialRow> {
4545
PartialRow::new(self, range)
4646
}
4747
}

src/transaction_manager.rs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -381,12 +381,8 @@ where
381381
..
382382
}) = conn.transaction_state().status
383383
{
384-
match Self::critical_transaction_block(
385-
&is_broken,
386-
Self::rollback_transaction(conn),
387-
)
388-
.await
389-
{
384+
// rollback_transaction handles the critical block internally on its own
385+
match Self::rollback_transaction(conn).await {
390386
Ok(()) => {}
391387
Err(rollback_error) => {
392388
conn.transaction_state().status.set_in_error();

tests/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ mod instrumentation;
1111
mod pooling;
1212
#[cfg(feature = "async-connection-wrapper")]
1313
mod sync_wrapper;
14+
mod transactions;
1415
mod type_check;
1516

1617
async fn transaction_test<C: AsyncConnection<Backend = TestBackend>>(

tests/transactions.rs

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
#[cfg(feature = "postgres")]
2+
#[tokio::test]
3+
async fn concurrent_serializable_transactions_behave_correctly() {
4+
use diesel::prelude::*;
5+
use diesel_async::RunQueryDsl;
6+
use std::sync::Arc;
7+
use tokio::sync::Barrier;
8+
9+
table! {
10+
users3 {
11+
id -> Integer,
12+
}
13+
}
14+
15+
// create an async connection
16+
let mut conn = super::connection_without_transaction().await;
17+
18+
let mut conn1 = super::connection_without_transaction().await;
19+
20+
diesel::sql_query("CREATE TABLE IF NOT EXISTS users3 (id int);")
21+
.execute(&mut conn)
22+
.await
23+
.unwrap();
24+
25+
let barrier_1 = Arc::new(Barrier::new(2));
26+
let barrier_2 = Arc::new(Barrier::new(2));
27+
let barrier_1_for_tx1 = barrier_1.clone();
28+
let barrier_1_for_tx2 = barrier_1.clone();
29+
let barrier_2_for_tx1 = barrier_2.clone();
30+
let barrier_2_for_tx2 = barrier_2.clone();
31+
32+
let mut tx = conn.build_transaction().serializable().read_write();
33+
34+
let res = tx.run(|conn| {
35+
Box::pin(async {
36+
users3::table.select(users3::id).load::<i32>(conn).await?;
37+
38+
barrier_1_for_tx1.wait().await;
39+
diesel::insert_into(users3::table)
40+
.values(users3::id.eq(1))
41+
.execute(conn)
42+
.await?;
43+
barrier_2_for_tx1.wait().await;
44+
45+
Ok::<_, diesel::result::Error>(())
46+
})
47+
});
48+
49+
let mut tx1 = conn1.build_transaction().serializable().read_write();
50+
51+
let res1 = async {
52+
let res = tx1
53+
.run(|conn| {
54+
Box::pin(async {
55+
users3::table.select(users3::id).load::<i32>(conn).await?;
56+
57+
barrier_1_for_tx2.wait().await;
58+
diesel::insert_into(users3::table)
59+
.values(users3::id.eq(1))
60+
.execute(conn)
61+
.await?;
62+
63+
Ok::<_, diesel::result::Error>(())
64+
})
65+
})
66+
.await;
67+
barrier_2_for_tx2.wait().await;
68+
res
69+
};
70+
71+
let (res, res1) = tokio::join!(res, res1);
72+
let _ = diesel::sql_query("DROP TABLE users3")
73+
.execute(&mut conn1)
74+
.await;
75+
76+
assert!(
77+
res1.is_ok(),
78+
"Expected the second transaction to be succussfull, but got an error: {:?}",
79+
res1.unwrap_err()
80+
);
81+
82+
assert!(res.is_err(), "Expected the first transaction to fail");
83+
let err = res.unwrap_err();
84+
assert!(
85+
matches!(
86+
&err,
87+
diesel::result::Error::DatabaseError(
88+
diesel::result::DatabaseErrorKind::SerializationFailure,
89+
_
90+
)
91+
),
92+
"Expected an serialization failure but got another error: {err:?}"
93+
);
94+
95+
let mut tx = conn.build_transaction();
96+
97+
let res = tx
98+
.run(|_| Box::pin(async { Ok::<_, diesel::result::Error>(()) }))
99+
.await;
100+
101+
assert!(
102+
res.is_ok(),
103+
"Expect transaction to run fine but got an error: {:?}",
104+
res.unwrap_err()
105+
);
106+
}

0 commit comments

Comments
 (0)