Skip to content

Commit 6b523bd

Browse files
committed
Cache references to edges_parent & edges_child
And remove the need to specify which nodes to constrain, which was never used anyway.
1 parent 974038d commit 6b523bd

File tree

4 files changed

+35
-43
lines changed

4 files changed

+35
-43
lines changed

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
tskit>=0.4.0
1+
tskit>=0.5.2
22
tsinfer>=0.2.0
33
flake8
44
numpy

tests/test_functions.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1530,18 +1530,7 @@ def test_constrain_ages_topo(self):
15301530
ts = utility_functions.two_tree_ts()
15311531
post_mn = np.array([0.0, 0.0, 0.0, 2.0, 1.0, 3.0])
15321532
eps = 1e-6
1533-
nodes_to_date = np.array([3, 4, 5])
1534-
constrained_ages = constrain_ages_topo(ts, post_mn, eps, nodes_to_date)
1535-
assert np.array_equal(
1536-
np.array([0.0, 0.0, 0.0, 2.0, 2.000001, 3.0]), constrained_ages
1537-
)
1538-
1539-
def test_constrain_ages_topo_no_nodes_to_date(self):
1540-
ts = utility_functions.two_tree_ts()
1541-
post_mn = np.array([0.0, 0.0, 0.0, 2.0, 1.0, 3.0])
1542-
eps = 1e-6
1543-
nodes_to_date = None
1544-
constrained_ages = constrain_ages_topo(ts, post_mn, eps, nodes_to_date)
1533+
constrained_ages = constrain_ages_topo(ts, post_mn, eps)
15451534
assert np.array_equal(
15461535
np.array([0.0, 0.0, 0.0, 2.0, 2.000001, 3.0]), constrained_ages
15471536
)

tsdate/core.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -936,22 +936,23 @@ def constrain_ages_topo(ts, node_times, eps, progress=False):
936936
If node_times violate topology, return increased node_times so that each node is
937937
guaranteed to be older than any of its children.
938938
"""
939-
tables = ts.tables
939+
edges_parent = ts.edges_parent
940+
edges_child = ts.edges_child
941+
940942
new_node_times = np.copy(node_times)
941943
# Traverse through the ARG, ensuring children come before parents.
942944
# This can be done by iterating over groups of edges with the same parent
943-
new_parent_edge_idx = np.concatenate(
944-
(
945-
[0],
946-
np.where(np.diff(tables.edges.parent) != 0)[0] + 1,
947-
[tables.edges.num_rows],
948-
)
949-
)
950-
for edges_start, edges_end in zip(
951-
new_parent_edge_idx[:-1], new_parent_edge_idx[1:]
945+
new_parent_edge_idx = np.where(np.diff(edges_parent) != 0)[0] + 1
946+
for edges_start, edges_end in tqdm(
947+
zip(
948+
itertools.chain([0], new_parent_edge_idx),
949+
itertools.chain(new_parent_edge_idx, [len(edges_parent)]),
950+
),
951+
desc="Constrain Ages",
952+
disable=not progress,
952953
):
953-
parent = tables.edges.parent[edges_start]
954-
child_ids = tables.edges.child[edges_start:edges_end] # May contain dups
954+
parent = edges_parent[edges_start]
955+
child_ids = edges_child[edges_start:edges_end] # May contain dups
955956
oldest_child_time = np.max(new_node_times[child_ids])
956957
if oldest_child_time >= new_node_times[parent]:
957958
new_node_times[parent] = oldest_child_time + eps

tsdate/prior.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
"""
2323
Routines and classes for creating priors and timeslices for use in tsdate
2424
"""
25+
import itertools
2526
import logging
2627
import os
2728
from collections import defaultdict
@@ -1030,10 +1031,8 @@ def _truncate_priors(ts, priors, progress=False):
10301031
Truncate priors for all nonfixed nodes
10311032
so they conform to the age of fixed nodes in the tree sequence
10321033
"""
1033-
tables = ts.tables
1034-
10351034
fixed_nodes = priors.fixed_node_ids()
1036-
fixed_times = tables.nodes.time[fixed_nodes]
1035+
fixed_times = ts.nodes_time[fixed_nodes]
10371036

10381037
grid_data = np.copy(priors.grid_data[:])
10391038
timepoints = priors.timepoints
@@ -1043,24 +1042,25 @@ def _truncate_priors(ts, priors, progress=False):
10431042
zero_value = 0
10441043
elif priors.probability_space == "logarithmic":
10451044
zero_value = -np.inf
1046-
constrained_min_times = np.zeros_like(tables.nodes.time)
1045+
constrained_min_times = np.zeros_like(ts.nodes_time)
10471046
# Set the min times of fixed nodes to those in the tree sequence
10481047
constrained_min_times[fixed_nodes] = fixed_times
10491048

10501049
# Traverse through the ARG, ensuring children come before parents.
10511050
# This can be done by iterating over groups of edges with the same parent
1052-
new_parent_edge_idx = np.concatenate(
1053-
(
1054-
[0],
1055-
np.where(np.diff(tables.edges.parent) != 0)[0] + 1,
1056-
[tables.edges.num_rows],
1057-
)
1058-
)
1059-
for edges_start, edges_end in zip(
1060-
new_parent_edge_idx[:-1], new_parent_edge_idx[1:]
1051+
edges_parent = ts.edges_parent
1052+
edges_child = ts.edges_child
1053+
new_parent_edge_idx = np.where(np.diff(edges_parent) != 0)[0] + 1
1054+
for edges_start, edges_end in tqdm(
1055+
zip(
1056+
itertools.chain([0], new_parent_edge_idx),
1057+
itertools.chain(new_parent_edge_idx, [len(edges_parent)]),
1058+
),
1059+
desc="Trunc priors",
1060+
disable=not progress,
10611061
):
1062-
parent = tables.edges.parent[edges_start]
1063-
child_ids = tables.edges.child[edges_start:edges_end] # May contain dups
1062+
parent = edges_parent[edges_start]
1063+
child_ids = edges_child[edges_start:edges_end] # May contain dups
10641064
oldest_child_time = np.max(constrained_min_times[child_ids])
10651065
if oldest_child_time > constrained_min_times[parent]:
10661066
if priors.is_fixed(parent):
@@ -1198,15 +1198,17 @@ def build_grid(
11981198
node_var_override=node_var_override,
11991199
progress=progress,
12001200
)
1201-
tables = tree_sequence.tables
1202-
if np.any(tables.nodes.time[tree_sequence.samples()] > 0):
1201+
if np.any(tree_sequence.nodes_time[tree_sequence.samples()] > 0):
12031202
if not allow_historical_samples:
12041203
raise ValueError(
12051204
"There are samples at non-zero times, invalidating the conditional "
12061205
"coalescent prior. You can set allow_historical_samples=True to carry "
12071206
"on regardless, calculating a prior as if all samples were "
12081207
"contemporaneous (reasonable if you only have a few ancient samples)"
12091208
)
1210-
if np.any(tables.nodes.time[priors.fixed_node_ids()] > 0) and truncate_priors:
1209+
if (
1210+
np.any(tree_sequence.nodes_time[priors.fixed_node_ids()] > 0)
1211+
and truncate_priors
1212+
):
12111213
priors = _truncate_priors(tree_sequence, priors, progress=progress)
12121214
return priors

0 commit comments

Comments
 (0)