@@ -113,8 +113,7 @@ impl PubSubDriver for PostgresDriver {
113
113
tracing:: debug!( %subject, ?lock_id, "calculated advisory lock id" ) ;
114
114
115
115
// Create a single connection for both subscription and lock holding
116
- let ( client, mut connection) =
117
- tokio_postgres:: connect ( & self . conn_str , tokio_postgres:: NoTls ) . await ?;
116
+ let ( client, mut connection) = pg_connect ( & self . conn_str ) . await ?;
118
117
119
118
// Set up message forwarding
120
119
let ( tx, rx) = tokio:: sync:: mpsc:: unbounded_channel :: < String > ( ) ;
@@ -155,7 +154,7 @@ impl PubSubDriver for PostgresDriver {
155
154
let listen_subject = subject_owned. clone ( ) ;
156
155
157
156
// Spawn task to handle connection, lock acquisition, and LISTEN
158
- tokio:: spawn ( async move {
157
+ let poll_handle = tokio:: spawn ( async move {
159
158
// First acquire the lock while polling the connection
160
159
let lock_sql = format ! ( "SELECT pg_try_advisory_lock_shared({})" , lock_id) ;
161
160
let lock_future = client_clone. query_one ( & lock_sql, & [ ] ) ;
@@ -265,6 +264,7 @@ impl PubSubDriver for PostgresDriver {
265
264
lock_id,
266
265
client,
267
266
subject : subject. to_string ( ) ,
267
+ poll_handle,
268
268
} ) )
269
269
}
270
270
@@ -419,16 +419,15 @@ impl PubSubDriver for PostgresDriver {
419
419
// Create a temporary reply subject and a dedicated listener connection
420
420
let reply_subject = format ! ( "_INBOX.{}" , uuid:: Uuid :: new_v4( ) ) ;
421
421
422
- let ( client, mut connection) =
423
- tokio_postgres:: connect ( & self . conn_str , tokio_postgres:: NoTls ) . await ?;
422
+ let ( client, mut connection) = pg_connect ( & self . conn_str ) . await ?;
424
423
425
424
// Setup connection and LISTEN in a task
426
425
let ( listen_done_tx, listen_done_rx) = tokio:: sync:: oneshot:: channel ( ) ;
427
426
let reply_subject_clone = reply_subject. clone ( ) ;
428
427
429
428
// Spawn task to handle connection and LISTEN
430
429
let ( response_tx, mut response_rx) = tokio:: sync:: mpsc:: unbounded_channel ( ) ;
431
- tokio:: spawn ( async move {
430
+ let poll_handle = tokio:: spawn ( async move {
432
431
// Convert subject to base64 hash string because Postgres identifiers can only be 63 bytes
433
432
let mut hasher = DefaultHasher :: new ( ) ;
434
433
reply_subject_clone. hash ( & mut hasher) ;
@@ -513,14 +512,19 @@ impl PubSubDriver for PostgresDriver {
513
512
} ;
514
513
515
514
// Apply timeout if specified
516
- if let Some ( dur) = timeout {
515
+ let res = if let Some ( dur) = timeout {
517
516
match tokio:: time:: timeout ( dur, response_future) . await {
518
517
std:: result:: Result :: Ok ( resp) => resp,
519
518
std:: result:: Result :: Err ( _) => Err ( errors:: Ups :: RequestTimeout . build ( ) . into ( ) ) ,
520
519
}
521
520
} else {
522
521
response_future. await
523
- }
522
+ } ;
523
+
524
+ // Stop poll loop
525
+ poll_handle. abort ( ) ;
526
+
527
+ res
524
528
}
525
529
526
530
async fn send_request_reply ( & self , reply : & str , payload : & [ u8 ] ) -> Result < ( ) > {
@@ -599,6 +603,7 @@ pub struct PostgresSubscriber {
599
603
lock_id : i64 ,
600
604
client : Arc < tokio_postgres:: Client > ,
601
605
subject : String ,
606
+ poll_handle : tokio:: task:: JoinHandle < ( ) > ,
602
607
}
603
608
604
609
#[ async_trait]
@@ -716,5 +721,22 @@ impl Drop for PostgresSubscriber {
716
721
. execute ( "SELECT pg_advisory_unlock_shared($1)" , & [ & lock_id] )
717
722
. await ;
718
723
} ) ;
724
+
725
+ // Stop polling task
726
+ self . poll_handle . abort ( ) ;
719
727
}
720
728
}
729
+
730
+ async fn pg_connect (
731
+ conn_str : & str ,
732
+ ) -> Result < (
733
+ tokio_postgres:: Client ,
734
+ tokio_postgres:: Connection <
735
+ tokio_postgres:: Socket ,
736
+ <tokio_postgres:: NoTls as tokio_postgres:: tls:: TlsConnect < tokio_postgres:: Socket > >:: Stream ,
737
+ > ,
738
+ ) > {
739
+ let ( client, conn) = tokio_postgres:: connect ( conn_str, tokio_postgres:: NoTls ) . await ?;
740
+
741
+ Ok ( ( client, conn) )
742
+ }
0 commit comments