Skip to content

Commit a807e68

Browse files
committed
Implement support for pipelining for the Postgres connection
1 parent 800da86 commit a807e68

File tree

6 files changed

+566
-320
lines changed

6 files changed

+566
-320
lines changed

src/lib.rs

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use diesel::query_builder::{AsQuery, QueryFragment, QueryId};
33
use diesel::row::Row;
44
use diesel::{ConnectionResult, QueryResult};
55
use futures::future::BoxFuture;
6-
use futures::Stream;
6+
use futures::{Future, Stream};
77
#[cfg(feature = "mysql")]
88
mod mysql;
99
#[cfg(feature = "postgres")]
@@ -25,32 +25,37 @@ pub trait SimpleAsyncConnection {
2525
async fn batch_execute(&mut self, query: &str) -> QueryResult<()>;
2626
}
2727

28-
pub trait AsyncConnectionGatWorkaround<'a, DB: Backend> {
29-
type Stream: Stream<Item = QueryResult<Self::Row>> + Send + 'a;
30-
type Row: Row<'a, DB> + 'a;
28+
pub trait AsyncConnectionGatWorkaround<'conn, 'query, DB: Backend> {
29+
type ExecuteFuture: Future<Output = QueryResult<usize>> + Send;
30+
type LoadFuture: Future<Output = QueryResult<Self::Stream>> + Send;
31+
type Stream: Stream<Item = QueryResult<Self::Row>> + Send;
32+
type Row: Row<'conn, DB>;
3133
}
3234

3335
#[async_trait::async_trait]
3436
pub trait AsyncConnection: SimpleAsyncConnection + Sized + Send
3537
where
36-
for<'a> Self: AsyncConnectionGatWorkaround<'a, Self::Backend>,
38+
for<'a, 'b> Self: AsyncConnectionGatWorkaround<'a, 'b, Self::Backend>,
3739
{
3840
type Backend: Backend;
3941
type TransactionManager: TransactionManager<Self>;
4042

4143
async fn establish(database_url: &str) -> ConnectionResult<Self>;
4244

43-
async fn load<'a, T>(
44-
&'a mut self,
45+
fn load<'conn, 'query, T>(
46+
&'conn mut self,
4547
source: T,
46-
) -> QueryResult<<Self as AsyncConnectionGatWorkaround<'a, Self::Backend>>::Stream>
48+
) -> <Self as AsyncConnectionGatWorkaround<'conn, 'query, Self::Backend>>::LoadFuture
4749
where
48-
T: AsQuery + Send,
49-
T::Query: QueryFragment<Self::Backend> + QueryId + Send;
50+
T: AsQuery + Send + 'query,
51+
T::Query: QueryFragment<Self::Backend> + QueryId + Send + 'query;
5052

51-
async fn execute_returning_count<T>(&mut self, source: T) -> QueryResult<usize>
53+
fn execute_returning_count<'conn, 'query, T>(
54+
&'conn mut self,
55+
source: T,
56+
) -> <Self as AsyncConnectionGatWorkaround<'conn, 'query, Self::Backend>>::ExecuteFuture
5257
where
53-
T: QueryFragment<Self::Backend> + QueryId + Send;
58+
T: QueryFragment<Self::Backend> + QueryId + Send + 'query;
5459

5560
async fn transaction<F, R, E>(&mut self, callback: F) -> Result<R, E>
5661
where

src/mysql/mod.rs

Lines changed: 56 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
use std::pin::Pin;
2-
31
use crate::stmt_cache::{PrepareCallback, StmtCache};
42
use crate::{
53
AnsiTransactionManager, AsyncConnection, AsyncConnectionGatWorkaround, SimpleAsyncConnection,
@@ -9,7 +7,9 @@ use diesel::mysql::{Mysql, MysqlType};
97
use diesel::query_builder::{bind_collector::RawBytesBindCollector, QueryFragment, QueryId};
108
use diesel::result::{ConnectionError, ConnectionResult};
119
use diesel::QueryResult;
12-
use futures::{Future, Stream, StreamExt, TryStreamExt};
10+
use futures::future::BoxFuture;
11+
use futures::stream::BoxStream;
12+
use futures::{Future, FutureExt, StreamExt, TryFutureExt, TryStreamExt};
1313
use mysql_async::prelude::Queryable;
1414
use mysql_async::{Opts, OptsBuilder, Statement};
1515

@@ -35,8 +35,10 @@ impl SimpleAsyncConnection for AsyncMysqlConnection {
3535
}
3636
}
3737

