Skip to content

Commit d89dcd1

Browse files
authored
Merge pull request #18 from molpopgen/tskit_wrapper_trait
Closes #17
2 parents 64ac8a0 + 3445af3 commit d89dcd1

File tree

4 files changed

+105
-46
lines changed

4 files changed

+105
-46
lines changed

src/_macros.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,34 @@ macro_rules! unsafe_tsk_column_access {
3030
}};
3131
}
3232

33+
macro_rules! build_tskit_type {
34+
($name: ident, $ll_name: ty, $drop: ident) => {
35+
impl Drop for $name {
36+
fn drop(&mut self) {
37+
let rv = unsafe { $drop(&mut *self.inner) };
38+
panic_on_tskit_error!(rv);
39+
}
40+
}
41+
42+
impl crate::ffi::TskitType<$ll_name> for $name {
43+
fn wrap() -> Self {
44+
let temp: std::mem::MaybeUninit<$ll_name> = std::mem::MaybeUninit::uninit();
45+
$name {
46+
inner: unsafe { Box::<$ll_name>::new(temp.assume_init()) },
47+
}
48+
}
49+
50+
fn as_ptr(&self) -> *const $ll_name {
51+
&*self.inner
52+
}
53+
54+
fn as_mut_ptr(&mut self) -> *mut $ll_name {
55+
&mut *self.inner
56+
}
57+
}
58+
};
59+
}
60+
3361
#[cfg(test)]
3462
mod test {
3563
use crate::error::TskitRustError;

src/ffi.rs

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
//! Define traits related to wrapping tskit stuff
2+
3+
/// Define what it means to wrap a tskit struct.
4+
/// The implementation of Drop should call the
5+
/// tsk_foo_free() function corresponding
6+
/// to tsk_foo_t.
7+
pub trait TskitType<T>: Drop {
8+
/// Encapsulate tsk_foo_t and return rust
9+
/// object. Best practices seem to
10+
/// suggest using Box for this.
11+
fn wrap() -> Self;
12+
/// Return const pointer
13+
fn as_ptr(&self) -> *const T;
14+
/// Return mutable pointer
15+
fn as_mut_ptr(&mut self) -> *mut T;
16+
}
17+
18+
#[cfg(test)]
19+
mod tests {
20+
use super::*;
21+
use crate::bindings as ll_bindings;
22+
use ll_bindings::tsk_table_collection_free;
23+
24+
pub struct TableCollectionMock {
25+
inner: Box<ll_bindings::tsk_table_collection_t>,
26+
}
27+
28+
build_tskit_type!(
29+
TableCollectionMock,
30+
ll_bindings::tsk_table_collection_t,
31+
tsk_table_collection_free
32+
);
33+
34+
impl TableCollectionMock {
35+
fn new(len: f64) -> Self {
36+
let mut s = Self::wrap();
37+
38+
let rv = unsafe { ll_bindings::tsk_table_collection_init(s.as_mut_ptr(), 0) };
39+
assert_eq!(rv, 0);
40+
41+
s.inner.sequence_length = len;
42+
43+
s
44+
}
45+
46+
fn sequence_length(&self) -> f64 {
47+
unsafe { (*self.as_ptr()).sequence_length }
48+
}
49+
}
50+
51+
#[test]
52+
fn test_create_mock_type() {
53+
let t = TableCollectionMock::new(10.);
54+
assert_eq!(t.sequence_length() as i64, 10);
55+
}
56+
}

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ pub mod bindings;
99
mod _macros; // Starts w/_ to be sorted at front by rustfmt!
1010
mod edge_table;
1111
pub mod error;
12+
pub mod ffi;
1213
mod mutation_table;
1314
mod node_table;
1415
mod population_table;
@@ -53,4 +54,3 @@ pub fn version() -> &'static str {
5354

5455
// Testing modules
5556
mod test_tsk_variables;
56-

src/table_collection.rs

Lines changed: 20 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use crate::bindings as ll_bindings;
22
use crate::error::TskitRustError;
3+
use crate::ffi::TskitType;
34
use crate::types::Bookmark;
45
use crate::EdgeTable;
56
use crate::MutationTable;
@@ -8,19 +9,7 @@ use crate::PopulationTable;
89
use crate::SiteTable;
910
use crate::TskReturnValue;
1011
use crate::{tsk_flags_t, tsk_id_t, tsk_size_t};
11-
12-
/// Handle allocation details.
13-
fn new_tsk_table_collection_t() -> Result<Box<ll_bindings::tsk_table_collection_t>, TskitRustError>
14-
{
15-
let mut tsk_tables: std::mem::MaybeUninit<ll_bindings::tsk_table_collection_t> =
16-
std::mem::MaybeUninit::uninit();
17-
let rv = unsafe { ll_bindings::tsk_table_collection_init(tsk_tables.as_mut_ptr(), 0) };
18-
if rv < 0 {
19-
return Err(TskitRustError::ErrorCode { code: rv });
20-
}
21-
let rv = unsafe { Box::<ll_bindings::tsk_table_collection_t>::new(tsk_tables.assume_init()) };
22-
Ok(rv)
23-
}
12+
use ll_bindings::tsk_table_collection_free;
2413

2514
/// A table collection.
2615
///
@@ -69,9 +58,15 @@ fn new_tsk_table_collection_t() -> Result<Box<ll_bindings::tsk_table_collection_
6958
///
7059
/// Addressing point 3 may require API breakage.
7160
pub struct TableCollection {
72-
tables: Box<ll_bindings::tsk_table_collection_t>,
61+
inner: Box<ll_bindings::tsk_table_collection_t>,
7362
}
7463

64+
build_tskit_type!(
65+
TableCollection,
66+
ll_bindings::tsk_table_collection_t,
67+
tsk_table_collection_free
68+
);
69+
7570
impl TableCollection {
7671
/// Create a new table collection with a sequence length.
7772
pub fn new(sequence_length: f64) -> Result<Self, TskitRustError> {
@@ -81,16 +76,13 @@ impl TableCollection {
8176
expected: "sequence_length >= 0.0".to_string(),
8277
});
8378
}
84-
let tables = new_tsk_table_collection_t();
85-
match tables {
86-
Ok(_) => (),
87-
Err(e) => return Err(e),
79+
let mut tables = Self::wrap();
80+
let rv = unsafe { ll_bindings::tsk_table_collection_init(tables.as_mut_ptr(), 0) };
81+
if rv < 0 {
82+
return Err(crate::error::TskitRustError::ErrorCode { code: rv });
8883
}
89-
let mut rv = TableCollection {
90-
tables: tables.unwrap(),
91-
};
92-
rv.tables.sequence_length = sequence_length;
93-
Ok(rv)
84+
tables.inner.sequence_length = sequence_length;
85+
Ok(tables)
9486
}
9587

9688
/// Load a table collection from a file.
@@ -119,16 +111,6 @@ impl TableCollection {
119111
}
120112
}
121113

