1+ use std:: fmt:: Debug ;
2+ use std:: hash:: Hash ;
13use std:: sync:: Arc ;
24
35use apollo_batcher_config:: config:: { BatcherConfig , BlockBuilderConfig } ;
@@ -32,7 +34,7 @@ use apollo_storage::test_utils::get_test_storage;
3234use apollo_storage:: { StorageError , StorageReader , StorageWriter } ;
3335use assert_matches:: assert_matches;
3436use blockifier:: abi:: constants;
35- use indexmap:: IndexSet ;
37+ use indexmap:: { indexmap , IndexMap , IndexSet } ;
3638use metrics_exporter_prometheus:: PrometheusBuilder ;
3739use mockall:: predicate:: eq;
3840use rstest:: rstest;
@@ -46,6 +48,7 @@ use starknet_api::block::{
4648use starknet_api:: block_hash:: block_hash_calculator:: PartialBlockHashComponents ;
4749use starknet_api:: block_hash:: state_diff_hash:: calculate_state_diff_hash;
4850use starknet_api:: consensus_transaction:: InternalConsensusTransaction ;
51+ use starknet_api:: core:: { ClassHash , CompiledClassHash , Nonce } ;
4952use starknet_api:: state:: ThinStateDiff ;
5053use starknet_api:: test_utils:: CHAIN_ID_FOR_TESTS ;
5154use starknet_api:: transaction:: TransactionHash ;
@@ -108,6 +111,59 @@ use crate::test_utils::{
108111 STREAMING_CHUNK_SIZE ,
109112} ;
110113
114+ fn get_test_state_diff (
115+ mut keys_stream : impl Iterator < Item = u64 > ,
116+ mut values_stream : impl Iterator < Item = u64 > ,
117+ ) -> ThinStateDiff {
118+ ThinStateDiff {
119+ deployed_contracts : indexmap ! {
120+ ( keys_stream. next( ) . unwrap( ) ) . into( ) => ClassHash ( values_stream. next( ) . unwrap( ) . into( ) ) ,
121+ ( keys_stream. next( ) . unwrap( ) ) . into( ) => ClassHash ( values_stream. next( ) . unwrap( ) . into( ) ) ,
122+ } ,
123+ storage_diffs : indexmap ! {
124+ ( keys_stream. next( ) . unwrap( ) ) . into( ) => indexmap! {
125+ ( keys_stream. next( ) . unwrap( ) ) . into( ) => ( values_stream. next( ) . unwrap( ) ) . into( ) ,
126+ ( keys_stream. next( ) . unwrap( ) ) . into( ) => values_stream. next( ) . unwrap( ) . into( ) ,
127+ } ,
128+ } ,
129+ class_hash_to_compiled_class_hash : indexmap ! {
130+ ClassHash ( keys_stream. next( ) . unwrap( ) . into( ) ) =>
131+ CompiledClassHash ( values_stream. next( ) . unwrap( ) . into( ) ) ,
132+ ClassHash ( keys_stream. next( ) . unwrap( ) . into( ) ) =>
133+ CompiledClassHash ( values_stream. next( ) . unwrap( ) . into( ) ) ,
134+ } ,
135+ nonces : indexmap ! {
136+ ( keys_stream. next( ) . unwrap( ) ) . into( ) => Nonce ( values_stream. next( ) . unwrap( ) . into( ) ) ,
137+ ( keys_stream. next( ) . unwrap( ) ) . into( ) => Nonce ( values_stream. next( ) . unwrap( ) . into( ) ) ,
138+ } ,
139+ deprecated_declared_classes : vec ! [
140+ ClassHash ( keys_stream. next( ) . unwrap( ) . into( ) ) ,
141+ ClassHash ( keys_stream. next( ) . unwrap( ) . into( ) ) ,
142+ ] ,
143+ }
144+ }
145+
146+ /// The keys in each consecutive state diff are overlapping, for each map in the state diff.
147+ /// If in block A the keys are x, x+1, then in block A+1 the keys are x+1, x+2.
148+ fn get_overlapping_state_diffs ( n_state_diffs : u64 ) -> Vec < ThinStateDiff > {
149+ let mut state_diffs = Vec :: new ( ) ;
150+ for i in 0 ..n_state_diffs {
151+ state_diffs. push ( get_test_state_diff ( i.., ( i * 100 ) ..) ) ;
152+ }
153+ state_diffs
154+ }
155+
156+ fn write_state_diff ( batcher : & mut Batcher , height : BlockNumber , state_diff : & ThinStateDiff ) {
157+ batcher
158+ . storage_writer
159+ . commit_proposal (
160+ height,
161+ state_diff. clone ( ) ,
162+ StorageCommitmentBlockHash :: Partial ( PartialBlockHashComponents :: default ( ) ) ,
163+ )
164+ . expect ( "set_state_diff failed" ) ;
165+ }
166+
111167async fn proposal_commitment ( ) -> ProposalCommitment {
112168 BlockExecutionArtifacts :: create_for_testing ( ) . await . commitment ( )
113169}
@@ -1558,3 +1614,61 @@ async fn get_block_hash_error() {
15581614 let result = batcher. get_block_hash ( INITIAL_HEIGHT ) ;
15591615 assert_eq ! ( result, Err ( BatcherError :: InternalError ) ) ;
15601616}
1617+
1618+ /// For every key in the original map, validates that the reversed map values are identical to the
1619+ /// base map, or zero if the key is missing in the base map.
1620+ fn validate_is_reversed < K : Eq + Hash + Debug , V : Debug + Default + Eq + Hash > (
1621+ base : IndexMap < K , V > ,
1622+ original : IndexMap < K , V > ,
1623+ reversed : IndexMap < K , V > ,
1624+ ) {
1625+ assert_eq ! ( original. len( ) , reversed. len( ) ) ;
1626+ for key in original. keys ( ) {
1627+ assert_eq ! ( reversed. get( key) . unwrap( ) , base. get( key) . unwrap_or( & V :: default ( ) ) ) ;
1628+ }
1629+ }
1630+
1631+ #[ tokio:: test]
1632+ async fn test_reversed_state_diff ( ) {
1633+ let mut batcher =
1634+ create_batcher_with_real_storage ( MockDependenciesWithRealStorage :: default ( ) ) . await ;
1635+
1636+ let state_diffs = get_overlapping_state_diffs ( 2 ) ;
1637+
1638+ let mut height = BlockNumber ( 0 ) ;
1639+ let base_state_diff = state_diffs[ 0 ] . clone ( ) ;
1640+ write_state_diff ( & mut batcher, height, & base_state_diff) ;
1641+
1642+ height = height. unchecked_next ( ) ;
1643+ let original_state_diff = state_diffs[ 1 ] . clone ( ) ;
1644+ write_state_diff ( & mut batcher, height, & original_state_diff) ;
1645+
1646+ let reversed_state_diff = batcher. storage_reader . reversed_state_diff ( height) . unwrap ( ) ;
1647+
1648+ validate_is_reversed (
1649+ base_state_diff. deployed_contracts ,
1650+ original_state_diff. deployed_contracts ,
1651+ reversed_state_diff. deployed_contracts ,
1652+ ) ;
1653+ for ( contract_address, storage_diffs) in original_state_diff. storage_diffs {
1654+ validate_is_reversed (
1655+ base_state_diff
1656+ . storage_diffs
1657+ . get ( & contract_address)
1658+ . unwrap_or ( & IndexMap :: new ( ) )
1659+ . clone ( ) ,
1660+ storage_diffs,
1661+ reversed_state_diff. storage_diffs . get ( & contract_address) . unwrap ( ) . clone ( ) ,
1662+ ) ;
1663+ }
1664+ validate_is_reversed (
1665+ base_state_diff. class_hash_to_compiled_class_hash ,
1666+ original_state_diff. class_hash_to_compiled_class_hash . clone ( ) ,
1667+ reversed_state_diff. class_hash_to_compiled_class_hash ,
1668+ ) ;
1669+ validate_is_reversed (
1670+ base_state_diff. nonces ,
1671+ original_state_diff. nonces . clone ( ) ,
1672+ reversed_state_diff. nonces ,
1673+ ) ;
1674+ }
0 commit comments