Skip to content

Commit 03cc55a

Browse files
committed
Allow truncate_prior to actually function!
1 parent b5e9bba commit 03cc55a

File tree

1 file changed

+18
-18
lines changed

1 file changed

+18
-18
lines changed

tsdate/prior.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -419,7 +419,7 @@ def __init__(self, tree_sequence, progress=False):
419419

420420
self.ts = tree_sequence
421421
self.sample_node_set = set(self.ts.samples())
422-
#if np.any(self.ts.tables.nodes.time[self.ts.samples()] != 0):
422+
# if np.any(self.ts.tables.nodes.time[self.ts.samples()] != 0):
423423
# raise ValueError(
424424
# "The SpansBySamples class needs a tree seq with all samples at time 0"
425425
# )
@@ -991,34 +991,36 @@ def fill_priors(node_parameters, timepoints, ts, Ne, *, prior_distr, progress=Fa
991991
return prior_times
992992

993993

994-
def truncate_priors(ts, sample_times, priors, nodes_to_date=None, progress=False):
994+
def truncate_priors(ts, priors, nodes_to_date=None, progress=False):
995995
"""
996996
Truncate priors so they conform to the age of nodes in the tree sequence
997997
"""
998+
tables = ts.tables
999+
if nodes_to_date is None:
1000+
nodes_to_date = np.arange(ts.num_nodes, dtype=np.uint64)
1001+
nodes_to_date = nodes_to_date[~np.isin(nodes_to_date, ts.samples())]
1002+
# ensure nodes_to_date is ordered by node time
1003+
nodes_to_date = nodes_to_date[np.argsort(tables.nodes.time[nodes_to_date])]
1004+
9981005
grid_data = np.copy(priors.grid_data[:])
9991006
timepoints = priors.timepoints
1000-
if np.max(sample_times) >= np.max(timepoints):
1007+
if np.max(tables.nodes.time[ts.samples()]) >= np.max(timepoints):
10011008
raise ValueError("Sample times cannot be larger than the oldest timepoint")
10021009
if priors.probability_space == "linear":
10031010
zero_value = 0
1004-
one_value = 1
10051011
elif priors.probability_space == "logarithmic":
10061012
zero_value = -np.inf
1007-
one_value = 0
1008-
constrained_min_times = np.copy(sample_times)
1009-
constrained_max_times = np.full(sample_times.shape[0], np.inf)
1010-
if nodes_to_date is None:
1011-
nodes_to_date = np.arange(ts.num_nodes, dtype=np.uint64)
1012-
nodes_to_date = nodes_to_date[~np.isin(nodes_to_date, ts.samples())]
1013+
constrained_min_times = np.copy(tables.nodes.time)
1014+
constrained_min_times[ts.samples()] = tables.nodes.time[ts.samples()]
1015+
constrained_max_times = np.full_like(constrained_min_times, np.inf)
10131016

1014-
tables = ts.tables
10151017
parents = tables.edges.parent
10161018
nd_children = tables.edges.child[np.argsort(parents)]
10171019
parents = sorted(parents)
10181020
parents_unique = np.unique(parents, return_index=True)
10191021
parent_indices = parents_unique[1][np.isin(parents_unique[0], nodes_to_date)]
10201022
for index, nd in tqdm(
1021-
enumerate(sorted(nodes_to_date)), desc="Constrain Ages", disable=not progress
1023+
enumerate(nodes_to_date), desc="Constrain Ages", disable=not progress
10221024
):
10231025
if index + 1 != len(nodes_to_date):
10241026
children_index = np.arange(parent_indices[index], parent_indices[index + 1])
@@ -1033,17 +1035,17 @@ def truncate_priors(ts, sample_times, priors, nodes_to_date=None, progress=False
10331035
lookup_index = priors.row_lookup[int(nd)]
10341036
grid_data[lookup_index][:nearest_time] = zero_value
10351037
assert np.all(constrained_min_times < constrained_max_times)
1036-
all_zeros = np.where(np.all(grid_data == zero_value, axis=1))[0]
10371038

10381039
rowmax = grid_data[:, 1:].max(axis=1)
10391040
if priors.probability_space == "linear":
10401041
grid_data = grid_data / rowmax[:, np.newaxis]
10411042
elif priors.probability_space == "logarithmic":
10421043
grid_data = grid_data - rowmax[:, np.newaxis]
1043-
1044+
10441045
priors.grid_data[:] = grid_data
10451046
return constrained_min_times, constrained_max_times, priors
10461047

1048+
10471049
def build_grid(
10481050
tree_sequence,
10491051
Ne,
@@ -1054,7 +1056,6 @@ def build_grid(
10541056
prior_distribution="lognorm",
10551057
eps=1e-6,
10561058
progress=False,
1057-
sample_times=None
10581059
):
10591060
"""
10601061
Using the conditional coalescent, calculate the prior distribution for the age of
@@ -1085,7 +1086,7 @@ def build_grid(
10851086
inference and a discretised time grid
10861087
:rtype: base.NodeGridValues Object
10871088
"""
1088-
#tree_sequence = tree_sequence.simplify(tree_sequence.samples())
1089+
# tree_sequence = tree_sequence.simplify(tree_sequence.samples())
10891090

10901091
if Ne <= 0:
10911092
raise ValueError("Parameter 'Ne' must be greater than 0")
@@ -1139,6 +1140,5 @@ def build_grid(
11391140
progress=progress,
11401141
)
11411142
if np.any(tree_sequence.tables.nodes.time[tree_sequence.samples()] != 0):
1142-
if False:
1143-
priors = truncate_priors(tree_sequence, sample_times, priors, eps, progress=progress)
1143+
_, _, priors = truncate_priors(tree_sequence, priors, progress=progress)
11441144
return priors

0 commit comments

Comments
 (0)