11use std:: collections:: { BTreeMap , HashSet } ;
2+ use std:: sync:: Arc ;
23
4+ use async_trait:: async_trait;
35use blockifier:: execution:: contract_class:: { CompiledClassV1 , RunnableCompiledClass } ;
46use blockifier:: state:: state_api:: StateReader ;
57use blockifier:: state:: state_reader_and_contract_manager:: {
@@ -9,6 +11,7 @@ use blockifier::state::state_reader_and_contract_manager::{
911use cairo_lang_starknet_classes:: casm_contract_class:: CasmContractClass ;
1012use cairo_lang_utils:: bigint:: BigUintAsHex ;
1113use cairo_vm:: types:: relocatable:: MaybeRelocatable ;
14+ use futures:: future:: try_join_all;
1215use starknet_api:: core:: { ClassHash , CompiledClassHash } ;
1316use starknet_types_core:: felt:: Felt ;
1417
@@ -43,6 +46,37 @@ fn compiled_class_v1_to_casm(class: &CompiledClassV1) -> CasmContractClass {
4346 }
4447}
4548
49+ /// Fetch class from the state reader and contract manager.
50+ /// Returns error if the class is deprecated.
51+ fn fetch_class < S > (
52+ state_reader_and_contract_manager : Arc < StateReaderAndContractManager < S > > ,
53+ class_hash : ClassHash ,
54+ ) -> Result < ( CompiledClassHash , CasmContractClass ) , ClassesProviderError >
55+ where
56+ S : FetchCompiledClasses + Send + Sync + ' static ,
57+ {
58+ let compiled_class = state_reader_and_contract_manager. get_compiled_class ( class_hash) ?;
59+
60+ let compiled_class_hash = state_reader_and_contract_manager
61+ . get_compiled_class_hash_v2 ( class_hash, & compiled_class) ?;
62+
63+ match compiled_class {
64+ RunnableCompiledClass :: V0 ( _v0) => {
65+ Err ( ClassesProviderError :: DeprecatedContractError ( class_hash) )
66+ }
67+ RunnableCompiledClass :: V1 ( compiled_class_v1) => {
68+ let casm = compiled_class_v1_to_casm ( & compiled_class_v1) ;
69+ Ok ( ( compiled_class_hash, casm) )
70+ }
71+ #[ cfg( feature = "cairo_native" ) ]
72+ RunnableCompiledClass :: V1Native ( compiled_class_v1_native) => {
73+ let compiled_class_v1 = compiled_class_v1_native. casm ( ) ;
74+ let casm = compiled_class_v1_to_casm ( & compiled_class_v1) ;
75+ Ok ( ( compiled_class_hash, casm) )
76+ }
77+ }
78+ }
79+
4680/// The classes required for a Starknet OS run.
4781/// Matches the fields in `StarknetOsInput`.
4882pub struct ClassesInput {
@@ -51,53 +85,46 @@ pub struct ClassesInput {
5185 pub compiled_classes : BTreeMap < CompiledClassHash , CasmContractClass > ,
5286}
5387
88+ #[ async_trait]
5489pub trait ClassesProvider {
5590 /// Fetches all classes required for the OS run based on the executed class hashes.
56- fn get_classes (
91+ /// This default implementation parallelizes fetching by spawning blocking tasks.
92+ async fn get_classes (
93+ & self ,
94+ executed_class_hashes : & HashSet < ClassHash > ,
95+ ) -> Result < ClassesInput , ClassesProviderError > ;
96+ }
97+
98+ #[ async_trait]
99+ impl < S > ClassesProvider for Arc < StateReaderAndContractManager < S > >
100+ where
101+ S : FetchCompiledClasses + Send + Sync + ' static ,
102+ {
103+ async fn get_classes (
57104 & self ,
58105 executed_class_hashes : & HashSet < ClassHash > ,
59106 ) -> Result < ClassesInput , ClassesProviderError > {
60- let mut compiled_classes = BTreeMap :: new ( ) ;
107+ // clonning the arc to create new refference with static lifetime.
108+ let shared_contract_class_manager = self . clone ( ) ;
61109
62- // TODO(Aviv): Parallelize the fetching of classes.
63- for & class_hash in executed_class_hashes {
64- let ( compiled_class_hash, casm) = self . fetch_class ( class_hash) ?;
65- compiled_classes. insert ( compiled_class_hash, casm) ;
66- }
67- Ok ( ClassesInput { compiled_classes } )
68- }
110+ // Creating tasks to fetch classes in parallel.
111+ let tasks = executed_class_hashes. iter ( ) . map ( |& class_hash| {
112+ let manager = shared_contract_class_manager. clone ( ) ;
69113
70- /// Fetches class by class hash.
71- fn fetch_class (
72- & self ,
73- class_hash : ClassHash ,
74- ) -> Result < ( CompiledClassHash , CasmContractClass ) , ClassesProviderError > ;
75- }
114+ tokio:: task:: spawn_blocking ( move || fetch_class ( manager, class_hash) )
115+ } ) ;
76116
77- impl < S : FetchCompiledClasses > ClassesProvider for StateReaderAndContractManager < S > {
78- /// Fetch class from the state reader and contract manager.
79- /// Returns error if the class is deprecated.
80- fn fetch_class (
81- & self ,
82- class_hash : ClassHash ,
83- ) -> Result < ( CompiledClassHash , CasmContractClass ) , ClassesProviderError > {
84- let compiled_class = self . get_compiled_class ( class_hash) ?;
85- // TODO(Aviv): Make sure that the state reader is not returning dummy compiled class hash.
86- let compiled_class_hash = self . get_compiled_class_hash_v2 ( class_hash, & compiled_class) ?;
87- match compiled_class {
88- RunnableCompiledClass :: V0 ( _v0) => {
89- Err ( ClassesProviderError :: DeprecatedContractError ( class_hash) )
90- }
91- RunnableCompiledClass :: V1 ( compiled_class_v1) => {
92- let casm = compiled_class_v1_to_casm ( & compiled_class_v1) ;
93- Ok ( ( compiled_class_hash, casm) )
94- }
95- #[ cfg( feature = "cairo_native" ) ]
96- RunnableCompiledClass :: V1Native ( compiled_class_v1_native) => {
97- let compiled_class_v1 = compiled_class_v1_native. casm ( ) ;
98- let casm = compiled_class_v1_to_casm ( & compiled_class_v1) ;
99- Ok ( ( compiled_class_hash, casm) )
100- }
101- }
117+ // Fetching classes in parallel.
118+ let results = try_join_all ( tasks)
119+ . await
120+ . map_err ( |e| ClassesProviderError :: GetClassesError ( format ! ( "Task join error: {e}" ) ) ) ?;
121+
122+ // Collecting results into a BTreeMap.
123+ // results is Vec<Result<(CompiledClassHash, CasmContractClass), ClassesProviderError>>
124+ let compiled_classes = results
125+ . into_iter ( )
126+ . collect :: < Result < BTreeMap < CompiledClassHash , CasmContractClass > , ClassesProviderError > > ( ) ?;
127+
128+ Ok ( ClassesInput { compiled_classes } )
102129 }
103130}
0 commit comments