Skip to content

Commit 9545f8d

Browse files
Add low-level Python-C support for arbirary ancestral state
1 parent 57594f4 commit 9545f8d

File tree

3 files changed

+49
-9
lines changed

3 files changed

+49
-9
lines changed

_tsinfermodule.c

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,17 @@ uint64_PyArray_converter(PyObject *in, PyObject **out)
6060
return NPY_SUCCEED;
6161
}
6262

63+
static int
64+
int8_PyArray_converter(PyObject *in, PyObject **out)
65+
{
66+
PyObject *ret = PyArray_FROMANY(in, NPY_INT8, 1, 1, NPY_ARRAY_IN_ARRAY);
67+
if (ret == NULL) {
68+
return NPY_FAIL;
69+
}
70+
*out = ret;
71+
return NPY_SUCCEED;
72+
}
73+
6374
/*===================================================================
6475
* AncestorBuilder
6576
*===================================================================
@@ -429,30 +440,43 @@ TreeSequenceBuilder_init(TreeSequenceBuilder *self, PyObject *args, PyObject *kw
429440
{
430441
int ret = -1;
431442
int err;
432-
static char *kwlist[] = {"num_alleles", "max_nodes", "max_edges", NULL};
443+
static char *kwlist[] = {"num_alleles", "max_nodes", "max_edges", "ancestral_state",
444+
NULL};
433445
PyArrayObject *num_alleles = NULL;
446+
PyArrayObject *ancestral_state = NULL;
447+
int8_t *ancestral_state_data = NULL;
434448
unsigned long max_nodes = 1024;
435449
unsigned long max_edges = 1024;
436450
unsigned long num_sites;
437451
npy_intp *shape;
438452
int flags = 0;
439453

440454
self->tree_sequence_builder = NULL;
441-
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&|kk", kwlist,
455+
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&|kkO&", kwlist,
442456
uint64_PyArray_converter, &num_alleles,
443-
&max_nodes, &max_edges)) {
457+
&max_nodes, &max_edges,
458+
int8_PyArray_converter, &ancestral_state)) {
444459
goto out;
445460
}
446461
shape = PyArray_DIMS(num_alleles);
447462
num_sites = shape[0];
448-
463+
if (ancestral_state != NULL) {
464+
shape = PyArray_DIMS(ancestral_state);
465+
if (shape[0] != (npy_intp) num_sites) {
466+
PyErr_SetString(PyExc_ValueError, "ancestral state array wrong size");
467+
goto out;
468+
}
469+
ancestral_state_data = PyArray_DATA(ancestral_state);
470+
}
449471
self->tree_sequence_builder = PyMem_Malloc(sizeof(tree_sequence_builder_t));
450472
if (self->tree_sequence_builder == NULL) {
451473
PyErr_NoMemory();
452474
goto out;
453475
}
454476
err = tree_sequence_builder_alloc(self->tree_sequence_builder,
455-
num_sites, PyArray_DATA(num_alleles),
477+
num_sites,
478+
PyArray_DATA(num_alleles),
479+
ancestral_state_data,
456480
max_nodes, max_edges, flags);
457481
if (err != 0) {
458482
handle_library_error(err);
@@ -461,6 +485,7 @@ TreeSequenceBuilder_init(TreeSequenceBuilder *self, PyObject *args, PyObject *kw
461485
ret = 0;
462486
out:
463487
Py_XDECREF(num_alleles);
488+
Py_XDECREF(ancestral_state);
464489
return ret;
465490
}
466491

lib/err.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ tsi_strerror(int err)
9494
ret = "Bad mutation information: mutation already exists for this node.";
9595
break;
9696
case TSI_ERR_BAD_ANCESTRAL_STATE:
97-
ret = "Bad derived state for site";
97+
ret = "Bad ancestral state for site";
9898
break;
9999
case TSI_ERR_BAD_NUM_SAMPLES:
100100
ret = "Must have at least 2 samples.";

tests/test_low_level.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,16 +88,31 @@ class TestTreeSequenceBuilder:
8888
def test_init(self):
8989
with pytest.raises(TypeError):
9090
_tsinfer.TreeSequenceBuilder()
91-
for bad_array in [None, "serf", [[], []], ["asdf"], {}]:
92-
with pytest.raises(ValueError):
93-
_tsinfer.TreeSequenceBuilder(bad_array)
9491

9592
for bad_type in [None, "sdf", {}]:
9693
with pytest.raises(TypeError):
9794
_tsinfer.TreeSequenceBuilder([2], max_nodes=bad_type)
9895
with pytest.raises(TypeError):
9996
_tsinfer.TreeSequenceBuilder([2], max_edges=bad_type)
10097

98+
def test_bad_num_alleles(self):
99+
for bad_array in [None, "serf", [[], []], ["asdf"], {}]:
100+
with pytest.raises(ValueError):
101+
_tsinfer.TreeSequenceBuilder(bad_array)
102+
with pytest.raises(_tsinfer.LibraryError, match="number of alleles"):
103+
_tsinfer.TreeSequenceBuilder([1000])
104+
105+
def test_bad_ancestral_state(self):
106+
for bad_array in [None, "serf", [[], []], ["asdf"], {}]:
107+
with pytest.raises(ValueError):
108+
_tsinfer.TreeSequenceBuilder([2], ancestral_state=bad_array)
109+
with pytest.raises(_tsinfer.LibraryError, match="Bad ancestral state"):
110+
for bad_ancestral_state in [-1, 2, 100]:
111+
_tsinfer.TreeSequenceBuilder([2], ancestral_state=[bad_ancestral_state])
112+
113+
with pytest.raises(ValueError, match="ancestral state array wrong size"):
114+
_tsinfer.TreeSequenceBuilder([2, 2, 2], ancestral_state=[])
115+
101116

102117
class TestAncestorBuilder:
103118
"""

0 commit comments

Comments
 (0)