Skip to content

Commit 5196120

Browse files
committed
apollo_staking: align the mock staking contract interface with the real contract
1 parent c04b632 commit 5196120

File tree

9 files changed

+3711
-2746
lines changed

9 files changed

+3711
-2746
lines changed

crates/apollo_staking/src/committee_provider.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use crate::contract_types::RetdataDeserializationError;
1717
pub type Committee = Vec<Staker>;
1818

1919
#[cfg_attr(test, derive(Clone))]
20-
#[derive(Debug, PartialEq, Eq)]
20+
#[derive(Debug, PartialEq, Eq, Hash)]
2121
pub struct Staker {
2222
// A contract address of the staker, to which rewards are sent.
2323
pub address: ContractAddress,

crates/apollo_staking/src/contract_types.rs

Lines changed: 138 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -9,94 +9,178 @@ use crate::committee_provider::Staker;
99
pub(crate) const GET_STAKERS_ENTRY_POINT: &str = "get_stakers";
1010
pub(crate) const EPOCH_LENGTH: u64 = 100; // Number of heights in an epoch.
1111

12+
/// Conversion from an [`Iterator`].
13+
///
14+
/// By implementing `TryFromIterator` for a type, you define how it will be
15+
/// created from an iterator.
16+
///
17+
/// Used in this context to parse Cairo1 types returned by a contract as a vector of Felts.
18+
pub trait TryFromIterator<Felt>: Sized {
19+
type Error;
20+
21+
fn try_from_iter<T: Iterator<Item = Felt>>(iter: &mut T) -> Result<Self, Self::Error>;
22+
}
23+
1224
// Represents a Cairo1 `Array` containing elements that can be deserialized to `T`.
1325
// `T` must implement `TryFrom<[Felt; N]>`, where `N` is the size of `T`'s Cairo equivalent.
1426
#[derive(Debug, PartialEq, Eq)]
15-
struct ArrayRetdata<const N: usize, T>(Vec<T>);
27+
struct ArrayRetdata<T>(Vec<T>);
28+
29+
#[derive(Debug, PartialEq, Eq)]
30+
pub(crate) struct ContractStaker {
31+
pub(crate) contract_address: ContractAddress,
32+
pub(crate) staking_power: StakingWeight,
33+
pub(crate) public_key: Option<Felt>,
34+
}
1635

1736
#[derive(Debug, Error)]
1837
pub enum RetdataDeserializationError {
1938
#[error("Failed to convert Felt to ContractAddress: {address}")]
2039
ContractAddressConversionError { address: Felt },
2140
#[error("Failed to convert Felt to u128: {felt}")]
2241
U128ConversionError { felt: Felt },
23-
#[error(
24-
"Invalid retdata length: expected 1 Felt followed by {num_structs} (number of structs) *
25-
{struct_size} (number of Felts per struct), but received {length} Felts."
26-
)]
27-
InvalidArrayLength { length: usize, num_structs: usize, struct_size: usize },
42+
#[error("Failed to convert Felt to usize: {felt}")]
43+
USizeConversionError { felt: Felt },
44+
#[error("Invalid object length: {message}.")]
45+
InvalidObjectLength { message: String },
46+
#[error("Unexpected enum variant: {variant}")]
47+
UnexpectedEnumVariant { variant: usize },
2848
}
2949

