Skip to content

Commit 58a6a52

Browse files
committed
feat!: add context to checks
BREAKING CHANGE Signed-off-by: Gustavo Inacio <[email protected]>
1 parent f49c21d commit 58a6a52

File tree

11 files changed

+79
-49
lines changed

11 files changed

+79
-49
lines changed

tap_core/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ anyhow.workspace = true
1515
rand.workspace = true
1616
thiserror = "1.0.38"
1717
async-trait = "0.1.72"
18+
anymap3 = "1.0.0"
1819

1920
[dev-dependencies]
2021
criterion = { version = "0.5", features = ["async_std"] }

tap_core/src/manager/context/memory.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ pub mod checks {
263263
receipt::{
264264
checks::{Check, CheckError, CheckResult, ReceiptCheck},
265265
state::Checking,
266-
ReceiptError, ReceiptWithState,
266+
Context, ReceiptError, ReceiptWithState,
267267
},
268268
signed_message::MessageId,
269269
};
@@ -296,7 +296,7 @@ pub mod checks {
296296

297297
#[async_trait::async_trait]
298298
impl Check for AllocationIdCheck {
299-
async fn check(&self, receipt: &ReceiptWithState<Checking>) -> CheckResult {
299+
async fn check(&self, _: &Context, receipt: &ReceiptWithState<Checking>) -> CheckResult {
300300
let received_allocation_id = receipt.signed_receipt().message.allocation_id;
301301
if self
302302
.allocation_ids
@@ -323,7 +323,7 @@ pub mod checks {
323323

324324
#[async_trait::async_trait]
325325
impl Check for SignatureCheck {
326-
async fn check(&self, receipt: &ReceiptWithState<Checking>) -> CheckResult {
326+
async fn check(&self, _: &Context, receipt: &ReceiptWithState<Checking>) -> CheckResult {
327327
let recovered_address = receipt
328328
.signed_receipt()
329329
.recover_signer(&self.domain_separator)

tap_core/src/manager/mod.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@
3939
//! ReceiptWithState,
4040
//! state::Checking,
4141
//! checks::CheckList,
42-
//! ReceiptError
42+
//! ReceiptError,
43+
//! Context
4344
//! },
4445
//! manager::{
4546
//! Manager,
@@ -70,7 +71,7 @@
7071
//! let receipt = EIP712SignedMessage::new(&domain_separator, message, &wallet).unwrap();
7172
//!
7273
//! let manager = Manager::new(domain_separator, MyContext, CheckList::empty());
73-
//! manager.verify_and_store_receipt(receipt).await.unwrap()
74+
//! manager.verify_and_store_receipt(&Context::new(), receipt).await.unwrap()
7475
//! # }
7576
//! ```
7677
//!

tap_core/src/manager/tap_manager.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use crate::{
99
receipt::{
1010
checks::{CheckBatch, CheckList, TimestampCheck, UniqueCheck},
1111
state::{Failed, Reserved},
12-
ReceiptError, ReceiptWithState, SignedReceipt,
12+
Context, ReceiptError, ReceiptWithState, SignedReceipt,
1313
},
1414
Error,
1515
};
@@ -99,6 +99,7 @@ where
9999
{
100100
async fn collect_receipts(
101101
&self,
102+
ctx: &Context,
102103
timestamp_buffer_ns: u64,
103104
min_timestamp_ns: u64,
104105
limit: Option<u64>,
@@ -140,7 +141,7 @@ where
140141

141142
for receipt in checking_receipts.into_iter() {
142143
let receipt = receipt
143-
.finalize_receipt_checks(&self.checks)
144+
.finalize_receipt_checks(ctx, &self.checks)
144145
.await
145146
.map_err(|e| Error::ReceiptError(ReceiptError::RetryableCheck(e)))?;
146147

@@ -184,6 +185,7 @@ where
184185
///
185186
pub async fn create_rav_request(
186187
&self,
188+
ctx: &Context,
187189
timestamp_buffer_ns: u64,
188190
receipts_limit: Option<u64>,
189191
) -> Result<RAVRequest, Error> {
@@ -194,7 +196,7 @@ where
194196
.unwrap_or(0);
195197

196198
let (valid_receipts, invalid_receipts) = self
197-
.collect_receipts(timestamp_buffer_ns, min_timestamp_ns, receipts_limit)
199+
.collect_receipts(ctx, timestamp_buffer_ns, min_timestamp_ns, receipts_limit)
198200
.await?;
199201

200202
let expected_rav = Self::generate_expected_rav(&valid_receipts, previous_rav.clone());
@@ -271,12 +273,13 @@ where
271273
///
272274
pub async fn verify_and_store_receipt(
273275
&self,
276+
ctx: &Context,
274277
signed_receipt: SignedReceipt,
275278
) -> std::result::Result<(), Error> {
276279
let mut received_receipt = ReceiptWithState::new(signed_receipt);
277280

278281
// perform checks
279-
received_receipt.perform_checks(&self.checks).await?;
282+
received_receipt.perform_checks(ctx, &self.checks).await?;
280283

281284
// store the receipt
282285
self.context

tap_core/src/rav.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
//! 1. Create a [`RAVRequest`] with the valid receipts and the previous RAV.
2727
//! 2. Send the request to the aggregator.
2828
//! 3. The aggregator will verify the request and increment the total amount that
29-
//! has been aggregated.
29+
//! has been aggregated.
3030
//! 4. The aggregator will return a [`SignedRAV`].
3131
//! 5. Store the [`SignedRAV`].
3232
//! 6. Repeat the process until the allocation is closed.

tap_core/src/receipt/checks.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,15 @@
1212
//! # use std::sync::Arc;
1313
//! use tap_core::{
1414
//! receipt::checks::{Check, CheckResult, ReceiptCheck},
15-
//! receipt::{ReceiptWithState, state::Checking}
15+
//! receipt::{Context, ReceiptWithState, state::Checking}
1616
//! };
1717
//! # use async_trait::async_trait;
1818
//!
1919
//! struct MyCheck;
2020
//!
2121
//! #[async_trait]
2222
//! impl Check for MyCheck {
23-
//! async fn check(&self, receipt: &ReceiptWithState<Checking>) -> CheckResult {
23+
//! async fn check(&self, ctx: &Context, receipt: &ReceiptWithState<Checking>) -> CheckResult {
2424
//! // Implement your check here
2525
//! Ok(())
2626
//! }
@@ -33,7 +33,7 @@ use crate::signed_message::{SignatureBytes, SignatureBytesExt};
3333

3434
use super::{
3535
state::{Checking, Failed},
36-
ReceiptError, ReceiptWithState,
36+
Context, ReceiptError, ReceiptWithState,
3737
};
3838
use std::{
3939
collections::HashSet,
@@ -80,7 +80,7 @@ impl Deref for CheckList {
8080
/// Check trait is implemented by the lib user to validate receipts before they are stored.
8181
#[async_trait::async_trait]
8282
pub trait Check {
83-
async fn check(&self, receipt: &ReceiptWithState<Checking>) -> CheckResult;
83+
async fn check(&self, ctx: &Context, receipt: &ReceiptWithState<Checking>) -> CheckResult;
8484
}
8585

8686
/// CheckBatch is mostly used by the lib to implement checks
@@ -119,7 +119,7 @@ impl StatefulTimestampCheck {
119119

120120
#[async_trait::async_trait]
121121
impl Check for StatefulTimestampCheck {
122-
async fn check(&self, receipt: &ReceiptWithState<Checking>) -> CheckResult {
122+
async fn check(&self, _: &Context, receipt: &ReceiptWithState<Checking>) -> CheckResult {
123123
let min_timestamp_ns = *self.min_timestamp_ns.read().unwrap();
124124
let signed_receipt = receipt.signed_receipt();
125125
if signed_receipt.message.timestamp_ns <= min_timestamp_ns {

tap_core/src/receipt/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,5 @@ pub type SignedReceipt = EIP712SignedMessage<Receipt>;
3636

3737
/// Result type for receipt
3838
pub type ReceiptResult<T> = Result<T, ReceiptError>;
39+
40+
pub type Context = anymap3::Map<dyn std::any::Any + Send + Sync>;

tap_core/src/receipt/received_receipt.rs

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
use alloy::dyn_abi::Eip712Domain;
1717

1818
use super::checks::CheckError;
19-
use super::{Receipt, ReceiptError, ReceiptResult, SignedReceipt};
19+
use super::{Context, Receipt, ReceiptError, ReceiptResult, SignedReceipt};
2020
use crate::receipt::state::{AwaitingReserve, Checking, Failed, ReceiptState, Reserved};
2121
use crate::{
2222
manager::adapters::EscrowHandler, receipt::checks::ReceiptCheck,
@@ -28,16 +28,15 @@ pub type ResultReceipt<S> = std::result::Result<ReceiptWithState<S>, ReceiptWith
2828
/// Typestate pattern for tracking the state of a receipt
2929
///
3030
/// - The [ `ReceiptState` ] trait represents the different states a receipt
31-
/// can be in.
31+
/// can be in.
3232
/// - The [ `Checking` ] state is used to represent a receipt that is currently
33-
/// being checked.
33+
/// being checked.
3434
/// - The [ `Failed` ] state is used to represent a receipt that has failed a
35-
/// check or validation.
35+
/// check or validation.
3636
/// - The [ `AwaitingReserve` ] state is used to represent a receipt that has
37-
/// passed all checks and is
38-
/// awaiting escrow reservation.
37+
/// passed all checks and is awaiting escrow reservation.
3938
/// - The [ `Reserved` ] state is used to represent a receipt that has
40-
/// successfully reserved escrow.
39+
/// successfully reserved escrow.
4140
#[derive(Debug, Clone)]
4241
pub struct ReceiptWithState<S>
4342
where
@@ -90,10 +89,14 @@ impl ReceiptWithState<Checking> {
9089
/// cannot be comleted in the receipts current internal state.
9190
/// All other checks must be complete before `CheckAndReserveEscrow`.
9291
///
93-
pub async fn perform_checks(&mut self, checks: &[ReceiptCheck]) -> ReceiptResult<()> {
92+
pub async fn perform_checks(
93+
&mut self,
94+
ctx: &Context,
95+
checks: &[ReceiptCheck],
96+
) -> ReceiptResult<()> {
9497
for check in checks {
9598
// return early on an error
96-
check.check(self).await.map_err(|e| match e {
99+
check.check(ctx, self).await.map_err(|e| match e {
97100
CheckError::Retryable(e) => ReceiptError::RetryableCheck(e.to_string()),
98101
CheckError::Failed(e) => ReceiptError::CheckFailure(e.to_string()),
99102
})?;
@@ -108,9 +111,10 @@ impl ReceiptWithState<Checking> {
108111
///
109112
pub async fn finalize_receipt_checks(
110113
mut self,
114+
ctx: &Context,
111115
checks: &[ReceiptCheck],
112116
) -> Result<ResultReceipt<AwaitingReserve>, String> {
113-
let all_checks_passed = self.perform_checks(checks).await;
117+
let all_checks_passed = self.perform_checks(ctx, checks).await;
114118
if let Err(ReceiptError::RetryableCheck(e)) = all_checks_passed {
115119
Err(e.to_string())
116120
} else if let Err(e) = all_checks_passed {

tap_core/tests/manager_test.rs

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ use tap_core::{
2727
receipt::{
2828
checks::{Check, CheckError, CheckList, StatefulTimestampCheck},
2929
state::Checking,
30-
Receipt, ReceiptWithState,
30+
Context, Receipt, ReceiptWithState,
3131
},
3232
signed_message::EIP712SignedMessage,
3333
tap_eip712_domain,
@@ -145,7 +145,7 @@ async fn manager_verify_and_store_varying_initial_checks(
145145
.insert(signer.address(), 999999);
146146

147147
assert!(manager
148-
.verify_and_store_receipt(signed_receipt)
148+
.verify_and_store_receipt(&Context::new(), signed_receipt)
149149
.await
150150
.is_ok());
151151
}
@@ -184,11 +184,11 @@ async fn manager_create_rav_request_all_valid_receipts(
184184
stored_signed_receipts.push(signed_receipt.clone());
185185
query_appraisals.write().unwrap().insert(query_id, value);
186186
assert!(manager
187-
.verify_and_store_receipt(signed_receipt)
187+
.verify_and_store_receipt(&Context::new(), signed_receipt)
188188
.await
189189
.is_ok());
190190
}
191-
let rav_request_result = manager.create_rav_request(0, None).await;
191+
let rav_request_result = manager.create_rav_request(&Context::new(), 0, None).await;
192192
assert!(rav_request_result.is_ok());
193193

194194
let rav_request = rav_request_result.unwrap();
@@ -279,12 +279,12 @@ async fn manager_create_multiple_rav_requests_all_valid_receipts(
279279
stored_signed_receipts.push(signed_receipt.clone());
280280
query_appraisals.write().unwrap().insert(query_id, value);
281281
assert!(manager
282-
.verify_and_store_receipt(signed_receipt)
282+
.verify_and_store_receipt(&Context::new(), signed_receipt)
283283
.await
284284
.is_ok());
285285
expected_accumulated_value += value;
286286
}
287-
let rav_request_result = manager.create_rav_request(0, None).await;
287+
let rav_request_result = manager.create_rav_request(&Context::new(), 0, None).await;
288288
assert!(rav_request_result.is_ok());
289289

290290
let rav_request = rav_request_result.unwrap();
@@ -323,12 +323,12 @@ async fn manager_create_multiple_rav_requests_all_valid_receipts(
323323
stored_signed_receipts.push(signed_receipt.clone());
324324
query_appraisals.write().unwrap().insert(query_id, value);
325325
assert!(manager
326-
.verify_and_store_receipt(signed_receipt)
326+
.verify_and_store_receipt(&Context::new(), signed_receipt)
327327
.await
328328
.is_ok());
329329
expected_accumulated_value += value;
330330
}
331-
let rav_request_result = manager.create_rav_request(0, None).await;
331+
let rav_request_result = manager.create_rav_request(&Context::new(), 0, None).await;
332332
assert!(rav_request_result.is_ok());
333333

334334
let rav_request = rav_request_result.unwrap();
@@ -391,7 +391,7 @@ async fn manager_create_multiple_rav_requests_all_valid_receipts_consecutive_tim
391391
stored_signed_receipts.push(signed_receipt.clone());
392392
query_appraisals.write().unwrap().insert(query_id, value);
393393
assert!(manager
394-
.verify_and_store_receipt(signed_receipt)
394+
.verify_and_store_receipt(&Context::new(), signed_receipt)
395395
.await
396396
.is_ok());
397397
expected_accumulated_value += value;
@@ -403,7 +403,7 @@ async fn manager_create_multiple_rav_requests_all_valid_receipts_consecutive_tim
403403
manager.remove_obsolete_receipts().await.unwrap();
404404
}
405405

406-
let rav_request_1_result = manager.create_rav_request(0, None).await;
406+
let rav_request_1_result = manager.create_rav_request(&Context::new(), 0, None).await;
407407
assert!(rav_request_1_result.is_ok());
408408

409409
let rav_request_1 = rav_request_1_result.unwrap();
@@ -438,7 +438,7 @@ async fn manager_create_multiple_rav_requests_all_valid_receipts_consecutive_tim
438438
stored_signed_receipts.push(signed_receipt.clone());
439439
query_appraisals.write().unwrap().insert(query_id, value);
440440
assert!(manager
441-
.verify_and_store_receipt(signed_receipt)
441+
.verify_and_store_receipt(&Context::new(), signed_receipt)
442442
.await
443443
.is_ok());
444444
expected_accumulated_value += value;
@@ -458,7 +458,7 @@ async fn manager_create_multiple_rav_requests_all_valid_receipts_consecutive_tim
458458
);
459459
}
460460

461-
let rav_request_2_result = manager.create_rav_request(0, None).await;
461+
let rav_request_2_result = manager.create_rav_request(&Context::new(), 0, None).await;
462462
assert!(rav_request_2_result.is_ok());
463463

464464
let rav_request_2 = rav_request_2_result.unwrap();
@@ -518,12 +518,15 @@ async fn manager_create_rav_and_ignore_invalid_receipts(
518518
let signed_receipt = EIP712SignedMessage::new(&domain_separator, receipt, &signer).unwrap();
519519
stored_signed_receipts.push(signed_receipt.clone());
520520
manager
521-
.verify_and_store_receipt(signed_receipt)
521+
.verify_and_store_receipt(&Context::new(), signed_receipt)
522522
.await
523523
.unwrap();
524524
}
525525

526-
let rav_request = manager.create_rav_request(0, None).await.unwrap();
526+
let rav_request = manager
527+
.create_rav_request(&Context::new(), 0, None)
528+
.await
529+
.unwrap();
527530
let expected_rav = rav_request.expected_rav.unwrap();
528531

529532
assert_eq!(rav_request.valid_receipts.len(), 1);
@@ -544,7 +547,11 @@ async fn test_retryable_checks(
544547

545548
#[async_trait::async_trait]
546549
impl Check for RetryableCheck {
547-
async fn check(&self, receipt: &ReceiptWithState<Checking>) -> Result<(), CheckError> {
550+
async fn check(
551+
&self,
552+
_: &Context,
553+
receipt: &ReceiptWithState<Checking>,
554+
) -> Result<(), CheckError> {
548555
// we want to fail only if nonce is 5 and if is create rav step
549556
if self.0.load(std::sync::atomic::Ordering::SeqCst)
550557
&& receipt.signed_receipt().message.nonce == 5
@@ -591,14 +598,14 @@ async fn test_retryable_checks(
591598
let signed_receipt = EIP712SignedMessage::new(&domain_separator, receipt, &signer).unwrap();
592599
stored_signed_receipts.push(signed_receipt.clone());
593600
manager
594-
.verify_and_store_receipt(signed_receipt)
601+
.verify_and_store_receipt(&Context::new(), signed_receipt)
595602
.await
596603
.unwrap();
597604
}
598605

599606
is_create_rav.store(true, std::sync::atomic::Ordering::SeqCst);
600607

601-
let rav_request = manager.create_rav_request(0, None).await;
608+
let rav_request = manager.create_rav_request(&Context::new(), 0, None).await;
602609

603610
assert_eq!(
604611
rav_request.expect_err("Didn't fail").to_string(),

0 commit comments

Comments
 (0)