Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 86 additions & 21 deletions crates/starknet_os_flow_tests/src/test_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use blockifier::context::{BlockContext, ChainInfo, FeeTokenAddresses};
use blockifier::state::cached_state::{CommitmentStateDiff, StateMaps};
use blockifier::state::stateful_compression_test_utils::decompress;
use blockifier::test_utils::ALIAS_CONTRACT_ADDRESS;
use blockifier::transaction::objects::TransactionExecutionInfo;
use blockifier::transaction::transaction_execution::Transaction as BlockifierTransaction;
use blockifier_test_utils::contracts::FeatureContract;
use itertools::Itertools;
Expand Down Expand Up @@ -90,13 +91,18 @@ pub(crate) struct TestParameters {
pub(crate) private_keys: Option<Vec<Felt>>,
}

pub(crate) struct FlowTestTx {
tx: BlockifierTransaction,
expected_revert_reason: Option<String>,
}

/// Manages the execution of flow tests by maintaining the initial state and transactions.
pub(crate) struct TestManager<S: FlowTestState> {
pub(crate) initial_state: InitialState<S>,
pub(crate) nonce_manager: NonceManager,
pub(crate) execution_contracts: OsExecutionContracts,

per_block_transactions: Vec<Vec<BlockifierTransaction>>,
per_block_transactions: Vec<Vec<FlowTestTx>>,
}

