Skip to content

Commit 4af6fcd

Browse files
committed
Permit configuring the notice callback
Right now the behavior is hardcoded to log any received notices at the info level. Add a `notice_callback` configuration option that permits installing an arbitrary callback to handle any received notices. As discussed in #588.
1 parent 4237843 commit 4af6fcd

File tree

3 files changed

+60
-8
lines changed

3 files changed

+60
-8
lines changed

postgres/src/config.rs

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,16 @@
44
55
use crate::connection::Connection;
66
use crate::Client;
7+
use log::info;
78
use std::fmt;
89
use std::path::Path;
910
use std::str::FromStr;
11+
use std::sync::Arc;
1012
use std::time::Duration;
1113
use tokio::runtime;
1214
#[doc(inline)]
1315
pub use tokio_postgres::config::{ChannelBinding, Host, SslMode, TargetSessionAttrs};
16+
use tokio_postgres::error::DbError;
1417
use tokio_postgres::tls::{MakeTlsConnect, TlsConnect};
1518
use tokio_postgres::{Error, Socket};
1619

@@ -90,6 +93,7 @@ use tokio_postgres::{Error, Socket};
9093
#[derive(Clone)]
9194
pub struct Config {
9295
config: tokio_postgres::Config,
96+
notice_callback: Arc<dyn Fn(DbError) + Send + Sync>,
9397
}
9498

9599
impl fmt::Debug for Config {
@@ -109,9 +113,7 @@ impl Default for Config {
109113
impl Config {
110114
/// Creates a new configuration.
111115
pub fn new() -> Config {
112-
Config {
113-
config: tokio_postgres::Config::new(),
114-
}
116+
tokio_postgres::Config::new().into()
115117
}
116118

117119
/// Sets the user to authenticate with.
@@ -307,6 +309,25 @@ impl Config {
307309
self.config.get_channel_binding()
308310
}
309311

312+
/// Sets the notice callback.
313+
///
314+
/// This callback will be invoked with the contents of every
315+
/// [`AsyncMessage::Notice`] that is received by the connection. Notices use
316+
/// the same structure as errors, but they are not "errors" per-se.
317+
///
318+
/// Notices are distinct from notifications, which are instead accessible
319+
/// via the [`Notifications`] API.
320+
///
321+
/// [`AsyncMessage::Notice`]: tokio_postgres::AsyncMessage::Notice
322+
/// [`Notifications`]: crate::Notifications
323+
pub fn notice_callback<F>(&mut self, f: F) -> &mut Config
324+
where
325+
F: Fn(DbError) + Send + Sync + 'static,
326+
{
327+
self.notice_callback = Arc::new(f);
328+
self
329+
}
330+
310331
/// Opens a connection to a PostgreSQL database.
311332
pub fn connect<T>(&self, tls: T) -> Result<Client, Error>
312333
where
@@ -323,7 +344,7 @@ impl Config {
323344

324345
let (client, connection) = runtime.block_on(self.config.connect(tls))?;
325346

326-
let connection = Connection::new(runtime, connection);
347+
let connection = Connection::new(runtime, connection, self.notice_callback.clone());
327348
Ok(Client::new(connection, client))
328349
}
329350
}
@@ -338,6 +359,11 @@ impl FromStr for Config {
338359

339360
impl From<tokio_postgres::Config> for Config {
340361
fn from(config: tokio_postgres::Config) -> Config {
341-
Config { config }
362+
Config {
363+
config,
364+
notice_callback: Arc::new(|notice| {
365+
info!("{}: {}", notice.severity(), notice.message())
366+
}),
367+
}
342368
}
343369
}

postgres/src/connection.rs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,30 @@
11
use crate::{Error, Notification};
22
use futures::future;
33
use futures::{pin_mut, Stream};
4-
use log::info;
54
use std::collections::VecDeque;
65
use std::future::Future;
76
use std::ops::{Deref, DerefMut};
87
use std::pin::Pin;
8+
use std::sync::Arc;
99
use std::task::{Context, Poll};
1010
use tokio::io::{AsyncRead, AsyncWrite};
1111
use tokio::runtime::Runtime;
12+
use tokio_postgres::error::DbError;
1213
use tokio_postgres::AsyncMessage;
1314

1415
pub struct Connection {
1516
runtime: Runtime,
1617
connection: Pin<Box<dyn Stream<Item = Result<AsyncMessage, Error>> + Send>>,
1718
notifications: VecDeque<Notification>,
19+
notice_callback: Arc<dyn Fn(DbError)>,
1820
}
1921

2022
impl Connection {
21-
pub fn new<S, T>(runtime: Runtime, connection: tokio_postgres::Connection<S, T>) -> Connection
23+
pub fn new<S, T>(
24+
runtime: Runtime,
25+
connection: tokio_postgres::Connection<S, T>,
26+
notice_callback: Arc<dyn Fn(DbError)>,
27+
) -> Connection
2228
where
2329
S: AsyncRead + AsyncWrite + Unpin + 'static + Send,
2430
T: AsyncRead + AsyncWrite + Unpin + 'static + Send,
@@ -27,6 +33,7 @@ impl Connection {
2733
runtime,
2834
connection: Box::pin(ConnectionStream { connection }),
2935
notifications: VecDeque::new(),
36+
notice_callback,
3037
}
3138
}
3239

@@ -55,6 +62,7 @@ impl Connection {
5562
{
5663
let connection = &mut self.connection;
5764
let notifications = &mut self.notifications;
65+
let notice_callback = &mut self.notice_callback;
5866
self.runtime.block_on({
5967
future::poll_fn(|cx| {
6068
let done = loop {
@@ -63,7 +71,7 @@ impl Connection {
6371
notifications.push_back(notification);
6472
}
6573
Poll::Ready(Some(Ok(AsyncMessage::Notice(notice)))) => {
66-
info!("{}: {}", notice.severity(), notice.message());
74+
notice_callback(notice)
6775
}
6876
Poll::Ready(Some(Ok(_))) => {}
6977
Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)),

postgres/src/test.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
use std::io::{Read, Write};
2+
use std::str::FromStr;
3+
use std::sync::mpsc;
24
use std::thread;
35
use std::time::Duration;
46
use tokio_postgres::error::SqlState;
@@ -476,6 +478,22 @@ fn notifications_timeout_iter() {
476478
assert_eq!(notifications[1].payload(), "world");
477479
}
478480

481+
#[test]
482+
fn notice_callback() {
483+
let (notice_tx, notice_rx) = mpsc::sync_channel(64);
484+
let mut client = Config::from_str("host=localhost port=5433 user=postgres")
485+
.unwrap()
486+
.notice_callback(move |n| notice_tx.send(n).unwrap())
487+
.connect(NoTls)
488+
.unwrap();
489+
490+
client
491+
.batch_execute("DO $$BEGIN RAISE NOTICE 'custom'; END$$")
492+
.unwrap();
493+
494+
assert_eq!(notice_rx.recv().unwrap().message(), "custom");
495+
}
496+
479497
#[test]
480498
fn explicit_close() {
481499
let client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap();

0 commit comments

Comments
 (0)