Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 102 additions & 0 deletions c/tests/test_tables.c
Original file line number Diff line number Diff line change
Expand Up @@ -11240,6 +11240,107 @@ test_table_collection_union(void)
tsk_table_collection_free(&tables);
}

/* Check that a "disjoint" union reproduces a hand-built combined table
 * collection. tables1 and tables2 share one node table but carry edges,
 * sites and mutations on [0, 1) and [1, 2) respectively; tables12 is built
 * by hand to hold the rows of both. Unioning tables2 into a copy of tables1
 * with an identity node mapping and the TSK_UNION_ALL_EDGES and
 * TSK_UNION_ALL_MUTATIONS options must then give tables12. */
static void
test_table_collection_disjoint_union(void)
{
    int ret;
    tsk_id_t ret_id;
    tsk_table_collection_t tables;
    tsk_table_collection_t tables1;
    tsk_table_collection_t tables2;
    tsk_table_collection_t tables12;
    tsk_id_t node_mapping[4];

    tsk_memset(node_mapping, 0xff, sizeof(node_mapping));

    ret = tsk_table_collection_init(&tables1, 0);
    CU_ASSERT_EQUAL_FATAL(ret, 0);
    tables1.sequence_length = 2;

    // set up nodes, which will be shared
    // flags, time, pop, ind, metadata, metadata_length
    ret_id = tsk_node_table_add_row(
        &tables1.nodes, TSK_NODE_IS_SAMPLE, 0.0, TSK_NULL, TSK_NULL, NULL, 0);
    CU_ASSERT_FATAL(ret_id >= 0);
    ret_id = tsk_node_table_add_row(
        &tables1.nodes, TSK_NODE_IS_SAMPLE, 0.0, TSK_NULL, TSK_NULL, NULL, 0);
    CU_ASSERT_FATAL(ret_id >= 0);
    ret_id = tsk_node_table_add_row(&tables1.nodes, 0, 0.5, TSK_NULL, TSK_NULL, NULL, 0);
    CU_ASSERT_FATAL(ret_id >= 0);
    ret_id = tsk_node_table_add_row(&tables1.nodes, 0, 1.5, TSK_NULL, TSK_NULL, NULL, 0);
    CU_ASSERT_FATAL(ret_id >= 0);
    // tables2 starts out as the bare node table, before tables1 gets any edges
    ret = tsk_table_collection_copy(&tables1, &tables2, 0);
    CU_ASSERT_EQUAL_FATAL(ret, 0);

    // for tables1:
    // on [0, 1] we have 0, 1 inherit from 2
    // left, right, parent, child, metadata, metadata_length
    ret_id = tsk_edge_table_add_row(&tables1.edges, 0.0, 1.0, 2, 0, NULL, 0);
    CU_ASSERT_FATAL(ret_id >= 0);
    ret_id = tsk_edge_table_add_row(&tables1.edges, 0.0, 1.0, 2, 1, NULL, 0);
    CU_ASSERT_FATAL(ret_id >= 0);
    ret_id = tsk_site_table_add_row(&tables1.sites, 0.4, "T", 1, NULL, 0);
    CU_ASSERT_FATAL(ret_id >= 0);
    // the mutation goes on the site just added (ret_id), above node 0
    ret_id = tsk_mutation_table_add_row(
        &tables1.mutations, ret_id, 0, TSK_NULL, TSK_UNKNOWN_TIME, NULL, 0, NULL, 0);
    CU_ASSERT_FATAL(ret_id >= 0);
    ret = tsk_table_collection_build_index(&tables1, 0);
    CU_ASSERT_EQUAL_FATAL(ret, 0);
    ret = tsk_table_collection_sort(&tables1, NULL, 0);
    CU_ASSERT_EQUAL_FATAL(ret, 0);

    // all this goes in tables12 so far
    ret = tsk_table_collection_copy(&tables1, &tables12, 0);
    CU_ASSERT_EQUAL_FATAL(ret, 0);

    // for tables2; and need to add it to tables12 also:
    // on [1, 2] we have 0, 1 inherit from 3
    // left, right, parent, child, metadata, metadata_length
    ret_id = tsk_edge_table_add_row(&tables2.edges, 1.0, 2.0, 3, 0, NULL, 0);
    CU_ASSERT_FATAL(ret_id >= 0);
    ret_id = tsk_edge_table_add_row(&tables2.edges, 1.0, 2.0, 3, 1, NULL, 0);
    CU_ASSERT_FATAL(ret_id >= 0);
    ret_id = tsk_site_table_add_row(&tables2.sites, 1.4, "A", 1, NULL, 0);
    CU_ASSERT_FATAL(ret_id >= 0);
    ret_id = tsk_mutation_table_add_row(
        &tables2.mutations, ret_id, 0, TSK_NULL, TSK_UNKNOWN_TIME, "T", 1, NULL, 0);
    CU_ASSERT_FATAL(ret_id >= 0);
    ret = tsk_table_collection_build_index(&tables2, 0);
    CU_ASSERT_EQUAL_FATAL(ret, 0);
    ret = tsk_table_collection_sort(&tables2, NULL, 0);
    CU_ASSERT_EQUAL_FATAL(ret, 0);
    // also tables12
    ret_id = tsk_edge_table_add_row(&tables12.edges, 1.0, 2.0, 3, 0, NULL, 0);
    CU_ASSERT_FATAL(ret_id >= 0);
    ret_id = tsk_edge_table_add_row(&tables12.edges, 1.0, 2.0, 3, 1, NULL, 0);
    CU_ASSERT_FATAL(ret_id >= 0);
    ret_id = tsk_site_table_add_row(&tables12.sites, 1.4, "A", 1, NULL, 0);
    CU_ASSERT_FATAL(ret_id >= 0);
    // Same mutation as in tables2: it must sit above node 0, because the
    // identity node mapping used below keeps node IDs unchanged.
    ret_id = tsk_mutation_table_add_row(
        &tables12.mutations, ret_id, 0, TSK_NULL, TSK_UNKNOWN_TIME, "T", 1, NULL, 0);
    CU_ASSERT_FATAL(ret_id >= 0);
    ret = tsk_table_collection_build_index(&tables12, 0);
    CU_ASSERT_EQUAL_FATAL(ret, 0);
    ret = tsk_table_collection_sort(&tables12, NULL, 0);
    CU_ASSERT_EQUAL_FATAL(ret, 0);

    // now disjoint union-ing tables1 and tables2 should get tables12
    ret = tsk_table_collection_copy(&tables1, &tables, 0);
    CU_ASSERT_EQUAL_FATAL(ret, 0);
    // every node of tables2 is shared, under the same ID
    node_mapping[0] = 0;
    node_mapping[1] = 1;
    node_mapping[2] = 2;
    node_mapping[3] = 3;
    ret = tsk_table_collection_union(&tables, &tables2, node_mapping,
        TSK_UNION_NO_CHECK_SHARED | TSK_UNION_ALL_EDGES | TSK_UNION_ALL_MUTATIONS);
    CU_ASSERT_EQUAL_FATAL(ret, 0);
    // Non-fatal so that the collections below are still freed on failure.
    CU_ASSERT_TRUE(tsk_table_collection_equals(&tables, &tables12, 0));

    tsk_table_collection_free(&tables12);
    tsk_table_collection_free(&tables2);
    tsk_table_collection_free(&tables1);
    tsk_table_collection_free(&tables);
}