pub(crate) struct OsTestExpectedValues {
Expand Down Expand Up @@ -273,7 +279,7 @@ impl<S: FlowTestState> TestManager<S> {
self.per_block_transactions.push(vec![]);
}

fn last_block_txs_mut(&mut self) -> &mut Vec<BlockifierTransaction> {
fn last_block_txs_mut(&mut self) -> &mut Vec<FlowTestTx> {
self.per_block_transactions
.last_mut()
.expect("Always initialized with at least one tx list (at least one block).")
Expand All @@ -289,9 +295,12 @@ impl<S: FlowTestState> TestManager<S> {
else {
panic!("Expected a V1 contract class");
};
self.last_block_txs_mut().push(BlockifierTransaction::new_for_sequencing(
StarknetApiTransaction::Account(AccountTransaction::Declare(tx)),
));
self.last_block_txs_mut().push(FlowTestTx {
tx: BlockifierTransaction::new_for_sequencing(StarknetApiTransaction::Account(
AccountTransaction::Declare(tx),
)),
expected_revert_reason: None,
});

self.execution_contracts
.declared_class_hash_to_component_hashes
Expand All @@ -303,14 +312,29 @@ impl<S: FlowTestState> TestManager<S> {
.insert(compiled_class_hash, casm.clone());
}

pub(crate) fn add_invoke_tx(&mut self, tx: InvokeTransaction) {
self.last_block_txs_mut().push(BlockifierTransaction::new_for_sequencing(
StarknetApiTransaction::Account(AccountTransaction::Invoke(tx)),
));
pub(crate) fn add_invoke_tx(
&mut self,
tx: InvokeTransaction,
expected_revert_reason: Option<String>,
) {
self.last_block_txs_mut().push(FlowTestTx {
tx: BlockifierTransaction::new_for_sequencing(StarknetApiTransaction::Account(
AccountTransaction::Invoke(tx),
)),
expected_revert_reason,
});
}

pub(crate) fn add_invoke_tx_from_args(&mut self, args: InvokeTxArgs, chain_id: &ChainId) {
self.add_invoke_tx(InvokeTransaction::create(invoke_tx(args), chain_id).unwrap());
pub(crate) fn add_invoke_tx_from_args(
&mut self,
args: InvokeTxArgs,
chain_id: &ChainId,
revert_reason: Option<String>,
) {
self.add_invoke_tx(
InvokeTransaction::create(invoke_tx(args), chain_id).unwrap(),
revert_reason,
);
}

/// Similar to `add_invoke_tx_from_args`, but with the sender address set to the funded account,
Expand All @@ -326,28 +350,41 @@ impl<S: FlowTestState> TestManager<S> {
..additional_args
},
&CHAIN_ID_FOR_TESTS,
None,
);
}

pub(crate) fn add_cairo0_declare_tx(&mut self, tx: DeclareTransaction, class_hash: ClassHash) {
let ContractClass::V0(class) = tx.class_info.contract_class.clone() else {
panic!("Expected a V0 contract class");
};
self.last_block_txs_mut().push(BlockifierTransaction::new_for_sequencing(
StarknetApiTransaction::Account(AccountTransaction::Declare(tx)),
));
self.last_block_txs_mut().push(FlowTestTx {
tx: BlockifierTransaction::new_for_sequencing(StarknetApiTransaction::Account(
AccountTransaction::Declare(tx),
)),
expected_revert_reason: None,
});
self.execution_contracts.executed_contracts.deprecated_contracts.insert(class_hash, class);
}

pub(crate) fn add_deploy_account_tx(&mut self, tx: DeployAccountTransaction) {
self.last_block_txs_mut().push(BlockifierTransaction::new_for_sequencing(
StarknetApiTransaction::Account(AccountTransaction::DeployAccount(tx)),
));
self.last_block_txs_mut().push(FlowTestTx {
tx: BlockifierTransaction::new_for_sequencing(StarknetApiTransaction::Account(
AccountTransaction::DeployAccount(tx),
)),
expected_revert_reason: None,
});
}

pub(crate) fn add_l1_handler_tx(&mut self, tx: L1HandlerTransaction) {
self.last_block_txs_mut()
.push(BlockifierTransaction::new_for_sequencing(StarknetApiTransaction::L1Handler(tx)));
pub(crate) fn add_l1_handler_tx(
&mut self,
tx: L1HandlerTransaction,
expected_revert_reason: Option<String>,
) {
self.last_block_txs_mut().push(FlowTestTx {
tx: BlockifierTransaction::new_for_sequencing(StarknetApiTransaction::L1Handler(tx)),
expected_revert_reason,
});
}

/// Executes the test using default block contexts, starting from the given block number.
Expand Down Expand Up @@ -422,6 +459,28 @@ impl<S: FlowTestState> TestManager<S> {
first_use_kzg_da
}

/// Verifies all the execution outputs are as expected w.r.t. revert reasons.
fn verify_execution_outputs(
revert_reasons: &[Option<String>],
execution_outputs: &[(TransactionExecutionInfo, StateMaps)],
) {
assert_eq!(revert_reasons.len(), execution_outputs.len());
for (revert_reason, (execution_info, _)) in
revert_reasons.iter().zip(execution_outputs.iter())
{
if let Some(revert_reason) = revert_reason {
let actual_revert_reason =
execution_info.revert_error.as_ref().unwrap().to_string();
assert!(
actual_revert_reason.contains(revert_reason),
"Expected '{revert_reason}' to be in revert string:\n'{actual_revert_reason}'"
);
} else {
assert!(execution_info.revert_error.is_none());
}
}
}

/// Decompresses the state diff from the OS output using the given OS output, state and alias
/// keys.
fn get_decompressed_state_diff(
Expand Down Expand Up @@ -485,13 +544,19 @@ impl<S: FlowTestState> TestManager<S> {
"use_kzg_da flag in block contexts must match the test parameter."
);
let mut alias_keys = HashSet::new();
for (block_txs, block_context) in per_block_txs.into_iter().zip(block_contexts.into_iter())
for (block_txs_with_reason, block_context) in
per_block_txs.into_iter().zip(block_contexts.into_iter())
{
let (block_txs, revert_reasons): (Vec<_>, Vec<_>) = block_txs_with_reason
.into_iter()
.map(|flow_test_tx| (flow_test_tx.tx, flow_test_tx.expected_revert_reason))
.unzip();
// Clone the block info for later use.
let block_info = block_context.block_info().clone();
// Execute the transactions.
let ExecutionOutput { execution_outputs, block_summary, mut final_state } =
execute_transactions(state, &block_txs, block_context);
Self::verify_execution_outputs(&revert_reasons, &execution_outputs);
let extended_state_diff = final_state.cache.borrow().extended_state_diff();
// Update the wrapped state.
let state_diff = final_state.to_state_diff().unwrap();
Expand Down
40 changes: 27 additions & 13 deletions crates/starknet_os_flow_tests/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,13 +226,18 @@ async fn trivial_diff_scenario(
/// 1. All storage changes made before the revert are properly rolled back.
/// 2. The transaction fee is still deducted from the caller's account.
#[rstest]
#[case::cairo0(
FeatureContract::TestContract(CairoVersion::Cairo0),
"ASSERT_EQ instruction failed: 1 != 0"
)]
#[case::cairo1(
FeatureContract::TestContract(CairoVersion::Cairo1(RunnableCairo1::Casm)),
"Panic for revert"
)]
#[tokio::test]
async fn test_reverted_invoke_tx(
#[values(
FeatureContract::TestContract(CairoVersion::Cairo0),
FeatureContract::TestContract(CairoVersion::Cairo1(RunnableCairo1::Casm))
)]
test_contract: FeatureContract,
#[case] test_contract: FeatureContract,
#[case] revert_reason: &str,
) {
let (use_kzg_da, full_output) = (true, false);

Expand All @@ -247,10 +252,14 @@ async fn test_reverted_invoke_tx(
let invoke_tx_args = invoke_tx_args! {
sender_address: *FUNDED_ACCOUNT_ADDRESS,
nonce: test_manager.next_nonce(*FUNDED_ACCOUNT_ADDRESS),
calldata: create_calldata(test_contract_address, "write_and_revert", &[]),
calldata: create_calldata(test_contract_address, "write_and_revert", &[Felt::ONE, Felt::TWO]),
resource_bounds: *NON_TRIVIAL_RESOURCE_BOUNDS,
};
test_manager.add_invoke_tx_from_args(invoke_tx_args, &CHAIN_ID_FOR_TESTS);
test_manager.add_invoke_tx_from_args(
invoke_tx_args,
&CHAIN_ID_FOR_TESTS,
Some(revert_reason.to_string()),
);

// Execute the test.
let test_output = test_manager
Expand Down Expand Up @@ -318,13 +327,18 @@ async fn test_encrypted_state_diff(
/// Verifies that when an L1 handler modifies storage and then reverts, all storage changes made
/// before the revert are properly rolled back.
#[rstest]
#[case::cairo0(
FeatureContract::TestContract(CairoVersion::Cairo0),
"ASSERT_EQ instruction failed: 1 != 0."
)]
#[case::cairo1(
FeatureContract::TestContract(CairoVersion::Cairo1(RunnableCairo1::Casm)),
"revert in l1 handler"
)]
#[tokio::test]
async fn test_reverted_l1_handler_tx(
#[values(
FeatureContract::TestContract(CairoVersion::Cairo0),
FeatureContract::TestContract(CairoVersion::Cairo1(RunnableCairo1::Casm))
)]
test_contract: FeatureContract,
#[case] test_contract: FeatureContract,
#[case] revert_reason: &str,
) {
let (mut test_manager, [test_contract_address]) =
TestManager::<DictStateReader>::new_with_default_initial_state([(
Expand All @@ -347,7 +361,7 @@ async fn test_reverted_l1_handler_tx(
Fee(1_000_000),
)
.unwrap();
test_manager.add_l1_handler_tx(tx);
test_manager.add_l1_handler_tx(tx, Some(revert_reason.to_string()));

let test_output =
test_manager.execute_test_with_default_block_contexts(&TestParameters::default()).await;
Expand Down