Skip to content

Commit 756a885

Browse files
committed
Improve the AnsiTransactionManager implementation to behave like that
one from diesel
1 parent 7410cea commit 756a885

File tree

5 files changed

+443
-181
lines changed

5 files changed

+443
-181
lines changed

src/lib.rs

Lines changed: 27 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -140,31 +140,27 @@ where
140140
///
141141
/// This function executes the provided closure `f` inside a database
142142
/// transaction. If there is already an open transaction for the current
143-
/// connection savepoints will be used instead. The connection is commited if
143+
/// connection savepoints will be used instead. The connection is committed if
144144
/// the closure returns `Ok(_)`, it will be rolled back if it returns `Err(_)`.
145145
/// For both cases the original result value will be returned from this function.
146146
///
147147
/// If the transaction fails to commit due to a `SerializationFailure` or a
148-
/// `ReadOnlyTransaction` a rollback will be attempted. In this case a
149-
/// [`Error::CommitTransactionFailed`](crate::result::Error::CommitTransactionFailed)
150-
/// error is returned, which contains details about the original error and
151-
/// the success of the rollback attempt.
152-
/// If the rollback failed the connection should be considered broken
148+
/// `ReadOnlyTransaction` a rollback will be attempted.
149+
/// If the rollback fails, the error will be returned in a
150+
/// [`Error::RollbackErrorOnCommit`](crate::result::Error::RollbackErrorOnCommit),
151+
/// from which you will be able to extract both the original commit error and
152+
/// the rollback error.
153+
/// In addition, the connection will be considered broken
153154
/// as it contains a uncommitted unabortable open transaction. Any further
154155
/// interaction with the transaction system will result in an returned error
155-
/// in this cases.
156+
/// in this case.
156157
///
157158
/// If the closure returns an `Err(_)` and the rollback fails the function
158-
/// will return a [`Error::RollbackError`](crate::result::Error::RollbackError)
159-
/// wrapping the error generated by the rollback operation instead.
160-
/// In this case the connection should be considered broken as it contains
161-
/// an unabortable open transaction.
159+
/// will return that rollback error directly, and the transaction manager will
160+
/// be marked as broken as it contains a uncommitted unabortable open transaction.
162161
///
163162
/// If a nested transaction fails to release the corresponding savepoint
164-
/// a rollback will be attempted. In this case a
165-
/// [`Error::CommitTransactionFailed`](crate::result::Error::CommitTransactionFailed)
166-
/// error is returned, which contains the original error and
167-
/// details about the success of the rollback attempt.
163+
/// the error will be returned directly.
168164
///
169165
/// # Example
170166
///
@@ -218,32 +214,28 @@ where
218214
E: From<diesel::result::Error> + Send,
219215
R: Send,
220216
{
221-
Self::TransactionManager::begin_transaction(self).await?;
222-
match callback(&mut *self).await {
223-
Ok(value) => {
224-
Self::TransactionManager::commit_transaction(self).await?;
225-
Ok(value)
226-
}
227-
Err(user_error) => {
228-
match Self::TransactionManager::rollback_transaction(self).await {
229-
Ok(()) => Err(user_error),
230-
Err(diesel::result::Error::BrokenTransactionManager) => {
231-
// In this case we are probably more interested by the
232-
// original error, which likely caused this
233-
Err(user_error)
234-
}
235-
Err(rollback_error) => Err(rollback_error.into()),
236-
}
237-
}
238-
}
217+
Self::TransactionManager::transaction(self, callback).await
239218
}
240219

241220
/// Creates a transaction that will never be committed. This is useful for
242221
/// tests. Panics if called while inside of a transaction or
243222
/// if called with a connection containing a broken transaction
244223
async fn begin_test_transaction(&mut self) -> QueryResult<()> {
245-
assert_eq!(Self::TransactionManager::get_transaction_depth(self), 0);
246-
Self::TransactionManager::begin_transaction(self).await
224+
use crate::transaction_manager::TransactionManagerStatus;
225+
226+
match Self::TransactionManager::transaction_manager_status_mut(self) {
227+
TransactionManagerStatus::Valid(valid_status) => {
228+
assert_eq!(None, valid_status.transaction_depth())
229+
}
230+
TransactionManagerStatus::InError => panic!("Transaction manager in error"),
231+
};
232+
Self::TransactionManager::begin_transaction(self).await?;
233+
// set the test transaction flag
234+
// to pervent that this connection gets droped in connection pools
235+
// Tests commonly set the poolsize to 1 and use `begin_test_transaction`
236+
// to prevent modifications to the schema
237+
Self::TransactionManager::transaction_manager_status_mut(self).set_test_transaction_flag();
238+
Ok(())
247239
}
248240

