Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 45 additions & 9 deletions src/pg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ pub struct AsyncPgConnection {
transaction_state: Arc<Mutex<AnsiTransactionManager>>,
metadata_cache: Arc<Mutex<PgMetadataCache>>,
connection_future: Option<broadcast::Receiver<Arc<tokio_postgres::Error>>>,
notification_rx: Option<broadcast::Receiver<diesel::pg::PgNotification>>,
shutdown_channel: Option<oneshot::Sender<()>>,
// a sync mutex is fine here as we only hold it for a really short time
instrumentation: Arc<std::sync::Mutex<DynInstrumentation>>,
Expand Down Expand Up @@ -283,11 +284,12 @@ impl AsyncConnection for AsyncPgConnection {
.await
.map_err(ErrorHelper)?;

let (error_rx, shutdown_tx) = drive_connection(connection);
let (error_rx, notification_rx, shutdown_tx) = drive_connection(connection);

let r = Self::setup(
client,
Some(error_rx),
Some(notification_rx),
Some(shutdown_tx),
Arc::clone(&instrumentation),
)
Expand Down Expand Up @@ -477,6 +479,7 @@ impl AsyncPgConnection {
conn,
None,
None,
None,
Arc::new(std::sync::Mutex::new(
DynInstrumentation::default_instrumentation(),
)),
Expand All @@ -493,11 +496,12 @@ impl AsyncPgConnection {
where
S: tokio_postgres::tls::TlsStream + Unpin + Send + 'static,
{
let (error_rx, shutdown_tx) = drive_connection(conn);
let (error_rx, notification_rx, shutdown_tx) = drive_connection(conn);

Self::setup(
client,
Some(error_rx),
Some(notification_rx),
Some(shutdown_tx),
Arc::new(std::sync::Mutex::new(DynInstrumentation::none())),
)
Expand All @@ -507,6 +511,7 @@ impl AsyncPgConnection {
async fn setup(
conn: tokio_postgres::Client,
connection_future: Option<broadcast::Receiver<Arc<tokio_postgres::Error>>>,
notification_rx: Option<broadcast::Receiver<diesel::pg::PgNotification>>,
shutdown_channel: Option<oneshot::Sender<()>>,
instrumentation: Arc<std::sync::Mutex<DynInstrumentation>>,
) -> ConnectionResult<Self> {
Expand All @@ -516,6 +521,7 @@ impl AsyncPgConnection {
transaction_state: Arc::new(Mutex::new(AnsiTransactionManager::default())),
metadata_cache: Arc::new(Mutex::new(PgMetadataCache::new())),
connection_future,
notification_rx,
shutdown_channel,
instrumentation,
};
Expand Down Expand Up @@ -724,6 +730,21 @@ impl AsyncPgConnection {
.unwrap_or_else(|p| p.into_inner())
.on_connection_event(event);
}

pub fn notification_stream(
&self,
) -> impl futures_core::Stream<Item = QueryResult<diesel::pg::PgNotification>> {
futures_util::stream::unfold(
self.notification_rx.as_ref().map(|rx| rx.resubscribe()),
|rx| async {
let mut rx = rx?;
match rx.recv().await {
Ok(notification) => Some((Ok(notification), Some(rx))),
Err(_) => todo!(),
}
},
)
}
}

struct BindData {
Expand Down Expand Up @@ -969,27 +990,42 @@ async fn drive_future<R>(
}

fn drive_connection<S>(
conn: tokio_postgres::Connection<tokio_postgres::Socket, S>,
mut conn: tokio_postgres::Connection<tokio_postgres::Socket, S>,
) -> (
broadcast::Receiver<Arc<tokio_postgres::Error>>,
broadcast::Receiver<diesel::pg::PgNotification>,
oneshot::Sender<()>,
)
where
S: tokio_postgres::tls::TlsStream + Unpin + Send + 'static,
{
let (error_tx, error_rx) = tokio::sync::broadcast::channel(1);
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
let (notification_tx, notification_rx) = tokio::sync::broadcast::channel(1);
let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel();

tokio::spawn(async move {
match futures_util::future::select(shutdown_rx, conn).await {
Either::Left(_) | Either::Right((Ok(_), _)) => {}
Either::Right((Err(e), _)) => {
let _ = error_tx.send(Arc::new(e));
let mut conn = futures_util::stream::poll_fn(|cx| conn.poll_message(cx));

loop {
match futures_util::future::select(&mut shutdown_rx, conn.next()).await {
Either::Left(_) | Either::Right((None, _)) => break,
Either::Right((Some(Ok(tokio_postgres::AsyncMessage::Notification(notif))), _)) => {
let _ = notification_tx.send(diesel::pg::PgNotification {
process_id: notif.process_id(),
channel: notif.channel().to_owned(),
payload: notif.payload().to_owned(),
});
}
Either::Right((Some(Ok(_)), _)) => {}
Either::Right((Some(Err(e)), _)) => {
let _ = error_tx.send(Arc::new(e));
break;
}
}
}
});

(error_rx, shutdown_tx)
(error_rx, notification_rx, shutdown_tx)
}

#[cfg(any(
Expand Down
Loading