33
44use std:: collections:: hash_set;
55
6+ use alloy_primitives:: Address ;
7+ use alloy_sol_types:: Eip712Domain ;
68use anyhow:: { Ok , Result } ;
7- use ethers_core:: types:: { Address , Signature } ;
9+ use ethers_core:: types:: Signature ;
810use ethers_signers:: { LocalWallet , Signer } ;
911
1012use tap_core:: {
@@ -13,21 +15,26 @@ use tap_core::{
1315} ;
1416
1517pub async fn check_and_aggregate_receipts (
18+ domain_separator : & Eip712Domain ,
1619 receipts : & [ EIP712SignedMessage < Receipt > ] ,
1720 previous_rav : Option < EIP712SignedMessage < ReceiptAggregateVoucher > > ,
1821 wallet : & LocalWallet ,
1922) -> Result < EIP712SignedMessage < ReceiptAggregateVoucher > > {
23+ // Get the address of the wallet
24+ let address: [ u8 ; 20 ] = wallet. address ( ) . into ( ) ;
25+ let address: Address = address. into ( ) ;
26+
2027 // Check that the receipts are unique
2128 check_signatures_unique ( receipts) ?;
2229
2330 // Check that the receipts are signed by ourselves
2431 receipts
2532 . iter ( )
26- . try_for_each ( |receipt| receipt. verify ( wallet . address ( ) ) ) ?;
33+ . try_for_each ( |receipt| receipt. verify ( domain_separator , address) ) ?;
2734
2835 // Check that the previous rav is signed by ourselves
2936 if let Some ( previous_rav) = & previous_rav {
30- previous_rav. verify ( wallet . address ( ) ) ?;
37+ previous_rav. verify ( domain_separator , address) ?;
3138 }
3239
3340 // Check that the receipts timestamp is greater than the previous rav
@@ -58,7 +65,7 @@ pub async fn check_and_aggregate_receipts(
5865 let rav = ReceiptAggregateVoucher :: aggregate_receipts ( allocation_id, receipts, previous_rav) ?;
5966
6067 // Sign the rav and return
61- Ok ( EIP712SignedMessage :: new ( rav, wallet) . await ?)
68+ Ok ( EIP712SignedMessage :: new ( domain_separator , rav, wallet) . await ?)
6269}
6370
6471fn check_allocation_id (
@@ -109,7 +116,8 @@ fn check_receipt_timestamps(
109116mod tests {
110117 use std:: str:: FromStr ;
111118
112- use ethers_core:: types:: Address ;
119+ use alloy_primitives:: Address ;
120+ use alloy_sol_types:: { eip712_domain, Eip712Domain } ;
113121 use ethers_signers:: { coins_bip39:: English , LocalWallet , MnemonicBuilder , Signer } ;
114122 use rstest:: * ;
115123
@@ -122,8 +130,8 @@ mod tests {
122130 . phrase ( "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about" )
123131 . build ( )
124132 . unwrap ( ) ;
125- let address = wallet. address ( ) ;
126- ( wallet, address)
133+ let address: [ u8 ; 20 ] = wallet. address ( ) . into ( ) ;
134+ ( wallet, address. into ( ) )
127135 }
128136
129137 #[ fixture]
@@ -136,18 +144,32 @@ mod tests {
136144 ]
137145 }
138146
147+ #[ fixture]
148+ fn domain_separator ( ) -> Eip712Domain {
149+ eip712_domain ! {
150+ name: "TAP" ,
151+ version: "1" ,
152+ chain_id: 1 ,
153+ verifying_contract: Address :: from( [ 0x11u8 ; 20 ] ) ,
154+ }
155+ }
156+
139157 #[ rstest]
140158 #[ tokio:: test]
141159 async fn check_signatures_unique_fail (
142160 keys : ( LocalWallet , Address ) ,
143161 allocation_ids : Vec < Address > ,
162+ domain_separator : Eip712Domain ,
144163 ) {
145164 // Create the same receipt twice (replay attack)
146165 let mut receipts = Vec :: new ( ) ;
147- let receipt =
148- EIP712SignedMessage :: new ( Receipt :: new ( allocation_ids[ 0 ] , 42 ) . unwrap ( ) , & keys. 0 )
149- . await
150- . unwrap ( ) ;
166+ let receipt = EIP712SignedMessage :: new (
167+ & domain_separator,
168+ Receipt :: new ( allocation_ids[ 0 ] , 42 ) . unwrap ( ) ,
169+ & keys. 0 ,
170+ )
171+ . await
172+ . unwrap ( ) ;
151173 receipts. push ( receipt. clone ( ) ) ;
152174 receipts. push ( receipt) ;
153175
@@ -160,18 +182,27 @@ mod tests {
160182 async fn check_signatures_unique_ok (
161183 keys : ( LocalWallet , Address ) ,
162184 allocation_ids : Vec < Address > ,
185+ domain_separator : Eip712Domain ,
163186 ) {
164187 // Create 2 different receipts
165188 let mut receipts = Vec :: new ( ) ;
166189 receipts. push (
167- EIP712SignedMessage :: new ( Receipt :: new ( allocation_ids[ 0 ] , 42 ) . unwrap ( ) , & keys. 0 )
168- . await
169- . unwrap ( ) ,
190+ EIP712SignedMessage :: new (
191+ & domain_separator,
192+ Receipt :: new ( allocation_ids[ 0 ] , 42 ) . unwrap ( ) ,
193+ & keys. 0 ,
194+ )
195+ . await
196+ . unwrap ( ) ,
170197 ) ;
171198 receipts. push (
172- EIP712SignedMessage :: new ( Receipt :: new ( allocation_ids[ 0 ] , 43 ) . unwrap ( ) , & keys. 0 )
173- . await
174- . unwrap ( ) ,
199+ EIP712SignedMessage :: new (
200+ & domain_separator,
201+ Receipt :: new ( allocation_ids[ 0 ] , 43 ) . unwrap ( ) ,
202+ & keys. 0 ,
203+ )
204+ . await
205+ . unwrap ( ) ,
175206 ) ;
176207
177208 let res = aggregator:: check_signatures_unique ( & receipts) ;
@@ -181,13 +212,18 @@ mod tests {
181212 #[ rstest]
182213 #[ tokio:: test]
183214 /// Test that a receipt with a timestamp greater then the rav timestamp passes
184- async fn check_receipt_timestamps ( keys : ( LocalWallet , Address ) , allocation_ids : Vec < Address > ) {
215+ async fn check_receipt_timestamps (
216+ keys : ( LocalWallet , Address ) ,
217+ allocation_ids : Vec < Address > ,
218+ domain_separator : Eip712Domain ,
219+ ) {
185220 // Create receipts with consecutive timestamps
186221 let receipt_timestamp_range = 10 ..20 ;
187222 let mut receipts = Vec :: new ( ) ;
188223 for i in receipt_timestamp_range. clone ( ) {
189224 receipts. push (
190225 EIP712SignedMessage :: new (
226+ & domain_separator,
191227 Receipt {
192228 allocation_id : allocation_ids[ 0 ] ,
193229 timestamp_ns : i,
@@ -203,6 +239,7 @@ mod tests {
203239
204240 // Create rav with max_timestamp below the receipts timestamps
205241 let rav = EIP712SignedMessage :: new (
242+ & domain_separator,
206243 tap_core:: receipt_aggregate_voucher:: ReceiptAggregateVoucher {
207244 allocation_id : allocation_ids[ 0 ] ,
208245 timestamp_ns : receipt_timestamp_range. clone ( ) . min ( ) . unwrap ( ) - 1 ,
@@ -217,6 +254,7 @@ mod tests {
217254 // Create rav with max_timestamp equal to the lowest receipt timestamp
218255 // Aggregation should fail
219256 let rav = EIP712SignedMessage :: new (
257+ & domain_separator,
220258 tap_core:: receipt_aggregate_voucher:: ReceiptAggregateVoucher {
221259 allocation_id : allocation_ids[ 0 ] ,
222260 timestamp_ns : receipt_timestamp_range. clone ( ) . min ( ) . unwrap ( ) ,
@@ -231,6 +269,7 @@ mod tests {
231269 // Create rav with max_timestamp above highest receipt timestamp
232270 // Aggregation should fail
233271 let rav = EIP712SignedMessage :: new (
272+ & domain_separator,
234273 tap_core:: receipt_aggregate_voucher:: ReceiptAggregateVoucher {
235274 allocation_id : allocation_ids[ 0 ] ,
236275 timestamp_ns : receipt_timestamp_range. clone ( ) . max ( ) . unwrap ( ) + 1 ,
@@ -247,22 +286,38 @@ mod tests {
247286 #[ tokio:: test]
248287 /// Test check_allocation_id with 2 receipts that have the correct allocation id
249288 /// and 1 receipt that has the wrong allocation id
250- async fn check_allocation_id_fail ( keys : ( LocalWallet , Address ) , allocation_ids : Vec < Address > ) {
289+ async fn check_allocation_id_fail (
290+ keys : ( LocalWallet , Address ) ,
291+ allocation_ids : Vec < Address > ,
292+ domain_separator : Eip712Domain ,
293+ ) {
251294 let mut receipts = Vec :: new ( ) ;
252295 receipts. push (
253- EIP712SignedMessage :: new ( Receipt :: new ( allocation_ids[ 0 ] , 42 ) . unwrap ( ) , & keys. 0 )
254- . await
255- . unwrap ( ) ,
296+ EIP712SignedMessage :: new (
297+ & domain_separator,
298+ Receipt :: new ( allocation_ids[ 0 ] , 42 ) . unwrap ( ) ,
299+ & keys. 0 ,
300+ )
301+ . await
302+ . unwrap ( ) ,
256303 ) ;
257304 receipts. push (
258- EIP712SignedMessage :: new ( Receipt :: new ( allocation_ids[ 0 ] , 43 ) . unwrap ( ) , & keys. 0 )
259- . await
260- . unwrap ( ) ,
305+ EIP712SignedMessage :: new (
306+ & domain_separator,
307+ Receipt :: new ( allocation_ids[ 0 ] , 43 ) . unwrap ( ) ,
308+ & keys. 0 ,
309+ )
310+ . await
311+ . unwrap ( ) ,
261312 ) ;
262313 receipts. push (
263- EIP712SignedMessage :: new ( Receipt :: new ( allocation_ids[ 1 ] , 44 ) . unwrap ( ) , & keys. 0 )
264- . await
265- . unwrap ( ) ,
314+ EIP712SignedMessage :: new (
315+ & domain_separator,
316+ Receipt :: new ( allocation_ids[ 1 ] , 44 ) . unwrap ( ) ,
317+ & keys. 0 ,
318+ )
319+ . await
320+ . unwrap ( ) ,
266321 ) ;
267322
268323 let res = aggregator:: check_allocation_id ( & receipts, allocation_ids[ 0 ] ) ;
@@ -273,22 +328,38 @@ mod tests {
273328 #[ rstest]
274329 #[ tokio:: test]
275330 /// Test check_allocation_id with 3 receipts that have the correct allocation id
276- async fn check_allocation_id_ok ( keys : ( LocalWallet , Address ) , allocation_ids : Vec < Address > ) {
331+ async fn check_allocation_id_ok (
332+ keys : ( LocalWallet , Address ) ,
333+ allocation_ids : Vec < Address > ,
334+ domain_separator : Eip712Domain ,
335+ ) {
277336 let mut receipts = Vec :: new ( ) ;
278337 receipts. push (
279- EIP712SignedMessage :: new ( Receipt :: new ( allocation_ids[ 0 ] , 42 ) . unwrap ( ) , & keys. 0 )
280- . await
281- . unwrap ( ) ,
338+ EIP712SignedMessage :: new (
339+ & domain_separator,
340+ Receipt :: new ( allocation_ids[ 0 ] , 42 ) . unwrap ( ) ,
341+ & keys. 0 ,
342+ )
343+ . await
344+ . unwrap ( ) ,
282345 ) ;
283346 receipts. push (
284- EIP712SignedMessage :: new ( Receipt :: new ( allocation_ids[ 0 ] , 43 ) . unwrap ( ) , & keys. 0 )
285- . await
286- . unwrap ( ) ,
347+ EIP712SignedMessage :: new (
348+ & domain_separator,
349+ Receipt :: new ( allocation_ids[ 0 ] , 43 ) . unwrap ( ) ,
350+ & keys. 0 ,
351+ )
352+ . await
353+ . unwrap ( ) ,
287354 ) ;
288355 receipts. push (
289- EIP712SignedMessage :: new ( Receipt :: new ( allocation_ids[ 0 ] , 44 ) . unwrap ( ) , & keys. 0 )
290- . await
291- . unwrap ( ) ,
356+ EIP712SignedMessage :: new (
357+ & domain_separator,
358+ Receipt :: new ( allocation_ids[ 0 ] , 44 ) . unwrap ( ) ,
359+ & keys. 0 ,
360+ )
361+ . await
362+ . unwrap ( ) ,
292363 ) ;
293364
294365 let res = aggregator:: check_allocation_id ( & receipts, allocation_ids[ 0 ] ) ;
0 commit comments