Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 85 additions & 12 deletions src/pg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,7 @@ use futures_util::{FutureExt, StreamExt};
use std::collections::{HashMap, HashSet};
use std::future::Future;
use std::sync::Arc;
use tokio::sync::broadcast;
use tokio::sync::oneshot;
use tokio::sync::Mutex;
use tokio::sync::{broadcast, mpsc, oneshot, Mutex};
use tokio_postgres::types::ToSql;
use tokio_postgres::types::Type;
use tokio_postgres::Statement;
Expand Down Expand Up @@ -172,6 +170,7 @@ pub struct AsyncPgConnection {
transaction_state: Arc<Mutex<AnsiTransactionManager>>,
metadata_cache: Arc<Mutex<PgMetadataCache>>,
connection_future: Option<broadcast::Receiver<Arc<tokio_postgres::Error>>>,
notification_rx: Option<mpsc::UnboundedReceiver<QueryResult<diesel::pg::PgNotification>>>,
shutdown_channel: Option<oneshot::Sender<()>>,
// a sync mutex is fine here as we only hold it for a really short time
instrumentation: Arc<std::sync::Mutex<DynInstrumentation>>,
Expand Down Expand Up @@ -283,11 +282,12 @@ impl AsyncConnection for AsyncPgConnection {
.await
.map_err(ErrorHelper)?;

let (error_rx, shutdown_tx) = drive_connection(connection);
let (error_rx, notification_rx, shutdown_tx) = drive_connection(connection);

let r = Self::setup(
client,
Some(error_rx),
Some(notification_rx),
Some(shutdown_tx),
Arc::clone(&instrumentation),
)
Expand Down Expand Up @@ -477,6 +477,7 @@ impl AsyncPgConnection {
conn,
None,
None,
None,
Arc::new(std::sync::Mutex::new(
DynInstrumentation::default_instrumentation(),
)),
Expand All @@ -493,11 +494,12 @@ impl AsyncPgConnection {
where
S: tokio_postgres::tls::TlsStream + Unpin + Send + 'static,
{
let (error_rx, shutdown_tx) = drive_connection(conn);
let (error_rx, notification_rx, shutdown_tx) = drive_connection(conn);

Self::setup(
client,
Some(error_rx),
Some(notification_rx),
Some(shutdown_tx),
Arc::new(std::sync::Mutex::new(DynInstrumentation::none())),
)
Expand All @@ -507,6 +509,7 @@ impl AsyncPgConnection {
async fn setup(
conn: tokio_postgres::Client,
connection_future: Option<broadcast::Receiver<Arc<tokio_postgres::Error>>>,
notification_rx: Option<mpsc::UnboundedReceiver<QueryResult<diesel::pg::PgNotification>>>,
shutdown_channel: Option<oneshot::Sender<()>>,
instrumentation: Arc<std::sync::Mutex<DynInstrumentation>>,
) -> ConnectionResult<Self> {
Expand All @@ -516,6 +519,7 @@ impl AsyncPgConnection {
transaction_state: Arc::new(Mutex::new(AnsiTransactionManager::default())),
metadata_cache: Arc::new(Mutex::new(PgMetadataCache::new())),
connection_future,
notification_rx,
shutdown_channel,
instrumentation,
};
Expand Down Expand Up @@ -724,6 +728,58 @@ impl AsyncPgConnection {
.unwrap_or_else(|p| p.into_inner())
.on_connection_event(event);
}

/// See Postgres documentation for SQL commands [NOTIFY][] and [LISTEN][]
///
/// The returned stream yields all notifications received by the connection, not only notifications received
/// after calling the function. The returned stream will never close, so no notifications will just result
/// in a pending state.
///
/// If there's no connection available to poll, the stream will yield no notifications and be pending forever.
/// This can happen if you created the [`AsyncPgConnection`] by the [`try_from`] constructor.
///
/// [NOTIFY]: https://www.postgresql.org/docs/current/sql-notify.html
/// [LISTEN]: https://www.postgresql.org/docs/current/sql-listen.html
/// [`AsyncPgConnection`]: crate::pg::AsyncPgConnection
/// [`try_from`]: crate::pg::AsyncPgConnection::try_from
///
/// ```rust
/// # include!("../doctest_setup.rs");
/// # use scoped_futures::ScopedFutureExt;
/// #
/// # #[tokio::main(flavor = "current_thread")]
/// # async fn main() {
/// # run_test().await.unwrap();
/// # }
/// #
/// # async fn run_test() -> QueryResult<()> {
/// # use diesel_async::RunQueryDsl;
/// # use futures_util::StreamExt;
/// # let conn = &mut connection_no_transaction().await;
/// // register the notifications channel we want to receive notifications for
/// diesel::sql_query("LISTEN example_channel").execute(conn).await?;
/// // send some notification (usually done from a different connection/thread/application)
/// diesel::sql_query("NOTIFY example_channel, 'additional data'").execute(conn).await?;
///
/// let mut notifications = std::pin::pin!(conn.notifications_stream());
/// let mut notification = notifications.next().await.unwrap().unwrap();
///
/// assert_eq!(notification.channel, "example_channel");
/// assert_eq!(notification.payload, "additional data");
/// println!("Notification received from process with id {}", notification.process_id);
/// # Ok(())
/// # }
/// ```
pub fn notifications_stream(
&mut self,
) -> impl futures_core::Stream<Item = QueryResult<diesel::pg::PgNotification>> + '_ {
match &mut self.notification_rx {
None => Either::Left(futures_util::stream::pending()),
Some(rx) => Either::Right(futures_util::stream::unfold(rx, |rx| async {
rx.recv().await.map(move |item| (item, rx))
})),
}
}
}

struct BindData {
Expand Down Expand Up @@ -969,27 +1025,44 @@ async fn drive_future<R>(
}

fn drive_connection<S>(
conn: tokio_postgres::Connection<tokio_postgres::Socket, S>,
mut conn: tokio_postgres::Connection<tokio_postgres::Socket, S>,
) -> (
broadcast::Receiver<Arc<tokio_postgres::Error>>,
mpsc::UnboundedReceiver<QueryResult<diesel::pg::PgNotification>>,
oneshot::Sender<()>,
)
where
S: tokio_postgres::tls::TlsStream + Unpin + Send + 'static,
{
let (error_tx, error_rx) = tokio::sync::broadcast::channel(1);
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
let (notification_tx, notification_rx) = tokio::sync::mpsc::unbounded_channel();
let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel();
let mut conn = futures_util::stream::poll_fn(move |cx| conn.poll_message(cx));

tokio::spawn(async move {
match futures_util::future::select(shutdown_rx, conn).await {
Either::Left(_) | Either::Right((Ok(_), _)) => {}
Either::Right((Err(e), _)) => {
let _ = error_tx.send(Arc::new(e));
loop {
match futures_util::future::select(&mut shutdown_rx, conn.next()).await {
Either::Left(_) | Either::Right((None, _)) => break,
Either::Right((Some(Ok(tokio_postgres::AsyncMessage::Notification(notif))), _)) => {
let _: Result<_, _> = notification_tx.send(Ok(diesel::pg::PgNotification {
process_id: notif.process_id(),
channel: notif.channel().to_owned(),
payload: notif.payload().to_owned(),
}));
}
Either::Right((Some(Ok(_)), _)) => {}
Either::Right((Some(Err(e)), _)) => {
let e = Arc::new(e);
let _: Result<_, _> = error_tx.send(e.clone());
let _: Result<_, _> =
notification_tx.send(Err(error_helper::from_tokio_postgres_error(e)));
break;
}
}
}
});

(error_rx, shutdown_tx)
(error_rx, notification_rx, shutdown_tx)
}

#[cfg(any(
Expand Down
1 change: 1 addition & 0 deletions tests/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use std::fmt::Debug;
#[cfg(feature = "postgres")]
mod custom_types;
mod instrumentation;
mod notifications;
#[cfg(any(feature = "bb8", feature = "deadpool", feature = "mobc"))]
mod pooling;
#[cfg(feature = "async-connection-wrapper")]
Expand Down
55 changes: 55 additions & 0 deletions tests/notifications.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#[cfg(feature = "postgres")]
#[tokio::test]
async fn notifications_arrive() {
use diesel_async::RunQueryDsl;
use futures_util::{StreamExt, TryStreamExt};

let conn = &mut super::connection_without_transaction().await;

diesel::sql_query("LISTEN test_notifications")
.execute(conn)
.await
.unwrap();

diesel::sql_query("NOTIFY test_notifications, 'first'")
.execute(conn)
.await
.unwrap();

diesel::sql_query("NOTIFY test_notifications, 'second'")
.execute(conn)
.await
.unwrap();

let notifications = conn
.notifications_stream()
.take(2)
.try_collect::<Vec<_>>()
.await
.unwrap();

assert_eq!(2, notifications.len());
assert_eq!(notifications[0].channel, "test_notifications");
assert_eq!(notifications[1].channel, "test_notifications");
assert_eq!(notifications[0].payload, "first");
assert_eq!(notifications[1].payload, "second");

let next_notification = tokio::time::timeout(
std::time::Duration::from_secs(1),
std::pin::pin!(conn.notifications_stream()).next(),
)
.await;

assert!(
next_notification.is_err(),
"Got a next notification, while not expecting one: {next_notification:?}"
);

diesel::sql_query("NOTIFY test_notifications")
.execute(conn)
.await
.unwrap();

let next_notification = std::pin::pin!(conn.notifications_stream()).next().await;
assert_eq!(next_notification.unwrap().unwrap().payload, "");
}