Skip to content

Commit b877807

Browse files
awohns authored and hyanwong committed
Allow ancient samples
1 parent 7737a55 commit b877807

File tree

2 files changed: 93 additions and 31 deletions

tsdate/core.py

Lines changed: 27 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -151,7 +151,7 @@ def _lik(muts, span, dt, mutation_rate, normalize=True):
151151
"""
152152
ll = scipy.stats.poisson.pmf(muts, dt * mutation_rate * span)
153153
if normalize:
154-
return ll / np.max(ll)
154+
return ll / np.nanmax(ll)
155155
else:
156156
return ll
157157

@@ -258,15 +258,28 @@ def get_mut_lik_fixed_node(self, edge):
258258

259259
mutations_on_edge = self.mut_edges[edge.id]
260260
child_time = self.ts.node(edge.child).time
261-
assert child_time == 0
262-
# Temporary hack - we should really take a more precise likelihood
263-
return self._lik(
264-
mutations_on_edge,
265-
edge.span,
266-
self.timediff,
267-
self.mut_rate,
268-
normalize=self.normalize,
269-
)
261+
if child_time == 0:
262+
return self._lik(
263+
mutations_on_edge,
264+
edge.span,
265+
self.timediff,
266+
self.mut_rate,
267+
normalize=self.normalize,
268+
)
269+
else:
270+
timediff = self.timepoints - child_time + 1e-8
271+
# Temporary hack - we should really take a more precise likelihood
272+
likelihood = self._lik(
273+
mutations_on_edge,
274+
edge.span,
275+
timediff,
276+
self.mut_rate,
277+
normalize=self.normalize,
278+
)
279+
# Prevent child from being older than parent
280+
likelihood[timediff < 0] = 0
281+
282+
return likelihood
270283

271284
def get_mut_lik_lower_tri(self, edge):
272285
"""
@@ -429,7 +442,7 @@ def _lik(muts, span, dt, mutation_rate, normalize=True):
429442
"""
430443
ll = scipy.stats.poisson.logpmf(muts, dt * mutation_rate * span)
431444
if normalize:
432-
return ll - np.max(ll)
445+
return ll - np.nanmax(ll)
433446
else:
434447
return ll
435448

@@ -677,10 +690,13 @@ def inside_pass(self, *, normalize=True, cache_inside=False, progress=None):
677690
)
678691
edge_lik = self.lik.get_inside(daughter_val, edge)
679692
val = self.lik.combine(val, edge_lik)
693+
if np.all(val == 0):
694+
raise ValueError
680695
if cache_inside:
681696
g_i[edge.id] = edge_lik
682697
norm[parent] = np.max(val) if normalize else 1
683698
inside[parent] = self.lik.reduce(val, norm[parent])
699+
684700
if cache_inside:
685701
self.g_i = self.lik.reduce(g_i, norm[self.ts.tables.edges.child, None])
686702
# Keep the results in this object
@@ -1064,10 +1080,6 @@ def get_dates(
10641080
10651081
:return: tuple(mn_post, posterior, timepoints, eps, nodes_to_date)
10661082
"""
1067-
# Stuff yet to be implemented. These can be deleted once fixed
1068-
for sample in tree_sequence.samples():
1069-
if tree_sequence.node(sample).time != 0:
1070-
raise NotImplementedError("Samples must all be at time 0")
10711083
fixed_nodes = set(tree_sequence.samples())
10721084

10731085
# Default to not creating approximate priors unless ts has > 1000 samples

tsdate/prior.py

Lines changed: 66 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -419,10 +419,10 @@ def __init__(self, tree_sequence, *, progress=False, allow_unary=False):
419419

420420
self.ts = tree_sequence
421421
self.sample_node_set = set(self.ts.samples())
422-
if np.any(self.ts.tables.nodes.time[self.ts.samples()] != 0):
423-
raise ValueError(
424-
"The SpansBySamples class needs a tree seq with all samples at time 0"
425-
)
422+
#if np.any(self.ts.tables.nodes.time[self.ts.samples()] != 0):
423+
# raise ValueError(
424+
# "The SpansBySamples class needs a tree seq with all samples at time 0"
425+
# )
426426
self.progress = progress
427427