38-
impl<'a> AsyncConnectionGatWorkaround<'a, Mysql> for AsyncMysqlConnection {
39-
type Stream = Pin<Box<dyn Stream<Item = QueryResult<Self::Row>> + Send + 'a>>;
38+
impl<'conn, 'query> AsyncConnectionGatWorkaround<'conn, 'query, Mysql> for AsyncMysqlConnection {
39+
type ExecuteFuture = BoxFuture<'conn, QueryResult<usize>>;
40+
type LoadFuture = BoxFuture<'conn, QueryResult<Self::Stream>>;
41+
type Stream = BoxStream<'conn, QueryResult<Self::Row>>;
4042

4143
type Row = MysqlRow;
4244
}
@@ -50,13 +52,15 @@ impl AsyncConnection for AsyncMysqlConnection {
5052
async fn establish(database_url: &str) -> diesel::ConnectionResult<Self> {
5153
let opts = Opts::from_url(database_url)
5254
.map_err(|e| diesel::result::ConnectionError::InvalidConnectionUrl(e.to_string()))?;
53-
let builder = OptsBuilder::from_opts(opts).init(vec![
54-
"SET sql_mode=(SELECT CONCAT(@@sql_mode, ',PIPES_AS_CONCAT'))",
55-
"SET time_zone = '+00:00';",
56-
"SET character_set_client = 'utf8mb4'",
57-
"SET character_set_connection = 'utf8mb4'",
58-
"SET character_set_results = 'utf8mb4'",
59-
]);
55+
let builder = OptsBuilder::from_opts(opts)
56+
.init(vec![
57+
"SET sql_mode=(SELECT CONCAT(@@sql_mode, ',PIPES_AS_CONCAT'))",
58+
"SET time_zone = '+00:00';",
59+
"SET character_set_client = 'utf8mb4'",
60+
"SET character_set_connection = 'utf8mb4'",
61+
"SET character_set_results = 'utf8mb4'",
62+
])
63+
.stmt_cache_size(0); // We have our own cache
6064

6165
let conn = mysql_async::Conn::new(builder).await.map_err(ErrorHelper)?;
6266

@@ -68,15 +72,16 @@ impl AsyncConnection for AsyncMysqlConnection {
6872
})
6973
}
7074

71-
async fn load<'a, T>(
72-
&'a mut self,
75+
fn load<'conn, 'query, T>(
76+
&'conn mut self,
7377
source: T,
74-
) -> diesel::QueryResult<<Self as AsyncConnectionGatWorkaround<'a, Self::Backend>>::Stream>
78+
) -> <Self as AsyncConnectionGatWorkaround<'conn, 'query, Self::Backend>>::LoadFuture
7579
where
7680
T: diesel::query_builder::AsQuery + Send,
7781
T::Query: diesel::query_builder::QueryFragment<Self::Backend>
7882
+ diesel::query_builder::QueryId
79-
+ Send,
83+
+ Send
84+
+ 'query,
8085
{
8186
self.with_prepared_statement(source.as_query(), |conn, stmt, binds| async move {
8287
let res = conn.exec_iter(&*stmt, binds).await.map_err(ErrorHelper)?;
@@ -95,20 +100,23 @@ impl AsyncConnection for AsyncMysqlConnection {
95100

96101
Ok(stream)
97102
})
98-
.await
103+
.boxed()
99104
}
100105

101-
async fn execute_returning_count<T>(&mut self, source: T) -> diesel::QueryResult<usize>
106+
fn execute_returning_count<'conn, 'query, T>(
107+
&'conn mut self,
108+
source: T,
109+
) -> <Self as AsyncConnectionGatWorkaround<'conn, 'query, Self::Backend>>::ExecuteFuture
102110
where
103111
T: diesel::query_builder::QueryFragment<Self::Backend>
104112
+ diesel::query_builder::QueryId
105-
+ Send,
113+
+ Send
114+
+ 'query,
106115
{
107116
self.with_prepared_statement(source, |conn, stmt, binds| async move {
108117
conn.exec_drop(&*stmt, binds).await.map_err(ErrorHelper)?;
109118
Ok(conn.affected_rows() as usize)
110119
})
111-
.await
112120
}
113121

114122
fn transaction_state(&mut self) -> &mut AnsiTransactionManager {
@@ -117,14 +125,15 @@ impl AsyncConnection for AsyncMysqlConnection {
117125
}
118126

119127
#[async_trait::async_trait]
120-
impl PrepareCallback<Statement, MysqlType> for mysql_async::Conn {
128+
impl PrepareCallback<Statement, MysqlType> for &'_ mut mysql_async::Conn {
121129
async fn prepare(
122-
&mut self,
130+
self,
123131
sql: &str,
124132
_metadata: &[MysqlType],
125133
_is_for_cache: diesel::connection::statement_cache::PrepareForCache,
126-
) -> QueryResult<Statement> {
127-
Ok(self.prep(sql).await.map_err(ErrorHelper)?)
134+
) -> QueryResult<(Statement, Self)> {
135+
let s = self.prep(sql).await.map_err(ErrorHelper)?;
136+
Ok((s, self))
128137
}
129138
}
130139

@@ -155,17 +164,22 @@ impl AsyncMysqlConnection {
155164
Ok(conn)
156165
}
157166

158-
async fn with_prepared_statement<'a, T, F, R>(
159-
&'a mut self,
167+
fn with_prepared_statement<'conn, T, F, R>(
168+
&'conn mut self,
160169
query: T,
161-
callback: impl FnOnce(&'a mut mysql_async::Conn, &'a Statement, ToSqlHelper) -> F,
162-
) -> QueryResult<R>
170+
callback: impl (FnOnce(&'conn mut mysql_async::Conn, &'conn Statement, ToSqlHelper) -> F)
171+
+ Send
172+
+ 'conn,
173+
) -> BoxFuture<'conn, QueryResult<R>>
163174
where
175+
R: Send + 'conn,
164176
T: QueryFragment<Mysql> + QueryId + Send,
165-
F: Future<Output = QueryResult<R>>,
177+
F: Future<Output = QueryResult<R>> + Send,
166178
{
167179
let mut bind_collector = RawBytesBindCollector::<Mysql>::new();
168-
query.collect_binds(&mut bind_collector, &mut (), &Mysql)?;
180+
if let Err(e) = query.collect_binds(&mut bind_collector, &mut (), &Mysql) {
181+
return futures::future::ready(Err(e)).boxed();
182+
}
169183

170184
let binds = bind_collector.binds;
171185
let metadata = bind_collector.metadata;
@@ -177,24 +191,19 @@ impl AsyncMysqlConnection {
177191
..
178192
} = self;
179193

180-
let conn = &mut *conn;
194+
let stmt = stmt_cache.cached_prepared_statement(query, &metadata, conn, &Mysql);
181195

182-
let stmt = {
183-
let stmt = stmt_cache
184-
.cached_prepared_statement(query, &metadata, conn, &Mysql)
185-
.await?;
186-
stmt
187-
};
188-
189-
let stmt = match stmt {
190-
MaybeCached::CannotCache(stmt) => {
191-
*last_stmt = Some(stmt);
192-
last_stmt.as_ref().unwrap()
193-
}
194-
MaybeCached::Cached(s) => s,
195-
_ => unreachable!("We've opted into breaking diesel changes and want to know if things break because someone added a new variant here")
196-
};
196+
stmt.and_then(|(stmt, conn)|async move {
197197

198-
callback(&mut self.conn, stmt, ToSqlHelper { metadata, binds }).await
198+
let stmt = match stmt {
199+
MaybeCached::CannotCache(stmt) => {
200+
*last_stmt = Some(stmt);
201+
last_stmt.as_ref().unwrap()
202+
}
203+
MaybeCached::Cached(s) => s,
204+
_ => unreachable!("We've opted into breaking diesel changes and want to know if things break because someone added a new variant here")
205+
};
206+
callback(conn, stmt, ToSqlHelper{metadata, binds}).await
207+
}).boxed()
199208
}
200209
}

src/mysql/row.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ impl<'a> diesel::row::Row<'a, Mysql> for MysqlRow {
6161
let buffer = match dbg!(value) {
6262
Value::NULL => None,
6363
Value::Bytes(b) => {
64+
dbg!(&b);
65+
dbg!(b.len());
66+
dbg!(std::mem::size_of::<diesel::mysql::data_types::MysqlTime>());
6467
// deserialize gets the length prepended, so we just use that buffer
6568
// directly
6669
Some(Cow::Borrowed(b as &[_]))

0 commit comments

Comments
 (0)