@@ -31,9 +31,7 @@ use futures_util::{FutureExt, StreamExt};
31
31
use std:: collections:: { HashMap , HashSet } ;
32
32
use std:: future:: Future ;
33
33
use std:: sync:: Arc ;
34
- use tokio:: sync:: broadcast;
35
- use tokio:: sync:: oneshot;
36
- use tokio:: sync:: Mutex ;
34
+ use tokio:: sync:: { broadcast, mpsc, oneshot, Mutex } ;
37
35
use tokio_postgres:: types:: ToSql ;
38
36
use tokio_postgres:: types:: Type ;
39
37
use tokio_postgres:: Statement ;
@@ -172,6 +170,7 @@ pub struct AsyncPgConnection {
172
170
transaction_state : Arc < Mutex < AnsiTransactionManager > > ,
173
171
metadata_cache : Arc < Mutex < PgMetadataCache > > ,
174
172
connection_future : Option < broadcast:: Receiver < Arc < tokio_postgres:: Error > > > ,
173
+ notification_rx : Option < mpsc:: UnboundedReceiver < QueryResult < diesel:: pg:: PgNotification > > > ,
175
174
shutdown_channel : Option < oneshot:: Sender < ( ) > > ,
176
175
// a sync mutex is fine here as we only hold it for a really short time
177
176
instrumentation : Arc < std:: sync:: Mutex < DynInstrumentation > > ,
@@ -283,11 +282,12 @@ impl AsyncConnection for AsyncPgConnection {
283
282
. await
284
283
. map_err ( ErrorHelper ) ?;
285
284
286
- let ( error_rx, shutdown_tx) = drive_connection ( connection) ;
285
+ let ( error_rx, notification_rx , shutdown_tx) = drive_connection ( connection) ;
287
286
288
287
let r = Self :: setup (
289
288
client,
290
289
Some ( error_rx) ,
290
+ Some ( notification_rx) ,
291
291
Some ( shutdown_tx) ,
292
292
Arc :: clone ( & instrumentation) ,
293
293
)
@@ -477,6 +477,7 @@ impl AsyncPgConnection {
477
477
conn,
478
478
None ,
479
479
None ,
480
+ None ,
480
481
Arc :: new ( std:: sync:: Mutex :: new (
481
482
DynInstrumentation :: default_instrumentation ( ) ,
482
483
) ) ,
@@ -493,11 +494,12 @@ impl AsyncPgConnection {
493
494
where
494
495
S : tokio_postgres:: tls:: TlsStream + Unpin + Send + ' static ,
495
496
{
496
- let ( error_rx, shutdown_tx) = drive_connection ( conn) ;
497
+ let ( error_rx, notification_rx , shutdown_tx) = drive_connection ( conn) ;
497
498
498
499
Self :: setup (
499
500
client,
500
501
Some ( error_rx) ,
502
+ Some ( notification_rx) ,
501
503
Some ( shutdown_tx) ,
502
504
Arc :: new ( std:: sync:: Mutex :: new ( DynInstrumentation :: none ( ) ) ) ,
503
505
)
@@ -507,6 +509,7 @@ impl AsyncPgConnection {
507
509
async fn setup (
508
510
conn : tokio_postgres:: Client ,
509
511
connection_future : Option < broadcast:: Receiver < Arc < tokio_postgres:: Error > > > ,
512
+ notification_rx : Option < mpsc:: UnboundedReceiver < QueryResult < diesel:: pg:: PgNotification > > > ,
510
513
shutdown_channel : Option < oneshot:: Sender < ( ) > > ,
511
514
instrumentation : Arc < std:: sync:: Mutex < DynInstrumentation > > ,
512
515
) -> ConnectionResult < Self > {
@@ -516,6 +519,7 @@ impl AsyncPgConnection {
516
519
transaction_state : Arc :: new ( Mutex :: new ( AnsiTransactionManager :: default ( ) ) ) ,
517
520
metadata_cache : Arc :: new ( Mutex :: new ( PgMetadataCache :: new ( ) ) ) ,
518
521
connection_future,
522
+ notification_rx,
519
523
shutdown_channel,
520
524
instrumentation,
521
525
} ;
@@ -724,6 +728,58 @@ impl AsyncPgConnection {
724
728
. unwrap_or_else ( |p| p. into_inner ( ) )
725
729
. on_connection_event ( event) ;
726
730
}
731
+
732
+ /// See Postgres documentation for SQL commands [NOTIFY][] and [LISTEN][]
733
+ ///
734
+ /// The returned stream yields all notifications received by the connection, not only notifications received
735
+ /// after calling the function. The returned stream will never close, so no notifications will just result
736
+ /// in a pending state.
737
+ ///
738
+ /// If there's no connection available to poll, the stream will yield no notifications and be pending forever.
739
+ /// This can happen if you created the [`AsyncPgConnection`] by the [`try_from`] constructor.
740
+ ///
741
+ /// [NOTIFY]: https://www.postgresql.org/docs/current/sql-notify.html
742
+ /// [LISTEN]: https://www.postgresql.org/docs/current/sql-listen.html
743
+ /// [`AsyncPgConnection`]: crate::pg::AsyncPgConnection
744
+ /// [`try_from`]: crate::pg::AsyncPgConnection::try_from
745
+ ///
746
+ /// ```rust
747
+ /// # include!("../doctest_setup.rs");
748
+ /// # use scoped_futures::ScopedFutureExt;
749
+ /// #
750
+ /// # #[tokio::main(flavor = "current_thread")]
751
+ /// # async fn main() {
752
+ /// # run_test().await.unwrap();
753
+ /// # }
754
+ /// #
755
+ /// # async fn run_test() -> QueryResult<()> {
756
+ /// # use diesel_async::RunQueryDsl;
757
+ /// # use futures_util::StreamExt;
758
+ /// # let conn = &mut connection_no_transaction().await;
759
+ /// // register the notifications channel we want to receive notifications for
760
+ /// diesel::sql_query("LISTEN example_channel").execute(conn).await?;
761
+ /// // send some notification (usually done from a different connection/thread/application)
762
+ /// diesel::sql_query("NOTIFY example_channel, 'additional data'").execute(conn).await?;
763
+ ///
764
+ /// let mut notifications = std::pin::pin!(conn.notifications_stream());
765
+ /// let mut notification = notifications.next().await.unwrap().unwrap();
766
+ ///
767
+ /// assert_eq!(notification.channel, "example_channel");
768
+ /// assert_eq!(notification.payload, "additional data");
769
+ /// println!("Notification received from process with id {}", notification.process_id);
770
+ /// # Ok(())
771
+ /// # }
772
+ /// ```
773
+ pub fn notifications_stream (
774
+ & mut self ,
775
+ ) -> impl futures_core:: Stream < Item = QueryResult < diesel:: pg:: PgNotification > > + ' _ {
776
+ match & mut self . notification_rx {
777
+ None => Either :: Left ( futures_util:: stream:: pending ( ) ) ,
778
+ Some ( rx) => Either :: Right ( futures_util:: stream:: unfold ( rx, |rx| async {
779
+ rx. recv ( ) . await . map ( move |item| ( item, rx) )
780
+ } ) ) ,
781
+ }
782
+ }
727
783
}
728
784
729
785
struct BindData {
@@ -969,27 +1025,44 @@ async fn drive_future<R>(
969
1025
}
970
1026
971
1027
fn drive_connection < S > (
972
- conn : tokio_postgres:: Connection < tokio_postgres:: Socket , S > ,
1028
+ mut conn : tokio_postgres:: Connection < tokio_postgres:: Socket , S > ,
973
1029
) -> (
974
1030
broadcast:: Receiver < Arc < tokio_postgres:: Error > > ,
1031
+ mpsc:: UnboundedReceiver < QueryResult < diesel:: pg:: PgNotification > > ,
975
1032
oneshot:: Sender < ( ) > ,
976
1033
)
977
1034
where
978
1035
S : tokio_postgres:: tls:: TlsStream + Unpin + Send + ' static ,
979
1036
{
980
1037
let ( error_tx, error_rx) = tokio:: sync:: broadcast:: channel ( 1 ) ;
981
- let ( shutdown_tx, shutdown_rx) = tokio:: sync:: oneshot:: channel ( ) ;
1038
+ let ( notification_tx, notification_rx) = tokio:: sync:: mpsc:: unbounded_channel ( ) ;
1039
+ let ( shutdown_tx, mut shutdown_rx) = tokio:: sync:: oneshot:: channel ( ) ;
1040
+ let mut conn = futures_util:: stream:: poll_fn ( move |cx| conn. poll_message ( cx) ) ;
982
1041
983
1042
tokio:: spawn ( async move {
984
- match futures_util:: future:: select ( shutdown_rx, conn) . await {
985
- Either :: Left ( _) | Either :: Right ( ( Ok ( _) , _) ) => { }
986
- Either :: Right ( ( Err ( e) , _) ) => {
987
- let _ = error_tx. send ( Arc :: new ( e) ) ;
1043
+ loop {
1044
+ match futures_util:: future:: select ( & mut shutdown_rx, conn. next ( ) ) . await {
1045
+ Either :: Left ( _) | Either :: Right ( ( None , _) ) => break ,
1046
+ Either :: Right ( ( Some ( Ok ( tokio_postgres:: AsyncMessage :: Notification ( notif) ) ) , _) ) => {
1047
+ let _: Result < _ , _ > = notification_tx. send ( Ok ( diesel:: pg:: PgNotification {
1048
+ process_id : notif. process_id ( ) ,
1049
+ channel : notif. channel ( ) . to_owned ( ) ,
1050
+ payload : notif. payload ( ) . to_owned ( ) ,
1051
+ } ) ) ;
1052
+ }
1053
+ Either :: Right ( ( Some ( Ok ( _) ) , _) ) => { }
1054
+ Either :: Right ( ( Some ( Err ( e) ) , _) ) => {
1055
+ let e = Arc :: new ( e) ;
1056
+ let _: Result < _ , _ > = error_tx. send ( e. clone ( ) ) ;
1057
+ let _: Result < _ , _ > =
1058
+ notification_tx. send ( Err ( error_helper:: from_tokio_postgres_error ( e) ) ) ;
1059
+ break ;
1060
+ }
988
1061
}
989
1062
}
990
1063
} ) ;
991
1064
992
- ( error_rx, shutdown_tx)
1065
+ ( error_rx, notification_rx , shutdown_tx)
993
1066
}
994
1067
995
1068
#[ cfg( any(
0 commit comments