1
+ use std:: collections:: HashMap ;
1
2
use std:: hash:: { DefaultHasher , Hash , Hasher } ;
2
- use std:: sync:: Arc ;
3
+ use std:: sync:: { Arc , Mutex } ;
3
4
4
5
use anyhow:: * ;
5
6
use async_trait:: async_trait;
6
7
use base64:: Engine ;
7
8
use base64:: engine:: general_purpose:: STANDARD_NO_PAD as BASE64 ;
8
9
use deadpool_postgres:: { Config , ManagerConfig , Pool , PoolConfig , RecyclingMethod , Runtime } ;
9
10
use futures_util:: future:: poll_fn;
10
- use moka:: future:: Cache ;
11
11
use tokio_postgres:: { AsyncMessage , NoTls } ;
12
12
use tracing:: Instrument ;
13
13
@@ -18,6 +18,15 @@ use crate::pubsub::DriverOutput;
18
18
struct Subscription {
19
19
// Channel to send messages to this subscription
20
20
tx : tokio:: sync:: broadcast:: Sender < Vec < u8 > > ,
21
+ // Cancellation token shared by all subscribers of this subject
22
+ token : tokio_util:: sync:: CancellationToken ,
23
+ }
24
+
25
+ impl Subscription {
26
+ fn new ( tx : tokio:: sync:: broadcast:: Sender < Vec < u8 > > ) -> Self {
27
+ let token = tokio_util:: sync:: CancellationToken :: new ( ) ;
28
+ Self { tx, token }
29
+ }
21
30
}
22
31
23
32
/// > In the default configuration it must be shorter than 8000 bytes
@@ -40,7 +49,7 @@ pub const POSTGRES_MAX_MESSAGE_SIZE: usize =
40
49
pub struct PostgresDriver {
41
50
pool : Arc < Pool > ,
42
51
client : Arc < tokio_postgres:: Client > ,
43
- subscriptions : Cache < String , Subscription > ,
52
+ subscriptions : Arc < Mutex < HashMap < String , Subscription > > > ,
44
53
}
45
54
46
55
impl PostgresDriver {
@@ -65,8 +74,8 @@ impl PostgresDriver {
65
74
. context ( "failed to create postgres pool" ) ?;
66
75
tracing:: debug!( "postgres pool created successfully" ) ;
67
76
68
- let subscriptions: Cache < String , Subscription > =
69
- Cache :: builder ( ) . initial_capacity ( 5 ) . build ( ) ;
77
+ let subscriptions: Arc < Mutex < HashMap < String , Subscription > > > =
78
+ Arc :: new ( Mutex :: new ( HashMap :: new ( ) ) ) ;
70
79
let subscriptions2 = subscriptions. clone ( ) ;
71
80
72
81
let ( client, mut conn) = tokio_postgres:: connect ( & conn_str, tokio_postgres:: NoTls ) . await ?;
@@ -75,7 +84,9 @@ impl PostgresDriver {
75
84
loop {
76
85
match poll_fn ( |cx| conn. poll_message ( cx) ) . await {
77
86
Some ( std:: result:: Result :: Ok ( AsyncMessage :: Notification ( note) ) ) => {
78
- if let Some ( sub) = subscriptions2. get ( note. channel ( ) ) . await {
87
+ if let Some ( sub) =
88
+ subscriptions2. lock ( ) . unwrap ( ) . get ( note. channel ( ) ) . cloned ( )
89
+ {
79
90
let bytes = match BASE64 . decode ( note. payload ( ) ) {
80
91
std:: result:: Result :: Ok ( b) => b,
81
92
std:: result:: Result :: Err ( err) => {
@@ -121,7 +132,7 @@ impl PostgresDriver {
121
132
#[ async_trait]
122
133
impl PubSubDriver for PostgresDriver {
123
134
async fn subscribe ( & self , subject : & str ) -> Result < SubscriberDriverHandle > {
124
- // TODO: To match NATS implementation, LIST must be pipelined (i.e. wait for the command
135
+ // TODO: To match NATS implementation, LISTEN must be pipelined (i.e. wait for the command
125
136
// to reach the server, but not wait for it to respond). However, this has to ensure that
126
137
// NOTIFY & LISTEN are called on the same connection (not diff connections in a pool) or
127
138
// else there will be race conditions where messages might be published before
@@ -135,33 +146,57 @@ impl PubSubDriver for PostgresDriver {
135
146
let hashed = self . hash_subject ( subject) ;
136
147
137
148
// Check if we already have a subscription for this channel
138
- let rx = if let Some ( existing_sub) = self . subscriptions . get ( & hashed) . await {
139
- // Reuse the existing broadcast channel
140
- existing_sub. tx . subscribe ( )
141
- } else {
142
- // Create a new broadcast channel for this subject
143
- let ( tx, rx) = tokio:: sync:: broadcast:: channel ( 1024 ) ;
144
- let subscription = Subscription { tx : tx. clone ( ) } ;
145
-
146
- // Register subscription
147
- self . subscriptions
148
- . insert ( hashed. clone ( ) , subscription)
149
- . await ;
150
-
151
- // Execute LISTEN command on the async client (for receiving notifications)
152
- // This only needs to be done once per channel
153
- let span = tracing:: trace_span!( "pg_listen" ) ;
154
- self . client
155
- . execute ( & format ! ( "LISTEN \" {hashed}\" " ) , & [ ] )
156
- . instrument ( span)
157
- . await ?;
158
-
159
- rx
160
- } ;
149
+ let ( rx, drop_guard) =
150
+ if let Some ( existing_sub) = self . subscriptions . lock ( ) . unwrap ( ) . get ( & hashed) . cloned ( ) {
151
+ // Reuse the existing broadcast channel
152
+ let rx = existing_sub. tx . subscribe ( ) ;
153
+ let drop_guard = existing_sub. token . clone ( ) . drop_guard ( ) ;
154
+ ( rx, drop_guard)
155
+ } else {
156
+ // Create a new broadcast channel for this subject
157
+ let ( tx, rx) = tokio:: sync:: broadcast:: channel ( 1024 ) ;
158
+ let subscription = Subscription :: new ( tx. clone ( ) ) ;
159
+
160
+ // Register subscription
161
+ self . subscriptions
162
+ . lock ( )
163
+ . unwrap ( )
164
+ . insert ( hashed. clone ( ) , subscription. clone ( ) ) ;
165
+
166
+ // Execute LISTEN command on the async client (for receiving notifications)
167
+ // This only needs to be done once per channel
168
+ let span = tracing:: trace_span!( "pg_listen" ) ;
169
+ self . client
170
+ . execute ( & format ! ( "LISTEN \" {hashed}\" " ) , & [ ] )
171
+ . instrument ( span)
172
+ . await ?;
173
+
174
+ // Spawn a single cleanup task for this subscription waiting on its token
175
+ let driver = self . clone ( ) ;
176
+ let hashed_clone = hashed. clone ( ) ;
177
+ let tx_clone = tx. clone ( ) ;
178
+ let token_clone = subscription. token . clone ( ) ;
179
+ tokio:: spawn ( async move {
180
+ token_clone. cancelled ( ) . await ;
181
+ if tx_clone. receiver_count ( ) == 0 {
182
+ let sql = format ! ( "UNLISTEN \" {}\" " , hashed_clone) ;
183
+ if let Err ( err) = driver. client . execute ( sql. as_str ( ) , & [ ] ) . await {
184
+ tracing:: warn!( ?err, %hashed_clone, "failed to UNLISTEN channel" ) ;
185
+ } else {
186
+ tracing:: trace!( %hashed_clone, "unlistened channel" ) ;
187
+ }
188
+ driver. subscriptions . lock ( ) . unwrap ( ) . remove ( & hashed_clone) ;
189
+ }
190
+ } ) ;
191
+
192
+ let drop_guard = subscription. token . clone ( ) . drop_guard ( ) ;
193
+ ( rx, drop_guard)
194
+ } ;
161
195
162
196
Ok ( Box :: new ( PostgresSubscriber {
163
197
subject : subject. to_string ( ) ,
164
- rx,
198
+ rx : Some ( rx) ,
199
+ _drop_guard : drop_guard,
165
200
} ) )
166
201
}
167
202
@@ -191,13 +226,18 @@ impl PubSubDriver for PostgresDriver {
191
226
192
227
pub struct PostgresSubscriber {
193
228
subject : String ,
194
- rx : tokio:: sync:: broadcast:: Receiver < Vec < u8 > > ,
229
+ rx : Option < tokio:: sync:: broadcast:: Receiver < Vec < u8 > > > ,
230
+ _drop_guard : tokio_util:: sync:: DropGuard ,
195
231
}
196
232
197
233
#[ async_trait]
198
234
impl SubscriberDriver for PostgresSubscriber {
199
235
async fn next ( & mut self ) -> Result < DriverOutput > {
200
- match self . rx . recv ( ) . await {
236
+ let rx = match self . rx . as_mut ( ) {
237
+ Some ( rx) => rx,
238
+ None => return Ok ( DriverOutput :: Unsubscribed ) ,
239
+ } ;
240
+ match rx. recv ( ) . await {
201
241
std:: result:: Result :: Ok ( payload) => Ok ( DriverOutput :: Message {
202
242
subject : self . subject . clone ( ) ,
203
243
payload,
0 commit comments