Skip to content

Commit b313ecd

Browse files
committed
Truncate on the basis of whether the nodes are fixed
1 parent 15bc4e8 commit b313ecd

File tree

2 files changed

+23
-13
lines changed

2 files changed

+23
-13
lines changed

tsdate/base.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,12 @@ def __init__(
9595
] = (-np.arange(num_nodes - self.num_nonfixed) - 1)
9696
self.probability_space = LIN
9797

98+
def fixed_node_ids(self):
99+
return np.where(self.row_lookup < 0)[0]
100+
101+
def nonfixed_node_ids(self):
102+
return np.where(self.row_lookup >= 0)[0]
103+
98104
def force_probability_space(self, probability_space):
99105
"""
100106
probability_space can be "logarithmic" or "linear": this function will force

tsdate/prior.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -991,38 +991,42 @@ def fill_priors(node_parameters, timepoints, ts, Ne, *, prior_distr, progress=Fa
991991
return prior_times
992992

993993

994-
def _truncate_priors(ts, priors, nodes_to_date=None, progress=False):
994+
def _truncate_priors(ts, priors, progress=False):
995995
"""
996-
Truncate priors so they conform to the age of nodes in the tree sequence
996+
Truncate priors for the nodes listed in truncate_nodes (or all nonfixed nodes
997+
if truncate_nodes in None) so they conform to the age of fixed nodes in the tree
998+
sequence
997999
"""
9981000
tables = ts.tables
999-
if nodes_to_date is None:
1000-
nodes_to_date = np.arange(ts.num_nodes, dtype=np.uint64)
1001-
nodes_to_date = nodes_to_date[~np.isin(nodes_to_date, ts.samples())]
1002-
# ensure nodes_to_date is ordered by node time
1003-
nodes_to_date = nodes_to_date[np.argsort(tables.nodes.time[nodes_to_date])]
1001+
truncate_nodes = priors.nonfixed_node_ids()
1002+
# ensure truncate_nodes is ordered by node time
1003+
truncate_nodes = truncate_nodes[np.argsort(tables.nodes.time[truncate_nodes])]
1004+
1005+
fixed_nodes = priors.fixed_node_ids()
1006+
fixed_times = tables.nodes.time[fixed_nodes]
10041007

10051008
grid_data = np.copy(priors.grid_data[:])
10061009
timepoints = priors.timepoints
1007-
if np.max(tables.nodes.time[ts.samples()]) >= np.max(timepoints):
1008-
raise ValueError("Sample times cannot be larger than the oldest timepoint")
1010+
if np.max(fixed_times) >= np.max(timepoints):
1011+
raise ValueError("Fixed node times cannot be older than the oldest timepoint")
10091012
if priors.probability_space == "linear":
10101013
zero_value = 0
10111014
elif priors.probability_space == "logarithmic":
10121015
zero_value = -np.inf
10131016
constrained_min_times = np.zeros_like(tables.nodes.time)
1014-
constrained_min_times[ts.samples()] = tables.nodes.time[ts.samples()]
1017+
# Set the min times of fixed nodes to those in the tree sequence
1018+
constrained_min_times[fixed_nodes] = fixed_times
10151019
constrained_max_times = np.full_like(constrained_min_times, np.inf)
10161020

10171021
parents = tables.edges.parent
10181022
nd_children = tables.edges.child[np.argsort(parents)]
10191023
parents = sorted(parents)
10201024
parents_unique = np.unique(parents, return_index=True)
1021-
parent_indices = parents_unique[1][np.isin(parents_unique[0], nodes_to_date)]
1025+
parent_indices = parents_unique[1][np.isin(parents_unique[0], truncate_nodes)]
10221026
for index, nd in tqdm(
1023-
enumerate(nodes_to_date), desc="Constrain Ages", disable=not progress
1027+
enumerate(truncate_nodes), desc="Constrain Ages", disable=not progress
10241028
):
1025-
if index + 1 != len(nodes_to_date):
1029+
if index + 1 != len(truncate_nodes):
10261030
children_index = np.arange(parent_indices[index], parent_indices[index + 1])
10271031
else:
10281032
children_index = np.arange(parent_indices[index], ts.num_edges)

0 commit comments

Comments
 (0)