Skip to content

Commit 2e42bfc

Browse files
committed
Move inherited_state array to CPython
1 parent e80c00a commit 2e42bfc

File tree

10 files changed

+138
-14
lines changed

10 files changed

+138
-14
lines changed

c/CHANGELOG.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,15 @@
1414
compatible with the correct mutation parent.
1515
(:user:`benjeffery`, :issue:`2729`, :issue:`2732`, :pr:`3212`).
1616

17+
- Mutations returned by ``tsk_treeseq_get_mutation`` now include pre-computed
18+
``inherited_state`` and ``inherited_state_length`` fields. The inherited state
19+
is computed during tree sequence initialization and represents the state that
20+
existed at the site before each mutation occurred (either the ancestral state
21+
if the mutation has no parent mutation, or otherwise the derived state of its parent mutation).
22+
Note that this breaks ABI compatibility due to the addition of these fields
23+
to the ``tsk_mutation_t`` struct.
24+
(:user:`benjeffery`, :pr:`3277`, :issue:`2631`).
25+
1726
**Features**
1827

1928
- Add ``TSK_TS_INIT_COMPUTE_MUTATION_PARENTS`` to ``tsk_treeseq_init``

c/tskit/trees.c

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,7 @@ tsk_treeseq_init_trees(tsk_treeseq_t *self)
269269
bool discrete_breakpoints = true;
270270
tsk_id_t *node_edge_map = tsk_malloc(num_nodes * sizeof(*node_edge_map));
271271
tsk_mutation_t *mutation;
272+
tsk_id_t parent_id;
272273

273274
self->tree_sites_length
274275
= tsk_malloc(num_trees_alloc * sizeof(*self->tree_sites_length));
@@ -329,7 +330,7 @@ tsk_treeseq_init_trees(tsk_treeseq_t *self)
329330
- sites_ancestral_state_offset[site_id];
330331
} else {
331332
/* Has parent: inherited state is parent's derived state */
332-
tsk_id_t parent_id = mutation_parent[mutation_id];
333+
parent_id = mutation_parent[mutation_id];
333334
mutation->inherited_state
334335
= mutations_derived_state
335336
+ mutations_derived_state_offset[parent_id];

docs/python-api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ for more information.
9292
TreeSequence.edges_parent
9393
TreeSequence.edges_child
9494
TreeSequence.sites_position
95+
TreeSequence.sites_ancestral_state
9596
TreeSequence.mutations_site
9697
TreeSequence.mutations_node
9798
TreeSequence.mutations_parent

python/CHANGELOG.rst

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,6 @@
3030
- Add ``TreeSequence.mutations_edge`` which returns the edge ID for each mutation's
3131
edge. (:user:`benjeffery`, :pr:`3226`, :issue:`3189`)
3232

33-
- Add ``TreeSequence.mutations_inherited_state`` which returns the inherited state
34-
for each mutation. (:user:`benjeffery`, :pr:`3276`, :issue:`2631`)
35-
3633
- Add ``TreeSequence.sites_ancestral_state``, ``TreeSequence.mutations_derived_state`` and
3734
``TreeSequence.mutations_inherited_state`` properties to return the ancestral state of sites,
3835
derived state of mutations and inherited state of mutations as NumPy arrays of

python/_tskitmodule.c

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11399,6 +11399,59 @@ TreeSequence_get_mutations_derived_state(TreeSequence *self, void *closure)
1139911399
out:
1140011400
return ret;
1140111401
}
11402+
/*
 * Getter for the "mutations_inherited_state" attribute of TreeSequence.
 *
 * Concatenates the inherited_state bytes of every entry in
 * ts->site_mutations_mem into one contiguous buffer, records the start
 * offset of each entry (plus one final end offset) in a second buffer, and
 * hands both to TreeSequence_decode_ragged_string_column, whose result is
 * returned. Both temporary buffers are released before returning.
 *
 * Returns NULL if TreeSequence_check_state reports a non-zero status, or if
 * either allocation fails (PyErr_NoMemory is called in that case).
 * The closure argument is not used.
 *
 * NOTE(review): entries of site_mutations_mem are indexed here by mutation
 * row index j; this assumes that array is stored in mutation ID order --
 * confirm against the tree sequence initialisation code.
 */
