diff --git a/src/pg/mod.rs b/src/pg/mod.rs index 39811b3..5a26692 100644 --- a/src/pg/mod.rs +++ b/src/pg/mod.rs @@ -31,9 +31,7 @@ use futures_util::{FutureExt, StreamExt}; use std::collections::{HashMap, HashSet}; use std::future::Future; use std::sync::Arc; -use tokio::sync::broadcast; -use tokio::sync::oneshot; -use tokio::sync::Mutex; +use tokio::sync::{broadcast, mpsc, oneshot, Mutex}; use tokio_postgres::types::ToSql; use tokio_postgres::types::Type; use tokio_postgres::Statement; @@ -172,6 +170,7 @@ pub struct AsyncPgConnection { transaction_state: Arc>, metadata_cache: Arc>, connection_future: Option>>, + notification_rx: Option>>, shutdown_channel: Option>, // a sync mutex is fine here as we only hold it for a really short time instrumentation: Arc>, @@ -283,11 +282,12 @@ impl AsyncConnection for AsyncPgConnection { .await .map_err(ErrorHelper)?; - let (error_rx, shutdown_tx) = drive_connection(connection); + let (error_rx, notification_rx, shutdown_tx) = drive_connection(connection); let r = Self::setup( client, Some(error_rx), + Some(notification_rx), Some(shutdown_tx), Arc::clone(&instrumentation), ) @@ -477,6 +477,7 @@ impl AsyncPgConnection { conn, None, None, + None, Arc::new(std::sync::Mutex::new( DynInstrumentation::default_instrumentation(), )), @@ -493,11 +494,12 @@ impl AsyncPgConnection { where S: tokio_postgres::tls::TlsStream + Unpin + Send + 'static, { - let (error_rx, shutdown_tx) = drive_connection(conn); + let (error_rx, notification_rx, shutdown_tx) = drive_connection(conn); Self::setup( client, Some(error_rx), + Some(notification_rx), Some(shutdown_tx), Arc::new(std::sync::Mutex::new(DynInstrumentation::none())), ) @@ -507,6 +509,7 @@ impl AsyncPgConnection { async fn setup( conn: tokio_postgres::Client, connection_future: Option>>, + notification_rx: Option>>, shutdown_channel: Option>, instrumentation: Arc>, ) -> ConnectionResult { @@ -516,6 +519,7 @@ impl AsyncPgConnection { transaction_state: Arc::new(Mutex::new(AnsiTransactionManager::default())), metadata_cache: Arc::new(Mutex::new(PgMetadataCache::new())), connection_future, + notification_rx, shutdown_channel, instrumentation, }; @@ -724,6 +728,58 @@ impl AsyncPgConnection { .unwrap_or_else(|p| p.into_inner()) .on_connection_event(event); } + + /// See Postgres documentation for SQL commands [NOTIFY][] and [LISTEN][] + /// + /// The returned stream yields all notifications received by the connection, not only notifications received + /// after calling the function. The returned stream will never close, so no notifications will just result + /// in a pending state. + /// + /// If there's no connection available to poll, the stream will yield no notifications and be pending forever. + /// This can happen if you created the [`AsyncPgConnection`] by the [`try_from`] constructor. + /// + /// [NOTIFY]: https://www.postgresql.org/docs/current/sql-notify.html + /// [LISTEN]: https://www.postgresql.org/docs/current/sql-listen.html + /// [`AsyncPgConnection`]: crate::pg::AsyncPgConnection + /// [`try_from`]: crate::pg::AsyncPgConnection::try_from + /// + /// ```rust + /// # include!("../doctest_setup.rs"); + /// # use scoped_futures::ScopedFutureExt; + /// # + /// # #[tokio::main(flavor = "current_thread")] + /// # async fn main() { + /// # run_test().await.unwrap(); + /// # } + /// # + /// # async fn run_test() -> QueryResult<()> { + /// # use diesel_async::RunQueryDsl; + /// # use futures_util::StreamExt; + /// # let conn = &mut connection_no_transaction().await; + /// // register the notifications channel we want to receive notifications for + /// diesel::sql_query("LISTEN example_channel").execute(conn).await?; + /// // send some notification (usually done from a different connection/thread/application) + /// diesel::sql_query("NOTIFY example_channel, 'additional data'").execute(conn).await?; + /// + /// let mut notifications = std::pin::pin!(conn.notifications_stream()); + /// let mut notification = notifications.next().await.unwrap().unwrap(); + /// + /// assert_eq!(notification.channel, "example_channel"); + /// assert_eq!(notification.payload, "additional data"); + /// println!("Notification received from process with id {}", notification.process_id); + /// # Ok(()) + /// # } + /// ``` + pub fn notifications_stream( + &mut self, + ) -> impl futures_core::Stream> + '_ { + match &mut self.notification_rx { + None => Either::Left(futures_util::stream::pending()), + Some(rx) => Either::Right(futures_util::stream::unfold(rx, |rx| async { + rx.recv().await.map(move |item| (item, rx)) + })), + } + } } struct BindData { @@ -969,27 +1025,44 @@ async fn drive_future( } fn drive_connection( - conn: tokio_postgres::Connection, + mut conn: tokio_postgres::Connection, ) -> ( broadcast::Receiver>, + mpsc::UnboundedReceiver>, oneshot::Sender<()>, ) where S: tokio_postgres::tls::TlsStream + Unpin + Send + 'static, { let (error_tx, error_rx) = tokio::sync::broadcast::channel(1); - let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel(); + let (notification_tx, notification_rx) = tokio::sync::mpsc::unbounded_channel(); + let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel(); + let mut conn = futures_util::stream::poll_fn(move |cx| conn.poll_message(cx)); tokio::spawn(async move { - match futures_util::future::select(shutdown_rx, conn).await { - Either::Left(_) | Either::Right((Ok(_), _)) => {} - Either::Right((Err(e), _)) => { - let _ = error_tx.send(Arc::new(e)); + loop { + match futures_util::future::select(&mut shutdown_rx, conn.next()).await { + Either::Left(_) | Either::Right((None, _)) => break, + Either::Right((Some(Ok(tokio_postgres::AsyncMessage::Notification(notif))), _)) => { + let _: Result<_, _> = notification_tx.send(Ok(diesel::pg::PgNotification { + process_id: notif.process_id(), + channel: notif.channel().to_owned(), + payload: notif.payload().to_owned(), + })); + } + Either::Right((Some(Ok(_)), _)) => {} + Either::Right((Some(Err(e)), _)) => { + let e = Arc::new(e); + let _: Result<_, _> = error_tx.send(e.clone()); + let _: Result<_, _> = + notification_tx.send(Err(error_helper::from_tokio_postgres_error(e))); + break; + } } } }); - (error_rx, shutdown_tx) + (error_rx, notification_rx, shutdown_tx) } #[cfg(any( diff --git a/tests/lib.rs b/tests/lib.rs index 24cd2a6..5125e28 100644 --- a/tests/lib.rs +++ b/tests/lib.rs @@ -7,6 +7,7 @@ use std::fmt::Debug; #[cfg(feature = "postgres")] mod custom_types; mod instrumentation; +mod notifications; #[cfg(any(feature = "bb8", feature = "deadpool", feature = "mobc"))] mod pooling; #[cfg(feature = "async-connection-wrapper")] diff --git a/tests/notifications.rs b/tests/notifications.rs new file mode 100644 index 0000000..17b790b --- /dev/null +++ b/tests/notifications.rs @@ -0,0 +1,55 @@ +#[cfg(feature = "postgres")] +#[tokio::test] +async fn notifications_arrive() { + use diesel_async::RunQueryDsl; + use futures_util::{StreamExt, TryStreamExt}; + + let conn = &mut super::connection_without_transaction().await; + + diesel::sql_query("LISTEN test_notifications") + .execute(conn) + .await + .unwrap(); + + diesel::sql_query("NOTIFY test_notifications, 'first'") + .execute(conn) + .await + .unwrap(); + + diesel::sql_query("NOTIFY test_notifications, 'second'") + .execute(conn) + .await + .unwrap(); + + let notifications = conn + .notifications_stream() + .take(2) + .try_collect::>() + .await + .unwrap(); + + assert_eq!(2, notifications.len()); + assert_eq!(notifications[0].channel, "test_notifications"); + assert_eq!(notifications[1].channel, "test_notifications"); + assert_eq!(notifications[0].payload, "first"); + assert_eq!(notifications[1].payload, "second"); + + let next_notification = tokio::time::timeout( + std::time::Duration::from_secs(1), + std::pin::pin!(conn.notifications_stream()).next(), + ) + .await; + + assert!( + next_notification.is_err(), + "Got a next notification, while not expecting one: {next_notification:?}" + ); + + diesel::sql_query("NOTIFY test_notifications") + .execute(conn) + .await + .unwrap(); + + let next_notification = std::pin::pin!(conn.notifications_stream()).next().await; + assert_eq!(next_notification.unwrap().unwrap().payload, ""); +}