Skip to content

Commit fec74d7

Browse files
authored
Replace Persister with SessionPersister for v2 Sender (payjoin#789)
The following replaces the existing implemention of the `Persister` trait with the new `SessionPersister`.
2 parents 9f403f1 + 3b157b0 commit fec74d7

File tree

18 files changed

+1214
-693
lines changed

18 files changed

+1214
-693
lines changed

payjoin-cli/src/app/v2/mod.rs

Lines changed: 108 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,17 @@ use std::sync::{Arc, Mutex};
22

33
use anyhow::{anyhow, Context, Result};
44
use payjoin::bitcoin::consensus::encode::serialize_hex;
5-
use payjoin::bitcoin::psbt::Psbt;
65
use payjoin::bitcoin::{Amount, FeeRate};
76
use payjoin::persist::OptionalTransitionOutcome;
87
use payjoin::receive::v2::{
98
process_err_res, replay_event_log as replay_receiver_event_log, MaybeInputsOwned,
109
MaybeInputsSeen, OutputsUnknown, PayjoinProposal, ProvisionalProposal, Receiver,
1110
ReceiverTypeState, SessionHistory, UncheckedProposal, WantsInputs, WantsOutputs, WithContext,
1211
};
13-
use payjoin::send::v2::{Sender, SenderBuilder, WithReplyKey};
12+
use payjoin::send::v2::{
13+
replay_event_log as replay_sender_event_log, Sender, SenderBuilder, SenderTypeState,
14+
V2GetContext, WithReplyKey,
15+
};
1416
use payjoin::Uri;
1517
use tokio::sync::watch;
1618

@@ -50,30 +52,46 @@ impl AppTrait for App {
5052

5153
fn wallet(&self) -> BitcoindWallet { self.wallet.clone() }
5254

55+
#[allow(clippy::incompatible_msrv)]
5356
async fn send_payjoin(&self, bip21: &str, fee_rate: FeeRate) -> Result<()> {
5457
use payjoin::UriExt;
5558
let uri =
5659
Uri::try_from(bip21).map_err(|e| anyhow!("Failed to create URI from BIP21: {}", e))?;
5760
let uri = uri.assume_checked();
5861
let uri = uri.check_pj_supported().map_err(|_| anyhow!("URI does not support Payjoin"))?;
5962
let url = uri.extras.endpoint();
60-
// match bip21 to send_session public_key
61-
let req_ctx = match self.db.get_send_session(url)? {
62-
Some(send_session) => send_session,
63+
// TODO: perhaps we should store pj uri in the session wrapper as to not replay the event log for each session
64+
let sender_state = self.db.get_send_session_ids()?.into_iter().find_map(|session_id| {
65+
let sender_persister = SenderPersister::from_id(self.db.clone(), session_id).ok()?;
66+
let replay_results = replay_sender_event_log(&sender_persister)
67+
.map_err(|e| anyhow!("Failed to replay sender event log: {:?}", e))
68+
.ok()?;
69+
70+
let pj_uri = replay_results.1.endpoint();
71+
let sender_state = pj_uri.filter(|uri| uri == &url).map(|_| replay_results.0);
72+
sender_state.map(|sender_state| (sender_state, sender_persister))
73+
});
74+
75+
let (sender_state, persister) = match sender_state {
76+
Some((sender_state, persister)) => (sender_state, persister),
6377
None => {
78+
let persister = SenderPersister::new(self.db.clone())?;
6479
let psbt = self.create_original_psbt(&uri, fee_rate)?;
65-
let mut persister = SenderPersister::new(self.db.clone());
66-
let new_sender = SenderBuilder::new(psbt, uri.clone())
80+
let sender = SenderBuilder::new(psbt, uri.clone())
6781
.build_recommended(fee_rate)
68-
.with_context(|| "Failed to build payjoin request")?;
69-
let storage_token = new_sender
70-
.persist(&mut persister)
71-
.map_err(|e| anyhow!("Failed to persist sender: {}", e))?;
72-
Sender::load(storage_token, &persister)
73-
.map_err(|e| anyhow!("Failed to load sender: {}", e))?
82+
.save(&persister)?;
83+
84+
(SenderTypeState::WithReplyKey(sender), persister)
7485
}
7586
};
76-
self.spawn_payjoin_sender(req_ctx).await
87+
let mut interrupt = self.interrupt.clone();
88+
tokio::select! {
89+
_ = self.process_sender_session(sender_state, &persister) => return Ok(()),
90+
_ = interrupt.changed() => {
91+
println!("Interrupted. Call `send` with the same arguments to resume this session or `resume` to resume all sessions.");
92+
return Err(anyhow!("Interrupted"))
93+
}
94+
}
7795
}
7896

7997
async fn receive_payjoin(&self, amount: Amount) -> Result<()> {
@@ -104,9 +122,9 @@ impl AppTrait for App {
104122
#[allow(clippy::incompatible_msrv)]
105123
async fn resume_payjoins(&self) -> Result<()> {
106124
let recv_session_ids = self.db.get_recv_session_ids()?;
107-
let send_sessions = self.db.get_send_sessions()?;
125+
let send_session_ids = self.db.get_send_session_ids()?;
108126

109-
if recv_session_ids.is_empty() && send_sessions.is_empty() {
127+
if recv_session_ids.is_empty() && send_session_ids.is_empty() {
110128
println!("No sessions to resume.");
111129
return Ok(());
112130
}
@@ -124,9 +142,15 @@ impl AppTrait for App {
124142
}));
125143
}
126144

127-
for session in send_sessions {
145+
for session_id in send_session_ids {
146+
let sender_persiter = SenderPersister::from_id(self.db.clone(), session_id)?;
147+
let sender_state = replay_sender_event_log(&sender_persiter)
148+
.map_err(|e| anyhow!("Failed to replay sender event log: {:?}", e))?
149+
.0;
128150
let self_clone = self.clone();
129-
tasks.push(tokio::spawn(async move { self_clone.spawn_payjoin_sender(session).await }));
151+
tasks.push(tokio::spawn(async move {
152+
self_clone.process_sender_session(sender_state, &sender_persiter).await
153+
}));
130154
}
131155

132156
let mut interrupt = self.interrupt.clone();
@@ -147,58 +171,79 @@ impl AppTrait for App {
147171
}
148172

149173
impl App {
150-
#[allow(clippy::incompatible_msrv)]
151-
async fn spawn_payjoin_sender(&self, mut req_ctx: Sender<WithReplyKey>) -> Result<()> {
152-
let mut interrupt = self.interrupt.clone();
153-
tokio::select! {
154-
res = self.long_poll_post(&mut req_ctx) => {
155-
self.process_pj_response(res?)?;
156-
self.db.clear_send_session(req_ctx.endpoint())?;
174+
async fn process_sender_session(
175+
&self,
176+
session: SenderTypeState,
177+
persister: &SenderPersister,
178+
) -> Result<()> {
179+
match session {
180+
SenderTypeState::WithReplyKey(context) => {
181+
// TODO: can we handle the fall back case in `post_original_proposal`. That way we don't have to clone here
182+
match self.post_original_proposal(context.clone(), persister).await {
183+
Ok(()) => (),
184+
Err(_) => {
185+
let (req, v1_ctx) = context.extract_v1();
186+
let response = post_request(req).await?;
187+
let psbt = Arc::new(
188+
v1_ctx.process_response(response.bytes().await?.to_vec().as_slice())?,
189+
);
190+
self.process_pj_response((*psbt).clone())?;
191+
}
192+
}
193+
return Ok(());
157194
}
158-
_ = interrupt.changed() => {
159-
println!("Interrupted. Call `send` with the same arguments to resume this session or `resume` to resume all sessions.");
195+
SenderTypeState::V2GetContext(context) =>
196+
self.get_proposed_payjoin_psbt(context, persister).await?,
197+
SenderTypeState::ProposalReceived(proposal) => {
198+
self.process_pj_response(proposal.clone())?;
199+
return Ok(());
160200
}
201+
_ => return Err(anyhow!("Unexpected sender state")),
161202
}
162203
Ok(())
163204
}
164205

165-
async fn long_poll_post(&self, req_ctx: &mut Sender<WithReplyKey>) -> Result<Psbt> {
166-
let ohttp_relay = self.unwrap_relay_or_else_fetch(Some(req_ctx.endpoint().clone())).await?;
167-
168-
match req_ctx.extract_v2(ohttp_relay.clone()) {
169-
Ok((req, ctx)) => {
170-
println!("Posting Original PSBT Payload request...");
171-
let response = post_request(req).await?;
172-
println!("Sent fallback transaction");
173-
let v2_ctx = Arc::new(ctx.process_response(&response.bytes().await?)?);
174-
loop {
175-
let (req, ohttp_ctx) = v2_ctx.extract_req(&ohttp_relay)?;
176-
let response = post_request(req).await?;
177-
match v2_ctx.process_response(&response.bytes().await?, ohttp_ctx) {
178-
Ok(Some(psbt)) => return Ok(psbt),
179-
Ok(None) => {
180-
println!("No response yet.");
181-
}
182-
Err(re) => {
183-
println!("{re}");
184-
log::debug!("{re:?}");
185-
return Err(anyhow!("Response error").context(re));
186-
}
187-
}
206+
async fn post_original_proposal(
207+
&self,
208+
sender: Sender<WithReplyKey>,
209+
persister: &SenderPersister,
210+
) -> Result<()> {
211+
let (req, ctx) = sender
212+
.extract_v2(self.unwrap_relay_or_else_fetch(Some(sender.endpoint().clone())).await?)?;
213+
let response = post_request(req).await?;
214+
println!("Posted original proposal...");
215+
let sender = sender.process_response(&response.bytes().await?, ctx).save(persister)?;
216+
self.get_proposed_payjoin_psbt(sender, persister).await
217+
}
218+
219+
async fn get_proposed_payjoin_psbt(
220+
&self,
221+
sender: Sender<V2GetContext>,
222+
persister: &SenderPersister,
223+
) -> Result<()> {
224+
let mut session = sender.clone();
225+
// Long poll until we get a response
226+
loop {
227+
let (req, ctx) = session.extract_req(
228+
self.unwrap_relay_or_else_fetch(Some(session.endpoint().clone())).await?,
229+
)?;
230+
let response = post_request(req).await?;
231+
let res = session.process_response(&response.bytes().await?, ctx).save(persister);
232+
match res {
233+
Ok(OptionalTransitionOutcome::Progress(psbt)) => {
234+
println!("Proposal received. Processing...");
235+
self.process_pj_response(psbt.clone())?;
236+
return Ok(());
188237
}
189-
}
190-
Err(_) => {
191-
let (req, v1_ctx) = req_ctx.extract_v1();
192-
println!("Posting Original PSBT Payload request...");
193-
let response = post_request(req).await?;
194-
println!("Sent fallback transaction");
195-
match v1_ctx.process_response(&response.bytes().await?) {
196-
Ok(psbt) => Ok(psbt),
197-
Err(re) => {
198-
println!("{re}");
199-
log::debug!("{re:?}");
200-
Err(anyhow!("Response error").context(re))
201-
}
238+
Ok(OptionalTransitionOutcome::Stasis(current_state)) => {
239+
println!("No response yet.");
240+
session = current_state;
241+
continue;
242+
}
243+
Err(re) => {
244+
println!("{re}");
245+
log::debug!("{re:?}");
246+
return Err(anyhow!("Response error").context(re));
202247
}
203248
}
204249
}

0 commit comments

Comments
 (0)