-
-
Notifications
You must be signed in to change notification settings - Fork 96
Implement notification_stream for AsyncPgConnection #251
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
6d9d4bc
9107381
e8bb91c
297c383
459f5aa
685f0ba
848c241
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,9 +31,7 @@ use futures_util::{FutureExt, StreamExt}; | |
use std::collections::{HashMap, HashSet}; | ||
use std::future::Future; | ||
use std::sync::Arc; | ||
use tokio::sync::broadcast; | ||
use tokio::sync::oneshot; | ||
use tokio::sync::Mutex; | ||
use tokio::sync::{broadcast, mpsc, oneshot, Mutex}; | ||
use tokio_postgres::types::ToSql; | ||
use tokio_postgres::types::Type; | ||
use tokio_postgres::Statement; | ||
|
@@ -172,6 +170,7 @@ pub struct AsyncPgConnection { | |
transaction_state: Arc<Mutex<AnsiTransactionManager>>, | ||
metadata_cache: Arc<Mutex<PgMetadataCache>>, | ||
connection_future: Option<broadcast::Receiver<Arc<tokio_postgres::Error>>>, | ||
notification_rx: Option<mpsc::UnboundedReceiver<diesel::pg::PgNotification>>, | ||
shutdown_channel: Option<oneshot::Sender<()>>, | ||
// a sync mutex is fine here as we only hold it for a really short time | ||
instrumentation: Arc<std::sync::Mutex<DynInstrumentation>>, | ||
|
@@ -283,11 +282,12 @@ impl AsyncConnection for AsyncPgConnection { | |
.await | ||
.map_err(ErrorHelper)?; | ||
|
||
let (error_rx, shutdown_tx) = drive_connection(connection); | ||
let (error_rx, notification_rx, shutdown_tx) = drive_connection(connection); | ||
|
||
let r = Self::setup( | ||
client, | ||
Some(error_rx), | ||
Some(notification_rx), | ||
Some(shutdown_tx), | ||
Arc::clone(&instrumentation), | ||
) | ||
|
@@ -477,6 +477,7 @@ impl AsyncPgConnection { | |
conn, | ||
None, | ||
None, | ||
None, | ||
Arc::new(std::sync::Mutex::new( | ||
DynInstrumentation::default_instrumentation(), | ||
)), | ||
|
@@ -493,11 +494,12 @@ impl AsyncPgConnection { | |
where | ||
S: tokio_postgres::tls::TlsStream + Unpin + Send + 'static, | ||
{ | ||
let (error_rx, shutdown_tx) = drive_connection(conn); | ||
let (error_rx, notification_rx, shutdown_tx) = drive_connection(conn); | ||
|
||
Self::setup( | ||
client, | ||
Some(error_rx), | ||
Some(notification_rx), | ||
Some(shutdown_tx), | ||
Arc::new(std::sync::Mutex::new(DynInstrumentation::none())), | ||
) | ||
|
@@ -507,6 +509,7 @@ impl AsyncPgConnection { | |
async fn setup( | ||
conn: tokio_postgres::Client, | ||
connection_future: Option<broadcast::Receiver<Arc<tokio_postgres::Error>>>, | ||
notification_rx: Option<mpsc::UnboundedReceiver<diesel::pg::PgNotification>>, | ||
shutdown_channel: Option<oneshot::Sender<()>>, | ||
instrumentation: Arc<std::sync::Mutex<DynInstrumentation>>, | ||
) -> ConnectionResult<Self> { | ||
|
@@ -516,6 +519,7 @@ impl AsyncPgConnection { | |
transaction_state: Arc::new(Mutex::new(AnsiTransactionManager::default())), | ||
metadata_cache: Arc::new(Mutex::new(PgMetadataCache::new())), | ||
connection_future, | ||
notification_rx, | ||
shutdown_channel, | ||
instrumentation, | ||
}; | ||
|
@@ -724,6 +728,28 @@ impl AsyncPgConnection { | |
.unwrap_or_else(|p| p.into_inner()) | ||
.on_connection_event(event); | ||
} | ||
|
||
pub fn notification_stream( | ||
lsunsi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
&mut self, | ||
) -> impl futures_core::Stream<Item = diesel::pg::PgNotification> + '_ { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's about this signature. I changed it to Stream to match the channel, but now I think maybe it's wrong because you said it should be Stream<QueryResult>. That said, the poll_messages function exposes Ok(Notification | Notice) | Err, so if I get an error I can't be sure it's from a notification, a notice, or anything. It's basically just an error. So it feels to me that I should either output the 3 possible cases, or just pick the one I want (notification). But I'm not sure of this. Anyway, I think diesel sync might have gone through the same question? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it needs to be (Hopefully that's "clear" and not expressed in a too confusing way) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, that actually makes sense. I implemented it on 297c383. |
||
NotificationStream(self.notification_rx.as_mut()) | ||
} | ||
} | ||
|
||
struct NotificationStream<'a>(Option<&'a mut mpsc::UnboundedReceiver<diesel::pg::PgNotification>>); | ||
|
||
impl futures_core::Stream for NotificationStream<'_> { | ||
type Item = diesel::pg::PgNotification; | ||
|
||
fn poll_next( | ||
mut self: std::pin::Pin<&mut Self>, | ||
cx: &mut std::task::Context<'_>, | ||
) -> std::task::Poll<Option<Self::Item>> { | ||
match &mut self.0 { | ||
Some(rx) => rx.poll_recv(cx), | ||
None => std::task::Poll::Pending, | ||
} | ||
} | ||
} | ||
|
||
struct BindData { | ||
|
@@ -969,27 +995,42 @@ async fn drive_future<R>( | |
} | ||
|
||
fn drive_connection<S>( | ||
conn: tokio_postgres::Connection<tokio_postgres::Socket, S>, | ||
mut conn: tokio_postgres::Connection<tokio_postgres::Socket, S>, | ||
) -> ( | ||
broadcast::Receiver<Arc<tokio_postgres::Error>>, | ||
mpsc::UnboundedReceiver<diesel::pg::PgNotification>, | ||
oneshot::Sender<()>, | ||
) | ||
where | ||
S: tokio_postgres::tls::TlsStream + Unpin + Send + 'static, | ||
{ | ||
let (error_tx, error_rx) = tokio::sync::broadcast::channel(1); | ||
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel(); | ||
let (notification_tx, notification_rx) = tokio::sync::mpsc::unbounded_channel(); | ||
let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel(); | ||
|
||
tokio::spawn(async move { | ||
match futures_util::future::select(shutdown_rx, conn).await { | ||
Either::Left(_) | Either::Right((Ok(_), _)) => {} | ||
Either::Right((Err(e), _)) => { | ||
let _ = error_tx.send(Arc::new(e)); | ||
let mut conn = futures_util::stream::poll_fn(|cx| conn.poll_message(cx)); | ||
lsunsi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
loop { | ||
match futures_util::future::select(&mut shutdown_rx, conn.next()).await { | ||
Either::Left(_) | Either::Right((None, _)) => break, | ||
Either::Right((Some(Ok(tokio_postgres::AsyncMessage::Notification(notif))), _)) => { | ||
let _: Result<_, _> = notification_tx.send(diesel::pg::PgNotification { | ||
process_id: notif.process_id(), | ||
channel: notif.channel().to_owned(), | ||
payload: notif.payload().to_owned(), | ||
}); | ||
} | ||
Either::Right((Some(Ok(_)), _)) => {} | ||
Either::Right((Some(Err(e)), _)) => { | ||
let _ = error_tx.send(Arc::new(e)); | ||
break; | ||
} | ||
} | ||
} | ||
}); | ||
|
||
(error_rx, shutdown_tx) | ||
(error_rx, notification_rx, shutdown_tx) | ||
} | ||
|
||
#[cfg(any( | ||
|
Uh oh!
There was an error while loading. Please reload this page.