Skip to content

Commit 64ae454

Browse files
apollo_gateway: return non-boxed state reader_with_compiled_classes from state reader factory (#11231)
1 parent 749abb6 commit 64ae454

File tree

3 files changed

+74
-103
lines changed

3 files changed

+74
-103
lines changed
Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,6 @@
11
use apollo_state_sync_types::communication::StateSyncClientResult;
22
use async_trait::async_trait;
3-
use blockifier::execution::contract_class::RunnableCompiledClass;
4-
use blockifier::state::global_cache::CompiledClasses;
5-
use blockifier::state::state_api::{StateReader as BlockifierStateReader, StateResult};
63
use blockifier::state::state_reader_and_contract_manager::FetchCompiledClasses;
7-
use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, Nonce};
8-
use starknet_api::state::StorageKey;
9-
use starknet_types_core::felt::Felt;
104

115
use crate::gateway_fixed_block_state_reader::GatewayFixedBlockStateReader;
126
#[async_trait]
@@ -25,39 +19,3 @@ pub trait StateReaderFactory: Send + Sync {
2519
// TODO(Arni): Delete this trait, once we replace `dyn GatewayStateReaderWithCompiledClasses` with
2620
// generics.
2721
pub trait GatewayStateReaderWithCompiledClasses: FetchCompiledClasses + Send + Sync {}
28-
29-
impl BlockifierStateReader for Box<dyn GatewayStateReaderWithCompiledClasses> {
30-
fn get_storage_at(
31-
&self,
32-
contract_address: ContractAddress,
33-
key: StorageKey,
34-
) -> StateResult<Felt> {
35-
self.as_ref().get_storage_at(contract_address, key)
36-
}
37-
38-
fn get_nonce_at(&self, contract_address: ContractAddress) -> StateResult<Nonce> {
39-
self.as_ref().get_nonce_at(contract_address)
40-
}
41-
42-
fn get_class_hash_at(&self, contract_address: ContractAddress) -> StateResult<ClassHash> {
43-
self.as_ref().get_class_hash_at(contract_address)
44-
}
45-
46-
fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult<RunnableCompiledClass> {
47-
self.as_ref().get_compiled_class(class_hash)
48-
}
49-
50-
fn get_compiled_class_hash(&self, class_hash: ClassHash) -> StateResult<CompiledClassHash> {
51-
self.as_ref().get_compiled_class_hash(class_hash)
52-
}
53-
}
54-
55-
impl FetchCompiledClasses for Box<dyn GatewayStateReaderWithCompiledClasses> {
56-
fn get_compiled_classes(&self, class_hash: ClassHash) -> StateResult<CompiledClasses> {
57-
self.as_ref().get_compiled_classes(class_hash)
58-
}
59-
60-
fn is_declared(&self, class_hash: ClassHash) -> StateResult<bool> {
61-
self.as_ref().is_declared(class_hash)
62-
}
63-
}

crates/apollo_gateway/src/stateful_transaction_validator.rs

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,6 @@ use crate::state_reader::{GatewayStateReaderWithCompiledClasses, StateReaderFact
4040
#[path = "stateful_transaction_validator_test.rs"]
4141
mod stateful_transaction_validator_test;
4242

43-
type BlockifierStatefulValidator = StatefulValidator<
44-
StateReaderAndContractManager<Box<dyn GatewayStateReaderWithCompiledClasses>>,
45-
>;
46-
4743
#[cfg_attr(test, mockall::automock)]
4844
#[async_trait]
4945
pub trait StatefulTransactionValidatorFactoryTrait: Send + Sync {
@@ -81,11 +77,8 @@ impl<TStateReaderFactory: StateReaderFactory> StatefulTransactionValidatorFactor
8177
e,
8278
)
8379
})?;
84-
// Convert concrete type to trait object.
85-
let boxed_state_reader: Box<dyn GatewayStateReaderWithCompiledClasses> =
86-
Box::new(blockifier_state_reader);
8780
let state_reader_and_contract_manager = StateReaderAndContractManager::new(
88-
boxed_state_reader,
81+
blockifier_state_reader,
8982
self.contract_class_manager.clone(),
9083
Some(GATEWAY_CLASS_CACHE_METRICS),
9184
);
@@ -109,22 +102,31 @@ pub trait StatefulTransactionValidatorTrait: Send {
109102
) -> StatefulTransactionValidatorResult<Nonce>;
110103
}
111104

