Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 53 additions & 12 deletions src/pg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,7 @@ use futures_util::{FutureExt, StreamExt};
use std::collections::{HashMap, HashSet};
use std::future::Future;
use std::sync::Arc;
use tokio::sync::broadcast;
use tokio::sync::oneshot;
use tokio::sync::Mutex;
use tokio::sync::{broadcast, mpsc, oneshot, Mutex};
use tokio_postgres::types::ToSql;
use tokio_postgres::types::Type;
use tokio_postgres::Statement;
Expand Down Expand Up @@ -172,6 +170,7 @@ pub struct AsyncPgConnection {
transaction_state: Arc<Mutex<AnsiTransactionManager>>,
metadata_cache: Arc<Mutex<PgMetadataCache>>,
connection_future: Option<broadcast::Receiver<Arc<tokio_postgres::Error>>>,
notification_rx: Option<mpsc::UnboundedReceiver<diesel::pg::PgNotification>>,
shutdown_channel: Option<oneshot::Sender<()>>,
// a sync mutex is fine here as we only hold it for a really short time
instrumentation: Arc<std::sync::Mutex<DynInstrumentation>>,
Expand Down Expand Up @@ -283,11 +282,12 @@ impl AsyncConnection for AsyncPgConnection {
.await
.map_err(ErrorHelper)?;

let (error_rx, shutdown_tx) = drive_connection(connection);
let (error_rx, notification_rx, shutdown_tx) = drive_connection(connection);

let r = Self::setup(
client,
Some(error_rx),
Some(notification_rx),
Some(shutdown_tx),
Arc::clone(&instrumentation),
)
Expand Down Expand Up @@ -477,6 +477,7 @@ impl AsyncPgConnection {
conn,
None,
None,
None,
Arc::new(std::sync::Mutex::new(
DynInstrumentation::default_instrumentation(),
)),
Expand All @@ -493,11 +494,12 @@ impl AsyncPgConnection {
where
S: tokio_postgres::tls::TlsStream + Unpin + Send + 'static,
{
let (error_rx, shutdown_tx) = drive_connection(conn);
let (error_rx, notification_rx, shutdown_tx) = drive_connection(conn);

Self::setup(
client,
Some(error_rx),
Some(notification_rx),
Some(shutdown_tx),
Arc::new(std::sync::Mutex::new(DynInstrumentation::none())),
)
Expand All @@ -507,6 +509,7 @@ impl AsyncPgConnection {
async fn setup(
conn: tokio_postgres::Client,
connection_future: Option<broadcast::Receiver<Arc<tokio_postgres::Error>>>,
notification_rx: Option<mpsc::UnboundedReceiver<diesel::pg::PgNotification>>,
shutdown_channel: Option<oneshot::Sender<()>>,
instrumentation: Arc<std::sync::Mutex<DynInstrumentation>>,
) -> ConnectionResult<Self> {
Expand All @@ -516,6 +519,7 @@ impl AsyncPgConnection {
transaction_state: Arc::new(Mutex::new(AnsiTransactionManager::default())),
metadata_cache: Arc::new(Mutex::new(PgMetadataCache::new())),
connection_future,
notification_rx,
shutdown_channel,
instrumentation,
};
Expand Down Expand Up @@ -724,6 +728,28 @@ impl AsyncPgConnection {
.unwrap_or_else(|p| p.into_inner())
.on_connection_event(event);
}

pub fn notification_stream(
&mut self,
) -> impl futures_core::Stream<Item = diesel::pg::PgNotification> + '_ {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's about this signature. I changed it to Stream to match the channel, but now I think maybe it's wrong because you said it should be Stream<QueryResult>.

That said, the poll_messages function exposes Ok(Notification | Notice) | Err, so if I get an error I can't be sure it's from a notification, a notice, or anything. It's basically just an error.

So it feels to me that I should either output the 3 possible cases, or just pick the one I want (notification). But I'm not sure of this. Anyway, I think diesel sync might have gone through the same question?

Copy link
Owner

@weiznich weiznich Jul 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it needs to be Stream<QueryResult> as we currently miss one important point: We still need to poll the drive_connection future as well as part of the stream, as otherwise we won't get any new notifications while the user polls on the channel we returned. I think that connection future might return an error (via the error channel), that then must be returned from the notification stream.

(Hopefully that's "clear" and not expressed in a too confusing way)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, that actually makes sense. I implemented it on 297c383.
One leftover question, does it make sense to have the error go on both error_tx and notification_tx? That's what I ended up leaving with but I was not really sure.

NotificationStream(self.notification_rx.as_mut())
}
}

struct NotificationStream<'a>(Option<&'a mut mpsc::UnboundedReceiver<diesel::pg::PgNotification>>);

impl futures_core::Stream for NotificationStream<'_> {
type Item = diesel::pg::PgNotification;

fn poll_next(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
match &mut self.0 {
Some(rx) => rx.poll_recv(cx),
None => std::task::Poll::Pending,
}
}
}

struct BindData {
Expand Down Expand Up @@ -969,27 +995,42 @@ async fn drive_future<R>(
}

fn drive_connection<S>(
conn: tokio_postgres::Connection<tokio_postgres::Socket, S>,
mut conn: tokio_postgres::Connection<tokio_postgres::Socket, S>,
) -> (
broadcast::Receiver<Arc<tokio_postgres::Error>>,
mpsc::UnboundedReceiver<diesel::pg::PgNotification>,
oneshot::Sender<()>,
)
where
S: tokio_postgres::tls::TlsStream + Unpin + Send + 'static,
{
let (error_tx, error_rx) = tokio::sync::broadcast::channel(1);
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
let (notification_tx, notification_rx) = tokio::sync::mpsc::unbounded_channel();
let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel();

tokio::spawn(async move {
match futures_util::future::select(shutdown_rx, conn).await {
Either::Left(_) | Either::Right((Ok(_), _)) => {}
Either::Right((Err(e), _)) => {
let _ = error_tx.send(Arc::new(e));
let mut conn = futures_util::stream::poll_fn(|cx| conn.poll_message(cx));

loop {
match futures_util::future::select(&mut shutdown_rx, conn.next()).await {
Either::Left(_) | Either::Right((None, _)) => break,
Either::Right((Some(Ok(tokio_postgres::AsyncMessage::Notification(notif))), _)) => {
let _: Result<_, _> = notification_tx.send(diesel::pg::PgNotification {
process_id: notif.process_id(),
channel: notif.channel().to_owned(),
payload: notif.payload().to_owned(),
});
}
Either::Right((Some(Ok(_)), _)) => {}
Either::Right((Some(Err(e)), _)) => {
let _ = error_tx.send(Arc::new(e));
break;
}
}
}
});

(error_rx, shutdown_tx)
(error_rx, notification_rx, shutdown_tx)
}

#[cfg(any(
Expand Down
Loading