Skip to content

Commit 04e04aa

Browse files
lkirkmergify[bot]
authored andcommitted
Creates the CPython layer for LD matricies. In this layer, we provide a
single matrix method for each statistic. These will get dispatched in the python layer. Adds some low level tests.
1 parent 1601a5e commit 04e04aa

File tree

4 files changed

+260
-6
lines changed

4 files changed

+260
-6
lines changed

c/tests/test_stats.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2541,7 +2541,7 @@ test_two_locus_stat_input_errors(void)
25412541

25422542
ret = tsk_treeseq_r2(&ts, 0, sample_set_sizes, sample_sets, num_sites, row_sites,
25432543
num_sites, col_sites, 0, result);
2544-
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_STATE_DIMS);
2544+
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_INSUFFICIENT_SAMPLE_SETS);
25452545

25462546
sample_set_sizes[0] = 0;
25472547
ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites,

c/tskit/trees.c

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2668,21 +2668,17 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl
26682668
ret = TSK_ERR_MULTIPLE_STAT_MODES;
26692669
goto out;
26702670
}
2671-
if (state_dim < 1) {
2672-
ret = TSK_ERR_BAD_STATE_DIMS;
2673-
goto out;
2674-
}
26752671
// TODO: impossible until we implement branch/windows
26762672
// if (result_dim < 1) {
26772673
// ret = TSK_ERR_BAD_RESULT_DIMS;
26782674
// goto out;
26792675
// }
2680-
26812676
ret = tsk_treeseq_check_sample_sets(
26822677
self, num_sample_sets, sample_set_sizes, sample_sets);
26832678
if (ret != 0) {
26842679
goto out;
26852680
}
2681+
tsk_bug_assert(state_dim > 0);
26862682
ret = sample_sets_to_bit_array(
26872683
self, sample_set_sizes, sample_sets, num_sample_sets, &sample_sets_bits);
26882684
if (ret != 0) {

python/_tskitmodule.c

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9860,6 +9860,130 @@ TreeSequence_divergence_matrix(TreeSequence *self, PyObject *args, PyObject *kwd
98609860
return ret;
98619861
}
98629862

9863+
static PyObject *
9864+
TreeSequence_ld_matrix(TreeSequence *self, PyObject *args, PyObject *kwds,
9865+
two_locus_count_stat_method *method)
9866+
{
9867+
PyObject *ret = NULL;
9868+
static char *kwlist[]
9869+
= { "sample_set_sizes", "sample_sets", "row_sites", "col_sites", "mode", NULL };
9870+
9871+
PyObject *row_sites = NULL;
9872+
PyObject *col_sites = NULL;
9873+
PyObject *sample_set_sizes = NULL;
9874+
PyObject *sample_sets = NULL;
9875+
PyArrayObject *sample_set_sizes_array = NULL;
9876+
PyArrayObject *sample_sets_array = NULL;
9877+
PyArrayObject *row_sites_array = NULL;
9878+
PyArrayObject *col_sites_array = NULL;
9879+
PyArrayObject *result_matrix = NULL;
9880+
npy_intp result_shape[3];
9881+
char *mode = NULL;
9882+
tsk_size_t num_sample_sets;
9883+
tsk_flags_t options = 0;
9884+
int err;
9885+
9886+
if (TreeSequence_check_state(self) != 0) {
9887+
goto out;
9888+
}
9889+
if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOOOs", kwlist, &sample_set_sizes,
9890+
&sample_sets, &row_sites, &col_sites, &mode)) {
9891+
goto out;
9892+
}
9893+
if (parse_stats_mode(mode, &options) != 0) {
9894+
goto out;
9895+
}
9896+
if (parse_sample_sets(sample_set_sizes, &sample_set_sizes_array, sample_sets,
9897+
&sample_sets_array, &num_sample_sets)
9898+
!= 0) {
9899+
goto out;
9900+
}
9901+
row_sites_array = (PyArrayObject *) PyArray_FROMANY(
9902+
row_sites, NPY_INT32, 1, 1, NPY_ARRAY_IN_ARRAY);
9903+
if (row_sites_array == NULL) {
9904+
goto out;
9905+
}
9906+
col_sites_array = (PyArrayObject *) PyArray_FROMANY(
9907+
col_sites, NPY_INT32, 1, 1, NPY_ARRAY_IN_ARRAY);
9908+
if (col_sites_array == NULL) {
9909+
goto out;
9910+
}
9911+
9912+
result_shape[0] = PyArray_DIM(row_sites_array, 0);
9913+
result_shape[1] = PyArray_DIM(col_sites_array, 0);
9914+
result_shape[2] = num_sample_sets;
9915+
result_matrix = (PyArrayObject *) PyArray_ZEROS(3, result_shape, NPY_FLOAT64, 0);
9916+
if (result_matrix == NULL) {
9917+
goto out;
9918+
}
9919+
9920+
// clang-format off
9921+
Py_BEGIN_ALLOW_THREADS
9922+
err = method(self->tree_sequence, num_sample_sets,
9923+
PyArray_DATA(sample_set_sizes_array), PyArray_DATA(sample_sets_array),
9924+
result_shape[0], PyArray_DATA(row_sites_array), result_shape[1],
9925+
PyArray_DATA(col_sites_array), options, PyArray_DATA(result_matrix));
9926+
Py_END_ALLOW_THREADS
9927+
// clang-format on
9928+
9929+
if (err != 0)
9930+
{
9931+
handle_library_error(err);
9932+
goto out;
9933+
}
9934+
ret = (PyObject *) result_matrix;
9935+
result_matrix = NULL;
9936+
out:
9937+
Py_XDECREF(row_sites_array);
9938+
Py_XDECREF(col_sites_array);
9939+
Py_XDECREF(sample_set_sizes_array);
9940+
Py_XDECREF(sample_sets_array);
9941+
Py_XDECREF(result_matrix);
9942+
return ret;
9943+
}
9944+
9945+
static PyObject *
9946+
TreeSequence_D_matrix(TreeSequence *self, PyObject *args, PyObject *kwds)
9947+
{
9948+
return TreeSequence_ld_matrix(self, args, kwds, tsk_treeseq_D);
9949+
}
9950+
9951+
static PyObject *
9952+
TreeSequence_D2_matrix(TreeSequence *self, PyObject *args, PyObject *kwds)
9953+
{
9954+
return TreeSequence_ld_matrix(self, args, kwds, tsk_treeseq_D2);
9955+
}
9956+
9957+
static PyObject *
9958+
TreeSequence_r2_matrix(TreeSequence *self, PyObject *args, PyObject *kwds)
9959+
{
9960+
return TreeSequence_ld_matrix(self, args, kwds, tsk_treeseq_r2);
9961+
}
9962+
9963+
static PyObject *
9964+
TreeSequence_D_prime_matrix(TreeSequence *self, PyObject *args, PyObject *kwds)
9965+
{
9966+
return TreeSequence_ld_matrix(self, args, kwds, tsk_treeseq_D_prime);
9967+
}
9968+
9969+
static PyObject *
9970+
TreeSequence_r_matrix(TreeSequence *self, PyObject *args, PyObject *kwds)
9971+
{
9972+
return TreeSequence_ld_matrix(self, args, kwds, tsk_treeseq_r);
9973+
}
9974+
9975+
static PyObject *
9976+
TreeSequence_Dz_matrix(TreeSequence *self, PyObject *args, PyObject *kwds)
9977+
{
9978+
return TreeSequence_ld_matrix(self, args, kwds, tsk_treeseq_Dz);
9979+
}
9980+
9981+
static PyObject *
9982+
TreeSequence_pi2_matrix(TreeSequence *self, PyObject *args, PyObject *kwds)
9983+
{
9984+
return TreeSequence_ld_matrix(self, args, kwds, tsk_treeseq_pi2);
9985+
}
9986+
98639987
static PyObject *
98649988
TreeSequence_get_num_mutations(TreeSequence *self)
98659989
{
@@ -10588,6 +10712,34 @@ static PyMethodDef TreeSequence_methods[] = {
1058810712
.ml_meth = (PyCFunction) TreeSequence_has_reference_sequence,
1058910713
.ml_flags = METH_NOARGS,
1059010714
.ml_doc = "Returns True if the TreeSequence has a reference sequence." },
10715+
{ .ml_name = "D_matrix",
10716+
.ml_meth = (PyCFunction) TreeSequence_D_matrix,
10717+
.ml_flags = METH_VARARGS | METH_KEYWORDS,
10718+
.ml_doc = "Computes the D matrix." },
10719+
{ .ml_name = "D2_matrix",
10720+
.ml_meth = (PyCFunction) TreeSequence_D2_matrix,
10721+
.ml_flags = METH_VARARGS | METH_KEYWORDS,
10722+
.ml_doc = "Computes the D2 matrix." },
10723+
{ .ml_name = "r2_matrix",
10724+
.ml_meth = (PyCFunction) TreeSequence_r2_matrix,
10725+
.ml_flags = METH_VARARGS | METH_KEYWORDS,
10726+
.ml_doc = "Computes the r2 matrix." },
10727+
{ .ml_name = "D_prime_matrix",
10728+
.ml_meth = (PyCFunction) TreeSequence_D_prime_matrix,
10729+
.ml_flags = METH_VARARGS | METH_KEYWORDS,
10730+
.ml_doc = "Computes the D_prime matrix." },
10731+
{ .ml_name = "r_matrix",
10732+
.ml_meth = (PyCFunction) TreeSequence_r_matrix,
10733+
.ml_flags = METH_VARARGS | METH_KEYWORDS,
10734+
.ml_doc = "Computes the r matrix." },
10735+
{ .ml_name = "Dz_matrix",
10736+
.ml_meth = (PyCFunction) TreeSequence_Dz_matrix,
10737+
.ml_flags = METH_VARARGS | METH_KEYWORDS,
10738+
.ml_doc = "Computes the Dz matrix." },
10739+
{ .ml_name = "pi2_matrix",
10740+
.ml_meth = (PyCFunction) TreeSequence_pi2_matrix,
10741+
.ml_flags = METH_VARARGS | METH_KEYWORDS,
10742+
.ml_doc = "Computes the pi2 matrix." },
1059110743
{ NULL } /* Sentinel */
1059210744
};
1059310745