428428
# We will store the spans in here, and normalize them at the end
@@ -996,6 +996,59 @@ def fill_priors(node_parameters, timepoints, ts, Ne, *, prior_distr, progress=Fa
996996
return prior_times
997997

998998

999+
def truncate_priors(ts, sample_times, priors, nodes_to_date=None, progress=False):
1000+
"""
1001+
Truncate priors so they conform to the age of nodes in the tree sequence
1002+
"""
1003+
grid_data = np.copy(priors.grid_data[:])
1004+
timepoints = priors.timepoints
1005+
if np.max(sample_times) >= np.max(timepoints):
1006+
raise ValueError("Sample times cannot be larger than the oldest timepoint")
1007+
if priors.probability_space == "linear":
1008+
zero_value = 0
1009+
one_value = 1
1010+
elif priors.probability_space == "logarithmic":
1011+
zero_value = -np.inf
1012+
one_value = 0
1013+
constrained_min_times = np.copy(sample_times)
1014+
constrained_max_times = np.full(sample_times.shape[0], np.inf)
1015+
if nodes_to_date is None:
1016+
nodes_to_date = np.arange(ts.num_nodes, dtype=np.uint64)
1017+
nodes_to_date = nodes_to_date[~np.isin(nodes_to_date, ts.samples())]
1018+
1019+
tables = ts.tables
1020+
parents = tables.edges.parent
1021+
nd_children = tables.edges.child[np.argsort(parents)]
1022+
parents = sorted(parents)
1023+
parents_unique = np.unique(parents, return_index=True)
1024+
parent_indices = parents_unique[1][np.isin(parents_unique[0], nodes_to_date)]
1025+
for index, nd in tqdm(
1026+
enumerate(sorted(nodes_to_date)), desc="Constrain Ages", disable=not progress
1027+
):
1028+
if index + 1 != len(nodes_to_date):
1029+
children_index = np.arange(parent_indices[index], parent_indices[index + 1])
1030+
else:
1031+
children_index = np.arange(parent_indices[index], ts.num_edges)
1032+
children = nd_children[children_index]
1033+
time = np.max(constrained_min_times[children])
1034+
# The constrained time of the node should be the age of the oldest child
1035+
if constrained_min_times[nd] <= time:
1036+
constrained_min_times[nd] = time
1037+
nearest_time = np.argmin(np.abs(timepoints - time))
1038+
lookup_index = priors.row_lookup[int(nd)]
1039+
grid_data[lookup_index][:nearest_time] = zero_value
1040+
assert np.all(constrained_min_times < constrained_max_times)
1041+
all_zeros = np.where(np.all(grid_data == zero_value, axis=1))[0]
1042+
1043+
rowmax = grid_data[:, 1:].max(axis=1)
1044+
if priors.probability_space == "linear":
1045+
grid_data = grid_data / rowmax[:, np.newaxis]
1046+
elif priors.probability_space == "logarithmic":
1047+
grid_data = grid_data - rowmax[:, np.newaxis]
1048+
1049+
priors.grid_data[:] = grid_data
1050+
return constrained_min_times, constrained_max_times, priors
1051+
9991052
def build_grid(
10001053
tree_sequence,
10011054
Ne,
@@ -1007,7 +1060,7 @@ def build_grid(
10071060
eps=1e-6,
10081061
# Parameters below undocumented
10091062
progress=False,
1010-
allow_unary=False,
1063+
sample_times=None
10111064
):
10121065
"""
10131066
Using the conditional coalescent, calculate the prior distribution for the age of
@@ -1038,6 +1091,8 @@ def build_grid(
10381091
inference and a discretised time grid
10391092
:rtype: base.NodeGridValues Object
10401093
"""
1094+
#tree_sequence = tree_sequence.simplify(tree_sequence.samples())
1095+
10411096
if Ne <= 0:
10421097
raise ValueError("Parameter 'Ne' must be greater than 0")
10431098
if approximate_priors:
@@ -1049,19 +1104,13 @@ def build_grid(
10491104
"Can't set approx_prior_size if approximate_prior is False"
10501105
)
10511106

1052-
contmpr_ts, node_map = util.reduce_to_contemporaneous(tree_sequence)
1053-
if contmpr_ts.num_nodes != tree_sequence.num_nodes:
1054-
raise ValueError(
1055-
"Passed tree sequence is not simplified and/or contains "
1056-
"noncontemporaneous samples"
1057-
)
1058-
span_data = SpansBySamples(contmpr_ts, progress=progress, allow_unary=allow_unary)
1107+
span_data = SpansBySamples(tree_sequence, progress=progress)
10591108

10601109
base_priors = ConditionalCoalescentTimes(
10611110
approx_prior_size, Ne, prior_distribution, progress=progress
10621111
)
10631112

1064-
base_priors.add(contmpr_ts.num_samples, approximate_priors)
1113+
base_priors.add(tree_sequence.num_samples, approximate_priors)
10651114
for total_fixed in span_data.total_fixed_at_0_counts:
10661115
# For missing data: trees vary in total fixed node count => have different priors
10671116
if total_fixed > 0:
@@ -1085,9 +1134,7 @@ def build_grid(
10851134
else:
10861135
raise ValueError("time_slices must be an integer or a numpy array of floats")
10871136

1088-
prior_params_contmpr = base_priors.get_mixture_prior_params(span_data)
1089-
# Map the nodes in the prior params back to the node ids in the original ts
1090-
prior_params = prior_params_contmpr[node_map, :]
1137+
prior_params = base_priors.get_mixture_prior_params(span_data)
10911138
# Set all fixed nodes (i.e. samples) to have 0 variance
10921139
priors = fill_priors(
10931140
prior_params,
@@ -1097,4 +1144,7 @@ def build_grid(
10971144
prior_distr=prior_distribution,
10981145
progress=progress,
10991146
)
1147+
if np.any(tree_sequence.tables.nodes.time[tree_sequence.samples()] != 0):
1148+
if False:
1149+
priors = truncate_priors(tree_sequence, sample_times, priors, eps, progress=progress)
11001150
return priors

0 commit comments

Comments
 (0)