Skip to content

Commit efc3313

Browse files
Site divmat based on genotype decoding
Closes #2779
1 parent b6f9872 commit efc3313

File tree

3 files changed

+254
-160
lines changed

3 files changed

+254
-160
lines changed

c/tests/test_stats.c

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1132,7 +1132,6 @@ test_single_tree_divergence_matrix_multi_root(void)
11321132
int ret;
11331133
double result[16];
11341134
double D_branch[16] = { 0, 2, 3, 3, 2, 0, 3, 3, 3, 3, 0, 4, 3, 3, 4, 0 };
1135-
double D_site[16] = { 0, 4, 6, 6, 4, 0, 6, 6, 6, 6, 0, 8, 6, 6, 8, 0 };
11361135

11371136
const char *nodes = "1 0 -1 -1\n"
11381137
"1 0 -1 -1\n" /* 2.00┊ 5 ┊ */
@@ -1142,7 +1141,7 @@ test_single_tree_divergence_matrix_multi_root(void)
11421141
"0 2 -1 -1\n"; /* 0 * * * * 1 */
11431142
const char *edges = "0 1 4 0,1\n"
11441143
"0 1 5 2,3\n";
1145-
/* Two mutations per branch unit so we get twice branch length value */
1144+
/* Two mutations per branch */
11461145
const char *sites = "0.1 A\n"
11471146
"0.2 A\n"
11481147
"0.3 A\n"
@@ -1166,9 +1165,8 @@ test_single_tree_divergence_matrix_multi_root(void)
11661165
CU_ASSERT_EQUAL_FATAL(ret, 0);
11671166
assert_arrays_almost_equal(16, result, D_branch);
11681167

1169-
ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_SITE, result);
1170-
CU_ASSERT_EQUAL_FATAL(ret, 0);
1171-
assert_arrays_almost_equal(16, result, D_site);
1168+
verify_divergence_matrix(&ts, TSK_STAT_SITE);
1169+
verify_divergence_matrix(&ts, TSK_STAT_BRANCH);
11721170

11731171
tsk_treeseq_free(&ts);
11741172
}
@@ -2041,6 +2039,13 @@ test_simplest_divergence_matrix(void)
20412039
ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 0, NULL, 0, result);
20422040
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS);
20432041

2042+
sample_ids[0] = 1;
2043+
ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 0, NULL, 0, result);
2044+
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_DUPLICATE_SAMPLE);
2045+
ret = tsk_treeseq_divergence_matrix(
2046+
&ts, 2, sample_ids, 0, NULL, TSK_STAT_BRANCH, result);
2047+
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_DUPLICATE_SAMPLE);
2048+
20442049
tsk_treeseq_free(&ts);
20452050
}
20462051

@@ -2051,15 +2056,19 @@ test_simplest_divergence_matrix_windows(void)
20512056
"1 0 0\n"
20522057
"0 1 0\n";
20532058
const char *edges = "0 1 2 0,1\n";
2059+
const char *sites = "0.1 A\n"
2060+
"0.6 A\n";
2061+
const char *mutations = "0 0 B -1\n"
2062+
"1 0 B -1\n";
20542063
tsk_treeseq_t ts;
20552064
tsk_id_t sample_ids[] = { 0, 1 };
20562065
double D_branch[8] = { 0, 1, 1, 0, 0, 1, 1, 0 };
2057-
double D_site[8] = { 0, 0, 0, 0, 0, 0, 0, 0 };
2066+
double D_site[8] = { 0, 1, 1, 0, 0, 1, 1, 0 };
20582067
double result[8];
20592068
double windows[] = { 0, 0.5, 1 };
20602069
int ret;
20612070

2062-
tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, NULL, NULL, NULL, NULL, 0);
2071+
tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, sites, mutations, NULL, NULL, 0);
20632072

20642073
ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 2, windows, 0, result);
20652074
CU_ASSERT_EQUAL_FATAL(ret, 0);
@@ -2069,6 +2078,16 @@ test_simplest_divergence_matrix_windows(void)
20692078
CU_ASSERT_EQUAL_FATAL(ret, 0);
20702079
assert_arrays_almost_equal(8, D_branch, result);
20712080

