Skip to content

Commit 9107381

Browse files
committed
address some PR comments
1 parent 6d9d4bc commit 9107381

File tree

1 file changed

+25
-20
lines changed

1 file changed

+25
-20
lines changed

src/pg/mod.rs

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,7 @@ use futures_util::{FutureExt, StreamExt};
3131
use std::collections::{HashMap, HashSet};
3232
use std::future::Future;
3333
use std::sync::Arc;
34-
use tokio::sync::broadcast;
35-
use tokio::sync::oneshot;
36-
use tokio::sync::Mutex;
34+
use tokio::sync::{broadcast, mpsc, oneshot, Mutex};
3735
use tokio_postgres::types::ToSql;
3836
use tokio_postgres::types::Type;
3937
use tokio_postgres::Statement;
@@ -172,7 +170,7 @@ pub struct AsyncPgConnection {
172170
transaction_state: Arc<Mutex<AnsiTransactionManager>>,
173171
metadata_cache: Arc<Mutex<PgMetadataCache>>,
174172
connection_future: Option<broadcast::Receiver<Arc<tokio_postgres::Error>>>,
175-
notification_rx: Option<broadcast::Receiver<diesel::pg::PgNotification>>,
173+
notification_rx: Option<mpsc::UnboundedReceiver<diesel::pg::PgNotification>>,
176174
shutdown_channel: Option<oneshot::Sender<()>>,
177175
// a sync mutex is fine here as we only hold it for a really short time
178176
instrumentation: Arc<std::sync::Mutex<DynInstrumentation>>,
@@ -511,7 +509,7 @@ impl AsyncPgConnection {
511509
async fn setup(
512510
conn: tokio_postgres::Client,
513511
connection_future: Option<broadcast::Receiver<Arc<tokio_postgres::Error>>>,
514-
notification_rx: Option<broadcast::Receiver<diesel::pg::PgNotification>>,
512+
notification_rx: Option<mpsc::UnboundedReceiver<diesel::pg::PgNotification>>,
515513
shutdown_channel: Option<oneshot::Sender<()>>,
516514
instrumentation: Arc<std::sync::Mutex<DynInstrumentation>>,
517515
) -> ConnectionResult<Self> {
@@ -732,18 +730,25 @@ impl AsyncPgConnection {
732730
}
733731

734732
pub fn notification_stream(
735-
&self,
736-
) -> impl futures_core::Stream<Item = QueryResult<diesel::pg::PgNotification>> {
737-
futures_util::stream::unfold(
738-
self.notification_rx.as_ref().map(|rx| rx.resubscribe()),
739-
|rx| async {
740-
let mut rx = rx?;
741-
match rx.recv().await {
742-
Ok(notification) => Some((Ok(notification), Some(rx))),
743-
Err(_) => todo!(),
744-
}
745-
},
746-
)
733+
&mut self,
734+
) -> impl futures_core::Stream<Item = diesel::pg::PgNotification> + '_ {
735+
NotificationStream(self.notification_rx.as_mut())
736+
}
737+
}
738+
739+
struct NotificationStream<'a>(Option<&'a mut mpsc::UnboundedReceiver<diesel::pg::PgNotification>>);
740+
741+
impl futures_core::Stream for NotificationStream<'_> {
742+
type Item = diesel::pg::PgNotification;
743+
744+
fn poll_next(
745+
mut self: std::pin::Pin<&mut Self>,
746+
cx: &mut std::task::Context<'_>,
747+
) -> std::task::Poll<Option<Self::Item>> {
748+
match &mut self.0 {
749+
Some(rx) => rx.poll_recv(cx),
750+
None => std::task::Poll::Pending,
751+
}
747752
}
748753
}
749754

@@ -993,14 +998,14 @@ fn drive_connection<S>(
993998
mut conn: tokio_postgres::Connection<tokio_postgres::Socket, S>,
994999
) -> (
9951000
broadcast::Receiver<Arc<tokio_postgres::Error>>,
996-
broadcast::Receiver<diesel::pg::PgNotification>,
1001+
mpsc::UnboundedReceiver<diesel::pg::PgNotification>,
9971002
oneshot::Sender<()>,
9981003
)
9991004
where
10001005
S: tokio_postgres::tls::TlsStream + Unpin + Send + 'static,
10011006
{
10021007
let (error_tx, error_rx) = tokio::sync::broadcast::channel(1);
1003-
let (notification_tx, notification_rx) = tokio::sync::broadcast::channel(1);
1008+
let (notification_tx, notification_rx) = tokio::sync::mpsc::unbounded_channel();
10041009
let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel();
10051010

10061011
tokio::spawn(async move {
@@ -1010,7 +1015,7 @@ where
10101015
match futures_util::future::select(&mut shutdown_rx, conn.next()).await {
10111016
Either::Left(_) | Either::Right((None, _)) => break,
10121017
Either::Right((Some(Ok(tokio_postgres::AsyncMessage::Notification(notif))), _)) => {
1013-
let _ = notification_tx.send(diesel::pg::PgNotification {
1018+
let _: Result<_, _> = notification_tx.send(diesel::pg::PgNotification {
10141019
process_id: notif.process_id(),
10151020
channel: notif.channel().to_owned(),
10161021
payload: notif.payload().to_owned(),

0 commit comments

Comments
 (0)