Skip to content

Commit 51be191

Browse files
committed
Truncate on the basis of whether the nodes are fixed
1 parent a417394 commit 51be191

File tree

2 files changed

+23
-13
lines changed

2 files changed

+23
-13
lines changed

tsdate/base.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,12 @@ def __init__(
9595
] = (-np.arange(num_nodes - self.num_nonfixed) - 1)
9696
self.probability_space = LIN
9797

98+
def fixed_node_ids(self):
99+
return np.where(self.row_lookup < 0)[0]
100+
101+
def nonfixed_node_ids(self):
102+
return np.where(self.row_lookup >= 0)[0]
103+
98104
def force_probability_space(self, probability_space):
99105
"""
100106
probability_space can be "logarithmic" or "linear": this function will force

tsdate/prior.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -996,38 +996,42 @@ def fill_priors(node_parameters, timepoints, ts, Ne, *, prior_distr, progress=Fa
996996
return prior_times
997997

998998

999-
def _truncate_priors(ts, priors, nodes_to_date=None, progress=False):
999+
def _truncate_priors(ts, priors, progress=False):
10001000
"""
1001-
Truncate priors so they conform to the age of nodes in the tree sequence
1001+
Truncate priors for the nodes listed in truncate_nodes (or all nonfixed nodes
1002+
if truncate_nodes in None) so they conform to the age of fixed nodes in the tree
1003+
sequence
10021004
"""
10031005
tables = ts.tables
1004-
if nodes_to_date is None:
1005-
nodes_to_date = np.arange(ts.num_nodes, dtype=np.uint64)
1006-
nodes_to_date = nodes_to_date[~np.isin(nodes_to_date, ts.samples())]
1007-
# ensure nodes_to_date is ordered by node time
1008-
nodes_to_date = nodes_to_date[np.argsort(tables.nodes.time[nodes_to_date])]
1006+
truncate_nodes = priors.nonfixed_node_ids()
1007+
# ensure truncate_nodes is ordered by node time
1008+
truncate_nodes = truncate_nodes[np.argsort(tables.nodes.time[truncate_nodes])]
1009+
1010+
fixed_nodes = priors.fixed_node_ids()
1011+
fixed_times = tables.nodes.time[fixed_nodes]
10091012

10101013
grid_data = np.copy(priors.grid_data[:])
10111014
timepoints = priors.timepoints
1012-
if np.max(tables.nodes.time[ts.samples()]) >= np.max(timepoints):
1013-
raise ValueError("Sample times cannot be larger than the oldest timepoint")
1015+
if np.max(fixed_times) >= np.max(timepoints):
1016+
raise ValueError("Fixed node times cannot be older than the oldest timepoint")
10141017
if priors.probability_space == "linear":
10151018
zero_value = 0
10161019
elif priors.probability_space == "logarithmic":
10171020
zero_value = -np.inf
10181021
constrained_min_times = np.zeros_like(tables.nodes.time)
1019-
constrained_min_times[ts.samples()] = tables.nodes.time[ts.samples()]
1022+
# Set the min times of fixed nodes to those in the tree sequence
1023+
constrained_min_times[fixed_nodes] = fixed_times
10201024
constrained_max_times = np.full_like(constrained_min_times, np.inf)
10211025

10221026
parents = tables.edges.parent
10231027
nd_children = tables.edges.child[np.argsort(parents)]
10241028
parents = sorted(parents)
10251029
parents_unique = np.unique(parents, return_index=True)
1026-
parent_indices = parents_unique[1][np.isin(parents_unique[0], nodes_to_date)]
1030+
parent_indices = parents_unique[1][np.isin(parents_unique[0], truncate_nodes)]
10271031
for index, nd in tqdm(
1028-
enumerate(nodes_to_date), desc="Constrain Ages", disable=not progress
1032+
enumerate(truncate_nodes), desc="Constrain Ages", disable=not progress
10291033
):
1030-
if index + 1 != len(nodes_to_date):
1034+
if index + 1 != len(truncate_nodes):
10311035
children_index = np.arange(parent_indices[index], parent_indices[index + 1])
10321036
else:
10331037
children_index = np.arange(parent_indices[index], ts.num_edges)

0 commit comments

Comments
 (0)