@@ -172,6 +172,7 @@ pub struct AsyncPgConnection {
172
172
transaction_state : Arc < Mutex < AnsiTransactionManager > > ,
173
173
metadata_cache : Arc < Mutex < PgMetadataCache > > ,
174
174
connection_future : Option < broadcast:: Receiver < Arc < tokio_postgres:: Error > > > ,
175
+ notification_rx : Option < broadcast:: Receiver < diesel:: pg:: PgNotification > > ,
175
176
shutdown_channel : Option < oneshot:: Sender < ( ) > > ,
176
177
// a sync mutex is fine here as we only hold it for a really short time
177
178
instrumentation : Arc < std:: sync:: Mutex < DynInstrumentation > > ,
@@ -283,11 +284,12 @@ impl AsyncConnection for AsyncPgConnection {
283
284
. await
284
285
. map_err ( ErrorHelper ) ?;
285
286
286
- let ( error_rx, shutdown_tx) = drive_connection ( connection) ;
287
+ let ( error_rx, notification_rx , shutdown_tx) = drive_connection ( connection) ;
287
288
288
289
let r = Self :: setup (
289
290
client,
290
291
Some ( error_rx) ,
292
+ Some ( notification_rx) ,
291
293
Some ( shutdown_tx) ,
292
294
Arc :: clone ( & instrumentation) ,
293
295
)
@@ -477,6 +479,7 @@ impl AsyncPgConnection {
477
479
conn,
478
480
None ,
479
481
None ,
482
+ None ,
480
483
Arc :: new ( std:: sync:: Mutex :: new (
481
484
DynInstrumentation :: default_instrumentation ( ) ,
482
485
) ) ,
@@ -493,11 +496,12 @@ impl AsyncPgConnection {
493
496
where
494
497
S : tokio_postgres:: tls:: TlsStream + Unpin + Send + ' static ,
495
498
{
496
- let ( error_rx, shutdown_tx) = drive_connection ( conn) ;
499
+ let ( error_rx, notification_rx , shutdown_tx) = drive_connection ( conn) ;
497
500
498
501
Self :: setup (
499
502
client,
500
503
Some ( error_rx) ,
504
+ Some ( notification_rx) ,
501
505
Some ( shutdown_tx) ,
502
506
Arc :: new ( std:: sync:: Mutex :: new ( DynInstrumentation :: none ( ) ) ) ,
503
507
)
@@ -507,6 +511,7 @@ impl AsyncPgConnection {
507
511
async fn setup (
508
512
conn : tokio_postgres:: Client ,
509
513
connection_future : Option < broadcast:: Receiver < Arc < tokio_postgres:: Error > > > ,
514
+ notification_rx : Option < broadcast:: Receiver < diesel:: pg:: PgNotification > > ,
510
515
shutdown_channel : Option < oneshot:: Sender < ( ) > > ,
511
516
instrumentation : Arc < std:: sync:: Mutex < DynInstrumentation > > ,
512
517
) -> ConnectionResult < Self > {
@@ -516,6 +521,7 @@ impl AsyncPgConnection {
516
521
transaction_state : Arc :: new ( Mutex :: new ( AnsiTransactionManager :: default ( ) ) ) ,
517
522
metadata_cache : Arc :: new ( Mutex :: new ( PgMetadataCache :: new ( ) ) ) ,
518
523
connection_future,
524
+ notification_rx,
519
525
shutdown_channel,
520
526
instrumentation,
521
527
} ;
@@ -724,6 +730,21 @@ impl AsyncPgConnection {
724
730
. unwrap_or_else ( |p| p. into_inner ( ) )
725
731
. on_connection_event ( event) ;
726
732
}
733
+
734
+ pub fn notification_stream (
735
+ & self ,
736
+ ) -> impl futures_core:: Stream < Item = QueryResult < diesel:: pg:: PgNotification > > {
737
+ futures_util:: stream:: unfold (
738
+ self . notification_rx . as_ref ( ) . map ( |rx| rx. resubscribe ( ) ) ,
739
+ |rx| async {
740
+ let mut rx = rx?;
741
+ match rx. recv ( ) . await {
742
+ Ok ( notification) => Some ( ( Ok ( notification) , Some ( rx) ) ) ,
743
+ Err ( _) => todo ! ( ) ,
744
+ }
745
+ } ,
746
+ )
747
+ }
727
748
}
728
749
729
750
struct BindData {
@@ -969,27 +990,42 @@ async fn drive_future<R>(
969
990
}
970
991
971
992
fn drive_connection < S > (
972
- conn : tokio_postgres:: Connection < tokio_postgres:: Socket , S > ,
993
+ mut conn : tokio_postgres:: Connection < tokio_postgres:: Socket , S > ,
973
994
) -> (
974
995
broadcast:: Receiver < Arc < tokio_postgres:: Error > > ,
996
+ broadcast:: Receiver < diesel:: pg:: PgNotification > ,
975
997
oneshot:: Sender < ( ) > ,
976
998
)
977
999
where
978
1000
S : tokio_postgres:: tls:: TlsStream + Unpin + Send + ' static ,
979
1001
{
980
1002
let ( error_tx, error_rx) = tokio:: sync:: broadcast:: channel ( 1 ) ;
981
- let ( shutdown_tx, shutdown_rx) = tokio:: sync:: oneshot:: channel ( ) ;
1003
+ let ( notification_tx, notification_rx) = tokio:: sync:: broadcast:: channel ( 1 ) ;
1004
+ let ( shutdown_tx, mut shutdown_rx) = tokio:: sync:: oneshot:: channel ( ) ;
982
1005
983
1006
tokio:: spawn ( async move {
984
- match futures_util:: future:: select ( shutdown_rx, conn) . await {
985
- Either :: Left ( _) | Either :: Right ( ( Ok ( _) , _) ) => { }
986
- Either :: Right ( ( Err ( e) , _) ) => {
987
- let _ = error_tx. send ( Arc :: new ( e) ) ;
1007
+ let mut conn = futures_util:: stream:: poll_fn ( |cx| conn. poll_message ( cx) ) ;
1008
+
1009
+ loop {
1010
+ match futures_util:: future:: select ( & mut shutdown_rx, conn. next ( ) ) . await {
1011
+ Either :: Left ( _) | Either :: Right ( ( None , _) ) => break ,
1012
+ Either :: Right ( ( Some ( Ok ( tokio_postgres:: AsyncMessage :: Notification ( notif) ) ) , _) ) => {
1013
+ let _ = notification_tx. send ( diesel:: pg:: PgNotification {
1014
+ process_id : notif. process_id ( ) ,
1015
+ channel : notif. channel ( ) . to_owned ( ) ,
1016
+ payload : notif. payload ( ) . to_owned ( ) ,
1017
+ } ) ;
1018
+ }
1019
+ Either :: Right ( ( Some ( Ok ( _) ) , _) ) => { }
1020
+ Either :: Right ( ( Some ( Err ( e) ) , _) ) => {
1021
+ let _ = error_tx. send ( Arc :: new ( e) ) ;
1022
+ break ;
1023
+ }
988
1024
}
989
1025
}
990
1026
} ) ;
991
1027
992
- ( error_rx, shutdown_tx)
1028
+ ( error_rx, notification_rx , shutdown_tx)
993
1029
}
994
1030
995
1031
#[ cfg( any(
0 commit comments