Skip to content

Commit 234906d

Browse files
committed
Use new algo for truncating priors
1 parent 34db782 commit 234906d

File tree

2 files changed

+31
-29
lines changed

2 files changed

+31
-29
lines changed

tsdate/base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,9 @@ def normalize(self):
159159
else:
160160
raise RuntimeError("Probability space is not", LIN, "or", LOG)
161161

162+
def is_fixed(self, node_id):
163+
return self.row_lookup[node_id] < 0
164+
162165
def __getitem__(self, node_id):
163166
index = self.row_lookup[node_id]
164167
if index < 0:

tsdate/prior.py

Lines changed: 28 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1032,14 +1032,10 @@ def shape_scale_from_mean_var(mean, var):
10321032

10331033
def _truncate_priors(ts, priors, progress=False):
10341034
"""
1035-
Truncate priors for the nodes listed in truncate_nodes (or all nonfixed nodes
1036-
if truncate_nodes is None) so they conform to the age of fixed nodes in the tree
1037-
sequence
1035+
Truncate priors for all nonfixed nodes
1036+
so they conform to the age of fixed nodes in the tree sequence
10381037
"""
10391038
tables = ts.tables
1040-
truncate_nodes = priors.nonfixed_node_ids()
1041-
# ensure truncate_nodes is ordered by node time
1042-
truncate_nodes = truncate_nodes[np.argsort(tables.nodes.time[truncate_nodes])]
10431039

10441040
fixed_nodes = priors.fixed_node_ids()
10451041
fixed_times = tables.nodes.time[fixed_nodes]
@@ -1055,29 +1051,32 @@ def _truncate_priors(ts, priors, progress=False):
10551051
constrained_min_times = np.zeros_like(tables.nodes.time)
10561052
# Set the min times of fixed nodes to those in the tree sequence
10571053
constrained_min_times[fixed_nodes] = fixed_times
1058-
constrained_max_times = np.full_like(constrained_min_times, np.inf)
1059-
1060-
parents = tables.edges.parent
1061-
nd_children = tables.edges.child[np.argsort(parents)]
1062-
parents = sorted(parents)
1063-
parents_unique = np.unique(parents, return_index=True)
1064-
parent_indices = parents_unique[1][np.isin(parents_unique[0], truncate_nodes)]
1065-
for index, nd in tqdm(
1066-
enumerate(truncate_nodes), desc="Constrain Ages", disable=not progress
1054+
1055+
# Traverse through the ARG, ensuring children come before parents.
1056+
# This can be done by iterating over groups of edges with the same parent
1057+
new_parent_edge_idx = np.concatenate(
1058+
(
1059+
[0],
1060+
np.where(np.diff(tables.edges.parent) != 0)[0] + 1,
1061+
[tables.edges.num_rows],
1062+
)
1063+
)
1064+
for edges_start, edges_end in zip(
1065+
new_parent_edge_idx[:-1], new_parent_edge_idx[1:]
10671066
):
1068-
if index + 1 != len(truncate_nodes):
1069-
children_index = np.arange(parent_indices[index], parent_indices[index + 1])
1070-
else:
1071-
children_index = np.arange(parent_indices[index], ts.num_edges)
1072-
children = nd_children[children_index]
1073-
time = np.max(constrained_min_times[children])
1074-
# The constrained time of the node should be the age of the oldest child
1075-
if constrained_min_times[nd] <= time:
1076-
constrained_min_times[nd] = time
1077-
nearest_time = np.argmin(np.abs(timepoints - time))
1078-
lookup_index = priors.row_lookup[int(nd)]
1079-
grid_data[lookup_index][:nearest_time] = zero_value
1080-
assert np.all(constrained_min_times < constrained_max_times)
1067+
parent = tables.edges.parent[edges_start]
1068+
child_ids = tables.edges.child[edges_start:edges_end] # May contain dups
1069+
oldest_child_time = np.max(constrained_min_times[child_ids])
1070+
if oldest_child_time > constrained_min_times[parent]:
1071+
if priors.is_fixed(parent):
1072+
raise ValueError(
1073+
"Invalid fixed times: time for"
1074+
+ f"fixed node {parent} is younger than some of its descendants"
1075+
)
1076+
constrained_min_times[parent] = oldest_child_time
1077+
if constrained_min_times[parent] > 0 and not priors.is_fixed(parent):
1078+
nearest_time = np.argmin(np.abs(timepoints - constrained_min_times[parent]))
1079+
grid_data[priors.row_lookup[parent]][:nearest_time] = zero_value
10811080

10821081
rowmax = grid_data[:, 1:].max(axis=1)
10831082
if priors.probability_space == "linear":
@@ -1138,7 +1137,7 @@ def build_grid(
11381137
:param dict node_var_override: is a dict mapping node IDs to a variance value.
11391138
Any nodes listed here will be treated as non-fixed nodes whose prior is not
11401139
calculated from the conditional coalescent but instead are allocated a prior
1141-
whose mean is thenode time in the tree sequence and whose variance is the
1140+
whose mean is the node time in the tree sequence and whose variance is the
11421141
value in this dictionary. This allows sample nodes to be treated as nonfixed
11431142
nodes, and therefore dated. If ``None`` (default) then all sample nodes are
11441143
treated as occurring at a fixed time (as if this were an empty dict).

0 commit comments

Comments
 (0)