6
6
//!
7
7
//! * using a sync Connection implementation in async context
8
8
//! * using the same code base for async crates needing multiple backends
9
+ use std:: error:: Error ;
10
+ use futures_util:: future:: BoxFuture ;
9
11
10
12
#[ cfg( feature = "sqlite" ) ]
11
13
mod sqlite;
12
14
15
+ /// This is a helper trait that allows to customize the
16
+ /// spawning blocking tasks as part of the
17
+ /// [`SyncConnectionWrapper`] type. By default a
18
+ /// tokio runtime and its spawn_blocking function is used.
19
+ pub trait SpawnBlocking {
20
+ /// This function should allow to execute a
21
+ /// given blocking task without blocking the caller
22
+ /// to get the result
23
+ fn spawn_blocking < ' a , R > (
24
+ & mut self ,
25
+ task : impl FnOnce ( ) -> R + Send + ' static ,
26
+ ) -> BoxFuture < ' a , Result < R , Box < dyn Error + Send + Sync + ' static > > >
27
+ where
28
+ R : Send + ' static ;
29
+
30
+ /// This function should be used to construct
31
+ /// a new runtime instance
32
+ fn get_runtime ( ) -> Self ;
33
+ }
34
+
35
+ #[ cfg( feature = "tokio" ) ]
36
+ pub type SyncConnectionWrapper < C , B = self :: implementation:: Tokio > = self :: implementation:: SyncConnectionWrapper < C , B > ;
37
+
38
+ #[ cfg( not( feature = "tokio" ) ) ]
13
39
pub use self :: implementation:: SyncConnectionWrapper ;
40
+
14
41
pub use self :: implementation:: SyncTransactionManagerWrapper ;
15
42
16
43
mod implementation {
@@ -25,17 +52,17 @@ mod implementation {
25
52
} ;
26
53
use diesel:: row:: IntoOwnedRow ;
27
54
use diesel:: { ConnectionResult , QueryResult } ;
28
- use futures_util:: future:: BoxFuture ;
29
55
use futures_util:: stream:: BoxStream ;
30
56
use futures_util:: { FutureExt , StreamExt , TryFutureExt } ;
31
57
use std:: marker:: PhantomData ;
32
58
use std:: sync:: { Arc , Mutex } ;
33
- use tokio:: task:: JoinError ;
34
59
35
- fn from_tokio_join_error ( join_error : JoinError ) -> diesel:: result:: Error {
60
+ use super :: * ;
61
+
62
+ fn from_spawn_blocking_error ( error : Box < dyn Error + Send + Sync + ' static > ) -> diesel:: result:: Error {
36
63
diesel:: result:: Error :: DatabaseError (
37
64
diesel:: result:: DatabaseErrorKind :: UnableToSendCommand ,
38
- Box :: new ( join_error . to_string ( ) ) ,
65
+ Box :: new ( error . to_string ( ) ) ,
39
66
)
40
67
}
41
68
@@ -77,13 +104,15 @@ mod implementation {
77
104
/// # some_async_fn().await;
78
105
/// # }
79
106
/// ```
80
- pub struct SyncConnectionWrapper < C > {
107
+ pub struct SyncConnectionWrapper < C , S > {
81
108
inner : Arc < Mutex < C > > ,
109
+ runtime : S ,
82
110
}
83
111
84
- impl < C > SimpleAsyncConnection for SyncConnectionWrapper < C >
112
+ impl < C , S > SimpleAsyncConnection for SyncConnectionWrapper < C , S >
85
113
where
86
114
C : diesel:: connection:: Connection + ' static ,
115
+ S : SpawnBlocking + Send ,
87
116
{
88
117
async fn batch_execute ( & mut self , query : & str ) -> QueryResult < ( ) > {
89
118
let query = query. to_string ( ) ;
@@ -92,7 +121,7 @@ mod implementation {
92
121
}
93
122
}
94
123
95
- impl < C , MD , O > AsyncConnection for SyncConnectionWrapper < C >
124
+ impl < C , S , MD , O > AsyncConnection for SyncConnectionWrapper < C , S >
96
125
where
97
126
// Backend bounds
98
127
<C as Connection >:: Backend : std:: default:: Default + DieselReserveSpecialization ,
@@ -108,6 +137,8 @@ mod implementation {
108
137
O : ' static + Send + for < ' conn > diesel:: row:: Row < ' conn , C :: Backend > ,
109
138
for < ' conn , ' query > <C as LoadConnection >:: Row < ' conn , ' query > :
110
139
IntoOwnedRow < ' conn , <C as Connection >:: Backend , OwnedRow = O > ,
140
+ // SpawnBlocking bounds
141
+ S : SpawnBlocking + Send ,
111
142
{
112
143
type LoadFuture < ' conn , ' query > = BoxFuture < ' query , QueryResult < Self :: Stream < ' conn , ' query > > > ;
113
144
type ExecuteFuture < ' conn , ' query > = BoxFuture < ' query , QueryResult < usize > > ;
@@ -118,10 +149,12 @@ mod implementation {
118
149
119
150
async fn establish ( database_url : & str ) -> ConnectionResult < Self > {
120
151
let database_url = database_url. to_string ( ) ;
121
- tokio:: task:: spawn_blocking ( move || C :: establish ( & database_url) )
152
+ let mut runtime = S :: get_runtime ( ) ;
153
+
154
+ runtime. spawn_blocking ( move || C :: establish ( & database_url) )
122
155
. await
123
156
. unwrap_or_else ( |e| Err ( diesel:: ConnectionError :: BadConnection ( e. to_string ( ) ) ) )
124
- . map ( |c| SyncConnectionWrapper :: new ( c ) )
157
+ . map ( move |c| SyncConnectionWrapper :: with_runtime ( c , runtime ) )
125
158
}
126
159
127
160
fn load < ' conn , ' query , T > ( & ' conn mut self , source : T ) -> Self :: LoadFuture < ' conn , ' query >
@@ -209,44 +242,60 @@ mod implementation {
209
242
/// A wrapper of a diesel transaction manager usable in async context.
210
243
pub struct SyncTransactionManagerWrapper < T > ( PhantomData < T > ) ;
211
244
212
- impl < T , C > TransactionManager < SyncConnectionWrapper < C > > for SyncTransactionManagerWrapper < T >
245
+ impl < T , C , S > TransactionManager < SyncConnectionWrapper < C , S > > for SyncTransactionManagerWrapper < T >
213
246
where
214
- SyncConnectionWrapper < C > : AsyncConnection ,
247
+ SyncConnectionWrapper < C , S > : AsyncConnection ,
215
248
C : Connection + ' static ,
249
+ S : SpawnBlocking ,
216
250
T : diesel:: connection:: TransactionManager < C > + Send ,
217
251
{
218
252
type TransactionStateData = T :: TransactionStateData ;
219
253
220
- async fn begin_transaction ( conn : & mut SyncConnectionWrapper < C > ) -> QueryResult < ( ) > {
254
+ async fn begin_transaction ( conn : & mut SyncConnectionWrapper < C , S > ) -> QueryResult < ( ) > {
221
255
conn. spawn_blocking ( move |inner| T :: begin_transaction ( inner) )
222
256
. await
223
257
}
224
258
225
- async fn commit_transaction ( conn : & mut SyncConnectionWrapper < C > ) -> QueryResult < ( ) > {
259
+ async fn commit_transaction ( conn : & mut SyncConnectionWrapper < C , S > ) -> QueryResult < ( ) > {
226
260
conn. spawn_blocking ( move |inner| T :: commit_transaction ( inner) )
227
261
. await
228
262
}
229
263
230
- async fn rollback_transaction ( conn : & mut SyncConnectionWrapper < C > ) -> QueryResult < ( ) > {
264
+ async fn rollback_transaction ( conn : & mut SyncConnectionWrapper < C , S > ) -> QueryResult < ( ) > {
231
265
conn. spawn_blocking ( move |inner| T :: rollback_transaction ( inner) )
232
266
. await
233
267
}
234
268
235
269
fn transaction_manager_status_mut (
236
- conn : & mut SyncConnectionWrapper < C > ,
270
+ conn : & mut SyncConnectionWrapper < C , S > ,
237
271
) -> & mut TransactionManagerStatus {
238
272
T :: transaction_manager_status_mut ( conn. exclusive_connection ( ) )
239
273
}
240
274
}
241
275
242
- impl < C > SyncConnectionWrapper < C > {
276
+ impl < C , S > SyncConnectionWrapper < C , S > {
243
277
/// Builds a wrapper with this underlying sync connection
244
278
pub fn new ( connection : C ) -> Self
245
279
where
246
280
C : Connection ,
281
+ S : SpawnBlocking ,
282
+ {
283
+ SyncConnectionWrapper {
284
+ inner : Arc :: new ( Mutex :: new ( connection) ) ,
285
+ runtime : S :: get_runtime ( ) ,
286
+ }
287
+ }
288
+
289
+ /// Builds a wrapper with this underlying sync connection
290
+ /// and runtime for spawning blocking tasks
291
+ pub fn with_runtime ( connection : C , runtime : S ) -> Self
292
+ where
293
+ C : Connection ,
294
+ S : SpawnBlocking ,
247
295
{
248
296
SyncConnectionWrapper {
249
297
inner : Arc :: new ( Mutex :: new ( connection) ) ,
298
+ runtime,
250
299
}
251
300
}
252
301
@@ -283,17 +332,18 @@ mod implementation {
283
332
where
284
333
C : Connection + ' static ,
285
334
R : Send + ' static ,
335
+ S : SpawnBlocking ,
286
336
{
287
337
let inner = self . inner . clone ( ) ;
288
- tokio :: task :: spawn_blocking ( move || {
338
+ self . runtime . spawn_blocking ( move || {
289
339
let mut inner = inner. lock ( ) . unwrap_or_else ( |poison| {
290
340
// try to be resilient by providing the guard
291
341
inner. clear_poison ( ) ;
292
342
poison. into_inner ( )
293
343
} ) ;
294
344
task ( & mut inner)
295
345
} )
296
- . unwrap_or_else ( |err| QueryResult :: Err ( from_tokio_join_error ( err) ) )
346
+ . unwrap_or_else ( |err| QueryResult :: Err ( from_spawn_blocking_error ( err) ) )
297
347
. boxed ( )
298
348
}
299
349
@@ -316,6 +366,8 @@ mod implementation {
316
366
// Arguments/Return bounds
317
367
Q : QueryFragment < C :: Backend > + QueryId ,
318
368
R : Send + ' static ,
369
+ // SpawnBlocking bounds
370
+ S : SpawnBlocking ,
319
371
{
320
372
let backend = C :: Backend :: default ( ) ;
321
373
@@ -383,4 +435,43 @@ mod implementation {
383
435
Self :: TransactionManager :: is_broken_transaction_manager ( self )
384
436
}
385
437
}
438
+
439
+ #[ cfg( feature = "tokio" ) ]
440
+ pub enum Tokio {
441
+ Handle ( tokio:: runtime:: Handle ) ,
442
+ Runtime ( tokio:: runtime:: Runtime )
443
+ }
444
+
445
+ #[ cfg( feature = "tokio" ) ]
446
+ impl SpawnBlocking for Tokio {
447
+ fn spawn_blocking < ' a , R > (
448
+ & mut self ,
449
+ task : impl FnOnce ( ) -> R + Send + ' static ,
450
+ ) -> BoxFuture < ' a , Result < R , Box < dyn Error + Send + Sync + ' static > > >
451
+ where
452
+ R : Send + ' static ,
453
+ {
454
+ let fut = match self {
455
+ Tokio :: Handle ( handle) => handle. spawn_blocking ( task) ,
456
+ Tokio :: Runtime ( runtime) => runtime. spawn_blocking ( task)
457
+ } ;
458
+
459
+ fut
460
+ . map_err ( |err| Box :: from ( err) )
461
+ . boxed ( )
462
+ }
463
+
464
+ fn get_runtime ( ) -> Self {
465
+ if let Ok ( handle) = tokio:: runtime:: Handle :: try_current ( ) {
466
+ Tokio :: Handle ( handle)
467
+ } else {
468
+ let runtime = tokio:: runtime:: Builder :: new_current_thread ( )
469
+ . enable_io ( )
470
+ . build ( )
471
+ . unwrap ( ) ;
472
+
473
+ Tokio :: Runtime ( runtime)
474
+ }
475
+ }
476
+ }
386
477
}
0 commit comments