Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/starknet_gateway/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ papyrus_network_types = { path = "../papyrus_network_types", features = ["testin
papyrus_test_utils.path = "../papyrus_test_utils"
pretty_assertions.workspace = true
rstest.workspace = true
starknet_class_manager_types = { path = "../starknet_class_manager_types", features = ["testing"] }
starknet_mempool.path = "../starknet_mempool"
starknet_mempool_types = { path = "../starknet_mempool_types", features = ["testing"] }
starknet_state_sync_types = { path = "../starknet_state_sync_types", features = ["testing"] }
Expand Down
5 changes: 4 additions & 1 deletion crates/starknet_gateway/src/gateway.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,10 @@ pub fn create_gateway(
mempool_client: SharedMempoolClient,
class_manager_client: SharedClassManagerClient,
) -> Gateway {
let state_reader_factory = Arc::new(SyncStateReaderFactory { shared_state_sync_client });
let state_reader_factory = Arc::new(SyncStateReaderFactory {
shared_state_sync_client,
class_manager_client: class_manager_client.clone(),
});
let transaction_converter =
TransactionConverter::new(class_manager_client, config.chain_info.chain_id.clone());

Expand Down
28 changes: 22 additions & 6 deletions crates/starknet_gateway/src/sync_state_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use starknet_api::contract_class::ContractClass;
use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, Nonce};
use starknet_api::data_availability::L1DataAvailabilityMode;
use starknet_api::state::StorageKey;
use starknet_class_manager_types::SharedClassManagerClient;
use starknet_state_sync_types::communication::{
SharedStateSyncClient,
StateSyncClientError,
Expand All @@ -20,14 +21,16 @@ use crate::state_reader::{MempoolStateReader, StateReaderFactory};
pub(crate) struct SyncStateReader {
block_number: BlockNumber,
state_sync_client: SharedStateSyncClient,
class_manager_client: SharedClassManagerClient,
}

impl SyncStateReader {
pub fn from_number(
state_sync_client: SharedStateSyncClient,
class_manager_client: SharedClassManagerClient,
block_number: BlockNumber,
) -> Self {
Self { block_number, state_sync_client }
Self { block_number, state_sync_client, class_manager_client }
}
}

Expand Down Expand Up @@ -99,10 +102,17 @@ impl BlockifierStateReader for SyncStateReader {
}

fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult<RunnableCompiledClass> {
let contract_class = block_on(
self.state_sync_client.get_compiled_class_deprecated(self.block_number, class_hash),
)
.map_err(|e| StateError::StateReadError(e.to_string()))?;
let contract_class = block_on(self.class_manager_client.get_executable(class_hash))
.map_err(|e| StateError::StateReadError(e.to_string()))?;

// TODO(noamsp): Remove this once class manager component is implemented.
let contract_class = match contract_class {
ContractClass::V0(ref inner) if inner == &Default::default() => block_on(
self.state_sync_client.get_compiled_class_deprecated(self.block_number, class_hash),
)
.map_err(|e| StateError::StateReadError(e.to_string()))?,
_ => contract_class,
};

match contract_class {
ContractClass::V1(casm_contract_class) => {
Expand Down Expand Up @@ -134,6 +144,7 @@ impl BlockifierStateReader for SyncStateReader {

pub struct SyncStateReaderFactory {
pub shared_state_sync_client: SharedStateSyncClient,
pub class_manager_client: SharedClassManagerClient,
}

impl StateReaderFactory for SyncStateReaderFactory {
Expand All @@ -146,11 +157,16 @@ impl StateReaderFactory for SyncStateReaderFactory {

Ok(Box::new(SyncStateReader::from_number(
self.shared_state_sync_client.clone(),
self.class_manager_client.clone(),
latest_block_number,
)))
}

fn get_state_reader(&self, block_number: BlockNumber) -> Box<dyn MempoolStateReader> {
Box::new(SyncStateReader::from_number(self.shared_state_sync_client.clone(), block_number))
Box::new(SyncStateReader::from_number(
self.shared_state_sync_client.clone(),
self.class_manager_client.clone(),
block_number,
))
}
}
51 changes: 36 additions & 15 deletions crates/starknet_gateway/src/sync_state_reader_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use starknet_api::contract_class::{ContractClass, SierraVersion};
use starknet_api::core::SequencerContractAddress;
use starknet_api::data_availability::L1DataAvailabilityMode;
use starknet_api::{class_hash, contract_address, felt, nonce, storage_key};
use starknet_class_manager_types::MockClassManagerClient;
use starknet_state_sync_types::communication::MockStateSyncClient;
use starknet_state_sync_types::state_sync_types::SyncBlock;

Expand All @@ -27,6 +28,7 @@ use crate::sync_state_reader::SyncStateReader;
#[tokio::test]
async fn test_get_block_info() {
let mut mock_state_sync_client = MockStateSyncClient::new();
let mock_class_manager_client = MockClassManagerClient::new();
let block_number = BlockNumber(1);
let block_timestamp = BlockTimestamp(2);
let sequencer_address = contract_address!("0x3");
Expand Down Expand Up @@ -55,8 +57,11 @@ async fn test_get_block_info() {
},
);

let state_sync_reader =
SyncStateReader::from_number(Arc::new(mock_state_sync_client), block_number);
let state_sync_reader = SyncStateReader::from_number(
Arc::new(mock_state_sync_client),
Arc::new(mock_class_manager_client),
block_number,
);
let result = state_sync_reader.get_block_info().unwrap();

assert_eq!(
Expand Down Expand Up @@ -92,6 +97,7 @@ async fn test_get_block_info() {
#[tokio::test]
async fn test_get_storage_at() {
let mut mock_state_sync_client = MockStateSyncClient::new();
let mock_class_manager_client = MockClassManagerClient::new();
let block_number = BlockNumber(1);
let contract_address = contract_address!("0x2");
let storage_key = storage_key!("0x3");
Expand All @@ -106,8 +112,11 @@ async fn test_get_storage_at() {
)
.returning(move |_, _, _| Ok(value));

let state_sync_reader =
SyncStateReader::from_number(Arc::new(mock_state_sync_client), block_number);
let state_sync_reader = SyncStateReader::from_number(
Arc::new(mock_state_sync_client),
Arc::new(mock_class_manager_client),
block_number,
);

let result = state_sync_reader.get_storage_at(contract_address, storage_key).unwrap();
assert_eq!(result, value);
Expand All @@ -116,6 +125,7 @@ async fn test_get_storage_at() {
#[tokio::test]
async fn test_get_nonce_at() {
let mut mock_state_sync_client = MockStateSyncClient::new();
let mock_class_manager_client = MockClassManagerClient::new();
let block_number = BlockNumber(1);
let contract_address = contract_address!("0x2");
let expected_result = nonce!(0x3);
Expand All @@ -126,8 +136,11 @@ async fn test_get_nonce_at() {
.with(predicate::eq(block_number), predicate::eq(contract_address))
.returning(move |_, _| Ok(expected_result));

let state_sync_reader =
SyncStateReader::from_number(Arc::new(mock_state_sync_client), block_number);
let state_sync_reader = SyncStateReader::from_number(
Arc::new(mock_state_sync_client),
Arc::new(mock_class_manager_client),
block_number,
);

let result = state_sync_reader.get_nonce_at(contract_address).unwrap();
assert_eq!(result, expected_result);
Expand All @@ -136,6 +149,7 @@ async fn test_get_nonce_at() {
#[tokio::test]
async fn test_get_class_hash_at() {
let mut mock_state_sync_client = MockStateSyncClient::new();
let mock_class_manager_client = MockClassManagerClient::new();
let block_number = BlockNumber(1);
let contract_address = contract_address!("0x2");
let expected_result = class_hash!("0x3");
Expand All @@ -146,16 +160,20 @@ async fn test_get_class_hash_at() {
.with(predicate::eq(block_number), predicate::eq(contract_address))
.returning(move |_, _| Ok(expected_result));

let state_sync_reader =
SyncStateReader::from_number(Arc::new(mock_state_sync_client), block_number);
let state_sync_reader = SyncStateReader::from_number(
Arc::new(mock_state_sync_client),
Arc::new(mock_class_manager_client),
block_number,
);

let result = state_sync_reader.get_class_hash_at(contract_address).unwrap();
assert_eq!(result, expected_result);
}

#[tokio::test]
async fn test_get_compiled_class() {
let mut mock_state_sync_client = MockStateSyncClient::new();
let mock_state_sync_client = MockStateSyncClient::new();
let mut mock_class_manager_client = MockClassManagerClient::new();
let block_number = BlockNumber(1);
let class_hash = class_hash!("0x2");
let casm_contract_class = CasmContractClass {
Expand All @@ -169,16 +187,19 @@ async fn test_get_compiled_class() {
};
let expected_result = casm_contract_class.clone();

mock_state_sync_client
.expect_get_compiled_class_deprecated()
mock_class_manager_client
.expect_get_executable()
.times(1)
.with(predicate::eq(block_number), predicate::eq(class_hash))
.returning(move |_, _| {
.with(predicate::eq(class_hash))
.returning(move |_| {
Ok(ContractClass::V1((casm_contract_class.clone(), SierraVersion::default())))
});

let state_sync_reader =
SyncStateReader::from_number(Arc::new(mock_state_sync_client), block_number);
let state_sync_reader = SyncStateReader::from_number(
Arc::new(mock_state_sync_client),
Arc::new(mock_class_manager_client),
block_number,
);

let result = state_sync_reader.get_compiled_class(class_hash).unwrap();
assert_eq!(
Expand Down