11use std:: collections:: HashMap ;
22
33use futures:: FutureExt ;
4+ use mockall:: predicate:: eq;
45use papyrus_common:: pending_classes:: ApiContractClass ;
56use papyrus_protobuf:: sync:: {
67 BlockHashOrNumber ,
@@ -16,14 +17,14 @@ use papyrus_test_utils::{get_rng, GetTestInstance};
1617use rand:: { Rng , RngCore } ;
1718use rand_chacha:: ChaCha8Rng ;
1819use starknet_api:: block:: BlockNumber ;
19- use starknet_api:: contract_class:: ContractClass ;
2020use starknet_api:: core:: { ClassHash , CompiledClassHash , EntryPointSelector } ;
2121use starknet_api:: deprecated_contract_class:: {
2222 ContractClass as DeprecatedContractClass ,
2323 EntryPointOffset ,
2424 EntryPointV0 ,
2525} ;
2626use starknet_api:: state:: SierraContractClass ;
27+ use starknet_class_manager_types:: MockClassManagerClient ;
2728
2829use super :: test_utils:: {
2930 random_header,
@@ -39,6 +40,8 @@ use super::test_utils::{
3940async fn class_basic_flow ( ) {
4041 let mut rng = get_rng ( ) ;
4142
43+ let mut class_manager_client = MockClassManagerClient :: new ( ) ;
44+
4245 let state_diffs_and_classes_of_blocks = [
4346 vec ! [
4447 create_random_state_diff_chunk_with_class( & mut rng) ,
@@ -51,6 +54,30 @@ async fn class_basic_flow() {
5154 ] ,
5255 ] ;
5356
57+ // Fill class manager client with expectations.
58+ for state_diffs_and_classes in & state_diffs_and_classes_of_blocks {
59+ for ( state_diff, class) in state_diffs_and_classes {
60+ let class_hash = state_diff. get_class_hash ( ) ;
61+ match class {
62+ ApiContractClass :: ContractClass ( class) => {
63+ let compiled_class_hash = state_diff. get_compiled_class_hash ( ) ;
64+ class_manager_client
65+ . expect_add_class ( )
66+ . times ( 1 )
67+ . with ( eq ( class_hash) , eq ( class. clone ( ) ) )
68+ . return_once ( move |_, _| Ok ( compiled_class_hash) ) ;
69+ }
70+ ApiContractClass :: DeprecatedContractClass ( class) => {
71+ class_manager_client
72+ . expect_add_deprecated_class ( )
73+ . times ( 1 )
74+ . with ( eq ( class_hash) , eq ( class. clone ( ) ) )
75+ . return_once ( |_, _| Ok ( ( ) ) ) ;
76+ }
77+ }
78+ }
79+ }
80+
5481 let mut actions = vec ! [
5582 Action :: RunP2pSync ,
5683 // We already validate the header query content in other tests.
@@ -99,7 +126,7 @@ async fn class_basic_flow() {
99126 let class_hash = state_diff. get_class_hash ( ) ;
100127
101128 // Check that before the last class was sent, the classes aren't written.
102- actions. push ( Action :: CheckStorage ( Box :: new ( move |( reader, _ ) | {
129+ actions. push ( Action :: CheckStorage ( Box :: new ( move |reader| {
103130 async move {
104131 assert_eq ! (
105132 u64 :: try_from( i) . unwrap( ) ,
@@ -111,7 +138,7 @@ async fn class_basic_flow() {
111138 actions. push ( Action :: SendClass ( DataOrFin ( Some ( ( class. clone ( ) , class_hash) ) ) ) ) ;
112139 }
113140 // Check that a block's classes are written before the entire query finished.
114- actions. push ( Action :: CheckStorage ( Box :: new ( move |( reader, class_manager_client ) | {
141+ actions. push ( Action :: CheckStorage ( Box :: new ( move |reader| {
115142 async move {
116143 let block_number = BlockNumber ( i. try_into ( ) . unwrap ( ) ) ;
117144 wait_for_marker (
@@ -122,22 +149,6 @@ async fn class_basic_flow() {
122149 TIMEOUT_FOR_TEST ,
123150 )
124151 . await ;
125-
126- for ( state_diff, expected_class) in state_diffs_and_classes {
127- let class_hash = state_diff. get_class_hash ( ) ;
128- match expected_class {
129- ApiContractClass :: ContractClass ( expected_class) => {
130- let actual_class =
131- class_manager_client. get_sierra ( class_hash) . await . unwrap ( ) ;
132- assert_eq ! ( actual_class, expected_class. clone( ) ) ;
133- }
134- ApiContractClass :: DeprecatedContractClass ( expected_class) => {
135- let actual_class =
136- class_manager_client. get_executable ( class_hash) . await . unwrap ( ) ;
137- assert_eq ! ( actual_class, ContractClass :: V0 ( expected_class. clone( ) ) ) ;
138- }
139- }
140- }
141152 }
142153 . boxed ( )
143154 } ) ) ) ;
@@ -149,6 +160,7 @@ async fn class_basic_flow() {
149160 ( DataType :: StateDiff , len. try_into ( ) . unwrap ( ) ) ,
150161 ( DataType :: Class , len. try_into ( ) . unwrap ( ) ) ,
151162 ] ) ,
163+ Some ( class_manager_client) ,
152164 actions,
153165 )
154166 . await ;
@@ -158,6 +170,7 @@ async fn class_basic_flow() {
158170// we need to define this trait because StateDiffChunk is defined in an other crate.
159171trait GetClassHash {
160172 fn get_class_hash ( & self ) -> ClassHash ;
173+ fn get_compiled_class_hash ( & self ) -> CompiledClassHash ;
161174}
162175
163176impl GetClassHash for StateDiffChunk {
@@ -170,6 +183,13 @@ impl GetClassHash for StateDiffChunk {
170183 _ => unreachable ! ( ) ,
171184 }
172185 }
186+
187+ fn get_compiled_class_hash ( & self ) -> CompiledClassHash {
188+ match self {
189+ StateDiffChunk :: DeclaredClass ( declared_class) => declared_class. compiled_class_hash ,
190+ _ => unreachable ! ( ) ,
191+ }
192+ }
173193}
174194
175195fn create_random_state_diff_chunk_with_class (
@@ -325,6 +345,7 @@ async fn validate_class_sync_fails(
325345 ( DataType :: StateDiff , header_state_diff_lengths. len ( ) . try_into ( ) . unwrap ( ) ) ,
326346 ( DataType :: Class , header_state_diff_lengths. len ( ) . try_into ( ) . unwrap ( ) ) ,
327347 ] ) ,
348+ None ,
328349 actions,
329350 )
330351 . await ;
0 commit comments