Skip to content

Commit d1a01a5

Browse files
committed
apollo_staking: align the mock staking contract interface with the real contract
1 parent 4f29510 commit d1a01a5

File tree

9 files changed

+3704
-2747
lines changed

9 files changed

+3704
-2747
lines changed

crates/apollo_staking/src/committee_provider.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use crate::contract_types::RetdataDeserializationError;
1717
pub type Committee = Vec<Staker>;
1818

1919
#[cfg_attr(test, derive(Clone))]
20-
#[derive(Debug, PartialEq, Eq)]
20+
#[derive(Debug, PartialEq, Eq, Hash)]
2121
pub struct Staker {
2222
// A contract address of the staker, to which rewards are sent.
2323
pub address: ContractAddress,

crates/apollo_staking/src/contract_types.rs

Lines changed: 127 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -4,99 +4,171 @@ use starknet_api::staking::StakingWeight;
44
use starknet_types_core::felt::Felt;
55
use thiserror::Error;
66

7+
#[cfg(test)]
78
use crate::committee_provider::Staker;
89

910
pub(crate) const GET_STAKERS_ENTRY_POINT: &str = "get_stakers";
1011
pub(crate) const EPOCH_LENGTH: u64 = 100; // Number of heights in an epoch.
1112

13+
/// Conversion from an [`Iterator`].
14+
///
15+
/// By implementing `TryFromIterator` for a type, you define how it will be
16+
/// created from an iterator.
17+
///
18+
/// Used in this context to parse Cairo1 types returned by a contract as a vector of Felts.
19+
pub trait TryFromIterator<Felt>: Sized {
20+
type Error;
21+
22+
fn try_from_iter<T: Iterator<Item = Felt>>(iter: &mut T) -> Result<Self, Self::Error>;
23+
}
24+
1225
// Represents a Cairo1 `Array` containing elements that can be deserialized to `T`.
1326
// `T` must implement `TryFrom<[Felt; N]>`, where `N` is the size of `T`'s Cairo equivalent.
1427
#[derive(Debug, PartialEq, Eq)]
15-
struct ArrayRetdata<const N: usize, T>(Vec<T>);
28+
struct ArrayRetdata<T>(Vec<T>);
29+
30+
#[derive(Debug, PartialEq, Eq)]
31+
pub(crate) struct ContractStaker {
32+
pub(crate) contract_address: ContractAddress,
33+
pub(crate) staking_power: StakingWeight,
34+
pub(crate) public_key: Option<Felt>,
35+
}
1636

1737
#[derive(Debug, Error)]
1838
pub enum RetdataDeserializationError {
1939
#[error("Failed to convert Felt to ContractAddress: {address}")]
2040
ContractAddressConversionError { address: Felt },
2141
#[error("Failed to convert Felt to u128: {felt}")]
2242
U128ConversionError { felt: Felt },
23-
#[error(
24-
"Invalid retdata length: expected 1 Felt followed by {num_structs} (number of structs) *
25-
{struct_size} (number of Felts per struct), but received {length} Felts."
26-
)]
27-
InvalidArrayLength { length: usize, num_structs: usize, struct_size: usize },
43+
#[error("Failed to convert Felt to usize: {felt}")]
44+
USizeConversionError { felt: Felt },
45+
#[error("Invalid object length: {message}.")]
46+
InvalidObjectLength { message: String },
47+
#[error("Unexpected enum variant: {variant}")]
48+
UnexpectedEnumVariant { variant: usize },
2849
}
2950

