Skip to content

Commit e0d69d9

Browse files
committed
Fix #26
`mysql_async` requires us to explicitly close prepared statements if their statement cache is disabled. This commit introduces the necessary code to close any non-cached prepared statements. Any cached prepared statement will live as long as the connection itself. They will be automatically deallocated on connection close on server side (as the the corresponding connection is gone at this point).
1 parent 9221da1 commit e0d69d9

File tree

1 file changed

+136
-29
lines changed

1 file changed

+136
-29
lines changed

src/mysql/mod.rs

Lines changed: 136 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ pub struct AsyncMysqlConnection {
2727
conn: mysql_async::Conn,
2828
stmt_cache: StmtCache<Mysql, Statement>,
2929
transaction_manager: AnsiTransactionManager,
30-
last_stmt: Option<Statement>,
3130
}
3231

3332
#[async_trait::async_trait]
@@ -72,7 +71,6 @@ impl AsyncConnection for AsyncMysqlConnection {
7271
conn,
7372
stmt_cache: StmtCache::new(),
7473
transaction_manager: AnsiTransactionManager::default(),
75-
last_stmt: None,
7674
})
7775
}
7876

@@ -88,19 +86,44 @@ impl AsyncConnection for AsyncMysqlConnection {
8886
+ 'query,
8987
{
9088
self.with_prepared_statement(source.as_query(), |conn, stmt, binds| async move {
91-
let res = conn.exec_iter(stmt, binds).await.map_err(ErrorHelper)?;
89+
let stmt_for_exec = match stmt {
90+
MaybeCached::Cached(ref s) => (*s).clone(),
91+
MaybeCached::CannotCache(ref s) => s.clone(),
92+
_ => todo!(),
93+
};
9294

93-
let stream = res
94-
.stream_and_drop::<MysqlRow>()
95-
.await
96-
.map_err(ErrorHelper)?
97-
.ok_or_else(|| {
98-
diesel::result::Error::DeserializationError(Box::new(
99-
diesel::result::UnexpectedEndOfRow,
100-
))
101-
})?
102-
.map_err(|e| diesel::result::Error::from(ErrorHelper(e)))
103-
.boxed();
95+
let (tx, rx) = futures::channel::mpsc::channel(0);
96+
97+
let yielder = async move {
98+
let r = Self::poll_result_stream(conn, stmt_for_exec, binds, tx).await;
99+
// We need to close any non-cached statement explicitly here as otherwise
100+
// we might error out on too many open statements. See https://github.com/weiznich/diesel_async/issues/26
101+
// for details
102+
//
103+
// This might be problematic for cases where the stream is droped before the end is reached
104+
//
105+
// Such behaviour might happen if users:
106+
// * Just drop the future/stream after polling at least once (timeouts!!)
107+
// * Users only fetch a fixed number of elements from the stream
108+
//
109+
// For now there is not really a good solution to this problem as this would require something like async drop
110+
// (and even with async drop that would be really hard to solve due to the involved lifetimes)
111+
if let MaybeCached::CannotCache(stmt) = stmt {
112+
conn.close(stmt).await.map_err(ErrorHelper)?;
113+
}
114+
r
115+
};
116+
117+
let fake_stream =
118+
futures::stream::once(yielder).filter_map(|e: QueryResult<()>| async move {
119+
if let Err(e) = e {
120+
Some(Err(e))
121+
} else {
122+
None
123+
}
124+
});
125+
126+
let stream = futures::stream::select(fake_stream, rx).boxed();
104127

105128
Ok(stream)
106129
})
@@ -118,7 +141,21 @@ impl AsyncConnection for AsyncMysqlConnection {
118141
+ 'query,
119142
{
120143
self.with_prepared_statement(source, |conn, stmt, binds| async move {
121-
conn.exec_drop(stmt, binds).await.map_err(ErrorHelper)?;
144+
conn.exec_drop(&*stmt, binds).await.map_err(ErrorHelper)?;
145+
// We need to close any non-cached statement explicitly here as otherwise
146+
// we might error out on too many open statements. See https://github.com/weiznich/diesel_async/issues/26
147+
// for details
148+
//
149+
// This might be problematic for cases where the stream is droped before the end is reached
150+
//
151+
// Such behaviour might happen if users:
152+
// * Just drop the future after polling at least once (timeouts!!)
153+
//
154+
// For now there is not really a good solution to this problem as this would require something like async drop
155+
// (and even with async drop that would be really hard to solve due to the involved lifetimes)
156+
if let MaybeCached::CannotCache(stmt) = stmt {
157+
conn.close(stmt).await.map_err(ErrorHelper)?;
158+
}
122159
Ok(conn.affected_rows() as usize)
123160
})
124161
}
@@ -169,7 +206,6 @@ impl AsyncMysqlConnection {
169206
conn,
170207
stmt_cache: StmtCache::new(),
171208
transaction_manager: AnsiTransactionManager::default(),
172-
last_stmt: None,
173209
};
174210

175211
for stmt in CONNECTION_SETUP_QUERIES {
@@ -185,7 +221,7 @@ impl AsyncMysqlConnection {
185221
fn with_prepared_statement<'conn, T, F, R>(
186222
&'conn mut self,
187223
query: T,
188-
callback: impl (FnOnce(&'conn mut mysql_async::Conn, &'conn Statement, ToSqlHelper) -> F)
224+
callback: impl (FnOnce(&'conn mut mysql_async::Conn, MaybeCached<'conn, Statement>, ToSqlHelper) -> F)
189225
+ Send
190226
+ 'conn,
191227
) -> BoxFuture<'conn, QueryResult<R>>
@@ -205,27 +241,98 @@ impl AsyncMysqlConnection {
205241
let AsyncMysqlConnection {
206242
ref mut conn,
207243
ref mut stmt_cache,
208-
ref mut last_stmt,
209244
ref mut transaction_manager,
210245
..
211246
} = self;
212247

213248
let stmt = stmt_cache.cached_prepared_statement(query, &metadata, conn, &Mysql);
214249

215-
stmt.and_then(|(stmt, conn)|async move {
250+
stmt.and_then(|(stmt, conn)| async move {
251+
let res = update_transaction_manager_status(
252+
callback(conn, stmt, ToSqlHelper { metadata, binds }).await,
253+
transaction_manager,
254+
);
255+
res
256+
})
257+
.boxed()
258+
}
259+
260+
async fn poll_result_stream(
261+
conn: &mut mysql_async::Conn,
262+
stmt_for_exec: mysql_async::Statement,
263+
binds: ToSqlHelper,
264+
mut tx: futures::channel::mpsc::Sender<QueryResult<MysqlRow>>,
265+
) -> QueryResult<()> {
266+
use futures::SinkExt;
267+
let res = conn
268+
.exec_iter(stmt_for_exec, binds)
269+
.await
270+
.map_err(ErrorHelper)?;
216271

217-
let stmt = match stmt {
218-
MaybeCached::CannotCache(stmt) => {
219-
*last_stmt = Some(stmt);
220-
last_stmt.as_ref().unwrap()
221-
}
222-
MaybeCached::Cached(s) => s,
223-
_ => unreachable!("We've opted into breaking diesel changes and want to know if things break because someone added a new variant here")
224-
};
225-
update_transaction_manager_status(callback(conn, stmt, ToSqlHelper{metadata, binds}).await, transaction_manager)
226-
}).boxed()
272+
let mut stream = res
273+
.stream_and_drop::<MysqlRow>()
274+
.await
275+
.map_err(ErrorHelper)?
276+
.ok_or_else(|| {
277+
diesel::result::Error::DeserializationError(Box::new(
278+
diesel::result::UnexpectedEndOfRow,
279+
))
280+
})?
281+
.map_err(|e| diesel::result::Error::from(ErrorHelper(e)));
282+
283+
while let Some(row) = stream.next().await {
284+
let row = row?;
285+
tx.send(Ok(row))
286+
.await
287+
.map_err(|e| diesel::result::Error::DeserializationError(Box::new(e)))?;
288+
}
289+
290+
Ok(())
227291
}
228292
}
229293

230294
#[cfg(any(feature = "deadpool", feature = "bb8", feature = "mobc"))]
231295
impl crate::pooled_connection::PoolableConnection for AsyncMysqlConnection {}
296+
297+
#[cfg(test)]
298+
mod tests {
299+
use crate::RunQueryDsl;
300+
mod diesel_async {
301+
pub use crate::*;
302+
}
303+
include!("../doctest_setup.rs");
304+
305+
#[tokio::test]
306+
async fn check_statements_are_dropped() {
307+
use self::schema::users;
308+
309+
let mut conn = establish_connection().await;
310+
// we cannot set a lower limit here without admin privileges
311+
// which makes this test really slow
312+
let stmt_count = 16382 + 10;
313+
314+
for i in 0..stmt_count {
315+
diesel::insert_into(users::table)
316+
.values(Some(users::name.eq(format!("User{i}"))))
317+
.execute(&mut conn)
318+
.await
319+
.unwrap();
320+
}
321+
322+
#[derive(QueryableByName)]
323+
#[diesel(table_name = users)]
324+
#[allow(dead_code)]
325+
struct User {
326+
id: i32,
327+
name: String,
328+
}
329+
330+
for i in 0..stmt_count {
331+
diesel::sql_query("SELECT id, name FROM users WHERE name = ?")
332+
.bind::<diesel::sql_types::Text, _>(format!("User{i}"))
333+
.load::<User>(&mut conn)
334+
.await
335+
.unwrap();
336+
}
337+
}
338+
}

0 commit comments

Comments
 (0)