Skip to content

Commit 7dff9d9

Browse files
authored
Add link-ancestors to TreeSequence and ImmutableTableCollection (#3312)
1 parent 17889ef commit 7dff9d9

File tree

7 files changed

+209
-2
lines changed

7 files changed

+209
-2
lines changed

python/_tskitmodule.c

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5344,6 +5344,69 @@ TreeSequence_dump_tables(TreeSequence *self, PyObject *args, PyObject *kwds)
53445344
return ret;
53455345
}
53465346

5347+
static PyObject *
5348+
TreeSequence_link_ancestors(TreeSequence *self, PyObject *args, PyObject *kwds)
5349+
{
5350+
int err;
5351+
PyObject *ret = NULL;
5352+
PyObject *samples = NULL;
5353+
PyObject *ancestors = NULL;
5354+
PyArrayObject *samples_array = NULL;
5355+
PyArrayObject *ancestors_array = NULL;
5356+
npy_intp *shape;
5357+
tsk_size_t num_samples, num_ancestors;
5358+
EdgeTable *result = NULL;
5359+
PyObject *result_args = NULL;
5360+
static char *kwlist[] = { "samples", "ancestors", NULL };
5361+
5362+
if (TreeSequence_check_state(self) != 0) {
5363+
goto out;
5364+
}
5365+
if (!PyArg_ParseTupleAndKeywords(args, kwds, "OO", kwlist, &samples, &ancestors)) {
5366+
goto out;
5367+
}
5368+
5369+
samples_array = (PyArrayObject *) PyArray_FROMANY(
5370+
samples, NPY_INT32, 1, 1, NPY_ARRAY_IN_ARRAY);
5371+
if (samples_array == NULL) {
5372+
goto out;
5373+
}
5374+
shape = PyArray_DIMS(samples_array);
5375+
num_samples = (tsk_size_t) shape[0];
5376+
5377+
ancestors_array = (PyArrayObject *) PyArray_FROMANY(
5378+
ancestors, NPY_INT32, 1, 1, NPY_ARRAY_IN_ARRAY);
5379+
if (ancestors_array == NULL) {
5380+
goto out;
5381+
}
5382+
shape = PyArray_DIMS(ancestors_array);
5383+
num_ancestors = (tsk_size_t) shape[0];
5384+
5385+
result_args = PyTuple_New(0);
5386+
if (result_args == NULL) {
5387+
goto out;
5388+
}
5389+
result = (EdgeTable *) PyObject_CallObject((PyObject *) &EdgeTableType, result_args);
5390+
if (result == NULL) {
5391+
goto out;
5392+
}
5393+
err = tsk_table_collection_link_ancestors(self->tree_sequence->tables,
5394+
PyArray_DATA(samples_array), num_samples, PyArray_DATA(ancestors_array),
5395+
num_ancestors, 0, result->table);
5396+
if (err != 0) {
5397+
handle_library_error(err);
5398+
goto out;
5399+
}
5400+
ret = (PyObject *) result;
5401+
result = NULL;
5402+
out:
5403+
Py_XDECREF(samples_array);
5404+
Py_XDECREF(ancestors_array);
5405+
Py_XDECREF(result);
5406+
Py_XDECREF(result_args);
5407+
return ret;
5408+
}
5409+
53475410
static PyObject *
53485411
TreeSequence_load(TreeSequence *self, PyObject *args, PyObject *kwds)
53495412
{
@@ -8624,6 +8687,10 @@ static PyMethodDef TreeSequence_methods[] = {
86248687
.ml_meth = (PyCFunction) TreeSequence_dump_tables,
86258688
.ml_flags = METH_VARARGS | METH_KEYWORDS,
86268689
.ml_doc = "Dumps the tree sequence to the specified set of tables" },
8690+
{ .ml_name = "link_ancestors",
8691+
.ml_meth = (PyCFunction) TreeSequence_link_ancestors,
8692+
.ml_flags = METH_VARARGS | METH_KEYWORDS,
8693+
.ml_doc = "Returns an EdgeTable linking the specified samples and ancestors." },
86278694
{ .ml_name = "get_node",
86288695
.ml_meth = (PyCFunction) TreeSequence_get_node,
86298696
.ml_flags = METH_VARARGS,

python/tests/test_highlevel.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,47 @@ def simplify_tree_sequence(ts, samples, filter_sites=True):
163163
return s.simplify()
164164

165165

166+
@pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences())
167+
class TestLinkAncestorsExamples:
168+
def test_link_ancestors_runs_and_is_sane(self, ts):
169+
# Can't link ancestors when edges have metadata.
170+
if ts.tables.edges.metadata_schema != tskit.MetadataSchema(schema=None):
171+
pytest.skip("link_ancestors does not support edges with metadata")
172+
173+
samples = ts.samples()
174+
if len(samples) == 0:
175+
pytest.skip("Tree sequence has no samples")
176+
177+
# Prefer internal nodes as ancestors; fall back to samples if none.
178+
ancestor_nodes = [u.id for u in ts.nodes() if not u.is_sample()]
179+
if len(ancestor_nodes) == 0:
180+
ancestor_nodes = list(samples)
181+
182+
# Keep argument sizes modest for large examples.
183+
samples = samples[: min(len(samples), 10)]
184+
ancestors = ancestor_nodes[: min(len(ancestor_nodes), 10)]
185+
186+
result = ts.link_ancestors(samples, ancestors)
187+
assert isinstance(result, tskit.EdgeTable)
188+
189+
# Basic invariants on the returned table.
190+
assert np.all(result.left >= 0)
191+
assert np.all(result.right <= ts.sequence_length)
192+
if result.num_rows > 0:
193+
assert np.all(result.left < result.right)
194+
assert set(result.parent).issubset(set(range(ts.num_nodes)))
195+
assert set(result.child).issubset(set(range(ts.num_nodes)))
196+
197+
# Parity with mutable TableCollection implementation.
198+
mutable_result = ts.dump_tables().link_ancestors(samples, ancestors)
199+
assert result == mutable_result
200+
201+
# Parity with immutable TableCollection, when available.
202+
if getattr(_tskit, "HAS_NUMPY_2", False):
203+
immutable_result = ts.tables.link_ancestors(samples, ancestors)
204+
assert result == immutable_result
205+
206+
166207
def oriented_forests(n):
167208
"""
168209
Implementation of Algorithm O from TAOCP section 7.2.1.6.

python/tests/test_immutable_table_collection.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,27 @@ def test_str_contains_identifier(self, ts):
128128
s = str(immutable)
129129
assert "ImmutableTableCollection" in s
130130

131+
def test_link_ancestors_parity(self, ts):
132+
# Can't link ancestors when edges have metadata.
133+
if ts.tables.edges.metadata_schema != tskit.MetadataSchema(schema=None):
134+
pytest.skip("link_ancestors does not support edges with metadata")
135+
136+
mutable, immutable = get_mutable_and_immutable(ts)
137+
samples = ts.samples()
138+
if len(samples) == 0:
139+
pytest.skip("Tree sequence has no samples")
140+
141+
ancestor_nodes = [u.id for u in ts.nodes() if not u.is_sample()]
142+
if len(ancestor_nodes) == 0:
143+
ancestor_nodes = list(samples)
144+
145+
samples = samples[: min(len(samples), 10)]
146+
ancestors = ancestor_nodes[: min(len(ancestor_nodes), 10)]
147+
148+
mutable_result = mutable.link_ancestors(samples, ancestors)
149+
immutable_result = immutable.link_ancestors(samples, ancestors)
150+
assert mutable_result == immutable_result
151+
131152

132153
@pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences())
133154
class TestTablesParity:

python/tests/test_lowlevel.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1495,6 +1495,47 @@ def test_mean_descendants(self):
14951495
A = ts.mean_descendants([focal[2:], focal[:2]])
14961496
assert A.shape == (ts.get_num_nodes(), 2)
14971497

1498+
def test_link_ancestors_bad_args(self):
1499+
ts = self.get_example_tree_sequence()
1500+
with pytest.raises(TypeError):
1501+
ts.link_ancestors()
1502+
with pytest.raises(TypeError):
1503+
ts.link_ancestors([0, 1])
1504+
with pytest.raises(ValueError):
1505+
ts.link_ancestors(samples=[0, 1], ancestors="sdf")
1506+
with pytest.raises(ValueError):
1507+
ts.link_ancestors(samples="sdf", ancestors=[0, 1])
1508+
with pytest.raises(_tskit.LibraryError):
1509+
ts.link_ancestors(samples=[0, 1], ancestors=[ts.get_num_nodes(), -1])
1510+
with pytest.raises(_tskit.LibraryError):
1511+
ts.link_ancestors(samples=[0, -1], ancestors=[0])
1512+
1513+
def test_link_ancestors(self):
1514+
# Check that the low-level method runs and does not mutate the tree sequence
1515+
# and that it matches the TableCollection implementation.
1516+
high_ts = msprime.simulate(4, random_seed=1)
1517+
ts = high_ts.ll_tree_sequence
1518+
samples = list(range(ts.get_num_samples()))
1519+
ancestors = list(range(ts.get_num_nodes()))
1520+
num_edges_before = ts.get_num_edges()
1521+
edges = ts.link_ancestors(samples, ancestors)
1522+
assert isinstance(edges, _tskit.EdgeTable)
1523+
assert edges.num_rows >= 0
1524+
if edges.num_rows > 0:
1525+
assert np.all(edges.left >= 0)
1526+
assert np.all(edges.right <= ts.get_sequence_length())
1527+
assert np.all(edges.left < edges.right)
1528+
assert np.all(edges.parent >= 0)
1529+
assert np.all(edges.parent < ts.get_num_nodes())
1530+
assert np.all(edges.child >= 0)
1531+
assert np.all(edges.child < ts.get_num_nodes())
1532+
assert ts.get_num_edges() == num_edges_before
1533+
1534+
# Parity with low-level TableCollection.link_ancestors
1535+
tc = high_ts.dump_tables()._ll_tables
1536+
edges_from_tables = tc.link_ancestors(samples, ancestors)
1537+
assert edges.equals(edges_from_tables)
1538+
14981539
def test_metadata_schemas(self):
14991540
tables = _tskit.TableCollection(1.0)
15001541
# Set the schema

python/tests/test_topology.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4901,6 +4901,10 @@ def do_map(self, ts, ancestors, samples=None, compare_lib=True):
49014901
if compare_lib:
49024902
lib_result = ts.dump_tables().link_ancestors(samples, ancestors)
49034903
assert ancestor_table == lib_result
4904+
ts_result = ts.link_ancestors(samples, ancestors)
4905+
assert ancestor_table == ts_result
4906+
tables_result = ts.tables.link_ancestors(samples, ancestors)
4907+
assert ancestor_table == tables_result
49044908
return ancestor_table
49054909

49064910
def test_deprecated_name(self):
@@ -4914,6 +4918,10 @@ def test_deprecated_name(self):
49144918
tss = s.link_ancestors()
49154919
lib_result = ts.dump_tables().map_ancestors(samples, ancestors)
49164920
assert tss == lib_result
4921+
ts_result = ts.link_ancestors(samples, ancestors)
4922+
assert tss == ts_result
4923+
immutable_result = ts.tables.map_ancestors(samples, ancestors)
4924+
assert tss == immutable_result
49174925
assert list(tss.parent) == [8, 8, 8, 8, 8]
49184926
assert list(tss.child) == [0, 1, 2, 3, 4]
49194927
assert all(tss.left) == 0

python/tskit/tables.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4780,6 +4780,21 @@ def __str__(self):
47804780
]
47814781
)
47824782

4783+
def link_ancestors(self, samples, ancestors):
4784+
"""
4785+
See :meth:`TableCollection.link_ancestors`.
4786+
"""
4787+
samples = util.safe_np_int_cast(samples, np.int32)
4788+
ancestors = util.safe_np_int_cast(ancestors, np.int32)
4789+
ll_edge_table = self._llts.link_ancestors(samples, ancestors)
4790+
return EdgeTable(ll_table=ll_edge_table)
4791+
4792+
def map_ancestors(self, *args, **kwargs):
4793+
"""
4794+
Deprecated alias for :meth:`link_ancestors`.
4795+
"""
4796+
return self.link_ancestors(*args, **kwargs)
4797+
47834798
_MUTATOR_METHODS = {
47844799
"clear",
47854800
"sort",
@@ -4803,8 +4818,6 @@ def __str__(self):
48034818
"ibd_segments",
48044819
"fromdict",
48054820
"simplify",
4806-
"link_ancestors",
4807-
"map_ancestors",
48084821
}
48094822

48104823
def copy(self):

python/tskit/trees.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4393,6 +4393,22 @@ def dump_tables(self):
43934393
self._ll_tree_sequence.dump_tables(ll_tables)
43944394
return tables.TableCollection(ll_tables=ll_tables)
43954395

4396+
def link_ancestors(self, samples, ancestors):
4397+
"""
4398+
Equivalent to :meth:`TableCollection.link_ancestors`; see that method for full
4399+
documentation and parameter semantics.
4400+
4401+
:param list[int] samples: Node IDs to retain as samples.
4402+
:param list[int] ancestors: Node IDs to treat as ancestors.
4403+
:return: An :class:`tables.EdgeTable` containing the genealogical links between
4404+
the supplied ``samples`` and ``ancestors``.
4405+
:rtype: tables.EdgeTable
4406+
"""
4407+
samples = util.safe_np_int_cast(samples, np.int32)
4408+
ancestors = util.safe_np_int_cast(ancestors, np.int32)
4409+
ll_edge_table = self._ll_tree_sequence.link_ancestors(samples, ancestors)
4410+
return tables.EdgeTable(ll_table=ll_edge_table)
4411+
43964412
def dump_text(
43974413
self,
43984414
nodes=None,

0 commit comments

Comments
 (0)