249241
#[doc(hidden)]

src/mysql/mod.rs

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,23 @@ impl AsyncConnection for AsyncMysqlConnection {
126126
}
127127
}
128128

129+
#[inline(always)]
130+
fn update_transaction_manager_status<T>(
131+
query_result: QueryResult<T>,
132+
transaction_manager: &mut AnsiTransactionManager,
133+
) -> QueryResult<T> {
134+
if let Err(diesel::result::Error::DatabaseError(
135+
diesel::result::DatabaseErrorKind::SerializationFailure,
136+
_,
137+
)) = query_result
138+
{
139+
transaction_manager
140+
.status
141+
.set_top_level_transaction_requires_rollback()
142+
}
143+
query_result
144+
}
145+
129146
#[async_trait::async_trait]
130147
impl PrepareCallback<Statement, MysqlType> for &'_ mut mysql_async::Conn {
131148
async fn prepare(
@@ -194,6 +211,7 @@ impl AsyncMysqlConnection {
194211
ref mut conn,
195212
ref mut stmt_cache,
196213
ref mut last_stmt,
214+
ref mut transaction_manager,
197215
..
198216
} = self;
199217

@@ -209,7 +227,7 @@ impl AsyncMysqlConnection {
209227
MaybeCached::Cached(s) => s,
210228
_ => unreachable!("We've opted into breaking diesel changes and want to know if things break because someone added a new variant here")
211229
};
212-
callback(conn, stmt, ToSqlHelper{metadata, binds}).await
230+
update_transaction_manager_status(callback(conn, stmt, ToSqlHelper{metadata, binds}).await, transaction_manager)
213231
}).boxed()
214232
}
215233
}

src/pg/mod.rs