112-
pub struct StatefulTransactionValidator<TGatewayFixedBlockStateReader: GatewayFixedBlockStateReader>
113-
{
105+
pub struct StatefulTransactionValidator<
106+
TGatewayStateReaderWithCompiledClasses: GatewayStateReaderWithCompiledClasses,
107+
TGatewayFixedBlockStateReader: GatewayFixedBlockStateReader,
108+
> {
114109
config: StatefulTransactionValidatorConfig,
115110
chain_info: ChainInfo,
116111
// Consumed when running the CPU-heavy blockifier validation.
117112
// TODO(Itamar): The whole `StatefulTransactionValidator` is never used after
118113
// `state_reader_and_contract_manager` is taken. Make it non-optional and discard the
119114
// instance after use.
120115
state_reader_and_contract_manager:
121-
Option<StateReaderAndContractManager<Box<dyn GatewayStateReaderWithCompiledClasses>>>,
116+
Option<StateReaderAndContractManager<TGatewayStateReaderWithCompiledClasses>>,
122117
gateway_fixed_block_state_reader: TGatewayFixedBlockStateReader,
123118
}
124119

125120
#[async_trait]
126-
impl<TGatewayFixedBlockStateReader: GatewayFixedBlockStateReader> StatefulTransactionValidatorTrait
127-
for StatefulTransactionValidator<TGatewayFixedBlockStateReader>
121+
impl<TGatewayStateReaderWithCompiledClasses, TGatewayFixedBlockStateReader>
122+
StatefulTransactionValidatorTrait
123+
for StatefulTransactionValidator<
124+
TGatewayStateReaderWithCompiledClasses,
125+
TGatewayFixedBlockStateReader,
126+
>
127+
where
128+
TGatewayStateReaderWithCompiledClasses: GatewayStateReaderWithCompiledClasses + 'static,
129+
TGatewayFixedBlockStateReader: GatewayFixedBlockStateReader,
128130
{
129131
async fn extract_state_nonce_and_run_validations(
130132
&mut self,
@@ -150,14 +152,20 @@ impl<TGatewayFixedBlockStateReader: GatewayFixedBlockStateReader> StatefulTransa
150152
}
151153
}
152154

153-
impl<TGatewayFixedBlockStateReader: GatewayFixedBlockStateReader>
154-
StatefulTransactionValidator<TGatewayFixedBlockStateReader>
155+
impl<TGatewayStateReaderWithCompiledClasses, TGatewayFixedBlockStateReader>
156+
StatefulTransactionValidator<
157+
TGatewayStateReaderWithCompiledClasses,
158+
TGatewayFixedBlockStateReader,
159+
>
160+
where
161+
TGatewayStateReaderWithCompiledClasses: GatewayStateReaderWithCompiledClasses + 'static,
162+
TGatewayFixedBlockStateReader: GatewayFixedBlockStateReader,
155163
{
156164
fn new(
157165
config: StatefulTransactionValidatorConfig,
158166
chain_info: ChainInfo,
159167
state_reader_and_contract_manager: StateReaderAndContractManager<
160-
Box<dyn GatewayStateReaderWithCompiledClasses>,
168+
TGatewayStateReaderWithCompiledClasses,
161169
>,
162170
gateway_fixed_block_state_reader: TGatewayFixedBlockStateReader,
163171
) -> Self {
@@ -171,7 +179,7 @@ impl<TGatewayFixedBlockStateReader: GatewayFixedBlockStateReader>
171179

172180
fn take_state_reader_and_contract_manager(
173181
&mut self,
174-
) -> StateReaderAndContractManager<Box<dyn GatewayStateReaderWithCompiledClasses>> {
182+
) -> StateReaderAndContractManager<TGatewayStateReaderWithCompiledClasses> {
175183
self.state_reader_and_contract_manager.take().expect("Validator was already consumed")
176184
}
177185

@@ -301,8 +309,8 @@ impl<TGatewayFixedBlockStateReader: GatewayFixedBlockStateReader>
301309
tokio::task::spawn_blocking(move || {
302310
cur_span.in_scope(|| {
303311
let state = CachedState::new(state_reader_and_contract_manager);
304-
let mut blockifier = BlockifierStatefulValidator::create(state, block_context);
305-
blockifier.validate(account_tx)
312+
let mut blockifier_validator = StatefulValidator::create(state, block_context);
313+
blockifier_validator.validate(account_tx)
306314
})
307315
})
308316
.await

crates/apollo_gateway/src/stateful_transaction_validator_test.rs

Lines changed: 47 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ use starknet_api::transaction::fields::{AllResourceBounds, ResourceBounds, Valid
2727
use starknet_api::{declare_tx_args, deploy_account_tx_args, invoke_tx_args, nonce};
2828

2929
use crate::gateway_fixed_block_state_reader::MockGatewayFixedBlockStateReader;
30-
use crate::state_reader_test_utils::local_test_state_reader_factory;
30+
use crate::state_reader_test_utils::{local_test_state_reader_factory, TestStateReader};
3131
use crate::stateful_transaction_validator::{
3232
StatefulTransactionValidator,
3333
StatefulTransactionValidatorFactory,
@@ -52,12 +52,13 @@ async fn test_get_nonce_fail_on_extract_state_nonce_and_run_validations() {
5252
});
5353

5454
let mempool_client = Arc::new(MockMempoolClient::new());
55-
let mut stateful_validator = StatefulTransactionValidator {
56-
config: StatefulTransactionValidatorConfig::default(),
57-
chain_info: ChainInfo::create_for_testing(),
58-
state_reader_and_contract_manager: None,
59-
gateway_fixed_block_state_reader: mock_gateway_fixed_block,
60-
};
55+
let mut stateful_validator: StatefulTransactionValidator<TestStateReader, _> =
56+
StatefulTransactionValidator {
57+
config: StatefulTransactionValidatorConfig::default(),
58+
chain_info: ChainInfo::create_for_testing(),
59+
state_reader_and_contract_manager: None,
60+
gateway_fixed_block_state_reader: mock_gateway_fixed_block,
61+
};
6162

6263
let result = stateful_validator
6364
.extract_state_nonce_and_run_validations(&executable_tx, mempool_client)
@@ -103,12 +104,13 @@ async fn test_run_pre_validation_checks(
103104
let mut mock_gateway_fixed_block = MockGatewayFixedBlockStateReader::new();
104105
mock_gateway_fixed_block.expect_get_block_info().returning(|| Ok(BlockInfo::default()));
105106

106-
let stateful_validator = StatefulTransactionValidator {
107-
config: StatefulTransactionValidatorConfig::default(),
108-
chain_info: ChainInfo::create_for_testing(),
109-
state_reader_and_contract_manager: None,
110-
gateway_fixed_block_state_reader: mock_gateway_fixed_block,
111-
};
107+
let stateful_validator: StatefulTransactionValidator<TestStateReader, _> =
108+
StatefulTransactionValidator {
109+
config: StatefulTransactionValidatorConfig::default(),
110+
chain_info: ChainInfo::create_for_testing(),
111+
state_reader_and_contract_manager: None,
112+
gateway_fixed_block_state_reader: mock_gateway_fixed_block,
113+
};
112114

113115
let resource_bounds = if zero_gas_fee {
114116
ValidResourceBounds::AllResources(AllResourceBounds {
@@ -203,15 +205,16 @@ async fn test_skip_validate(
203205
.expect_get_nonce()
204206
.with(eq(executable_tx.sender_address()))
205207
.return_once(move |_| Ok(sender_nonce));
206-
let stateful_validator = StatefulTransactionValidator {
207-
config: StatefulTransactionValidatorConfig {
208-
validate_resource_bounds: false,
209-
..Default::default()
210-
},
211-
chain_info: ChainInfo::create_for_testing(),
212-
state_reader_and_contract_manager: None,
213-
gateway_fixed_block_state_reader: mock_gateway_fixed_block,
214-
};
208+
let stateful_validator: StatefulTransactionValidator<TestStateReader, _> =
209+
StatefulTransactionValidator {
210+
config: StatefulTransactionValidatorConfig {
211+
validate_resource_bounds: false,
212+
..Default::default()
213+
},
214+
chain_info: ChainInfo::create_for_testing(),
215+
state_reader_and_contract_manager: None,
216+
gateway_fixed_block_state_reader: mock_gateway_fixed_block,
217+
};
215218

216219
let skip_validate = stateful_validator
217220
.run_pre_validation_checks(&executable_tx, sender_nonce, mempool_client)
@@ -305,16 +308,17 @@ async fn validate_resource_bounds(
305308
})
306309
});
307310

308-
let stateful_validator = StatefulTransactionValidator {
309-
config: StatefulTransactionValidatorConfig {
310-
validate_resource_bounds: true,
311-
min_gas_price_percentage,
312-
..Default::default()
313-
},
314-
chain_info: ChainInfo::create_for_testing(),
315-
state_reader_and_contract_manager: None,
316-
gateway_fixed_block_state_reader: mock_gateway_fixed_block,
317-
};
311+
let stateful_validator: StatefulTransactionValidator<TestStateReader, _> =
312+
StatefulTransactionValidator {
313+
config: StatefulTransactionValidatorConfig {
314+
validate_resource_bounds: true,
315+
min_gas_price_percentage,
316+
..Default::default()
317+
},
318+
chain_info: ChainInfo::create_for_testing(),
319+
state_reader_and_contract_manager: None,
320+
gateway_fixed_block_state_reader: mock_gateway_fixed_block,
321+
};
318322

319323
let result = stateful_validator.validate_resource_bounds(&executable_tx).await;
320324
assert_eq!(result, expected_result);
@@ -415,16 +419,17 @@ async fn run_pre_validation_checks_test(
415419
.expect_get_nonce()
416420
.with(eq(executable_tx.sender_address()))
417421
.return_once(move |_| Ok(account_nonce));
418-
let stateful_validator = StatefulTransactionValidator {
419-
config: StatefulTransactionValidatorConfig {
420-
max_allowed_nonce_gap,
421-
validate_resource_bounds: false,
422-
..Default::default()
423-
},
424-
chain_info: ChainInfo::create_for_testing(),
425-
state_reader_and_contract_manager: None,
426-
gateway_fixed_block_state_reader: mock_gateway_fixed_block,
427-
};
422+
let stateful_validator: StatefulTransactionValidator<TestStateReader, _> =
423+
StatefulTransactionValidator {
424+
config: StatefulTransactionValidatorConfig {
425+
max_allowed_nonce_gap,
426+
validate_resource_bounds: false,
427+
..Default::default()
428+
},
429+
chain_info: ChainInfo::create_for_testing(),
430+
state_reader_and_contract_manager: None,
431+
gateway_fixed_block_state_reader: mock_gateway_fixed_block,
432+
};
428433

429434
let mut mempool_client = MockMempoolClient::new();
430435
mempool_client.expect_validate_tx().returning(|_| Ok(()));

0 commit comments

Comments
 (0)