122-
/// Access to raw C pointer as const tsk_table_collection_t *.
123-
pub fn as_ptr(&self) -> *const ll_bindings::tsk_table_collection_t {
124-
&*self.tables
125-
}
126-
127-
/// Access to raw C pointer as tsk_table_collection_t *.
128-
pub fn as_mut_ptr(&mut self) -> *mut ll_bindings::tsk_table_collection_t {
129-
&mut *self.tables
130-
}
131-
132114
/// Length of the sequence/"genome".
133115
pub fn sequence_length(&self) -> f64 {
134116
unsafe { (*self.as_ptr()).sequence_length }
@@ -138,35 +120,35 @@ impl TableCollection {
138120
/// Lifetime of return value is tied to (this)
139121
/// parent object.
140122
pub fn edges<'a>(&'a self) -> EdgeTable<'a> {
141-
EdgeTable::<'a>::new_from_table(&self.tables.edges)
123+
EdgeTable::<'a>::new_from_table(&self.inner.edges)
142124
}
143125

144126
/// Get reference to the [``NodeTable``](crate::NodeTable).
145127
/// Lifetime of return value is tied to (this)
146128
/// parent object.
147129
pub fn nodes<'a>(&'a self) -> NodeTable<'a> {
148-
NodeTable::<'a>::new_from_table(&self.tables.nodes)
130+
NodeTable::<'a>::new_from_table(&self.inner.nodes)
149131
}
150132

151133
/// Get reference to the [``SiteTable``](crate::SiteTable).
152134
/// Lifetime of return value is tied to (this)
153135
/// parent object.
154136
pub fn sites<'a>(&'a self) -> SiteTable<'a> {
155-
SiteTable::<'a>::new_from_table(&self.tables.sites)
137+
SiteTable::<'a>::new_from_table(&self.inner.sites)
156138
}
157139

158140
/// Get reference to the [``MutationTable``](crate::MutationTable).
159141
/// Lifetime of return value is tied to (this)
160142
/// parent object.
161143
pub fn mutations<'a>(&'a self) -> MutationTable<'a> {
162-
MutationTable::<'a>::new_from_table(&self.tables.mutations)
144+
MutationTable::<'a>::new_from_table(&self.inner.mutations)
163145
}
164146

165147
/// Get reference to the [``PopulationTable``](crate::PopulationTable).
166148
/// Lifetime of return value is tied to (this)
167149
/// parent object.
168150
pub fn populations<'a>(&'a self) -> PopulationTable<'a> {
169-
PopulationTable::<'a>::new_from_table(&self.tables.populations)
151+
PopulationTable::<'a>::new_from_table(&self.inner.populations)
170152
}
171153

172154
/// Add a row to the edge table
@@ -347,13 +329,6 @@ impl TableCollection {
347329
}
348330
}
349331

350-
impl Drop for TableCollection {
351-
fn drop(&mut self) {
352-
let rv = unsafe { ll_bindings::tsk_table_collection_free(&mut *self.tables) };
353-
panic_on_tskit_error!(rv);
354-
}
355-
}
356-
357332
#[cfg(test)]
358333
mod test {
359334
use super::*;

0 commit comments

Comments
 (0)