Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 138 additions & 0 deletions crates/starknet_os_flow_tests/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2083,3 +2083,141 @@ async fn test_initial_sierra_gas() {

test_output.perform_default_validations();
}

#[rstest]
#[tokio::test]
async fn test_reverted_call() {
let test_contract = FeatureContract::TestContract(CairoVersion::Cairo1(RunnableCairo1::Casm));
let test_class_hash = get_class_hash_of_feature_contract(test_contract);
let test_contract2 = FeatureContract::TestContract2;
let empty_contract = FeatureContract::Empty(CairoVersion::Cairo0);
let (mut test_manager, [main_contract_address, test_contract2_address, empty_contract_address]) =
TestManager::<DictStateReader>::new_with_default_initial_state([
// The `my_storage_var` cell is initialized as the sum of the ctor args in the
// constructor. This cell is also used in revert tests, so it must be
// initialized to zero.
(test_contract, calldata![Felt::ZERO, Felt::ZERO]),
(test_contract2, calldata![]),
(empty_contract, calldata![]),
])
.await;

// Tests 1+2.

// Tell inner contract to panic.
let to_panic = true;
for (inner_selector, is_meta_tx) in
[("test_revert_helper", false), ("bad_selector", false), ("__execute__", true)]
{
// Test contract call to test_revert_helper.
let calldata = create_calldata(
main_contract_address,
"test_call_contract_revert",
&[
**main_contract_address,
selector_from_name(inner_selector).0,
Felt::TWO,
test_class_hash.0,
to_panic.into(),
is_meta_tx.into(),
],
);
test_manager.add_funded_account_invoke(invoke_tx_args! { calldata });
}

// Test call to cairo0 contracts with 0 and more then 0 entry points.
for (contract, address) in
[(empty_contract, empty_contract_address), (test_contract2, test_contract2_address)]
{
let class_hash = get_class_hash_of_feature_contract(contract);
let calldata = create_calldata(
main_contract_address,
"test_call_contract_revert",
&[
**address,
selector_from_name("bad_selector").0,
Felt::TWO,
class_hash.0,
to_panic.into(),
false.into(),
],
);
test_manager.add_funded_account_invoke(invoke_tx_args! { calldata });

// Test 3:
// - Contract A calls Contract B.
// - Contract B changes the storage value from 0 to 10.
// - Contract A calls Contract C.
// - Contract C changes the storage value from 10 to 17 and raises an exception.
// - Contract A checks that the storage value == 10.
let calldata = create_calldata(
main_contract_address,
"test_revert_with_inner_call_and_reverted_storage",
&[**main_contract_address, test_class_hash.0],
);
test_manager.add_funded_account_invoke(invoke_tx_args! { calldata });
}

// Test 4:
// - Contract A calls Contract B and asserts that the state remains unchanged.
// - Contract B calls Contract C and panics.
// - Contract C modifies the state but does not panic.

// Tell contract C not to panic:
let to_panic = false;
let contract_c_calldata = [test_class_hash.0, to_panic.into()];

// Create calldata recursively.
let contract_b_calldata = [
vec![
**main_contract_address,
selector_from_name("test_revert_helper").0,
contract_c_calldata.len().into(),
],
contract_c_calldata.to_vec(),
]
.concat();
let contract_a_calldata = [
vec![
**main_contract_address,
selector_from_name("middle_revert_contract").0,
contract_b_calldata.len().into(),
],
contract_b_calldata,
vec![false.into()], // is_meta_tx.
]
.concat();

// Call contract A.
let calldata =
create_calldata(main_contract_address, "test_call_contract_revert", &contract_a_calldata);
test_manager.add_funded_account_invoke(invoke_tx_args! { calldata });

// Run the test and assert only the fee token contract and the OS contracts have storage
// updates.
let test_output = test_manager
.execute_test_with_default_block_contexts(&TestParameters {
use_kzg_da: true,
..Default::default()
})
.await;

test_output.perform_default_validations();

let block_hash_contract_address = ContractAddress(
Const::BlockHashContractAddress.fetch_from_os_program().unwrap().try_into().unwrap(),
);
let expected_changed_addresses: HashSet<&ContractAddress> = HashSet::from_iter([
&*STRK_FEE_TOKEN_ADDRESS,
&*ALIAS_CONTRACT_ADDRESS,
&block_hash_contract_address,
]);
let actual_changed_addresses =
test_output.decompressed_state_diff.storage_updates.keys().collect::<HashSet<_>>();
assert!(
actual_changed_addresses.is_subset(&expected_changed_addresses),
"Expected changed addresses are not subset of actual changed addresses: actual changed \
addresses are {actual_changed_addresses:#?}, expected changed addresses are \
{expected_changed_addresses:#?}"
);
}