@@ -2,15 +2,17 @@ use std::sync::{Arc, Mutex};
22
33use anyhow:: { anyhow, Context , Result } ;
44use payjoin:: bitcoin:: consensus:: encode:: serialize_hex;
5- use payjoin:: bitcoin:: psbt:: Psbt ;
65use payjoin:: bitcoin:: { Amount , FeeRate } ;
76use payjoin:: persist:: OptionalTransitionOutcome ;
87use payjoin:: receive:: v2:: {
98 process_err_res, replay_event_log as replay_receiver_event_log, MaybeInputsOwned ,
109 MaybeInputsSeen , OutputsUnknown , PayjoinProposal , ProvisionalProposal , Receiver ,
1110 ReceiverTypeState , SessionHistory , UncheckedProposal , WantsInputs , WantsOutputs , WithContext ,
1211} ;
13- use payjoin:: send:: v2:: { Sender , SenderBuilder , WithReplyKey } ;
12+ use payjoin:: send:: v2:: {
13+ replay_event_log as replay_sender_event_log, Sender , SenderBuilder , SenderTypeState ,
14+ V2GetContext , WithReplyKey ,
15+ } ;
1416use payjoin:: Uri ;
1517use tokio:: sync:: watch;
1618
@@ -50,30 +52,46 @@ impl AppTrait for App {
5052
5153 fn wallet ( & self ) -> BitcoindWallet { self . wallet . clone ( ) }
5254
55+ #[ allow( clippy:: incompatible_msrv) ]
5356 async fn send_payjoin ( & self , bip21 : & str , fee_rate : FeeRate ) -> Result < ( ) > {
5457 use payjoin:: UriExt ;
5558 let uri =
5659 Uri :: try_from ( bip21) . map_err ( |e| anyhow ! ( "Failed to create URI from BIP21: {}" , e) ) ?;
5760 let uri = uri. assume_checked ( ) ;
5861 let uri = uri. check_pj_supported ( ) . map_err ( |_| anyhow ! ( "URI does not support Payjoin" ) ) ?;
5962 let url = uri. extras . endpoint ( ) ;
60- // match bip21 to send_session public_key
61- let req_ctx = match self . db . get_send_session ( url) ? {
62- Some ( send_session) => send_session,
63+ // TODO: perhaps we should store pj uri in the session wrapper as to not replay the event log for each session
64+ let sender_state = self . db . get_send_session_ids ( ) ?. into_iter ( ) . find_map ( |session_id| {
65+ let sender_persister = SenderPersister :: from_id ( self . db . clone ( ) , session_id) . ok ( ) ?;
66+ let replay_results = replay_sender_event_log ( & sender_persister)
67+ . map_err ( |e| anyhow ! ( "Failed to replay sender event log: {:?}" , e) )
68+ . ok ( ) ?;
69+
70+ let pj_uri = replay_results. 1 . endpoint ( ) ;
71+ let sender_state = pj_uri. filter ( |uri| uri == & url) . map ( |_| replay_results. 0 ) ;
72+ sender_state. map ( |sender_state| ( sender_state, sender_persister) )
73+ } ) ;
74+
75+ let ( sender_state, persister) = match sender_state {
76+ Some ( ( sender_state, persister) ) => ( sender_state, persister) ,
6377 None => {
78+ let persister = SenderPersister :: new ( self . db . clone ( ) ) ?;
6479 let psbt = self . create_original_psbt ( & uri, fee_rate) ?;
65- let mut persister = SenderPersister :: new ( self . db . clone ( ) ) ;
66- let new_sender = SenderBuilder :: new ( psbt, uri. clone ( ) )
80+ let sender = SenderBuilder :: new ( psbt, uri. clone ( ) )
6781 . build_recommended ( fee_rate)
68- . with_context ( || "Failed to build payjoin request" ) ?;
69- let storage_token = new_sender
70- . persist ( & mut persister)
71- . map_err ( |e| anyhow ! ( "Failed to persist sender: {}" , e) ) ?;
72- Sender :: load ( storage_token, & persister)
73- . map_err ( |e| anyhow ! ( "Failed to load sender: {}" , e) ) ?
82+ . save ( & persister) ?;
83+
84+ ( SenderTypeState :: WithReplyKey ( sender) , persister)
7485 }
7586 } ;
76- self . spawn_payjoin_sender ( req_ctx) . await
87+ let mut interrupt = self . interrupt . clone ( ) ;
88+ tokio:: select! {
89+ _ = self . process_sender_session( sender_state, & persister) => return Ok ( ( ) ) ,
90+ _ = interrupt. changed( ) => {
91+ println!( "Interrupted. Call `send` with the same arguments to resume this session or `resume` to resume all sessions." ) ;
92+ return Err ( anyhow!( "Interrupted" ) )
93+ }
94+ }
7795 }
7896
7997 async fn receive_payjoin ( & self , amount : Amount ) -> Result < ( ) > {
@@ -104,9 +122,9 @@ impl AppTrait for App {
104122 #[ allow( clippy:: incompatible_msrv) ]
105123 async fn resume_payjoins ( & self ) -> Result < ( ) > {
106124 let recv_session_ids = self . db . get_recv_session_ids ( ) ?;
107- let send_sessions = self . db . get_send_sessions ( ) ?;
125+ let send_session_ids = self . db . get_send_session_ids ( ) ?;
108126
109- if recv_session_ids. is_empty ( ) && send_sessions . is_empty ( ) {
127+ if recv_session_ids. is_empty ( ) && send_session_ids . is_empty ( ) {
110128 println ! ( "No sessions to resume." ) ;
111129 return Ok ( ( ) ) ;
112130 }
@@ -124,9 +142,15 @@ impl AppTrait for App {
124142 } ) ) ;
125143 }
126144
127- for session in send_sessions {
145+ for session_id in send_session_ids {
146+ let sender_persiter = SenderPersister :: from_id ( self . db . clone ( ) , session_id) ?;
147+ let sender_state = replay_sender_event_log ( & sender_persiter)
148+ . map_err ( |e| anyhow ! ( "Failed to replay sender event log: {:?}" , e) ) ?
149+ . 0 ;
128150 let self_clone = self . clone ( ) ;
129- tasks. push ( tokio:: spawn ( async move { self_clone. spawn_payjoin_sender ( session) . await } ) ) ;
151+ tasks. push ( tokio:: spawn ( async move {
152+ self_clone. process_sender_session ( sender_state, & sender_persiter) . await
153+ } ) ) ;
130154 }
131155
132156 let mut interrupt = self . interrupt . clone ( ) ;
@@ -147,58 +171,79 @@ impl AppTrait for App {
147171}
148172
149173impl App {
150- #[ allow( clippy:: incompatible_msrv) ]
151- async fn spawn_payjoin_sender ( & self , mut req_ctx : Sender < WithReplyKey > ) -> Result < ( ) > {
152- let mut interrupt = self . interrupt . clone ( ) ;
153- tokio:: select! {
154- res = self . long_poll_post( & mut req_ctx) => {
155- self . process_pj_response( res?) ?;
156- self . db. clear_send_session( req_ctx. endpoint( ) ) ?;
174+ async fn process_sender_session (
175+ & self ,
176+ session : SenderTypeState ,
177+ persister : & SenderPersister ,
178+ ) -> Result < ( ) > {
179+ match session {
180+ SenderTypeState :: WithReplyKey ( context) => {
181+ // TODO: can we handle the fall back case in `post_original_proposal`. That way we don't have to clone here
182+ match self . post_original_proposal ( context. clone ( ) , persister) . await {
183+ Ok ( ( ) ) => ( ) ,
184+ Err ( _) => {
185+ let ( req, v1_ctx) = context. extract_v1 ( ) ;
186+ let response = post_request ( req) . await ?;
187+ let psbt = Arc :: new (
188+ v1_ctx. process_response ( response. bytes ( ) . await ?. to_vec ( ) . as_slice ( ) ) ?,
189+ ) ;
190+ self . process_pj_response ( ( * psbt) . clone ( ) ) ?;
191+ }
192+ }
193+ return Ok ( ( ) ) ;
157194 }
158- _ = interrupt. changed( ) => {
159- println!( "Interrupted. Call `send` with the same arguments to resume this session or `resume` to resume all sessions." ) ;
195+ SenderTypeState :: V2GetContext ( context) =>
196+ self . get_proposed_payjoin_psbt ( context, persister) . await ?,
197+ SenderTypeState :: ProposalReceived ( proposal) => {
198+ self . process_pj_response ( proposal. clone ( ) ) ?;
199+ return Ok ( ( ) ) ;
160200 }
201+ _ => return Err ( anyhow ! ( "Unexpected sender state" ) ) ,
161202 }
162203 Ok ( ( ) )
163204 }
164205
165- async fn long_poll_post ( & self , req_ctx : & mut Sender < WithReplyKey > ) -> Result < Psbt > {
166- let ohttp_relay = self . unwrap_relay_or_else_fetch ( Some ( req_ctx. endpoint ( ) . clone ( ) ) ) . await ?;
167-
168- match req_ctx. extract_v2 ( ohttp_relay. clone ( ) ) {
169- Ok ( ( req, ctx) ) => {
170- println ! ( "Posting Original PSBT Payload request..." ) ;
171- let response = post_request ( req) . await ?;
172- println ! ( "Sent fallback transaction" ) ;
173- let v2_ctx = Arc :: new ( ctx. process_response ( & response. bytes ( ) . await ?) ?) ;
174- loop {
175- let ( req, ohttp_ctx) = v2_ctx. extract_req ( & ohttp_relay) ?;
176- let response = post_request ( req) . await ?;
177- match v2_ctx. process_response ( & response. bytes ( ) . await ?, ohttp_ctx) {
178- Ok ( Some ( psbt) ) => return Ok ( psbt) ,
179- Ok ( None ) => {
180- println ! ( "No response yet." ) ;
181- }
182- Err ( re) => {
183- println ! ( "{re}" ) ;
184- log:: debug!( "{re:?}" ) ;
185- return Err ( anyhow ! ( "Response error" ) . context ( re) ) ;
186- }
187- }
206+ async fn post_original_proposal (
207+ & self ,
208+ sender : Sender < WithReplyKey > ,
209+ persister : & SenderPersister ,
210+ ) -> Result < ( ) > {
211+ let ( req, ctx) = sender
212+ . extract_v2 ( self . unwrap_relay_or_else_fetch ( Some ( sender. endpoint ( ) . clone ( ) ) ) . await ?) ?;
213+ let response = post_request ( req) . await ?;
214+ println ! ( "Posted original proposal..." ) ;
215+ let sender = sender. process_response ( & response. bytes ( ) . await ?, ctx) . save ( persister) ?;
216+ self . get_proposed_payjoin_psbt ( sender, persister) . await
217+ }
218+
219+ async fn get_proposed_payjoin_psbt (
220+ & self ,
221+ sender : Sender < V2GetContext > ,
222+ persister : & SenderPersister ,
223+ ) -> Result < ( ) > {
224+ let mut session = sender. clone ( ) ;
225+ // Long poll until we get a response
226+ loop {
227+ let ( req, ctx) = session. extract_req (
228+ self . unwrap_relay_or_else_fetch ( Some ( session. endpoint ( ) . clone ( ) ) ) . await ?,
229+ ) ?;
230+ let response = post_request ( req) . await ?;
231+ let res = session. process_response ( & response. bytes ( ) . await ?, ctx) . save ( persister) ;
232+ match res {
233+ Ok ( OptionalTransitionOutcome :: Progress ( psbt) ) => {
234+ println ! ( "Proposal received. Processing..." ) ;
235+ self . process_pj_response ( psbt. clone ( ) ) ?;
236+ return Ok ( ( ) ) ;
188237 }
189- }
190- Err ( _) => {
191- let ( req, v1_ctx) = req_ctx. extract_v1 ( ) ;
192- println ! ( "Posting Original PSBT Payload request..." ) ;
193- let response = post_request ( req) . await ?;
194- println ! ( "Sent fallback transaction" ) ;
195- match v1_ctx. process_response ( & response. bytes ( ) . await ?) {
196- Ok ( psbt) => Ok ( psbt) ,
197- Err ( re) => {
198- println ! ( "{re}" ) ;
199- log:: debug!( "{re:?}" ) ;
200- Err ( anyhow ! ( "Response error" ) . context ( re) )
201- }
238+ Ok ( OptionalTransitionOutcome :: Stasis ( current_state) ) => {
239+ println ! ( "No response yet." ) ;
240+ session = current_state;
241+ continue ;
242+ }
243+ Err ( re) => {
244+ println ! ( "{re}" ) ;
245+ log:: debug!( "{re:?}" ) ;
246+ return Err ( anyhow ! ( "Response error" ) . context ( re) ) ;
202247 }
203248 }
204249 }
0 commit comments