Skip to content

Commit 790ef0d

Browse files
starknet_os_runner: fetch classes concurrent
1 parent 250272f commit 790ef0d

File tree

6 files changed

+94
-58
lines changed

6 files changed

+94
-58
lines changed

Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/starknet_os_runner/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,13 @@ description = "Runs transactions through the Starknet OS and returns Cairo PIE a
1010
cairo_native = ["blockifier/cairo_native"]
1111

1212
[dependencies]
13+
async-trait.workspace = true
1314
blockifier.workspace = true
1415
blockifier_reexecution.workspace = true
1516
cairo-lang-starknet-classes.workspace = true
1617
cairo-lang-utils.workspace = true
1718
cairo-vm.workspace = true
19+
futures.workspace = true
1820
indexmap.workspace = true
1921
shared_execution_objects.workspace = true
2022
starknet-rust.workspace = true
Lines changed: 67 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
use std::collections::{BTreeMap, HashSet};
2+
use std::sync::Arc;
23

4+
use async_trait::async_trait;
35
use blockifier::execution::contract_class::{CompiledClassV1, RunnableCompiledClass};
46
use blockifier::state::state_api::StateReader;
57
use blockifier::state::state_reader_and_contract_manager::{
@@ -9,6 +11,7 @@ use blockifier::state::state_reader_and_contract_manager::{
911
use cairo_lang_starknet_classes::casm_contract_class::CasmContractClass;
1012
use cairo_lang_utils::bigint::BigUintAsHex;
1113
use cairo_vm::types::relocatable::MaybeRelocatable;
14+
use futures::future::try_join_all;
1215
use starknet_api::core::{ClassHash, CompiledClassHash};
1316
use starknet_types_core::felt::Felt;
1417

@@ -43,6 +46,37 @@ fn compiled_class_v1_to_casm(class: &CompiledClassV1) -> CasmContractClass {
4346
}
4447
}
4548

49+
/// Fetch class from the state reader and contract manager.
50+
/// Returns error if the class is deprecated.
51+
fn fetch_class<S>(
52+
state_reader_and_contract_manager: Arc<StateReaderAndContractManager<S>>,
53+
class_hash: ClassHash,
54+
) -> Result<(CompiledClassHash, CasmContractClass), ClassesProviderError>
55+
where
56+
S: FetchCompiledClasses + Send + Sync + 'static,
57+
{
58+
let compiled_class = state_reader_and_contract_manager.get_compiled_class(class_hash)?;
59+
60+
let compiled_class_hash = state_reader_and_contract_manager
61+
.get_compiled_class_hash_v2(class_hash, &compiled_class)?;
62+
63+
match compiled_class {
64+
RunnableCompiledClass::V0(_v0) => {
65+
Err(ClassesProviderError::DeprecatedContractError(class_hash))
66+
}
67+
RunnableCompiledClass::V1(compiled_class_v1) => {
68+
let casm = compiled_class_v1_to_casm(&compiled_class_v1);
69+
Ok((compiled_class_hash, casm))
70+
}
71+
#[cfg(feature = "cairo_native")]
72+
RunnableCompiledClass::V1Native(compiled_class_v1_native) => {
73+
let compiled_class_v1 = compiled_class_v1_native.casm();
74+
let casm = compiled_class_v1_to_casm(&compiled_class_v1);
75+
Ok((compiled_class_hash, casm))
76+
}
77+
}
78+
}
79+
4680
/// The classes required for a Starknet OS run.
4781
/// Matches the fields in `StarknetOsInput`.
4882
pub struct ClassesInput {
@@ -51,53 +85,46 @@ pub struct ClassesInput {
5185
pub compiled_classes: BTreeMap<CompiledClassHash, CasmContractClass>,
5286
}
5387

88+
#[async_trait]
5489
pub trait ClassesProvider {
5590
/// Fetches all classes required for the OS run based on the executed class hashes.
56-
fn get_classes(
91+
/// This default implementation parallelizes fetching by spawning blocking tasks.
92+
async fn get_classes(
93+
&self,
94+
executed_class_hashes: &HashSet<ClassHash>,
95+
) -> Result<ClassesInput, ClassesProviderError>;
96+
}
97+
98+
#[async_trait]
99+
impl<S> ClassesProvider for Arc<StateReaderAndContractManager<S>>
100+
where
101+
S: FetchCompiledClasses + Send + Sync + 'static,
102+
{
103+
async fn get_classes(
57104
&self,
58105
executed_class_hashes: &HashSet<ClassHash>,
59106
) -> Result<ClassesInput, ClassesProviderError> {
60-
let mut compiled_classes = BTreeMap::new();
107+
// clonning the arc to create new refference with static lifetime.
108+
let shared_contract_class_manager = self.clone();
61109

62-
// TODO(Aviv): Parallelize the fetching of classes.
63-
for &class_hash in executed_class_hashes {
64-
let (compiled_class_hash, casm) = self.fetch_class(class_hash)?;
65-
compiled_classes.insert(compiled_class_hash, casm);
66-
}
67-
Ok(ClassesInput { compiled_classes })
68-
}
110+
// Creating tasks to fetch classes in parallel.
111+
let tasks = executed_class_hashes.iter().map(|&class_hash| {
112+
let manager = shared_contract_class_manager.clone();
69113

70-
/// Fetches class by class hash.
71-
fn fetch_class(
72-
&self,
73-
class_hash: ClassHash,
74-
) -> Result<(CompiledClassHash, CasmContractClass), ClassesProviderError>;
75-
}
114+
tokio::task::spawn_blocking(move || fetch_class(manager, class_hash))
115+
});
76116

77-
impl<S: FetchCompiledClasses> ClassesProvider for StateReaderAndContractManager<S> {
78-
/// Fetch class from the state reader and contract manager.
79-
/// Returns error if the class is deprecated.
80-
fn fetch_class(
81-
&self,
82-
class_hash: ClassHash,
83-
) -> Result<(CompiledClassHash, CasmContractClass), ClassesProviderError> {
84-
let compiled_class = self.get_compiled_class(class_hash)?;
85-
// TODO(Aviv): Make sure that the state reader is not returning dummy compiled class hash.
86-
let compiled_class_hash = self.get_compiled_class_hash_v2(class_hash, &compiled_class)?;
87-
match compiled_class {
88-
RunnableCompiledClass::V0(_v0) => {
89-
Err(ClassesProviderError::DeprecatedContractError(class_hash))
90-
}
91-
RunnableCompiledClass::V1(compiled_class_v1) => {
92-
let casm = compiled_class_v1_to_casm(&compiled_class_v1);
93-
Ok((compiled_class_hash, casm))
94-
}
95-
#[cfg(feature = "cairo_native")]
96-
RunnableCompiledClass::V1Native(compiled_class_v1_native) => {
97-
let compiled_class_v1 = compiled_class_v1_native.casm();
98-
let casm = compiled_class_v1_to_casm(&compiled_class_v1);
99-
Ok((compiled_class_hash, casm))
100-
}
101-
}
117+
// Fetching classes in parallel.
118+
let results = try_join_all(tasks)
119+
.await
120+
.map_err(|e| ClassesProviderError::GetClassesError(format!("Task join error: {e}")))?;
121+
122+
// Collecting results into a BTreeMap.
123+
// results is Vec<Result<(CompiledClassHash, CasmContractClass), ClassesProviderError>>
124+
let compiled_classes = results
125+
.into_iter()
126+
.collect::<Result<BTreeMap<CompiledClassHash, CasmContractClass>, ClassesProviderError>>()?;
127+
128+
Ok(ClassesInput { compiled_classes })
102129
}
103130
}

crates/starknet_os_runner/src/runner.rs

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,8 @@ impl From<VirtualOsBlockInput> for StarknetOsInput {
7979

8080
pub struct Runner<C, S, V>
8181
where
82-
C: ClassesProvider,
83-
S: StorageProofProvider,
82+
C: ClassesProvider + Sync,
83+
S: StorageProofProvider + Sync,
8484
V: VirtualBlockExecutor,
8585
{
8686
pub classes_provider: C,
@@ -90,8 +90,8 @@ where
9090

9191
impl<C, S, V> Runner<C, S, V>
9292
where
93-
C: ClassesProvider,
94-
S: StorageProofProvider,
93+
C: ClassesProvider + Sync,
94+
S: StorageProofProvider + Sync,
9595
V: VirtualBlockExecutor,
9696
{
9797
pub fn new(classes_provider: C, storage_proofs_provider: S, virtual_block_executor: V) -> Self {
@@ -100,7 +100,7 @@ where
100100

101101
/// Creates the OS hints required to run the given transactions virtually
102102
/// on top of the given block number.
103-
pub fn create_os_hints(
103+
pub async fn create_os_hints(
104104
&self,
105105
block_number: BlockNumber,
106106
contract_class_manager: ContractClassManager,
@@ -120,12 +120,13 @@ where
120120
strk_fee_token_address: chain_info.fee_token_addresses.strk_fee_token_address,
121121
};
122122

123-
// Fetch classes.
124-
let classes = self.classes_provider.get_classes(&execution_data.executed_class_hashes)?;
125-
126-
// Fetch storage proofs.
127-
let storage_proofs =
128-
self.storage_proofs_provider.get_storage_proofs(block_number, &execution_data)?;
123+
// Fetch classes and storage proofs in parallel.
124+
let (classes, storage_proofs) = tokio::join!(
125+
self.classes_provider.get_classes(&execution_data.executed_class_hashes),
126+
self.storage_proofs_provider.get_storage_proofs(block_number, &execution_data)
127+
);
128+
let classes = classes?;
129+
let storage_proofs = storage_proofs?;
129130

130131
// Convert execution outputs to CentralTransactionExecutionInfo.
131132
let tx_execution_infos =
@@ -177,13 +178,13 @@ where
177178
/// 4. Runs the OS in stateless mode (all state pre-loaded in input)
178179
///
179180
/// Returns the OS output containing the Cairo PIE and execution metrics.
180-
pub fn run_os(
181+
pub async fn run_os(
181182
&self,
182183
block_number: BlockNumber,
183184
contract_class_manager: ContractClassManager,
184185
txs: Vec<(InvokeTransaction, TransactionHash)>,
185186
) -> Result<StarknetOsRunnerOutput, RunnerError> {
186-
let os_hints = self.create_os_hints(block_number, contract_class_manager, txs)?;
187+
let os_hints = self.create_os_hints(block_number, contract_class_manager, txs).await?;
187188
let output = run_os_stateless(DEFAULT_OS_LAYOUT, os_hints)?;
188189
Ok(output)
189190
}

crates/starknet_os_runner/src/storage_proofs.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use std::collections::HashMap;
22

3+
use async_trait::async_trait;
34
use blockifier::state::cached_state::StateMaps;
45
use starknet_api::block::BlockNumber;
56
use starknet_api::core::{ClassHash, ContractAddress, Nonce};
@@ -32,8 +33,9 @@ use crate::virtual_block_executor::VirtualBlockExecutionData;
3233
/// The returned `StorageProofs` contains:
3334
/// - `proof_state`: The ambient state values (nonces, class hashes) discovered in the proof.
3435
/// - `commitment_infos`: The Patricia Merkle proof nodes for contracts, classes, and storage tries.
36+
#[async_trait]
3537
pub trait StorageProofProvider {
36-
fn get_storage_proofs(
38+
async fn get_storage_proofs(
3739
&self,
3840
block_number: BlockNumber,
3941
execution_data: &VirtualBlockExecutionData,
@@ -288,16 +290,16 @@ impl RpcStorageProofsProvider {
288290
}
289291
}
290292

293+
#[async_trait]
291294
impl StorageProofProvider for RpcStorageProofsProvider {
292-
fn get_storage_proofs(
295+
async fn get_storage_proofs(
293296
&self,
294297
block_number: BlockNumber,
295298
execution_data: &VirtualBlockExecutionData,
296299
) -> Result<StorageProofs, ProofProviderError> {
297300
let query = Self::prepare_query(execution_data);
298301

299-
let runtime = tokio::runtime::Runtime::new().expect("Failed to create tokio runtime");
300-
let rpc_proof = runtime.block_on(self.fetch_proofs(block_number, &query))?;
302+
let rpc_proof = self.fetch_proofs(block_number, &query).await?;
301303

302304
Self::to_storage_proofs(&rpc_proof, &query)
303305
}

crates/starknet_os_runner/src/storage_proofs_test.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,9 @@ fn test_get_storage_proofs_from_rpc(
5050
prev_base_block_hash: BlockHash::default(),
5151
};
5252

53-
let result = rpc_provider.get_storage_proofs(BlockNumber(block_number), &execution_data);
53+
let result = runtime.block_on(async {
54+
rpc_provider.get_storage_proofs(BlockNumber(block_number), &execution_data).await
55+
});
5456
assert!(result.is_ok(), "Failed to get storage proofs: {:?}", result.err());
5557

5658
let storage_proofs = result.unwrap();

0 commit comments

Comments
 (0)