Skip to content

Commit 3569190

Browse files
authored
Merge pull request #262 from stormshield-kg/fix-hrtb-errors
Use custom `Future` combinators to avoid GAT errors
2 parents c339949 + 661de8c commit 3569190

File tree

3 files changed

+154
-105
lines changed

3 files changed

+154
-105
lines changed

src/pg/mod.rs

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1166,46 +1166,42 @@ mod tests {
11661166
.await
11671167
.unwrap();
11681168

1169-
fn erase<'a, T: Future + Send + 'a>(t: T) -> impl Future<Output = T::Output> + Send + 'a {
1170-
t
1171-
}
1172-
11731169
async fn fn12(mut conn: &AsyncPgConnection) -> QueryResult<(i32, i32)> {
11741170
let f1 = diesel::select(1_i32.into_sql::<Integer>()).get_result::<i32>(&mut conn);
11751171
let f2 = diesel::select(2_i32.into_sql::<Integer>()).get_result::<i32>(&mut conn);
11761172

1177-
erase(try_join(f1, f2)).await
1178-
}
1179-
1180-
async fn fn34(mut conn: &AsyncPgConnection) -> QueryResult<(i32, i32)> {
1181-
let f3 = diesel::select(3_i32.into_sql::<Integer>()).get_result::<i32>(&mut conn);
1182-
let f4 = diesel::select(4_i32.into_sql::<Integer>()).get_result::<i32>(&mut conn);
1183-
1184-
try_join(f3, f4).boxed().await
1173+
try_join(f1, f2).await
11851174
}
11861175