static void
test_table_collection_union_middle_merge(void)
{
Expand Down Expand Up @@ -11836,6 +11937,7 @@ main(int argc, char **argv)
test_table_collection_subset_unsorted },
{ "test_table_collection_subset_errors", test_table_collection_subset_errors },
{ "test_table_collection_union", test_table_collection_union },
{ "test_table_collection_disjoint_union", test_table_collection_disjoint_union },
{ "test_table_collection_union_middle_merge",
test_table_collection_union_middle_merge },
{ "test_table_collection_union_errors", test_table_collection_union_errors },
Expand Down
25 changes: 22 additions & 3 deletions c/tskit/tables.c
Original file line number Diff line number Diff line change
Expand Up @@ -13202,6 +13202,8 @@ tsk_table_collection_union(tsk_table_collection_t *self,
tsk_id_t *site_map = NULL;
bool add_populations = !(options & TSK_UNION_NO_ADD_POP);
bool check_shared_portion = !(options & TSK_UNION_NO_CHECK_SHARED);
bool all_edges = !!(options & TSK_UNION_ALL_EDGES);
bool all_mutations = !!(options & TSK_UNION_ALL_MUTATIONS);

/* Not calling TSK_CHECK_TREES so casting to int is safe */
ret = (int) tsk_table_collection_check_integrity(self, 0);
Expand Down Expand Up @@ -13285,7 +13287,7 @@ tsk_table_collection_union(tsk_table_collection_t *self,
// edges
for (k = 0; k < (tsk_id_t) other->edges.num_rows; k++) {
tsk_edge_table_get_row_unsafe(&other->edges, k, &edge);
if ((other_node_mapping[edge.parent] == TSK_NULL)
if (all_edges || (other_node_mapping[edge.parent] == TSK_NULL)
|| (other_node_mapping[edge.child] == TSK_NULL)) {
new_parent = node_map[edge.parent];
new_child = node_map[edge.child];
Expand All @@ -13298,14 +13300,31 @@ tsk_table_collection_union(tsk_table_collection_t *self,
}
}

// mutations and sites
// sites
// first do the "disjoint" (all_mutations) case, where we just add all sites;
// otherwise we want to just add sites for new mutations
if (all_mutations) {
for (k = 0; k < (tsk_id_t) other->sites.num_rows; k++) {
tsk_site_table_get_row_unsafe(&other->sites, k, &site);
ret_id = tsk_site_table_add_row(&self->sites, site.position,
site.ancestral_state, site.ancestral_state_length, site.metadata,
site.metadata_length);
if (ret_id < 0) {
ret = (int) ret_id;
goto out;
}
site_map[site.id] = ret_id;
}
}

// mutations (and maybe sites)
i = 0;
for (k = 0; k < (tsk_id_t) other->sites.num_rows; k++) {
tsk_site_table_get_row_unsafe(&other->sites, k, &site);
while ((i < (tsk_id_t) other->mutations.num_rows)
&& (other->mutations.site[i] == site.id)) {
tsk_mutation_table_get_row_unsafe(&other->mutations, i, &mut);
if (other_node_mapping[mut.node] == TSK_NULL) {
if (all_mutations || (other_node_mapping[mut.node] == TSK_NULL)) {
if (site_map[site.id] == TSK_NULL) {
ret_id = tsk_site_table_add_row(&self->sites, site.position,
site.ancestral_state, site.ancestral_state_length, site.metadata,
Expand Down
16 changes: 15 additions & 1 deletion c/tskit/tables.h
Original file line number Diff line number Diff line change
Expand Up @@ -858,11 +858,21 @@ equality of the subsets.
*/
#define TSK_UNION_NO_CHECK_SHARED (1 << 0)
/**
By default, all nodes new to ``self`` are assigned new populations. If this
By default, all nodes new to ``self`` are assigned new populations. If this
option is specified, nodes that are added to ``self`` will retain the
population IDs they have in ``other``.
*/
#define TSK_UNION_NO_ADD_POP (1 << 1)
/**
By default, union adds only edges adjacent to a newly added node;
this option adds all edges.
*/
#define TSK_UNION_ALL_EDGES (1 << 2)
/**
By default, union adds only mutations on newly added nodes (together with
the sites they occur at); this option adds all sites and mutations.
*/
#define TSK_UNION_ALL_MUTATIONS (1 << 3)
/** @} */

/**
Expand Down Expand Up @@ -4414,6 +4424,10 @@ that are exclusive ``other`` are added to ``self``, along with:
By default, populations of newly added nodes are assumed to be new populations,
and added to the population table as well.

The behavior can be changed by the flags ``TSK_UNION_ALL_EDGES`` and
``TSK_UNION_ALL_MUTATIONS``, which will (respectively) add *all* edges
or *all* sites and mutations instead.

This operation will also sort the resulting tables, so the tables may change
even if nothing new is added, if the original tables were not sorted.

Expand Down
17 changes: 13 additions & 4 deletions python/_tskitmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -7069,15 +7069,18 @@ TableCollection_union(TableCollection *self, PyObject *args, PyObject *kwds)
npy_intp *shape;
tsk_flags_t options = 0;
int check_shared = true;
int all_edges = false;
int all_mutations = false;
int add_populations = true;
static char *kwlist[] = { "other", "other_node_mapping", "check_shared_equality",
"add_populations", NULL };
static char *kwlist[] = { "other", "other_node_mapping", "all_edges",
"all_mutations", "check_shared_equality", "add_populations", NULL };

if (TableCollection_check_state(self) != 0) {
goto out;
}
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!O|ii", kwlist, &TableCollectionType,
&other, &other_node_mapping, &check_shared, &add_populations)) {
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!O|iiii", kwlist,
&TableCollectionType, &other, &other_node_mapping, &all_edges,
&all_mutations, &check_shared, &add_populations)) {
goto out;
}
nmap_array = (PyArrayObject *) PyArray_FROMANY(
Expand All @@ -7092,6 +7095,12 @@ TableCollection_union(TableCollection *self, PyObject *args, PyObject *kwds)
" number of nodes in the other tree sequence.");
goto out;
}
if (all_edges) {
options |= TSK_UNION_ALL_EDGES;
}
if (all_mutations) {
options |= TSK_UNION_ALL_MUTATIONS;
}
if (!check_shared) {
options |= TSK_UNION_NO_CHECK_SHARED;
}
Expand Down
19 changes: 19 additions & 0 deletions python/tests/test_lowlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,25 @@ def test_union_bad_args(self):
with pytest.raises(ValueError):
tc.union(tc2, np.array([[1], [2]], dtype="int32"))

@pytest.mark.parametrize("value", [True, False])
@pytest.mark.parametrize(
    "flag",
    ["all_edges", "all_mutations", "check_shared_equality", "add_populations"],
)
def test_union_options(self, flag, value):
    """
    Check that the low-level ``union`` accepts each boolean keyword option,
    set to either value, when unioning with a fully cleared table collection.
    """
    ts = msprime.simulate(10, random_seed=1)
    ll_tables = ts.tables._ll_tables
    # Build the "other" collection by clearing every table of a copy.
    cleared = ts.tables.copy()
    for name in cleared.table_name_map:
        getattr(cleared, name).clear()
    ll_other = cleared._ll_tables
    # The other collection has no rows, so the node mapping is empty too.
    no_nodes = np.arange(0, dtype="int32")
    options = {flag: value}
    ll_tables.union(ll_other, no_nodes, **options)

def test_equals_bad_args(self):
ts = msprime.simulate(10, random_seed=1242)
tc = ts.tables._ll_tables
Expand Down
45 changes: 45 additions & 0 deletions python/tests/test_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -5172,6 +5172,51 @@ def test_examples(self):
ts = tables.tree_sequence()
self.verify_union(*self.split_example(ts, T))

def test_split_and_rejoin(self):
    """
    Delete alternating quarters of the genome from two copies of the same
    tables, then check that a union with ``all_edges`` and ``all_mutations``
    over an identity node mapping gives back the original tables once edges
    are squashed and the result is sorted.
    """
    ts = self.get_msprime_example(5, T=2, seed=928)
    breaks = np.array([0, 0.25, 0.5, 0.75, 1]) * ts.sequence_length
    # The first copy loses the 1st and 3rd quarters ...
    first = ts.dump_tables()
    first.delete_intervals([breaks[0:2], breaks[2:4]], simplify=False)
    # ... and the second copy loses the 2nd and 4th.
    second = ts.dump_tables()
    second.delete_intervals([breaks[1:3], breaks[3:]], simplify=False)
    identity = np.arange(ts.num_nodes)
    first.union(
        second,
        node_mapping=identity,
        check_shared_equality=False,
        all_edges=True,
        all_mutations=True,
    )
    first.edges.squash()
    first.sort()
    first.assert_equals(ts.tables, ignore_provenance=True)

def test_add_from_empty(self):
    """
    Reciprocally add mutations from one table collection and edges from the
    other: one collection starts with edges but no sites, the other with
    sites and mutations but no edges, and both share the same nodes.
    """
    with_edges = tskit.Tree.generate_comb(6, span=6).tree_sequence.dump_tables()
    with_muts = tskit.TableCollection(sequence_length=6)
    with_muts.nodes.replace_with(with_edges.nodes)  # same nodes, no edges
    for pos in range(6):
        site = with_muts.sites.add_row(position=pos, ancestral_state="0")
        # Only even positions get a mutation, so some sites stay empty.
        if pos % 2 == 0:
            with_muts.mutations.add_row(site=site, node=pos, derived_state="1")
    union_kwargs = {
        "node_mapping": np.arange(len(with_muts.nodes), dtype="int32"),
        "check_shared_equality": False,
    }

    # all_edges alone must not bring in any sites or mutations
    with_edges.union(with_muts, **union_kwargs, all_edges=True)
    assert len(with_edges.sites) == 0
    assert len(with_edges.mutations) == 0
    # all_mutations brings in every site and every mutation
    with_edges.union(with_muts, **union_kwargs, all_mutations=True)
    assert len(with_edges.sites) == 6
    assert len(with_edges.mutations) == 3

    # all_mutations alone must not bring in any edges
    with_muts.union(with_edges, **union_kwargs, all_mutations=True)
    assert len(with_muts.edges) == 0
    # all_edges is NOT a null op here: the edges are expected to be copied
    with_muts.union(with_edges, **union_kwargs, all_edges=True)
    assert len(with_muts.edges) != 0

    with_edges.assert_equals(with_muts, ignore_provenance=True)


class TestTableSetitemMetadata:
@pytest.mark.parametrize("table_name", tskit.TABLE_NAMES)
Expand Down
Loading
Loading