From bac74442705c1342d25c2981700c26be3a085819 Mon Sep 17 00:00:00 2001 From: lkirk Date: Wed, 18 Dec 2024 14:34:46 -0600 Subject: [PATCH] Optimize two-locus site operations This PR is a combination of three separate modifications, described below and in #3290. Fixes #3290. * Two-locus malloc optimizations This revision moves all malloc operations out of the hot loop in two-locus statistics, instead providing pre-allocated regions of memory that the two-locus framework uses to perform its work. Instead of passing each pre-allocated array into each function call, we introduce a simple structure called `two_locus_work_t`, which stores the statistical results and provides temporary arrays for storing the normalisation constants. Setup and teardown methods for this work structure are provided. Python and C tests are passing and valgrind reports no errors. * Refactor bit array API, rename to bitset. As discussed in #2834, this patch renames tsk_bit_array_t to tsk_bitset_t. Philosophically, we treat these as sets and not arrays, performing intersections, unions, and membership tests. It therefore makes sense to alter the API to use set-theoretic vocabulary, describing the intent more precisely. Fundamentally, the bitset structure is a list of N independent bitsets, and each operation on two sets must select the row on which to operate. The old tsk_bit_array_t stored only the number of integer chunks per row (`size`); the new tsk_bitset_t tracks both `len`, which is N, the number of sets, and, for convenience, `row_len`, the number of unsigned integers per row. Multiplying `row_len` by `TSK_BITSET_BITS` gives the number of bits that each set (or row) in the list of bitsets can hold. We had also discussed having each set-theoretic operation accept a row index instead of a pointer to a row within the bitset object. Now, each operation accepts a row index for each bitset structure passed into the function. This simplifies consumption of the API considerably, removing the need to store and track many intermediate temporary array pointers. We also see some performance improvements from this cleanup. For DRY purposes, I've created a private macro, `BITSET_DATA_ROW`, which abstracts away the pointer arithmetic for selecting a row out of the list of sets. Because of these changes, `tsk_bit_array_get_row` is no longer needed and has been removed from the API. This change does not alter the size of the "chunk", the unsigned integer that stores the bits. It remains a 32-bit unsigned integer, which is the most performant choice for bit counting (popcount). I've streamlined the macros used to determine which integer in a row stores a particular bit: everything now revolves around the TSK_BITSET_BITS macro, which is simply 32, and bit-shift operations have been converted to unsigned integer division. Testing has been refactored to reflect these changes, removing tests that operate on specific rows. Tests in C and Python are passing and valgrind shows no errors. Fixes #2834. * Precompute A/B Counts and Biallelic Summary Func Precompute A/B counts for each sample set. We were previously computing them redundantly for each site pair in our results matrix. The precomputation happens in a function called `get_mutation_sample_sets`, which takes our list of sets (`tsk_bitset_t`), one per mutation, and intersects the samples carrying each mutation with the sample sets passed in by the user. The result is an expanded list of sets with one set per mutation per sample set.
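To make the row layout concrete, here is a minimal sketch of how a consumer could read one of these sets. This is illustrative only and not part of the patch; the helper name and include path are assumptions. The set for allele row `a` and sample set `j` lives at row `j + a * num_sample_sets` of the expanded list:

    /* Hypothetical helper, showing the row-addressing convention of the
     * expanded list produced by get_mutation_sample_sets(). */
    #include <tskit/core.h>

    static tsk_size_t
    count_allele_in_sample_set(const tsk_bitset_t *allele_sample_sets,
        tsk_size_t allele, tsk_size_t sample_set, tsk_size_t num_sample_sets)
    {
        /* Each row is one set; tsk_bitset_count popcounts that row. */
        return tsk_bitset_count(
            allele_sample_sets, sample_set + allele * num_sample_sets);
    }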
During this operation, we compute the number of samples containing the given allele for each mutation, avoiding the need to perform redundant count operations on the data. In addition to precomputation, we add a non-normalized version of `compute_general_two_site_stat_result` for situations where we're computing stats from biallelic loci. We dispatch the computation of the result based on the number of alleles in the two loci we're comparing. If the number of alleles in both loci is 2, then we simply perform an LD computation on the derived alleles for the two loci. As a result, we remove the need to compute a matrix of LD values, then take a weighted sum. This is much more efficient and means that we only run the full multiallelic LD routine on sites that are multiallelic. --- c/tests/test_core.c | 80 +++---- c/tskit/core.c | 119 ++++++---- c/tskit/core.h | 46 ++-- c/tskit/trees.c | 564 ++++++++++++++++++++++++++------------------ 4 files changed, 461 insertions(+), 348 deletions(-) diff --git a/c/tests/test_core.c b/c/tests/test_core.c index 6a14ecc655..1b97e4485f 100644 --- a/c/tests/test_core.c +++ b/c/tests/test_core.c @@ -531,22 +531,24 @@ test_bit_arrays(void) { // NB: This test is only valid for the 32 bit implementation of bit arrays. If we // were to change the chunk size of a bit array, we'd need to update these tests - tsk_bit_array_t arr; + tsk_bitset_t arr; tsk_id_t items_truth[64] = { 0 }, items[64] = { 0 }; tsk_size_t n_items = 0, n_items_truth = 0; // test item retrieval - tsk_bit_array_init(&arr, 90, 1); - tsk_bit_array_get_items(&arr, items, &n_items); + tsk_bitset_init(&arr, 90, 1); + CU_ASSERT_EQUAL_FATAL(arr.len, 1); + CU_ASSERT_EQUAL_FATAL(arr.row_len, 3); + tsk_bitset_get_items(&arr, 0, items, &n_items); assert_arrays_equal(n_items_truth, items, items_truth); - for (tsk_bit_array_value_t i = 0; i < 20; i++) { - tsk_bit_array_add_bit(&arr, i); + for (tsk_bitset_val_t i = 0; i < 20; i++) { + tsk_bitset_set_bit(&arr, 0, i); items_truth[n_items_truth] = (tsk_id_t) i; n_items_truth++; } - tsk_bit_array_add_bit(&arr, 63); - tsk_bit_array_add_bit(&arr, 65); + tsk_bitset_set_bit(&arr, 0, 63); + tsk_bitset_set_bit(&arr, 0, 65); // these assertions are only valid for 32-bit values CU_ASSERT_EQUAL_FATAL(arr.data[0], 1048575); @@ -554,32 +556,29 @@ test_bit_arrays(void) CU_ASSERT_EQUAL_FATAL(arr.data[2], 2); // verify our assumptions about bit array counting - CU_ASSERT_EQUAL_FATAL(tsk_bit_array_count(&arr), 22); + CU_ASSERT_EQUAL_FATAL(tsk_bitset_count(&arr, 0), 22); - tsk_bit_array_get_items(&arr, items, &n_items); + tsk_bitset_get_items(&arr, 0, items, &n_items); assert_arrays_equal(n_items_truth, items, items_truth); tsk_memset(items, 0, 64); tsk_memset(items_truth, 0, 64); n_items = n_items_truth = 0; - tsk_bit_array_free(&arr); + tsk_bitset_free(&arr); - // create a length-2 array with 64 bit capacity - tsk_bit_array_init(&arr, 64, 2); - tsk_bit_array_t arr_row1, arr_row2; - - // select the first and second row - tsk_bit_array_get_row(&arr, 0, &arr_row1); - tsk_bit_array_get_row(&arr, 1, &arr_row2); + // create a length-2 array with 64 bit capacity (two chunks per row) + tsk_bitset_init(&arr, 64, 2); + CU_ASSERT_EQUAL_FATAL(arr.len, 2); + CU_ASSERT_EQUAL_FATAL(arr.row_len, 2); // fill the first 50 bits of the first row - for (tsk_bit_array_value_t i = 0; i < 50; i++) { - tsk_bit_array_add_bit(&arr_row1, i); + for (tsk_bitset_val_t i = 0; i < 50; i++) { + tsk_bitset_set_bit(&arr, 0, i); items_truth[n_items_truth] = (tsk_id_t) i; n_items_truth++; } - 
tsk_bit_array_get_items(&arr_row1, items, &n_items); + tsk_bitset_get_items(&arr, 0, items, &n_items); assert_arrays_equal(n_items_truth, items, items_truth); tsk_memset(items, 0, 64); @@ -587,13 +586,13 @@ test_bit_arrays(void) n_items = n_items_truth = 0; // fill bits 20-40 of the second row - for (tsk_bit_array_value_t i = 20; i < 40; i++) { - tsk_bit_array_add_bit(&arr_row2, i); + for (tsk_bitset_val_t i = 20; i < 40; i++) { + tsk_bitset_set_bit(&arr, 1, i); items_truth[n_items_truth] = (tsk_id_t) i; n_items_truth++; } - tsk_bit_array_get_items(&arr_row2, items, &n_items); + tsk_bitset_get_items(&arr, 1, items, &n_items); assert_arrays_equal(n_items_truth, items, items_truth); tsk_memset(items, 0, 64); @@ -601,41 +600,38 @@ test_bit_arrays(void) n_items = n_items_truth = 0; // verify our assumptions about row selection - CU_ASSERT_EQUAL_FATAL(arr.data[0], 4294967295); - CU_ASSERT_EQUAL_FATAL(arr.data[1], 262143); - CU_ASSERT_EQUAL_FATAL(arr_row1.data[0], 4294967295); - CU_ASSERT_EQUAL_FATAL(arr_row1.data[1], 262143); - - CU_ASSERT_EQUAL_FATAL(arr.data[2], 4293918720); - CU_ASSERT_EQUAL_FATAL(arr.data[3], 255); - CU_ASSERT_EQUAL_FATAL(arr_row2.data[0], 4293918720); - CU_ASSERT_EQUAL_FATAL(arr_row2.data[1], 255); + CU_ASSERT_EQUAL_FATAL(arr.data[0], 4294967295); // row1 elem1 + CU_ASSERT_EQUAL_FATAL(arr.data[1], 262143); // row1 elem2 + CU_ASSERT_EQUAL_FATAL(arr.data[2], 4293918720); // row2 elem1 + CU_ASSERT_EQUAL_FATAL(arr.data[3], 255); // row2 elem2 // subtract the second from the first row, store in first - tsk_bit_array_subtract(&arr_row1, &arr_row2); + tsk_bitset_subtract(&arr, 0, &arr, 1); // verify our assumptions about subtraction - CU_ASSERT_EQUAL_FATAL(arr_row1.data[0], 1048575); - CU_ASSERT_EQUAL_FATAL(arr_row1.data[1], 261888); + CU_ASSERT_EQUAL_FATAL(arr.data[0], 1048575); + CU_ASSERT_EQUAL_FATAL(arr.data[1], 261888); - tsk_bit_array_t int_result; - tsk_bit_array_init(&int_result, 64, 1); + tsk_bitset_t int_result; + tsk_bitset_init(&int_result, 64, 1); + CU_ASSERT_EQUAL_FATAL(int_result.len, 1); + CU_ASSERT_EQUAL_FATAL(int_result.row_len, 2); // their intersection should be zero - tsk_bit_array_intersect(&arr_row1, &arr_row2, &int_result); + tsk_bitset_intersect(&arr, 0, &arr, 1, &int_result); CU_ASSERT_EQUAL_FATAL(int_result.data[0], 0); CU_ASSERT_EQUAL_FATAL(int_result.data[1], 0); // now, add them back together, storing back in a - tsk_bit_array_add(&arr_row1, &arr_row2); + tsk_bitset_union(&arr, 0, &arr, 1); // now, their intersection should be the subtracted chunk (20-40) - tsk_bit_array_intersect(&arr_row1, &arr_row2, &int_result); + tsk_bitset_intersect(&arr, 0, &arr, 1, &int_result); CU_ASSERT_EQUAL_FATAL(int_result.data[0], 4293918720); CU_ASSERT_EQUAL_FATAL(int_result.data[1], 255); - tsk_bit_array_free(&int_result); - tsk_bit_array_free(&arr); + tsk_bitset_free(&int_result); + tsk_bitset_free(&arr); } static void diff --git a/c/tskit/core.c b/c/tskit/core.c index 53cc0ce679..44ab0a18bd 100644 --- a/c/tskit/core.c +++ b/c/tskit/core.c @@ -1260,16 +1260,16 @@ tsk_avl_tree_int_ordered_nodes(const tsk_avl_tree_int_t *self, tsk_avl_node_int_ } // Bit Array implementation. Allows us to store unsigned integers in a compact manner. -// Currently implemented as an array of 32-bit unsigned integers for ease of counting. +// Currently implemented as an array of 32-bit unsigned integers. 
int -tsk_bit_array_init(tsk_bit_array_t *self, tsk_size_t num_bits, tsk_size_t length) +tsk_bitset_init(tsk_bitset_t *self, tsk_size_t num_bits, tsk_size_t length) { int ret = 0; - self->size = (num_bits >> TSK_BIT_ARRAY_CHUNK) - + (num_bits % TSK_BIT_ARRAY_NUM_BITS ? 1 : 0); - self->data = tsk_calloc(self->size * length, sizeof(*self->data)); + self->row_len = (num_bits / TSK_BITSET_BITS) + (num_bits % TSK_BITSET_BITS ? 1 : 0); + self->len = length; + self->data = tsk_calloc(self->row_len * length, sizeof(*self->data)); if (self->data == NULL) { ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; @@ -1278,96 +1278,111 @@ tsk_bit_array_init(tsk_bit_array_t *self, tsk_size_t num_bits, tsk_size_t length return ret; } -void -tsk_bit_array_get_row(const tsk_bit_array_t *self, tsk_size_t row, tsk_bit_array_t *out) -{ - out->size = self->size; - out->data = self->data + (row * self->size); -} +#define BITSET_DATA_ROW(bs, row) (bs)->data + (row) * (bs)->row_len void -tsk_bit_array_intersect( - const tsk_bit_array_t *self, const tsk_bit_array_t *other, tsk_bit_array_t *out) +tsk_bitset_intersect(const tsk_bitset_t *self, tsk_size_t self_row, + const tsk_bitset_t *other, tsk_size_t other_row, tsk_bitset_t *out) { - for (tsk_size_t i = 0; i < self->size; i++) { - out->data[i] = self->data[i] & other->data[i]; + const tsk_bitset_val_t *restrict self_d = BITSET_DATA_ROW(self, self_row); + const tsk_bitset_val_t *restrict other_d = BITSET_DATA_ROW(other, other_row); + tsk_bitset_val_t *restrict out_d = out->data; + for (tsk_size_t i = 0; i < self->row_len; i++) { + out_d[i] = self_d[i] & other_d[i]; } } void -tsk_bit_array_subtract(tsk_bit_array_t *self, const tsk_bit_array_t *other) +tsk_bitset_subtract(tsk_bitset_t *self, tsk_size_t self_row, const tsk_bitset_t *other, + tsk_size_t other_row) { - for (tsk_size_t i = 0; i < self->size; i++) { - self->data[i] &= ~(other->data[i]); + tsk_bitset_val_t *restrict self_d = BITSET_DATA_ROW(self, self_row); + const tsk_bitset_val_t *restrict other_d = BITSET_DATA_ROW(other, other_row); + for (tsk_size_t i = 0; i < self->row_len; i++) { + self_d[i] &= ~(other_d[i]); } } void -tsk_bit_array_add(tsk_bit_array_t *self, const tsk_bit_array_t *other) +tsk_bitset_union(tsk_bitset_t *self, tsk_size_t self_row, const tsk_bitset_t *other, + tsk_size_t other_row) { - for (tsk_size_t i = 0; i < self->size; i++) { - self->data[i] |= other->data[i]; + tsk_bitset_val_t *restrict self_d = BITSET_DATA_ROW(self, self_row); + const tsk_bitset_val_t *restrict other_d = BITSET_DATA_ROW(other, other_row); + for (tsk_size_t i = 0; i < self->row_len; i++) { + self_d[i] |= other_d[i]; } } void -tsk_bit_array_add_bit(tsk_bit_array_t *self, const tsk_bit_array_value_t bit) +tsk_bitset_set_bit(tsk_bitset_t *self, tsk_size_t row, const tsk_bitset_val_t bit) { - tsk_bit_array_value_t i = bit >> TSK_BIT_ARRAY_CHUNK; - self->data[i] |= (tsk_bit_array_value_t) 1 << (bit - (TSK_BIT_ARRAY_NUM_BITS * i)); + tsk_bitset_val_t i = (bit / TSK_BITSET_BITS); + *(BITSET_DATA_ROW(self, row) + i) |= (tsk_bitset_val_t) 1 + << (bit - (TSK_BITSET_BITS * i)); } bool -tsk_bit_array_contains(const tsk_bit_array_t *self, const tsk_bit_array_value_t bit) +tsk_bitset_contains(const tsk_bitset_t *self, tsk_size_t row, const tsk_bitset_val_t bit) { - tsk_bit_array_value_t i = bit >> TSK_BIT_ARRAY_CHUNK; - return self->data[i] - & ((tsk_bit_array_value_t) 1 << (bit - (TSK_BIT_ARRAY_NUM_BITS * i))); + tsk_bitset_val_t i = (bit / TSK_BITSET_BITS); + return *(BITSET_DATA_ROW(self, row) + i) + & ((tsk_bitset_val_t) 1 
<< (bit - (TSK_BITSET_BITS * i))); } -tsk_size_t -tsk_bit_array_count(const tsk_bit_array_t *self) +static inline uint32_t +popcount(tsk_bitset_val_t v) { - // Utilizes 12 operations per bit array. NB this only works on 32 bit integers. + // Utilizes 12 operations per chunk. NB this only works on 32-bit integers. // Taken from: // https://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel // There's a nice breakdown of this algorithm here: // https://stackoverflow.com/a/109025 - // Could probably do better with explicit SIMD (instead of SWAR), but not as - // portable: https://arxiv.org/pdf/1611.07612.pdf // - // There is one solution to explore further, which uses __builtin_popcountll. - // This option is relatively simple, but requires a 64 bit bit array and also - // involves some compiler flag plumbing (-mpopcnt) + // The gcc/clang compiler flag -mpopcnt will convert this code to a + // popcnt instruction (most if not all modern CPUs will support this). The + // popcnt instruction will yield some speed improvements, which depend on + // the tree sequence. + // + // NB: 32-bit counting is typically faster than 64-bit counting for this task. + // (at least on x86-64) - tsk_bit_array_value_t tmp; - tsk_size_t i, count = 0; + v = v - ((v >> 1) & 0x55555555); + v = (v & 0x33333333) + ((v >> 2) & 0x33333333); + return (((v + (v >> 4)) & 0xF0F0F0F) * 0x1010101) >> 24; +} + +tsk_size_t +tsk_bitset_count(const tsk_bitset_t *self, tsk_size_t row) +{ + tsk_size_t i = 0; + uint32_t count = 0; + const tsk_bitset_val_t *restrict self_d = BITSET_DATA_ROW(self, row); - for (i = 0; i < self->size; i++) { - tmp = self->data[i] - ((self->data[i] >> 1) & 0x55555555); - tmp = (tmp & 0x33333333) + ((tmp >> 2) & 0x33333333); - count += (((tmp + (tmp >> 4)) & 0xF0F0F0F) * 0x1010101) >> 24; + for (i = 0; i < self->row_len; i++) { + count += popcount(self_d[i]); } - return count; + return (tsk_size_t) count; } void -tsk_bit_array_get_items( - const tsk_bit_array_t *self, tsk_id_t *items, tsk_size_t *n_items) +tsk_bitset_get_items( + const tsk_bitset_t *self, tsk_size_t row, tsk_id_t *items, tsk_size_t *n_items) { // Get the items stored in the row of a bitset. - // Uses a de Bruijn sequence lookup table to determine the lowest bit set. See the - // wikipedia article for more info: https://w.wiki/BYiF + // Uses a de Bruijn sequence lookup table to determine the lowest bit set. + // See the wikipedia article for more info: https://w.wiki/BYiF tsk_size_t i, n, off; - tsk_bit_array_value_t v, lsb; // least significant bit + tsk_bitset_val_t v, lsb; // least significant bit static const tsk_id_t lookup[32] = { 0, 1, 28, 2, 29, 14, 24, 3, 30, 22, 20, 15, 25, 17, 4, 8, 31, 27, 13, 23, 21, 19, 16, 7, 26, 12, 18, 6, 11, 5, 10, 9 }; + const tsk_bitset_val_t *restrict self_d = BITSET_DATA_ROW(self, row); n = 0; - for (i = 0; i < self->size; i++) { - v = self->data[i]; - off = i * ((tsk_size_t) TSK_BIT_ARRAY_NUM_BITS); + for (i = 0; i < self->row_len; i++) { + v = self_d[i]; + off = i * TSK_BITSET_BITS; if (v == 0) { continue; } @@ -1381,7 +1396,7 @@ tsk_bit_array_get_items( } void -tsk_bit_array_free(tsk_bit_array_t *self) +tsk_bitset_free(tsk_bitset_t *self) { tsk_safe_free(self->data); } diff --git a/c/tskit/core.h b/c/tskit/core.h index 7dd24eba56..dede20d7d5 100644 --- a/c/tskit/core.h +++ b/c/tskit/core.h @@ -1104,29 +1104,31 @@ FILE *tsk_get_debug_stream(void); /* Bit Array functionality */ -typedef uint32_t tsk_bit_array_value_t; +// define a 32-bit chunk size for our bitsets. 
+// this means we'll be able to hold 32 distinct items in each 32 bit uint +#define TSK_BITSET_BITS (tsk_size_t) 32 +typedef uint32_t tsk_bitset_val_t; + typedef struct { - tsk_size_t size; // Number of chunks per row - tsk_bit_array_value_t *data; // Array data -} tsk_bit_array_t; - -#define TSK_BIT_ARRAY_CHUNK 5U -#define TSK_BIT_ARRAY_NUM_BITS (1U << TSK_BIT_ARRAY_CHUNK) - -int tsk_bit_array_init(tsk_bit_array_t *self, tsk_size_t num_bits, tsk_size_t length); -void tsk_bit_array_free(tsk_bit_array_t *self); -void tsk_bit_array_get_row( - const tsk_bit_array_t *self, tsk_size_t row, tsk_bit_array_t *out); -void tsk_bit_array_intersect( - const tsk_bit_array_t *self, const tsk_bit_array_t *other, tsk_bit_array_t *out); -void tsk_bit_array_subtract(tsk_bit_array_t *self, const tsk_bit_array_t *other); -void tsk_bit_array_add(tsk_bit_array_t *self, const tsk_bit_array_t *other); -void tsk_bit_array_add_bit(tsk_bit_array_t *self, const tsk_bit_array_value_t bit); -bool tsk_bit_array_contains( - const tsk_bit_array_t *self, const tsk_bit_array_value_t bit); -tsk_size_t tsk_bit_array_count(const tsk_bit_array_t *self); -void tsk_bit_array_get_items( - const tsk_bit_array_t *self, tsk_id_t *items, tsk_size_t *n_items); + tsk_size_t row_len; // Number of size TSK_BITSET_BITS chunks per row + tsk_size_t len; // Number of rows + tsk_bitset_val_t *data; +} tsk_bitset_t; + +int tsk_bitset_init(tsk_bitset_t *self, tsk_size_t num_bits, tsk_size_t length); +void tsk_bitset_free(tsk_bitset_t *self); +void tsk_bitset_intersect(const tsk_bitset_t *self, tsk_size_t self_row, + const tsk_bitset_t *other, tsk_size_t other_row, tsk_bitset_t *out); +void tsk_bitset_subtract(tsk_bitset_t *self, tsk_size_t self_row, + const tsk_bitset_t *other, tsk_size_t other_row); +void tsk_bitset_union(tsk_bitset_t *self, tsk_size_t self_row, const tsk_bitset_t *other, + tsk_size_t other_row); +void tsk_bitset_set_bit(tsk_bitset_t *self, tsk_size_t row, const tsk_bitset_val_t bit); +bool tsk_bitset_contains( + const tsk_bitset_t *self, tsk_size_t row, const tsk_bitset_val_t bit); +tsk_size_t tsk_bitset_count(const tsk_bitset_t *self, tsk_size_t row); +void tsk_bitset_get_items( + const tsk_bitset_t *self, tsk_size_t row, tsk_id_t *items, tsk_size_t *n_items); #ifdef __cplusplus } diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 7a159a7fe4..16d72ef4b5 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -2223,20 +2223,18 @@ tsk_treeseq_sample_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_s ***********************************/ static int -get_allele_samples(const tsk_site_t *site, const tsk_bit_array_t *state, - tsk_bit_array_t *out_allele_samples, tsk_size_t *out_num_alleles) +get_allele_samples(const tsk_site_t *site, tsk_size_t site_offset, + const tsk_bitset_t *state, tsk_bitset_t *out_allele_samples, + tsk_size_t *out_num_alleles) { int ret = 0; tsk_mutation_t mutation, parent_mut; - tsk_size_t mutation_index, allele, alt_allele_length; + tsk_size_t mutation_index, allele, alt_allele, alt_allele_length; /* The allele table */ tsk_size_t max_alleles = site->mutations_length + 1; const char **alleles = tsk_malloc(max_alleles * sizeof(*alleles)); tsk_size_t *allele_lengths = tsk_calloc(max_alleles, sizeof(*allele_lengths)); - const char *alt_allele; - tsk_bit_array_t state_row; - tsk_bit_array_t allele_samples_row; - tsk_bit_array_t alt_allele_samples_row; + const char *alt_allele_state; tsk_size_t num_alleles = 1; if (alleles == NULL || allele_lengths == NULL) { @@ -2267,29 +2265,29 @@ 
get_allele_samples(const tsk_site_t *site, const tsk_bit_array_t *state, } /* Add the mutation's samples to this allele */ - tsk_bit_array_get_row(out_allele_samples, allele, &allele_samples_row); - tsk_bit_array_get_row(state, mutation_index, &state_row); - tsk_bit_array_add(&allele_samples_row, &state_row); + tsk_bitset_union( + out_allele_samples, allele + site_offset, state, mutation_index); /* Get the index for the alternate allele that we must subtract from */ - alt_allele = site->ancestral_state; + alt_allele_state = site->ancestral_state; alt_allele_length = site->ancestral_state_length; if (mutation.parent != TSK_NULL) { parent_mut = site->mutations[mutation.parent - site->mutations[0].id]; - alt_allele = parent_mut.derived_state; + alt_allele_state = parent_mut.derived_state; alt_allele_length = parent_mut.derived_state_length; } - for (allele = 0; allele < num_alleles; allele++) { - if (alt_allele_length == allele_lengths[allele] - && tsk_memcmp(alt_allele, alleles[allele], allele_lengths[allele]) + for (alt_allele = 0; alt_allele < num_alleles; alt_allele++) { + if (alt_allele_length == allele_lengths[alt_allele] + && tsk_memcmp( + alt_allele_state, alleles[alt_allele], allele_lengths[alt_allele]) == 0) { break; } } tsk_bug_assert(allele < num_alleles); - tsk_bit_array_get_row(out_allele_samples, allele, &alt_allele_samples_row); - tsk_bit_array_subtract(&alt_allele_samples_row, &allele_samples_row); + tsk_bitset_subtract(out_allele_samples, alt_allele + site_offset, + out_allele_samples, allele + site_offset); } *out_num_alleles = num_alleles; out: @@ -2310,7 +2308,6 @@ norm_hap_weighted(tsk_size_t result_dim, const double *hap_weights, for (k = 0; k < result_dim; k++) { weight_row = GET_2D_ROW(hap_weights, 3, k); n = (double) args.sample_set_sizes[k]; - // TODO: what to do when n = 0 result[k] = weight_row[0] / n; } return 0; @@ -2347,111 +2344,107 @@ norm_total_weighted(tsk_size_t result_dim, const double *TSK_UNUSED(hap_weights) tsk_size_t n_a, tsk_size_t n_b, double *result, void *TSK_UNUSED(params)) { tsk_size_t k; + double norm = 1 / (double) (n_a * n_b); for (k = 0; k < result_dim; k++) { - result[k] = 1 / (double) (n_a * n_b); + result[k] = norm; } return 0; } static void -get_all_samples_bits(tsk_bit_array_t *all_samples, tsk_size_t n) +get_all_samples_bits(tsk_bitset_t *all_samples, tsk_size_t n) { tsk_size_t i; - const tsk_bit_array_value_t all = ~((tsk_bit_array_value_t) 0); - const tsk_bit_array_value_t remainder_samples = n % TSK_BIT_ARRAY_NUM_BITS; + const tsk_bitset_val_t all = ~((tsk_bitset_val_t) 0); + const tsk_bitset_val_t remainder_samples = n % TSK_BITSET_BITS; - all_samples->data[all_samples->size - 1] + all_samples->data[all_samples->row_len - 1] = remainder_samples ? 
~(all << remainder_samples) : all; - for (i = 0; i < all_samples->size - 1; i++) { + for (i = 0; i < all_samples->row_len - 1; i++) { all_samples->data[i] = all; } } +typedef struct { + double *weights; + double *norm; + double *result_tmp; + tsk_bitset_t AB_samples; +} two_locus_work_t; + static int -compute_general_two_site_stat_result(const tsk_bit_array_t *site_a_state, - const tsk_bit_array_t *site_b_state, tsk_size_t num_a_alleles, - tsk_size_t num_b_alleles, tsk_size_t num_samples, tsk_size_t state_dim, - const tsk_bit_array_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f, - sample_count_stat_params_t *f_params, norm_func_t *norm_f, bool polarised, - double *result) +two_locus_work_init(tsk_size_t max_alleles, tsk_size_t num_samples, + tsk_size_t result_dim, tsk_size_t state_dim, two_locus_work_t *out) { int ret = 0; - tsk_bit_array_t A_samples, B_samples; - // ss_ prefix refers to a sample set - tsk_bit_array_t ss_row; - tsk_bit_array_t ss_A_samples, ss_B_samples, ss_AB_samples, AB_samples; - // Sample sets and b sites are rows, a sites are columns - // b1 b2 b3 - // a1 [s1, s2, s3] [s1, s2, s3] [s1, s2, s3] - // a2 [s1, s2, s3] [s1, s2, s3] [s1, s2, s3] - // a3 [s1, s2, s3] [s1, s2, s3] [s1, s2, s3] - tsk_size_t k, mut_a, mut_b; - tsk_size_t result_row_len = num_b_alleles * result_dim; - tsk_size_t w_A = 0, w_B = 0, w_AB = 0; - uint8_t polarised_val = polarised ? 1 : 0; - double *hap_weight_row; - double *result_tmp_row; - double *weights = tsk_malloc(3 * state_dim * sizeof(*weights)); - double *norm = tsk_malloc(result_dim * sizeof(*norm)); - double *result_tmp - = tsk_malloc(result_row_len * num_a_alleles * sizeof(*result_tmp)); - - tsk_memset(&ss_A_samples, 0, sizeof(ss_A_samples)); - tsk_memset(&ss_B_samples, 0, sizeof(ss_B_samples)); - tsk_memset(&ss_AB_samples, 0, sizeof(ss_AB_samples)); - tsk_memset(&AB_samples, 0, sizeof(AB_samples)); - - if (weights == NULL || norm == NULL || result_tmp == NULL) { - ret = tsk_trace_error(TSK_ERR_NO_MEMORY); - goto out; - } - ret = tsk_bit_array_init(&ss_A_samples, num_samples, 1); - if (ret != 0) { - goto out; - } - ret = tsk_bit_array_init(&ss_B_samples, num_samples, 1); - if (ret != 0) { - goto out; - } - ret = tsk_bit_array_init(&ss_AB_samples, num_samples, 1); - if (ret != 0) { + out->weights = tsk_malloc(3 * state_dim * sizeof(*out->weights)); + out->norm = tsk_malloc(result_dim * sizeof(*out->norm)); + out->result_tmp + = tsk_malloc(result_dim * max_alleles * max_alleles * sizeof(*out->result_tmp)); + tsk_memset(&out->AB_samples, 0, sizeof(out->AB_samples)); + if (out->weights == NULL || out->norm == NULL || out->result_tmp == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } - ret = tsk_bit_array_init(&AB_samples, num_samples, 1); + ret = tsk_bitset_init(&out->AB_samples, num_samples, 1); if (ret != 0) { goto out; } +out: + return ret; +} + +static void +two_locus_work_free(two_locus_work_t *work) +{ + tsk_safe_free(work->weights); + tsk_safe_free(work->norm); + tsk_safe_free(work->result_tmp); + tsk_bitset_free(&work->AB_samples); +} - for (mut_a = polarised_val; mut_a < num_a_alleles; mut_a++) { +static int +compute_general_normed_two_site_stat_result(const tsk_bitset_t *state, + const tsk_size_t *allele_counts, tsk_size_t a_off, tsk_size_t b_off, + tsk_size_t num_a_alleles, tsk_size_t num_b_alleles, tsk_size_t state_dim, + tsk_size_t result_dim, general_stat_func_t *f, sample_count_stat_params_t *f_params, + norm_func_t *norm_f, bool polarised, two_locus_work_t *restrict work, double *result) +{ + int 
ret = 0; + // Sample sets and b sites are rows, a sites are columns + // b1 b2 b3 + // a1 [s1, s2, s3] [s1, s2, s3] [s1, s2, s3] + // a2 [s1, s2, s3] [s1, s2, s3] [s1, s2, s3] + // a3 [s1, s2, s3] [s1, s2, s3] [s1, s2, s3] + tsk_size_t k, mut_a, mut_b, result_row_len = num_b_alleles * result_dim; + uint8_t is_polarised = polarised ? 1 : 0; + double *restrict hap_row, *restrict result_tmp_row; + double *restrict norm = work->norm; + double *restrict weights = work->weights; + double *restrict result_tmp = work->result_tmp; + tsk_bitset_t AB_samples = work->AB_samples; + + for (mut_a = is_polarised; mut_a < num_a_alleles; mut_a++) { result_tmp_row = GET_2D_ROW(result_tmp, result_row_len, mut_a); - for (mut_b = polarised_val; mut_b < num_b_alleles; mut_b++) { - tsk_bit_array_get_row(site_a_state, mut_a, &A_samples); - tsk_bit_array_get_row(site_b_state, mut_b, &B_samples); - tsk_bit_array_intersect(&A_samples, &B_samples, &AB_samples); + for (mut_b = is_polarised; mut_b < num_b_alleles; mut_b++) { for (k = 0; k < state_dim; k++) { - tsk_bit_array_get_row(sample_sets, k, &ss_row); - hap_weight_row = GET_2D_ROW(weights, 3, k); - - tsk_bit_array_intersect(&A_samples, &ss_row, &ss_A_samples); - tsk_bit_array_intersect(&B_samples, &ss_row, &ss_B_samples); - tsk_bit_array_intersect(&AB_samples, &ss_row, &ss_AB_samples); - - w_AB = tsk_bit_array_count(&ss_AB_samples); - w_A = tsk_bit_array_count(&ss_A_samples); - w_B = tsk_bit_array_count(&ss_B_samples); - - hap_weight_row[0] = (double) w_AB; - hap_weight_row[1] = (double) (w_A - w_AB); // w_Ab - hap_weight_row[2] = (double) (w_B - w_AB); // w_aB + tsk_bitset_intersect(state, a_off + (mut_a * state_dim) + k, state, + b_off + (mut_b * state_dim) + k, &AB_samples); + hap_row = GET_2D_ROW(weights, 3, k); + hap_row[0] = (double) tsk_bitset_count(&AB_samples, 0); + hap_row[1] = (double) allele_counts[a_off + (mut_a * state_dim) + k] + - hap_row[0]; + hap_row[2] = (double) allele_counts[b_off + (mut_b * state_dim) + k] + - hap_row[0]; } ret = f(state_dim, weights, result_dim, result_tmp_row, f_params); if (ret != 0) { goto out; } - ret = norm_f(result_dim, weights, num_a_alleles - polarised_val, - num_b_alleles - polarised_val, norm, f_params); + ret = norm_f(result_dim, weights, num_a_alleles - is_polarised, + num_b_alleles - is_polarised, norm, f_params); if (ret != 0) { goto out; } @@ -2461,15 +2454,38 @@ compute_general_two_site_stat_result(const tsk_bit_array_t *site_a_state, result_tmp_row += result_dim; // Advance to the next column } } +out: + return ret; +} + +static int +compute_general_two_site_stat_result(const tsk_bitset_t *state, + const tsk_size_t *allele_counts, tsk_size_t a_off, tsk_size_t b_off, + tsk_size_t state_dim, tsk_size_t result_dim, general_stat_func_t *f, + sample_count_stat_params_t *f_params, two_locus_work_t *restrict work, + double *result) +{ + int ret = 0; + tsk_size_t k; + tsk_bitset_t AB_samples = work->AB_samples; + tsk_size_t mut_a = 1, mut_b = 1; + double *restrict hap_row, *restrict weights = work->weights; + for (k = 0; k < state_dim; k++) { + tsk_bitset_intersect(state, a_off + (mut_a * state_dim) + k, state, + b_off + (mut_b * state_dim) + k, &AB_samples); + hap_row = GET_2D_ROW(weights, 3, k); + hap_row[0] = (double) tsk_bitset_count(&AB_samples, 0); + hap_row[1] + = (double) allele_counts[a_off + (mut_a * state_dim) + k] - hap_row[0]; + hap_row[2] + = (double) allele_counts[b_off + (mut_b * state_dim) + k] - hap_row[0]; + } + ret = f(state_dim, weights, result_dim, result, f_params); + if (ret != 0) { + goto 
out; + } out: - tsk_safe_free(weights); - tsk_safe_free(norm); - tsk_safe_free(result_tmp); - tsk_bit_array_free(&ss_A_samples); - tsk_bit_array_free(&ss_B_samples); - tsk_bit_array_free(&ss_AB_samples); - tsk_bit_array_free(&AB_samples); return ret; } @@ -2520,7 +2536,7 @@ get_site_row_col_indices(tsk_size_t n_rows, const tsk_id_t *row_sites, tsk_size_ static int get_mutation_samples(const tsk_treeseq_t *ts, const tsk_id_t *sites, tsk_size_t n_sites, - tsk_size_t *num_alleles, tsk_bit_array_t *allele_samples) + tsk_size_t *num_alleles, tsk_bitset_t *allele_samples) { int ret = 0; const tsk_flags_t *restrict flags = ts->tables->nodes.flags; @@ -2528,7 +2544,7 @@ get_mutation_samples(const tsk_treeseq_t *ts, const tsk_id_t *sites, tsk_size_t const tsk_size_t *restrict site_muts_len = ts->site_mutations_length; tsk_site_t site; tsk_tree_t tree; - tsk_bit_array_t all_samples_bits, mut_samples, mut_samples_row, out_row; + tsk_bitset_t all_samples_bits, mut_samples; tsk_size_t max_muts_len, site_offset, num_nodes, site_idx, s, m, n; tsk_id_t node, *nodes = NULL; void *tmp_nodes; @@ -2543,11 +2559,11 @@ get_mutation_samples(const tsk_treeseq_t *ts, const tsk_id_t *sites, tsk_size_t } } // Allocate a bit array of size max alleles for all sites - ret = tsk_bit_array_init(&mut_samples, num_samples, max_muts_len); + ret = tsk_bitset_init(&mut_samples, num_samples, max_muts_len); if (ret != 0) { goto out; } - ret = tsk_bit_array_init(&all_samples_bits, num_samples, 1); + ret = tsk_bitset_init(&all_samples_bits, num_samples, 1); if (ret != 0) { goto out; } @@ -2572,15 +2588,11 @@ get_mutation_samples(const tsk_treeseq_t *ts, const tsk_id_t *sites, tsk_size_t goto out; } nodes = tmp_nodes; - - tsk_bit_array_get_row(allele_samples, site_offset, &out_row); - tsk_bit_array_add(&out_row, &all_samples_bits); - + tsk_bitset_union(allele_samples, site_offset, &all_samples_bits, 0); // Zero out results before the start of each iteration tsk_memset(mut_samples.data, 0, - mut_samples.size * max_muts_len * sizeof(tsk_bit_array_value_t)); + mut_samples.row_len * max_muts_len * sizeof(tsk_bitset_val_t)); for (m = 0; m < site.mutations_length; m++) { - tsk_bit_array_get_row(&mut_samples, m, &mut_samples_row); node = site.mutations[m].node; ret = tsk_tree_preorder_from(&tree, node, nodes, &num_nodes); if (ret != 0) { @@ -2589,43 +2601,92 @@ get_mutation_samples(const tsk_treeseq_t *ts, const tsk_id_t *sites, tsk_size_t for (n = 0; n < num_nodes; n++) { node = nodes[n]; if (flags[node] & TSK_NODE_IS_SAMPLE) { - tsk_bit_array_add_bit(&mut_samples_row, - (tsk_bit_array_value_t) ts->sample_index_map[node]); + tsk_bitset_set_bit( + &mut_samples, m, (tsk_bitset_val_t) ts->sample_index_map[node]); } } } + get_allele_samples( + &site, site_offset, &mut_samples, allele_samples, &(num_alleles[site_idx])); site_offset += site.mutations_length + 1; - get_allele_samples(&site, &mut_samples, &out_row, &(num_alleles[site_idx])); } // if adding code below, check ret before continuing out: tsk_safe_free(nodes); tsk_tree_free(&tree); - tsk_bit_array_free(&mut_samples); - tsk_bit_array_free(&all_samples_bits); + tsk_bitset_free(&mut_samples); + tsk_bitset_free(&all_samples_bits); return ret == TSK_TREE_OK ? 
0 : ret; } +static int +get_mutation_sample_sets(const tsk_bitset_t *allele_samples, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + const tsk_id_t *sample_index_map, tsk_size_t *max_ss_size, + tsk_bitset_t *allele_sample_sets, tsk_size_t **allele_sample_set_counts) +{ + int ret = 0; + tsk_bitset_val_t k, sample; + tsk_size_t i, j, ss_off; + + *max_ss_size = 0; + for (i = 0; i < num_sample_sets; i++) { + if (sample_set_sizes[i] > *max_ss_size) { + *max_ss_size = sample_set_sizes[i]; + } + } + + *allele_sample_set_counts = tsk_calloc( + allele_samples->len * num_sample_sets, sizeof(**allele_sample_set_counts)); + if (*allele_sample_set_counts == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } + ret = tsk_bitset_init( + allele_sample_sets, *max_ss_size, allele_samples->len * num_sample_sets); + if (ret != 0) { + goto out; + } + + for (i = 0; i < allele_samples->len; i++) { + ss_off = 0; + for (j = 0; j < num_sample_sets; j++) { + for (k = 0; k < sample_set_sizes[j]; k++) { + sample = (tsk_bitset_val_t) sample_index_map[sample_sets[k + ss_off]]; + if (tsk_bitset_contains(allele_samples, i, sample)) { + tsk_bitset_set_bit(allele_sample_sets, j + i * num_sample_sets, k); + (*allele_sample_set_counts)[j + i * num_sample_sets]++; + } + } + ss_off += sample_set_sizes[j]; + } + } +out: + return ret; +} + static int tsk_treeseq_two_site_count_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, - const tsk_bit_array_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f, + tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, + const tsk_id_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f, sample_count_stat_params_t *f_params, norm_func_t *norm_f, tsk_size_t n_rows, const tsk_id_t *row_sites, tsk_size_t n_cols, const tsk_id_t *col_sites, tsk_flags_t options, double *result) { int ret = 0; - tsk_bit_array_t allele_samples, c_state, r_state; + tsk_bitset_t allele_samples, allele_sample_sets; bool polarised = false; tsk_id_t *sites; - tsk_size_t r, c, s, n_alleles, n_sites, *row_idx, *col_idx; + tsk_size_t i, j, max_ss_size, max_alleles, n_alleles, n_sites, *row_idx, *col_idx; double *result_row; const tsk_size_t num_samples = self->num_samples; - tsk_size_t *num_alleles = NULL, *site_offsets = NULL; + tsk_size_t *num_alleles = NULL, *site_offsets = NULL, *allele_counts = NULL; tsk_size_t result_row_len = n_cols * result_dim; + two_locus_work_t work; + tsk_memset(&work, 0, sizeof(work)); tsk_memset(&allele_samples, 0, sizeof(allele_samples)); - + tsk_memset(&allele_sample_sets, 0, sizeof(allele_sample_sets)); sites = tsk_malloc(self->tables->sites.num_rows * sizeof(*sites)); row_idx = tsk_malloc(self->tables->sites.num_rows * sizeof(*row_idx)); col_idx = tsk_malloc(self->tables->sites.num_rows * sizeof(*col_idx)); @@ -2644,35 +2705,57 @@ tsk_treeseq_two_site_count_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } - n_alleles = 0; - for (s = 0; s < n_sites; s++) { - site_offsets[s] = n_alleles; - n_alleles += self->site_mutations_length[sites[s]] + 1; + n_alleles = 0; + max_alleles = 0; + for (i = 0; i < n_sites; i++) { + site_offsets[i] = n_alleles * num_sample_sets; + n_alleles += self->site_mutations_length[sites[i]] + 1; + if (self->site_mutations_length[sites[i]] > max_alleles) { + max_alleles = self->site_mutations_length[sites[i]]; + } } - ret = tsk_bit_array_init(&allele_samples, num_samples, n_alleles); + max_alleles++; // add 1 for the ancestral allele 
+ + ret = tsk_bitset_init(&allele_samples, num_samples, n_alleles); if (ret != 0) { goto out; } + // we track the number of alleles to account for backmutations ret = get_mutation_samples(self, sites, n_sites, num_alleles, &allele_samples); if (ret != 0) { goto out; } + ret = get_mutation_sample_sets(&allele_samples, num_sample_sets, sample_set_sizes, + sample_sets, self->sample_index_map, &max_ss_size, &allele_sample_sets, + &allele_counts); + if (ret != 0) { + goto out; + } + ret = two_locus_work_init(max_alleles, max_ss_size, result_dim, state_dim, &work); + if (ret != 0) { + goto out; + } if (options & TSK_STAT_POLARISED) { polarised = true; } // For each row/column pair, fill in the sample set in the result matrix. - for (r = 0; r < n_rows; r++) { - result_row = GET_2D_ROW(result, result_row_len, r); - for (c = 0; c < n_cols; c++) { - tsk_bit_array_get_row(&allele_samples, site_offsets[row_idx[r]], &r_state); - tsk_bit_array_get_row(&allele_samples, site_offsets[col_idx[c]], &c_state); - ret = compute_general_two_site_stat_result(&r_state, &c_state, - num_alleles[row_idx[r]], num_alleles[col_idx[c]], num_samples, state_dim, - sample_sets, result_dim, f, f_params, norm_f, polarised, - &(result_row[c * result_dim])); + for (i = 0; i < n_rows; i++) { + result_row = GET_2D_ROW(result, result_row_len, i); + for (j = 0; j < n_cols; j++) { + if (num_alleles[row_idx[i]] == 2 && num_alleles[col_idx[j]] == 2) { + ret = compute_general_two_site_stat_result(&allele_sample_sets, + allele_counts, site_offsets[row_idx[i]], site_offsets[col_idx[j]], + state_dim, result_dim, f, f_params, &work, + &(result_row[j * result_dim])); + } else { + ret = compute_general_normed_two_site_stat_result(&allele_sample_sets, + allele_counts, site_offsets[row_idx[i]], site_offsets[col_idx[j]], + num_alleles[row_idx[i]], num_alleles[col_idx[j]], state_dim, + result_dim, f, f_params, norm_f, polarised, &work, + &(result_row[j * result_dim])); + } if (ret != 0) { goto out; } @@ -2685,37 +2768,37 @@ tsk_treeseq_two_site_count_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, tsk_safe_free(col_idx); tsk_safe_free(num_alleles); tsk_safe_free(site_offsets); - tsk_bit_array_free(&allele_samples); + tsk_safe_free(allele_counts); + two_locus_work_free(&work); + tsk_bitset_free(&allele_samples); + tsk_bitset_free(&allele_sample_sets); return ret; } static int -sample_sets_to_bit_array(const tsk_treeseq_t *self, const tsk_size_t *sample_set_sizes, +sample_sets_to_bitset(const tsk_treeseq_t *self, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_sample_sets, - tsk_bit_array_t *sample_sets_bits) + tsk_bitset_t *sample_sets_bits) { int ret; - tsk_bit_array_t bits_row; tsk_size_t j, k, l; tsk_id_t u, sample_index; - ret = tsk_bit_array_init(sample_sets_bits, self->num_samples, num_sample_sets); + ret = tsk_bitset_init(sample_sets_bits, self->num_samples, num_sample_sets); if (ret != 0) { return ret; } - j = 0; for (k = 0; k < num_sample_sets; k++) { - tsk_bit_array_get_row(sample_sets_bits, k, &bits_row); for (l = 0; l < sample_set_sizes[k]; l++) { u = sample_sets[j]; sample_index = self->sample_index_map[u]; - if (tsk_bit_array_contains( - &bits_row, (tsk_bit_array_value_t) sample_index)) { + if (tsk_bitset_contains( + sample_sets_bits, k, (tsk_bitset_val_t) sample_index)) { ret = tsk_trace_error(TSK_ERR_DUPLICATE_SAMPLE); goto out; } - tsk_bit_array_add_bit(&bits_row, (tsk_bit_array_value_t) sample_index); + tsk_bitset_set_bit(sample_sets_bits, k, (tsk_bitset_val_t) sample_index); j++; } } @@ 
-2852,7 +2935,7 @@ get_index_counts( typedef struct { tsk_tree_t tree; - tsk_bit_array_t *node_samples; + tsk_bitset_t *node_samples; tsk_id_t *parent; tsk_id_t *edges_out; tsk_id_t *edges_in; @@ -2876,7 +2959,7 @@ iter_state_init(iter_state *self, const tsk_treeseq_t *ts, tsk_size_t state_dim) ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } - ret = tsk_bit_array_init(self->node_samples, ts->num_samples, state_dim * num_nodes); + ret = tsk_bitset_init(self->node_samples, ts->num_samples, state_dim * num_nodes); if (ret != 0) { goto out; } @@ -2895,29 +2978,25 @@ iter_state_init(iter_state *self, const tsk_treeseq_t *ts, tsk_size_t state_dim) static int get_node_samples(const tsk_treeseq_t *ts, tsk_size_t state_dim, - const tsk_bit_array_t *sample_sets, tsk_bit_array_t *node_samples) + const tsk_bitset_t *sample_sets, tsk_bitset_t *node_samples) { int ret = 0; tsk_size_t n, k; - tsk_bit_array_t sample_set_row, node_samples_row; tsk_size_t num_nodes = ts->tables->nodes.num_rows; - tsk_bit_array_value_t sample; + tsk_bitset_val_t sample; const tsk_id_t *restrict sample_index_map = ts->sample_index_map; const tsk_flags_t *restrict flags = ts->tables->nodes.flags; - ret = tsk_bit_array_init(node_samples, ts->num_samples, num_nodes * state_dim); + ret = tsk_bitset_init(node_samples, ts->num_samples, num_nodes * state_dim); if (ret != 0) { goto out; } for (k = 0; k < state_dim; k++) { - tsk_bit_array_get_row(sample_sets, k, &sample_set_row); for (n = 0; n < num_nodes; n++) { if (flags[n] & TSK_NODE_IS_SAMPLE) { - sample = (tsk_bit_array_value_t) sample_index_map[n]; - if (tsk_bit_array_contains(&sample_set_row, sample)) { - tsk_bit_array_get_row( - node_samples, (state_dim * n) + k, &node_samples_row); - tsk_bit_array_add_bit(&node_samples_row, sample); + sample = (tsk_bitset_val_t) sample_index_map[n]; + if (tsk_bitset_contains(sample_sets, k, sample)) { + tsk_bitset_set_bit(node_samples, (state_dim * n) + k, sample); } } } @@ -2928,7 +3007,7 @@ get_node_samples(const tsk_treeseq_t *ts, tsk_size_t state_dim, static void iter_state_clear(iter_state *self, tsk_size_t state_dim, tsk_size_t num_nodes, - const tsk_bit_array_t *node_samples) + const tsk_bitset_t *node_samples) { self->n_edges_out = 0; self->n_edges_in = 0; @@ -2938,14 +3017,14 @@ iter_state_clear(iter_state *self, tsk_size_t state_dim, tsk_size_t num_nodes, tsk_memset(self->edges_in, TSK_NULL, num_nodes * sizeof(*self->edges_in)); tsk_memset(self->branch_len, 0, num_nodes * sizeof(*self->branch_len)); tsk_memcpy(self->node_samples->data, node_samples->data, - node_samples->size * state_dim * num_nodes * sizeof(*node_samples->data)); + node_samples->row_len * state_dim * num_nodes * sizeof(*node_samples->data)); } static void iter_state_free(iter_state *self) { tsk_tree_free(&self->tree); - tsk_bit_array_free(self->node_samples); + tsk_bitset_free(self->node_samples); tsk_safe_free(self->node_samples); tsk_safe_free(self->parent); tsk_safe_free(self->edges_out); @@ -3025,41 +3104,26 @@ static int compute_two_tree_branch_state_update(const tsk_treeseq_t *ts, tsk_id_t c, const iter_state *A_state, const iter_state *B_state, tsk_size_t state_dim, tsk_size_t result_dim, int sign, general_stat_func_t *f, - sample_count_stat_params_t *f_params, double *result) + sample_count_stat_params_t *f_params, two_locus_work_t *restrict work, + double *result) { int ret = 0; double a_len, b_len; double *restrict B_branch_len = B_state->branch_len; - double *weights = NULL, *weights_row, *result_tmp = NULL; + double *weights_row; tsk_size_t n, k, 
a_row, b_row; - tsk_bit_array_t A_samples, B_samples, AB_samples, B_samples_tmp; const double *restrict A_branch_len = A_state->branch_len; - const tsk_bit_array_t *restrict A_state_samples = A_state->node_samples; - const tsk_bit_array_t *restrict B_state_samples = B_state->node_samples; - tsk_size_t num_samples = ts->num_samples; + const tsk_bitset_t *restrict A_state_samples = A_state->node_samples; + const tsk_bitset_t *restrict B_state_samples = B_state->node_samples; tsk_size_t num_nodes = ts->tables->nodes.num_rows; + double *weights = work->weights; + double *result_tmp = work->result_tmp; + tsk_bitset_t AB_samples = work->AB_samples; + b_len = B_branch_len[c] * sign; if (b_len == 0) { return ret; } - - tsk_memset(&AB_samples, 0, sizeof(AB_samples)); - tsk_memset(&B_samples_tmp, 0, sizeof(B_samples_tmp)); - - weights = tsk_calloc(3 * state_dim, sizeof(*weights)); - result_tmp = tsk_calloc(result_dim, sizeof(*result_tmp)); - if (weights == NULL || result_tmp == NULL) { - ret = tsk_trace_error(TSK_ERR_NO_MEMORY); - goto out; - } - ret = tsk_bit_array_init(&AB_samples, num_samples, 1); - if (ret != 0) { - goto out; - } - ret = tsk_bit_array_init(&B_samples_tmp, num_samples, 1); - if (ret != 0) { - goto out; - } for (n = 0; n < num_nodes; n++) { a_len = A_branch_len[n]; if (a_len == 0) { @@ -3068,15 +3132,14 @@ compute_two_tree_branch_state_update(const tsk_treeseq_t *ts, tsk_id_t c, for (k = 0; k < state_dim; k++) { a_row = (state_dim * n) + k; b_row = (state_dim * (tsk_size_t) c) + k; - tsk_bit_array_get_row(A_state_samples, a_row, &A_samples); - tsk_bit_array_get_row(B_state_samples, b_row, &B_samples); - tsk_bit_array_intersect(&A_samples, &B_samples, &AB_samples); weights_row = GET_2D_ROW(weights, 3, k); - weights_row[0] = (double) tsk_bit_array_count(&AB_samples); // w_AB + tsk_bitset_intersect( + A_state_samples, a_row, B_state_samples, b_row, &AB_samples); + weights_row[0] = (double) tsk_bitset_count(&AB_samples, 0); weights_row[1] - = (double) tsk_bit_array_count(&A_samples) - weights_row[0]; // w_Ab + = (double) tsk_bitset_count(A_state_samples, a_row) - weights_row[0]; weights_row[2] - = (double) tsk_bit_array_count(&B_samples) - weights_row[0]; // w_aB + = (double) tsk_bitset_count(B_state_samples, b_row) - weights_row[0]; } ret = f(state_dim, weights, result_dim, result_tmp, f_params); if (ret != 0) { @@ -3087,10 +3150,6 @@ compute_two_tree_branch_state_update(const tsk_treeseq_t *ts, tsk_id_t c, } } out: - tsk_safe_free(weights); - tsk_safe_free(result_tmp); - tsk_bit_array_free(&AB_samples); - tsk_bit_array_free(&B_samples_tmp); return ret; } @@ -3106,10 +3165,17 @@ compute_two_tree_branch_stat(const tsk_treeseq_t *ts, const iter_state *l_state, const tsk_id_t *restrict edges_child = ts->tables->edges.child; const tsk_id_t *restrict edges_parent = ts->tables->edges.parent; const tsk_size_t num_nodes = ts->tables->nodes.num_rows; - tsk_bit_array_t updates, row, ec_row, *r_samples = r_state->node_samples; + tsk_bitset_t updates, *r_samples = r_state->node_samples; + two_locus_work_t work; + tsk_memset(&work, 0, sizeof(work)); tsk_memset(&updates, 0, sizeof(updates)); - ret = tsk_bit_array_init(&updates, num_nodes, 1); + // only two alleles are possible for branch stats + ret = two_locus_work_init(2, ts->num_samples, result_dim, state_dim, &work); + if (ret != 0) { + goto out; + } + ret = tsk_bitset_init(&updates, num_nodes, 1); if (ret != 0) { goto out; } @@ -3126,18 +3192,18 @@ compute_two_tree_branch_stat(const tsk_treeseq_t *ts, const iter_state *l_state, c = 
edges_child[e]; // Identify affected nodes above child while (p != TSK_NULL) { - tsk_bit_array_add_bit(&updates, (tsk_bit_array_value_t) c); + tsk_bitset_set_bit(&updates, 0, (tsk_bitset_val_t) c); c = p; p = r_state->parent[p]; } } // Subtract the whole contribution from the child node - tsk_bit_array_get_items(&updates, updated_nodes, &n_updates); + tsk_bitset_get_items(&updates, 0, updated_nodes, &n_updates); while (n_updates != 0) { n_updates--; c = updated_nodes[n_updates]; - compute_two_tree_branch_state_update( - ts, c, l_state, r_state, state_dim, result_dim, -1, f, f_params, result); + compute_two_tree_branch_state_update(ts, c, l_state, r_state, state_dim, + result_dim, -1, f, f_params, &work, result); } // Remove samples under nodes from removed edges to parent nodes for (j = 0; j < r_state->n_edges_out; j++) { @@ -3146,10 +3212,8 @@ compute_two_tree_branch_stat(const tsk_treeseq_t *ts, const iter_state *l_state, ec = edges_child[e]; // edge child while (p != TSK_NULL) { for (k = 0; k < state_dim; k++) { - tsk_bit_array_get_row( - r_samples, (state_dim * (tsk_size_t) ec) + k, &ec_row); - tsk_bit_array_get_row(r_samples, (state_dim * (tsk_size_t) p) + k, &row); - tsk_bit_array_subtract(&row, &ec_row); + tsk_bitset_subtract(r_samples, (state_dim * (tsk_size_t) p) + k, + r_samples, (state_dim * (tsk_size_t) ec) + k); } p = r_state->parent[p]; } @@ -3164,12 +3228,10 @@ compute_two_tree_branch_stat(const tsk_treeseq_t *ts, const iter_state *l_state, r_state->branch_len[c] = time[p] - time[c]; r_state->parent[c] = p; while (p != TSK_NULL) { - tsk_bit_array_add_bit(&updates, (tsk_bit_array_value_t) c); + tsk_bitset_set_bit(&updates, 0, (tsk_bitset_val_t) c); for (k = 0; k < state_dim; k++) { - tsk_bit_array_get_row( - r_samples, (state_dim * (tsk_size_t) ec) + k, &ec_row); - tsk_bit_array_get_row(r_samples, (state_dim * (tsk_size_t) p) + k, &row); - tsk_bit_array_add(&row, &ec_row); + tsk_bitset_union(r_samples, (state_dim * (tsk_size_t) p) + k, r_samples, + (state_dim * (tsk_size_t) ec) + k); } c = p; p = r_state->parent[p]; @@ -3177,22 +3239,24 @@ compute_two_tree_branch_stat(const tsk_treeseq_t *ts, const iter_state *l_state, } // Update all affected child nodes (fully subtracted, deferred from addition) n_updates = 0; - tsk_bit_array_get_items(&updates, updated_nodes, &n_updates); + tsk_bitset_get_items(&updates, 0, updated_nodes, &n_updates); while (n_updates != 0) { n_updates--; c = updated_nodes[n_updates]; - compute_two_tree_branch_state_update( - ts, c, l_state, r_state, state_dim, result_dim, +1, f, f_params, result); + compute_two_tree_branch_state_update(ts, c, l_state, r_state, state_dim, + result_dim, +1, f, f_params, &work, result); } out: tsk_safe_free(updated_nodes); - tsk_bit_array_free(&updates); + two_locus_work_free(&work); + tsk_bitset_free(&updates); return ret; } static int tsk_treeseq_two_branch_count_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, - const tsk_bit_array_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f, + tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, + const tsk_id_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f, sample_count_stat_params_t *f_params, norm_func_t *TSK_UNUSED(norm_f), tsk_size_t n_rows, const double *row_positions, tsk_size_t n_cols, const double *col_positions, tsk_flags_t TSK_UNUSED(options), double *result) @@ -3201,11 +3265,12 @@ tsk_treeseq_two_branch_count_stat(const tsk_treeseq_t *self, tsk_size_t state_di int r, c; tsk_id_t *row_indexes = NULL, *col_indexes = NULL; 
tsk_size_t i, j, k, row, col, *row_repeats = NULL, *col_repeats = NULL; - tsk_bit_array_t node_samples; + tsk_bitset_t node_samples, sample_sets_bits; iter_state l_state, r_state; double *result_tmp = NULL, *result_row; const tsk_size_t num_nodes = self->tables->nodes.num_rows; + tsk_memset(&sample_sets_bits, 0, sizeof(sample_sets_bits)); tsk_memset(&node_samples, 0, sizeof(node_samples)); tsk_memset(&l_state, 0, sizeof(l_state)); tsk_memset(&r_state, 0, sizeof(r_state)); @@ -3222,6 +3287,11 @@ tsk_treeseq_two_branch_count_stat(const tsk_treeseq_t *self, tsk_size_t state_di if (ret != 0) { goto out; } + ret = sample_sets_to_bitset( + self, sample_set_sizes, sample_sets, num_sample_sets, &sample_sets_bits); + if (ret != 0) { + goto out; + } ret = positions_to_tree_indexes(self, row_positions, n_rows, &row_indexes); if (ret != 0) { goto out; @@ -3238,7 +3308,7 @@ tsk_treeseq_two_branch_count_stat(const tsk_treeseq_t *self, tsk_size_t state_di if (ret != 0) { goto out; } - ret = get_node_samples(self, state_dim, sample_sets, &node_samples); + ret = get_node_samples(self, state_dim, &sample_sets_bits, &node_samples); if (ret != 0) { goto out; } @@ -3289,7 +3359,42 @@ tsk_treeseq_two_branch_count_stat(const tsk_treeseq_t *self, tsk_size_t state_di tsk_safe_free(col_repeats); iter_state_free(&l_state); iter_state_free(&r_state); - tsk_bit_array_free(&node_samples); + tsk_bitset_free(&node_samples); + tsk_bitset_free(&sample_sets_bits); + return ret; +} + +static int +check_sample_set_dups(tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, + const tsk_id_t *sample_sets, const tsk_id_t *restrict sample_index_map, + tsk_size_t num_samples) +{ + int ret; + tsk_size_t j, k, l; + tsk_id_t u, sample_index; + tsk_bitset_t tmp; + + tsk_memset(&tmp, 0, sizeof(tmp)); + ret = tsk_bitset_init(&tmp, num_samples, 1); + if (ret != 0) { + return ret; + } + j = 0; + for (k = 0; k < num_sample_sets; k++) { + tsk_memset(tmp.data, 0, sizeof(*tmp.data) * tmp.row_len); + for (l = 0; l < sample_set_sizes[k]; l++) { + u = sample_sets[j]; + sample_index = sample_index_map[u]; + if (tsk_bitset_contains(&tmp, 0, (tsk_bitset_val_t) sample_index)) { + ret = tsk_trace_error(TSK_ERR_DUPLICATE_SAMPLE); + goto out; + } + tsk_bitset_set_bit(&tmp, 0, (tsk_bitset_val_t) sample_index); + j++; + } + } +out: + tsk_bitset_free(&tmp); return ret; } @@ -3304,7 +3409,6 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl // TODO: generalize this function if we ever decide to do weighted two_locus stats. // We only implement count stats and therefore we don't handle weights. 
int ret = 0; - tsk_bit_array_t sample_sets_bits; bool stat_site = !!(options & TSK_STAT_SITE); bool stat_branch = !!(options & TSK_STAT_BRANCH); tsk_size_t state_dim = num_sample_sets; @@ -3313,8 +3417,6 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl .sample_set_sizes = sample_set_sizes, .set_indexes = set_indexes }; - tsk_memset(&sample_sets_bits, 0, sizeof(sample_sets_bits)); - // We do not support two-locus node stats if (!!(options & TSK_STAT_NODE)) { ret = tsk_trace_error(TSK_ERR_UNSUPPORTED_STAT_MODE); @@ -3338,12 +3440,6 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl ret = tsk_trace_error(TSK_ERR_BAD_RESULT_DIMS); goto out; } - ret = sample_sets_to_bit_array( - self, sample_set_sizes, sample_sets, num_sample_sets, &sample_sets_bits); - if (ret != 0) { - goto out; - } - if (stat_site) { ret = check_sites(row_sites, out_rows, self->tables->sites.num_rows); if (ret != 0) { @@ -3353,9 +3449,15 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl if (ret != 0) { goto out; } - ret = tsk_treeseq_two_site_count_stat(self, state_dim, &sample_sets_bits, - result_dim, f, &f_params, norm_f, out_rows, row_sites, out_cols, col_sites, - options, result); + ret = check_sample_set_dups(num_sample_sets, sample_set_sizes, sample_sets, + self->sample_index_map, self->num_samples); + if (ret != 0) { + goto out; + } + // TODO: result dim/state dim can be set internally now. + ret = tsk_treeseq_two_site_count_stat(self, state_dim, num_sample_sets, + sample_set_sizes, sample_sets, result_dim, f, &f_params, norm_f, out_rows, + row_sites, out_cols, col_sites, options, result); } else if (stat_branch) { ret = check_positions( row_positions, out_rows, tsk_treeseq_get_sequence_length(self)); @@ -3367,13 +3469,11 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl if (ret != 0) { goto out; } - ret = tsk_treeseq_two_branch_count_stat(self, state_dim, &sample_sets_bits, - result_dim, f, &f_params, norm_f, out_rows, row_positions, out_cols, - col_positions, options, result); + ret = tsk_treeseq_two_branch_count_stat(self, state_dim, num_sample_sets, + sample_set_sizes, sample_sets, result_dim, f, &f_params, norm_f, out_rows, + row_positions, out_cols, col_positions, options, result); } - out: - tsk_bit_array_free(&sample_sets_bits); return ret; }