@@ -253,7 +253,49 @@ impl RelayPool {
253
253
U : TryIntoUrl ,
254
254
Error : From < <U as TryIntoUrl >:: Err > ,
255
255
{
256
- self . inner . add_relay ( url, opts) . await
256
+ // Convert into url
257
+ let url: RelayUrl = url. try_into_url ( ) ?;
258
+
259
+ // Check if the pool has been shutdown
260
+ if self . inner . is_shutdown ( ) {
261
+ return Err ( Error :: Shutdown ) ;
262
+ }
263
+
264
+ // Get relays
265
+ let mut relays = self . inner . atomic . relays . write ( ) . await ;
266
+
267
+ // Check if map already contains url
268
+ if relays. contains_key ( & url) {
269
+ return Ok ( false ) ;
270
+ }
271
+
272
+ // Check number fo relays and limit
273
+ if let Some ( max) = self . inner . opts . max_relays {
274
+ if relays. len ( ) >= max {
275
+ return Err ( Error :: TooManyRelays { limit : max } ) ;
276
+ }
277
+ }
278
+
279
+ // Compose new relay
280
+ let mut relay: Relay = Relay :: new ( url, self . inner . state . clone ( ) , opts) ;
281
+
282
+ // Set notification sender
283
+ relay
284
+ . inner
285
+ . set_notification_sender ( self . inner . notification_sender . clone ( ) ) ;
286
+
287
+ // If relay has `READ` flag, inherit pool subscriptions
288
+ if relay. flags ( ) . has_read ( ) {
289
+ let subscriptions = self . subscriptions ( ) . await ;
290
+ for ( id, filters) in subscriptions. into_iter ( ) {
291
+ relay. inner . update_subscription ( id, filters, false ) . await ;
292
+ }
293
+ }
294
+
295
+ // Insert relay into map
296
+ relays. insert ( relay. url ( ) . clone ( ) , relay) ;
297
+
298
+ Ok ( true )
257
299
}
258
300
259
301
// Private API
@@ -270,12 +312,41 @@ impl RelayPool {
270
312
match self . relay ( & url) . await {
271
313
Ok ( relay) => Ok ( Some ( relay) ) ,
272
314
Err ( ..) => {
273
- self . inner . add_relay ( url, opts) . await ?;
315
+ self . add_relay ( url, opts) . await ?;
274
316
Ok ( None )
275
317
}
276
318
}
277
319
}
278
320
321
+ async fn _remove_relay < U > ( & self , url : U , force : bool ) -> Result < ( ) , Error >
322
+ where
323
+ U : TryIntoUrl ,
324
+ Error : From < <U as TryIntoUrl >:: Err > ,
325
+ {
326
+ // Convert into url
327
+ let url: RelayUrl = url. try_into_url ( ) ?;
328
+
329
+ // Acquire write lock
330
+ let mut relays = self . inner . atomic . relays . write ( ) . await ;
331
+
332
+ // Remove relay
333
+ let relay: Relay = relays. remove ( & url) . ok_or ( Error :: RelayNotFound ) ?;
334
+
335
+ // If NOT force, check if it has `GOSSIP` flag
336
+ if !force {
337
+ // If can't be removed, re-insert it.
338
+ if !can_remove_relay ( & relay) {
339
+ relays. insert ( url, relay) ;
340
+ return Ok ( ( ) ) ;
341
+ }
342
+ }
343
+
344
+ // Disconnect
345
+ relay. disconnect ( ) ;
346
+
347
+ Ok ( ( ) )
348
+ }
349
+
279
350
/// Remove and disconnect relay
280
351
///
281
352
/// If the relay has [`RelayServiceFlags::GOSSIP`], it will not be removed from the pool and its
@@ -289,7 +360,7 @@ impl RelayPool {
289
360
U : TryIntoUrl ,
290
361
Error : From < <U as TryIntoUrl >:: Err > ,
291
362
{
292
- self . inner . remove_relay ( url, false ) . await
363
+ self . _remove_relay ( url, false ) . await
293
364
}
294
365
295
366
/// Force remove and disconnect relay
@@ -301,7 +372,7 @@ impl RelayPool {
301
372
U : TryIntoUrl ,
302
373
Error : From < <U as TryIntoUrl >:: Err > ,
303
374
{
304
- self . inner . remove_relay ( url, true ) . await
375
+ self . _remove_relay ( url, true ) . await
305
376
}
306
377
307
378
/// Disconnect and remove all relays
@@ -310,7 +381,11 @@ impl RelayPool {
310
381
/// Use [`RelayPool::force_remove_all_relays`] to remove every relay.
311
382
#[ inline]
312
383
pub async fn remove_all_relays ( & self ) {
313
- self . inner . remove_all_relays ( ) . await
384
+ // Acquire write lock
385
+ let mut relays = self . inner . atomic . relays . write ( ) . await ;
386
+
387
+ // Retains all relays that can't be removed
388
+ relays. retain ( |_, r| !can_remove_relay ( r) ) ;
314
389
}
315
390
316
391
/// Disconnect and force remove all relays
@@ -486,21 +561,34 @@ impl RelayPool {
486
561
/// Get subscriptions
487
562
#[ inline]
488
563
pub async fn subscriptions ( & self ) -> HashMap < SubscriptionId , Filter > {
489
- self . inner . subscriptions ( ) . await
564
+ self . inner . atomic . subscriptions . read ( ) . await . clone ( )
490
565
}
491
566
492
- /// Get subscription
567
+ /// Get a subscription
493
568
#[ inline]
494
569
pub async fn subscription ( & self , id : & SubscriptionId ) -> Option < Filter > {
495
- self . inner . subscription ( id) . await
570
+ let subscriptions = self . inner . atomic . subscriptions . read ( ) . await ;
571
+ subscriptions. get ( id) . cloned ( )
496
572
}
497
573
498
574
/// Register subscription in the [RelayPool]
499
575
///
500
576
/// When a new relay will be added, saved subscriptions will be automatically used for it.
501
577
#[ inline]
502
578
pub async fn save_subscription ( & self , id : SubscriptionId , filter : Filter ) {
503
- self . inner . save_subscription ( id, filter) . await
579
+ let mut subscriptions = self . inner . atomic . subscriptions . write ( ) . await ;
580
+ let current: & mut Filter = subscriptions. entry ( id) . or_default ( ) ;
581
+ * current = filter;
582
+ }
583
+
584
+ async fn remove_subscription ( & self , id : & SubscriptionId ) {
585
+ let mut subscriptions = self . inner . atomic . subscriptions . write ( ) . await ;
586
+ subscriptions. remove ( id) ;
587
+ }
588
+
589
+ async fn remove_all_subscriptions ( & self ) {
590
+ let mut subscriptions = self . inner . atomic . subscriptions . write ( ) . await ;
591
+ subscriptions. clear ( ) ;
504
592
}
505
593
506
594
/// Send a client message to specific relays
@@ -833,7 +921,7 @@ impl RelayPool {
833
921
/// Unsubscribe from subscription
834
922
pub async fn unsubscribe ( & self , id : & SubscriptionId ) {
835
923
// Remove subscription from pool
836
- self . inner . remove_subscription ( id) . await ;
924
+ self . remove_subscription ( id) . await ;
837
925
838
926
// Lock with read shared access
839
927
let relays = self . inner . atomic . relays . read ( ) . await ;
@@ -851,7 +939,7 @@ impl RelayPool {
851
939
/// Unsubscribe from all subscriptions
852
940
pub async fn unsubscribe_all ( & self ) {
853
941
// Remove subscriptions from pool
854
- self . inner . remove_all_subscriptions ( ) . await ;
942
+ self . remove_all_subscriptions ( ) . await ;
855
943
856
944
// Lock with read shared access
857
945
let relays = self . inner . atomic . relays . read ( ) . await ;
@@ -1165,6 +1253,26 @@ impl RelayPool {
1165
1253
}
1166
1254
}
1167
1255
1256
+ /// Return `true` if the relay can be removed
1257
+ ///
1258
+ /// If it CAN'T be removed,
1259
+ /// the flags are automatically updated (remove `READ`, `WRITE` and `DISCOVERY` flags).
1260
+ fn can_remove_relay ( relay : & Relay ) -> bool {
1261
+ let flags = relay. flags ( ) ;
1262
+ if flags. has_any ( RelayServiceFlags :: GOSSIP ) {
1263
+ // Remove READ, WRITE and DISCOVERY flags
1264
+ flags. remove (
1265
+ RelayServiceFlags :: READ | RelayServiceFlags :: WRITE | RelayServiceFlags :: DISCOVERY ,
1266
+ ) ;
1267
+
1268
+ // Relay has `GOSSIP` flag so it can't be removed.
1269
+ return false ;
1270
+ }
1271
+
1272
+ // Relay can be removed
1273
+ true
1274
+ }
1275
+
1168
1276
#[ cfg( test) ]
1169
1277
mod tests {
1170
1278
use nostr_relay_builder:: MockRelay ;
0 commit comments