Skip to content

Commit 4ec0035

Browse files
committed
apollo_staking: align the mock staking contract interface with the real contract
1 parent 803470d commit 4ec0035

File tree

4 files changed

+138
-51
lines changed

4 files changed

+138
-51
lines changed

crates/apollo_staking/src/contract_types.rs

Lines changed: 58 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use starknet_api::staking::StakingWeight;
44
use starknet_types_core::felt::Felt;
55
use thiserror::Error;
66

7+
#[cfg(test)]
78
use crate::committee_provider::Staker;
89

910
pub(crate) const GET_STAKERS_ENTRY_POINT: &str = "get_stakers";
@@ -14,46 +15,89 @@ pub(crate) const EPOCH_LENGTH: u64 = 100; // Number of heights in an epoch.
1415
#[derive(Debug, PartialEq, Eq)]
1516
struct ArrayRetdata<const N: usize, T>(Vec<T>);
1617

18+
#[derive(Debug, PartialEq, Eq)]
19+
pub(crate) struct ContractStaker {
20+
pub(crate) contract_address: ContractAddress,
21+
pub(crate) staking_power: StakingWeight,
22+
pub(crate) public_key: Option<Felt>,
23+
}
24+
1725
#[derive(Debug, Error)]
1826
pub enum RetdataDeserializationError {
1927
#[error("Failed to convert Felt to ContractAddress: {address}")]
2028
ContractAddressConversionError { address: Felt },
2129
#[error("Failed to convert Felt to u128: {felt}")]
2230
U128ConversionError { felt: Felt },
31+
#[error("Failed to convert Felt to usize: {felt}")]
32+
USizeConversionError { felt: Felt },
2333
#[error(
2434
"Invalid retdata length: expected 1 Felt followed by {num_structs} (number of structs) *
2535
{struct_size} (number of Felts per struct), but received {length} Felts."
2636
)]
2737
InvalidArrayLength { length: usize, num_structs: usize, struct_size: usize },
38+
#[error("Invalid retdata length: expected {expected} Felts, but received {received} Felts.")]
39+
InvalidObjectLength { expected: usize, received: usize },
40+
#[error("Unexpected enum variant: {variant}")]
41+
UnexpectedEnumVariant { variant: usize },
2842
}
2943

