Skip to content

Commit 80d48d6

Browse files
committed
apollo_staking: align the mock staking contract interface with the real contract
1 parent 4f29510 commit 80d48d6

File tree

9 files changed

+3695
-2748
lines changed

9 files changed

+3695
-2748
lines changed

crates/apollo_staking/src/committee_provider.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use crate::contract_types::RetdataDeserializationError;
1717
pub type Committee = Vec<Staker>;
1818

1919
#[cfg_attr(test, derive(Clone))]
20-
#[derive(Debug, PartialEq, Eq)]
20+
#[derive(Debug, PartialEq, Eq, Hash)]
2121
pub struct Staker {
2222
// A contract address of the staker, to which rewards are sent.
2323
pub address: ContractAddress,

crates/apollo_staking/src/contract_types.rs

Lines changed: 118 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -4,99 +4,161 @@ use starknet_api::staking::StakingWeight;
44
use starknet_types_core::felt::Felt;
55
use thiserror::Error;
66

7+
#[cfg(test)]
78
use crate::committee_provider::Staker;
89

910
pub(crate) const GET_STAKERS_ENTRY_POINT: &str = "get_stakers";
1011
pub(crate) const EPOCH_LENGTH: u64 = 100; // Number of heights in an epoch.
1112

13+
/// Conversion from an [`Iterator`].
14+
///
15+
/// By implementing `TryFromIterator` for a type, you define how it will be
16+
/// created from an iterator.
17+
///
18+
/// Used in this context to parse Cairo1 types returned by a contract as a vector of Felts.
19+
pub trait TryFromIterator<Felt>: Sized {
20+
type Error;
21+
22+
fn try_from_iter<T: Iterator<Item = Felt>>(iter: &mut T) -> Result<Self, Self::Error>;
23+
}
24+
1225
// Represents a Cairo1 `Array` containing elements that can be deserialized to `T`.
1326
// `T` must implement `TryFrom<[Felt; N]>`, where `N` is the size of `T`'s Cairo equivalent.
1427
#[derive(Debug, PartialEq, Eq)]
15-
struct ArrayRetdata<const N: usize, T>(Vec<T>);
28+
struct ArrayRetdata<T>(Vec<T>);
29+
30+
#[derive(Debug, PartialEq, Eq)]
31+
pub(crate) struct ContractStaker {
32+
pub(crate) contract_address: ContractAddress,
33+
pub(crate) staking_power: StakingWeight,
34+
pub(crate) public_key: Option<Felt>,
35+
}
1636

1737
#[derive(Debug, Error)]
1838
pub enum RetdataDeserializationError {
1939
#[error("Failed to convert Felt to ContractAddress: {address}")]
2040
ContractAddressConversionError { address: Felt },
2141
#[error("Failed to convert Felt to u128: {felt}")]
2242
U128ConversionError { felt: Felt },
23-
#[error(
24-
"Invalid retdata length: expected 1 Felt followed by {num_structs} (number of structs) *
25-
{struct_size} (number of Felts per struct), but received {length} Felts."
26-
)]
27-
InvalidArrayLength { length: usize, num_structs: usize, struct_size: usize },
43+
#[error("Failed to convert Felt to usize: {felt}")]
44+
USizeConversionError { felt: Felt },
45+
#[error("Invalid object length: {message}.")]
46+
InvalidObjectLength { message: String },
47+
#[error("Unexpected enum variant: {variant}")]
48+
UnexpectedEnumVariant { variant: usize },
2849
}
2950

