Skip to content

Commit fb85ba9

Browse files
committed
Add Tree::num_tracked_samples and Tree::kc_distance.
1 parent e6495a9 commit fb85ba9

File tree

1 file changed

+146
-1
lines changed

1 file changed

+146
-1
lines changed

src/trees.rs

Lines changed: 146 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,22 @@ impl Tree {
5151

5252
fn new(ts: &TreeSequence, flags: TreeFlags) -> Result<Self, TskitError> {
5353
let mut tree = Self::wrap(ts.consumed.nodes().num_rows(), flags);
54-
let rv = unsafe { ll_bindings::tsk_tree_init(tree.as_mut_ptr(), ts.as_ptr(), flags.bits) };
54+
let mut rv =
55+
unsafe { ll_bindings::tsk_tree_init(tree.as_mut_ptr(), ts.as_ptr(), flags.bits) };
56+
if rv < 0 {
57+
return Err(TskitError::ErrorCode { code: rv });
58+
}
59+
// TODO: ask Jerome why tsk_tree_init does not handle this itself. Until then, track all samples here unless sample counts are disabled.
60+
if !flags.contains(TreeFlags::NO_SAMPLE_COUNTS) {
61+
rv = unsafe {
62+
ll_bindings::tsk_tree_set_tracked_samples(
63+
tree.as_mut_ptr(),
64+
ts.num_samples() as u64,
65+
tree.inner.samples,
66+
)
67+
};
68+
}
69+
5570
handle_tsk_return_value!(rv, tree)
5671
}
5772

@@ -355,6 +370,39 @@ impl Tree {
355370
false => Ok(b),
356371
}
357372
}
373+
374+
/// Get the number of samples below node `u`.
375+
///
376+
/// # Errors
377+
///
378+
/// * [`TskitError`] if [`TreeFlags::NO_SAMPLE_COUNTS`].
379+
pub fn num_tracked_samples(&self, u: tsk_id_t) -> Result<u64, TskitError> {
380+
let mut n = u64::MAX;
381+
let np: *mut u64 = &mut n;
382+
let code = unsafe { ll_bindings::tsk_tree_get_num_tracked_samples(self.as_ptr(), u, np) };
383+
handle_tsk_return_value!(code, n)
384+
}
385+
386+
/// Calculate the Kendall-Colijn (`K-C`) distance between
387+
/// this tree and `other`.
388+
///
389+
/// # Note
390+
///
391+
/// * [Citation](https://doi.org/10.1093/molbev/msw124)
392+
///
393+
/// # Parameters
394+
///
395+
/// * `lambda` specifies the relative weight of topology and branch length.
396+
/// If `lambda` is 0, we only consider topology.
397+
/// If `lambda` is 1, we only consider branch lengths.
398+
pub fn kc_distance(&self, other: &Tree, lambda: f64) -> Result<f64, TskitError> {
399+
let mut kc = f64::NAN;
400+
let kcp: *mut f64 = &mut kc;
401+
let code = unsafe {
402+
ll_bindings::tsk_tree_kc_distance(self.as_ptr(), other.as_ptr(), lambda, kcp)
403+
};
404+
handle_tsk_return_value!(code, kc)
405+
}
358406
}
359407

