Skip to content

Commit 8e28db2

Browse files
committed
Rework build-prior and inside / outside logic to allow historical samples
And speed up time constraint algorithms while also allowing nodes to be out of time order
1 parent b877807 commit 8e28db2

File tree

6 files changed

+218
-116
lines changed

6 files changed

+218
-116
lines changed

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
tskit>=0.4.0
1+
tskit>=0.5.2
22
tsinfer>=0.2.0
33
flake8
44
numpy

tests/test_functions.py

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -742,14 +742,16 @@ def test_logsumexp(self):
742742
assert np.allclose(LogLikelihoods.logsumexp(log_lls), np.log(ll_sum))
743743

744744
def test_zeros_logsumexp(self):
745-
lls = np.log(np.concatenate([np.zeros(100), np.random.rand(1000)]))
746-
assert np.allclose(LogLikelihoods.logsumexp(lls), self.naive_logsumexp(lls))
745+
with np.errstate(divide="ignore"):
746+
lls = np.log(np.concatenate([np.zeros(100), np.random.rand(1000)]))
747+
assert np.allclose(LogLikelihoods.logsumexp(lls), self.naive_logsumexp(lls))
747748

748749
def test_logsumexp_underflow(self):
749750
# underflow in the naive case, but not in the LogLikelihoods implementation
750-
lls = np.array([-1000, -1001])
751-
assert self.naive_logsumexp(lls) == -np.inf
752-
assert LogLikelihoods.logsumexp(lls) != -np.inf
751+
with np.errstate(divide="ignore"):
752+
lls = np.array([-1000, -1001])
753+
assert self.naive_logsumexp(lls) == -np.inf
754+
assert LogLikelihoods.logsumexp(lls) != -np.inf
753755

754756
def test_log_tri_functions(self):
755757
ts = utility_functions.two_tree_mutation_ts()
@@ -1047,7 +1049,7 @@ def test_dangling_fails(self):
10471049
print(ts.draw_text())
10481050
print("Samples:", ts.samples())
10491051
Ne = 0.5
1050-
with pytest.raises(ValueError, match="simplified"):
1052+
with pytest.raises(ValueError, match="simplify"):
10511053
tsdate.build_prior_grid(ts, Ne, timepoints=np.array([0, 1.2, 2]))
10521054
# mut_rate = 1
10531055
# eps = 1e-6
@@ -1420,7 +1422,7 @@ def test_date_input(self):
14201422

14211423
def test_sample_as_parent_fails(self):
14221424
ts = utility_functions.single_tree_ts_n3_sample_as_parent()
1423-
with pytest.raises(NotImplementedError):
1425+
with pytest.raises(ValueError, match="samples at non-zero times"):
14241426
tsdate.date(ts, mutation_rate=None, Ne=1)
14251427

14261428
def test_recombination_not_implemented(self):
@@ -1531,18 +1533,7 @@ def test_constrain_ages_topo(self):
15311533
ts = utility_functions.two_tree_ts()
15321534
post_mn = np.array([0.0, 0.0, 0.0, 2.0, 1.0, 3.0])
15331535
eps = 1e-6
1534-
nodes_to_date = np.array([3, 4, 5])
1535-
constrained_ages = constrain_ages_topo(ts, post_mn, eps, nodes_to_date)
1536-
assert np.array_equal(
1537-
np.array([0.0, 0.0, 0.0, 2.0, 2.000001, 3.0]), constrained_ages
1538-
)
1539-
1540-
def test_constrain_ages_topo_no_nodes_to_date(self):
1541-
ts = utility_functions.two_tree_ts()
1542-
post_mn = np.array([0.0, 0.0, 0.0, 2.0, 1.0, 3.0])
1543-
eps = 1e-6
1544-
nodes_to_date = None
1545-
constrained_ages = constrain_ages_topo(ts, post_mn, eps, nodes_to_date)
1536+
constrained_ages = constrain_ages_topo(ts, post_mn, eps)
15461537
assert np.array_equal(
15471538
np.array([0.0, 0.0, 0.0, 2.0, 2.000001, 3.0]), constrained_ages
15481539
)

tests/test_inference.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def test_bad_Ne(self):
6161

6262
def test_dangling_failure(self):
6363
ts = utility_functions.single_tree_ts_n2_dangling()
64-
with pytest.raises(ValueError, match="simplified"):
64+
with pytest.raises(ValueError, match="simplify"):
6565
tsdate.date(ts, mutation_rate=None, Ne=1)
6666

6767
def test_unary_failure(self):
@@ -271,16 +271,29 @@ def test_fails_multi_root(self):
271271
with pytest.raises(ValueError):
272272
tsdate.date(multiroot_ts, Ne=1, mutation_rate=2, priors=good_priors)
273273