static PyObject *
11403+
TreeSequence_get_mutations_inherited_state(TreeSequence *self, void *closure)
11404+
{
11405+
PyObject *ret = NULL;
11406+
tsk_treeseq_t *ts;
11407+
tsk_size_t num_mutations;
11408+
char *inherited_state_data = NULL;
11409+
tsk_size_t *inherited_state_offsets = NULL;
11410+
tsk_size_t total_length = 0;
11411+
tsk_size_t j, offset;
11412+
11413+
if (TreeSequence_check_state(self) != 0) {
11414+
goto out;
11415+
}
11416+
11417+
ts = self->tree_sequence;
11418+
num_mutations = ts->tables->mutations.num_rows;
11419+
11420+
/* First pass: sum the per-mutation lengths to size the data buffer */
11421+
for (j = 0; j < num_mutations; j++) {
11422+
total_length += ts->site_mutations_mem[j].inherited_state_length;
11423+
}
11424+
11425+
/* Allocate the ragged array: a flat data buffer plus num_mutations + 1
 * offsets (the extra slot holds the end offset of the last entry) */
11426+
inherited_state_data = PyMem_Malloc(total_length * sizeof(char));
11427+
inherited_state_offsets = PyMem_Malloc((num_mutations + 1) * sizeof(tsk_size_t));
11428+
if (inherited_state_data == NULL || inherited_state_offsets == NULL) {
11429+
PyErr_NoMemory();
11430+
goto out;
11431+
}
11432+
11433+
/* Second pass: copy each inherited state into place and record where it
 * starts */
11434+
offset = 0;
11435+
for (j = 0; j < num_mutations; j++) {
11436+
inherited_state_offsets[j] = offset;
11437+
memcpy(inherited_state_data + offset, ts->site_mutations_mem[j].inherited_state,
11438+
ts->site_mutations_mem[j].inherited_state_length);
11439+
offset += ts->site_mutations_mem[j].inherited_state_length;
11440+
}
11441+
inherited_state_offsets[num_mutations] = offset;
11442+
11443+
ret = TreeSequence_decode_ragged_string_column(
11444+
self, num_mutations, inherited_state_data, inherited_state_offsets);
11445+
11446+
/* Shared exit path: the temporary buffers are freed on success and on
 * failure alike */
out:
11447+
if (inherited_state_data != NULL) {
11448+
PyMem_Free(inherited_state_data);
11449+
}
11450+
if (inherited_state_offsets != NULL) {
11451+
PyMem_Free(inherited_state_offsets);
11452+
}
11453+
return ret;
11454+
}
1140211455
#endif
1140311456

1140411457
static PyObject *
@@ -12018,6 +12071,9 @@ static PyGetSetDef TreeSequence_getsetters[] = {
1201812071
{ .name = "mutations_derived_state",
1201912072
.get = (getter) TreeSequence_get_mutations_derived_state,
1202012073
.doc = "The mutation derived state array" },
12074+
{ .name = "mutations_inherited_state",
12075+
.get = (getter) TreeSequence_get_mutations_inherited_state,
12076+
.doc = "The mutation inherited state array" },
1202112077
#endif
1202212078
{ .name = "mutations_metadata",
1202312079
.get = (getter) TreeSequence_get_mutations_metadata,

python/tests/test_highlevel.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5560,6 +5560,14 @@ def test_equality_mutations_derived_state(self, ts):
55605560
[mutation.derived_state for mutation in ts.mutations()],
55615561
)
55625562

5563+
@pytest.mark.skipif(not _tskit.HAS_NUMPY_2, reason="Requires NumPy 2.0 or higher")
5564+
@pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences())
5565+
def test_equality_mutations_inherited_state(self, ts):
5566+
# The array-valued property must agree, element by element, with the
# inherited_state attribute of each mutation yielded by ts.mutations().
assert_array_equal(
5567+
ts.mutations_inherited_state,
5568+
[mutation.inherited_state for mutation in ts.mutations()],
5569+
)
5570+
55635571
@pytest.mark.skipif(not _tskit.HAS_NUMPY_2, reason="Requires NumPy 2.0 or higher")
55645572
@pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences())
55655573
def test_mutations_inherited_state(self, ts):

