1
1
use std:: collections:: HashMap ;
2
2
use std:: hash:: { DefaultHasher , Hash , Hasher } ;
3
- use std:: sync:: { Arc , Mutex } ;
3
+ use std:: sync:: Arc ;
4
4
5
5
use anyhow:: * ;
6
6
use async_trait:: async_trait;
7
7
use base64:: Engine ;
8
8
use base64:: engine:: general_purpose:: STANDARD_NO_PAD as BASE64 ;
9
9
use deadpool_postgres:: { Config , ManagerConfig , Pool , PoolConfig , RecyclingMethod , Runtime } ;
10
10
use futures_util:: future:: poll_fn;
11
+ use rivet_util:: backoff:: Backoff ;
12
+ use tokio:: sync:: { Mutex , broadcast} ;
11
13
use tokio_postgres:: { AsyncMessage , NoTls } ;
12
14
use tracing:: Instrument ;
13
15
@@ -17,13 +19,13 @@ use crate::pubsub::DriverOutput;
17
19
#[ derive( Clone ) ]
18
20
struct Subscription {
19
21
// Channel to send messages to this subscription
20
- tx : tokio :: sync :: broadcast:: Sender < Vec < u8 > > ,
22
+ tx : broadcast:: Sender < Vec < u8 > > ,
21
23
// Cancellation token shared by all subscribers of this subject
22
24
token : tokio_util:: sync:: CancellationToken ,
23
25
}
24
26
25
27
impl Subscription {
26
- fn new ( tx : tokio :: sync :: broadcast:: Sender < Vec < u8 > > ) -> Self {
28
+ fn new ( tx : broadcast:: Sender < Vec < u8 > > ) -> Self {
27
29
let token = tokio_util:: sync:: CancellationToken :: new ( ) ;
28
30
Self { tx, token }
29
31
}
@@ -48,8 +50,9 @@ pub const POSTGRES_MAX_MESSAGE_SIZE: usize =
48
50
#[ derive( Clone ) ]
49
51
pub struct PostgresDriver {
50
52
pool : Arc < Pool > ,
51
- client : Arc < tokio_postgres:: Client > ,
53
+ client : Arc < Mutex < Option < Arc < tokio_postgres:: Client > > > > ,
52
54
subscriptions : Arc < Mutex < HashMap < String , Subscription > > > ,
55
+ client_ready : tokio:: sync:: watch:: Receiver < bool > ,
53
56
}
54
57
55
58
impl PostgresDriver {
@@ -76,48 +79,168 @@ impl PostgresDriver {
76
79
77
80
let subscriptions: Arc < Mutex < HashMap < String , Subscription > > > =
78
81
Arc :: new ( Mutex :: new ( HashMap :: new ( ) ) ) ;
79
- let subscriptions2 = subscriptions . clone ( ) ;
82
+ let client : Arc < Mutex < Option < Arc < tokio_postgres :: Client > > > > = Arc :: new ( Mutex :: new ( None ) ) ;
80
83
81
- let ( client, mut conn) = tokio_postgres:: connect ( & conn_str, tokio_postgres:: NoTls ) . await ?;
82
- tokio:: spawn ( async move {
83
- // NOTE: This loop will stop automatically when client is dropped
84
- loop {
85
- match poll_fn ( |cx| conn. poll_message ( cx) ) . await {
86
- Some ( std:: result:: Result :: Ok ( AsyncMessage :: Notification ( note) ) ) => {
87
- if let Some ( sub) =
88
- subscriptions2. lock ( ) . unwrap ( ) . get ( note. channel ( ) ) . cloned ( )
89
- {
90
- let bytes = match BASE64 . decode ( note. payload ( ) ) {
91
- std:: result:: Result :: Ok ( b) => b,
92
- std:: result:: Result :: Err ( err) => {
93
- tracing:: error!( ?err, "failed decoding base64" ) ;
94
- break ;
95
- }
96
- } ;
97
- let _ = sub. tx . send ( bytes) ;
98
- }
99
- }
100
- Some ( std:: result:: Result :: Ok ( _) ) => {
101
- // Ignore other async messages
84
+ // Create channel for client ready notifications
85
+ let ( ready_tx, client_ready) = tokio:: sync:: watch:: channel ( false ) ;
86
+
87
+ // Spawn connection lifecycle task
88
+ tokio:: spawn ( Self :: spawn_connection_lifecycle (
89
+ conn_str. clone ( ) ,
90
+ subscriptions. clone ( ) ,
91
+ client. clone ( ) ,
92
+ ready_tx,
93
+ ) ) ;
94
+
95
+ let driver = Self {
96
+ pool : Arc :: new ( pool) ,
97
+ client,
98
+ subscriptions,
99
+ client_ready,
100
+ } ;
101
+
102
+ // Wait for initial connection to be established
103
+ driver. wait_for_client ( ) . await ?;
104
+
105
+ Ok ( driver)
106
+ }
107
+
108
+ /// Manages the connection lifecycle with automatic reconnection
109
+ async fn spawn_connection_lifecycle (
110
+ conn_str : String ,
111
+ subscriptions : Arc < Mutex < HashMap < String , Subscription > > > ,
112
+ client : Arc < Mutex < Option < Arc < tokio_postgres:: Client > > > > ,
113
+ ready_tx : tokio:: sync:: watch:: Sender < bool > ,
114
+ ) {
115
+ let mut backoff = Backoff :: new ( 8 , None , 1_000 , 1_000 ) ;
116
+
117
+ loop {
118
+ match tokio_postgres:: connect ( & conn_str, tokio_postgres:: NoTls ) . await {
119
+ Result :: Ok ( ( new_client, conn) ) => {
120
+ tracing:: info!( "postgres listen connection established" ) ;
121
+ // Reset backoff on successful connection
122
+ backoff = Backoff :: new ( 8 , None , 1_000 , 1_000 ) ;
123
+
124
+ let new_client = Arc :: new ( new_client) ;
125
+
126
+ // Update the client reference immediately
127
+ * client. lock ( ) . await = Some ( new_client. clone ( ) ) ;
128
+ // Notify that client is ready
129
+ let _ = ready_tx. send ( true ) ;
130
+
131
+ // Get channels to re-subscribe to
132
+ let channels: Vec < String > =
133
+ subscriptions. lock ( ) . await . keys ( ) . cloned ( ) . collect ( ) ;
134
+ let needs_resubscribe = !channels. is_empty ( ) ;
135
+
136
+ if needs_resubscribe {
137
+ tracing:: debug!(
138
+ ?channels,
139
+ "will re-subscribe to channels after connection starts"
140
+ ) ;
102
141
}
103
- Some ( std:: result:: Result :: Err ( err) ) => {
104
- tracing:: error!( ?err, "async postgres error" ) ;
105
- break ;
142
+
143
+ // Spawn a task to re-subscribe after a short delay
144
+ if needs_resubscribe {
145
+ let client_for_resub = new_client. clone ( ) ;
146
+ let channels_clone = channels. clone ( ) ;
147
+ tokio:: spawn ( async move {
148
+ tracing:: debug!(
149
+ ?channels_clone,
150
+ "re-subscribing to channels after reconnection"
151
+ ) ;
152
+ for channel in & channels_clone {
153
+ if let Result :: Err ( e) = client_for_resub
154
+ . execute ( & format ! ( "LISTEN \" {}\" " , channel) , & [ ] )
155
+ . await
156
+ {
157
+ tracing:: error!( ?e, %channel, "failed to re-subscribe to channel" ) ;
158
+ } else {
159
+ tracing:: debug!( %channel, "successfully re-subscribed to channel" ) ;
160
+ }
161
+ }
162
+ } ) ;
106
163
}
107
- None => {
108
- tracing:: debug!( "async postgres connection closed" ) ;
109
- break ;
164
+
165
+ // Poll the connection until it closes
166
+ Self :: poll_connection ( conn, subscriptions. clone ( ) ) . await ;
167
+
168
+ // Clear the client reference on disconnect
169
+ * client. lock ( ) . await = None ;
170
+ // Notify that client is disconnected
171
+ let _ = ready_tx. send ( false ) ;
172
+ }
173
+ Result :: Err ( e) => {
174
+ tracing:: error!( ?e, "failed to connect to postgres, retrying" ) ;
175
+ backoff. tick ( ) . await ;
176
+ }
177
+ }
178
+ }
179
+ }
180
+
181
+ /// Polls the connection for notifications until it closes or errors
182
+ async fn poll_connection (
183
+ mut conn : tokio_postgres:: Connection <
184
+ tokio_postgres:: Socket ,
185
+ tokio_postgres:: tls:: NoTlsStream ,
186
+ > ,
187
+ subscriptions : Arc < Mutex < HashMap < String , Subscription > > > ,
188
+ ) {
189
+ loop {
190
+ match poll_fn ( |cx| conn. poll_message ( cx) ) . await {
191
+ Some ( std:: result:: Result :: Ok ( AsyncMessage :: Notification ( note) ) ) => {
192
+ tracing:: trace!( channel = %note. channel( ) , "received notification" ) ;
193
+ if let Some ( sub) = subscriptions. lock ( ) . await . get ( note. channel ( ) ) . cloned ( ) {
194
+ let bytes = match BASE64 . decode ( note. payload ( ) ) {
195
+ std:: result:: Result :: Ok ( b) => b,
196
+ std:: result:: Result :: Err ( err) => {
197
+ tracing:: error!( ?err, "failed decoding base64" ) ;
198
+ continue ;
199
+ }
200
+ } ;
201
+ tracing:: trace!( channel = %note. channel( ) , bytes_len = bytes. len( ) , "sending to broadcast channel" ) ;
202
+ let _ = sub. tx . send ( bytes) ;
203
+ } else {
204
+ tracing:: warn!( channel = %note. channel( ) , "received notification for unknown channel" ) ;
110
205
}
111
206
}
207
+ Some ( std:: result:: Result :: Ok ( _) ) => {
208
+ // Ignore other async messages
209
+ }
210
+ Some ( std:: result:: Result :: Err ( err) ) => {
211
+ tracing:: error!( ?err, "postgres connection error, reconnecting" ) ;
212
+ break ; // Exit loop to reconnect
213
+ }
214
+ None => {
215
+ tracing:: warn!( "postgres connection closed, reconnecting" ) ;
216
+ break ; // Exit loop to reconnect
217
+ }
112
218
}
113
- tracing :: debug! ( "listen connection closed" ) ;
114
- } ) ;
219
+ }
220
+ }
115
221
116
- Ok ( Self {
117
- pool : Arc :: new ( pool) ,
118
- client : Arc :: new ( client) ,
119
- subscriptions,
222
+ /// Wait for the client to be connected
223
+ async fn wait_for_client ( & self ) -> Result < Arc < tokio_postgres:: Client > > {
224
+ let mut ready_rx = self . client_ready . clone ( ) ;
225
+ tokio:: time:: timeout ( tokio:: time:: Duration :: from_secs ( 5 ) , async {
226
+ loop {
227
+ // Subscribe to changed before attempting to access client
228
+ let changed_fut = ready_rx. changed ( ) ;
229
+
230
+ // Check if client is already available
231
+ if let Some ( client) = self . client . lock ( ) . await . clone ( ) {
232
+ return Ok ( client) ;
233
+ }
234
+
235
+ // Wait for change, will return client if exists on next iteration
236
+ changed_fut
237
+ . await
238
+ . map_err ( |_| anyhow ! ( "connection lifecycle task ended" ) ) ?;
239
+ tracing:: debug!( "client does not exist immediately after receive ready" ) ;
240
+ }
120
241
} )
242
+ . await
243
+ . map_err ( |_| anyhow ! ( "timeout waiting for postgres client connection" ) ) ?
121
244
}
122
245
123
246
fn hash_subject ( & self , subject : & str ) -> String {
@@ -147,7 +270,7 @@ impl PubSubDriver for PostgresDriver {
147
270
148
271
// Check if we already have a subscription for this channel
149
272
let ( rx, drop_guard) =
150
- if let Some ( existing_sub) = self . subscriptions . lock ( ) . unwrap ( ) . get ( & hashed) . cloned ( ) {
273
+ if let Some ( existing_sub) = self . subscriptions . lock ( ) . await . get ( & hashed) . cloned ( ) {
151
274
// Reuse the existing broadcast channel
152
275
let rx = existing_sub. tx . subscribe ( ) ;
153
276
let drop_guard = existing_sub. token . clone ( ) . drop_guard ( ) ;
@@ -160,13 +283,15 @@ impl PubSubDriver for PostgresDriver {
160
283
// Register subscription
161
284
self . subscriptions
162
285
. lock ( )
163
- . unwrap ( )
286
+ . await
164
287
. insert ( hashed. clone ( ) , subscription. clone ( ) ) ;
165
288
166
289
// Execute LISTEN command on the async client (for receiving notifications)
167
290
// This only needs to be done once per channel
291
+ // Wait for client to be connected with retry logic
292
+ let client = self . wait_for_client ( ) . await ?;
168
293
let span = tracing:: trace_span!( "pg_listen" ) ;
169
- self . client
294
+ client
170
295
. execute ( & format ! ( "LISTEN \" {hashed}\" " ) , & [ ] )
171
296
. instrument ( span)
172
297
. await ?;
@@ -179,13 +304,16 @@ impl PubSubDriver for PostgresDriver {
179
304
tokio:: spawn ( async move {
180
305
token_clone. cancelled ( ) . await ;
181
306
if tx_clone. receiver_count ( ) == 0 {
182
- let sql = format ! ( "UNLISTEN \" {}\" " , hashed_clone) ;
183
- if let Err ( err) = driver. client . execute ( sql. as_str ( ) , & [ ] ) . await {
184
- tracing:: warn!( ?err, %hashed_clone, "failed to UNLISTEN channel" ) ;
185
- } else {
186
- tracing:: trace!( %hashed_clone, "unlistened channel" ) ;
307
+ let client = driver. client . lock ( ) . await . clone ( ) ;
308
+ if let Some ( client) = client {
309
+ let sql = format ! ( "UNLISTEN \" {}\" " , hashed_clone) ;
310
+ if let Err ( err) = client. execute ( sql. as_str ( ) , & [ ] ) . await {
311
+ tracing:: warn!( ?err, %hashed_clone, "failed to UNLISTEN channel" ) ;
312
+ } else {
313
+ tracing:: trace!( %hashed_clone, "unlistened channel" ) ;
314
+ }
187
315
}
188
- driver. subscriptions . lock ( ) . unwrap ( ) . remove ( & hashed_clone) ;
316
+ driver. subscriptions . lock ( ) . await . remove ( & hashed_clone) ;
189
317
}
190
318
} ) ;
191
319
0 commit comments