Skip to content
This repository was archived by the owner on Mar 11, 2025. It is now read-only.

Commit 38212ea

Browse files
authored
stake-pool: Remove unsafe pointer casts via Pod types (#5185)
* stake-pool: Force BigVec to work with Pod types * Remove all unsafe through an enum wrapper struct * Also fix the CLI
1 parent 908ea3f commit 38212ea

File tree

12 files changed

+209
-204
lines changed

12 files changed

+209
-204
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

stake-pool/cli/src/output.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@ use {
33
solana_cli_output::{QuietDisplay, VerboseDisplay},
44
solana_sdk::native_token::Sol,
55
solana_sdk::{pubkey::Pubkey, stake::state::Lockup},
6-
spl_stake_pool::state::{Fee, StakePool, StakeStatus, ValidatorList, ValidatorStakeInfo},
6+
spl_stake_pool::state::{
7+
Fee, PodStakeStatus, StakePool, StakeStatus, ValidatorList, ValidatorStakeInfo,
8+
},
79
std::fmt::{Display, Formatter, Result, Write},
810
};
911

@@ -384,8 +386,9 @@ impl From<ValidatorStakeInfo> for CliStakePoolValidator {
384386
}
385387
}
386388

387-
impl From<StakeStatus> for CliStakePoolValidatorStakeStatus {
388-
fn from(s: StakeStatus) -> CliStakePoolValidatorStakeStatus {
389+
impl From<PodStakeStatus> for CliStakePoolValidatorStakeStatus {
390+
fn from(s: PodStakeStatus) -> CliStakePoolValidatorStakeStatus {
391+
let s = StakeStatus::try_from(s).unwrap();
389392
match s {
390393
StakeStatus::Active => CliStakePoolValidatorStakeStatus::Active,
391394
StakeStatus::DeactivatingTransient => {

stake-pool/program/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ test-sbf = []
1414
[dependencies]
1515
arrayref = "0.3.7"
1616
borsh = "0.10"
17+
bytemuck = "1.13"
1718
num-derive = "0.4"
1819
num-traits = "0.2"
1920
num_enum = "0.7.0"

stake-pool/program/src/big_vec.rs

Lines changed: 52 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,9 @@
44
use {
55
arrayref::array_ref,
66
borsh::BorshDeserialize,
7-
solana_program::{
8-
program_error::ProgramError, program_memory::sol_memmove, program_pack::Pack,
9-
},
10-
std::marker::PhantomData,
7+
bytemuck::Pod,
8+
solana_program::{program_error::ProgramError, program_memory::sol_memmove},
9+
std::mem,
1110
};
1211

1312
/// Contains easy to use utilities for a big vector of Borsh-compatible types,
@@ -32,7 +31,7 @@ impl<'data> BigVec<'data> {
3231
}
3332

3433
/// Retain all elements that match the provided function, discard all others
35-
pub fn retain<T: Pack, F: Fn(&[u8]) -> bool>(
34+
pub fn retain<T: Pod, F: Fn(&[u8]) -> bool>(
3635
&mut self,
3736
predicate: F,
3837
) -> Result<(), ProgramError> {
@@ -42,12 +41,12 @@ impl<'data> BigVec<'data> {
4241

4342
let data_start_index = VEC_SIZE_BYTES;
4443
let data_end_index =
45-
data_start_index.saturating_add((vec_len as usize).saturating_mul(T::LEN));
46-
for start_index in (data_start_index..data_end_index).step_by(T::LEN) {
47-
let end_index = start_index + T::LEN;
44+
data_start_index.saturating_add((vec_len as usize).saturating_mul(mem::size_of::<T>()));
45+
for start_index in (data_start_index..data_end_index).step_by(mem::size_of::<T>()) {
46+
let end_index = start_index + mem::size_of::<T>();
4847
let slice = &self.data[start_index..end_index];
4948
if !predicate(slice) {
50-
let gap = removals_found * T::LEN;
49+
let gap = removals_found * mem::size_of::<T>();
5150
if removals_found > 0 {
5251
// In case the compute budget is ever bumped up, allowing us
5352
// to use this safe code instead:
@@ -68,7 +67,7 @@ impl<'data> BigVec<'data> {
6867

6968
// final memmove
7069
if removals_found > 0 {
71-
let gap = removals_found * T::LEN;
70+
let gap = removals_found * mem::size_of::<T>();
7271
// In case the compute budget is ever bumped up, allowing us
7372
// to use this safe code instead:
7473
//self.data.copy_within(dst_start_index + gap..data_end_index, dst_start_index);
@@ -88,11 +87,11 @@ impl<'data> BigVec<'data> {
8887
}
8988

9089
/// Extracts a slice of the data types
91-
pub fn deserialize_mut_slice<T: Pack>(
90+
pub fn deserialize_mut_slice<T: Pod>(
9291
&mut self,
9392
skip: usize,
9493
len: usize,
95-
) -> Result<Vec<&'data mut T>, ProgramError> {
94+
) -> Result<&mut [T], ProgramError> {
9695
let vec_len = self.len();
9796
let last_item_index = skip
9897
.checked_add(len)
@@ -101,66 +100,60 @@ impl<'data> BigVec<'data> {
101100
return Err(ProgramError::AccountDataTooSmall);
102101
}
103102

104-
let start_index = VEC_SIZE_BYTES.saturating_add(skip.saturating_mul(T::LEN));
105-
let end_index = start_index.saturating_add(len.saturating_mul(T::LEN));
106-
let mut deserialized = vec![];
107-
for slice in self.data[start_index..end_index].chunks_exact_mut(T::LEN) {
108-
deserialized.push(unsafe { &mut *(slice.as_ptr() as *mut T) });
103+
let start_index = VEC_SIZE_BYTES.saturating_add(skip.saturating_mul(mem::size_of::<T>()));
104+
let end_index = start_index.saturating_add(len.saturating_mul(mem::size_of::<T>()));
105+
bytemuck::try_cast_slice_mut(&mut self.data[start_index..end_index])
106+
.map_err(|_| ProgramError::InvalidAccountData)
107+
}
108+
109+
/// Extracts a slice of the data types
110+
pub fn deserialize_slice<T: Pod>(&self, skip: usize, len: usize) -> Result<&[T], ProgramError> {
111+
let vec_len = self.len();
112+
let last_item_index = skip
113+
.checked_add(len)
114+
.ok_or(ProgramError::AccountDataTooSmall)?;
115+
if last_item_index > vec_len as usize {
116+
return Err(ProgramError::AccountDataTooSmall);
109117
}
110-
Ok(deserialized)
118+
119+
let start_index = VEC_SIZE_BYTES.saturating_add(skip.saturating_mul(mem::size_of::<T>()));
120+
let end_index = start_index.saturating_add(len.saturating_mul(mem::size_of::<T>()));
121+
bytemuck::try_cast_slice(&self.data[start_index..end_index])
122+
.map_err(|_| ProgramError::InvalidAccountData)
111123
}
112124

113125
/// Add new element to the end
114-
pub fn push<T: Pack>(&mut self, element: T) -> Result<(), ProgramError> {
126+
pub fn push<T: Pod>(&mut self, element: T) -> Result<(), ProgramError> {
115127
let mut vec_len_ref = &mut self.data[0..VEC_SIZE_BYTES];
116128
let mut vec_len = u32::try_from_slice(vec_len_ref)?;
117129

118-
let start_index = VEC_SIZE_BYTES + vec_len as usize * T::LEN;
119-
let end_index = start_index + T::LEN;
130+
let start_index = VEC_SIZE_BYTES + vec_len as usize * mem::size_of::<T>();
131+
let end_index = start_index + mem::size_of::<T>();
120132

121133
vec_len += 1;
122134
borsh::to_writer(&mut vec_len_ref, &vec_len)?;
123135

124136
if self.data.len() < end_index {
125137
return Err(ProgramError::AccountDataTooSmall);
126138
}
127-
let element_ref = &mut self.data[start_index..start_index + T::LEN];
128-
element.pack_into_slice(element_ref);
139+
let element_ref = bytemuck::try_from_bytes_mut(
140+
&mut self.data[start_index..start_index + mem::size_of::<T>()],
141+
)
142+
.map_err(|_| ProgramError::InvalidAccountData)?;
143+
*element_ref = element;
129144
Ok(())
130145
}
131146

132-
/// Get an iterator for the type provided
133-
pub fn iter<'vec, T: Pack>(&'vec self) -> Iter<'data, 'vec, T> {
134-
Iter {
135-
len: self.len() as usize,
136-
current: 0,
137-
current_index: VEC_SIZE_BYTES,
138-
inner: self,
139-
phantom: PhantomData,
140-
}
141-
}
142-
143-
/// Get a mutable iterator for the type provided
144-
pub fn iter_mut<'vec, T: Pack>(&'vec mut self) -> IterMut<'data, 'vec, T> {
145-
IterMut {
146-
len: self.len() as usize,
147-
current: 0,
148-
current_index: VEC_SIZE_BYTES,
149-
inner: self,
150-
phantom: PhantomData,
151-
}
152-
}
153-
154147
/// Find matching data in the array
155-
pub fn find<T: Pack, F: Fn(&[u8]) -> bool>(&self, predicate: F) -> Option<&T> {
148+
pub fn find<T: Pod, F: Fn(&[u8]) -> bool>(&self, predicate: F) -> Option<&T> {
156149
let len = self.len() as usize;
157150
let mut current = 0;
158151
let mut current_index = VEC_SIZE_BYTES;
159152
while current != len {
160-
let end_index = current_index + T::LEN;
153+
let end_index = current_index + mem::size_of::<T>();
161154
let current_slice = &self.data[current_index..end_index];
162155
if predicate(current_slice) {
163-
return Some(unsafe { &*(current_slice.as_ptr() as *const T) });
156+
return Some(bytemuck::from_bytes(current_slice));
164157
}
165158
current_index = end_index;
166159
current += 1;
@@ -169,15 +162,17 @@ impl<'data> BigVec<'data> {
169162
}
170163

171164
/// Find matching data in the array
172-
pub fn find_mut<T: Pack, F: Fn(&[u8]) -> bool>(&mut self, predicate: F) -> Option<&mut T> {
165+
pub fn find_mut<T: Pod, F: Fn(&[u8]) -> bool>(&mut self, predicate: F) -> Option<&mut T> {
173166
let len = self.len() as usize;
174167
let mut current = 0;
175168
let mut current_index = VEC_SIZE_BYTES;
176169
while current != len {
177-
let end_index = current_index + T::LEN;
170+
let end_index = current_index + mem::size_of::<T>();
178171
let current_slice = &self.data[current_index..end_index];
179172
if predicate(current_slice) {
180-
return Some(unsafe { &mut *(current_slice.as_ptr() as *mut T) });
173+
return Some(bytemuck::from_bytes_mut(
174+
&mut self.data[current_index..end_index],
175+
));
181176
}
182177
current_index = end_index;
183178
current += 1;
@@ -186,84 +181,16 @@ impl<'data> BigVec<'data> {
186181
}
187182
}
188183

189-
/// Iterator wrapper over a BigVec
190-
pub struct Iter<'data, 'vec, T> {
191-
len: usize,
192-
current: usize,
193-
current_index: usize,
194-
inner: &'vec BigVec<'data>,
195-
phantom: PhantomData<T>,
196-
}
197-
198-
impl<'data, 'vec, T: Pack + 'data> Iterator for Iter<'data, 'vec, T> {
199-
type Item = &'data T;
200-
201-
fn next(&mut self) -> Option<Self::Item> {
202-
if self.current == self.len {
203-
None
204-
} else {
205-
let end_index = self.current_index + T::LEN;
206-
let value = Some(unsafe {
207-
&*(self.inner.data[self.current_index..end_index].as_ptr() as *const T)
208-
});
209-
self.current += 1;
210-
self.current_index = end_index;
211-
value
212-
}
213-
}
214-
}
215-
216-
/// Iterator wrapper over a BigVec
217-
pub struct IterMut<'data, 'vec, T> {
218-
len: usize,
219-
current: usize,
220-
current_index: usize,
221-
inner: &'vec mut BigVec<'data>,
222-
phantom: PhantomData<T>,
223-
}
224-
225-
impl<'data, 'vec, T: Pack + 'data> Iterator for IterMut<'data, 'vec, T> {
226-
type Item = &'data mut T;
227-
228-
fn next(&mut self) -> Option<Self::Item> {
229-
if self.current == self.len {
230-
None
231-
} else {
232-
let end_index = self.current_index + T::LEN;
233-
let value = Some(unsafe {
234-
&mut *(self.inner.data[self.current_index..end_index].as_ptr() as *mut T)
235-
});
236-
self.current += 1;
237-
self.current_index = end_index;
238-
value
239-
}
240-
}
241-
}
242-
243184
#[cfg(test)]
244185
mod tests {
245-
use {super::*, solana_program::program_pack::Sealed};
186+
use {super::*, bytemuck::Zeroable};
246187

247-
#[derive(Debug, PartialEq)]
188+
#[repr(C)]
189+
#[derive(Debug, Copy, Clone, PartialEq, Pod, Zeroable)]
248190
struct TestStruct {
249191
value: [u8; 8],
250192
}
251193

252-
impl Sealed for TestStruct {}
253-
254-
impl Pack for TestStruct {
255-
const LEN: usize = 8;
256-
fn pack_into_slice(&self, data: &mut [u8]) {
257-
let mut data = data;
258-
borsh::to_writer(&mut data, &self.value).unwrap();
259-
}
260-
fn unpack_from_slice(src: &[u8]) -> Result<Self, ProgramError> {
261-
Ok(TestStruct {
262-
value: src.try_into().unwrap(),
263-
})
264-
}
265-
}
266-
267194
impl TestStruct {
268195
fn new(value: u8) -> Self {
269196
let value = [value, 0, 0, 0, 0, 0, 0, 0];
@@ -281,7 +208,9 @@ mod tests {
281208

282209
fn check_big_vec_eq(big_vec: &BigVec, slice: &[u8]) {
283210
assert!(big_vec
284-
.iter::<TestStruct>()
211+
.deserialize_slice::<TestStruct>(0, big_vec.len() as usize)
212+
.unwrap()
213+
.iter()
285214
.map(|x| &x.value[0])
286215
.zip(slice.iter())
287216
.all(|(a, b)| a == b));

0 commit comments

Comments
 (0)