python/tests/test_jit.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,10 @@ def test_jitwrap_properties(ts):
420420
assert numba_ts.mutations_time.dtype == np.float64
421421
nt.assert_array_equal(numba_ts.mutations_derived_state, ts.mutations_derived_state)
422422
assert numba_ts.mutations_derived_state.dtype.kind == "U" # Unicode string
423+
nt.assert_array_equal(
424+
numba_ts.mutations_inherited_state, ts.mutations_inherited_state
425+
)
426+
assert numba_ts.mutations_inherited_state.dtype.kind == "U" # Unicode string
423427
nt.assert_array_equal(
424428
numba_ts.indexes_edge_insertion_order, ts.indexes_edge_insertion_order
425429
)

python/tests/test_lowlevel.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2192,7 +2192,12 @@ def test_generated_columns(self, ts_fixture, name):
21922192

21932193
@pytest.mark.skipif(not _tskit.HAS_NUMPY_2, reason="Requires NumPy 2.0+")
21942194
@pytest.mark.parametrize(
2195-
"string_array", ["sites_ancestral_state", "mutations_derived_state"]
2195+
"string_array",
2196+
[
2197+
"sites_ancestral_state",
2198+
"mutations_derived_state",
2199+
"mutations_inherited_state",
2200+
],
21962201
)
21972202
@pytest.mark.parametrize(
21982203
"str_lengths",
@@ -2210,6 +2215,9 @@ def test_string_arrays(self, ts_fixture, str_lengths, string_array):
22102215
elif string_array == "mutations_derived_state":
22112216
assert ts.num_mutations > 0
22122217
assert {len(mut.derived_state) for mut in ts.mutations()} == {1}
2218+
elif string_array == "mutations_inherited_state":
2219+
assert ts.num_mutations > 0
2220+
assert {len(mut.inherited_state) for mut in ts.mutations()} == {1}
22132221
else:
22142222
tables = ts_fixture.dump_tables()
22152223

@@ -2239,6 +2247,25 @@ def test_string_arrays(self, ts_fixture, str_lengths, string_array):
22392247
derived_state=get_derived_state(i, mutation)
22402248
)
22412249
)
2250+
elif string_array == "mutations_inherited_state":
2251+
# For inherited state, we modify sites and mutations to create
2252+
# varied lengths
2253+
sites = tables.sites.copy()
2254+
tables.sites.clear()
2255+
get_ancestral_state = str_map[str_lengths]
2256+
for i, site in enumerate(sites):
2257+
tables.sites.append(
2258+
site.replace(ancestral_state=get_ancestral_state(i, site))
2259+
)
2260+
mutations = tables.mutations.copy()
2261+
tables.mutations.clear()
2262+
get_derived_state = str_map[str_lengths]
2263+
for i, mutation in enumerate(mutations):
2264+
tables.mutations.append(
2265+
mutation.replace(
2266+
derived_state=get_derived_state(i, mutation)
2267+
)
2268+
)
22422269

22432270
ts = tables.tree_sequence()
22442271
ll_ts = ts.ll_tree_sequence
@@ -2255,6 +2282,9 @@ def test_string_arrays(self, ts_fixture, str_lengths, string_array):
22552282
elif string_array == "mutations_derived_state":
22562283
for mutation in ts.mutations():
22572284
assert a[mutation.id] == mutation.derived_state
2285+
elif string_array == "mutations_inherited_state":
2286+
for mutation in ts.mutations():
2287+
assert a[mutation.id] == mutation.inherited_state
22582288

22592289
# Read only
22602290
with pytest.raises(AttributeError, match="not writable"):

