Skip to content

Commit ead316c

Browse files
authored
refactor: (correctly) mark an internal fn as unsafe (#708)
1 parent cc97621 commit ead316c

File tree

4 files changed

+81
-44
lines changed

4 files changed

+81
-44
lines changed

src/_macros.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,8 @@ macro_rules! build_table_column_slice_getter {
118118
($(#[$attr:meta])* => $column: ident, $name: ident, $cast: ty) => {
119119
$(#[$attr])*
120120
pub fn $name(&self) -> &[$cast] {
121-
$crate::sys::generate_slice(self.as_ref().$column, self.num_rows())
121+
// SAFETY: all array lengths are the number of rows in the table
122+
unsafe{$crate::sys::generate_slice(self.as_ref().$column, self.num_rows())}
122123
}
123124
};
124125
}
@@ -127,7 +128,8 @@ macro_rules! build_table_column_slice_mut_getter {
127128
($(#[$attr:meta])* => $column: ident, $name: ident, $cast: ty) => {
128129
$(#[$attr])*
129130
pub fn $name(&mut self) -> &mut [$cast] {
130-
$crate::sys::generate_slice_mut(self.as_ref().$column, self.num_rows())
131+
// SAFETY: all array lengths are the number of rows in the table
132+
unsafe{$crate::sys::generate_slice_mut(self.as_ref().$column, self.num_rows())}
131133
}
132134
};
133135
}

src/sys/mod.rs

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -189,22 +189,28 @@ pub fn tsk_ragged_column_access<
189189
.map(|(p, n)| unsafe { std::slice::from_raw_parts(p.cast::<O>(), n) })
190190
}
191191

192-
pub fn generate_slice<'a, L: Into<bindings::tsk_size_t>, I, O>(
192+
/// # SAFETY
193+
///
194+
/// * data must not be NULL
195+
/// * length must be a valid offset from data
196+
/// (ideally it comes from the tskit-c API)
197+
pub unsafe fn generate_slice<'a, L: Into<bindings::tsk_size_t>, I, O>(
193198
data: *const I,
194199
length: L,
195200
) -> &'a [O] {
196-
assert!(!data.is_null());
197-
// SAFETY: pointer is not null, length comes from C API
198-
unsafe { std::slice::from_raw_parts(data.cast::<O>(), length.into() as usize) }
201+
std::slice::from_raw_parts(data.cast::<O>(), length.into() as usize)
199202
}
200203

201-
pub fn generate_slice_mut<'a, L: Into<bindings::tsk_size_t>, I, O>(
204+
/// # SAFETY
205+
///
206+
/// * data must not be NULL
207+
/// * length must be a valid offset from data
208+
/// (ideally it comes from the tskit-c API)
209+
pub unsafe fn generate_slice_mut<'a, L: Into<bindings::tsk_size_t>, I, O>(
202210
data: *mut I,
203211
length: L,
204212
) -> &'a mut [O] {
205-
assert!(!data.is_null());
206-
// SAFETY: pointer is not null, length comes from C API
207-
unsafe { std::slice::from_raw_parts_mut(data.cast::<O>(), length.into() as usize) }
213+
std::slice::from_raw_parts_mut(data.cast::<O>(), length.into() as usize)
208214
}
209215