30-
impl Staker {
31-
pub const CAIRO_OBJECT_NUM_FELTS: usize = 3;
32-
51+
impl ContractStaker {
3352
pub fn from_retdata_many(retdata: Retdata) -> Result<Vec<Self>, RetdataDeserializationError> {
34-
Ok(ArrayRetdata::<{ Self::CAIRO_OBJECT_NUM_FELTS }, Staker>::try_from(retdata)?.0)
53+
Ok(ArrayRetdata::try_from(retdata)?.0)
3554
}
3655
}
3756

38-
impl TryFrom<[Felt; Self::CAIRO_OBJECT_NUM_FELTS]> for Staker {
57+
impl TryFromIterator<Felt> for ContractStaker {
3958
type Error = RetdataDeserializationError;
40-
41-
fn try_from(felts: [Felt; Self::CAIRO_OBJECT_NUM_FELTS]) -> Result<Self, Self::Error> {
42-
let [address, weight, public_key] = felts;
43-
let address = ContractAddress::try_from(address)
44-
.map_err(|_| RetdataDeserializationError::ContractAddressConversionError { address })?;
45-
let weight = StakingWeight(
46-
u128::try_from(weight)
47-
.map_err(|_| RetdataDeserializationError::U128ConversionError { felt: weight })?,
48-
);
49-
Ok(Self { address, weight, public_key })
50-
}
51-
}
52-
53-
#[cfg(test)]
54-
impl From<&Staker> for Vec<Felt> {
55-
fn from(staker: &Staker) -> Self {
56-
vec![Felt::from(staker.address), Felt::from(staker.weight.0), staker.public_key]
59+
fn try_from_iter<T: Iterator<Item = Felt>>(iter: &mut T) -> Result<Self, Self::Error> {
60+
// Parse contract address.
61+
let raw_address = iter.next().ok_or(RetdataDeserializationError::InvalidObjectLength {
62+
message: "missing contract address.".to_string(),
63+
})?;
64+
let contract_address = ContractAddress::try_from(raw_address).map_err(|_| {
65+
RetdataDeserializationError::ContractAddressConversionError { address: raw_address }
66+
})?;
67+
68+
// Parse staking power.
69+
let raw_staking_power =
70+
iter.next().ok_or(RetdataDeserializationError::InvalidObjectLength {
71+
message: "missing staking power.".to_string(),
72+
})?;
73+
let staking_power = StakingWeight(u128::try_from(raw_staking_power).map_err(|_| {
74+
RetdataDeserializationError::U128ConversionError { felt: raw_staking_power }
75+
})?);
76+
77+
// Parse public key.
78+
let raw_option_variant =
79+
iter.next().ok_or(RetdataDeserializationError::InvalidObjectLength {
80+
message: "missing public key option variant.".to_string(),
81+
})?;
82+
let option_variant = usize::try_from(raw_option_variant).map_err(|_| {
83+
RetdataDeserializationError::USizeConversionError { felt: raw_option_variant }
84+
})?;
85+
let public_key = match option_variant {
86+
1 => None,
87+
0 => {
88+
let public_key =
89+
iter.next().ok_or(RetdataDeserializationError::InvalidObjectLength {
90+
message: "missing public key.".to_string(),
91+
})?;
92+
Some(public_key)
93+
}
94+
_ => {
95+
return Err(RetdataDeserializationError::UnexpectedEnumVariant {
96+
variant: option_variant,
97+
});
98+
}
99+
};
100+
101+
Ok(Self { contract_address, staking_power, public_key })
57102
}
58103
}
59104

60-
impl<const N: usize, T> TryFrom<Retdata> for ArrayRetdata<N, T>
105+
impl<T> TryFrom<Retdata> for ArrayRetdata<T>
61106
where
62-
T: TryFrom<[Felt; N], Error = RetdataDeserializationError>,
107+
T: TryFromIterator<Felt, Error = RetdataDeserializationError>,
63108
{
64109
type Error = RetdataDeserializationError;
65110

66111
fn try_from(retdata: Retdata) -> Result<Self, Self::Error> {
67-
let data = retdata.0;
112+
let mut iter = retdata.0.into_iter();
68113

69114
// The first Felt in the Retdata must be the number of structs in the array.
70-
if data.is_empty() {
71-
return Err(RetdataDeserializationError::InvalidArrayLength {
72-
length: data.len(),
73-
num_structs: 0,
74-
struct_size: N,
75-
});
115+
let raw_num_items =
116+
iter.next().ok_or(RetdataDeserializationError::InvalidObjectLength {
117+
message: "missing number of items in an array.".to_string(),
118+
})?;
119+
120+
let num_items = usize::try_from(raw_num_items).map_err(|_| {
121+
RetdataDeserializationError::USizeConversionError { felt: raw_num_items }
122+
})?;
123+
124+
let mut result = Vec::new();
125+
for _ in 0..num_items {
126+
let item = T::try_from_iter(&mut iter)?;
127+
result.push(item);
76128
}
77129

78-
// Split the remaining Felts into chunks of N Felts, each is a struct in the array.
79-
let data_chunks = data[1..].chunks_exact(N);
80-
81-
// Verify that the number of structs in the array matches the number of chunks.
82-
let num_structs = usize::try_from(data[0]).expect("num_structs should fit in usize.");
83-
if data_chunks.len() != num_structs || !data_chunks.remainder().is_empty() {
84-
return Err(RetdataDeserializationError::InvalidArrayLength {
85-
length: data.len(),
86-
num_structs,
87-
struct_size: N,
130+
if iter.next().is_some() {
131+
return Err(RetdataDeserializationError::InvalidObjectLength {
132+
message: "Unconsumed elements found in retdata.".to_string(),
88133
});
89134
}
90135

91-
// Convert each chunk to T.
92-
let result = data_chunks
93-
.map(|chunk| {
94-
T::try_from(
95-
chunk.try_into().unwrap_or_else(|_| panic!("chunk size must be N: {N}.")),
96-
)
97-
})
98-
.collect::<Result<Vec<_>, _>>()?;
99-
100136
Ok(ArrayRetdata(result))
101137
}
102138
}
139+
140+
#[cfg(test)]
141+
impl From<&ContractStaker> for Vec<Felt> {
142+
fn from(staker: &ContractStaker) -> Self {
143+
let public_key = match staker.public_key {
144+
Some(public_key) => vec![Felt::ZERO, public_key],
145+
None => vec![Felt::ONE],
146+
};
147+
[
148+
[Felt::from(staker.contract_address), Felt::from(staker.staking_power.0)].as_slice(),
149+
public_key.as_slice(),
150+
]
151+
.concat()
152+
}
153+
}
154+
155+
#[cfg(test)]
156+
impl From<&Staker> for ContractStaker {
157+
fn from(staker: &Staker) -> Self {
158+
Self {
159+
contract_address: staker.address,
160+
staking_power: staker.weight,
161+
public_key: Some(staker.public_key),
162+
}
163+
}
164+
}

crates/apollo_staking/src/staking_manager.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ use crate::committee_provider::{
1919
ExecutionContext,
2020
Staker,
2121
};
22-
use crate::contract_types::{EPOCH_LENGTH, GET_STAKERS_ENTRY_POINT};
22+
use crate::contract_types::{ContractStaker, EPOCH_LENGTH, GET_STAKERS_ENTRY_POINT};
2323
use crate::utils::BlockRandomGenerator;
2424

2525
pub type StakerSet = Vec<Staker>;
@@ -132,7 +132,17 @@ impl StakingManager {
132132
Calldata(vec![Felt::from(epoch)].into()),
133133
)?;
134134

135-
let stakers = Staker::from_retdata_many(call_info.execution.retdata)?;
135+
let stakers: Vec<Staker> = ContractStaker::from_retdata_many(call_info.execution.retdata)?
136+
.into_iter()
137+
.filter_map(|staker| {
138+
// Filter out stakers that don't have a public key.
139+
staker.public_key.map(|public_key| Staker {
140+
address: staker.contract_address,
141+
weight: staker.staking_power,
142+
public_key,
143+
})
144+
})
145+
.collect();
136146
let committee_members = self.select_committee(stakers);
137147

138148
// Prepare the data needed for proposer selection.

0 commit comments

Comments
 (0)