30-
impl Staker {
31-
pub const CAIRO_OBJECT_NUM_FELTS: usize = 3;
44+
impl ContractStaker {
45+
pub const CAIRO_OBJECT_NUM_FELTS: usize = 4;
3246

3347
pub fn from_retdata_many(retdata: Retdata) -> Result<Vec<Self>, RetdataDeserializationError> {
34-
Ok(ArrayRetdata::<{ Self::CAIRO_OBJECT_NUM_FELTS }, Staker>::try_from(retdata)?.0)
48+
Ok(ArrayRetdata::<{ Self::CAIRO_OBJECT_NUM_FELTS }, ContractStaker>::try_from(retdata)?.0)
3549
}
3650
}
3751

38-
impl TryFrom<[Felt; Self::CAIRO_OBJECT_NUM_FELTS]> for Staker {
52+
impl TryFrom<[Felt; Self::CAIRO_OBJECT_NUM_FELTS]> for ContractStaker {
3953
type Error = RetdataDeserializationError;
4054

4155
fn try_from(felts: [Felt; Self::CAIRO_OBJECT_NUM_FELTS]) -> Result<Self, Self::Error> {
42-
let [address, weight, public_key] = felts;
43-
let address = ContractAddress::try_from(address)
44-
.map_err(|_| RetdataDeserializationError::ContractAddressConversionError { address })?;
45-
let weight = StakingWeight(
46-
u128::try_from(weight)
47-
.map_err(|_| RetdataDeserializationError::U128ConversionError { felt: weight })?,
48-
);
49-
Ok(Self { address, weight, public_key })
56+
let [contract_address, staking_power, option_variant, public_key] = felts;
57+
let contract_address = ContractAddress::try_from(contract_address).map_err(|_| {
58+
RetdataDeserializationError::ContractAddressConversionError {
59+
address: contract_address,
60+
}
61+
})?;
62+
let staking_power = StakingWeight(u128::try_from(staking_power).map_err(|_| {
63+
RetdataDeserializationError::U128ConversionError { felt: staking_power }
64+
})?);
65+
let option_variant = usize::try_from(option_variant).map_err(|_| {
66+
RetdataDeserializationError::USizeConversionError { felt: option_variant }
67+
})?;
68+
let public_key = match option_variant {
69+
0 => Some(public_key),
70+
1 => None,
71+
_ => {
72+
return Err(RetdataDeserializationError::UnexpectedEnumVariant {
73+
variant: option_variant,
74+
});
75+
}
76+
};
77+
Ok(Self { contract_address, staking_power, public_key })
5078
}
5179
}
5280

5381
#[cfg(test)]
54-
impl From<&Staker> for Vec<Felt> {
82+
impl From<&ContractStaker> for Vec<Felt> {
83+
fn from(staker: &ContractStaker) -> Self {
84+
vec![
85+
Felt::from(staker.contract_address),
86+
Felt::from(staker.staking_power.0),
87+
Felt::from(if staker.public_key.is_some() { 0 } else { 1 }),
88+
staker.public_key.unwrap_or_default(),
89+
]
90+
}
91+
}
92+
93+
#[cfg(test)]
94+
impl From<&Staker> for ContractStaker {
5595
fn from(staker: &Staker) -> Self {
56-
vec![Felt::from(staker.address), Felt::from(staker.weight.0), staker.public_key]
96+
Self {
97+
contract_address: staker.address,
98+
staking_power: staker.weight,
99+
public_key: Some(staker.public_key),
100+
}
57101
}
58102
}
59103

crates/apollo_staking/src/staking_manager.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ use crate::committee_provider::{
1919
ExecutionContext,
2020
Staker,
2121
};
22-
use crate::contract_types::{EPOCH_LENGTH, GET_STAKERS_ENTRY_POINT};
22+
use crate::contract_types::{ContractStaker, EPOCH_LENGTH, GET_STAKERS_ENTRY_POINT};
2323
use crate::utils::BlockRandomGenerator;
2424

2525
pub type StakerSet = Vec<Staker>;
@@ -132,7 +132,17 @@ impl StakingManager {
132132
Calldata(vec![Felt::from(epoch)].into()),
133133
)?;
134134

135-
let stakers = Staker::from_retdata_many(call_info.execution.retdata)?;
135+
let stakers: Vec<Staker> = ContractStaker::from_retdata_many(call_info.execution.retdata)?
136+
.into_iter()
137+
.filter_map(|staker| {
138+
// Filter out stakers that don't have a public key.
139+
staker.public_key.map(|public_key| Staker {
140+
address: staker.contract_address,
141+
weight: staker.staking_power,
142+
public_key,
143+
})
144+
})
145+
.collect();
136146
let committee_members = self.select_committee(stakers);
137147

138148
// Prepare the data needed for proposer selection.

crates/apollo_staking/src/staking_manager_test.rs

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ use crate::committee_provider::{
3232
ExecutionContext,
3333
Staker,
3434
};
35-
use crate::contract_types::RetdataDeserializationError;
36-
use crate::staking_manager::{StakerSet, StakingManager, StakingManagerConfig};
35+
use crate::contract_types::{ContractStaker, RetdataDeserializationError};
36+
use crate::staking_manager::{StakingManager, StakingManagerConfig};
3737
use crate::utils::MockBlockRandomGenerator;
3838

3939
const STAKING_CONTRACT: FeatureContract =
@@ -93,7 +93,10 @@ fn default_config() -> StakingManagerConfig {
9393
}
9494

9595
fn set_stakers(state: &mut State, block_context: &Context, stakers: &[Staker]) {
96-
let mut stakers_as_felts: Vec<Felt> = stakers.iter().flat_map(<Vec<Felt>>::from).collect();
96+
let mut stakers_as_felts: Vec<Felt> = stakers
97+
.iter()
98+
.flat_map(|staker| <Vec<Felt>>::from(&ContractStaker::from(staker)))
99+
.collect();
97100
stakers_as_felts.insert(0, Felt::from(stakers.len()));
98101

99102
// Invoke the set_stakers function on the mock staking contract.
@@ -123,7 +126,7 @@ fn get_committee_success(
123126
default_config: StakingManagerConfig,
124127
mut state: State,
125128
block_context: Context,
126-
#[case] stakers: StakerSet,
129+
#[case] stakers: Vec<Staker>,
127130
#[case] expected_committee: Committee,
128131
) {
129132
set_stakers(&mut state, &block_context, &stakers);
@@ -276,25 +279,39 @@ async fn get_proposer_random_value_exceeds_total_weight(
276279
let _ = committee_manager.get_proposer(BlockNumber(1), 0, context).await;
277280
}
278281

279-
// --- TryFrom tests for Staker and ArrayRetdata ---
282+
// --- TryFrom tests for ContractStaker and ArrayRetdata ---
280283

281284
#[rstest]
282285
fn staker_try_from_valid() {
283-
let staker = Staker::try_from([Felt::ONE, Felt::TWO, Felt::THREE]).unwrap();
284-
assert_eq!(staker.address, contract_address!("0x1"));
285-
assert_eq!(staker.weight, StakingWeight(2));
286-
assert_eq!(staker.public_key, Felt::THREE);
286+
let staker = ContractStaker::try_from([Felt::ONE, Felt::TWO, Felt::ZERO, Felt::THREE]).unwrap();
287+
assert_eq!(staker.contract_address, contract_address!("0x1"));
288+
assert_eq!(staker.staking_power, StakingWeight(2));
289+
assert_eq!(staker.public_key, Some(Felt::THREE));
290+
291+
// A valid staker with no public key.
292+
let staker = ContractStaker::try_from([Felt::ONE, Felt::TWO, Felt::ONE, Felt::THREE]).unwrap();
293+
assert_eq!(staker.contract_address, contract_address!("0x1"));
294+
assert_eq!(staker.staking_power, StakingWeight(2));
295+
assert_eq!(staker.public_key, None);
287296
}
288297

289298
#[rstest]
290299
fn staker_try_from_invalid_address() {
291-
let err = Staker::try_from([CONTRACT_ADDRESS_DOMAIN_SIZE, Felt::ONE, Felt::ONE]).unwrap_err();
300+
let err =
301+
ContractStaker::try_from([CONTRACT_ADDRESS_DOMAIN_SIZE, Felt::ONE, Felt::ZERO, Felt::ONE])
302+
.unwrap_err();
292303
assert_matches!(err, RetdataDeserializationError::ContractAddressConversionError { .. });
293304
}
294305

306+
#[rstest]
307+
fn staker_try_from_invalid_public_key() {
308+
let err = ContractStaker::try_from([Felt::ONE, Felt::TWO, Felt::TWO, Felt::THREE]).unwrap_err();
309+
assert_matches!(err, RetdataDeserializationError::UnexpectedEnumVariant { .. });
310+
}
311+
295312
#[rstest]
296313
fn staker_try_from_invalid_staked_amount() {
297-
let err = Staker::try_from([Felt::ONE, Felt::MAX, Felt::ONE]).unwrap_err(); // Felt::MAX is too big for u128
314+
let err = ContractStaker::try_from([Felt::ONE, Felt::MAX, Felt::ZERO, Felt::ONE]).unwrap_err(); // Felt::MAX is too big for u128
298315
assert_matches!(err, RetdataDeserializationError::U128ConversionError { .. });
299316
}
300317

@@ -304,19 +321,19 @@ fn staker_try_from_invalid_staked_amount() {
304321
fn staker_array_retdata_try_from_valid(#[case] num_structs: usize) {
305322
let valid_retdata = [
306323
[Felt::from(num_structs)].as_slice(),
307-
vec![Felt::ONE; Staker::CAIRO_OBJECT_NUM_FELTS * num_structs].as_slice(),
324+
vec![Felt::ONE; ContractStaker::CAIRO_OBJECT_NUM_FELTS * num_structs].as_slice(),
308325
]
309326
.concat();
310327

311-
let result = Staker::from_retdata_many(Retdata(valid_retdata)).unwrap();
328+
let result = ContractStaker::from_retdata_many(Retdata(valid_retdata)).unwrap();
312329
assert_eq!(result.len(), num_structs);
313330
}
314331

315332
#[rstest]
316333
#[case::empty_retdata(vec![])]
317-
#[case::missing_num_structs(vec![Felt::ONE; Staker::CAIRO_OBJECT_NUM_FELTS * 2])]
318-
#[case::invalid_staker_length(vec![Felt::ONE; Staker::CAIRO_OBJECT_NUM_FELTS - 1])]
334+
#[case::missing_num_structs(vec![Felt::ONE; ContractStaker::CAIRO_OBJECT_NUM_FELTS * 2])]
335+
#[case::invalid_staker_length(vec![Felt::ONE; ContractStaker::CAIRO_OBJECT_NUM_FELTS - 1])]
319336
fn staker_array_retdata_try_from_invalid_length(#[case] retdata: Vec<Felt>) {
320-
let err = Staker::from_retdata_many(Retdata(retdata)).unwrap_err();
337+
let err = ContractStaker::from_retdata_many(Retdata(retdata)).unwrap_err();
321338
assert_matches!(err, RetdataDeserializationError::InvalidArrayLength { .. });
322339
}
Lines changed: 35 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,44 @@
1+
use starknet::ContractAddress;
2+
3+
pub type BlockNumber = u64;
4+
pub type Epoch = u64;
5+
pub type PublicKey = felt252;
6+
pub type StakingPower = u128;
7+
18
#[derive(Drop, Serde, starknet::Store)]
29
pub struct Staker {
3-
pub address: felt252,
4-
pub staked_amount: u128,
5-
pub pubkey: felt252,
10+
pub contract_address: ContractAddress,
11+
pub staking_power: StakingPower,
12+
pub pub_key: Option<PublicKey>,
613
}
714

815
#[derive(Drop, Serde, starknet::Store)]
916
pub struct EpochInfo {
10-
pub epoch: u64,
11-
pub start_block: u64,
12-
pub end_block: u64,
17+
pub epoch_id: Epoch,
18+
pub start_block: BlockNumber,
19+
pub epoch_length: u32,
1320
}
1421

1522
#[starknet::interface]
1623
pub trait IStaking<TContractState> {
1724
fn add_staker(ref self: TContractState, staker: Staker);
1825
fn set_stakers(ref self: TContractState, stakers: Array<Staker>);
19-
fn get_stakers(self: @TContractState, epoch: u64) -> Array<Staker>;
2026
fn set_current_epoch(ref self: TContractState, epoch: EpochInfo);
21-
fn get_current_epoch(self: @TContractState) -> EpochInfo;
27+
28+
// The following functions have exactly the same interface as the real Staking contract.
29+
fn get_stakers(
30+
self: @TContractState, epoch_id: Epoch,
31+
) -> Span<(ContractAddress, StakingPower, Option<PublicKey>)>;
32+
fn get_current_epoch_data(self: @TContractState) -> (Epoch, BlockNumber, u32);
2233
}
2334

2435
#[starknet::contract]
2536
mod Staking {
37+
use starknet::ContractAddress;
2638
use starknet::storage::{
2739
MutableVecTrait, StoragePointerReadAccess, StoragePointerWriteAccess, Vec, VecTrait,
2840
};
29-
use super::{EpochInfo, Staker};
41+
use super::{BlockNumber, Epoch, EpochInfo, PublicKey, Staker, StakingPower};
3042

3143
#[storage]
3244
struct Storage {
@@ -50,21 +62,25 @@ mod Staking {
5062
}
5163
}
5264

53-
// epoch is not used in this mock, but should be part of the interface.
54-
fn get_stakers(self: @ContractState, epoch: u64) -> Array<Staker> {
65+
fn set_current_epoch(ref self: ContractState, epoch: EpochInfo) {
66+
self.current_epoch.write(epoch);
67+
}
68+
69+
// epoch_id is not used in this mock, but should be part of the interface.
70+
fn get_stakers(
71+
self: @ContractState, epoch_id: Epoch,
72+
) -> Span<(ContractAddress, StakingPower, Option<PublicKey>)> {
5573
let mut stakers = array![];
5674
for i in 0..self.stakers.len() {
57-
stakers.append(self.stakers.at(i).read());
75+
let staker = self.stakers.at(i).read();
76+
stakers.append((staker.contract_address, staker.staking_power, staker.pub_key));
5877
}
59-
stakers
60-
}
61-
62-
fn set_current_epoch(ref self: ContractState, epoch: EpochInfo) {
63-
self.current_epoch.write(epoch);
78+
stakers.span()
6479
}
6580

66-
fn get_current_epoch(self: @ContractState) -> EpochInfo {
67-
self.current_epoch.read()
81+
fn get_current_epoch_data(self: @ContractState) -> (Epoch, BlockNumber, u32) {
82+
let epoch_info = self.current_epoch.read();
83+
(epoch_info.epoch_id, epoch_info.start_block, epoch_info.epoch_length)
6884
}
6985
}
7086
}

0 commit comments

Comments
 (0)