Skip to content

Commit dfe1b3c

Browse files
authored
refactor: internal fn for ragged access is now safe (#721)
1 parent dfb487a commit dfe1b3c

File tree

7 files changed

+170
-144
lines changed

7 files changed

+170
-144
lines changed

src/metadata.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,13 @@
9191
//! let id = tables.add_individual_with_metadata(0, &[] as &[tskit::Location], &[tskit::IndividualId::NULL], &individual).unwrap();
9292
//! let decoded = tables.individuals().metadata::<IndividualMetadata>(id).unwrap().unwrap();
9393
//! assert_eq!(decoded.genetic_value.partial_cmp(&individual.genetic_value).unwrap(), std::cmp::Ordering::Equal);
94+
//! let _ = tables.add_individual(0, &[] as &[tskit::Location], &[tskit::IndividualId::NULL]).unwrap();
95+
//! let individual2 = IndividualMetadata {
96+
//! genetic_value: GeneticValue(1.0),
97+
//! };
98+
//! let id2 = tables.add_individual_with_metadata(0, &[] as &[tskit::Location], &[tskit::IndividualId::NULL], &individual2).unwrap();
99+
//! let decoded2 = tables.individuals().metadata::<IndividualMetadata>(id2).unwrap().unwrap();
100+
//! assert_eq!(decoded2.genetic_value.partial_cmp(&individual2.genetic_value).unwrap(), std::cmp::Ordering::Equal);
94101
//! # }
95102
//! ```
96103
//!

src/sys/individual_table.rs

Lines changed: 44 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use std::ptr::NonNull;
22

3+
use super::bindings::tsk_size_t;
34
use super::flags::IndividualFlags;
45
use super::newtypes::IndividualId;
56

@@ -76,18 +77,46 @@ impl IndividualTable {
7677

7778
raw_metadata_getter_for_tables!(IndividualId);
7879

79-
pub fn location(&self, row: IndividualId) -> Option<&[super::newtypes::Location]> {
80-
assert!(
81-
(self.as_ref().num_rows == 0 && self.as_ref().location_length == 0)
82-
|| (!self.as_ref().location.is_null() && !self.as_ref().location_offset.is_null())
83-
);
80+
pub fn location_column(&self) -> &[super::newtypes::Location] {
8481
unsafe {
85-
super::tsk_ragged_column_access(
86-
row,
87-
self.as_ref().location,
88-
self.as_ref().num_rows,
82+
std::slice::from_raw_parts(
83+
self.as_ref().location.cast::<super::newtypes::Location>(),
84+
self.as_ref().location_length as usize,
85+
)
86+
}
87+
}
88+
89+
fn location_offset_column_raw(&self) -> &[tsk_size_t] {
90+
unsafe {
91+
std::slice::from_raw_parts(
8992
self.as_ref().location_offset,
90-
self.as_ref().location_length,
93+
self.as_ref().num_rows as usize,
94+
)
95+
}
96+
}
97+
98+
pub fn location(&self, row: IndividualId) -> Option<&[super::newtypes::Location]> {
99+
super::tsk_ragged_column_access(
100+
row,
101+
self.location_column(),
102+
self.location_offset_column_raw(),
103+
)
104+
}
105+
106+
fn parents_column(&self) -> &[IndividualId] {
107+
unsafe {
108+
std::slice::from_raw_parts(
109+
self.as_ref().parents.cast::<IndividualId>(),
110+
self.as_ref().parents_length as usize,
111+
)
112+
}
113+
}
114+
115+
fn parents_offset_column_raw(&self) -> &[tsk_size_t] {
116+
unsafe {
117+
std::slice::from_raw_parts(
118+
self.as_ref().parents_offset,
119+
self.as_ref().num_rows as usize,
91120
)
92121
}
93122
}
@@ -97,15 +126,11 @@ impl IndividualTable {
97126
(self.as_ref().num_rows == 0 && self.as_ref().parents_length == 0)
98127
|| (!self.as_ref().parents.is_null() && !self.as_ref().location_offset.is_null())
99128
);
100-
unsafe {
101-
super::tsk_ragged_column_access(
102-
row,
103-
self.as_ref().parents,
104-
self.as_ref().num_rows,
105-
self.as_ref().parents_offset,
106-
self.as_ref().parents_length,
107-
)
108-
}
129+
super::tsk_ragged_column_access(
130+
row,
131+
self.parents_column(),
132+
self.parents_offset_column_raw(),
133+
)
109134
}
110135
}
111136

src/sys/macros.rs

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -295,21 +295,30 @@ macro_rules! safe_tsk_column_access {
295295

296296
macro_rules! raw_metadata_getter_for_tables {
297297
($idtype: ty) => {
298-
pub fn raw_metadata<I: Into<$idtype>>(&self, row: I) -> Option<&[u8]> {
299-
assert!(
300-
(self.as_ref().num_rows == 0 && self.as_ref().metadata_length == 0)
301-
|| (!self.as_ref().metadata.is_null()
302-
&& !self.as_ref().metadata_offset.is_null())
303-
);
298+
fn metadata_column(&self) -> &[u8] {
299+
unsafe {
300+
std::slice::from_raw_parts(
301+
self.as_ref().metadata.cast::<u8>(),
302+
self.as_ref().metadata_length as usize,
303+
)
304+
}
305+
}
306+
307+
fn metadata_offset_raw(&self) -> &[super::bindings::tsk_size_t] {
304308
unsafe {
305-
$crate::sys::tsk_ragged_column_access::<'_, u8, $idtype, _, _>(
306-
row.into(),
307-
self.as_ref().metadata,
308-
self.as_ref().num_rows,
309+
std::slice::from_raw_parts(
309310
self.as_ref().metadata_offset,
310-
self.as_ref().metadata_length,
311+
self.as_ref().num_rows as usize,
311312
)
312313
}
313314
}
315+
316+
pub fn raw_metadata<I: Into<$idtype>>(&self, row: I) -> Option<&[u8]> {
317+
$crate::sys::tsk_ragged_column_access(
318+
row.into(),
319+
self.metadata_column(),
320+
self.metadata_offset_raw(),
321+
)
322+
}
314323
};
315324
}

src/sys/mod.rs

Lines changed: 17 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -141,79 +141,36 @@ unsafe fn tsk_column_access<
141141
tsk_column_access_detail(row, column, column_length).map(|v| v.into())
142142
}
143143

144-
/// # SAFETY
145-
///
146-
/// The safety requirements here are a bit fiddly.
147-
///
148-
/// The hard case is when the columns contain data:
149-
///
150-
/// * column and offset must both not be NULL
151-
/// * column_length and offset_length must both be
152-
/// the correct lengths for the input pointers
153-
/// * we return None if row < 0 or row > array length.
154-
/// * Thus, the requirement is that the two _lengths
155-
/// == 0 or (pointer both not NULL and the lengths are correct)
156-
///
157-
/// When the lengths of each column are 0, we
158-
/// don't worry about anything else
159-
unsafe fn tsk_ragged_column_access_detail<
160-
R: Into<bindings::tsk_id_t>,
161-
L: Into<bindings::tsk_size_t>,
162-
T: Copy,
163-
>(
164-
row: R,
165-
column: *const T,
166-
column_length: L,
167-
offset: *const bindings::tsk_size_t,
168-
offset_length: bindings::tsk_size_t,
169-
) -> Option<(*const T, usize)> {
170-
let row = row.into();
171-
let column_length = column_length.into();
172-
if row < 0 || row as bindings::tsk_size_t > column_length || offset_length == 0 {
144+
fn tsk_ragged_column_access_detail<'a, T>(
145+
row: usize,
146+
column: &'a [T],
147+
raw_offset: &'a [bindings::tsk_size_t],
148+
) -> Option<&'a [T]> {
149+
if row >= raw_offset.len() || raw_offset.is_empty() {
173150
None
174151
} else {
175-
// SAFETY: pointers are not null
176-
// and *_length are given by tskit-c
177-
let index = row as isize;
178-
let start = *offset.offset(index);
179-
let stop = if (row as bindings::tsk_size_t) < column_length {
180-
*offset.offset(index + 1)
152+
let start = usize::try_from(raw_offset[row]).ok()?;
153+
let stop = if row < raw_offset.len() - 1 {
154+
usize::try_from(raw_offset[row + 1]).ok()?
181155
} else {
182-
offset_length
156+
column.len()
183157
};
184158
if start == stop {
185159
None
186160
} else {
187-
Some((
188-
column.offset(start as isize),
189-
stop as usize - start as usize,
190-
))
161+
Some(&column[start..stop])
191162
}
192163
}
193164
}
194165

195-
// SAFETY: see tsk_ragged_column_access_detail
196-
// We further erquire that a pointer to a T can
197-
// be safely cast to a pointer to an O.
198-
unsafe fn tsk_ragged_column_access<
199-
'a,
200-
O,
201-
R: Into<bindings::tsk_id_t>,
202-
L: Into<bindings::tsk_size_t>,
203-
T: Copy,
204-
>(
166+
fn tsk_ragged_column_access<'a, O, R: Into<bindings::tsk_id_t>>(
205167
row: R,
206-
column: *const T,
207-
column_length: L,
208-
offset: *const bindings::tsk_size_t,
209-
offset_length: bindings::tsk_size_t,
168+
column: &'a [O],
169+
raw_offset: &'a [bindings::tsk_size_t],
210170
) -> Option<&'a [O]> {
211-
unsafe {
212-
tsk_ragged_column_access_detail(row, column, column_length, offset, offset_length)
213-
// If the safety requirements of tsk_ragged_column_access_detail are upheld,
214-
// then we have received a valid pointer + length from which to make a slice
215-
.map(|(p, n)| std::slice::from_raw_parts(p.cast::<O>(), n))
216-
}
171+
let row = row.into();
172+
let row = usize::try_from(row).ok()?;
173+
tsk_ragged_column_access_detail(row, column, raw_offset)
217174
}
218175

219176
/// # SAFETY

src/sys/mutation_table.rs

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use super::bindings::tsk_mutation_table_add_row;
1010
use super::bindings::tsk_mutation_table_clear;
1111
use super::bindings::tsk_mutation_table_init;
1212
use super::bindings::tsk_mutation_table_t;
13+
use super::bindings::tsk_size_t;
1314
use super::tskbox::TskBox;
1415
use super::TskitError;
1516

@@ -100,24 +101,31 @@ impl MutationTable {
100101

101102
raw_metadata_getter_for_tables!(MutationId);
102103

103-
pub fn derived_state(&self, row: MutationId) -> Option<&[u8]> {
104-
assert!(
105-
(self.as_ref().num_rows == 0 && self.as_ref().derived_state_length == 0)
106-
|| (!self.as_ref().derived_state.is_null()
107-
&& !self.as_ref().derived_state_offset.is_null())
108-
);
109-
// SAFETY: either both columns are empty or both pointers at not NULL,
110-
// in which case the correct lengths are from the low-level objects
104+
fn derived_state_column(&self) -> &[u8] {
111105
unsafe {
112-
super::tsk_ragged_column_access(
113-
row,
114-
self.as_ref().derived_state,
115-
self.as_ref().num_rows,
106+
std::slice::from_raw_parts(
107+
self.as_ref().derived_state.cast::<u8>(),
108+
self.as_ref().derived_state_length as usize,
109+
)
110+
}
111+
}
112+
113+
fn derived_state_offset_raw(&self) -> &[tsk_size_t] {
114+
unsafe {
115+
std::slice::from_raw_parts(
116116
self.as_ref().derived_state_offset,
117-
self.as_ref().derived_state_length,
117+
self.as_ref().num_rows as usize,
118118
)
119119
}
120120
}
121+
122+
pub fn derived_state(&self, row: MutationId) -> Option<&[u8]> {
123+
super::tsk_ragged_column_access(
124+
row,
125+
self.derived_state_column(),
126+
self.derived_state_offset_raw(),
127+
)
128+
}
121129
}
122130

123131
impl Default for MutationTable {

src/sys/provenance_table.rs

Lines changed: 40 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -58,47 +58,57 @@ impl ProvenanceTable {
5858
Ok(rv)
5959
}
6060

61-
pub fn timestamp(&self, row: ProvenanceId) -> Option<&str> {
62-
assert!(
63-
(self.as_ref().num_rows != 0 && self.as_ref().timestamp_length != 0)
64-
|| (!self.as_ref().timestamp.is_null()
65-
&& !self.as_ref().timestamp_offset.is_null())
66-
);
61+
fn timestamp_column(&self) -> &[u8] {
62+
unsafe {
63+
std::slice::from_raw_parts(
64+
self.as_ref().timestamp.cast::<u8>(),
65+
self.as_ref().timestamp_length as usize,
66+
)
67+
}
68+
}
6769

68-
// SAFETY: the previous assert checks the safety
69-
// requirements
70-
let timestamp_slice = unsafe {
71-
super::tsk_ragged_column_access(
72-
row,
73-
self.as_ref().timestamp,
74-
self.as_ref().num_rows,
70+
fn timestamp_offset_column_raw(&self) -> &[tsk_size_t] {
71+
unsafe {
72+
std::slice::from_raw_parts(
7573
self.as_ref().timestamp_offset,
76-
self.as_ref().timestamp_length,
74+
self.as_ref().num_rows as usize,
7775
)
78-
};
76+
}
77+
}
78+
79+
pub fn timestamp(&self, row: ProvenanceId) -> Option<&str> {
80+
let timestamp_slice = super::tsk_ragged_column_access(
81+
row,
82+
self.timestamp_column(),
83+
self.timestamp_offset_column_raw(),
84+
);
7985
match timestamp_slice {
8086
Some(tstamp) => std::str::from_utf8(tstamp).ok(),
8187
None => None,
8288
}
8389
}
8490

91+
fn record_column(&self) -> &[u8] {
92+
unsafe {
93+
std::slice::from_raw_parts(
94+
self.as_ref().record.cast::<u8>(),
95+
self.as_ref().record_length as usize,
96+
)
97+
}
98+
}
99+
100+
fn record_offset_column_raw(&self) -> &[tsk_size_t] {
101+
unsafe {
102+
std::slice::from_raw_parts(self.as_ref().record_offset, self.as_ref().num_rows as usize)
103+
}
104+
}
105+
85106
pub fn record(&self, row: ProvenanceId) -> Option<&str> {
86-
assert!(
87-
(self.as_ref().num_rows != 0 && self.as_ref().record_length != 0)
88-
|| (!self.as_ref().record.is_null() && !self.as_ref().record_offset.is_null())
107+
let record_slice = super::tsk_ragged_column_access(
108+
row,
109+
self.record_column(),
110+
self.record_offset_column_raw(),
89111
);
90-
91-
// SAFETY: the previous assert checks the safety
92-
// requirements
93-
let record_slice = unsafe {
94-
super::tsk_ragged_column_access(
95-
row,
96-
self.as_ref().record,
97-
self.as_ref().num_rows,
98-
self.as_ref().record_offset,
99-
self.as_ref().record_length,
100-
)
101-
};
102112
match record_slice {
103113
Some(rec) => std::str::from_utf8(rec).ok(),
104114
None => None,

0 commit comments

Comments
 (0)