@@ -31,9 +31,7 @@ use futures_util::{FutureExt, StreamExt};
3131use std:: collections:: { HashMap , HashSet } ;
3232use std:: future:: Future ;
3333use std:: sync:: Arc ;
34- use tokio:: sync:: broadcast;
35- use tokio:: sync:: oneshot;
36- use tokio:: sync:: Mutex ;
34+ use tokio:: sync:: { broadcast, mpsc, oneshot, Mutex } ;
3735use tokio_postgres:: types:: ToSql ;
3836use tokio_postgres:: types:: Type ;
3937use tokio_postgres:: Statement ;
@@ -172,6 +170,7 @@ pub struct AsyncPgConnection {
172170 transaction_state : Arc < Mutex < AnsiTransactionManager > > ,
173171 metadata_cache : Arc < Mutex < PgMetadataCache > > ,
174172 connection_future : Option < broadcast:: Receiver < Arc < tokio_postgres:: Error > > > ,
173+ notification_rx : Option < mpsc:: UnboundedReceiver < QueryResult < diesel:: pg:: PgNotification > > > ,
175174 shutdown_channel : Option < oneshot:: Sender < ( ) > > ,
176175 // a sync mutex is fine here as we only hold it for a really short time
177176 instrumentation : Arc < std:: sync:: Mutex < DynInstrumentation > > ,
@@ -283,11 +282,12 @@ impl AsyncConnection for AsyncPgConnection {
283282 . await
284283 . map_err ( ErrorHelper ) ?;
285284
286- let ( error_rx, shutdown_tx) = drive_connection ( connection) ;
285+ let ( error_rx, notification_rx , shutdown_tx) = drive_connection ( connection) ;
287286
288287 let r = Self :: setup (
289288 client,
290289 Some ( error_rx) ,
290+ Some ( notification_rx) ,
291291 Some ( shutdown_tx) ,
292292 Arc :: clone ( & instrumentation) ,
293293 )
@@ -477,6 +477,7 @@ impl AsyncPgConnection {
477477 conn,
478478 None ,
479479 None ,
480+ None ,
480481 Arc :: new ( std:: sync:: Mutex :: new (
481482 DynInstrumentation :: default_instrumentation ( ) ,
482483 ) ) ,
@@ -493,11 +494,12 @@ impl AsyncPgConnection {
493494 where
494495 S : tokio_postgres:: tls:: TlsStream + Unpin + Send + ' static ,
495496 {
496- let ( error_rx, shutdown_tx) = drive_connection ( conn) ;
497+ let ( error_rx, notification_rx , shutdown_tx) = drive_connection ( conn) ;
497498
498499 Self :: setup (
499500 client,
500501 Some ( error_rx) ,
502+ Some ( notification_rx) ,
501503 Some ( shutdown_tx) ,
502504 Arc :: new ( std:: sync:: Mutex :: new ( DynInstrumentation :: none ( ) ) ) ,
503505 )
@@ -507,6 +509,7 @@ impl AsyncPgConnection {
507509 async fn setup (
508510 conn : tokio_postgres:: Client ,
509511 connection_future : Option < broadcast:: Receiver < Arc < tokio_postgres:: Error > > > ,
512+ notification_rx : Option < mpsc:: UnboundedReceiver < QueryResult < diesel:: pg:: PgNotification > > > ,
510513 shutdown_channel : Option < oneshot:: Sender < ( ) > > ,
511514 instrumentation : Arc < std:: sync:: Mutex < DynInstrumentation > > ,
512515 ) -> ConnectionResult < Self > {
@@ -516,6 +519,7 @@ impl AsyncPgConnection {
516519 transaction_state : Arc :: new ( Mutex :: new ( AnsiTransactionManager :: default ( ) ) ) ,
517520 metadata_cache : Arc :: new ( Mutex :: new ( PgMetadataCache :: new ( ) ) ) ,
518521 connection_future,
522+ notification_rx,
519523 shutdown_channel,
520524 instrumentation,
521525 } ;
@@ -724,6 +728,58 @@ impl AsyncPgConnection {
724728 . unwrap_or_else ( |p| p. into_inner ( ) )
725729 . on_connection_event ( event) ;
726730 }
731+
732+ /// See Postgres documentation for SQL commands [NOTIFY][] and [LISTEN][]
733+ ///
734+ /// The returned stream yields all notifications received by the connection, not only notifications received
735+ /// after calling the function. The returned stream will never close, so no notifications will just result
736+ /// in a pending state.
737+ ///
738+ /// If there's no connection available to poll, the stream will yield no notifications and be pending forever.
739+ /// This can happen if you created the [`AsyncPgConnection`] by the [`try_from`] constructor.
740+ ///
741+ /// [NOTIFY]: https://www.postgresql.org/docs/current/sql-notify.html
742+ /// [LISTEN]: https://www.postgresql.org/docs/current/sql-listen.html
743+ /// [`AsyncPgConnection`]: crate::pg::AsyncPgConnection
744+ /// [`try_from`]: crate::pg::AsyncPgConnection::try_from
745+ ///
746+ /// ```rust
747+ /// # include!("../doctest_setup.rs");
748+ /// # use scoped_futures::ScopedFutureExt;
749+ /// #
750+ /// # #[tokio::main(flavor = "current_thread")]
751+ /// # async fn main() {
752+ /// # run_test().await.unwrap();
753+ /// # }
754+ /// #
755+ /// # async fn run_test() -> QueryResult<()> {
756+ /// # use diesel_async::RunQueryDsl;
757+ /// # use futures_util::StreamExt;
758+ /// # let conn = &mut connection_no_transaction().await;
759+ /// // register the notifications channel we want to receive notifications for
760+ /// diesel::sql_query("LISTEN example_channel").execute(conn).await?;
761+ /// // send some notification (usually done from a different connection/thread/application)
762+ /// diesel::sql_query("NOTIFY example_channel, 'additional data'").execute(conn).await?;
763+ ///
764+ /// let mut notifications = std::pin::pin!(conn.notifications_stream());
765+ /// let mut notification = notifications.next().await.unwrap().unwrap();
766+ ///
767+ /// assert_eq!(notification.channel, "example_channel");
768+ /// assert_eq!(notification.payload, "additional data");
769+ /// println!("Notification received from process with id {}", notification.process_id);
770+ /// # Ok(())
771+ /// # }
772+ /// ```
773+ pub fn notifications_stream (
774+ & mut self ,
775+ ) -> impl futures_core:: Stream < Item = QueryResult < diesel:: pg:: PgNotification > > + ' _ {
776+ match & mut self . notification_rx {
777+ None => Either :: Left ( futures_util:: stream:: pending ( ) ) ,
778+ Some ( rx) => Either :: Right ( futures_util:: stream:: unfold ( rx, |rx| async {
779+ rx. recv ( ) . await . map ( move |item| ( item, rx) )
780+ } ) ) ,
781+ }
782+ }
727783}
728784
729785struct BindData {
@@ -969,27 +1025,44 @@ async fn drive_future<R>(
9691025}
9701026
9711027fn drive_connection < S > (
972- conn : tokio_postgres:: Connection < tokio_postgres:: Socket , S > ,
1028+ mut conn : tokio_postgres:: Connection < tokio_postgres:: Socket , S > ,
9731029) -> (
9741030 broadcast:: Receiver < Arc < tokio_postgres:: Error > > ,
1031+ mpsc:: UnboundedReceiver < QueryResult < diesel:: pg:: PgNotification > > ,
9751032 oneshot:: Sender < ( ) > ,
9761033)
9771034where
9781035 S : tokio_postgres:: tls:: TlsStream + Unpin + Send + ' static ,
9791036{
9801037 let ( error_tx, error_rx) = tokio:: sync:: broadcast:: channel ( 1 ) ;
981- let ( shutdown_tx, shutdown_rx) = tokio:: sync:: oneshot:: channel ( ) ;
1038+ let ( notification_tx, notification_rx) = tokio:: sync:: mpsc:: unbounded_channel ( ) ;
1039+ let ( shutdown_tx, mut shutdown_rx) = tokio:: sync:: oneshot:: channel ( ) ;
1040+ let mut conn = futures_util:: stream:: poll_fn ( move |cx| conn. poll_message ( cx) ) ;
9821041
9831042 tokio:: spawn ( async move {
984- match futures_util:: future:: select ( shutdown_rx, conn) . await {
985- Either :: Left ( _) | Either :: Right ( ( Ok ( _) , _) ) => { }
986- Either :: Right ( ( Err ( e) , _) ) => {
987- let _ = error_tx. send ( Arc :: new ( e) ) ;
1043+ loop {
1044+ match futures_util:: future:: select ( & mut shutdown_rx, conn. next ( ) ) . await {
1045+ Either :: Left ( _) | Either :: Right ( ( None , _) ) => break ,
1046+ Either :: Right ( ( Some ( Ok ( tokio_postgres:: AsyncMessage :: Notification ( notif) ) ) , _) ) => {
1047+ let _: Result < _ , _ > = notification_tx. send ( Ok ( diesel:: pg:: PgNotification {
1048+ process_id : notif. process_id ( ) ,
1049+ channel : notif. channel ( ) . to_owned ( ) ,
1050+ payload : notif. payload ( ) . to_owned ( ) ,
1051+ } ) ) ;
1052+ }
1053+ Either :: Right ( ( Some ( Ok ( _) ) , _) ) => { }
1054+ Either :: Right ( ( Some ( Err ( e) ) , _) ) => {
1055+ let e = Arc :: new ( e) ;
1056+ let _: Result < _ , _ > = error_tx. send ( e. clone ( ) ) ;
1057+ let _: Result < _ , _ > =
1058+ notification_tx. send ( Err ( error_helper:: from_tokio_postgres_error ( e) ) ) ;
1059+ break ;
1060+ }
9881061 }
9891062 }
9901063 } ) ;
9911064
992- ( error_rx, shutdown_tx)
1065+ ( error_rx, notification_rx , shutdown_tx)
9931066}
9941067
9951068#[ cfg( any(
0 commit comments