Skip to content

Commit 48a41a1

Browse files
committed
define SpawnBlocking trait to customize runtime used for spawning blocking tasks.
Previously, `SyncConnectionWrapper` was using tokio as spawning and running blocking tasks. This had prevented using Sqlite backend on wasm32-unknown-unknown target since futures generally run on top of JavaScript promises with the help of wasm_bindgen_futures crate. It is now possible for users to provide their own runtime to spawn blocking tasks inside the `SyncConnectionWrapper`.
1 parent c8a752f commit 48a41a1

File tree

1 file changed

+109
-18
lines changed
  • src/sync_connection_wrapper

1 file changed

+109
-18
lines changed

src/sync_connection_wrapper/mod.rs

Lines changed: 109 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,38 @@
66
//!
77
//! * using a sync Connection implementation in async context
88
//! * using the same code base for async crates needing multiple backends
9+
use std::error::Error;
10+
use futures_util::future::BoxFuture;
911

1012
#[cfg(feature = "sqlite")]
1113
mod sqlite;
1214

15+
/// This is a helper trait that allows to customize the
16+
/// spawning blocking tasks as part of the
17+
/// [`SyncConnectionWrapper`] type. By default a
18+
/// tokio runtime and its spawn_blocking function is used.
19+
pub trait SpawnBlocking {
20+
/// This function should allow to execute a
21+
/// given blocking task without blocking the caller
22+
/// to get the result
23+
fn spawn_blocking<'a, R>(
24+
&mut self,
25+
task: impl FnOnce() -> R + Send + 'static,
26+
) -> BoxFuture<'a, Result<R, Box<dyn Error + Send + Sync + 'static>>>
27+
where
28+
R: Send + 'static;
29+
30+
/// This function should be used to construct
31+
/// a new runtime instance
32+
fn get_runtime() -> Self;
33+
}
34+
35+
#[cfg(feature = "tokio")]
36+
pub type SyncConnectionWrapper<C, B = self::implementation::Tokio> = self::implementation::SyncConnectionWrapper<C, B>;
37+
38+
#[cfg(not(feature = "tokio"))]
1339
pub use self::implementation::SyncConnectionWrapper;
40+
1441
pub use self::implementation::SyncTransactionManagerWrapper;
1542

