@@ -6597,43 +6597,62 @@ tsk_treeseq_divergence_matrix_branch(const tsk_treeseq_t *self, tsk_size_t num_s
6597
6597
return ret ;
6598
6598
}
6599
6599
6600
- static tsk_size_t
6601
- count_mutations_on_path (tsk_id_t u , tsk_id_t v , const tsk_id_t * restrict parent ,
6602
- const double * restrict time , const tsk_size_t * restrict mutations_per_node )
6600
+ // FIXME see #2817
6601
+ // Just including this here for now as it's the simplest option. Everything
6602
+ // will probably move to stats.[c,h] in the near future though, and it
6603
+ // can pull in ``genotypes.h`` without issues.
6604
+ #include <tskit/genotypes.h>
6605
+
6606
+ static void
6607
+ update_site_divergence (const tsk_variant_t * var , const tsk_id_t * restrict A ,
6608
+ const tsk_size_t * restrict offsets , double * D )
6609
+
6603
6610
{
6604
- double tu , tv ;
6605
- tsk_size_t count = 0 ;
6611
+ const tsk_size_t num_alleles = var -> num_alleles ;
6612
+ const tsk_id_t n = ( tsk_id_t ) var -> num_samples ;
6606
6613
6607
- tu = time [u ];
6608
- tv = time [v ];
6609
- while (u != v ) {
6610
- if (tu < tv ) {
6611
- count += mutations_per_node [u ];
6612
- u = parent [u ];
6613
- if (u == TSK_NULL ) {
6614
- break ;
6615
- }
6616
- tu = time [u ];
6617
- } else {
6618
- count += mutations_per_node [v ];
6619
- v = parent [v ];
6620
- if (v == TSK_NULL ) {
6621
- break ;
6614
+ tsk_size_t a , b , j , k ;
6615
+ tsk_id_t u , v ;
6616
+
6617
+ for (a = 0 ; a < num_alleles ; a ++ ) {
6618
+ for (b = a + 1 ; b < num_alleles ; b ++ ) {
6619
+ for (j = offsets [a ]; j < offsets [a + 1 ]; j ++ ) {
6620
+ for (k = offsets [b ]; k < offsets [b + 1 ]; k ++ ) {
6621
+ u = A [j ];
6622
+ v = A [k ];
6623
+ /* Only increment the upper triangle to (hopefully) improve memory
6624
+ * access patterns */
6625
+ if (u > v ) {
6626
+ v = A [j ];
6627
+ u = A [k ];
6628
+ }
6629
+ D [u * n + v ]++ ;
6630
+ }
6622
6631
}
6623
- tv = time [v ];
6624
6632
}
6625
6633
}
6626
- if (u != v ) {
6627
- while (u != TSK_NULL ) {
6628
- count += mutations_per_node [u ];
6629
- u = parent [u ];
6630
- }
6631
- while (v != TSK_NULL ) {
6632
- count += mutations_per_node [v ];
6633
- v = parent [v ];
6634
+ }
6635
+
6636
+ static void
6637
+ group_alleles (const tsk_variant_t * var , tsk_id_t * restrict A , tsk_size_t * offsets )
6638
+ {
6639
+ const tsk_size_t n = var -> num_samples ;
6640
+ const int32_t * restrict genotypes = var -> genotypes ;
6641
+ tsk_id_t a ;
6642
+ tsk_size_t j , k ;
6643
+
6644
+ k = 0 ;
6645
+ offsets [0 ] = 0 ;
6646
+ for (a = 0 ; a < (tsk_id_t ) var -> num_alleles ; a ++ ) {
6647
+ offsets [a + 1 ] = offsets [a ];
6648
+ for (j = 0 ; j < n ; j ++ ) {
6649
+ if (genotypes [j ] == a ) {
6650
+ offsets [a + 1 ]++ ;
6651
+ A [k ] = (tsk_id_t ) j ;
6652
+ k ++ ;
6653
+ }
6634
6654
}
6635
6655
}
6636
- return count ;
6637
6656
}
6638
6657
6639
6658
static int
@@ -6643,72 +6662,99 @@ tsk_treeseq_divergence_matrix_site(const tsk_treeseq_t *self, tsk_size_t num_sam
6643
6662
double * restrict result )
6644
6663
{
6645
6664
int ret = 0 ;
6646
- tsk_tree_t tree ;
6647
- const tsk_size_t n = num_samples ;
6648
- const tsk_size_t num_nodes = self -> tables -> nodes .num_rows ;
6649
- const double * restrict nodes_time = self -> tables -> nodes .time ;
6650
- tsk_size_t i , j , k , tree_site , tree_mut ;
6651
- tsk_site_t site ;
6652
- tsk_mutation_t mut ;
6653
- tsk_id_t u , v ;
6654
- double left , right , span_left , span_right ;
6665
+ tsk_size_t i ;
6666
+ tsk_id_t site_id ;
6667
+ double left , right ;
6655
6668
double * restrict D ;
6656
- tsk_size_t * mutations_per_node = tsk_malloc (num_nodes * sizeof (* mutations_per_node ));
6657
-
6658
- ret = tsk_tree_init (& tree , self , 0 );
6669
+ const tsk_id_t num_sites = (tsk_id_t ) self -> tables -> sites .num_rows ;
6670
+ const double * restrict sites_position = self -> tables -> sites .position ;
6671
+ tsk_id_t * A = tsk_malloc (num_samples * sizeof (* A ));
6672
+ /* Allocate the allele offsets at the first variant */
6673
+ tsk_size_t max_alleles = 0 ;
6674
+ tsk_size_t * allele_offsets = NULL ;
6675
+ tsk_variant_t variant ;
6676
+
6677
+ ret = tsk_variant_init (
6678
+ & variant , self , samples , num_samples , NULL , TSK_ISOLATED_NOT_MISSING );
6659
6679
if (ret != 0 ) {
6660
6680
goto out ;
6661
6681
}
6662
- if (mutations_per_node == NULL ) {
6682
+ if (A == NULL ) {
6663
6683
ret = TSK_ERR_NO_MEMORY ;
6664
6684
goto out ;
6665
6685
}
6666
6686
6687
+ site_id = 0 ;
6688
+ while (site_id < num_sites && sites_position [site_id ] < windows [0 ]) {
6689
+ site_id ++ ;
6690
+ }
6691
+
6667
6692
for (i = 0 ; i < num_windows ; i ++ ) {
6668
6693
left = windows [i ];
6669
6694
right = windows [i + 1 ];
6670
- D = result + i * n * n ;
6671
- ret = tsk_tree_seek (& tree , left , 0 );
6672
- if (ret != 0 ) {
6673
- goto out ;
6674
- }
6675
- while (tree .interval .left < right && tree .index != -1 ) {
6676
- span_left = TSK_MAX (tree .interval .left , left );
6677
- span_right = TSK_MIN (tree .interval .right , right );
6695
+ D = result + i * num_samples * num_samples ;
6678
6696
6679
- /* NOTE: we could avoid this full memset across all nodes by doing
6680
- * the same loops again and decrementing at the end of the main
6681
- * tree-loop. It's probably not worth it though, because of the
6682
- * overwhelming O(n^2) below */
6683
- tsk_memset (mutations_per_node , 0 , num_nodes * sizeof (* mutations_per_node ));
6684
- for (tree_site = 0 ; tree_site < tree .sites_length ; tree_site ++ ) {
6685
- site = tree .sites [tree_site ];
6686
- if (span_left <= site .position && site .position < span_right ) {
6687
- for (tree_mut = 0 ; tree_mut < site .mutations_length ; tree_mut ++ ) {
6688
- mut = site .mutations [tree_mut ];
6689
- mutations_per_node [mut .node ]++ ;
6690
- }
6691
- }
6697
+ if (site_id < num_sites ) {
6698
+ tsk_bug_assert (sites_position [site_id ] >= left );
6699
+ }
6700
+ while (site_id < num_sites && sites_position [site_id ] < right ) {
6701
+ ret = tsk_variant_decode (& variant , site_id , 0 );
6702
+ if (ret != 0 ) {
6703
+ goto out ;
6692
6704
}
6693
-
6694
- for (j = 0 ; j < n ; j ++ ) {
6695
- u = samples [j ];
6696
- for (k = j + 1 ; k < n ; k ++ ) {
6697
- v = samples [k ];
6698
- D [j * n + k ] += (double ) count_mutations_on_path (
6699
- u , v , tree .parent , nodes_time , mutations_per_node );
6705
+ if (variant .num_alleles > max_alleles ) {
6706
+ /* could do some kind of doubling here, but there's no
6707
+ * point - just keep it simple for testing. */
6708
+ max_alleles = variant .num_alleles ;
6709
+ tsk_safe_free (allele_offsets );
6710
+ allele_offsets = tsk_malloc ((max_alleles + 1 ) * sizeof (* allele_offsets ));
6711
+ if (allele_offsets == NULL ) {
6712
+ ret = TSK_ERR_NO_MEMORY ;
6713
+ goto out ;
6700
6714
}
6701
6715
}
6702
- ret = tsk_tree_next (& tree );
6703
- if (ret < 0 ) {
6704
- goto out ;
6705
- }
6716
+ group_alleles (& variant , A , allele_offsets );
6717
+ update_site_divergence (& variant , A , allele_offsets , D );
6718
+ site_id ++ ;
6706
6719
}
6707
6720
}
6708
6721
ret = 0 ;
6709
6722
out :
6710
- tsk_tree_free (& tree );
6711
- tsk_safe_free (mutations_per_node );
6723
+ tsk_variant_free (& variant );
6724
+ tsk_safe_free (A );
6725
+ tsk_safe_free (allele_offsets );
6726
+ return ret ;
6727
+ }
6728
+
6729
+ static int
6730
+ get_sample_index_map (const tsk_size_t num_nodes , const tsk_size_t num_samples ,
6731
+ const tsk_id_t * restrict samples , tsk_id_t * * ret_sample_index_map )
6732
+ {
6733
+ int ret = 0 ;
6734
+ tsk_size_t j ;
6735
+ tsk_id_t u ;
6736
+ tsk_id_t * sample_index_map = tsk_malloc (num_nodes * sizeof (* sample_index_map ));
6737
+
6738
+ if (sample_index_map == NULL ) {
6739
+ ret = TSK_ERR_NO_MEMORY ;
6740
+ goto out ;
6741
+ }
6742
+ /* Assign the output pointer here so that it will be freed in the case
6743
+ * of an error raised in the input checking */
6744
+ * ret_sample_index_map = sample_index_map ;
6745
+
6746
+ for (j = 0 ; j < num_nodes ; j ++ ) {
6747
+ sample_index_map [j ] = TSK_NULL ;
6748
+ }
6749
+ for (j = 0 ; j < num_samples ; j ++ ) {
6750
+ u = samples [j ];
6751
+ if (sample_index_map [u ] != TSK_NULL ) {
6752
+ ret = TSK_ERR_DUPLICATE_SAMPLE ;
6753
+ goto out ;
6754
+ }
6755
+ sample_index_map [u ] = (tsk_id_t ) j ;
6756
+ }
6757
+ out :
6712
6758
return ret ;
6713
6759
}
6714
6760
@@ -6739,9 +6785,11 @@ tsk_treeseq_divergence_matrix(const tsk_treeseq_t *self, tsk_size_t num_samples,
6739
6785
const tsk_id_t * samples = self -> samples ;
6740
6786
tsk_size_t n = self -> num_samples ;
6741
6787
const double default_windows [] = { 0 , self -> tables -> sequence_length };
6788
+ const tsk_size_t num_nodes = self -> tables -> nodes .num_rows ;
6742
6789
bool stat_site = !!(options & TSK_STAT_SITE );
6743
6790
bool stat_branch = !!(options & TSK_STAT_BRANCH );
6744
6791
bool stat_node = !!(options & TSK_STAT_NODE );
6792
+ tsk_id_t * sample_index_map = NULL ;
6745
6793
6746
6794
if (stat_node ) {
6747
6795
ret = TSK_ERR_UNSUPPORTED_STAT_MODE ;
@@ -6785,6 +6833,13 @@ tsk_treeseq_divergence_matrix(const tsk_treeseq_t *self, tsk_size_t num_samples,
6785
6833
}
6786
6834
}
6787
6835
6836
+ /* NOTE: we're just using this here to check the input for duplicates.
6837
+ */
6838
+ ret = get_sample_index_map (num_nodes , n , samples , & sample_index_map );
6839
+ if (ret != 0 ) {
6840
+ goto out ;
6841
+ }
6842
+
6788
6843
tsk_memset (result , 0 , num_windows * n * n * sizeof (* result ));
6789
6844
6790
6845
if (stat_branch ) {
@@ -6801,5 +6856,6 @@ tsk_treeseq_divergence_matrix(const tsk_treeseq_t *self, tsk_size_t num_samples,
6801
6856
fill_lower_triangle (result , n , num_windows );
6802
6857
6803
6858
out :
6859
+ tsk_safe_free (sample_index_map );
6804
6860
return ret ;
6805
6861
}
0 commit comments