30-
impl Staker {
31-
pub const CAIRO_OBJECT_NUM_FELTS: usize = 3;
32-
50+
impl ContractStaker {
3351
pub fn from_retdata_many(retdata: Retdata) -> Result<Vec<Self>, RetdataDeserializationError> {
34-
Ok(ArrayRetdata::<{ Self::CAIRO_OBJECT_NUM_FELTS }, Staker>::try_from(retdata)?.0)
52+
Ok(ArrayRetdata::try_from(retdata)?.0)
3553
}
3654
}
3755

38-
impl TryFrom<[Felt; Self::CAIRO_OBJECT_NUM_FELTS]> for Staker {
56+
impl TryFromIterator<Felt> for ContractStaker {
3957
type Error = RetdataDeserializationError;
4058

41-
fn try_from(felts: [Felt; Self::CAIRO_OBJECT_NUM_FELTS]) -> Result<Self, Self::Error> {
42-
let [address, weight, public_key] = felts;
43-
let address = ContractAddress::try_from(address)
44-
.map_err(|_| RetdataDeserializationError::ContractAddressConversionError { address })?;
45-
let weight = StakingWeight(
46-
u128::try_from(weight)
47-
.map_err(|_| RetdataDeserializationError::U128ConversionError { felt: weight })?,
48-
);
49-
Ok(Self { address, weight, public_key })
59+
// Parses a single `ContractStaker` from a stream of Felts.
60+
//
61+
// The iterator is expected to yield the following values, in order:
62+
// 1. Contract Address (1 Felt)
63+
// 2. Staking Power (1 Felt)
64+
// 3. Public Key option variant (1 Felt):
65+
// - 0 => Some
66+
// - 1 => None
67+
// 4. Public Key (1 Felt), only if the option variant is `Some`
68+
fn try_from_iter<T: Iterator<Item = Felt>>(iter: &mut T) -> Result<Self, Self::Error> {
69+
// Parse contract address.
70+
let raw_address = iter.next().ok_or(RetdataDeserializationError::InvalidObjectLength {
71+
message: "missing contract address.".to_string(),
72+
})?;
73+
let contract_address = ContractAddress::try_from(raw_address).map_err(|_| {
74+
RetdataDeserializationError::ContractAddressConversionError { address: raw_address }
75+
})?;
76+
77+
// Parse staking power.
78+
let raw_staking_power =
79+
iter.next().ok_or(RetdataDeserializationError::InvalidObjectLength {
80+
message: "missing staking power.".to_string(),
81+
})?;
82+
let staking_power = StakingWeight(u128::try_from(raw_staking_power).map_err(|_| {
83+
RetdataDeserializationError::U128ConversionError { felt: raw_staking_power }
84+
})?);
85+
86+
// Parse public key.
87+
let raw_option_variant =
88+
iter.next().ok_or(RetdataDeserializationError::InvalidObjectLength {
89+
message: "missing public key option variant.".to_string(),
90+
})?;
91+
let option_variant = usize::try_from(raw_option_variant).map_err(|_| {
92+
RetdataDeserializationError::USizeConversionError { felt: raw_option_variant }
93+
})?;
94+
let public_key = match option_variant {
95+
1 => None,
96+
0 => {
97+
let public_key =
98+
iter.next().ok_or(RetdataDeserializationError::InvalidObjectLength {
99+
message: "missing public key.".to_string(),
100+
})?;
101+
Some(public_key)
102+
}
103+
_ => {
104+
return Err(RetdataDeserializationError::UnexpectedEnumVariant {
105+
variant: option_variant,
106+
});
107+
}
108+
};
109+
110+
Ok(Self { contract_address, staking_power, public_key })
50111
}
51112
}
52113

53-
#[cfg(test)]
54-
impl From<&Staker> for Vec<Felt> {
55-
fn from(staker: &Staker) -> Self {
56-
vec![Felt::from(staker.address), Felt::from(staker.weight.0), staker.public_key]
57-
}
58-
}
59-
60-
impl<const N: usize, T> TryFrom<Retdata> for ArrayRetdata<N, T>
114+
impl<T> TryFrom<Retdata> for ArrayRetdata<T>
61115
where
62-
T: TryFrom<[Felt; N], Error = RetdataDeserializationError>,
116+
T: TryFromIterator<Felt, Error = RetdataDeserializationError>,
63117
{
64118
type Error = RetdataDeserializationError;
65119

66120
fn try_from(retdata: Retdata) -> Result<Self, Self::Error> {
67-
let data = retdata.0;
121+
let mut iter = retdata.0.into_iter();
68122

69123
// The first Felt in the Retdata must be the number of structs in the array.
70-
if data.is_empty() {
71-
return Err(RetdataDeserializationError::InvalidArrayLength {
72-
length: data.len(),
73-
num_structs: 0,
74-
struct_size: N,
124+
let raw_num_items =
125+
iter.next().ok_or(RetdataDeserializationError::InvalidObjectLength {
126+
message: "missing number of items in an array.".to_string(),
127+
})?;
128+
129+
let num_items = usize::try_from(raw_num_items).map_err(|_| {
130+
RetdataDeserializationError::USizeConversionError { felt: raw_num_items }
131+
})?;
132+
133+
let mut result = Vec::with_capacity(num_items);
134+
for _ in 0..num_items {
135+
let item = T::try_from_iter(&mut iter)?;
136+
result.push(item);
137+
}
138+
139+
if iter.next().is_some() {
140+
return Err(RetdataDeserializationError::InvalidObjectLength {
141+
message: "Unconsumed elements found in retdata.".to_string(),
75142
});
76143
}
77144

78-
// Split the remaining Felts into chunks of N Felts, each is a struct in the array.
79-
let data_chunks = data[1..].chunks_exact(N);
145+
Ok(ArrayRetdata(result))
146+
}
147+
}
80148

81-
// Verify that the number of structs in the array matches the number of chunks.
82-
let num_structs = usize::try_from(data[0]).expect("num_structs should fit in usize.");
83-
if data_chunks.len() != num_structs || !data_chunks.remainder().is_empty() {
84-
return Err(RetdataDeserializationError::InvalidArrayLength {
85-
length: data.len(),
86-
num_structs,
87-
struct_size: N,
88-
});
149+
impl From<&ContractStaker> for Staker {
150+
/// # Panics
151+
///
152+
/// Panics if `public_key` is `None`.
153+
fn from(contract_staker: &ContractStaker) -> Self {
154+
Self {
155+
address: contract_staker.contract_address,
156+
weight: contract_staker.staking_power,
157+
public_key: contract_staker.public_key.expect("public key is required."),
89158
}
159+
}
160+
}
90161

91-
// Convert each chunk to T.
92-
let result = data_chunks
93-
.map(|chunk| {
94-
T::try_from(
95-
chunk.try_into().unwrap_or_else(|_| panic!("chunk size must be N: {N}.")),
96-
)
97-
})
98-
.collect::<Result<Vec<_>, _>>()?;
162+
#[cfg(test)]
163+
impl From<&ContractStaker> for Vec<Felt> {
164+
fn from(staker: &ContractStaker) -> Self {
165+
let public_key = match staker.public_key {
166+
Some(public_key) => vec![Felt::ZERO, public_key],
167+
None => vec![Felt::ONE],
168+
};
169+
[
170+
[Felt::from(staker.contract_address), Felt::from(staker.staking_power.0)].as_slice(),
171+
public_key.as_slice(),
172+
]
173+
.concat()
174+
}
175+
}
99176

100-
Ok(ArrayRetdata(result))
177+
#[cfg(test)]
178+
impl From<&Staker> for ContractStaker {
179+
fn from(staker: &Staker) -> Self {
180+
Self {
181+
contract_address: staker.address,
182+
staking_power: staker.weight,
183+
public_key: Some(staker.public_key),
184+
}
101185
}
102186
}

crates/apollo_staking/src/staking_manager.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ use crate::committee_provider::{
1919
ExecutionContext,
2020
Staker,
2121
};
22-
use crate::contract_types::{EPOCH_LENGTH, GET_STAKERS_ENTRY_POINT};
22+
use crate::contract_types::{ContractStaker, EPOCH_LENGTH, GET_STAKERS_ENTRY_POINT};
2323
use crate::utils::BlockRandomGenerator;
2424

2525
pub type StakerSet = Vec<Staker>;
@@ -132,7 +132,13 @@ impl StakingManager {
132132
Calldata(vec![Felt::from(epoch)].into()),
133133
)?;
134134

135-
let stakers = Staker::from_retdata_many(call_info.execution.retdata)?;
135+
let stakers: Vec<Staker> = ContractStaker::from_retdata_many(call_info.execution.retdata)?
136+
.into_iter()
137+
.filter_map(|contract_staker| {
138+
// Filter out stakers that don't have a public key.
139+
contract_staker.public_key.map(|_| Staker::from(&contract_staker))
140+
})
141+
.collect();
136142
let committee_members = self.select_committee(stakers);
137143

138144
// Prepare the data needed for proposer selection.

0 commit comments

Comments
 (0)