Skip to content

Commit 7422acf

Browse files
committed
Cache references to edges_parent & edges_child
And remove the need to specify which nodes to constrain, which was never used anyway.
1 parent 5c97a60 commit 7422acf

File tree

4 files changed

+35
-43
lines changed

4 files changed

+35
-43
lines changed

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
tskit>=0.4.0
1+
tskit>=0.5.2
22
tsinfer>=0.2.0
33
flake8
44
numpy

tests/test_functions.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1533,18 +1533,7 @@ def test_constrain_ages_topo(self):
15331533
ts = utility_functions.two_tree_ts()
15341534
post_mn = np.array([0.0, 0.0, 0.0, 2.0, 1.0, 3.0])
15351535
eps = 1e-6
1536-
nodes_to_date = np.array([3, 4, 5])
1537-
constrained_ages = constrain_ages_topo(ts, post_mn, eps, nodes_to_date)
1538-
assert np.array_equal(
1539-
np.array([0.0, 0.0, 0.0, 2.0, 2.000001, 3.0]), constrained_ages
1540-
)
1541-
1542-
def test_constrain_ages_topo_no_nodes_to_date(self):
1543-
ts = utility_functions.two_tree_ts()
1544-
post_mn = np.array([0.0, 0.0, 0.0, 2.0, 1.0, 3.0])
1545-
eps = 1e-6
1546-
nodes_to_date = None
1547-
constrained_ages = constrain_ages_topo(ts, post_mn, eps, nodes_to_date)
1536+
constrained_ages = constrain_ages_topo(ts, post_mn, eps)
15481537
assert np.array_equal(
15491538
np.array([0.0, 0.0, 0.0, 2.0, 2.000001, 3.0]), constrained_ages
15501539
)

tsdate/core.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -935,22 +935,23 @@ def constrain_ages_topo(ts, node_times, eps, progress=False):
935935
If node_times violate topology, return increased node_times so that each node is
936936
guaranteed to be older than any of its their children.
937937
"""
938-
tables = ts.tables
938+
edges_parent = ts.edges_parent
939+
edges_child = ts.edges_child
940+
939941
new_node_times = np.copy(node_times)
940942
# Traverse through the ARG, ensuring children come before parents.
941943
# This can be done by iterating over groups of edges with the same parent
942-
new_parent_edge_idx = np.concatenate(
943-
(
944-
[0],
945-
np.where(np.diff(tables.edges.parent) != 0)[0] + 1,
946-
[tables.edges.num_rows],
947-
)
948-
)
949-
for edges_start, edges_end in zip(
950-
new_parent_edge_idx[:-1], new_parent_edge_idx[1:]
944+
new_parent_edge_idx = np.where(np.diff(edges_parent) != 0)[0] + 1
945+
for edges_start, edges_end in tqdm(
946+
zip(
947+
itertools.chain([0], new_parent_edge_idx),
948+
itertools.chain(new_parent_edge_idx, [len(edges_parent)]),
949+
),
950+
desc="Constrain Ages",
951+
disable=not progress,
951952
):
952-
parent = tables.edges.parent[edges_start]
953-
child_ids = tables.edges.child[edges_start:edges_end] # May contain dups
953+
parent = edges_parent[edges_start]
954+
child_ids = edges_child[edges_start:edges_end] # May contain dups
954955
oldest_child_time = np.max(new_node_times[child_ids])
955956
if oldest_child_time >= new_node_times[parent]:
956957
new_node_times[parent] = oldest_child_time + eps

tsdate/prior.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
"""
2323
Routines and classes for creating priors and timeslices for use in tsdate
2424
"""
25+
import itertools
2526
import logging
2627
import os
2728
from collections import defaultdict
@@ -1035,10 +1036,8 @@ def _truncate_priors(ts, priors, progress=False):
10351036
Truncate priors for all nonfixed nodes
10361037
so they conform to the age of fixed nodes in the tree sequence
10371038
"""
1038-
tables = ts.tables
1039-
10401039
fixed_nodes = priors.fixed_node_ids()
1041-
fixed_times = tables.nodes.time[fixed_nodes]
1040+
fixed_times = ts.nodes_time[fixed_nodes]
10421041

10431042
grid_data = np.copy(priors.grid_data[:])
10441043
timepoints = priors.timepoints
@@ -1048,24 +1047,25 @@ def _truncate_priors(ts, priors, progress=False):
10481047
zero_value = 0
10491048
elif priors.probability_space == "logarithmic":
10501049
zero_value = -np.inf
1051-
constrained_min_times = np.zeros_like(tables.nodes.time)
1050+
constrained_min_times = np.zeros_like(ts.nodes_time)
10521051
# Set the min times of fixed nodes to those in the tree sequence
10531052
constrained_min_times[fixed_nodes] = fixed_times
10541053

10551054
# Traverse through the ARG, ensuring children come before parents.
10561055
# This can be done by iterating over groups of edges with the same parent
1057-
new_parent_edge_idx = np.concatenate(
1058-
(
1059-
[0],
1060-
np.where(np.diff(tables.edges.parent) != 0)[0] + 1,
1061-
[tables.edges.num_rows],
1062-
)
1063-
)
1064-
for edges_start, edges_end in zip(
1065-
new_parent_edge_idx[:-1], new_parent_edge_idx[1:]
1056+
edges_parent = ts.edges_parent
1057+
edges_child = ts.edges_child
1058+
new_parent_edge_idx = np.where(np.diff(edges_parent) != 0)[0] + 1
1059+
for edges_start, edges_end in tqdm(
1060+
zip(
1061+
itertools.chain([0], new_parent_edge_idx),
1062+
itertools.chain(new_parent_edge_idx, [len(edges_parent)]),
1063+
),
1064+
desc="Trunc priors",
1065+
disable=not progress,
10661066
):
1067-
parent = tables.edges.parent[edges_start]
1068-
child_ids = tables.edges.child[edges_start:edges_end] # May contain dups
1067+
parent = edges_parent[edges_start]
1068+
child_ids = edges_child[edges_start:edges_end] # May contain dups
10691069
oldest_child_time = np.max(constrained_min_times[child_ids])
10701070
if oldest_child_time > constrained_min_times[parent]:
10711071
if priors.is_fixed(parent):
@@ -1204,15 +1204,17 @@ def build_grid(
12041204
node_var_override=node_var_override,
12051205
progress=progress,
12061206
)
1207-
tables = tree_sequence.tables
1208-
if np.any(tables.nodes.time[tree_sequence.samples()] > 0):
1207+
if np.any(tree_sequence.nodes_time[tree_sequence.samples()] > 0):
12091208
if not allow_historical_samples:
12101209
raise ValueError(
12111210
"There are samples at non-zero times, invalidating the conditional "
12121211
"coalescent prior. You can set allow_historical_samples=True to carry "
12131212
"on regardless, calculating a prior as if all samples were "
12141213
"contemporaneous (reasonable if you only have a few ancient samples)"
12151214
)
1216-
if np.any(tables.nodes.time[priors.fixed_node_ids()] > 0) and truncate_priors:
1215+
if (
1216+
np.any(tree_sequence.nodes_time[priors.fixed_node_ids()] > 0)
1217+
and truncate_priors
1218+
):
12171219
priors = _truncate_priors(tree_sequence, priors, progress=progress)
12181220
return priors

0 commit comments

Comments
 (0)