1643
mod implementation {
@@ -25,17 +52,17 @@ mod implementation {
2552
};
2653
use diesel::row::IntoOwnedRow;
2754
use diesel::{ConnectionResult, QueryResult};
28-
use futures_util::future::BoxFuture;
2955
use futures_util::stream::BoxStream;
3056
use futures_util::{FutureExt, StreamExt, TryFutureExt};
3157
use std::marker::PhantomData;
3258
use std::sync::{Arc, Mutex};
33-
use tokio::task::JoinError;
3459

35-
fn from_tokio_join_error(join_error: JoinError) -> diesel::result::Error {
60+
use super::*;
61+
62+
fn from_spawn_blocking_error(error: Box<dyn Error + Send + Sync + 'static>) -> diesel::result::Error {
3663
diesel::result::Error::DatabaseError(
3764
diesel::result::DatabaseErrorKind::UnableToSendCommand,
38-
Box::new(join_error.to_string()),
65+
Box::new(error.to_string()),
3966
)
4067
}
4168

@@ -77,13 +104,15 @@ mod implementation {
77104
/// # some_async_fn().await;
78105
/// # }
79106
/// ```
80-
pub struct SyncConnectionWrapper<C> {
107+
pub struct SyncConnectionWrapper<C, S> {
81108
inner: Arc<Mutex<C>>,
109+
runtime: S,
82110
}
83111

84-
impl<C> SimpleAsyncConnection for SyncConnectionWrapper<C>
112+
impl<C, S> SimpleAsyncConnection for SyncConnectionWrapper<C, S>
85113
where
86114
C: diesel::connection::Connection + 'static,
115+
S: SpawnBlocking + Send,
87116
{
88117
async fn batch_execute(&mut self, query: &str) -> QueryResult<()> {
89118
let query = query.to_string();
@@ -92,7 +121,7 @@ mod implementation {
92121
}
93122
}
94123

95-
impl<C, MD, O> AsyncConnection for SyncConnectionWrapper<C>
124+
impl<C, S, MD, O> AsyncConnection for SyncConnectionWrapper<C, S>
96125
where
97126
// Backend bounds
98127
<C as Connection>::Backend: std::default::Default + DieselReserveSpecialization,
@@ -108,6 +137,8 @@ mod implementation {
108137
O: 'static + Send + for<'conn> diesel::row::Row<'conn, C::Backend>,
109138
for<'conn, 'query> <C as LoadConnection>::Row<'conn, 'query>:
110139
IntoOwnedRow<'conn, <C as Connection>::Backend, OwnedRow = O>,
140+
// SpawnBlocking bounds
141+
S: SpawnBlocking + Send,
111142
{
112143
type LoadFuture<'conn, 'query> = BoxFuture<'query, QueryResult<Self::Stream<'conn, 'query>>>;
113144
type ExecuteFuture<'conn, 'query> = BoxFuture<'query, QueryResult<usize>>;
@@ -118,10 +149,12 @@ mod implementation {
118149

119150
async fn establish(database_url: &str) -> ConnectionResult<Self> {
120151
let database_url = database_url.to_string();
121-
tokio::task::spawn_blocking(move || C::establish(&database_url))
152+
let mut runtime = S::get_runtime();
153+
154+
runtime.spawn_blocking(move || C::establish(&database_url))
122155
.await
123156
.unwrap_or_else(|e| Err(diesel::ConnectionError::BadConnection(e.to_string())))
124-
.map(|c| SyncConnectionWrapper::new(c))
157+
.map(move |c| SyncConnectionWrapper::with_runtime(c, runtime))
125158
}
126159

127160
fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query>
@@ -209,44 +242,60 @@ mod implementation {
209242
/// A wrapper of a diesel transaction manager usable in async context.
210243
pub struct SyncTransactionManagerWrapper<T>(PhantomData<T>);
211244

212-
impl<T, C> TransactionManager<SyncConnectionWrapper<C>> for SyncTransactionManagerWrapper<T>
245+
impl<T, C, S> TransactionManager<SyncConnectionWrapper<C, S>> for SyncTransactionManagerWrapper<T>
213246
where
214-
SyncConnectionWrapper<C>: AsyncConnection,
247+
SyncConnectionWrapper<C, S>: AsyncConnection,
215248
C: Connection + 'static,
249+
S: SpawnBlocking,
216250
T: diesel::connection::TransactionManager<C> + Send,
217251
{
218252
type TransactionStateData = T::TransactionStateData;
219253

220-
async fn begin_transaction(conn: &mut SyncConnectionWrapper<C>) -> QueryResult<()> {
254+
async fn begin_transaction(conn: &mut SyncConnectionWrapper<C, S>) -> QueryResult<()> {
221255
conn.spawn_blocking(move |inner| T::begin_transaction(inner))
222256
.await
223257
}
224258

225-
async fn commit_transaction(conn: &mut SyncConnectionWrapper<C>) -> QueryResult<()> {
259+
async fn commit_transaction(conn: &mut SyncConnectionWrapper<C, S>) -> QueryResult<()> {
226260
conn.spawn_blocking(move |inner| T::commit_transaction(inner))
227261
.await
228262
}
229263

230-
async fn rollback_transaction(conn: &mut SyncConnectionWrapper<C>) -> QueryResult<()> {
264+
async fn rollback_transaction(conn: &mut SyncConnectionWrapper<C, S>) -> QueryResult<()> {
231265
conn.spawn_blocking(move |inner| T::rollback_transaction(inner))
232266
.await
233267
}
234268

235269
fn transaction_manager_status_mut(
236-
conn: &mut SyncConnectionWrapper<C>,
270+
conn: &mut SyncConnectionWrapper<C, S>,
237271
) -> &mut TransactionManagerStatus {
238272
T::transaction_manager_status_mut(conn.exclusive_connection())
239273
}
240274
}
241275

242-
impl<C> SyncConnectionWrapper<C> {
276+
impl<C, S> SyncConnectionWrapper<C, S> {
243277
/// Builds a wrapper with this underlying sync connection
244278
pub fn new(connection: C) -> Self
245279
where
246280
C: Connection,
281+
S: SpawnBlocking,
282+
{
283+
SyncConnectionWrapper {
284+
inner: Arc::new(Mutex::new(connection)),
285+
runtime: S::get_runtime(),
286+
}
287+
}
288+
289+
/// Builds a wrapper with this underlying sync connection
290+
/// and runtime for spawning blocking tasks
291+
pub fn with_runtime(connection: C, runtime: S) -> Self
292+
where
293+
C: Connection,
294+
S: SpawnBlocking,
247295
{
248296
SyncConnectionWrapper {
249297
inner: Arc::new(Mutex::new(connection)),
298+
runtime,
250299
}
251300
}
252301

@@ -283,17 +332,18 @@ mod implementation {
283332
where
284333
C: Connection + 'static,
285334
R: Send + 'static,
335+
S: SpawnBlocking,
286336
{
287337
let inner = self.inner.clone();
288-
tokio::task::spawn_blocking(move || {
338+
self.runtime.spawn_blocking(move || {
289339
let mut inner = inner.lock().unwrap_or_else(|poison| {
290340
// try to be resilient by providing the guard
291341
inner.clear_poison();
292342
poison.into_inner()
293343
});
294344
task(&mut inner)
295345
})
296-
.unwrap_or_else(|err| QueryResult::Err(from_tokio_join_error(err)))
346+
.unwrap_or_else(|err| QueryResult::Err(from_spawn_blocking_error(err)))
297347
.boxed()
298348
}
299349

@@ -316,6 +366,8 @@ mod implementation {
316366
// Arguments/Return bounds
317367
Q: QueryFragment<C::Backend> + QueryId,
318368
R: Send + 'static,
369+
// SpawnBlocking bounds
370+
S: SpawnBlocking,
319371
{
320372
let backend = C::Backend::default();
321373

@@ -383,4 +435,43 @@ mod implementation {
383435
Self::TransactionManager::is_broken_transaction_manager(self)
384436
}
385437
}
438+
439+
#[cfg(feature = "tokio")]
440+
pub enum Tokio {
441+
Handle(tokio::runtime::Handle),
442+
Runtime(tokio::runtime::Runtime)
443+
}
444+
445+
#[cfg(feature = "tokio")]
446+
impl SpawnBlocking for Tokio {
447+
fn spawn_blocking<'a, R>(
448+
&mut self,
449+
task: impl FnOnce() -> R + Send + 'static,
450+
) -> BoxFuture<'a, Result<R, Box<dyn Error + Send + Sync + 'static>>>
451+
where
452+
R: Send + 'static,
453+
{
454+
let fut = match self {
455+
Tokio::Handle(handle) => handle.spawn_blocking(task),
456+
Tokio::Runtime(runtime) => runtime.spawn_blocking(task)
457+
};
458+
459+
fut
460+
.map_err(|err| Box::from(err))
461+
.boxed()
462+
}
463+
464+
fn get_runtime() -> Self {
465+
if let Ok(handle) = tokio::runtime::Handle::try_current() {
466+
Tokio::Handle(handle)
467+
} else {
468+
let runtime = tokio::runtime::Builder::new_current_thread()
469+
.enable_io()
470+
.build()
471+
.unwrap();
472+
473+
Tokio::Runtime(runtime)
474+
}
475+
}
476+
}
386477
}

0 commit comments

Comments
 (0)