360408
impl streaming_iterator::StreamingIterator for Tree {
@@ -763,6 +811,7 @@ impl TreeSequence {
763811
/// # Parameters
764812
///
765813
/// * `lambda` specifies the relative weight of topology and branch length.
814+
/// See [`Tree::kc_distance`] for more details.
766815
pub fn kc_distance(&self, other: &TreeSequence, lambda: f64) -> Result<f64, TskitError> {
767816
let mut kc: f64 = f64::NAN;
768817
let kcp: *mut f64 = &mut kc;
@@ -771,6 +820,11 @@ impl TreeSequence {
771820
};
772821
handle_tsk_return_value!(code, kc)
773822
}
823+
824+
/// Get the number of samples in the tree sequence.
825+
pub fn num_samples(&self) -> tsk_size_t {
826+
unsafe { ll_bindings::tsk_treeseq_get_num_samples(self.as_ptr()) }
827+
}
774828
}
775829

776830
#[cfg(test)]
@@ -799,6 +853,51 @@ mod test_trees {
799853
tables.tree_sequence().unwrap()
800854
}
801855

856+
fn make_small_table_collection_two_trees() -> TableCollection {
857+
// The two trees, covering positions 0-500 and 500-1000 respectively, are:
858+
// 0
859+
// +++
860+
// | | 1
861+
// | | +++
862+
// 2 3 4 5
863+
864+
// 0
865+
// +-+-+
866+
// 1 |
867+
// +-+-+ |
868+
// 2 4 5 3
869+
870+
let mut tables = TableCollection::new(1000.).unwrap();
871+
tables.add_node(0, 2.0, TSK_NULL, TSK_NULL).unwrap();
872+
tables.add_node(0, 1.0, TSK_NULL, TSK_NULL).unwrap();
873+
tables
874+
.add_node(TSK_NODE_IS_SAMPLE, 0.0, TSK_NULL, TSK_NULL)
875+
.unwrap();
876+
tables
877+
.add_node(TSK_NODE_IS_SAMPLE, 0.0, TSK_NULL, TSK_NULL)
878+
.unwrap();
879+
tables
880+
.add_node(TSK_NODE_IS_SAMPLE, 0.0, TSK_NULL, TSK_NULL)
881+
.unwrap();
882+
tables
883+
.add_node(TSK_NODE_IS_SAMPLE, 0.0, TSK_NULL, TSK_NULL)
884+
.unwrap();
885+
tables.add_edge(500., 1000., 0, 1).unwrap();
886+
tables.add_edge(0., 500., 0, 2).unwrap();
887+
tables.add_edge(0., 1000., 0, 3).unwrap();
888+
tables.add_edge(500., 1000., 1, 2).unwrap();
889+
tables.add_edge(0., 1000., 1, 4).unwrap();
890+
tables.add_edge(0., 1000., 1, 5).unwrap();
891+
tables.full_sort().unwrap();
892+
tables.build_index(0).unwrap();
893+
tables
894+
}
895+
896+
fn treeseq_from_small_table_collection_two_trees() -> TreeSequence {
897+
let tables = make_small_table_collection_two_trees();
898+
tables.tree_sequence().unwrap()
899+
}
900+
802901
#[test]
803902
fn test_create_treeseq_new_from_tables() {
804903
let tables = make_small_table_collection();
@@ -877,18 +976,46 @@ mod test_trees {
877976
}
878977
}
879978

979+
#[test]
980+
fn test_num_tracked_samples() {
981+
let treeseq = treeseq_from_small_table_collection();
982+
assert_eq!(treeseq.inner.num_samples, 2);
983+
let mut tree_iter = treeseq.tree_iterator(TreeFlags::default()).unwrap();
984+
if let Some(tree) = tree_iter.next() {
985+
assert_eq!(tree.num_tracked_samples(2).unwrap(), 1);
986+
assert_eq!(tree.num_tracked_samples(1).unwrap(), 1);
987+
assert_eq!(tree.num_tracked_samples(0).unwrap(), 2);
988+
}
989+
}
990+
991+
#[should_panic]
992+
#[test]
993+
fn test_num_tracked_samples_not_tracking_samples() {
994+
let treeseq = treeseq_from_small_table_collection();
995+
assert_eq!(treeseq.inner.num_samples, 2);
996+
let mut tree_iter = treeseq.tree_iterator(TreeFlags::NO_SAMPLE_COUNTS).unwrap();
997+
if let Some(tree) = tree_iter.next() {
998+
assert_eq!(tree.num_tracked_samples(2).unwrap(), 0);
999+
assert_eq!(tree.num_tracked_samples(1).unwrap(), 0);
1000+
assert_eq!(tree.num_tracked_samples(0).unwrap(), 0);
1001+
}
1002+
}
1003+
8801004
#[test]
8811005
fn test_iterate_samples() {
8821006
let tables = make_small_table_collection();
8831007
let treeseq = tables.tree_sequence().unwrap();
8841008

8851009
let mut tree_iter = treeseq.tree_iterator(TreeFlags::SAMPLE_LISTS).unwrap();
8861010
if let Some(tree) = tree_iter.next() {
1011+
assert!(!tree.flags.contains(TreeFlags::NO_SAMPLE_COUNTS));
1012+
assert!(tree.flags.contains(TreeFlags::SAMPLE_LISTS));
8871013
let mut s = vec![];
8881014
for i in tree.samples(0).unwrap() {
8891015
s.push(i);
8901016
}
8911017
assert_eq!(s.len(), 2);
1018+
assert_eq!(s.len(), tree.num_tracked_samples(0).unwrap() as usize);
8921019
assert_eq!(s[0], 1);
8931020
assert_eq!(s[1], 2);
8941021

@@ -899,12 +1026,30 @@ mod test_trees {
8991026
}
9001027
assert_eq!(s.len(), 1);
9011028
assert_eq!(s[0], u);
1029+
assert_eq!(s.len(), tree.num_tracked_samples(u).unwrap() as usize);
9021030
}
9031031
} else {
9041032
panic!("Expected a tree");
9051033
}
9061034
}
9071035

1036+
#[test]
1037+
fn test_iterate_samples_two_trees() {
1038+
let treeseq = treeseq_from_small_table_collection_two_trees();
1039+
assert_eq!(treeseq.inner.num_trees, 2);
1040+
let mut tree_iter = treeseq.tree_iterator(TreeFlags::SAMPLE_LISTS).unwrap();
1041+
while let Some(tree) = tree_iter.next() {
1042+
for n in tree.nodes(NodeTraversalOrder::Preorder) {
1043+
let mut nsamples = 0;
1044+
for _ in tree.samples(n).unwrap() {
1045+
nsamples += 1;
1046+
}
1047+
assert!(nsamples > 0);
1048+
assert_eq!(nsamples, tree.num_tracked_samples(n).unwrap());
1049+
}
1050+
}
1051+
}
1052+
9081053
#[test]
9091054
fn test_kc_distance_naive_test() {
9101055
let ts1 = treeseq_from_small_table_collection();

0 commit comments

Comments
 (0)