2081+
/* Windows for the second half */
2082+
ret = tsk_treeseq_divergence_matrix(
2083+
&ts, 2, sample_ids, 1, windows + 1, TSK_STAT_SITE, result);
2084+
CU_ASSERT_EQUAL_FATAL(ret, 0);
2085+
assert_arrays_almost_equal(4, D_site, result);
2086+
ret = tsk_treeseq_divergence_matrix(
2087+
&ts, 2, sample_ids, 1, windows + 1, TSK_STAT_BRANCH, result);
2088+
CU_ASSERT_EQUAL_FATAL(ret, 0);
2089+
assert_arrays_almost_equal(4, D_branch, result);
2090+
20722091
ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 0, windows, 0, result);
20732092
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_NUM_WINDOWS);
20742093

c/tskit/trees.c

Lines changed: 133 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -6597,43 +6597,62 @@ tsk_treeseq_divergence_matrix_branch(const tsk_treeseq_t *self, tsk_size_t num_s
65976597
return ret;
65986598
}
65996599

6600-
static tsk_size_t
6601-
count_mutations_on_path(tsk_id_t u, tsk_id_t v, const tsk_id_t *restrict parent,
6602-
const double *restrict time, const tsk_size_t *restrict mutations_per_node)
6600+
// FIXME see #2817
6601+
// Just including this here for now as it's the simplest option. Everything
6602+
// will probably move to stats.[c,h] in the near future though, and it
6603+
// can pull in ``genotypes.h`` without issues.
6604+
#include <tskit/genotypes.h>
6605+
6606+
static void
6607+
update_site_divergence(const tsk_variant_t *var, const tsk_id_t *restrict A,
6608+
const tsk_size_t *restrict offsets, double *D)
6609+
66036610
{
6604-
double tu, tv;
6605-
tsk_size_t count = 0;
6611+
const tsk_size_t num_alleles = var->num_alleles;
6612+
const tsk_id_t n = (tsk_id_t) var->num_samples;
66066613

6607-
tu = time[u];
6608-
tv = time[v];
6609-
while (u != v) {
6610-
if (tu < tv) {
6611-
count += mutations_per_node[u];
6612-
u = parent[u];
6613-
if (u == TSK_NULL) {
6614-
break;
6615-
}
6616-
tu = time[u];
6617-
} else {
6618-
count += mutations_per_node[v];
6619-
v = parent[v];
6620-
if (v == TSK_NULL) {
6621-
break;
6614+
tsk_size_t a, b, j, k;
6615+
tsk_id_t u, v;
6616+
6617+
for (a = 0; a < num_alleles; a++) {
6618+
for (b = a + 1; b < num_alleles; b++) {
6619+
for (j = offsets[a]; j < offsets[a + 1]; j++) {
6620+
for (k = offsets[b]; k < offsets[b + 1]; k++) {
6621+
u = A[j];
6622+
v = A[k];
6623+
/* Only increment the upper triangle to (hopefully) improve memory
6624+
* access patterns */
6625+
if (u > v) {
6626+
v = A[j];
6627+
u = A[k];
6628+
}
6629+
D[u * n + v]++;
6630+
}
66226631
}
6623-
tv = time[v];
66246632
}
66256633
}
6626-
if (u != v) {
6627-
while (u != TSK_NULL) {
6628-
count += mutations_per_node[u];
6629-
u = parent[u];
6630-
}
6631-
while (v != TSK_NULL) {
6632-
count += mutations_per_node[v];
6633-
v = parent[v];
6634+
}
6635+
6636+
static void
6637+
group_alleles(const tsk_variant_t *var, tsk_id_t *restrict A, tsk_size_t *offsets)
6638+
{
6639+
const tsk_size_t n = var->num_samples;
6640+
const int32_t *restrict genotypes = var->genotypes;
6641+
tsk_id_t a;
6642+
tsk_size_t j, k;
6643+
6644+
k = 0;
6645+
offsets[0] = 0;
6646+
for (a = 0; a < (tsk_id_t) var->num_alleles; a++) {
6647+
offsets[a + 1] = offsets[a];
6648+
for (j = 0; j < n; j++) {
6649+
if (genotypes[j] == a) {
6650+
offsets[a + 1]++;
6651+
A[k] = (tsk_id_t) j;
6652+
k++;
6653+
}
66346654
}
66356655
}
6636-
return count;
66376656
}
66386657

66396658
static int
@@ -6643,72 +6662,99 @@ tsk_treeseq_divergence_matrix_site(const tsk_treeseq_t *self, tsk_size_t num_sam
66436662
double *restrict result)
66446663
{
66456664
int ret = 0;
6646-
tsk_tree_t tree;
6647-
const tsk_size_t n = num_samples;
6648-
const tsk_size_t num_nodes = self->tables->nodes.num_rows;
6649-
const double *restrict nodes_time = self->tables->nodes.time;
6650-
tsk_size_t i, j, k, tree_site, tree_mut;
6651-
tsk_site_t site;
6652-
tsk_mutation_t mut;
6653-
tsk_id_t u, v;
6654-
double left, right, span_left, span_right;
6665+
tsk_size_t i;
6666+
tsk_id_t site_id;
6667+
double left, right;
66556668
double *restrict D;
6656-
tsk_size_t *mutations_per_node = tsk_malloc(num_nodes * sizeof(*mutations_per_node));
6657-
6658-
ret = tsk_tree_init(&tree, self, 0);
6669+
const tsk_id_t num_sites = (tsk_id_t) self->tables->sites.num_rows;
6670+
const double *restrict sites_position = self->tables->sites.position;
6671+
tsk_id_t *A = tsk_malloc(num_samples * sizeof(*A));
6672+
/* Allocate the allele offsets at the first variant */
6673+
tsk_size_t max_alleles = 0;
6674+
tsk_size_t *allele_offsets = NULL;
6675+
tsk_variant_t variant;
6676+
6677+
ret = tsk_variant_init(
6678+
&variant, self, samples, num_samples, NULL, TSK_ISOLATED_NOT_MISSING);
66596679
if (ret != 0) {
66606680
goto out;
66616681
}
6662-
if (mutations_per_node == NULL) {
6682+
if (A == NULL) {
66636683
ret = TSK_ERR_NO_MEMORY;
66646684
goto out;
66656685
}
66666686

6687+
site_id = 0;
6688+
while (site_id < num_sites && sites_position[site_id] < windows[0]) {
6689+
site_id++;
6690+
}
6691+
66676692
for (i = 0; i < num_windows; i++) {
66686693
left = windows[i];
66696694
right = windows[i + 1];
6670-
D = result + i * n * n;
6671-
ret = tsk_tree_seek(&tree, left, 0);
6672-
if (ret != 0) {
6673-
goto out;
6674-
}
6675-
while (tree.interval.left < right && tree.index != -1) {
6676-
span_left = TSK_MAX(tree.interval.left, left);
6677-
span_right = TSK_MIN(tree.interval.right, right);
6695+
D = result + i * num_samples * num_samples;
66786696

6679-
/* NOTE: we could avoid this full memset across all nodes by doing
6680-
* the same loops again and decrementing at the end of the main
6681-
* tree-loop. It's probably not worth it though, because of the
6682-
* overwhelming O(n^2) below */
6683-
tsk_memset(mutations_per_node, 0, num_nodes * sizeof(*mutations_per_node));
6684-
for (tree_site = 0; tree_site < tree.sites_length; tree_site++) {
6685-
site = tree.sites[tree_site];
6686-
if (span_left <= site.position && site.position < span_right) {
6687-
for (tree_mut = 0; tree_mut < site.mutations_length; tree_mut++) {
6688-
mut = site.mutations[tree_mut];
6689-
mutations_per_node[mut.node]++;
6690-
}
6691-
}
6697+
if (site_id < num_sites) {
6698+
tsk_bug_assert(sites_position[site_id] >= left);
6699+
}
6700+
while (site_id < num_sites && sites_position[site_id] < right) {
6701+
ret = tsk_variant_decode(&variant, site_id, 0);
6702+
if (ret != 0) {
6703+
goto out;
66926704
}
6693-
6694-
for (j = 0; j < n; j++) {
6695-
u = samples[j];
6696-
for (k = j + 1; k < n; k++) {
6697-
v = samples[k];
6698-
D[j * n + k] += (double) count_mutations_on_path(
6699-
u, v, tree.parent, nodes_time, mutations_per_node);
6705+
if (variant.num_alleles > max_alleles) {
6706+
/* could do some kind of doubling here, but there's no
6707+
* point - just keep it simple for testing. */
6708+
max_alleles = variant.num_alleles;
6709+
tsk_safe_free(allele_offsets);
6710+
allele_offsets = tsk_malloc((max_alleles + 1) * sizeof(*allele_offsets));
6711+
if (allele_offsets == NULL) {
6712+
ret = TSK_ERR_NO_MEMORY;
6713+
goto out;
67006714
}
67016715
}
6702-
ret = tsk_tree_next(&tree);
6703-
if (ret < 0) {
6704-
goto out;
6705-
}
6716+
group_alleles(&variant, A, allele_offsets);
6717+
update_site_divergence(&variant, A, allele_offsets, D);
6718+
site_id++;
67066719
}
67076720
}
67086721
ret = 0;
67096722
out:
6710-
tsk_tree_free(&tree);
6711-
tsk_safe_free(mutations_per_node);
6723+
tsk_variant_free(&variant);
6724+
tsk_safe_free(A);
6725+
tsk_safe_free(allele_offsets);
6726+
return ret;
6727+
}
6728+
6729+
static int
6730+
get_sample_index_map(const tsk_size_t num_nodes, const tsk_size_t num_samples,
6731+
const tsk_id_t *restrict samples, tsk_id_t **ret_sample_index_map)
6732+
{
6733+
int ret = 0;
6734+
tsk_size_t j;
6735+
tsk_id_t u;
6736+
tsk_id_t *sample_index_map = tsk_malloc(num_nodes * sizeof(*sample_index_map));
6737+
6738+
if (sample_index_map == NULL) {
6739+
ret = TSK_ERR_NO_MEMORY;
6740+
goto out;
6741+
}
6742+
/* Assign the output pointer here so that it will be freed in the case
6743+
* of an error raised in the input checking */
6744+
*ret_sample_index_map = sample_index_map;
6745+
6746+
for (j = 0; j < num_nodes; j++) {
6747+
sample_index_map[j] = TSK_NULL;
6748+
}
6749+
for (j = 0; j < num_samples; j++) {
6750+
u = samples[j];
6751+
if (sample_index_map[u] != TSK_NULL) {
6752+
ret = TSK_ERR_DUPLICATE_SAMPLE;
6753+
goto out;
6754+
}
6755+
sample_index_map[u] = (tsk_id_t) j;
6756+
}
6757+
out:
67126758
return ret;
67136759
}
67146760

@@ -6739,9 +6785,11 @@ tsk_treeseq_divergence_matrix(const tsk_treeseq_t *self, tsk_size_t num_samples,
67396785
const tsk_id_t *samples = self->samples;
67406786
tsk_size_t n = self->num_samples;
67416787
const double default_windows[] = { 0, self->tables->sequence_length };
6788+
const tsk_size_t num_nodes = self->tables->nodes.num_rows;
67426789
bool stat_site = !!(options & TSK_STAT_SITE);
67436790
bool stat_branch = !!(options & TSK_STAT_BRANCH);
67446791
bool stat_node = !!(options & TSK_STAT_NODE);
6792+
tsk_id_t *sample_index_map = NULL;
67456793

67466794
if (stat_node) {
67476795
ret = TSK_ERR_UNSUPPORTED_STAT_MODE;
@@ -6785,6 +6833,13 @@ tsk_treeseq_divergence_matrix(const tsk_treeseq_t *self, tsk_size_t num_samples,
67856833
}
67866834
}
67876835

6836+
/* NOTE: we're just using this here to check the input for duplicates.
6837+
*/
6838+
ret = get_sample_index_map(num_nodes, n, samples, &sample_index_map);
6839+
if (ret != 0) {
6840+
goto out;
6841+
}
6842+
67886843
tsk_memset(result, 0, num_windows * n * n * sizeof(*result));
67896844

67906845
if (stat_branch) {
@@ -6801,5 +6856,6 @@ tsk_treeseq_divergence_matrix(const tsk_treeseq_t *self, tsk_size_t num_samples,
68016856
fill_lower_triangle(result, n, num_windows);
68026857

68036858
out:
6859+
tsk_safe_free(sample_index_map);
68046860
return ret;
68056861
}

0 commit comments

Comments
 (0)