4
4
use {
5
5
arrayref:: array_ref,
6
6
borsh:: BorshDeserialize ,
7
- solana_program:: {
8
- program_error:: ProgramError , program_memory:: sol_memmove, program_pack:: Pack ,
9
- } ,
10
- std:: marker:: PhantomData ,
7
+ bytemuck:: Pod ,
8
+ solana_program:: { program_error:: ProgramError , program_memory:: sol_memmove} ,
9
+ std:: mem,
11
10
} ;
12
11
13
12
/// Contains easy to use utilities for a big vector of Borsh-compatible types,
@@ -32,7 +31,7 @@ impl<'data> BigVec<'data> {
32
31
}
33
32
34
33
/// Retain all elements that match the provided function, discard all others
35
- pub fn retain < T : Pack , F : Fn ( & [ u8 ] ) -> bool > (
34
+ pub fn retain < T : Pod , F : Fn ( & [ u8 ] ) -> bool > (
36
35
& mut self ,
37
36
predicate : F ,
38
37
) -> Result < ( ) , ProgramError > {
@@ -42,12 +41,12 @@ impl<'data> BigVec<'data> {
42
41
43
42
let data_start_index = VEC_SIZE_BYTES ;
44
43
let data_end_index =
45
- data_start_index. saturating_add ( ( vec_len as usize ) . saturating_mul ( T :: LEN ) ) ;
46
- for start_index in ( data_start_index..data_end_index) . step_by ( T :: LEN ) {
47
- let end_index = start_index + T :: LEN ;
44
+ data_start_index. saturating_add ( ( vec_len as usize ) . saturating_mul ( mem :: size_of :: < T > ( ) ) ) ;
45
+ for start_index in ( data_start_index..data_end_index) . step_by ( mem :: size_of :: < T > ( ) ) {
46
+ let end_index = start_index + mem :: size_of :: < T > ( ) ;
48
47
let slice = & self . data [ start_index..end_index] ;
49
48
if !predicate ( slice) {
50
- let gap = removals_found * T :: LEN ;
49
+ let gap = removals_found * mem :: size_of :: < T > ( ) ;
51
50
if removals_found > 0 {
52
51
// In case the compute budget is ever bumped up, allowing us
53
52
// to use this safe code instead:
@@ -68,7 +67,7 @@ impl<'data> BigVec<'data> {
68
67
69
68
// final memmove
70
69
if removals_found > 0 {
71
- let gap = removals_found * T :: LEN ;
70
+ let gap = removals_found * mem :: size_of :: < T > ( ) ;
72
71
// In case the compute budget is ever bumped up, allowing us
73
72
// to use this safe code instead:
74
73
//self.data.copy_within(dst_start_index + gap..data_end_index, dst_start_index);
@@ -88,11 +87,11 @@ impl<'data> BigVec<'data> {
88
87
}
89
88
90
89
/// Extracts a slice of the data types
91
- pub fn deserialize_mut_slice < T : Pack > (
90
+ pub fn deserialize_mut_slice < T : Pod > (
92
91
& mut self ,
93
92
skip : usize ,
94
93
len : usize ,
95
- ) -> Result < Vec < & ' data mut T > , ProgramError > {
94
+ ) -> Result < & mut [ T ] , ProgramError > {
96
95
let vec_len = self . len ( ) ;
97
96
let last_item_index = skip
98
97
. checked_add ( len)
@@ -101,66 +100,60 @@ impl<'data> BigVec<'data> {
101
100
return Err ( ProgramError :: AccountDataTooSmall ) ;
102
101
}
103
102
104
- let start_index = VEC_SIZE_BYTES . saturating_add ( skip. saturating_mul ( T :: LEN ) ) ;
105
- let end_index = start_index. saturating_add ( len. saturating_mul ( T :: LEN ) ) ;
106
- let mut deserialized = vec ! [ ] ;
107
- for slice in self . data [ start_index..end_index] . chunks_exact_mut ( T :: LEN ) {
108
- deserialized. push ( unsafe { & mut * ( slice. as_ptr ( ) as * mut T ) } ) ;
103
+ let start_index = VEC_SIZE_BYTES . saturating_add ( skip. saturating_mul ( mem:: size_of :: < T > ( ) ) ) ;
104
+ let end_index = start_index. saturating_add ( len. saturating_mul ( mem:: size_of :: < T > ( ) ) ) ;
105
+ bytemuck:: try_cast_slice_mut ( & mut self . data [ start_index..end_index] )
106
+ . map_err ( |_| ProgramError :: InvalidAccountData )
107
+ }
108
+
109
+ /// Extracts a slice of the data types
110
+ pub fn deserialize_slice < T : Pod > ( & self , skip : usize , len : usize ) -> Result < & [ T ] , ProgramError > {
111
+ let vec_len = self . len ( ) ;
112
+ let last_item_index = skip
113
+ . checked_add ( len)
114
+ . ok_or ( ProgramError :: AccountDataTooSmall ) ?;
115
+ if last_item_index > vec_len as usize {
116
+ return Err ( ProgramError :: AccountDataTooSmall ) ;
109
117
}
110
- Ok ( deserialized)
118
+
119
+ let start_index = VEC_SIZE_BYTES . saturating_add ( skip. saturating_mul ( mem:: size_of :: < T > ( ) ) ) ;
120
+ let end_index = start_index. saturating_add ( len. saturating_mul ( mem:: size_of :: < T > ( ) ) ) ;
121
+ bytemuck:: try_cast_slice ( & self . data [ start_index..end_index] )
122
+ . map_err ( |_| ProgramError :: InvalidAccountData )
111
123
}
112
124
113
125
/// Add new element to the end
114
- pub fn push < T : Pack > ( & mut self , element : T ) -> Result < ( ) , ProgramError > {
126
+ pub fn push < T : Pod > ( & mut self , element : T ) -> Result < ( ) , ProgramError > {
115
127
let mut vec_len_ref = & mut self . data [ 0 ..VEC_SIZE_BYTES ] ;
116
128
let mut vec_len = u32:: try_from_slice ( vec_len_ref) ?;
117
129
118
- let start_index = VEC_SIZE_BYTES + vec_len as usize * T :: LEN ;
119
- let end_index = start_index + T :: LEN ;
130
+ let start_index = VEC_SIZE_BYTES + vec_len as usize * mem :: size_of :: < T > ( ) ;
131
+ let end_index = start_index + mem :: size_of :: < T > ( ) ;
120
132
121
133
vec_len += 1 ;
122
134
borsh:: to_writer ( & mut vec_len_ref, & vec_len) ?;
123
135
124
136
if self . data . len ( ) < end_index {
125
137
return Err ( ProgramError :: AccountDataTooSmall ) ;
126
138
}
127
- let element_ref = & mut self . data [ start_index..start_index + T :: LEN ] ;
128
- element. pack_into_slice ( element_ref) ;
139
+ let element_ref = bytemuck:: try_from_bytes_mut (
140
+ & mut self . data [ start_index..start_index + mem:: size_of :: < T > ( ) ] ,
141
+ )
142
+ . map_err ( |_| ProgramError :: InvalidAccountData ) ?;
143
+ * element_ref = element;
129
144
Ok ( ( ) )
130
145
}
131
146
132
- /// Get an iterator for the type provided
133
- pub fn iter < ' vec , T : Pack > ( & ' vec self ) -> Iter < ' data , ' vec , T > {
134
- Iter {
135
- len : self . len ( ) as usize ,
136
- current : 0 ,
137
- current_index : VEC_SIZE_BYTES ,
138
- inner : self ,
139
- phantom : PhantomData ,
140
- }
141
- }
142
-
143
- /// Get a mutable iterator for the type provided
144
- pub fn iter_mut < ' vec , T : Pack > ( & ' vec mut self ) -> IterMut < ' data , ' vec , T > {
145
- IterMut {
146
- len : self . len ( ) as usize ,
147
- current : 0 ,
148
- current_index : VEC_SIZE_BYTES ,
149
- inner : self ,
150
- phantom : PhantomData ,
151
- }
152
- }
153
-
154
147
/// Find matching data in the array
155
- pub fn find < T : Pack , F : Fn ( & [ u8 ] ) -> bool > ( & self , predicate : F ) -> Option < & T > {
148
+ pub fn find < T : Pod , F : Fn ( & [ u8 ] ) -> bool > ( & self , predicate : F ) -> Option < & T > {
156
149
let len = self . len ( ) as usize ;
157
150
let mut current = 0 ;
158
151
let mut current_index = VEC_SIZE_BYTES ;
159
152
while current != len {
160
- let end_index = current_index + T :: LEN ;
153
+ let end_index = current_index + mem :: size_of :: < T > ( ) ;
161
154
let current_slice = & self . data [ current_index..end_index] ;
162
155
if predicate ( current_slice) {
163
- return Some ( unsafe { & * ( current_slice. as_ptr ( ) as * const T ) } ) ;
156
+ return Some ( bytemuck :: from_bytes ( current_slice) ) ;
164
157
}
165
158
current_index = end_index;
166
159
current += 1 ;
@@ -169,15 +162,17 @@ impl<'data> BigVec<'data> {
169
162
}
170
163
171
164
/// Find matching data in the array
172
- pub fn find_mut < T : Pack , F : Fn ( & [ u8 ] ) -> bool > ( & mut self , predicate : F ) -> Option < & mut T > {
165
+ pub fn find_mut < T : Pod , F : Fn ( & [ u8 ] ) -> bool > ( & mut self , predicate : F ) -> Option < & mut T > {
173
166
let len = self . len ( ) as usize ;
174
167
let mut current = 0 ;
175
168
let mut current_index = VEC_SIZE_BYTES ;
176
169
while current != len {
177
- let end_index = current_index + T :: LEN ;
170
+ let end_index = current_index + mem :: size_of :: < T > ( ) ;
178
171
let current_slice = & self . data [ current_index..end_index] ;
179
172
if predicate ( current_slice) {
180
- return Some ( unsafe { & mut * ( current_slice. as_ptr ( ) as * mut T ) } ) ;
173
+ return Some ( bytemuck:: from_bytes_mut (
174
+ & mut self . data [ current_index..end_index] ,
175
+ ) ) ;
181
176
}
182
177
current_index = end_index;
183
178
current += 1 ;
@@ -186,84 +181,16 @@ impl<'data> BigVec<'data> {
186
181
}
187
182
}
188
183
189
- /// Iterator wrapper over a BigVec
190
- pub struct Iter < ' data , ' vec , T > {
191
- len : usize ,
192
- current : usize ,
193
- current_index : usize ,
194
- inner : & ' vec BigVec < ' data > ,
195
- phantom : PhantomData < T > ,
196
- }
197
-
198
- impl < ' data , ' vec , T : Pack + ' data > Iterator for Iter < ' data , ' vec , T > {
199
- type Item = & ' data T ;
200
-
201
- fn next ( & mut self ) -> Option < Self :: Item > {
202
- if self . current == self . len {
203
- None
204
- } else {
205
- let end_index = self . current_index + T :: LEN ;
206
- let value = Some ( unsafe {
207
- & * ( self . inner . data [ self . current_index ..end_index] . as_ptr ( ) as * const T )
208
- } ) ;
209
- self . current += 1 ;
210
- self . current_index = end_index;
211
- value
212
- }
213
- }
214
- }
215
-
216
- /// Iterator wrapper over a BigVec
217
- pub struct IterMut < ' data , ' vec , T > {
218
- len : usize ,
219
- current : usize ,
220
- current_index : usize ,
221
- inner : & ' vec mut BigVec < ' data > ,
222
- phantom : PhantomData < T > ,
223
- }
224
-
225
- impl < ' data , ' vec , T : Pack + ' data > Iterator for IterMut < ' data , ' vec , T > {
226
- type Item = & ' data mut T ;
227
-
228
- fn next ( & mut self ) -> Option < Self :: Item > {
229
- if self . current == self . len {
230
- None
231
- } else {
232
- let end_index = self . current_index + T :: LEN ;
233
- let value = Some ( unsafe {
234
- & mut * ( self . inner . data [ self . current_index ..end_index] . as_ptr ( ) as * mut T )
235
- } ) ;
236
- self . current += 1 ;
237
- self . current_index = end_index;
238
- value
239
- }
240
- }
241
- }
242
-
243
184
#[ cfg( test) ]
244
185
mod tests {
245
- use { super :: * , solana_program :: program_pack :: Sealed } ;
186
+ use { super :: * , bytemuck :: Zeroable } ;
246
187
247
- #[ derive( Debug , PartialEq ) ]
188
+ #[ repr( C ) ]
189
+ #[ derive( Debug , Copy , Clone , PartialEq , Pod , Zeroable ) ]
248
190
struct TestStruct {
249
191
value : [ u8 ; 8 ] ,
250
192
}
251
193
252
- impl Sealed for TestStruct { }
253
-
254
- impl Pack for TestStruct {
255
- const LEN : usize = 8 ;
256
- fn pack_into_slice ( & self , data : & mut [ u8 ] ) {
257
- let mut data = data;
258
- borsh:: to_writer ( & mut data, & self . value ) . unwrap ( ) ;
259
- }
260
- fn unpack_from_slice ( src : & [ u8 ] ) -> Result < Self , ProgramError > {
261
- Ok ( TestStruct {
262
- value : src. try_into ( ) . unwrap ( ) ,
263
- } )
264
- }
265
- }
266
-
267
194
impl TestStruct {
268
195
fn new ( value : u8 ) -> Self {
269
196
let value = [ value, 0 , 0 , 0 , 0 , 0 , 0 , 0 ] ;
@@ -281,7 +208,9 @@ mod tests {
281
208
282
209
fn check_big_vec_eq ( big_vec : & BigVec , slice : & [ u8 ] ) {
283
210
assert ! ( big_vec
284
- . iter:: <TestStruct >( )
211
+ . deserialize_slice:: <TestStruct >( 0 , big_vec. len( ) as usize )
212
+ . unwrap( )
213
+ . iter( )
285
214
. map( |x| & x. value[ 0 ] )
286
215
. zip( slice. iter( ) )
287
216
. all( |( a, b) | a == b) ) ;
0 commit comments