210216
pub fn get_tskit_error_message(code: i32) -> String {

src/sys/tree.rs

Lines changed: 59 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ impl<'treeseq> LLTree<'treeseq> {
5959
pub fn samples_array(&self) -> Result<&[super::newtypes::NodeId], TskitError> {
6060
err_if_not_tracking_samples!(
6161
self.flags,
62-
super::generate_slice(self.as_ll_ref().samples, self.num_samples())
62+
// SAFETY: num_samples is the correct value
63+
unsafe { super::generate_slice(self.as_ll_ref().samples, self.num_samples()) }
6364
)
6465
}
6566

@@ -182,43 +183,63 @@ impl<'treeseq> LLTree<'treeseq> {
182183

183184
pub fn sample_nodes(&self) -> &[NodeId] {
184185
assert!(!self.as_ptr().is_null());
185-
// SAFETY: self ptr is not null and the tree is initialized
186-
let num_samples =
187-
unsafe { bindings::tsk_treeseq_get_num_samples(self.as_ll_ref().tree_sequence) };
188-
super::generate_slice(self.as_ll_ref().samples, num_samples)
186+
unsafe {
187+
// SAFETY: self ptr is not null and the tree is initialized
188+
// num_samples is the correct array length
189+
let num_samples = bindings::tsk_treeseq_get_num_samples(self.as_ll_ref().tree_sequence);
190+
super::generate_slice(self.as_ll_ref().samples, num_samples)
191+
}
189192
}
190193

191194
pub fn parent_array(&self) -> &[NodeId] {
192-
super::generate_slice(self.as_ll_ref().parent, self.treeseq.num_nodes_raw() + 1)
195+
// SAFETY: the array length is the number of nodes + 1 for the "virtual root"
196+
unsafe { super::generate_slice(self.as_ll_ref().parent, self.treeseq.num_nodes_raw() + 1) }
193197
}
194198

195199
pub fn left_sib_array(&self) -> &[NodeId] {
196-
super::generate_slice(self.as_ll_ref().left_sib, self.treeseq.num_nodes_raw() + 1)
200+
// SAFETY: the array length is the number of nodes + 1 for the "virtual root"
201+
unsafe {
202+
super::generate_slice(self.as_ll_ref().left_sib, self.treeseq.num_nodes_raw() + 1)
203+
}
197204
}
198205

199206
pub fn right_sib_array(&self) -> &[NodeId] {
200-
super::generate_slice(self.as_ll_ref().right_sib, self.treeseq.num_nodes_raw() + 1)
207+
// SAFETY: the array length is the number of nodes + 1 for the "virtual root"
208+
unsafe {
209+
super::generate_slice(self.as_ll_ref().right_sib, self.treeseq.num_nodes_raw() + 1)
210+
}
201211
}
202212

203213
pub fn left_child_array(&self) -> &[NodeId] {
204-
super::generate_slice(
205-
self.as_ll_ref().left_child,
206-
self.treeseq.num_nodes_raw() + 1,
207-
)
214+
// SAFETY: the array length is the number of nodes + 1 for the "virtual root"
215+
unsafe {
216+
super::generate_slice(
217+
self.as_ll_ref().left_child,
218+
self.treeseq.num_nodes_raw() + 1,
219+
)
220+
}
208221
}
209222

210223
pub fn right_child_array(&self) -> &[NodeId] {
211-
super::generate_slice(
212-
self.as_ll_ref().right_child,
213-
self.treeseq.num_nodes_raw() + 1,
214-
)
224+
// SAFETY: the array length is the number of nodes + 1 for the "virtual root"
225+
unsafe {
226+
super::generate_slice(
227+
self.as_ll_ref().right_child,
228+
self.treeseq.num_nodes_raw() + 1,
229+
)
230+
}
215231
}
216232

217233
pub fn total_branch_length(&self, by_span: bool) -> Result<Time, TskitError> {
218-
let time: &[Time] = super::generate_slice(
219-
unsafe { (*(*(*self.as_ptr()).tree_sequence).tables).nodes.time },
220-
self.treeseq.num_nodes_raw() + 1,
221-
);
234+
assert!(!self.treeseq.as_ref().tables.is_null());
235+
// SAFETY: array len is number of nodes + 1 for the "virtual root"
236+
// tables ptr is not NULL
237+
let time: &[Time] = unsafe {
238+
super::generate_slice(
239+
(*(self.treeseq.as_ref()).tables).nodes.time,
240+
self.treeseq.num_nodes_raw() + 1,
241+
)
242+
};
222243
let mut b = Time::from(0.);
223244
for n in self.traverse_nodes(NodeTraversalOrder::Preorder) {
224245
let p = self.parent(n).ok_or(TskitError::IndexError {})?;
@@ -246,32 +267,38 @@ impl<'treeseq> LLTree<'treeseq> {
246267
}
247268

248269
pub fn left_sample_array(&self) -> Result<&[NodeId], TskitError> {
249-
err_if_not_tracking_samples!(
250-
self.flags,
270+
err_if_not_tracking_samples!(self.flags, unsafe {
271+
// SAFETY: array length is number of nodes + 1 for the "virtual root"
251272
super::generate_slice(
252273
self.as_ll_ref().left_sample,
253-
self.treeseq.num_nodes_raw() + 1
274+
self.treeseq.num_nodes_raw() + 1,
254275
)
255-
)
276+
})
256277
}
257278

258279
pub fn right_sample_array(&self) -> Result<&[NodeId], TskitError> {
259280
err_if_not_tracking_samples!(
260281
self.flags,
261-
super::generate_slice(
262-
self.as_ll_ref().right_sample,
263-
self.treeseq.num_nodes_raw() + 1
264-
)
282+
// SAFETY: array length is number of nodes + 1 for the "virtual root"
283+
unsafe {
284+
super::generate_slice(
285+
self.as_ll_ref().right_sample,
286+
self.treeseq.num_nodes_raw() + 1,
287+
)
288+
}
265289
)
266290
}
267291

268292
pub fn next_sample_array(&self) -> Result<&[NodeId], TskitError> {
269293
err_if_not_tracking_samples!(
270294
self.flags,
271-
super::generate_slice(
272-
self.as_ll_ref().next_sample,
273-
self.treeseq.num_nodes_raw() + 1
274-
)
295+
// SAFETY: array length is number of nodes + 1 for the "virtual root"
296+
unsafe {
297+
super::generate_slice(
298+
self.as_ll_ref().next_sample,
299+
self.treeseq.num_nodes_raw() + 1,
300+
)
301+
}
275302
)
276303
}
277304

src/trees/treeseq.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -283,8 +283,10 @@ impl TreeSequence {
283283

284284
/// Get the list of sample nodes as a slice.
285285
pub fn sample_nodes(&self) -> &[NodeId] {
286-
let num_samples = unsafe { ll_bindings::tsk_treeseq_get_num_samples(self.as_ptr()) };
287-
sys::generate_slice(self.as_ref().samples, num_samples)
286+
unsafe {
287+
let num_samples = ll_bindings::tsk_treeseq_get_num_samples(self.as_ref());
288+
sys::generate_slice(self.as_ref().samples, num_samples)
289+
}
288290
}
289291

290292
/// Get the number of trees.

0 commit comments

Comments
 (0)