diff --git a/crates/starknet_os_flow_tests/src/test_manager.rs b/crates/starknet_os_flow_tests/src/test_manager.rs index 029faf4c4a4..d0c1b40cac9 100644 --- a/crates/starknet_os_flow_tests/src/test_manager.rs +++ b/crates/starknet_os_flow_tests/src/test_manager.rs @@ -36,7 +36,12 @@ use starknet_api::test_utils::invoke::{invoke_tx, InvokeTxArgs}; use starknet_api::test_utils::{NonceManager, CHAIN_ID_FOR_TESTS}; use starknet_api::transaction::fields::{Calldata, Tip}; use starknet_api::transaction::MessageToL1; -use starknet_committer::block_committer::input::{IsSubset, StarknetStorageKey, StateDiff}; +use starknet_committer::block_committer::input::{ + IsSubset, + StarknetStorageKey, + StarknetStorageValue, + StateDiff, +}; use starknet_committer::db::facts_db::FactsDb; use starknet_os::hints::hint_implementation::state_diff_encryption::utils::compute_public_keys; use starknet_os::io::os_input::{ @@ -181,6 +186,27 @@ impl OsTestOutput { ); } + #[track_caller] + pub(crate) fn assert_storage_diff_eq( + &self, + contract_address: ContractAddress, + storage_updates: HashMap, + ) { + assert_eq!( + self.decompressed_state_diff + .storage_updates + .get(&contract_address) + .unwrap_or(&HashMap::default()), + &storage_updates + .into_iter() + .map(|(key, value)| ( + StarknetStorageKey(key.try_into().unwrap()), + StarknetStorageValue(value) + )) + .collect::>() + ); + } + fn perform_global_validations(&self) { // TODO(Dori): Implement global validations for the OS test output. diff --git a/crates/starknet_os_flow_tests/src/tests.rs b/crates/starknet_os_flow_tests/src/tests.rs index 1b5df598c2a..a8b64e85322 100644 --- a/crates/starknet_os_flow_tests/src/tests.rs +++ b/crates/starknet_os_flow_tests/src/tests.rs @@ -3,6 +3,7 @@ use std::sync::{Arc, LazyLock}; use blockifier::abi::constants::STORED_BLOCK_HASH_BUFFER; use blockifier::blockifier_versioned_constants::VersionedConstants; +use blockifier::test_utils::contracts::FeatureContractTrait; use blockifier::test_utils::dict_state_reader::DictStateReader; use blockifier::test_utils::ALIAS_CONTRACT_ADDRESS; use blockifier::transaction::test_utils::ExpectedExecutionInfo; @@ -15,7 +16,7 @@ use rstest::rstest; use starknet_api::abi::abi_utils::{get_storage_var_address, selector_from_name}; use starknet_api::block::{BlockInfo, BlockNumber, BlockTimestamp, GasPrice}; use starknet_api::contract_class::compiled_class_hash::{HashVersion, HashableCompiledClass}; -use starknet_api::contract_class::{ClassInfo, ContractClass}; +use starknet_api::contract_class::{ClassInfo, ContractClass, SierraVersion}; use starknet_api::core::{ calculate_contract_address, ClassHash, @@ -2221,3 +2222,106 @@ async fn test_reverted_call() { {expected_changed_addresses:#?}" ); } + +/// Tests that the OS correctly handles calls between Cairo 1.0 contracts that count resources by +/// cairo steps and sierra gas. +#[rstest] +#[tokio::test] +async fn test_resources_type() { + let test_contract = FeatureContract::TestContract(CairoVersion::Cairo1(RunnableCairo1::Casm)); + let (mut test_manager, [sierra_gas_contract_address]) = + TestManager::::new_with_default_initial_state([( + test_contract, + calldata![Felt::ZERO, Felt::ZERO], + )]) + .await; + + // Define an updated Cairo 1.0 contract by overriding the encoded sierra version. + let mut cairo_steps_contract_sierra = test_contract.get_sierra(); + let min_sierra_version = + VersionedConstants::latest_constants().min_sierra_version_for_sierra_gas.clone(); + let (old_major, old_minor, old_patch) = (1, 5, 0); + let old_sierra_version = SierraVersion::new(old_major, old_minor, old_patch); + assert!(old_sierra_version < min_sierra_version); + + // Override version. + assert!(cairo_steps_contract_sierra.get_sierra_version().unwrap() >= min_sierra_version); + cairo_steps_contract_sierra.sierra_program[0] = Felt::from(old_major); + cairo_steps_contract_sierra.sierra_program[1] = Felt::from(old_minor); + cairo_steps_contract_sierra.sierra_program[2] = Felt::from(old_patch); + assert_eq!(cairo_steps_contract_sierra.get_sierra_version().unwrap(), old_sierra_version); + + // Declare it. + let cairo_steps_class_hash = cairo_steps_contract_sierra.calculate_class_hash(); + let compiled_class_hash = test_contract.get_compiled_class_hash(&HashVersion::V2); + let declare_tx_args = declare_tx_args! { + sender_address: *FUNDED_ACCOUNT_ADDRESS, + class_hash: cairo_steps_class_hash, + compiled_class_hash, + resource_bounds: *NON_TRIVIAL_RESOURCE_BOUNDS, + nonce: test_manager.next_nonce(*FUNDED_ACCOUNT_ADDRESS), + }; + let account_declare_tx = declare_tx(declare_tx_args); + let class_info = match test_contract.get_class() { + ContractClass::V0(_) => panic!("Expected Cairo 1.0 contract"), + ContractClass::V1((contract_class, _sierra_version)) => ClassInfo { + contract_class: ContractClass::V1((contract_class, old_sierra_version.clone())), + sierra_program_length: cairo_steps_contract_sierra.sierra_program.len(), + abi_length: cairo_steps_contract_sierra.abi.len(), + sierra_version: old_sierra_version, + }, + }; + let tx = + DeclareTransaction::create(account_declare_tx, class_info, &CHAIN_ID_FOR_TESTS).unwrap(); + test_manager.add_cairo1_declare_tx(tx, &cairo_steps_contract_sierra); + + // Deploy it. + let (deploy_tx, cairo_steps_contract_address) = get_deploy_contract_tx_and_address_with_salt( + cairo_steps_class_hash, + calldata![Felt::ZERO, Felt::ZERO], + test_manager.next_nonce(*FUNDED_ACCOUNT_ADDRESS), + *NON_TRIVIAL_RESOURCE_BOUNDS, + ContractAddressSalt(Felt::ZERO), + ); + test_manager.add_invoke_tx(deploy_tx, None); + + // Test recursive calling. + let (key, value) = (Felt::from(123), Felt::from(45)); + let calldata_0 = vec![ + **cairo_steps_contract_address, + selector_from_name("test_call_contract").0, + Felt::from(5), // Outer calldata length. + **sierra_gas_contract_address, + selector_from_name("test_storage_write").0, + Felt::TWO, // Inner calldata length. + key, + value, + ]; + let calldata_1 = vec![ + **sierra_gas_contract_address, + selector_from_name("test_storage_write").0, + Felt::TWO, + key + Felt::ONE, + value + Felt::ONE, + ]; + let calldata = create_calldata( + sierra_gas_contract_address, + "test_call_two_contracts", + &[calldata_0, calldata_1].concat(), + ); + test_manager.add_funded_account_invoke(invoke_tx_args! { calldata }); + + // Run test and check storage updates. + let test_output = test_manager + .execute_test_with_default_block_contexts(&TestParameters { + use_kzg_da: true, + ..Default::default() + }) + .await; + + let expected_storage_updates = + HashMap::from([(key, value), (key + Felt::ONE, value + Felt::ONE)]); + test_output.perform_default_validations(); + test_output.assert_storage_diff_eq(cairo_steps_contract_address, HashMap::default()); + test_output.assert_storage_diff_eq(sierra_gas_contract_address, expected_storage_updates); +}