python/tests/test_lowlevel.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1512,6 +1512,112 @@ def test_extend_edges_bad_args(self):
15121512
):
15131513
tsm.extend_edges(1)
15141514

1515+
@pytest.mark.parametrize(
1516+
"stat_method_name",
1517+
[
1518+
"D_matrix",
1519+
"D2_matrix",
1520+
"r2_matrix",
1521+
"D_prime_matrix",
1522+
"r_matrix",
1523+
"Dz_matrix",
1524+
"pi2_matrix",
1525+
],
1526+
)
1527+
def test_ld_matrix(self, stat_method_name):
1528+
ts = self.get_example_tree_sequence(10)
1529+
stat_method = getattr(ts, stat_method_name)
1530+
1531+
mode = "site"
1532+
sample_sets = ts.get_samples()
1533+
sample_set_sizes = np.array([len(sample_sets)], dtype=np.uint32)
1534+
row_sites = np.arange(ts.get_num_sites(), dtype=np.int32)
1535+
col_sites = row_sites
1536+
row_sites_list = list(range(ts.get_num_sites()))
1537+
col_sites_list = row_sites_list
1538+
1539+
# happy path
1540+
a = stat_method(sample_set_sizes, sample_sets, row_sites, col_sites, mode)
1541+
assert a.shape == (10, 10, 1)
1542+
1543+
a = stat_method(
1544+
sample_set_sizes, sample_sets, row_sites_list, col_sites_list, mode
1545+
)
1546+
assert a.shape == (10, 10, 1)
1547+
1548+
# CPython API errors
1549+
with pytest.raises(ValueError, match="Sum of sample_set_sizes"):
1550+
bad_sample_sets = np.array([], dtype=np.int32)
1551+
stat_method(sample_set_sizes, bad_sample_sets, row_sites, col_sites, mode)
1552+
with pytest.raises(TypeError, match="cast array data"):
1553+
bad_sample_sets = np.array(ts.get_samples(), dtype=np.uint32)
1554+
stat_method(sample_set_sizes, bad_sample_sets, row_sites, col_sites, mode)
1555+
with pytest.raises(ValueError, match="Unrecognised stats mode"):
1556+
stat_method(sample_set_sizes, sample_sets, row_sites, col_sites, "bla")
1557+
with pytest.raises(TypeError, match="at most"):
1558+
stat_method(
1559+
sample_set_sizes, sample_sets, row_sites, col_sites, mode, "abc"
1560+
)
1561+
with pytest.raises(ValueError, match="invalid literal"):
1562+
bad_sites = ["abadsite", 0, 3, 2]
1563+
stat_method(sample_set_sizes, sample_sets, bad_sites, col_sites, mode)
1564+
with pytest.raises(TypeError):
1565+
bad_sites = [None, 0, 3, 2]
1566+
stat_method(sample_set_sizes, sample_sets, bad_sites, col_sites, mode)
1567+
with pytest.raises(TypeError):
1568+
bad_sites = [{}, 0, 3, 2]
1569+
stat_method(sample_set_sizes, sample_sets, bad_sites, col_sites, mode)
1570+
with pytest.raises(TypeError, match="Cannot cast array data"):
1571+
bad_sites = np.array([0, 1, 2], dtype=np.uint32)
1572+
stat_method(sample_set_sizes, sample_sets, bad_sites, col_sites, mode)
1573+
with pytest.raises(ValueError, match="invalid literal"):
1574+
bad_sites = ["abadsite", 0, 3, 2]
1575+
stat_method(sample_set_sizes, sample_sets, row_sites, bad_sites, mode)
1576+
with pytest.raises(TypeError):
1577+
bad_sites = [None, 0, 3, 2]
1578+
stat_method(sample_set_sizes, sample_sets, row_sites, bad_sites, mode)
1579+
with pytest.raises(TypeError):
1580+
bad_sites = [{}, 0, 3, 2]
1581+
stat_method(sample_set_sizes, sample_sets, row_sites, bad_sites, mode)
1582+
with pytest.raises(TypeError, match="Cannot cast array data"):
1583+
bad_sites = np.array([0, 1, 2], dtype=np.uint32)
1584+
stat_method(sample_set_sizes, sample_sets, row_sites, bad_sites, mode)
1585+
# C API errors
1586+
with pytest.raises(tskit.LibraryError, match="TSK_ERR_UNSORTED_SITES"):
1587+
bad_sites = np.array([1, 0, 2], dtype=np.int32)
1588+
stat_method(sample_set_sizes, sample_sets, bad_sites, col_sites, mode)
1589+
with pytest.raises(tskit.LibraryError, match="TSK_ERR_UNSORTED_SITES"):
1590+
bad_sites = np.array([1, 0, 2], dtype=np.int32)
1591+
stat_method(sample_set_sizes, sample_sets, row_sites, bad_sites, mode)
1592+
with pytest.raises(
1593+
_tskit.LibraryError, match="TSK_ERR_INSUFFICIENT_SAMPLE_SETS"
1594+
):
1595+
bad_sample_sets = np.array([], dtype=np.int32)
1596+
bad_sample_set_sizes = np.array([], dtype=np.uint32)
1597+
stat_method(
1598+
bad_sample_set_sizes, bad_sample_sets, row_sites, col_sites, mode
1599+
)
1600+
with pytest.raises(_tskit.LibraryError, match="TSK_ERR_EMPTY_SAMPLE_SET"):
1601+
bad_sample_sets = np.array([], dtype=np.int32)
1602+
bad_sample_set_sizes = np.array([0], dtype=np.uint32)
1603+
stat_method(
1604+
bad_sample_set_sizes, bad_sample_sets, row_sites, col_sites, mode
1605+
)
1606+
with pytest.raises(_tskit.LibraryError, match="TSK_ERR_NODE_OUT_OF_BOUNDS"):
1607+
bad_sample_sets = np.array([1000], dtype=np.int32)
1608+
bad_sample_set_sizes = np.array([1], dtype=np.uint32)
1609+
stat_method(
1610+
bad_sample_set_sizes, bad_sample_sets, row_sites, col_sites, mode
1611+
)
1612+
with pytest.raises(_tskit.LibraryError, match="TSK_ERR_DUPLICATE_SAMPLE"):
1613+
bad_sample_sets = np.array([2, 2], dtype=np.int32)
1614+
bad_sample_set_sizes = np.array([2], dtype=np.uint32)
1615+
stat_method(
1616+
bad_sample_set_sizes, bad_sample_sets, row_sites, col_sites, mode
1617+
)
1618+
with pytest.raises(_tskit.LibraryError, match="TSK_ERR_UNSUPPORTED_STAT_MODE"):
1619+
stat_method(sample_set_sizes, sample_sets, row_sites, col_sites, "branch")
1620+
15151621
def test_kc_distance_errors(self):
15161622
ts1 = self.get_example_tree_sequence(10)
15171623
with pytest.raises(TypeError):

0 commit comments

Comments
 (0)