30-
impl Staker {
31-
pub const CAIRO_OBJECT_NUM_FELTS: usize = 3;
32-
51+
impl ContractStaker {
3352
pub fn from_retdata_many(retdata: Retdata) -> Result<Vec<Self>, RetdataDeserializationError> {
34-
Ok(ArrayRetdata::<{ Self::CAIRO_OBJECT_NUM_FELTS }, Staker>::try_from(retdata)?.0)
53+
Ok(ArrayRetdata::try_from(retdata)?.0)
3554
}
3655
}
3756

38-
impl TryFrom<[Felt; Self::CAIRO_OBJECT_NUM_FELTS]> for Staker {
57+
impl TryFromIterator<Felt> for ContractStaker {
3958
type Error = RetdataDeserializationError;
4059

41-
fn try_from(felts: [Felt; Self::CAIRO_OBJECT_NUM_FELTS]) -> Result<Self, Self::Error> {
42-
let [address, weight, public_key] = felts;
43-
let address = ContractAddress::try_from(address)
44-
.map_err(|_| RetdataDeserializationError::ContractAddressConversionError { address })?;
45-
let weight = StakingWeight(
46-
u128::try_from(weight)
47-
.map_err(|_| RetdataDeserializationError::U128ConversionError { felt: weight })?,
48-
);
49-
Ok(Self { address, weight, public_key })
50-
}
51-
}
52-
53-
#[cfg(test)]
54-
impl From<&Staker> for Vec<Felt> {
55-
fn from(staker: &Staker) -> Self {
56-
vec![Felt::from(staker.address), Felt::from(staker.weight.0), staker.public_key]
60+
// Parses a single `ContractStaker` from a stream of Felts.
61+
//
62+
// The iterator is expected to yield the following values, in order:
63+
// 1. Contract Address (1 Felt)
64+
// 2. Staking Power (1 Felt)
65+
// 3. Public Key option variant (1 Felt):
66+
// - 0 => Some
67+
// - 1 => None
68+
// 4. Public Key (1 Felt), only if the option variant is `Some`
69+
fn try_from_iter<T: Iterator<Item = Felt>>(iter: &mut T) -> Result<Self, Self::Error> {
70+
// Parse contract address.
71+
let raw_address = iter.next().ok_or(RetdataDeserializationError::InvalidObjectLength {
72+
message: "missing contract address.".to_string(),
73+
})?;
74+
let contract_address = ContractAddress::try_from(raw_address).map_err(|_| {
75+
RetdataDeserializationError::ContractAddressConversionError { address: raw_address }
76+
})?;
77+
78+
// Parse staking power.
79+
let raw_staking_power =
80+
iter.next().ok_or(RetdataDeserializationError::InvalidObjectLength {
81+
message: "missing staking power.".to_string(),
82+
})?;
83+
let staking_power = StakingWeight(u128::try_from(raw_staking_power).map_err(|_| {
84+
RetdataDeserializationError::U128ConversionError { felt: raw_staking_power }
85+
})?);
86+
87+
// Parse public key.
88+
let raw_option_variant =
89+
iter.next().ok_or(RetdataDeserializationError::InvalidObjectLength {
90+
message: "missing public key option variant.".to_string(),
91+
})?;
92+
let option_variant = usize::try_from(raw_option_variant).map_err(|_| {
93+
RetdataDeserializationError::USizeConversionError { felt: raw_option_variant }
94+
})?;
95+
let public_key = match option_variant {
96+
1 => None,
97+
0 => {
98+
let public_key =
99+
iter.next().ok_or(RetdataDeserializationError::InvalidObjectLength {
100+
message: "missing public key.".to_string(),
101+
})?;
102+
Some(public_key)
103+
}
104+
_ => {
105+
return Err(RetdataDeserializationError::UnexpectedEnumVariant {
106+
variant: option_variant,
107+
});
108+
}
109+
};
110+
111+
Ok(Self { contract_address, staking_power, public_key })
57112
}
58113
}
59114

60-
impl<const N: usize, T> TryFrom<Retdata> for ArrayRetdata<N, T>
115+
impl<T> TryFrom<Retdata> for ArrayRetdata<T>
61116
where
62-
T: TryFrom<[Felt; N], Error = RetdataDeserializationError>,
117+
T: TryFromIterator<Felt, Error = RetdataDeserializationError>,
63118
{
64119
type Error = RetdataDeserializationError;
65120

66121
fn try_from(retdata: Retdata) -> Result<Self, Self::Error> {
67-
let data = retdata.0;
122+
let mut iter = retdata.0.into_iter();
68123

69124
// The first Felt in the Retdata must be the number of structs in the array.
70-
if data.is_empty() {
71-
return Err(RetdataDeserializationError::InvalidArrayLength {
72-
length: data.len(),
73-
num_structs: 0,
74-
struct_size: N,
75-
});
125+
let raw_num_items =
126+
iter.next().ok_or(RetdataDeserializationError::InvalidObjectLength {
127+
message: "missing number of items in an array.".to_string(),
128+
})?;
129+
130+
let num_items = usize::try_from(raw_num_items).map_err(|_| {
131+
RetdataDeserializationError::USizeConversionError { felt: raw_num_items }
132+
})?;
133+
134+
let mut result = Vec::with_capacity(num_items);
135+
for _ in 0..num_items {
136+
let item = T::try_from_iter(&mut iter)?;
137+
result.push(item);
76138
}
77139

78-
// Split the remaining Felts into chunks of N Felts, each is a struct in the array.
79-
let data_chunks = data[1..].chunks_exact(N);
80-
81-
// Verify that the number of structs in the array matches the number of chunks.
82-
let num_structs = usize::try_from(data[0]).expect("num_structs should fit in usize.");
83-
if data_chunks.len() != num_structs || !data_chunks.remainder().is_empty() {
84-
return Err(RetdataDeserializationError::InvalidArrayLength {
85-
length: data.len(),
86-
num_structs,
87-
struct_size: N,
140+
if iter.next().is_some() {
141+
return Err(RetdataDeserializationError::InvalidObjectLength {
142+
message: "Unconsumed elements found in retdata.".to_string(),
88143
});
89144
}
90145

91-
// Convert each chunk to T.
92-
let result = data_chunks
93-
.map(|chunk| {
94-
T::try_from(
95-
chunk.try_into().unwrap_or_else(|_| panic!("chunk size must be N: {N}.")),
96-
)
97-
})
98-
.collect::<Result<Vec<_>, _>>()?;
99-
100146
Ok(ArrayRetdata(result))
101147
}
102148
}
149+
150+
#[cfg(test)]
151+
impl From<&ContractStaker> for Vec<Felt> {
152+
fn from(staker: &ContractStaker) -> Self {
153+
let public_key = match staker.public_key {
154+
Some(public_key) => vec![Felt::ZERO, public_key],
155+
None => vec![Felt::ONE],
156+
};
157+
[
158+
[Felt::from(staker.contract_address), Felt::from(staker.staking_power.0)].as_slice(),
159+
public_key.as_slice(),
160+
]
161+
.concat()
162+
}
163+
}
164+
165+
#[cfg(test)]
166+
impl From<&Staker> for ContractStaker {
167+
fn from(staker: &Staker) -> Self {
168+
Self {
169+
contract_address: staker.address,
170+
staking_power: staker.weight,
171+
public_key: Some(staker.public_key),
172+
}
173+
}
174+
}

crates/apollo_staking/src/staking_manager.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ use crate::committee_provider::{
1919
ExecutionContext,
2020
Staker,
2121
};
22-
use crate::contract_types::{EPOCH_LENGTH, GET_STAKERS_ENTRY_POINT};
22+
use crate::contract_types::{ContractStaker, EPOCH_LENGTH, GET_STAKERS_ENTRY_POINT};
2323
use crate::utils::BlockRandomGenerator;
2424

2525
pub type StakerSet = Vec<Staker>;
@@ -132,7 +132,17 @@ impl StakingManager {
132132
Calldata(vec![Felt::from(epoch)].into()),
133133
)?;
134134

135-
let stakers = Staker::from_retdata_many(call_info.execution.retdata)?;
135+
let stakers: Vec<Staker> = ContractStaker::from_retdata_many(call_info.execution.retdata)?
136+
.into_iter()
137+
.filter_map(|staker| {
138+
// Filter out stakers that don't have a public key.
139+
staker.public_key.map(|public_key| Staker {
140+
address: staker.contract_address,
141+
weight: staker.staking_power,
142+
public_key,
143+
})
144+
})
145+
.collect();
136146
let committee_members = self.select_committee(stakers);
137147

138148
// Prepare the data needed for proposer selection.

0 commit comments

Comments
 (0)