python/tskit/jit/numba.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -402,9 +402,11 @@ def __init__(
402402
mutations_parent,
403403
mutations_time,
404404
mutations_derived_state,
405+
mutations_inherited_state,
405406
breakpoints,
406407
max_ancestral_length,
407408
max_derived_length,
409+
max_inherited_length,
408410
):
409411
self.num_trees = num_trees
410412
self.num_nodes = num_nodes
@@ -431,9 +433,11 @@ def __init__(
431433
self.mutations_parent = mutations_parent
432434
self.mutations_time = mutations_time
433435
self.mutations_derived_state = mutations_derived_state
436+
self.mutations_inherited_state = mutations_inherited_state
434437
self.breakpoints = breakpoints
435438
self.max_ancestral_length = max_ancestral_length
436439
self.max_derived_length = max_derived_length
440+
self.max_inherited_length = max_inherited_length
437441

438442
def tree_index(self):
439443
"""
@@ -526,7 +530,7 @@ def parent_index(self):
526530

527531
# We cache these classes to avoid repeated JIT compilation
528532
@functools.lru_cache(None)
529-
def _jitwrap(max_ancestral_length, max_derived_length):
533+
def _jitwrap(max_ancestral_length, max_derived_length, max_inherited_length):
530534
# We have a circular dependency in JIT compilation between NumbaTreeSequence
531535
# and NumbaTreeIndex so we used a deferred type to break it
532536
tree_sequence_type = numba.deferred_type()
@@ -576,9 +580,14 @@ def _jitwrap(max_ancestral_length, max_derived_length):
576580
("mutations_parent", numba.int32[:]),
577581
("mutations_time", numba.float64[:]),
578582
("mutations_derived_state", numba.types.UnicodeCharSeq(max_derived_length)[:]),
583+
(
584+
"mutations_inherited_state",
585+
numba.types.UnicodeCharSeq(max_inherited_length)[:],
586+
),
579587
("breakpoints", numba.float64[:]),
580588
("max_ancestral_length", numba.int32),
581589
("max_derived_length", numba.int32),
590+
("max_inherited_length", numba.int32),
582591
]
583592

584593
# The `tree_index` method on NumbaTreeSequence uses NumbaTreeIndex
@@ -614,8 +623,13 @@ def jitwrap(ts):
614623
"""
615624
max_ancestral_length = max(1, max(map(len, ts.sites_ancestral_state), default=1))
616625
max_derived_length = max(1, max(map(len, ts.mutations_derived_state), default=1))
626+
max_inherited_length = max(
627+
1, max(map(len, ts.mutations_inherited_state), default=1)
628+
)
617629

618-
JittedTreeSequence = _jitwrap(max_ancestral_length, max_derived_length)
630+
JittedTreeSequence = _jitwrap(
631+
max_ancestral_length, max_derived_length, max_inherited_length
632+
)
619633

620634
# Create the tree sequence instance
621635
numba_ts = JittedTreeSequence(
@@ -648,9 +662,13 @@ def jitwrap(ts):
648662
mutations_derived_state=ts.mutations_derived_state.astype(
649663
f"U{max_derived_length}"
650664
),
665+
mutations_inherited_state=ts.mutations_inherited_state.astype(
666+
f"U{max_inherited_length}"
667+
),
651668
breakpoints=ts.breakpoints(as_array=True),
652669
max_ancestral_length=max_ancestral_length,
653670
max_derived_length=max_derived_length,
671+
max_inherited_length=max_inherited_length,
654672
)
655673

656674
return numba_ts

python/tskit/trees.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6092,14 +6092,14 @@ def mutations_inherited_state(self):
60926092
:return: Array of shape (num_mutations,) containing inherited states.
60936093
:rtype: numpy.ndarray
60946094
"""
6095+
if not _tskit.HAS_NUMPY_2:
6096+
raise RuntimeError(
6097+
"The mutations_inherited_state property requires numpy 2.0 or later."
6098+
)
60956099
if self._mutations_inherited_state is None:
6096-
inherited_state = self.sites_ancestral_state[self.mutations_site]
6097-
mutations_with_parent = self.mutations_parent != -1
6098-
parent = self.mutations_parent[mutations_with_parent]
6099-
inherited_state[mutations_with_parent] = self.mutations_derived_state[
6100-
parent
6101-
]
6102-
self._mutations_inherited_state = inherited_state
6100+
self._mutations_inherited_state = (
6101+
self._ll_tree_sequence.mutations_inherited_state
6102+
)
61036103
return self._mutations_inherited_state
61046104

61056105
@property

0 commit comments

Comments
 (0)