274-
def test_non_contemporaneous(self):
274+
def test_non_contemporaneous_warn(self):
275275
samples = [
276276
msprime.Sample(population=0, time=0),
277277
msprime.Sample(population=0, time=0),
278278
msprime.Sample(population=0, time=0),
279279
msprime.Sample(population=0, time=1.0),
280280
]
281281
ts = msprime.simulate(samples=samples, Ne=1, mutation_rate=2, random_seed=12)
282-
with pytest.raises(NotImplementedError):
282+
with pytest.raises(ValueError, match="samples at non-zero times"):
283283
tsdate.date(ts, Ne=1, mutation_rate=2)
284+
with pytest.raises(ValueError, match="samples at non-zero times"):
285+
tsdate.build_prior_grid(ts, Ne=1)
286+
287+
def test_non_contemporaneous(self):
288+
samples = [
289+
msprime.Sample(population=0, time=0),
290+
msprime.Sample(population=0, time=0),
291+
msprime.Sample(population=0, time=0),
292+
msprime.Sample(population=0, time=1.0),
293+
]
294+
ts = msprime.simulate(samples=samples, Ne=1, mutation_rate=2, random_seed=12)
295+
priors = tsdate.build_prior_grid(ts, Ne=1, allow_historical_samples=True)
296+
tsdate.date(ts, priors=priors, mutation_rate=2)
284297

285298
def test_no_mutation_times(self):
286299
ts = msprime.simulate(20, Ne=1, mutation_rate=1, random_seed=12)

tsdate/base.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,12 @@ def __init__(
9595
] = (-np.arange(num_nodes - self.num_nonfixed) - 1)
9696
self.probability_space = LIN
9797