Lines changed: 35 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ use self::serialize::ToSqlHelper;
44
use crate::stmt_cache::{PrepareCallback, StmtCache};
55
use crate::{
66
AnsiTransactionManager, AsyncConnection, AsyncConnectionGatWorkaround, SimpleAsyncConnection,
7-
TransactionManager,
87
};
98
use diesel::connection::statement_cache::PrepareForCache;
109
use diesel::pg::{
@@ -94,7 +93,7 @@ mod transaction_builder;
9493
pub struct AsyncPgConnection {
9594
conn: Arc<tokio_postgres::Client>,
9695
stmt_cache: Arc<Mutex<StmtCache<diesel::pg::Pg, Statement>>>,
97-
transaction_state: AnsiTransactionManager,
96+
transaction_state: Arc<Mutex<AnsiTransactionManager>>,
9897
metadata_cache: Arc<Mutex<Option<PgMetadataCache>>>,
9998
}
10099

@@ -152,11 +151,13 @@ impl AsyncConnection for AsyncPgConnection {
152151
let conn = self.conn.clone();
153152
let stmt_cache = self.stmt_cache.clone();
154153
let metadata_cache = self.metadata_cache.clone();
154+
let tm = self.transaction_state.clone();
155155
let query = source.as_query();
156156
Self::with_prepared_statement(
157157
conn,
158158
stmt_cache,
159159
metadata_cache,
160+
tm,
160161
query,
161162
|conn, stmt, binds| async move {
162163
let res = conn.query_raw(&stmt, binds).await.map_err(ErrorHelper)?;
@@ -181,6 +182,7 @@ impl AsyncConnection for AsyncPgConnection {
181182
self.conn.clone(),
182183
self.stmt_cache.clone(),
183184
self.metadata_cache.clone(),
185+
self.transaction_state.clone(),
184186
source,
185187
|conn, stmt, binds| async move {
186188
let binds = binds
@@ -197,42 +199,33 @@ impl AsyncConnection for AsyncPgConnection {
197199
.boxed()
198200
}
199201

200-
fn transaction_state(
201-
&mut self,
202-
) -> &mut <Self::TransactionManager as TransactionManager<Self>>::TransactionStateData {
203-
&mut self.transaction_state
204-
}
205-
206-
async fn transaction<R, E, F>(&mut self, callback: F) -> Result<R, E>
207-
where
208-
F: FnOnce(&mut Self) -> futures::future::BoxFuture<Result<R, E>> + Send,
209-
E: From<diesel::result::Error> + Send,
210-
R: Send,
211-
{
212-
Self::TransactionManager::begin_transaction(self).await?;
213-
match callback(&mut *self).await {
214-
Ok(value) => {
215-
Self::TransactionManager::commit_transaction(self).await?;
216-
Ok(value)
217-
}
218-
Err(user_error) => {
219-
match Self::TransactionManager::rollback_transaction(self).await {
220-
Ok(()) => Err(user_error),
221-
Err(diesel::result::Error::BrokenTransactionManager) => {
222-
// In this case we are probably more interested by the
223-
// original error, which likely caused this
224-
Err(user_error)
225-
}
226-
Err(rollback_error) => Err(rollback_error.into()),
227-
}
228-
}
202+
fn transaction_state(&mut self) -> &mut AnsiTransactionManager {
203+
// there should be no other pending future when this is called
204+
// that means there is only one instance of this arc and
205+
// we can simply access the inner data
206+
if let Some(tm) = Arc::get_mut(&mut self.transaction_state) {
207+
tm.get_mut()
208+
} else {
209+
panic!("Cannot access shared transaction state")
229210
}
230211
}
212+
}
231213

232-
async fn begin_test_transaction(&mut self) -> QueryResult<()> {
233-
assert_eq!(Self::TransactionManager::get_transaction_depth(self), 0);
234-
Self::TransactionManager::begin_transaction(self).await
214+
#[inline(always)]
215+
fn update_transaction_manager_status<T>(
216+
query_result: QueryResult<T>,
217+
transaction_manager: &mut AnsiTransactionManager,
218+
) -> QueryResult<T> {
219+
if let Err(diesel::result::Error::DatabaseError(
220+
diesel::result::DatabaseErrorKind::SerializationFailure,
221+
_,
222+
)) = query_result
223+
{
224+
transaction_manager
225+
.status
226+
.set_top_level_transaction_requires_rollback()
235227
}
228+
query_result
236229
}
237230

238231
#[async_trait::async_trait]
@@ -308,7 +301,7 @@ impl AsyncPgConnection {
308301
let mut conn = Self {
309302
conn: Arc::new(conn),
310303
stmt_cache: Arc::new(Mutex::new(StmtCache::new())),
311-
transaction_state: AnsiTransactionManager::default(),
304+
transaction_state: Arc::new(Mutex::new(AnsiTransactionManager::default())),
312305
metadata_cache: Arc::new(Mutex::new(Some(PgMetadataCache::new()))),
313306
};
314307
conn.set_config_options()
@@ -333,6 +326,7 @@ impl AsyncPgConnection {
333326
raw_connection: Arc<tokio_postgres::Client>,
334327
stmt_cache: Arc<Mutex<StmtCache<diesel::pg::Pg, Statement>>>,
335328
metadata_cache: Arc<Mutex<Option<PgMetadataCache>>>,
329+
tm: Arc<Mutex<AnsiTransactionManager>>,
336330
query: T,
337331
callback: impl FnOnce(Arc<tokio_postgres::Client>, Statement, Vec<ToSqlHelper>) -> F,
338332
) -> QueryResult<R>
@@ -396,15 +390,15 @@ impl AsyncPgConnection {
396390
}
397391
} else {
398392
// bubble up any error as soon as we have done all lookups
399-
let _ = res?;
393+
res?;
400394
break;
401395
}
402396
}
403397
}
404398

405399
let stmt = {
406400
let mut stmt_cache = stmt_cache.lock().await;
407-
let stmt = stmt_cache
401+
stmt_cache
408402
.cached_prepared_statement(
409403
query,
410404
&bind_collector.metadata,
@@ -413,8 +407,7 @@ impl AsyncPgConnection {
413407
)
414408
.await?
415409
.0
416-
.clone();
417-
stmt
410+
.clone()
418411
};
419412

420413
let binds = bind_collector
@@ -423,7 +416,9 @@ impl AsyncPgConnection {
423416
.zip(bind_collector.binds)
424417
.map(|(meta, bind)| ToSqlHelper(meta, bind))
425418
.collect::<Vec<_>>();
426-
callback(raw_connection, stmt.clone(), binds).await
419+
let res = callback(raw_connection, stmt.clone(), binds).await;
420+
let mut tm = tm.lock().await;
421+
update_transaction_manager_status(res, &mut *tm)
427422
}
428423
}
429424

src/pooled_connection/mod.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,10 @@ where
173173
TM::commit_transaction(&mut **conn).await
174174
}
175175

176-
fn get_transaction_depth(conn: &mut C) -> u32 {
177-
TM::get_transaction_depth(&mut **conn)
176+
fn transaction_manager_status_mut(
177+
conn: &mut C,
178+
) -> &mut crate::transaction_manager::TransactionManagerStatus {
179+
TM::transaction_manager_status_mut(&mut **conn)
178180
}
179181
}
180182

0 commit comments

Comments
 (0)