diff --git a/Cargo.lock b/Cargo.lock index 3f29c500bf1..05a927d5224 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1728,6 +1728,7 @@ dependencies = [ "itertools 0.12.1", "keccak", "log", + "mockall", "num-bigint 0.4.6", "num-integer", "num-rational 0.4.2", @@ -10698,6 +10699,7 @@ dependencies = [ "assert_matches", "async-trait", "blockifier", + "cairo-lang-starknet-classes", "chrono", "futures", "indexmap 2.7.0", diff --git a/crates/blockifier/Cargo.toml b/crates/blockifier/Cargo.toml index 673b0cb06db..298c562648f 100644 --- a/crates/blockifier/Cargo.toml +++ b/crates/blockifier/Cargo.toml @@ -33,6 +33,7 @@ cairo-native = { workspace = true, optional = true } cairo-vm.workspace = true derive_more.workspace = true indexmap.workspace = true +mockall.workspace = true itertools.workspace = true keccak.workspace = true log.workspace = true diff --git a/crates/blockifier/src/state/state_api.rs b/crates/blockifier/src/state/state_api.rs index 074d7835c56..32544db04d9 100644 --- a/crates/blockifier/src/state/state_api.rs +++ b/crates/blockifier/src/state/state_api.rs @@ -19,6 +19,7 @@ pub enum DataAvailabilityMode { /// /// The `self` argument is mutable for flexibility during reads (for example, caching reads), /// and to allow for the `State` trait below to also be considered a `StateReader`. +#[cfg_attr(any(test, feature = "testing"), mockall::automock)] pub trait StateReader { /// Returns the storage value under the given key in the given contract instance (represented by /// its address). diff --git a/crates/starknet_batcher/Cargo.toml b/crates/starknet_batcher/Cargo.toml index 6e03503a41d..afe2695f792 100644 --- a/crates/starknet_batcher/Cargo.toml +++ b/crates/starknet_batcher/Cargo.toml @@ -18,6 +18,7 @@ papyrus_config.workspace = true papyrus_state_reader.workspace = true papyrus_storage.workspace = true serde.workspace = true +starknet-types-core.workspace = true starknet_api.workspace = true starknet_batcher_types.workspace = true starknet_class_manager_types.workspace = true @@ -33,15 +34,18 @@ validator.workspace = true [dev-dependencies] assert_matches.workspace = true +cairo-lang-starknet-classes.workspace = true chrono = { workspace = true } futures.workspace = true mempool_test_utils.path = "../mempool_test_utils" metrics.workspace = true metrics-exporter-prometheus.workspace = true mockall.workspace = true +papyrus_storage = { path = "../papyrus_storage", features = ["testing"] } rstest.workspace = true starknet-types-core.workspace = true starknet_api = { path = "../starknet_api", features = ["testing"] } +starknet_class_manager_types = { path = "../starknet_class_manager_types", features = ["testing"] } starknet_infra_utils.path = "../starknet_infra_utils" starknet_l1_provider_types = { path = "../starknet_l1_provider_types", features = ["testing"] } starknet_mempool_types = { path = "../starknet_mempool_types", features = ["testing"] } diff --git a/crates/starknet_batcher/src/batcher.rs b/crates/starknet_batcher/src/batcher.rs index eba110a13ad..26a34ec8c91 100644 --- a/crates/starknet_batcher/src/batcher.rs +++ b/crates/starknet_batcher/src/batcher.rs @@ -646,6 +646,7 @@ pub fn create_batcher( contract_class_manager: ContractClassManager::start( config.contract_class_manager_config.clone(), ), + class_manager_client: class_manager_client.clone(), }); let storage_reader = Arc::new(storage_reader); let storage_writer = Box::new(storage_writer); diff --git a/crates/starknet_batcher/src/block_builder.rs b/crates/starknet_batcher/src/block_builder.rs index 89b89aadea1..7982d27f6f8 100644 --- a/crates/starknet_batcher/src/block_builder.rs +++ b/crates/starknet_batcher/src/block_builder.rs @@ -32,9 +32,11 @@ use starknet_api::execution_resources::GasAmount; use starknet_api::state::ThinStateDiff; use starknet_api::transaction::TransactionHash; use starknet_batcher_types::batcher_types::ProposalCommitment; +use starknet_class_manager_types::SharedClassManagerClient; use thiserror::Error; use tracing::{debug, error, info, trace}; +use crate::reader_with_class_manager::ReaderWithClassManager; use crate::transaction_executor::TransactionExecutorTrait; use crate::transaction_provider::{NextTxs, TransactionProvider, TransactionProviderError}; @@ -344,13 +346,14 @@ pub struct BlockBuilderFactory { pub block_builder_config: BlockBuilderConfig, pub storage_reader: StorageReader, pub contract_class_manager: ContractClassManager, + pub class_manager_client: SharedClassManagerClient, } impl BlockBuilderFactory { fn preprocess_and_create_transaction_executor( &self, block_metadata: BlockMetadata, - ) -> BlockBuilderResult> { + ) -> BlockBuilderResult>> { let height = block_metadata.block_info.block_number; let block_builder_config = self.block_builder_config.clone(); let versioned_constants = VersionedConstants::get_versioned_constants( @@ -363,11 +366,13 @@ impl BlockBuilderFactory { block_builder_config.bouncer_config, ); - let state_reader = PapyrusReader::new( + let papyrus_state_reader = PapyrusReader::new( self.storage_reader.clone(), height, self.contract_class_manager.clone(), ); + let state_reader = + ReaderWithClassManager::new(papyrus_state_reader, self.class_manager_client.clone()); let executor = TransactionExecutor::pre_process_and_create( state_reader, diff --git a/crates/starknet_batcher/src/lib.rs b/crates/starknet_batcher/src/lib.rs index 2d08e46e30c..ae7e7e99e4d 100644 --- a/crates/starknet_batcher/src/lib.rs +++ b/crates/starknet_batcher/src/lib.rs @@ -7,6 +7,9 @@ mod block_builder_test; pub mod communication; pub mod config; mod metrics; +mod reader_with_class_manager; +#[cfg(test)] +mod reader_with_class_manager_test; #[cfg(test)] mod test_utils; mod transaction_executor; diff --git a/crates/starknet_batcher/src/reader_with_class_manager.rs b/crates/starknet_batcher/src/reader_with_class_manager.rs new file mode 100644 index 00000000000..9cbe838376c --- /dev/null +++ b/crates/starknet_batcher/src/reader_with_class_manager.rs @@ -0,0 +1,60 @@ +use blockifier::execution::contract_class::RunnableCompiledClass; +use blockifier::state::errors::StateError; +use blockifier::state::state_api::{StateReader, StateResult}; +use futures::executor::block_on; +use starknet_api::contract_class::ContractClass; +use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, Nonce}; +use starknet_api::state::StorageKey; +use starknet_class_manager_types::SharedClassManagerClient; +use starknet_types_core::felt::Felt; + +pub struct ReaderWithClassManager { + state_reader: S, + class_manager_client: SharedClassManagerClient, +} + +impl ReaderWithClassManager { + pub fn new(state_reader: S, class_manager_client: SharedClassManagerClient) -> Self { + Self { state_reader, class_manager_client } + } +} + +impl StateReader for ReaderWithClassManager { + fn get_storage_at( + &self, + contract_address: ContractAddress, + key: StorageKey, + ) -> StateResult { + self.state_reader.get_storage_at(contract_address, key) + } + + fn get_nonce_at(&self, contract_address: ContractAddress) -> StateResult { + self.state_reader.get_nonce_at(contract_address) + } + + fn get_class_hash_at(&self, contract_address: ContractAddress) -> StateResult { + self.state_reader.get_class_hash_at(contract_address) + } + + fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult { + let contract_class = block_on(self.class_manager_client.get_executable(class_hash)) + .map_err(|e| StateError::StateReadError(e.to_string()))?; + + match contract_class { + // TODO(noamsp): Remove this once class manager component is implemented. + ContractClass::V0(ref inner) if inner == &Default::default() => { + self.state_reader.get_compiled_class(class_hash) + } + ContractClass::V1(casm_contract_class) => { + Ok(RunnableCompiledClass::V1(casm_contract_class.try_into()?)) + } + ContractClass::V0(deprecated_contract_class) => { + Ok(RunnableCompiledClass::V0(deprecated_contract_class.try_into()?)) + } + } + } + + fn get_compiled_class_hash(&self, class_hash: ClassHash) -> StateResult { + self.state_reader.get_compiled_class_hash(class_hash) + } +} diff --git a/crates/starknet_batcher/src/reader_with_class_manager_test.rs b/crates/starknet_batcher/src/reader_with_class_manager_test.rs new file mode 100644 index 00000000000..eeb73e39a2d --- /dev/null +++ b/crates/starknet_batcher/src/reader_with_class_manager_test.rs @@ -0,0 +1,48 @@ +use std::sync::Arc; + +use blockifier::execution::contract_class::RunnableCompiledClass; +use blockifier::state::state_api::{MockStateReader, StateReader}; +use cairo_lang_starknet_classes::casm_contract_class::CasmContractClass; +use mockall::predicate; +use starknet_api::class_hash; +use starknet_api::contract_class::{ContractClass, SierraVersion}; +use starknet_class_manager_types::MockClassManagerClient; + +use crate::reader_with_class_manager::ReaderWithClassManager; + +#[tokio::test] +async fn test_get_compiled_class() { + let mock_state_state_reader = MockStateReader::new(); + let mut mock_class_manager_client = MockClassManagerClient::new(); + let class_hash = class_hash!("0x2"); + let casm_contract_class = CasmContractClass { + compiler_version: "0.0.0".to_string(), + prime: Default::default(), + bytecode: Default::default(), + bytecode_segment_lengths: Default::default(), + hints: Default::default(), + pythonic_hints: Default::default(), + entry_points_by_type: Default::default(), + }; + let expected_result = casm_contract_class.clone(); + + mock_class_manager_client + .expect_get_executable() + .times(1) + .with(predicate::eq(class_hash)) + .returning(move |_| { + Ok(ContractClass::V1((casm_contract_class.clone(), SierraVersion::default()))) + }); + + let state_reader = + ReaderWithClassManager::new(mock_state_state_reader, Arc::new(mock_class_manager_client)); + + let result = state_reader.get_compiled_class(class_hash).unwrap(); + assert_eq!( + result, + RunnableCompiledClass::V1((expected_result, SierraVersion::default()).try_into().unwrap()) + ); +} + +// TODO(noamsp): Add tests for get_storage_at, get_nonce_at, get_class_hash_at, +// get_compiled_class_hash