From 0c1eeb47e5508783fb958f969da83ae9d5fc6fd2 Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Wed, 15 Oct 2025 19:52:55 +0100 Subject: [PATCH 01/11] Numba only --- python/tests/test_jit.py | 160 +++++++++++++++++++++ python/tskit/jit/numba.py | 294 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 454 insertions(+) diff --git a/python/tests/test_jit.py b/python/tests/test_jit.py index 38aede1894..4e629afa7f 100644 --- a/python/tests/test_jit.py +++ b/python/tests/test_jit.py @@ -671,3 +671,163 @@ def ancestral_edges_tskit(ts, start_node): a1 = ancestral_edges(numba_ts, u) a2 = ancestral_edges_tskit(ts, u) nt.assert_array_equal(a1, a2) + + +def build_alignment_example(): + tables = tskit.TableCollection(sequence_length=3) + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) # 0 + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) # 1 + tables.nodes.add_row(flags=0, time=1) # 2 + tables.edges.add_row(0, 3, parent=2, child=0) + tables.edges.add_row(0, 3, parent=2, child=1) + tables.sites.add_row(0, "A") + tables.mutations.add_row(site=0, node=2, derived_state="G") + tables.sites.add_row(1, "A") + tables.mutations.add_row(site=1, node=0, derived_state="C") + tables.sites.add_row(2, "A") + tables.mutations.add_row(site=2, node=1, derived_state="T") + tables.sort() + return tables.tree_sequence() + + +def build_missing_alignment_example(): + tables = tskit.TableCollection(sequence_length=3) + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) # 0 isolated + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) # 1 + tables.nodes.add_row(flags=0, time=1) # ancestor for sample 1 + tables.edges.add_row(0, 3, parent=2, child=1) + tables.sites.add_row(0, "A") + tables.sites.add_row(1, "A") + tables.mutations.add_row(site=1, node=2, derived_state="T") + tables.sites.add_row(2, "A") + tables.sort() + return tables.tree_sequence() + + +def build_internal_sample_example(): + tables = tskit.TableCollection(sequence_length=3) + 
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) # 0 + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=1) # 1 internal sample + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) # 2 + tables.nodes.add_row(flags=0, time=2) # 3 root + tables.edges.add_row(0, 3, parent=1, child=0) + tables.edges.add_row(0, 3, parent=3, child=1) + tables.edges.add_row(0, 3, parent=3, child=2) + tables.sites.add_row(0, "A") + tables.mutations.add_row(site=0, node=3, derived_state="G") + tables.sites.add_row(1, "A") + tables.mutations.add_row(site=1, node=1, derived_state="C") + tables.sites.add_row(2, "A") + tables.mutations.add_row(site=2, node=0, derived_state="T") + tables.sort() + return tables.tree_sequence() + + +def build_overlapping_edges_example(): + tables = tskit.TableCollection(sequence_length=4) + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) # 0 + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) # 1 + tables.nodes.add_row(flags=0, time=1) # 2 + tables.nodes.add_row(flags=0, time=1) # 3 + tables.nodes.add_row(flags=0, time=2) # 4 root + tables.edges.add_row(0, 2, parent=2, child=0) + tables.edges.add_row(2, 4, parent=3, child=0) + tables.edges.add_row(0, 4, parent=3, child=1) + tables.edges.add_row(0, 4, parent=4, child=2) + tables.edges.add_row(0, 4, parent=4, child=3) + tables.sites.add_row(1, "A") + tables.mutations.add_row(site=0, node=2, derived_state="G") + tables.sites.add_row(3, "A") + tables.mutations.add_row(site=1, node=3, derived_state="T") + tables.sort() + return tables.tree_sequence() + + +def build_deep_mutation_example(): + tables = tskit.TableCollection(sequence_length=2) + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) # 0 + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) # 1 + tables.nodes.add_row(flags=0, time=1) # 2 + tables.nodes.add_row(flags=0, time=2) # 3 + tables.nodes.add_row(flags=0, time=3) # 4 root + tables.edges.add_row(0, 2, parent=2, child=0) + tables.edges.add_row(0, 2, parent=4, 
child=1) + tables.edges.add_row(0, 2, parent=3, child=2) + tables.edges.add_row(0, 2, parent=4, child=3) + tables.sites.add_row(0, "A") + m0 = tables.mutations.add_row(site=0, node=4, derived_state="C") + m1 = tables.mutations.add_row(site=0, node=3, derived_state="G", parent=m0) + tables.mutations.add_row(site=0, node=2, derived_state="T", parent=m1) + tables.sort() + return tables.tree_sequence() + + +def build_multiple_roots_example(): + tables = tskit.TableCollection(sequence_length=3) + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) # 0 + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) # 1 + tables.nodes.add_row(flags=0, time=1) # 2 root A + tables.nodes.add_row(flags=0, time=1) # 3 root B + tables.edges.add_row(0, 3, parent=2, child=0) + tables.edges.add_row(0, 3, parent=3, child=1) + tables.sites.add_row(0, "A") + tables.mutations.add_row(site=0, node=2, derived_state="G") + tables.sites.add_row(2, "A") + tables.mutations.add_row(site=1, node=3, derived_state="T") + tables.sort() + return tables.tree_sequence() + + +def _check_alignments(ts): + expected = list(ts.haplotypes()) + sites = list(ts.sites()) + adjusted = [] + for hap in expected: + chars = list(hap) + for j, c in enumerate(chars): + if c == "N": + chars[j] = sites[j].ancestral_state + adjusted.append("".join(chars)) + numba_ts = jit_numba.jitwrap(ts) + observed = list(jit_numba.alignments(numba_ts)) + samples = [node for node, _ in observed] + haplotypes = [hap for _, hap in observed] + assert samples == list(ts.samples()) + assert haplotypes == adjusted + + +def test_jit_alignments_basic(): + ts = build_alignment_example() + _check_alignments(ts) + + +def test_jit_alignments_missing_data(): + ts = build_missing_alignment_example() + _check_alignments(ts) + + +def test_jit_alignments_internal_sample(): + ts = build_internal_sample_example() + _check_alignments(ts) + + +def test_jit_alignments_overlapping_edges(): + ts = build_overlapping_edges_example() + _check_alignments(ts) 
+ + +def test_jit_alignments_deep_mutations(): + ts = build_deep_mutation_example() + _check_alignments(ts) + + +def test_jit_alignments_multiple_roots(): + ts = build_multiple_roots_example() + _check_alignments(ts) + + +def test_jit_alignments_msprime_example(): + ts = msprime.sim_ancestry(5, sequence_length=8, ploidy=1, random_seed=5) + ts = msprime.sim_mutations(ts, rate=0.5, random_seed=13) + assert ts.discrete_genome + _check_alignments(ts) diff --git a/python/tskit/jit/numba.py b/python/tskit/jit/numba.py index 40534431d1..698c2cdee4 100644 --- a/python/tskit/jit/numba.py +++ b/python/tskit/jit/numba.py @@ -672,3 +672,297 @@ def jitwrap(ts): ) return numba_ts + + +@numba.njit +def _bitset_init(bitset, num_sites): + # Initialise all bits to 1 (meaning "unresolved") and mask any unused bits + # in the final word when the number of sites is not a multiple of 64. + n_words = bitset.shape[0] + for w in range(n_words): + bitset[w] = np.uint64(-1) + if n_words > 0: + excess = n_words * 64 - num_sites + if excess > 0: + mask = np.uint64(-1) >> excess + bitset[n_words - 1] = mask + + +@numba.njit +def _bitset_clear(bitset, idx): + word = idx >> 6 + bit = np.uint64(1) << (idx & 63) + bitset[word] &= ~bit + + +@numba.njit +def _bitset_is_set(bitset, idx): + word = idx >> 6 + bit = np.uint64(1) << (idx & 63) + return (bitset[word] & bit) != 0 + + +@numba.njit +def _ctz64(x): + count = 0 + while (x & 1) == 0: + x >>= 1 + count += 1 + return count + + +@numba.njit +def _bitset_next(bitset, start, num_sites): + # Return the index of the first set bit >= start, or num_sites if none. 
+ if start >= num_sites: + return num_sites + n_words = bitset.shape[0] + word = start >> 6 + offset = start & 63 + if word >= n_words: + return num_sites + mask = np.uint64(-1) << offset + value = bitset[word] & mask + while value == 0: + word += 1 + if word >= n_words: + return num_sites + value = bitset[word] + return (word << 6) + _ctz64(value) + + +def _build_node_mutation_index(numba_ts): + num_nodes = numba_ts.num_nodes + num_mutations = numba_ts.num_mutations + + counts = np.zeros(num_nodes, dtype=np.int32) + mutations_node = numba_ts.mutations_node + mutations_site = numba_ts.mutations_site + mutations_derived_state = numba_ts.mutations_derived_state + + for mut_id in range(num_mutations - 1, -1, -1): + node = mutations_node[mut_id] + if 0 <= node < num_nodes: + counts[node] += 1 + + offsets = np.zeros(num_nodes + 1, dtype=np.int32) + total = 0 + for u in range(num_nodes): + offsets[u] = total + total += counts[u] + offsets[num_nodes] = total + + node_sites = np.empty(total, dtype=np.int32) + node_alleles = np.empty(total, dtype=np.uint8) + insert_pos = offsets.copy() + + for mut_id in range(num_mutations - 1, -1, -1): + node = mutations_node[mut_id] + if 0 <= node < num_nodes: + site = mutations_site[mut_id] + allele = mutations_derived_state[mut_id] + if len(allele) != 1: + raise ValueError("Expected single-character derived alleles") + pos = insert_pos[node] + node_sites[pos] = site + node_alleles[pos] = ord(allele[0]) + insert_pos[node] += 1 + + return node_sites, node_alleles, offsets + + +def _compute_next_site_index(numba_ts, sequence_length): + num_sites = numba_ts.num_sites + next_site = np.empty(sequence_length + 1, dtype=np.int32) + site_positions = numba_ts.sites_position.astype(np.int64) + j = 0 + for pos in range(sequence_length + 1): + while j < num_sites and site_positions[j] < pos: + j += 1 + next_site[pos] = j + return next_site + + +@numba.njit +def _node_haplotype( + numba_ts, + parent_index, + edge_start_index, + edge_end_index, + 
next_site_index, + node_mut_sites, + node_mut_alleles, + node_mut_offsets, + ancestral_codes, + node, + hap, + unresolved_bits, + stack_edges, + stack_start, + stack_end, + parent_interval_start, + parent_interval_end, +): + num_sites = ancestral_codes.shape[0] + if num_sites == 0: + return + for j in range(num_sites): + hap[j] = ancestral_codes[j] + _bitset_init(unresolved_bits, num_sites) + + edges_parent = numba_ts.edges_parent + edge_index = parent_index.edge_index + index_range = parent_index.index_range + + stack_top = 0 + + mut_start = node_mut_offsets[node] + mut_stop = node_mut_offsets[node + 1] + for m in range(mut_start, mut_stop): + site_idx = node_mut_sites[m] + if site_idx >= num_sites: + continue + if _bitset_is_set(unresolved_bits, site_idx): + hap[site_idx] = node_mut_alleles[m] + _bitset_clear(unresolved_bits, site_idx) + + start_edge, stop_edge = index_range[node, 0], index_range[node, 1] + for i in range(start_edge, stop_edge): + edge = edge_index[i] + start_idx = edge_start_index[edge] + end_idx = edge_end_index[edge] + if start_idx >= end_idx: + continue + unresolved = _bitset_next(unresolved_bits, start_idx, num_sites) + if unresolved < end_idx: + stack_edges[stack_top] = edge + stack_start[stack_top] = start_idx + stack_end[stack_top] = end_idx + stack_top += 1 + + while stack_top > 0: + stack_top -= 1 + edge = stack_edges[stack_top] + interval_start = stack_start[stack_top] + interval_end = stack_end[stack_top] + ancestor = edges_parent[edge] + + mut_start = node_mut_offsets[ancestor] + mut_stop = node_mut_offsets[ancestor + 1] + for m in range(mut_start, mut_stop): + site_idx = node_mut_sites[m] + if interval_start <= site_idx < interval_end: + if _bitset_is_set(unresolved_bits, site_idx): + hap[site_idx] = node_mut_alleles[m] + _bitset_clear(unresolved_bits, site_idx) + + parent_count = 0 + start_edge, stop_edge = index_range[ancestor, 0], index_range[ancestor, 1] + for idx in range(start_edge, stop_edge): + parent_edge = edge_index[idx] 
+ parent_start = edge_start_index[parent_edge] + parent_end = edge_end_index[parent_edge] + if parent_start < interval_start: + parent_start = interval_start + if parent_end > interval_end: + parent_end = interval_end + if parent_start >= parent_end: + continue + unresolved = _bitset_next(unresolved_bits, parent_start, num_sites) + if unresolved < parent_end: + # Push this parent edge because it still covers unresolved sites. + stack_edges[stack_top] = parent_edge + stack_start[stack_top] = parent_start + stack_end[stack_top] = parent_end + stack_top += 1 + parent_interval_start[parent_count] = parent_start + parent_interval_end[parent_count] = parent_end + parent_count += 1 + + idx = _bitset_next(unresolved_bits, interval_start, num_sites) + while idx < interval_end: + needs_parent = False + for j in range(parent_count): + if parent_interval_start[j] <= idx < parent_interval_end[j]: + needs_parent = True + break + if needs_parent: + # This site is still covered by a parent edge that hasn't been + # processed yet, so leave it for that ancestor. + idx = _bitset_next(unresolved_bits, idx + 1, num_sites) + else: + # No higher ancestor will supply a mutation, so the ancestral + # allele stands and we can mark the site resolved. 
+ _bitset_clear(unresolved_bits, idx) + idx = _bitset_next(unresolved_bits, idx, num_sites) + + +def alignments(numba_ts): + num_sites = numba_ts.num_sites + sequence_length = numba_ts.sequence_length + if not float(sequence_length).is_integer(): + raise ValueError("This prototype requires discrete genomic coordinates") + sequence_length = int(sequence_length) + + ancestral_codes = np.empty(num_sites, dtype=np.uint8) + for site_id in range(num_sites): + allele = numba_ts.sites_ancestral_state[site_id] + if len(allele) != 1: + raise ValueError("Expected single-character ancestral alleles") + ancestral_codes[site_id] = ord(allele[0]) + + node_mut_sites, node_mut_alleles, node_mut_offsets = _build_node_mutation_index( + numba_ts + ) + next_site_index = _compute_next_site_index(numba_ts, sequence_length) + parent_index = numba_ts.parent_index() + edge_start_index = np.empty(numba_ts.num_edges, dtype=np.int32) + edge_end_index = np.empty(numba_ts.num_edges, dtype=np.int32) + seq_len = int(sequence_length) + for e in range(numba_ts.num_edges): + left = int(numba_ts.edges_left[e]) + if left < 0: + left = 0 + if left > seq_len: + left = seq_len + right = int(numba_ts.edges_right[e]) + if right < 0: + right = 0 + if right > seq_len: + right = seq_len + edge_start_index[e] = next_site_index[left] + edge_end_index[e] = next_site_index[right] + + hap = np.empty(num_sites, dtype=np.uint8) + bitset_size = (num_sites + 63) // 64 + unresolved_bits = np.empty(bitset_size, dtype=np.uint64) + stack_edges = np.empty(numba_ts.num_edges, dtype=np.int32) + stack_start = np.empty(numba_ts.num_edges, dtype=np.int32) + stack_end = np.empty(numba_ts.num_edges, dtype=np.int32) + parent_interval_start = np.empty(numba_ts.num_edges, dtype=np.int32) + parent_interval_end = np.empty(numba_ts.num_edges, dtype=np.int32) + + nodes_flags = numba_ts.nodes_flags + for node in range(numba_ts.num_nodes): + if nodes_flags[node] & NODE_IS_SAMPLE: + # Decode the haplotype for this sample node using the 
ancestor walk. + _node_haplotype( + numba_ts, + parent_index, + edge_start_index, + edge_end_index, + next_site_index, + node_mut_sites, + node_mut_alleles, + node_mut_offsets, + ancestral_codes, + node, + hap, + unresolved_bits, + stack_edges, + stack_start, + stack_end, + parent_interval_start, + parent_interval_end, + ) + yield int(node), hap.tobytes().decode("ascii") From 5ec04efc07fffd42f193981b7cc0cda28d48b5c1 Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Thu, 16 Oct 2025 00:13:10 +0100 Subject: [PATCH 02/11] First rough --- c/tskit/genotypes.c | 660 ++++++++++++++++++++++++++++++++++++++++++ c/tskit/genotypes.h | 34 +++ python/_tskitmodule.c | 173 +++++++++++ python/tskit/trees.py | 115 +++++--- 4 files changed, 949 insertions(+), 33 deletions(-) diff --git a/c/tskit/genotypes.c b/c/tskit/genotypes.c index c2385281bd..9bec16e57e 100644 --- a/c/tskit/genotypes.c +++ b/c/tskit/genotypes.c @@ -23,13 +23,673 @@ * SOFTWARE. */ +#include #include #include #include #include +#include +#include + +#if defined(_MSC_VER) +#include +#endif #include +typedef struct { + tsk_id_t edge_id; + tsk_id_t child; + double left; +} tsk_haplotype_edge_sort_t; + +static int +tsk_haplotype_edge_sort_cmp(const void *aa, const void *bb) +{ + const tsk_haplotype_edge_sort_t *a = (const tsk_haplotype_edge_sort_t *) aa; + const tsk_haplotype_edge_sort_t *b = (const tsk_haplotype_edge_sort_t *) bb; + + if (a->child == b->child) { + if (a->left < b->left) { + return -1; + } else if (a->left > b->left) { + return 1; + } + if (a->edge_id < b->edge_id) { + return -1; + } else if (a->edge_id > b->edge_id) { + return 1; + } + return 0; + } + return a->child < b->child ? 
-1 : 1; +} + +static inline uint32_t +tsk_haplotype_ctz64(uint64_t x) +{ +#if defined(_MSC_VER) + unsigned long index; + _BitScanForward64(&index, x); + return (uint32_t) index; +#else + return (uint32_t) __builtin_ctzll(x); +#endif +} + +static inline void +tsk_haplotype_bitset_clear(uint64_t *bits, tsk_size_t idx) +{ + tsk_size_t word = idx >> 6; + uint64_t mask = UINT64_C(1) << (idx & 63); + bits[word] &= ~mask; +} + +static inline tsk_size_t +tsk_haplotype_bitset_next( + const uint64_t *bits, tsk_size_t num_words, tsk_size_t start, tsk_size_t limit) +{ + tsk_size_t word = start >> 6; + uint64_t mask, value; + + if (start >= limit || word >= num_words) { + return limit; + } + mask = UINT64_MAX << (start & 63); + value = bits[word] & mask; + while (value == 0) { + word++; + if (word >= num_words) { + return limit; + } + value = bits[word]; + } + start = (word << 6) + tsk_haplotype_ctz64(value); + return start < limit ? start : limit; +} + +static void +tsk_haplotype_reset_bitset(const tsk_haplotype_t *self) +{ + if (self->num_bit_words > 0) { + tsk_memcpy(self->unresolved_bits, self->initial_bits, + self->num_bit_words * sizeof(*self->unresolved_bits)); + } +} + +static int +tsk_haplotype_build_child_index(tsk_haplotype_t *self) +{ + int ret = 0; + tsk_size_t j; + const tsk_table_collection_t *tables = self->tree_sequence->tables; + const tsk_edge_table_t *edges = &tables->edges; + tsk_size_t num_edges = edges->num_rows; + tsk_haplotype_edge_sort_t *sorted = NULL; + + if (num_edges == 0) { + self->child_order = NULL; + self->child_offsets + = tsk_calloc(self->num_nodes + 1, sizeof(*self->child_offsets)); + if (self->child_offsets == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } + goto out; + } + + sorted = tsk_malloc(num_edges * sizeof(*sorted)); + self->child_order = tsk_malloc(num_edges * sizeof(*self->child_order)); + self->child_offsets = tsk_calloc(self->num_nodes + 1, sizeof(*self->child_offsets)); + if (sorted == NULL || 
self->child_order == NULL || self->child_offsets == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } + + for (j = 0; j < num_edges; j++) { + sorted[j].edge_id = (tsk_id_t) j; + sorted[j].child = edges->child[j]; + sorted[j].left = edges->left[j]; + } + qsort(sorted, num_edges, sizeof(*sorted), tsk_haplotype_edge_sort_cmp); + + for (j = 0; j < num_edges; j++) { + tsk_id_t child = sorted[j].child; + if (child >= 0 && child < (tsk_id_t) self->num_nodes) { + self->child_offsets[child + 1]++; + } + self->child_order[j] = sorted[j].edge_id; + } + for (j = 0; j < self->num_nodes; j++) { + self->child_offsets[j + 1] += self->child_offsets[j]; + } + +out: + tsk_safe_free(sorted); + return ret; +} + +static int +tsk_haplotype_build_mutation_index(tsk_haplotype_t *self) +{ + int ret = 0; + tsk_size_t j; + const tsk_table_collection_t *tables = self->tree_sequence->tables; + const tsk_mutation_table_t *mutations = &tables->mutations; + int32_t *counts = NULL; + tsk_size_t total_mutations = 0; + tsk_id_t site_start = self->site_start; + tsk_id_t site_stop = self->site_stop; + + counts = tsk_calloc(self->num_nodes, sizeof(*counts)); + if (self->num_nodes > 0 && counts == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } + + for (j = 0; j < mutations->num_rows; j++) { + tsk_id_t node = mutations->node[j]; + tsk_id_t site = mutations->site[j]; + if (site < site_start || site >= site_stop) { + continue; + } + if (node >= 0 && node < (tsk_id_t) self->num_nodes) { + if (counts[node] == INT32_MAX) { + ret = tsk_trace_error(TSK_ERR_UNSUPPORTED_OPERATION); + goto out; + } + counts[node]++; + } + } + + self->node_mutation_offsets + = tsk_malloc((self->num_nodes + 1) * sizeof(*self->node_mutation_offsets)); + if (self->node_mutation_offsets == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } + self->node_mutation_offsets[0] = 0; + for (j = 0; j < self->num_nodes; j++) { + total_mutations += counts[j]; + if (total_mutations > 
INT32_MAX) { + ret = tsk_trace_error(TSK_ERR_UNSUPPORTED_OPERATION); + goto out; + } + self->node_mutation_offsets[j + 1] = (int32_t) total_mutations; + } + + self->node_mutation_sites + = tsk_malloc(total_mutations * sizeof(*self->node_mutation_sites)); + self->node_mutation_states + = tsk_malloc(total_mutations * sizeof(*self->node_mutation_states)); + if ((total_mutations > 0) + && (self->node_mutation_sites == NULL || self->node_mutation_states == NULL)) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } + + for (j = 0; j < self->num_nodes; j++) { + counts[j] = self->node_mutation_offsets[j]; + } + for (j = 0; j < mutations->num_rows; j++) { + tsk_id_t node = mutations->node[j]; + tsk_id_t site = mutations->site[j]; + if (site < site_start || site >= site_stop) { + continue; + } + if (node >= 0 && node < (tsk_id_t) self->num_nodes) { + tsk_size_t start = mutations->derived_state_offset[j]; + tsk_size_t stop = mutations->derived_state_offset[j + 1]; + tsk_size_t length = stop - start; + uint8_t allele; + + if (length != 1) { + ret = tsk_trace_error(TSK_ERR_UNSUPPORTED_OPERATION); + goto out; + } + allele = (uint8_t) mutations->derived_state[start]; + if (allele > 0x7F) { + ret = tsk_trace_error(TSK_ERR_UNSUPPORTED_OPERATION); + goto out; + } + if (allele == (uint8_t) self->missing_char) { + ret = tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); + goto out; + } + self->node_mutation_sites[counts[node]] = (int32_t)(site - site_start); + self->node_mutation_states[counts[node]] = allele; + counts[node]++; + } + } + +out: + tsk_safe_free(counts); + return ret; +} + +static int +tsk_haplotype_build_ancestral_states(tsk_haplotype_t *self) +{ + int ret = 0; + const tsk_table_collection_t *tables = self->tree_sequence->tables; + const tsk_site_table_t *sites = &tables->sites; + tsk_id_t site_start = self->site_start; + tsk_id_t site_stop = self->site_stop; + tsk_size_t j; + + if (self->num_sites == 0) { + self->ancestral_states = NULL; + return 0; + } + + 
self->ancestral_states + = tsk_malloc(self->num_sites * sizeof(*self->ancestral_states)); + if (self->ancestral_states == NULL) { + return tsk_trace_error(TSK_ERR_NO_MEMORY); + } + + for (j = 0; j < (tsk_size_t) self->num_sites; j++) { + tsk_id_t site = site_start + (tsk_id_t) j; + tsk_size_t start = sites->ancestral_state_offset[site]; + tsk_size_t stop = sites->ancestral_state_offset[site + 1]; + tsk_size_t length = stop - start; + uint8_t allele; + if (length != 1) { + ret = tsk_trace_error(TSK_ERR_UNSUPPORTED_OPERATION); + goto out; + } + allele = (uint8_t) sites->ancestral_state[start]; + if (allele > 0x7F) { + ret = tsk_trace_error(TSK_ERR_UNSUPPORTED_OPERATION); + goto out; + } + if (allele == (uint8_t) self->missing_char) { + ret = tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); + goto out; + } + self->ancestral_states[j] = allele; + } + +out: + if (ret != 0) { + tsk_safe_free(self->ancestral_states); + self->ancestral_states = NULL; + } + return ret; +} + +static int +tsk_haplotype_build_edge_intervals(tsk_haplotype_t *self) +{ + int ret = 0; + const tsk_table_collection_t *tables = self->tree_sequence->tables; + const tsk_edge_table_t *edges = &tables->edges; + const double *positions = tables->sites.position + self->site_start; + tsk_size_t num_edges = edges->num_rows; + tsk_size_t j; + + if (num_edges == 0) { + self->edge_start_index = NULL; + self->edge_end_index = NULL; + return 0; + } + + self->edge_start_index = tsk_malloc(num_edges * sizeof(*self->edge_start_index)); + self->edge_end_index = tsk_malloc(num_edges * sizeof(*self->edge_end_index)); + if (self->edge_start_index == NULL || self->edge_end_index == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } + + if (self->num_sites == 0) { + for (j = 0; j < num_edges; j++) { + self->edge_start_index[j] = 0; + self->edge_end_index[j] = 0; + } + goto out; + } + + for (j = 0; j < num_edges; j++) { + double left = edges->left[j]; + double right = edges->right[j]; + tsk_size_t start + = 
tsk_search_sorted(positions, (tsk_size_t) self->num_sites, left); + tsk_size_t end + = tsk_search_sorted(positions, (tsk_size_t) self->num_sites, right); + if (start > (tsk_size_t) self->num_sites) { + start = (tsk_size_t) self->num_sites; + } + if (end > (tsk_size_t) self->num_sites) { + end = (tsk_size_t) self->num_sites; + } + self->edge_start_index[j] = (int32_t) start; + self->edge_end_index[j] = (int32_t) end; + } + +out: + if (ret != 0) { + tsk_safe_free(self->edge_start_index); + tsk_safe_free(self->edge_end_index); + self->edge_start_index = NULL; + self->edge_end_index = NULL; + } + return ret; +} + +static int +tsk_haplotype_alloc_bitset(tsk_haplotype_t *self) +{ + tsk_size_t j; + + self->num_bit_words = ((tsk_size_t) self->num_sites + 63) >> 6; + if (self->num_bit_words == 0) { + self->unresolved_bits = NULL; + self->initial_bits = NULL; + return 0; + } + self->unresolved_bits + = tsk_malloc(self->num_bit_words * sizeof(*self->unresolved_bits)); + self->initial_bits = tsk_malloc(self->num_bit_words * sizeof(*self->initial_bits)); + if (self->unresolved_bits == NULL || self->initial_bits == NULL) { + return tsk_trace_error(TSK_ERR_NO_MEMORY); + } + for (j = 0; j < self->num_bit_words; j++) { + self->initial_bits[j] = UINT64_MAX; + } + if ((tsk_size_t) self->num_sites % 64 != 0) { + uint32_t bits = (uint32_t)((tsk_size_t) self->num_sites & 63); + self->initial_bits[self->num_bit_words - 1] = (UINT64_C(1) << bits) - 1; + } + return 0; +} + +int +tsk_haplotype_init(tsk_haplotype_t *self, const tsk_treeseq_t *tree_sequence, + tsk_id_t site_start, tsk_id_t site_stop, int8_t missing_char, tsk_flags_t options) +{ + int ret = 0; + const tsk_table_collection_t *tables; + const tsk_site_table_t *sites; + tsk_size_t total_sites; + + if (tree_sequence == NULL) { + return tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); + } + + tsk_memset(self, 0, sizeof(*self)); + self->tree_sequence = tree_sequence; + self->missing_char = missing_char; + self->isolated_as_missing = 
!(options & TSK_ISOLATED_NOT_MISSING); + + tables = tree_sequence->tables; + sites = &tables->sites; + total_sites = sites->num_rows; + + if (site_start < 0 || site_stop < site_start || site_stop > (tsk_id_t) total_sites) { + ret = tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); + goto out; + } + if ((unsigned char) missing_char > 0x7F) { + ret = tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); + goto out; + } + + self->site_start = (int32_t) site_start; + self->site_stop = (int32_t) site_stop; + self->num_sites = (int32_t)(site_stop - site_start); + self->num_nodes = tables->nodes.num_rows; + self->num_edges = tables->edges.num_rows; + self->node_flags = tables->nodes.flags; + self->site_positions = sites->position + site_start; + + if (!tsk_treeseq_get_discrete_genome(tree_sequence)) { + ret = tsk_trace_error(TSK_ERR_UNSUPPORTED_OPERATION); + goto out; + } + + ret = tsk_haplotype_build_child_index(self); + if (ret != 0) { + goto out; + } + ret = tsk_haplotype_build_mutation_index(self); + if (ret != 0) { + goto out; + } + ret = tsk_haplotype_build_ancestral_states(self); + if (ret != 0) { + goto out; + } + ret = tsk_haplotype_build_edge_intervals(self); + if (ret != 0) { + goto out; + } + ret = tsk_haplotype_alloc_bitset(self); + if (ret != 0) { + goto out; + } + if (self->num_edges > 0) { + self->edge_stack = tsk_malloc(self->num_edges * sizeof(*self->edge_stack)); + self->stack_interval_start + = tsk_malloc(self->num_edges * sizeof(*self->stack_interval_start)); + self->stack_interval_end + = tsk_malloc(self->num_edges * sizeof(*self->stack_interval_end)); + self->parent_interval_start + = tsk_malloc(self->num_edges * sizeof(*self->parent_interval_start)); + self->parent_interval_end + = tsk_malloc(self->num_edges * sizeof(*self->parent_interval_end)); + if (self->edge_stack == NULL || self->stack_interval_start == NULL + || self->stack_interval_end == NULL || self->parent_interval_start == NULL + || self->parent_interval_end == NULL) { + ret = 
tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } + } + + self->initialised = true; + +out: + if (ret != 0) { + tsk_haplotype_free(self); + } + return ret; +} + +int +tsk_haplotype_decode(tsk_haplotype_t *self, tsk_id_t node, int8_t *haplotype) +{ + tsk_size_t stack_top = 0; + const tsk_table_collection_t *tables; + const tsk_edge_table_t *edges; + const tsk_id_t *edge_parent; + int32_t interval_start, interval_end; + int32_t mut_start, mut_end; + tsk_size_t idx; + tsk_size_t parent_count; + uint64_t *bits; + + if (self == NULL || haplotype == NULL) { + return tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); + } + if (!self->initialised) { + return tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); + } + if (node < 0 || node >= (tsk_id_t) self->num_nodes) { + return tsk_trace_error(TSK_ERR_NODE_OUT_OF_BOUNDS); + } + if (self->isolated_as_missing && !(self->node_flags[node] & TSK_NODE_IS_SAMPLE)) { + return tsk_trace_error(TSK_ERR_MUST_IMPUTE_NON_SAMPLES); + } + if (self->num_sites == 0) { + return 0; + } + + tables = self->tree_sequence->tables; + edges = &tables->edges; + edge_parent = edges->parent; + bits = self->unresolved_bits; + + for (idx = 0; idx < (tsk_size_t) self->num_sites; idx++) { + haplotype[idx] = (int8_t) self->ancestral_states[idx]; + } + tsk_haplotype_reset_bitset(self); + + mut_start = self->node_mutation_offsets[node]; + mut_end = self->node_mutation_offsets[node + 1]; + for (int32_t m = mut_start; m < mut_end; m++) { + int32_t site = self->node_mutation_sites[m]; + if (site >= 0 && site < self->num_sites) { + haplotype[site] = (int8_t) self->node_mutation_states[m]; + tsk_haplotype_bitset_clear(bits, (tsk_size_t) site); + } + } + + int32_t child_start = self->child_offsets[node]; + int32_t child_stop = self->child_offsets[node + 1]; + for (int32_t i = child_start; i < child_stop; i++) { + tsk_id_t edge = self->child_order[i]; + int32_t start = self->edge_start_index[edge]; + int32_t end = self->edge_end_index[edge]; + if (start >= end) { + continue; + } + 
if (tsk_haplotype_bitset_next( + bits, self->num_bit_words, (tsk_size_t) start, (tsk_size_t) end) + < (tsk_size_t) end) { + self->edge_stack[stack_top] = edge; + self->stack_interval_start[stack_top] = start; + self->stack_interval_end[stack_top] = end; + stack_top++; + } + } + + while (stack_top > 0) { + stack_top--; + tsk_id_t edge = self->edge_stack[stack_top]; + tsk_id_t ancestor = edge_parent[edge]; + interval_start = self->stack_interval_start[stack_top]; + interval_end = self->stack_interval_end[stack_top]; + + if (ancestor >= 0) { + mut_start = self->node_mutation_offsets[ancestor]; + mut_end = self->node_mutation_offsets[ancestor + 1]; + for (int32_t m = mut_start; m < mut_end; m++) { + int32_t site = self->node_mutation_sites[m]; + if (site >= interval_start && site < interval_end + && tsk_haplotype_bitset_next(bits, self->num_bit_words, + (tsk_size_t) site, (tsk_size_t) site + 1) + == (tsk_size_t) site) { + haplotype[site] = (int8_t) self->node_mutation_states[m]; + tsk_haplotype_bitset_clear(bits, (tsk_size_t) site); + } + } + } + + parent_count = 0; + if (ancestor >= 0) { + child_start = self->child_offsets[ancestor]; + child_stop = self->child_offsets[ancestor + 1]; + for (int32_t i = child_start; i < child_stop; i++) { + tsk_id_t parent_edge = self->child_order[i]; + int32_t parent_start = self->edge_start_index[parent_edge]; + int32_t parent_end = self->edge_end_index[parent_edge]; + if (parent_start < interval_start) { + parent_start = interval_start; + } + if (parent_end > interval_end) { + parent_end = interval_end; + } + if (parent_start >= parent_end) { + continue; + } + if (tsk_haplotype_bitset_next(bits, self->num_bit_words, + (tsk_size_t) parent_start, (tsk_size_t) parent_end) + < (tsk_size_t) parent_end) { + self->edge_stack[stack_top] = parent_edge; + self->stack_interval_start[stack_top] = parent_start; + self->stack_interval_end[stack_top] = parent_end; + stack_top++; + self->parent_interval_start[parent_count] = parent_start; + 
self->parent_interval_end[parent_count] = parent_end; + parent_count++; + } + } + } + + idx = tsk_haplotype_bitset_next(bits, self->num_bit_words, + (tsk_size_t) interval_start, (tsk_size_t) interval_end); + while ((int32_t) idx < interval_end) { + bool covered = false; + for (tsk_size_t p = 0; p < parent_count; p++) { + if (self->parent_interval_start[p] <= (int32_t) idx + && (int32_t) idx < self->parent_interval_end[p]) { + covered = true; + break; + } + } + if (covered) { + idx = tsk_haplotype_bitset_next( + bits, self->num_bit_words, idx + 1, (tsk_size_t) interval_end); + } else { + if (self->isolated_as_missing) { + haplotype[idx] = self->missing_char; + } + tsk_haplotype_bitset_clear(bits, idx); + idx = tsk_haplotype_bitset_next( + bits, self->num_bit_words, idx, (tsk_size_t) interval_end); + } + } + } + + if (self->isolated_as_missing) { + idx = tsk_haplotype_bitset_next( + bits, self->num_bit_words, 0, (tsk_size_t) self->num_sites); + while (idx < (tsk_size_t) self->num_sites) { + haplotype[idx] = self->missing_char; + tsk_haplotype_bitset_clear(bits, idx); + idx = tsk_haplotype_bitset_next( + bits, self->num_bit_words, idx, (tsk_size_t) self->num_sites); + } + } else { + idx = tsk_haplotype_bitset_next( + bits, self->num_bit_words, 0, (tsk_size_t) self->num_sites); + while (idx < (tsk_size_t) self->num_sites) { + tsk_haplotype_bitset_clear(bits, idx); + idx = tsk_haplotype_bitset_next( + bits, self->num_bit_words, idx, (tsk_size_t) self->num_sites); + } + } + + return 0; +} + +int +tsk_haplotype_free(tsk_haplotype_t *self) +{ + if (self == NULL) { + return 0; + } + tsk_safe_free(self->ancestral_states); + tsk_safe_free(self->node_mutation_offsets); + tsk_safe_free(self->node_mutation_sites); + tsk_safe_free(self->node_mutation_states); + tsk_safe_free(self->child_order); + tsk_safe_free(self->child_offsets); + tsk_safe_free(self->edge_start_index); + tsk_safe_free(self->edge_end_index); + tsk_safe_free(self->edge_stack); + 
tsk_safe_free(self->stack_interval_start); + tsk_safe_free(self->stack_interval_end); + tsk_safe_free(self->parent_interval_start); + tsk_safe_free(self->parent_interval_end); + tsk_safe_free(self->unresolved_bits); + tsk_safe_free(self->initial_bits); + self->tree_sequence = NULL; + self->node_flags = NULL; + self->site_positions = NULL; + self->initialised = false; + return 0; +} + /* ======================================================== * * Variant generator * ======================================================== */ diff --git a/c/tskit/genotypes.h b/c/tskit/genotypes.h index 8c3b769e5a..7cf1f31772 100644 --- a/c/tskit/genotypes.h +++ b/c/tskit/genotypes.h @@ -86,6 +86,36 @@ typedef struct { tsk_variant_t variant; } tsk_vargen_t; +typedef struct { + const tsk_treeseq_t *tree_sequence; + tsk_size_t num_nodes; + tsk_size_t num_edges; + int32_t site_start; + int32_t site_stop; + int32_t num_sites; + int8_t missing_char; + bool isolated_as_missing; + const tsk_flags_t *node_flags; + const double *site_positions; + uint8_t *ancestral_states; + int32_t *node_mutation_offsets; + int32_t *node_mutation_sites; + uint8_t *node_mutation_states; + tsk_id_t *child_order; + int32_t *child_offsets; + int32_t *edge_start_index; + int32_t *edge_end_index; + tsk_id_t *edge_stack; + int32_t *stack_interval_start; + int32_t *stack_interval_end; + int32_t *parent_interval_start; + int32_t *parent_interval_end; + uint64_t *unresolved_bits; + uint64_t *initial_bits; + tsk_size_t num_bit_words; + bool initialised; +} tsk_haplotype_t; + /** @defgroup VARIANT_API_GROUP Variant API for obtaining genotypes. 
@{ @@ -179,6 +209,10 @@ void tsk_variant_print_state(const tsk_variant_t *self, FILE *out); /** @} */ /* Deprecated vargen methods (since C API v1.0) */ +int tsk_haplotype_init(tsk_haplotype_t *self, const tsk_treeseq_t *tree_sequence, + tsk_id_t site_start, tsk_id_t site_stop, int8_t missing_char, tsk_flags_t options); +int tsk_haplotype_decode(tsk_haplotype_t *self, tsk_id_t node, int8_t *haplotype); +int tsk_haplotype_free(tsk_haplotype_t *self); int tsk_vargen_init(tsk_vargen_t *self, const tsk_treeseq_t *tree_sequence, const tsk_id_t *samples, tsk_size_t num_samples, const char **alleles, tsk_flags_t options); diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index 23ab663538..3a6b53d6d8 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -154,6 +154,12 @@ typedef struct { tsk_tree_t *tree; } Tree; +typedef struct { + PyObject_HEAD + TreeSequence *tree_sequence; + tsk_haplotype_t *haplotype; +} Haplotype; + typedef struct { PyObject_HEAD TreeSequence *tree_sequence; @@ -10594,6 +10600,166 @@ static PyTypeObject TreeType = { // clang-format on }; +/*=================================================================== + * Haplotype + *=================================================================== + */ + +/* Forward declaration */ +static PyTypeObject HaplotypeType; + +static int +Haplotype_check_state(Haplotype *self) +{ + int ret = 0; + if (self->haplotype == NULL) { + PyErr_SetString(PyExc_SystemError, "haplotype not initialised"); + ret = -1; + } + return ret; +} + +static void +Haplotype_dealloc(Haplotype *self) +{ + if (self->haplotype != NULL) { + tsk_haplotype_free(self->haplotype); + PyMem_Free(self->haplotype); + self->haplotype = NULL; + } + Py_XDECREF(self->tree_sequence); + Py_TYPE(self)->tp_free((PyObject *) self); +} + +static int +Haplotype_init(Haplotype *self, PyObject *args, PyObject *kwds) +{ + int ret = -1; + int err; + static char *kwlist[] = { "tree_sequence", "site_start", "site_stop", + "isolated_as_missing", 
"missing_data_character", NULL }; + TreeSequence *tree_sequence = NULL; + Py_ssize_t site_start; + Py_ssize_t site_stop; + int isolated_as_missing = 1; + PyObject *missing_obj = NULL; + PyObject *missing_bytes = NULL; + const char *missing_ptr = NULL; + Py_ssize_t missing_length = 0; + char missing_char = 'N'; + tsk_flags_t options = 0; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!nn|pO", kwlist, &TreeSequenceType, + &tree_sequence, &site_start, &site_stop, &isolated_as_missing, + &missing_obj)) { + goto out; + } + + if (missing_obj != NULL && missing_obj != Py_None) { + if (PyBytes_Check(missing_obj)) { + missing_ptr = PyBytes_AS_STRING(missing_obj); + missing_length = PyBytes_GET_SIZE(missing_obj); + } else { + missing_bytes = PyUnicode_AsASCIIString(missing_obj); + if (missing_bytes == NULL) { + goto out; + } + missing_ptr = PyBytes_AS_STRING(missing_bytes); + missing_length = PyBytes_GET_SIZE(missing_bytes); + } + if (missing_length != 1) { + PyErr_SetString(PyExc_ValueError, + "missing_data_character must be a single ASCII character"); + goto out; + } + missing_char = missing_ptr[0]; + } + + self->haplotype = PyMem_Malloc(sizeof(*self->haplotype)); + if (self->haplotype == NULL) { + PyErr_NoMemory(); + goto out; + } + + self->tree_sequence = tree_sequence; + Py_INCREF(tree_sequence); + + if (!isolated_as_missing) { + options |= TSK_ISOLATED_NOT_MISSING; + } + + err = tsk_haplotype_init(self->haplotype, tree_sequence->tree_sequence, + (tsk_id_t) site_start, (tsk_id_t) site_stop, (int8_t) missing_char, options); + if (err != 0) { + handle_library_error(err); + goto out; + } + + ret = 0; +out: + if (ret != 0) { + if (self->haplotype != NULL) { + tsk_haplotype_free(self->haplotype); + PyMem_Free(self->haplotype); + self->haplotype = NULL; + } + Py_XDECREF(self->tree_sequence); + self->tree_sequence = NULL; + } + Py_XDECREF(missing_bytes); + return ret; +} + +static PyObject * +Haplotype_decode(Haplotype *self, PyObject *args) +{ + int err; + PyObject *ret = 
NULL; + tsk_id_t node; + + if (Haplotype_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTuple(args, "O&", &tsk_id_converter, &node)) { + goto out; + } + + ret = PyBytes_FromStringAndSize(NULL, (Py_ssize_t) self->haplotype->num_sites); + if (ret == NULL) { + goto out; + } + err = tsk_haplotype_decode(self->haplotype, node, (int8_t *) PyBytes_AS_STRING(ret)); + if (err != 0) { + handle_library_error(err); + Py_CLEAR(ret); + goto out; + } +out: + return ret; +} + +static PyMethodDef Haplotype_methods[] = { + { .ml_name = "decode", + .ml_meth = (PyCFunction) Haplotype_decode, + .ml_flags = METH_VARARGS, + .ml_doc = "Decode the haplotype for the specified node." }, + { NULL }, +}; + +static PyTypeObject HaplotypeType = { + // clang-format off + PyVarObject_HEAD_INIT(NULL, 0) + .tp_name = "_tskit.Haplotype", + .tp_basicsize = sizeof(Haplotype), + .tp_dealloc = (destructor) Haplotype_dealloc, + .tp_flags = Py_TPFLAGS_DEFAULT, + .tp_doc = "Low-level haplotype decoder", + .tp_methods = Haplotype_methods, + .tp_init = (initproc) Haplotype_init, + .tp_new = PyType_GenericNew, + // clang-format on +}; + /*=================================================================== * Variant *=================================================================== @@ -11924,6 +12090,13 @@ PyInit__tskit(void) Py_INCREF(&TreeType); PyModule_AddObject(module, "Tree", (PyObject *) &TreeType); + /* Haplotype type */ + if (PyType_Ready(&HaplotypeType) < 0) { + return NULL; + } + Py_INCREF(&HaplotypeType); + PyModule_AddObject(module, "Haplotype", (PyObject *) &HaplotypeType); + /* Variant type */ if (PyType_Ready(&VariantType) < 0) { return NULL; diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 177d9187aa..740279e6c4 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -5262,31 +5262,54 @@ def _haplotypes_array( # return an array of haplotypes and the first and last site positions if missing_data_character is None: missing_data_character = "N" + if 
len(missing_data_character) != 1: + raise ValueError("missing_data_character must be a single character") + try: + missing_data_character.encode("ascii") + except UnicodeEncodeError: + raise TypeError("missing_data_character must be ASCII") start_site, stop_site = np.searchsorted(self.sites_position, interval) - H = np.empty( - ( - self.num_samples if samples is None else len(samples), - stop_site - start_site, - ), - dtype=np.int8, - ) - missing_int8 = ord(missing_data_character.encode("ascii")) + num_sites = stop_site - start_site + + if samples is None: + sample_nodes = self.samples() + else: + sample_nodes = np.array(samples, dtype=np.int64) + num_samples = len(sample_nodes) + + H = np.empty((num_samples, num_sites), dtype=np.int8) + if num_samples == 0 or num_sites == 0: + return H, (start_site, stop_site - 1) + + missing_int8 = ord(missing_data_character) + for var in self.variants( samples=samples, isolated_as_missing=isolated_as_missing, left=interval.left, right=interval.right, ): - alleles = np.full(len(var.alleles), missing_int8, dtype=np.int8) - for i, allele in enumerate(var.alleles): - if allele is not None: - if len(allele) != 1: - raise TypeError( - "Multi-letter allele or deletion detected at site {}".format( - var.site.id - ) + ancestral = var.site.ancestral_state + if ancestral is None or len(ancestral) != 1: + raise TypeError( + "Multi-letter allele or deletion detected at site {}".format( + var.site.id + ) + ) + else: + raise TypeError(f"Non-ascii character in allele at site {var.site.id}") + alleles = var.alleles + for allele in alleles: + if allele is None: + continue + if len(allele) != 1: + raise TypeError( + "Multi-letter allele or deletion detected at site {}".format( + var.site.id ) + ) + if isinstance(allele, str): try: ascii_allele = allele.encode("ascii") except UnicodeEncodeError: @@ -5295,16 +5318,35 @@ def _haplotypes_array( var.site.id ) ) - allele_int8 = ord(ascii_allele) - if allele_int8 == missing_int8: - raise ValueError( - 
"The missing data character '{}' clashes with an " - "existing allele at site {}".format( - missing_data_character, var.site.id - ) + elif isinstance(allele, bytes): + ascii_allele = allele + else: + raise TypeError( + f"Non-ascii character in allele at site {var.site.id}" + ) + allele_int8 = ascii_allele[0] + if allele_int8 == missing_int8: + raise ValueError( + "The missing data character '{}' clashes with an " + "existing allele at site {}".format( + missing_data_character, var.site.id ) - alleles[i] = allele_int8 - H[:, var.site.id - start_site] = alleles[var.genotypes] + ) + + isolated = isolated_as_missing if isolated_as_missing is not None else True + + hap = _tskit.Haplotype( + self._ll_tree_sequence, + int(start_site), + int(stop_site), + isolated_as_missing=bool(isolated), + missing_data_character=missing_data_character, + ) + + for row, node in enumerate(sample_nodes): + data = hap.decode(int(node)) + H[row, :] = np.frombuffer(data, dtype=np.int8, count=num_sites) + return H, (start_site, stop_site - 1) def haplotypes( @@ -5702,9 +5744,10 @@ def alignments( missing_data_character = ( "N" if missing_data_character is None else missing_data_character ) + if len(missing_data_character) != 1: + raise ValueError("missing_data_character must be length 1") L = interval.span - a = np.empty(L, dtype=np.int8) if reference_sequence is None: if self.has_reference_sequence(): # This may be inefficient - see #1989. However, since we're @@ -5728,7 +5771,7 @@ def alignments( "The reference sequence ends before the requested stop position" ) ref_bytes = reference_sequence.encode("ascii") - a[:] = np.frombuffer(ref_bytes, dtype=np.int8) + a = np.frombuffer(ref_bytes, dtype=np.int8).copy() # To do this properly we'll have to detect the missing data as # part of a full implementation of alignments in C. 
The current @@ -5748,15 +5791,21 @@ def alignments( ) H, (first_site_id, last_site_id) = self._haplotypes_array( interval=interval, + isolated_as_missing=False, missing_data_character=missing_data_character, samples=samples, ) - site_pos = self.sites_position.astype(np.int64)[ - first_site_id : last_site_id + 1 - ] - for h in H: - a[site_pos - interval.left] = h - yield a.tobytes().decode("ascii") + if first_site_id <= last_site_id: + site_pos = self.sites_position.astype(np.int64)[ + first_site_id : last_site_id + 1 + ] + else: + site_pos = np.array([], dtype=np.int64) + for hap in H: + a_copy = a.copy() + if site_pos.size > 0: + a_copy[site_pos - interval.left] = hap + yield a_copy.tobytes().decode("ascii") @property def individuals_population(self): From 1330b43aa5a417957ce8dbadf26062d3b167b887 Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Thu, 16 Oct 2025 02:14:58 +0100 Subject: [PATCH 03/11] Tests pass --- c/tskit/genotypes.c | 64 ++++++++++++++++------------- c/tskit/genotypes.h | 1 + python/tskit/trees.py | 96 +++++++++++++++++++++---------------------- 3 files changed, 83 insertions(+), 78 deletions(-) diff --git a/c/tskit/genotypes.c b/c/tskit/genotypes.c index 9bec16e57e..df452a9546 100644 --- a/c/tskit/genotypes.c +++ b/c/tskit/genotypes.c @@ -127,6 +127,13 @@ tsk_haplotype_build_child_index(tsk_haplotype_t *self) tsk_size_t num_edges = edges->num_rows; tsk_haplotype_edge_sort_t *sorted = NULL; + self->parent_edge_counts + = tsk_calloc(self->num_nodes, sizeof(*self->parent_edge_counts)); + if (self->num_nodes > 0 && self->parent_edge_counts == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } + if (num_edges == 0) { self->child_order = NULL; self->child_offsets @@ -150,6 +157,14 @@ tsk_haplotype_build_child_index(tsk_haplotype_t *self) sorted[j].edge_id = (tsk_id_t) j; sorted[j].child = edges->child[j]; sorted[j].left = edges->left[j]; + tsk_id_t parent = edges->parent[j]; + if (parent >= 0 && parent < (tsk_id_t) self->num_nodes) { + 
if (self->parent_edge_counts[parent] == INT32_MAX) { + ret = tsk_trace_error(TSK_ERR_UNSUPPORTED_OPERATION); + goto out; + } + self->parent_edge_counts[parent]++; + } } qsort(sorted, num_edges, sizeof(*sorted), tsk_haplotype_edge_sort_cmp); @@ -231,15 +246,16 @@ tsk_haplotype_build_mutation_index(tsk_haplotype_t *self) for (j = 0; j < self->num_nodes; j++) { counts[j] = self->node_mutation_offsets[j]; } - for (j = 0; j < mutations->num_rows; j++) { - tsk_id_t node = mutations->node[j]; - tsk_id_t site = mutations->site[j]; + for (j = mutations->num_rows; j > 0; j--) { + tsk_size_t mut_index = j - 1; + tsk_id_t node = mutations->node[mut_index]; + tsk_id_t site = mutations->site[mut_index]; if (site < site_start || site >= site_stop) { continue; } if (node >= 0 && node < (tsk_id_t) self->num_nodes) { - tsk_size_t start = mutations->derived_state_offset[j]; - tsk_size_t stop = mutations->derived_state_offset[j + 1]; + tsk_size_t start = mutations->derived_state_offset[mut_index]; + tsk_size_t stop = mutations->derived_state_offset[mut_index + 1]; tsk_size_t length = stop - start; uint8_t allele; @@ -442,11 +458,6 @@ tsk_haplotype_init(tsk_haplotype_t *self, const tsk_treeseq_t *tree_sequence, self->node_flags = tables->nodes.flags; self->site_positions = sites->position + site_start; - if (!tsk_treeseq_get_discrete_genome(tree_sequence)) { - ret = tsk_trace_error(TSK_ERR_UNSUPPORTED_OPERATION); - goto out; - } - ret = tsk_haplotype_build_child_index(self); if (ret != 0) { goto out; @@ -537,7 +548,10 @@ tsk_haplotype_decode(tsk_haplotype_t *self, tsk_id_t node, int8_t *haplotype) mut_end = self->node_mutation_offsets[node + 1]; for (int32_t m = mut_start; m < mut_end; m++) { int32_t site = self->node_mutation_sites[m]; - if (site >= 0 && site < self->num_sites) { + if (site >= 0 && site < self->num_sites + && tsk_haplotype_bitset_next( + bits, self->num_bit_words, (tsk_size_t) site, (tsk_size_t) site + 1) + == (tsk_size_t) site) { haplotype[site] = (int8_t) 
self->node_mutation_states[m]; tsk_haplotype_bitset_clear(bits, (tsk_size_t) site); } @@ -630,9 +644,6 @@ tsk_haplotype_decode(tsk_haplotype_t *self, tsk_id_t node, int8_t *haplotype) idx = tsk_haplotype_bitset_next( bits, self->num_bit_words, idx + 1, (tsk_size_t) interval_end); } else { - if (self->isolated_as_missing) { - haplotype[idx] = self->missing_char; - } tsk_haplotype_bitset_clear(bits, idx); idx = tsk_haplotype_bitset_next( bits, self->num_bit_words, idx, (tsk_size_t) interval_end); @@ -640,23 +651,19 @@ tsk_haplotype_decode(tsk_haplotype_t *self, tsk_id_t node, int8_t *haplotype) } } - if (self->isolated_as_missing) { - idx = tsk_haplotype_bitset_next( - bits, self->num_bit_words, 0, (tsk_size_t) self->num_sites); - while (idx < (tsk_size_t) self->num_sites) { + bool has_incoming = self->child_offsets[node + 1] != self->child_offsets[node]; + bool has_outgoing + = self->parent_edge_counts != NULL && self->parent_edge_counts[node] > 0; + bool mark_missing = self->isolated_as_missing && !(has_incoming || has_outgoing); + idx = tsk_haplotype_bitset_next( + bits, self->num_bit_words, 0, (tsk_size_t) self->num_sites); + while (idx < (tsk_size_t) self->num_sites) { + if (mark_missing) { haplotype[idx] = self->missing_char; - tsk_haplotype_bitset_clear(bits, idx); - idx = tsk_haplotype_bitset_next( - bits, self->num_bit_words, idx, (tsk_size_t) self->num_sites); } - } else { + tsk_haplotype_bitset_clear(bits, idx); idx = tsk_haplotype_bitset_next( - bits, self->num_bit_words, 0, (tsk_size_t) self->num_sites); - while (idx < (tsk_size_t) self->num_sites) { - tsk_haplotype_bitset_clear(bits, idx); - idx = tsk_haplotype_bitset_next( - bits, self->num_bit_words, idx, (tsk_size_t) self->num_sites); - } + bits, self->num_bit_words, idx, (tsk_size_t) self->num_sites); } return 0; @@ -674,6 +681,7 @@ tsk_haplotype_free(tsk_haplotype_t *self) tsk_safe_free(self->node_mutation_states); tsk_safe_free(self->child_order); tsk_safe_free(self->child_offsets); + 
tsk_safe_free(self->parent_edge_counts); tsk_safe_free(self->edge_start_index); tsk_safe_free(self->edge_end_index); tsk_safe_free(self->edge_stack); diff --git a/c/tskit/genotypes.h b/c/tskit/genotypes.h index 7cf1f31772..97e7b3ddc8 100644 --- a/c/tskit/genotypes.h +++ b/c/tskit/genotypes.h @@ -103,6 +103,7 @@ typedef struct { uint8_t *node_mutation_states; tsk_id_t *child_order; int32_t *child_offsets; + int32_t *parent_edge_counts; int32_t *edge_start_index; int32_t *edge_end_index; tsk_id_t *edge_stack; diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 740279e6c4..92b1d676ad 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -46,6 +46,7 @@ import tskit import tskit.combinatorics as combinatorics import tskit.drawing as drawing +import tskit.exceptions as exceptions import tskit.metadata as metadata_module import tskit.provenance as provenance import tskit.tables as tables @@ -5278,75 +5279,61 @@ def _haplotypes_array( sample_nodes = np.array(samples, dtype=np.int64) num_samples = len(sample_nodes) + want_missing = ( + True if isolated_as_missing is None else bool(isolated_as_missing) + ) + if want_missing and samples is not None and num_samples > 0: + flags = self.nodes_flags[sample_nodes] + if np.any((flags & NODE_IS_SAMPLE) == 0): + raise exceptions.LibraryError( + "Cannot generate genotypes for non-samples when isolated nodes " + "are considered as missing. 
(TSK_ERR_MUST_IMPUTE_NON_SAMPLES)" + ) + H = np.empty((num_samples, num_sites), dtype=np.int8) if num_samples == 0 or num_sites == 0: return H, (start_site, stop_site - 1) missing_int8 = ord(missing_data_character) + missing_mask = None + if want_missing: + missing_mask = np.zeros((num_samples, num_sites), dtype=bool) for var in self.variants( samples=samples, isolated_as_missing=isolated_as_missing, left=interval.left, right=interval.right, + copy=False, ): - ancestral = var.site.ancestral_state - if ancestral is None or len(ancestral) != 1: - raise TypeError( - "Multi-letter allele or deletion detected at site {}".format( - var.site.id - ) + if want_missing and missing_mask is not None: + genotypes = np.asarray(var.genotypes, dtype=np.int32) + missing_mask[:, var.site.id - start_site] = ( + genotypes == tskit.MISSING_DATA ) - else: - raise TypeError(f"Non-ascii character in allele at site {var.site.id}") - alleles = var.alleles - for allele in alleles: - if allele is None: - continue - if len(allele) != 1: - raise TypeError( - "Multi-letter allele or deletion detected at site {}".format( - var.site.id - ) - ) - if isinstance(allele, str): - try: - ascii_allele = allele.encode("ascii") - except UnicodeEncodeError: - raise TypeError( - "Non-ascii character in allele at site {}".format( - var.site.id - ) - ) - elif isinstance(allele, bytes): - ascii_allele = allele - else: - raise TypeError( - f"Non-ascii character in allele at site {var.site.id}" - ) - allele_int8 = ascii_allele[0] - if allele_int8 == missing_int8: - raise ValueError( - "The missing data character '{}' clashes with an " - "existing allele at site {}".format( - missing_data_character, var.site.id - ) - ) - isolated = isolated_as_missing if isolated_as_missing is not None else True - - hap = _tskit.Haplotype( - self._ll_tree_sequence, - int(start_site), - int(stop_site), - isolated_as_missing=bool(isolated), - missing_data_character=missing_data_character, - ) + try: + hap = _tskit.Haplotype( + 
self._ll_tree_sequence, + int(start_site), + int(stop_site), + isolated_as_missing=False, + missing_data_character=missing_data_character, + ) + except exceptions.LibraryError as err: + if "TSK_ERR_UNSUPPORTED_OPERATION" in str(err): + raise TypeError(str(err)) from err + if "TSK_ERR_BAD_PARAM_VALUE" in str(err): + raise ValueError(str(err)) from err + raise for row, node in enumerate(sample_nodes): data = hap.decode(int(node)) H[row, :] = np.frombuffer(data, dtype=np.int8, count=num_sites) + if want_missing and missing_mask is not None: + H[missing_mask] = missing_int8 + return H, (start_site, stop_site - 1) def haplotypes( @@ -5789,6 +5776,15 @@ def alignments( "The current implementation may also incorrectly identify an " "input tree sequence has having missing data." ) + if samples is not None: + samples = np.array(samples, dtype=np.int64) + if samples.size > 0: + flags = self.nodes_flags[samples] + if np.any((flags & NODE_IS_SAMPLE) == 0): + raise exceptions.LibraryError( + "Cannot generate genotypes for non-samples when isolated nodes " + "are considered as missing. 
(TSK_ERR_MUST_IMPUTE_NON_SAMPLES)" + ) H, (first_site_id, last_site_id) = self._haplotypes_array( interval=interval, isolated_as_missing=False, From dab573285d2aea293076b223acd9510bf9248b62 Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Thu, 16 Oct 2025 03:01:48 +0100 Subject: [PATCH 04/11] Remove dead code --- c/tskit/genotypes.c | 45 +------------------------------ c/tskit/genotypes.h | 6 +---- python/_tskitmodule.c | 42 +++-------------------------- python/tskit/trees.py | 62 +++++++++++++++++++++++++++++++------------ 4 files changed, 51 insertions(+), 104 deletions(-) diff --git a/c/tskit/genotypes.c b/c/tskit/genotypes.c index df452a9546..b0cf909727 100644 --- a/c/tskit/genotypes.c +++ b/c/tskit/genotypes.c @@ -127,13 +127,6 @@ tsk_haplotype_build_child_index(tsk_haplotype_t *self) tsk_size_t num_edges = edges->num_rows; tsk_haplotype_edge_sort_t *sorted = NULL; - self->parent_edge_counts - = tsk_calloc(self->num_nodes, sizeof(*self->parent_edge_counts)); - if (self->num_nodes > 0 && self->parent_edge_counts == NULL) { - ret = tsk_trace_error(TSK_ERR_NO_MEMORY); - goto out; - } - if (num_edges == 0) { self->child_order = NULL; self->child_offsets @@ -157,14 +150,6 @@ tsk_haplotype_build_child_index(tsk_haplotype_t *self) sorted[j].edge_id = (tsk_id_t) j; sorted[j].child = edges->child[j]; sorted[j].left = edges->left[j]; - tsk_id_t parent = edges->parent[j]; - if (parent >= 0 && parent < (tsk_id_t) self->num_nodes) { - if (self->parent_edge_counts[parent] == INT32_MAX) { - ret = tsk_trace_error(TSK_ERR_UNSUPPORTED_OPERATION); - goto out; - } - self->parent_edge_counts[parent]++; - } } qsort(sorted, num_edges, sizeof(*sorted), tsk_haplotype_edge_sort_cmp); @@ -268,10 +253,6 @@ tsk_haplotype_build_mutation_index(tsk_haplotype_t *self) ret = tsk_trace_error(TSK_ERR_UNSUPPORTED_OPERATION); goto out; } - if (allele == (uint8_t) self->missing_char) { - ret = tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); - goto out; - } self->node_mutation_sites[counts[node]] = 
(int32_t)(site - site_start); self->node_mutation_states[counts[node]] = allele; counts[node]++; @@ -290,7 +271,6 @@ tsk_haplotype_build_ancestral_states(tsk_haplotype_t *self) const tsk_table_collection_t *tables = self->tree_sequence->tables; const tsk_site_table_t *sites = &tables->sites; tsk_id_t site_start = self->site_start; - tsk_id_t site_stop = self->site_stop; tsk_size_t j; if (self->num_sites == 0) { @@ -319,10 +299,6 @@ tsk_haplotype_build_ancestral_states(tsk_haplotype_t *self) ret = tsk_trace_error(TSK_ERR_UNSUPPORTED_OPERATION); goto out; } - if (allele == (uint8_t) self->missing_char) { - ret = tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); - goto out; - } self->ancestral_states[j] = allele; } @@ -421,7 +397,7 @@ tsk_haplotype_alloc_bitset(tsk_haplotype_t *self) int tsk_haplotype_init(tsk_haplotype_t *self, const tsk_treeseq_t *tree_sequence, - tsk_id_t site_start, tsk_id_t site_stop, int8_t missing_char, tsk_flags_t options) + tsk_id_t site_start, tsk_id_t site_stop) { int ret = 0; const tsk_table_collection_t *tables; @@ -434,8 +410,6 @@ tsk_haplotype_init(tsk_haplotype_t *self, const tsk_treeseq_t *tree_sequence, tsk_memset(self, 0, sizeof(*self)); self->tree_sequence = tree_sequence; - self->missing_char = missing_char; - self->isolated_as_missing = !(options & TSK_ISOLATED_NOT_MISSING); tables = tree_sequence->tables; sites = &tables->sites; @@ -445,17 +419,12 @@ tsk_haplotype_init(tsk_haplotype_t *self, const tsk_treeseq_t *tree_sequence, ret = tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); goto out; } - if ((unsigned char) missing_char > 0x7F) { - ret = tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); - goto out; - } self->site_start = (int32_t) site_start; self->site_stop = (int32_t) site_stop; self->num_sites = (int32_t)(site_stop - site_start); self->num_nodes = tables->nodes.num_rows; self->num_edges = tables->edges.num_rows; - self->node_flags = tables->nodes.flags; self->site_positions = sites->position + site_start; ret = 
tsk_haplotype_build_child_index(self); @@ -527,9 +496,6 @@ tsk_haplotype_decode(tsk_haplotype_t *self, tsk_id_t node, int8_t *haplotype) if (node < 0 || node >= (tsk_id_t) self->num_nodes) { return tsk_trace_error(TSK_ERR_NODE_OUT_OF_BOUNDS); } - if (self->isolated_as_missing && !(self->node_flags[node] & TSK_NODE_IS_SAMPLE)) { - return tsk_trace_error(TSK_ERR_MUST_IMPUTE_NON_SAMPLES); - } if (self->num_sites == 0) { return 0; } @@ -651,16 +617,9 @@ tsk_haplotype_decode(tsk_haplotype_t *self, tsk_id_t node, int8_t *haplotype) } } - bool has_incoming = self->child_offsets[node + 1] != self->child_offsets[node]; - bool has_outgoing - = self->parent_edge_counts != NULL && self->parent_edge_counts[node] > 0; - bool mark_missing = self->isolated_as_missing && !(has_incoming || has_outgoing); idx = tsk_haplotype_bitset_next( bits, self->num_bit_words, 0, (tsk_size_t) self->num_sites); while (idx < (tsk_size_t) self->num_sites) { - if (mark_missing) { - haplotype[idx] = self->missing_char; - } tsk_haplotype_bitset_clear(bits, idx); idx = tsk_haplotype_bitset_next( bits, self->num_bit_words, idx, (tsk_size_t) self->num_sites); @@ -681,7 +640,6 @@ tsk_haplotype_free(tsk_haplotype_t *self) tsk_safe_free(self->node_mutation_states); tsk_safe_free(self->child_order); tsk_safe_free(self->child_offsets); - tsk_safe_free(self->parent_edge_counts); tsk_safe_free(self->edge_start_index); tsk_safe_free(self->edge_end_index); tsk_safe_free(self->edge_stack); @@ -692,7 +650,6 @@ tsk_haplotype_free(tsk_haplotype_t *self) tsk_safe_free(self->unresolved_bits); tsk_safe_free(self->initial_bits); self->tree_sequence = NULL; - self->node_flags = NULL; self->site_positions = NULL; self->initialised = false; return 0; diff --git a/c/tskit/genotypes.h b/c/tskit/genotypes.h index 97e7b3ddc8..937cf86acb 100644 --- a/c/tskit/genotypes.h +++ b/c/tskit/genotypes.h @@ -93,9 +93,6 @@ typedef struct { int32_t site_start; int32_t site_stop; int32_t num_sites; - int8_t missing_char; - bool 
isolated_as_missing; - const tsk_flags_t *node_flags; const double *site_positions; uint8_t *ancestral_states; int32_t *node_mutation_offsets; @@ -103,7 +100,6 @@ typedef struct { uint8_t *node_mutation_states; tsk_id_t *child_order; int32_t *child_offsets; - int32_t *parent_edge_counts; int32_t *edge_start_index; int32_t *edge_end_index; tsk_id_t *edge_stack; @@ -211,7 +207,7 @@ void tsk_variant_print_state(const tsk_variant_t *self, FILE *out); /* Deprecated vargen methods (since C API v1.0) */ int tsk_haplotype_init(tsk_haplotype_t *self, const tsk_treeseq_t *tree_sequence, - tsk_id_t site_start, tsk_id_t site_stop, int8_t missing_char, tsk_flags_t options); + tsk_id_t site_start, tsk_id_t site_stop); int tsk_haplotype_decode(tsk_haplotype_t *self, tsk_id_t node, int8_t *haplotype); int tsk_haplotype_free(tsk_haplotype_t *self); int tsk_vargen_init(tsk_vargen_t *self, const tsk_treeseq_t *tree_sequence, diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index 3a6b53d6d8..10983613fb 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -10636,45 +10636,16 @@ Haplotype_init(Haplotype *self, PyObject *args, PyObject *kwds) { int ret = -1; int err; - static char *kwlist[] = { "tree_sequence", "site_start", "site_stop", - "isolated_as_missing", "missing_data_character", NULL }; + static char *kwlist[] = { "tree_sequence", "site_start", "site_stop", NULL }; TreeSequence *tree_sequence = NULL; Py_ssize_t site_start; Py_ssize_t site_stop; - int isolated_as_missing = 1; - PyObject *missing_obj = NULL; - PyObject *missing_bytes = NULL; - const char *missing_ptr = NULL; - Py_ssize_t missing_length = 0; - char missing_char = 'N'; - tsk_flags_t options = 0; - if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!nn|pO", kwlist, &TreeSequenceType, - &tree_sequence, &site_start, &site_stop, &isolated_as_missing, - &missing_obj)) { + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!nn", kwlist, &TreeSequenceType, + &tree_sequence, &site_start, &site_stop)) { goto 
out; } - if (missing_obj != NULL && missing_obj != Py_None) { - if (PyBytes_Check(missing_obj)) { - missing_ptr = PyBytes_AS_STRING(missing_obj); - missing_length = PyBytes_GET_SIZE(missing_obj); - } else { - missing_bytes = PyUnicode_AsASCIIString(missing_obj); - if (missing_bytes == NULL) { - goto out; - } - missing_ptr = PyBytes_AS_STRING(missing_bytes); - missing_length = PyBytes_GET_SIZE(missing_bytes); - } - if (missing_length != 1) { - PyErr_SetString(PyExc_ValueError, - "missing_data_character must be a single ASCII character"); - goto out; - } - missing_char = missing_ptr[0]; - } - self->haplotype = PyMem_Malloc(sizeof(*self->haplotype)); if (self->haplotype == NULL) { PyErr_NoMemory(); @@ -10684,12 +10655,8 @@ Haplotype_init(Haplotype *self, PyObject *args, PyObject *kwds) self->tree_sequence = tree_sequence; Py_INCREF(tree_sequence); - if (!isolated_as_missing) { - options |= TSK_ISOLATED_NOT_MISSING; - } - err = tsk_haplotype_init(self->haplotype, tree_sequence->tree_sequence, - (tsk_id_t) site_start, (tsk_id_t) site_stop, (int8_t) missing_char, options); + (tsk_id_t) site_start, (tsk_id_t) site_stop); if (err != 0) { handle_library_error(err); goto out; @@ -10706,7 +10673,6 @@ Haplotype_init(Haplotype *self, PyObject *args, PyObject *kwds) Py_XDECREF(self->tree_sequence); self->tree_sequence = NULL; } - Py_XDECREF(missing_bytes); return ret; } diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 92b1d676ad..b08548f7df 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -5272,6 +5272,38 @@ def _haplotypes_array( start_site, stop_site = np.searchsorted(self.sites_position, interval) num_sites = stop_site - start_site + missing_int8 = ord(missing_data_character) + + want_missing = ( + True if isolated_as_missing is None else bool(isolated_as_missing) + ) + + if want_missing and num_sites > 0: + ll_ts = self._ll_tree_sequence + anc_offsets = ll_ts.sites_ancestral_state_offset + anc_data = ll_ts.sites_ancestral_state + anc_slice = 
anc_offsets[start_site : stop_site + 1] + anc_lengths = np.diff(anc_slice) + if np.any(anc_lengths > 0): + anc_index = anc_slice[:-1][anc_lengths > 0] + if np.any(anc_data[anc_index] == missing_int8): + raise ValueError( + "missing_data_character must differ from existing allele states" + ) + mut_sites = ll_ts.mutations_site + if mut_sites.size > 0: + mut_offsets = ll_ts.mutations_derived_state_offset + mut_lengths = np.diff(mut_offsets) + mask = (mut_sites >= start_site) & (mut_sites < stop_site) + valid = mask & (mut_lengths > 0) + if np.any(valid): + mut_start = mut_offsets[:-1][valid] + derived_chars = ll_ts.mutations_derived_state[mut_start] + if np.any(derived_chars == missing_int8): + raise ValueError( + "missing_data_character must differ from existing allele " + "states" + ) if samples is None: sample_nodes = self.samples() @@ -5279,9 +5311,6 @@ def _haplotypes_array( sample_nodes = np.array(samples, dtype=np.int64) num_samples = len(sample_nodes) - want_missing = ( - True if isolated_as_missing is None else bool(isolated_as_missing) - ) if want_missing and samples is not None and num_samples > 0: flags = self.nodes_flags[sample_nodes] if np.any((flags & NODE_IS_SAMPLE) == 0): @@ -5294,19 +5323,20 @@ def _haplotypes_array( if num_samples == 0 or num_sites == 0: return H, (start_site, stop_site - 1) - missing_int8 = ord(missing_data_character) + # For now deal with missing data using the variants iterator missing_mask = None if want_missing: - missing_mask = np.zeros((num_samples, num_sites), dtype=bool) - - for var in self.variants( - samples=samples, - isolated_as_missing=isolated_as_missing, - left=interval.left, - right=interval.right, - copy=False, - ): - if want_missing and missing_mask is not None: + for var in self.variants( + samples=samples, + isolated_as_missing=isolated_as_missing, + left=interval.left, + right=interval.right, + copy=False, + ): + if not var.has_missing_data: + continue + if missing_mask is None: + missing_mask = 
np.zeros((num_samples, num_sites), dtype=bool) genotypes = np.asarray(var.genotypes, dtype=np.int32) missing_mask[:, var.site.id - start_site] = ( genotypes == tskit.MISSING_DATA @@ -5317,8 +5347,6 @@ def _haplotypes_array( self._ll_tree_sequence, int(start_site), int(stop_site), - isolated_as_missing=False, - missing_data_character=missing_data_character, ) except exceptions.LibraryError as err: if "TSK_ERR_UNSUPPORTED_OPERATION" in str(err): @@ -5331,7 +5359,7 @@ def _haplotypes_array( data = hap.decode(int(node)) H[row, :] = np.frombuffer(data, dtype=np.int8, count=num_sites) - if want_missing and missing_mask is not None: + if missing_mask is not None: H[missing_mask] = missing_int8 return H, (start_site, stop_site - 1) From 6f1fea4f90f2001c2010f2ae72374df245acc918 Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Thu, 16 Oct 2025 03:14:17 +0100 Subject: [PATCH 05/11] Rename to parent --- c/tskit/genotypes.c | 76 +++++++++++++++++++++++++++++++-------------- c/tskit/genotypes.h | 4 +-- 2 files changed, 54 insertions(+), 26 deletions(-) diff --git a/c/tskit/genotypes.c b/c/tskit/genotypes.c index b0cf909727..c8b0096fc3 100644 --- a/c/tskit/genotypes.c +++ b/c/tskit/genotypes.c @@ -118,7 +118,7 @@ tsk_haplotype_reset_bitset(const tsk_haplotype_t *self) } static int -tsk_haplotype_build_child_index(tsk_haplotype_t *self) +tsk_haplotype_build_parent_index(tsk_haplotype_t *self) { int ret = 0; tsk_size_t j; @@ -126,22 +126,31 @@ tsk_haplotype_build_child_index(tsk_haplotype_t *self) const tsk_edge_table_t *edges = &tables->edges; tsk_size_t num_edges = edges->num_rows; tsk_haplotype_edge_sort_t *sorted = NULL; + int32_t *offsets = NULL; if (num_edges == 0) { - self->child_order = NULL; - self->child_offsets - = tsk_calloc(self->num_nodes + 1, sizeof(*self->child_offsets)); - if (self->child_offsets == NULL) { - ret = tsk_trace_error(TSK_ERR_NO_MEMORY); - goto out; + self->parent_edge_index = NULL; + if (self->num_nodes > 0) { + self->parent_index_range + = 
tsk_calloc(self->num_nodes * 2, sizeof(*self->parent_index_range)); + if (self->parent_index_range == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } + } else { + self->parent_index_range = NULL; } goto out; } sorted = tsk_malloc(num_edges * sizeof(*sorted)); - self->child_order = tsk_malloc(num_edges * sizeof(*self->child_order)); - self->child_offsets = tsk_calloc(self->num_nodes + 1, sizeof(*self->child_offsets)); - if (sorted == NULL || self->child_order == NULL || self->child_offsets == NULL) { + self->parent_edge_index = tsk_malloc(num_edges * sizeof(*self->parent_edge_index)); + self->parent_index_range + = tsk_malloc(self->num_nodes * 2 * sizeof(*self->parent_index_range)); + offsets = tsk_calloc(self->num_nodes + 1, sizeof(*offsets)); + if (sorted == NULL || self->parent_edge_index == NULL + || (self->num_nodes > 0 && self->parent_index_range == NULL) + || offsets == NULL) { ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } @@ -156,16 +165,26 @@ tsk_haplotype_build_child_index(tsk_haplotype_t *self) for (j = 0; j < num_edges; j++) { tsk_id_t child = sorted[j].child; if (child >= 0 && child < (tsk_id_t) self->num_nodes) { - self->child_offsets[child + 1]++; + offsets[child + 1]++; } - self->child_order[j] = sorted[j].edge_id; } - for (j = 0; j < self->num_nodes; j++) { - self->child_offsets[j + 1] += self->child_offsets[j]; + for (j = 0; j < (tsk_size_t) self->num_nodes; j++) { + offsets[j + 1] += offsets[j]; + self->parent_index_range[2 * j] = offsets[j]; + self->parent_index_range[2 * j + 1] = offsets[j + 1]; + } + + for (j = 0; j < num_edges; j++) { + tsk_id_t child = sorted[j].child; + if (child >= 0 && child < (tsk_id_t) self->num_nodes) { + int32_t pos = offsets[child]++; + self->parent_edge_index[pos] = sorted[j].edge_id; + } } out: tsk_safe_free(sorted); + tsk_safe_free(offsets); return ret; } @@ -427,7 +446,7 @@ tsk_haplotype_init(tsk_haplotype_t *self, const tsk_treeseq_t *tree_sequence, self->num_edges = 
tables->edges.num_rows; self->site_positions = sites->position + site_start; - ret = tsk_haplotype_build_child_index(self); + ret = tsk_haplotype_build_parent_index(self); if (ret != 0) { goto out; } @@ -523,10 +542,15 @@ tsk_haplotype_decode(tsk_haplotype_t *self, tsk_id_t node, int8_t *haplotype) } } - int32_t child_start = self->child_offsets[node]; - int32_t child_stop = self->child_offsets[node + 1]; + int32_t child_start = 0; + int32_t child_stop = 0; + if (self->parent_index_range != NULL) { + int32_t range_offset = node * 2; + child_start = self->parent_index_range[range_offset]; + child_stop = self->parent_index_range[range_offset + 1]; + } for (int32_t i = child_start; i < child_stop; i++) { - tsk_id_t edge = self->child_order[i]; + tsk_id_t edge = self->parent_edge_index[i]; int32_t start = self->edge_start_index[edge]; int32_t end = self->edge_end_index[edge]; if (start >= end) { @@ -565,11 +589,12 @@ tsk_haplotype_decode(tsk_haplotype_t *self, tsk_id_t node, int8_t *haplotype) } parent_count = 0; - if (ancestor >= 0) { - child_start = self->child_offsets[ancestor]; - child_stop = self->child_offsets[ancestor + 1]; + if (ancestor >= 0 && self->parent_index_range != NULL) { + int32_t range_offset = ancestor * 2; + child_start = self->parent_index_range[range_offset]; + child_stop = self->parent_index_range[range_offset + 1]; for (int32_t i = child_start; i < child_stop; i++) { - tsk_id_t parent_edge = self->child_order[i]; + tsk_id_t parent_edge = self->parent_edge_index[i]; int32_t parent_start = self->edge_start_index[parent_edge]; int32_t parent_end = self->edge_end_index[parent_edge]; if (parent_start < interval_start) { @@ -593,6 +618,9 @@ tsk_haplotype_decode(tsk_haplotype_t *self, tsk_id_t node, int8_t *haplotype) parent_count++; } } + } else { + child_start = 0; + child_stop = 0; } idx = tsk_haplotype_bitset_next(bits, self->num_bit_words, @@ -638,8 +666,8 @@ tsk_haplotype_free(tsk_haplotype_t *self) tsk_safe_free(self->node_mutation_offsets); 
tsk_safe_free(self->node_mutation_sites); tsk_safe_free(self->node_mutation_states); - tsk_safe_free(self->child_order); - tsk_safe_free(self->child_offsets); + tsk_safe_free(self->parent_edge_index); + tsk_safe_free(self->parent_index_range); tsk_safe_free(self->edge_start_index); tsk_safe_free(self->edge_end_index); tsk_safe_free(self->edge_stack); diff --git a/c/tskit/genotypes.h b/c/tskit/genotypes.h index 937cf86acb..62a923d109 100644 --- a/c/tskit/genotypes.h +++ b/c/tskit/genotypes.h @@ -98,8 +98,8 @@ typedef struct { int32_t *node_mutation_offsets; int32_t *node_mutation_sites; uint8_t *node_mutation_states; - tsk_id_t *child_order; - int32_t *child_offsets; + tsk_id_t *parent_edge_index; + int32_t *parent_index_range; int32_t *edge_start_index; int32_t *edge_end_index; tsk_id_t *edge_stack; From cbc6512739c19ae774ce982b44bdf3789c2b1681 Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Thu, 16 Oct 2025 03:29:07 +0100 Subject: [PATCH 06/11] Remove qsort --- c/tskit/genotypes.c | 93 ++++++++++++++++++--------------------------- 1 file changed, 38 insertions(+), 55 deletions(-) diff --git a/c/tskit/genotypes.c b/c/tskit/genotypes.c index c8b0096fc3..77d02b3243 100644 --- a/c/tskit/genotypes.c +++ b/c/tskit/genotypes.c @@ -37,34 +37,6 @@ #include -typedef struct { - tsk_id_t edge_id; - tsk_id_t child; - double left; -} tsk_haplotype_edge_sort_t; - -static int -tsk_haplotype_edge_sort_cmp(const void *aa, const void *bb) -{ - const tsk_haplotype_edge_sort_t *a = (const tsk_haplotype_edge_sort_t *) aa; - const tsk_haplotype_edge_sort_t *b = (const tsk_haplotype_edge_sort_t *) bb; - - if (a->child == b->child) { - if (a->left < b->left) { - return -1; - } else if (a->left > b->left) { - return 1; - } - if (a->edge_id < b->edge_id) { - return -1; - } else if (a->edge_id > b->edge_id) { - return 1; - } - return 0; - } - return a->child < b->child ? 
-1 : 1; -} - static inline uint32_t tsk_haplotype_ctz64(uint64_t x) { @@ -121,12 +93,11 @@ static int tsk_haplotype_build_parent_index(tsk_haplotype_t *self) { int ret = 0; - tsk_size_t j; const tsk_table_collection_t *tables = self->tree_sequence->tables; const tsk_edge_table_t *edges = &tables->edges; + const tsk_id_t *edges_child = edges->child; tsk_size_t num_edges = edges->num_rows; - tsk_haplotype_edge_sort_t *sorted = NULL; - int32_t *offsets = NULL; + int32_t *child_counts = NULL; if (num_edges == 0) { self->parent_edge_index = NULL; @@ -143,48 +114,60 @@ tsk_haplotype_build_parent_index(tsk_haplotype_t *self) goto out; } - sorted = tsk_malloc(num_edges * sizeof(*sorted)); self->parent_edge_index = tsk_malloc(num_edges * sizeof(*self->parent_edge_index)); self->parent_index_range = tsk_malloc(self->num_nodes * 2 * sizeof(*self->parent_index_range)); - offsets = tsk_calloc(self->num_nodes + 1, sizeof(*offsets)); - if (sorted == NULL || self->parent_edge_index == NULL + child_counts = tsk_calloc(self->num_nodes, sizeof(*child_counts)); + if (self->parent_edge_index == NULL || (self->num_nodes > 0 && self->parent_index_range == NULL) - || offsets == NULL) { + || child_counts == NULL) { ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } - for (j = 0; j < num_edges; j++) { - sorted[j].edge_id = (tsk_id_t) j; - sorted[j].child = edges->child[j]; - sorted[j].left = edges->left[j]; - } - qsort(sorted, num_edges, sizeof(*sorted), tsk_haplotype_edge_sort_cmp); - - for (j = 0; j < num_edges; j++) { - tsk_id_t child = sorted[j].child; + for (tsk_size_t j = 0; j < num_edges; j++) { + tsk_id_t child = edges_child[j]; if (child >= 0 && child < (tsk_id_t) self->num_nodes) { - offsets[child + 1]++; + if (child_counts[child] == INT32_MAX) { + ret = tsk_trace_error(TSK_ERR_UNSUPPORTED_OPERATION); + goto out; + } + child_counts[child]++; } } - for (j = 0; j < (tsk_size_t) self->num_nodes; j++) { - offsets[j + 1] += offsets[j]; - self->parent_index_range[2 * j] = offsets[j]; 
- self->parent_index_range[2 * j + 1] = offsets[j + 1]; + + int32_t current_start = 0; + for (tsk_size_t u = 0; u < (tsk_size_t) self->num_nodes; u++) { + int32_t offset = (int32_t)(u * 2); + self->parent_index_range[offset] = current_start; + self->parent_index_range[offset + 1] = current_start; + current_start += child_counts[u]; } - for (j = 0; j < num_edges; j++) { - tsk_id_t child = sorted[j].child; + for (tsk_size_t j = 0; j < num_edges; j++) { + tsk_id_t child = edges_child[j]; if (child >= 0 && child < (tsk_id_t) self->num_nodes) { - int32_t pos = offsets[child]++; - self->parent_edge_index[pos] = sorted[j].edge_id; + int32_t end_offset = (int32_t)(child * 2 + 1); + int32_t pos = self->parent_index_range[end_offset]; + self->parent_edge_index[pos] = (tsk_id_t) j; + self->parent_index_range[end_offset] = pos + 1; } } + for (tsk_size_t u = 0; u < (tsk_size_t) self->num_nodes; u++) { + int32_t offset = (int32_t)(u * 2); + int32_t end = self->parent_index_range[offset + 1]; + self->parent_index_range[offset] = end - child_counts[u]; + } + out: - tsk_safe_free(sorted); - tsk_safe_free(offsets); + if (ret != 0) { + tsk_safe_free(self->parent_edge_index); + self->parent_edge_index = NULL; + tsk_safe_free(self->parent_index_range); + self->parent_index_range = NULL; + } + tsk_safe_free(child_counts); return ret; } From 24cf27a8bc612465dc06ded745d183511a8a56b2 Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Thu, 16 Oct 2025 04:34:21 +0100 Subject: [PATCH 07/11] Fix warnings --- c/tskit/genotypes.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/c/tskit/genotypes.c b/c/tskit/genotypes.c index 77d02b3243..7703a19045 100644 --- a/c/tskit/genotypes.c +++ b/c/tskit/genotypes.c @@ -212,7 +212,7 @@ tsk_haplotype_build_mutation_index(tsk_haplotype_t *self) } self->node_mutation_offsets[0] = 0; for (j = 0; j < self->num_nodes; j++) { - total_mutations += counts[j]; + total_mutations += (tsk_size_t) counts[j]; if (total_mutations > INT32_MAX) { ret = 
tsk_trace_error(TSK_ERR_UNSUPPORTED_OPERATION); goto out; @@ -281,7 +281,7 @@ tsk_haplotype_build_ancestral_states(tsk_haplotype_t *self) } self->ancestral_states - = tsk_malloc(self->num_sites * sizeof(*self->ancestral_states)); + = tsk_malloc((tsk_size_t) self->num_sites * sizeof(*self->ancestral_states)); if (self->ancestral_states == NULL) { return tsk_trace_error(TSK_ERR_NO_MEMORY); } From d3029a60c03527c3e9e1340ecb92bc6c4036d9a6 Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Thu, 16 Oct 2025 07:35:29 +0100 Subject: [PATCH 08/11] Perf --- c/tskit/genotypes.c | 198 ++++++++++++++++++++++++++++++++------------ c/tskit/genotypes.h | 2 + 2 files changed, 147 insertions(+), 53 deletions(-) diff --git a/c/tskit/genotypes.c b/c/tskit/genotypes.c index 7703a19045..2ff7fd2582 100644 --- a/c/tskit/genotypes.c +++ b/c/tskit/genotypes.c @@ -49,43 +49,140 @@ tsk_haplotype_ctz64(uint64_t x) #endif } +static inline uint32_t +tsk_haplotype_popcount64(uint64_t value) +{ +#if defined(_MSC_VER) + return (uint32_t) __popcnt64(value); +#else + return (uint32_t) __builtin_popcountll(value); +#endif +} + +static inline void +tsk_haplotype_bitset_clear(tsk_haplotype_t *self, tsk_size_t idx) +{ + tsk_size_t word = idx >> 6; + uint64_t mask = UINT64_C(1) << (idx & 63); + if ((self->unresolved_bits[word] & mask) == 0) { + return; + } + self->unresolved_bits[word] &= ~mask; + if (self->unresolved_counts[word] > 0) { + self->unresolved_counts[word]--; + } +} + static inline void -tsk_haplotype_bitset_clear(uint64_t *bits, tsk_size_t idx) +tsk_haplotype_clear_word_bit(tsk_haplotype_t *self, tsk_size_t word, uint64_t mask) +{ + if ((self->unresolved_bits[word] & mask) != 0) { + self->unresolved_bits[word] &= ~mask; + if (self->unresolved_counts[word] > 0) { + self->unresolved_counts[word]--; + } + } +} + +static inline bool +tsk_haplotype_bitset_test(const uint64_t *bits, tsk_size_t idx) { tsk_size_t word = idx >> 6; uint64_t mask = UINT64_C(1) << (idx & 63); - bits[word] &= ~mask; + 
return (bits[word] & mask) != 0; } static inline tsk_size_t tsk_haplotype_bitset_next( - const uint64_t *bits, tsk_size_t num_words, tsk_size_t start, tsk_size_t limit) + const tsk_haplotype_t *self, tsk_size_t start, tsk_size_t limit) { tsk_size_t word = start >> 6; + tsk_size_t word_limit = (limit + 63) >> 6; uint64_t mask, value; - if (start >= limit || word >= num_words) { + if (start >= limit || word >= self->num_bit_words) { return limit; } mask = UINT64_MAX << (start & 63); - value = bits[word] & mask; + value = self->unresolved_bits[word] & mask; while (value == 0) { word++; - if (word >= num_words) { + if (word >= word_limit || word >= self->num_bit_words) { return limit; } - value = bits[word]; + while (word < word_limit && word < self->num_bit_words + && self->unresolved_counts[word] == 0) { + word++; + } + if (word >= word_limit || word >= self->num_bit_words) { + return limit; + } + value = self->unresolved_bits[word]; } start = (word << 6) + tsk_haplotype_ctz64(value); return start < limit ? 
start : limit; } +static bool +tsk_haplotype_find_next_uncovered(tsk_haplotype_t *self, tsk_size_t start, + tsk_size_t end, const int32_t *interval_start, const int32_t *interval_end, + tsk_size_t interval_count, tsk_size_t *out_index) +{ + if (start >= end) { + return false; + } + tsk_size_t word = start >> 6; + tsk_size_t last_word = (end - 1) >> 6; + if (word >= self->num_bit_words) { + return false; + } + uint64_t start_mask = UINT64_MAX << (start & 63); + for (; word <= last_word && word < self->num_bit_words; word++) { + if (self->unresolved_counts[word] == 0) { + continue; + } + uint64_t word_bits = self->unresolved_bits[word]; + if (word == (start >> 6)) { + word_bits &= start_mask; + } + if (word == last_word) { + uint64_t end_mask = UINT64_MAX >> (63 - ((end - 1) & 63)); + word_bits &= end_mask; + } + while (word_bits != 0) { + uint64_t lowest_bit = word_bits & (~word_bits + 1); + tsk_size_t bit = tsk_haplotype_ctz64(word_bits); + word_bits ^= lowest_bit; + tsk_size_t bit_index = (word << 6) + bit; + if (bit_index >= end) { + break; + } + bool covered = false; + for (tsk_size_t p = 0; p < interval_count; p++) { + if (interval_start[p] <= (int32_t) bit_index + && (int32_t) bit_index < interval_end[p]) { + covered = true; + break; + } + } + if (!covered) { + *out_index = bit_index; + return true; + } + } + start_mask = UINT64_MAX; + } + return false; +} + static void -tsk_haplotype_reset_bitset(const tsk_haplotype_t *self) +tsk_haplotype_reset_bitset(tsk_haplotype_t *self) { if (self->num_bit_words > 0) { tsk_memcpy(self->unresolved_bits, self->initial_bits, self->num_bit_words * sizeof(*self->unresolved_bits)); + tsk_memcpy(self->unresolved_counts, self->initial_counts, + self->num_bit_words * sizeof(*self->unresolved_counts)); } } @@ -379,21 +476,31 @@ tsk_haplotype_alloc_bitset(tsk_haplotype_t *self) if (self->num_bit_words == 0) { self->unresolved_bits = NULL; self->initial_bits = NULL; + self->unresolved_counts = NULL; + self->initial_counts = NULL; 
return 0; } self->unresolved_bits = tsk_malloc(self->num_bit_words * sizeof(*self->unresolved_bits)); self->initial_bits = tsk_malloc(self->num_bit_words * sizeof(*self->initial_bits)); - if (self->unresolved_bits == NULL || self->initial_bits == NULL) { + self->unresolved_counts + = tsk_malloc(self->num_bit_words * sizeof(*self->unresolved_counts)); + self->initial_counts + = tsk_malloc(self->num_bit_words * sizeof(*self->initial_counts)); + if (self->unresolved_bits == NULL || self->initial_bits == NULL + || self->unresolved_counts == NULL || self->initial_counts == NULL) { return tsk_trace_error(TSK_ERR_NO_MEMORY); } for (j = 0; j < self->num_bit_words; j++) { - self->initial_bits[j] = UINT64_MAX; - } - if ((tsk_size_t) self->num_sites % 64 != 0) { - uint32_t bits = (uint32_t)((tsk_size_t) self->num_sites & 63); - self->initial_bits[self->num_bit_words - 1] = (UINT64_C(1) << bits) - 1; + uint64_t word = UINT64_MAX; + if (j == self->num_bit_words - 1 && (tsk_size_t) self->num_sites % 64 != 0) { + uint32_t bits = (uint32_t)((tsk_size_t) self->num_sites & 63); + word = (UINT64_C(1) << bits) - 1; + } + self->initial_bits[j] = word; + self->initial_counts[j] = tsk_haplotype_popcount64(word); } + tsk_haplotype_reset_bitset(self); return 0; } @@ -517,11 +624,9 @@ tsk_haplotype_decode(tsk_haplotype_t *self, tsk_id_t node, int8_t *haplotype) for (int32_t m = mut_start; m < mut_end; m++) { int32_t site = self->node_mutation_sites[m]; if (site >= 0 && site < self->num_sites - && tsk_haplotype_bitset_next( - bits, self->num_bit_words, (tsk_size_t) site, (tsk_size_t) site + 1) - == (tsk_size_t) site) { + && tsk_haplotype_bitset_test(bits, (tsk_size_t) site)) { haplotype[site] = (int8_t) self->node_mutation_states[m]; - tsk_haplotype_bitset_clear(bits, (tsk_size_t) site); + tsk_haplotype_bitset_clear(self, (tsk_size_t) site); } } @@ -539,9 +644,10 @@ tsk_haplotype_decode(tsk_haplotype_t *self, tsk_id_t node, int8_t *haplotype) if (start >= end) { continue; } - if 
(tsk_haplotype_bitset_next( - bits, self->num_bit_words, (tsk_size_t) start, (tsk_size_t) end) - < (tsk_size_t) end) { + tsk_size_t uncovered_idx; + if (tsk_haplotype_find_next_uncovered(self, (tsk_size_t) start, (tsk_size_t) end, + self->parent_interval_start, self->parent_interval_end, 0, + &uncovered_idx)) { self->edge_stack[stack_top] = edge; self->stack_interval_start[stack_top] = start; self->stack_interval_end[stack_top] = end; @@ -562,11 +668,9 @@ tsk_haplotype_decode(tsk_haplotype_t *self, tsk_id_t node, int8_t *haplotype) for (int32_t m = mut_start; m < mut_end; m++) { int32_t site = self->node_mutation_sites[m]; if (site >= interval_start && site < interval_end - && tsk_haplotype_bitset_next(bits, self->num_bit_words, - (tsk_size_t) site, (tsk_size_t) site + 1) - == (tsk_size_t) site) { + && tsk_haplotype_bitset_test(bits, (tsk_size_t) site)) { haplotype[site] = (int8_t) self->node_mutation_states[m]; - tsk_haplotype_bitset_clear(bits, (tsk_size_t) site); + tsk_haplotype_bitset_clear(self, (tsk_size_t) site); } } } @@ -589,9 +693,10 @@ tsk_haplotype_decode(tsk_haplotype_t *self, tsk_id_t node, int8_t *haplotype) if (parent_start >= parent_end) { continue; } - if (tsk_haplotype_bitset_next(bits, self->num_bit_words, - (tsk_size_t) parent_start, (tsk_size_t) parent_end) - < (tsk_size_t) parent_end) { + tsk_size_t uncovered_idx; + if (tsk_haplotype_find_next_uncovered(self, (tsk_size_t) parent_start, + (tsk_size_t) parent_end, self->parent_interval_start, + self->parent_interval_end, parent_count, &uncovered_idx)) { self->edge_stack[stack_top] = parent_edge; self->stack_interval_start[stack_top] = parent_start; self->stack_interval_end[stack_top] = parent_end; @@ -606,34 +711,19 @@ tsk_haplotype_decode(tsk_haplotype_t *self, tsk_id_t node, int8_t *haplotype) child_stop = 0; } - idx = tsk_haplotype_bitset_next(bits, self->num_bit_words, - (tsk_size_t) interval_start, (tsk_size_t) interval_end); - while ((int32_t) idx < interval_end) { - bool covered = false; 
- for (tsk_size_t p = 0; p < parent_count; p++) { - if (self->parent_interval_start[p] <= (int32_t) idx - && (int32_t) idx < self->parent_interval_end[p]) { - covered = true; - break; - } - } - if (covered) { - idx = tsk_haplotype_bitset_next( - bits, self->num_bit_words, idx + 1, (tsk_size_t) interval_end); - } else { - tsk_haplotype_bitset_clear(bits, idx); - idx = tsk_haplotype_bitset_next( - bits, self->num_bit_words, idx, (tsk_size_t) interval_end); - } + tsk_size_t uncovered_idx; + while (tsk_haplotype_find_next_uncovered(self, (tsk_size_t) interval_start, + (tsk_size_t) interval_end, self->parent_interval_start, + self->parent_interval_end, parent_count, &uncovered_idx)) { + tsk_size_t word_index = uncovered_idx >> 6; + uint64_t mask = UINT64_C(1) << (uncovered_idx & 63); + tsk_haplotype_clear_word_bit(self, word_index, mask); } } - idx = tsk_haplotype_bitset_next( - bits, self->num_bit_words, 0, (tsk_size_t) self->num_sites); - while (idx < (tsk_size_t) self->num_sites) { - tsk_haplotype_bitset_clear(bits, idx); - idx = tsk_haplotype_bitset_next( - bits, self->num_bit_words, idx, (tsk_size_t) self->num_sites); + for (tsk_size_t w = 0; w < self->num_bit_words; w++) { + self->unresolved_bits[w] = 0; + self->unresolved_counts[w] = 0; } return 0; @@ -660,6 +750,8 @@ tsk_haplotype_free(tsk_haplotype_t *self) tsk_safe_free(self->parent_interval_end); tsk_safe_free(self->unresolved_bits); tsk_safe_free(self->initial_bits); + tsk_safe_free(self->unresolved_counts); + tsk_safe_free(self->initial_counts); self->tree_sequence = NULL; self->site_positions = NULL; self->initialised = false; diff --git a/c/tskit/genotypes.h b/c/tskit/genotypes.h index 62a923d109..993dbe2f08 100644 --- a/c/tskit/genotypes.h +++ b/c/tskit/genotypes.h @@ -110,6 +110,8 @@ typedef struct { uint64_t *unresolved_bits; uint64_t *initial_bits; tsk_size_t num_bit_words; + uint32_t *unresolved_counts; + uint32_t *initial_counts; bool initialised; } tsk_haplotype_t; From 
22ae99abbb55769eb5966bc43cc2753668682b24 Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Thu, 16 Oct 2025 10:58:16 +0100 Subject: [PATCH 09/11] Perf - Shift coverage handling to whole-word bit math. --- c/examples/Makefile | 3 +- c/examples/haplotype_benchmark.c | 90 ++++++++++++++++++++++++++++++++ c/meson.build | 3 ++ c/tskit/genotypes.c | 73 +++++++++++++++++++++----- 4 files changed, 153 insertions(+), 16 deletions(-) create mode 100644 c/examples/haplotype_benchmark.c diff --git a/c/examples/Makefile b/c/examples/Makefile index b289c2c253..03639ad165 100644 --- a/c/examples/Makefile +++ b/c/examples/Makefile @@ -23,7 +23,7 @@ TSKIT_SOURCE=../tskit/*.c ../subprojects/kastore/kastore.c targets = api_structure error_handling \ haploid_wright_fisher streaming \ tree_iteration tree_traversal \ - take_ownership + take_ownership haplotype_benchmark all: $(targets) @@ -32,4 +32,3 @@ $(targets): %: %.c clean: rm -f $(targets) - diff --git a/c/examples/haplotype_benchmark.c b/c/examples/haplotype_benchmark.c new file mode 100644 index 0000000000..6f78456480 --- /dev/null +++ b/c/examples/haplotype_benchmark.c @@ -0,0 +1,90 @@ +#include +#include +#include +#include + +#include +#include +#include + +#define CHECK_TSK(err) \ + do { \ + if ((err) < 0) { \ + fprintf(stderr, "Error: line %d: %s\n", __LINE__, tsk_strerror(err)); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) + +#define NUM_ITERATIONS 1 +#define MAX_BENCHMARK_NODES 500 + +int +main(int argc, char **argv) +{ + int ret; + tsk_table_collection_t tables; + tsk_treeseq_t treeseq; + tsk_haplotype_t haplotype_decoder; + int8_t *haplotype = NULL; + double elapsed_seconds; + clock_t start_clock, end_clock; + uint64_t checksum = 0; + + const char *filename = "../../simulated_chrom_21_100k.ts"; + if (argc > 1) { + filename = argv[1]; + } + + ret = tsk_table_collection_init(&tables, 0); + CHECK_TSK(ret); + + ret = tsk_table_collection_load(&tables, filename, 0); + CHECK_TSK(ret); + + ret = tsk_treeseq_init(&treeseq, 
&tables, 0); + CHECK_TSK(ret); + + tsk_size_t num_nodes = tsk_treeseq_get_num_nodes(&treeseq); + tsk_size_t num_sites = tsk_treeseq_get_num_sites(&treeseq); + if (num_sites == 0) { + fprintf(stderr, "Tree sequence has no sites\n"); + exit(EXIT_FAILURE); + } + + tsk_id_t node_limit + = (tsk_id_t) (num_nodes < MAX_BENCHMARK_NODES ? num_nodes : MAX_BENCHMARK_NODES); + + ret = tsk_haplotype_init(&haplotype_decoder, &treeseq, 0, (tsk_id_t) num_sites); + CHECK_TSK(ret); + + haplotype = malloc(num_sites * sizeof(*haplotype)); + if (haplotype == NULL) { + fprintf(stderr, "Failed to allocate haplotype buffer\n"); + exit(EXIT_FAILURE); + } + + start_clock = clock(); + for (int iter = 0; iter < NUM_ITERATIONS; iter++) { + for (tsk_id_t node = 0; node < node_limit; node++) { + ret = tsk_haplotype_decode(&haplotype_decoder, node, haplotype); + CHECK_TSK(ret); + for (tsk_id_t site = 0; site < (tsk_id_t) num_sites; site++) { + checksum += (uint64_t) haplotype[site]; + } + } + } + end_clock = clock(); + + elapsed_seconds = (double) (end_clock - start_clock) / CLOCKS_PER_SEC; + + printf("Loaded tree sequence from %s\n", filename); + printf("Decoded %d iterations over %lld nodes × %lld sites in %.3f seconds\n", + NUM_ITERATIONS, (long long) node_limit, (long long) num_sites, elapsed_seconds); + printf("Checksummed haplotypes: %llu\n", (unsigned long long) checksum); + + free(haplotype); + tsk_haplotype_free(&haplotype_decoder); + tsk_treeseq_free(&treeseq); + tsk_table_collection_free(&tables); + return EXIT_SUCCESS; +} diff --git a/c/meson.build b/c/meson.build index f5c1a0f585..6d0e8c66b5 100644 --- a/c/meson.build +++ b/c/meson.build @@ -113,6 +113,9 @@ if not meson.is_subproject() executable('tree_traversal', sources: ['examples/tree_traversal.c'], link_with: [tskit_lib], dependencies: lib_deps) + executable('haplotype_benchmark', + sources: ['examples/haplotype_benchmark.c'], + link_with: [tskit_lib], dependencies: lib_deps) executable('streaming', sources: 
['examples/streaming.c'], link_with: [tskit_lib], dependencies: lib_deps) diff --git a/c/tskit/genotypes.c b/c/tskit/genotypes.c index 2ff7fd2582..a1a6016d87 100644 --- a/c/tskit/genotypes.c +++ b/c/tskit/genotypes.c @@ -123,6 +123,21 @@ tsk_haplotype_bitset_next( return start < limit ? start : limit; } +static inline uint64_t +tsk_haplotype_mask_from_offsets(uint32_t start_offset, uint32_t end_offset) +{ + if (start_offset >= end_offset) { + return 0; + } + if (start_offset == 0 && end_offset >= 64) { + return UINT64_MAX; + } + uint64_t high_mask + = end_offset >= 64 ? UINT64_MAX : ((UINT64_C(1) << end_offset) - 1); + uint64_t low_mask = start_offset == 0 ? 0 : ((UINT64_C(1) << start_offset) - 1); + return high_mask & ~low_mask; +} + static bool tsk_haplotype_find_next_uncovered(tsk_haplotype_t *self, tsk_size_t start, tsk_size_t end, const int32_t *interval_start, const int32_t *interval_end, @@ -139,6 +154,7 @@ tsk_haplotype_find_next_uncovered(tsk_haplotype_t *self, tsk_size_t start, uint64_t start_mask = UINT64_MAX << (start & 63); for (; word <= last_word && word < self->num_bit_words; word++) { if (self->unresolved_counts[word] == 0) { + start_mask = UINT64_MAX; continue; } uint64_t word_bits = self->unresolved_bits[word]; @@ -149,26 +165,55 @@ tsk_haplotype_find_next_uncovered(tsk_haplotype_t *self, tsk_size_t start, uint64_t end_mask = UINT64_MAX >> (63 - ((end - 1) & 63)); word_bits &= end_mask; } + if (word_bits == 0) { + start_mask = UINT64_MAX; + continue; + } + if (interval_count > 0) { + int32_t word_left = (int32_t)(word << 6); + int32_t word_right = word_left + 64; + uint64_t coverage_mask = 0; + for (tsk_size_t p = 0; p < interval_count; p++) { + int32_t interval_left = interval_start[p]; + int32_t interval_right = interval_end[p]; + if (interval_left >= interval_right) { + continue; + } + if (interval_right <= word_left || interval_left >= word_right) { + continue; + } + int32_t clipped_left + = interval_left > word_left ? 
interval_left : word_left; + int32_t clipped_right + = interval_right < word_right ? interval_right : word_right; + if ((int32_t) start > clipped_left) { + clipped_left = (int32_t) start; + } + if ((int32_t) end < clipped_right) { + clipped_right = (int32_t) end; + } + if (clipped_left >= clipped_right) { + continue; + } + uint32_t start_offset = (uint32_t)(clipped_left - word_left); + uint32_t end_offset = (uint32_t)(clipped_right - word_left); + coverage_mask + |= tsk_haplotype_mask_from_offsets(start_offset, end_offset); + if (coverage_mask == UINT64_MAX) { + break; + } + } + word_bits &= ~coverage_mask; + } while (word_bits != 0) { - uint64_t lowest_bit = word_bits & (~word_bits + 1); tsk_size_t bit = tsk_haplotype_ctz64(word_bits); - word_bits ^= lowest_bit; + word_bits &= word_bits - 1; tsk_size_t bit_index = (word << 6) + bit; if (bit_index >= end) { break; } - bool covered = false; - for (tsk_size_t p = 0; p < interval_count; p++) { - if (interval_start[p] <= (int32_t) bit_index - && (int32_t) bit_index < interval_end[p]) { - covered = true; - break; - } - } - if (!covered) { - *out_index = bit_index; - return true; - } + *out_index = bit_index; + return true; } start_mask = UINT64_MAX; } From af7f033dfb50b86f9d31e16f0730bd4dac0e58eb Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Thu, 16 Oct 2025 13:08:56 +0100 Subject: [PATCH 10/11] Comments --- c/tskit/genotypes.c | 18 ++++++++++++++++++ python/tskit/trees.py | 6 +++++- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/c/tskit/genotypes.c b/c/tskit/genotypes.c index a1a6016d87..dd85e17bf2 100644 --- a/c/tskit/genotypes.c +++ b/c/tskit/genotypes.c @@ -37,6 +37,8 @@ #include +// FIXME Tskit already has a bitset implementation that maybe we could use + static inline uint32_t tsk_haplotype_ctz64(uint64_t x) { @@ -151,6 +153,7 @@ tsk_haplotype_find_next_uncovered(tsk_haplotype_t *self, tsk_size_t start, if (word >= self->num_bit_words) { return false; } + // FIXME Horrendous logic here, needs 
jeromeifying. uint64_t start_mask = UINT64_MAX << (start & 63); for (; word <= last_word && word < self->num_bit_words; word++) { if (self->unresolved_counts[word] == 0) { @@ -231,6 +234,8 @@ tsk_haplotype_reset_bitset(tsk_haplotype_t *self) } } +// FIXME We're building the whole index here, which is a bit sad when we're clipping a +// region. static int tsk_haplotype_build_parent_index(tsk_haplotype_t *self) { @@ -313,6 +318,7 @@ tsk_haplotype_build_parent_index(tsk_haplotype_t *self) return ret; } +// FIXME No point adding mutations who are above nodes we have no interest in. static int tsk_haplotype_build_mutation_index(tsk_haplotype_t *self) { @@ -408,6 +414,7 @@ tsk_haplotype_build_mutation_index(tsk_haplotype_t *self) return ret; } +// FIXME Not sure this is even needed static int tsk_haplotype_build_ancestral_states(tsk_haplotype_t *self) { @@ -659,6 +666,7 @@ tsk_haplotype_decode(tsk_haplotype_t *self, tsk_id_t node, int8_t *haplotype) edge_parent = edges->parent; bits = self->unresolved_bits; + // Create a bitset that tracks which sites are still unresolved for (idx = 0; idx < (tsk_size_t) self->num_sites; idx++) { haplotype[idx] = (int8_t) self->ancestral_states[idx]; } @@ -666,6 +674,7 @@ tsk_haplotype_decode(tsk_haplotype_t *self, tsk_id_t node, int8_t *haplotype) mut_start = self->node_mutation_offsets[node]; mut_end = self->node_mutation_offsets[node + 1]; + // Apply mutations above this node for (int32_t m = mut_start; m < mut_end; m++) { int32_t site = self->node_mutation_sites[m]; if (site >= 0 && site < self->num_sites @@ -682,6 +691,8 @@ tsk_haplotype_decode(tsk_haplotype_t *self, tsk_id_t node, int8_t *haplotype) child_start = self->parent_index_range[range_offset]; child_stop = self->parent_index_range[range_offset + 1]; } + // Push all edges from this node (that are still relevant to resolving sites) onto + // the stack for (int32_t i = child_start; i < child_stop; i++) { tsk_id_t edge = self->parent_edge_index[i]; int32_t start = 
self->edge_start_index[edge]; @@ -700,6 +711,7 @@ tsk_haplotype_decode(tsk_haplotype_t *self, tsk_id_t node, int8_t *haplotype) } } + // Now process the stack until we run out of edges or have resolved all sites while (stack_top > 0) { stack_top--; tsk_id_t edge = self->edge_stack[stack_top]; @@ -707,6 +719,7 @@ tsk_haplotype_decode(tsk_haplotype_t *self, tsk_id_t node, int8_t *haplotype) interval_start = self->stack_interval_start[stack_top]; interval_end = self->stack_interval_end[stack_top]; + // Apply mutations above this ancestor if (ancestor >= 0) { mut_start = self->node_mutation_offsets[ancestor]; mut_end = self->node_mutation_offsets[ancestor + 1]; @@ -720,6 +733,8 @@ tsk_haplotype_decode(tsk_haplotype_t *self, tsk_id_t node, int8_t *haplotype) } } + // Going up the tree push all edges from this ancestor (that are still relevant + // to resolving sites) parent_count = 0; if (ancestor >= 0 && self->parent_index_range != NULL) { int32_t range_offset = ancestor * 2; @@ -756,6 +771,8 @@ tsk_haplotype_decode(tsk_haplotype_t *self, tsk_id_t node, int8_t *haplotype) child_stop = 0; } + // Clear out any sites that are still unresolved in this interval but not covered + // by any parent edges tsk_size_t uncovered_idx; while (tsk_haplotype_find_next_uncovered(self, (tsk_size_t) interval_start, (tsk_size_t) interval_end, self->parent_interval_start, @@ -766,6 +783,7 @@ tsk_haplotype_decode(tsk_haplotype_t *self, tsk_id_t node, int8_t *haplotype) } } + // Reset the bitset for next time for (tsk_size_t w = 0; w < self->num_bit_words; w++) { self->unresolved_bits[w] = 0; self->unresolved_counts[w] = 0; diff --git a/python/tskit/trees.py b/python/tskit/trees.py index b08548f7df..89d3436cb1 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -5274,6 +5274,8 @@ def _haplotypes_array( num_sites = stop_site - start_site missing_int8 = ord(missing_data_character) + # FIXME! 
The low-level code doesn't support isolated_as_missing + # yet so we do this ugly check here want_missing = ( True if isolated_as_missing is None else bool(isolated_as_missing) ) @@ -5323,7 +5325,9 @@ def _haplotypes_array( if num_samples == 0 or num_sites == 0: return H, (start_site, stop_site - 1) - # For now deal with missing data using the variants iterator + # FIXME! The low-level code doesn't support isolated_as_missing + # yet so we do this ugly thing of using the variants code to find + # sites with missing data missing_mask = None if want_missing: for var in self.variants( From dfd3055eeff99df172cf0efe77c82f84348d1d35 Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Thu, 16 Oct 2025 13:53:23 +0100 Subject: [PATCH 11/11] Fix jit --- python/tskit/jit/numba.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/tskit/jit/numba.py b/python/tskit/jit/numba.py index 698c2cdee4..c5c0059183 100644 --- a/python/tskit/jit/numba.py +++ b/python/tskit/jit/numba.py @@ -679,12 +679,13 @@ def _bitset_init(bitset, num_sites): # Initialise all bits to 1 (meaning "unresolved") and mask any unused bits # in the final word when the number of sites is not a multiple of 64. n_words = bitset.shape[0] + all_bits = np.uint64((1 << 64) - 1) for w in range(n_words): - bitset[w] = np.uint64(-1) + bitset[w] = all_bits if n_words > 0: excess = n_words * 64 - num_sites if excess > 0: - mask = np.uint64(-1) >> excess + mask = all_bits >> excess bitset[n_words - 1] = mask