diff --git a/c/tests/test_core.c b/c/tests/test_core.c
index 6a14ecc655..1b97e4485f 100644
--- a/c/tests/test_core.c
+++ b/c/tests/test_core.c
@@ -531,22 +531,24 @@ test_bit_arrays(void)
 {
     // NB: This test is only valid for the 32 bit implementation of bit arrays. If we
     // were to change the chunk size of a bit array, we'd need to update these tests
-    tsk_bit_array_t arr;
+    tsk_bitset_t arr;
     tsk_id_t items_truth[64] = { 0 }, items[64] = { 0 };
     tsk_size_t n_items = 0, n_items_truth = 0;

     // test item retrieval
-    tsk_bit_array_init(&arr, 90, 1);
-    tsk_bit_array_get_items(&arr, items, &n_items);
+    tsk_bitset_init(&arr, 90, 1);
+    CU_ASSERT_EQUAL_FATAL(arr.len, 1);
+    CU_ASSERT_EQUAL_FATAL(arr.row_len, 3);
+    tsk_bitset_get_items(&arr, 0, items, &n_items);
     assert_arrays_equal(n_items_truth, items, items_truth);

-    for (tsk_bit_array_value_t i = 0; i < 20; i++) {
-        tsk_bit_array_add_bit(&arr, i);
+    for (tsk_bitset_val_t i = 0; i < 20; i++) {
+        tsk_bitset_set_bit(&arr, 0, i);
         items_truth[n_items_truth] = (tsk_id_t) i;
         n_items_truth++;
     }
-    tsk_bit_array_add_bit(&arr, 63);
-    tsk_bit_array_add_bit(&arr, 65);
+    tsk_bitset_set_bit(&arr, 0, 63);
+    tsk_bitset_set_bit(&arr, 0, 65);

     // these assertions are only valid for 32-bit values
     CU_ASSERT_EQUAL_FATAL(arr.data[0], 1048575);
@@ -554,32 +556,29 @@ test_bit_arrays(void)
     CU_ASSERT_EQUAL_FATAL(arr.data[2], 2);

     // verify our assumptions about bit array counting
-    CU_ASSERT_EQUAL_FATAL(tsk_bit_array_count(&arr), 22);
+    CU_ASSERT_EQUAL_FATAL(tsk_bitset_count(&arr, 0), 22);

-    tsk_bit_array_get_items(&arr, items, &n_items);
+    tsk_bitset_get_items(&arr, 0, items, &n_items);
     assert_arrays_equal(n_items_truth, items, items_truth);

     tsk_memset(items, 0, 64);
     tsk_memset(items_truth, 0, 64);
     n_items = n_items_truth = 0;
-    tsk_bit_array_free(&arr);
+    tsk_bitset_free(&arr);

-    // create a length-2 array with 64 bit capacity
-    tsk_bit_array_init(&arr, 64, 2);
-    tsk_bit_array_t arr_row1, arr_row2;
-
-    // select the first and second row
-    tsk_bit_array_get_row(&arr, 0, &arr_row1);
-    tsk_bit_array_get_row(&arr, 1, &arr_row2);
+    // create a length-2 array with 64 bit capacity (two chunks per row)
+    tsk_bitset_init(&arr, 64, 2);
+    CU_ASSERT_EQUAL_FATAL(arr.len, 2);
+    CU_ASSERT_EQUAL_FATAL(arr.row_len, 2);

     // fill the first 50 bits of the first row
-    for (tsk_bit_array_value_t i = 0; i < 50; i++) {
-        tsk_bit_array_add_bit(&arr_row1, i);
+    for (tsk_bitset_val_t i = 0; i < 50; i++) {
+        tsk_bitset_set_bit(&arr, 0, i);
         items_truth[n_items_truth] = (tsk_id_t) i;
         n_items_truth++;
     }

-    tsk_bit_array_get_items(&arr_row1, items, &n_items);
+    tsk_bitset_get_items(&arr, 0, items, &n_items);
     assert_arrays_equal(n_items_truth, items, items_truth);

     tsk_memset(items, 0, 64);
@@ -587,13 +586,13 @@ test_bit_arrays(void)
     n_items = n_items_truth = 0;

     // fill bits 20-40 of the second row
-    for (tsk_bit_array_value_t i = 20; i < 40; i++) {
-        tsk_bit_array_add_bit(&arr_row2, i);
+    for (tsk_bitset_val_t i = 20; i < 40; i++) {
+        tsk_bitset_set_bit(&arr, 1, i);
         items_truth[n_items_truth] = (tsk_id_t) i;
         n_items_truth++;
     }

-    tsk_bit_array_get_items(&arr_row2, items, &n_items);
+    tsk_bitset_get_items(&arr, 1, items, &n_items);
     assert_arrays_equal(n_items_truth, items, items_truth);

     tsk_memset(items, 0, 64);
@@ -601,41 +600,38 @@ test_bit_arrays(void)
     n_items = n_items_truth = 0;

     // verify our assumptions about row selection
-    CU_ASSERT_EQUAL_FATAL(arr.data[0], 4294967295);
-    CU_ASSERT_EQUAL_FATAL(arr.data[1], 262143);
-    CU_ASSERT_EQUAL_FATAL(arr_row1.data[0], 4294967295);
-    CU_ASSERT_EQUAL_FATAL(arr_row1.data[1], 262143);
-
-    CU_ASSERT_EQUAL_FATAL(arr.data[2], 4293918720);
-    CU_ASSERT_EQUAL_FATAL(arr.data[3], 255);
-    CU_ASSERT_EQUAL_FATAL(arr_row2.data[0], 4293918720);
-    CU_ASSERT_EQUAL_FATAL(arr_row2.data[1], 255);
+    CU_ASSERT_EQUAL_FATAL(arr.data[0], 4294967295); // row1 elem1
+    CU_ASSERT_EQUAL_FATAL(arr.data[1], 262143);     // row1 elem2
+    CU_ASSERT_EQUAL_FATAL(arr.data[2], 4293918720); // row2 elem1
+    CU_ASSERT_EQUAL_FATAL(arr.data[3], 255);        // row2 elem2

     // subtract the second from the first row, store in first
-    tsk_bit_array_subtract(&arr_row1, &arr_row2);
+    tsk_bitset_subtract(&arr, 0, &arr, 1);

     // verify our assumptions about subtraction
-    CU_ASSERT_EQUAL_FATAL(arr_row1.data[0], 1048575);
-    CU_ASSERT_EQUAL_FATAL(arr_row1.data[1], 261888);
+    CU_ASSERT_EQUAL_FATAL(arr.data[0], 1048575);
+    CU_ASSERT_EQUAL_FATAL(arr.data[1], 261888);

-    tsk_bit_array_t int_result;
-    tsk_bit_array_init(&int_result, 64, 1);
+    tsk_bitset_t int_result;
+    tsk_bitset_init(&int_result, 64, 1);
+    CU_ASSERT_EQUAL_FATAL(int_result.len, 1);
+    CU_ASSERT_EQUAL_FATAL(int_result.row_len, 2);

     // their intersection should be zero
-    tsk_bit_array_intersect(&arr_row1, &arr_row2, &int_result);
+    tsk_bitset_intersect(&arr, 0, &arr, 1, &int_result);
     CU_ASSERT_EQUAL_FATAL(int_result.data[0], 0);
     CU_ASSERT_EQUAL_FATAL(int_result.data[1], 0);

     // now, add them back together, storing back in a
-    tsk_bit_array_add(&arr_row1, &arr_row2);
+    tsk_bitset_union(&arr, 0, &arr, 1);

     // now, their intersection should be the subtracted chunk (20-40)
-    tsk_bit_array_intersect(&arr_row1, &arr_row2, &int_result);
+    tsk_bitset_intersect(&arr, 0, &arr, 1, &int_result);
     CU_ASSERT_EQUAL_FATAL(int_result.data[0], 4293918720);
     CU_ASSERT_EQUAL_FATAL(int_result.data[1], 255);

-    tsk_bit_array_free(&int_result);
-    tsk_bit_array_free(&arr);
+    tsk_bitset_free(&int_result);
+    tsk_bitset_free(&arr);
 }

 static void
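The constants asserted in this test all follow from the 32-bit chunk layout: bits 0-19 fill the low 20 bits of chunk 0, bit 63 is the top bit of chunk 1, bit 65 is bit 1 of chunk 2, and 50 bits in a row give one full chunk plus 18 bits. A standalone sketch of that arithmetic (plain C with `assert` standing in for the CUnit macros; not part of the patch):

```c
#include <assert.h>
#include <stdint.h>

int
main(void)
{
    uint32_t chunk0 = 0;
    uint32_t bit;

    // 90 bits at 32 bits per chunk -> ceil(90 / 32) = 3 chunks per row,
    // matching the row_len == 3 assertion above.
    assert((90 / 32) + (90 % 32 ? 1 : 0) == 3);

    // Setting bits 0..19 leaves the low 20 bits of chunk 0 set:
    // 2^20 - 1 = 1048575, the value asserted for arr.data[0].
    for (bit = 0; bit < 20; bit++) {
        chunk0 |= UINT32_C(1) << bit;
    }
    assert(chunk0 == 1048575);

    // Bit 63 -> chunk 63 / 32 == 1 at position 63 % 32 == 31 (the top bit);
    // bit 65 -> chunk 2 at position 1, giving the arr.data[2] == 2 assertion.
    assert(63 / 32 == 1 && UINT32_C(1) << (63 % 32) == UINT32_C(2147483648));
    assert(65 / 32 == 2 && UINT32_C(1) << (65 % 32) == 2);

    // Bits 32..49 of a 50-bit fill land in chunk 1: 2^18 - 1 = 262143; after
    // subtracting bits 32..39 (2^8 - 1 = 255), 262143 - 255 = 261888 remains.
    assert(262143 - 255 == 261888);
    return 0;
}
```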
diff --git a/c/tskit/core.c b/c/tskit/core.c
index 53cc0ce679..44ab0a18bd 100644
--- a/c/tskit/core.c
+++ b/c/tskit/core.c
@@ -1260,16 +1260,16 @@ tsk_avl_tree_int_ordered_nodes(const tsk_avl_tree_int_t *self, tsk_avl_node_int_
 }

 // Bit Array implementation. Allows us to store unsigned integers in a compact manner.
-// Currently implemented as an array of 32-bit unsigned integers for ease of counting.
+// Currently implemented as an array of 32-bit unsigned integers.

 int
-tsk_bit_array_init(tsk_bit_array_t *self, tsk_size_t num_bits, tsk_size_t length)
+tsk_bitset_init(tsk_bitset_t *self, tsk_size_t num_bits, tsk_size_t length)
 {
     int ret = 0;

-    self->size = (num_bits >> TSK_BIT_ARRAY_CHUNK)
-                 + (num_bits % TSK_BIT_ARRAY_NUM_BITS ? 1 : 0);
-    self->data = tsk_calloc(self->size * length, sizeof(*self->data));
+    self->row_len = (num_bits / TSK_BITSET_BITS) + (num_bits % TSK_BITSET_BITS ? 1 : 0);
+    self->len = length;
+    self->data = tsk_calloc(self->row_len * length, sizeof(*self->data));
     if (self->data == NULL) {
         ret = tsk_trace_error(TSK_ERR_NO_MEMORY);
         goto out;
@@ -1278,96 +1278,111 @@ tsk_bit_array_init(tsk_bit_array_t *self, tsk_size_t num_bits, tsk_size_t length
     return ret;
 }

-void
-tsk_bit_array_get_row(const tsk_bit_array_t *self, tsk_size_t row, tsk_bit_array_t *out)
-{
-    out->size = self->size;
-    out->data = self->data + (row * self->size);
-}
+#define BITSET_DATA_ROW(bs, row) (bs)->data + (row) * (bs)->row_len

 void
-tsk_bit_array_intersect(
-    const tsk_bit_array_t *self, const tsk_bit_array_t *other, tsk_bit_array_t *out)
+tsk_bitset_intersect(const tsk_bitset_t *self, tsk_size_t self_row,
+    const tsk_bitset_t *other, tsk_size_t other_row, tsk_bitset_t *out)
 {
-    for (tsk_size_t i = 0; i < self->size; i++) {
-        out->data[i] = self->data[i] & other->data[i];
+    const tsk_bitset_val_t *restrict self_d = BITSET_DATA_ROW(self, self_row);
+    const tsk_bitset_val_t *restrict other_d = BITSET_DATA_ROW(other, other_row);
+    tsk_bitset_val_t *restrict out_d = out->data;
+    for (tsk_size_t i = 0; i < self->row_len; i++) {
+        out_d[i] = self_d[i] & other_d[i];
     }
 }

 void
-tsk_bit_array_subtract(tsk_bit_array_t *self, const tsk_bit_array_t *other)
+tsk_bitset_subtract(tsk_bitset_t *self, tsk_size_t self_row, const tsk_bitset_t *other,
+    tsk_size_t other_row)
 {
-    for (tsk_size_t i = 0; i < self->size; i++) {
-        self->data[i] &= ~(other->data[i]);
+    tsk_bitset_val_t *restrict self_d = BITSET_DATA_ROW(self, self_row);
+    const tsk_bitset_val_t *restrict other_d = BITSET_DATA_ROW(other, other_row);
+    for (tsk_size_t i = 0; i < self->row_len; i++) {
+        self_d[i] &= ~(other_d[i]);
     }
 }

 void
-tsk_bit_array_add(tsk_bit_array_t *self, const tsk_bit_array_t *other)
+tsk_bitset_union(tsk_bitset_t *self, tsk_size_t self_row, const tsk_bitset_t *other,
+    tsk_size_t other_row)
 {
-    for (tsk_size_t i = 0; i < self->size; i++) {
-        self->data[i] |= other->data[i];
+    tsk_bitset_val_t *restrict self_d = BITSET_DATA_ROW(self, self_row);
+    const tsk_bitset_val_t *restrict other_d = BITSET_DATA_ROW(other, other_row);
+    for (tsk_size_t i = 0; i < self->row_len; i++) {
+        self_d[i] |= other_d[i];
     }
 }

 void
-tsk_bit_array_add_bit(tsk_bit_array_t *self, const tsk_bit_array_value_t bit)
+tsk_bitset_set_bit(tsk_bitset_t *self, tsk_size_t row, const tsk_bitset_val_t bit)
 {
-    tsk_bit_array_value_t i = bit >> TSK_BIT_ARRAY_CHUNK;
-    self->data[i] |= (tsk_bit_array_value_t) 1 << (bit - (TSK_BIT_ARRAY_NUM_BITS * i));
+    tsk_bitset_val_t i = (bit / TSK_BITSET_BITS);
+    *(BITSET_DATA_ROW(self, row) + i) |= (tsk_bitset_val_t) 1
+                                         << (bit - (TSK_BITSET_BITS * i));
 }

 bool
-tsk_bit_array_contains(const tsk_bit_array_t *self, const tsk_bit_array_value_t bit)
+tsk_bitset_contains(const tsk_bitset_t *self, tsk_size_t row, const tsk_bitset_val_t bit)
 {
-    tsk_bit_array_value_t i = bit >> TSK_BIT_ARRAY_CHUNK;
-    return self->data[i]
-           & ((tsk_bit_array_value_t) 1 << (bit - (TSK_BIT_ARRAY_NUM_BITS * i)));
+    tsk_bitset_val_t i = (bit / TSK_BITSET_BITS);
+    return *(BITSET_DATA_ROW(self, row) + i)
+           & ((tsk_bitset_val_t) 1 << (bit - (TSK_BITSET_BITS * i)));
 }

-tsk_size_t
-tsk_bit_array_count(const tsk_bit_array_t *self)
+static inline uint32_t
+popcount(tsk_bitset_val_t v)
 {
-    // Utilizes 12 operations per bit array. NB this only works on 32 bit integers.
+    // Utilizes 12 operations per chunk. NB this only works on 32 bit integers.
     // Taken from:
     // https://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel
     // There's a nice breakdown of this algorithm here:
     // https://stackoverflow.com/a/109025
-    // Could probably do better with explicit SIMD (instead of SWAR), but not as
-    // portable: https://arxiv.org/pdf/1611.07612.pdf
     //
-    // There is one solution to explore further, which uses __builtin_popcountll.
-    // This option is relatively simple, but requires a 64 bit bit array and also
-    // involves some compiler flag plumbing (-mpopcnt)
+    // The gcc/clang compiler flag -mpopcnt will convert this code to a
+    // popcnt instruction (most if not all modern CPUs will support this). The
+    // popcnt instruction will yield some speed improvements, which depend on
+    // the tree sequence.
+    //
+    // NB: 32-bit counting is typically faster than 64-bit counting for this
+    // task (at least on x86-64).

-    tsk_bit_array_value_t tmp;
-    tsk_size_t i, count = 0;
+    v = v - ((v >> 1) & 0x55555555);
+    v = (v & 0x33333333) + ((v >> 2) & 0x33333333);
+    return (((v + (v >> 4)) & 0xF0F0F0F) * 0x1010101) >> 24;
+}
+
+tsk_size_t
+tsk_bitset_count(const tsk_bitset_t *self, tsk_size_t row)
+{
+    tsk_size_t i = 0;
+    uint32_t count = 0;
+    const tsk_bitset_val_t *restrict self_d = BITSET_DATA_ROW(self, row);

-    for (i = 0; i < self->size; i++) {
-        tmp = self->data[i] - ((self->data[i] >> 1) & 0x55555555);
-        tmp = (tmp & 0x33333333) + ((tmp >> 2) & 0x33333333);
-        count += (((tmp + (tmp >> 4)) & 0xF0F0F0F) * 0x1010101) >> 24;
+    for (i = 0; i < self->row_len; i++) {
+        count += popcount(self_d[i]);
     }
-    return count;
+    return (tsk_size_t) count;
 }

 void
-tsk_bit_array_get_items(
-    const tsk_bit_array_t *self, tsk_id_t *items, tsk_size_t *n_items)
+tsk_bitset_get_items(
+    const tsk_bitset_t *self, tsk_size_t row, tsk_id_t *items, tsk_size_t *n_items)
 {
     // Get the items stored in the row of a bitset.
-    // Uses a de Bruijn sequence lookup table to determine the lowest bit set. See the
-    // wikipedia article for more info: https://w.wiki/BYiF
+    // Uses a de Bruijn sequence lookup table to determine the lowest bit set.
+    // See the wikipedia article for more info: https://w.wiki/BYiF
     tsk_size_t i, n, off;
-    tsk_bit_array_value_t v, lsb; // least significant bit
+    tsk_bitset_val_t v, lsb; // least significant bit
     static const tsk_id_t lookup[32] = { 0, 1, 28, 2, 29, 14, 24, 3, 30, 22, 20, 15, 25,
         17, 4, 8, 31, 27, 13, 23, 21, 19, 16, 7, 26, 12, 18, 6, 11, 5, 10, 9 };
+    const tsk_bitset_val_t *restrict self_d = BITSET_DATA_ROW(self, row);

     n = 0;
-    for (i = 0; i < self->size; i++) {
-        v = self->data[i];
-        off = i * ((tsk_size_t) TSK_BIT_ARRAY_NUM_BITS);
+    for (i = 0; i < self->row_len; i++) {
+        v = self_d[i];
+        off = i * TSK_BITSET_BITS;
         if (v == 0) {
             continue;
         }
@@ -1381,7 +1396,7 @@ tsk_bit_array_get_items(
 }

 void
-tsk_bit_array_free(tsk_bit_array_t *self)
+tsk_bitset_free(tsk_bitset_t *self)
 {
     tsk_safe_free(self->data);
 }
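Both bit tricks used in core.c are easy to sanity-check in isolation. The sketch below replays the SWAR popcount and the de Bruijn lowest-set-bit lookup against naive loops. The multiplier 0x077CB531 is the constant that pairs with this exact lookup table in the classic bithacks formulation; the hunk above elides the multiply line, so treat the constant as an assumption here.

```c
#include <assert.h>
#include <stdint.h>

// SWAR popcount, as in popcount() above: 2-bit sums, 4-bit sums, then a
// multiply to gather the byte sums into the top byte.
static uint32_t
swar_popcount(uint32_t v)
{
    v = v - ((v >> 1) & 0x55555555);
    v = (v & 0x33333333) + ((v >> 2) & 0x33333333);
    return (((v + (v >> 4)) & 0xF0F0F0F) * 0x1010101) >> 24;
}

// De Bruijn multiply-and-lookup for the index of the lowest set bit, using
// the same table as tsk_bitset_get_items.
static int
debruijn_lsb(uint32_t v)
{
    static const int lookup[32] = { 0, 1, 28, 2, 29, 14, 24, 3, 30, 22, 20, 15,
        25, 17, 4, 8, 31, 27, 13, 23, 21, 19, 16, 7, 26, 12, 18, 6, 11, 5, 10, 9 };
    uint32_t lowest = v & (0U - v); // isolate the lowest set bit
    return lookup[(lowest * UINT32_C(0x077CB531)) >> 27];
}

int
main(void)
{
    for (uint32_t v = 1; v < 100000; v++) {
        uint32_t naive = 0;
        int lsb = -1;
        for (int i = 0; i < 32; i++) {
            naive += (v >> i) & 1;
            if (lsb < 0 && ((v >> i) & 1)) {
                lsb = i;
            }
        }
        assert(swar_popcount(v) == naive);
        assert(debruijn_lsb(v) == lsb);
    }
    return 0;
}
```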
diff --git a/c/tskit/core.h b/c/tskit/core.h
index 7dd24eba56..dede20d7d5 100644
--- a/c/tskit/core.h
+++ b/c/tskit/core.h
@@ -1104,29 +1104,31 @@ FILE *tsk_get_debug_stream(void);

 /* Bit Array functionality */

-typedef uint32_t tsk_bit_array_value_t;
+// define a 32-bit chunk size for our bitsets.
+// this means we'll be able to hold 32 distinct items in each 32-bit uint
+#define TSK_BITSET_BITS (tsk_size_t) 32
+typedef uint32_t tsk_bitset_val_t;
+
 typedef struct {
-    tsk_size_t size;             // Number of chunks per row
-    tsk_bit_array_value_t *data; // Array data
-} tsk_bit_array_t;
-
-#define TSK_BIT_ARRAY_CHUNK 5U
-#define TSK_BIT_ARRAY_NUM_BITS (1U << TSK_BIT_ARRAY_CHUNK)
-
-int tsk_bit_array_init(tsk_bit_array_t *self, tsk_size_t num_bits, tsk_size_t length);
-void tsk_bit_array_free(tsk_bit_array_t *self);
-void tsk_bit_array_get_row(
-    const tsk_bit_array_t *self, tsk_size_t row, tsk_bit_array_t *out);
-void tsk_bit_array_intersect(
-    const tsk_bit_array_t *self, const tsk_bit_array_t *other, tsk_bit_array_t *out);
-void tsk_bit_array_subtract(tsk_bit_array_t *self, const tsk_bit_array_t *other);
-void tsk_bit_array_add(tsk_bit_array_t *self, const tsk_bit_array_t *other);
-void tsk_bit_array_add_bit(tsk_bit_array_t *self, const tsk_bit_array_value_t bit);
-bool tsk_bit_array_contains(
-    const tsk_bit_array_t *self, const tsk_bit_array_value_t bit);
-tsk_size_t tsk_bit_array_count(const tsk_bit_array_t *self);
-void tsk_bit_array_get_items(
-    const tsk_bit_array_t *self, tsk_id_t *items, tsk_size_t *n_items);
+    tsk_size_t row_len; // Number of size TSK_BITSET_BITS chunks per row
+    tsk_size_t len;     // Number of rows
+    tsk_bitset_val_t *data;
+} tsk_bitset_t;
+
+int tsk_bitset_init(tsk_bitset_t *self, tsk_size_t num_bits, tsk_size_t length);
+void tsk_bitset_free(tsk_bitset_t *self);
+void tsk_bitset_intersect(const tsk_bitset_t *self, tsk_size_t self_row,
+    const tsk_bitset_t *other, tsk_size_t other_row, tsk_bitset_t *out);
+void tsk_bitset_subtract(tsk_bitset_t *self, tsk_size_t self_row,
+    const tsk_bitset_t *other, tsk_size_t other_row);
+void tsk_bitset_union(tsk_bitset_t *self, tsk_size_t self_row, const tsk_bitset_t *other,
+    tsk_size_t other_row);
+void tsk_bitset_set_bit(tsk_bitset_t *self, tsk_size_t row, const tsk_bitset_val_t bit);
+bool tsk_bitset_contains(
+    const tsk_bitset_t *self, tsk_size_t row, const tsk_bitset_val_t bit);
+tsk_size_t tsk_bitset_count(const tsk_bitset_t *self, tsk_size_t row);
+void tsk_bitset_get_items(
+    const tsk_bitset_t *self, tsk_size_t row, tsk_id_t *items, tsk_size_t *n_items);

 #ifdef __cplusplus
 }
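With the row-view type gone, every operation now takes (set, row) index pairs, which is what lets the implementations above use per-row `restrict` pointers. A minimal usage sketch against the declarations above (hypothetical caller, not part of the patch):

```c
#include <tskit/core.h>

// Hypothetical caller: one bitset of two 64-bit rows, plus a one-row scratch.
static int
bitset_row_api_demo(void)
{
    tsk_bitset_t bs, tmp;
    int ret = tsk_bitset_init(&bs, 64, 2);
    if (ret != 0) {
        return ret;
    }
    ret = tsk_bitset_init(&tmp, 64, 1);
    if (ret != 0) {
        tsk_bitset_free(&bs);
        return ret;
    }
    tsk_bitset_set_bit(&bs, 0, 3);  // row 0: {3, 40}
    tsk_bitset_set_bit(&bs, 0, 40);
    tsk_bitset_set_bit(&bs, 1, 40); // row 1: {40}
    // The intersection of the two rows lands in row 0 of tmp: {40}, so
    // tsk_bitset_count(&tmp, 0) == 1 and tsk_bitset_contains(&tmp, 0, 40) holds.
    tsk_bitset_intersect(&bs, 0, &bs, 1, &tmp);
    tsk_bitset_free(&tmp);
    tsk_bitset_free(&bs);
    return 0;
}
```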
diff --git a/c/tskit/trees.c b/c/tskit/trees.c
index 7a159a7fe4..16d72ef4b5 100644
--- a/c/tskit/trees.c
+++ b/c/tskit/trees.c
@@ -2223,20 +2223,18 @@ tsk_treeseq_sample_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_s
 ***********************************/

 static int
-get_allele_samples(const tsk_site_t *site, const tsk_bit_array_t *state,
-    tsk_bit_array_t *out_allele_samples, tsk_size_t *out_num_alleles)
+get_allele_samples(const tsk_site_t *site, tsk_size_t site_offset,
+    const tsk_bitset_t *state, tsk_bitset_t *out_allele_samples,
+    tsk_size_t *out_num_alleles)
 {
     int ret = 0;
     tsk_mutation_t mutation, parent_mut;
-    tsk_size_t mutation_index, allele, alt_allele_length;
+    tsk_size_t mutation_index, allele, alt_allele, alt_allele_length;
     /* The allele table */
     tsk_size_t max_alleles = site->mutations_length + 1;
     const char **alleles = tsk_malloc(max_alleles * sizeof(*alleles));
     tsk_size_t *allele_lengths = tsk_calloc(max_alleles, sizeof(*allele_lengths));
-    const char *alt_allele;
-    tsk_bit_array_t state_row;
-    tsk_bit_array_t allele_samples_row;
-    tsk_bit_array_t alt_allele_samples_row;
+    const char *alt_allele_state;
     tsk_size_t num_alleles = 1;

     if (alleles == NULL || allele_lengths == NULL) {
@@ -2267,29 +2265,29 @@ get_allele_samples(const tsk_site_t *site, const tsk_bit_array_t *state,
         }

         /* Add the mutation's samples to this allele */
-        tsk_bit_array_get_row(out_allele_samples, allele, &allele_samples_row);
-        tsk_bit_array_get_row(state, mutation_index, &state_row);
-        tsk_bit_array_add(&allele_samples_row, &state_row);
+        tsk_bitset_union(
+            out_allele_samples, allele + site_offset, state, mutation_index);

         /* Get the index for the alternate allele that we must subtract from */
-        alt_allele = site->ancestral_state;
+        alt_allele_state = site->ancestral_state;
         alt_allele_length = site->ancestral_state_length;
         if (mutation.parent != TSK_NULL) {
             parent_mut = site->mutations[mutation.parent - site->mutations[0].id];
-            alt_allele = parent_mut.derived_state;
+            alt_allele_state = parent_mut.derived_state;
             alt_allele_length = parent_mut.derived_state_length;
         }
-        for (allele = 0; allele < num_alleles; allele++) {
-            if (alt_allele_length == allele_lengths[allele]
-                && tsk_memcmp(alt_allele, alleles[allele], allele_lengths[allele])
+        for (alt_allele = 0; alt_allele < num_alleles; alt_allele++) {
+            if (alt_allele_length == allele_lengths[alt_allele]
+                && tsk_memcmp(
+                       alt_allele_state, alleles[alt_allele], allele_lengths[alt_allele])
                        == 0) {
                 break;
             }
         }
-        tsk_bug_assert(allele < num_alleles);
+        tsk_bug_assert(alt_allele < num_alleles);

-        tsk_bit_array_get_row(out_allele_samples, allele, &alt_allele_samples_row);
-        tsk_bit_array_subtract(&alt_allele_samples_row, &allele_samples_row);
+        tsk_bitset_subtract(out_allele_samples, alt_allele + site_offset,
+            out_allele_samples, allele + site_offset);
     }
     *out_num_alleles = num_alleles;
 out:
@@ -2310,7 +2308,6 @@ norm_hap_weighted(tsk_size_t result_dim, const double *hap_weights,
     for (k = 0; k < result_dim; k++) {
         weight_row = GET_2D_ROW(hap_weights, 3, k);
         n = (double) args.sample_set_sizes[k];
-        // TODO: what to do when n = 0
         result[k] = weight_row[0] / n;
     }
     return 0;
 }
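The bookkeeping in get_allele_samples is pure set algebra: each mutation's sample set is unioned into the row of its derived allele and subtracted from the row of its parent's allele (the ancestral allele when `parent` is TSK_NULL), which is what makes back mutations come out right. A toy replay with plain 32-bit masks in place of bitset rows (illustrative only):

```c
#include <assert.h>
#include <stdint.h>

int
main(void)
{
    // Per-allele sample masks for one site over samples {0,1,2,3};
    // allele 0 = ancestral "A", allele 1 = derived "T". Every sample starts
    // out carrying the ancestral state.
    uint32_t allele[2] = { 0xF, 0 };

    // Mutation 1: "T" over samples {0,1}; its parent allele is ancestral "A".
    uint32_t mut1 = 0x3;
    allele[1] |= mut1;  // union into the derived allele's row
    allele[0] &= ~mut1; // subtract from the parent allele's row
    assert(allele[0] == 0xC && allele[1] == 0x3);

    // Mutation 2: back-mutation to "A" over sample {1}; its parent is
    // mutation 1, so we subtract from "T" rather than the ancestral allele.
    uint32_t mut2 = 0x2;
    allele[0] |= mut2;
    allele[1] &= ~mut2;
    assert(allele[0] == 0xE && allele[1] == 0x1);
    return 0;
}
```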
@@ -2347,111 +2344,107 @@ norm_total_weighted(tsk_size_t result_dim, const double *TSK_UNUSED(hap_weights)
     tsk_size_t n_a, tsk_size_t n_b, double *result, void *TSK_UNUSED(params))
 {
     tsk_size_t k;
+    double norm = 1 / (double) (n_a * n_b);

     for (k = 0; k < result_dim; k++) {
-        result[k] = 1 / (double) (n_a * n_b);
+        result[k] = norm;
     }
     return 0;
 }

 static void
-get_all_samples_bits(tsk_bit_array_t *all_samples, tsk_size_t n)
+get_all_samples_bits(tsk_bitset_t *all_samples, tsk_size_t n)
 {
     tsk_size_t i;
-    const tsk_bit_array_value_t all = ~((tsk_bit_array_value_t) 0);
-    const tsk_bit_array_value_t remainder_samples = n % TSK_BIT_ARRAY_NUM_BITS;
+    const tsk_bitset_val_t all = ~((tsk_bitset_val_t) 0);
+    const tsk_bitset_val_t remainder_samples = n % TSK_BITSET_BITS;

-    all_samples->data[all_samples->size - 1]
+    all_samples->data[all_samples->row_len - 1]
         = remainder_samples ? ~(all << remainder_samples) : all;
-    for (i = 0; i < all_samples->size - 1; i++) {
+    for (i = 0; i < all_samples->row_len - 1; i++) {
         all_samples->data[i] = all;
     }
 }

+typedef struct {
+    double *weights;
+    double *norm;
+    double *result_tmp;
+    tsk_bitset_t AB_samples;
+} two_locus_work_t;
+
 static int
-compute_general_two_site_stat_result(const tsk_bit_array_t *site_a_state,
-    const tsk_bit_array_t *site_b_state, tsk_size_t num_a_alleles,
-    tsk_size_t num_b_alleles, tsk_size_t num_samples, tsk_size_t state_dim,
-    const tsk_bit_array_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f,
-    sample_count_stat_params_t *f_params, norm_func_t *norm_f, bool polarised,
-    double *result)
+two_locus_work_init(tsk_size_t max_alleles, tsk_size_t num_samples,
+    tsk_size_t result_dim, tsk_size_t state_dim, two_locus_work_t *out)
 {
     int ret = 0;
-    tsk_bit_array_t A_samples, B_samples;
-    // ss_ prefix refers to a sample set
-    tsk_bit_array_t ss_row;
-    tsk_bit_array_t ss_A_samples, ss_B_samples, ss_AB_samples, AB_samples;
-    // Sample sets and b sites are rows, a sites are columns
-    //          b1           b2           b3
-    // a1 [s1, s2, s3] [s1, s2, s3] [s1, s2, s3]
-    // a2 [s1, s2, s3] [s1, s2, s3] [s1, s2, s3]
-    // a3 [s1, s2, s3] [s1, s2, s3] [s1, s2, s3]
-    tsk_size_t k, mut_a, mut_b;
-    tsk_size_t result_row_len = num_b_alleles * result_dim;
-    tsk_size_t w_A = 0, w_B = 0, w_AB = 0;
-    uint8_t polarised_val = polarised ? 1 : 0;
-    double *hap_weight_row;
-    double *result_tmp_row;
-    double *weights = tsk_malloc(3 * state_dim * sizeof(*weights));
-    double *norm = tsk_malloc(result_dim * sizeof(*norm));
-    double *result_tmp
-        = tsk_malloc(result_row_len * num_a_alleles * sizeof(*result_tmp));
-
-    tsk_memset(&ss_A_samples, 0, sizeof(ss_A_samples));
-    tsk_memset(&ss_B_samples, 0, sizeof(ss_B_samples));
-    tsk_memset(&ss_AB_samples, 0, sizeof(ss_AB_samples));
-    tsk_memset(&AB_samples, 0, sizeof(AB_samples));
-
-    if (weights == NULL || norm == NULL || result_tmp == NULL) {
-        ret = tsk_trace_error(TSK_ERR_NO_MEMORY);
-        goto out;
-    }
-    ret = tsk_bit_array_init(&ss_A_samples, num_samples, 1);
-    if (ret != 0) {
-        goto out;
-    }
-    ret = tsk_bit_array_init(&ss_B_samples, num_samples, 1);
-    if (ret != 0) {
-        goto out;
-    }
-    ret = tsk_bit_array_init(&ss_AB_samples, num_samples, 1);
-    if (ret != 0) {
+    out->weights = tsk_malloc(3 * state_dim * sizeof(*out->weights));
+    out->norm = tsk_malloc(result_dim * sizeof(*out->norm));
+    out->result_tmp
+        = tsk_malloc(result_dim * max_alleles * max_alleles * sizeof(*out->result_tmp));
+    tsk_memset(&out->AB_samples, 0, sizeof(out->AB_samples));
+    if (out->weights == NULL || out->norm == NULL || out->result_tmp == NULL) {
+        ret = tsk_trace_error(TSK_ERR_NO_MEMORY);
         goto out;
     }
-    ret = tsk_bit_array_init(&AB_samples, num_samples, 1);
+    ret = tsk_bitset_init(&out->AB_samples, num_samples, 1);
     if (ret != 0) {
         goto out;
     }
+out:
+    return ret;
+}
+
+static void
+two_locus_work_free(two_locus_work_t *work)
+{
+    tsk_safe_free(work->weights);
+    tsk_safe_free(work->norm);
+    tsk_safe_free(work->result_tmp);
+    tsk_bitset_free(&work->AB_samples);
+}

-    for (mut_a = polarised_val; mut_a < num_a_alleles; mut_a++) {
+static int
+compute_general_normed_two_site_stat_result(const tsk_bitset_t *state,
+    const tsk_size_t *allele_counts, tsk_size_t a_off, tsk_size_t b_off,
+    tsk_size_t num_a_alleles, tsk_size_t num_b_alleles, tsk_size_t state_dim,
+    tsk_size_t result_dim, general_stat_func_t *f, sample_count_stat_params_t *f_params,
+    norm_func_t *norm_f, bool polarised, two_locus_work_t *restrict work, double *result)
+{
+    int ret = 0;
+    // Sample sets and b sites are rows, a sites are columns
+    //          b1           b2           b3
+    // a1 [s1, s2, s3] [s1, s2, s3] [s1, s2, s3]
+    // a2 [s1, s2, s3] [s1, s2, s3] [s1, s2, s3]
+    // a3 [s1, s2, s3] [s1, s2, s3] [s1, s2, s3]
+    tsk_size_t k, mut_a, mut_b, result_row_len = num_b_alleles * result_dim;
+    uint8_t is_polarised = polarised ? 1 : 0;
+    double *restrict hap_row, *restrict result_tmp_row;
+    double *restrict norm = work->norm;
+    double *restrict weights = work->weights;
+    double *restrict result_tmp = work->result_tmp;
+    tsk_bitset_t AB_samples = work->AB_samples;
+
+    for (mut_a = is_polarised; mut_a < num_a_alleles; mut_a++) {
         result_tmp_row = GET_2D_ROW(result_tmp, result_row_len, mut_a);
-        for (mut_b = polarised_val; mut_b < num_b_alleles; mut_b++) {
-            tsk_bit_array_get_row(site_a_state, mut_a, &A_samples);
-            tsk_bit_array_get_row(site_b_state, mut_b, &B_samples);
-            tsk_bit_array_intersect(&A_samples, &B_samples, &AB_samples);
+        for (mut_b = is_polarised; mut_b < num_b_alleles; mut_b++) {
             for (k = 0; k < state_dim; k++) {
-                tsk_bit_array_get_row(sample_sets, k, &ss_row);
-                hap_weight_row = GET_2D_ROW(weights, 3, k);
-
-                tsk_bit_array_intersect(&A_samples, &ss_row, &ss_A_samples);
-                tsk_bit_array_intersect(&B_samples, &ss_row, &ss_B_samples);
-                tsk_bit_array_intersect(&AB_samples, &ss_row, &ss_AB_samples);
-
-                w_AB = tsk_bit_array_count(&ss_AB_samples);
-                w_A = tsk_bit_array_count(&ss_A_samples);
-                w_B = tsk_bit_array_count(&ss_B_samples);
-
-                hap_weight_row[0] = (double) w_AB;
-                hap_weight_row[1] = (double) (w_A - w_AB); // w_Ab
-                hap_weight_row[2] = (double) (w_B - w_AB); // w_aB
+                tsk_bitset_intersect(state, a_off + (mut_a * state_dim) + k, state,
+                    b_off + (mut_b * state_dim) + k, &AB_samples);
+                hap_row = GET_2D_ROW(weights, 3, k);
+                hap_row[0] = (double) tsk_bitset_count(&AB_samples, 0);
+                hap_row[1] = (double) allele_counts[a_off + (mut_a * state_dim) + k]
+                             - hap_row[0];
+                hap_row[2] = (double) allele_counts[b_off + (mut_b * state_dim) + k]
+                             - hap_row[0];
             }
             ret = f(state_dim, weights, result_dim, result_tmp_row, f_params);
             if (ret != 0) {
                 goto out;
             }
-            ret = norm_f(result_dim, weights, num_a_alleles - polarised_val,
-                num_b_alleles - polarised_val, norm, f_params);
+            ret = norm_f(result_dim, weights, num_a_alleles - is_polarised,
+                num_b_alleles - is_polarised, norm, f_params);
             if (ret != 0) {
                 goto out;
             }
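The inner loop above is where the refactor pays off: the old code computed w_A and w_B with two extra per-pair bitset intersections and counts, whereas the new code reads |A| and |B| from the precomputed allele_counts table and recovers w_Ab and w_aB by subtraction, leaving a single intersection per (allele, allele, sample set) triple. A toy check of those identities (illustrative sketch, assuming 8 samples and plain 32-bit masks rather than the tskit types):

```c
#include <assert.h>
#include <stdint.h>

// Naive popcount for the sketch (core.c uses the SWAR version).
static uint32_t
pc(uint32_t v)
{
    uint32_t n = 0;
    while (v) {
        n += v & 1;
        v >>= 1;
    }
    return n;
}

int
main(void)
{
    uint32_t A = 0x3C; // samples carrying allele A at the first locus: {2,3,4,5}
    uint32_t B = 0x0F; // samples carrying allele B at the second locus: {0,1,2,3}
    uint32_t w_AB = pc(A & B);    // = 2, the only intersection needed per pair
    uint32_t w_Ab = pc(A) - w_AB; // = 2; pc(A) plays the role of allele_counts[a]
    uint32_t w_aB = pc(B) - w_AB; // = 2; pc(B) plays the role of allele_counts[b]
    assert(w_AB == 2 && w_Ab == 2 && w_aB == 2);
    // The fourth haplotype count follows from the sample-set size:
    // w_ab = n - (w_AB + w_Ab + w_aB).
    assert(8 - (w_AB + w_Ab + w_aB) == 2);
    return 0;
}
```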
@@ -2461,15 +2454,38 @@ compute_general_two_site_stat_result(const tsk_bit_array_t *site_a_state,
             result_tmp_row += result_dim; // Advance to the next column
         }
     }
+out:
+    return ret;
+}
+
+static int
+compute_general_two_site_stat_result(const tsk_bitset_t *state,
+    const tsk_size_t *allele_counts, tsk_size_t a_off, tsk_size_t b_off,
+    tsk_size_t state_dim, tsk_size_t result_dim, general_stat_func_t *f,
+    sample_count_stat_params_t *f_params, two_locus_work_t *restrict work,
+    double *result)
+{
+    int ret = 0;
+    tsk_size_t k;
+    tsk_bitset_t AB_samples = work->AB_samples;
+    tsk_size_t mut_a = 1, mut_b = 1;
+    double *restrict hap_row, *restrict weights = work->weights;
+
+    for (k = 0; k < state_dim; k++) {
+        tsk_bitset_intersect(state, a_off + (mut_a * state_dim) + k, state,
+            b_off + (mut_b * state_dim) + k, &AB_samples);
+        hap_row = GET_2D_ROW(weights, 3, k);
+        hap_row[0] = (double) tsk_bitset_count(&AB_samples, 0);
+        hap_row[1]
+            = (double) allele_counts[a_off + (mut_a * state_dim) + k] - hap_row[0];
+        hap_row[2]
+            = (double) allele_counts[b_off + (mut_b * state_dim) + k] - hap_row[0];
+    }
+    ret = f(state_dim, weights, result_dim, result, f_params);
+    if (ret != 0) {
+        goto out;
+    }
 out:
-    tsk_safe_free(weights);
-    tsk_safe_free(norm);
-    tsk_safe_free(result_tmp);
-    tsk_bit_array_free(&ss_A_samples);
-    tsk_bit_array_free(&ss_B_samples);
-    tsk_bit_array_free(&ss_AB_samples);
-    tsk_bit_array_free(&AB_samples);
     return ret;
 }

@@ -2520,7 +2536,7 @@ get_site_row_col_indices(tsk_size_t n_rows, const tsk_id_t *row_sites, tsk_size_

 static int
 get_mutation_samples(const tsk_treeseq_t *ts, const tsk_id_t *sites, tsk_size_t n_sites,
-    tsk_size_t *num_alleles, tsk_bit_array_t *allele_samples)
+    tsk_size_t *num_alleles, tsk_bitset_t *allele_samples)
 {
     int ret = 0;
     const tsk_flags_t *restrict flags = ts->tables->nodes.flags;
@@ -2528,7 +2544,7 @@ get_mutation_samples(const tsk_treeseq_t *ts, const tsk_id_t *sites, tsk_size_t
     const tsk_size_t *restrict site_muts_len = ts->site_mutations_length;
     tsk_site_t site;
     tsk_tree_t tree;
-    tsk_bit_array_t all_samples_bits, mut_samples, mut_samples_row, out_row;
+    tsk_bitset_t all_samples_bits, mut_samples;
     tsk_size_t max_muts_len, site_offset, num_nodes, site_idx, s, m, n;
     tsk_id_t node, *nodes = NULL;
     void *tmp_nodes;
@@ -2543,11 +2559,11 @@ get_mutation_samples(const tsk_treeseq_t *ts, const tsk_id_t *sites, tsk_size_t
         }
     }
     // Allocate a bit array of size max alleles for all sites
-    ret = tsk_bit_array_init(&mut_samples, num_samples, max_muts_len);
+    ret = tsk_bitset_init(&mut_samples, num_samples, max_muts_len);
     if (ret != 0) {
         goto out;
     }
-    ret = tsk_bit_array_init(&all_samples_bits, num_samples, 1);
+    ret = tsk_bitset_init(&all_samples_bits, num_samples, 1);
     if (ret != 0) {
         goto out;
     }
@@ -2572,15 +2588,11 @@ get_mutation_samples(const tsk_treeseq_t *ts, const tsk_id_t *sites, tsk_size_t
             goto out;
         }
         nodes = tmp_nodes;
-
-        tsk_bit_array_get_row(allele_samples, site_offset, &out_row);
-        tsk_bit_array_add(&out_row, &all_samples_bits);
-
+        tsk_bitset_union(allele_samples, site_offset, &all_samples_bits, 0);
         // Zero out results before the start of each iteration
         tsk_memset(mut_samples.data, 0,
-            mut_samples.size * max_muts_len * sizeof(tsk_bit_array_value_t));
+            mut_samples.row_len * max_muts_len * sizeof(tsk_bitset_val_t));
         for (m = 0; m < site.mutations_length; m++) {
-            tsk_bit_array_get_row(&mut_samples, m, &mut_samples_row);
             node = site.mutations[m].node;
             ret = tsk_tree_preorder_from(&tree, node, nodes, &num_nodes);
             if (ret != 0) {
@@ -2589,43 +2601,92 @@ get_mutation_samples(const tsk_treeseq_t *ts, const tsk_id_t *sites, tsk_size_t
             for (n = 0; n < num_nodes; n++) {
                 node = nodes[n];
                 if (flags[node] & TSK_NODE_IS_SAMPLE) {
-                    tsk_bit_array_add_bit(&mut_samples_row,
-                        (tsk_bit_array_value_t) ts->sample_index_map[node]);
+                    tsk_bitset_set_bit(
+                        &mut_samples, m, (tsk_bitset_val_t) ts->sample_index_map[node]);
                 }
             }
         }
+        get_allele_samples(
+            &site, site_offset, &mut_samples, allele_samples, &(num_alleles[site_idx]));
         site_offset += site.mutations_length + 1;
-        get_allele_samples(&site, &mut_samples, &out_row, &(num_alleles[site_idx]));
     }
     // if adding code below, check ret before continuing
 out:
     tsk_safe_free(nodes);
     tsk_tree_free(&tree);
-    tsk_bit_array_free(&mut_samples);
-    tsk_bit_array_free(&all_samples_bits);
+    tsk_bitset_free(&mut_samples);
+    tsk_bitset_free(&all_samples_bits);
     return ret == TSK_TREE_OK ? 0 : ret;
 }

+static int
+get_mutation_sample_sets(const tsk_bitset_t *allele_samples, tsk_size_t num_sample_sets,
+    const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets,
+    const tsk_id_t *sample_index_map, tsk_size_t *max_ss_size,
+    tsk_bitset_t *allele_sample_sets, tsk_size_t **allele_sample_set_counts)
+{
+    int ret = 0;
+    tsk_bitset_val_t k, sample;
+    tsk_size_t i, j, ss_off;
+
+    *max_ss_size = 0;
+    for (i = 0; i < num_sample_sets; i++) {
+        if (sample_set_sizes[i] > *max_ss_size) {
+            *max_ss_size = sample_set_sizes[i];
+        }
+    }
+
+    *allele_sample_set_counts = tsk_calloc(
+        allele_samples->len * num_sample_sets, sizeof(**allele_sample_set_counts));
+    if (*allele_sample_set_counts == NULL) {
+        ret = tsk_trace_error(TSK_ERR_NO_MEMORY);
+        goto out;
+    }
+    ret = tsk_bitset_init(
+        allele_sample_sets, *max_ss_size, allele_samples->len * num_sample_sets);
+    if (ret != 0) {
+        goto out;
+    }
+
+    for (i = 0; i < allele_samples->len; i++) {
+        ss_off = 0;
+        for (j = 0; j < num_sample_sets; j++) {
+            for (k = 0; k < sample_set_sizes[j]; k++) {
+                sample = (tsk_bitset_val_t) sample_index_map[sample_sets[k + ss_off]];
+                if (tsk_bitset_contains(allele_samples, i, sample)) {
+                    tsk_bitset_set_bit(allele_sample_sets, j + i * num_sample_sets, k);
+                    (*allele_sample_set_counts)[j + i * num_sample_sets]++;
+                }
+            }
+            ss_off += sample_set_sizes[j];
+        }
+    }
+out:
+    return ret;
+}
+
 static int
 tsk_treeseq_two_site_count_stat(const tsk_treeseq_t *self, tsk_size_t state_dim,
-    const tsk_bit_array_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f,
+    tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes,
+    const tsk_id_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f,
     sample_count_stat_params_t *f_params, norm_func_t *norm_f, tsk_size_t n_rows,
     const tsk_id_t *row_sites, tsk_size_t n_cols, const tsk_id_t *col_sites,
     tsk_flags_t options, double *result)
 {
     int ret = 0;
-    tsk_bit_array_t allele_samples, c_state, r_state;
+    tsk_bitset_t allele_samples, allele_sample_sets;
     bool polarised = false;
     tsk_id_t *sites;
-    tsk_size_t r, c, s, n_alleles, n_sites, *row_idx, *col_idx;
+    tsk_size_t i, j, max_ss_size, max_alleles, n_alleles, n_sites, *row_idx, *col_idx;
     double *result_row;
     const tsk_size_t num_samples = self->num_samples;
-    tsk_size_t *num_alleles = NULL, *site_offsets = NULL;
+    tsk_size_t *num_alleles = NULL, *site_offsets = NULL, *allele_counts = NULL;
     tsk_size_t result_row_len = n_cols * result_dim;
+    two_locus_work_t work;

+    tsk_memset(&work, 0, sizeof(work));
     tsk_memset(&allele_samples, 0, sizeof(allele_samples));
-
+    tsk_memset(&allele_sample_sets, 0, sizeof(allele_sample_sets));
     sites = tsk_malloc(self->tables->sites.num_rows * sizeof(*sites));
     row_idx = tsk_malloc(self->tables->sites.num_rows * sizeof(*row_idx));
     col_idx = tsk_malloc(self->tables->sites.num_rows * sizeof(*col_idx));
@@ -2644,35 +2705,57 @@ tsk_treeseq_two_site_count_stat(const tsk_treeseq_t *self, tsk_size_t state_dim,
         ret = tsk_trace_error(TSK_ERR_NO_MEMORY);
         goto out;
     }
     n_alleles = 0;
-    for (s = 0; s < n_sites; s++) {
-        site_offsets[s] = n_alleles;
-        n_alleles += self->site_mutations_length[sites[s]] + 1;
+    max_alleles = 0;
+    for (i = 0; i < n_sites; i++) {
+        site_offsets[i] = n_alleles * num_sample_sets;
+        n_alleles += self->site_mutations_length[sites[i]] + 1;
+        if (self->site_mutations_length[sites[i]] > max_alleles) {
+            max_alleles = self->site_mutations_length[sites[i]];
+        }
     }
-    ret = tsk_bit_array_init(&allele_samples, num_samples, n_alleles);
+    max_alleles++; // add 1 for the ancestral allele
+
+    ret = tsk_bitset_init(&allele_samples, num_samples, n_alleles);
     if (ret != 0) {
         goto out;
     }
+    // we track the number of alleles to account for backmutations
     ret = get_mutation_samples(self, sites, n_sites, num_alleles, &allele_samples);
     if (ret != 0) {
         goto out;
     }
+    ret = get_mutation_sample_sets(&allele_samples, num_sample_sets, sample_set_sizes,
+        sample_sets, self->sample_index_map, &max_ss_size, &allele_sample_sets,
+        &allele_counts);
+    if (ret != 0) {
+        goto out;
+    }
+    ret = two_locus_work_init(max_alleles, max_ss_size, result_dim, state_dim, &work);
+    if (ret != 0) {
+        goto out;
+    }

     if (options & TSK_STAT_POLARISED) {
         polarised = true;
     }

     // For each row/column pair, fill in the sample set in the result matrix.
-    for (r = 0; r < n_rows; r++) {
-        result_row = GET_2D_ROW(result, result_row_len, r);
-        for (c = 0; c < n_cols; c++) {
-            tsk_bit_array_get_row(&allele_samples, site_offsets[row_idx[r]], &r_state);
-            tsk_bit_array_get_row(&allele_samples, site_offsets[col_idx[c]], &c_state);
-            ret = compute_general_two_site_stat_result(&r_state, &c_state,
-                num_alleles[row_idx[r]], num_alleles[col_idx[c]], num_samples, state_dim,
-                sample_sets, result_dim, f, f_params, norm_f, polarised,
-                &(result_row[c * result_dim]));
+    for (i = 0; i < n_rows; i++) {
+        result_row = GET_2D_ROW(result, result_row_len, i);
+        for (j = 0; j < n_cols; j++) {
+            if (num_alleles[row_idx[i]] == 2 && num_alleles[col_idx[j]] == 2) {
+                ret = compute_general_two_site_stat_result(&allele_sample_sets,
+                    allele_counts, site_offsets[row_idx[i]], site_offsets[col_idx[j]],
+                    state_dim, result_dim, f, f_params, &work,
+                    &(result_row[j * result_dim]));
+            } else {
+                ret = compute_general_normed_two_site_stat_result(&allele_sample_sets,
+                    allele_counts, site_offsets[row_idx[i]], site_offsets[col_idx[j]],
+                    num_alleles[row_idx[i]], num_alleles[col_idx[j]], state_dim,
+                    result_dim, f, f_params, norm_f, polarised, &work,
+                    &(result_row[j * result_dim]));
+            }
             if (ret != 0) {
                 goto out;
             }
@@ -2685,37 +2768,37 @@ tsk_treeseq_two_site_count_stat(const tsk_treeseq_t *self, tsk_size_t state_dim,
     tsk_safe_free(col_idx);
     tsk_safe_free(num_alleles);
     tsk_safe_free(site_offsets);
-    tsk_bit_array_free(&allele_samples);
+    tsk_safe_free(allele_counts);
+    two_locus_work_free(&work);
+    tsk_bitset_free(&allele_samples);
+    tsk_bitset_free(&allele_sample_sets);
     return ret;
 }

 static int
-sample_sets_to_bit_array(const tsk_treeseq_t *self, const tsk_size_t *sample_set_sizes,
+sample_sets_to_bitset(const tsk_treeseq_t *self, const tsk_size_t *sample_set_sizes,
     const tsk_id_t *sample_sets, tsk_size_t num_sample_sets,
-    tsk_bit_array_t *sample_sets_bits)
+    tsk_bitset_t *sample_sets_bits)
 {
     int ret;
-    tsk_bit_array_t bits_row;
     tsk_size_t j, k, l;
     tsk_id_t u, sample_index;

-    ret = tsk_bit_array_init(sample_sets_bits, self->num_samples, num_sample_sets);
+    ret = tsk_bitset_init(sample_sets_bits, self->num_samples, num_sample_sets);
     if (ret != 0) {
         return ret;
     }
-
     j = 0;
     for (k = 0; k < num_sample_sets; k++) {
-        tsk_bit_array_get_row(sample_sets_bits, k, &bits_row);
         for (l = 0; l < sample_set_sizes[k]; l++) {
             u = sample_sets[j];
             sample_index = self->sample_index_map[u];
-            if (tsk_bit_array_contains(
-                    &bits_row, (tsk_bit_array_value_t) sample_index)) {
+            if (tsk_bitset_contains(
+                    sample_sets_bits, k, (tsk_bitset_val_t) sample_index)) {
                 ret = tsk_trace_error(TSK_ERR_DUPLICATE_SAMPLE);
                 goto out;
             }
-            tsk_bit_array_add_bit(&bits_row, (tsk_bit_array_value_t) sample_index);
+            tsk_bitset_set_bit(sample_sets_bits, k, (tsk_bitset_val_t) sample_index);
             j++;
         }
     }
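Two things in the hunks above are worth spelling out. First, get_mutation_sample_sets flattens each (site allele, sample set) pair onto a single bitset row as j + i * num_sample_sets, and site_offsets pre-multiplies the cumulative allele count by num_sample_sets so the compute functions can address rows as off + (mut * state_dim) + k (state_dim equals num_sample_sets here). Second, the dispatch in the row/column loop sends strictly biallelic pairs down compute_general_two_site_stat_result, which fixes mut_a = mut_b = 1 and skips the allele loops and normalisation entirely. A sketch of the index arithmetic with hypothetical sizes:

```c
#include <assert.h>
#include <stddef.h>

// Row of (site, allele, sample set) in the flattened bitset, matching
// a_off + (mut_a * state_dim) + k with state_dim == num_sample_sets.
static size_t
row_index(size_t site_offset, size_t allele, size_t num_sample_sets, size_t set)
{
    return site_offset + allele * num_sample_sets + set;
}

int
main(void)
{
    // Hypothetical layout: site 0 has 2 alleles, site 1 has 3, 2 sample sets,
    // so the offsets are { 0, 2 * 2 } and there are (2 + 3) * 2 == 10 rows.
    size_t num_sample_sets = 2;
    size_t site_offsets[2] = { 0, 2 * 2 };

    assert(row_index(site_offsets[0], 0, num_sample_sets, 0) == 0);
    assert(row_index(site_offsets[0], 1, num_sample_sets, 1) == 3);
    assert(row_index(site_offsets[1], 0, num_sample_sets, 0) == 4);
    assert(row_index(site_offsets[1], 2, num_sample_sets, 1) == 9); // last row
    return 0;
}
```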
@@ -2852,7 +2935,7 @@ get_index_counts(

 typedef struct {
     tsk_tree_t tree;
-    tsk_bit_array_t *node_samples;
+    tsk_bitset_t *node_samples;
     tsk_id_t *parent;
     tsk_id_t *edges_out;
     tsk_id_t *edges_in;
@@ -2876,7 +2959,7 @@ iter_state_init(iter_state *self, const tsk_treeseq_t *ts, tsk_size_t state_dim)
         ret = tsk_trace_error(TSK_ERR_NO_MEMORY);
         goto out;
     }
-    ret = tsk_bit_array_init(self->node_samples, ts->num_samples, state_dim * num_nodes);
+    ret = tsk_bitset_init(self->node_samples, ts->num_samples, state_dim * num_nodes);
     if (ret != 0) {
         goto out;
     }
@@ -2895,29 +2978,25 @@ iter_state_init(iter_state *self, const tsk_treeseq_t *ts, tsk_size_t state_dim)

 static int
 get_node_samples(const tsk_treeseq_t *ts, tsk_size_t state_dim,
-    const tsk_bit_array_t *sample_sets, tsk_bit_array_t *node_samples)
+    const tsk_bitset_t *sample_sets, tsk_bitset_t *node_samples)
 {
     int ret = 0;
     tsk_size_t n, k;
-    tsk_bit_array_t sample_set_row, node_samples_row;
     tsk_size_t num_nodes = ts->tables->nodes.num_rows;
-    tsk_bit_array_value_t sample;
+    tsk_bitset_val_t sample;
     const tsk_id_t *restrict sample_index_map = ts->sample_index_map;
     const tsk_flags_t *restrict flags = ts->tables->nodes.flags;

-    ret = tsk_bit_array_init(node_samples, ts->num_samples, num_nodes * state_dim);
+    ret = tsk_bitset_init(node_samples, ts->num_samples, num_nodes * state_dim);
     if (ret != 0) {
         goto out;
     }
     for (k = 0; k < state_dim; k++) {
-        tsk_bit_array_get_row(sample_sets, k, &sample_set_row);
         for (n = 0; n < num_nodes; n++) {
             if (flags[n] & TSK_NODE_IS_SAMPLE) {
-                sample = (tsk_bit_array_value_t) sample_index_map[n];
-                if (tsk_bit_array_contains(&sample_set_row, sample)) {
-                    tsk_bit_array_get_row(
-                        node_samples, (state_dim * n) + k, &node_samples_row);
-                    tsk_bit_array_add_bit(&node_samples_row, sample);
+                sample = (tsk_bitset_val_t) sample_index_map[n];
+                if (tsk_bitset_contains(sample_sets, k, sample)) {
+                    tsk_bitset_set_bit(node_samples, (state_dim * n) + k, sample);
                 }
             }
         }
@@ -2928,7 +3007,7 @@ get_node_samples(const tsk_treeseq_t *ts, tsk_size_t state_dim,

 static void
 iter_state_clear(iter_state *self, tsk_size_t state_dim, tsk_size_t num_nodes,
-    const tsk_bit_array_t *node_samples)
+    const tsk_bitset_t *node_samples)
 {
     self->n_edges_out = 0;
     self->n_edges_in = 0;
@@ -2938,14 +3017,14 @@ iter_state_clear(iter_state *self, tsk_size_t state_dim, tsk_size_t num_nodes,
     tsk_memset(self->edges_in, TSK_NULL, num_nodes * sizeof(*self->edges_in));
     tsk_memset(self->branch_len, 0, num_nodes * sizeof(*self->branch_len));
     tsk_memcpy(self->node_samples->data, node_samples->data,
-        node_samples->size * state_dim * num_nodes * sizeof(*node_samples->data));
+        node_samples->row_len * state_dim * num_nodes * sizeof(*node_samples->data));
 }

 static void
 iter_state_free(iter_state *self)
 {
     tsk_tree_free(&self->tree);
-    tsk_bit_array_free(self->node_samples);
+    tsk_bitset_free(self->node_samples);
     tsk_safe_free(self->node_samples);
     tsk_safe_free(self->parent);
     tsk_safe_free(self->edges_out);
@@ -3025,41 +3104,26 @@ static int
 compute_two_tree_branch_state_update(const tsk_treeseq_t *ts, tsk_id_t c,
     const iter_state *A_state, const iter_state *B_state, tsk_size_t state_dim,
     tsk_size_t result_dim, int sign, general_stat_func_t *f,
-    sample_count_stat_params_t *f_params, double *result)
+    sample_count_stat_params_t *f_params, two_locus_work_t *restrict work,
+    double *result)
 {
     int ret = 0;
     double a_len, b_len;
     double *restrict B_branch_len = B_state->branch_len;
-    double *weights = NULL, *weights_row, *result_tmp = NULL;
+    double *weights_row;
     tsk_size_t n, k, a_row, b_row;
-    tsk_bit_array_t A_samples, B_samples, AB_samples, B_samples_tmp;
     const double *restrict A_branch_len = A_state->branch_len;
-    const tsk_bit_array_t *restrict A_state_samples = A_state->node_samples;
-    const tsk_bit_array_t *restrict B_state_samples = B_state->node_samples;
-    tsk_size_t num_samples = ts->num_samples;
+    const tsk_bitset_t *restrict A_state_samples = A_state->node_samples;
+    const tsk_bitset_t *restrict B_state_samples = B_state->node_samples;
     tsk_size_t num_nodes = ts->tables->nodes.num_rows;
+    double *weights = work->weights;
+    double *result_tmp = work->result_tmp;
+    tsk_bitset_t AB_samples = work->AB_samples;
+
     b_len = B_branch_len[c] * sign;
     if (b_len == 0) {
         return ret;
     }
-
-    tsk_memset(&AB_samples, 0, sizeof(AB_samples));
-    tsk_memset(&B_samples_tmp, 0, sizeof(B_samples_tmp));
-
-    weights = tsk_calloc(3 * state_dim, sizeof(*weights));
-    result_tmp = tsk_calloc(result_dim, sizeof(*result_tmp));
-    if (weights == NULL || result_tmp == NULL) {
-        ret = tsk_trace_error(TSK_ERR_NO_MEMORY);
-        goto out;
-    }
-    ret = tsk_bit_array_init(&AB_samples, num_samples, 1);
-    if (ret != 0) {
-        goto out;
-    }
-    ret = tsk_bit_array_init(&B_samples_tmp, num_samples, 1);
-    if (ret != 0) {
-        goto out;
-    }
     for (n = 0; n < num_nodes; n++) {
         a_len = A_branch_len[n];
         if (a_len == 0) {
@@ -3068,15 +3132,14 @@ compute_two_tree_branch_state_update(const tsk_treeseq_t *ts, tsk_id_t c,
         for (k = 0; k < state_dim; k++) {
             a_row = (state_dim * n) + k;
             b_row = (state_dim * (tsk_size_t) c) + k;
-            tsk_bit_array_get_row(A_state_samples, a_row, &A_samples);
-            tsk_bit_array_get_row(B_state_samples, b_row, &B_samples);
-            tsk_bit_array_intersect(&A_samples, &B_samples, &AB_samples);
             weights_row = GET_2D_ROW(weights, 3, k);
-            weights_row[0] = (double) tsk_bit_array_count(&AB_samples); // w_AB
+            tsk_bitset_intersect(
+                A_state_samples, a_row, B_state_samples, b_row, &AB_samples);
+            weights_row[0] = (double) tsk_bitset_count(&AB_samples, 0);
             weights_row[1]
-                = (double) tsk_bit_array_count(&A_samples) - weights_row[0]; // w_Ab
+                = (double) tsk_bitset_count(A_state_samples, a_row) - weights_row[0];
             weights_row[2]
-                = (double) tsk_bit_array_count(&B_samples) - weights_row[0]; // w_aB
+                = (double) tsk_bitset_count(B_state_samples, b_row) - weights_row[0];
         }
         ret = f(state_dim, weights, result_dim, result_tmp, f_params);
         if (ret != 0) {
@@ -3087,10 +3150,6 @@ compute_two_tree_branch_state_update(const tsk_treeseq_t *ts, tsk_id_t c,
         }
     }
 out:
-    tsk_safe_free(weights);
-    tsk_safe_free(result_tmp);
-    tsk_bit_array_free(&AB_samples);
-    tsk_bit_array_free(&B_samples_tmp);
     return ret;
 }

@@ -3106,10 +3165,17 @@ compute_two_tree_branch_stat(const tsk_treeseq_t *ts, const iter_state *l_state,
     const tsk_id_t *restrict edges_child = ts->tables->edges.child;
     const tsk_id_t *restrict edges_parent = ts->tables->edges.parent;
     const tsk_size_t num_nodes = ts->tables->nodes.num_rows;
-    tsk_bit_array_t updates, row, ec_row, *r_samples = r_state->node_samples;
+    tsk_bitset_t updates, *r_samples = r_state->node_samples;
+    two_locus_work_t work;

+    tsk_memset(&work, 0, sizeof(work));
     tsk_memset(&updates, 0, sizeof(updates));
-    ret = tsk_bit_array_init(&updates, num_nodes, 1);
+    // only two alleles are possible for branch stats
+    ret = two_locus_work_init(2, ts->num_samples, result_dim, state_dim, &work);
+    if (ret != 0) {
+        goto out;
+    }
+    ret = tsk_bitset_init(&updates, num_nodes, 1);
     if (ret != 0) {
         goto out;
     }
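In the edge-update loops that follow, an inserted child's sample rows are unioned into every ancestor, removed edges get the mirror-image subtraction, and the updates bitset simply records which child nodes need their contribution recomputed. A toy version of the upward propagation (illustrative only, with one 32-bit mask per node in place of tsk_bitset_t rows):

```c
#include <assert.h>
#include <stdint.h>

int
main(void)
{
    // parent[] encodes a 5-node tree: nodes 0,1 -> 3; nodes 2,3 -> 4; 4 is root.
    int parent[5] = { 3, 3, 4, 4, -1 };
    uint32_t samples[5] = { 0x1, 0x2, 0x4, 0x3, 0x0 };
    int c, p;

    // Attach node 3's subtree: union its sample set into every ancestor,
    // mirroring the tsk_bitset_union(r_samples, parent_row, r_samples,
    // child_row) calls in compute_two_tree_branch_stat.
    c = 3;
    p = parent[c];
    while (p != -1) {
        samples[p] |= samples[c];
        c = p;
        p = parent[p];
    }
    assert(samples[4] == 0x3); // the root now carries samples {0, 1}
    return 0;
}
```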
@@ -3126,18 +3192,18 @@ compute_two_tree_branch_stat(const tsk_treeseq_t *ts, const iter_state *l_state,
         c = edges_child[e];
         // Identify affected nodes above child
         while (p != TSK_NULL) {
-            tsk_bit_array_add_bit(&updates, (tsk_bit_array_value_t) c);
+            tsk_bitset_set_bit(&updates, 0, (tsk_bitset_val_t) c);
             c = p;
             p = r_state->parent[p];
         }
     }
     // Subtract the whole contribution from the child node
-    tsk_bit_array_get_items(&updates, updated_nodes, &n_updates);
+    tsk_bitset_get_items(&updates, 0, updated_nodes, &n_updates);
     while (n_updates != 0) {
         n_updates--;
         c = updated_nodes[n_updates];
-        compute_two_tree_branch_state_update(
-            ts, c, l_state, r_state, state_dim, result_dim, -1, f, f_params, result);
+        compute_two_tree_branch_state_update(ts, c, l_state, r_state, state_dim,
+            result_dim, -1, f, f_params, &work, result);
     }
     // Remove samples under nodes from removed edges to parent nodes
     for (j = 0; j < r_state->n_edges_out; j++) {
@@ -3146,10 +3212,8 @@ compute_two_tree_branch_stat(const tsk_treeseq_t *ts, const iter_state *l_state,
         ec = edges_child[e]; // edge child
         while (p != TSK_NULL) {
             for (k = 0; k < state_dim; k++) {
-                tsk_bit_array_get_row(
-                    r_samples, (state_dim * (tsk_size_t) ec) + k, &ec_row);
-                tsk_bit_array_get_row(r_samples, (state_dim * (tsk_size_t) p) + k, &row);
-                tsk_bit_array_subtract(&row, &ec_row);
+                tsk_bitset_subtract(r_samples, (state_dim * (tsk_size_t) p) + k,
+                    r_samples, (state_dim * (tsk_size_t) ec) + k);
             }
             p = r_state->parent[p];
         }
@@ -3164,12 +3228,10 @@ compute_two_tree_branch_stat(const tsk_treeseq_t *ts, const iter_state *l_state,
         r_state->branch_len[c] = time[p] - time[c];
         r_state->parent[c] = p;
         while (p != TSK_NULL) {
-            tsk_bit_array_add_bit(&updates, (tsk_bit_array_value_t) c);
+            tsk_bitset_set_bit(&updates, 0, (tsk_bitset_val_t) c);
             for (k = 0; k < state_dim; k++) {
-                tsk_bit_array_get_row(
-                    r_samples, (state_dim * (tsk_size_t) ec) + k, &ec_row);
-                tsk_bit_array_get_row(r_samples, (state_dim * (tsk_size_t) p) + k, &row);
-                tsk_bit_array_add(&row, &ec_row);
+                tsk_bitset_union(r_samples, (state_dim * (tsk_size_t) p) + k, r_samples,
+                    (state_dim * (tsk_size_t) ec) + k);
             }
             c = p;
             p = r_state->parent[p];
@@ -3177,22 +3239,24 @@ compute_two_tree_branch_stat(const tsk_treeseq_t *ts, const iter_state *l_state,
     }
     // Update all affected child nodes (fully subtracted, deferred from addition)
     n_updates = 0;
-    tsk_bit_array_get_items(&updates, updated_nodes, &n_updates);
+    tsk_bitset_get_items(&updates, 0, updated_nodes, &n_updates);
     while (n_updates != 0) {
         n_updates--;
         c = updated_nodes[n_updates];
-        compute_two_tree_branch_state_update(
-            ts, c, l_state, r_state, state_dim, result_dim, +1, f, f_params, result);
+        compute_two_tree_branch_state_update(ts, c, l_state, r_state, state_dim,
+            result_dim, +1, f, f_params, &work, result);
     }
 out:
     tsk_safe_free(updated_nodes);
-    tsk_bit_array_free(&updates);
+    two_locus_work_free(&work);
+    tsk_bitset_free(&updates);
     return ret;
 }

 static int
 tsk_treeseq_two_branch_count_stat(const tsk_treeseq_t *self, tsk_size_t state_dim,
-    const tsk_bit_array_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f,
+    tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes,
+    const tsk_id_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f,
     sample_count_stat_params_t *f_params, norm_func_t *TSK_UNUSED(norm_f),
     tsk_size_t n_rows, const double *row_positions, tsk_size_t n_cols,
     const double *col_positions, tsk_flags_t TSK_UNUSED(options), double *result)
@@ -3201,11 +3265,12 @@ tsk_treeseq_two_branch_count_stat(const tsk_treeseq_t *self, tsk_size_t state_di
     int r, c;
     tsk_id_t *row_indexes = NULL, *col_indexes = NULL;
     tsk_size_t i, j, k, row, col, *row_repeats = NULL, *col_repeats = NULL;
-    tsk_bit_array_t node_samples;
+    tsk_bitset_t node_samples, sample_sets_bits;
     iter_state l_state, r_state;
     double *result_tmp = NULL, *result_row;
     const tsk_size_t num_nodes = self->tables->nodes.num_rows;

+    tsk_memset(&sample_sets_bits, 0, sizeof(sample_sets_bits));
     tsk_memset(&node_samples, 0, sizeof(node_samples));
     tsk_memset(&l_state, 0, sizeof(l_state));
     tsk_memset(&r_state, 0, sizeof(r_state));
@@ -3222,6 +3287,11 @@ tsk_treeseq_two_branch_count_stat(const tsk_treeseq_t *self, tsk_size_t state_di
     if (ret != 0) {
         goto out;
     }
+    ret = sample_sets_to_bitset(
+        self, sample_set_sizes, sample_sets, num_sample_sets, &sample_sets_bits);
+    if (ret != 0) {
+        goto out;
+    }
     ret = positions_to_tree_indexes(self, row_positions, n_rows, &row_indexes);
     if (ret != 0) {
         goto out;
@@ -3238,7 +3308,7 @@ tsk_treeseq_two_branch_count_stat(const tsk_treeseq_t *self, tsk_size_t state_di
     if (ret != 0) {
         goto out;
     }
-    ret = get_node_samples(self, state_dim, sample_sets, &node_samples);
+    ret = get_node_samples(self, state_dim, &sample_sets_bits, &node_samples);
     if (ret != 0) {
         goto out;
     }
@@ -3289,7 +3359,42 @@ tsk_treeseq_two_branch_count_stat(const tsk_treeseq_t *self, tsk_size_t state_di
     tsk_safe_free(col_repeats);
     iter_state_free(&l_state);
     iter_state_free(&r_state);
-    tsk_bit_array_free(&node_samples);
+    tsk_bitset_free(&node_samples);
+    tsk_bitset_free(&sample_sets_bits);
+    return ret;
+}
+
+static int
+check_sample_set_dups(tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes,
+    const tsk_id_t *sample_sets, const tsk_id_t *restrict sample_index_map,
+    tsk_size_t num_samples)
+{
+    int ret;
+    tsk_size_t j, k, l;
+    tsk_id_t u, sample_index;
+    tsk_bitset_t tmp;
+
+    tsk_memset(&tmp, 0, sizeof(tmp));
+    ret = tsk_bitset_init(&tmp, num_samples, 1);
+    if (ret != 0) {
+        return ret;
+    }
+    j = 0;
+    for (k = 0; k < num_sample_sets; k++) {
+        tsk_memset(tmp.data, 0, sizeof(*tmp.data) * tmp.row_len);
+        for (l = 0; l < sample_set_sizes[k]; l++) {
+            u = sample_sets[j];
+            sample_index = sample_index_map[u];
+            if (tsk_bitset_contains(&tmp, 0, (tsk_bitset_val_t) sample_index)) {
+                ret = tsk_trace_error(TSK_ERR_DUPLICATE_SAMPLE);
+                goto out;
+            }
+            tsk_bitset_set_bit(&tmp, 0, (tsk_bitset_val_t) sample_index);
+            j++;
+        }
+    }
+out:
+    tsk_bitset_free(&tmp);
     return ret;
 }

@@ -3304,7 +3409,6 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl
     // TODO: generalize this function if we ever decide to do weighted two_locus stats.
     // We only implement count stats and therefore we don't handle weights.
     int ret = 0;
-    tsk_bit_array_t sample_sets_bits;
     bool stat_site = !!(options & TSK_STAT_SITE);
     bool stat_branch = !!(options & TSK_STAT_BRANCH);
     tsk_size_t state_dim = num_sample_sets;
@@ -3313,8 +3417,6 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl
         .sample_set_sizes = sample_set_sizes,
         .set_indexes = set_indexes };

-    tsk_memset(&sample_sets_bits, 0, sizeof(sample_sets_bits));
-
     // We do not support two-locus node stats
     if (!!(options & TSK_STAT_NODE)) {
         ret = tsk_trace_error(TSK_ERR_UNSUPPORTED_STAT_MODE);
@@ -3338,12 +3440,6 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl
         ret = tsk_trace_error(TSK_ERR_BAD_RESULT_DIMS);
         goto out;
     }
-    ret = sample_sets_to_bit_array(
-        self, sample_set_sizes, sample_sets, num_sample_sets, &sample_sets_bits);
-    if (ret != 0) {
-        goto out;
-    }
-
     if (stat_site) {
         ret = check_sites(row_sites, out_rows, self->tables->sites.num_rows);
         if (ret != 0) {
@@ -3353,9 +3449,15 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl
         if (ret != 0) {
             goto out;
         }
-        ret = tsk_treeseq_two_site_count_stat(self, state_dim, &sample_sets_bits,
-            result_dim, f, &f_params, norm_f, out_rows, row_sites, out_cols, col_sites,
-            options, result);
+        ret = check_sample_set_dups(num_sample_sets, sample_set_sizes, sample_sets,
+            self->sample_index_map, self->num_samples);
+        if (ret != 0) {
+            goto out;
+        }
+        // TODO: result dim/state dim can be set internally now.
+        ret = tsk_treeseq_two_site_count_stat(self, state_dim, num_sample_sets,
+            sample_set_sizes, sample_sets, result_dim, f, &f_params, norm_f, out_rows,
+            row_sites, out_cols, col_sites, options, result);
     } else if (stat_branch) {
         ret = check_positions(
             row_positions, out_rows, tsk_treeseq_get_sequence_length(self));
@@ -3367,13 +3469,11 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl
         if (ret != 0) {
             goto out;
         }
-        ret = tsk_treeseq_two_branch_count_stat(self, state_dim, &sample_sets_bits,
-            result_dim, f, &f_params, norm_f, out_rows, row_positions, out_cols,
-            col_positions, options, result);
+        ret = tsk_treeseq_two_branch_count_stat(self, state_dim, num_sample_sets,
+            sample_set_sizes, sample_sets, result_dim, f, &f_params, norm_f, out_rows,
+            row_positions, out_cols, col_positions, options, result);
     }
-
 out:
-    tsk_bit_array_free(&sample_sets_bits);
     return ret;
 }
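A note on check_sample_set_dups: it reuses a single scratch row and clears it between sets, so a sample may appear in two different sets but never twice within one set (the same semantics the deleted sample_sets_to_bit_array check enforced, now decoupled from building the bitsets). A standalone restatement with a plain bool array (illustrative only; the helper name is hypothetical):

```c
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>

// Restates the check_sample_set_dups logic with a bool array instead of a
// tsk_bitset_t; assumes num_samples <= 64 for the sketch.
static bool
has_dup_within_set(size_t num_sets, const size_t *sizes, const int *sets,
    size_t num_samples)
{
    bool seen[64] = { false };
    size_t j = 0;
    for (size_t k = 0; k < num_sets; k++) {
        for (size_t i = 0; i < num_samples; i++) {
            seen[i] = false; // cleared between sets, like the tsk_memset
        }
        for (size_t l = 0; l < sizes[k]; l++) {
            if (seen[sets[j]]) {
                return true;
            }
            seen[sets[j]] = true;
            j++;
        }
    }
    return false;
}

int
main(void)
{
    size_t sizes[2] = { 2, 2 };
    int overlap_across_sets[4] = { 0, 1, 1, 2 }; // allowed: overlap between sets
    int dup_within_set[4] = { 0, 0, 1, 2 };      // rejected: duplicate in set 0
    assert(!has_dup_within_set(2, sizes, overlap_across_sets, 4));
    assert(has_dup_within_set(2, sizes, dup_within_set, 4));
    return 0;
}
```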