diff --git a/c/tests/test_tables.c b/c/tests/test_tables.c index 9def0a29d6..6593b39ef5 100644 --- a/c/tests/test_tables.c +++ b/c/tests/test_tables.c @@ -11240,6 +11240,107 @@ test_table_collection_union(void) tsk_table_collection_free(&tables); } +static void +test_table_collection_disjoint_union(void) +{ + int ret; + tsk_id_t ret_id; + tsk_table_collection_t tables; + tsk_table_collection_t tables1; + tsk_table_collection_t tables2; + tsk_table_collection_t tables12; + tsk_id_t node_mapping[4]; + + tsk_memset(node_mapping, 0xff, sizeof(node_mapping)); + + ret = tsk_table_collection_init(&tables1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tables1.sequence_length = 2; + + // set up nodes, which will be shared + // flags, time, pop, ind, metadata, metadata_length + ret_id = tsk_node_table_add_row( + &tables1.nodes, TSK_NODE_IS_SAMPLE, 0.0, TSK_NULL, TSK_NULL, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_node_table_add_row( + &tables1.nodes, TSK_NODE_IS_SAMPLE, 0.0, TSK_NULL, TSK_NULL, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_node_table_add_row(&tables1.nodes, 0, 0.5, TSK_NULL, TSK_NULL, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_node_table_add_row(&tables1.nodes, 0, 1.5, TSK_NULL, TSK_NULL, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret = tsk_table_collection_copy(&tables1, &tables2, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + // for tables1: + // on [0, 1] we have 0, 1 inherit from 2 + // left, right, parent, child, metadata, metadata_length + ret_id = tsk_edge_table_add_row(&tables1.edges, 0.0, 1.0, 2, 0, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_edge_table_add_row(&tables1.edges, 0.0, 1.0, 2, 1, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_site_table_add_row(&tables1.sites, 0.4, "T", 1, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_mutation_table_add_row( + &tables1.mutations, ret_id, 0, TSK_NULL, TSK_UNKNOWN_TIME, NULL, 0, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret = tsk_table_collection_build_index(&tables1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_table_collection_sort(&tables1, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + // all this goes in tables12 so far + ret = tsk_table_collection_copy(&tables1, &tables12, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + // for tables2; and need to add it to tables12 also: + // on [1, 2] we have 0, 1 inherit from 3 + // left, right, parent, child, metadata, metadata_length + ret_id = tsk_edge_table_add_row(&tables2.edges, 1.0, 2.0, 3, 0, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_edge_table_add_row(&tables2.edges, 1.0, 2.0, 3, 1, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_site_table_add_row(&tables2.sites, 1.4, "A", 1, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_mutation_table_add_row( + &tables2.mutations, ret_id, 0, TSK_NULL, TSK_UNKNOWN_TIME, "T", 1, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret = tsk_table_collection_build_index(&tables2, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_table_collection_sort(&tables2, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + // also tables12 + ret_id = tsk_edge_table_add_row(&tables12.edges, 1.0, 2.0, 3, 0, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_edge_table_add_row(&tables12.edges, 1.0, 2.0, 3, 1, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_site_table_add_row(&tables12.sites, 1.4, "A", 1, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret_id = tsk_mutation_table_add_row( + &tables12.mutations, ret_id, 1, TSK_NULL, TSK_UNKNOWN_TIME, "T", 1, NULL, 0); + CU_ASSERT_FATAL(ret_id >= 0); + ret = tsk_table_collection_build_index(&tables12, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_table_collection_sort(&tables12, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + // now disjoint union-ing tables1 and tables2 should get tables12 + ret = tsk_table_collection_copy(&tables1, &tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + node_mapping[0] = 0; + node_mapping[1] = 1; + node_mapping[2] = 2; + node_mapping[3] = 3; + ret = tsk_table_collection_union(&tables, &tables2, node_mapping, + TSK_UNION_NO_CHECK_SHARED | TSK_UNION_ALL_EDGES | TSK_UNION_ALL_MUTATIONS); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + tsk_table_collection_free(&tables12); + tsk_table_collection_free(&tables2); + tsk_table_collection_free(&tables1); + tsk_table_collection_free(&tables); +} + static void test_table_collection_union_middle_merge(void) { @@ -11836,6 +11937,7 @@ main(int argc, char **argv) test_table_collection_subset_unsorted }, { "test_table_collection_subset_errors", test_table_collection_subset_errors }, { "test_table_collection_union", test_table_collection_union }, + { "test_table_collection_disjoint_union", test_table_collection_disjoint_union }, { "test_table_collection_union_middle_merge", test_table_collection_union_middle_merge }, { "test_table_collection_union_errors", test_table_collection_union_errors }, diff --git a/c/tskit/tables.c b/c/tskit/tables.c index 9805d669a5..7106300f3d 100644 --- a/c/tskit/tables.c +++ b/c/tskit/tables.c @@ -13202,6 +13202,8 @@ tsk_table_collection_union(tsk_table_collection_t *self, tsk_id_t *site_map = NULL; bool add_populations = !(options & TSK_UNION_NO_ADD_POP); bool check_shared_portion = !(options & TSK_UNION_NO_CHECK_SHARED); + bool all_edges = !!(options & TSK_UNION_ALL_EDGES); + bool all_mutations = !!(options & TSK_UNION_ALL_MUTATIONS); /* Not calling TSK_CHECK_TREES so casting to int is safe */ ret = (int) tsk_table_collection_check_integrity(self, 0); @@ -13285,7 +13287,7 @@ tsk_table_collection_union(tsk_table_collection_t *self, // edges for (k = 0; k < (tsk_id_t) other->edges.num_rows; k++) { tsk_edge_table_get_row_unsafe(&other->edges, k, &edge); - if ((other_node_mapping[edge.parent] == TSK_NULL) + if (all_edges || (other_node_mapping[edge.parent] == TSK_NULL) || (other_node_mapping[edge.child] == TSK_NULL)) { new_parent = node_map[edge.parent]; new_child = node_map[edge.child]; @@ -13298,14 +13300,31 @@ tsk_table_collection_union(tsk_table_collection_t *self, } } - // mutations and sites + // sites + // first do the "disjoint" (all_mutations) case, where we just add all sites; + // otherwise we want to just add sites for new mutations + if (all_mutations) { + for (k = 0; k < (tsk_id_t) other->sites.num_rows; k++) { + tsk_site_table_get_row_unsafe(&other->sites, k, &site); + ret_id = tsk_site_table_add_row(&self->sites, site.position, + site.ancestral_state, site.ancestral_state_length, site.metadata, + site.metadata_length); + if (ret_id < 0) { + ret = (int) ret_id; + goto out; + } + site_map[site.id] = ret_id; + } + } + + // mutations (and maybe sites) i = 0; for (k = 0; k < (tsk_id_t) other->sites.num_rows; k++) { tsk_site_table_get_row_unsafe(&other->sites, k, &site); while ((i < (tsk_id_t) other->mutations.num_rows) && (other->mutations.site[i] == site.id)) { tsk_mutation_table_get_row_unsafe(&other->mutations, i, &mut); - if (other_node_mapping[mut.node] == TSK_NULL) { + if (all_mutations || (other_node_mapping[mut.node] == TSK_NULL)) { if (site_map[site.id] == TSK_NULL) { ret_id = tsk_site_table_add_row(&self->sites, site.position, site.ancestral_state, site.ancestral_state_length, site.metadata, diff --git a/c/tskit/tables.h b/c/tskit/tables.h index 85ed29d58c..b9a4c28415 100644 --- a/c/tskit/tables.h +++ b/c/tskit/tables.h @@ -858,11 +858,21 @@ equality of the subsets. */ #define TSK_UNION_NO_CHECK_SHARED (1 << 0) /** - By default, all nodes new to ``self`` are assigned new populations. If this +By default, all nodes new to ``self`` are assigned new populations. If this option is specified, nodes that are added to ``self`` will retain the population IDs they have in ``other``. */ #define TSK_UNION_NO_ADD_POP (1 << 1) +/** +By default, union only adds only edges adjacent to a newly added node; +this option adds all edges. + */ +#define TSK_UNION_ALL_EDGES (1 << 2) +/** +By default, union only adds only mutations on newly added edges; +this option adds all mutations. + */ +#define TSK_UNION_ALL_MUTATIONS (1 << 3) /** @} */ /** @@ -4414,6 +4424,10 @@ that are exclusive ``other`` are added to ``self``, along with: By default, populations of newly added nodes are assumed to be new populations, and added to the population table as well. +The behavior can be changed by the flags ``TSK_UNION_ALL_EDGES`` and +``TSK_UNION_ALL_MUTATIONS``, which will (respectively) add *all* edges +or *all* sites and mutations instead. + This operation will also sort the resulting tables, so the tables may change even if nothing new is added, if the original tables were not sorted. diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index f595cf5150..c538d1de51 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -7069,15 +7069,18 @@ TableCollection_union(TableCollection *self, PyObject *args, PyObject *kwds) npy_intp *shape; tsk_flags_t options = 0; int check_shared = true; + int all_edges = false; + int all_mutations = false; int add_populations = true; - static char *kwlist[] = { "other", "other_node_mapping", "check_shared_equality", - "add_populations", NULL }; + static char *kwlist[] = { "other", "other_node_mapping", "all_edges", + "all_mutations", "check_shared_equality", "add_populations", NULL }; if (TableCollection_check_state(self) != 0) { goto out; } - if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!O|ii", kwlist, &TableCollectionType, - &other, &other_node_mapping, &check_shared, &add_populations)) { + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!O|iiii", kwlist, + &TableCollectionType, &other, &other_node_mapping, &all_edges, + &all_mutations, &check_shared, &add_populations)) { goto out; } nmap_array = (PyArrayObject *) PyArray_FROMANY( @@ -7092,6 +7095,12 @@ TableCollection_union(TableCollection *self, PyObject *args, PyObject *kwds) " number of nodes in the other tree sequence."); goto out; } + if (all_edges) { + options |= TSK_UNION_ALL_EDGES; + } + if (all_mutations) { + options |= TSK_UNION_ALL_MUTATIONS; + } if (!check_shared) { options |= TSK_UNION_NO_CHECK_SHARED; } diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index ccf1022ed0..aa1b442384 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -433,6 +433,25 @@ def test_union_bad_args(self): with pytest.raises(ValueError): tc.union(tc2, np.array([[1], [2]], dtype="int32")) + @pytest.mark.parametrize("value", [True, False]) + @pytest.mark.parametrize( + "flag", + [ + "all_edges", + "all_mutations", + "check_shared_equality", + "add_populations", + ], + ) + def test_union_options(self, flag, value): + ts = msprime.simulate(10, random_seed=1) + tc = ts.tables._ll_tables + empty_tables = ts.tables.copy() + for table in empty_tables.table_name_map.keys(): + getattr(empty_tables, table).clear() + tc2 = empty_tables._ll_tables + tc.union(tc2, np.arange(0, dtype="int32"), **{flag: value}) + def test_equals_bad_args(self): ts = msprime.simulate(10, random_seed=1242) tc = ts.tables._ll_tables diff --git a/python/tests/test_tables.py b/python/tests/test_tables.py index 5c59653c7b..6b797e38fd 100644 --- a/python/tests/test_tables.py +++ b/python/tests/test_tables.py @@ -5172,6 +5172,51 @@ def test_examples(self): ts = tables.tree_sequence() self.verify_union(*self.split_example(ts, T)) + def test_split_and_rejoin(self): + ts = self.get_msprime_example(5, T=2, seed=928) + cutpoints = np.array([0, 0.25, 0.5, 0.75, 1]) * ts.sequence_length + tables1 = ts.dump_tables() + tables1.delete_intervals([cutpoints[0:2], cutpoints[2:4]], simplify=False) + tables2 = ts.dump_tables() + tables2.delete_intervals([cutpoints[1:3], cutpoints[3:]], simplify=False) + tables1.union( + tables2, + all_edges=True, + all_mutations=True, + node_mapping=np.arange(ts.num_nodes), + check_shared_equality=False, + ) + tables1.edges.squash() + tables1.sort() + tables1.assert_equals(ts.tables, ignore_provenance=True) + + def test_add_from_empty(self): + # reciprocally add mutations from one table and edges from the other + edges_table = tskit.Tree.generate_comb(6, span=6).tree_sequence.dump_tables() + muts_table = tskit.TableCollection(sequence_length=6) + muts_table.nodes.replace_with(edges_table.nodes) # same nodes, no edges + for j in range(0, 6): + site_id = muts_table.sites.add_row(position=j, ancestral_state="0") + if j % 2 == 0: + # Some sites empty + muts_table.mutations.add_row(site=site_id, node=j, derived_state="1") + identity_map = np.arange(len(muts_table.nodes), dtype="int32") + params = {"node_mapping": identity_map, "check_shared_equality": False} + + edges_table.union(muts_table, **params, all_edges=True) # null op + assert len(edges_table.sites) == 0 + assert len(edges_table.mutations) == 0 + edges_table.union(muts_table, **params, all_mutations=True) + assert len(edges_table.sites) == 6 + assert len(edges_table.mutations) == 3 + + muts_table.union(edges_table, **params, all_mutations=True) # null op + assert len(muts_table.edges) == 0 + muts_table.union(edges_table, **params, all_edges=True) # null op + assert len(muts_table.edges) != 0 + + edges_table.assert_equals(muts_table, ignore_provenance=True) + class TestTableSetitemMetadata: @pytest.mark.parametrize("table_name", tskit.TABLE_NAMES) diff --git a/python/tests/test_topology.py b/python/tests/test_topology.py index dd1eb9f23e..24b39f3a3c 100644 --- a/python/tests/test_topology.py +++ b/python/tests/test_topology.py @@ -7215,18 +7215,22 @@ def test_reference_sequence(self): class TestConcatenate: def test_simple(self): ts1 = tskit.Tree.generate_comb(5, span=2).tree_sequence + ts1 = msprime.sim_mutations(ts1, rate=1, random_seed=1) ts2 = tskit.Tree.generate_balanced(5, arity=3, span=3).tree_sequence + ts2 = msprime.sim_mutations(ts2, rate=1, random_seed=2) assert ts1.num_samples == ts2.num_samples assert ts1.num_nodes != ts2.num_nodes joint_ts = ts1.concatenate(ts2) assert joint_ts.num_nodes == ts1.num_nodes + ts2.num_nodes - 5 assert joint_ts.sequence_length == ts1.sequence_length + ts2.sequence_length assert joint_ts.num_samples == ts1.num_samples + assert joint_ts.num_sites == ts1.num_sites + ts2.num_sites + assert joint_ts.num_mutations == ts1.num_mutations + ts2.num_mutations ts3 = joint_ts.delete_intervals([[2, 5]]).rtrim() # Have to simplify here, to remove the redundant nodes - assert ts3.equals(ts1.simplify(), ignore_provenance=True) + ts3.tables.assert_equals(ts1.tables, ignore_provenance=True) ts4 = joint_ts.delete_intervals([[0, 2]]).ltrim() - assert ts4.equals(ts2.simplify(), ignore_provenance=True) + ts4.tables.assert_equals(ts2.tables, ignore_provenance=True) def test_multiple(self): np.random.seed(42) @@ -7278,15 +7282,47 @@ def test_internal_samples(self): assert joint_ts.sequence_length == ts.sequence_length * 2 def test_some_shared_samples(self): - ts1 = tskit.Tree.generate_comb(4, span=2).tree_sequence - ts2 = tskit.Tree.generate_balanced(8, arity=3, span=3).tree_sequence - shared = np.full(ts2.num_nodes, tskit.NULL) - shared[0] = 1 - shared[1] = 0 - joint_ts = ts1.concatenate(ts2, node_mappings=[shared]) - assert joint_ts.sequence_length == ts1.sequence_length + ts2.sequence_length - assert joint_ts.num_samples == ts1.num_samples + ts2.num_samples - 2 - assert joint_ts.num_nodes == ts1.num_nodes + ts2.num_nodes - 2 + tables = tskit.Tree.generate_comb(5).tree_sequence.dump_tables() + tables.nodes[5] = tables.nodes[5].replace(flags=tskit.NODE_IS_SAMPLE) + ts1 = tables.tree_sequence() + tables = tskit.Tree.generate_balanced(5).tree_sequence.dump_tables() + tables.nodes[5] = tables.nodes[5].replace(flags=tskit.NODE_IS_SAMPLE) + ts2 = tables.tree_sequence() + assert ts1.num_samples == ts2.num_samples + joint_ts = ts1.concatenate(ts2) + assert joint_ts.num_samples == ts1.num_samples + assert joint_ts.num_edges == ts1.num_edges + ts2.num_edges + for tree in joint_ts.trees(): + assert tree.num_roots == 1 + + @pytest.mark.parametrize("simplify", [True, False]) + def test_wf_sim(self, simplify): + # Test that we can split & concat a wf_sim ts, which has internal samples + tables = wf.wf_sim( + 6, + 5, + seed=3, + deep_history=True, + initial_generation_samples=True, + num_loci=10, + ) + tables.sort() + tables.simplify() + ts = msprime.mutate(tables.tree_sequence(), rate=0.05, random_seed=234) + assert ts.num_trees > 2 + assert len(np.unique(ts.nodes_time[ts.samples()])) > 1 + ts1 = ts.keep_intervals([[0, 4.5]], simplify=False).trim() + ts2 = ts.keep_intervals([[4.5, ts.sequence_length]], simplify=False).trim() + if simplify: + ts1 = ts1.simplify(filter_nodes=False) + ts2, node_map = ts2.simplify(map_nodes=True) + node_mapping = np.zeros_like(node_map, shape=ts2.num_nodes) + kept = node_map != tskit.NULL + node_mapping[node_map[kept]] = np.arange(len(node_map))[kept] + else: + node_mapping = np.arange(ts.num_nodes) + ts_new = ts1.concatenate(ts2, node_mappings=[node_mapping]).simplify() + ts_new.tables.assert_equals(ts.tables, ignore_provenance=True) def test_provenance(self): ts = tskit.Tree.generate_comb(2).tree_sequence @@ -7304,9 +7340,6 @@ def test_unequal_samples(self): with pytest.raises(ValueError, match="must have the same number of samples"): ts1.concatenate(ts2) - @pytest.mark.skip( - reason="union bug: https://github.com/tskit-dev/tskit/issues/3168" - ) def test_duplicate_ts(self): ts1 = tskit.Tree.generate_comb(3, span=4).tree_sequence ts = ts1.keep_intervals([[0, 1]]).trim() # a quarter of the original diff --git a/python/tskit/tables.py b/python/tskit/tables.py index 09e22e4443..68fb78c0c3 100644 --- a/python/tskit/tables.py +++ b/python/tskit/tables.py @@ -4181,6 +4181,8 @@ def union( self, other, node_mapping, + all_edges=False, + all_mutations=False, check_shared_equality=True, add_populations=True, record_provenance=True, @@ -4199,6 +4201,10 @@ def union( should be the index of the equivalent node in ``self``, or ``tskit.NULL`` if the node is not present in ``self`` (in which case it will be added to self). + :param bool all_edges: If True, then all edges in ``other`` are added + to ``self``. Must have ``check_shared_equality=False``. + :param bool all_mutations: If True, then all mutations in ``other`` are added + to ``self``. Must have ``check_shared_equality=False``. :param bool check_shared_equality: If True, the shared portions of the table collections will be checked for equality. :param bool add_populations: If True, nodes new to ``self`` will be @@ -4210,6 +4216,8 @@ def union( self._ll_tables.union( other._ll_tables, node_mapping, + all_edges=all_edges, + all_mutations=all_mutations, check_shared_equality=check_shared_equality, add_populations=add_populations, ) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 81c49c3224..f154361abe 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -7241,6 +7241,8 @@ def concatenate( other_tables, node_mapping=node_mapping, check_shared_equality=False, # Else checks fail with internal samples + all_mutations=True, + all_edges=True, record_provenance=False, add_populations=add_populations, ) @@ -7466,6 +7468,8 @@ def union( self, other, node_mapping, + all_edges=False, + all_mutations=False, check_shared_equality=True, add_populations=True, record_provenance=True, @@ -7513,6 +7517,10 @@ def union( :param TableCollection other: Another table collection. :param list node_mapping: An array of node IDs that relate nodes in ``other`` to nodes in ``self``. + :param bool all_edges: If True, then all edges in ``other`` are added + to ``self``. + :param bool all_mutations: If True, then all mutations and sites in + ``other`` are added to ``self``. :param bool check_shared_equality: If True, the shared portions of the tree sequences will be checked for equality. It does so by running :meth:`TreeSequence.subset` on both ``self`` and ``other`` @@ -7528,6 +7536,8 @@ def union( tables.union( other_tables, node_mapping, + all_edges=all_edges, + all_mutations=all_mutations, check_shared_equality=check_shared_equality, add_populations=add_populations, record_provenance=record_provenance,