Skip to content

Commit 34c383c

Browse files
committed
WIP
1 parent eef71d9 commit 34c383c

File tree

1 file changed

+57
-0
lines changed

1 file changed

+57
-0
lines changed

tsdate/prior.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -991,6 +991,59 @@ def fill_priors(node_parameters, timepoints, ts, Ne, *, prior_distr, progress=Fa
991991
return prior_times
992992

993993

994+
def truncate_priors(ts, sample_times, priors, nodes_to_date=None, progress=False):
995+
"""
996+
Truncate priors so they conform to the age of nodes in the tree sequence
997+
"""
998+
grid_data = np.copy(priors.grid_data[:])
999+
timepoints = priors.timepoints
1000+
if np.max(sample_times) >= np.max(timepoints):
1001+
raise ValueError("Sample times cannot be larger than the oldest timepoint")
1002+
if priors.probability_space == "linear":
1003+
zero_value = 0
1004+
one_value = 1
1005+
elif priors.probability_space == "logarithmic":
1006+
zero_value = -np.inf
1007+
one_value = 0
1008+
constrained_min_times = np.copy(sample_times)
1009+
constrained_max_times = np.full(sample_times.shape[0], np.inf)
1010+
if nodes_to_date is None:
1011+
nodes_to_date = np.arange(ts.num_nodes, dtype=np.uint64)
1012+
nodes_to_date = nodes_to_date[~np.isin(nodes_to_date, ts.samples())]
1013+
1014+
tables = ts.tables
1015+
parents = tables.edges.parent
1016+
nd_children = tables.edges.child[np.argsort(parents)]
1017+
parents = sorted(parents)
1018+
parents_unique = np.unique(parents, return_index=True)
1019+
parent_indices = parents_unique[1][np.isin(parents_unique[0], nodes_to_date)]
1020+
for index, nd in tqdm(
1021+
enumerate(sorted(nodes_to_date)), desc="Constrain Ages", disable=not progress
1022+
):
1023+
if index + 1 != len(nodes_to_date):
1024+
children_index = np.arange(parent_indices[index], parent_indices[index + 1])
1025+
else:
1026+
children_index = np.arange(parent_indices[index], ts.num_edges)
1027+
children = nd_children[children_index]
1028+
time = np.max(constrained_min_times[children])
1029+
# The constrained time of the node should be the age of the oldest child
1030+
if constrained_min_times[nd] <= time:
1031+
constrained_min_times[nd] = time
1032+
nearest_time = np.argmin(np.abs(timepoints - time))
1033+
lookup_index = priors.row_lookup[int(nd)]
1034+
grid_data[lookup_index][:nearest_time] = zero_value
1035+
assert np.all(constrained_min_times < constrained_max_times)
1036+
all_zeros = np.where(np.all(grid_data == zero_value, axis=1))[0]
1037+
1038+
rowmax = grid_data[:, 1:].max(axis=1)
1039+
if priors.probability_space == "linear":
1040+
grid_data = grid_data / rowmax[:, np.newaxis]
1041+
elif priors.probability_space == "logarithmic":
1042+
grid_data = grid_data - rowmax[:, np.newaxis]
1043+
1044+
priors.grid_data[:] = grid_data
1045+
return constrained_min_times, constrained_max_times, priors
1046+
9941047
def build_grid(
9951048
tree_sequence,
9961049
Ne,
@@ -1001,6 +1054,7 @@ def build_grid(
10011054
prior_distribution="lognorm",
10021055
eps=1e-6,
10031056
progress=False,
1057+
sample_times=None
10041058
):
10051059
"""
10061060
Using the conditional coalescent, calculate the prior distribution for the age of
@@ -1084,4 +1138,7 @@ def build_grid(
10841138
prior_distr=prior_distribution,
10851139
progress=progress,
10861140
)
1141+
if np.any(tree_sequence.tables.nodes.time[tree_sequence.samples()] != 0):
1142+
if False:
1143+
priors = truncate_priors(tree_sequence, sample_times, priors, eps, progress=progress)
10871144
return priors

0 commit comments

Comments
 (0)