Skip to content

Commit 53c52a4

Browse files
authored
Merge pull request #251 from lsunsi/notification-stream
Implement notification_stream for AsyncPgConnection
2 parents 0e703ba + 848c241 commit 53c52a4

File tree

3 files changed

+141
-12
lines changed

3 files changed

+141
-12
lines changed

src/pg/mod.rs

Lines changed: 85 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,7 @@ use futures_util::{FutureExt, StreamExt};
3131
use std::collections::{HashMap, HashSet};
3232
use std::future::Future;
3333
use std::sync::Arc;
34-
use tokio::sync::broadcast;
35-
use tokio::sync::oneshot;
36-
use tokio::sync::Mutex;
34+
use tokio::sync::{broadcast, mpsc, oneshot, Mutex};
3735
use tokio_postgres::types::ToSql;
3836
use tokio_postgres::types::Type;
3937
use tokio_postgres::Statement;
@@ -172,6 +170,7 @@ pub struct AsyncPgConnection {
172170
transaction_state: Arc<Mutex<AnsiTransactionManager>>,
173171
metadata_cache: Arc<Mutex<PgMetadataCache>>,
174172
connection_future: Option<broadcast::Receiver<Arc<tokio_postgres::Error>>>,
173+
notification_rx: Option<mpsc::UnboundedReceiver<QueryResult<diesel::pg::PgNotification>>>,
175174
shutdown_channel: Option<oneshot::Sender<()>>,
176175
// a sync mutex is fine here as we only hold it for a really short time
177176
instrumentation: Arc<std::sync::Mutex<DynInstrumentation>>,
@@ -283,11 +282,12 @@ impl AsyncConnection for AsyncPgConnection {
283282
.await
284283
.map_err(ErrorHelper)?;
285284

286-
let (error_rx, shutdown_tx) = drive_connection(connection);
285+
let (error_rx, notification_rx, shutdown_tx) = drive_connection(connection);
287286

288287
let r = Self::setup(
289288
client,
290289
Some(error_rx),
290+
Some(notification_rx),
291291
Some(shutdown_tx),
292292
Arc::clone(&instrumentation),
293293
)
@@ -477,6 +477,7 @@ impl AsyncPgConnection {
477477
conn,
478478
None,
479479
None,
480+
None,
480481
Arc::new(std::sync::Mutex::new(
481482
DynInstrumentation::default_instrumentation(),
482483
)),
@@ -493,11 +494,12 @@ impl AsyncPgConnection {
493494
where
494495
S: tokio_postgres::tls::TlsStream + Unpin + Send + 'static,
495496
{
496-
let (error_rx, shutdown_tx) = drive_connection(conn);
497+
let (error_rx, notification_rx, shutdown_tx) = drive_connection(conn);
497498

498499
Self::setup(
499500
client,
500501
Some(error_rx),
502+
Some(notification_rx),
501503
Some(shutdown_tx),
502504
Arc::new(std::sync::Mutex::new(DynInstrumentation::none())),
503505
)
@@ -507,6 +509,7 @@ impl AsyncPgConnection {
507509
async fn setup(
508510
conn: tokio_postgres::Client,
509511
connection_future: Option<broadcast::Receiver<Arc<tokio_postgres::Error>>>,
512+
notification_rx: Option<mpsc::UnboundedReceiver<QueryResult<diesel::pg::PgNotification>>>,
510513
shutdown_channel: Option<oneshot::Sender<()>>,
511514
instrumentation: Arc<std::sync::Mutex<DynInstrumentation>>,
512515
) -> ConnectionResult<Self> {
@@ -516,6 +519,7 @@ impl AsyncPgConnection {
516519
transaction_state: Arc::new(Mutex::new(AnsiTransactionManager::default())),
517520
metadata_cache: Arc::new(Mutex::new(PgMetadataCache::new())),
518521
connection_future,
522+
notification_rx,
519523
shutdown_channel,
520524
instrumentation,
521525
};
@@ -724,6 +728,58 @@ impl AsyncPgConnection {
724728
.unwrap_or_else(|p| p.into_inner())
725729
.on_connection_event(event);
726730
}
731+
732+
/// See Postgres documentation for SQL commands [NOTIFY][] and [LISTEN][]
733+
///
734+
/// The returned stream yields all notifications received by the connection, not only notifications received
735+
/// after calling the function. The returned stream will never close, so no notifications will just result
736+
/// in a pending state.
737+
///
738+
/// If there's no connection available to poll, the stream will yield no notifications and be pending forever.
739+
/// This can happen if you created the [`AsyncPgConnection`] by the [`try_from`] constructor.
740+
///
741+
/// [NOTIFY]: https://www.postgresql.org/docs/current/sql-notify.html
742+
/// [LISTEN]: https://www.postgresql.org/docs/current/sql-listen.html
743+
/// [`AsyncPgConnection`]: crate::pg::AsyncPgConnection
744+
/// [`try_from`]: crate::pg::AsyncPgConnection::try_from
745+
///
746+
/// ```rust
747+
/// # include!("../doctest_setup.rs");
748+
/// # use scoped_futures::ScopedFutureExt;
749+
/// #
750+
/// # #[tokio::main(flavor = "current_thread")]
751+
/// # async fn main() {
752+
/// # run_test().await.unwrap();
753+
/// # }
754+
/// #
755+
/// # async fn run_test() -> QueryResult<()> {
756+
/// # use diesel_async::RunQueryDsl;
757+
/// # use futures_util::StreamExt;
758+
/// # let conn = &mut connection_no_transaction().await;
759+
/// // register the notifications channel we want to receive notifications for
760+
/// diesel::sql_query("LISTEN example_channel").execute(conn).await?;
761+
/// // send some notification (usually done from a different connection/thread/application)
762+
/// diesel::sql_query("NOTIFY example_channel, 'additional data'").execute(conn).await?;
763+
///
764+
/// let mut notifications = std::pin::pin!(conn.notifications_stream());
765+
/// let mut notification = notifications.next().await.unwrap().unwrap();
766+
///
767+
/// assert_eq!(notification.channel, "example_channel");
768+
/// assert_eq!(notification.payload, "additional data");
769+
/// println!("Notification received from process with id {}", notification.process_id);
770+
/// # Ok(())
771+
/// # }
772+
/// ```
773+
pub fn notifications_stream(
774+
&mut self,
775+
) -> impl futures_core::Stream<Item = QueryResult<diesel::pg::PgNotification>> + '_ {
776+
match &mut self.notification_rx {
777+
None => Either::Left(futures_util::stream::pending()),
778+
Some(rx) => Either::Right(futures_util::stream::unfold(rx, |rx| async {
779+
rx.recv().await.map(move |item| (item, rx))
780+
})),
781+
}
782+
}
727783
}
728784

729785
struct BindData {
@@ -969,27 +1025,44 @@ async fn drive_future<R>(
9691025
}
9701026

9711027
fn drive_connection<S>(
972-
conn: tokio_postgres::Connection<tokio_postgres::Socket, S>,
1028+
mut conn: tokio_postgres::Connection<tokio_postgres::Socket, S>,
9731029
) -> (
9741030
broadcast::Receiver<Arc<tokio_postgres::Error>>,
1031+
mpsc::UnboundedReceiver<QueryResult<diesel::pg::PgNotification>>,
9751032
oneshot::Sender<()>,
9761033
)
9771034
where
9781035
S: tokio_postgres::tls::TlsStream + Unpin + Send + 'static,
9791036
{
9801037
let (error_tx, error_rx) = tokio::sync::broadcast::channel(1);
981-
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
1038+
let (notification_tx, notification_rx) = tokio::sync::mpsc::unbounded_channel();
1039+
let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel();
1040+
let mut conn = futures_util::stream::poll_fn(move |cx| conn.poll_message(cx));
9821041

9831042
tokio::spawn(async move {
984-
match futures_util::future::select(shutdown_rx, conn).await {
985-
Either::Left(_) | Either::Right((Ok(_), _)) => {}
986-
Either::Right((Err(e), _)) => {
987-
let _ = error_tx.send(Arc::new(e));
1043+
loop {
1044+
match futures_util::future::select(&mut shutdown_rx, conn.next()).await {
1045+
Either::Left(_) | Either::Right((None, _)) => break,
1046+
Either::Right((Some(Ok(tokio_postgres::AsyncMessage::Notification(notif))), _)) => {
1047+
let _: Result<_, _> = notification_tx.send(Ok(diesel::pg::PgNotification {
1048+
process_id: notif.process_id(),
1049+
channel: notif.channel().to_owned(),
1050+
payload: notif.payload().to_owned(),
1051+
}));
1052+
}
1053+
Either::Right((Some(Ok(_)), _)) => {}
1054+
Either::Right((Some(Err(e)), _)) => {
1055+
let e = Arc::new(e);
1056+
let _: Result<_, _> = error_tx.send(e.clone());
1057+
let _: Result<_, _> =
1058+
notification_tx.send(Err(error_helper::from_tokio_postgres_error(e)));
1059+
break;
1060+
}
9881061
}
9891062
}
9901063
});
9911064

992-
(error_rx, shutdown_tx)
1065+
(error_rx, notification_rx, shutdown_tx)
9931066
}
9941067

9951068
#[cfg(any(

tests/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use std::fmt::Debug;
77
#[cfg(feature = "postgres")]
88
mod custom_types;
99
mod instrumentation;
10+
mod notifications;
1011
#[cfg(any(feature = "bb8", feature = "deadpool", feature = "mobc"))]
1112
mod pooling;
1213
#[cfg(feature = "async-connection-wrapper")]

tests/notifications.rs

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
#[cfg(feature = "postgres")]
2+
#[tokio::test]
3+
async fn notifications_arrive() {
4+
use diesel_async::RunQueryDsl;
5+
use futures_util::{StreamExt, TryStreamExt};
6+
7+
let conn = &mut super::connection_without_transaction().await;
8+
9+
diesel::sql_query("LISTEN test_notifications")
10+
.execute(conn)
11+
.await
12+
.unwrap();
13+
14+
diesel::sql_query("NOTIFY test_notifications, 'first'")
15+
.execute(conn)
16+
.await
17+
.unwrap();
18+
19+
diesel::sql_query("NOTIFY test_notifications, 'second'")
20+
.execute(conn)
21+
.await
22+
.unwrap();
23+
24+
let notifications = conn
25+
.notifications_stream()
26+
.take(2)
27+
.try_collect::<Vec<_>>()
28+
.await
29+
.unwrap();
30+
31+
assert_eq!(2, notifications.len());
32+
assert_eq!(notifications[0].channel, "test_notifications");
33+
assert_eq!(notifications[1].channel, "test_notifications");
34+
assert_eq!(notifications[0].payload, "first");
35+
assert_eq!(notifications[1].payload, "second");
36+
37+
let next_notification = tokio::time::timeout(
38+
std::time::Duration::from_secs(1),
39+
std::pin::pin!(conn.notifications_stream()).next(),
40+
)
41+
.await;
42+
43+
assert!(
44+
next_notification.is_err(),
45+
"Got a next notification, while not expecting one: {next_notification:?}"
46+
);
47+
48+
diesel::sql_query("NOTIFY test_notifications")
49+
.execute(conn)
50+
.await
51+
.unwrap();
52+
53+
let next_notification = std::pin::pin!(conn.notifications_stream()).next().await;
54+
assert_eq!(next_notification.unwrap().unwrap().payload, "");
55+
}

0 commit comments

Comments
 (0)