Skip to content

Commit 3b7e59c

Browse files
committed
more
1 parent 79e2268 commit 3b7e59c

File tree

2 files changed

+63
-34
lines changed

2 files changed

+63
-34
lines changed

src/sys/tree.rs

Lines changed: 59 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ impl<'treeseq> LLTree<'treeseq> {
5959
pub fn samples_array(&self) -> Result<&[super::newtypes::NodeId], TskitError> {
6060
err_if_not_tracking_samples!(
6161
self.flags,
62-
super::generate_slice(self.as_ll_ref().samples, self.num_samples())
62+
// SAFETY: num_samples is the correct value
63+
unsafe { super::generate_slice(self.as_ll_ref().samples, self.num_samples()) }
6364
)
6465
}
6566

@@ -182,43 +183,63 @@ impl<'treeseq> LLTree<'treeseq> {
182183

183184
pub fn sample_nodes(&self) -> &[NodeId] {
184185
assert!(!self.as_ptr().is_null());
185-
// SAFETY: self ptr is not null and the tree is initialized
186-
let num_samples =
187-
unsafe { bindings::tsk_treeseq_get_num_samples(self.as_ll_ref().tree_sequence) };
188-
super::generate_slice(self.as_ll_ref().samples, num_samples)
186+
unsafe {
187+
// SAFETY: self ptr is not null and the tree is initialized
188+
// num_samples is the correct array length
189+
let num_samples = bindings::tsk_treeseq_get_num_samples(self.as_ll_ref().tree_sequence);
190+
super::generate_slice(self.as_ll_ref().samples, num_samples)
191+
}
189192
}
190193

191194
pub fn parent_array(&self) -> &[NodeId] {
192-
super::generate_slice(self.as_ll_ref().parent, self.treeseq.num_nodes_raw() + 1)
195+
// SAFETY: the array length is the number of nodes + 1 for the "virtual root"
196+
unsafe { super::generate_slice(self.as_ll_ref().parent, self.treeseq.num_nodes_raw() + 1) }
193197
}
194198

195199
pub fn left_sib_array(&self) -> &[NodeId] {
196-
super::generate_slice(self.as_ll_ref().left_sib, self.treeseq.num_nodes_raw() + 1)
200+
// SAFETY: the array length is the number of nodes + 1 for the "virtual root"
201+
unsafe {
202+
super::generate_slice(self.as_ll_ref().left_sib, self.treeseq.num_nodes_raw() + 1)
203+
}
197204
}
198205

199206
pub fn right_sib_array(&self) -> &[NodeId] {
200-
super::generate_slice(self.as_ll_ref().right_sib, self.treeseq.num_nodes_raw() + 1)
207+
// SAFETY: the array length is the number of nodes + 1 for the "virtual root"
208+
unsafe {
209+
super::generate_slice(self.as_ll_ref().right_sib, self.treeseq.num_nodes_raw() + 1)
210+
}
201211
}
202212

203213
pub fn left_child_array(&self) -> &[NodeId] {
204-
super::generate_slice(
205-
self.as_ll_ref().left_child,
206-
self.treeseq.num_nodes_raw() + 1,
207-
)
214+
// SAFETY: the array length is the number of nodes + 1 for the "virtual root"
215+
unsafe {
216+
super::generate_slice(
217+
self.as_ll_ref().left_child,
218+
self.treeseq.num_nodes_raw() + 1,
219+
)
220+
}
208221
}
209222

210223
pub fn right_child_array(&self) -> &[NodeId] {
211-
super::generate_slice(
212-
self.as_ll_ref().right_child,
213-
self.treeseq.num_nodes_raw() + 1,
214-
)
224+
// SAFETY: the array length is the number of nodes + 1 for the "virtual root"
225+
unsafe {
226+
super::generate_slice(
227+
self.as_ll_ref().right_child,
228+
self.treeseq.num_nodes_raw() + 1,
229+
)
230+
}
215231
}
216232

217233
pub fn total_branch_length(&self, by_span: bool) -> Result<Time, TskitError> {
218-
let time: &[Time] = super::generate_slice(
219-
unsafe { (*(*(*self.as_ptr()).tree_sequence).tables).nodes.time },
220-
self.treeseq.num_nodes_raw() + 1,
221-
);
234+
assert!(!self.treeseq.as_ref().tables.is_null());
235+
// SAFETY: array len is number of nodes + 1 for the "virtual root"
236+
// tables ptr is not NULL
237+
let time: &[Time] = unsafe {
238+
super::generate_slice(
239+
(*(self.treeseq.as_ref()).tables).nodes.time,
240+
self.treeseq.num_nodes_raw() + 1,
241+
)
242+
};
222243
let mut b = Time::from(0.);
223244
for n in self.traverse_nodes(NodeTraversalOrder::Preorder) {
224245
let p = self.parent(n).ok_or(TskitError::IndexError {})?;
@@ -246,32 +267,38 @@ impl<'treeseq> LLTree<'treeseq> {
246267
}
247268

248269
pub fn left_sample_array(&self) -> Result<&[NodeId], TskitError> {
249-
err_if_not_tracking_samples!(
250-
self.flags,
270+
err_if_not_tracking_samples!(self.flags, unsafe {
271+
// SAFETY: array length is number of nodes + 1 for the "virtual root"
251272
super::generate_slice(
252273
self.as_ll_ref().left_sample,
253-
self.treeseq.num_nodes_raw() + 1
274+
self.treeseq.num_nodes_raw() + 1,
254275
)
255-
)
276+
})
256277
}
257278

258279
pub fn right_sample_array(&self) -> Result<&[NodeId], TskitError> {
259280
err_if_not_tracking_samples!(
260281
self.flags,
261-
super::generate_slice(
262-
self.as_ll_ref().right_sample,
263-
self.treeseq.num_nodes_raw() + 1
264-
)
282+
// SAFETY: array length is number of nodes + 1 for the "virtual root"
283+
unsafe {
284+
super::generate_slice(
285+
self.as_ll_ref().right_sample,
286+
self.treeseq.num_nodes_raw() + 1,
287+
)
288+
}
265289
)
266290
}
267291

268292
pub fn next_sample_array(&self) -> Result<&[NodeId], TskitError> {
269293
err_if_not_tracking_samples!(
270294
self.flags,
271-
super::generate_slice(
272-
self.as_ll_ref().next_sample,
273-
self.treeseq.num_nodes_raw() + 1
274-
)
295+
// SAFETY: array length is number of nodes + 1 for the "virtual root"
296+
unsafe {
297+
super::generate_slice(
298+
self.as_ll_ref().next_sample,
299+
self.treeseq.num_nodes_raw() + 1,
300+
)
301+
}
275302
)
276303
}
277304

src/trees/treeseq.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -283,8 +283,10 @@ impl TreeSequence {
283283

284284
/// Get the list of sample nodes as a slice.
285285
pub fn sample_nodes(&self) -> &[NodeId] {
286-
let num_samples = unsafe { ll_bindings::tsk_treeseq_get_num_samples(self.as_ptr()) };
287-
sys::generate_slice(self.as_ref().samples, num_samples)
286+
unsafe {
287+
let num_samples = ll_bindings::tsk_treeseq_get_num_samples(self.as_ref());
288+
sys::generate_slice(self.as_ref().samples, num_samples)
289+
}
288290
}
289291

290292
/// Get the number of trees.

0 commit comments

Comments
 (0)