Skip to content

Commit 6d9d4bc

Browse files
committed
implement notification_stream for AsyncPgConnection
1 parent 0e703ba commit 6d9d4bc

File tree

1 file changed

+45
-9
lines changed

1 file changed

+45
-9
lines changed

src/pg/mod.rs

Lines changed: 45 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ pub struct AsyncPgConnection {
172172
transaction_state: Arc<Mutex<AnsiTransactionManager>>,
173173
metadata_cache: Arc<Mutex<PgMetadataCache>>,
174174
connection_future: Option<broadcast::Receiver<Arc<tokio_postgres::Error>>>,
175+
notification_rx: Option<broadcast::Receiver<diesel::pg::PgNotification>>,
175176
shutdown_channel: Option<oneshot::Sender<()>>,
176177
// a sync mutex is fine here as we only hold it for a really short time
177178
instrumentation: Arc<std::sync::Mutex<DynInstrumentation>>,
@@ -283,11 +284,12 @@ impl AsyncConnection for AsyncPgConnection {
283284
.await
284285
.map_err(ErrorHelper)?;
285286

286-
let (error_rx, shutdown_tx) = drive_connection(connection);
287+
let (error_rx, notification_rx, shutdown_tx) = drive_connection(connection);
287288

288289
let r = Self::setup(
289290
client,
290291
Some(error_rx),
292+
Some(notification_rx),
291293
Some(shutdown_tx),
292294
Arc::clone(&instrumentation),
293295
)
@@ -477,6 +479,7 @@ impl AsyncPgConnection {
477479
conn,
478480
None,
479481
None,
482+
None,
480483
Arc::new(std::sync::Mutex::new(
481484
DynInstrumentation::default_instrumentation(),
482485
)),
@@ -493,11 +496,12 @@ impl AsyncPgConnection {
493496
where
494497
S: tokio_postgres::tls::TlsStream + Unpin + Send + 'static,
495498
{
496-
let (error_rx, shutdown_tx) = drive_connection(conn);
499+
let (error_rx, notification_rx, shutdown_tx) = drive_connection(conn);
497500

498501
Self::setup(
499502
client,
500503
Some(error_rx),
504+
Some(notification_rx),
501505
Some(shutdown_tx),
502506
Arc::new(std::sync::Mutex::new(DynInstrumentation::none())),
503507
)
@@ -507,6 +511,7 @@ impl AsyncPgConnection {
507511
async fn setup(
508512
conn: tokio_postgres::Client,
509513
connection_future: Option<broadcast::Receiver<Arc<tokio_postgres::Error>>>,
514+
notification_rx: Option<broadcast::Receiver<diesel::pg::PgNotification>>,
510515
shutdown_channel: Option<oneshot::Sender<()>>,
511516
instrumentation: Arc<std::sync::Mutex<DynInstrumentation>>,
512517
) -> ConnectionResult<Self> {
@@ -516,6 +521,7 @@ impl AsyncPgConnection {
516521
transaction_state: Arc::new(Mutex::new(AnsiTransactionManager::default())),
517522
metadata_cache: Arc::new(Mutex::new(PgMetadataCache::new())),
518523
connection_future,
524+
notification_rx,
519525
shutdown_channel,
520526
instrumentation,
521527
};
@@ -724,6 +730,21 @@ impl AsyncPgConnection {
724730
.unwrap_or_else(|p| p.into_inner())
725731
.on_connection_event(event);
726732
}
733+
734+
pub fn notification_stream(
735+
&self,
736+
) -> impl futures_core::Stream<Item = QueryResult<diesel::pg::PgNotification>> {
737+
futures_util::stream::unfold(
738+
self.notification_rx.as_ref().map(|rx| rx.resubscribe()),
739+
|rx| async {
740+
let mut rx = rx?;
741+
match rx.recv().await {
742+
Ok(notification) => Some((Ok(notification), Some(rx))),
743+
Err(_) => todo!(),
744+
}
745+
},
746+
)
747+
}
727748
}
728749

729750
struct BindData {
@@ -969,27 +990,42 @@ async fn drive_future<R>(
969990
}
970991

971992
fn drive_connection<S>(
972-
conn: tokio_postgres::Connection<tokio_postgres::Socket, S>,
993+
mut conn: tokio_postgres::Connection<tokio_postgres::Socket, S>,
973994
) -> (
974995
broadcast::Receiver<Arc<tokio_postgres::Error>>,
996+
broadcast::Receiver<diesel::pg::PgNotification>,
975997
oneshot::Sender<()>,
976998
)
977999
where
9781000
S: tokio_postgres::tls::TlsStream + Unpin + Send + 'static,
9791001
{
9801002
let (error_tx, error_rx) = tokio::sync::broadcast::channel(1);
981-
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
1003+
let (notification_tx, notification_rx) = tokio::sync::broadcast::channel(1);
1004+
let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel();
9821005

9831006
tokio::spawn(async move {
984-
match futures_util::future::select(shutdown_rx, conn).await {
985-
Either::Left(_) | Either::Right((Ok(_), _)) => {}
986-
Either::Right((Err(e), _)) => {
987-
let _ = error_tx.send(Arc::new(e));
1007+
let mut conn = futures_util::stream::poll_fn(|cx| conn.poll_message(cx));
1008+
1009+
loop {
1010+
match futures_util::future::select(&mut shutdown_rx, conn.next()).await {
1011+
Either::Left(_) | Either::Right((None, _)) => break,
1012+
Either::Right((Some(Ok(tokio_postgres::AsyncMessage::Notification(notif))), _)) => {
1013+
let _ = notification_tx.send(diesel::pg::PgNotification {
1014+
process_id: notif.process_id(),
1015+
channel: notif.channel().to_owned(),
1016+
payload: notif.payload().to_owned(),
1017+
});
1018+
}
1019+
Either::Right((Some(Ok(_)), _)) => {}
1020+
Either::Right((Some(Err(e)), _)) => {
1021+
let _ = error_tx.send(Arc::new(e));
1022+
break;
1023+
}
9881024
}
9891025
}
9901026
});
9911027

992-
(error_rx, shutdown_tx)
1028+
(error_rx, notification_rx, shutdown_tx)
9931029
}
9941030

9951031
#[cfg(any(

0 commit comments

Comments
 (0)