@@ -4,7 +4,6 @@ use self::serialize::ToSqlHelper;
4
4
use crate :: stmt_cache:: { PrepareCallback , StmtCache } ;
5
5
use crate :: {
6
6
AnsiTransactionManager , AsyncConnection , AsyncConnectionGatWorkaround , SimpleAsyncConnection ,
7
- TransactionManager ,
8
7
} ;
9
8
use diesel:: connection:: statement_cache:: PrepareForCache ;
10
9
use diesel:: pg:: {
@@ -94,7 +93,7 @@ mod transaction_builder;
94
93
pub struct AsyncPgConnection {
95
94
conn : Arc < tokio_postgres:: Client > ,
96
95
stmt_cache : Arc < Mutex < StmtCache < diesel:: pg:: Pg , Statement > > > ,
97
- transaction_state : AnsiTransactionManager ,
96
+ transaction_state : Arc < Mutex < AnsiTransactionManager > > ,
98
97
metadata_cache : Arc < Mutex < Option < PgMetadataCache > > > ,
99
98
}
100
99
@@ -152,11 +151,13 @@ impl AsyncConnection for AsyncPgConnection {
152
151
let conn = self . conn . clone ( ) ;
153
152
let stmt_cache = self . stmt_cache . clone ( ) ;
154
153
let metadata_cache = self . metadata_cache . clone ( ) ;
154
+ let tm = self . transaction_state . clone ( ) ;
155
155
let query = source. as_query ( ) ;
156
156
Self :: with_prepared_statement (
157
157
conn,
158
158
stmt_cache,
159
159
metadata_cache,
160
+ tm,
160
161
query,
161
162
|conn, stmt, binds| async move {
162
163
let res = conn. query_raw ( & stmt, binds) . await . map_err ( ErrorHelper ) ?;
@@ -181,6 +182,7 @@ impl AsyncConnection for AsyncPgConnection {
181
182
self . conn . clone ( ) ,
182
183
self . stmt_cache . clone ( ) ,
183
184
self . metadata_cache . clone ( ) ,
185
+ self . transaction_state . clone ( ) ,
184
186
source,
185
187
|conn, stmt, binds| async move {
186
188
let binds = binds
@@ -197,42 +199,33 @@ impl AsyncConnection for AsyncPgConnection {
197
199
. boxed ( )
198
200
}
199
201
200
- fn transaction_state (
201
- & mut self ,
202
- ) -> & mut <Self :: TransactionManager as TransactionManager < Self > >:: TransactionStateData {
203
- & mut self . transaction_state
204
- }
205
-
206
- async fn transaction < R , E , F > ( & mut self , callback : F ) -> Result < R , E >
207
- where
208
- F : FnOnce ( & mut Self ) -> futures:: future:: BoxFuture < Result < R , E > > + Send ,
209
- E : From < diesel:: result:: Error > + Send ,
210
- R : Send ,
211
- {
212
- Self :: TransactionManager :: begin_transaction ( self ) . await ?;
213
- match callback ( & mut * self ) . await {
214
- Ok ( value) => {
215
- Self :: TransactionManager :: commit_transaction ( self ) . await ?;
216
- Ok ( value)
217
- }
218
- Err ( user_error) => {
219
- match Self :: TransactionManager :: rollback_transaction ( self ) . await {
220
- Ok ( ( ) ) => Err ( user_error) ,
221
- Err ( diesel:: result:: Error :: BrokenTransactionManager ) => {
222
- // In this case we are probably more interested by the
223
- // original error, which likely caused this
224
- Err ( user_error)
225
- }
226
- Err ( rollback_error) => Err ( rollback_error. into ( ) ) ,
227
- }
228
- }
202
+ fn transaction_state ( & mut self ) -> & mut AnsiTransactionManager {
203
+ // there should be no other pending future when this is called
204
+ // that means there is only one instance of this arc and
205
+ // we can simply access the inner data
206
+ if let Some ( tm) = Arc :: get_mut ( & mut self . transaction_state ) {
207
+ tm. get_mut ( )
208
+ } else {
209
+ panic ! ( "Cannot access shared transaction state" )
229
210
}
230
211
}
212
+ }
231
213
232
- async fn begin_test_transaction ( & mut self ) -> QueryResult < ( ) > {
233
- assert_eq ! ( Self :: TransactionManager :: get_transaction_depth( self ) , 0 ) ;
234
- Self :: TransactionManager :: begin_transaction ( self ) . await
214
+ #[ inline( always) ]
215
+ fn update_transaction_manager_status < T > (
216
+ query_result : QueryResult < T > ,
217
+ transaction_manager : & mut AnsiTransactionManager ,
218
+ ) -> QueryResult < T > {
219
+ if let Err ( diesel:: result:: Error :: DatabaseError (
220
+ diesel:: result:: DatabaseErrorKind :: SerializationFailure ,
221
+ _,
222
+ ) ) = query_result
223
+ {
224
+ transaction_manager
225
+ . status
226
+ . set_top_level_transaction_requires_rollback ( )
235
227
}
228
+ query_result
236
229
}
237
230
238
231
#[ async_trait:: async_trait]
@@ -308,7 +301,7 @@ impl AsyncPgConnection {
308
301
let mut conn = Self {
309
302
conn : Arc :: new ( conn) ,
310
303
stmt_cache : Arc :: new ( Mutex :: new ( StmtCache :: new ( ) ) ) ,
311
- transaction_state : AnsiTransactionManager :: default ( ) ,
304
+ transaction_state : Arc :: new ( Mutex :: new ( AnsiTransactionManager :: default ( ) ) ) ,
312
305
metadata_cache : Arc :: new ( Mutex :: new ( Some ( PgMetadataCache :: new ( ) ) ) ) ,
313
306
} ;
314
307
conn. set_config_options ( )
@@ -333,6 +326,7 @@ impl AsyncPgConnection {
333
326
raw_connection : Arc < tokio_postgres:: Client > ,
334
327
stmt_cache : Arc < Mutex < StmtCache < diesel:: pg:: Pg , Statement > > > ,
335
328
metadata_cache : Arc < Mutex < Option < PgMetadataCache > > > ,
329
+ tm : Arc < Mutex < AnsiTransactionManager > > ,
336
330
query : T ,
337
331
callback : impl FnOnce ( Arc < tokio_postgres:: Client > , Statement , Vec < ToSqlHelper > ) -> F ,
338
332
) -> QueryResult < R >
@@ -396,15 +390,15 @@ impl AsyncPgConnection {
396
390
}
397
391
} else {
398
392
// bubble up any error as soon as we have done all lookups
399
- let _ = res?;
393
+ res?;
400
394
break ;
401
395
}
402
396
}
403
397
}
404
398
405
399
let stmt = {
406
400
let mut stmt_cache = stmt_cache. lock ( ) . await ;
407
- let stmt = stmt_cache
401
+ stmt_cache
408
402
. cached_prepared_statement (
409
403
query,
410
404
& bind_collector. metadata ,
@@ -413,8 +407,7 @@ impl AsyncPgConnection {
413
407
)
414
408
. await ?
415
409
. 0
416
- . clone ( ) ;
417
- stmt
410
+ . clone ( )
418
411
} ;
419
412
420
413
let binds = bind_collector
@@ -423,7 +416,9 @@ impl AsyncPgConnection {
423
416
. zip ( bind_collector. binds )
424
417
. map ( |( meta, bind) | ToSqlHelper ( meta, bind) )
425
418
. collect :: < Vec < _ > > ( ) ;
426
- callback ( raw_connection, stmt. clone ( ) , binds) . await
419
+ let res = callback ( raw_connection, stmt. clone ( ) , binds) . await ;
420
+ let mut tm = tm. lock ( ) . await ;
421
+ update_transaction_manager_status ( res, & mut * tm)
427
422
}
428
423
}
429
424
0 commit comments