1187-
async fn fn56(mut conn: &AsyncPgConnection) -> QueryResult<(i32, i32)> {
1176+
async fn fn37(
1177+
mut conn: &AsyncPgConnection,
1178+
) -> QueryResult<(usize, (Vec<i32>, (i32, (Vec<i32>, i32))))> {
1179+
let f3 = diesel::select(0_i32.into_sql::<Integer>()).execute(&mut conn);
1180+
let f4 = diesel::select(4_i32.into_sql::<Integer>()).load::<i32>(&mut conn);
11881181
let f5 = diesel::select(5_i32.into_sql::<Integer>()).get_result::<i32>(&mut conn);
1189-
let f6 = diesel::select(6_i32.into_sql::<Integer>()).get_result::<i32>(&mut conn);
1182+
let f6 = diesel::select(6_i32.into_sql::<Integer>()).get_results::<i32>(&mut conn);
1183+
let f7 = diesel::select(7_i32.into_sql::<Integer>()).first::<i32>(&mut conn);
11901184

1191-
try_join(f5.boxed(), f6.boxed()).await
1185+
try_join(f3, try_join(f4, try_join(f5, try_join(f6, f7)))).await
11921186
}
11931187

11941188
conn.transaction(|conn| {
11951189
async move {
11961190
let f12 = fn12(conn);
1197-
let f34 = fn34(conn);
1198-
let f56 = fn56(conn);
1191+
let f37 = fn37(conn);
11991192

1200-
let ((r1, r2), ((r3, r4), (r5, r6))) =
1201-
try_join(f12, try_join(f34, f56)).await.unwrap();
1193+
let ((r1, r2), (r3, (r4, (r5, (r6, r7))))) = try_join(f12, f37).await.unwrap();
12021194

12031195
assert_eq!(r1, 1);
12041196
assert_eq!(r2, 2);
1205-
assert_eq!(r3, 3);
1206-
assert_eq!(r4, 4);
1197+
assert_eq!(r3, 1);
1198+
assert_eq!(r4, vec![4]);
12071199
assert_eq!(r5, 5);
1208-
assert_eq!(r6, 6);
1200+
assert_eq!(r6, vec![6]);
1201+
assert_eq!(r7, 7);
1202+
1203+
fn12(conn).await?;
1204+
fn37(conn).await?;
12091205

12101206
QueryResult::<_>::Ok(())
12111207
}

src/run_query_dsl/mod.rs

Lines changed: 24 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
1+
mod utils;
2+
13
use crate::AsyncConnectionCore;
24
use diesel::associations::HasTable;
35
use diesel::query_builder::IntoUpdateTarget;
46
use diesel::result::QueryResult;
57
use diesel::AsChangeset;
68
use futures_core::future::BoxFuture;
7-
use futures_core::Stream;
8-
use futures_util::{future, stream, FutureExt, StreamExt, TryFutureExt, TryStreamExt};
9+
#[cfg(any(feature = "mysql", feature = "postgres"))]
10+
use futures_util::FutureExt;
11+
use futures_util::{stream, StreamExt, TryStreamExt};
912
use std::future::Future;
10-
use std::pin::Pin;
1113

1214
/// The traits used by `QueryDsl`.
1315
///
@@ -22,7 +24,7 @@ pub mod methods {
2224
use diesel::expression::QueryMetadata;
2325
use diesel::query_builder::{AsQuery, QueryFragment, QueryId};
2426
use diesel::query_dsl::CompatibleType;
25-
use futures_util::{Future, Stream, TryFutureExt};
27+
use futures_util::{Future, Stream};
2628

2729
/// The `execute` method
2830
///
@@ -74,6 +76,7 @@ pub mod methods {
7476
type LoadFuture<'conn>: Future<Output = QueryResult<Self::Stream<'conn>>> + Send
7577
where
7678
Conn: 'conn;
79+
7780
/// The inner stream returned by [`LoadQuery::internal_load`]
7881
type Stream<'conn>: Stream<Item = QueryResult<U>> + Send
7982
where
@@ -96,10 +99,7 @@ pub mod methods {
9699
ST: 'static,
97100
{
98101
type LoadFuture<'conn>
99-
= future::MapOk<
100-
Conn::LoadFuture<'conn, 'query>,
101-
fn(Conn::Stream<'conn, 'query>) -> Self::Stream<'conn>,
102-
>
102+
= utils::MapOk<Conn::LoadFuture<'conn, 'query>, Self::Stream<'conn>>
103103
where
104104
Conn: 'conn;
105105

@@ -112,33 +112,13 @@ pub mod methods {
112112
Conn: 'conn;
113113

114114
fn internal_load(self, conn: &mut Conn) -> Self::LoadFuture<'_> {
115-
conn.load(self)
116-
.map_ok(map_result_stream_future::<U, _, _, DB, ST>)
115+
utils::MapOk::new(conn.load(self), |stream| {
116+
stream.map(|row| {
117+
U::build_from_row(&row?).map_err(diesel::result::Error::DeserializationError)
118+
})
119+
})
117120
}
118121
}
119-
120-
#[allow(clippy::type_complexity)]
121-
fn map_result_stream_future<'s, 'a, U, S, R, DB, ST>(
122-
stream: S,
123-
) -> stream::Map<S, fn(QueryResult<R>) -> QueryResult<U>>
124-
where
125-
S: Stream<Item = QueryResult<R>> + Send + 's,
126-
R: diesel::row::Row<'a, DB> + 's,
127-
DB: Backend + 'static,
128-
U: FromSqlRow<ST, DB> + 'static,
129-
ST: 'static,
130-
{
131-
stream.map(map_row_helper::<_, DB, U, ST>)
132-
}
133-
134-
fn map_row_helper<'a, R, DB, U, ST>(row: QueryResult<R>) -> QueryResult<U>
135-
where
136-
U: FromSqlRow<ST, DB>,
137-
R: diesel::row::Row<'a, DB>,
138-
DB: Backend,
139-
{
140-
U::build_from_row(&row?).map_err(diesel::result::Error::DeserializationError)
141-
}
142122
}
143123

144124
/// The return types produced by the various [`RunQueryDsl`] methods
@@ -149,37 +129,24 @@ pub mod methods {
149129
// the same connection
150130
#[allow(type_alias_bounds)] // we need these bounds otherwise we cannot use GAT's
151131
pub mod return_futures {
132+
use crate::run_query_dsl::utils;
133+
152134
use super::methods::LoadQuery;
153-
use diesel::QueryResult;
154-
use futures_util::{future, stream};
135+
use futures_util::stream;
155136
use std::pin::Pin;
156137

157138
/// The future returned by [`RunQueryDsl::load`](super::RunQueryDsl::load)
158139
/// and [`RunQueryDsl::get_results`](super::RunQueryDsl::get_results)
159140
///
160141
/// This is essentially `impl Future<Output = QueryResult<Vec<U>>>`
161-
pub type LoadFuture<'conn, 'query, Q: LoadQuery<'query, Conn, U>, Conn, U> = future::AndThen<
162-
Q::LoadFuture<'conn>,
163-
stream::TryCollect<Q::Stream<'conn>, Vec<U>>,
164-
fn(Q::Stream<'conn>) -> stream::TryCollect<Q::Stream<'conn>, Vec<U>>,
165-
>;
142+
pub type LoadFuture<'conn, 'query, Q: LoadQuery<'query, Conn, U>, Conn, U> =
143+
utils::AndThen<Q::LoadFuture<'conn>, stream::TryCollect<Q::Stream<'conn>, Vec<U>>>;
166144

167145
/// The future returned by [`RunQueryDsl::get_result`](super::RunQueryDsl::get_result)
168146
///
169147
/// This is essentially `impl Future<Output = QueryResult<U>>`
170-
pub type GetResult<'conn, 'query, Q: LoadQuery<'query, Conn, U>, Conn, U> = future::AndThen<
171-
Q::LoadFuture<'conn>,
172-
future::Map<
173-
stream::StreamFuture<Pin<Box<Q::Stream<'conn>>>>,
174-
fn((Option<QueryResult<U>>, Pin<Box<Q::Stream<'conn>>>)) -> QueryResult<U>,
175-
>,
176-
fn(
177-
Q::Stream<'conn>,
178-
) -> future::Map<
179-
stream::StreamFuture<Pin<Box<Q::Stream<'conn>>>>,
180-
fn((Option<QueryResult<U>>, Pin<Box<Q::Stream<'conn>>>)) -> QueryResult<U>,
181-
>,
182-
>;
148+
pub type GetResult<'conn, 'query, Q: LoadQuery<'query, Conn, U>, Conn, U> =
149+
utils::AndThen<Q::LoadFuture<'conn>, utils::LoadNext<Pin<Box<Q::Stream<'conn>>>>>;
183150
}
184151

185152
/// Methods used to execute queries.
@@ -346,13 +313,7 @@ pub trait RunQueryDsl<Conn>: Sized {
346313
Conn: AsyncConnectionCore,
347314
Self: methods::LoadQuery<'query, Conn, U> + 'query,
348315
{
349-
fn collect_result<U, S>(stream: S) -> stream::TryCollect<S, Vec<U>>
350-
where
351-
S: Stream<Item = QueryResult<U>>,
352-
{
353-
stream.try_collect()
354-
}
355-
self.internal_load(conn).and_then(collect_result::<U, _>)
316+
utils::AndThen::new(self.internal_load(conn), |stream| stream.try_collect())
356317
}
357318

358319
/// Executes the given query, returning a [`Stream`] with the returned rows.
@@ -547,29 +508,9 @@ pub trait RunQueryDsl<Conn>: Sized {
547508
Conn: AsyncConnectionCore,
548509
Self: methods::LoadQuery<'query, Conn, U> + 'query,
549510
{
550-
#[allow(clippy::type_complexity)]
551-
fn get_next_stream_element<S, U>(
552-
stream: S,
553-
) -> future::Map<
554-
stream::StreamFuture<Pin<Box<S>>>,
555-
fn((Option<QueryResult<U>>, Pin<Box<S>>)) -> QueryResult<U>,
556-
>
557-
where
558-
S: Stream<Item = QueryResult<U>>,
559-
{
560-
fn map_option_to_result<U, S>(
561-
(o, _): (Option<QueryResult<U>>, Pin<Box<S>>),
562-
) -> QueryResult<U> {
563-
match o {
564-
Some(s) => s,
565-
None => Err(diesel::result::Error::NotFound),
566-
}
567-
}
568-
569-
Box::pin(stream).into_future().map(map_option_to_result)
570-
}
571-
572-
self.load_stream(conn).and_then(get_next_stream_element)
511+
utils::AndThen::new(self.internal_load(conn), |stream| {
512+
utils::LoadNext::new(Box::pin(stream))
513+
})
573514
}
574515

575516
/// Runs the command, returning an `Vec` with the affected rows.

src/run_query_dsl/utils.rs

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
use std::future::Future;
2+
use std::pin::Pin;
3+
use std::task::{Context, Poll};
4+
5+
use diesel::QueryResult;
6+
use futures_core::{ready, TryFuture, TryStream};
7+
use futures_util::{TryFutureExt, TryStreamExt};
8+
9+
// We use a custom future implementation here to erase some lifetimes
10+
// that otherwise need to be specified explicitly
11+
//
12+
// Specifying these lifetimes results in the compiler not beeing
13+
// able to look through the generic code and emit
14+
// lifetime erros for pipelined queries. See
15+
// https://github.com/weiznich/diesel_async/issues/249 for more context
16+
#[repr(transparent)]
17+
pub struct MapOk<F: TryFutureExt, T> {
18+
future: futures_util::future::MapOk<F, fn(F::Ok) -> T>,
19+
}
20+
21+
impl<F, T> Future for MapOk<F, T>
22+
where
23+
F: TryFuture,
24+
futures_util::future::MapOk<F, fn(F::Ok) -> T>: Future<Output = Result<T, F::Error>>,
25+
{
26+
type Output = Result<T, F::Error>;
27+
28+
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
29+
unsafe {
30+
// SAFETY: This projects pinning to the only inner field, so it
31+
// should be safe
32+
self.map_unchecked_mut(|s| &mut s.future)
33+
}
34+
.poll(cx)
35+
}
36+
}
37+
38+
impl<Fut: TryFutureExt, T> MapOk<Fut, T> {
39+
pub(crate) fn new(future: Fut, f: fn(Fut::Ok) -> T) -> Self {
40+
Self {
41+
future: future.map_ok(f),
42+
}
43+
}
44+
}
45+
46+
// similar to `MapOk` above this mainly exists to hide the lifetime
47+
#[repr(transparent)]
48+
pub struct AndThen<F1: TryFuture, F2> {
49+
future: futures_util::future::AndThen<F1, F2, fn(F1::Ok) -> F2>,
50+
}
51+
52+
impl<Fut1, Fut2> AndThen<Fut1, Fut2>
53+
where
54+
Fut1: TryFuture,
55+
Fut2: TryFuture<Error = Fut1::Error>,
56+
{
57+
pub(crate) fn new(fut1: Fut1, f: fn(Fut1::Ok) -> Fut2) -> AndThen<Fut1, Fut2> {
58+
Self {
59+
future: fut1.and_then(f),
60+
}
61+
}
62+
}
63+
64+
impl<F1, F2> Future for AndThen<F1, F2>
65+
where
66+
F1: TryFuture,
67+
F2: TryFuture<Error = F1::Error>,
68+
{
69+
type Output = Result<F2::Ok, F2::Error>;
70+
71+
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
72+
unsafe {
73+
// SAFETY: This projects pinning to the only inner field, so it
74+
// should be safe
75+
self.map_unchecked_mut(|s| &mut s.future)
76+
}
77+
.poll(cx)
78+
}
79+
}
80+
81+
/// Converts a stream into a future, only yielding the first element.
82+
/// Based on [`futures_util::stream::StreamFuture`].
83+
pub struct LoadNext<St> {
84+
stream: Option<St>,
85+
}
86+
87+
impl<St> LoadNext<St> {
88+
pub(crate) fn new(stream: St) -> Self {
89+
Self {
90+
stream: Some(stream),
91+
}
92+
}
93+
}
94+
95+
impl<St> Future for LoadNext<St>
96+
where
97+
St: TryStream<Error = diesel::result::Error> + Unpin,
98+
{
99+
type Output = QueryResult<St::Ok>;
100+
101+
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
102+
let first = {
103+
let s = self.stream.as_mut().expect("polling LoadNext twice");
104+
ready!(s.try_poll_next_unpin(cx))
105+
};
106+
self.stream = None;
107+
match first {
108+
Some(first) => Poll::Ready(first),
109+
None => Poll::Ready(Err(diesel::result::Error::NotFound)),
110+
}
111+
}
112+
}

0 commit comments

Comments
 (0)