Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/python-api.md
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ which perform the same actions but modify the {class}`TableCollection` in place.
.. autosummary::
TreeSequence.simplify
TreeSequence.subset
TreeSequence.merge
TreeSequence.union
TreeSequence.concatenate
TreeSequence.keep_intervals
Expand Down Expand Up @@ -753,6 +754,7 @@ a functional way, returning a new tree sequence while leaving the original uncha
TableCollection.delete_sites
TableCollection.trim
TableCollection.shift
TableCollection.merge
TableCollection.union
TableCollection.delete_older
```
Expand Down
7 changes: 7 additions & 0 deletions python/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,19 @@
- ``TreeSequence.map_to_vcf_model`` now also returns the transformed positions and
contig length. (:user:`benjeffery`, :pr:`XXXX`, :issue:`3173`)

- New ``TreeSequence.merge`` and ``TableCollection.merge`` methods, which merge another
tree sequence or table collection into the current one. (:user:`hyanwong`, :pr:`3183`, :issue:`3181`)

**Bugfixes**

- Fix bug in ``TreeSequence.pair_coalescence_counts`` when ``span_normalise=True``
and a window breakpoint falls within an internal missing interval.
(:user:`nspope`, :pr:`3176`, :issue:`3175`)

- Change ``TreeSequence.concatenate`` to use ``merge``, as ``union`` does not
copy edges, sites, or mutations from the added tree sequences if they are associated
with shared nodes. (:user:`hyanwong`, :pr:`3183`, :issue:`3168`, :issue:`3182`)

--------------------
[0.6.4] - 2025-05-21
--------------------
Expand Down
281 changes: 268 additions & 13 deletions python/tests/test_topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import io
import itertools
import json
import platform
import random
import sys
import unittest
Expand All @@ -43,6 +44,9 @@
import tskit.provenance as provenance


IS_WINDOWS = platform.system() == "Windows"


def simple_keep_intervals(tables, intervals, simplify=True, record_provenance=True):
"""
Simple Python implementation of keep_intervals.
Expand Down Expand Up @@ -7141,18 +7145,223 @@ def test_bad_seq_len(self):
ts.shift(1, sequence_length=1)


class TestMerge:
    """
    Tests for ``TreeSequence.merge``, which adds the contents of another tree
    sequence to this one. The ``node_mapping`` argument has one entry per node
    in the *other* tree sequence: ``tskit.NULL`` marks a node to be added as
    new, any other value is the ID of the equivalent node in this one.
    """

    def test_empty(self):
        """Merging an empty tree sequence with itself stays empty."""
        ts = tskit.TableCollection(2).tree_sequence()
        merged_ts = ts.merge(ts, node_mapping=[])
        assert merged_ts.num_nodes == 0
        assert merged_ts.num_edges == 0
        assert merged_ts.sequence_length == 2

    def test_overlay(self):
        """With no shared nodes, the two topologies sit side by side."""
        ts1 = tskit.Tree.generate_balanced(4, span=2).tree_sequence
        tables = tskit.Tree.generate_comb(4, span=2).tree_sequence.dump_tables()
        tables.populations.add_row()
        tables.nodes[5] = tables.nodes[5].replace(
            flags=tskit.NODE_IS_SAMPLE, population=0
        )
        ts2 = tables.tree_sequence()
        # The mapping describes the nodes of the tree sequence being merged in,
        # so it is sized by ts2 (as in test_new_individuals), not ts1.
        ts_merge = ts1.merge(ts2, node_mapping=np.full(ts2.num_nodes, tskit.NULL))
        assert ts_merge.sequence_length == ts1.sequence_length
        assert ts_merge.num_samples == ts1.num_samples + ts2.num_samples
        assert ts_merge.num_nodes == ts1.num_nodes + ts2.num_nodes
        assert ts_merge.num_edges == ts1.num_edges + ts2.num_edges
        assert ts_merge.num_trees == 1
        assert ts_merge.num_populations == 1
        assert ts_merge.first().num_roots == 2

    def test_split_and_merge(self):
        """Split one tree's edges and mutations in two, then merge them back."""
        # Cut up a single tree into alternating edges and mutations, then merge
        ts = tskit.Tree.generate_comb(4, span=10).tree_sequence
        ts = msprime.sim_mutations(ts, rate=0.1, random_seed=1)
        mut_counts = np.bincount(ts.mutations_site, minlength=ts.num_sites)
        # Need both singly and multiply mutated sites for a useful test
        assert min(mut_counts) == 1
        assert max(mut_counts) > 1
        tables1 = ts.dump_tables()
        tables1.mutations.clear()
        tables2 = tables1.copy()
        i = 0
        for s in ts.sites():
            for m in s.mutations:
                i += 1
                if i % 2:
                    tables1.mutations.append(m.replace(parent=tskit.NULL))
                else:
                    tables2.mutations.append(m.replace(parent=tskit.NULL))
        tables1.simplify()
        tables2.simplify()
        assert tables1.sites.num_rows != ts.num_sites
        tables1.edges.clear()
        tables2.edges.clear()
        for e in ts.edges():
            if e.id % 2:
                tables1.edges.append(e)
            else:
                tables2.edges.append(e)
        ts1 = tables1.tree_sequence()
        ts2 = tables2.tree_sequence()
        new_ts = ts1.merge(ts2, node_mapping=np.arange(ts.num_nodes)).simplify()
        assert new_ts.equals(ts, ignore_provenance=True)

    def test_multi_tree(self):
        """Split a multi-tree sequence at a non-breakpoint and merge it back."""
        ts = msprime.sim_ancestry(
            2, sequence_length=4, recombination_rate=1, random_seed=1
        )
        ts = msprime.sim_mutations(ts, rate=1, random_seed=1)
        assert ts.num_trees > 3
        assert ts.num_mutations > 4
        ts1 = ts.keep_intervals([[0, 1.5]], simplify=False)
        ts2 = ts.keep_intervals([[1.5, 4]], simplify=False)
        new_ts = ts1.merge(
            ts2, node_mapping=np.arange(ts.num_nodes), add_populations=False
        )
        # The cut at 1.5 introduces one extra tree until edges are squashed
        assert new_ts.num_trees == ts.num_trees + 1
        new_ts = new_ts.simplify()
        # The comparison result was previously discarded; assert it.
        assert new_ts.equals(ts, ignore_provenance=True)

    def test_new_individuals(self):
        """Unmapped nodes from the other tree sequence bring a new individual."""
        ts1 = msprime.sim_ancestry(2, sequence_length=1, random_seed=1)
        ts2 = msprime.sim_ancestry(2, sequence_length=1, random_seed=2)
        tables = ts2.dump_tables()
        tables.edges.clear()
        ts2 = tables.tree_sequence()
        node_map = np.full(ts2.num_nodes, tskit.NULL)
        node_map[0:2] = [0, 1]  # map first two nodes to themselves
        ts_merged = ts1.merge(ts2, node_mapping=node_map)
        assert ts_merged.num_nodes == ts1.num_nodes + ts2.num_nodes - 2
        assert ts1.num_individuals == 2
        assert ts_merged.num_individuals == 3

    def test_popcheck(self):
        """Differing population rows are rejected unless checking is disabled."""
        tables = tskit.TableCollection(1)
        p1 = tables.populations.add_row(b"foo")
        p2 = tables.populations.add_row(b"bar")
        tables.nodes.add_row(time=0, flags=tskit.NODE_IS_SAMPLE, population=p1)
        tables.nodes.add_row(time=0, flags=tskit.NODE_IS_SAMPLE, population=p2)
        ts1 = tables.tree_sequence()
        # ts2 differs from ts1 only in the metadata of population 0
        tables.populations[0] = tables.populations[0].replace(metadata=b"baz")
        ts2 = tables.tree_sequence()
        with pytest.raises(ValueError, match="Non-matching populations"):
            ts1.merge(ts2, node_mapping=[0, 1])
        ts1.merge(ts2, node_mapping=[0, 1], check_populations=False)
        # Check with add_populations=False
        ts1.merge(ts2, node_mapping=[-1, 1])  # only merge the last one
        with pytest.raises(ValueError, match="Non-matching populations"):
            ts1.merge(ts2, node_mapping=[-1, 1], add_populations=False)

        with pytest.raises(ValueError, match="Non-matching populations"):
            ts1.simplify([0]).merge(ts2, node_mapping=[-1, 1])

    def test_isolated_mutations(self):
        """Mutations above a shared isolated node are combined at one site."""
        tables = tskit.TableCollection(1)
        u = tables.nodes.add_row(time=0, flags=tskit.NODE_IS_SAMPLE)
        s = tables.sites.add_row(0.5, "A")
        tables.mutations.add_row(s, u, derived_state="T", time=1, metadata=b"xxx")
        ts1 = tables.tree_sequence()
        tables.mutations[0] = tables.mutations[0].replace(time=0.5, metadata=b"yyy")
        ts2 = tables.tree_sequence()
        ts_merge = ts1.merge(ts2, node_mapping=[0])
        assert ts_merge.num_sites == 1
        assert ts_merge.num_mutations == 2
        assert ts_merge.mutation(0).time == 1
        assert ts_merge.mutation(0).parent == tskit.NULL
        assert ts_merge.mutation(0).metadata == b"xxx"
        assert ts_merge.mutation(1).time == 0.5
        assert ts_merge.mutation(1).parent == 0
        assert ts_merge.mutation(1).metadata == b"yyy"

    def test_identity(self):
        """Merging a single-node tree sequence with itself changes nothing."""
        tables = tskit.TableCollection(1)
        tables.nodes.add_row(time=0, flags=tskit.NODE_IS_SAMPLE)
        ts = tables.tree_sequence()
        ts_merge = ts.merge(ts, node_mapping=[0])
        assert ts.equals(ts_merge, ignore_provenance=True)

    @pytest.mark.skipif(IS_WINDOWS, reason="Msprime gives different result on Windows")
    def test_migrations(self):
        """Split a simulation with recorded migrations and merge it back."""
        pop_configs = [msprime.PopulationConfiguration(3) for _ in range(2)]
        migration_matrix = [[0, 0.001], [0.001, 0]]
        ts = msprime.simulate(
            population_configurations=pop_configs,
            migration_matrix=migration_matrix,
            record_migrations=True,
            recombination_rate=2,
            random_seed=42,  # pick a seed that gives min(migrations.left) > 0
            end_time=100,
        )
        # No migration_table.squash() function exists, so we just try to cut on the
        # LHS of all the migrations
        assert ts.num_migrations > 0
        assert ts.migrations_left.min() > 0
        cutpoint = ts.migrations_left.min()
        ts1 = ts.keep_intervals([[0, cutpoint]], simplify=False)
        ts2 = ts.keep_intervals([[cutpoint, ts.sequence_length]], simplify=False)
        ts_new = ts1.merge(ts2, node_mapping=np.arange(ts.num_nodes))
        tables = ts_new.dump_tables()
        tables.edges.squash()
        tables.sort()
        ts_new = tables.tree_sequence()
        ts.tables.assert_equals(ts_new.tables, ignore_provenance=True)

    def test_provenance(self):
        """A provenance record with the call parameters is added by default."""
        tables = tskit.TableCollection(1)
        tables.nodes.add_row(time=0, flags=tskit.NODE_IS_SAMPLE)
        ts = tables.tree_sequence()
        ts_merge = ts.merge(ts, node_mapping=[0], record_provenance=False)
        assert ts_merge.num_provenances == ts.num_provenances
        ts_merge = ts.merge(ts, node_mapping=[0])
        assert ts_merge.num_provenances == ts.num_provenances + 1
        prov = json.loads(ts_merge.provenance(-1).record)
        assert prov["parameters"]["command"] == "merge"
        assert prov["parameters"]["node_mapping"] == [0]
        assert prov["parameters"]["add_populations"] is True
        assert prov["parameters"]["check_populations"] is True

    def test_bad_sequence_length(self):
        """Tree sequences with different sequence lengths cannot be merged."""
        ts1 = tskit.TableCollection(1).tree_sequence()
        ts2 = tskit.TableCollection(2).tree_sequence()
        with pytest.raises(ValueError, match="sequence length"):
            ts1.merge(ts2, node_mapping=[])

    def test_bad_node_mapping(self):
        """A node_mapping shorter than the other's node table is rejected."""
        ts = tskit.Tree.generate_comb(5).tree_sequence
        with pytest.raises(ValueError, match="node_mapping"):
            ts.merge(ts, node_mapping=[0, 1, 2])

    def test_bad_populations(self):
        """New nodes may not refer to a population missing from the target."""
        tables = tskit.TableCollection(1)
        p1 = tables.populations.add_row()
        p2 = tables.populations.add_row()
        tables.nodes.add_row(time=0, flags=tskit.NODE_IS_SAMPLE, population=p1)
        tables.nodes.add_row(time=0, flags=tskit.NODE_IS_SAMPLE, population=p1)
        tables.nodes.add_row(time=0, flags=tskit.NODE_IS_SAMPLE, population=p2)
        ts2 = tables.tree_sequence()
        ts1 = ts2.simplify([0, 1])
        assert ts1.num_populations == 1
        assert ts2.num_populations == 2
        ts2.merge(ts1, [0, -1], check_populations=False, add_populations=False)
        with pytest.raises(ValueError, match="population not present"):
            ts1.merge(ts2, [0, -1, -1], check_populations=False, add_populations=False)


class TestConcatenate:
def test_simple(self):
    """
    Concatenate two mutated single-tree sequences that have the same number
    of samples, then check that each half can be cut back out unchanged.
    """
    left_ts = tskit.Tree.generate_comb(5, span=2).tree_sequence
    left_ts = msprime.sim_mutations(left_ts, rate=1, random_seed=1)
    right_ts = tskit.Tree.generate_balanced(5, arity=3, span=3).tree_sequence
    right_ts = msprime.sim_mutations(right_ts, rate=1, random_seed=1)
    assert left_ts.num_samples == right_ts.num_samples
    assert left_ts.num_nodes != right_ts.num_nodes

    joint_ts = left_ts.concatenate(right_ts)

    # Presumably the 5 samples are shared between the two inputs, so they
    # are counted only once in the result.
    expected_nodes = left_ts.num_nodes + right_ts.num_nodes - 5
    expected_length = left_ts.sequence_length + right_ts.sequence_length
    assert joint_ts.num_nodes == expected_nodes
    assert joint_ts.sequence_length == expected_length
    assert joint_ts.num_samples == left_ts.num_samples
    assert joint_ts.num_sites == left_ts.num_sites + right_ts.num_sites
    assert joint_ts.num_mutations == left_ts.num_mutations + right_ts.num_mutations

    # Recover the left-hand portion
    recovered_left = joint_ts.delete_intervals([[2, 5]]).rtrim()
    recovered_left.tables.assert_equals(left_ts.tables, ignore_provenance=True)
    # Have to simplify here, to remove the redundant nodes
    assert recovered_left.equals(left_ts.simplify(), ignore_provenance=True)

    # Recover the right-hand portion
    recovered_right = joint_ts.delete_intervals([[0, 2]]).ltrim()
    assert recovered_right.equals(right_ts.simplify(), ignore_provenance=True)
Expand Down Expand Up @@ -7183,6 +7392,13 @@ def test_empty(self):
assert ts.num_nodes == 0
assert ts.sequence_length == 40

def test_check_populations(self):
    """
    Concatenating a tree sequence with two copies of itself keeps a single
    population by default, and gives three when ``add_populations=True``.
    """
    base_ts = msprime.sim_ancestry(2)

    shared_pops = base_ts.concatenate(base_ts, base_ts, check_populations=True)
    assert shared_pops.num_populations == 1

    added_pops = base_ts.concatenate(
        base_ts, base_ts, add_populations=True, check_populations=True
    )
    assert added_pops.num_populations == 3

def test_samples_at_end(self):
ts1 = tskit.Tree.generate_comb(5, span=2).tree_sequence
ts2 = tskit.Tree.generate_balanced(5, arity=3, span=3).tree_sequence
Expand All @@ -7200,22 +7416,58 @@ def test_internal_samples(self):
nodes_flags[:] = tskit.NODE_IS_SAMPLE
nodes_flags[-1] = 0 # Only root is not a sample
tables.nodes.flags = nodes_flags
ts = tables.tree_sequence()
ts = msprime.sim_mutations(tables.tree_sequence(), rate=0.5, random_seed=1)
assert ts.num_mutations > 0
assert ts.num_mutations > ts.num_sites
joint_ts = ts.concatenate(ts)
assert joint_ts.num_samples == ts.num_samples
assert joint_ts.num_nodes == ts.num_nodes + 1
assert joint_ts.num_mutations == ts.num_mutations * 2
assert joint_ts.num_sites == ts.num_sites * 2
assert joint_ts.sequence_length == ts.sequence_length * 2

def test_some_shared_samples(self):
ts1 = tskit.Tree.generate_comb(4, span=2).tree_sequence
ts2 = tskit.Tree.generate_balanced(8, arity=3, span=3).tree_sequence
shared = np.full(ts2.num_nodes, tskit.NULL)
shared[0] = 1
shared[1] = 0
joint_ts = ts1.concatenate(ts2, node_mappings=[shared])
assert joint_ts.sequence_length == ts1.sequence_length + ts2.sequence_length
assert joint_ts.num_samples == ts1.num_samples + ts2.num_samples - 2
assert joint_ts.num_nodes == ts1.num_nodes + ts2.num_nodes - 2
tables = tskit.Tree.generate_comb(5).tree_sequence.dump_tables()
tables.nodes[5] = tables.nodes[5].replace(flags=tskit.NODE_IS_SAMPLE)
ts1 = tables.tree_sequence()
tables = tskit.Tree.generate_balanced(5).tree_sequence.dump_tables()
tables.nodes[5] = tables.nodes[5].replace(flags=tskit.NODE_IS_SAMPLE)
ts2 = tables.tree_sequence()
assert ts1.num_samples == ts2.num_samples
joint_ts = ts1.concatenate(ts2)
assert joint_ts.num_samples == ts1.num_samples
assert joint_ts.num_edges == ts1.num_edges + ts2.num_edges
for tree in joint_ts.trees():
assert tree.num_roots == 1

@pytest.mark.parametrize("simplify", [True, False])
def test_wf_sim(self, simplify):
    """
    Split a mutated Wright-Fisher simulation in two at position 4.5 and check
    that concatenating the halves (then simplifying) restores the original.
    When ``simplify`` is true the halves are simplified first, so an explicit
    node mapping has to be rebuilt from the simplify node map.
    """
    # Test that we can split & concat a wf_sim ts, which has internal samples
    sim_tables = wf.wf_sim(
        6,
        5,
        seed=3,
        deep_history=True,
        initial_generation_samples=True,
        num_loci=10,
    )
    sim_tables.sort()
    sim_tables.simplify()
    ts = msprime.mutate(sim_tables.tree_sequence(), rate=0.05, random_seed=234)
    assert ts.num_trees > 2
    sample_times = ts.nodes_time[ts.samples()]
    assert len(np.unique(sample_times)) > 1

    left_ts = ts.keep_intervals([[0, 4.5]], simplify=False).trim()
    right_ts = ts.keep_intervals(
        [[4.5, ts.sequence_length]], simplify=False
    ).trim()

    if not simplify:
        # Node IDs are untouched, so the mapping is the identity
        node_mapping = np.arange(ts.num_nodes)
    else:
        left_ts = left_ts.simplify(filter_nodes=False)
        right_ts, node_map = right_ts.simplify(map_nodes=True)
        # Invert node_map: for each node kept in the simplified right half,
        # record the ID it had before simplification.
        node_mapping = np.zeros_like(node_map, shape=right_ts.num_nodes)
        kept = node_map != tskit.NULL
        node_mapping[node_map[kept]] = np.arange(len(node_map))[kept]

    rejoined = left_ts.concatenate(right_ts, node_mappings=[node_mapping])
    ts_new = rejoined.simplify()
    ts_new.tables.assert_equals(ts.tables, ignore_provenance=True)

def test_provenance(self):
ts = tskit.Tree.generate_comb(2).tree_sequence
Expand All @@ -7233,9 +7485,12 @@ def test_unequal_samples(self):
with pytest.raises(ValueError, match="must have the same number of samples"):
ts1.concatenate(ts2)

@pytest.mark.skip(
reason="union bug: https://github.com/tskit-dev/tskit/issues/3168"
)
def test_different_sample_numbers(self):
    """Concatenating trees with 5 and 4 leaves raises a ValueError."""
    five_leaf_ts = tskit.Tree.generate_comb(5, span=2).tree_sequence
    four_leaf_ts = tskit.Tree.generate_balanced(4, arity=3, span=3).tree_sequence
    expected_message = "must have the same number of samples"
    with pytest.raises(ValueError, match=expected_message):
        five_leaf_ts.concatenate(four_leaf_ts)

def test_duplicate_ts(self):
ts1 = tskit.Tree.generate_comb(3, span=4).tree_sequence
ts = ts1.keep_intervals([[0, 1]]).trim() # a quarter of the original
Expand Down
Loading
Loading