@@ -172,6 +172,7 @@ pub struct AsyncPgConnection {
172172 transaction_state : Arc < Mutex < AnsiTransactionManager > > ,
173173 metadata_cache : Arc < Mutex < PgMetadataCache > > ,
174174 connection_future : Option < broadcast:: Receiver < Arc < tokio_postgres:: Error > > > ,
175+ notification_rx : Option < broadcast:: Receiver < diesel:: pg:: PgNotification > > ,
175176 shutdown_channel : Option < oneshot:: Sender < ( ) > > ,
176177 // a sync mutex is fine here as we only hold it for a really short time
177178 instrumentation : Arc < std:: sync:: Mutex < DynInstrumentation > > ,
@@ -283,11 +284,12 @@ impl AsyncConnection for AsyncPgConnection {
283284 . await
284285 . map_err ( ErrorHelper ) ?;
285286
286- let ( error_rx, shutdown_tx) = drive_connection ( connection) ;
287+ let ( error_rx, notification_rx , shutdown_tx) = drive_connection ( connection) ;
287288
288289 let r = Self :: setup (
289290 client,
290291 Some ( error_rx) ,
292+ Some ( notification_rx) ,
291293 Some ( shutdown_tx) ,
292294 Arc :: clone ( & instrumentation) ,
293295 )
@@ -477,6 +479,7 @@ impl AsyncPgConnection {
477479 conn,
478480 None ,
479481 None ,
482+ None ,
480483 Arc :: new ( std:: sync:: Mutex :: new (
481484 DynInstrumentation :: default_instrumentation ( ) ,
482485 ) ) ,
@@ -493,11 +496,12 @@ impl AsyncPgConnection {
493496 where
494497 S : tokio_postgres:: tls:: TlsStream + Unpin + Send + ' static ,
495498 {
496- let ( error_rx, shutdown_tx) = drive_connection ( conn) ;
499+ let ( error_rx, notification_rx , shutdown_tx) = drive_connection ( conn) ;
497500
498501 Self :: setup (
499502 client,
500503 Some ( error_rx) ,
504+ Some ( notification_rx) ,
501505 Some ( shutdown_tx) ,
502506 Arc :: new ( std:: sync:: Mutex :: new ( DynInstrumentation :: none ( ) ) ) ,
503507 )
@@ -507,6 +511,7 @@ impl AsyncPgConnection {
507511 async fn setup (
508512 conn : tokio_postgres:: Client ,
509513 connection_future : Option < broadcast:: Receiver < Arc < tokio_postgres:: Error > > > ,
514+ notification_rx : Option < broadcast:: Receiver < diesel:: pg:: PgNotification > > ,
510515 shutdown_channel : Option < oneshot:: Sender < ( ) > > ,
511516 instrumentation : Arc < std:: sync:: Mutex < DynInstrumentation > > ,
512517 ) -> ConnectionResult < Self > {
@@ -516,6 +521,7 @@ impl AsyncPgConnection {
516521 transaction_state : Arc :: new ( Mutex :: new ( AnsiTransactionManager :: default ( ) ) ) ,
517522 metadata_cache : Arc :: new ( Mutex :: new ( PgMetadataCache :: new ( ) ) ) ,
518523 connection_future,
524+ notification_rx,
519525 shutdown_channel,
520526 instrumentation,
521527 } ;
@@ -724,6 +730,21 @@ impl AsyncPgConnection {
724730 . unwrap_or_else ( |p| p. into_inner ( ) )
725731 . on_connection_event ( event) ;
726732 }
733+
734+ pub fn notification_stream (
735+ & self ,
736+ ) -> impl futures_core:: Stream < Item = QueryResult < diesel:: pg:: PgNotification > > {
737+ futures_util:: stream:: unfold (
738+ self . notification_rx . as_ref ( ) . map ( |rx| rx. resubscribe ( ) ) ,
739+ |rx| async {
740+ let mut rx = rx?;
741+ match rx. recv ( ) . await {
742+ Ok ( notification) => Some ( ( Ok ( notification) , Some ( rx) ) ) ,
743+ Err ( _) => todo ! ( ) ,
744+ }
745+ } ,
746+ )
747+ }
727748}
728749
729750struct BindData {
@@ -969,27 +990,42 @@ async fn drive_future<R>(
969990}
970991
971992fn drive_connection < S > (
972- conn : tokio_postgres:: Connection < tokio_postgres:: Socket , S > ,
993+ mut conn : tokio_postgres:: Connection < tokio_postgres:: Socket , S > ,
973994) -> (
974995 broadcast:: Receiver < Arc < tokio_postgres:: Error > > ,
996+ broadcast:: Receiver < diesel:: pg:: PgNotification > ,
975997 oneshot:: Sender < ( ) > ,
976998)
977999where
9781000 S : tokio_postgres:: tls:: TlsStream + Unpin + Send + ' static ,
9791001{
9801002 let ( error_tx, error_rx) = tokio:: sync:: broadcast:: channel ( 1 ) ;
981- let ( shutdown_tx, shutdown_rx) = tokio:: sync:: oneshot:: channel ( ) ;
1003+ let ( notification_tx, notification_rx) = tokio:: sync:: broadcast:: channel ( 1 ) ;
1004+ let ( shutdown_tx, mut shutdown_rx) = tokio:: sync:: oneshot:: channel ( ) ;
9821005
9831006 tokio:: spawn ( async move {
984- match futures_util:: future:: select ( shutdown_rx, conn) . await {
985- Either :: Left ( _) | Either :: Right ( ( Ok ( _) , _) ) => { }
986- Either :: Right ( ( Err ( e) , _) ) => {
987- let _ = error_tx. send ( Arc :: new ( e) ) ;
1007+ let mut conn = futures_util:: stream:: poll_fn ( |cx| conn. poll_message ( cx) ) ;
1008+
1009+ loop {
1010+ match futures_util:: future:: select ( & mut shutdown_rx, conn. next ( ) ) . await {
1011+ Either :: Left ( _) | Either :: Right ( ( None , _) ) => break ,
1012+ Either :: Right ( ( Some ( Ok ( tokio_postgres:: AsyncMessage :: Notification ( notif) ) ) , _) ) => {
1013+ let _ = notification_tx. send ( diesel:: pg:: PgNotification {
1014+ process_id : notif. process_id ( ) ,
1015+ channel : notif. channel ( ) . to_owned ( ) ,
1016+ payload : notif. payload ( ) . to_owned ( ) ,
1017+ } ) ;
1018+ }
1019+ Either :: Right ( ( Some ( Ok ( _) ) , _) ) => { }
1020+ Either :: Right ( ( Some ( Err ( e) ) , _) ) => {
1021+ let _ = error_tx. send ( Arc :: new ( e) ) ;
1022+ break ;
1023+ }
9881024 }
9891025 }
9901026 } ) ;
9911027
992- ( error_rx, shutdown_tx)
1028+ ( error_rx, notification_rx , shutdown_tx)
9931029}
9941030
9951031#[ cfg( any(
0 commit comments