@@ -89,6 +89,7 @@ use core::future::Future;
8989use core:: mem;
9090use core:: pin:: Pin ;
9191use core:: sync:: atomic:: { AtomicBool , AtomicUsize , Ordering } ;
92+ use core:: task:: { Context , Poll , Waker } ;
9293use core:: time:: Duration ;
9394
9495use bitcoin:: psbt:: Psbt ;
@@ -856,15 +857,93 @@ impl<Signer: sign::ecdsa::EcdsaChannelSigner> Persist<Signer> for TestPersister
856857 }
857858}
858859
860+ // A simple multi-producer-single-consumer one-shot channel
861+ type OneShotChannelState = Arc < Mutex < ( Option < Result < ( ) , io:: Error > > , Option < Waker > ) > > ;
862+ struct OneShotChannel ( OneShotChannelState ) ;
863+ impl Future for OneShotChannel {
864+ type Output = Result < ( ) , io:: Error > ;
865+ fn poll ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Result < ( ) , io:: Error > > {
866+ let mut state = self . 0 . lock ( ) . unwrap ( ) ;
867+ // If the future is complete, take() the result and return it,
868+ state. 0 . take ( ) . map ( |res| Poll :: Ready ( res) ) . unwrap_or_else ( || {
869+ // otherwise, store the waker so that the future will be poll()ed again when the result
870+ // is ready.
871+ state. 1 = Some ( cx. waker ( ) . clone ( ) ) ;
872+ Poll :: Pending
873+ } )
874+ }
875+ }
876+
877+ /// An in-memory KVStore for testing.
878+ ///
879+ /// Sync writes always complete immediately while async writes always block until manually
880+ /// completed with [`Self::complete_async_writes_through`] or [`Self::complete_all_async_writes`].
881+ ///
882+ /// Removes always complete immediately.
859883pub struct TestStore {
884+ pending_async_writes : Mutex < HashMap < String , Vec < ( usize , OneShotChannelState , Vec < u8 > ) > > > ,
860885 persisted_bytes : Mutex < HashMap < String , HashMap < String , Vec < u8 > > > > ,
861886 read_only : bool ,
862887}
863888
864889impl TestStore {
865890 pub fn new ( read_only : bool ) -> Self {
891+ let pending_async_writes = Mutex :: new ( new_hash_map ( ) ) ;
866892 let persisted_bytes = Mutex :: new ( new_hash_map ( ) ) ;
867- Self { persisted_bytes, read_only }
893+ Self { pending_async_writes, persisted_bytes, read_only }
894+ }
895+
896+ pub fn list_pending_async_writes (
897+ & self , primary_namespace : & str , secondary_namespace : & str , key : & str ,
898+ ) -> Vec < usize > {
899+ let key = format ! ( "{primary_namespace}/{secondary_namespace}/{key}" ) ;
900+ let writes_lock = self . pending_async_writes . lock ( ) . unwrap ( ) ;
901+ writes_lock
902+ . get ( & key)
903+ . map ( |v| v. iter ( ) . map ( |( id, _, _) | * id) . collect ( ) )
904+ . unwrap_or ( Vec :: new ( ) )
905+ }
906+
907+ /// Completes all pending async writes for the given namespace and key, up to and through the
908+ /// given `write_id` (which can be fetched from [`Self::list_pending_async_writes`]).
909+ pub fn complete_async_writes_through (
910+ & self , primary_namespace : & str , secondary_namespace : & str , key : & str , write_id : usize ,
911+ ) {
912+ let prefix = format ! ( "{primary_namespace}/{secondary_namespace}" ) ;
913+ let key = format ! ( "{primary_namespace}/{secondary_namespace}/{key}" ) ;
914+
915+ let mut persisted_lock = self . persisted_bytes . lock ( ) . unwrap ( ) ;
916+ let mut writes_lock = self . pending_async_writes . lock ( ) . unwrap ( ) ;
917+
918+ let pending_writes = writes_lock. get_mut ( & key) . expect ( "No pending writes for given key" ) ;
919+ pending_writes. retain ( |( id, res, data) | {
920+ if * id <= write_id {
921+ let namespace = persisted_lock. entry ( prefix. clone ( ) ) . or_insert ( new_hash_map ( ) ) ;
922+ * namespace. entry ( key. to_string ( ) ) . or_default ( ) = data. clone ( ) ;
923+ let mut future_state = res. lock ( ) . unwrap ( ) ;
924+ future_state. 0 = Some ( Ok ( ( ) ) ) ;
925+ if let Some ( waker) = future_state. 1 . take ( ) {
926+ waker. wake ( ) ;
927+ }
928+ false
929+ } else {
930+ true
931+ }
932+ } ) ;
933+ }
934+
935+ /// Completes all pending async writes on all namespaces and keys.
936+ pub fn complete_all_async_writes ( & self ) {
937+ let pending_writes: Vec < String > =
938+ self . pending_async_writes . lock ( ) . unwrap ( ) . keys ( ) . cloned ( ) . collect ( ) ;
939+ for key in pending_writes {
940+ let mut levels = key. split ( "/" ) ;
941+ let primary = levels. next ( ) . unwrap ( ) ;
942+ let secondary = levels. next ( ) . unwrap ( ) ;
943+ let key = levels. next ( ) . unwrap ( ) ;
944+ assert ! ( levels. next( ) . is_none( ) ) ;
945+ self . complete_async_writes_through ( primary, secondary, key, usize:: MAX ) ;
946+ }
868947 }
869948
870949 fn read_internal (
@@ -885,23 +964,6 @@ impl TestStore {
885964 }
886965 }
887966
888- fn write_internal (
889- & self , primary_namespace : & str , secondary_namespace : & str , key : & str , buf : Vec < u8 > ,
890- ) -> io:: Result < ( ) > {
891- if self . read_only {
892- return Err ( io:: Error :: new (
893- io:: ErrorKind :: PermissionDenied ,
894- "Cannot modify read-only store" ,
895- ) ) ;
896- }
897- let mut persisted_lock = self . persisted_bytes . lock ( ) . unwrap ( ) ;
898-
899- let prefixed = format ! ( "{primary_namespace}/{secondary_namespace}" ) ;
900- let outer_e = persisted_lock. entry ( prefixed) . or_insert ( new_hash_map ( ) ) ;
901- outer_e. insert ( key. to_string ( ) , buf) ;
902- Ok ( ( ) )
903- }
904-
905967 fn remove_internal (
906968 & self , primary_namespace : & str , secondary_namespace : & str , key : & str , _lazy : bool ,
907969 ) -> io:: Result < ( ) > {
@@ -913,12 +975,23 @@ impl TestStore {
913975 }
914976
915977 let mut persisted_lock = self . persisted_bytes . lock ( ) . unwrap ( ) ;
978+ let mut async_writes_lock = self . pending_async_writes . lock ( ) . unwrap ( ) ;
916979
917980 let prefixed = format ! ( "{primary_namespace}/{secondary_namespace}" ) ;
918981 if let Some ( outer_ref) = persisted_lock. get_mut ( & prefixed) {
919982 outer_ref. remove ( & key. to_string ( ) ) ;
920983 }
921984
985+ if let Some ( pending_writes) = async_writes_lock. remove ( & format ! ( "{prefixed}/{key}" ) ) {
986+ for ( _, future, _) in pending_writes {
987+ let mut future_lock = future. lock ( ) . unwrap ( ) ;
988+ future_lock. 0 = Some ( Ok ( ( ) ) ) ;
989+ if let Some ( waker) = future_lock. 1 . take ( ) {
990+ waker. wake ( ) ;
991+ }
992+ }
993+ }
994+
922995 Ok ( ( ) )
923996 }
924997
@@ -945,8 +1018,15 @@ impl KVStore for TestStore {
9451018 fn write (
9461019 & self , primary_namespace : & str , secondary_namespace : & str , key : & str , buf : Vec < u8 > ,
9471020 ) -> Pin < Box < dyn Future < Output = Result < ( ) , io:: Error > > + ' static + Send > > {
948- let res = self . write_internal ( & primary_namespace, & secondary_namespace, & key, buf) ;
949- Box :: pin ( async move { res } )
1021+ let path = format ! ( "{primary_namespace}/{secondary_namespace}/{key}" ) ;
1022+ let future = Arc :: new ( Mutex :: new ( ( None , None ) ) ) ;
1023+
1024+ let mut async_writes_lock = self . pending_async_writes . lock ( ) . unwrap ( ) ;
1025+ let pending_writes = async_writes_lock. entry ( path) . or_insert ( Vec :: new ( ) ) ;
1026+ let new_id = pending_writes. last ( ) . map ( |( id, _, _) | id + 1 ) . unwrap_or ( 0 ) ;
1027+ pending_writes. push ( ( new_id, Arc :: clone ( & future) , buf) ) ;
1028+
1029+ Box :: pin ( OneShotChannel ( future) )
9501030 }
9511031 fn remove (
9521032 & self , primary_namespace : & str , secondary_namespace : & str , key : & str , lazy : bool ,
@@ -972,7 +1052,30 @@ impl KVStoreSync for TestStore {
9721052 fn write (
9731053 & self , primary_namespace : & str , secondary_namespace : & str , key : & str , buf : Vec < u8 > ,
9741054 ) -> io:: Result < ( ) > {
975- self . write_internal ( primary_namespace, secondary_namespace, key, buf)
1055+ if self . read_only {
1056+ return Err ( io:: Error :: new (
1057+ io:: ErrorKind :: PermissionDenied ,
1058+ "Cannot modify read-only store" ,
1059+ ) ) ;
1060+ }
1061+ let mut persisted_lock = self . persisted_bytes . lock ( ) . unwrap ( ) ;
1062+ let mut async_writes_lock = self . pending_async_writes . lock ( ) . unwrap ( ) ;
1063+
1064+ let prefixed = format ! ( "{primary_namespace}/{secondary_namespace}" ) ;
1065+ let async_writes_pending = async_writes_lock. remove ( & format ! ( "{prefixed}/{key}" ) ) ;
1066+ let outer_e = persisted_lock. entry ( prefixed) . or_insert ( new_hash_map ( ) ) ;
1067+ outer_e. insert ( key. to_string ( ) , buf) ;
1068+
1069+ if let Some ( pending_writes) = async_writes_pending {
1070+ for ( _, future, _) in pending_writes {
1071+ let mut future_lock = future. lock ( ) . unwrap ( ) ;
1072+ future_lock. 0 = Some ( Ok ( ( ) ) ) ;
1073+ if let Some ( waker) = future_lock. 1 . take ( ) {
1074+ waker. wake ( ) ;
1075+ }
1076+ }
1077+ }
1078+ Ok ( ( ) )
9761079 }
9771080
9781081 fn remove (
0 commit comments