98+
def fixed_node_ids(self):
99+
return np.where(self.row_lookup < 0)[0]
100+
101+
def nonfixed_node_ids(self):
102+
return np.where(self.row_lookup >= 0)[0]
103+
98104
def force_probability_space(self, probability_space):
99105
"""
100106
probability_space can be "logarithmic" or "linear": this function will force
@@ -119,7 +125,7 @@ def force_probability_space(self, probability_space):
119125
if self.probability_space == LOG:
120126
pass
121127
elif self.probability_space == LIN:
122-
with np.errstate(divide="ignore"):
128+
with np.errstate(divide="ignore", invalid="ignore"):
123129
self.grid_data = np.log(self.grid_data)
124130
self.fixed_data = np.log(self.fixed_data)
125131
self.probability_space = LOG
@@ -140,6 +146,9 @@ def normalize(self):
140146
else:
141147
raise RuntimeError("Probability space is not", LIN, "or", LOG)
142148

149+
def is_fixed(self, node_id):
150+
return self.row_lookup[node_id] < 0
151+
143152
def __getitem__(self, node_id):
144153
index = self.row_lookup[node_id]
145154
if index < 0:
@@ -207,8 +216,7 @@ def fill_fixed(orig, fixed_data):
207216
new_obj.fixed_data = fill_fixed(
208217
self, grid_data if fixed_data is None else fixed_data
209218
)
210-
if probability_space is None:
211-
new_obj.probability_space = self.probability_space
212-
else:
213-
new_obj.probability_space = probability_space
219+
new_obj.probability_space = self.probability_space
220+
if probability_space is not None:
221+
new_obj.force_probability_space(probability_space)
214222
return new_obj

tsdate/core.py

Lines changed: 46 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,7 @@ def get_fixed(self, arr, edge):
402402
return arr * liks
403403

404404
def scale_geometric(self, fraction, value):
405-
return value**fraction
405+
return value ** fraction
406406

407407

408408
class LogLikelihoods(Likelihoods):
@@ -647,11 +647,22 @@ def inside_pass(self, *, normalize=True, cache_inside=False, progress=None):
647647
inside = self.priors.clone_with_new_data( # store inside matrix values
648648
grid_data=np.nan, fixed_data=self.lik.identity_constant
649649
)
650+
# It is possible that a simple node is non-fixed, in which case we want to
651+
# provide an inside array that reflects the prior distribution
652+
nonfixed_samples = np.intersect1d(inside.nonfixed_node_ids(), self.ts.samples())
653+
for u in nonfixed_samples:
654+
# this is in the same probability space as the prior, so we should be
655+
# OK just to copy the prior values straight in. It's unclear to me (Yan)
656+
# how/if they should be normalised, however
657+
inside[u][:] = self.priors[u]
658+
650659
if cache_inside:
651660
g_i = np.full(
652661
(self.ts.num_edges, self.lik.grid_size), self.lik.identity_constant
653662
)
654663
norm = np.full(self.ts.num_nodes, np.nan)
664+
to_visit = np.zeros(self.ts.num_nodes, dtype=bool)
665+
to_visit[inside.nonfixed_node_ids()] = True
655666
# Iterate through the nodes via groupby on parent node
656667
for parent, edges in tqdm(
657668
self.edges_by_parent_asc(),
@@ -686,16 +697,22 @@ def inside_pass(self, *, normalize=True, cache_inside=False, progress=None):
686697
"dangling nodes: please simplify it"
687698
)
688699
daughter_val = self.lik.scale_geometric(
689-
spanfrac, self.lik.make_lower_tri(inside[edge.child])
700+
spanfrac, self.lik.make_lower_tri(inside_values)
690701
)
691702
edge_lik = self.lik.get_inside(daughter_val, edge)
692703
val = self.lik.combine(val, edge_lik)
693704
if np.all(val == 0):
694705
raise ValueError
695706
if cache_inside:
696707
g_i[edge.id] = edge_lik
697-
norm[parent] = np.max(val) if normalize else 1
708+
norm[parent] = np.max(val) if normalize else self.lik.identity_constant
698709
inside[parent] = self.lik.reduce(val, norm[parent])
710+
to_visit[parent] = False
711+
712+
# There may be nodes that are not parents but are also not fixed (e.g.
713+
# undated sample nodes). These need an identity normalization constant
714+
for unfixed_unvisited in np.where(to_visit)[0]:
715+
norm[unfixed_unvisited] = self.lik.identity_constant
699716

700717
if cache_inside:
701718
self.g_i = self.lik.reduce(g_i, norm[self.ts.tables.edges.child, None])
@@ -913,34 +930,32 @@ def posterior_mean_var(ts, posterior, *, fixed_node_set=None):
913930
return ts, mn_post, vr_post
914931

915932

916-
def constrain_ages_topo(ts, post_mn, eps, nodes_to_date=None, progress=False):
933+
def constrain_ages_topo(ts, node_times, eps, progress=False):
917934
"""
918-
If predicted node times violate topology, restrict node ages so that they
919-
must be older than all their children.
935+
If node_times violate topology, return increased node_times so that each node is
936+
guaranteed to be older than any of its their children.
920937
"""
921-
new_mn_post = np.copy(post_mn)
922-
if nodes_to_date is None:
923-
nodes_to_date = np.arange(ts.num_nodes, dtype=np.uint64)
924-
nodes_to_date = nodes_to_date[~np.isin(nodes_to_date, ts.samples())]
925-
926-
tables = ts.tables
927-
parents = tables.edges.parent
928-
nd_children = tables.edges.child[np.argsort(parents)]
929-
parents = sorted(parents)
930-
parents_unique = np.unique(parents, return_index=True)
931-
parent_indices = parents_unique[1][np.isin(parents_unique[0], nodes_to_date)]
932-
for index, nd in tqdm(
933-
enumerate(sorted(nodes_to_date)), desc="Constrain Ages", disable=not progress
938+
edges_parent = ts.edges_parent
939+
edges_child = ts.edges_child
940+
941+
new_node_times = np.copy(node_times)
942+
# Traverse through the ARG, ensuring children come before parents.
943+
# This can be done by iterating over groups of edges with the same parent
944+
new_parent_edge_idx = np.where(np.diff(edges_parent) != 0)[0] + 1
945+
for edges_start, edges_end in tqdm(
946+
zip(
947+
itertools.chain([0], new_parent_edge_idx),
948+
itertools.chain(new_parent_edge_idx, [len(edges_parent)]),
949+
),
950+
desc="Constrain Ages",
951+
disable=not progress,
934952
):
935-
if index + 1 != len(nodes_to_date):
936-
children_index = np.arange(parent_indices[index], parent_indices[index + 1])
937-
else:
938-
children_index = np.arange(parent_indices[index], ts.num_edges)
939-
children = nd_children[children_index]
940-
time = np.max(new_mn_post[children])
941-
if new_mn_post[nd] <= time:
942-
new_mn_post[nd] = time + eps
943-
return new_mn_post
953+
parent = edges_parent[edges_start]
954+
child_ids = edges_child[edges_start:edges_end] # May contain dups
955+
oldest_child_time = np.max(new_node_times[child_ids])
956+
if oldest_child_time >= new_node_times[parent]:
957+
new_node_times[parent] = oldest_child_time + eps
958+
return new_node_times
944959

945960

946961
def date(
@@ -1031,7 +1046,7 @@ def date(
10311046
progress=progress,
10321047
**kwargs
10331048
)
1034-
constrained = constrain_ages_topo(tree_sequence, dates, eps, nds, progress)
1049+
constrained = constrain_ages_topo(tree_sequence, dates, eps, progress)
10351050
tables = tree_sequence.dump_tables()
10361051
tables.time_units = time_units
10371052
tables.nodes.time = constrained
@@ -1080,8 +1095,6 @@ def get_dates(
10801095
10811096
:return: tuple(mn_post, posterior, timepoints, eps, nodes_to_date)
10821097
"""
1083-
fixed_nodes = set(tree_sequence.samples())
1084-
10851098
# Default to not creating approximate priors unless ts has > 1000 samples
10861099
approx_priors = False
10871100
if tree_sequence.num_samples > 1000:
@@ -1109,6 +1122,8 @@ def get_dates(
11091122
)
11101123
priors = priors
11111124

1125+
fixed_nodes = set(priors.fixed_node_ids())
1126+
11121127
if probability_space != base.LOG:
11131128
liklhd = Likelihoods(
11141129
tree_sequence,

0 commit comments

Comments
 (0)