@@ -51,7 +51,22 @@ impl Tree {
5151
5252 fn new ( ts : & TreeSequence , flags : TreeFlags ) -> Result < Self , TskitError > {
5353 let mut tree = Self :: wrap ( ts. consumed . nodes ( ) . num_rows ( ) , flags) ;
54- let rv = unsafe { ll_bindings:: tsk_tree_init ( tree. as_mut_ptr ( ) , ts. as_ptr ( ) , flags. bits ) } ;
54+ let mut rv =
55+ unsafe { ll_bindings:: tsk_tree_init ( tree. as_mut_ptr ( ) , ts. as_ptr ( ) , flags. bits ) } ;
56+ if rv < 0 {
57+ return Err ( TskitError :: ErrorCode { code : rv } ) ;
58+ }
59+ // Gotta ask Jerome about this one--why isn't this handled in tsk_tree_init??
60+ if !flags. contains ( TreeFlags :: NO_SAMPLE_COUNTS ) {
61+ rv = unsafe {
62+ ll_bindings:: tsk_tree_set_tracked_samples (
63+ tree. as_mut_ptr ( ) ,
64+ ts. num_samples ( ) as u64 ,
65+ tree. inner . samples ,
66+ )
67+ } ;
68+ }
69+
5570 handle_tsk_return_value ! ( rv, tree)
5671 }
5772
@@ -355,6 +370,39 @@ impl Tree {
355370 false => Ok ( b) ,
356371 }
357372 }
373+
374+ /// Get the number of samples below node `u`.
375+ ///
376+ /// # Errors
377+ ///
378+ /// * [`TskitError`] if [`TreeFlags::NO_SAMPLE_COUNTS`].
379+ pub fn num_tracked_samples ( & self , u : tsk_id_t ) -> Result < u64 , TskitError > {
380+ let mut n = u64:: MAX ;
381+ let np: * mut u64 = & mut n;
382+ let code = unsafe { ll_bindings:: tsk_tree_get_num_tracked_samples ( self . as_ptr ( ) , u, np) } ;
383+ handle_tsk_return_value ! ( code, n)
384+ }
385+
386+ /// Calculate the average Kendall-Colijn (`K-C`) distance between
387+ /// pairs of trees whose intervals overlap.
388+ ///
389+ /// # Note
390+ ///
391+ /// * [Citation](https://doi.org/10.1093/molbev/msw124)
392+ ///
393+ /// # Parameters
394+ ///
395+ /// * `lambda` specifies the relative weight of topology and branch length.
396+ /// If `lambda` is 0, we only consider topology.
397+ /// If `lambda` is 1, we only consider branch lengths.
398+ pub fn kc_distance ( & self , other : & Tree , lambda : f64 ) -> Result < f64 , TskitError > {
399+ let mut kc = f64:: NAN ;
400+ let kcp: * mut f64 = & mut kc;
401+ let code = unsafe {
402+ ll_bindings:: tsk_tree_kc_distance ( self . as_ptr ( ) , other. as_ptr ( ) , lambda, kcp)
403+ } ;
404+ handle_tsk_return_value ! ( code, kc)
405+ }
358406}
359407
360408impl streaming_iterator:: StreamingIterator for Tree {
@@ -763,6 +811,7 @@ impl TreeSequence {
763811 /// # Parameters
764812 ///
765813 /// * `lambda` specifies the relative weight of topology and branch length.
814+ /// See [`Tree::kc_distance`] for more details.
766815 pub fn kc_distance ( & self , other : & TreeSequence , lambda : f64 ) -> Result < f64 , TskitError > {
767816 let mut kc: f64 = f64:: NAN ;
768817 let kcp: * mut f64 = & mut kc;
@@ -771,6 +820,11 @@ impl TreeSequence {
771820 } ;
772821 handle_tsk_return_value ! ( code, kc)
773822 }
823+
824+ // FIXME: document
825+ pub fn num_samples ( & self ) -> tsk_size_t {
826+ unsafe { ll_bindings:: tsk_treeseq_get_num_samples ( self . as_ptr ( ) ) }
827+ }
774828}
775829
776830#[ cfg( test) ]
@@ -799,6 +853,51 @@ mod test_trees {
799853 tables. tree_sequence ( ) . unwrap ( )
800854 }
801855
856+ fn make_small_table_collection_two_trees ( ) -> TableCollection {
857+ // The two trees are:
858+ // 0
859+ // +++
860+ // | | 1
861+ // | | +++
862+ // 2 3 4 5
863+
864+ // 0
865+ // +-+-+
866+ // 1 |
867+ // +-+-+ |
868+ // 2 4 5 3
869+
870+ let mut tables = TableCollection :: new ( 1000. ) . unwrap ( ) ;
871+ tables. add_node ( 0 , 2.0 , TSK_NULL , TSK_NULL ) . unwrap ( ) ;
872+ tables. add_node ( 0 , 1.0 , TSK_NULL , TSK_NULL ) . unwrap ( ) ;
873+ tables
874+ . add_node ( TSK_NODE_IS_SAMPLE , 0.0 , TSK_NULL , TSK_NULL )
875+ . unwrap ( ) ;
876+ tables
877+ . add_node ( TSK_NODE_IS_SAMPLE , 0.0 , TSK_NULL , TSK_NULL )
878+ . unwrap ( ) ;
879+ tables
880+ . add_node ( TSK_NODE_IS_SAMPLE , 0.0 , TSK_NULL , TSK_NULL )
881+ . unwrap ( ) ;
882+ tables
883+ . add_node ( TSK_NODE_IS_SAMPLE , 0.0 , TSK_NULL , TSK_NULL )
884+ . unwrap ( ) ;
885+ tables. add_edge ( 500. , 1000. , 0 , 1 ) . unwrap ( ) ;
886+ tables. add_edge ( 0. , 500. , 0 , 2 ) . unwrap ( ) ;
887+ tables. add_edge ( 0. , 1000. , 0 , 3 ) . unwrap ( ) ;
888+ tables. add_edge ( 500. , 1000. , 1 , 2 ) . unwrap ( ) ;
889+ tables. add_edge ( 0. , 1000. , 1 , 4 ) . unwrap ( ) ;
890+ tables. add_edge ( 0. , 1000. , 1 , 5 ) . unwrap ( ) ;
891+ tables. full_sort ( ) . unwrap ( ) ;
892+ tables. build_index ( 0 ) . unwrap ( ) ;
893+ tables
894+ }
895+
896+ fn treeseq_from_small_table_collection_two_trees ( ) -> TreeSequence {
897+ let tables = make_small_table_collection_two_trees ( ) ;
898+ tables. tree_sequence ( ) . unwrap ( )
899+ }
900+
802901 #[ test]
803902 fn test_create_treeseq_new_from_tables ( ) {
804903 let tables = make_small_table_collection ( ) ;
@@ -877,18 +976,46 @@ mod test_trees {
877976 }
878977 }
879978
979+ #[ test]
980+ fn test_num_tracked_samples ( ) {
981+ let treeseq = treeseq_from_small_table_collection ( ) ;
982+ assert_eq ! ( treeseq. inner. num_samples, 2 ) ;
983+ let mut tree_iter = treeseq. tree_iterator ( TreeFlags :: default ( ) ) . unwrap ( ) ;
984+ if let Some ( tree) = tree_iter. next ( ) {
985+ assert_eq ! ( tree. num_tracked_samples( 2 ) . unwrap( ) , 1 ) ;
986+ assert_eq ! ( tree. num_tracked_samples( 1 ) . unwrap( ) , 1 ) ;
987+ assert_eq ! ( tree. num_tracked_samples( 0 ) . unwrap( ) , 2 ) ;
988+ }
989+ }
990+
991+ #[ should_panic]
992+ #[ test]
993+ fn test_num_tracked_samples_not_tracking_samples ( ) {
994+ let treeseq = treeseq_from_small_table_collection ( ) ;
995+ assert_eq ! ( treeseq. inner. num_samples, 2 ) ;
996+ let mut tree_iter = treeseq. tree_iterator ( TreeFlags :: NO_SAMPLE_COUNTS ) . unwrap ( ) ;
997+ if let Some ( tree) = tree_iter. next ( ) {
998+ assert_eq ! ( tree. num_tracked_samples( 2 ) . unwrap( ) , 0 ) ;
999+ assert_eq ! ( tree. num_tracked_samples( 1 ) . unwrap( ) , 0 ) ;
1000+ assert_eq ! ( tree. num_tracked_samples( 0 ) . unwrap( ) , 0 ) ;
1001+ }
1002+ }
1003+
8801004 #[ test]
8811005 fn test_iterate_samples ( ) {
8821006 let tables = make_small_table_collection ( ) ;
8831007 let treeseq = tables. tree_sequence ( ) . unwrap ( ) ;
8841008
8851009 let mut tree_iter = treeseq. tree_iterator ( TreeFlags :: SAMPLE_LISTS ) . unwrap ( ) ;
8861010 if let Some ( tree) = tree_iter. next ( ) {
1011+ assert ! ( !tree. flags. contains( TreeFlags :: NO_SAMPLE_COUNTS ) ) ;
1012+ assert ! ( tree. flags. contains( TreeFlags :: SAMPLE_LISTS ) ) ;
8871013 let mut s = vec ! [ ] ;
8881014 for i in tree. samples ( 0 ) . unwrap ( ) {
8891015 s. push ( i) ;
8901016 }
8911017 assert_eq ! ( s. len( ) , 2 ) ;
1018+ assert_eq ! ( s. len( ) , tree. num_tracked_samples( 0 ) . unwrap( ) as usize ) ;
8921019 assert_eq ! ( s[ 0 ] , 1 ) ;
8931020 assert_eq ! ( s[ 1 ] , 2 ) ;
8941021
@@ -899,12 +1026,30 @@ mod test_trees {
8991026 }
9001027 assert_eq ! ( s. len( ) , 1 ) ;
9011028 assert_eq ! ( s[ 0 ] , u) ;
1029+ assert_eq ! ( s. len( ) , tree. num_tracked_samples( u) . unwrap( ) as usize ) ;
9021030 }
9031031 } else {
9041032 panic ! ( "Expected a tree" ) ;
9051033 }
9061034 }
9071035
1036+ #[ test]
1037+ fn test_iterate_samples_two_trees ( ) {
1038+ let treeseq = treeseq_from_small_table_collection_two_trees ( ) ;
1039+ assert_eq ! ( treeseq. inner. num_trees, 2 ) ;
1040+ let mut tree_iter = treeseq. tree_iterator ( TreeFlags :: SAMPLE_LISTS ) . unwrap ( ) ;
1041+ while let Some ( tree) = tree_iter. next ( ) {
1042+ for n in tree. nodes ( NodeTraversalOrder :: Preorder ) {
1043+ let mut nsamples = 0 ;
1044+ for _ in tree. samples ( n) . unwrap ( ) {
1045+ nsamples += 1 ;
1046+ }
1047+ assert ! ( nsamples > 0 ) ;
1048+ assert_eq ! ( nsamples, tree. num_tracked_samples( n) . unwrap( ) ) ;
1049+ }
1050+ }
1051+ }
1052+
9081053 #[ test]
9091054 fn test_kc_distance_naive_test ( ) {
9101055 let ts1 = treeseq_from_small_table_collection ( ) ;
0 commit comments